pax_global_header00006660000000000000000000000064127307464600014523gustar00rootroot0000000000000052 comment=d42167fd04f636e20b005e9934159e95454233c7 golang-github-valyala-fasthttp-20160617/000077500000000000000000000000001273074646000200045ustar00rootroot00000000000000golang-github-valyala-fasthttp-20160617/.gitignore000066400000000000000000000000331273074646000217700ustar00rootroot00000000000000tags *.pprof *.fasthttp.gz golang-github-valyala-fasthttp-20160617/.travis.yml000066400000000000000000000004021273074646000221110ustar00rootroot00000000000000language: go go: - 1.6 script: # build test for supported platforms - GOOS=linux go build - GOOS=darwin go build - GOOS=freebsd go build - GOOS=windows go build - GOARCH=386 go build # run tests on a standard platform - go test -v ./... golang-github-valyala-fasthttp-20160617/LICENSE000066400000000000000000000021211273074646000210050ustar00rootroot00000000000000The MIT License (MIT) Copyright (c) 2015-2016 Aliaksandr Valialkin, VertaMedia Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. golang-github-valyala-fasthttp-20160617/README.md000066400000000000000000000675361273074646000213040ustar00rootroot00000000000000[![Build Status](https://travis-ci.org/valyala/fasthttp.svg)](https://travis-ci.org/valyala/fasthttp) [![GoDoc](https://godoc.org/github.com/valyala/fasthttp?status.svg)](http://godoc.org/github.com/valyala/fasthttp) [![Go Report](http://goreportcard.com/badge/valyala/fasthttp)](http://goreportcard.com/report/valyala/fasthttp) # fasthttp Fast HTTP implementation for Go. Currently fasthttp is successfully used by [VertaMedia](https://vertamedia.com/) in a production serving up to 200K rps from more than 1.5M concurrent keep-alive connections per physical server. [TechEmpower Benchmark round 12 results](https://www.techempower.com/benchmarks/#section=data-r12&hw=peak&test=plaintext) [Server Benchmarks](#http-server-performance-comparison-with-nethttp) [Client Benchmarks](#http-client-comparison-with-nethttp) [Install](#install) [Documentation](https://godoc.org/github.com/valyala/fasthttp) [Examples from docs](https://godoc.org/github.com/valyala/fasthttp#pkg-examples) [Code examples](examples) [Switching from net/http to fasthttp](#switching-from-nethttp-to-fasthttp) [Fasthttp best practices](#fasthttp-best-practices) [Tricks with byte buffers](#tricks-with-byte-buffers) [Related projects](#related-projects) [FAQ](#faq) # HTTP server performance comparison with [net/http](https://golang.org/pkg/net/http/) In short, fasthttp server is up to 10 times faster than net/http. 
Below are benchmark results. *GOMAXPROCS=1* net/http server: ``` $ GOMAXPROCS=1 go test -bench=NetHTTPServerGet -benchmem -benchtime=10s BenchmarkNetHTTPServerGet1ReqPerConn 1000000 12052 ns/op 2297 B/op 29 allocs/op BenchmarkNetHTTPServerGet2ReqPerConn 1000000 12278 ns/op 2327 B/op 24 allocs/op BenchmarkNetHTTPServerGet10ReqPerConn 2000000 8903 ns/op 2112 B/op 19 allocs/op BenchmarkNetHTTPServerGet10KReqPerConn 2000000 8451 ns/op 2058 B/op 18 allocs/op BenchmarkNetHTTPServerGet1ReqPerConn10KClients 500000 26733 ns/op 3229 B/op 29 allocs/op BenchmarkNetHTTPServerGet2ReqPerConn10KClients 1000000 23351 ns/op 3211 B/op 24 allocs/op BenchmarkNetHTTPServerGet10ReqPerConn10KClients 1000000 13390 ns/op 2483 B/op 19 allocs/op BenchmarkNetHTTPServerGet100ReqPerConn10KClients 1000000 13484 ns/op 2171 B/op 18 allocs/op ``` fasthttp server: ``` $ GOMAXPROCS=1 go test -bench=kServerGet -benchmem -benchtime=10s BenchmarkServerGet1ReqPerConn 10000000 1559 ns/op 0 B/op 0 allocs/op BenchmarkServerGet2ReqPerConn 10000000 1248 ns/op 0 B/op 0 allocs/op BenchmarkServerGet10ReqPerConn 20000000 797 ns/op 0 B/op 0 allocs/op BenchmarkServerGet10KReqPerConn 20000000 716 ns/op 0 B/op 0 allocs/op BenchmarkServerGet1ReqPerConn10KClients 10000000 1974 ns/op 0 B/op 0 allocs/op BenchmarkServerGet2ReqPerConn10KClients 10000000 1352 ns/op 0 B/op 0 allocs/op BenchmarkServerGet10ReqPerConn10KClients 20000000 789 ns/op 2 B/op 0 allocs/op BenchmarkServerGet100ReqPerConn10KClients 20000000 604 ns/op 0 B/op 0 allocs/op ``` *GOMAXPROCS=4* net/http server: ``` $ GOMAXPROCS=4 go test -bench=NetHTTPServerGet -benchmem -benchtime=10s BenchmarkNetHTTPServerGet1ReqPerConn-4 3000000 4529 ns/op 2389 B/op 29 allocs/op BenchmarkNetHTTPServerGet2ReqPerConn-4 5000000 3896 ns/op 2418 B/op 24 allocs/op BenchmarkNetHTTPServerGet10ReqPerConn-4 5000000 3145 ns/op 2160 B/op 19 allocs/op BenchmarkNetHTTPServerGet10KReqPerConn-4 5000000 3054 ns/op 2065 B/op 18 allocs/op BenchmarkNetHTTPServerGet1ReqPerConn10KClients-4 1000000 10321 ns/op 3710 B/op 30 allocs/op BenchmarkNetHTTPServerGet2ReqPerConn10KClients-4 2000000 7556 ns/op 3296 B/op 24 allocs/op BenchmarkNetHTTPServerGet10ReqPerConn10KClients-4 5000000 3905 ns/op 2349 B/op 19 allocs/op BenchmarkNetHTTPServerGet100ReqPerConn10KClients-4 5000000 3435 ns/op 2130 B/op 18 allocs/op ``` fasthttp server: ``` $ GOMAXPROCS=4 go test -bench=kServerGet -benchmem -benchtime=10s BenchmarkServerGet1ReqPerConn-4 10000000 1141 ns/op 0 B/op 0 allocs/op BenchmarkServerGet2ReqPerConn-4 20000000 707 ns/op 0 B/op 0 allocs/op BenchmarkServerGet10ReqPerConn-4 30000000 341 ns/op 0 B/op 0 allocs/op BenchmarkServerGet10KReqPerConn-4 50000000 310 ns/op 0 B/op 0 allocs/op BenchmarkServerGet1ReqPerConn10KClients-4 10000000 1119 ns/op 0 B/op 0 allocs/op BenchmarkServerGet2ReqPerConn10KClients-4 20000000 644 ns/op 0 B/op 0 allocs/op BenchmarkServerGet10ReqPerConn10KClients-4 30000000 346 ns/op 0 B/op 0 allocs/op BenchmarkServerGet100ReqPerConn10KClients-4 50000000 282 ns/op 0 B/op 0 allocs/op ``` # HTTP client comparison with net/http In short, fasthttp client is up to 10 times faster than net/http. Below are benchmark results. 
*GOMAXPROCS=1* net/http client: ``` $ GOMAXPROCS=1 go test -bench='HTTPClient(Do|GetEndToEnd)' -benchmem -benchtime=10s BenchmarkNetHTTPClientDoFastServer 1000000 12567 ns/op 2616 B/op 35 allocs/op BenchmarkNetHTTPClientGetEndToEnd1TCP 200000 67030 ns/op 5028 B/op 56 allocs/op BenchmarkNetHTTPClientGetEndToEnd10TCP 300000 51098 ns/op 5031 B/op 56 allocs/op BenchmarkNetHTTPClientGetEndToEnd100TCP 300000 45096 ns/op 5026 B/op 55 allocs/op BenchmarkNetHTTPClientGetEndToEnd1Inmemory 500000 24779 ns/op 5035 B/op 57 allocs/op BenchmarkNetHTTPClientGetEndToEnd10Inmemory 1000000 26425 ns/op 5035 B/op 57 allocs/op BenchmarkNetHTTPClientGetEndToEnd100Inmemory 500000 28515 ns/op 5045 B/op 57 allocs/op BenchmarkNetHTTPClientGetEndToEnd1000Inmemory 500000 39511 ns/op 5096 B/op 56 allocs/op ``` fasthttp client: ``` $ GOMAXPROCS=1 go test -bench='kClient(Do|GetEndToEnd)' -benchmem -benchtime=10s BenchmarkClientDoFastServer 20000000 865 ns/op 0 B/op 0 allocs/op BenchmarkClientGetEndToEnd1TCP 1000000 18711 ns/op 0 B/op 0 allocs/op BenchmarkClientGetEndToEnd10TCP 1000000 14664 ns/op 0 B/op 0 allocs/op BenchmarkClientGetEndToEnd100TCP 1000000 14043 ns/op 1 B/op 0 allocs/op BenchmarkClientGetEndToEnd1Inmemory 5000000 3965 ns/op 0 B/op 0 allocs/op BenchmarkClientGetEndToEnd10Inmemory 3000000 4060 ns/op 0 B/op 0 allocs/op BenchmarkClientGetEndToEnd100Inmemory 5000000 3396 ns/op 0 B/op 0 allocs/op BenchmarkClientGetEndToEnd1000Inmemory 5000000 3306 ns/op 2 B/op 0 allocs/op ``` *GOMAXPROCS=4* net/http client: ``` $ GOMAXPROCS=4 go test -bench='HTTPClient(Do|GetEndToEnd)' -benchmem -benchtime=10s BenchmarkNetHTTPClientDoFastServer-4 2000000 8774 ns/op 2619 B/op 35 allocs/op BenchmarkNetHTTPClientGetEndToEnd1TCP-4 500000 22951 ns/op 5047 B/op 56 allocs/op BenchmarkNetHTTPClientGetEndToEnd10TCP-4 1000000 19182 ns/op 5037 B/op 55 allocs/op BenchmarkNetHTTPClientGetEndToEnd100TCP-4 1000000 16535 ns/op 5031 B/op 55 allocs/op BenchmarkNetHTTPClientGetEndToEnd1Inmemory-4 1000000 14495 ns/op 5038 B/op 56 allocs/op BenchmarkNetHTTPClientGetEndToEnd10Inmemory-4 1000000 10237 ns/op 5034 B/op 56 allocs/op BenchmarkNetHTTPClientGetEndToEnd100Inmemory-4 1000000 10125 ns/op 5045 B/op 56 allocs/op BenchmarkNetHTTPClientGetEndToEnd1000Inmemory-4 1000000 11132 ns/op 5136 B/op 56 allocs/op ``` fasthttp client: ``` $ GOMAXPROCS=4 go test -bench='kClient(Do|GetEndToEnd)' -benchmem -benchtime=10s BenchmarkClientDoFastServer-4 50000000 397 ns/op 0 B/op 0 allocs/op BenchmarkClientGetEndToEnd1TCP-4 2000000 7388 ns/op 0 B/op 0 allocs/op BenchmarkClientGetEndToEnd10TCP-4 2000000 6689 ns/op 0 B/op 0 allocs/op BenchmarkClientGetEndToEnd100TCP-4 3000000 4927 ns/op 1 B/op 0 allocs/op BenchmarkClientGetEndToEnd1Inmemory-4 10000000 1604 ns/op 0 B/op 0 allocs/op BenchmarkClientGetEndToEnd10Inmemory-4 10000000 1458 ns/op 0 B/op 0 allocs/op BenchmarkClientGetEndToEnd100Inmemory-4 10000000 1329 ns/op 0 B/op 0 allocs/op BenchmarkClientGetEndToEnd1000Inmemory-4 10000000 1316 ns/op 5 B/op 0 allocs/op ``` # Install ``` go get -u github.com/valyala/fasthttp ``` # Switching from net/http to fasthttp Unfortunately, fasthttp doesn't provide API identical to net/http. See the [FAQ](#faq) for details. There is [net/http -> fasthttp handler converter](https://godoc.org/github.com/valyala/fasthttp/fasthttpadaptor), but it is advisable writing fasthttp request handlers by hands for gaining all the fasthttp advantages (especially high performance :) ). 
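If you do want a gradual migration path via the converter mentioned above, a minimal sketch might look like the following. This is an illustration, not the project's official example: it assumes `fasthttpadaptor` exposes `NewFastHTTPHandlerFunc` as described in its godoc, and `legacyHandler` plus the `:8080` address are made up for the sketch. The adaptor copies data between net/http and fasthttp structures, so it gives up part of the performance benefit and is best treated as a stop-gap:

```go
package main

import (
	"fmt"
	"net/http"

	"github.com/valyala/fasthttp"
	"github.com/valyala/fasthttp/fasthttpadaptor"
)

// legacyHandler is an existing net/http-style handler we don't want to rewrite yet.
// (hypothetical handler used only for this sketch)
func legacyHandler(w http.ResponseWriter, r *http.Request) {
	fmt.Fprintf(w, "hello from net/http-style code, path=%q", r.URL.Path)
}

func main() {
	// Wrap the net/http handler so it can be served by fasthttp.
	// The wrapping costs extra allocations and copies compared to
	// a native fasthttp.RequestHandler.
	wrapped := fasthttpadaptor.NewFastHTTPHandlerFunc(legacyHandler)
	fasthttp.ListenAndServe(":8080", wrapped)
}
```

Handlers wrapped this way keep their net/http signatures, which makes it possible to migrate one route at a time and rewrite the hot paths as native fasthttp handlers later.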
Important points: * Fasthttp works with [RequestHandler functions](https://godoc.org/github.com/valyala/fasthttp#RequestHandler) instead of objects implementing [Handler interface](https://golang.org/pkg/net/http/#Handler). Fortunately, it is easy to pass bound struct methods to fasthttp: ```go type MyHandler struct { foobar string } // request handler in net/http style, i.e. method bound to MyHandler struct. func (h *MyHandler) HandleFastHTTP(ctx *fasthttp.RequestCtx) { // notice that we may access MyHandler properties here - see h.foobar. fmt.Fprintf(ctx, "Hello, world! Requested path is %q. Foobar is %q", ctx.Path(), h.foobar) } // request handler in fasthttp style, i.e. just plain function. func fastHTTPHandler(ctx *fasthttp.RequestCtx) { fmt.Fprintf(ctx, "Hi there! RequestURI is %q", ctx.RequestURI()) } // pass bound struct method to fasthttp myHandler := &MyHandler{ foobar: "foobar", } fasthttp.ListenAndServe(":8080", myHandler.HandleFastHTTP) // pass plain function to fasthttp fasthttp.ListenAndServe(":8081", fastHTTPHandler) ``` * The [RequestHandler](https://godoc.org/github.com/valyala/fasthttp#RequestHandler) accepts only one argument - [RequestCtx](https://godoc.org/github.com/valyala/fasthttp#RequestCtx). It contains all the functionality required for http request processing and response writing. Below is an example of a simple request handler conversion from net/http to fasthttp. ```go // net/http request handler requestHandler := func(w http.ResponseWriter, r *http.Request) { switch r.URL.Path { case "/foo": fooHandler(w, r) case "/bar": barHandler(w, r) default: http.Error(w, "Unsupported path", http.StatusNotFound) } } ``` ```go // the corresponding fasthttp request handler requestHandler := func(ctx *fasthttp.RequestCtx) { switch string(ctx.Path()) { case "/foo": fooHandler(ctx) case "/bar": barHandler(ctx) default: ctx.Error("Unsupported path", fasthttp.StatusNotFound) } } ``` * Fasthttp allows setting response headers and writing response body in arbitrary order. There is no 'headers first, then body' restriction like in net/http. The following code is valid for fasthttp: ```go requestHandler := func(ctx *fasthttp.RequestCtx) { // set some headers and status code first ctx.SetContentType("foo/bar") ctx.SetStatusCode(fasthttp.StatusOK) // then write the first part of body fmt.Fprintf(ctx, "this is the first part of body\n") // then set more headers ctx.Response.Header.Set("Foo-Bar", "baz") // then write more body fmt.Fprintf(ctx, "this is the second part of body\n") // then override already written body ctx.SetBody([]byte("this is completely new body contents")) // then update status code ctx.SetStatusCode(fasthttp.StatusNotFound) // basically, anything may be updated many times before // returning from RequestHandler. // // Unlike net/http fasthttp doesn't put response to the wire until // returning from RequestHandler. 
}
```

* Fasthttp doesn't provide [ServeMux](https://golang.org/pkg/net/http/#ServeMux), but there are more powerful third-party routers and web frameworks with fasthttp support:

  * [Iris](https://github.com/kataras/iris)
  * [fasthttp-routing](https://github.com/qiangxue/fasthttp-routing)
  * [fasthttprouter](https://github.com/buaazp/fasthttprouter)
  * [echo v2](https://github.com/labstack/echo)

  net/http code with a simple ServeMux is trivially converted to fasthttp code:

  ```go
  // net/http code

  m := &http.ServeMux{}
  m.HandleFunc("/foo", fooHandlerFunc)
  m.HandleFunc("/bar", barHandlerFunc)
  m.Handle("/baz", bazHandler)

  http.ListenAndServe(":80", m)
  ```

  ```go
  // the corresponding fasthttp code
  m := func(ctx *fasthttp.RequestCtx) {
  	switch string(ctx.Path()) {
  	case "/foo":
  		fooHandlerFunc(ctx)
  	case "/bar":
  		barHandlerFunc(ctx)
  	case "/baz":
  		bazHandler.HandlerFunc(ctx)
  	default:
  		ctx.Error("not found", fasthttp.StatusNotFound)
  	}
  }

  fasthttp.ListenAndServe(":80", m)
  ```

* net/http -> fasthttp conversion table:

  * All the pseudocode below assumes w, r and ctx have these types:

  ```go
  var (
  	w   http.ResponseWriter
  	r   *http.Request
  	ctx *fasthttp.RequestCtx
  )
  ```
  * r.Body -> [ctx.PostBody()](https://godoc.org/github.com/valyala/fasthttp#RequestCtx.PostBody)
  * r.URL.Path -> [ctx.Path()](https://godoc.org/github.com/valyala/fasthttp#RequestCtx.Path)
  * r.URL -> [ctx.URI()](https://godoc.org/github.com/valyala/fasthttp#RequestCtx.URI)
  * r.Method -> [ctx.Method()](https://godoc.org/github.com/valyala/fasthttp#RequestCtx.Method)
  * r.Header -> [ctx.Request.Header](https://godoc.org/github.com/valyala/fasthttp#RequestHeader)
  * r.Header.Get() -> [ctx.Request.Header.Peek()](https://godoc.org/github.com/valyala/fasthttp#RequestHeader.Peek)
  * r.Host -> [ctx.Host()](https://godoc.org/github.com/valyala/fasthttp#RequestCtx.Host)
  * r.Form -> [ctx.QueryArgs()](https://godoc.org/github.com/valyala/fasthttp#RequestCtx.QueryArgs) + [ctx.PostArgs()](https://godoc.org/github.com/valyala/fasthttp#RequestCtx.PostArgs)
  * r.PostForm -> [ctx.PostArgs()](https://godoc.org/github.com/valyala/fasthttp#RequestCtx.PostArgs)
  * r.FormValue() -> [ctx.FormValue()](https://godoc.org/github.com/valyala/fasthttp#RequestCtx.FormValue)
  * r.FormFile() -> [ctx.FormFile()](https://godoc.org/github.com/valyala/fasthttp#RequestCtx.FormFile)
  * r.MultipartForm -> [ctx.MultipartForm()](https://godoc.org/github.com/valyala/fasthttp#RequestCtx.MultipartForm)
  * r.RemoteAddr -> [ctx.RemoteAddr()](https://godoc.org/github.com/valyala/fasthttp#RequestCtx.RemoteAddr)
  * r.RequestURI -> [ctx.RequestURI()](https://godoc.org/github.com/valyala/fasthttp#RequestCtx.RequestURI)
  * r.TLS -> [ctx.IsTLS()](https://godoc.org/github.com/valyala/fasthttp#RequestCtx.IsTLS)
  * r.Cookie() -> [ctx.Request.Header.Cookie()](https://godoc.org/github.com/valyala/fasthttp#RequestHeader.Cookie)
  * r.Referer() -> [ctx.Referer()](https://godoc.org/github.com/valyala/fasthttp#RequestCtx.Referer)
  * r.UserAgent() -> [ctx.UserAgent()](https://godoc.org/github.com/valyala/fasthttp#RequestCtx.UserAgent)
  * w.Header() -> [ctx.Response.Header](https://godoc.org/github.com/valyala/fasthttp#ResponseHeader)
  * w.Header().Set() -> [ctx.Response.Header.Set()](https://godoc.org/github.com/valyala/fasthttp#ResponseHeader.Set)
  * w.Header().Set("Content-Type") -> [ctx.SetContentType()](https://godoc.org/github.com/valyala/fasthttp#RequestCtx.SetContentType)
  * w.Header().Set("Set-Cookie") -> [ctx.Response.Header.SetCookie()](https://godoc.org/github.com/valyala/fasthttp#ResponseHeader.SetCookie)
  * w.Write() ->
[ctx.Write()](https://godoc.org/github.com/valyala/fasthttp#RequestCtx.Write), [ctx.SetBody()](https://godoc.org/github.com/valyala/fasthttp#RequestCtx.SetBody), [ctx.SetBodyStream()](https://godoc.org/github.com/valyala/fasthttp#RequestCtx.SetBodyStream), [ctx.SetBodyStreamWriter()](https://godoc.org/github.com/valyala/fasthttp#RequestCtx.SetBodyStreamWriter) * w.WriteHeader() -> [ctx.SetStatusCode()](https://godoc.org/github.com/valyala/fasthttp#RequestCtx.SetStatusCode) * w.(http.Hijacker).Hijack() -> [ctx.Hijack()](https://godoc.org/github.com/valyala/fasthttp#RequestCtx.Hijack) * http.Error() -> [ctx.Error()](https://godoc.org/github.com/valyala/fasthttp#RequestCtx.Error) * http.FileServer() -> [fasthttp.FSHandler()](https://godoc.org/github.com/valyala/fasthttp#FSHandler), [fasthttp.FS](https://godoc.org/github.com/valyala/fasthttp#FS) * http.ServeFile() -> [fasthttp.ServeFile()](https://godoc.org/github.com/valyala/fasthttp#ServeFile) * http.Redirect() -> [ctx.Redirect()](https://godoc.org/github.com/valyala/fasthttp#RequestCtx.Redirect) * http.NotFound() -> [ctx.NotFound()](https://godoc.org/github.com/valyala/fasthttp#RequestCtx.NotFound) * http.StripPrefix() -> [fasthttp.PathRewriteFunc](https://godoc.org/github.com/valyala/fasthttp#PathRewriteFunc) * *VERY IMPORTANT!* Fasthttp disallows holding references to [RequestCtx](https://godoc.org/github.com/valyala/fasthttp#RequestCtx) or to its' members after returning from [RequestHandler](https://godoc.org/github.com/valyala/fasthttp#RequestHandler). Otherwise [data races](http://blog.golang.org/race-detector) are inevitable. Carefully inspect all the net/http request handlers converted to fasthttp whether they retain references to RequestCtx or to its' members after returning. RequestCtx provides the following _band aids_ for this case: * Wrap RequestHandler into [TimeoutHandler](https://godoc.org/github.com/valyala/fasthttp#TimeoutHandler). * Call [TimeoutError](https://godoc.org/github.com/valyala/fasthttp#RequestCtx.TimeoutError) before returning from RequestHandler if there are references to RequestCtx or to its' members. See [the example](https://godoc.org/github.com/valyala/fasthttp#example-RequestCtx-TimeoutError) for more details. Use brilliant tool - [race detector](http://blog.golang.org/race-detector) - for detecting and eliminating data races in your program. If you detected data race related to fasthttp in your program, then there is high probability you forgot calling [TimeoutError](https://godoc.org/github.com/valyala/fasthttp#RequestCtx.TimeoutError) before returning from [RequestHandler](https://godoc.org/github.com/valyala/fasthttp#RequestHandler). * Blind switching from net/http to fasthttp won't give you performance boost. While fasthttp is optimized for speed, its' performance may be easily saturated by slow [RequestHandler](https://godoc.org/github.com/valyala/fasthttp#RequestHandler). So [profile](http://blog.golang.org/profiling-go-programs) and optimize your code after switching to fasthttp. For instance, use [quicktemplate](https://github.com/valyala/quicktemplate) instead of [html/template](https://golang.org/pkg/html/template/). * See also [fasthttputil](https://godoc.org/github.com/valyala/fasthttp/fasthttputil), [fasthttpadaptor](https://godoc.org/github.com/valyala/fasthttp/fasthttpadaptor) and [expvarhandler](https://godoc.org/github.com/valyala/fasthttp/expvarhandler). 
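To make the RequestCtx retention rule above concrete, here is a rough sketch of a handler that passes ctx to another goroutine and therefore calls `TimeoutError` before returning whenever that goroutine may outlive the handler. It follows the pattern of the linked godoc example but is not the library's official code; the durations and the `:8080` address are made up for illustration:

```go
package main

import (
	"fmt"
	"time"

	"github.com/valyala/fasthttp"
)

func main() {
	requestHandler := func(ctx *fasthttp.RequestCtx) {
		// Hand the work (and the ctx) to another goroutine.
		doneCh := make(chan struct{})
		go func() {
			// Simulated slow work that still writes to ctx.
			time.Sleep(500 * time.Millisecond)
			fmt.Fprintf(ctx, "slow work finished\n")
			close(doneCh)
		}()

		select {
		case <-doneCh:
			// The goroutine has finished, so nothing references ctx
			// after we return - safe.
		case <-time.After(100 * time.Millisecond):
			// The goroutine may still touch ctx after we return,
			// so TimeoutError MUST be called before returning.
			ctx.TimeoutError("request timed out")
		}
	}
	fasthttp.ListenAndServe(":8080", requestHandler)
}
```

Without the `TimeoutError` call in the timeout branch, the background goroutine would write to a RequestCtx that fasthttp has already recycled for another request, which is exactly the data race the race detector reports.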
# Performance optimization tips for multi-core systems * Use [reuseport](https://godoc.org/github.com/valyala/fasthttp/reuseport) listener. * Run a separate server instance per CPU core with GOMAXPROCS=1. * Pin each server instance to a separate CPU core using [taskset](http://linux.die.net/man/1/taskset). * Ensure the interrupts of multiqueue network card are evenly distributed between CPU cores. See [this article](https://blog.cloudflare.com/how-to-achieve-low-latency/) for details. * Use Go 1.6 as it provides some considerable performance improvements. # Fasthttp best practices * Do not allocate objects and `[]byte` buffers - just reuse them as much as possible. Fasthttp API design encourages this. * [sync.Pool](https://golang.org/pkg/sync/#Pool) is your best friend. * [Profile your program](http://blog.golang.org/profiling-go-programs) in production. `go tool pprof --alloc_objects your-program mem.pprof` usually gives better insights for optimization opportunities than `go tool pprof your-program cpu.pprof`. * Write [tests and benchmarks](https://golang.org/pkg/testing/) for hot paths. * Avoid conversion between `[]byte` and `string`, since this may result in memory allocation+copy. Fasthttp API provides functions for both `[]byte` and `string` - use these functions instead of converting manually between `[]byte` and `string`. There are some exceptions - see [this wiki page](https://github.com/golang/go/wiki/CompilerOptimizations#string-and-byte) for more details. * Verify your tests and production code under [race detector](https://golang.org/doc/articles/race_detector.html) on a regular basis. * Prefer [quicktemplate](https://github.com/valyala/quicktemplate) instead of [html/template](https://golang.org/pkg/html/template/) in your webserver. # Tricks with `[]byte` buffers The following tricks are used by fasthttp. Use them in your code too. * Standard Go functions accept nil buffers ```go var ( // both buffers are uninitialized dst []byte src []byte ) dst = append(dst, src...) // is legal if dst is nil and/or src is nil copy(dst, src) // is legal if dst is nil and/or src is nil (string(src) == "") // is true if src is nil (len(src) == 0) // is true if src is nil src = src[:0] // works like a charm with nil src // this for loop doesn't panic if src is nil for i, ch := range src { doSomething(i, ch) } ``` So throw away nil checks for `[]byte` buffers from you code. For example, ```go srcLen := 0 if src != nil { srcLen = len(src) } ``` becomes ```go srcLen := len(src) ``` * String may be appended to `[]byte` buffer with `append` ```go dst = append(dst, "foobar"...) ``` * `[]byte` buffer may be extended to its' capacity. ```go buf := make([]byte, 100) a := buf[:10] // len(a) == 10, cap(a) == 100. b := a[:100] // is valid, since cap(a) == 100. ``` * All fasthttp functions accept nil `[]byte` buffer ```go statusCode, body, err := fasthttp.Get(nil, "http://google.com/") uintBuf := fasthttp.AppendUint(nil, 1234) ``` # Related projects * [fasthttp-contrib](https://github.com/fasthttp-contrib) - various useful helpers for projects based on fasthttp. * [iris](https://github.com/kataras/iris) - web application framework built on top of fasthttp. Features speed and functionality. * [fasthttp-routing](https://github.com/qiangxue/fasthttp-routing) - fast and powerful routing package for fasthttp servers. * [fasthttprouter](https://github.com/buaazp/fasthttprouter) - a high performance fasthttp request router that scales well. 
* [echo](https://github.com/labstack/echo) - fast and unfancy HTTP server framework with fasthttp support.
* [websocket](https://github.com/leavengood/websocket) - Gorilla-based websocket implementation for fasthttp.

# FAQ

* *Why create yet another http package instead of optimizing net/http?*

  Because the net/http API limits many optimization opportunities. For example:

  * net/http Request object lifetime isn't limited by request handler execution time. So the server must create a new request object for each request instead of reusing existing objects like fasthttp does.
  * net/http headers are stored in a `map[string][]string`. So the server must parse all the headers, convert them from `[]byte` to `string` and put them into the map before calling the user-provided request handler. This all requires unnecessary memory allocations avoided by fasthttp.
  * net/http client API requires creating a new response object for each request.

* *Why is the fasthttp API incompatible with net/http?*

  Because the net/http API limits many optimization opportunities. See the answer above for more details. Also certain net/http API parts are suboptimal for use:

  * Compare [net/http connection hijacking](https://golang.org/pkg/net/http/#Hijacker) to [fasthttp connection hijacking](https://godoc.org/github.com/valyala/fasthttp#RequestCtx.Hijack).
  * Compare [net/http Request.Body reading](https://golang.org/pkg/net/http/#Request) to [fasthttp request body reading](https://godoc.org/github.com/valyala/fasthttp#RequestCtx.PostBody).

* *Why doesn't fasthttp support HTTP/2.0 and WebSockets?*

  There are [plans](TODO) for adding HTTP/2.0 and WebSockets support in the future. In the meantime, third parties may use [RequestCtx.Hijack](https://godoc.org/github.com/valyala/fasthttp#RequestCtx.Hijack) for implementing these goodies. See [the first third-party websocket implementation on top of fasthttp](https://github.com/leavengood/websocket).

* *Are there known net/http advantages compared to fasthttp?*

  Yes:

  * net/http supports [HTTP/2.0 starting from go1.6](https://http2.golang.org/).
  * net/http API is stable, while fasthttp API constantly evolves.
  * net/http handles more HTTP corner cases.
  * net/http should contain fewer bugs, since it is used and tested by a much wider audience.
  * net/http works on Go older than 1.5.

* *Why does the fasthttp API prefer returning `[]byte` instead of `string`?*

  Because `[]byte` to `string` conversion isn't free - it requires a memory allocation and copy. Feel free to wrap the returned `[]byte` result into `string()` if you prefer working with strings instead of byte slices. But be aware that this has non-zero overhead.

* *Which Go versions are supported by fasthttp?*

  Go 1.5+. Older versions won't be supported, since their standard library [misses useful functions](https://github.com/valyala/fasthttp/issues/5).

* *Please provide real benchmark data and server information*

  See [this issue](https://github.com/valyala/fasthttp/issues/4).

* *Are there plans to add request routing to fasthttp?*

  There are no plans to add request routing into fasthttp. Use third-party routers and web frameworks with fasthttp support:

  * [Iris](https://github.com/kataras/iris)
  * [fasthttp-routing](https://github.com/qiangxue/fasthttp-routing)
  * [fasthttprouter](https://github.com/buaazp/fasthttprouter)
  * [echo v2](https://github.com/labstack/echo)

  See also [this issue](https://github.com/valyala/fasthttp/issues/9) for more info.

* *I detected a data race in fasthttp!*

  Cool! [File a bug](https://github.com/valyala/fasthttp/issues/new).
But before doing this check the following in your code: * Make sure there are no references to [RequestCtx](https://godoc.org/github.com/valyala/fasthttp#RequestCtx) or to its' members after returning from [RequestHandler](https://godoc.org/github.com/valyala/fasthttp#RequestHandler). * Make sure you call [TimeoutError](https://godoc.org/github.com/valyala/fasthttp#RequestCtx.TimeoutError) before returning from [RequestHandler](https://godoc.org/github.com/valyala/fasthttp#RequestHandler) if there are references to [RequestCtx](https://godoc.org/github.com/valyala/fasthttp#RequestCtx) or to its' members, which may be accessed by other goroutines. * *I didn't find an answer for my question here* Try exploring [these questions](https://github.com/valyala/fasthttp/issues?q=label%3Aquestion). golang-github-valyala-fasthttp-20160617/TODO000066400000000000000000000003051273074646000204720ustar00rootroot00000000000000- SessionClient with referer and cookies support. - ProxyHandler similar to FSHandler. - WebSockets. See https://tools.ietf.org/html/rfc6455 . - HTTP/2.0. See https://tools.ietf.org/html/rfc7540 . golang-github-valyala-fasthttp-20160617/args.go000066400000000000000000000237561273074646000213040ustar00rootroot00000000000000package fasthttp import ( "bytes" "errors" "io" "sync" ) // AcquireArgs returns an empty Args object from the pool. // // The returned Args may be returned to the pool with ReleaseArgs // when no longer needed. This allows reducing GC load. func AcquireArgs() *Args { return argsPool.Get().(*Args) } // ReleaseArgs returns the object acquired via AquireArgs to the pool. // // Do not access the released Args object, otherwise data races may occur. func ReleaseArgs(a *Args) { a.Reset() argsPool.Put(a) } var argsPool = &sync.Pool{ New: func() interface{} { return &Args{} }, } // Args represents query arguments. // // It is forbidden copying Args instances. Create new instances instead // and use CopyTo(). // // Args instance MUST NOT be used from concurrently running goroutines. type Args struct { noCopy noCopy args []argsKV buf []byte } type argsKV struct { key []byte value []byte } // Reset clears query args. func (a *Args) Reset() { a.args = a.args[:0] } // CopyTo copies all args to dst. func (a *Args) CopyTo(dst *Args) { dst.Reset() dst.args = copyArgs(dst.args, a.args) } // VisitAll calls f for each existing arg. // // f must not retain references to key and value after returning. // Make key and/or value copies if you need storing them after returning. func (a *Args) VisitAll(f func(key, value []byte)) { visitArgs(a.args, f) } // Len returns the number of query args. func (a *Args) Len() int { return len(a.args) } // Parse parses the given string containing query args. func (a *Args) Parse(s string) { a.buf = append(a.buf[:0], s...) a.ParseBytes(a.buf) } // ParseBytes parses the given b containing query args. func (a *Args) ParseBytes(b []byte) { a.Reset() var s argsScanner s.b = b var kv *argsKV a.args, kv = allocArg(a.args) for s.next(kv) { if len(kv.key) > 0 || len(kv.value) > 0 { a.args, kv = allocArg(a.args) } } a.args = releaseArg(a.args) } // String returns string representation of query args. func (a *Args) String() string { return string(a.QueryString()) } // QueryString returns query string for the args. // // The returned value is valid until the next call to Args methods. func (a *Args) QueryString() []byte { a.buf = a.AppendBytes(a.buf[:0]) return a.buf } // AppendBytes appends query string to dst and returns the extended dst. 
func (a *Args) AppendBytes(dst []byte) []byte { for i, n := 0, len(a.args); i < n; i++ { kv := &a.args[i] dst = AppendQuotedArg(dst, kv.key) if len(kv.value) > 0 { dst = append(dst, '=') dst = AppendQuotedArg(dst, kv.value) } if i+1 < n { dst = append(dst, '&') } } return dst } // WriteTo writes query string to w. // // WriteTo implements io.WriterTo interface. func (a *Args) WriteTo(w io.Writer) (int64, error) { n, err := w.Write(a.QueryString()) return int64(n), err } // Del deletes argument with the given key from query args. func (a *Args) Del(key string) { a.args = delAllArgs(a.args, key) } // DelBytes deletes argument with the given key from query args. func (a *Args) DelBytes(key []byte) { a.args = delAllArgs(a.args, b2s(key)) } // Add adds 'key=value' argument. // // Multiple values for the same key may be added. func (a *Args) Add(key, value string) { a.args = appendArg(a.args, key, value) } // AddBytesK adds 'key=value' argument. // // Multiple values for the same key may be added. func (a *Args) AddBytesK(key []byte, value string) { a.args = appendArg(a.args, b2s(key), value) } // AddBytesV adds 'key=value' argument. // // Multiple values for the same key may be added. func (a *Args) AddBytesV(key string, value []byte) { a.args = appendArg(a.args, key, b2s(value)) } // AddBytesKV adds 'key=value' argument. // // Multiple values for the same key may be added. func (a *Args) AddBytesKV(key, value []byte) { a.args = appendArg(a.args, b2s(key), b2s(value)) } // Set sets 'key=value' argument. func (a *Args) Set(key, value string) { a.args = setArg(a.args, key, value) } // SetBytesK sets 'key=value' argument. func (a *Args) SetBytesK(key []byte, value string) { a.args = setArg(a.args, b2s(key), value) } // SetBytesV sets 'key=value' argument. func (a *Args) SetBytesV(key string, value []byte) { a.args = setArg(a.args, key, b2s(value)) } // SetBytesKV sets 'key=value' argument. func (a *Args) SetBytesKV(key, value []byte) { a.args = setArgBytes(a.args, key, value) } // Peek returns query arg value for the given key. // // Returned value is valid until the next Args call. func (a *Args) Peek(key string) []byte { return peekArgStr(a.args, key) } // PeekBytes returns query arg value for the given key. // // Returned value is valid until the next Args call. func (a *Args) PeekBytes(key []byte) []byte { return peekArgBytes(a.args, key) } // PeekMulti returns all the arg values for the given key. func (a *Args) PeekMulti(key string) [][]byte { var values [][]byte a.VisitAll(func(k, v []byte) { if string(k) == key { values = append(values, v) } }) return values } // PeekMultiBytes returns all the arg values for the given key. func (a *Args) PeekMultiBytes(key []byte) [][]byte { return a.PeekMulti(b2s(key)) } // Has returns true if the given key exists in Args. func (a *Args) Has(key string) bool { return hasArg(a.args, key) } // HasBytes returns true if the given key exists in Args. func (a *Args) HasBytes(key []byte) bool { return hasArg(a.args, b2s(key)) } // ErrNoArgValue is returned when Args value with the given key is missing. var ErrNoArgValue = errors.New("no Args value for the given key") // GetUint returns uint value for the given key. func (a *Args) GetUint(key string) (int, error) { value := a.Peek(key) if len(value) == 0 { return -1, ErrNoArgValue } return ParseUint(value) } // SetUint sets uint value for the given key. 
func (a *Args) SetUint(key string, value int) { bb := AcquireByteBuffer() bb.B = AppendUint(bb.B[:0], value) a.SetBytesV(key, bb.B) ReleaseByteBuffer(bb) } // SetUintBytes sets uint value for the given key. func (a *Args) SetUintBytes(key []byte, value int) { a.SetUint(b2s(key), value) } // GetUintOrZero returns uint value for the given key. // // Zero (0) is returned on error. func (a *Args) GetUintOrZero(key string) int { n, err := a.GetUint(key) if err != nil { n = 0 } return n } // GetUfloat returns ufloat value for the given key. func (a *Args) GetUfloat(key string) (float64, error) { value := a.Peek(key) if len(value) == 0 { return -1, ErrNoArgValue } return ParseUfloat(value) } // GetUfloatOrZero returns ufloat value for the given key. // // Zero (0) is returned on error. func (a *Args) GetUfloatOrZero(key string) float64 { f, err := a.GetUfloat(key) if err != nil { f = 0 } return f } func visitArgs(args []argsKV, f func(k, v []byte)) { for i, n := 0, len(args); i < n; i++ { kv := &args[i] f(kv.key, kv.value) } } func copyArgs(dst, src []argsKV) []argsKV { if cap(dst) < len(src) { tmp := make([]argsKV, len(src)) copy(tmp, dst) dst = tmp } n := len(src) dst = dst[:n] for i := 0; i < n; i++ { dstKV := &dst[i] srcKV := &src[i] dstKV.key = append(dstKV.key[:0], srcKV.key...) dstKV.value = append(dstKV.value[:0], srcKV.value...) } return dst } func delAllArgsBytes(args []argsKV, key []byte) []argsKV { return delAllArgs(args, b2s(key)) } func delAllArgs(args []argsKV, key string) []argsKV { for i, n := 0, len(args); i < n; i++ { kv := &args[i] if key == string(kv.key) { tmp := *kv copy(args[i:], args[i+1:]) n-- args[n] = tmp args = args[:n] } } return args } func setArgBytes(h []argsKV, key, value []byte) []argsKV { return setArg(h, b2s(key), b2s(value)) } func setArg(h []argsKV, key, value string) []argsKV { n := len(h) for i := 0; i < n; i++ { kv := &h[i] if key == string(kv.key) { kv.value = append(kv.value[:0], value...) return h } } return appendArg(h, key, value) } func appendArgBytes(h []argsKV, key, value []byte) []argsKV { return appendArg(h, b2s(key), b2s(value)) } func appendArg(args []argsKV, key, value string) []argsKV { var kv *argsKV args, kv = allocArg(args) kv.key = append(kv.key[:0], key...) kv.value = append(kv.value[:0], value...) 
return args } func allocArg(h []argsKV) ([]argsKV, *argsKV) { n := len(h) if cap(h) > n { h = h[:n+1] } else { h = append(h, argsKV{}) } return h, &h[n] } func releaseArg(h []argsKV) []argsKV { return h[:len(h)-1] } func hasArg(h []argsKV, key string) bool { for i, n := 0, len(h); i < n; i++ { kv := &h[i] if key == string(kv.key) { return true } } return false } func peekArgBytes(h []argsKV, k []byte) []byte { for i, n := 0, len(h); i < n; i++ { kv := &h[i] if bytes.Equal(kv.key, k) { return kv.value } } return nil } func peekArgStr(h []argsKV, k string) []byte { for i, n := 0, len(h); i < n; i++ { kv := &h[i] if string(kv.key) == k { return kv.value } } return nil } type argsScanner struct { b []byte } func (s *argsScanner) next(kv *argsKV) bool { if len(s.b) == 0 { return false } isKey := true k := 0 for i, c := range s.b { switch c { case '=': if isKey { isKey = false kv.key = decodeArg(kv.key, s.b[:i], true) k = i + 1 } case '&': if isKey { kv.key = decodeArg(kv.key, s.b[:i], true) kv.value = kv.value[:0] } else { kv.value = decodeArg(kv.value, s.b[k:i], true) } s.b = s.b[i+1:] return true } } if isKey { kv.key = decodeArg(kv.key, s.b, true) kv.value = kv.value[:0] } else { kv.value = decodeArg(kv.value, s.b[k:], true) } s.b = s.b[len(s.b):] return true } func decodeArg(dst, src []byte, decodePlus bool) []byte { return decodeArgAppend(dst[:0], src, decodePlus) } func decodeArgAppend(dst, src []byte, decodePlus bool) []byte { for i, n := 0, len(src); i < n; i++ { c := src[i] if c == '%' { if i+2 >= n { return append(dst, src[i:]...) } x1 := hexbyte2int(src[i+1]) x2 := hexbyte2int(src[i+2]) if x1 < 0 || x2 < 0 { dst = append(dst, c) } else { dst = append(dst, byte(x1<<4|x2)) i += 2 } } else if decodePlus && c == '+' { dst = append(dst, ' ') } else { dst = append(dst, c) } } return dst } golang-github-valyala-fasthttp-20160617/args_test.go000066400000000000000000000247261273074646000223410ustar00rootroot00000000000000package fasthttp import ( "fmt" "reflect" "strings" "testing" "time" ) func TestArgsAdd(t *testing.T) { var a Args a.Add("foo", "bar") a.Add("foo", "baz") a.Add("foo", "1") a.Add("ba", "23") if a.Len() != 4 { t.Fatalf("unexpected number of elements: %d. Expecting 4", a.Len()) } s := a.String() expectedS := "foo=bar&foo=baz&foo=1&ba=23" if s != expectedS { t.Fatalf("unexpected result: %q. Expecting %q", s, expectedS) } var a1 Args a1.Parse(s) if a1.Len() != 4 { t.Fatalf("unexpected number of elements: %d. Expecting 4", a.Len()) } var barFound, bazFound, oneFound, baFound bool a1.VisitAll(func(k, v []byte) { switch string(k) { case "foo": switch string(v) { case "bar": barFound = true case "baz": bazFound = true case "1": oneFound = true default: t.Fatalf("unexpected value %q", v) } case "ba": if string(v) != "23" { t.Fatalf("unexpected value: %q. 
Expecting %q", v, "23") } baFound = true default: t.Fatalf("unexpected key found %q", k) } }) if !barFound || !bazFound || !oneFound || !baFound { t.Fatalf("something is missing: %v, %v, %v, %v", barFound, bazFound, oneFound, baFound) } } func TestArgsAcquireReleaseSequential(t *testing.T) { testArgsAcquireRelease(t) } func TestArgsAcquireReleaseConcurrent(t *testing.T) { ch := make(chan struct{}, 10) for i := 0; i < 10; i++ { go func() { testArgsAcquireRelease(t) ch <- struct{}{} }() } for i := 0; i < 10; i++ { select { case <-ch: case <-time.After(time.Second): t.Fatalf("timeout") } } } func testArgsAcquireRelease(t *testing.T) { a := AcquireArgs() for i := 0; i < 10; i++ { k := fmt.Sprintf("key_%d", i) v := fmt.Sprintf("value_%d", i*3+123) a.Set(k, v) } s := a.String() a.Reset() a.Parse(s) for i := 0; i < 10; i++ { k := fmt.Sprintf("key_%d", i) expectedV := fmt.Sprintf("value_%d", i*3+123) v := a.Peek(k) if string(v) != expectedV { t.Fatalf("unexpected value %q for key %q. Expecting %q", v, k, expectedV) } } ReleaseArgs(a) } func TestArgsPeekMulti(t *testing.T) { var a Args a.Parse("foo=123&bar=121&foo=321&foo=&barz=sdf") vv := a.PeekMulti("foo") expectedVV := [][]byte{ []byte("123"), []byte("321"), []byte(nil), } if !reflect.DeepEqual(vv, expectedVV) { t.Fatalf("unexpected vv\n%#v\nExpecting\n%#v\n", vv, expectedVV) } vv = a.PeekMulti("aaaa") if len(vv) > 0 { t.Fatalf("expecting empty result for non-existing key. Got %#v", vv) } vv = a.PeekMulti("bar") expectedVV = [][]byte{[]byte("121")} if !reflect.DeepEqual(vv, expectedVV) { t.Fatalf("unexpected vv\n%#v\nExpecting\n%#v\n", vv, expectedVV) } } func TestArgsEscape(t *testing.T) { testArgsEscape(t, "foo", "bar", "foo=bar") testArgsEscape(t, "f.o,1:2/4", "~`!@#$%^&*()_-=+\\|/[]{};:'\"<>,./?", "f.o%2C1%3A2%2F4=%7E%60%21%40%23%24%25%5E%26*%28%29_-%3D%2B%5C%7C%2F%5B%5D%7B%7D%3B%3A%27%22%3C%3E%2C.%2F%3F") } func testArgsEscape(t *testing.T, k, v, expectedS string) { var a Args a.Set(k, v) s := a.String() if s != expectedS { t.Fatalf("unexpected args %q. Expecting %q. k=%q, v=%q", s, expectedS, k, v) } } func TestArgsWriteTo(t *testing.T) { s := "foo=bar&baz=123&aaa=bbb" var a Args a.Parse(s) var w ByteBuffer n, err := a.WriteTo(&w) if err != nil { t.Fatalf("unexpected error: %s", err) } if n != int64(len(s)) { t.Fatalf("unexpected n: %d. Expecting %d", n, len(s)) } result := string(w.B) if result != s { t.Fatalf("unexpected result %q. Expecting %q", result, s) } } func TestArgsUint(t *testing.T) { var a Args a.SetUint("foo", 123) a.SetUint("bar", 0) a.SetUint("aaaa", 34566) expectedS := "foo=123&bar=0&aaaa=34566" s := string(a.QueryString()) if s != expectedS { t.Fatalf("unexpected args %q. Expecting %q", s, expectedS) } if a.GetUintOrZero("foo") != 123 { t.Fatalf("unexpected arg value %d. Expecting %d", a.GetUintOrZero("foo"), 123) } if a.GetUintOrZero("bar") != 0 { t.Fatalf("unexpected arg value %d. Expecting %d", a.GetUintOrZero("bar"), 0) } if a.GetUintOrZero("aaaa") != 34566 { t.Fatalf("unexpected arg value %d. Expecting %d", a.GetUintOrZero("aaaa"), 34566) } if string(a.Peek("foo")) != "123" { t.Fatalf("unexpected arg value %q. Expecting %q", a.Peek("foo"), "123") } if string(a.Peek("bar")) != "0" { t.Fatalf("unexpected arg value %q. Expecting %q", a.Peek("bar"), "0") } if string(a.Peek("aaaa")) != "34566" { t.Fatalf("unexpected arg value %q. 
Expecting %q", a.Peek("aaaa"), "34566") } } func TestArgsCopyTo(t *testing.T) { var a Args // empty args testCopyTo(t, &a) a.Set("foo", "bar") testCopyTo(t, &a) a.Set("xxx", "yyy") testCopyTo(t, &a) a.Del("foo") testCopyTo(t, &a) } func testCopyTo(t *testing.T, a *Args) { keys := make(map[string]struct{}) a.VisitAll(func(k, v []byte) { keys[string(k)] = struct{}{} }) var b Args a.CopyTo(&b) b.VisitAll(func(k, v []byte) { if _, ok := keys[string(k)]; !ok { t.Fatalf("unexpected key %q after copying from %q", k, a.String()) } delete(keys, string(k)) }) if len(keys) > 0 { t.Fatalf("missing keys %#v after copying from %q", keys, a.String()) } } func TestArgsVisitAll(t *testing.T) { var a Args a.Set("foo", "bar") i := 0 a.VisitAll(func(k, v []byte) { if string(k) != "foo" { t.Fatalf("unexpected key %q. Expected %q", k, "foo") } if string(v) != "bar" { t.Fatalf("unexpected value %q. Expected %q", v, "bar") } i++ }) if i != 1 { t.Fatalf("unexpected number of VisitAll calls: %d. Expected %d", i, 1) } } func TestArgsStringCompose(t *testing.T) { var a Args a.Set("foo", "bar") a.Set("aa", "bbb") a.Set("привет", "мир") a.Set("", "xxxx") a.Set("cvx", "") expectedS := "foo=bar&aa=bbb&%D0%BF%D1%80%D0%B8%D0%B2%D0%B5%D1%82=%D0%BC%D0%B8%D1%80&=xxxx&cvx" s := a.String() if s != expectedS { t.Fatalf("Unexpected string %q. Exected %q", s, expectedS) } } func TestArgsString(t *testing.T) { var a Args testArgsString(t, &a, "") testArgsString(t, &a, "foobar") testArgsString(t, &a, "foo=bar") testArgsString(t, &a, "foo=bar&baz=sss") testArgsString(t, &a, "") testArgsString(t, &a, "f%20o=x.x*-_8x%D0%BF%D1%80%D0%B8%D0%B2%D0%B5aaa&sdf=ss") testArgsString(t, &a, "=asdfsdf") } func testArgsString(t *testing.T, a *Args, s string) { a.Parse(s) s1 := a.String() if s != s1 { t.Fatalf("Unexpected args %q. Expected %q", s1, s) } } func TestArgsSetGetDel(t *testing.T) { var a Args if len(a.Peek("foo")) > 0 { t.Fatalf("Unexpected value: %q", a.Peek("foo")) } if len(a.Peek("")) > 0 { t.Fatalf("Unexpected value: %q", a.Peek("")) } a.Del("xxx") for j := 0; j < 3; j++ { for i := 0; i < 10; i++ { k := fmt.Sprintf("foo%d", i) v := fmt.Sprintf("bar_%d", i) a.Set(k, v) if string(a.Peek(k)) != v { t.Fatalf("Unexpected value: %q. Expected %q", a.Peek(k), v) } } } for i := 0; i < 10; i++ { k := fmt.Sprintf("foo%d", i) v := fmt.Sprintf("bar_%d", i) if string(a.Peek(k)) != v { t.Fatalf("Unexpected value: %q. Expected %q", a.Peek(k), v) } a.Del(k) if string(a.Peek(k)) != "" { t.Fatalf("Unexpected value: %q. Expected %q", a.Peek(k), "") } } a.Parse("aaa=xxx&bb=aa") if string(a.Peek("foo0")) != "" { t.Fatalf("Unepxected value %q", a.Peek("foo0")) } if string(a.Peek("aaa")) != "xxx" { t.Fatalf("Unexpected value %q. Expected %q", a.Peek("aaa"), "xxx") } if string(a.Peek("bb")) != "aa" { t.Fatalf("Unexpected value %q. Expected %q", a.Peek("bb"), "aa") } for i := 0; i < 10; i++ { k := fmt.Sprintf("xx%d", i) v := fmt.Sprintf("yy%d", i) a.Set(k, v) if string(a.Peek(k)) != v { t.Fatalf("Unexpected value: %q. Expected %q", a.Peek(k), v) } } for i := 5; i < 10; i++ { k := fmt.Sprintf("xx%d", i) v := fmt.Sprintf("yy%d", i) if string(a.Peek(k)) != v { t.Fatalf("Unexpected value: %q. Expected %q", a.Peek(k), v) } a.Del(k) if string(a.Peek(k)) != "" { t.Fatalf("Unexpected value: %q. 
Expected %q", a.Peek(k), "") } } } func TestArgsParse(t *testing.T) { var a Args // empty args testArgsParse(t, &a, "", 0, "foo=", "bar=", "=") // arg without value testArgsParse(t, &a, "foo1", 1, "foo=", "bar=", "=") // arg without value, but with equal sign testArgsParse(t, &a, "foo2=", 1, "foo=", "bar=", "=") // arg with value testArgsParse(t, &a, "foo3=bar1", 1, "foo3=bar1", "bar=", "=") // empty key testArgsParse(t, &a, "=bar2", 1, "foo=", "=bar2", "bar2=") // missing kv testArgsParse(t, &a, "&&&&", 0, "foo=", "bar=", "=") // multiple values with the same key testArgsParse(t, &a, "x=1&x=2&x=3", 3, "x=1") // multiple args testArgsParse(t, &a, "&&&qw=er&tyx=124&&&zxc_ss=2234&&", 3, "qw=er", "tyx=124", "zxc_ss=2234") // multiple args without values testArgsParse(t, &a, "&&a&&b&&bar&baz", 4, "a=", "b=", "bar=", "baz=") // values with '=' testArgsParse(t, &a, "zz=1&k=v=v=a=a=s", 2, "k=v=v=a=a=s", "zz=1") // mixed '=' and '&' testArgsParse(t, &a, "sss&z=dsf=&df", 3, "sss=", "z=dsf=", "df=") // encoded args testArgsParse(t, &a, "f+o%20o=%D0%BF%D1%80%D0%B8%D0%B2%D0%B5%D1%82+test", 1, "f o o=привет test") // invalid percent encoding testArgsParse(t, &a, "f%=x&qw%z=d%0k%20p&%%20=%%%20x", 3, "f%=x", "qw%z=d%0k p", "% =%% x") // special chars testArgsParse(t, &a, "a.b,c:d/e=f.g,h:i/q", 1, "a.b,c:d/e=f.g,h:i/q") } func TestArgsHas(t *testing.T) { var a Args // single arg testArgsHas(t, &a, "foo", "foo") testArgsHasNot(t, &a, "foo", "bar", "baz", "") // multi args without values testArgsHas(t, &a, "foo&bar", "foo", "bar") testArgsHasNot(t, &a, "foo&bar", "", "aaaa") // multi args testArgsHas(t, &a, "b=xx&=aaa&c=", "b", "", "c") testArgsHasNot(t, &a, "b=xx&=aaa&c=", "xx", "aaa", "foo") // encoded args testArgsHas(t, &a, "a+b=c+d%20%20e", "a b") testArgsHasNot(t, &a, "a+b=c+d", "a+b", "c+d") } func testArgsHas(t *testing.T, a *Args, s string, expectedKeys ...string) { a.Parse(s) for _, key := range expectedKeys { if !a.Has(key) { t.Fatalf("Missing key %q in %q", key, s) } } } func testArgsHasNot(t *testing.T, a *Args, s string, unexpectedKeys ...string) { a.Parse(s) for _, key := range unexpectedKeys { if a.Has(key) { t.Fatalf("Unexpected key %q in %q", key, s) } } } func testArgsParse(t *testing.T, a *Args, s string, expectedLen int, expectedArgs ...string) { a.Parse(s) if a.Len() != expectedLen { t.Fatalf("Unexpected args len %d. Expected %d. s=%q", a.Len(), expectedLen, s) } for _, xx := range expectedArgs { tmp := strings.SplitN(xx, "=", 2) k := tmp[0] v := tmp[1] buf := a.Peek(k) if string(buf) != v { t.Fatalf("Unexpected value for key=%q: %q. Expected %q. s=%q", k, buf, v, s) } } } golang-github-valyala-fasthttp-20160617/args_timing_test.go000066400000000000000000000010461273074646000236760ustar00rootroot00000000000000package fasthttp import ( "bytes" "testing" ) func BenchmarkArgsParse(b *testing.B) { s := []byte("foo=bar&baz=qqq&aaaaa=bbbb") b.RunParallel(func(pb *testing.PB) { var a Args for pb.Next() { a.ParseBytes(s) } }) } func BenchmarkArgsPeek(b *testing.B) { value := []byte("foobarbaz1234") key := "foobarbaz" b.RunParallel(func(pb *testing.PB) { var a Args a.SetBytesV(key, value) for pb.Next() { if !bytes.Equal(a.Peek(key), value) { b.Fatalf("unexpected arg value %q. 
Expecting %q", a.Peek(key), value) } } }) } golang-github-valyala-fasthttp-20160617/bytebuffer.go000066400000000000000000000037401273074646000224740ustar00rootroot00000000000000package fasthttp import ( "sync" ) const ( defaultByteBufferSize = 128 ) // ByteBuffer provides byte buffer, which can be used with fasthttp API // in order to minimize memory allocations. // // ByteBuffer may be used with functions appending data to the given []byte // slice. See example code for details. // // Use AcquireByteBuffer for obtaining an empty byte buffer. type ByteBuffer struct { // B is a byte buffer to use in append-like workloads. // See example code for details. B []byte } // Write implements io.Writer - it appends p to ByteBuffer.B func (b *ByteBuffer) Write(p []byte) (int, error) { b.B = append(b.B, p...) return len(p), nil } // WriteString appends s to ByteBuffer.B func (b *ByteBuffer) WriteString(s string) (int, error) { b.B = append(b.B, s...) return len(s), nil } // Set sets ByteBuffer.B to p func (b *ByteBuffer) Set(p []byte) { b.B = append(b.B[:0], p...) } // SetString sets ByteBuffer.B to s func (b *ByteBuffer) SetString(s string) { b.B = append(b.B[:0], s...) } // Reset makes ByteBuffer.B empty. func (b *ByteBuffer) Reset() { b.B = b.B[:0] } // AcquireByteBuffer returns an empty byte buffer from the pool. // // Acquired byte buffer may be returned to the pool via ReleaseByteBuffer call. // This reduces the number of memory allocations required for byte buffer // management. func AcquireByteBuffer() *ByteBuffer { return defaultByteBufferPool.Acquire() } // ReleaseByteBuffer returns byte buffer to the pool. // // ByteBuffer.B mustn't be touched after returning it to the pool. // Otherwise data races occur. func ReleaseByteBuffer(b *ByteBuffer) { defaultByteBufferPool.Release(b) } type byteBufferPool struct { pool sync.Pool } var defaultByteBufferPool byteBufferPool func (p *byteBufferPool) Acquire() *ByteBuffer { v := p.pool.Get() if v == nil { return &ByteBuffer{ B: make([]byte, 0, defaultByteBufferSize), } } return v.(*ByteBuffer) } func (p *byteBufferPool) Release(b *ByteBuffer) { b.B = b.B[:0] p.pool.Put(b) } golang-github-valyala-fasthttp-20160617/bytebuffer_example_test.go000066400000000000000000000015271273074646000252470ustar00rootroot00000000000000package fasthttp_test import ( "fmt" "github.com/valyala/fasthttp" ) func ExampleByteBuffer() { // This request handler sets 'Your-IP' response header // to 'Your IP is '. It uses ByteBuffer for constructing response // header value with zero memory allocations. yourIPRequestHandler := func(ctx *fasthttp.RequestCtx) { b := fasthttp.AcquireByteBuffer() b.B = append(b.B, "Your IP is <"...) b.B = fasthttp.AppendIPv4(b.B, ctx.RemoteIP()) b.B = append(b.B, ">"...) ctx.Response.Header.SetBytesV("Your-IP", b.B) fmt.Fprintf(ctx, "Check response headers - they must contain 'Your-IP: %s'", b.B) // It is safe to release byte buffer now, since it is // no longer used. fasthttp.ReleaseByteBuffer(b) } // Start fasthttp server returning your ip in response headers. 
fasthttp.ListenAndServe(":8080", yourIPRequestHandler) } golang-github-valyala-fasthttp-20160617/bytebuffer_test.go000066400000000000000000000015131273074646000235270ustar00rootroot00000000000000package fasthttp import ( "fmt" "testing" "time" ) func TestByteBufferAcquireReleaseSerial(t *testing.T) { testByteBufferAcquireRelease(t) } func TestByteBufferAcquireReleaseConcurrent(t *testing.T) { concurrency := 10 ch := make(chan struct{}, concurrency) for i := 0; i < concurrency; i++ { go func() { testByteBufferAcquireRelease(t) ch <- struct{}{} }() } for i := 0; i < concurrency; i++ { select { case <-ch: case <-time.After(time.Second): t.Fatalf("timeout!") } } } func testByteBufferAcquireRelease(t *testing.T) { for i := 0; i < 10; i++ { b := AcquireByteBuffer() b.B = append(b.B, "num "...) b.B = AppendUint(b.B, i) expectedS := fmt.Sprintf("num %d", i) if string(b.B) != expectedS { t.Fatalf("unexpected result: %q. Expecting %q", b.B, expectedS) } ReleaseByteBuffer(b) } } golang-github-valyala-fasthttp-20160617/bytebuffer_timing_test.go000066400000000000000000000007721273074646000251040ustar00rootroot00000000000000package fasthttp import ( "bytes" "testing" ) func BenchmarkByteBufferWrite(b *testing.B) { s := []byte("foobarbaz") b.RunParallel(func(pb *testing.PB) { var buf ByteBuffer for pb.Next() { for i := 0; i < 100; i++ { buf.Write(s) } buf.Reset() } }) } func BenchmarkBytesBufferWrite(b *testing.B) { s := []byte("foobarbaz") b.RunParallel(func(pb *testing.PB) { var buf bytes.Buffer for pb.Next() { for i := 0; i < 100; i++ { buf.Write(s) } buf.Reset() } }) } golang-github-valyala-fasthttp-20160617/bytesconv.go000066400000000000000000000220161273074646000223500ustar00rootroot00000000000000package fasthttp import ( "bufio" "bytes" "errors" "fmt" "io" "math" "net" "reflect" "sync" "time" "unsafe" ) // AppendHTMLEscape appends html-escaped s to dst and returns the extended dst. func AppendHTMLEscape(dst []byte, s string) []byte { var prev int var sub string for i, n := 0, len(s); i < n; i++ { sub = "" switch s[i] { case '<': sub = "<" case '>': sub = ">" case '"': sub = """ case '\'': sub = "'" } if len(sub) > 0 { dst = append(dst, s[prev:i]...) dst = append(dst, sub...) prev = i + 1 } } return append(dst, s[prev:]...) } // AppendHTMLEscapeBytes appends html-escaped s to dst and returns // the extended dst. func AppendHTMLEscapeBytes(dst, s []byte) []byte { return AppendHTMLEscape(dst, b2s(s)) } // AppendIPv4 appends string representation of the given ip v4 to dst // and returns the extended dst. func AppendIPv4(dst []byte, ip net.IP) []byte { ip = ip.To4() if ip == nil { return append(dst, "non-v4 ip passed to AppendIPv4"...) } dst = AppendUint(dst, int(ip[0])) for i := 1; i < 4; i++ { dst = append(dst, '.') dst = AppendUint(dst, int(ip[i])) } return dst } var errEmptyIPStr = errors.New("empty ip address string") // ParseIPv4 parses ip address from ipStr into dst and returns the extended dst. 
func ParseIPv4(dst net.IP, ipStr []byte) (net.IP, error) { if len(ipStr) == 0 { return dst, errEmptyIPStr } if len(dst) < net.IPv4len { dst = make([]byte, net.IPv4len) } copy(dst, net.IPv4zero) dst = dst.To4() if dst == nil { panic("BUG: dst must not be nil") } b := ipStr for i := 0; i < 3; i++ { n := bytes.IndexByte(b, '.') if n < 0 { return dst, fmt.Errorf("cannot find dot in ipStr %q", ipStr) } v, err := ParseUint(b[:n]) if err != nil { return dst, fmt.Errorf("cannot parse ipStr %q: %s", ipStr, err) } if v > 255 { return dst, fmt.Errorf("cannot parse ipStr %q: ip part cannot exceed 255: parsed %d", ipStr, v) } dst[i] = byte(v) b = b[n+1:] } v, err := ParseUint(b) if err != nil { return dst, fmt.Errorf("cannot parse ipStr %q: %s", ipStr, err) } if v > 255 { return dst, fmt.Errorf("cannot parse ipStr %q: ip part cannot exceed 255: parsed %d", ipStr, v) } dst[3] = byte(v) return dst, nil } // AppendHTTPDate appends HTTP-compliant (RFC1123) representation of date // to dst and returns the extended dst. func AppendHTTPDate(dst []byte, date time.Time) []byte { dst = date.In(time.UTC).AppendFormat(dst, time.RFC1123) copy(dst[len(dst)-3:], strGMT) return dst } // ParseHTTPDate parses HTTP-compliant (RFC1123) date. func ParseHTTPDate(date []byte) (time.Time, error) { return time.Parse(time.RFC1123, b2s(date)) } // AppendUint appends n to dst and returns the extended dst. func AppendUint(dst []byte, n int) []byte { if n < 0 { panic("BUG: int must be positive") } var b [20]byte buf := b[:] i := len(buf) var q int for n >= 10 { i-- q = n / 10 buf[i] = '0' + byte(n-q*10) n = q } i-- buf[i] = '0' + byte(n) dst = append(dst, buf[i:]...) return dst } // ParseUint parses uint from buf. func ParseUint(buf []byte) (int, error) { v, n, err := parseUintBuf(buf) if n != len(buf) { return -1, errUnexpectedTrailingChar } return v, err } var ( errEmptyInt = errors.New("empty integer") errUnexpectedFirstChar = errors.New("unexpected first char found. Expecting 0-9") errUnexpectedTrailingChar = errors.New("unexpected traling char found. Expecting 0-9") errTooLongInt = errors.New("too long int") ) func parseUintBuf(b []byte) (int, int, error) { n := len(b) if n == 0 { return -1, 0, errEmptyInt } v := 0 for i := 0; i < n; i++ { c := b[i] k := c - '0' if k > 9 { if i == 0 { return -1, i, errUnexpectedFirstChar } return v, i, nil } if i >= maxIntChars { return -1, i, errTooLongInt } v = 10*v + int(k) } return v, n, nil } var ( errEmptyFloat = errors.New("empty float number") errDuplicateFloatPoint = errors.New("duplicate point found in float number") errUnexpectedFloatEnd = errors.New("unexpected end of float number") errInvalidFloatExponent = errors.New("invalid float number exponent") errUnexpectedFloatChar = errors.New("unexpected char found in float number") ) // ParseUfloat parses unsigned float from buf. func ParseUfloat(buf []byte) (float64, error) { if len(buf) == 0 { return -1, errEmptyFloat } b := buf var v uint64 var offset = 1.0 var pointFound bool for i, c := range b { if c < '0' || c > '9' { if c == '.' 
{ if pointFound { return -1, errDuplicateFloatPoint } pointFound = true continue } if c == 'e' || c == 'E' { if i+1 >= len(b) { return -1, errUnexpectedFloatEnd } b = b[i+1:] minus := -1 switch b[0] { case '+': b = b[1:] minus = 1 case '-': b = b[1:] default: minus = 1 } vv, err := ParseUint(b) if err != nil { return -1, errInvalidFloatExponent } return float64(v) * offset * math.Pow10(minus*int(vv)), nil } return -1, errUnexpectedFloatChar } v = 10*v + uint64(c-'0') if pointFound { offset /= 10 } } return float64(v) * offset, nil } var ( errEmptyHexNum = errors.New("empty hex number") errTooLargeHexNum = errors.New("too large hex number") ) func readHexInt(r *bufio.Reader) (int, error) { n := 0 i := 0 var k int for { c, err := r.ReadByte() if err != nil { if err == io.EOF && i > 0 { return n, nil } return -1, err } k = hexbyte2int(c) if k < 0 { if i == 0 { return -1, errEmptyHexNum } r.UnreadByte() return n, nil } if i >= maxHexIntChars { return -1, errTooLargeHexNum } n = (n << 4) | k i++ } } var hexIntBufPool sync.Pool func writeHexInt(w *bufio.Writer, n int) error { if n < 0 { panic("BUG: int must be positive") } v := hexIntBufPool.Get() if v == nil { v = make([]byte, maxHexIntChars+1) } buf := v.([]byte) i := len(buf) - 1 for { buf[i] = int2hexbyte(n & 0xf) n >>= 4 if n == 0 { break } i-- } _, err := w.Write(buf[i:]) hexIntBufPool.Put(v) return err } func int2hexbyte(n int) byte { if n < 10 { return '0' + byte(n) } return 'a' + byte(n) - 10 } func hexCharUpper(c byte) byte { if c < 10 { return '0' + c } return c - 10 + 'A' } var hex2intTable = func() []byte { b := make([]byte, 255) for i := byte(0); i < 255; i++ { c := byte(0) if i >= '0' && i <= '9' { c = 1 + i - '0' } else if i >= 'a' && i <= 'f' { c = 1 + i - 'a' + 10 } else if i >= 'A' && i <= 'F' { c = 1 + i - 'A' + 10 } b[i] = c } return b }() func hexbyte2int(c byte) int { return int(hex2intTable[c]) - 1 } const toLower = 'a' - 'A' func uppercaseByte(p *byte) { c := *p if c >= 'a' && c <= 'z' { *p = c - toLower } } func lowercaseByte(p *byte) { c := *p if c >= 'A' && c <= 'Z' { *p = c + toLower } } func lowercaseBytes(b []byte) { for i, n := 0, len(b); i < n; i++ { lowercaseByte(&b[i]) } } // b2s converts byte slice to a string without memory allocation. // See https://groups.google.com/forum/#!msg/Golang-Nuts/ENgbUzYvCuU/90yGx7GUAgAJ . // // Note it may break if string and/or slice header will change // in the future go versions. func b2s(b []byte) string { return *(*string)(unsafe.Pointer(&b)) } // s2b converts string to a byte slice without memory allocation. // // Note it may break if string and/or slice header will change // in the future go versions. func s2b(s string) []byte { sh := (*reflect.StringHeader)(unsafe.Pointer(&s)) bh := reflect.SliceHeader{ Data: sh.Data, Len: sh.Len, Cap: sh.Len, } return *(*[]byte)(unsafe.Pointer(&bh)) } // AppendQuotedArg appends url-encoded src to dst and returns appended dst. func AppendQuotedArg(dst, src []byte) []byte { for _, c := range src { // See http://www.w3.org/TR/html5/forms.html#form-submission-algorithm if c >= 'a' && c <= 'z' || c >= 'A' && c <= 'Z' || c >= '0' && c <= '9' || c == '*' || c == '-' || c == '.' || c == '_' { dst = append(dst, c) } else { dst = append(dst, '%', hexCharUpper(c>>4), hexCharUpper(c&15)) } } return dst } func appendQuotedPath(dst, src []byte) []byte { for _, c := range src { if c >= 'a' && c <= 'z' || c >= 'A' && c <= 'Z' || c >= '0' && c <= '9' || c == '/' || c == '.' 
|| c == ',' || c == '=' || c == ':' || c == '&' || c == '~' || c == '-' || c == '_' { dst = append(dst, c) } else { dst = append(dst, '%', hexCharUpper(c>>4), hexCharUpper(c&15)) } } return dst } // EqualBytesStr returns true if string(b) == s. // // This function has no performance benefits comparing to string(b) == s. // It is left here for backwards compatibility only. // // This function is deperecated and may be deleted soon. func EqualBytesStr(b []byte, s string) bool { return string(b) == s } // AppendBytesStr appends src to dst and returns the extended dst. // // This function has no performance benefits comparing to append(dst, src...). // It is left here for backwards compatibility only. // // This function is deprecated and may be deleted soon. func AppendBytesStr(dst []byte, src string) []byte { return append(dst, src...) } golang-github-valyala-fasthttp-20160617/bytesconv_32.go000066400000000000000000000001441273074646000226520ustar00rootroot00000000000000// +build !amd64,!arm64,!ppc64 package fasthttp const ( maxIntChars = 9 maxHexIntChars = 7 ) golang-github-valyala-fasthttp-20160617/bytesconv_32_test.go000066400000000000000000000015471273074646000237210ustar00rootroot00000000000000// +build !amd64,!arm64,!ppc64 package fasthttp import ( "testing" ) func TestWriteHexInt(t *testing.T) { testWriteHexInt(t, 0, "0") testWriteHexInt(t, 1, "1") testWriteHexInt(t, 0x123, "123") testWriteHexInt(t, 0x7fffffff, "7fffffff") } func TestAppendUint(t *testing.T) { testAppendUint(t, 0) testAppendUint(t, 123) testAppendUint(t, 0x7fffffff) for i := 0; i < 2345; i++ { testAppendUint(t, i) } } func TestReadHexIntSuccess(t *testing.T) { testReadHexIntSuccess(t, "0", 0) testReadHexIntSuccess(t, "fF", 0xff) testReadHexIntSuccess(t, "00abc", 0xabc) testReadHexIntSuccess(t, "7ffffff", 0x7ffffff) testReadHexIntSuccess(t, "000", 0) testReadHexIntSuccess(t, "1234ZZZ", 0x1234) } func TestParseUintSuccess(t *testing.T) { testParseUintSuccess(t, "0", 0) testParseUintSuccess(t, "123", 123) testParseUintSuccess(t, "123456789", 123456789) } golang-github-valyala-fasthttp-20160617/bytesconv_64.go000066400000000000000000000001431273074646000226560ustar00rootroot00000000000000// +build amd64 arm64 ppc64 package fasthttp const ( maxIntChars = 18 maxHexIntChars = 15 ) golang-github-valyala-fasthttp-20160617/bytesconv_64_test.go000066400000000000000000000020031273074646000237120ustar00rootroot00000000000000// +build amd64 arm64 ppc64 package fasthttp import ( "testing" ) func TestWriteHexInt(t *testing.T) { testWriteHexInt(t, 0, "0") testWriteHexInt(t, 1, "1") testWriteHexInt(t, 0x123, "123") testWriteHexInt(t, 0x7fffffffffffffff, "7fffffffffffffff") } func TestAppendUint(t *testing.T) { testAppendUint(t, 0) testAppendUint(t, 123) testAppendUint(t, 0x7fffffffffffffff) for i := 0; i < 2345; i++ { testAppendUint(t, i) } } func TestReadHexIntSuccess(t *testing.T) { testReadHexIntSuccess(t, "0", 0) testReadHexIntSuccess(t, "fF", 0xff) testReadHexIntSuccess(t, "00abc", 0xabc) testReadHexIntSuccess(t, "7fffffff", 0x7fffffff) testReadHexIntSuccess(t, "000", 0) testReadHexIntSuccess(t, "1234ZZZ", 0x1234) testReadHexIntSuccess(t, "7ffffffffffffff", 0x7ffffffffffffff) } func TestParseUintSuccess(t *testing.T) { testParseUintSuccess(t, "0", 0) testParseUintSuccess(t, "123", 123) testParseUintSuccess(t, "1234567890", 1234567890) testParseUintSuccess(t, "123456789012345678", 123456789012345678) } 
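// exampleConversionHelpersSketch is an illustrative sketch (not part of the
// original test suite) showing how the conversion helpers above are typically
// combined: AppendUint/ParseUint round-trip an integer, and
// ParseIPv4/AppendIPv4 round-trip an IPv4 address while reusing destination
// buffers. The numeric and address values are arbitrary placeholders.
func exampleConversionHelpersSketch() {
	// Integer round trip: AppendUint formats n into buf, ParseUint parses it back.
	buf := AppendUint(nil, 1234567)
	if n, err := ParseUint(buf); err != nil || n != 1234567 {
		panic("unexpected AppendUint/ParseUint round-trip result")
	}

	// IPv4 round trip: ParseIPv4 fills a reusable net.IP, AppendIPv4 formats it back.
	ip, err := ParseIPv4(nil, []byte("192.168.10.1"))
	if err != nil {
		panic(err)
	}
	if s := AppendIPv4(buf[:0], ip); string(s) != "192.168.10.1" {
		panic("unexpected ParseIPv4/AppendIPv4 round-trip result")
	}
}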
golang-github-valyala-fasthttp-20160617/bytesconv_test.go000066400000000000000000000146461273074646000234210ustar00rootroot00000000000000package fasthttp import ( "bufio" "bytes" "fmt" "net" "testing" "time" ) func TestAppendHTMLEscape(t *testing.T) { testAppendHTMLEscape(t, "", "") testAppendHTMLEscape(t, "<", "<") testAppendHTMLEscape(t, "a", "a") testAppendHTMLEscape(t, `><"''`, "><"''") testAppendHTMLEscape(t, "foaxxx", "fo<b x='ss'>a</b>xxx") } func testAppendHTMLEscape(t *testing.T, s, expectedS string) { buf := AppendHTMLEscapeBytes(nil, []byte(s)) if string(buf) != expectedS { t.Fatalf("unexpected html-escaped string %q. Expecting %q. Original string %q", buf, expectedS, s) } } func TestParseIPv4(t *testing.T) { testParseIPv4(t, "0.0.0.0", true) testParseIPv4(t, "255.255.255.255", true) testParseIPv4(t, "123.45.67.89", true) // ipv6 shouldn't work testParseIPv4(t, "2001:4860:0:2001::68", false) // invalid ip testParseIPv4(t, "foobar", false) testParseIPv4(t, "1.2.3", false) testParseIPv4(t, "123.456.789.11", false) } func testParseIPv4(t *testing.T, ipStr string, isValid bool) { ip, err := ParseIPv4(nil, []byte(ipStr)) if isValid { if err != nil { t.Fatalf("unexpected error when parsing ip %q: %s", ipStr, err) } s := string(AppendIPv4(nil, ip)) if s != ipStr { t.Fatalf("unexpected ip parsed %q. Expecting %q", s, ipStr) } } else { if err == nil { t.Fatalf("expecting error when parsing ip %q", ipStr) } } } func TestAppendIPv4(t *testing.T) { testAppendIPv4(t, "0.0.0.0", true) testAppendIPv4(t, "127.0.0.1", true) testAppendIPv4(t, "8.8.8.8", true) testAppendIPv4(t, "123.45.67.89", true) // ipv6 shouldn't work testAppendIPv4(t, "2001:4860:0:2001::68", false) } func testAppendIPv4(t *testing.T, ipStr string, isValid bool) { ip := net.ParseIP(ipStr) if ip == nil { t.Fatalf("cannot parse ip %q", ipStr) } s := string(AppendIPv4(nil, ip)) if isValid { if s != ipStr { t.Fatalf("unepxected ip %q. Expecting %q", s, ipStr) } } else { ipStr = "non-v4 ip passed to AppendIPv4" if s != ipStr { t.Fatalf("unexpected ip %q. Expecting %q", s, ipStr) } } } func testAppendUint(t *testing.T, n int) { expectedS := fmt.Sprintf("%d", n) s := AppendUint(nil, n) if string(s) != expectedS { t.Fatalf("unexpected uint %q. Expecting %q. n=%d", s, expectedS, n) } } func testWriteHexInt(t *testing.T, n int, expectedS string) { var w ByteBuffer bw := bufio.NewWriter(&w) if err := writeHexInt(bw, n); err != nil { t.Fatalf("unexpected error when writing hex %x: %s", n, err) } if err := bw.Flush(); err != nil { t.Fatalf("unexpected error when flushing hex %x: %s", n, err) } s := string(w.B) if s != expectedS { t.Fatalf("unexpected hex after writing %q. Expected %q", s, expectedS) } } func TestReadHexIntError(t *testing.T) { testReadHexIntError(t, "") testReadHexIntError(t, "ZZZ") testReadHexIntError(t, "-123") testReadHexIntError(t, "+434") } func testReadHexIntError(t *testing.T, s string) { r := bytes.NewBufferString(s) br := bufio.NewReader(r) n, err := readHexInt(br) if err == nil { t.Fatalf("expecting error when reading hex int %q", s) } if n >= 0 { t.Fatalf("unexpected hex value read %d for hex int %q. must be negative", n, s) } } func testReadHexIntSuccess(t *testing.T, s string, expectedN int) { r := bytes.NewBufferString(s) br := bufio.NewReader(r) n, err := readHexInt(br) if err != nil { t.Fatalf("unexpected error: %s. s=%q", err, s) } if n != expectedN { t.Fatalf("unexpected hex int %d. Expected %d. 
s=%q", n, expectedN, s) } } func TestAppendHTTPDate(t *testing.T) { d := time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC) s := string(AppendHTTPDate(nil, d)) expectedS := "Tue, 10 Nov 2009 23:00:00 GMT" if s != expectedS { t.Fatalf("unexpected date %q. Expecting %q", s, expectedS) } b := []byte("prefix") s = string(AppendHTTPDate(b, d)) if s[:len(b)] != string(b) { t.Fatalf("unexpected prefix %q. Expecting %q", s[:len(b)], b) } s = s[len(b):] if s != expectedS { t.Fatalf("unexpected date %q. Expecting %q", s, expectedS) } } func TestParseUintError(t *testing.T) { // empty string testParseUintError(t, "") // negative value testParseUintError(t, "-123") // non-num testParseUintError(t, "foobar234") // non-num chars at the end testParseUintError(t, "123w") // floating point num testParseUintError(t, "1234.545") // too big num testParseUintError(t, "12345678901234567890") } func TestParseUfloatSuccess(t *testing.T) { testParseUfloatSuccess(t, "0", 0) testParseUfloatSuccess(t, "1.", 1.) testParseUfloatSuccess(t, ".1", 0.1) testParseUfloatSuccess(t, "123.456", 123.456) testParseUfloatSuccess(t, "123", 123) testParseUfloatSuccess(t, "1234e2", 1234e2) testParseUfloatSuccess(t, "1234E-5", 1234E-5) testParseUfloatSuccess(t, "1.234e+3", 1.234e+3) } func TestParseUfloatError(t *testing.T) { // empty num testParseUfloatError(t, "") // negative num testParseUfloatError(t, "-123.53") // non-num chars testParseUfloatError(t, "123sdfsd") testParseUfloatError(t, "sdsf234") testParseUfloatError(t, "sdfdf") // non-num chars in exponent testParseUfloatError(t, "123e3s") testParseUfloatError(t, "12.3e-op") testParseUfloatError(t, "123E+SS5") // duplicate point testParseUfloatError(t, "1.3.4") // duplicate exponent testParseUfloatError(t, "123e5e6") // missing exponent testParseUfloatError(t, "123534e") } func testParseUfloatError(t *testing.T, s string) { n, err := ParseUfloat([]byte(s)) if err == nil { t.Fatalf("Expecting error when parsing %q. obtained %f", s, n) } if n >= 0 { t.Fatalf("Expecting negative num instead of %f when parsing %q", n, s) } } func testParseUfloatSuccess(t *testing.T, s string, expectedF float64) { f, err := ParseUfloat([]byte(s)) if err != nil { t.Fatalf("Unexpected error when parsing %q: %s", s, err) } delta := f - expectedF if delta < 0 { delta = -delta } if delta > expectedF*1e-10 { t.Fatalf("Unexpected value when parsing %q: %f. Expected %f", s, f, expectedF) } } func testParseUintError(t *testing.T, s string) { n, err := ParseUint([]byte(s)) if err == nil { t.Fatalf("Expecting error when parsing %q. obtained %d", s, n) } if n >= 0 { t.Fatalf("Unexpected n=%d when parsing %q. Expected negative num", n, s) } } func testParseUintSuccess(t *testing.T, s string, expectedN int) { n, err := ParseUint([]byte(s)) if err != nil { t.Fatalf("Unexpected error when parsing %q: %s", s, err) } if n != expectedN { t.Fatalf("Unexpected value %d. Expected %d. num=%q", n, expectedN, s) } } golang-github-valyala-fasthttp-20160617/bytesconv_timing_test.go000066400000000000000000000063211273074646000247570ustar00rootroot00000000000000package fasthttp import ( "bufio" "html" "net" "testing" ) func BenchmarkAppendHTMLEscape(b *testing.B) { sOrig := "foobarbazxxxyyyzzz" sExpected := string(AppendHTMLEscape(nil, sOrig)) b.RunParallel(func(pb *testing.PB) { var buf []byte for pb.Next() { for i := 0; i < 10; i++ { buf = AppendHTMLEscape(buf[:0], sOrig) if string(buf) != sExpected { b.Fatalf("unexpected escaped string: %s. 
Expecting %s", buf, sExpected) } } } }) } func BenchmarkHTMLEscapeString(b *testing.B) { sOrig := "foobarbazxxxyyyzzz" sExpected := html.EscapeString(sOrig) b.RunParallel(func(pb *testing.PB) { var s string for pb.Next() { for i := 0; i < 10; i++ { s = html.EscapeString(sOrig) if s != sExpected { b.Fatalf("unexpected escaped string: %s. Expecting %s", s, sExpected) } } } }) } func BenchmarkParseIPv4(b *testing.B) { ipStr := []byte("123.145.167.189") b.RunParallel(func(pb *testing.PB) { var ip net.IP var err error for pb.Next() { ip, err = ParseIPv4(ip, ipStr) if err != nil { b.Fatalf("unexpected error: %s", err) } } }) } func BenchmarkAppendIPv4(b *testing.B) { ip := net.ParseIP("123.145.167.189") b.RunParallel(func(pb *testing.PB) { var buf []byte for pb.Next() { buf = AppendIPv4(buf[:0], ip) } }) } func BenchmarkInt2HexByte(b *testing.B) { buf := []int{1, 0xf, 2, 0xd, 3, 0xe, 4, 0xa, 5, 0xb, 6, 0xc, 7, 0xf, 0, 0xf, 6, 0xd, 9, 8, 4, 0x5} b.RunParallel(func(pb *testing.PB) { var n int for pb.Next() { for _, n = range buf { int2hexbyte(n) } } }) } func BenchmarkHexByte2Int(b *testing.B) { buf := []byte("0A1B2c3d4E5F6C7a8D9ab7cd03ef") b.RunParallel(func(pb *testing.PB) { var c byte for pb.Next() { for _, c = range buf { hexbyte2int(c) } } }) } func BenchmarkWriteHexInt(b *testing.B) { b.RunParallel(func(pb *testing.PB) { var w ByteBuffer bw := bufio.NewWriter(&w) i := 0 for pb.Next() { writeHexInt(bw, i) i++ if i > 0x7fffffff { i = 0 } w.Reset() bw.Reset(&w) } }) } func BenchmarkParseUint(b *testing.B) { b.RunParallel(func(pb *testing.PB) { buf := []byte("1234567") for pb.Next() { n, err := ParseUint(buf) if err != nil { b.Fatalf("unexpected error: %s", err) } if n != 1234567 { b.Fatalf("unexpected result: %d. Expecting %s", n, buf) } } }) } func BenchmarkAppendUint(b *testing.B) { b.RunParallel(func(pb *testing.PB) { var buf []byte i := 0 for pb.Next() { buf = AppendUint(buf[:0], i) i++ if i > 0x7fffffff { i = 0 } } }) } func BenchmarkLowercaseBytesNoop(b *testing.B) { src := []byte("foobarbaz_lowercased_all") b.RunParallel(func(pb *testing.PB) { s := make([]byte, len(src)) for pb.Next() { copy(s, src) lowercaseBytes(s) } }) } func BenchmarkLowercaseBytesAll(b *testing.B) { src := []byte("FOOBARBAZ_UPPERCASED_ALL") b.RunParallel(func(pb *testing.PB) { s := make([]byte, len(src)) for pb.Next() { copy(s, src) lowercaseBytes(s) } }) } func BenchmarkLowercaseBytesMixed(b *testing.B) { src := []byte("Foobarbaz_Uppercased_Mix") b.RunParallel(func(pb *testing.PB) { s := make([]byte, len(src)) for pb.Next() { copy(s, src) lowercaseBytes(s) } }) } golang-github-valyala-fasthttp-20160617/client.go000066400000000000000000001400221273074646000216100ustar00rootroot00000000000000package fasthttp import ( "bufio" "bytes" "crypto/tls" "errors" "fmt" "io" "net" "strings" "sync" "sync/atomic" "time" ) // Do performs the given http request and fills the given http response. // // Request must contain at least non-zero RequestURI with full url (including // scheme and host) or non-zero Host header + RequestURI. // // Response is ignored if resp is nil. // // Client determines the server to be requested in the following order: // // - from RequestURI if it contains full url with scheme and host; // - from Host header otherwise. // // ErrNoFreeConns is returned if all DefaultMaxConnsPerHost connections // to the requested host are busy. // // It is recommended obtaining req and resp via AcquireRequest // and AcquireResponse in performance-critical code. 
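//
// Illustrative sketch (the URL is a placeholder):
//
//	req := AcquireRequest()
//	resp := AcquireResponse()
//	req.SetRequestURI("http://example.com/")
//	err := Do(req, resp)
//	// check err, then inspect resp.Header.StatusCode() and resp.Body()
//	ReleaseRequest(req)
//	ReleaseResponse(resp)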
func Do(req *Request, resp *Response) error { return defaultClient.Do(req, resp) } // DoTimeout performs the given request and waits for response during // the given timeout duration. // // Request must contain at least non-zero RequestURI with full url (including // scheme and host) or non-zero Host header + RequestURI. // // Client determines the server to be requested in the following order: // // - from RequestURI if it contains full url with scheme and host; // - from Host header otherwise. // // Response is ignored if resp is nil. // // ErrTimeout is returned if the response wasn't returned during // the given timeout. // // It is recommended obtaining req and resp via AcquireRequest // and AcquireResponse in performance-critical code. func DoTimeout(req *Request, resp *Response, timeout time.Duration) error { return defaultClient.DoTimeout(req, resp, timeout) } // DoDeadline performs the given request and waits for response until // the given deadline. // // Request must contain at least non-zero RequestURI with full url (including // scheme and host) or non-zero Host header + RequestURI. // // Client determines the server to be requested in the following order: // // - from RequestURI if it contains full url with scheme and host; // - from Host header otherwise. // // Response is ignored if resp is nil. // // ErrTimeout is returned if the response wasn't returned until // the given deadline. // // It is recommended obtaining req and resp via AcquireRequest // and AcquireResponse in performance-critical code. func DoDeadline(req *Request, resp *Response, deadline time.Time) error { return defaultClient.DoDeadline(req, resp, deadline) } // Get appends url contents to dst and returns it as body. // // New body buffer is allocated if dst is nil. func Get(dst []byte, url string) (statusCode int, body []byte, err error) { return defaultClient.Get(dst, url) } // GetTimeout appends url contents to dst and returns it as body. // // New body buffer is allocated if dst is nil. // // ErrTimeout error is returned if url contents couldn't be fetched // during the given timeout. func GetTimeout(dst []byte, url string, timeout time.Duration) (statusCode int, body []byte, err error) { return defaultClient.GetTimeout(dst, url, timeout) } // GetDeadline appends url contents to dst and returns it as body. // // New body buffer is allocated if dst is nil. // // ErrTimeout error is returned if url contents couldn't be fetched // until the given deadline. func GetDeadline(dst []byte, url string, deadline time.Time) (statusCode int, body []byte, err error) { return defaultClient.GetDeadline(dst, url, deadline) } // Post sends POST request to the given url with the given POST arguments. // // Response body is appended to dst, which is returned as body. // // New body buffer is allocated if dst is nil. // // Empty POST body is sent if postArgs is nil. func Post(dst []byte, url string, postArgs *Args) (statusCode int, body []byte, err error) { return defaultClient.Post(dst, url, postArgs) } var defaultClient Client // Client implements http client. // // Copying Client by value is prohibited. Create new instance instead. // // It is safe calling Client methods from concurrently running goroutines. type Client struct { noCopy noCopy // Client name. Used in User-Agent request header. // // Default client name is used if not set. Name string // Callback for establishing new connections to hosts. // // Default Dial is used if not set. 
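//
// A custom dialer sketch (the 5-second timeout is an arbitrary placeholder):
//
//	Dial: func(addr string) (net.Conn, error) {
//		return net.DialTimeout("tcp", addr, 5*time.Second)
//	},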
Dial DialFunc // Attempt to connect to both ipv4 and ipv6 addresses if set to true. // // This option is used only if default TCP dialer is used, // i.e. if Dial is blank. // // By default client connects only to ipv4 addresses, // since unfortunately ipv6 remains broken in many networks worldwide :) DialDualStack bool // TLS config for https connections. // // Default TLS config is used if not set. TLSConfig *tls.Config // Maximum number of connections per each host which may be established. // // DefaultMaxConnsPerHost is used if not set. MaxConnsPerHost int // Idle keep-alive connections are closed after this duration. // // By default idle connections are closed // after DefaultMaxIdleConnDuration. MaxIdleConnDuration time.Duration // Per-connection buffer size for responses' reading. // This also limits the maximum header size. // // Default buffer size is used if 0. ReadBufferSize int // Per-connection buffer size for requests' writing. // // Default buffer size is used if 0. WriteBufferSize int // Maximum duration for full response reading (including body). // // By default response read timeout is unlimited. ReadTimeout time.Duration // Maximum duration for full request writing (including body). // // By default request write timeout is unlimited. WriteTimeout time.Duration // Maximum response body size. // // The client returns ErrBodyTooLarge if this limit is greater than 0 // and response body is greater than the limit. // // By default response body size is unlimited. MaxResponseBodySize int // Header names are passed as-is without normalization // if this option is set. // // Disabled header names' normalization may be useful only for proxying // responses to other clients expecting case-sensitive // header names. See https://github.com/valyala/fasthttp/issues/57 // for details. // // By default request and response header names are normalized, i.e. // The first letter and the first letters following dashes // are uppercased, while all the other letters are lowercased. // Examples: // // * HOST -> Host // * content-type -> Content-Type // * cONTENT-lenGTH -> Content-Length DisableHeaderNamesNormalizing bool mLock sync.Mutex m map[string]*HostClient ms map[string]*HostClient } // Get appends url contents to dst and returns it as body. // // New body buffer is allocated if dst is nil. func (c *Client) Get(dst []byte, url string) (statusCode int, body []byte, err error) { return clientGetURL(dst, url, c) } // GetTimeout appends url contents to dst and returns it as body. // // New body buffer is allocated if dst is nil. // // ErrTimeout error is returned if url contents couldn't be fetched // during the given timeout. func (c *Client) GetTimeout(dst []byte, url string, timeout time.Duration) (statusCode int, body []byte, err error) { return clientGetURLTimeout(dst, url, timeout, c) } // GetDeadline appends url contents to dst and returns it as body. // // New body buffer is allocated if dst is nil. // // ErrTimeout error is returned if url contents couldn't be fetched // until the given deadline. func (c *Client) GetDeadline(dst []byte, url string, deadline time.Time) (statusCode int, body []byte, err error) { return clientGetURLDeadline(dst, url, deadline, c) } // Post sends POST request to the given url with the given POST arguments. // // Response body is appended to dst, which is returned as body. // // New body buffer is allocated if dst is nil. // // Empty POST body is sent if postArgs is nil. 
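//
// Illustrative sketch (URL and form values are placeholders):
//
//	var args Args
//	args.Set("query", "fasthttp")
//	statusCode, body, err := c.Post(nil, "http://example.com/search", &args)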
func (c *Client) Post(dst []byte, url string, postArgs *Args) (statusCode int, body []byte, err error) { return clientPostURL(dst, url, postArgs, c) } // DoTimeout performs the given request and waits for response during // the given timeout duration. // // Request must contain at least non-zero RequestURI with full url (including // scheme and host) or non-zero Host header + RequestURI. // // Client determines the server to be requested in the following order: // // - from RequestURI if it contains full url with scheme and host; // - from Host header otherwise. // // Response is ignored if resp is nil. // // ErrTimeout is returned if the response wasn't returned during // the given timeout. // // It is recommended obtaining req and resp via AcquireRequest // and AcquireResponse in performance-critical code. func (c *Client) DoTimeout(req *Request, resp *Response, timeout time.Duration) error { return clientDoTimeout(req, resp, timeout, c) } // DoDeadline performs the given request and waits for response until // the given deadline. // // Request must contain at least non-zero RequestURI with full url (including // scheme and host) or non-zero Host header + RequestURI. // // Client determines the server to be requested in the following order: // // - from RequestURI if it contains full url with scheme and host; // - from Host header otherwise. // // Response is ignored if resp is nil. // // ErrTimeout is returned if the response wasn't returned until // the given deadline. // // It is recommended obtaining req and resp via AcquireRequest // and AcquireResponse in performance-critical code. func (c *Client) DoDeadline(req *Request, resp *Response, deadline time.Time) error { return clientDoDeadline(req, resp, deadline, c) } // Do performs the given http request and fills the given http response. // // Request must contain at least non-zero RequestURI with full url (including // scheme and host) or non-zero Host header + RequestURI. // // Response is ignored if resp is nil. // // Client determines the server to be requested in the following order: // // - from RequestURI if it contains full url with scheme and host; // - from Host header otherwise. // // ErrNoFreeConns is returned if all Client.MaxConnsPerHost connections // to the requested host are busy. // // It is recommended obtaining req and resp via AcquireRequest // and AcquireResponse in performance-critical code. func (c *Client) Do(req *Request, resp *Response) error { uri := req.URI() host := uri.Host() isTLS := false scheme := uri.Scheme() if bytes.Equal(scheme, strHTTPS) { isTLS = true } else if !bytes.Equal(scheme, strHTTP) { return fmt.Errorf("unsupported protocol %q. 
http and https are supported", scheme) } startCleaner := false c.mLock.Lock() m := c.m if isTLS { m = c.ms } if m == nil { m = make(map[string]*HostClient) if isTLS { c.ms = m } else { c.m = m } } hc := m[string(host)] if hc == nil { hc = &HostClient{ Addr: addMissingPort(string(host), isTLS), Name: c.Name, Dial: c.Dial, DialDualStack: c.DialDualStack, IsTLS: isTLS, TLSConfig: c.TLSConfig, MaxConns: c.MaxConnsPerHost, MaxIdleConnDuration: c.MaxIdleConnDuration, ReadBufferSize: c.ReadBufferSize, WriteBufferSize: c.WriteBufferSize, ReadTimeout: c.ReadTimeout, WriteTimeout: c.WriteTimeout, MaxResponseBodySize: c.MaxResponseBodySize, DisableHeaderNamesNormalizing: c.DisableHeaderNamesNormalizing, } m[string(host)] = hc if len(m) == 1 { startCleaner = true } } c.mLock.Unlock() if startCleaner { go c.mCleaner(m) } return hc.Do(req, resp) } func (c *Client) mCleaner(m map[string]*HostClient) { mustStop := false for { t := time.Now() c.mLock.Lock() for k, v := range m { if t.Sub(v.LastUseTime()) > time.Minute { delete(m, k) } } if len(m) == 0 { mustStop = true } c.mLock.Unlock() if mustStop { break } time.Sleep(10 * time.Second) } } // DefaultMaxConnsPerHost is the maximum number of concurrent connections // http client may establish per host by default (i.e. if // Client.MaxConnsPerHost isn't set). const DefaultMaxConnsPerHost = 512 // DefaultMaxIdleConnDuration is the default duration before idle keep-alive // connection is closed. const DefaultMaxIdleConnDuration = 10 * time.Second // DialFunc must establish connection to addr. // // There is no need in establishing TLS (SSL) connection for https. // The client automatically converts connection to TLS // if HostClient.IsTLS is set. // // TCP address passed to DialFunc always contains host and port. // Example TCP addr values: // // - foobar.com:80 // - foobar.com:443 // - foobar.com:8080 type DialFunc func(addr string) (net.Conn, error) // HostClient balances http requests among hosts listed in Addr. // // HostClient may be used for balancing load among multiple upstream hosts. // // It is forbidden copying HostClient instances. Create new instances instead. // // It is safe calling HostClient methods from concurrently running goroutines. type HostClient struct { noCopy noCopy // Comma-separated list of upstream HTTP server host addresses, // which are passed to Dial in round-robin manner. // // Each address may contain port if default dialer is used. // For example, // // - foobar.com:80 // - foobar.com:443 // - foobar.com:8080 Addr string // Client name. Used in User-Agent request header. Name string // Callback for establishing new connection to the host. // // Default Dial is used if not set. Dial DialFunc // Attempt to connect to both ipv4 and ipv6 host addresses // if set to true. // // This option is used only if default TCP dialer is used, // i.e. if Dial is blank. // // By default client connects only to ipv4 addresses, // since unfortunately ipv6 remains broken in many networks worldwide :) DialDualStack bool // Whether to use TLS (aka SSL or HTTPS) for host connections. IsTLS bool // Optional TLS config. TLSConfig *tls.Config // Maximum number of connections which may be established to all hosts // listed in Addr. // // DefaultMaxConnsPerHost is used if not set. MaxConns int // Keep-alive connections are closed after this duration. // // By default connection duration is unlimited. MaxConnDuration time.Duration // Idle keep-alive connections are closed after this duration. 
// // By default idle connections are closed // after DefaultMaxIdleConnDuration. MaxIdleConnDuration time.Duration // Per-connection buffer size for responses' reading. // This also limits the maximum header size. // // Default buffer size is used if 0. ReadBufferSize int // Per-connection buffer size for requests' writing. // // Default buffer size is used if 0. WriteBufferSize int // Maximum duration for full response reading (including body). // // By default response read timeout is unlimited. ReadTimeout time.Duration // Maximum duration for full request writing (including body). // // By default request write timeout is unlimited. WriteTimeout time.Duration // Maximum response body size. // // The client returns ErrBodyTooLarge if this limit is greater than 0 // and response body is greater than the limit. // // By default response body size is unlimited. MaxResponseBodySize int // Header names are passed as-is without normalization // if this option is set. // // Disabled header names' normalization may be useful only for proxying // responses to other clients expecting case-sensitive // header names. See https://github.com/valyala/fasthttp/issues/57 // for details. // // By default request and response header names are normalized, i.e. // The first letter and the first letters following dashes // are uppercased, while all the other letters are lowercased. // Examples: // // * HOST -> Host // * content-type -> Content-Type // * cONTENT-lenGTH -> Content-Length DisableHeaderNamesNormalizing bool clientName atomic.Value lastUseTime uint32 connsLock sync.Mutex connsCount int conns []*clientConn addrsLock sync.Mutex addrs []string addrIdx uint32 readerPool sync.Pool writerPool sync.Pool } type clientConn struct { c net.Conn createdTime time.Time lastUseTime time.Time lastReadDeadlineTime time.Time lastWriteDeadlineTime time.Time } var startTimeUnix = time.Now().Unix() // LastUseTime returns time the client was last used func (c *HostClient) LastUseTime() time.Time { n := atomic.LoadUint32(&c.lastUseTime) return time.Unix(startTimeUnix+int64(n), 0) } // Get appends url contents to dst and returns it as body. // // New body buffer is allocated if dst is nil. func (c *HostClient) Get(dst []byte, url string) (statusCode int, body []byte, err error) { return clientGetURL(dst, url, c) } // GetTimeout appends url contents to dst and returns it as body. // // New body buffer is allocated if dst is nil. // // ErrTimeout error is returned if url contents couldn't be fetched // during the given timeout. func (c *HostClient) GetTimeout(dst []byte, url string, timeout time.Duration) (statusCode int, body []byte, err error) { return clientGetURLTimeout(dst, url, timeout, c) } // GetDeadline appends url contents to dst and returns it as body. // // New body buffer is allocated if dst is nil. // // ErrTimeout error is returned if url contents couldn't be fetched // until the given deadline. func (c *HostClient) GetDeadline(dst []byte, url string, deadline time.Time) (statusCode int, body []byte, err error) { return clientGetURLDeadline(dst, url, deadline, c) } // Post sends POST request to the given url with the given POST arguments. // // Response body is appended to dst, which is returned as body. // // New body buffer is allocated if dst is nil. // // Empty POST body is sent if postArgs is nil. 
func (c *HostClient) Post(dst []byte, url string, postArgs *Args) (statusCode int, body []byte, err error) { return clientPostURL(dst, url, postArgs, c) } type clientDoer interface { Do(req *Request, resp *Response) error } func clientGetURL(dst []byte, url string, c clientDoer) (statusCode int, body []byte, err error) { req := AcquireRequest() statusCode, body, err = doRequestFollowRedirects(req, dst, url, c) ReleaseRequest(req) return statusCode, body, err } func clientGetURLTimeout(dst []byte, url string, timeout time.Duration, c clientDoer) (statusCode int, body []byte, err error) { deadline := time.Now().Add(timeout) return clientGetURLDeadline(dst, url, deadline, c) } func clientGetURLDeadline(dst []byte, url string, deadline time.Time, c clientDoer) (statusCode int, body []byte, err error) { var sleepTime time.Duration for { statusCode, body, err = clientGetURLDeadlineFreeConn(dst, url, deadline, c) if err != ErrNoFreeConns { return statusCode, body, err } sleepTime = updateSleepTime(sleepTime, deadline) time.Sleep(sleepTime) } } type clientURLResponse struct { statusCode int body []byte err error } func clientGetURLDeadlineFreeConn(dst []byte, url string, deadline time.Time, c clientDoer) (statusCode int, body []byte, err error) { timeout := -time.Since(deadline) if timeout <= 0 { return 0, dst, ErrTimeout } var ch chan clientURLResponse chv := clientURLResponseChPool.Get() if chv == nil { chv = make(chan clientURLResponse, 1) } ch = chv.(chan clientURLResponse) req := AcquireRequest() // Note that the request continues execution on ErrTimeout until // client-specific ReadTimeout exceeds. This helps limiting load // on slow hosts by MaxConns* concurrent requests. // // Without this 'hack' the load on slow host could exceed MaxConns* // concurrent requests, since timed out requests on client side // usually continue execution on the host. 
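// The request below runs in its own goroutine; the select further down
// waits for either its result on ch or the acquired timer to fire (ErrTimeout).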
go func() { statusCodeCopy, bodyCopy, errCopy := doRequestFollowRedirects(req, dst, url, c) ch <- clientURLResponse{ statusCode: statusCodeCopy, body: bodyCopy, err: errCopy, } }() tc := acquireTimer(timeout) select { case resp := <-ch: ReleaseRequest(req) clientURLResponseChPool.Put(chv) statusCode = resp.statusCode body = resp.body err = resp.err case <-tc.C: body = dst err = ErrTimeout } releaseTimer(tc) return statusCode, body, err } var clientURLResponseChPool sync.Pool func clientPostURL(dst []byte, url string, postArgs *Args, c clientDoer) (statusCode int, body []byte, err error) { req := AcquireRequest() req.Header.SetMethodBytes(strPost) req.Header.SetContentTypeBytes(strPostArgsContentType) if postArgs != nil { postArgs.WriteTo(req.BodyWriter()) } statusCode, body, err = doRequestFollowRedirects(req, dst, url, c) ReleaseRequest(req) return statusCode, body, err } var ( errMissingLocation = errors.New("missing Location header for http redirect") errTooManyRedirects = errors.New("too many redirects detected when doing the request") ) const maxRedirectsCount = 16 func doRequestFollowRedirects(req *Request, dst []byte, url string, c clientDoer) (statusCode int, body []byte, err error) { resp := AcquireResponse() bodyBuf := resp.bodyBuffer() resp.keepBodyBuffer = true oldBody := bodyBuf.B bodyBuf.B = dst redirectsCount := 0 for { req.parsedURI = false req.Header.host = req.Header.host[:0] req.SetRequestURI(url) if err = c.Do(req, resp); err != nil { break } statusCode = resp.Header.StatusCode() if statusCode != StatusMovedPermanently && statusCode != StatusFound && statusCode != StatusSeeOther { break } redirectsCount++ if redirectsCount > maxRedirectsCount { err = errTooManyRedirects break } location := resp.Header.peek(strLocation) if len(location) == 0 { err = errMissingLocation break } url = getRedirectURL(url, location) } body = bodyBuf.B bodyBuf.B = oldBody resp.keepBodyBuffer = false ReleaseResponse(resp) return statusCode, body, err } func getRedirectURL(baseURL string, location []byte) string { u := AcquireURI() u.Update(baseURL) u.UpdateBytes(location) redirectURL := u.String() ReleaseURI(u) return redirectURL } var ( requestPool sync.Pool responsePool sync.Pool ) // AcquireRequest returns an empty Request instance from request pool. // // The returned Request instance may be passed to ReleaseRequest when it is // no longer needed. This allows Request recycling, reduces GC pressure // and usually improves performance. func AcquireRequest() *Request { v := requestPool.Get() if v == nil { return &Request{} } return v.(*Request) } // ReleaseRequest returns req acquired via AcquireRequest to request pool. // // It is forbidden accessing req and/or its' members after returning // it to request pool. func ReleaseRequest(req *Request) { req.Reset() requestPool.Put(req) } // AcquireResponse returns an empty Response instance from response pool. // // The returned Response instance may be passed to ReleaseResponse when it is // no longer needed. This allows Response recycling, reduces GC pressure // and usually improves performance. func AcquireResponse() *Response { v := responsePool.Get() if v == nil { return &Response{} } return v.(*Response) } // ReleaseResponse return resp acquired via AcquireResponse to response pool. // // It is forbidden accessing resp and/or its' members after returning // it to response pool. 
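//
// Typical usage pattern (sketch):
//
//	resp := AcquireResponse()
//	defer ReleaseResponse(resp)
//	// use resp only until the surrounding function returns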
func ReleaseResponse(resp *Response) { resp.Reset() responsePool.Put(resp) } // DoTimeout performs the given request and waits for response during // the given timeout duration. // // Request must contain at least non-zero RequestURI with full url (including // scheme and host) or non-zero Host header + RequestURI. // // Response is ignored if resp is nil. // // ErrTimeout is returned if the response wasn't returned during // the given timeout. // // It is recommended obtaining req and resp via AcquireRequest // and AcquireResponse in performance-critical code. func (c *HostClient) DoTimeout(req *Request, resp *Response, timeout time.Duration) error { return clientDoTimeout(req, resp, timeout, c) } // DoDeadline performs the given request and waits for response until // the given deadline. // // Request must contain at least non-zero RequestURI with full url (including // scheme and host) or non-zero Host header + RequestURI. // // Response is ignored if resp is nil. // // ErrTimeout is returned if the response wasn't returned until // the given deadline. // // It is recommended obtaining req and resp via AcquireRequest // and AcquireResponse in performance-critical code. func (c *HostClient) DoDeadline(req *Request, resp *Response, deadline time.Time) error { return clientDoDeadline(req, resp, deadline, c) } func clientDoTimeout(req *Request, resp *Response, timeout time.Duration, c clientDoer) error { deadline := time.Now().Add(timeout) return clientDoDeadline(req, resp, deadline, c) } func clientDoDeadline(req *Request, resp *Response, deadline time.Time, c clientDoer) error { var sleepTime time.Duration for { err := clientDoDeadlineFreeConn(req, resp, deadline, c) if err != ErrNoFreeConns { return err } sleepTime = updateSleepTime(sleepTime, deadline) time.Sleep(sleepTime) } } var sleepJitter uint64 func updateSleepTime(prevTime time.Duration, deadline time.Time) time.Duration { sleepTime := prevTime * 2 if sleepTime == 0 { jitter := atomic.AddUint64(&sleepJitter, 1) % 40 sleepTime = (10 + time.Duration(jitter)) * time.Millisecond } remainingTime := deadline.Sub(time.Now()) if sleepTime >= remainingTime { // Just sleep for the remaining time and then time out. // This should save CPU time for real work by other goroutines. sleepTime = remainingTime + 10*time.Millisecond if sleepTime < 0 { sleepTime = 10 * time.Millisecond } } return sleepTime } func clientDoDeadlineFreeConn(req *Request, resp *Response, deadline time.Time, c clientDoer) error { timeout := -time.Since(deadline) if timeout <= 0 { return ErrTimeout } var ch chan error chv := errorChPool.Get() if chv == nil { chv = make(chan error, 1) } ch = chv.(chan error) // Make req and resp copies, since on timeout they no longer // may be accessed. reqCopy := AcquireRequest() req.copyToSkipBody(reqCopy) swapRequestBody(req, reqCopy) respCopy := AcquireResponse() // Note that the request continues execution on ErrTimeout until // client-specific ReadTimeout exceeds. This helps limiting load // on slow hosts by MaxConns* concurrent requests. // // Without this 'hack' the load on slow host could exceed MaxConns* // concurrent requests, since timed out requests on client side // usually continue execution on the host. 
go func() { ch <- c.Do(reqCopy, respCopy) }() tc := acquireTimer(timeout) var err error select { case err = <-ch: if resp != nil { respCopy.copyToSkipBody(resp) swapResponseBody(resp, respCopy) } ReleaseResponse(respCopy) ReleaseRequest(reqCopy) errorChPool.Put(chv) case <-tc.C: err = ErrTimeout } releaseTimer(tc) return err } var errorChPool sync.Pool // Do performs the given http request and sets the corresponding response. // // Request must contain at least non-zero RequestURI with full url (including // scheme and host) or non-zero Host header + RequestURI. // // Response is ignored if resp is nil. // // ErrNoFreeConns is returned if all HostClient.MaxConns connections // to the host are busy. // // It is recommended obtaining req and resp via AcquireRequest // and AcquireResponse in performance-critical code. func (c *HostClient) Do(req *Request, resp *Response) error { retry, err := c.do(req, resp) if err != nil && retry && isIdempotent(req) { _, err = c.do(req, resp) } if err == io.EOF { err = ErrConnectionClosed } return err } func isIdempotent(req *Request) bool { return req.Header.IsGet() || req.Header.IsHead() || req.Header.IsPut() } func (c *HostClient) do(req *Request, resp *Response) (bool, error) { nilResp := false if resp == nil { nilResp = true resp = AcquireResponse() } ok, err := c.doNonNilReqResp(req, resp) if nilResp { ReleaseResponse(resp) } return ok, err } func (c *HostClient) doNonNilReqResp(req *Request, resp *Response) (bool, error) { if req == nil { panic("BUG: req cannot be nil") } if resp == nil { panic("BUG: resp cannot be nil") } atomic.StoreUint32(&c.lastUseTime, uint32(time.Now().Unix()-startTimeUnix)) // Free up resources occupied by response before sending the request, // so the GC may reclaim these resources (e.g. response body). resp.Reset() cc, err := c.acquireConn() if err != nil { return false, err } conn := cc.c if c.WriteTimeout > 0 { // Optimization: update write deadline only if more than 25% // of the last write deadline exceeded. // See https://github.com/golang/go/issues/15133 for details. currentTime := time.Now() if currentTime.Sub(cc.lastWriteDeadlineTime) > (c.WriteTimeout >> 2) { if err = conn.SetWriteDeadline(currentTime.Add(c.WriteTimeout)); err != nil { c.closeConn(cc) return true, err } cc.lastWriteDeadlineTime = currentTime } } resetConnection := false if c.MaxConnDuration > 0 && time.Since(cc.createdTime) > c.MaxConnDuration && !req.ConnectionClose() { req.SetConnectionClose() resetConnection = true } userAgentOld := req.Header.UserAgent() if len(userAgentOld) == 0 { req.Header.userAgent = c.getClientName() } bw := c.acquireWriter(conn) err = req.Write(bw) if len(userAgentOld) == 0 { req.Header.userAgent = userAgentOld } if resetConnection { req.Header.ResetConnectionClose() } if err == nil { err = bw.Flush() } if err != nil { c.releaseWriter(bw) c.closeConn(cc) return true, err } c.releaseWriter(bw) if c.ReadTimeout > 0 { // Optimization: update read deadline only if more than 25% // of the last read deadline exceeded. // See https://github.com/golang/go/issues/15133 for details. 
currentTime := time.Now() if currentTime.Sub(cc.lastReadDeadlineTime) > (c.ReadTimeout >> 2) { if err = conn.SetReadDeadline(currentTime.Add(c.ReadTimeout)); err != nil { c.closeConn(cc) return true, err } cc.lastReadDeadlineTime = currentTime } } if !req.Header.IsGet() && req.Header.IsHead() { resp.SkipBody = true } if c.DisableHeaderNamesNormalizing { resp.Header.DisableNormalizing() } br := c.acquireReader(conn) if err = resp.ReadLimitBody(br, c.MaxResponseBodySize); err != nil { c.releaseReader(br) c.closeConn(cc) if err == io.EOF { return true, err } return false, err } c.releaseReader(br) if resetConnection || req.ConnectionClose() || resp.ConnectionClose() { c.closeConn(cc) } else { c.releaseConn(cc) } return false, err } var ( // ErrNoFreeConns is returned when no free connections available // to the given host. ErrNoFreeConns = errors.New("no free connections available to host") // ErrTimeout is returned from timed out calls. ErrTimeout = errors.New("timeout") // ErrConnectionClosed may be returned from client methods if the server // closes connection before returning the first response byte. // // If you see this error, then either fix the server by returning // 'Connection: close' response header before closing the connection // or add 'Connection: close' request header before sending requests // to broken server. ErrConnectionClosed = errors.New("the server closed connection before returning the first response byte. " + "Make sure the server returns 'Connection: close' response header before closing the connection") ) func (c *HostClient) acquireConn() (*clientConn, error) { var cc *clientConn createConn := false startCleaner := false var n int c.connsLock.Lock() n = len(c.conns) if n == 0 { maxConns := c.MaxConns if maxConns <= 0 { maxConns = DefaultMaxConnsPerHost } if c.connsCount < maxConns { c.connsCount++ createConn = true } if createConn && c.connsCount == 1 { startCleaner = true } } else { n-- cc = c.conns[n] c.conns = c.conns[:n] } c.connsLock.Unlock() if cc != nil { return cc, nil } if !createConn { return nil, ErrNoFreeConns } conn, err := c.dialHostHard() if err != nil { c.decConnsCount() return nil, err } cc = acquireClientConn(conn) if startCleaner { go c.connsCleaner() } return cc, nil } func (c *HostClient) connsCleaner() { var ( scratch []*clientConn mustStop bool maxIdleConnDuration = c.MaxIdleConnDuration ) if maxIdleConnDuration <= 0 { maxIdleConnDuration = DefaultMaxIdleConnDuration } for { currentTime := time.Now() c.connsLock.Lock() conns := c.conns n := len(conns) i := 0 for i < n && currentTime.Sub(conns[i].lastUseTime) > maxIdleConnDuration { i++ } mustStop = (c.connsCount == i) scratch = append(scratch[:0], conns[:i]...) 
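// Remove the idle connections collected into scratch above from the head
// of c.conns: shift the remaining connections to the front and clear the
// freed tail slots so the underlying objects can be garbage collected.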
if i > 0 { m := copy(conns, conns[i:]) for i = m; i < n; i++ { conns[i] = nil } c.conns = conns[:m] } c.connsLock.Unlock() for i, cc := range scratch { c.closeConn(cc) scratch[i] = nil } if mustStop { break } time.Sleep(maxIdleConnDuration) } } func (c *HostClient) closeConn(cc *clientConn) { c.decConnsCount() cc.c.Close() releaseClientConn(cc) } func (c *HostClient) decConnsCount() { c.connsLock.Lock() c.connsCount-- c.connsLock.Unlock() } func acquireClientConn(conn net.Conn) *clientConn { v := clientConnPool.Get() if v == nil { v = &clientConn{} } cc := v.(*clientConn) cc.c = conn cc.createdTime = time.Now() return cc } func releaseClientConn(cc *clientConn) { cc.c = nil clientConnPool.Put(cc) } var clientConnPool sync.Pool func (c *HostClient) releaseConn(cc *clientConn) { cc.lastUseTime = time.Now() c.connsLock.Lock() c.conns = append(c.conns, cc) c.connsLock.Unlock() } func (c *HostClient) acquireWriter(conn net.Conn) *bufio.Writer { v := c.writerPool.Get() if v == nil { n := c.WriteBufferSize if n <= 0 { n = defaultWriteBufferSize } return bufio.NewWriterSize(conn, n) } bw := v.(*bufio.Writer) bw.Reset(conn) return bw } func (c *HostClient) releaseWriter(bw *bufio.Writer) { c.writerPool.Put(bw) } func (c *HostClient) acquireReader(conn net.Conn) *bufio.Reader { v := c.readerPool.Get() if v == nil { n := c.ReadBufferSize if n <= 0 { n = defaultReadBufferSize } return bufio.NewReaderSize(conn, n) } br := v.(*bufio.Reader) br.Reset(conn) return br } func (c *HostClient) releaseReader(br *bufio.Reader) { c.readerPool.Put(br) } func newDefaultTLSConfig() *tls.Config { return &tls.Config{ InsecureSkipVerify: true, ClientSessionCache: tls.NewLRUClientSessionCache(0), } } func (c *HostClient) nextAddr() string { c.addrsLock.Lock() if c.addrs == nil { c.addrs = strings.Split(c.Addr, ",") } addr := c.addrs[0] if len(c.addrs) > 1 { addr = c.addrs[c.addrIdx%uint32(len(c.addrs))] c.addrIdx++ } c.addrsLock.Unlock() return addr } func (c *HostClient) dialHostHard() (conn net.Conn, err error) { // attempt to dial all the available hosts before giving up. c.addrsLock.Lock() n := len(c.addrs) c.addrsLock.Unlock() if n == 0 { // It looks like c.addrs isn't initialized yet. 
n = 1 } timeout := c.ReadTimeout + c.WriteTimeout if timeout <= 0 { timeout = DefaultDialTimeout } deadline := time.Now().Add(timeout) for n > 0 { addr := c.nextAddr() conn, err = dialAddr(addr, c.Dial, c.DialDualStack, c.IsTLS, c.TLSConfig) if err == nil { return conn, nil } if time.Since(deadline) >= 0 { break } n-- } return nil, err } func dialAddr(addr string, dial DialFunc, dialDualStack, isTLS bool, tlsConfig *tls.Config) (net.Conn, error) { if dial == nil { if dialDualStack { dial = DialDualStack } else { dial = Dial } addr = addMissingPort(addr, isTLS) } conn, err := dial(addr) if err != nil { return nil, err } if conn == nil { panic("BUG: DialFunc returned (nil, nil)") } if isTLS { if tlsConfig == nil { tlsConfig = newDefaultTLSConfig() } conn = tls.Client(conn, tlsConfig) } return conn, nil } func (c *HostClient) getClientName() []byte { v := c.clientName.Load() var clientName []byte if v == nil { clientName = []byte(c.Name) if len(clientName) == 0 { clientName = defaultUserAgent } c.clientName.Store(clientName) } else { clientName = v.([]byte) } return clientName } func addMissingPort(addr string, isTLS bool) string { n := strings.Index(addr, ":") if n >= 0 { return addr } port := 80 if isTLS { port = 443 } return fmt.Sprintf("%s:%d", addr, port) } // PipelineClient pipelines requests over a single connection to the given Addr. // // This client may be used in highly loaded HTTP-based RPC systems for reducing // context switches and network level overhead. // See https://en.wikipedia.org/wiki/HTTP_pipelining for details. // // It is forbidden copying PipelineClient instances. Create new instances // instead. // // It is safe calling PipelineClient methods from concurrently running // goroutines. type PipelineClient struct { noCopy noCopy // Address of the host to connect to. Addr string // The maximum number of pending pipelined requests to the server. // // DefaultMaxPendingRequests is used by default. MaxPendingRequests int // The maximum delay before sending pipelined requests as a batch // to the server. // // By default requests are sent immediately to the server. MaxBatchDelay time.Duration // Callback for connection establishing to the host. // // Default Dial is used if not set. Dial DialFunc // Attempt to connect to both ipv4 and ipv6 host addresses // if set to true. // // This option is used only if default TCP dialer is used, // i.e. if Dial is blank. // // By default client connects only to ipv4 addresses, // since unfortunately ipv6 remains broken in many networks worldwide :) DialDualStack bool // Whether to use TLS (aka SSL or HTTPS) for host connections. IsTLS bool // Optional TLS config. TLSConfig *tls.Config // Idle connection to the host is closed after this duration. // // By default idle connection is closed after // DefaultMaxIdleConnDuration. MaxIdleConnDuration time.Duration // Buffer size for responses' reading. // This also limits the maximum header size. // // Default buffer size is used if 0. ReadBufferSize int // Buffer size for requests' writing. // // Default buffer size is used if 0. WriteBufferSize int // Maximum duration for full response reading (including body). // // By default response read timeout is unlimited. ReadTimeout time.Duration // Maximum duration for full request writing (including body). // // By default request write timeout is unlimited. WriteTimeout time.Duration // Logger for logging client errors. // // By default standard logger from log package is used. 
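//
// Any value providing a Printf(format string, args ...interface{}) method
// satisfies this interface; for example (sketch),
// log.New(os.Stderr, "pipeline: ", log.LstdFlags) from the standard library.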
Logger Logger workPool sync.Pool chLock sync.Mutex chW chan *pipelineWork chR chan *pipelineWork } type pipelineWork struct { reqCopy Request respCopy Response req *Request resp *Response t *time.Timer deadline time.Time err error done chan struct{} } // DoTimeout performs the given request and waits for response during // the given timeout duration. // // Request must contain at least non-zero RequestURI with full url (including // scheme and host) or non-zero Host header + RequestURI. // // Response is ignored if resp is nil. // // ErrTimeout is returned if the response wasn't returned during // the given timeout. // // It is recommended obtaining req and resp via AcquireRequest // and AcquireResponse in performance-critical code. func (c *PipelineClient) DoTimeout(req *Request, resp *Response, timeout time.Duration) error { return c.DoDeadline(req, resp, time.Now().Add(timeout)) } // DoDeadline performs the given request and waits for response until // the given deadline. // // Request must contain at least non-zero RequestURI with full url (including // scheme and host) or non-zero Host header + RequestURI. // // Response is ignored if resp is nil. // // ErrTimeout is returned if the response wasn't returned until // the given deadline. // // It is recommended obtaining req and resp via AcquireRequest // and AcquireResponse in performance-critical code. func (c *PipelineClient) DoDeadline(req *Request, resp *Response, deadline time.Time) error { c.init() timeout := -time.Since(deadline) if timeout < 0 { return ErrTimeout } w := acquirePipelineWork(&c.workPool, timeout) w.req = &w.reqCopy w.resp = &w.respCopy // Make a copy of the request in order to avoid data races on timeouts req.copyToSkipBody(&w.reqCopy) swapRequestBody(req, &w.reqCopy) // Put the request to outgoing queue select { case c.chW <- w: // Fast path: len(c.ch) < cap(c.ch) default: // Slow path select { case c.chW <- w: case <-w.t.C: releasePipelineWork(&c.workPool, w) return ErrTimeout } } // Wait for the response var err error select { case <-w.done: if resp != nil { w.respCopy.copyToSkipBody(resp) swapResponseBody(resp, &w.respCopy) } err = w.err releasePipelineWork(&c.workPool, w) case <-w.t.C: err = ErrTimeout } return err } // Do performs the given http request and sets the corresponding response. // // Request must contain at least non-zero RequestURI with full url (including // scheme and host) or non-zero Host header + RequestURI. // // Response is ignored if resp is nil. // // ErrNoFreeConns is returned if all HostClient.MaxConns connections // to the host are busy. // // It is recommended obtaining req and resp via AcquireRequest // and AcquireResponse in performance-critical code. func (c *PipelineClient) Do(req *Request, resp *Response) error { c.init() w := acquirePipelineWork(&c.workPool, 0) w.req = req if resp != nil { w.resp = resp } else { w.resp = &w.respCopy } // Put the request to outgoing queue select { case c.chW <- w: default: // Try substituting the oldest w with the current one. select { case wOld := <-c.chW: wOld.err = ErrPipelineOverflow wOld.done <- struct{}{} default: } select { case c.chW <- w: default: releasePipelineWork(&c.workPool, w) return ErrPipelineOverflow } } // Wait for the response <-w.done err := w.err releasePipelineWork(&c.workPool, w) return err } // ErrPipelineOverflow may be returned from PipelineClient.Do // if the requests' queue is overflown. var ErrPipelineOverflow = errors.New("pipelined requests' queue has been overflown. 
Increase MaxPendingRequests") // DefaultMaxPendingRequests is the default value // for PipelineClient.MaxPendingRequests. const DefaultMaxPendingRequests = 1024 func (c *PipelineClient) init() { c.chLock.Lock() if c.chR == nil { maxPendingRequests := c.MaxPendingRequests if maxPendingRequests <= 0 { maxPendingRequests = DefaultMaxPendingRequests } c.chR = make(chan *pipelineWork, maxPendingRequests) if c.chW == nil { c.chW = make(chan *pipelineWork, maxPendingRequests) } go func() { if err := c.worker(); err != nil { c.logger().Printf("error in PipelineClient(%q): %s", c.Addr, err) if netErr, ok := err.(net.Error); ok && netErr.Temporary() { // Throttle client reconnections on temporary errors time.Sleep(time.Second) } } c.chLock.Lock() // Do not reset c.chW to nil, since it may contain // pending requests, which could be served on the next // connection to the host. c.chR = nil c.chLock.Unlock() }() } c.chLock.Unlock() } func (c *PipelineClient) worker() error { conn, err := dialAddr(c.Addr, c.Dial, c.DialDualStack, c.IsTLS, c.TLSConfig) if err != nil { return err } // Start reader and writer stopW := make(chan struct{}) doneW := make(chan error) go func() { doneW <- c.writer(conn, stopW) }() stopR := make(chan struct{}) doneR := make(chan error) go func() { doneR <- c.reader(conn, stopR) }() // Wait until reader and writer are stopped select { case err = <-doneW: conn.Close() close(stopR) <-doneR case err = <-doneR: conn.Close() close(stopW) <-doneW } // Notify pending readers for len(c.chR) > 0 { w := <-c.chR w.err = errPipelineClientStopped w.done <- struct{}{} } return err } func (c *PipelineClient) writer(conn net.Conn, stopCh <-chan struct{}) error { writeBufferSize := c.WriteBufferSize if writeBufferSize <= 0 { writeBufferSize = defaultWriteBufferSize } bw := bufio.NewWriterSize(conn, writeBufferSize) defer bw.Flush() chR := c.chR chW := c.chW writeTimeout := c.WriteTimeout maxIdleConnDuration := c.MaxIdleConnDuration if maxIdleConnDuration <= 0 { maxIdleConnDuration = DefaultMaxIdleConnDuration } maxBatchDelay := c.MaxBatchDelay var ( stopTimer = time.NewTimer(time.Hour) flushTimer = time.NewTimer(time.Hour) flushTimerCh <-chan time.Time instantTimerCh = make(chan time.Time) w *pipelineWork err error lastWriteDeadlineTime time.Time ) close(instantTimerCh) for { againChW: select { case w = <-chW: // Fast path: len(chW) > 0 default: // Slow path stopTimer.Reset(maxIdleConnDuration) select { case w = <-chW: case <-stopTimer.C: return nil case <-stopCh: return nil case <-flushTimerCh: if err = bw.Flush(); err != nil { return err } flushTimerCh = nil goto againChW } } if !w.deadline.IsZero() && time.Since(w.deadline) >= 0 { w.err = ErrTimeout w.done <- struct{}{} continue } if writeTimeout > 0 { // Optimization: update write deadline only if more than 25% // of the last write deadline exceeded. // See https://github.com/golang/go/issues/15133 for details. 
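			// E.g. with WriteTimeout = 40s the deadline is pushed forward
			// at most once every 10s instead of before every single write.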
currentTime := time.Now() if currentTime.Sub(lastWriteDeadlineTime) > (writeTimeout >> 2) { if err = conn.SetWriteDeadline(currentTime.Add(writeTimeout)); err != nil { w.err = err w.done <- struct{}{} return err } lastWriteDeadlineTime = currentTime } } if err = w.req.Write(bw); err != nil { w.err = err w.done <- struct{}{} return err } if flushTimerCh == nil && (len(chW) == 0 || len(chR) == cap(chR)) { if maxBatchDelay > 0 { flushTimer.Reset(maxBatchDelay) flushTimerCh = flushTimer.C } else { flushTimerCh = instantTimerCh } } againChR: select { case chR <- w: // Fast path: len(chR) < cap(chR) default: // Slow path select { case chR <- w: case <-stopCh: w.err = errPipelineClientStopped w.done <- struct{}{} return nil case <-flushTimerCh: if err = bw.Flush(); err != nil { w.err = err w.done <- struct{}{} return err } flushTimerCh = nil goto againChR } } } } func (c *PipelineClient) reader(conn net.Conn, stopCh <-chan struct{}) error { readBufferSize := c.ReadBufferSize if readBufferSize <= 0 { readBufferSize = defaultReadBufferSize } br := bufio.NewReaderSize(conn, readBufferSize) chR := c.chR readTimeout := c.ReadTimeout var ( w *pipelineWork err error lastReadDeadlineTime time.Time ) for { select { case w = <-chR: // Fast path: len(chR) > 0 default: // Slow path select { case w = <-chR: case <-stopCh: return nil } } if readTimeout > 0 { // Optimization: update read deadline only if more than 25% // of the last read deadline exceeded. // See https://github.com/golang/go/issues/15133 for details. currentTime := time.Now() if currentTime.Sub(lastReadDeadlineTime) > (readTimeout >> 2) { if err = conn.SetReadDeadline(currentTime.Add(readTimeout)); err != nil { w.err = err w.done <- struct{}{} return err } lastReadDeadlineTime = currentTime } } if err = w.resp.Read(br); err != nil { w.err = err w.done <- struct{}{} return err } w.done <- struct{}{} } } func (c *PipelineClient) logger() Logger { if c.Logger != nil { return c.Logger } return defaultLogger } // PendingRequests returns the current number of pending requests pipelined // to the server. // // This number may exceed MaxPendingRequests by up to two times, since // the client may keep up to MaxPendingRequests requests in the queue before // sending them to the server. func (c *PipelineClient) PendingRequests() int { c.init() c.chLock.Lock() n := len(c.chR) + len(c.chW) c.chLock.Unlock() return n } var errPipelineClientStopped = errors.New("pipeline client has been stopped") func acquirePipelineWork(pool *sync.Pool, timeout time.Duration) *pipelineWork { v := pool.Get() if v == nil { v = &pipelineWork{ done: make(chan struct{}, 1), } } w := v.(*pipelineWork) if timeout > 0 { if w.t == nil { w.t = time.NewTimer(timeout) } else { w.t.Reset(timeout) } w.deadline = time.Now().Add(timeout) } else { w.deadline = zeroTime } return w } func releasePipelineWork(pool *sync.Pool, w *pipelineWork) { if w.t != nil { w.t.Stop() } w.reqCopy.Reset() w.respCopy.Reset() w.req = nil w.resp = nil w.err = nil pool.Put(w) } golang-github-valyala-fasthttp-20160617/client_example_test.go000066400000000000000000000020401273074646000243570ustar00rootroot00000000000000package fasthttp_test import ( "log" "github.com/valyala/fasthttp" ) func ExampleHostClient() { // Perpare a client, which fetches webpages via HTTP proxy listening // on the localhost:8080. c := &fasthttp.HostClient{ Addr: "localhost:8080", } // Fetch google page via local proxy. 
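	// Get appends the response body to the given dst buffer (nil here)
	// and returns the status code, the resulting body and an error, if any.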
statusCode, body, err := c.Get(nil, "http://google.com/foo/bar") if err != nil { log.Fatalf("Error when loading google page through local proxy: %s", err) } if statusCode != fasthttp.StatusOK { log.Fatalf("Unexpected status code: %d. Expecting %d", statusCode, fasthttp.StatusOK) } useResponseBody(body) // Fetch foobar page via local proxy. Reuse body buffer. statusCode, body, err = c.Get(body, "http://foobar.com/google/com") if err != nil { log.Fatalf("Error when loading foobar page through local proxy: %s", err) } if statusCode != fasthttp.StatusOK { log.Fatalf("Unexpected status code: %d. Expecting %d", statusCode, fasthttp.StatusOK) } useResponseBody(body) } func useResponseBody(body []byte) { // Do something with body :) } golang-github-valyala-fasthttp-20160617/client_test.go000066400000000000000000000503071273074646000226550ustar00rootroot00000000000000package fasthttp import ( "crypto/tls" "fmt" "io" "net" "os" "runtime" "strings" "sync" "sync/atomic" "testing" "time" "github.com/valyala/fasthttp/fasthttputil" ) func TestPipelineClientDoSerial(t *testing.T) { testPipelineClientDoConcurrent(t, 1, 0) } func TestPipelineClientDoConcurrent(t *testing.T) { testPipelineClientDoConcurrent(t, 10, 0) } func TestPipelineClientDoBatchDelayConcurrent(t *testing.T) { testPipelineClientDoConcurrent(t, 10, 5*time.Millisecond) } func testPipelineClientDoConcurrent(t *testing.T, concurrency int, maxBatchDelay time.Duration) { ln := fasthttputil.NewInmemoryListener() s := &Server{ Handler: func(ctx *RequestCtx) { ctx.WriteString("OK") }, } serverStopCh := make(chan struct{}) go func() { if err := s.Serve(ln); err != nil { t.Fatalf("unexpected error: %s", err) } close(serverStopCh) }() c := &PipelineClient{ Dial: func(addr string) (net.Conn, error) { return ln.Dial() }, MaxIdleConnDuration: 23 * time.Millisecond, MaxPendingRequests: 6, MaxBatchDelay: maxBatchDelay, Logger: &customLogger{}, } clientStopCh := make(chan struct{}, concurrency) for i := 0; i < concurrency; i++ { go func() { testPipelineClientDo(t, c) clientStopCh <- struct{}{} }() } for i := 0; i < concurrency; i++ { select { case <-clientStopCh: case <-time.After(3 * time.Second): t.Fatalf("timeout") } } if c.PendingRequests() != 0 { t.Fatalf("unexpected number of pending requests: %d. Expecting zero", c.PendingRequests()) } if err := ln.Close(); err != nil { t.Fatalf("unexpected error: %s", err) } select { case <-serverStopCh: case <-time.After(time.Second): t.Fatalf("timeout") } } func testPipelineClientDo(t *testing.T, c *PipelineClient) { var err error req := AcquireRequest() req.SetRequestURI("http://foobar/baz") resp := AcquireResponse() for i := 0; i < 10; i++ { if i&1 == 0 { err = c.DoTimeout(req, resp, time.Second) } else { err = c.Do(req, resp) } if err != nil { if err == ErrPipelineOverflow { time.Sleep(10 * time.Millisecond) continue } t.Fatalf("unexpected error on iteration %d: %s", i, err) } if resp.StatusCode() != StatusOK { t.Fatalf("unexpected status code: %d. Expecting %d", resp.StatusCode(), StatusOK) } body := string(resp.Body()) if body != "OK" { t.Fatalf("unexpected body: %q. Expecting %q", body, "OK") } // sleep for a while, so the connection to the host may expire. 
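		// (MaxIdleConnDuration is set to 23ms for this client, so the 30ms
		// pause below lets the idle connection expire and exercises the
		// reconnect path.)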
if i%5 == 0 { time.Sleep(30 * time.Millisecond) } } ReleaseRequest(req) ReleaseResponse(resp) } func TestClientDoTimeoutDisableNormalizing(t *testing.T) { ln := fasthttputil.NewInmemoryListener() s := &Server{ Handler: func(ctx *RequestCtx) { ctx.Response.Header.Set("foo-BAR", "baz") }, DisableHeaderNamesNormalizing: true, } serverStopCh := make(chan struct{}) go func() { if err := s.Serve(ln); err != nil { t.Fatalf("unexpected error: %s", err) } close(serverStopCh) }() c := &Client{ Dial: func(addr string) (net.Conn, error) { return ln.Dial() }, DisableHeaderNamesNormalizing: true, } var req Request req.SetRequestURI("http://aaaai.com/bsdf?sddfsd") var resp Response for i := 0; i < 5; i++ { if err := c.DoTimeout(&req, &resp, time.Second); err != nil { t.Fatalf("unexpected error: %s", err) } hv := resp.Header.Peek("foo-BAR") if string(hv) != "baz" { t.Fatalf("unexpected header value: %q. Expecting %q", hv, "baz") } hv = resp.Header.Peek("Foo-Bar") if len(hv) > 0 { t.Fatalf("unexpected non-empty header value %q", hv) } } if err := ln.Close(); err != nil { t.Fatalf("unexpected error: %s", err) } select { case <-serverStopCh: case <-time.After(time.Second): t.Fatalf("timeout") } } func TestHostClientMaxConnDuration(t *testing.T) { ln := fasthttputil.NewInmemoryListener() connectionCloseCount := uint32(0) s := &Server{ Handler: func(ctx *RequestCtx) { ctx.WriteString("abcd") if ctx.Request.ConnectionClose() { atomic.AddUint32(&connectionCloseCount, 1) } }, } serverStopCh := make(chan struct{}) go func() { if err := s.Serve(ln); err != nil { t.Fatalf("unexpected error: %s", err) } close(serverStopCh) }() c := &HostClient{ Addr: "foobar", Dial: func(addr string) (net.Conn, error) { return ln.Dial() }, MaxConnDuration: 10 * time.Millisecond, } for i := 0; i < 5; i++ { statusCode, body, err := c.Get(nil, "http://aaaa.com/bbb/cc") if err != nil { t.Fatalf("unexpected error: %s", err) } if statusCode != StatusOK { t.Fatalf("unexpected status code %d. Expecting %d", statusCode, StatusOK) } if string(body) != "abcd" { t.Fatalf("unexpected body %q. Expecting %q", body, "abcd") } time.Sleep(c.MaxConnDuration) } if err := ln.Close(); err != nil { t.Fatalf("unexpected error: %s", err) } select { case <-serverStopCh: case <-time.After(time.Second): t.Fatalf("timeout") } if connectionCloseCount == 0 { t.Fatalf("expecting at least one 'Connection: close' request header") } } func TestHostClientMultipleAddrs(t *testing.T) { ln := fasthttputil.NewInmemoryListener() s := &Server{ Handler: func(ctx *RequestCtx) { ctx.Write(ctx.Host()) ctx.SetConnectionClose() }, } serverStopCh := make(chan struct{}) go func() { if err := s.Serve(ln); err != nil { t.Fatalf("unexpected error: %s", err) } close(serverStopCh) }() dialsCount := make(map[string]int) c := &HostClient{ Addr: "foo,bar,baz", Dial: func(addr string) (net.Conn, error) { dialsCount[addr]++ return ln.Dial() }, } for i := 0; i < 9; i++ { statusCode, body, err := c.Get(nil, "http://foobar/baz/aaa?bbb=ddd") if err != nil { t.Fatalf("unexpected error: %s", err) } if statusCode != StatusOK { t.Fatalf("unexpected status code %d. Expecting %d", statusCode, StatusOK) } if string(body) != "foobar" { t.Fatalf("unexpected body %q. Expecting %q", body, "foobar") } } if err := ln.Close(); err != nil { t.Fatalf("unexpected error: %s", err) } select { case <-serverStopCh: case <-time.After(time.Second): t.Fatalf("timeout") } if len(dialsCount) != 3 { t.Fatalf("unexpected dialsCount size %d. 
Expecting 3", len(dialsCount)) } for _, k := range []string{"foo", "bar", "baz"} { if dialsCount[k] != 3 { t.Fatalf("unexpected dialsCount for %q. Expecting 3", k) } } } func TestClientFollowRedirects(t *testing.T) { addr := "127.0.0.1:55234" s := &Server{ Handler: func(ctx *RequestCtx) { switch string(ctx.Path()) { case "/foo": u := ctx.URI() u.Update("/xy?z=wer") ctx.Redirect(u.String(), StatusFound) case "/xy": u := ctx.URI() u.Update("/bar") ctx.Redirect(u.String(), StatusFound) default: ctx.Success("text/plain", ctx.Path()) } }, } ln, err := net.Listen("tcp4", addr) if err != nil { t.Fatalf("unexpected error: %s", err) } serverStopCh := make(chan struct{}) go func() { if err := s.Serve(ln); err != nil { t.Fatalf("unexpected error: %s", err) } close(serverStopCh) }() uri := fmt.Sprintf("http://%s/foo", addr) for i := 0; i < 10; i++ { statusCode, body, err := GetTimeout(nil, uri, time.Second) if err != nil { t.Fatalf("unexpected error: %s", err) } if statusCode != StatusOK { t.Fatalf("unexpected status code: %d", statusCode) } if string(body) != "/bar" { t.Fatalf("unexpected response %q. Expecting %q", body, "/bar") } } uri = fmt.Sprintf("http://%s/aaab/sss", addr) for i := 0; i < 10; i++ { statusCode, body, err := Get(nil, uri) if err != nil { t.Fatalf("unexpected error: %s", err) } if statusCode != StatusOK { t.Fatalf("unexpected status code: %d", statusCode) } if string(body) != "/aaab/sss" { t.Fatalf("unexpected response %q. Expecting %q", body, "/aaab/sss") } } } func TestClientGetTimeoutSuccess(t *testing.T) { addr := "127.0.0.1:56889" s := startEchoServer(t, "tcp", addr) defer s.Stop() addr = "http://" + addr testClientGetTimeoutSuccess(t, &defaultClient, addr, 100) } func TestClientGetTimeoutSuccessConcurrent(t *testing.T) { addr := "127.0.0.1:56989" s := startEchoServer(t, "tcp", addr) defer s.Stop() addr = "http://" + addr var wg sync.WaitGroup for i := 0; i < 10; i++ { wg.Add(1) go func() { defer wg.Done() testClientGetTimeoutSuccess(t, &defaultClient, addr, 100) }() } wg.Wait() } func TestClientDoTimeoutSuccess(t *testing.T) { addr := "127.0.0.1:63897" s := startEchoServer(t, "tcp", addr) defer s.Stop() addr = "http://" + addr testClientDoTimeoutSuccess(t, &defaultClient, addr, 100) } func TestClientDoTimeoutSuccessConcurrent(t *testing.T) { addr := "127.0.0.1:63898" s := startEchoServer(t, "tcp", addr) defer s.Stop() addr = "http://" + addr var wg sync.WaitGroup for i := 0; i < 10; i++ { wg.Add(1) go func() { defer wg.Done() testClientDoTimeoutSuccess(t, &defaultClient, addr, 100) }() } wg.Wait() } func TestClientGetTimeoutError(t *testing.T) { c := &Client{ Dial: func(addr string) (net.Conn, error) { return &readTimeoutConn{t: time.Second}, nil }, } testClientGetTimeoutError(t, c, 100) } func TestClientGetTimeoutErrorConcurrent(t *testing.T) { c := &Client{ Dial: func(addr string) (net.Conn, error) { return &readTimeoutConn{t: time.Second}, nil }, MaxConnsPerHost: 1000, } var wg sync.WaitGroup for i := 0; i < 10; i++ { wg.Add(1) go func() { defer wg.Done() testClientGetTimeoutError(t, c, 100) }() } wg.Wait() } func TestClientDoTimeoutError(t *testing.T) { c := &Client{ Dial: func(addr string) (net.Conn, error) { return &readTimeoutConn{t: time.Second}, nil }, } testClientDoTimeoutError(t, c, 100) } func TestClientDoTimeoutErrorConcurrent(t *testing.T) { c := &Client{ Dial: func(addr string) (net.Conn, error) { return &readTimeoutConn{t: time.Second}, nil }, MaxConnsPerHost: 1000, } var wg sync.WaitGroup for i := 0; i < 10; i++ { wg.Add(1) go func() { defer wg.Done() 
testClientDoTimeoutError(t, c, 100) }() } wg.Wait() } func testClientDoTimeoutError(t *testing.T, c *Client, n int) { var req Request var resp Response req.SetRequestURI("http://foobar.com/baz") for i := 0; i < n; i++ { err := c.DoTimeout(&req, &resp, time.Millisecond) if err == nil { t.Fatalf("expecting error") } if err != ErrTimeout { t.Fatalf("unexpected error: %s. Expecting %s", err, ErrTimeout) } } } func testClientGetTimeoutError(t *testing.T, c *Client, n int) { buf := make([]byte, 10) for i := 0; i < n; i++ { statusCode, body, err := c.GetTimeout(buf, "http://foobar.com/baz", time.Millisecond) if err == nil { t.Fatalf("expecting error") } if err != ErrTimeout { t.Fatalf("unexpected error: %s. Expecting %s", err, ErrTimeout) } if statusCode != 0 { t.Fatalf("unexpected statusCode=%d. Expecting %d", statusCode, 0) } if body == nil { t.Fatalf("body must be non-nil") } } } type readTimeoutConn struct { net.Conn t time.Duration } func (r *readTimeoutConn) Read(p []byte) (int, error) { time.Sleep(r.t) return 0, io.EOF } func (r *readTimeoutConn) Write(p []byte) (int, error) { return len(p), nil } func (r *readTimeoutConn) Close() error { return nil } func TestClientIdempotentRequest(t *testing.T) { dialsCount := 0 c := &Client{ Dial: func(addr string) (net.Conn, error) { switch dialsCount { case 0: dialsCount++ return &readErrorConn{}, nil case 1: dialsCount++ return &singleReadConn{ s: "HTTP/1.1 345 OK\r\nContent-Type: foobar\r\nContent-Length: 7\r\n\r\n0123456", }, nil default: t.Fatalf("unexpected number of dials: %d", dialsCount) } panic("unreachable") }, } statusCode, body, err := c.Get(nil, "http://foobar/a/b") if err != nil { t.Fatalf("unexpected error: %s", err) } if statusCode != 345 { t.Fatalf("unexpected status code: %d. Expecting 345", statusCode) } if string(body) != "0123456" { t.Fatalf("unexpected body: %q. 
Expecting %q", body, "0123456") } var args Args dialsCount = 0 statusCode, body, err = c.Post(nil, "http://foobar/a/b", &args) if err == nil { t.Fatalf("expecting error") } dialsCount = 0 statusCode, body, err = c.Post(nil, "http://foobar/a/b", nil) if err == nil { t.Fatalf("expecting error") } } type readErrorConn struct { net.Conn } func (r *readErrorConn) Read(p []byte) (int, error) { return 0, fmt.Errorf("error") } func (r *readErrorConn) Write(p []byte) (int, error) { return len(p), nil } func (r *readErrorConn) Close() error { return nil } type singleReadConn struct { net.Conn s string n int } func (r *singleReadConn) Read(p []byte) (int, error) { if len(r.s) == r.n { return 0, io.EOF } n := copy(p, []byte(r.s[r.n:])) r.n += n return n, nil } func (r *singleReadConn) Write(p []byte) (int, error) { return len(p), nil } func (r *singleReadConn) Close() error { return nil } func TestClientHTTPSConcurrent(t *testing.T) { addrHTTP := "127.0.0.1:56793" sHTTP := startEchoServer(t, "tcp", addrHTTP) defer sHTTP.Stop() addrHTTPS := "127.0.0.1:56794" sHTTPS := startEchoServerTLS(t, "tcp", addrHTTPS) defer sHTTPS.Stop() var wg sync.WaitGroup for i := 0; i < 4; i++ { wg.Add(1) addr := "http://" + addrHTTP if i&1 != 0 { addr = "https://" + addrHTTPS } go func() { defer wg.Done() testClientGet(t, &defaultClient, addr, 20) testClientPost(t, &defaultClient, addr, 10) }() } wg.Wait() } func TestClientManyServers(t *testing.T) { var addrs []string for i := 0; i < 10; i++ { addr := fmt.Sprintf("127.0.0.1:%d", 56904+i) s := startEchoServer(t, "tcp", addr) defer s.Stop() addrs = append(addrs, addr) } var wg sync.WaitGroup for i := 0; i < 4; i++ { wg.Add(1) addr := "http://" + addrs[i] go func() { defer wg.Done() testClientGet(t, &defaultClient, addr, 20) testClientPost(t, &defaultClient, addr, 10) }() } wg.Wait() } func TestClientGet(t *testing.T) { addr := "127.0.0.1:56789" s := startEchoServer(t, "tcp", addr) defer s.Stop() addr = "http://" + addr testClientGet(t, &defaultClient, addr, 100) } func TestClientPost(t *testing.T) { addr := "127.0.0.1:56798" s := startEchoServer(t, "tcp", addr) defer s.Stop() addr = "http://" + addr testClientPost(t, &defaultClient, addr, 100) } func TestClientConcurrent(t *testing.T) { addr := "127.0.0.1:55780" s := startEchoServer(t, "tcp", addr) defer s.Stop() addr = "http://" + addr var wg sync.WaitGroup for i := 0; i < 10; i++ { wg.Add(1) go func() { defer wg.Done() testClientGet(t, &defaultClient, addr, 30) testClientPost(t, &defaultClient, addr, 10) }() } wg.Wait() } func skipIfNotUnix(tb testing.TB) { switch runtime.GOOS { case "android", "nacl", "plan9", "windows": tb.Skipf("%s does not support unix sockets", runtime.GOOS) } if runtime.GOOS == "darwin" && (runtime.GOARCH == "arm" || runtime.GOARCH == "arm64") { tb.Skip("iOS does not support unix, unixgram") } } func TestHostClientGet(t *testing.T) { skipIfNotUnix(t) addr := "TestHostClientGet.unix" s := startEchoServer(t, "unix", addr) defer s.Stop() c := createEchoClient(t, "unix", addr) testHostClientGet(t, c, 100) } func TestHostClientPost(t *testing.T) { skipIfNotUnix(t) addr := "./TestHostClientPost.unix" s := startEchoServer(t, "unix", addr) defer s.Stop() c := createEchoClient(t, "unix", addr) testHostClientPost(t, c, 100) } func TestHostClientConcurrent(t *testing.T) { skipIfNotUnix(t) addr := "./TestHostClientConcurrent.unix" s := startEchoServer(t, "unix", addr) defer s.Stop() c := createEchoClient(t, "unix", addr) var wg sync.WaitGroup for i := 0; i < 10; i++ { wg.Add(1) go func() { defer wg.Done() 
testHostClientGet(t, c, 30) testHostClientPost(t, c, 10) }() } wg.Wait() } func testClientGet(t *testing.T, c clientGetter, addr string, n int) { var buf []byte for i := 0; i < n; i++ { uri := fmt.Sprintf("%s/foo/%d?bar=baz", addr, i) statusCode, body, err := c.Get(buf, uri) buf = body if err != nil { t.Fatalf("unexpected error when doing http request: %s", err) } if statusCode != StatusOK { t.Fatalf("unexpected status code: %d. Expecting %d", statusCode, StatusOK) } resultURI := string(body) if strings.HasPrefix(uri, "https") { resultURI = uri[:5] + resultURI[4:] } if resultURI != uri { t.Fatalf("unexpected uri %q. Expecting %q", resultURI, uri) } } } func testClientDoTimeoutSuccess(t *testing.T, c *Client, addr string, n int) { var req Request var resp Response for i := 0; i < n; i++ { uri := fmt.Sprintf("%s/foo/%d?bar=baz", addr, i) req.SetRequestURI(uri) if err := c.DoTimeout(&req, &resp, time.Second); err != nil { t.Fatalf("unexpected error: %s", err) } if resp.StatusCode() != StatusOK { t.Fatalf("unexpected status code: %d. Expecting %d", resp.StatusCode(), StatusOK) } resultURI := string(resp.Body()) if strings.HasPrefix(uri, "https") { resultURI = uri[:5] + resultURI[4:] } if resultURI != uri { t.Fatalf("unexpected uri %q. Expecting %q", resultURI, uri) } } } func testClientGetTimeoutSuccess(t *testing.T, c *Client, addr string, n int) { var buf []byte for i := 0; i < n; i++ { uri := fmt.Sprintf("%s/foo/%d?bar=baz", addr, i) statusCode, body, err := c.GetTimeout(buf, uri, time.Second) buf = body if err != nil { t.Fatalf("unexpected error when doing http request: %s", err) } if statusCode != StatusOK { t.Fatalf("unexpected status code: %d. Expecting %d", statusCode, StatusOK) } resultURI := string(body) if strings.HasPrefix(uri, "https") { resultURI = uri[:5] + resultURI[4:] } if resultURI != uri { t.Fatalf("unexpected uri %q. Expecting %q", resultURI, uri) } } } func testClientPost(t *testing.T, c clientPoster, addr string, n int) { var buf []byte var args Args for i := 0; i < n; i++ { uri := fmt.Sprintf("%s/foo/%d?bar=baz", addr, i) args.Set("xx", fmt.Sprintf("yy%d", i)) args.Set("zzz", fmt.Sprintf("qwe_%d", i)) argsS := args.String() statusCode, body, err := c.Post(buf, uri, &args) buf = body if err != nil { t.Fatalf("unexpected error when doing http request: %s", err) } if statusCode != StatusOK { t.Fatalf("unexpected status code: %d. Expecting %d", statusCode, StatusOK) } s := string(body) if s != argsS { t.Fatalf("unexpected response %q. 
Expecting %q", s, argsS) } } } func testHostClientGet(t *testing.T, c *HostClient, n int) { testClientGet(t, c, "http://google.com", n) } func testHostClientPost(t *testing.T, c *HostClient, n int) { testClientPost(t, c, "http://post-host.com", n) } type clientPoster interface { Post(dst []byte, uri string, postArgs *Args) (int, []byte, error) } type clientGetter interface { Get(dst []byte, uri string) (int, []byte, error) } func createEchoClient(t *testing.T, network, addr string) *HostClient { return &HostClient{ Addr: addr, Dial: func(addr string) (net.Conn, error) { return net.Dial(network, addr) }, } } type testEchoServer struct { s *Server ln net.Listener ch chan struct{} t *testing.T } func (s *testEchoServer) Stop() { s.ln.Close() select { case <-s.ch: case <-time.After(time.Second): s.t.Fatalf("timeout when waiting for server close") } } func startEchoServerTLS(t *testing.T, network, addr string) *testEchoServer { return startEchoServerExt(t, network, addr, true) } func startEchoServer(t *testing.T, network, addr string) *testEchoServer { return startEchoServerExt(t, network, addr, false) } func startEchoServerExt(t *testing.T, network, addr string, isTLS bool) *testEchoServer { if network == "unix" { os.Remove(addr) } var ln net.Listener var err error if isTLS { certFile := "./ssl-cert-snakeoil.pem" keyFile := "./ssl-cert-snakeoil.key" cert, err := tls.LoadX509KeyPair(certFile, keyFile) if err != nil { t.Fatalf("Cannot load TLS certificate: %s", err) } tlsConfig := &tls.Config{ Certificates: []tls.Certificate{cert}, } ln, err = tls.Listen(network, addr, tlsConfig) } else { ln, err = net.Listen(network, addr) } if err != nil { t.Fatalf("cannot listen %q: %s", addr, err) } s := &Server{ Handler: func(ctx *RequestCtx) { if ctx.IsGet() { ctx.Success("text/plain", ctx.URI().FullURI()) } else if ctx.IsPost() { ctx.PostArgs().WriteTo(ctx) } }, } ch := make(chan struct{}) go func() { err := s.Serve(ln) if err != nil { t.Fatalf("unexpected error returned from Serve(): %s", err) } close(ch) }() return &testEchoServer{ s: s, ln: ln, ch: ch, t: t, } } golang-github-valyala-fasthttp-20160617/client_timing_test.go000066400000000000000000000371751273074646000242340ustar00rootroot00000000000000package fasthttp import ( "bytes" "fmt" "io/ioutil" "net" "net/http" "runtime" "strings" "sync" "sync/atomic" "testing" "time" "github.com/valyala/fasthttp/fasthttputil" ) type fakeClientConn struct { net.Conn s []byte n int ch chan struct{} } func (c *fakeClientConn) Write(b []byte) (int, error) { c.ch <- struct{}{} return len(b), nil } func (c *fakeClientConn) Read(b []byte) (int, error) { if c.n == 0 { // wait for request :) <-c.ch } n := 0 for len(b) > 0 { if c.n == len(c.s) { c.n = 0 return n, nil } n = copy(b, c.s[c.n:]) c.n += n b = b[n:] } return n, nil } func (c *fakeClientConn) Close() error { releaseFakeServerConn(c) return nil } func releaseFakeServerConn(c *fakeClientConn) { c.n = 0 fakeClientConnPool.Put(c) } func acquireFakeServerConn(s []byte) *fakeClientConn { v := fakeClientConnPool.Get() if v == nil { c := &fakeClientConn{ s: s, ch: make(chan struct{}, 1), } return c } return v.(*fakeClientConn) } var fakeClientConnPool sync.Pool func BenchmarkClientGetTimeoutFastServer(b *testing.B) { body := []byte("123456789099") s := []byte(fmt.Sprintf("HTTP/1.1 200 OK\r\nContent-Type: text/plain\r\nContent-Length: %d\r\n\r\n%s", len(body), body)) c := &Client{ Dial: func(addr string) (net.Conn, error) { return acquireFakeServerConn(s), nil }, } nn := uint32(0) b.RunParallel(func(pb *testing.PB) { 
url := fmt.Sprintf("http://foobar%d.com/aaa/bbb", atomic.AddUint32(&nn, 1)) var statusCode int var bodyBuf []byte var err error for pb.Next() { statusCode, bodyBuf, err = c.GetTimeout(bodyBuf[:0], url, time.Second) if err != nil { b.Fatalf("unexpected error: %s", err) } if statusCode != StatusOK { b.Fatalf("unexpected status code: %d", statusCode) } if !bytes.Equal(bodyBuf, body) { b.Fatalf("unexpected response body: %q. Expected %q", bodyBuf, body) } } }) } func BenchmarkClientDoFastServer(b *testing.B) { body := []byte("012345678912") s := []byte(fmt.Sprintf("HTTP/1.1 200 OK\r\nContent-Type: text/plain\r\nContent-Length: %d\r\n\r\n%s", len(body), body)) c := &Client{ Dial: func(addr string) (net.Conn, error) { return acquireFakeServerConn(s), nil }, MaxConnsPerHost: runtime.GOMAXPROCS(-1), } nn := uint32(0) b.RunParallel(func(pb *testing.PB) { var req Request var resp Response req.Header.SetRequestURI(fmt.Sprintf("http://foobar%d.com/aaa/bbb", atomic.AddUint32(&nn, 1))) for pb.Next() { if err := c.Do(&req, &resp); err != nil { b.Fatalf("unexpected error: %s", err) } if resp.Header.StatusCode() != StatusOK { b.Fatalf("unexpected status code: %d", resp.Header.StatusCode()) } if !bytes.Equal(resp.Body(), body) { b.Fatalf("unexpected response body: %q. Expected %q", resp.Body(), body) } } }) } func BenchmarkNetHTTPClientDoFastServer(b *testing.B) { body := []byte("012345678912") s := []byte(fmt.Sprintf("HTTP/1.1 200 OK\r\nContent-Type: text/plain\r\nContent-Length: %d\r\n\r\n%s", len(body), body)) c := &http.Client{ Transport: &http.Transport{ Dial: func(network, addr string) (net.Conn, error) { return acquireFakeServerConn(s), nil }, MaxIdleConnsPerHost: runtime.GOMAXPROCS(-1), }, } nn := uint32(0) b.RunParallel(func(pb *testing.PB) { req, err := http.NewRequest("GET", fmt.Sprintf("http://foobar%d.com/aaa/bbb", atomic.AddUint32(&nn, 1)), nil) if err != nil { b.Fatalf("unexpected error: %s", err) } for pb.Next() { resp, err := c.Do(req) if err != nil { b.Fatalf("unexpected error: %s", err) } if resp.StatusCode != http.StatusOK { b.Fatalf("unexpected status code: %d", resp.StatusCode) } respBody, err := ioutil.ReadAll(resp.Body) resp.Body.Close() if err != nil { b.Fatalf("unexpected error when reading response body: %s", err) } if !bytes.Equal(respBody, body) { b.Fatalf("unexpected response body: %q. 
Expected %q", respBody, body) } } }) } func fasthttpEchoHandler(ctx *RequestCtx) { ctx.Success("text/plain", ctx.RequestURI()) } func nethttpEchoHandler(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "text/plain") w.Write([]byte(r.RequestURI)) } func BenchmarkClientGetEndToEnd1TCP(b *testing.B) { benchmarkClientGetEndToEndTCP(b, 1) } func BenchmarkClientGetEndToEnd10TCP(b *testing.B) { benchmarkClientGetEndToEndTCP(b, 10) } func BenchmarkClientGetEndToEnd100TCP(b *testing.B) { benchmarkClientGetEndToEndTCP(b, 100) } func benchmarkClientGetEndToEndTCP(b *testing.B, parallelism int) { addr := "127.0.0.1:8543" ln, err := net.Listen("tcp4", addr) if err != nil { b.Fatalf("cannot listen %q: %s", addr, err) } ch := make(chan struct{}) go func() { if err := Serve(ln, fasthttpEchoHandler); err != nil { b.Fatalf("error when serving requests: %s", err) } close(ch) }() c := &Client{ MaxConnsPerHost: runtime.GOMAXPROCS(-1) * parallelism, } requestURI := "/foo/bar?baz=123" url := "http://" + addr + requestURI b.SetParallelism(parallelism) b.RunParallel(func(pb *testing.PB) { var buf []byte for pb.Next() { statusCode, body, err := c.Get(buf, url) if err != nil { b.Fatalf("unexpected error: %s", err) } if statusCode != StatusOK { b.Fatalf("unexpected status code: %d. Expecting %d", statusCode, StatusOK) } if string(body) != requestURI { b.Fatalf("unexpected response %q. Expecting %q", body, requestURI) } buf = body } }) ln.Close() select { case <-ch: case <-time.After(time.Second): b.Fatalf("server wasn't stopped") } } func BenchmarkNetHTTPClientGetEndToEnd1TCP(b *testing.B) { benchmarkNetHTTPClientGetEndToEndTCP(b, 1) } func BenchmarkNetHTTPClientGetEndToEnd10TCP(b *testing.B) { benchmarkNetHTTPClientGetEndToEndTCP(b, 10) } func BenchmarkNetHTTPClientGetEndToEnd100TCP(b *testing.B) { benchmarkNetHTTPClientGetEndToEndTCP(b, 100) } func benchmarkNetHTTPClientGetEndToEndTCP(b *testing.B, parallelism int) { addr := "127.0.0.1:8542" ln, err := net.Listen("tcp4", addr) if err != nil { b.Fatalf("cannot listen %q: %s", addr, err) } ch := make(chan struct{}) go func() { if err := http.Serve(ln, http.HandlerFunc(nethttpEchoHandler)); err != nil && !strings.Contains( err.Error(), "use of closed network connection") { b.Fatalf("error when serving requests: %s", err) } close(ch) }() c := &http.Client{ Transport: &http.Transport{ MaxIdleConnsPerHost: parallelism * runtime.GOMAXPROCS(-1), }, } requestURI := "/foo/bar?baz=123" url := "http://" + addr + requestURI b.SetParallelism(parallelism) b.RunParallel(func(pb *testing.PB) { for pb.Next() { resp, err := c.Get(url) if err != nil { b.Fatalf("unexpected error: %s", err) } if resp.StatusCode != http.StatusOK { b.Fatalf("unexpected status code: %d. Expecting %d", resp.StatusCode, http.StatusOK) } body, err := ioutil.ReadAll(resp.Body) resp.Body.Close() if err != nil { b.Fatalf("unexpected error when reading response body: %s", err) } if string(body) != requestURI { b.Fatalf("unexpected response %q. 
Expecting %q", body, requestURI) } } }) ln.Close() select { case <-ch: case <-time.After(time.Second): b.Fatalf("server wasn't stopped") } } func BenchmarkClientGetEndToEnd1Inmemory(b *testing.B) { benchmarkClientGetEndToEndInmemory(b, 1) } func BenchmarkClientGetEndToEnd10Inmemory(b *testing.B) { benchmarkClientGetEndToEndInmemory(b, 10) } func BenchmarkClientGetEndToEnd100Inmemory(b *testing.B) { benchmarkClientGetEndToEndInmemory(b, 100) } func BenchmarkClientGetEndToEnd1000Inmemory(b *testing.B) { benchmarkClientGetEndToEndInmemory(b, 1000) } func BenchmarkClientGetEndToEnd10KInmemory(b *testing.B) { benchmarkClientGetEndToEndInmemory(b, 10000) } func benchmarkClientGetEndToEndInmemory(b *testing.B, parallelism int) { ln := fasthttputil.NewInmemoryListener() ch := make(chan struct{}) go func() { if err := Serve(ln, fasthttpEchoHandler); err != nil { b.Fatalf("error when serving requests: %s", err) } close(ch) }() c := &Client{ MaxConnsPerHost: runtime.GOMAXPROCS(-1) * parallelism, Dial: func(addr string) (net.Conn, error) { return ln.Dial() }, } requestURI := "/foo/bar?baz=123" url := "http://unused.host" + requestURI b.SetParallelism(parallelism) b.RunParallel(func(pb *testing.PB) { var buf []byte for pb.Next() { statusCode, body, err := c.Get(buf, url) if err != nil { b.Fatalf("unexpected error: %s", err) } if statusCode != StatusOK { b.Fatalf("unexpected status code: %d. Expecting %d", statusCode, StatusOK) } if string(body) != requestURI { b.Fatalf("unexpected response %q. Expecting %q", body, requestURI) } buf = body } }) ln.Close() select { case <-ch: case <-time.After(time.Second): b.Fatalf("server wasn't stopped") } } func BenchmarkNetHTTPClientGetEndToEnd1Inmemory(b *testing.B) { benchmarkNetHTTPClientGetEndToEndInmemory(b, 1) } func BenchmarkNetHTTPClientGetEndToEnd10Inmemory(b *testing.B) { benchmarkNetHTTPClientGetEndToEndInmemory(b, 10) } func BenchmarkNetHTTPClientGetEndToEnd100Inmemory(b *testing.B) { benchmarkNetHTTPClientGetEndToEndInmemory(b, 100) } func BenchmarkNetHTTPClientGetEndToEnd1000Inmemory(b *testing.B) { benchmarkNetHTTPClientGetEndToEndInmemory(b, 1000) } func benchmarkNetHTTPClientGetEndToEndInmemory(b *testing.B, parallelism int) { ln := fasthttputil.NewInmemoryListener() ch := make(chan struct{}) go func() { if err := http.Serve(ln, http.HandlerFunc(nethttpEchoHandler)); err != nil && !strings.Contains( err.Error(), "use of closed network connection") { b.Fatalf("error when serving requests: %s", err) } close(ch) }() c := &http.Client{ Transport: &http.Transport{ Dial: func(_, _ string) (net.Conn, error) { return ln.Dial() }, MaxIdleConnsPerHost: parallelism * runtime.GOMAXPROCS(-1), }, } requestURI := "/foo/bar?baz=123" url := "http://unused.host" + requestURI b.SetParallelism(parallelism) b.RunParallel(func(pb *testing.PB) { for pb.Next() { resp, err := c.Get(url) if err != nil { b.Fatalf("unexpected error: %s", err) } if resp.StatusCode != http.StatusOK { b.Fatalf("unexpected status code: %d. Expecting %d", resp.StatusCode, http.StatusOK) } body, err := ioutil.ReadAll(resp.Body) resp.Body.Close() if err != nil { b.Fatalf("unexpected error when reading response body: %s", err) } if string(body) != requestURI { b.Fatalf("unexpected response %q. 
Expecting %q", body, requestURI) } } }) ln.Close() select { case <-ch: case <-time.After(time.Second): b.Fatalf("server wasn't stopped") } } func BenchmarkClientEndToEndBigResponse1Inmemory(b *testing.B) { benchmarkClientEndToEndBigResponseInmemory(b, 1) } func BenchmarkClientEndToEndBigResponse10Inmemory(b *testing.B) { benchmarkClientEndToEndBigResponseInmemory(b, 10) } func benchmarkClientEndToEndBigResponseInmemory(b *testing.B, parallelism int) { bigResponse := createFixedBody(1024 * 1024) h := func(ctx *RequestCtx) { ctx.SetContentType("text/plain") ctx.Write(bigResponse) } ln := fasthttputil.NewInmemoryListener() ch := make(chan struct{}) go func() { if err := Serve(ln, h); err != nil { b.Fatalf("error when serving requests: %s", err) } close(ch) }() c := &Client{ MaxConnsPerHost: runtime.GOMAXPROCS(-1) * parallelism, Dial: func(addr string) (net.Conn, error) { return ln.Dial() }, } requestURI := "/foo/bar?baz=123" url := "http://unused.host" + requestURI b.SetParallelism(parallelism) b.RunParallel(func(pb *testing.PB) { var req Request req.SetRequestURI(url) var resp Response for pb.Next() { if err := c.DoTimeout(&req, &resp, 5*time.Second); err != nil { b.Fatalf("unexpected error: %s", err) } if resp.StatusCode() != StatusOK { b.Fatalf("unexpected status code: %d. Expecting %d", resp.StatusCode(), StatusOK) } body := resp.Body() if !bytes.Equal(bigResponse, body) { b.Fatalf("unexpected response %q. Expecting %q", body, bigResponse) } } }) ln.Close() select { case <-ch: case <-time.After(time.Second): b.Fatalf("server wasn't stopped") } } func BenchmarkNetHTTPClientEndToEndBigResponse1Inmemory(b *testing.B) { benchmarkNetHTTPClientEndToEndBigResponseInmemory(b, 1) } func BenchmarkNetHTTPClientEndToEndBigResponse10Inmemory(b *testing.B) { benchmarkNetHTTPClientEndToEndBigResponseInmemory(b, 10) } func benchmarkNetHTTPClientEndToEndBigResponseInmemory(b *testing.B, parallelism int) { bigResponse := createFixedBody(1024 * 1024) h := func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "text/plain") w.Write(bigResponse) } ln := fasthttputil.NewInmemoryListener() ch := make(chan struct{}) go func() { if err := http.Serve(ln, http.HandlerFunc(h)); err != nil && !strings.Contains( err.Error(), "use of closed network connection") { b.Fatalf("error when serving requests: %s", err) } close(ch) }() c := &http.Client{ Transport: &http.Transport{ Dial: func(_, _ string) (net.Conn, error) { return ln.Dial() }, MaxIdleConnsPerHost: parallelism * runtime.GOMAXPROCS(-1), }, Timeout: 5 * time.Second, } requestURI := "/foo/bar?baz=123" url := "http://unused.host" + requestURI b.SetParallelism(parallelism) b.RunParallel(func(pb *testing.PB) { req, err := http.NewRequest("GET", url, nil) if err != nil { b.Fatalf("unexpected error: %s", err) } for pb.Next() { resp, err := c.Do(req) if err != nil { b.Fatalf("unexpected error: %s", err) } if resp.StatusCode != http.StatusOK { b.Fatalf("unexpected status code: %d. Expecting %d", resp.StatusCode, http.StatusOK) } body, err := ioutil.ReadAll(resp.Body) resp.Body.Close() if err != nil { b.Fatalf("unexpected error when reading response body: %s", err) } if !bytes.Equal(bigResponse, body) { b.Fatalf("unexpected response %q. 
Expecting %q", body, bigResponse) } } }) ln.Close() select { case <-ch: case <-time.After(time.Second): b.Fatalf("server wasn't stopped") } } func BenchmarkPipelineClient1(b *testing.B) { benchmarkPipelineClient(b, 1) } func BenchmarkPipelineClient10(b *testing.B) { benchmarkPipelineClient(b, 10) } func BenchmarkPipelineClient100(b *testing.B) { benchmarkPipelineClient(b, 100) } func BenchmarkPipelineClient1000(b *testing.B) { benchmarkPipelineClient(b, 1000) } func benchmarkPipelineClient(b *testing.B, parallelism int) { h := func(ctx *RequestCtx) { ctx.WriteString("foobar") } ln := fasthttputil.NewInmemoryListener() ch := make(chan struct{}) go func() { if err := Serve(ln, h); err != nil { b.Fatalf("error when serving requests: %s", err) } close(ch) }() var clients []*PipelineClient for i := 0; i < runtime.GOMAXPROCS(-1); i++ { c := &PipelineClient{ Dial: func(addr string) (net.Conn, error) { return ln.Dial() }, ReadBufferSize: 1024 * 1024, WriteBufferSize: 1024 * 1024, MaxPendingRequests: parallelism, } clients = append(clients, c) } clientID := uint32(0) requestURI := "/foo/bar?baz=123" url := "http://unused.host" + requestURI b.SetParallelism(parallelism) b.RunParallel(func(pb *testing.PB) { n := atomic.AddUint32(&clientID, 1) c := clients[n%uint32(len(clients))] var req Request req.SetRequestURI(url) var resp Response for pb.Next() { if err := c.Do(&req, &resp); err != nil { b.Fatalf("unexpected error: %s", err) } if resp.StatusCode() != StatusOK { b.Fatalf("unexpected status code: %d. Expecting %d", resp.StatusCode(), StatusOK) } body := resp.Body() if string(body) != "foobar" { b.Fatalf("unexpected response %q. Expecting %q", body, "foobar") } } }) ln.Close() select { case <-ch: case <-time.After(time.Second): b.Fatalf("server wasn't stopped") } } golang-github-valyala-fasthttp-20160617/compress.go000066400000000000000000000145671273074646000222030ustar00rootroot00000000000000package fasthttp import ( "fmt" "io" "os" "sync" "github.com/klauspost/compress/flate" "github.com/klauspost/compress/gzip" "github.com/klauspost/compress/zlib" ) // Supported compression levels. const ( CompressNoCompression = flate.NoCompression CompressBestSpeed = flate.BestSpeed CompressBestCompression = flate.BestCompression CompressDefaultCompression = flate.DefaultCompression ) func acquireGzipReader(r io.Reader) (*gzip.Reader, error) { v := gzipReaderPool.Get() if v == nil { return gzip.NewReader(r) } zr := v.(*gzip.Reader) if err := zr.Reset(r); err != nil { return nil, err } return zr, nil } func releaseGzipReader(zr *gzip.Reader) { zr.Close() gzipReaderPool.Put(zr) } var gzipReaderPool sync.Pool func acquireFlateReader(r io.Reader) (io.ReadCloser, error) { v := flateReaderPool.Get() if v == nil { zr, err := zlib.NewReader(r) if err != nil { return nil, err } return zr, nil } zr := v.(io.ReadCloser) if err := resetFlateReader(zr, r); err != nil { return nil, err } return zr, nil } func releaseFlateReader(zr io.ReadCloser) { zr.Close() flateReaderPool.Put(zr) } func resetFlateReader(zr io.ReadCloser, r io.Reader) error { zrr, ok := zr.(zlib.Resetter) if !ok { panic("BUG: zlib.Reader doesn't implement zlib.Resetter???") } return zrr.Reset(r, nil) } var flateReaderPool sync.Pool func acquireGzipWriter(w io.Writer, level int) *gzipWriter { p := gzipWriterPoolMap[level] if p == nil { panic(fmt.Sprintf("BUG: unexpected compression level passed: %d. 
See compress/gzip for supported levels", level)) } v := p.Get() if v == nil { zw, err := gzip.NewWriterLevel(w, level) if err != nil { panic(fmt.Sprintf("BUG: unexpected error from gzip.NewWriterLevel(%d): %s", level, err)) } return &gzipWriter{ Writer: zw, p: p, } } zw := v.(*gzipWriter) zw.Reset(w) return zw } func releaseGzipWriter(zw *gzipWriter) { zw.Close() zw.p.Put(zw) } type gzipWriter struct { *gzip.Writer p *sync.Pool } var gzipWriterPoolMap = func() map[int]*sync.Pool { // Initialize pools for all the compression levels defined // in https://golang.org/pkg/compress/gzip/#pkg-constants . m := make(map[int]*sync.Pool, 11) m[-1] = &sync.Pool{} for i := 0; i < 10; i++ { m[i] = &sync.Pool{} } return m }() // AppendGzipBytesLevel appends gzipped src to dst using the given // compression level and returns the resulting dst. // // Supported compression levels are: // // * CompressNoCompression // * CompressBestSpeed // * CompressBestCompression // * CompressDefaultCompression func AppendGzipBytesLevel(dst, src []byte, level int) []byte { w := &byteSliceWriter{dst} WriteGzipLevel(w, src, level) return w.b } // WriteGzipLevel writes gzipped p to w using the given compression level // and returns the number of compressed bytes written to w. // // Supported compression levels are: // // * CompressNoCompression // * CompressBestSpeed // * CompressBestCompression // * CompressDefaultCompression func WriteGzipLevel(w io.Writer, p []byte, level int) (int, error) { zw := acquireGzipWriter(w, level) n, err := zw.Write(p) releaseGzipWriter(zw) return n, err } // WriteGzip writes gzipped p to w and returns the number of compressed // bytes written to w. func WriteGzip(w io.Writer, p []byte) (int, error) { return WriteGzipLevel(w, p, CompressDefaultCompression) } // AppendGzipBytes appends gzipped src to dst and returns the resulting dst. func AppendGzipBytes(dst, src []byte) []byte { return AppendGzipBytesLevel(dst, src, CompressDefaultCompression) } // WriteGunzip writes ungzipped p to w and returns the number of uncompressed // bytes written to w. func WriteGunzip(w io.Writer, p []byte) (int, error) { r := &byteSliceReader{p} zr, err := acquireGzipReader(r) if err != nil { return 0, err } n, err := copyZeroAlloc(w, zr) releaseGzipReader(zr) nn := int(n) if int64(nn) != n { return 0, fmt.Errorf("too much data gunzipped: %d", n) } return nn, err } // WriteInflate writes inflated p to w and returns the number of uncompressed // bytes written to w. func WriteInflate(w io.Writer, p []byte) (int, error) { r := &byteSliceReader{p} zr, err := acquireFlateReader(r) if err != nil { return 0, err } n, err := copyZeroAlloc(w, zr) releaseFlateReader(zr) nn := int(n) if int64(nn) != n { return 0, fmt.Errorf("too much data inflated: %d", n) } return nn, err } // AppendGunzipBytes append gunzipped src to dst and returns the resulting dst. func AppendGunzipBytes(dst, src []byte) ([]byte, error) { w := &byteSliceWriter{dst} _, err := WriteGunzip(w, src) return w.b, err } type byteSliceWriter struct { b []byte } func (w *byteSliceWriter) Write(p []byte) (int, error) { w.b = append(w.b, p...) return len(p), nil } type byteSliceReader struct { b []byte } func (r *byteSliceReader) Read(p []byte) (int, error) { if len(r.b) == 0 { return 0, io.EOF } n := copy(p, r.b) r.b = r.b[n:] return n, nil } func acquireFlateWriter(w io.Writer, level int) *flateWriter { p := flateWriterPoolMap[level] if p == nil { panic(fmt.Sprintf("BUG: unexpected compression level passed: %d. 
See compress/flate for supported levels", level)) } v := p.Get() if v == nil { zw, err := zlib.NewWriterLevel(w, level) if err != nil { panic(fmt.Sprintf("BUG: unexpected error in zlib.NewWriterLevel(%d): %s", level, err)) } return &flateWriter{ Writer: zw, p: p, } } zw := v.(*flateWriter) zw.Reset(w) return zw } func releaseFlateWriter(zw *flateWriter) { zw.Close() zw.p.Put(zw) } type flateWriter struct { *zlib.Writer p *sync.Pool } var flateWriterPoolMap = func() map[int]*sync.Pool { // Initialize pools for all the compression levels defined // in https://golang.org/pkg/compress/flate/#pkg-constants . m := make(map[int]*sync.Pool, 11) m[-1] = &sync.Pool{} for i := 0; i < 10; i++ { m[i] = &sync.Pool{} } return m }() func isFileCompressible(f *os.File, minCompressRatio float64) bool { // Try compressing the first 4kb of of the file // and see if it can be compressed by more than // the given minCompressRatio. b := AcquireByteBuffer() zw := acquireGzipWriter(b, CompressDefaultCompression) lr := &io.LimitedReader{ R: f, N: 4096, } _, err := copyZeroAlloc(zw, lr) releaseGzipWriter(zw) f.Seek(0, 0) if err != nil { return false } n := 4096 - lr.N zn := len(b.B) ReleaseByteBuffer(b) return float64(zn) < float64(n)*minCompressRatio } golang-github-valyala-fasthttp-20160617/compress_test.go000066400000000000000000000046461273074646000232370ustar00rootroot00000000000000package fasthttp import ( "bytes" "io/ioutil" "testing" ) func TestGzipBytes(t *testing.T) { testGzipBytes(t, "") testGzipBytes(t, "foobar") testGzipBytes(t, "выфаодлодл одлфываыв sd2 k34") } func testGzipBytes(t *testing.T, s string) { prefix := []byte("foobar") gzippedS := AppendGzipBytes(prefix, []byte(s)) if !bytes.Equal(gzippedS[:len(prefix)], prefix) { t.Fatalf("unexpected prefix when compressing %q: %q. Expecting %q", s, gzippedS[:len(prefix)], prefix) } gunzippedS, err := AppendGunzipBytes(prefix, gzippedS[len(prefix):]) if err != nil { t.Fatalf("unexpected error when uncompressing %q: %s", s, err) } if !bytes.Equal(gunzippedS[:len(prefix)], prefix) { t.Fatalf("unexpected prefix when uncompressing %q: %q. Expecting %q", s, gunzippedS[:len(prefix)], prefix) } gunzippedS = gunzippedS[len(prefix):] if string(gunzippedS) != s { t.Fatalf("unexpected uncompressed string %q. Expecting %q", gunzippedS, s) } } func TestGzipCompress(t *testing.T) { testGzipCompress(t, "") testGzipCompress(t, "foobar") testGzipCompress(t, "ajjnkn asdlkjfqoijfw jfqkwj foj eowjiq") } func TestFlateCompress(t *testing.T) { testFlateCompress(t, "") testFlateCompress(t, "foobar") testFlateCompress(t, "adf asd asd fasd fasd") } func testGzipCompress(t *testing.T, s string) { var buf bytes.Buffer zw := acquireGzipWriter(&buf, CompressDefaultCompression) if _, err := zw.Write([]byte(s)); err != nil { t.Fatalf("unexpected error: %s. s=%q", err, s) } releaseGzipWriter(zw) zr, err := acquireGzipReader(&buf) if err != nil { t.Fatalf("unexpected error: %s. s=%q", err, s) } body, err := ioutil.ReadAll(zr) if err != nil { t.Fatalf("unexpected error: %s. s=%q", err, s) } if string(body) != s { t.Fatalf("unexpected string after decompression: %q. Expecting %q", body, s) } releaseGzipReader(zr) } func testFlateCompress(t *testing.T, s string) { var buf bytes.Buffer zw := acquireFlateWriter(&buf, CompressDefaultCompression) if _, err := zw.Write([]byte(s)); err != nil { t.Fatalf("unexpected error: %s. s=%q", err, s) } releaseFlateWriter(zw) zr, err := acquireFlateReader(&buf) if err != nil { t.Fatalf("unexpected error: %s. 
s=%q", err, s) } body, err := ioutil.ReadAll(zr) if err != nil { t.Fatalf("unexpected error: %s. s=%q", err, s) } if string(body) != s { t.Fatalf("unexpected string after decompression: %q. Expecting %q", body, s) } releaseFlateReader(zr) } golang-github-valyala-fasthttp-20160617/cookie.go000066400000000000000000000206511273074646000216100ustar00rootroot00000000000000package fasthttp import ( "bytes" "errors" "io" "sync" "time" ) var zeroTime time.Time var ( // CookieExpireDelete may be set on Cookie.Expire for expiring the given cookie. CookieExpireDelete = time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC) // CookieExpireUnlimited indicates that the cookie doesn't expire. CookieExpireUnlimited = zeroTime ) // AcquireCookie returns an empty Cookie object from the pool. // // The returned object may be returned back to the pool with ReleaseCookie. // This allows reducing GC load. func AcquireCookie() *Cookie { return cookiePool.Get().(*Cookie) } // ReleaseCookie returns the Cookie object acquired with AcquireCookie back // to the pool. // // Do not access released Cookie object, otherwise data races may occur. func ReleaseCookie(c *Cookie) { c.Reset() cookiePool.Put(c) } var cookiePool = &sync.Pool{ New: func() interface{} { return &Cookie{} }, } // Cookie represents HTTP response cookie. // // Do not copy Cookie objects. Create new object and use CopyTo instead. // // Cookie instance MUST NOT be used from concurrently running goroutines. type Cookie struct { noCopy noCopy key []byte value []byte expire time.Time domain []byte path []byte httpOnly bool secure bool bufKV argsKV buf []byte } // CopyTo copies src cookie to c. func (c *Cookie) CopyTo(src *Cookie) { c.Reset() c.key = append(c.key[:0], src.key...) c.value = append(c.value[:0], src.value...) c.expire = src.expire c.domain = append(c.domain[:0], src.domain...) c.path = append(c.path[:0], src.path...) c.httpOnly = src.httpOnly c.secure = src.secure } // HTTPOnly returns true if the cookie is http only. func (c *Cookie) HTTPOnly() bool { return c.httpOnly } // SetHTTPOnly sets cookie's httpOnly flag to the given value. func (c *Cookie) SetHTTPOnly(httpOnly bool) { c.httpOnly = httpOnly } // Secure returns true if the cookie is secure. func (c *Cookie) Secure() bool { return c.secure } // SetSecure sets cookie's secure flag to the given value. func (c *Cookie) SetSecure(secure bool) { c.secure = secure } // Path returns cookie path. func (c *Cookie) Path() []byte { return c.path } // SetPath sets cookie path. func (c *Cookie) SetPath(path string) { c.buf = append(c.buf[:0], path...) c.path = normalizePath(c.path, c.buf) } // SetPathBytes sets cookie path. func (c *Cookie) SetPathBytes(path []byte) { c.buf = append(c.buf[:0], path...) c.path = normalizePath(c.path, c.buf) } // Domain returns cookie domain. // // The returned domain is valid until the next Cookie modification method call. func (c *Cookie) Domain() []byte { return c.domain } // SetDomain sets cookie domain. func (c *Cookie) SetDomain(domain string) { c.domain = append(c.domain[:0], domain...) } // SetDomainBytes sets cookie domain. func (c *Cookie) SetDomainBytes(domain []byte) { c.domain = append(c.domain[:0], domain...) } // Expire returns cookie expiration time. // // CookieExpireUnlimited is returned if cookie doesn't expire func (c *Cookie) Expire() time.Time { expire := c.expire if expire.IsZero() { expire = CookieExpireUnlimited } return expire } // SetExpire sets cookie expiration time. 
// // Set expiration time to CookieExpireDelete for expiring (deleting) // the cookie on the client. // // By default cookie lifetime is limited by browser session. func (c *Cookie) SetExpire(expire time.Time) { c.expire = expire } // Value returns cookie value. // // The returned value is valid until the next Cookie modification method call. func (c *Cookie) Value() []byte { return c.value } // SetValue sets cookie value. func (c *Cookie) SetValue(value string) { c.value = append(c.value[:0], value...) } // SetValueBytes sets cookie value. func (c *Cookie) SetValueBytes(value []byte) { c.value = append(c.value[:0], value...) } // Key returns cookie name. // // The returned value is valid until the next Cookie modification method call. func (c *Cookie) Key() []byte { return c.key } // SetKey sets cookie name. func (c *Cookie) SetKey(key string) { c.key = append(c.key[:0], key...) } // SetKeyBytes sets cookie name. func (c *Cookie) SetKeyBytes(key []byte) { c.key = append(c.key[:0], key...) } // Reset clears the cookie. func (c *Cookie) Reset() { c.key = c.key[:0] c.value = c.value[:0] c.expire = zeroTime c.domain = c.domain[:0] c.path = c.path[:0] c.httpOnly = false c.secure = false } // AppendBytes appends cookie representation to dst and returns // the extended dst. func (c *Cookie) AppendBytes(dst []byte) []byte { if len(c.key) > 0 { dst = AppendQuotedArg(dst, c.key) dst = append(dst, '=') } dst = AppendQuotedArg(dst, c.value) if !c.expire.IsZero() { c.bufKV.value = AppendHTTPDate(c.bufKV.value[:0], c.expire) dst = append(dst, ';', ' ') dst = append(dst, strCookieExpires...) dst = append(dst, '=') dst = append(dst, c.bufKV.value...) } if len(c.domain) > 0 { dst = appendCookiePart(dst, strCookieDomain, c.domain) } if len(c.path) > 0 { dst = appendCookiePart(dst, strCookiePath, c.path) } if c.httpOnly { dst = append(dst, ';', ' ') dst = append(dst, strCookieHTTPOnly...) } if c.secure { dst = append(dst, ';', ' ') dst = append(dst, strCookieSecure...) } return dst } // Cookie returns cookie representation. // // The returned value is valid until the next call to Cookie methods. func (c *Cookie) Cookie() []byte { c.buf = c.AppendBytes(c.buf[:0]) return c.buf } // String returns cookie representation. func (c *Cookie) String() string { return string(c.Cookie()) } // WriteTo writes cookie representation to w. // // WriteTo implements io.WriterTo interface. func (c *Cookie) WriteTo(w io.Writer) (int64, error) { n, err := w.Write(c.Cookie()) return int64(n), err } var errNoCookies = errors.New("no cookies found") // Parse parses Set-Cookie header. func (c *Cookie) Parse(src string) error { c.buf = append(c.buf[:0], src...) return c.ParseBytes(c.buf) } // ParseBytes parses Set-Cookie header. func (c *Cookie) ParseBytes(src []byte) error { c.Reset() var s cookieScanner s.b = src kv := &c.bufKV if !s.next(kv, true) { return errNoCookies } c.key = append(c.key[:0], kv.key...) c.value = append(c.value[:0], kv.value...) for s.next(kv, false) { if len(kv.key) == 0 && len(kv.value) == 0 { continue } switch string(kv.key) { case "expires": v := b2s(kv.value) exptime, err := time.ParseInLocation(time.RFC1123, v, time.UTC) if err != nil { return err } c.expire = exptime case "domain": c.domain = append(c.domain[:0], kv.value...) case "path": c.path = append(c.path[:0], kv.value...) 
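		// An empty key means a valueless attribute such as "HttpOnly"
		// or "secure"; these are matched on the value below.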
case "": switch string(kv.value) { case "HttpOnly": c.httpOnly = true case "secure": c.secure = true } } } return nil } func appendCookiePart(dst, key, value []byte) []byte { dst = append(dst, ';', ' ') dst = append(dst, key...) dst = append(dst, '=') return append(dst, value...) } func getCookieKey(dst, src []byte) []byte { n := bytes.IndexByte(src, '=') if n >= 0 { src = src[:n] } return decodeCookieArg(dst, src, true) } func appendRequestCookieBytes(dst []byte, cookies []argsKV) []byte { for i, n := 0, len(cookies); i < n; i++ { kv := &cookies[i] if len(kv.key) > 0 { dst = AppendQuotedArg(dst, kv.key) dst = append(dst, '=') } dst = AppendQuotedArg(dst, kv.value) if i+1 < n { dst = append(dst, ';', ' ') } } return dst } func parseRequestCookies(cookies []argsKV, src []byte) []argsKV { var s cookieScanner s.b = src var kv *argsKV cookies, kv = allocArg(cookies) for s.next(kv, true) { if len(kv.key) > 0 || len(kv.value) > 0 { cookies, kv = allocArg(cookies) } } return releaseArg(cookies) } type cookieScanner struct { b []byte } func (s *cookieScanner) next(kv *argsKV, decode bool) bool { b := s.b if len(b) == 0 { return false } isKey := true k := 0 for i, c := range b { switch c { case '=': if isKey { isKey = false kv.key = decodeCookieArg(kv.key, b[:i], decode) k = i + 1 } case ';': if isKey { kv.key = kv.key[:0] } kv.value = decodeCookieArg(kv.value, b[k:i], decode) s.b = b[i+1:] return true } } if isKey { kv.key = kv.key[:0] } kv.value = decodeCookieArg(kv.value, b[k:], decode) s.b = b[len(b):] return true } func decodeCookieArg(dst, src []byte, decode bool) []byte { for len(src) > 0 && src[0] == ' ' { src = src[1:] } for len(src) > 0 && src[len(src)-1] == ' ' { src = src[:len(src)-1] } if !decode { return append(dst[:0], src...) } return decodeArg(dst, src, true) } golang-github-valyala-fasthttp-20160617/cookie_test.go000066400000000000000000000143271273074646000226520ustar00rootroot00000000000000package fasthttp import ( "strings" "testing" "time" ) func TestCookieSecureHttpOnly(t *testing.T) { var c Cookie if err := c.Parse("foo=bar; HttpOnly; secure"); err != nil { t.Fatalf("unexpected error: %s", err) } if !c.Secure() { t.Fatalf("secure must be set") } if !c.HTTPOnly() { t.Fatalf("HttpOnly must be set") } s := c.String() if !strings.Contains(s, "; secure") { t.Fatalf("missing secure flag in cookie %q", s) } if !strings.Contains(s, "; HttpOnly") { t.Fatalf("missing HttpOnly flag in cookie %q", s) } } func TestCookieSecure(t *testing.T) { var c Cookie if err := c.Parse("foo=bar; secure"); err != nil { t.Fatalf("unexpected error: %s", err) } if !c.Secure() { t.Fatalf("secure must be set") } s := c.String() if !strings.Contains(s, "; secure") { t.Fatalf("missing secure flag in cookie %q", s) } if err := c.Parse("foo=bar"); err != nil { t.Fatalf("unexpected error: %s", err) } if c.HTTPOnly() { t.Fatalf("Unexpected secure flag set") } s = c.String() if strings.Contains(s, "secure") { t.Fatalf("unexpected secure flag in cookie %q", s) } } func TestCookieHttpOnly(t *testing.T) { var c Cookie if err := c.Parse("foo=bar; HttpOnly"); err != nil { t.Fatalf("unexpected error: %s", err) } if !c.HTTPOnly() { t.Fatalf("HTTPOnly must be set") } s := c.String() if !strings.Contains(s, "; HttpOnly") { t.Fatalf("missing HttpOnly flag in cookie %q", s) } if err := c.Parse("foo=bar"); err != nil { t.Fatalf("unexpected error: %s", err) } if c.HTTPOnly() { t.Fatalf("Unexpected HTTPOnly flag set") } s = c.String() if strings.Contains(s, "HttpOnly") { t.Fatalf("unexpected HttpOnly flag in cookie %q", s) } } 
func TestCookieAcquireReleaseSequential(t *testing.T) { testCookieAcquireRelease(t) } func TestCookieAcquireReleaseConcurrent(t *testing.T) { ch := make(chan struct{}, 10) for i := 0; i < 10; i++ { go func() { testCookieAcquireRelease(t) ch <- struct{}{} }() } for i := 0; i < 10; i++ { select { case <-ch: case <-time.After(time.Second): t.Fatalf("timeout") } } } func testCookieAcquireRelease(t *testing.T) { c := AcquireCookie() key := "foo" c.SetKey(key) value := "bar" c.SetValue(value) domain := "foo.bar.com" c.SetDomain(domain) path := "/foi/bar/aaa" c.SetPath(path) s := c.String() c.Reset() if err := c.Parse(s); err != nil { t.Fatalf("unexpected error: %s", err) } if string(c.Key()) != key { t.Fatalf("unexpected cookie name %q. Expecting %q", c.Key(), key) } if string(c.Value()) != value { t.Fatalf("unexpected cookie value %q. Expecting %q", c.Value(), value) } if string(c.Domain()) != domain { t.Fatalf("unexpected domain %q. Expecting %q", c.Domain(), domain) } if string(c.Path()) != path { t.Fatalf("unexpected path %q. Expecting %q", c.Path(), path) } ReleaseCookie(c) } func TestCookieParse(t *testing.T) { testCookieParse(t, "foo", "foo") testCookieParse(t, "foo=bar", "foo=bar") testCookieParse(t, "foo=", "foo=") testCookieParse(t, "foo=bar; domain=aaa.com; path=/foo/bar", "foo=bar; domain=aaa.com; path=/foo/bar") testCookieParse(t, " xxx = yyy ; path=/a/b;;;domain=foobar.com ; expires= Tue, 10 Nov 2009 23:00:00 GMT ; ;;", "xxx=yyy; expires=Tue, 10 Nov 2009 23:00:00 GMT; domain=foobar.com; path=/a/b") } func testCookieParse(t *testing.T, s, expectedS string) { var c Cookie if err := c.Parse(s); err != nil { t.Fatalf("unexpected error: %s", err) } result := string(c.Cookie()) if result != expectedS { t.Fatalf("unexpected cookies %q. Expected %q. Original %q", result, expectedS, s) } } func TestCookieAppendBytes(t *testing.T) { c := &Cookie{} testCookieAppendBytes(t, c, "", "bar", "bar") testCookieAppendBytes(t, c, "foo", "", "foo=") testCookieAppendBytes(t, c, "ффф", "12 лодлы", "%D1%84%D1%84%D1%84=12%20%D0%BB%D0%BE%D0%B4%D0%BB%D1%8B") c.SetDomain("foobar.com") testCookieAppendBytes(t, c, "a", "b", "a=b; domain=foobar.com") c.SetPath("/a/b") testCookieAppendBytes(t, c, "aa", "bb", "aa=bb; domain=foobar.com; path=/a/b") c.SetExpire(CookieExpireDelete) testCookieAppendBytes(t, c, "xxx", "yyy", "xxx=yyy; expires=Tue, 10 Nov 2009 23:00:00 GMT; domain=foobar.com; path=/a/b") } func testCookieAppendBytes(t *testing.T, c *Cookie, key, value, expectedS string) { c.SetKey(key) c.SetValue(value) result := string(c.AppendBytes(nil)) if result != expectedS { t.Fatalf("Unexpected cookie %q. Expected %q", result, expectedS) } } func TestParseRequestCookies(t *testing.T) { testParseRequestCookies(t, "", "") testParseRequestCookies(t, "=", "") testParseRequestCookies(t, "foo", "foo") testParseRequestCookies(t, "=foo", "foo") testParseRequestCookies(t, "bar=", "bar=") testParseRequestCookies(t, "xxx=aa;bb=c; =d; ;;e=g", "xxx=aa; bb=c; d; e=g") testParseRequestCookies(t, "a;b;c; d=1;d=2", "a; b; c; d=1; d=2") testParseRequestCookies(t, " %D0%B8%D0%B2%D0%B5%D1%82=a%20b%3Bc ;s%20s=aaa ", "%D0%B8%D0%B2%D0%B5%D1%82=a%20b%3Bc; s%20s=aaa") } func testParseRequestCookies(t *testing.T, s, expectedS string) { cookies := parseRequestCookies(nil, []byte(s)) ss := string(appendRequestCookieBytes(nil, cookies)) if ss != expectedS { t.Fatalf("Unexpected cookies after parsing: %q. Expected %q. 
String to parse %q", ss, expectedS, s) } } func TestAppendRequestCookieBytes(t *testing.T) { testAppendRequestCookieBytes(t, "=", "") testAppendRequestCookieBytes(t, "foo=", "foo=") testAppendRequestCookieBytes(t, "=bar", "bar") testAppendRequestCookieBytes(t, "привет=a b;c&s s=aaa", "%D0%BF%D1%80%D0%B8%D0%B2%D0%B5%D1%82=a%20b%3Bc; s%20s=aaa") } func testAppendRequestCookieBytes(t *testing.T, s, expectedS string) { var cookies []argsKV for _, ss := range strings.Split(s, "&") { tmp := strings.SplitN(ss, "=", 2) if len(tmp) != 2 { t.Fatalf("Cannot find '=' in %q, part of %q", ss, s) } cookies = append(cookies, argsKV{ key: []byte(tmp[0]), value: []byte(tmp[1]), }) } prefix := "foobar" result := string(appendRequestCookieBytes([]byte(prefix), cookies)) if result[:len(prefix)] != prefix { t.Fatalf("unexpected prefix %q. Expected %q for cookie %q", result[:len(prefix)], prefix, s) } result = result[len(prefix):] if result != expectedS { t.Fatalf("Unexpected result %q. Expected %q for cookie %q", result, expectedS, s) } } golang-github-valyala-fasthttp-20160617/cookie_timing_test.go000066400000000000000000000014651273074646000242200ustar00rootroot00000000000000package fasthttp import ( "testing" ) func BenchmarkCookieParseMin(b *testing.B) { var c Cookie s := []byte("xxx=yyy") for i := 0; i < b.N; i++ { if err := c.ParseBytes(s); err != nil { b.Fatalf("unexpected error when parsing cookies: %s", err) } } } func BenchmarkCookieParseNoExpires(b *testing.B) { var c Cookie s := []byte("xxx=yyy; domain=foobar.com; path=/a/b") for i := 0; i < b.N; i++ { if err := c.ParseBytes(s); err != nil { b.Fatalf("unexpected error when parsing cookies: %s", err) } } } func BenchmarkCookieParseFull(b *testing.B) { var c Cookie s := []byte("xxx=yyy; expires=Tue, 10 Nov 2009 23:00:00 GMT; domain=foobar.com; path=/a/b") for i := 0; i < b.N; i++ { if err := c.ParseBytes(s); err != nil { b.Fatalf("unexpected error when parsing cookies: %s", err) } } } golang-github-valyala-fasthttp-20160617/doc.go000066400000000000000000000031431273074646000211010ustar00rootroot00000000000000/* Package fasthttp provides fast HTTP server and client API. Fasthttp provides the following features: * Optimized for speed. Easily handles more than 100K qps and more than 1M concurrent keep-alive connections on modern hardware. * Optimized for low memory usage. * Easy 'Connection: Upgrade' support via RequestCtx.Hijack. * Server supports requests' pipelining. Multiple requests may be read from a single network packet and multiple responses may be sent in a single network packet. This may be useful for highly loaded REST services. * Server provides the following anti-DoS limits: * The number of concurrent connections. * The number of concurrent connections per client IP. * The number of requests per connection. * Request read timeout. * Response write timeout. * Maximum request header size. * Maximum request body size. * Maximum request execution time. * Maximum keep-alive connection lifetime. * Early filtering out non-GET requests. * A lot of additional useful info is exposed to request handler: * Server and client address. * Per-request logger. * Unique request id. * Request start time. * Connection start time. * Request sequence number for the current connection. * Client supports automatic retry on idempotent requests' failure. * Fasthttp API is designed with the ability to extend existing client and server implementations or to write custom client and server implementations from scratch. 
*/ package fasthttp golang-github-valyala-fasthttp-20160617/examples/000077500000000000000000000000001273074646000216225ustar00rootroot00000000000000golang-github-valyala-fasthttp-20160617/examples/README.md000066400000000000000000000001341273074646000230770ustar00rootroot00000000000000# Code examples * [HelloWorld server](helloworldserver) * [Static file server](fileserver) golang-github-valyala-fasthttp-20160617/examples/fileserver/000077500000000000000000000000001273074646000237705ustar00rootroot00000000000000golang-github-valyala-fasthttp-20160617/examples/fileserver/.gitignore000066400000000000000000000000131273074646000257520ustar00rootroot00000000000000fileserver golang-github-valyala-fasthttp-20160617/examples/fileserver/Makefile000066400000000000000000000002221273074646000254240ustar00rootroot00000000000000fileserver: clean go get -u github.com/valyala/fasthttp go get -u github.com/valyala/fasthttp/expvarhandler go build clean: rm -f fileserver golang-github-valyala-fasthttp-20160617/examples/fileserver/README.md000066400000000000000000000041321273074646000252470ustar00rootroot00000000000000# Static file server example * Serves files from the given directory. * Supports transparent response compression. * Supports byte range responses. * Generates directory index pages. * Supports TLS (aka SSL or HTTPS). * Supports virtual hosts. * Exports various stats on /stats path. # How to build ``` make ``` # How to run ``` ./fileserver -h ./fileserver -addr=tcp.addr.to.listen:to -dir=/path/to/directory/to/serve ``` # fileserver vs nginx performance comparison Serving default nginx path (`/usr/share/nginx/html` on ubuntu). * nginx ``` $ ./wrk -t 4 -c 16 -d 10 http://localhost:80 Running 10s test @ http://localhost:80 4 threads and 16 connections Thread Stats Avg Stdev Max +/- Stdev Latency 397.76us 1.08ms 20.23ms 95.19% Req/Sec 21.20k 2.49k 31.34k 79.65% 850220 requests in 10.10s, 695.65MB read Requests/sec: 84182.71 Transfer/sec: 68.88MB ``` * fileserver ``` $ ./wrk -t 4 -c 16 -d 10 http://localhost:8080 Running 10s test @ http://localhost:8080 4 threads and 16 connections Thread Stats Avg Stdev Max +/- Stdev Latency 447.99us 1.59ms 27.20ms 94.79% Req/Sec 37.13k 3.99k 47.86k 76.00% 1478457 requests in 10.02s, 1.03GB read Requests/sec: 147597.06 Transfer/sec: 105.15MB ``` 8 pipelined requests * nginx ``` $ ./wrk -s pipeline.lua -t 4 -c 16 -d 10 http://localhost:80 -- 8 Running 10s test @ http://localhost:80 4 threads and 16 connections Thread Stats Avg Stdev Max +/- Stdev Latency 1.34ms 2.15ms 30.91ms 92.16% Req/Sec 33.54k 7.36k 108.12k 76.81% 1339908 requests in 10.10s, 1.07GB read Requests/sec: 132705.81 Transfer/sec: 108.58MB ``` * fileserver ``` $ ./wrk -s pipeline.lua -t 4 -c 16 -d 10 http://localhost:8080 -- 8 Running 10s test @ http://localhost:8080 4 threads and 16 connections Thread Stats Avg Stdev Max +/- Stdev Latency 2.08ms 6.33ms 88.26ms 92.83% Req/Sec 116.54k 14.66k 167.98k 69.00% 4642226 requests in 10.03s, 3.23GB read Requests/sec: 462769.41 Transfer/sec: 329.67MB ``` golang-github-valyala-fasthttp-20160617/examples/fileserver/fileserver.go000066400000000000000000000073231273074646000264720ustar00rootroot00000000000000// Example static file server. // // Serves static files from the given directory. // Exports various stats at /stats . 
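//
// A hedged usage sketch (not part of the original sources): after building the
// binary with the accompanying Makefile, it is typically started as, e.g.
//
//	./fileserver -addr=localhost:8080 -dir=/usr/share/nginx/html -compress
//
// after which the expvar counters defined below can be inspected at
// http://localhost:8080/stats (optionally filtered, e.g. /stats?r=fs).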
package main import ( "expvar" "flag" "log" "github.com/valyala/fasthttp" "github.com/valyala/fasthttp/expvarhandler" ) var ( addr = flag.String("addr", "localhost:8080", "TCP address to listen to") addrTLS = flag.String("addrTLS", "", "TCP address to listen to TLS (aka SSL or HTTPS) requests. Leave empty for disabling TLS") byteRange = flag.Bool("byteRange", false, "Enables byte range requests if set to true") certFile = flag.String("certFile", "./ssl-cert-snakeoil.pem", "Path to TLS certificate file") compress = flag.Bool("compress", false, "Enables transparent response compression if set to true") dir = flag.String("dir", "/usr/share/nginx/html", "Directory to serve static files from") generateIndexPages = flag.Bool("generateIndexPages", true, "Whether to generate directory index pages") keyFile = flag.String("keyFile", "./ssl-cert-snakeoil.key", "Path to TLS key file") vhost = flag.Bool("vhost", false, "Enables virtual hosting by prepending the requested path with the requested hostname") ) func main() { // Parse command-line flags. flag.Parse() // Setup FS handler fs := &fasthttp.FS{ Root: *dir, IndexNames: []string{"index.html"}, GenerateIndexPages: *generateIndexPages, Compress: *compress, AcceptByteRange: *byteRange, } if *vhost { fs.PathRewrite = fasthttp.NewVHostPathRewriter(0) } fsHandler := fs.NewRequestHandler() // Create RequestHandler serving server stats on /stats and files // on other requested paths. // /stats output may be filtered using regexps. For example: // // * /stats?r=fs will show only stats (expvars) containing 'fs' // in their names. requestHandler := func(ctx *fasthttp.RequestCtx) { switch string(ctx.Path()) { case "/stats": expvarhandler.ExpvarHandler(ctx) default: fsHandler(ctx) updateFSCounters(ctx) } } // Start HTTP server. if len(*addr) > 0 { log.Printf("Starting HTTP server on %q", *addr) go func() { if err := fasthttp.ListenAndServe(*addr, requestHandler); err != nil { log.Fatalf("error in ListenAndServe: %s", err) } }() } // Start HTTPS server. if len(*addrTLS) > 0 { log.Printf("Starting HTTPS server on %q", *addrTLS) go func() { if err := fasthttp.ListenAndServeTLS(*addrTLS, *certFile, *keyFile, requestHandler); err != nil { log.Fatalf("error in ListenAndServeTLS: %s", err) } }() } log.Printf("Serving files from directory %q", *dir) log.Printf("See stats at http://%s/stats", *addr) // Wait forever. select {} } func updateFSCounters(ctx *fasthttp.RequestCtx) { // Increment the number of fsHandler calls. fsCalls.Add(1) // Update other stats counters resp := &ctx.Response switch resp.StatusCode() { case fasthttp.StatusOK: fsOKResponses.Add(1) fsResponseBodyBytes.Add(int64(resp.Header.ContentLength())) case fasthttp.StatusNotModified: fsNotModifiedResponses.Add(1) case fasthttp.StatusNotFound: fsNotFoundResponses.Add(1) default: fsOtherResponses.Add(1) } } // Various counters - see https://golang.org/pkg/expvar/ for details. var ( // Counter for total number of fs calls fsCalls = expvar.NewInt("fsCalls") // Counters for various response status codes fsOKResponses = expvar.NewInt("fsOKResponses") fsNotModifiedResponses = expvar.NewInt("fsNotModifiedResponses") fsNotFoundResponses = expvar.NewInt("fsNotFoundResponses") fsOtherResponses = expvar.NewInt("fsOtherResponses") // Total size in bytes for OK response bodies served. 
fsResponseBodyBytes = expvar.NewInt("fsResponseBodyBytes") ) golang-github-valyala-fasthttp-20160617/examples/fileserver/ssl-cert-snakeoil.key000066400000000000000000000032501273074646000300410ustar00rootroot00000000000000-----BEGIN PRIVATE KEY----- MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQD4IQusAs8PJdnG 3mURt/AXtgC+ceqLOatJ49JJE1VPTkMAy+oE1f1XvkMrYsHqmDf6GWVzgVXryL4U wq2/nJSm56ddhN55nI8oSN3dtywUB8/ShelEN73nlN77PeD9tl6NksPwWaKrqxq0 FlabRPZSQCfmgZbhDV8Sa8mfCkFU0G0lit6kLGceCKMvmW+9Bz7ebsYmVdmVMxmf IJStFD44lWFTdUc65WISKEdW2ELcUefb0zOLw+0PCbXFGJH5x5ktksW8+BBk2Hkg GeQRL/qPCccthbScO0VgNj3zJ3ZZL0ObSDAbvNDG85joeNjDNq5DT/BAZ0bOSbEF sh+f9BAzAgMBAAECggEBAJWv2cq7Jw6MVwSRxYca38xuD6TUNBopgBvjREixURW2 sNUaLuMb9Omp7fuOaE2N5rcJ+xnjPGIxh/oeN5MQctz9gwn3zf6vY+15h97pUb4D uGvYPRDaT8YVGS+X9NMZ4ZCmqW2lpWzKnCFoGHcy8yZLbcaxBsRdvKzwOYGoPiFb K2QuhXZ/1UPmqK9i2DFKtj40X6vBszTNboFxOVpXrPu0FJwLVSDf2hSZ4fMM0DH3 YqwKcYf5te+hxGKgrqRA3tn0NCWii0in6QIwXMC+kMw1ebg/tZKqyDLMNptAK8J+ DVw9m5X1seUHS5ehU/g2jrQrtK5WYn7MrFK4lBzlRwECgYEA/d1TeANYECDWRRDk B0aaRZs87Rwl/J9PsvbsKvtU/bX+OfSOUjOa9iQBqn0LmU8GqusEET/QVUfocVwV Bggf/5qDLxz100Rj0ags/yE/kNr0Bb31kkkKHFMnCT06YasR7qKllwrAlPJvQv9x IzBKq+T/Dx08Wep9bCRSFhzRCnsCgYEA+jdeZXTDr/Vz+D2B3nAw1frqYFfGnEVY wqmoK3VXMDkGuxsloO2rN+SyiUo3JNiQNPDub/t7175GH5pmKtZOlftePANsUjBj wZ1D0rI5Bxu/71ibIUYIRVmXsTEQkh/ozoh3jXCZ9+bLgYiYx7789IUZZSokFQ3D FICUT9KJ36kCgYAGoq9Y1rWJjmIrYfqj2guUQC+CfxbbGIrrwZqAsRsSmpwvhZ3m tiSZxG0quKQB+NfSxdvQW5ulbwC7Xc3K35F+i9pb8+TVBdeaFkw+yu6vaZmxQLrX fQM/pEjD7A7HmMIaO7QaU5SfEAsqdCTP56Y8AftMuNXn/8IRfo2KuGwaWwKBgFpU ILzJoVdlad9E/Rw7LjYhZfkv1uBVXIyxyKcfrkEXZSmozDXDdxsvcZCEfVHM6Ipk K/+7LuMcqp4AFEAEq8wTOdq6daFaHLkpt/FZK6M4TlruhtpFOPkoNc3e45eM83OT 6mziKINJC1CQ6m65sQHpBtjxlKMRG8rL/D6wx9s5AoGBAMRlqNPMwglT3hvDmsAt 9Lf9pdmhERUlHhD8bj8mDaBj2Aqv7f6VRJaYZqP403pKKQexuqcn80mtjkSAPFkN Cj7BVt/RXm5uoxDTnfi26RF9F6yNDEJ7UU9+peBr99aazF/fTgW/1GcMkQnum8uV c257YgaWmjK9uB0Y2r2VxS0G -----END PRIVATE KEY----- golang-github-valyala-fasthttp-20160617/examples/fileserver/ssl-cert-snakeoil.pem000066400000000000000000000017551273074646000300420ustar00rootroot00000000000000-----BEGIN CERTIFICATE----- MIICujCCAaKgAwIBAgIJAMbXnKZ/cikUMA0GCSqGSIb3DQEBCwUAMBUxEzARBgNV BAMTCnVidW50dS5uYW4wHhcNMTUwMjA0MDgwMTM5WhcNMjUwMjAxMDgwMTM5WjAV MRMwEQYDVQQDEwp1YnVudHUubmFuMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIB CgKCAQEA+CELrALPDyXZxt5lEbfwF7YAvnHqizmrSePSSRNVT05DAMvqBNX9V75D K2LB6pg3+hllc4FV68i+FMKtv5yUpuenXYTeeZyPKEjd3bcsFAfP0oXpRDe955Te +z3g/bZejZLD8Fmiq6satBZWm0T2UkAn5oGW4Q1fEmvJnwpBVNBtJYrepCxnHgij L5lvvQc+3m7GJlXZlTMZnyCUrRQ+OJVhU3VHOuViEihHVthC3FHn29Mzi8PtDwm1 xRiR+ceZLZLFvPgQZNh5IBnkES/6jwnHLYW0nDtFYDY98yd2WS9Dm0gwG7zQxvOY 6HjYwzauQ0/wQGdGzkmxBbIfn/QQMwIDAQABow0wCzAJBgNVHRMEAjAAMA0GCSqG SIb3DQEBCwUAA4IBAQBQjKm/4KN/iTgXbLTL3i7zaxYXFLXsnT1tF+ay4VA8aj98 L3JwRTciZ3A5iy/W4VSCt3eASwOaPWHKqDBB5RTtL73LoAqsWmO3APOGQAbixcQ2 45GXi05OKeyiYRi1Nvq7Unv9jUkRDHUYVPZVSAjCpsXzPhFkmZoTRxmx5l0ZF7Li K91lI5h+eFq0dwZwrmlPambyh1vQUi70VHv8DNToVU29kel7YLbxGbuqETfhrcy6 X+Mha6RYITkAn5FqsZcKMsc9eYGEF4l3XV+oS7q6xfTxktYJMFTI18J0lQ2Lv/CI whdMnYGntDQBE/iFCrJEGNsKGc38796GBOb5j+zd -----END CERTIFICATE----- golang-github-valyala-fasthttp-20160617/examples/helloworldserver/000077500000000000000000000000001273074646000252245ustar00rootroot00000000000000golang-github-valyala-fasthttp-20160617/examples/helloworldserver/.gitignore000066400000000000000000000000211273074646000272050ustar00rootroot00000000000000helloworldserver 
golang-github-valyala-fasthttp-20160617/examples/helloworldserver/Makefile000066400000000000000000000001511273074646000266610ustar00rootroot00000000000000helloworldserver: clean go get -u github.com/valyala/fasthttp go build clean: rm -f helloworldserver golang-github-valyala-fasthttp-20160617/examples/helloworldserver/README.md000066400000000000000000000003531273074646000265040ustar00rootroot00000000000000# HelloWorld server example * Displays various request info. * Sets response headers and cookies. * Supports transparent compression. # How to build ``` make ``` # How to run ``` ./helloworldserver -addr=tcp.addr.to.listen:to ``` golang-github-valyala-fasthttp-20160617/examples/helloworldserver/helloworldserver.go000066400000000000000000000027761273074646000311710ustar00rootroot00000000000000package main import ( "flag" "fmt" "log" "github.com/valyala/fasthttp" ) var ( addr = flag.String("addr", ":8080", "TCP address to listen to") compress = flag.Bool("compress", false, "Whether to enable transparent response compression") ) func main() { flag.Parse() h := requestHandler if *compress { h = fasthttp.CompressHandler(h) } if err := fasthttp.ListenAndServe(*addr, h); err != nil { log.Fatalf("Error in ListenAndServe: %s", err) } } func requestHandler(ctx *fasthttp.RequestCtx) { fmt.Fprintf(ctx, "Hello, world!\n\n") fmt.Fprintf(ctx, "Request method is %q\n", ctx.Method()) fmt.Fprintf(ctx, "RequestURI is %q\n", ctx.RequestURI()) fmt.Fprintf(ctx, "Requested path is %q\n", ctx.Path()) fmt.Fprintf(ctx, "Host is %q\n", ctx.Host()) fmt.Fprintf(ctx, "Query string is %q\n", ctx.QueryArgs()) fmt.Fprintf(ctx, "User-Agent is %q\n", ctx.UserAgent()) fmt.Fprintf(ctx, "Connection has been established at %s\n", ctx.ConnTime()) fmt.Fprintf(ctx, "Request has been started at %s\n", ctx.Time()) fmt.Fprintf(ctx, "Serial request number for the current connection is %d\n", ctx.ConnRequestNum()) fmt.Fprintf(ctx, "Your ip is %q\n\n", ctx.RemoteIP()) fmt.Fprintf(ctx, "Raw request is:\n---CUT---\n%s\n---CUT---", &ctx.Request) ctx.SetContentType("text/plain; charset=utf8") // Set arbitrary headers ctx.Response.Header.Set("X-My-Header", "my-header-value") // Set cookies var c fasthttp.Cookie c.SetKey("cookie-name") c.SetValue("cookie-value") ctx.Response.Header.SetCookie(&c) } golang-github-valyala-fasthttp-20160617/expvarhandler/000077500000000000000000000000001273074646000226475ustar00rootroot00000000000000golang-github-valyala-fasthttp-20160617/expvarhandler/expvar.go000066400000000000000000000025601273074646000245060ustar00rootroot00000000000000// Package expvarhandler provides fasthttp-compatible request handler // serving expvars. package expvarhandler import ( "expvar" "fmt" "regexp" "github.com/valyala/fasthttp" ) var ( expvarHandlerCalls = expvar.NewInt("expvarHandlerCalls") expvarRegexpErrors = expvar.NewInt("expvarRegexpErrors") ) // ExpvarHandler dumps json representation of expvars to http response. // // Expvars may be filtered by regexp provided via 'r' query argument. // // See https://golang.org/pkg/expvar/ for details. 
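//
// A hedged wiring sketch (it mirrors the fileserver example elsewhere in this
// repository; the surrounding handler and route names are illustrative):
//
//	requestHandler := func(ctx *fasthttp.RequestCtx) {
//		switch string(ctx.Path()) {
//		case "/stats":
//			expvarhandler.ExpvarHandler(ctx)
//		default:
//			// serve regular content here
//		}
//	}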
func ExpvarHandler(ctx *fasthttp.RequestCtx) { expvarHandlerCalls.Add(1) ctx.Response.Reset() r, err := getExpvarRegexp(ctx) if err != nil { expvarRegexpErrors.Add(1) fmt.Fprintf(ctx, "Error when obtaining expvar regexp: %s", err) ctx.SetStatusCode(fasthttp.StatusBadRequest) return } fmt.Fprintf(ctx, "{\n") first := true expvar.Do(func(kv expvar.KeyValue) { if !first { fmt.Fprintf(ctx, ",\n") } if r.MatchString(kv.Key) { first = false fmt.Fprintf(ctx, "\t%q: %s", kv.Key, kv.Value) } }) fmt.Fprintf(ctx, "\n}\n") ctx.SetContentType("application/json; charset=utf-8") } func getExpvarRegexp(ctx *fasthttp.RequestCtx) (*regexp.Regexp, error) { r := string(ctx.QueryArgs().Peek("r")) if len(r) == 0 { r = "." } rr, err := regexp.Compile(r) if err != nil { return nil, fmt.Errorf("cannot parse r=%q: %s", r, err) } return rr, nil } golang-github-valyala-fasthttp-20160617/expvarhandler/expvar_test.go000066400000000000000000000026501273074646000255450ustar00rootroot00000000000000package expvarhandler import ( "encoding/json" "expvar" "strings" "testing" "github.com/valyala/fasthttp" ) func TestExpvarHandlerBasic(t *testing.T) { expvar.Publish("customVar", expvar.Func(func() interface{} { return "foobar" })) var ctx fasthttp.RequestCtx expvarHandlerCalls.Set(0) ExpvarHandler(&ctx) body := ctx.Response.Body() var m map[string]interface{} if err := json.Unmarshal(body, &m); err != nil { t.Fatalf("unexpected error: %s", err) } if _, ok := m["cmdline"]; !ok { t.Fatalf("cannot locate cmdline expvar") } if _, ok := m["memstats"]; !ok { t.Fatalf("cannot locate memstats expvar") } v := m["customVar"] sv, ok := v.(string) if !ok { t.Fatalf("unexpected custom var type %T. Expecting string", v) } if sv != "foobar" { t.Fatalf("unexpected custom var value: %q. Expecting %q", v, "foobar") } v = m["expvarHandlerCalls"] fv, ok := v.(float64) if !ok { t.Fatalf("unexpected expvarHandlerCalls type %T. Expecting float64", v) } if int(fv) != 1 { t.Fatalf("unexpected value for expvarHandlerCalls: %v. Expecting %v", fv, 1) } } func TestExpvarHandlerRegexp(t *testing.T) { var ctx fasthttp.RequestCtx ctx.QueryArgs().Set("r", "cmd") ExpvarHandler(&ctx) body := string(ctx.Response.Body()) if !strings.Contains(body, `"cmdline"`) { t.Fatalf("missing 'cmdline' expvar") } if strings.Contains(body, `"memstats"`) { t.Fatalf("unexpected memstats expvar found") } } golang-github-valyala-fasthttp-20160617/fasthttpadaptor/000077500000000000000000000000001273074646000232145ustar00rootroot00000000000000golang-github-valyala-fasthttp-20160617/fasthttpadaptor/adaptor.go000066400000000000000000000073501273074646000252020ustar00rootroot00000000000000// Package fasthttpadaptor provides helper functions for converting net/http // request handlers to fasthttp request handlers. package fasthttpadaptor import ( "io" "net/http" "net/url" "github.com/valyala/fasthttp" ) // NewFastHTTPHandlerFunc wraps net/http handler func to fasthttp // request handler, so it can be passed to fasthttp server. // // While this function may be used for easy switching from net/http to fasthttp, // it has the following drawbacks comparing to using manually written fasthttp // request handler: // // * A lot of useful functionality provided by fasthttp is missing // from net/http handler. // * net/http -> fasthttp handler conversion has some overhead, // so the returned handler will be always slower than manually written // fasthttp handler. // // So it is advisable using this function only for quick net/http -> fasthttp // switching. 
Then manually convert net/http handlers to fasthttp handlers // according to https://github.com/valyala/fasthttp#switching-from-nethttp-to-fasthttp . func NewFastHTTPHandlerFunc(h http.HandlerFunc) fasthttp.RequestHandler { return NewFastHTTPHandler(h) } // NewFastHTTPHandler wraps net/http handler to fasthttp request handler, // so it can be passed to fasthttp server. // // While this function may be used for easy switching from net/http to fasthttp, // it has the following drawbacks comparing to using manually written fasthttp // request handler: // // * A lot of useful functionality provided by fasthttp is missing // from net/http handler. // * net/http -> fasthttp handler conversion has some overhead, // so the returned handler will be always slower than manually written // fasthttp handler. // // So it is advisable using this function only for quick net/http -> fasthttp // switching. Then manually convert net/http handlers to fasthttp handlers // according to https://github.com/valyala/fasthttp#switching-from-nethttp-to-fasthttp . func NewFastHTTPHandler(h http.Handler) fasthttp.RequestHandler { return func(ctx *fasthttp.RequestCtx) { var r http.Request body := ctx.PostBody() r.Method = string(ctx.Method()) r.Proto = "HTTP/1.1" r.ProtoMajor = 1 r.ProtoMinor = 1 r.RequestURI = string(ctx.RequestURI()) r.ContentLength = int64(len(body)) r.Host = string(ctx.Host()) r.RemoteAddr = ctx.RemoteAddr().String() hdr := make(http.Header) ctx.Request.Header.VisitAll(func(k, v []byte) { hdr.Set(string(k), string(v)) }) r.Header = hdr r.Body = &netHTTPBody{body} rURL, err := url.ParseRequestURI(r.RequestURI) if err != nil { ctx.Logger().Printf("cannot parse requestURI %q: %s", r.RequestURI, err) ctx.Error("Internal Server Error", fasthttp.StatusInternalServerError) return } r.URL = rURL var w netHTTPResponseWriter h.ServeHTTP(&w, &r) ctx.SetStatusCode(w.StatusCode()) for k, vv := range w.Header() { for _, v := range vv { ctx.Response.Header.Set(k, v) } } ctx.Write(w.body) } } type netHTTPBody struct { b []byte } func (r *netHTTPBody) Read(p []byte) (int, error) { if len(r.b) == 0 { return 0, io.EOF } n := copy(p, r.b) r.b = r.b[n:] return n, nil } func (r *netHTTPBody) Close() error { r.b = r.b[:0] return nil } type netHTTPResponseWriter struct { statusCode int h http.Header body []byte } func (w *netHTTPResponseWriter) StatusCode() int { if w.statusCode == 0 { return http.StatusOK } return w.statusCode } func (w *netHTTPResponseWriter) Header() http.Header { if w.h == nil { w.h = make(http.Header) } return w.h } func (w *netHTTPResponseWriter) WriteHeader(statusCode int) { w.statusCode = statusCode } func (w *netHTTPResponseWriter) Write(p []byte) (int, error) { w.body = append(w.body, p...) 
return len(p), nil } golang-github-valyala-fasthttp-20160617/fasthttpadaptor/adaptor_test.go000066400000000000000000000074571273074646000262510ustar00rootroot00000000000000package fasthttpadaptor import ( "fmt" "io/ioutil" "net" "net/http" "net/url" "reflect" "testing" "github.com/valyala/fasthttp" ) func TestNewFastHTTPHandler(t *testing.T) { expectedMethod := "POST" expectedProto := "HTTP/1.1" expectedProtoMajor := 1 expectedProtoMinor := 1 expectedRequestURI := "/foo/bar?baz=123" expectedBody := "body 123 foo bar baz" expectedContentLength := len(expectedBody) expectedHost := "foobar.com" expectedRemoteAddr := "1.2.3.4:6789" expectedHeader := map[string]string{ "Foo-Bar": "baz", "Abc": "defg", "XXX-Remote-Addr": "123.43.4543.345", } expectedURL, err := url.ParseRequestURI(expectedRequestURI) if err != nil { t.Fatalf("unexpected error: %s", err) } callsCount := 0 nethttpH := func(w http.ResponseWriter, r *http.Request) { callsCount++ if r.Method != expectedMethod { t.Fatalf("unexpected method %q. Expecting %q", r.Method, expectedMethod) } if r.Proto != expectedProto { t.Fatalf("unexpected proto %q. Expecting %q", r.Proto, expectedProto) } if r.ProtoMajor != expectedProtoMajor { t.Fatalf("unexpected protoMajor %d. Expecting %d", r.ProtoMajor, expectedProtoMajor) } if r.ProtoMinor != expectedProtoMinor { t.Fatalf("unexpected protoMinor %d. Expecting %d", r.ProtoMinor, expectedProtoMinor) } if r.RequestURI != expectedRequestURI { t.Fatalf("unexpected requestURI %q. Expecting %q", r.RequestURI, expectedRequestURI) } if r.ContentLength != int64(expectedContentLength) { t.Fatalf("unexpected contentLength %d. Expecting %d", r.ContentLength, expectedContentLength) } if r.Host != expectedHost { t.Fatalf("unexpected host %q. Expecting %q", r.Host, expectedHost) } if r.RemoteAddr != expectedRemoteAddr { t.Fatalf("unexpected remoteAddr %q. Expecting %q", r.RemoteAddr, expectedRemoteAddr) } body, err := ioutil.ReadAll(r.Body) r.Body.Close() if err != nil { t.Fatalf("unexpected error when reading request body: %s", err) } if string(body) != expectedBody { t.Fatalf("unexpected body %q. Expecting %q", body, expectedBody) } if !reflect.DeepEqual(r.URL, expectedURL) { t.Fatalf("unexpected URL: %#v. Expecting %#v", r.URL, expectedURL) } for k, expectedV := range expectedHeader { v := r.Header.Get(k) if v != expectedV { t.Fatalf("unexpected header value %q for key %q. Expecting %q", v, k, expectedV) } } w.Header().Set("Header1", "value1") w.Header().Set("Header2", "value2") w.WriteHeader(http.StatusBadRequest) fmt.Fprintf(w, "request body is %q", body) } fasthttpH := NewFastHTTPHandler(http.HandlerFunc(nethttpH)) var ctx fasthttp.RequestCtx var req fasthttp.Request req.Header.SetMethod(expectedMethod) req.SetRequestURI(expectedRequestURI) req.Header.SetHost(expectedHost) req.BodyWriter().Write([]byte(expectedBody)) for k, v := range expectedHeader { req.Header.Set(k, v) } remoteAddr, err := net.ResolveTCPAddr("tcp", expectedRemoteAddr) if err != nil { t.Fatalf("unexpected error: %s", err) } ctx.Init(&req, remoteAddr, nil) fasthttpH(&ctx) if callsCount != 1 { t.Fatalf("unexpected callsCount: %d. Expecting 1", callsCount) } resp := &ctx.Response if resp.StatusCode() != fasthttp.StatusBadRequest { t.Fatalf("unexpected statusCode: %d. Expecting %d", resp.StatusCode(), fasthttp.StatusBadRequest) } if string(resp.Header.Peek("Header1")) != "value1" { t.Fatalf("unexpected header value: %q. 
Expecting %q", resp.Header.Peek("Header1"), "value1") } if string(resp.Header.Peek("Header2")) != "value2" { t.Fatalf("unexpected header value: %q. Expecting %q", resp.Header.Peek("Header2"), "value2") } expectedResponseBody := fmt.Sprintf("request body is %q", expectedBody) if string(resp.Body()) != expectedResponseBody { t.Fatalf("unexpected response body %q. Expecting %q", resp.Body(), expectedResponseBody) } } golang-github-valyala-fasthttp-20160617/fasthttputil/000077500000000000000000000000001273074646000225375ustar00rootroot00000000000000golang-github-valyala-fasthttp-20160617/fasthttputil/doc.go000066400000000000000000000001261273074646000236320ustar00rootroot00000000000000// Package fasthttputil provides utility functions for fasthttp. package fasthttputil golang-github-valyala-fasthttp-20160617/fasthttputil/inmemory_listener.go000066400000000000000000000036061273074646000266370ustar00rootroot00000000000000package fasthttputil import ( "fmt" "net" "sync" ) // InmemoryListener provides in-memory dialer<->net.Listener implementation. // // It may be used either for fast in-process client<->server communcations // without network stack overhead or for client<->server tests. type InmemoryListener struct { lock sync.Mutex closed bool conns chan net.Conn } // NewInmemoryListener returns new in-memory dialer<->net.Listener. func NewInmemoryListener() *InmemoryListener { return &InmemoryListener{ conns: make(chan net.Conn, 1024), } } // Accept implements net.Listener's Accept. // // It is safe calling Accept from concurrently running goroutines. // // Accept returns new connection per each Dial call. func (ln *InmemoryListener) Accept() (net.Conn, error) { c, ok := <-ln.conns if !ok { return nil, fmt.Errorf("InmemoryListener is already closed: use of closed network connection") } return c, nil } // Close implements net.Listener's Close. func (ln *InmemoryListener) Close() error { var err error ln.lock.Lock() if !ln.closed { close(ln.conns) ln.closed = true } else { err = fmt.Errorf("InmemoryListener is already closed") } ln.lock.Unlock() return err } // Addr implements net.Listener's Addr. func (ln *InmemoryListener) Addr() net.Addr { return &net.UnixAddr{ Name: "InmemoryListener", Net: "memory", } } // Dial creates new client<->server connection, enqueues server side // of the connection to Accept and returns client side of the connection. // // It is safe calling Dial from concurrently running goroutines. func (ln *InmemoryListener) Dial() (net.Conn, error) { pc := NewPipeConns() cConn := pc.Conn1() sConn := pc.Conn2() ln.lock.Lock() if !ln.closed { ln.conns <- sConn } else { sConn.Close() cConn.Close() cConn = nil } ln.lock.Unlock() if cConn == nil { return nil, fmt.Errorf("InmemoryListener is already closed") } return cConn, nil } golang-github-valyala-fasthttp-20160617/fasthttputil/inmemory_listener_test.go000066400000000000000000000037151273074646000276770ustar00rootroot00000000000000package fasthttputil import ( "bytes" "fmt" "testing" "time" ) func TestInmemoryListener(t *testing.T) { ln := NewInmemoryListener() ch := make(chan struct{}) for i := 0; i < 10; i++ { go func(n int) { conn, err := ln.Dial() if err != nil { t.Fatalf("unexpected error: %s", err) } defer conn.Close() req := fmt.Sprintf("request_%d", n) nn, err := conn.Write([]byte(req)) if err != nil { t.Fatalf("unexpected error: %s", err) } if nn != len(req) { t.Fatalf("unexpected number of bytes written: %d. 
Expecting %d", nn, len(req)) } buf := make([]byte, 30) nn, err = conn.Read(buf) if err != nil { t.Fatalf("unexpected error: %s", err) } buf = buf[:nn] resp := fmt.Sprintf("response_%d", n) if nn != len(resp) { t.Fatalf("unexpected number of bytes read: %d. Expecting %d", nn, len(resp)) } if string(buf) != resp { t.Fatalf("unexpected response %q. Expecting %q", buf, resp) } ch <- struct{}{} }(i) } serverCh := make(chan struct{}) go func() { for { conn, err := ln.Accept() if err != nil { close(serverCh) return } defer conn.Close() buf := make([]byte, 30) n, err := conn.Read(buf) if err != nil { t.Fatalf("unexpected error: %s", err) } buf = buf[:n] if !bytes.HasPrefix(buf, []byte("request_")) { t.Fatalf("unexpected request prefix %q. Expecting %q", buf, "request_") } resp := fmt.Sprintf("response_%s", buf[len("request_"):]) n, err = conn.Write([]byte(resp)) if err != nil { t.Fatalf("unexpected error: %s", err) } if n != len(resp) { t.Fatalf("unexpected number of bytes written: %d. Expecting %d", n, len(resp)) } } }() for i := 0; i < 10; i++ { select { case <-ch: case <-time.After(time.Second): t.Fatalf("timeout") } } if err := ln.Close(); err != nil { t.Fatalf("unexpected error: %s", err) } select { case <-serverCh: case <-time.After(time.Second): t.Fatalf("timeout") } } golang-github-valyala-fasthttp-20160617/fasthttputil/inmemory_listener_timing_test.go000066400000000000000000000054201273074646000312410ustar00rootroot00000000000000package fasthttputil_test import ( "net" "testing" "github.com/valyala/fasthttp" "github.com/valyala/fasthttp/fasthttputil" ) // BenchmarkPlainStreaming measures end-to-end plaintext streaming performance // for fasthttp client and server. // // It issues http requests over a small number of keep-alive connections. func BenchmarkPlainStreaming(b *testing.B) { benchmark(b, streamingHandler, false) } // BenchmarkPlainHandshake measures end-to-end plaintext handshake performance // for fasthttp client and server. // // It re-establishes new connection per each http request. func BenchmarkPlainHandshake(b *testing.B) { benchmark(b, handshakeHandler, false) } // BenchmarkTLSStreaming measures end-to-end TLS streaming performance // for fasthttp client and server. // // It issues http requests over a small number of TLS keep-alive connections. func BenchmarkTLSStreaming(b *testing.B) { benchmark(b, streamingHandler, true) } // BenchmarkTLSHandshake measures end-to-end TLS handshake performance // for fasthttp client and server. // // It re-establishes new TLS connection per each http request. func BenchmarkTLSHandshake(b *testing.B) { benchmark(b, handshakeHandler, true) } func benchmark(b *testing.B, h fasthttp.RequestHandler, isTLS bool) { ln := fasthttputil.NewInmemoryListener() serverStopCh := startServer(b, ln, h, isTLS) c := newClient(ln, isTLS) b.RunParallel(func(pb *testing.PB) { runRequests(b, pb, c) }) ln.Close() <-serverStopCh } func streamingHandler(ctx *fasthttp.RequestCtx) { ctx.WriteString("foobar") } func handshakeHandler(ctx *fasthttp.RequestCtx) { streamingHandler(ctx) // Explicitly close connection after each response. 
ctx.SetConnectionClose() } func startServer(b *testing.B, ln *fasthttputil.InmemoryListener, h fasthttp.RequestHandler, isTLS bool) <-chan struct{} { ch := make(chan struct{}) go func() { var err error if isTLS { err = fasthttp.ServeTLS(ln, certFile, keyFile, h) } else { err = fasthttp.Serve(ln, h) } if err != nil { b.Fatalf("unexpected error in server: %s", err) } close(ch) }() return ch } const ( certFile = "./ssl-cert-snakeoil.pem" keyFile = "./ssl-cert-snakeoil.key" ) func newClient(ln *fasthttputil.InmemoryListener, isTLS bool) *fasthttp.HostClient { return &fasthttp.HostClient{ Dial: func(addr string) (net.Conn, error) { return ln.Dial() }, IsTLS: isTLS, } } func runRequests(b *testing.B, pb *testing.PB, c *fasthttp.HostClient) { var req fasthttp.Request req.SetRequestURI("http://foo.bar/baz") var resp fasthttp.Response for pb.Next() { if err := c.Do(&req, &resp); err != nil { b.Fatalf("unexpected error: %s", err) } if resp.StatusCode() != fasthttp.StatusOK { b.Fatalf("unexpected status code: %d. Expecting %d", resp.StatusCode(), fasthttp.StatusOK) } } } golang-github-valyala-fasthttp-20160617/fasthttputil/pipeconns.go000066400000000000000000000075321273074646000250710ustar00rootroot00000000000000package fasthttputil import ( "errors" "io" "net" "sync" "time" ) // NewPipeConns returns a new bi-directional connection pipe. func NewPipeConns() *PipeConns { ch1 := make(chan *byteBuffer, 4) ch2 := make(chan *byteBuffer, 4) pc := &PipeConns{ stopCh: make(chan struct{}), } pc.c1.rCh = ch1 pc.c1.wCh = ch2 pc.c2.rCh = ch2 pc.c2.wCh = ch1 pc.c1.pc = pc pc.c2.pc = pc return pc } // PipeConns provides a bi-directional connection pipe, // which uses in-process memory as a transport. // // PipeConns must be created by calling NewPipeConns. // // PipeConns has the following additional features compared to connections // returned from net.Pipe(): // // * It is faster. // * It buffers Write calls, so there is no need to have a concurrent goroutine // calling Read in order to unblock each Write call. type PipeConns struct { c1 pipeConn c2 pipeConn stopCh chan struct{} stopChLock sync.Mutex } // Conn1 returns the first end of the bi-directional pipe. // // Data written to Conn1 may be read from Conn2. // Data written to Conn2 may be read from Conn1. func (pc *PipeConns) Conn1() net.Conn { return &pc.c1 } // Conn2 returns the second end of the bi-directional pipe. // // Data written to Conn2 may be read from Conn1. // Data written to Conn1 may be read from Conn2. func (pc *PipeConns) Conn2() net.Conn { return &pc.c2 } // Close closes pipe connections. func (pc *PipeConns) Close() error { pc.stopChLock.Lock() select { case <-pc.stopCh: default: close(pc.stopCh) } pc.stopChLock.Unlock() return nil } type pipeConn struct { b *byteBuffer bb []byte rCh chan *byteBuffer wCh chan *byteBuffer pc *PipeConns } func (c *pipeConn) Write(p []byte) (int, error) { b := acquireByteBuffer() b.b = append(b.b[:0], p...)
select { case <-c.pc.stopCh: releaseByteBuffer(b) return 0, errConnectionClosed default: } select { case c.wCh <- b: default: select { case c.wCh <- b: case <-c.pc.stopCh: releaseByteBuffer(b) return 0, errConnectionClosed } } return len(p), nil } func (c *pipeConn) Read(p []byte) (int, error) { mayBlock := true nn := 0 for len(p) > 0 { n, err := c.read(p, mayBlock) nn += n if err != nil { if !mayBlock && err == errWouldBlock { err = nil } return nn, err } p = p[n:] mayBlock = false } return nn, nil } func (c *pipeConn) read(p []byte, mayBlock bool) (int, error) { if len(c.bb) == 0 { if err := c.readNextByteBuffer(mayBlock); err != nil { return 0, err } } n := copy(p, c.bb) c.bb = c.bb[n:] return n, nil } func (c *pipeConn) readNextByteBuffer(mayBlock bool) error { releaseByteBuffer(c.b) c.b = nil select { case c.b = <-c.rCh: default: if !mayBlock { return errWouldBlock } select { case c.b = <-c.rCh: case <-c.pc.stopCh: return io.EOF } } c.bb = c.b.b return nil } var ( errWouldBlock = errors.New("would block") errConnectionClosed = errors.New("connection closed") errNoDeadlines = errors.New("deadline not supported") ) func (c *pipeConn) Close() error { return c.pc.Close() } func (c *pipeConn) LocalAddr() net.Addr { return pipeAddr(0) } func (c *pipeConn) RemoteAddr() net.Addr { return pipeAddr(0) } func (c *pipeConn) SetDeadline(t time.Time) error { return errNoDeadlines } func (c *pipeConn) SetReadDeadline(t time.Time) error { return c.SetDeadline(t) } func (c *pipeConn) SetWriteDeadline(t time.Time) error { return c.SetDeadline(t) } type pipeAddr int func (pipeAddr) Network() string { return "pipe" } func (pipeAddr) String() string { return "pipe" } type byteBuffer struct { b []byte } func acquireByteBuffer() *byteBuffer { return byteBufferPool.Get().(*byteBuffer) } func releaseByteBuffer(b *byteBuffer) { if b != nil { byteBufferPool.Put(b) } } var byteBufferPool = &sync.Pool{ New: func() interface{} { return &byteBuffer{ b: make([]byte, 1024), } }, } golang-github-valyala-fasthttp-20160617/fasthttputil/pipeconns_test.go000066400000000000000000000116211273074646000261240ustar00rootroot00000000000000package fasthttputil import ( "fmt" "io" "io/ioutil" "net" "testing" "time" ) func TestPipeConnsCloseWhileReadWriteConcurrent(t *testing.T) { concurrency := 4 ch := make(chan struct{}, concurrency) for i := 0; i < concurrency; i++ { go func() { testPipeConnsCloseWhileReadWriteSerial(t) ch <- struct{}{} }() } for i := 0; i < concurrency; i++ { select { case <-ch: case <-time.After(3 * time.Second): t.Fatalf("timeout") } } } func TestPipeConnsCloseWhileReadWriteSerial(t *testing.T) { testPipeConnsCloseWhileReadWriteSerial(t) } func testPipeConnsCloseWhileReadWriteSerial(t *testing.T) { for i := 0; i < 10; i++ { testPipeConnsCloseWhileReadWrite(t) } } func testPipeConnsCloseWhileReadWrite(t *testing.T) { pc := NewPipeConns() c1 := pc.Conn1() c2 := pc.Conn2() readCh := make(chan error) go func() { var err error if _, err = io.Copy(ioutil.Discard, c1); err != nil { if err != errConnectionClosed { err = fmt.Errorf("unexpected error: %s", err) } else { err = nil } } readCh <- err }() writeCh := make(chan error) go func() { var err error for { if _, err = c2.Write([]byte("foobar")); err != nil { if err != errConnectionClosed { err = fmt.Errorf("unexpected error: %s", err) } else { err = nil } break } } writeCh <- err }() time.Sleep(10 * time.Millisecond) if err := c1.Close(); err != nil { t.Fatalf("unexpected error: %s", err) } if err := c2.Close(); err != nil { t.Fatalf("unexpected error: %s", err) } 
select { case err := <-readCh: if err != nil { t.Fatalf("unexpected error in reader: %s", err) } case <-time.After(time.Second): t.Fatalf("timeout") } select { case err := <-writeCh: if err != nil { t.Fatalf("unexpected error in writer: %s", err) } case <-time.After(time.Second): t.Fatalf("timeout") } } func TestPipeConnsReadWriteSerial(t *testing.T) { testPipeConnsReadWriteSerial(t) } func TestPipeConnsReadWriteConcurrent(t *testing.T) { testConcurrency(t, 10, testPipeConnsReadWriteSerial) } func testPipeConnsReadWriteSerial(t *testing.T) { pc := NewPipeConns() testPipeConnsReadWrite(t, pc.Conn1(), pc.Conn2()) pc = NewPipeConns() testPipeConnsReadWrite(t, pc.Conn2(), pc.Conn1()) } func testPipeConnsReadWrite(t *testing.T, c1, c2 net.Conn) { defer c1.Close() defer c2.Close() var buf [32]byte for i := 0; i < 10; i++ { // The first write s1 := fmt.Sprintf("foo_%d", i) n, err := c1.Write([]byte(s1)) if err != nil { t.Fatalf("unexpected error: %s", err) } if n != len(s1) { t.Fatalf("unexpected number of bytes written: %d. Expecting %d", n, len(s1)) } // The second write s2 := fmt.Sprintf("bar_%d", i) n, err = c1.Write([]byte(s2)) if err != nil { t.Fatalf("unexpected error: %s", err) } if n != len(s2) { t.Fatalf("unexpected number of bytes written: %d. Expecting %d", n, len(s2)) } // Read data written above in two writes s := s1 + s2 n, err = c2.Read(buf[:]) if err != nil { t.Fatalf("unexpected error: %s", err) } if n != len(s) { t.Fatalf("unexpected number of bytes read: %d. Expecting %d", n, len(s)) } if string(buf[:n]) != s { t.Fatalf("unexpected string read: %q. Expecting %q", buf[:n], s) } } } func TestPipeConnsCloseSerial(t *testing.T) { testPipeConnsCloseSerial(t) } func TestPipeConnsCloseConcurrent(t *testing.T) { testConcurrency(t, 10, testPipeConnsCloseSerial) } func testPipeConnsCloseSerial(t *testing.T) { pc := NewPipeConns() testPipeConnsClose(t, pc.Conn1(), pc.Conn2()) pc = NewPipeConns() testPipeConnsClose(t, pc.Conn2(), pc.Conn1()) } func testPipeConnsClose(t *testing.T, c1, c2 net.Conn) { if err := c1.Close(); err != nil { t.Fatalf("unexpected error: %s", err) } var buf [10]byte // attempt writing to closed conn for i := 0; i < 10; i++ { n, err := c1.Write(buf[:]) if err == nil { t.Fatalf("expecting error") } if n != 0 { t.Fatalf("unexpected number of bytes written: %d. Expecting 0", n) } } // attempt reading from closed conn for i := 0; i < 10; i++ { n, err := c2.Read(buf[:]) if err == nil { t.Fatalf("expecting error") } if err != io.EOF { t.Fatalf("unexpected error: %s. Expecting %s", err, io.EOF) } if n != 0 { t.Fatalf("unexpected number of bytes read: %d. 
Expecting 0", n) } } if err := c2.Close(); err != nil { t.Fatalf("unexpected error: %s", err) } // attempt closing already closed conns for i := 0; i < 10; i++ { if err := c1.Close(); err != nil { t.Fatalf("unexpected error: %s", err) } if err := c2.Close(); err != nil { t.Fatalf("unexpected error: %s", err) } } } func testConcurrency(t *testing.T, concurrency int, f func(*testing.T)) { ch := make(chan struct{}, concurrency) for i := 0; i < concurrency; i++ { go func() { f(t) ch <- struct{}{} }() } for i := 0; i < concurrency; i++ { select { case <-ch: case <-time.After(time.Second): t.Fatalf("timeout") } } } golang-github-valyala-fasthttp-20160617/fasthttputil/ssl-cert-snakeoil.key000066400000000000000000000032501273074646000266100ustar00rootroot00000000000000-----BEGIN PRIVATE KEY----- MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQD4IQusAs8PJdnG 3mURt/AXtgC+ceqLOatJ49JJE1VPTkMAy+oE1f1XvkMrYsHqmDf6GWVzgVXryL4U wq2/nJSm56ddhN55nI8oSN3dtywUB8/ShelEN73nlN77PeD9tl6NksPwWaKrqxq0 FlabRPZSQCfmgZbhDV8Sa8mfCkFU0G0lit6kLGceCKMvmW+9Bz7ebsYmVdmVMxmf IJStFD44lWFTdUc65WISKEdW2ELcUefb0zOLw+0PCbXFGJH5x5ktksW8+BBk2Hkg GeQRL/qPCccthbScO0VgNj3zJ3ZZL0ObSDAbvNDG85joeNjDNq5DT/BAZ0bOSbEF sh+f9BAzAgMBAAECggEBAJWv2cq7Jw6MVwSRxYca38xuD6TUNBopgBvjREixURW2 sNUaLuMb9Omp7fuOaE2N5rcJ+xnjPGIxh/oeN5MQctz9gwn3zf6vY+15h97pUb4D uGvYPRDaT8YVGS+X9NMZ4ZCmqW2lpWzKnCFoGHcy8yZLbcaxBsRdvKzwOYGoPiFb K2QuhXZ/1UPmqK9i2DFKtj40X6vBszTNboFxOVpXrPu0FJwLVSDf2hSZ4fMM0DH3 YqwKcYf5te+hxGKgrqRA3tn0NCWii0in6QIwXMC+kMw1ebg/tZKqyDLMNptAK8J+ DVw9m5X1seUHS5ehU/g2jrQrtK5WYn7MrFK4lBzlRwECgYEA/d1TeANYECDWRRDk B0aaRZs87Rwl/J9PsvbsKvtU/bX+OfSOUjOa9iQBqn0LmU8GqusEET/QVUfocVwV Bggf/5qDLxz100Rj0ags/yE/kNr0Bb31kkkKHFMnCT06YasR7qKllwrAlPJvQv9x IzBKq+T/Dx08Wep9bCRSFhzRCnsCgYEA+jdeZXTDr/Vz+D2B3nAw1frqYFfGnEVY wqmoK3VXMDkGuxsloO2rN+SyiUo3JNiQNPDub/t7175GH5pmKtZOlftePANsUjBj wZ1D0rI5Bxu/71ibIUYIRVmXsTEQkh/ozoh3jXCZ9+bLgYiYx7789IUZZSokFQ3D FICUT9KJ36kCgYAGoq9Y1rWJjmIrYfqj2guUQC+CfxbbGIrrwZqAsRsSmpwvhZ3m tiSZxG0quKQB+NfSxdvQW5ulbwC7Xc3K35F+i9pb8+TVBdeaFkw+yu6vaZmxQLrX fQM/pEjD7A7HmMIaO7QaU5SfEAsqdCTP56Y8AftMuNXn/8IRfo2KuGwaWwKBgFpU ILzJoVdlad9E/Rw7LjYhZfkv1uBVXIyxyKcfrkEXZSmozDXDdxsvcZCEfVHM6Ipk K/+7LuMcqp4AFEAEq8wTOdq6daFaHLkpt/FZK6M4TlruhtpFOPkoNc3e45eM83OT 6mziKINJC1CQ6m65sQHpBtjxlKMRG8rL/D6wx9s5AoGBAMRlqNPMwglT3hvDmsAt 9Lf9pdmhERUlHhD8bj8mDaBj2Aqv7f6VRJaYZqP403pKKQexuqcn80mtjkSAPFkN Cj7BVt/RXm5uoxDTnfi26RF9F6yNDEJ7UU9+peBr99aazF/fTgW/1GcMkQnum8uV c257YgaWmjK9uB0Y2r2VxS0G -----END PRIVATE KEY----- golang-github-valyala-fasthttp-20160617/fasthttputil/ssl-cert-snakeoil.pem000066400000000000000000000017551273074646000266110ustar00rootroot00000000000000-----BEGIN CERTIFICATE----- MIICujCCAaKgAwIBAgIJAMbXnKZ/cikUMA0GCSqGSIb3DQEBCwUAMBUxEzARBgNV BAMTCnVidW50dS5uYW4wHhcNMTUwMjA0MDgwMTM5WhcNMjUwMjAxMDgwMTM5WjAV MRMwEQYDVQQDEwp1YnVudHUubmFuMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIB CgKCAQEA+CELrALPDyXZxt5lEbfwF7YAvnHqizmrSePSSRNVT05DAMvqBNX9V75D K2LB6pg3+hllc4FV68i+FMKtv5yUpuenXYTeeZyPKEjd3bcsFAfP0oXpRDe955Te +z3g/bZejZLD8Fmiq6satBZWm0T2UkAn5oGW4Q1fEmvJnwpBVNBtJYrepCxnHgij L5lvvQc+3m7GJlXZlTMZnyCUrRQ+OJVhU3VHOuViEihHVthC3FHn29Mzi8PtDwm1 xRiR+ceZLZLFvPgQZNh5IBnkES/6jwnHLYW0nDtFYDY98yd2WS9Dm0gwG7zQxvOY 6HjYwzauQ0/wQGdGzkmxBbIfn/QQMwIDAQABow0wCzAJBgNVHRMEAjAAMA0GCSqG SIb3DQEBCwUAA4IBAQBQjKm/4KN/iTgXbLTL3i7zaxYXFLXsnT1tF+ay4VA8aj98 L3JwRTciZ3A5iy/W4VSCt3eASwOaPWHKqDBB5RTtL73LoAqsWmO3APOGQAbixcQ2 45GXi05OKeyiYRi1Nvq7Unv9jUkRDHUYVPZVSAjCpsXzPhFkmZoTRxmx5l0ZF7Li K91lI5h+eFq0dwZwrmlPambyh1vQUi70VHv8DNToVU29kel7YLbxGbuqETfhrcy6 
X+Mha6RYITkAn5FqsZcKMsc9eYGEF4l3XV+oS7q6xfTxktYJMFTI18J0lQ2Lv/CI whdMnYGntDQBE/iFCrJEGNsKGc38796GBOb5j+zd -----END CERTIFICATE----- golang-github-valyala-fasthttp-20160617/fs.go000066400000000000000000000776531273074646000207650ustar00rootroot00000000000000package fasthttp import ( "bytes" "errors" "fmt" "html" "io" "io/ioutil" "mime" "net/http" "os" "path/filepath" "sort" "strings" "sync" "time" "github.com/klauspost/compress/gzip" ) // ServeFileBytesUncompressed returns HTTP response containing file contents // from the given path. // // Directory contents is returned if path points to directory. // // ServeFileBytes may be used for saving network traffic when serving files // with good compression ratio. // // See also RequestCtx.SendFileBytes. func ServeFileBytesUncompressed(ctx *RequestCtx, path []byte) { ServeFileUncompressed(ctx, b2s(path)) } // ServeFileUncompressed returns HTTP response containing file contents // from the given path. // // Directory contents is returned if path points to directory. // // ServeFile may be used for saving network traffic when serving files // with good compression ratio. // // See also RequestCtx.SendFile. func ServeFileUncompressed(ctx *RequestCtx, path string) { ctx.Request.Header.DelBytes(strAcceptEncoding) ServeFile(ctx, path) } // ServeFileBytes returns HTTP response containing compressed file contents // from the given path. // // HTTP response may contain uncompressed file contents in the following cases: // // * Missing 'Accept-Encoding: gzip' request header. // * No write access to directory containing the file. // // Directory contents is returned if path points to directory. // // Use ServeFileBytesUncompressed is you don't need serving compressed // file contents. // // See also RequestCtx.SendFileBytes. func ServeFileBytes(ctx *RequestCtx, path []byte) { ServeFile(ctx, b2s(path)) } // ServeFile returns HTTP response containing compressed file contents // from the given path. // // HTTP response may contain uncompressed file contents in the following cases: // // * Missing 'Accept-Encoding: gzip' request header. // * No write access to directory containing the file. // // Directory contents is returned if path points to directory. // // Use ServeFileUncompressed is you don't need serving compressed file contents. // // See also RequestCtx.SendFile. func ServeFile(ctx *RequestCtx, path string) { rootFSOnce.Do(func() { rootFSHandler = rootFS.NewRequestHandler() }) if len(path) == 0 || path[0] != '/' { // extend relative path to absolute path var err error if path, err = filepath.Abs(path); err != nil { ctx.Logger().Printf("cannot resolve path %q to absolute file path: %s", path, err) ctx.Error("Internal Server Error", StatusInternalServerError) return } } ctx.Request.SetRequestURI(path) rootFSHandler(ctx) } var ( rootFSOnce sync.Once rootFS = &FS{ Root: "/", GenerateIndexPages: true, Compress: true, AcceptByteRange: true, } rootFSHandler RequestHandler ) // PathRewriteFunc must return new request path based on arbitrary ctx // info such as ctx.Path(). // // Path rewriter is used in FS for translating the current request // to the local filesystem path relative to FS.Root. // // The returned path must not contain '/../' substrings due to security reasons, // since such paths may refer files outside FS.Root. // // The returned path may refer to ctx members. For example, ctx.Path(). 
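//
// A minimal hedged sketch of a custom rewriter (the Root and prefix values are
// illustrative), built on NewPathPrefixStripper defined below:
//
//	fs := &FS{
//		Root: "/var/www",
//		PathRewrite: func(ctx *RequestCtx) []byte {
//			// serve "/static/foo/bar" from "/var/www/foo/bar"
//			return NewPathPrefixStripper(len("/static"))(ctx)
//		},
//	}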
type PathRewriteFunc func(ctx *RequestCtx) []byte // NewVHostPathRewriter returns path rewriter, which strips slashesCount // leading slashes from the path and prepends the path with request's host, // thus simplifying virtual hosting for static files. // // Examples: // // * host=foobar.com, slashesCount=0, original path="/foo/bar". // Resulting path: "/foobar.com/foo/bar" // // * host=img.aaa.com, slashesCount=1, original path="/images/123/456.jpg" // Resulting path: "/img.aaa.com/123/456.jpg" // func NewVHostPathRewriter(slashesCount int) PathRewriteFunc { return func(ctx *RequestCtx) []byte { path := stripLeadingSlashes(ctx.Path(), slashesCount) host := ctx.Host() if n := bytes.IndexByte(host, '/'); n >= 0 { host = nil } if len(host) == 0 { host = strInvalidHost } b := AcquireByteBuffer() b.B = append(b.B, '/') b.B = append(b.B, host...) b.B = append(b.B, path...) ctx.URI().SetPathBytes(b.B) ReleaseByteBuffer(b) return ctx.Path() } } var strInvalidHost = []byte("invalid-host") // NewPathSlashesStripper returns path rewriter, which strips slashesCount // leading slashes from the path. // // Examples: // // * slashesCount = 0, original path: "/foo/bar", result: "/foo/bar" // * slashesCount = 1, original path: "/foo/bar", result: "/bar" // * slashesCount = 2, original path: "/foo/bar", result: "" // // The returned path rewriter may be used as FS.PathRewrite . func NewPathSlashesStripper(slashesCount int) PathRewriteFunc { return func(ctx *RequestCtx) []byte { return stripLeadingSlashes(ctx.Path(), slashesCount) } } // NewPathPrefixStripper returns path rewriter, which removes prefixSize bytes // from the path prefix. // // Examples: // // * prefixSize = 0, original path: "/foo/bar", result: "/foo/bar" // * prefixSize = 3, original path: "/foo/bar", result: "o/bar" // * prefixSize = 7, original path: "/foo/bar", result: "r" // // The returned path rewriter may be used as FS.PathRewrite . func NewPathPrefixStripper(prefixSize int) PathRewriteFunc { return func(ctx *RequestCtx) []byte { path := ctx.Path() if len(path) >= prefixSize { path = path[prefixSize:] } return path } } // FS represents settings for request handler serving static files // from the local filesystem. // // It is prohibited copying FS values. Create new values instead. type FS struct { noCopy noCopy // Path to the root directory to serve files from. Root string // List of index file names to try opening during directory access. // // For example: // // * index.html // * index.htm // * my-super-index.xml // // By default the list is empty. IndexNames []string // Index pages for directories without files matching IndexNames // are automatically generated if set. // // Directory index generation may be quite slow for directories // with many files (more than 1K), so it is discouraged enabling // index pages' generation for such directories. // // By default index pages aren't generated. GenerateIndexPages bool // Transparently compresses responses if set to true. // // The server tries minimizing CPU usage by caching compressed files. // It adds CompressedFileSuffix suffix to the original file name and // tries saving the resulting compressed file under the new file name. // So it is advisable to give the server write access to Root // and to all inner folders in order to minimze CPU usage when serving // compressed responses. // // Transparent compression is disabled by default. Compress bool // Enables byte range requests if set to true. // // Byte range requests are disabled by default. 
AcceptByteRange bool // Path rewriting function. // // By default request path is not modified. PathRewrite PathRewriteFunc // Expiration duration for inactive file handlers. // // FSHandlerCacheDuration is used by default. CacheDuration time.Duration // Suffix to add to the name of cached compressed file. // // This value has sense only if Compress is set. // // FSCompressedFileSuffix is used by default. CompressedFileSuffix string once sync.Once h RequestHandler } // FSCompressedFileSuffix is the suffix FS adds to the original file names // when trying to store compressed file under the new file name. // See FS.Compress for details. const FSCompressedFileSuffix = ".fasthttp.gz" // FSHandlerCacheDuration is the default expiration duration for inactive // file handlers opened by FS. const FSHandlerCacheDuration = 10 * time.Second // FSHandler returns request handler serving static files from // the given root folder. // // stripSlashes indicates how many leading slashes must be stripped // from requested path before searching requested file in the root folder. // Examples: // // * stripSlashes = 0, original path: "/foo/bar", result: "/foo/bar" // * stripSlashes = 1, original path: "/foo/bar", result: "/bar" // * stripSlashes = 2, original path: "/foo/bar", result: "" // // The returned request handler automatically generates index pages // for directories without index.html. // // The returned handler caches requested file handles // for FSHandlerCacheDuration. // Make sure your program has enough 'max open files' limit aka // 'ulimit -n' if root folder contains many files. // // Do not create multiple request handler instances for the same // (root, stripSlashes) arguments - just reuse a single instance. // Otherwise goroutine leak will occur. func FSHandler(root string, stripSlashes int) RequestHandler { fs := &FS{ Root: root, IndexNames: []string{"index.html"}, GenerateIndexPages: true, AcceptByteRange: true, } if stripSlashes > 0 { fs.PathRewrite = NewPathSlashesStripper(stripSlashes) } return fs.NewRequestHandler() } // NewRequestHandler returns new request handler with the given FS settings. // // The returned handler caches requested file handles // for FS.CacheDuration. // Make sure your program has enough 'max open files' limit aka // 'ulimit -n' if FS.Root folder contains many files. // // Do not create multiple request handlers from a single FS instance - // just reuse a single request handler. func (fs *FS) NewRequestHandler() RequestHandler { fs.once.Do(fs.initRequestHandler) return fs.h } func (fs *FS) initRequestHandler() { root := fs.Root // serve files from the current working directory if root is empty if len(root) == 0 { root = "." 
} // strip trailing slashes from the root path for len(root) > 0 && root[len(root)-1] == '/' { root = root[:len(root)-1] } cacheDuration := fs.CacheDuration if cacheDuration <= 0 { cacheDuration = FSHandlerCacheDuration } compressedFileSuffix := fs.CompressedFileSuffix if len(compressedFileSuffix) == 0 { compressedFileSuffix = FSCompressedFileSuffix } h := &fsHandler{ root: root, indexNames: fs.IndexNames, pathRewrite: fs.PathRewrite, generateIndexPages: fs.GenerateIndexPages, compress: fs.Compress, acceptByteRange: fs.AcceptByteRange, cacheDuration: cacheDuration, compressedFileSuffix: compressedFileSuffix, cache: make(map[string]*fsFile), compressedCache: make(map[string]*fsFile), } go func() { var pendingFiles []*fsFile for { time.Sleep(cacheDuration / 2) pendingFiles = h.cleanCache(pendingFiles) } }() fs.h = h.handleRequest } type fsHandler struct { root string indexNames []string pathRewrite PathRewriteFunc generateIndexPages bool compress bool acceptByteRange bool cacheDuration time.Duration compressedFileSuffix string cache map[string]*fsFile compressedCache map[string]*fsFile cacheLock sync.Mutex smallFileReaderPool sync.Pool } type fsFile struct { h *fsHandler f *os.File dirIndex []byte contentType string contentLength int compressed bool lastModified time.Time lastModifiedStr []byte t time.Time readersCount int bigFiles []*bigFileReader bigFilesLock sync.Mutex } func (ff *fsFile) NewReader() (io.Reader, error) { if ff.isBig() { r, err := ff.bigFileReader() if err != nil { ff.decReadersCount() } return r, err } return ff.smallFileReader(), nil } func (ff *fsFile) smallFileReader() io.Reader { v := ff.h.smallFileReaderPool.Get() if v == nil { v = &fsSmallFileReader{} } r := v.(*fsSmallFileReader) r.ff = ff r.endPos = ff.contentLength if r.startPos > 0 { panic("BUG: fsSmallFileReader with non-nil startPos found in the pool") } return r } // files bigger than this size are sent with sendfile const maxSmallFileSize = 2 * 4096 func (ff *fsFile) isBig() bool { return ff.contentLength > maxSmallFileSize && len(ff.dirIndex) == 0 } func (ff *fsFile) bigFileReader() (io.Reader, error) { if ff.f == nil { panic("BUG: ff.f must be non-nil in bigFileReader") } var r io.Reader ff.bigFilesLock.Lock() n := len(ff.bigFiles) if n > 0 { r = ff.bigFiles[n-1] ff.bigFiles = ff.bigFiles[:n-1] } ff.bigFilesLock.Unlock() if r != nil { return r, nil } f, err := os.Open(ff.f.Name()) if err != nil { return nil, fmt.Errorf("cannot open already opened file: %s", err) } return &bigFileReader{ f: f, ff: ff, r: f, }, nil } func (ff *fsFile) Release() { if ff.f != nil { ff.f.Close() if ff.isBig() { ff.bigFilesLock.Lock() for _, r := range ff.bigFiles { r.f.Close() } ff.bigFilesLock.Unlock() } } } func (ff *fsFile) decReadersCount() { ff.h.cacheLock.Lock() ff.readersCount-- if ff.readersCount < 0 { panic("BUG: negative fsFile.readersCount!") } ff.h.cacheLock.Unlock() } // bigFileReader attempts to trigger sendfile // for sending big files over the wire. type bigFileReader struct { f *os.File ff *fsFile r io.Reader lr io.LimitedReader } func (r *bigFileReader) UpdateByteRange(startPos, endPos int) error { if _, err := r.f.Seek(int64(startPos), 0); err != nil { return err } r.r = &r.lr r.lr.R = r.f r.lr.N = int64(endPos - startPos + 1) return nil } func (r *bigFileReader) Read(p []byte) (int, error) { return r.r.Read(p) } func (r *bigFileReader) WriteTo(w io.Writer) (int64, error) { if rf, ok := w.(io.ReaderFrom); ok { // fast path. 
Senfile must be triggered return rf.ReadFrom(r.r) } // slow path return copyZeroAlloc(w, r.r) } func (r *bigFileReader) Close() error { r.r = r.f n, err := r.f.Seek(0, 0) if err == nil { if n != 0 { panic("BUG: File.Seek(0,0) returned (non-zero, nil)") } ff := r.ff ff.bigFilesLock.Lock() ff.bigFiles = append(ff.bigFiles, r) ff.bigFilesLock.Unlock() } else { r.f.Close() } r.ff.decReadersCount() return err } type fsSmallFileReader struct { ff *fsFile startPos int endPos int } func (r *fsSmallFileReader) Close() error { ff := r.ff ff.decReadersCount() r.ff = nil r.startPos = 0 r.endPos = 0 ff.h.smallFileReaderPool.Put(r) return nil } func (r *fsSmallFileReader) UpdateByteRange(startPos, endPos int) error { r.startPos = startPos r.endPos = endPos + 1 return nil } func (r *fsSmallFileReader) Read(p []byte) (int, error) { tailLen := r.endPos - r.startPos if tailLen <= 0 { return 0, io.EOF } if len(p) > tailLen { p = p[:tailLen] } ff := r.ff if ff.f != nil { n, err := ff.f.ReadAt(p, int64(r.startPos)) r.startPos += n return n, err } n := copy(p, ff.dirIndex[r.startPos:]) r.startPos += n return n, nil } func (r *fsSmallFileReader) WriteTo(w io.Writer) (int64, error) { ff := r.ff var n int var err error if ff.f == nil { n, err = w.Write(ff.dirIndex[r.startPos:r.endPos]) return int64(n), err } if rf, ok := w.(io.ReaderFrom); ok { return rf.ReadFrom(r) } curPos := r.startPos bufv := copyBufPool.Get() buf := bufv.([]byte) for err != nil { tailLen := r.endPos - curPos if tailLen <= 0 { break } if len(buf) > tailLen { buf = buf[:tailLen] } n, err = ff.f.ReadAt(buf, int64(curPos)) nw, errw := w.Write(buf[:n]) curPos += nw if errw == nil && nw != n { panic("BUG: Write(p) returned (n, nil), where n != len(p)") } if err == nil { err = errw } } copyBufPool.Put(bufv) if err == io.EOF { err = nil } return int64(curPos - r.startPos), err } func (h *fsHandler) cleanCache(pendingFiles []*fsFile) []*fsFile { var filesToRelease []*fsFile h.cacheLock.Lock() // Close files which couldn't be closed before due to non-zero // readers count on the previous run. var remainingFiles []*fsFile for _, ff := range pendingFiles { if ff.readersCount > 0 { remainingFiles = append(remainingFiles, ff) } else { filesToRelease = append(filesToRelease, ff) } } pendingFiles = remainingFiles pendingFiles, filesToRelease = cleanCacheNolock(h.cache, pendingFiles, filesToRelease, h.cacheDuration) pendingFiles, filesToRelease = cleanCacheNolock(h.compressedCache, pendingFiles, filesToRelease, h.cacheDuration) h.cacheLock.Unlock() for _, ff := range filesToRelease { ff.Release() } return pendingFiles } func cleanCacheNolock(cache map[string]*fsFile, pendingFiles, filesToRelease []*fsFile, cacheDuration time.Duration) ([]*fsFile, []*fsFile) { t := time.Now() for k, ff := range cache { if t.Sub(ff.t) > cacheDuration { if ff.readersCount > 0 { // There are pending readers on stale file handle, // so we cannot close it. Put it into pendingFiles // so it will be closed later. 
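// readersCount is decremented when the corresponding reader is closed
// (see fsFile.decReadersCount), at which point the handle becomes
// eligible for release on a later cleanup pass.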
pendingFiles = append(pendingFiles, ff) } else { filesToRelease = append(filesToRelease, ff) } delete(cache, k) } } return pendingFiles, filesToRelease } func (h *fsHandler) handleRequest(ctx *RequestCtx) { var path []byte if h.pathRewrite != nil { path = h.pathRewrite(ctx) } else { path = ctx.Path() } path = stripTrailingSlashes(path) if n := bytes.IndexByte(path, 0); n >= 0 { ctx.Logger().Printf("cannot serve path with nil byte at position %d: %q", n, path) ctx.Error("Are you a hacker?", StatusBadRequest) return } if h.pathRewrite != nil { // There is no need to check for '/../' if path = ctx.Path(), // since ctx.Path must normalize and sanitize the path. if n := bytes.Index(path, strSlashDotDotSlash); n >= 0 { ctx.Logger().Printf("cannot serve path with '/../' at position %d due to security reasons: %q", n, path) ctx.Error("Internal Server Error", StatusInternalServerError) return } } mustCompress := false fileCache := h.cache byteRange := ctx.Request.Header.peek(strRange) if len(byteRange) == 0 && h.compress && ctx.Request.Header.HasAcceptEncodingBytes(strGzip) { mustCompress = true fileCache = h.compressedCache } h.cacheLock.Lock() ff, ok := fileCache[string(path)] if ok { ff.readersCount++ } h.cacheLock.Unlock() if !ok { pathStr := string(path) filePath := h.root + pathStr var err error ff, err = h.openFSFile(filePath, mustCompress) if mustCompress && err == errNoCreatePermission { ctx.Logger().Printf("insufficient permissions for saving compressed file for %q. Serving uncompressed file. "+ "Allow write access to the directory with this file in order to improve fasthttp performance", filePath) mustCompress = false ff, err = h.openFSFile(filePath, mustCompress) } if err == errDirIndexRequired { ff, err = h.openIndexFile(ctx, filePath, mustCompress) if err != nil { ctx.Logger().Printf("cannot open dir index %q: %s", filePath, err) ctx.Error("Directory index is forbidden", StatusForbidden) return } } else if err != nil { ctx.Logger().Printf("cannot open file %q: %s", filePath, err) ctx.Error("Cannot open requested path", StatusNotFound) return } h.cacheLock.Lock() ff1, ok := fileCache[pathStr] if !ok { fileCache[pathStr] = ff ff.readersCount++ } else { ff1.readersCount++ } h.cacheLock.Unlock() if ok { // The file has been already opened by another // goroutine, so close the current file and use // the file opened by another goroutine instead. 
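// The freshly opened ff was never registered in the cache, so no readers
// reference it and it can be released right away.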
ff.Release() ff = ff1 } } if !ctx.IfModifiedSince(ff.lastModified) { ff.decReadersCount() ctx.NotModified() return } r, err := ff.NewReader() if err != nil { ctx.Logger().Printf("cannot obtain file reader for path=%q: %s", path, err) ctx.Error("Internal Server Error", StatusInternalServerError) return } hdr := &ctx.Response.Header if ff.compressed { hdr.SetCanonical(strContentEncoding, strGzip) } statusCode := StatusOK contentLength := ff.contentLength if h.acceptByteRange { hdr.SetCanonical(strAcceptRanges, strBytes) if len(byteRange) > 0 { startPos, endPos, err := ParseByteRange(byteRange, contentLength) if err != nil { r.(io.Closer).Close() ctx.Logger().Printf("cannot parse byte range %q for path=%q: %s", byteRange, path, err) ctx.Error("Range Not Satisfiable", StatusRequestedRangeNotSatisfiable) return } if err = r.(byteRangeUpdater).UpdateByteRange(startPos, endPos); err != nil { r.(io.Closer).Close() ctx.Logger().Printf("cannot seek byte range %q for path=%q: %s", byteRange, path, err) ctx.Error("Internal Server Error", StatusInternalServerError) return } hdr.SetContentRange(startPos, endPos, contentLength) contentLength = endPos - startPos + 1 statusCode = StatusPartialContent } } hdr.SetCanonical(strLastModified, ff.lastModifiedStr) if !ctx.IsHead() { ctx.SetBodyStream(r, contentLength) } else { ctx.Response.ResetBody() ctx.Response.SkipBody = true ctx.Response.Header.SetContentLength(contentLength) if rc, ok := r.(io.Closer); ok { if err := rc.Close(); err != nil { ctx.Logger().Printf("cannot close file reader: %s", err) ctx.Error("Internal Server Error", StatusInternalServerError) return } } } ctx.SetContentType(ff.contentType) ctx.SetStatusCode(statusCode) } type byteRangeUpdater interface { UpdateByteRange(startPos, endPos int) error } // ParseByteRange parses 'Range: bytes=...' header value. // // It follows https://www.w3.org/Protocols/rfc2616/rfc2616-sec14.html#sec14.35 . func ParseByteRange(byteRange []byte, contentLength int) (startPos, endPos int, err error) { b := byteRange if !bytes.HasPrefix(b, strBytes) { return 0, 0, fmt.Errorf("unsupported range units: %q. Expecting %q", byteRange, strBytes) } b = b[len(strBytes):] if len(b) == 0 || b[0] != '=' { return 0, 0, fmt.Errorf("missing byte range in %q", byteRange) } b = b[1:] n := bytes.IndexByte(b, '-') if n < 0 { return 0, 0, fmt.Errorf("missing the end position of byte range in %q", byteRange) } if n == 0 { v, err := ParseUint(b[n+1:]) if err != nil { return 0, 0, err } startPos := contentLength - v if startPos < 0 { startPos = 0 } return startPos, contentLength - 1, nil } if startPos, err = ParseUint(b[:n]); err != nil { return 0, 0, err } if startPos >= contentLength { return 0, 0, fmt.Errorf("the start position of byte range cannot exceed %d. byte range %q", contentLength-1, byteRange) } b = b[n+1:] if len(b) == 0 { return startPos, contentLength - 1, nil } if endPos, err = ParseUint(b); err != nil { return 0, 0, err } if endPos >= contentLength { endPos = contentLength - 1 } if endPos < startPos { return 0, 0, fmt.Errorf("the start position of byte range cannot exceed the end position. 
byte range %q", byteRange) } return startPos, endPos, nil } func (h *fsHandler) openIndexFile(ctx *RequestCtx, dirPath string, mustCompress bool) (*fsFile, error) { for _, indexName := range h.indexNames { indexFilePath := dirPath + "/" + indexName ff, err := h.openFSFile(indexFilePath, mustCompress) if err == nil { return ff, nil } if !os.IsNotExist(err) { return nil, fmt.Errorf("cannot open file %q: %s", indexFilePath, err) } } if !h.generateIndexPages { return nil, fmt.Errorf("cannot access directory without index page. Directory %q", dirPath) } return h.createDirIndex(ctx.URI(), dirPath, mustCompress) } var ( errDirIndexRequired = errors.New("directory index required") errNoCreatePermission = errors.New("no 'create file' permissions") ) func (h *fsHandler) createDirIndex(base *URI, dirPath string, mustCompress bool) (*fsFile, error) { w := &ByteBuffer{} basePathEscaped := html.EscapeString(string(base.Path())) fmt.Fprintf(w, "%s", basePathEscaped) fmt.Fprintf(w, "
<h1>%s</h1>
", basePathEscaped) fmt.Fprintf(w, "") if mustCompress { var zbuf ByteBuffer zw := acquireGzipWriter(&zbuf, CompressDefaultCompression) _, err = zw.Write(w.B) releaseGzipWriter(zw) if err != nil { return nil, fmt.Errorf("error when compressing automatically generated index for directory %q: %s", dirPath, err) } w = &zbuf } dirIndex := w.B lastModified := time.Now() ff := &fsFile{ h: h, dirIndex: dirIndex, contentType: "text/html; charset=utf-8", contentLength: len(dirIndex), compressed: mustCompress, lastModified: lastModified, lastModifiedStr: AppendHTTPDate(nil, lastModified), t: lastModified, } return ff, nil } const ( fsMinCompressRatio = 0.8 fsMaxCompressibleFileSize = 8 * 1024 * 1024 ) func (h *fsHandler) compressAndOpenFSFile(filePath string) (*fsFile, error) { f, err := os.Open(filePath) if err != nil { return nil, err } fileInfo, err := f.Stat() if err != nil { f.Close() return nil, fmt.Errorf("cannot obtain info for file %q: %s", filePath, err) } if fileInfo.IsDir() { f.Close() return nil, errDirIndexRequired } if strings.HasSuffix(filePath, h.compressedFileSuffix) || fileInfo.Size() > fsMaxCompressibleFileSize || !isFileCompressible(f, fsMinCompressRatio) { return h.newFSFile(f, fileInfo, false) } compressedFilePath := filePath + h.compressedFileSuffix absPath, err := filepath.Abs(compressedFilePath) if err != nil { f.Close() return nil, fmt.Errorf("cannot determine absolute path for %q: %s", compressedFilePath, err) } flock := getFileLock(absPath) flock.Lock() ff, err := h.compressFileNolock(f, fileInfo, filePath, compressedFilePath) flock.Unlock() return ff, err } func (h *fsHandler) compressFileNolock(f *os.File, fileInfo os.FileInfo, filePath, compressedFilePath string) (*fsFile, error) { // Attempt to open compressed file created by another concurrent // goroutine. // It is safe opening such a file, since the file creation // is guarded by file mutex - see getFileLock call. if _, err := os.Stat(compressedFilePath); err == nil { f.Close() return h.newCompressedFSFile(compressedFilePath) } // Create temporary file, so concurrent goroutines don't use // it until it is created. 
tmpFilePath := compressedFilePath + ".tmp" zf, err := os.Create(tmpFilePath) if err != nil { f.Close() if !os.IsPermission(err) { return nil, fmt.Errorf("cannot create temporary file %q: %s", tmpFilePath, err) } return nil, errNoCreatePermission } zw := acquireGzipWriter(zf, CompressDefaultCompression) _, err = copyZeroAlloc(zw, f) if err1 := zw.Flush(); err == nil { err = err1 } releaseGzipWriter(zw) zf.Close() f.Close() if err != nil { return nil, fmt.Errorf("error when compressing file %q to %q: %s", filePath, tmpFilePath, err) } if err = os.Chtimes(tmpFilePath, time.Now(), fileInfo.ModTime()); err != nil { return nil, fmt.Errorf("cannot change modification time to %s for tmp file %q: %s", fileInfo.ModTime(), tmpFilePath, err) } if err = os.Rename(tmpFilePath, compressedFilePath); err != nil { return nil, fmt.Errorf("cannot move compressed file from %q to %q: %s", tmpFilePath, compressedFilePath, err) } return h.newCompressedFSFile(compressedFilePath) } func (h *fsHandler) newCompressedFSFile(filePath string) (*fsFile, error) { f, err := os.Open(filePath) if err != nil { return nil, fmt.Errorf("cannot open compressed file %q: %s", filePath, err) } fileInfo, err := f.Stat() if err != nil { f.Close() return nil, fmt.Errorf("cannot obtain info for compressed file %q: %s", filePath, err) } return h.newFSFile(f, fileInfo, true) } func (h *fsHandler) openFSFile(filePath string, mustCompress bool) (*fsFile, error) { filePathOriginal := filePath if mustCompress { filePath += h.compressedFileSuffix } f, err := os.Open(filePath) if err != nil { if mustCompress && os.IsNotExist(err) { return h.compressAndOpenFSFile(filePathOriginal) } return nil, err } fileInfo, err := f.Stat() if err != nil { f.Close() return nil, fmt.Errorf("cannot obtain info for file %q: %s", filePath, err) } if fileInfo.IsDir() { f.Close() if mustCompress { return nil, fmt.Errorf("directory with unexpected suffix found: %q. Suffix: %q", filePath, h.compressedFileSuffix) } return nil, errDirIndexRequired } if mustCompress { fileInfoOriginal, err := os.Stat(filePathOriginal) if err != nil { f.Close() return nil, fmt.Errorf("cannot obtain info for original file %q: %s", filePathOriginal, err) } if fileInfoOriginal.ModTime() != fileInfo.ModTime() { // The compressed file became stale. Re-create it. 
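// Staleness is detected by comparing the ModTime of the original file with
// that of the compressed copy; compressFileNolock keeps them in sync via
// os.Chtimes after writing the compressed file.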
f.Close() os.Remove(filePath) return h.compressAndOpenFSFile(filePathOriginal) } } return h.newFSFile(f, fileInfo, mustCompress) } func (h *fsHandler) newFSFile(f *os.File, fileInfo os.FileInfo, compressed bool) (*fsFile, error) { n := fileInfo.Size() contentLength := int(n) if n != int64(contentLength) { f.Close() return nil, fmt.Errorf("too big file: %d bytes", n) } // detect content-type ext := fileExtension(fileInfo.Name(), compressed, h.compressedFileSuffix) contentType := mime.TypeByExtension(ext) if len(contentType) == 0 { data, err := readFileHeader(f, compressed) if err != nil { return nil, fmt.Errorf("cannot read header of the file %q: %s", f.Name(), err) } contentType = http.DetectContentType(data) } lastModified := fileInfo.ModTime() ff := &fsFile{ h: h, f: f, contentType: contentType, contentLength: contentLength, compressed: compressed, lastModified: lastModified, lastModifiedStr: AppendHTTPDate(nil, lastModified), t: time.Now(), } return ff, nil } func readFileHeader(f *os.File, compressed bool) ([]byte, error) { r := io.Reader(f) var zr *gzip.Reader if compressed { var err error if zr, err = acquireGzipReader(f); err != nil { return nil, err } r = zr } lr := &io.LimitedReader{ R: r, N: 512, } data, err := ioutil.ReadAll(lr) f.Seek(0, 0) if zr != nil { releaseGzipReader(zr) } return data, err } func stripLeadingSlashes(path []byte, stripSlashes int) []byte { for stripSlashes > 0 && len(path) > 0 { if path[0] != '/' { panic("BUG: path must start with slash") } n := bytes.IndexByte(path[1:], '/') if n < 0 { path = path[:0] break } path = path[n+1:] stripSlashes-- } return path } func stripTrailingSlashes(path []byte) []byte { for len(path) > 0 && path[len(path)-1] == '/' { path = path[:len(path)-1] } return path } func fileExtension(path string, compressed bool, compressedFileSuffix string) string { if compressed && strings.HasSuffix(path, compressedFileSuffix) { path = path[:len(path)-len(compressedFileSuffix)] } n := strings.LastIndexByte(path, '.') if n < 0 { return "" } return path[n:] } // FileLastModified returns last modified time for the file. func FileLastModified(path string) (time.Time, error) { f, err := os.Open(path) if err != nil { return zeroTime, err } fileInfo, err := f.Stat() f.Close() if err != nil { return zeroTime, err } return fsModTime(fileInfo.ModTime()), nil } func fsModTime(t time.Time) time.Time { return t.In(time.UTC).Truncate(time.Second) } var ( filesLockMap = make(map[string]*sync.Mutex) filesLockMapLock sync.Mutex ) func getFileLock(absPath string) *sync.Mutex { filesLockMapLock.Lock() flock := filesLockMap[absPath] if flock == nil { flock = &sync.Mutex{} filesLockMap[absPath] = flock } filesLockMapLock.Unlock() return flock } golang-github-valyala-fasthttp-20160617/fs_example_test.go000066400000000000000000000011011273074646000235060ustar00rootroot00000000000000package fasthttp_test import ( "log" "github.com/valyala/fasthttp" ) func ExampleFS() { fs := &fasthttp.FS{ // Path to directory to serve. Root: "/var/www/static-site", // Generate index pages if client requests directory contents. GenerateIndexPages: true, // Enable transparent compression to save network traffic. Compress: true, } // Create request handler for serving static files. h := fs.NewRequestHandler() // Start the server. 
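// Note: ListenAndServe blocks while the server is running, so start it
// last (or in a separate goroutine if further setup is needed).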
if err := fasthttp.ListenAndServe(":8080", h); err != nil { log.Fatalf("error in ListenAndServe: %s", err) } } golang-github-valyala-fasthttp-20160617/fs_handler_example_test.go000066400000000000000000000022211273074646000252070ustar00rootroot00000000000000package fasthttp_test import ( "bytes" "log" "github.com/valyala/fasthttp" ) // Setup file handlers (aka 'file server config') var ( // Handler for serving images from /img/ path, // i.e. /img/foo/bar.jpg will be served from // /var/www/images/foo/bar.jpb . imgPrefix = []byte("/img/") imgHandler = fasthttp.FSHandler("/var/www/images", 1) // Handler for serving css from /static/css/ path, // i.e. /static/css/foo/bar.css will be served from // /home/dev/css/foo/bar.css . cssPrefix = []byte("/static/css/") cssHandler = fasthttp.FSHandler("/home/dev/css", 2) // Handler for serving the rest of requests, // i.e. /foo/bar/baz.html will be served from // /var/www/files/foo/bar/baz.html . filesHandler = fasthttp.FSHandler("/var/www/files", 0) ) // Main request handler func requestHandler(ctx *fasthttp.RequestCtx) { path := ctx.Path() switch { case bytes.HasPrefix(path, imgPrefix): imgHandler(ctx) case bytes.HasPrefix(path, cssPrefix): cssHandler(ctx) default: filesHandler(ctx) } } func ExampleFSHandler() { if err := fasthttp.ListenAndServe(":80", requestHandler); err != nil { log.Fatalf("Error in server: %s", err) } } golang-github-valyala-fasthttp-20160617/fs_test.go000066400000000000000000000366641273074646000220210ustar00rootroot00000000000000package fasthttp import ( "bufio" "bytes" "fmt" "io/ioutil" "math/rand" "os" "sort" "testing" "time" ) func TestNewVHostPathRewriter(t *testing.T) { var ctx RequestCtx var req Request req.Header.SetHost("foobar.com") req.SetRequestURI("/foo/bar/baz") ctx.Init(&req, nil, nil) f := NewVHostPathRewriter(0) path := f(&ctx) expectedPath := "/foobar.com/foo/bar/baz" if string(path) != expectedPath { t.Fatalf("unexpected path %q. Expecting %q", path, expectedPath) } ctx.Request.Reset() ctx.Request.SetRequestURI("https://aaa.bbb.cc/one/two/three/four?asdf=dsf") f = NewVHostPathRewriter(2) path = f(&ctx) expectedPath = "/aaa.bbb.cc/three/four" if string(path) != expectedPath { t.Fatalf("unexpected path %q. Expecting %q", path, expectedPath) } } func TestNewVHostPathRewriterMaliciousHost(t *testing.T) { var ctx RequestCtx var req Request req.Header.SetHost("/../../../etc/passwd") req.SetRequestURI("/foo/bar/baz") ctx.Init(&req, nil, nil) f := NewVHostPathRewriter(0) path := f(&ctx) expectedPath := "/invalid-host/foo/bar/baz" if string(path) != expectedPath { t.Fatalf("unexpected path %q. Expecting %q", path, expectedPath) } } func TestServeFileHead(t *testing.T) { var ctx RequestCtx var req Request req.Header.SetMethod("HEAD") req.SetRequestURI("http://foobar.com/baz") ctx.Init(&req, nil, nil) ServeFile(&ctx, "fs.go") var resp Response resp.SkipBody = true s := ctx.Response.String() br := bufio.NewReader(bytes.NewBufferString(s)) if err := resp.Read(br); err != nil { t.Fatalf("unexpected error: %s", err) } ce := resp.Header.Peek("Content-Encoding") if len(ce) > 0 { t.Fatalf("Unexpected 'Content-Encoding' %q", ce) } body := resp.Body() if len(body) > 0 { t.Fatalf("unexpected response body %q. Expecting empty body", body) } expectedBody, err := getFileContents("/fs.go") if err != nil { t.Fatalf("unexpected error: %s", err) } contentLength := resp.Header.ContentLength() if contentLength != len(expectedBody) { t.Fatalf("unexpected Content-Length: %d. 
expecting %d", contentLength, len(expectedBody)) } } func TestServeFileCompressed(t *testing.T) { var ctx RequestCtx var req Request req.SetRequestURI("http://foobar.com/baz") req.Header.Set("Accept-Encoding", "gzip") ctx.Init(&req, nil, nil) ServeFile(&ctx, "fs.go") var resp Response s := ctx.Response.String() br := bufio.NewReader(bytes.NewBufferString(s)) if err := resp.Read(br); err != nil { t.Fatalf("unexpected error: %s", err) } ce := resp.Header.Peek("Content-Encoding") if string(ce) != "gzip" { t.Fatalf("Unexpected 'Content-Encoding' %q. Expecting %q", ce, "gzip") } body, err := resp.BodyGunzip() if err != nil { t.Fatalf("unexpected error: %s", err) } expectedBody, err := getFileContents("/fs.go") if err != nil { t.Fatalf("unexpected error: %s", err) } if !bytes.Equal(body, expectedBody) { t.Fatalf("unexpected body %q. expecting %q", body, expectedBody) } } func TestServeFileUncompressed(t *testing.T) { var ctx RequestCtx var req Request req.SetRequestURI("http://foobar.com/baz") req.Header.Set("Accept-Encoding", "gzip") ctx.Init(&req, nil, nil) ServeFileUncompressed(&ctx, "fs.go") var resp Response s := ctx.Response.String() br := bufio.NewReader(bytes.NewBufferString(s)) if err := resp.Read(br); err != nil { t.Fatalf("unexpected error: %s", err) } ce := resp.Header.Peek("Content-Encoding") if len(ce) > 0 { t.Fatalf("Unexpected 'Content-Encoding' %q", ce) } body := resp.Body() expectedBody, err := getFileContents("/fs.go") if err != nil { t.Fatalf("unexpected error: %s", err) } if !bytes.Equal(body, expectedBody) { t.Fatalf("unexpected body %q. expecting %q", body, expectedBody) } } func TestFSByteRangeConcurrent(t *testing.T) { fs := &FS{ Root: ".", AcceptByteRange: true, } h := fs.NewRequestHandler() concurrency := 10 ch := make(chan struct{}, concurrency) for i := 0; i < concurrency; i++ { go func() { for j := 0; j < 5; j++ { testFSByteRange(t, h, "/fs.go") testFSByteRange(t, h, "/README.md") } ch <- struct{}{} }() } for i := 0; i < concurrency; i++ { select { case <-time.After(time.Second): t.Fatalf("timeout") case <-ch: } } } func TestFSByteRangeSingleThread(t *testing.T) { fs := &FS{ Root: ".", AcceptByteRange: true, } h := fs.NewRequestHandler() testFSByteRange(t, h, "/fs.go") testFSByteRange(t, h, "/README.md") } func testFSByteRange(t *testing.T, h RequestHandler, filePath string) { var ctx RequestCtx ctx.Init(&Request{}, nil, nil) expectedBody, err := getFileContents(filePath) if err != nil { t.Fatalf("cannot read file %q: %s", filePath, err) } fileSize := len(expectedBody) startPos := rand.Intn(fileSize) endPos := rand.Intn(fileSize) if endPos < startPos { startPos, endPos = endPos, startPos } ctx.Request.SetRequestURI(filePath) ctx.Request.Header.SetByteRange(startPos, endPos) h(&ctx) var resp Response s := ctx.Response.String() br := bufio.NewReader(bytes.NewBufferString(s)) if err := resp.Read(br); err != nil { t.Fatalf("unexpected error: %s. filePath=%q", err, filePath) } if resp.StatusCode() != StatusPartialContent { t.Fatalf("unexpected status code: %d. Expecting %d. filePath=%q", resp.StatusCode(), StatusPartialContent, filePath) } cr := resp.Header.Peek("Content-Range") expectedCR := fmt.Sprintf("bytes %d-%d/%d", startPos, endPos, fileSize) if string(cr) != expectedCR { t.Fatalf("unexpected content-range %q. Expecting %q. filePath=%q", cr, expectedCR, filePath) } body := resp.Body() bodySize := endPos - startPos + 1 if len(body) != bodySize { t.Fatalf("unexpected body size %d. Expecting %d. 
filePath=%q, startPos=%d, endPos=%d", len(body), bodySize, filePath, startPos, endPos) } expectedBody = expectedBody[startPos : endPos+1] if !bytes.Equal(body, expectedBody) { t.Fatalf("unexpected body %q. Expecting %q. filePath=%q, startPos=%d, endPos=%d", body, expectedBody, filePath, startPos, endPos) } } func getFileContents(path string) ([]byte, error) { path = "." + path f, err := os.Open(path) if err != nil { return nil, err } defer f.Close() return ioutil.ReadAll(f) } func TestParseByteRangeSuccess(t *testing.T) { testParseByteRangeSuccess(t, "bytes=0-0", 1, 0, 0) testParseByteRangeSuccess(t, "bytes=1234-6789", 6790, 1234, 6789) testParseByteRangeSuccess(t, "bytes=123-", 456, 123, 455) testParseByteRangeSuccess(t, "bytes=-1", 1, 0, 0) testParseByteRangeSuccess(t, "bytes=-123", 456, 333, 455) // End position exceeding content-length. It should be updated to content-length-1. // See https://www.w3.org/Protocols/rfc2616/rfc2616-sec14.html#sec14.35 testParseByteRangeSuccess(t, "bytes=1-2345", 234, 1, 233) testParseByteRangeSuccess(t, "bytes=0-2345", 2345, 0, 2344) // Start position overflow. Whole range must be returned. // See https://www.w3.org/Protocols/rfc2616/rfc2616-sec14.html#sec14.35 testParseByteRangeSuccess(t, "bytes=-567", 56, 0, 55) } func testParseByteRangeSuccess(t *testing.T, v string, contentLength, startPos, endPos int) { startPos1, endPos1, err := ParseByteRange([]byte(v), contentLength) if err != nil { t.Fatalf("unexpected error: %s. v=%q, contentLength=%d", err, v, contentLength) } if startPos1 != startPos { t.Fatalf("unexpected startPos=%d. Expecting %d. v=%q, contentLength=%d", startPos1, startPos, v, contentLength) } if endPos1 != endPos { t.Fatalf("unexpected endPos=%d. Expectind %d. v=%q, contentLenght=%d", endPos1, endPos, v, contentLength) } } func TestParseByteRangeError(t *testing.T) { // invalid value testParseByteRangeError(t, "asdfasdfas", 1234) // invalid units testParseByteRangeError(t, "foobar=1-34", 600) // missing '-' testParseByteRangeError(t, "bytes=1234", 1235) // non-numeric range testParseByteRangeError(t, "bytes=foobar", 123) testParseByteRangeError(t, "bytes=1-foobar", 123) testParseByteRangeError(t, "bytes=df-344", 545) // multiple byte ranges testParseByteRangeError(t, "bytes=1-2,4-6", 123) // byte range exceeding contentLength testParseByteRangeError(t, "bytes=123-", 12) // startPos exceeding endPos testParseByteRangeError(t, "bytes=123-34", 1234) } func testParseByteRangeError(t *testing.T, v string, contentLength int) { _, _, err := ParseByteRange([]byte(v), contentLength) if err == nil { t.Fatalf("expecting error when parsing byte range %q", v) } } func TestFSCompressConcurrent(t *testing.T) { fs := &FS{ Root: ".", GenerateIndexPages: true, Compress: true, } h := fs.NewRequestHandler() concurrency := 4 ch := make(chan struct{}, concurrency) for i := 0; i < concurrency; i++ { go func() { for j := 0; j < 5; j++ { testFSCompress(t, h, "/fs.go") testFSCompress(t, h, "/") testFSCompress(t, h, "/README.md") } ch <- struct{}{} }() } for i := 0; i < concurrency; i++ { select { case <-ch: case <-time.After(time.Second): t.Fatalf("timeout") } } } func TestFSCompressSingleThread(t *testing.T) { fs := &FS{ Root: ".", GenerateIndexPages: true, Compress: true, } h := fs.NewRequestHandler() testFSCompress(t, h, "/fs.go") testFSCompress(t, h, "/") testFSCompress(t, h, "/README.md") } func testFSCompress(t *testing.T, h RequestHandler, filePath string) { var ctx RequestCtx ctx.Init(&Request{}, nil, nil) // request uncompressed file ctx.Request.Reset() 
ctx.Request.SetRequestURI(filePath) h(&ctx) var resp Response s := ctx.Response.String() br := bufio.NewReader(bytes.NewBufferString(s)) if err := resp.Read(br); err != nil { t.Fatalf("unexpected error: %s. filePath=%q", err, filePath) } if resp.StatusCode() != StatusOK { t.Fatalf("unexpected status code: %d. Expecting %d. filePath=%q", resp.StatusCode(), StatusOK, filePath) } ce := resp.Header.Peek("Content-Encoding") if string(ce) != "" { t.Fatalf("unexpected content-encoding %q. Expecting empty string. filePath=%q", ce, filePath) } body := string(resp.Body()) // request compressed file ctx.Request.Reset() ctx.Request.SetRequestURI(filePath) ctx.Request.Header.Set("Accept-Encoding", "gzip") h(&ctx) s = ctx.Response.String() br = bufio.NewReader(bytes.NewBufferString(s)) if err := resp.Read(br); err != nil { t.Fatalf("unexpected error: %s. filePath=%q", err, filePath) } if resp.StatusCode() != StatusOK { t.Fatalf("unexpected status code: %d. Expecting %d. filePath=%q", resp.StatusCode(), StatusOK, filePath) } ce = resp.Header.Peek("Content-Encoding") if string(ce) != "gzip" { t.Fatalf("unexpected content-encoding %q. Expecting %q. filePath=%q", ce, "gzip", filePath) } zbody, err := resp.BodyGunzip() if err != nil { t.Fatalf("unexpected error when gunzipping response body: %s. filePath=%q", err, filePath) } if string(zbody) != body { t.Fatalf("unexpected body %q. Expected %q. FilePath=%q", zbody, body, filePath) } } func TestFileLock(t *testing.T) { for i := 0; i < 10; i++ { filePath := fmt.Sprintf("foo/bar/%d.jpg", i) lock := getFileLock(filePath) lock.Lock() lock.Unlock() } for i := 0; i < 10; i++ { filePath := fmt.Sprintf("foo/bar/%d.jpg", i) lock := getFileLock(filePath) lock.Lock() lock.Unlock() } } func TestFSHandlerSingleThread(t *testing.T) { requestHandler := FSHandler(".", 0) f, err := os.Open(".") if err != nil { t.Fatalf("cannot open cwd: %s", err) } filenames, err := f.Readdirnames(0) f.Close() if err != nil { t.Fatalf("cannot read dirnames in cwd: %s", err) } sort.Sort(sort.StringSlice(filenames)) for i := 0; i < 3; i++ { fsHandlerTest(t, requestHandler, filenames) } } func TestFSHandlerConcurrent(t *testing.T) { requestHandler := FSHandler(".", 0) f, err := os.Open(".") if err != nil { t.Fatalf("cannot open cwd: %s", err) } filenames, err := f.Readdirnames(0) f.Close() if err != nil { t.Fatalf("cannot read dirnames in cwd: %s", err) } sort.Sort(sort.StringSlice(filenames)) concurrency := 10 ch := make(chan struct{}, concurrency) for j := 0; j < concurrency; j++ { go func() { for i := 0; i < 3; i++ { fsHandlerTest(t, requestHandler, filenames) } ch <- struct{}{} }() } for j := 0; j < concurrency; j++ { select { case <-ch: case <-time.After(time.Second): t.Fatalf("timeout") } } } func fsHandlerTest(t *testing.T, requestHandler RequestHandler, filenames []string) { var ctx RequestCtx var req Request ctx.Init(&req, nil, defaultLogger) ctx.Request.Header.SetHost("foobar.com") filesTested := 0 for _, name := range filenames { f, err := os.Open(name) if err != nil { t.Fatalf("cannot open file %q: %s", name, err) } stat, err := f.Stat() if err != nil { t.Fatalf("cannot get file stat %q: %s", name, err) } if stat.IsDir() { f.Close() continue } data, err := ioutil.ReadAll(f) f.Close() if err != nil { t.Fatalf("cannot read file contents %q: %s", name, err) } ctx.URI().Update(name) requestHandler(&ctx) if ctx.Response.bodyStream == nil { t.Fatalf("response body stream must be non-empty") } body, err := ioutil.ReadAll(ctx.Response.bodyStream) if err != nil { t.Fatalf("error when reading 
response body stream: %s", err) } if !bytes.Equal(body, data) { t.Fatalf("unexpected body returned: %q. Expecting %q", body, data) } filesTested++ if filesTested >= 10 { break } } // verify index page generation ctx.URI().Update("/") requestHandler(&ctx) if ctx.Response.bodyStream == nil { t.Fatalf("response body stream must be non-empty") } body, err := ioutil.ReadAll(ctx.Response.bodyStream) if err != nil { t.Fatalf("error when reading response body stream: %s", err) } if len(body) == 0 { t.Fatalf("index page must be non-empty") } } func TestStripPathSlashes(t *testing.T) { testStripPathSlashes(t, "", 0, "") testStripPathSlashes(t, "", 10, "") testStripPathSlashes(t, "/", 0, "") testStripPathSlashes(t, "/", 1, "") testStripPathSlashes(t, "/", 10, "") testStripPathSlashes(t, "/foo/bar/baz", 0, "/foo/bar/baz") testStripPathSlashes(t, "/foo/bar/baz", 1, "/bar/baz") testStripPathSlashes(t, "/foo/bar/baz", 2, "/baz") testStripPathSlashes(t, "/foo/bar/baz", 3, "") testStripPathSlashes(t, "/foo/bar/baz", 10, "") // trailing slash testStripPathSlashes(t, "/foo/bar/", 0, "/foo/bar") testStripPathSlashes(t, "/foo/bar/", 1, "/bar") testStripPathSlashes(t, "/foo/bar/", 2, "") testStripPathSlashes(t, "/foo/bar/", 3, "") } func testStripPathSlashes(t *testing.T, path string, stripSlashes int, expectedPath string) { s := stripLeadingSlashes([]byte(path), stripSlashes) s = stripTrailingSlashes(s) if string(s) != expectedPath { t.Fatalf("unexpected path after stripping %q with stripSlashes=%d: %q. Expecting %q", path, stripSlashes, s, expectedPath) } } func TestFileExtension(t *testing.T) { testFileExtension(t, "foo.bar", false, "zzz", ".bar") testFileExtension(t, "foobar", false, "zzz", "") testFileExtension(t, "foo.bar.baz", false, "zzz", ".baz") testFileExtension(t, "", false, "zzz", "") testFileExtension(t, "/a/b/c.d/efg.jpg", false, ".zzz", ".jpg") testFileExtension(t, "foo.bar", true, ".zzz", ".bar") testFileExtension(t, "foobar.zzz", true, ".zzz", "") testFileExtension(t, "foo.bar.baz.fasthttp.gz", true, ".fasthttp.gz", ".baz") testFileExtension(t, "", true, ".zzz", "") testFileExtension(t, "/a/b/c.d/efg.jpg.xxx", true, ".xxx", ".jpg") } func testFileExtension(t *testing.T, path string, compressed bool, compressedFileSuffix, expectedExt string) { ext := fileExtension(path, compressed, compressedFileSuffix) if ext != expectedExt { t.Fatalf("unexpected file extension for file %q: %q. Expecting %q", path, ext, expectedExt) } } golang-github-valyala-fasthttp-20160617/header.go000066400000000000000000001445411273074646000215740ustar00rootroot00000000000000package fasthttp import ( "bufio" "bytes" "errors" "fmt" "io" "sync/atomic" "time" ) // ResponseHeader represents HTTP response header. // // It is forbidden copying ResponseHeader instances. // Create new instances instead and use CopyTo. // // ResponseHeader instance MUST NOT be used from concurrently running // goroutines. type ResponseHeader struct { noCopy noCopy disableNormalizing bool noHTTP11 bool connectionClose bool statusCode int contentLength int contentLengthBytes []byte contentType []byte server []byte h []argsKV bufKV argsKV cookies []argsKV } // RequestHeader represents HTTP request header. // // It is forbidden copying RequestHeader instances. // Create new instances instead and use CopyTo. // // RequestHeader instance MUST NOT be used from concurrently running // goroutines. 
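//
// A short usage sketch (reaching the header through a Request value;
// the header names and values below are illustrative):
//
//	var req Request
//	req.Header.SetMethod("POST")
//	req.Header.SetContentType("application/json")
//	req.Header.Set("X-Request-Id", "12345")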
type RequestHeader struct { noCopy noCopy disableNormalizing bool noHTTP11 bool connectionClose bool isGet bool // These two fields have been moved close to other bool fields // for reducing RequestHeader object size. cookiesCollected bool rawHeadersParsed bool contentLength int contentLengthBytes []byte method []byte requestURI []byte host []byte contentType []byte userAgent []byte h []argsKV bufKV argsKV cookies []argsKV rawHeaders []byte } // SetContentRange sets 'Content-Range: bytes startPos-endPos/contentLength' // header. func (h *ResponseHeader) SetContentRange(startPos, endPos, contentLength int) { b := h.bufKV.value[:0] b = append(b, strBytes...) b = append(b, ' ') b = AppendUint(b, startPos) b = append(b, '-') b = AppendUint(b, endPos) b = append(b, '/') b = AppendUint(b, contentLength) h.bufKV.value = b h.SetCanonical(strContentRange, h.bufKV.value) } // SetByteRange sets 'Range: bytes=startPos-endPos' header. // // * If startPos is negative, then 'bytes=-startPos' value is set. // * If endPos is negative, then 'bytes=startPos-' value is set. func (h *RequestHeader) SetByteRange(startPos, endPos int) { h.parseRawHeaders() b := h.bufKV.value[:0] b = append(b, strBytes...) b = append(b, '=') if startPos >= 0 { b = AppendUint(b, startPos) } else { endPos = -startPos } b = append(b, '-') if endPos >= 0 { b = AppendUint(b, endPos) } h.bufKV.value = b h.SetCanonical(strRange, h.bufKV.value) } // StatusCode returns response status code. func (h *ResponseHeader) StatusCode() int { if h.statusCode == 0 { return StatusOK } return h.statusCode } // SetStatusCode sets response status code. func (h *ResponseHeader) SetStatusCode(statusCode int) { h.statusCode = statusCode } // SetLastModified sets 'Last-Modified' header to the given value. func (h *ResponseHeader) SetLastModified(t time.Time) { h.bufKV.value = AppendHTTPDate(h.bufKV.value[:0], t) h.SetCanonical(strLastModified, h.bufKV.value) } // ConnectionClose returns true if 'Connection: close' header is set. func (h *ResponseHeader) ConnectionClose() bool { return h.connectionClose } // SetConnectionClose sets 'Connection: close' header. func (h *ResponseHeader) SetConnectionClose() { h.connectionClose = true } // ResetConnectionClose clears 'Connection: close' header if it exists. func (h *ResponseHeader) ResetConnectionClose() { if h.connectionClose { h.connectionClose = false h.h = delAllArgsBytes(h.h, strConnection) } } // ConnectionClose returns true if 'Connection: close' header is set. func (h *RequestHeader) ConnectionClose() bool { h.parseRawHeaders() return h.connectionClose } func (h *RequestHeader) connectionCloseFast() bool { // h.parseRawHeaders() isn't called for performance reasons. // Use ConnectionClose for triggering raw headers parsing. return h.connectionClose } // SetConnectionClose sets 'Connection: close' header. func (h *RequestHeader) SetConnectionClose() { // h.parseRawHeaders() isn't called for performance reasons. h.connectionClose = true } // ResetConnectionClose clears 'Connection: close' header if it exists. func (h *RequestHeader) ResetConnectionClose() { h.parseRawHeaders() if h.connectionClose { h.connectionClose = false h.h = delAllArgsBytes(h.h, strConnection) } } // ConnectionUpgrade returns true if 'Connection: Upgrade' header is set. func (h *ResponseHeader) ConnectionUpgrade() bool { return hasHeaderValue(h.Peek("Connection"), strUpgrade) } // ConnectionUpgrade returns true if 'Connection: Upgrade' header is set. 
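// It is typically checked together with the 'Upgrade' header when deciding
// whether to hand the connection over to a protocol upgrade (e.g. WebSocket).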
func (h *RequestHeader) ConnectionUpgrade() bool { h.parseRawHeaders() return hasHeaderValue(h.Peek("Connection"), strUpgrade) } // ContentLength returns Content-Length header value. // // It may be negative: // -1 means Transfer-Encoding: chunked. // -2 means Transfer-Encoding: identity. func (h *ResponseHeader) ContentLength() int { return h.contentLength } // SetContentLength sets Content-Length header value. // // Content-Length may be negative: // -1 means Transfer-Encoding: chunked. // -2 means Transfer-Encoding: identity. func (h *ResponseHeader) SetContentLength(contentLength int) { if h.mustSkipContentLength() { return } h.contentLength = contentLength if contentLength >= 0 { h.contentLengthBytes = AppendUint(h.contentLengthBytes[:0], contentLength) h.h = delAllArgsBytes(h.h, strTransferEncoding) } else { h.contentLengthBytes = h.contentLengthBytes[:0] value := strChunked if contentLength == -2 { h.SetConnectionClose() value = strIdentity } h.h = setArgBytes(h.h, strTransferEncoding, value) } } func (h *ResponseHeader) mustSkipContentLength() bool { // From http/1.1 specs: // All 1xx (informational), 204 (no content), and 304 (not modified) responses MUST NOT include a message-body statusCode := h.StatusCode() // Fast path. if statusCode < 100 || statusCode == StatusOK { return false } // Slow path. return statusCode == StatusNotModified || statusCode == StatusNoContent || statusCode < 200 } // ContentLength returns Content-Length header value. // // It may be negative: // -1 means Transfer-Encoding: chunked. func (h *RequestHeader) ContentLength() int { if h.noBody() { return 0 } h.parseRawHeaders() return h.contentLength } // SetContentLength sets Content-Length header value. // // Negative content-length sets 'Transfer-Encoding: chunked' header. func (h *RequestHeader) SetContentLength(contentLength int) { h.parseRawHeaders() h.contentLength = contentLength if contentLength >= 0 { h.contentLengthBytes = AppendUint(h.contentLengthBytes[:0], contentLength) h.h = delAllArgsBytes(h.h, strTransferEncoding) } else { h.contentLengthBytes = h.contentLengthBytes[:0] h.h = setArgBytes(h.h, strTransferEncoding, strChunked) } } // ContentType returns Content-Type header value. func (h *ResponseHeader) ContentType() []byte { contentType := h.contentType if len(h.contentType) == 0 { contentType = defaultContentType } return contentType } // SetContentType sets Content-Type header value. func (h *ResponseHeader) SetContentType(contentType string) { h.contentType = append(h.contentType[:0], contentType...) } // SetContentTypeBytes sets Content-Type header value. func (h *ResponseHeader) SetContentTypeBytes(contentType []byte) { h.contentType = append(h.contentType[:0], contentType...) } // Server returns Server header value. func (h *ResponseHeader) Server() []byte { return h.server } // SetServer sets Server header value. func (h *ResponseHeader) SetServer(server string) { h.server = append(h.server[:0], server...) } // SetServerBytes sets Server header value. func (h *ResponseHeader) SetServerBytes(server []byte) { h.server = append(h.server[:0], server...) } // ContentType returns Content-Type header value. func (h *RequestHeader) ContentType() []byte { h.parseRawHeaders() return h.contentType } // SetContentType sets Content-Type header value. func (h *RequestHeader) SetContentType(contentType string) { h.parseRawHeaders() h.contentType = append(h.contentType[:0], contentType...) } // SetContentTypeBytes sets Content-Type header value. 
func (h *RequestHeader) SetContentTypeBytes(contentType []byte) { h.parseRawHeaders() h.contentType = append(h.contentType[:0], contentType...) } // SetMultipartFormBoundary sets the following Content-Type: // 'multipart/form-data; boundary=...' // where ... is substituted by the given boundary. func (h *RequestHeader) SetMultipartFormBoundary(boundary string) { h.parseRawHeaders() b := h.bufKV.value[:0] b = append(b, strMultipartFormData...) b = append(b, ';', ' ') b = append(b, strBoundary...) b = append(b, '=') b = append(b, boundary...) h.bufKV.value = b h.SetContentTypeBytes(h.bufKV.value) } // SetMultipartFormBoundaryBytes sets the following Content-Type: // 'multipart/form-data; boundary=...' // where ... is substituted by the given boundary. func (h *RequestHeader) SetMultipartFormBoundaryBytes(boundary []byte) { h.parseRawHeaders() b := h.bufKV.value[:0] b = append(b, strMultipartFormData...) b = append(b, ';', ' ') b = append(b, strBoundary...) b = append(b, '=') b = append(b, boundary...) h.bufKV.value = b h.SetContentTypeBytes(h.bufKV.value) } // MultipartFormBoundary returns boundary part // from 'multipart/form-data; boundary=...' Content-Type. func (h *RequestHeader) MultipartFormBoundary() []byte { b := h.ContentType() if !bytes.HasPrefix(b, strMultipartFormData) { return nil } b = b[len(strMultipartFormData):] if len(b) == 0 || b[0] != ';' { return nil } var n int for len(b) > 0 { n++ for len(b) > n && b[n] == ' ' { n++ } b = b[n:] if !bytes.HasPrefix(b, strBoundary) { if n = bytes.IndexByte(b, ';'); n < 0 { return nil } continue } b = b[len(strBoundary):] if len(b) == 0 || b[0] != '=' { return nil } b = b[1:] if n = bytes.IndexByte(b, ';'); n >= 0 { b = b[:n] } return b } return nil } // Host returns Host header value. func (h *RequestHeader) Host() []byte { if len(h.host) > 0 { return h.host } if !h.rawHeadersParsed { // fast path without employing full headers parsing. host := peekRawHeader(h.rawHeaders, strHost) if len(host) > 0 { h.host = append(h.host[:0], host...) return h.host } } // slow path. h.parseRawHeaders() return h.host } // SetHost sets Host header value. func (h *RequestHeader) SetHost(host string) { h.parseRawHeaders() h.host = append(h.host[:0], host...) } // SetHostBytes sets Host header value. func (h *RequestHeader) SetHostBytes(host []byte) { h.parseRawHeaders() h.host = append(h.host[:0], host...) } // UserAgent returns User-Agent header value. func (h *RequestHeader) UserAgent() []byte { h.parseRawHeaders() return h.userAgent } // SetUserAgent sets User-Agent header value. func (h *RequestHeader) SetUserAgent(userAgent string) { h.parseRawHeaders() h.userAgent = append(h.userAgent[:0], userAgent...) } // SetUserAgentBytes sets User-Agent header value. func (h *RequestHeader) SetUserAgentBytes(userAgent []byte) { h.parseRawHeaders() h.userAgent = append(h.userAgent[:0], userAgent...) } // Referer returns Referer header value. func (h *RequestHeader) Referer() []byte { return h.PeekBytes(strReferer) } // SetReferer sets Referer header value. func (h *RequestHeader) SetReferer(referer string) { h.SetBytesK(strReferer, referer) } // SetRefererBytes sets Referer header value. func (h *RequestHeader) SetRefererBytes(referer []byte) { h.SetCanonical(strReferer, referer) } // Method returns HTTP request method. func (h *RequestHeader) Method() []byte { if len(h.method) == 0 { return strGet } return h.method } // SetMethod sets HTTP request method. func (h *RequestHeader) SetMethod(method string) { h.method = append(h.method, method...) 
} // SetMethodBytes sets HTTP request method. func (h *RequestHeader) SetMethodBytes(method []byte) { h.method = append(h.method[:0], method...) } // RequestURI returns RequestURI from the first HTTP request line. func (h *RequestHeader) RequestURI() []byte { requestURI := h.requestURI if len(requestURI) == 0 { requestURI = strSlash } return requestURI } // SetRequestURI sets RequestURI for the first HTTP request line. // RequestURI must be properly encoded. // Use URI.RequestURI for constructing proper RequestURI if unsure. func (h *RequestHeader) SetRequestURI(requestURI string) { h.requestURI = append(h.requestURI[:0], requestURI...) } // SetRequestURIBytes sets RequestURI for the first HTTP request line. // RequestURI must be properly encoded. // Use URI.RequestURI for constructing proper RequestURI if unsure. func (h *RequestHeader) SetRequestURIBytes(requestURI []byte) { h.requestURI = append(h.requestURI[:0], requestURI...) } // IsGet returns true if request method is GET. func (h *RequestHeader) IsGet() bool { // Optimize fast path for GET requests. if !h.isGet { h.isGet = bytes.Equal(h.Method(), strGet) } return h.isGet } // IsPost returns true if request methos is POST. func (h *RequestHeader) IsPost() bool { return bytes.Equal(h.Method(), strPost) } // IsPut returns true if request method is PUT. func (h *RequestHeader) IsPut() bool { return bytes.Equal(h.Method(), strPut) } // IsHead returns true if request method is HEAD. func (h *RequestHeader) IsHead() bool { // Fast path if h.isGet { return false } return bytes.Equal(h.Method(), strHead) } // IsDelete returns true if request method is DELETE. func (h *RequestHeader) IsDelete() bool { return bytes.Equal(h.Method(), strDelete) } // IsHTTP11 returns true if the request is HTTP/1.1. func (h *RequestHeader) IsHTTP11() bool { return !h.noHTTP11 } // IsHTTP11 returns true if the response is HTTP/1.1. func (h *ResponseHeader) IsHTTP11() bool { return !h.noHTTP11 } // HasAcceptEncoding returns true if the header contains // the given Accept-Encoding value. func (h *RequestHeader) HasAcceptEncoding(acceptEncoding string) bool { h.bufKV.value = append(h.bufKV.value[:0], acceptEncoding...) return h.HasAcceptEncodingBytes(h.bufKV.value) } // HasAcceptEncodingBytes returns true if the header contains // the given Accept-Encoding value. func (h *RequestHeader) HasAcceptEncodingBytes(acceptEncoding []byte) bool { ae := h.peek(strAcceptEncoding) n := bytes.Index(ae, acceptEncoding) if n < 0 { return false } b := ae[n+len(acceptEncoding):] if len(b) > 0 && b[0] != ',' { return false } if n == 0 { return true } return ae[n-1] == ' ' } // Len returns the number of headers set, // i.e. the number of times f is called in VisitAll. func (h *ResponseHeader) Len() int { n := 0 h.VisitAll(func(k, v []byte) { n++ }) return n } // Len returns the number of headers set, // i.e. the number of times f is called in VisitAll. func (h *RequestHeader) Len() int { n := 0 h.VisitAll(func(k, v []byte) { n++ }) return n } // DisableNormalizing disables header names' normalization. // // By default all the header names are normalized by uppercasing // the first letter and all the first letters following dashes, // while lowercasing all the other letters. // Examples: // // * CONNECTION -> Connection // * conteNT-tYPE -> Content-Type // * foo-bar-baz -> Foo-Bar-Baz // // Disable header names' normalization only if know what are you doing. 
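//
// Sketch of sending a header name exactly as written, without canonicalization
// (the header name below is illustrative):
//
//	var h RequestHeader
//	h.DisableNormalizing()
//	h.Set("X-RAW-HEADER", "v") // kept as-is instead of becoming X-Raw-Header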
func (h *RequestHeader) DisableNormalizing() { h.disableNormalizing = true } // DisableNormalizing disables header names' normalization. // // By default all the header names are normalized by uppercasing // the first letter and all the first letters following dashes, // while lowercasing all the other letters. // Examples: // // * CONNECTION -> Connection // * conteNT-tYPE -> Content-Type // * foo-bar-baz -> Foo-Bar-Baz // // Disable header names' normalization only if know what are you doing. func (h *ResponseHeader) DisableNormalizing() { h.disableNormalizing = true } // Reset clears response header. func (h *ResponseHeader) Reset() { h.disableNormalizing = false h.resetSkipNormalize() } func (h *ResponseHeader) resetSkipNormalize() { h.noHTTP11 = false h.connectionClose = false h.statusCode = 0 h.contentLength = 0 h.contentLengthBytes = h.contentLengthBytes[:0] h.contentType = h.contentType[:0] h.server = h.server[:0] h.h = h.h[:0] h.cookies = h.cookies[:0] } // Reset clears request header. func (h *RequestHeader) Reset() { h.disableNormalizing = false h.resetSkipNormalize() } func (h *RequestHeader) resetSkipNormalize() { h.noHTTP11 = false h.connectionClose = false h.isGet = false h.contentLength = 0 h.contentLengthBytes = h.contentLengthBytes[:0] h.method = h.method[:0] h.requestURI = h.requestURI[:0] h.host = h.host[:0] h.contentType = h.contentType[:0] h.userAgent = h.userAgent[:0] h.h = h.h[:0] h.cookies = h.cookies[:0] h.cookiesCollected = false h.rawHeaders = h.rawHeaders[:0] h.rawHeadersParsed = false } // CopyTo copies all the headers to dst. func (h *ResponseHeader) CopyTo(dst *ResponseHeader) { dst.Reset() dst.disableNormalizing = h.disableNormalizing dst.noHTTP11 = h.noHTTP11 dst.connectionClose = h.connectionClose dst.statusCode = h.statusCode dst.contentLength = h.contentLength dst.contentLengthBytes = append(dst.contentLengthBytes[:0], h.contentLengthBytes...) dst.contentType = append(dst.contentType[:0], h.contentType...) dst.server = append(dst.server[:0], h.server...) dst.h = copyArgs(dst.h, h.h) dst.cookies = copyArgs(dst.cookies, h.cookies) } // CopyTo copies all the headers to dst. func (h *RequestHeader) CopyTo(dst *RequestHeader) { dst.Reset() dst.disableNormalizing = h.disableNormalizing dst.noHTTP11 = h.noHTTP11 dst.connectionClose = h.connectionClose dst.isGet = h.isGet dst.contentLength = h.contentLength dst.contentLengthBytes = append(dst.contentLengthBytes[:0], h.contentLengthBytes...) dst.method = append(dst.method[:0], h.method...) dst.requestURI = append(dst.requestURI[:0], h.requestURI...) dst.host = append(dst.host[:0], h.host...) dst.contentType = append(dst.contentType[:0], h.contentType...) dst.userAgent = append(dst.userAgent[:0], h.userAgent...) dst.h = copyArgs(dst.h, h.h) dst.cookies = copyArgs(dst.cookies, h.cookies) dst.cookiesCollected = h.cookiesCollected dst.rawHeaders = append(dst.rawHeaders[:0], h.rawHeaders...) dst.rawHeadersParsed = h.rawHeadersParsed } // VisitAll calls f for each header. // // f must not retain references to key and/or value after returning. // Copy key and/or value contents before returning if you need retaining them. 
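//
// Sketch of retaining values safely by copying them:
//
//	headers := make(map[string]string)
//	h.VisitAll(func(key, value []byte) {
//		headers[string(key)] = string(value) // string() copies the bytes
//	})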
func (h *ResponseHeader) VisitAll(f func(key, value []byte)) { if len(h.contentLengthBytes) > 0 { f(strContentLength, h.contentLengthBytes) } contentType := h.ContentType() if len(contentType) > 0 { f(strContentType, contentType) } server := h.Server() if len(server) > 0 { f(strServer, server) } if len(h.cookies) > 0 { visitArgs(h.cookies, func(k, v []byte) { f(strSetCookie, v) }) } visitArgs(h.h, f) if h.ConnectionClose() { f(strConnection, strClose) } } // VisitAllCookie calls f for each response cookie. // // Cookie name is passed in key and the whole Set-Cookie header value // is passed in value on each f invocation. Value may be parsed // with Cookie.ParseBytes(). // // f must not retain references to key and/or value after returning. func (h *ResponseHeader) VisitAllCookie(f func(key, value []byte)) { visitArgs(h.cookies, f) } // VisitAllCookie calls f for each request cookie. // // f must not retain references to key and/or value after returning. func (h *RequestHeader) VisitAllCookie(f func(key, value []byte)) { h.parseRawHeaders() h.collectCookies() visitArgs(h.cookies, f) } // VisitAll calls f for each header. // // f must not retain references to key and/or value after returning. // Copy key and/or value contents before returning if you need retaining them. func (h *RequestHeader) VisitAll(f func(key, value []byte)) { h.parseRawHeaders() host := h.Host() if len(host) > 0 { f(strHost, host) } if len(h.contentLengthBytes) > 0 { f(strContentLength, h.contentLengthBytes) } contentType := h.ContentType() if len(contentType) > 0 { f(strContentType, contentType) } userAgent := h.UserAgent() if len(userAgent) > 0 { f(strUserAgent, userAgent) } h.collectCookies() if len(h.cookies) > 0 { h.bufKV.value = appendRequestCookieBytes(h.bufKV.value[:0], h.cookies) f(strCookie, h.bufKV.value) } visitArgs(h.h, f) if h.ConnectionClose() { f(strConnection, strClose) } } // Del deletes header with the given key. func (h *ResponseHeader) Del(key string) { k := getHeaderKeyBytes(&h.bufKV, key, h.disableNormalizing) h.del(k) } // DelBytes deletes header with the given key. func (h *ResponseHeader) DelBytes(key []byte) { h.bufKV.key = append(h.bufKV.key[:0], key...) normalizeHeaderKey(h.bufKV.key, h.disableNormalizing) h.del(h.bufKV.key) } func (h *ResponseHeader) del(key []byte) { switch string(key) { case "Content-Type": h.contentType = h.contentType[:0] case "Server": h.server = h.server[:0] case "Set-Cookie": h.cookies = h.cookies[:0] case "Content-Length": h.contentLength = 0 h.contentLengthBytes = h.contentLengthBytes[:0] case "Connection": h.connectionClose = false } h.h = delAllArgsBytes(h.h, key) } // Del deletes header with the given key. func (h *RequestHeader) Del(key string) { h.parseRawHeaders() k := getHeaderKeyBytes(&h.bufKV, key, h.disableNormalizing) h.del(k) } // DelBytes deletes header with the given key. func (h *RequestHeader) DelBytes(key []byte) { h.parseRawHeaders() h.bufKV.key = append(h.bufKV.key[:0], key...) normalizeHeaderKey(h.bufKV.key, h.disableNormalizing) h.del(h.bufKV.key) } func (h *RequestHeader) del(key []byte) { switch string(key) { case "Host": h.host = h.host[:0] case "Content-Type": h.contentType = h.contentType[:0] case "User-Agent": h.userAgent = h.userAgent[:0] case "Cookie": h.cookies = h.cookies[:0] case "Content-Length": h.contentLength = 0 h.contentLengthBytes = h.contentLengthBytes[:0] case "Connection": h.connectionClose = false } h.h = delAllArgsBytes(h.h, key) } // Add adds the given 'key: value' header. 
// // Multiple headers with the same key may be added. func (h *ResponseHeader) Add(key, value string) { k := getHeaderKeyBytes(&h.bufKV, key, h.disableNormalizing) h.h = appendArg(h.h, b2s(k), value) } // AddBytesK adds the given 'key: value' header. // // Multiple headers with the same key may be added. func (h *ResponseHeader) AddBytesK(key []byte, value string) { h.Add(b2s(key), value) } // AddBytesV adds the given 'key: value' header. // // Multiple headers with the same key may be added. func (h *ResponseHeader) AddBytesV(key string, value []byte) { h.Add(key, b2s(value)) } // AddBytesKV adds the given 'key: value' header. // // Multiple headers with the same key may be added. func (h *ResponseHeader) AddBytesKV(key, value []byte) { h.Add(b2s(key), b2s(value)) } // Set sets the given 'key: value' header. func (h *ResponseHeader) Set(key, value string) { initHeaderKV(&h.bufKV, key, value, h.disableNormalizing) h.SetCanonical(h.bufKV.key, h.bufKV.value) } // SetBytesK sets the given 'key: value' header. func (h *ResponseHeader) SetBytesK(key []byte, value string) { h.bufKV.value = append(h.bufKV.value[:0], value...) h.SetBytesKV(key, h.bufKV.value) } // SetBytesV sets the given 'key: value' header. func (h *ResponseHeader) SetBytesV(key string, value []byte) { k := getHeaderKeyBytes(&h.bufKV, key, h.disableNormalizing) h.SetCanonical(k, value) } // SetBytesKV sets the given 'key: value' header. func (h *ResponseHeader) SetBytesKV(key, value []byte) { h.bufKV.key = append(h.bufKV.key[:0], key...) normalizeHeaderKey(h.bufKV.key, h.disableNormalizing) h.SetCanonical(h.bufKV.key, value) } // SetCanonical sets the given 'key: value' header assuming that // key is in canonical form. func (h *ResponseHeader) SetCanonical(key, value []byte) { switch string(key) { case "Content-Type": h.SetContentTypeBytes(value) case "Server": h.SetServerBytes(value) case "Set-Cookie": var kv *argsKV h.cookies, kv = allocArg(h.cookies) kv.key = getCookieKey(kv.key, value) kv.value = append(kv.value[:0], value...) case "Content-Length": if contentLength, err := parseContentLength(value); err == nil { h.contentLength = contentLength h.contentLengthBytes = append(h.contentLengthBytes[:0], value...) } case "Connection": if bytes.Equal(strClose, value) { h.SetConnectionClose() } else { h.ResetConnectionClose() h.h = setArgBytes(h.h, key, value) } case "Transfer-Encoding": // Transfer-Encoding is managed automatically. case "Date": // Date is managed automatically. default: h.h = setArgBytes(h.h, key, value) } } // SetCookie sets the given response cookie. func (h *ResponseHeader) SetCookie(cookie *Cookie) { h.cookies = setArgBytes(h.cookies, cookie.Key(), cookie.Cookie()) } // SetCookie sets 'key: value' cookies. func (h *RequestHeader) SetCookie(key, value string) { h.parseRawHeaders() h.collectCookies() h.cookies = setArg(h.cookies, key, value) } // SetCookieBytesK sets 'key: value' cookies. func (h *RequestHeader) SetCookieBytesK(key []byte, value string) { h.SetCookie(b2s(key), value) } // SetCookieBytesKV sets 'key: value' cookies. func (h *RequestHeader) SetCookieBytesKV(key, value []byte) { h.SetCookie(b2s(key), b2s(value)) } // DelClientCookie instructs the client to remove the given cookie. // // Use DelCookie if you want just removing the cookie from response header. 
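//
// A minimal illustrative sketch (hypothetical usage; h is a ResponseHeader and
// "session" is a made-up cookie name):
//
//	// Ask the client to drop its "session" cookie by sending it back
//	// with CookieExpireDelete as the expiration time.
//	h.DelClientCookie("session")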
func (h *ResponseHeader) DelClientCookie(key string) { h.DelCookie(key) c := AcquireCookie() c.SetKey(key) c.SetExpire(CookieExpireDelete) h.SetCookie(c) ReleaseCookie(c) } // DelClientCookieBytes instructs the client to remove the given cookie. // // Use DelCookieBytes if you want just removing the cookie from response header. func (h *ResponseHeader) DelClientCookieBytes(key []byte) { h.DelClientCookie(b2s(key)) } // DelCookie removes cookie under the given key from response header. // // Note that DelCookie doesn't remove the cookie from the client. // Use DelClientCookie instead. func (h *ResponseHeader) DelCookie(key string) { h.cookies = delAllArgs(h.cookies, key) } // DelCookieBytes removes cookie under the given key from response header. // // Note that DelCookieBytes doesn't remove the cookie from the client. // Use DelClientCookieBytes instead. func (h *ResponseHeader) DelCookieBytes(key []byte) { h.DelCookie(b2s(key)) } // DelCookie removes cookie under the given key. func (h *RequestHeader) DelCookie(key string) { h.parseRawHeaders() h.collectCookies() h.cookies = delAllArgs(h.cookies, key) } // DelCookieBytes removes cookie under the given key. func (h *RequestHeader) DelCookieBytes(key []byte) { h.DelCookie(b2s(key)) } // DelAllCookies removes all the cookies from response headers. func (h *ResponseHeader) DelAllCookies() { h.cookies = h.cookies[:0] } // DelAllCookies removes all the cookies from request headers. func (h *RequestHeader) DelAllCookies() { h.parseRawHeaders() h.collectCookies() h.cookies = h.cookies[:0] } // Add adds the given 'key: value' header. // // Multiple headers with the same key may be added. func (h *RequestHeader) Add(key, value string) { k := getHeaderKeyBytes(&h.bufKV, key, h.disableNormalizing) h.h = appendArg(h.h, b2s(k), value) } // AddBytesK adds the given 'key: value' header. // // Multiple headers with the same key may be added. func (h *RequestHeader) AddBytesK(key []byte, value string) { h.Add(b2s(key), value) } // AddBytesV adds the given 'key: value' header. // // Multiple headers with the same key may be added. func (h *RequestHeader) AddBytesV(key string, value []byte) { h.Add(key, b2s(value)) } // AddBytesKV adds the given 'key: value' header. // // Multiple headers with the same key may be added. func (h *RequestHeader) AddBytesKV(key, value []byte) { h.Add(b2s(key), b2s(value)) } // Set sets the given 'key: value' header. func (h *RequestHeader) Set(key, value string) { initHeaderKV(&h.bufKV, key, value, h.disableNormalizing) h.SetCanonical(h.bufKV.key, h.bufKV.value) } // SetBytesK sets the given 'key: value' header. func (h *RequestHeader) SetBytesK(key []byte, value string) { h.bufKV.value = append(h.bufKV.value[:0], value...) h.SetBytesKV(key, h.bufKV.value) } // SetBytesV sets the given 'key: value' header. func (h *RequestHeader) SetBytesV(key string, value []byte) { k := getHeaderKeyBytes(&h.bufKV, key, h.disableNormalizing) h.SetCanonical(k, value) } // SetBytesKV sets the given 'key: value' header. func (h *RequestHeader) SetBytesKV(key, value []byte) { h.bufKV.key = append(h.bufKV.key[:0], key...) normalizeHeaderKey(h.bufKV.key, h.disableNormalizing) h.SetCanonical(h.bufKV.key, value) } // SetCanonical sets the given 'key: value' header assuming that // key is in canonical form. 
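//
// A minimal illustrative sketch (hypothetical usage; the key is already in
// canonical form, so no normalization pass is needed):
//
//	h.SetCanonical([]byte("Content-Type"), []byte("application/json"))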
func (h *RequestHeader) SetCanonical(key, value []byte) { h.parseRawHeaders() switch string(key) { case "Host": h.SetHostBytes(value) case "Content-Type": h.SetContentTypeBytes(value) case "User-Agent": h.SetUserAgentBytes(value) case "Cookie": h.collectCookies() h.cookies = parseRequestCookies(h.cookies, value) case "Content-Length": if contentLength, err := parseContentLength(value); err == nil { h.contentLength = contentLength h.contentLengthBytes = append(h.contentLengthBytes[:0], value...) } case "Connection": if bytes.Equal(strClose, value) { h.SetConnectionClose() } else { h.ResetConnectionClose() h.h = setArgBytes(h.h, key, value) } case "Transfer-Encoding": // Transfer-Encoding is managed automatically. default: h.h = setArgBytes(h.h, key, value) } } // Peek returns header value for the given key. // // Returned value is valid until the next call to ResponseHeader. // Do not store references to returned value. Make copies instead. func (h *ResponseHeader) Peek(key string) []byte { k := getHeaderKeyBytes(&h.bufKV, key, h.disableNormalizing) return h.peek(k) } // PeekBytes returns header value for the given key. // // Returned value is valid until the next call to ResponseHeader. // Do not store references to returned value. Make copies instead. func (h *ResponseHeader) PeekBytes(key []byte) []byte { h.bufKV.key = append(h.bufKV.key[:0], key...) normalizeHeaderKey(h.bufKV.key, h.disableNormalizing) return h.peek(h.bufKV.key) } // Peek returns header value for the given key. // // Returned value is valid until the next call to RequestHeader. // Do not store references to returned value. Make copies instead. func (h *RequestHeader) Peek(key string) []byte { k := getHeaderKeyBytes(&h.bufKV, key, h.disableNormalizing) return h.peek(k) } // PeekBytes returns header value for the given key. // // Returned value is valid until the next call to RequestHeader. // Do not store references to returned value. Make copies instead. func (h *RequestHeader) PeekBytes(key []byte) []byte { h.bufKV.key = append(h.bufKV.key[:0], key...) normalizeHeaderKey(h.bufKV.key, h.disableNormalizing) return h.peek(h.bufKV.key) } func (h *ResponseHeader) peek(key []byte) []byte { switch string(key) { case "Content-Type": return h.ContentType() case "Server": return h.Server() case "Connection": if h.ConnectionClose() { return strClose } return peekArgBytes(h.h, key) case "Content-Length": return h.contentLengthBytes default: return peekArgBytes(h.h, key) } } func (h *RequestHeader) peek(key []byte) []byte { h.parseRawHeaders() switch string(key) { case "Host": return h.Host() case "Content-Type": return h.ContentType() case "User-Agent": return h.UserAgent() case "Connection": if h.ConnectionClose() { return strClose } return peekArgBytes(h.h, key) case "Content-Length": return h.contentLengthBytes default: return peekArgBytes(h.h, key) } } // Cookie returns cookie for the given key. func (h *RequestHeader) Cookie(key string) []byte { h.parseRawHeaders() h.collectCookies() return peekArgStr(h.cookies, key) } // CookieBytes returns cookie for the given key. func (h *RequestHeader) CookieBytes(key []byte) []byte { h.parseRawHeaders() h.collectCookies() return peekArgBytes(h.cookies, key) } // Cookie fills cookie for the given cookie.Key. // // Returns false if cookie with the given cookie.Key is missing. func (h *ResponseHeader) Cookie(cookie *Cookie) bool { v := peekArgBytes(h.cookies, cookie.Key()) if v == nil { return false } cookie.ParseBytes(v) return true } // Read reads response header from r. 
// // io.EOF is returned if r is closed before reading the first header byte. func (h *ResponseHeader) Read(r *bufio.Reader) error { n := 1 for { err := h.tryRead(r, n) if err == nil { return nil } if err != errNeedMore { h.resetSkipNormalize() return err } n = r.Buffered() + 1 } } func (h *ResponseHeader) tryRead(r *bufio.Reader, n int) error { h.resetSkipNormalize() b, err := r.Peek(n) if len(b) == 0 { // treat all errors on the first byte read as EOF if n == 1 || err == io.EOF { return io.EOF } if err == bufio.ErrBufferFull { err = bufferFullError(r) } return fmt.Errorf("error when reading response headers: %s", err) } isEOF := (err != nil) b = mustPeekBuffered(r) var headersLen int if headersLen, err = h.parse(b); err != nil { if err == errNeedMore { if !isEOF { return err } // Buggy servers may leave trailing CRLFs after response body. // Treat this case as EOF. if isOnlyCRLF(b) { return io.EOF } } bStart, bEnd := bufferStartEnd(b) return fmt.Errorf("error when reading response headers: %s. buf=%q...%q", err, bStart, bEnd) } mustDiscard(r, headersLen) return nil } // Read reads request header from r. // // io.EOF is returned if r is closed before reading the first header byte. func (h *RequestHeader) Read(r *bufio.Reader) error { n := 1 for { err := h.tryRead(r, n) if err == nil { return nil } if err != errNeedMore { h.resetSkipNormalize() return err } n = r.Buffered() + 1 } } func (h *RequestHeader) tryRead(r *bufio.Reader, n int) error { h.resetSkipNormalize() b, err := r.Peek(n) if len(b) == 0 { // treat all errors on the first byte read as EOF if n == 1 || err == io.EOF { return io.EOF } if err == bufio.ErrBufferFull { err = bufferFullError(r) } return fmt.Errorf("error when reading request headers: %s", err) } isEOF := (err != nil) b = mustPeekBuffered(r) var headersLen int if headersLen, err = h.parse(b); err != nil { if err == errNeedMore { if !isEOF { return err } // Buggy clients may leave trailing CRLFs after the request body. // Treat this case as EOF. if isOnlyCRLF(b) { return io.EOF } } bStart, bEnd := bufferStartEnd(b) return fmt.Errorf("error when reading request headers: %s. buf=%q...%q", err, bStart, bEnd) } mustDiscard(r, headersLen) return nil } func bufferFullError(r *bufio.Reader) error { n := r.Buffered() b, err := r.Peek(n) if err != nil { panic(fmt.Sprintf("BUG: unexpected error returned from bufio.Reader.Peek(Buffered()): %s", err)) } bStart, bEnd := bufferStartEnd(b) return fmt.Errorf("headers exceed %d bytes. Increase ReadBufferSize. buf=%q...%q", n, bStart, bEnd) } func bufferStartEnd(b []byte) ([]byte, []byte) { n := len(b) start := 200 end := n - start if start >= end { start = n end = n } return b[:start], b[end:] } func isOnlyCRLF(b []byte) bool { for _, ch := range b { if ch != '\r' && ch != '\n' { return false } } return true } func init() { refreshServerDate() go func() { for { time.Sleep(time.Second) refreshServerDate() } }() } var serverDate atomic.Value func refreshServerDate() { b := AppendHTTPDate(nil, time.Now()) serverDate.Store(b) } // Write writes response header to w. func (h *ResponseHeader) Write(w *bufio.Writer) error { _, err := w.Write(h.Header()) return err } // WriteTo writes response header to w. // // WriteTo implements io.WriterTo interface. func (h *ResponseHeader) WriteTo(w io.Writer) (int64, error) { n, err := w.Write(h.Header()) return int64(n), err } // Header returns response header representation. // // The returned value is valid until the next call to ResponseHeader methods. 
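//
// A minimal illustrative sketch (hypothetical usage): copy the bytes if they
// must outlive subsequent method calls on h:
//
//	headerBytes := append([]byte(nil), h.Header()...)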
func (h *ResponseHeader) Header() []byte { h.bufKV.value = h.AppendBytes(h.bufKV.value[:0]) return h.bufKV.value } // String returns response header representation. func (h *ResponseHeader) String() string { return string(h.Header()) } // AppendBytes appends response header representation to dst and returns // the extended dst. func (h *ResponseHeader) AppendBytes(dst []byte) []byte { statusCode := h.StatusCode() if statusCode < 0 { statusCode = StatusOK } dst = append(dst, statusLine(statusCode)...) server := h.Server() if len(server) == 0 { server = defaultServerName } dst = appendHeaderLine(dst, strServer, server) dst = appendHeaderLine(dst, strDate, serverDate.Load().([]byte)) // Append Content-Type only for non-zero responses // or if it is explicitly set. // See https://github.com/valyala/fasthttp/issues/28 . if h.ContentLength() != 0 || len(h.contentType) > 0 { dst = appendHeaderLine(dst, strContentType, h.ContentType()) } if len(h.contentLengthBytes) > 0 { dst = appendHeaderLine(dst, strContentLength, h.contentLengthBytes) } for i, n := 0, len(h.h); i < n; i++ { kv := &h.h[i] if !bytes.Equal(kv.key, strDate) { dst = appendHeaderLine(dst, kv.key, kv.value) } } n := len(h.cookies) if n > 0 { for i := 0; i < n; i++ { kv := &h.cookies[i] dst = appendHeaderLine(dst, strSetCookie, kv.value) } } if h.ConnectionClose() { dst = appendHeaderLine(dst, strConnection, strClose) } return append(dst, strCRLF...) } // Write writes request header to w. func (h *RequestHeader) Write(w *bufio.Writer) error { _, err := w.Write(h.Header()) return err } // WriteTo writes request header to w. // // WriteTo implements io.WriterTo interface. func (h *RequestHeader) WriteTo(w io.Writer) (int64, error) { n, err := w.Write(h.Header()) return int64(n), err } // Header returns request header representation. // // The returned representation is valid until the next call to RequestHeader methods. func (h *RequestHeader) Header() []byte { h.bufKV.value = h.AppendBytes(h.bufKV.value[:0]) return h.bufKV.value } // String returns request header representation. func (h *RequestHeader) String() string { return string(h.Header()) } // AppendBytes appends request header representation to dst and returns // the extended dst. func (h *RequestHeader) AppendBytes(dst []byte) []byte { // there is no need in h.parseRawHeaders() here - raw headers are specially handled below. dst = append(dst, h.Method()...) dst = append(dst, ' ') dst = append(dst, h.RequestURI()...) dst = append(dst, ' ') dst = append(dst, strHTTP11...) dst = append(dst, strCRLF...) if !h.rawHeadersParsed && len(h.rawHeaders) > 0 { return append(dst, h.rawHeaders...) } userAgent := h.UserAgent() if len(userAgent) == 0 { userAgent = defaultUserAgent } dst = appendHeaderLine(dst, strUserAgent, userAgent) host := h.Host() if len(host) > 0 { dst = appendHeaderLine(dst, strHost, host) } contentType := h.ContentType() if !h.noBody() { if len(contentType) == 0 { contentType = strPostArgsContentType } dst = appendHeaderLine(dst, strContentType, contentType) if len(h.contentLengthBytes) > 0 { dst = appendHeaderLine(dst, strContentLength, h.contentLengthBytes) } } else if len(contentType) > 0 { dst = appendHeaderLine(dst, strContentType, contentType) } for i, n := 0, len(h.h); i < n; i++ { kv := &h.h[i] dst = appendHeaderLine(dst, kv.key, kv.value) } // there is no need in h.collectCookies() here, since if cookies aren't collected yet, // they all are located in h.h. n := len(h.cookies) if n > 0 { dst = append(dst, strCookie...) dst = append(dst, strColonSpace...) 
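// The "Cookie" key and ": " separator were appended above; all collected
// cookies are now serialized into a single "k1=v1; k2=v2" value, followed
// by CRLF.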
dst = appendRequestCookieBytes(dst, h.cookies) dst = append(dst, strCRLF...) } if h.ConnectionClose() { dst = appendHeaderLine(dst, strConnection, strClose) } return append(dst, strCRLF...) } func appendHeaderLine(dst, key, value []byte) []byte { dst = append(dst, key...) dst = append(dst, strColonSpace...) dst = append(dst, value...) return append(dst, strCRLF...) } func (h *ResponseHeader) parse(buf []byte) (int, error) { m, err := h.parseFirstLine(buf) if err != nil { return 0, err } n, err := h.parseHeaders(buf[m:]) if err != nil { return 0, err } return m + n, nil } func (h *RequestHeader) noBody() bool { return h.IsGet() || h.IsHead() } func (h *RequestHeader) parse(buf []byte) (int, error) { m, err := h.parseFirstLine(buf) if err != nil { return 0, err } var n int if !h.noBody() || h.noHTTP11 { n, err = h.parseHeaders(buf[m:]) if err != nil { return 0, err } h.rawHeadersParsed = true } else { var rawHeaders []byte rawHeaders, n, err = readRawHeaders(h.rawHeaders[:0], buf[m:]) if err != nil { return 0, err } h.rawHeaders = rawHeaders } return m + n, nil } func (h *ResponseHeader) parseFirstLine(buf []byte) (int, error) { bNext := buf var b []byte var err error for len(b) == 0 { if b, bNext, err = nextLine(bNext); err != nil { return 0, err } } // parse protocol n := bytes.IndexByte(b, ' ') if n < 0 { return 0, fmt.Errorf("cannot find whitespace in the first line of response %q", buf) } h.noHTTP11 = !bytes.Equal(b[:n], strHTTP11) b = b[n+1:] // parse status code h.statusCode, n, err = parseUintBuf(b) if err != nil { return 0, fmt.Errorf("cannot parse response status code: %s. Response %q", err, buf) } if len(b) > n && b[n] != ' ' { return 0, fmt.Errorf("unexpected char at the end of status code. Response %q", buf) } return len(buf) - len(bNext), nil } func (h *RequestHeader) parseFirstLine(buf []byte) (int, error) { bNext := buf var b []byte var err error for len(b) == 0 { if b, bNext, err = nextLine(bNext); err != nil { return 0, err } } // parse method n := bytes.IndexByte(b, ' ') if n <= 0 { return 0, fmt.Errorf("cannot find http request method in %q", buf) } h.method = append(h.method[:0], b[:n]...) b = b[n+1:] // parse requestURI n = bytes.LastIndexByte(b, ' ') if n < 0 { h.noHTTP11 = true n = len(b) } else if n == 0 { return 0, fmt.Errorf("requestURI cannot be empty in %q", buf) } else if !bytes.Equal(b[n+1:], strHTTP11) { h.noHTTP11 = true } h.requestURI = append(h.requestURI[:0], b[:n]...) return len(buf) - len(bNext), nil } func peekRawHeader(buf, key []byte) []byte { n := bytes.Index(buf, key) if n < 0 { return nil } if n > 0 && buf[n-1] != '\n' { return nil } n += len(key) if n >= len(buf) { return nil } if buf[n] != ':' { return nil } n++ if buf[n] != ' ' { return nil } n++ buf = buf[n:] n = bytes.IndexByte(buf, '\n') if n < 0 { return nil } if n > 0 && buf[n-1] == '\r' { n-- } return buf[:n] } func readRawHeaders(dst, buf []byte) ([]byte, int, error) { n := bytes.IndexByte(buf, '\n') if n < 0 { return nil, 0, errNeedMore } if (n == 1 && buf[0] == '\r') || n == 0 { // empty headers return dst, n + 1, nil } n++ b := buf m := n for { b = b[m:] m = bytes.IndexByte(b, '\n') if m < 0 { return nil, 0, errNeedMore } m++ n += m if (m == 2 && b[0] == '\r') || m == 1 { dst = append(dst, buf[:n]...) 
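// The empty line terminating the header block was found: dst holds the raw
// headers up to and including that empty line, and n is the number of bytes
// consumed from buf.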
return dst, n, nil } } } func (h *ResponseHeader) parseHeaders(buf []byte) (int, error) { // 'identity' content-length by default h.contentLength = -2 var s headerScanner s.b = buf s.disableNormalizing = h.disableNormalizing var err error var kv *argsKV for s.next() { switch string(s.key) { case "Content-Type": h.contentType = append(h.contentType[:0], s.value...) case "Server": h.server = append(h.server[:0], s.value...) case "Content-Length": if h.contentLength != -1 { if h.contentLength, err = parseContentLength(s.value); err != nil { h.contentLength = -2 } else { h.contentLengthBytes = append(h.contentLengthBytes[:0], s.value...) } } case "Transfer-Encoding": if !bytes.Equal(s.value, strIdentity) { h.contentLength = -1 h.h = setArgBytes(h.h, strTransferEncoding, strChunked) } case "Set-Cookie": h.cookies, kv = allocArg(h.cookies) kv.key = getCookieKey(kv.key, s.value) kv.value = append(kv.value[:0], s.value...) case "Connection": if bytes.Equal(s.value, strClose) { h.connectionClose = true } else { h.connectionClose = false h.h = appendArgBytes(h.h, s.key, s.value) } default: h.h = appendArgBytes(h.h, s.key, s.value) } } if s.err != nil { h.connectionClose = true return 0, s.err } if h.contentLength < 0 { h.contentLengthBytes = h.contentLengthBytes[:0] } if h.contentLength == -2 && !h.ConnectionUpgrade() && !h.mustSkipContentLength() { h.h = setArgBytes(h.h, strTransferEncoding, strIdentity) h.connectionClose = true } if h.noHTTP11 && !h.connectionClose { // close connection for non-http/1.1 response unless 'Connection: keep-alive' is set. v := peekArgBytes(h.h, strConnection) h.connectionClose = !hasHeaderValue(v, strKeepAlive) && !hasHeaderValue(v, strKeepAliveCamelCase) } return len(buf) - len(s.b), nil } func (h *RequestHeader) parseHeaders(buf []byte) (int, error) { h.contentLength = -2 var s headerScanner s.b = buf s.disableNormalizing = h.disableNormalizing var err error for s.next() { switch string(s.key) { case "Host": h.host = append(h.host[:0], s.value...) case "User-Agent": h.userAgent = append(h.userAgent[:0], s.value...) case "Content-Type": h.contentType = append(h.contentType[:0], s.value...) case "Content-Length": if h.contentLength != -1 { if h.contentLength, err = parseContentLength(s.value); err != nil { h.contentLength = -2 } else { h.contentLengthBytes = append(h.contentLengthBytes[:0], s.value...) } } case "Transfer-Encoding": if !bytes.Equal(s.value, strIdentity) { h.contentLength = -1 h.h = setArgBytes(h.h, strTransferEncoding, strChunked) } case "Connection": if bytes.Equal(s.value, strClose) { h.connectionClose = true } else { h.connectionClose = false h.h = appendArgBytes(h.h, s.key, s.value) } default: h.h = appendArgBytes(h.h, s.key, s.value) } } if s.err != nil { h.connectionClose = true return 0, s.err } if h.contentLength < 0 { h.contentLengthBytes = h.contentLengthBytes[:0] } if h.noBody() { h.contentLength = 0 h.contentLengthBytes = h.contentLengthBytes[:0] } if h.noHTTP11 && !h.connectionClose { // close connection for non-http/1.1 request unless 'Connection: keep-alive' is set. 
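// Non-HTTP/1.1 connections are not persistent by default, so the client
// must opt in explicitly with a keep-alive token.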
v := peekArgBytes(h.h, strConnection) h.connectionClose = !hasHeaderValue(v, strKeepAlive) && !hasHeaderValue(v, strKeepAliveCamelCase) } return len(buf) - len(s.b), nil } func (h *RequestHeader) parseRawHeaders() { if h.rawHeadersParsed { return } h.rawHeadersParsed = true if len(h.rawHeaders) == 0 { return } h.parseHeaders(h.rawHeaders) } func (h *RequestHeader) collectCookies() { if h.cookiesCollected { return } for i, n := 0, len(h.h); i < n; i++ { kv := &h.h[i] if bytes.Equal(kv.key, strCookie) { h.cookies = parseRequestCookies(h.cookies, kv.value) tmp := *kv copy(h.h[i:], h.h[i+1:]) n-- i-- h.h[n] = tmp h.h = h.h[:n] } } h.cookiesCollected = true } func parseContentLength(b []byte) (int, error) { v, n, err := parseUintBuf(b) if err != nil { return -1, err } if n != len(b) { return -1, fmt.Errorf("non-numeric chars at the end of Content-Length") } return v, nil } type headerScanner struct { b []byte key []byte value []byte err error disableNormalizing bool } func (s *headerScanner) next() bool { bLen := len(s.b) if bLen >= 2 && s.b[0] == '\r' && s.b[1] == '\n' { s.b = s.b[2:] return false } if bLen >= 1 && s.b[0] == '\n' { s.b = s.b[1:] return false } n := bytes.IndexByte(s.b, ':') if n < 0 { s.err = errNeedMore return false } s.key = s.b[:n] normalizeHeaderKey(s.key, s.disableNormalizing) n++ for len(s.b) > n && s.b[n] == ' ' { n++ } s.b = s.b[n:] n = bytes.IndexByte(s.b, '\n') if n < 0 { s.err = errNeedMore return false } s.value = s.b[:n] s.b = s.b[n+1:] if n > 0 && s.value[n-1] == '\r' { n-- } for n > 0 && s.value[n-1] == ' ' { n-- } s.value = s.value[:n] return true } type headerValueScanner struct { b []byte value []byte } func (s *headerValueScanner) next() bool { b := s.b if len(b) == 0 { return false } n := bytes.IndexByte(b, ',') if n < 0 { s.value = stripSpace(b) s.b = b[len(b):] return true } s.value = stripSpace(b[:n]) s.b = b[n+1:] return true } func stripSpace(b []byte) []byte { for len(b) > 0 && b[0] == ' ' { b = b[1:] } for len(b) > 0 && b[len(b)-1] == ' ' { b = b[:len(b)-1] } return b } func hasHeaderValue(s, value []byte) bool { var vs headerValueScanner vs.b = s for vs.next() { if bytes.Equal(vs.value, value) { return true } } return false } func nextLine(b []byte) ([]byte, []byte, error) { nNext := bytes.IndexByte(b, '\n') if nNext < 0 { return nil, nil, errNeedMore } n := nNext if n > 0 && b[n-1] == '\r' { n-- } return b[:n], b[nNext+1:], nil } func initHeaderKV(kv *argsKV, key, value string, disableNormalizing bool) { kv.key = getHeaderKeyBytes(kv, key, disableNormalizing) kv.value = append(kv.value[:0], value...) } func getHeaderKeyBytes(kv *argsKV, key string, disableNormalizing bool) []byte { kv.key = append(kv.key[:0], key...) normalizeHeaderKey(kv.key, disableNormalizing) return kv.key } func normalizeHeaderKey(b []byte, disableNormalizing bool) { if disableNormalizing { return } n := len(b) up := true for i := 0; i < n; i++ { switch b[i] { case '-': up = true default: if up { up = false uppercaseByte(&b[i]) } else { lowercaseByte(&b[i]) } } } } // AppendNormalizedHeaderKey appends normalized header key (name) to dst // and returns the resulting dst. // // Normalized header key starts with uppercase letter. The first letters // after dashes are also uppercased. All the other letters are lowercased. // Examples: // // * coNTENT-TYPe -> Content-Type // * HOST -> Host // * foo-bar-baz -> Foo-Bar-Baz func AppendNormalizedHeaderKey(dst []byte, key string) []byte { dst = append(dst, key...) 
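// Normalize only the bytes just appended (the last len(key) bytes of dst),
// leaving any previously appended data untouched.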
normalizeHeaderKey(dst[len(dst)-len(key):], false) return dst } // AppendNormalizedHeaderKeyBytes appends normalized header key (name) to dst // and returns the resulting dst. // // Normalized header key starts with uppercase letter. The first letters // after dashes are also uppercased. All the other letters are lowercased. // Examples: // // * coNTENT-TYPe -> Content-Type // * HOST -> Host // * foo-bar-baz -> Foo-Bar-Baz func AppendNormalizedHeaderKeyBytes(dst, key []byte) []byte { return AppendNormalizedHeaderKey(dst, b2s(key)) } var errNeedMore = errors.New("need more data: cannot find trailing lf") func mustPeekBuffered(r *bufio.Reader) []byte { buf, err := r.Peek(r.Buffered()) if len(buf) == 0 || err != nil { panic(fmt.Sprintf("bufio.Reader.Peek() returned unexpected data (%q, %v)", buf, err)) } return buf } func mustDiscard(r *bufio.Reader, n int) { if _, err := r.Discard(n); err != nil { panic(fmt.Sprintf("bufio.Reader.Discard(%d) failed: %s", n, err)) } } golang-github-valyala-fasthttp-20160617/header_regression_test.go000066400000000000000000000050121273074646000250600ustar00rootroot00000000000000package fasthttp import ( "bufio" "bytes" "fmt" "strings" "testing" ) func TestIssue28ResponseWithoutBodyNoContentType(t *testing.T) { var r Response // Empty response without content-type s := r.String() if strings.Contains(s, "Content-Type") { t.Fatalf("unexpected Content-Type found in response header with empty body: %q", s) } // Explicitly set content-type r.Header.SetContentType("foo/bar") s = r.String() if !strings.Contains(s, "Content-Type: foo/bar\r\n") { t.Fatalf("missing explicitly set content-type for empty response: %q", s) } // Non-empty response. r.Reset() r.SetBodyString("foobar") s = r.String() if !strings.Contains(s, fmt.Sprintf("Content-Type: %s\r\n", defaultContentType)) { t.Fatalf("missing default content-type for non-empty response: %q", s) } // Non-empty response with custom content-type. r.Header.SetContentType("aaa/bbb") s = r.String() if !strings.Contains(s, "Content-Type: aaa/bbb\r\n") { t.Fatalf("missing custom content-type: %q", s) } } func TestIssue6RequestHeaderSetContentType(t *testing.T) { testIssue6RequestHeaderSetContentType(t, "GET") testIssue6RequestHeaderSetContentType(t, "POST") testIssue6RequestHeaderSetContentType(t, "PUT") testIssue6RequestHeaderSetContentType(t, "PATCH") } func testIssue6RequestHeaderSetContentType(t *testing.T, method string) { contentType := "application/json" contentLength := 123 var h RequestHeader h.SetMethod(method) h.SetRequestURI("http://localhost/test") h.SetContentType(contentType) h.SetContentLength(contentLength) issue6VerifyRequestHeader(t, &h, contentType, contentLength, method) s := h.String() var h1 RequestHeader br := bufio.NewReader(bytes.NewBufferString(s)) if err := h1.Read(br); err != nil { t.Fatalf("unexpected error: %s", err) } issue6VerifyRequestHeader(t, &h1, contentType, contentLength, method) } func issue6VerifyRequestHeader(t *testing.T, h *RequestHeader, contentType string, contentLength int, method string) { if string(h.ContentType()) != contentType { t.Fatalf("unexpected content-type: %q. Expecting %q. method=%q", h.ContentType(), contentType, method) } if string(h.Method()) != method { t.Fatalf("unexpected method: %q. Expecting %q", h.Method(), method) } if method != "GET" { if h.ContentLength() != contentLength { t.Fatalf("unexpected content-length: %d. Expecting %d. 
method=%q", h.ContentLength(), contentLength, method) } } else if h.ContentLength() != 0 { t.Fatalf("unexpected content-length for GET method: %d. Expecting 0", h.ContentLength()) } } golang-github-valyala-fasthttp-20160617/header_test.go000066400000000000000000001722041273074646000226300ustar00rootroot00000000000000package fasthttp import ( "bufio" "bytes" "fmt" "io" "io/ioutil" "strings" "testing" ) func TestResponseHeaderDefaultStatusCode(t *testing.T) { var h ResponseHeader statusCode := h.StatusCode() if statusCode != StatusOK { t.Fatalf("unexpected status code: %d. Expecting %d", statusCode, StatusOK) } } func TestResponseHeaderDelClientCookie(t *testing.T) { cookieName := "foobar" var h ResponseHeader c := AcquireCookie() c.SetKey(cookieName) c.SetValue("aasdfsdaf") h.SetCookie(c) h.DelClientCookieBytes([]byte(cookieName)) if !h.Cookie(c) { t.Fatalf("expecting cookie %q", c.Key()) } if !c.Expire().Equal(CookieExpireDelete) { t.Fatalf("unexpected cookie expiration time: %s. Expecting %s", c.Expire(), CookieExpireDelete) } if len(c.Value()) > 0 { t.Fatalf("unexpected cookie value: %q. Expecting empty value", c.Value()) } ReleaseCookie(c) } func TestResponseHeaderAdd(t *testing.T) { m := make(map[string]struct{}) var h ResponseHeader h.Add("aaa", "bbb") m["bbb"] = struct{}{} for i := 0; i < 10; i++ { v := fmt.Sprintf("%d", i) h.Add("Foo-Bar", v) m[v] = struct{}{} } if h.Len() != 12 { t.Fatalf("unexpected header len %d. Expecting 12", h.Len()) } h.VisitAll(func(k, v []byte) { switch string(k) { case "Aaa", "Foo-Bar": if _, ok := m[string(v)]; !ok { t.Fatalf("unexpected value found %q. key %q", v, k) } delete(m, string(v)) case "Content-Type": default: t.Fatalf("unexpected key found: %q", k) } }) if len(m) > 0 { t.Fatalf("%d headers are missed", len(m)) } s := h.String() br := bufio.NewReader(bytes.NewBufferString(s)) var h1 ResponseHeader if err := h1.Read(br); err != nil { t.Fatalf("unexpected error: %s", err) } h.VisitAll(func(k, v []byte) { switch string(k) { case "Aaa", "Foo-Bar": m[string(v)] = struct{}{} case "Content-Type": default: t.Fatalf("unexpected key found: %q", k) } }) if len(m) != 11 { t.Fatalf("unexpected number of headers: %d. Expecting 11", len(m)) } } func TestRequestHeaderAdd(t *testing.T) { m := make(map[string]struct{}) var h RequestHeader h.Add("aaa", "bbb") m["bbb"] = struct{}{} for i := 0; i < 10; i++ { v := fmt.Sprintf("%d", i) h.Add("Foo-Bar", v) m[v] = struct{}{} } if h.Len() != 11 { t.Fatalf("unexpected header len %d. Expecting 11", h.Len()) } h.VisitAll(func(k, v []byte) { switch string(k) { case "Aaa", "Foo-Bar": if _, ok := m[string(v)]; !ok { t.Fatalf("unexpected value found %q. key %q", v, k) } delete(m, string(v)) default: t.Fatalf("unexpected key found: %q", k) } }) if len(m) > 0 { t.Fatalf("%d headers are missed", len(m)) } s := h.String() br := bufio.NewReader(bytes.NewBufferString(s)) var h1 RequestHeader if err := h1.Read(br); err != nil { t.Fatalf("unexpected error: %s", err) } h.VisitAll(func(k, v []byte) { switch string(k) { case "Aaa", "Foo-Bar": m[string(v)] = struct{}{} case "User-Agent": default: t.Fatalf("unexpected key found: %q", k) } }) if len(m) != 11 { t.Fatalf("unexpected number of headers: %d. Expecting 11", len(m)) } s1 := h1.String() if s != s1 { t.Fatalf("unexpected headers %q. 
Expecting %q", s1, s) } } func TestHasHeaderValue(t *testing.T) { testHasHeaderValue(t, "foobar", "foobar", true) testHasHeaderValue(t, "foobar", "foo", false) testHasHeaderValue(t, "foobar", "bar", false) testHasHeaderValue(t, "keep-alive, Upgrade", "keep-alive", true) testHasHeaderValue(t, "keep-alive , Upgrade", "Upgrade", true) testHasHeaderValue(t, "keep-alive, Upgrade", "Upgrade-foo", false) testHasHeaderValue(t, "keep-alive, Upgrade", "Upgr", false) testHasHeaderValue(t, "foo , bar, baz ,", "foo", true) testHasHeaderValue(t, "foo , bar, baz ,", "bar", true) testHasHeaderValue(t, "foo , bar, baz ,", "baz", true) testHasHeaderValue(t, "foo , bar, baz ,", "ba", false) testHasHeaderValue(t, "foo, ", "", true) testHasHeaderValue(t, "foo", "", false) } func testHasHeaderValue(t *testing.T, s, value string, has bool) { ok := hasHeaderValue([]byte(s), []byte(value)) if ok != has { t.Fatalf("unexpected hasHeaderValue(%q, %q)=%v. Expecting %v", s, value, ok, has) } } func TestRequestHeaderDel(t *testing.T) { var h RequestHeader h.Set("Foo-Bar", "baz") h.Set("aaa", "bbb") h.Set("Connection", "keep-alive") h.Set("Content-Type", "aaa") h.Set("Host", "aaabbb") h.Set("User-Agent", "asdfas") h.Set("Content-Length", "1123") h.Set("Cookie", "foobar=baz") h.Del("foo-bar") h.Del("connection") h.DelBytes([]byte("content-type")) h.Del("Host") h.Del("user-agent") h.Del("content-length") h.Del("cookie") hv := h.Peek("aaa") if string(hv) != "bbb" { t.Fatalf("unexpected header value: %q. Expecting %q", hv, "bbb") } hv = h.Peek("Foo-Bar") if len(hv) > 0 { t.Fatalf("non-zero value: %q", hv) } hv = h.Peek("Connection") if len(hv) > 0 { t.Fatalf("non-zero value: %q", hv) } hv = h.Peek("Content-Type") if len(hv) > 0 { t.Fatalf("non-zero value: %q", hv) } hv = h.Peek("Host") if len(hv) > 0 { t.Fatalf("non-zero value: %q", hv) } hv = h.Peek("User-Agent") if len(hv) > 0 { t.Fatalf("non-zero value: %q", hv) } hv = h.Peek("Content-Length") if len(hv) > 0 { t.Fatalf("non-zero value: %q", hv) } hv = h.Peek("Cookie") if len(hv) > 0 { t.Fatalf("non-zero value: %q", hv) } cv := h.Cookie("foobar") if len(cv) > 0 { t.Fatalf("unexpected cookie obtianed: %q", cv) } if h.ContentLength() != 0 { t.Fatalf("unexpected content-length: %d. Expecting 0", h.ContentLength()) } } func TestResponseHeaderDel(t *testing.T) { var h ResponseHeader h.Set("Foo-Bar", "baz") h.Set("aaa", "bbb") h.Set("Connection", "keep-alive") h.Set("Content-Type", "aaa") h.Set("Server", "aaabbb") h.Set("Content-Length", "1123") var c Cookie c.SetKey("foo") c.SetValue("bar") h.SetCookie(&c) h.Del("foo-bar") h.Del("connection") h.DelBytes([]byte("content-type")) h.Del("Server") h.Del("content-length") h.Del("set-cookie") hv := h.Peek("aaa") if string(hv) != "bbb" { t.Fatalf("unexpected header value: %q. Expecting %q", hv, "bbb") } hv = h.Peek("Foo-Bar") if len(hv) > 0 { t.Fatalf("non-zero header value: %q", hv) } hv = h.Peek("Connection") if len(hv) > 0 { t.Fatalf("non-zero value: %q", hv) } hv = h.Peek("Content-Type") if string(hv) != string(defaultContentType) { t.Fatalf("unexpected content-type: %q. Expecting %q", hv, defaultContentType) } hv = h.Peek("Server") if len(hv) > 0 { t.Fatalf("non-zero value: %q", hv) } hv = h.Peek("Content-Length") if len(hv) > 0 { t.Fatalf("non-zero value: %q", hv) } if h.Cookie(&c) { t.Fatalf("unexpected cookie obtianed: %q", &c) } if h.ContentLength() != 0 { t.Fatalf("unexpected content-length: %d. 
Expecting 0", h.ContentLength()) } } func TestAppendNormalizedHeaderKeyBytes(t *testing.T) { testAppendNormalizedHeaderKeyBytes(t, "", "") testAppendNormalizedHeaderKeyBytes(t, "Content-Type", "Content-Type") testAppendNormalizedHeaderKeyBytes(t, "foO-bAr-BAZ", "Foo-Bar-Baz") } func testAppendNormalizedHeaderKeyBytes(t *testing.T, key, expectedKey string) { buf := []byte("foobar") result := AppendNormalizedHeaderKeyBytes(buf, []byte(key)) normalizedKey := result[len(buf):] if string(normalizedKey) != expectedKey { t.Fatalf("unexpected normalized key %q. Expecting %q", normalizedKey, expectedKey) } } func TestRequestHeaderHTTP10ConnectionClose(t *testing.T) { s := "GET / HTTP/1.0\r\nHost: foobar\r\n\r\n" var h RequestHeader br := bufio.NewReader(bytes.NewBufferString(s)) if err := h.Read(br); err != nil { t.Fatalf("unexpected error: %s", err) } if !h.connectionCloseFast() { t.Fatalf("expecting 'Connection: close' request header") } if !h.ConnectionClose() { t.Fatalf("expecting 'Connection: close' request header") } } func TestRequestHeaderHTTP10ConnectionKeepAlive(t *testing.T) { s := "GET / HTTP/1.0\r\nHost: foobar\r\nConnection: keep-alive\r\n\r\n" var h RequestHeader br := bufio.NewReader(bytes.NewBufferString(s)) if err := h.Read(br); err != nil { t.Fatalf("unexpected error: %s", err) } if h.ConnectionClose() { t.Fatalf("unexpected 'Connection: close' request header") } } func TestBufferStartEnd(t *testing.T) { testBufferStartEnd(t, "", "", "") testBufferStartEnd(t, "foobar", "foobar", "") b := string(createFixedBody(199)) testBufferStartEnd(t, b, b, "") for i := 0; i < 10; i++ { b += "foobar" testBufferStartEnd(t, b, b, "") } b = string(createFixedBody(400)) testBufferStartEnd(t, b, b, "") for i := 0; i < 10; i++ { b += "sadfqwer" testBufferStartEnd(t, b, b[:200], b[len(b)-200:]) } } func testBufferStartEnd(t *testing.T, buf, expectedStart, expectedEnd string) { start, end := bufferStartEnd([]byte(buf)) if string(start) != expectedStart { t.Fatalf("unexpected start %q. Expecting %q. buf %q", start, expectedStart, buf) } if string(end) != expectedEnd { t.Fatalf("unexpected end %q. Expecting %q. buf %q", end, expectedEnd, buf) } } func TestResponseHeaderTrailingCRLFSuccess(t *testing.T) { trailingCRLF := "\r\n\r\n\r\n" s := "HTTP/1.1 200 OK\r\nContent-Type: aa\r\nContent-Length: 123\r\n\r\n" + trailingCRLF var r ResponseHeader br := bufio.NewReader(bytes.NewBufferString(s)) if err := r.Read(br); err != nil { t.Fatalf("unexpected error: %s", err) } // try reading the trailing CRLF. It must return EOF err := r.Read(br) if err == nil { t.Fatalf("expecting error") } if err != io.EOF { t.Fatalf("unexpected error: %s. Expecting %s", err, io.EOF) } } func TestResponseHeaderTrailingCRLFError(t *testing.T) { trailingCRLF := "\r\nerror\r\n\r\n" s := "HTTP/1.1 200 OK\r\nContent-Type: aa\r\nContent-Length: 123\r\n\r\n" + trailingCRLF var r ResponseHeader br := bufio.NewReader(bytes.NewBufferString(s)) if err := r.Read(br); err != nil { t.Fatalf("unexpected error: %s", err) } // try reading the trailing CRLF. It must return EOF err := r.Read(br) if err == nil { t.Fatalf("expecting error") } if err == io.EOF { t.Fatalf("unexpected error: %s", err) } } func TestRequestHeaderTrailingCRLFSuccess(t *testing.T) { trailingCRLF := "\r\n\r\n\r\n" s := "GET / HTTP/1.1\r\nHost: aaa.com\r\n\r\n" + trailingCRLF var r RequestHeader br := bufio.NewReader(bytes.NewBufferString(s)) if err := r.Read(br); err != nil { t.Fatalf("unexpected error: %s", err) } // try reading the trailing CRLF. 
It must return EOF err := r.Read(br) if err == nil { t.Fatalf("expecting error") } if err != io.EOF { t.Fatalf("unexpected error: %s. Expecting %s", err, io.EOF) } } func TestRequestHeaderTrailingCRLFError(t *testing.T) { trailingCRLF := "\r\nerror\r\n\r\n" s := "GET / HTTP/1.1\r\nHost: aaa.com\r\n\r\n" + trailingCRLF var r RequestHeader br := bufio.NewReader(bytes.NewBufferString(s)) if err := r.Read(br); err != nil { t.Fatalf("unexpected error: %s", err) } // try reading the trailing CRLF. It must return EOF err := r.Read(br) if err == nil { t.Fatalf("expecting error") } if err == io.EOF { t.Fatalf("unexpected error: %s", err) } } func TestRequestHeaderReadEOF(t *testing.T) { var r RequestHeader br := bufio.NewReader(&bytes.Buffer{}) err := r.Read(br) if err == nil { t.Fatalf("expecting error") } if err != io.EOF { t.Fatalf("unexpected error: %s. Expecting %s", err, io.EOF) } // incomplete request header mustn't return io.EOF br = bufio.NewReader(bytes.NewBufferString("GET ")) err = r.Read(br) if err == nil { t.Fatalf("expecting error") } if err == io.EOF { t.Fatalf("expecting non-EOF error") } } func TestResponseHeaderReadEOF(t *testing.T) { var r ResponseHeader br := bufio.NewReader(&bytes.Buffer{}) err := r.Read(br) if err == nil { t.Fatalf("expecting error") } if err != io.EOF { t.Fatalf("unexpected error: %s. Expecting %s", err, io.EOF) } // incomplete response header mustn't return io.EOF br = bufio.NewReader(bytes.NewBufferString("HTTP/1.1 ")) err = r.Read(br) if err == nil { t.Fatalf("expecting error") } if err == io.EOF { t.Fatalf("expecting non-EOF error") } } func TestResponseHeaderOldVersion(t *testing.T) { var h ResponseHeader s := "HTTP/1.0 200 OK\r\nContent-Length: 5\r\nContent-Type: aaa\r\n\r\n12345" s += "HTTP/1.0 200 OK\r\nContent-Length: 2\r\nContent-Type: ass\r\nConnection: keep-alive\r\n\r\n42" br := bufio.NewReader(bytes.NewBufferString(s)) if err := h.Read(br); err != nil { t.Fatalf("unexpected error: %s", err) } if !h.ConnectionClose() { t.Fatalf("expecting 'Connection: close' for the response with old http protocol") } if err := h.Read(br); err != nil { t.Fatalf("unexpected error: %s", err) } if h.ConnectionClose() { t.Fatalf("unexpected 'Connection: close' for keep-alive response with old http protocol") } } func TestRequestHeaderSetByteRange(t *testing.T) { testRequestHeaderSetByteRange(t, 0, 10, "bytes=0-10") testRequestHeaderSetByteRange(t, 123, -1, "bytes=123-") testRequestHeaderSetByteRange(t, -234, 58349, "bytes=-234") } func testRequestHeaderSetByteRange(t *testing.T, startPos, endPos int, expectedV string) { var h RequestHeader h.SetByteRange(startPos, endPos) v := h.Peek("Range") if string(v) != expectedV { t.Fatalf("unexpected range: %q. Expecting %q. startPos=%d, endPos=%d", v, expectedV, startPos, endPos) } } func TestResponseHeaderSetContentRange(t *testing.T) { testResponseHeaderSetContentRange(t, 0, 0, 1, "bytes 0-0/1") testResponseHeaderSetContentRange(t, 123, 456, 789, "bytes 123-456/789") } func testResponseHeaderSetContentRange(t *testing.T, startPos, endPos, contentLength int, expectedV string) { var h ResponseHeader h.SetContentRange(startPos, endPos, contentLength) v := h.Peek("Content-Range") if string(v) != expectedV { t.Fatalf("unexpected content-range: %q. Expecting %q. 
startPos=%d, endPos=%d, contentLength=%d", v, expectedV, startPos, endPos, contentLength) } } func TestRequestHeaderHasAcceptEncoding(t *testing.T) { testRequestHeaderHasAcceptEncoding(t, "", "gzip", false) testRequestHeaderHasAcceptEncoding(t, "gzip", "sdhc", false) testRequestHeaderHasAcceptEncoding(t, "deflate", "deflate", true) testRequestHeaderHasAcceptEncoding(t, "gzip, deflate, sdhc", "gzi", false) testRequestHeaderHasAcceptEncoding(t, "gzip, deflate, sdhc", "dhc", false) testRequestHeaderHasAcceptEncoding(t, "gzip, deflate, sdhc", "sdh", false) testRequestHeaderHasAcceptEncoding(t, "gzip, deflate, sdhc", "zip", false) testRequestHeaderHasAcceptEncoding(t, "gzip, deflate, sdhc", "flat", false) testRequestHeaderHasAcceptEncoding(t, "gzip, deflate, sdhc", "flate", false) testRequestHeaderHasAcceptEncoding(t, "gzip, deflate, sdhc", "def", false) testRequestHeaderHasAcceptEncoding(t, "gzip, deflate, sdhc", "gzip", true) testRequestHeaderHasAcceptEncoding(t, "gzip, deflate, sdhc", "deflate", true) testRequestHeaderHasAcceptEncoding(t, "gzip, deflate, sdhc", "sdhc", true) } func testRequestHeaderHasAcceptEncoding(t *testing.T, ae, v string, resultExpected bool) { var h RequestHeader h.Set("Accept-Encoding", ae) result := h.HasAcceptEncoding(v) if result != resultExpected { t.Fatalf("unexpected result in HasAcceptEncoding(%q, %q): %v. Expecting %v", ae, v, result, resultExpected) } } func TestRequestMultipartFormBoundary(t *testing.T) { testRequestMultipartFormBoundary(t, "POST / HTTP/1.1\r\nContent-Type: multipart/form-data; boundary=foobar\r\n\r\n", "foobar") // incorrect content-type testRequestMultipartFormBoundary(t, "POST / HTTP/1.1\r\nContent-Type: foo/bar\r\n\r\n", "") // empty boundary testRequestMultipartFormBoundary(t, "POST / HTTP/1.1\r\nContent-Type: multipart/form-data; boundary=\r\n\r\n", "") // missing boundary testRequestMultipartFormBoundary(t, "POST / HTTP/1.1\r\nContent-Type: multipart/form-data\r\n\r\n", "") // boundary after other content-type params testRequestMultipartFormBoundary(t, "POST / HTTP/1.1\r\nContent-Type: multipart/form-data; foo=bar; boundary=--aaabb \r\n\r\n", "--aaabb") var h RequestHeader h.SetMultipartFormBoundary("foobarbaz") b := h.MultipartFormBoundary() if string(b) != "foobarbaz" { t.Fatalf("unexpected boundary %q. Expecting %q", b, "foobarbaz") } } func testRequestMultipartFormBoundary(t *testing.T, s, boundary string) { var h RequestHeader r := bytes.NewBufferString(s) br := bufio.NewReader(r) if err := h.Read(br); err != nil { t.Fatalf("unexpected error: %s. s=%q, boundary=%q", err, s, boundary) } b := h.MultipartFormBoundary() if string(b) != boundary { t.Fatalf("unexpected boundary %q. Expecting %q. 
s=%q", b, boundary, s) } } func TestResponseHeaderConnectionUpgrade(t *testing.T) { testResponseHeaderConnectionUpgrade(t, "HTTP/1.1 200 OK\r\nContent-Length: 10\r\nConnection: Upgrade, HTTP2-Settings\r\n\r\n", true, true) testResponseHeaderConnectionUpgrade(t, "HTTP/1.1 200 OK\r\nContent-Length: 10\r\nConnection: keep-alive, Upgrade\r\n\r\n", true, true) // non-http/1.1 protocol has 'connection: close' by default, which also disables 'connection: upgrade' testResponseHeaderConnectionUpgrade(t, "HTTP/1.0 200 OK\r\nContent-Length: 10\r\nConnection: Upgrade, HTTP2-Settings\r\n\r\n", false, false) // explicit keep-alive for non-http/1.1, so 'connection: upgrade' works testResponseHeaderConnectionUpgrade(t, "HTTP/1.0 200 OK\r\nContent-Length: 10\r\nConnection: Upgrade, keep-alive\r\n\r\n", true, true) // implicit keep-alive for http/1.1 testResponseHeaderConnectionUpgrade(t, "HTTP/1.1 200 OK\r\nContent-Length: 10\r\n\r\n", false, true) // no content-length, so 'connection: close' is assumed testResponseHeaderConnectionUpgrade(t, "HTTP/1.1 200 OK\r\n\r\n", false, false) } func testResponseHeaderConnectionUpgrade(t *testing.T, s string, isUpgrade, isKeepAlive bool) { var h ResponseHeader r := bytes.NewBufferString(s) br := bufio.NewReader(r) if err := h.Read(br); err != nil { t.Fatalf("unexpected error: %s. Response header %q", err, s) } upgrade := h.ConnectionUpgrade() if upgrade != isUpgrade { t.Fatalf("unexpected 'connection: upgrade' when parsing response header: %v. Expecting %v. header %q. v=%q", upgrade, isUpgrade, s, h.Peek("Connection")) } keepAlive := !h.ConnectionClose() if keepAlive != isKeepAlive { t.Fatalf("unexpected 'connection: keep-alive' when parsing response header: %v. Expecting %v. header %q. v=%q", keepAlive, isKeepAlive, s, &h) } } func TestRequestHeaderConnectionUpgrade(t *testing.T) { testRequestHeaderConnectionUpgrade(t, "GET /foobar HTTP/1.1\r\nConnection: Upgrade, HTTP2-Settings\r\nHost: foobar.com\r\n\r\n", true, true) testRequestHeaderConnectionUpgrade(t, "GET /foobar HTTP/1.1\r\nConnection: keep-alive,Upgrade\r\nHost: foobar.com\r\n\r\n", true, true) // non-http/1.1 has 'connection: close' by default, which resets 'connection: upgrade' testRequestHeaderConnectionUpgrade(t, "GET /foobar HTTP/1.0\r\nConnection: Upgrade, HTTP2-Settings\r\nHost: foobar.com\r\n\r\n", false, false) // explicit 'connection: keep-alive' in non-http/1.1 testRequestHeaderConnectionUpgrade(t, "GET /foobar HTTP/1.0\r\nConnection: foo, Upgrade, keep-alive\r\nHost: foobar.com\r\n\r\n", true, true) // no upgrade testRequestHeaderConnectionUpgrade(t, "GET /foobar HTTP/1.1\r\nConnection: Upgradess, foobar\r\nHost: foobar.com\r\n\r\n", false, true) testRequestHeaderConnectionUpgrade(t, "GET /foobar HTTP/1.1\r\nHost: foobar.com\r\n\r\n", false, true) // explicit connection close testRequestHeaderConnectionUpgrade(t, "GET /foobar HTTP/1.1\r\nConnection: close\r\nHost: foobar.com\r\n\r\n", false, false) } func testRequestHeaderConnectionUpgrade(t *testing.T, s string, isUpgrade, isKeepAlive bool) { var h RequestHeader r := bytes.NewBufferString(s) br := bufio.NewReader(r) if err := h.Read(br); err != nil { t.Fatalf("unexpected error: %s. Request header %q", err, s) } upgrade := h.ConnectionUpgrade() if upgrade != isUpgrade { t.Fatalf("unexpected 'connection: upgrade' when parsing request header: %v. Expecting %v. header %q", upgrade, isUpgrade, s) } keepAlive := !h.ConnectionClose() if keepAlive != isKeepAlive { t.Fatalf("unexpected 'connection: keep-alive' when parsing request header: %v. 
Expecting %v. header %q", keepAlive, isKeepAlive, s) } } func TestRequestHeaderProxyWithCookie(t *testing.T) { // Proxy request header (read it, then write it without touching any headers). var h RequestHeader r := bytes.NewBufferString("GET /foo HTTP/1.1\r\nFoo: bar\r\nHost: aaa.com\r\nCookie: foo=bar; bazzz=aaaaaaa; x=y\r\nCookie: aqqqqq=123\r\n\r\n") br := bufio.NewReader(r) if err := h.Read(br); err != nil { t.Fatalf("unexpected error: %s", err) } w := &bytes.Buffer{} bw := bufio.NewWriter(w) if err := h.Write(bw); err != nil { t.Fatalf("unexpected error: %s", err) } if err := bw.Flush(); err != nil { t.Fatalf("unexpected error: %s", err) } var h1 RequestHeader br.Reset(w) if err := h1.Read(br); err != nil { t.Fatalf("unexpected error: %s", err) } if string(h1.RequestURI()) != "/foo" { t.Fatalf("unexpected requestURI: %q. Expecting %q", h1.RequestURI(), "/foo") } if string(h1.Host()) != "aaa.com" { t.Fatalf("unexpected host: %q. Expecting %q", h1.Host(), "aaa.com") } if string(h1.Peek("Foo")) != "bar" { t.Fatalf("unexpected Foo: %q. Expecting %q", h1.Peek("Foo"), "bar") } if string(h1.Cookie("foo")) != "bar" { t.Fatalf("unexpected coookie foo=%q. Expecting %q", h1.Cookie("foo"), "bar") } if string(h1.Cookie("bazzz")) != "aaaaaaa" { t.Fatalf("unexpected cookie bazzz=%q. Expecting %q", h1.Cookie("bazzz"), "aaaaaaa") } if string(h1.Cookie("x")) != "y" { t.Fatalf("unexpected cookie x=%q. Expecting %q", h1.Cookie("x"), "y") } if string(h1.Cookie("aqqqqq")) != "123" { t.Fatalf("unexpected cookie aqqqqq=%q. Expecting %q", h1.Cookie("aqqqqq"), "123") } } func TestPeekRawHeader(t *testing.T) { // empty header testPeekRawHeader(t, "", "Foo-Bar", "") // different case testPeekRawHeader(t, "Content-Length: 3443\r\n", "content-length", "") // no trailing crlf testPeekRawHeader(t, "Content-Length: 234", "Content-Length", "") // single header testPeekRawHeader(t, "Content-Length: 12345\r\n", "Content-Length", "12345") // multiple headers testPeekRawHeader(t, "Host: foobar\r\nContent-Length: 434\r\nFoo: bar\r\n\r\n", "Content-Length", "434") // lf without cr testPeekRawHeader(t, "Foo: bar\nConnection: close\nAaa: bbb\ncc: ddd\n", "Connection", "close") } func testPeekRawHeader(t *testing.T, rawHeaders, key string, expectedValue string) { v := peekRawHeader([]byte(rawHeaders), []byte(key)) if string(v) != expectedValue { t.Fatalf("unexpected raw headers value %q. Expected %q. key %q, rawHeaders %q", v, expectedValue, key, rawHeaders) } } func TestResponseHeaderFirstByteReadEOF(t *testing.T) { var h ResponseHeader r := &errorReader{fmt.Errorf("non-eof error")} br := bufio.NewReader(r) err := h.Read(br) if err == nil { t.Fatalf("expecting error") } if err != io.EOF { t.Fatalf("unexpected error %s. Expecting %s", err, io.EOF) } } func TestRequestHeaderFirstByteReadEOF(t *testing.T) { var h RequestHeader r := &errorReader{fmt.Errorf("non-eof error")} br := bufio.NewReader(r) err := h.Read(br) if err == nil { t.Fatalf("expecting error") } if err != io.EOF { t.Fatalf("unexpected error %s. 
Expecting %s", err, io.EOF) } } type errorReader struct { err error } func (r *errorReader) Read(p []byte) (int, error) { return 0, r.err } func TestRequestHeaderEmptyMethod(t *testing.T) { var h RequestHeader if !h.IsGet() { t.Fatalf("empty method must be equivalent to GET") } if h.IsPost() { t.Fatalf("empty method cannot be POST") } if h.IsHead() { t.Fatalf("empty method cannot be HEAD") } if h.IsDelete() { t.Fatalf("empty method cannot be DELETE") } } func TestResponseHeaderHTTPVer(t *testing.T) { // non-http/1.1 testResponseHeaderHTTPVer(t, "HTTP/1.0 200 OK\r\nContent-Type: aaa\r\nContent-Length: 123\r\n\r\n", true) testResponseHeaderHTTPVer(t, "HTTP/0.9 200 OK\r\nContent-Type: aaa\r\nContent-Length: 123\r\n\r\n", true) testResponseHeaderHTTPVer(t, "foobar 200 OK\r\nContent-Type: aaa\r\nContent-Length: 123\r\n\r\n", true) // http/1.1 testResponseHeaderHTTPVer(t, "HTTP/1.1 200 OK\r\nContent-Type: aaa\r\nContent-Length: 123\r\n\r\n", false) } func TestRequestHeaderHTTPVer(t *testing.T) { // non-http/1.1 testRequestHeaderHTTPVer(t, "GET / HTTP/1.0\r\nHost: aa.com\r\n\r\n", true) testRequestHeaderHTTPVer(t, "GET / HTTP/0.9\r\nHost: aa.com\r\n\r\n", true) testRequestHeaderHTTPVer(t, "GET / foobar\r\nHost: aa.com\r\n\r\n", true) // empty http version testRequestHeaderHTTPVer(t, "GET /\r\nHost: aaa.com\r\n\r\n", true) testRequestHeaderHTTPVer(t, "GET / \r\nHost: aaa.com\r\n\r\n", true) // http/1.1 testRequestHeaderHTTPVer(t, "GET / HTTP/1.1\r\nHost: a.com\r\n\r\n", false) } func testResponseHeaderHTTPVer(t *testing.T, s string, connectionClose bool) { var h ResponseHeader r := bytes.NewBufferString(s) br := bufio.NewReader(r) if err := h.Read(br); err != nil { t.Fatalf("unexpected error: %s. response=%q", err, s) } if h.ConnectionClose() != connectionClose { t.Fatalf("unexpected connectionClose %v. Expecting %v. response=%q", h.ConnectionClose(), connectionClose, s) } } func testRequestHeaderHTTPVer(t *testing.T, s string, connectionClose bool) { var h RequestHeader r := bytes.NewBufferString(s) br := bufio.NewReader(r) if err := h.Read(br); err != nil { t.Fatalf("unexpected error: %s. request=%q", err, s) } if h.ConnectionClose() != connectionClose { t.Fatalf("unexpected connectionClose %v. Expecting %v. request=%q", h.ConnectionClose(), connectionClose, s) } } func TestResponseHeaderCopyTo(t *testing.T) { var h ResponseHeader h.Set("Set-Cookie", "foo=bar") h.Set("Content-Type", "foobar") h.Set("AAA-BBB", "aaaa") var h1 ResponseHeader h.CopyTo(&h1) if !bytes.Equal(h1.Peek("Set-cookie"), h.Peek("Set-Cookie")) { t.Fatalf("unexpected cookie %q. Expected %q", h1.Peek("set-cookie"), h.Peek("set-cookie")) } if !bytes.Equal(h1.Peek("Content-Type"), h.Peek("Content-Type")) { t.Fatalf("unexpected content-type %q. Expected %q", h1.Peek("content-type"), h.Peek("content-type")) } if !bytes.Equal(h1.Peek("aaa-bbb"), h.Peek("AAA-BBB")) { t.Fatalf("unexpected aaa-bbb %q. Expected %q", h1.Peek("aaa-bbb"), h.Peek("aaa-bbb")) } } func TestRequestHeaderCopyTo(t *testing.T) { var h RequestHeader h.Set("Cookie", "aa=bb; cc=dd") h.Set("Content-Type", "foobar") h.Set("Host", "aaaa") h.Set("aaaxxx", "123") var h1 RequestHeader h.CopyTo(&h1) if !bytes.Equal(h1.Peek("cookie"), h.Peek("Cookie")) { t.Fatalf("unexpected cookie after copying: %q. Expected %q", h1.Peek("cookie"), h.Peek("cookie")) } if !bytes.Equal(h1.Peek("content-type"), h.Peek("Content-Type")) { t.Fatalf("unexpected content-type %q. 
Expected %q", h1.Peek("content-type"), h.Peek("content-type")) } if !bytes.Equal(h1.Peek("host"), h.Peek("host")) { t.Fatalf("unexpected host %q. Expected %q", h1.Peek("host"), h.Peek("host")) } if !bytes.Equal(h1.Peek("aaaxxx"), h.Peek("aaaxxx")) { t.Fatalf("unexpected aaaxxx %q. Expected %q", h1.Peek("aaaxxx"), h.Peek("aaaxxx")) } } func TestRequestHeaderConnectionClose(t *testing.T) { var h RequestHeader h.Set("Connection", "close") h.Set("Host", "foobar") if !h.ConnectionClose() { t.Fatalf("connection: close not set") } var w bytes.Buffer bw := bufio.NewWriter(&w) if err := h.Write(bw); err != nil { t.Fatalf("unexpected error: %s", err) } if err := bw.Flush(); err != nil { t.Fatalf("unexpected error: %s", err) } var h1 RequestHeader br := bufio.NewReader(&w) if err := h1.Read(br); err != nil { t.Fatalf("error when reading request header: %s", err) } if !h1.ConnectionClose() { t.Fatalf("unexpected connection: close value: %v", h1.ConnectionClose()) } if string(h1.Peek("Connection")) != "close" { t.Fatalf("unexpected connection value: %q. Expecting %q", h.Peek("Connection"), "close") } } func TestRequestHeaderSetCookie(t *testing.T) { var h RequestHeader h.Set("Cookie", "foo=bar; baz=aaa") h.Set("cOOkie", "xx=yyy") if string(h.Cookie("foo")) != "bar" { t.Fatalf("Unexpected cookie %q. Expecting %q", h.Cookie("foo"), "bar") } if string(h.Cookie("baz")) != "aaa" { t.Fatalf("Unexpected cookie %q. Expecting %q", h.Cookie("baz"), "aaa") } if string(h.Cookie("xx")) != "yyy" { t.Fatalf("unexpected cookie %q. Expecting %q", h.Cookie("xx"), "yyy") } } func TestResponseHeaderSetCookie(t *testing.T) { var h ResponseHeader h.Set("set-cookie", "foo=bar; path=/aa/bb; domain=aaa.com") h.Set("Set-Cookie", "aaaaa=bxx") var c Cookie c.SetKey("foo") if !h.Cookie(&c) { t.Fatalf("cannot obtain %q cookie", c.Key()) } if string(c.Value()) != "bar" { t.Fatalf("unexpected cookie value %q. Expected %q", c.Value(), "bar") } if string(c.Path()) != "/aa/bb" { t.Fatalf("unexpected cookie path %q. Expected %q", c.Path(), "/aa/bb") } if string(c.Domain()) != "aaa.com" { t.Fatalf("unexpected cookie domain %q. Expected %q", c.Domain(), "aaa.com") } c.SetKey("aaaaa") if !h.Cookie(&c) { t.Fatalf("cannot obtain %q cookie", c.Key()) } if string(c.Value()) != "bxx" { t.Fatalf("unexpected cookie value %q. Expecting %q", c.Value(), "bxx") } } func TestResponseHeaderVisitAll(t *testing.T) { var h ResponseHeader r := bytes.NewBufferString("HTTP/1.1 200 OK\r\nContent-Type: aa\r\nContent-Length: 123\r\nSet-Cookie: aa=bb; path=/foo/bar\r\nSet-Cookie: ccc\r\n\r\n") br := bufio.NewReader(r) if err := h.Read(br); err != nil { t.Fatalf("Unepxected error: %s", err) } if h.Len() != 4 { t.Fatalf("Unexpected number of headers: %d. Expected 4", h.Len()) } contentLengthCount := 0 contentTypeCount := 0 cookieCount := 0 h.VisitAll(func(key, value []byte) { k := string(key) v := string(value) switch k { case "Content-Length": if v != string(h.Peek(k)) { t.Fatalf("unexpected content-length: %q. Expecting %q", v, h.Peek(k)) } contentLengthCount++ case "Content-Type": if v != string(h.Peek(k)) { t.Fatalf("Unexpected content-type: %q. Expected %q", v, h.Peek(k)) } contentTypeCount++ case "Set-Cookie": if cookieCount == 0 && v != "aa=bb; path=/foo/bar" { t.Fatalf("unexpected cookie header: %q. Expected %q", v, "aa=bb; path=/foo/bar") } if cookieCount == 1 && v != "ccc" { t.Fatalf("unexpected cookie header: %q. 
Expected %q", v, "ccc") } cookieCount++ default: t.Fatalf("unexpected header %q=%q", k, v) } }) if contentLengthCount != 1 { t.Fatalf("unexpected number of content-length headers: %d. Expected 1", contentLengthCount) } if contentTypeCount != 1 { t.Fatalf("unexpected number of content-type headers: %d. Expected 1", contentTypeCount) } if cookieCount != 2 { t.Fatalf("unexpected number of cookie header: %d. Expected 2", cookieCount) } } func TestRequestHeaderVisitAll(t *testing.T) { var h RequestHeader r := bytes.NewBufferString("GET / HTTP/1.1\r\nHost: aa.com\r\nXX: YYY\r\nXX: ZZ\r\nCookie: a=b; c=d\r\n\r\n") br := bufio.NewReader(r) if err := h.Read(br); err != nil { t.Fatalf("Unexpected error: %s", err) } if h.Len() != 4 { t.Fatalf("Unexpected number of header: %d. Expected 4", h.Len()) } hostCount := 0 xxCount := 0 cookieCount := 0 h.VisitAll(func(key, value []byte) { k := string(key) v := string(value) switch k { case "Host": if v != string(h.Peek(k)) { t.Fatalf("Unexpected host value %q. Expected %q", v, h.Peek(k)) } hostCount++ case "Xx": if xxCount == 0 && v != "YYY" { t.Fatalf("Unexpected value %q. Expected %q", v, "YYY") } if xxCount == 1 && v != "ZZ" { t.Fatalf("Unexpected value %q. Expected %q", v, "ZZ") } xxCount++ case "Cookie": if v != "a=b; c=d" { t.Fatalf("Unexpected cookie %q. Expected %q", v, "a=b; c=d") } cookieCount++ default: t.Fatalf("Unepxected header %q=%q", k, v) } }) if hostCount != 1 { t.Fatalf("Unepxected number of host headers detected %d. Expected 1", hostCount) } if xxCount != 2 { t.Fatalf("Unexpected number of xx headers detected %d. Expected 2", xxCount) } if cookieCount != 1 { t.Fatalf("Unexpected number of cookie headers %d. Expected 1", cookieCount) } } func TestResponseHeaderCookie(t *testing.T) { var h ResponseHeader var c Cookie c.SetKey("foobar") c.SetValue("aaa") h.SetCookie(&c) c.SetKey("йцук") c.SetDomain("foobar.com") h.SetCookie(&c) c.Reset() c.SetKey("foobar") if !h.Cookie(&c) { t.Fatalf("Cannot find cookie %q", c.Key()) } var expectedC1 Cookie expectedC1.SetKey("foobar") expectedC1.SetValue("aaa") if !equalCookie(&expectedC1, &c) { t.Fatalf("unexpected cookie\n%#v\nExpected\n%#v\n", &c, &expectedC1) } c.SetKey("йцук") if !h.Cookie(&c) { t.Fatalf("cannot find cookie %q", c.Key()) } var expectedC2 Cookie expectedC2.SetKey("йцук") expectedC2.SetValue("aaa") expectedC2.SetDomain("foobar.com") if !equalCookie(&expectedC2, &c) { t.Fatalf("unexpected cookie\n%v\nExpected\n%v\n", &c, &expectedC2) } h.VisitAllCookie(func(key, value []byte) { var cc Cookie cc.ParseBytes(value) if !bytes.Equal(key, cc.Key()) { t.Fatalf("Unexpected cookie key %q. 
Expected %q", key, cc.Key()) } switch { case bytes.Equal(key, []byte("foobar")): if !equalCookie(&expectedC1, &cc) { t.Fatalf("unexpected cookie\n%v\nExpected\n%v\n", &cc, &expectedC1) } case bytes.Equal(key, []byte("йцук")): if !equalCookie(&expectedC2, &cc) { t.Fatalf("unexpected cookie\n%v\nExpected\n%v\n", &cc, &expectedC2) } default: t.Fatalf("unexpected cookie key %q", key) } }) w := &bytes.Buffer{} bw := bufio.NewWriter(w) if err := h.Write(bw); err != nil { t.Fatalf("unexpected error: %s", err) } if err := bw.Flush(); err != nil { t.Fatalf("unexpected error: %s", err) } h.DelAllCookies() var h1 ResponseHeader br := bufio.NewReader(w) if err := h1.Read(br); err != nil { t.Fatalf("unexpected error: %s", err) } c.SetKey("foobar") if !h1.Cookie(&c) { t.Fatalf("Cannot find cookie %q", c.Key()) } if !equalCookie(&expectedC1, &c) { t.Fatalf("unexpected cookie\n%v\nExpected\n%v\n", &c, &expectedC1) } h1.DelCookie("foobar") if h.Cookie(&c) { t.Fatalf("Unexpected cookie found: %v", &c) } if h1.Cookie(&c) { t.Fatalf("Unexpected cookie found: %v", &c) } c.SetKey("йцук") if !h1.Cookie(&c) { t.Fatalf("cannot find cookie %q", c.Key()) } if !equalCookie(&expectedC2, &c) { t.Fatalf("unexpected cookie\n%v\nExpected\n%v\n", &c, &expectedC2) } h1.DelCookie("йцук") if h.Cookie(&c) { t.Fatalf("Unexpected cookie found: %v", &c) } if h1.Cookie(&c) { t.Fatalf("Unexpected cookie found: %v", &c) } } func equalCookie(c1, c2 *Cookie) bool { if !bytes.Equal(c1.Key(), c2.Key()) { return false } if !bytes.Equal(c1.Value(), c2.Value()) { return false } if !c1.Expire().Equal(c2.Expire()) { return false } if !bytes.Equal(c1.Domain(), c2.Domain()) { return false } if !bytes.Equal(c1.Path(), c2.Path()) { return false } return true } func TestRequestHeaderCookie(t *testing.T) { var h RequestHeader h.SetRequestURI("/foobar") h.Set("Host", "foobar.com") h.SetCookie("foo", "bar") h.SetCookie("привет", "мир") if string(h.Cookie("foo")) != "bar" { t.Fatalf("Unexpected cookie value %q. Exepcted %q", h.Cookie("foo"), "bar") } if string(h.Cookie("привет")) != "мир" { t.Fatalf("Unexpected cookie value %q. Expected %q", h.Cookie("привет"), "мир") } w := &bytes.Buffer{} bw := bufio.NewWriter(w) if err := h.Write(bw); err != nil { t.Fatalf("Unexpected error: %s", err) } if err := bw.Flush(); err != nil { t.Fatalf("Unexpected error: %s", err) } var h1 RequestHeader br := bufio.NewReader(w) if err := h1.Read(br); err != nil { t.Fatalf("Unexpected error: %s", err) } if !bytes.Equal(h1.Cookie("foo"), h.Cookie("foo")) { t.Fatalf("Unexpected cookie value %q. Exepcted %q", h1.Cookie("foo"), h.Cookie("foo")) } h1.DelCookie("foo") if len(h1.Cookie("foo")) > 0 { t.Fatalf("Unexpected cookie found: %q", h1.Cookie("foo")) } if !bytes.Equal(h1.Cookie("привет"), h.Cookie("привет")) { t.Fatalf("Unexpected cookie value %q. 
Expected %q", h1.Cookie("привет"), h.Cookie("привет")) } h1.DelCookie("привет") if len(h1.Cookie("привет")) > 0 { t.Fatalf("Unexpected cookie found: %q", h1.Cookie("привет")) } h.DelAllCookies() if len(h.Cookie("foo")) > 0 { t.Fatalf("Unexpected cookie found: %q", h.Cookie("foo")) } if len(h.Cookie("привет")) > 0 { t.Fatalf("Unexpected cookie found: %q", h.Cookie("привет")) } } func TestRequestHeaderMethod(t *testing.T) { // common http methods testRequestHeaderMethod(t, "GET") testRequestHeaderMethod(t, "POST") testRequestHeaderMethod(t, "HEAD") testRequestHeaderMethod(t, "DELETE") // non-http methods testRequestHeaderMethod(t, "foobar") testRequestHeaderMethod(t, "ABC") } func testRequestHeaderMethod(t *testing.T, expectedMethod string) { var h RequestHeader h.SetMethod(expectedMethod) m := h.Method() if string(m) != expectedMethod { t.Fatalf("unexpected method: %q. Expecting %q", m, expectedMethod) } s := h.String() var h1 RequestHeader br := bufio.NewReader(bytes.NewBufferString(s)) if err := h1.Read(br); err != nil { t.Fatalf("unexpected error: %s", err) } m1 := h1.Method() if string(m) != string(m1) { t.Fatalf("unexpected method: %q. Expecting %q", m, m1) } } func TestRequestHeaderSetGet(t *testing.T) { h := &RequestHeader{} h.SetRequestURI("/aa/bbb") h.SetMethod("POST") h.Set("foo", "bar") h.Set("host", "12345") h.Set("content-type", "aaa/bbb") h.Set("content-length", "1234") h.Set("user-agent", "aaabbb") h.Set("referer", "axcv") h.Set("baz", "xxxxx") h.Set("transfer-encoding", "chunked") h.Set("connection", "close") expectRequestHeaderGet(t, h, "Foo", "bar") expectRequestHeaderGet(t, h, "Host", "12345") expectRequestHeaderGet(t, h, "Content-Type", "aaa/bbb") expectRequestHeaderGet(t, h, "Content-Length", "1234") expectRequestHeaderGet(t, h, "USER-AGent", "aaabbb") expectRequestHeaderGet(t, h, "Referer", "axcv") expectRequestHeaderGet(t, h, "baz", "xxxxx") expectRequestHeaderGet(t, h, "Transfer-Encoding", "") expectRequestHeaderGet(t, h, "connecTION", "close") if !h.ConnectionClose() { t.Fatalf("unset connection: close") } if h.ContentLength() != 1234 { t.Fatalf("Unexpected content-length %d. Expected %d", h.ContentLength(), 1234) } w := &bytes.Buffer{} bw := bufio.NewWriter(w) err := h.Write(bw) if err != nil { t.Fatalf("Unexpected error when writing request header: %s", err) } if err := bw.Flush(); err != nil { t.Fatalf("Unexpected error when flushing request header: %s", err) } var h1 RequestHeader br := bufio.NewReader(w) if err = h1.Read(br); err != nil { t.Fatalf("Unexpected error when reading request header: %s", err) } if h1.ContentLength() != h.ContentLength() { t.Fatalf("Unexpected Content-Length %d. 
Expected %d", h1.ContentLength(), h.ContentLength()) } expectRequestHeaderGet(t, &h1, "Foo", "bar") expectRequestHeaderGet(t, &h1, "HOST", "12345") expectRequestHeaderGet(t, &h1, "Content-Type", "aaa/bbb") expectRequestHeaderGet(t, &h1, "Content-Length", "1234") expectRequestHeaderGet(t, &h1, "USER-AGent", "aaabbb") expectRequestHeaderGet(t, &h1, "Referer", "axcv") expectRequestHeaderGet(t, &h1, "baz", "xxxxx") expectRequestHeaderGet(t, &h1, "Transfer-Encoding", "") expectRequestHeaderGet(t, &h1, "Connection", "close") if !h1.ConnectionClose() { t.Fatalf("unset connection: close") } } func TestResponseHeaderSetGet(t *testing.T) { h := &ResponseHeader{} h.Set("foo", "bar") h.Set("content-type", "aaa/bbb") h.Set("connection", "close") h.Set("content-length", "1234") h.Set("Server", "aaaa") h.Set("baz", "xxxxx") h.Set("Transfer-Encoding", "chunked") expectResponseHeaderGet(t, h, "Foo", "bar") expectResponseHeaderGet(t, h, "Content-Type", "aaa/bbb") expectResponseHeaderGet(t, h, "Connection", "close") expectResponseHeaderGet(t, h, "Content-Length", "1234") expectResponseHeaderGet(t, h, "seRVer", "aaaa") expectResponseHeaderGet(t, h, "baz", "xxxxx") expectResponseHeaderGet(t, h, "Transfer-Encoding", "") if h.ContentLength() != 1234 { t.Fatalf("Unexpected content-length %d. Expected %d", h.ContentLength(), 1234) } if !h.ConnectionClose() { t.Fatalf("Unexpected Connection: close value %v. Expected %v", h.ConnectionClose(), true) } w := &bytes.Buffer{} bw := bufio.NewWriter(w) err := h.Write(bw) if err != nil { t.Fatalf("Unexpected error when writing response header: %s", err) } if err := bw.Flush(); err != nil { t.Fatalf("Unexpected error when flushing response header: %s", err) } var h1 ResponseHeader br := bufio.NewReader(w) if err = h1.Read(br); err != nil { t.Fatalf("Unexpected error when reading response header: %s", err) } if h1.ContentLength() != h.ContentLength() { t.Fatalf("Unexpected Content-Length %d. Expected %d", h1.ContentLength(), h.ContentLength()) } if h1.ConnectionClose() != h.ConnectionClose() { t.Fatalf("unexpected connection: close %v. Expected %v", h1.ConnectionClose(), h.ConnectionClose()) } expectResponseHeaderGet(t, &h1, "Foo", "bar") expectResponseHeaderGet(t, &h1, "Content-Type", "aaa/bbb") expectResponseHeaderGet(t, &h1, "Connection", "close") expectResponseHeaderGet(t, &h1, "seRVer", "aaaa") expectResponseHeaderGet(t, &h1, "baz", "xxxxx") } func expectRequestHeaderGet(t *testing.T, h *RequestHeader, key, expectedValue string) { if string(h.Peek(key)) != expectedValue { t.Fatalf("Unexpected value for key %q: %q. Expected %q", key, h.Peek(key), expectedValue) } } func expectResponseHeaderGet(t *testing.T, h *ResponseHeader, key, expectedValue string) { if string(h.Peek(key)) != expectedValue { t.Fatalf("Unexpected value for key %q: %q. 
Expected %q", key, h.Peek(key), expectedValue) } } func TestResponseHeaderConnectionClose(t *testing.T) { testResponseHeaderConnectionClose(t, true) testResponseHeaderConnectionClose(t, false) } func testResponseHeaderConnectionClose(t *testing.T, connectionClose bool) { h := &ResponseHeader{} if connectionClose { h.SetConnectionClose() } h.SetContentLength(123) w := &bytes.Buffer{} bw := bufio.NewWriter(w) err := h.Write(bw) if err != nil { t.Fatalf("Unexpected error when writing response header: %s", err) } if err := bw.Flush(); err != nil { t.Fatalf("Unexpected error when flushing response header: %s", err) } var h1 ResponseHeader br := bufio.NewReader(w) err = h1.Read(br) if err != nil { t.Fatalf("Unexpected error when reading response header: %s", err) } if h1.ConnectionClose() != h.ConnectionClose() { t.Fatalf("Unexpected value for ConnectionClose: %v. Expected %v", h1.ConnectionClose(), h.ConnectionClose()) } } func TestRequestHeaderTooBig(t *testing.T) { s := "GET / HTTP/1.1\r\nHost: aaa.com\r\n" + getHeaders(10500) + "\r\n" r := bytes.NewBufferString(s) br := bufio.NewReaderSize(r, 4096) h := &RequestHeader{} err := h.Read(br) if err == nil { t.Fatalf("Expecting error when reading too big header") } } func TestResponseHeaderTooBig(t *testing.T) { s := "HTTP/1.1 200 OK\r\nContent-Type: sss\r\nContent-Length: 0\r\n" + getHeaders(100500) + "\r\n" r := bytes.NewBufferString(s) br := bufio.NewReaderSize(r, 4096) h := &ResponseHeader{} err := h.Read(br) if err == nil { t.Fatalf("Expecting error when reading too big header") } } type bufioPeekReader struct { s string n int } func (r *bufioPeekReader) Read(b []byte) (int, error) { if len(r.s) == 0 { return 0, io.EOF } r.n++ n := r.n if len(r.s) < n { n = len(r.s) } src := []byte(r.s[:n]) r.s = r.s[n:] n = copy(b, src) return n, nil } func TestRequestHeaderBufioPeek(t *testing.T) { r := &bufioPeekReader{ s: "GET / HTTP/1.1\r\nHost: foobar.com\r\n" + getHeaders(10) + "\r\naaaa", } br := bufio.NewReaderSize(r, 4096) h := &RequestHeader{} if err := h.Read(br); err != nil { t.Fatalf("Unexpected error when reading request: %s", err) } verifyRequestHeader(t, h, 0, "/", "foobar.com", "", "") verifyTrailer(t, br, "aaaa") } func TestResponseHeaderBufioPeek(t *testing.T) { r := &bufioPeekReader{ s: "HTTP/1.1 200 OK\r\nContent-Length: 10\r\nContent-Type: aaa\r\n" + getHeaders(10) + "\r\n0123456789", } br := bufio.NewReaderSize(r, 4096) h := &ResponseHeader{} if err := h.Read(br); err != nil { t.Fatalf("Unexpected error when reading response: %s", err) } verifyResponseHeader(t, h, 200, 10, "aaa") verifyTrailer(t, br, "0123456789") } func getHeaders(n int) string { var h []string for i := 0; i < n; i++ { h = append(h, fmt.Sprintf("Header_%d: Value_%d\r\n", i, i)) } return strings.Join(h, "") } func TestResponseHeaderReadSuccess(t *testing.T) { h := &ResponseHeader{} // straight order of content-length and content-type testResponseHeaderReadSuccess(t, h, "HTTP/1.1 200 OK\r\nContent-Length: 123\r\nContent-Type: text/html\r\n\r\n", 200, 123, "text/html", "") if h.ConnectionClose() { t.Fatalf("unexpected connection: close") } // reverse order of content-length and content-type testResponseHeaderReadSuccess(t, h, "HTTP/1.1 202 OK\r\nContent-Type: text/plain; encoding=utf-8\r\nContent-Length: 543\r\nConnection: close\r\n\r\n", 202, 543, "text/plain; encoding=utf-8", "") if !h.ConnectionClose() { t.Fatalf("expecting connection: close") } // tranfer-encoding: chunked testResponseHeaderReadSuccess(t, h, "HTTP/1.1 505 Internal error\r\nContent-Type: 
text/html\r\nTransfer-Encoding: chunked\r\n\r\n", 505, -1, "text/html", "") if h.ConnectionClose() { t.Fatalf("unexpected connection: close") } // reverse order of content-type and tranfer-encoding testResponseHeaderReadSuccess(t, h, "HTTP/1.1 343 foobar\r\nTransfer-Encoding: chunked\r\nContent-Type: text/json\r\n\r\n", 343, -1, "text/json", "") // additional headers testResponseHeaderReadSuccess(t, h, "HTTP/1.1 100 Continue\r\nFoobar: baz\r\nContent-Type: aaa/bbb\r\nUser-Agent: x\r\nContent-Length: 123\r\nZZZ: werer\r\n\r\n", 100, 123, "aaa/bbb", "") // trailer (aka body) testResponseHeaderReadSuccess(t, h, "HTTP/1.1 200 OK\r\nContent-Type: text/plain\r\nContent-Length: 32245\r\n\r\nqwert aaa", 200, 32245, "text/plain", "qwert aaa") // ancient http protocol testResponseHeaderReadSuccess(t, h, "HTTP/0.9 300 OK\r\nContent-Length: 123\r\nContent-Type: text/html\r\n\r\nqqqq", 300, 123, "text/html", "qqqq") // lf instead of crlf testResponseHeaderReadSuccess(t, h, "HTTP/1.1 200 OK\nContent-Length: 123\nContent-Type: text/html\n\n", 200, 123, "text/html", "") // Zero-length headers with mixed crlf and lf testResponseHeaderReadSuccess(t, h, "HTTP/1.1 400 OK\nContent-Length: 345\nZero-Value: \r\nContent-Type: aaa\n: zero-key\r\n\r\nooa", 400, 345, "aaa", "ooa") // No space after colon testResponseHeaderReadSuccess(t, h, "HTTP/1.1 200 OK\nContent-Length:34\nContent-Type: sss\n\naaaa", 200, 34, "sss", "aaaa") // invalid case testResponseHeaderReadSuccess(t, h, "HTTP/1.1 400 OK\nconTEnt-leNGTH: 123\nConTENT-TYPE: ass\n\n", 400, 123, "ass", "") // duplicate content-length testResponseHeaderReadSuccess(t, h, "HTTP/1.1 200 OK\r\nContent-Length: 456\r\nContent-Type: foo/bar\r\nContent-Length: 321\r\n\r\n", 200, 321, "foo/bar", "") // duplicate content-type testResponseHeaderReadSuccess(t, h, "HTTP/1.1 200 OK\r\nContent-Length: 234\r\nContent-Type: foo/bar\r\nContent-Type: baz/bar\r\n\r\n", 200, 234, "baz/bar", "") // both transfer-encoding: chunked and content-length testResponseHeaderReadSuccess(t, h, "HTTP/1.1 200 OK\r\nContent-Type: foo/bar\r\nContent-Length: 123\r\nTransfer-Encoding: chunked\r\n\r\n", 200, -1, "foo/bar", "") testResponseHeaderReadSuccess(t, h, "HTTP/1.1 300 OK\r\nContent-Type: foo/barr\r\nTransfer-Encoding: chunked\r\nContent-Length: 354\r\n\r\n", 300, -1, "foo/barr", "") // duplicate transfer-encoding: chunked testResponseHeaderReadSuccess(t, h, "HTTP/1.1 200 OK\r\nContent-Type: text/html\r\nTransfer-Encoding: chunked\r\nTransfer-Encoding: chunked\r\n\r\n", 200, -1, "text/html", "") // no reason string in the first line testResponseHeaderReadSuccess(t, h, "HTTP/1.1 456\r\nContent-Type: xxx/yyy\r\nContent-Length: 134\r\n\r\naaaxxx", 456, 134, "xxx/yyy", "aaaxxx") // blank lines before the first line testResponseHeaderReadSuccess(t, h, "\r\nHTTP/1.1 200 OK\r\nContent-Type: aa\r\nContent-Length: 0\r\n\r\nsss", 200, 0, "aa", "sss") if h.ConnectionClose() { t.Fatalf("unexpected connection: close") } // no content-length (informational responses) testResponseHeaderReadSuccess(t, h, "HTTP/1.1 101 OK\r\n\r\n", 101, -2, "text/plain; charset=utf-8", "") if h.ConnectionClose() { t.Fatalf("expecting connection: keep-alive for informational response") } // no content-length (no-content responses) testResponseHeaderReadSuccess(t, h, "HTTP/1.1 204 OK\r\n\r\n", 204, -2, "text/plain; charset=utf-8", "") if h.ConnectionClose() { t.Fatalf("expecting connection: keep-alive for no-content response") } // no content-length (not-modified responses) testResponseHeaderReadSuccess(t, h, "HTTP/1.1 304 
OK\r\n\r\n", 304, -2, "text/plain; charset=utf-8", "") if h.ConnectionClose() { t.Fatalf("expecting connection: keep-alive for not-modified response") } // no content-length (identity transfer-encoding) testResponseHeaderReadSuccess(t, h, "HTTP/1.1 200 OK\r\nContent-Type: foo/bar\r\n\r\nabcdefg", 200, -2, "foo/bar", "abcdefg") if !h.ConnectionClose() { t.Fatalf("expecting connection: close for identity response") } // non-numeric content-length testResponseHeaderReadSuccess(t, h, "HTTP/1.1 200 OK\r\nContent-Length: faaa\r\nContent-Type: text/html\r\n\r\nfoobar", 200, -2, "text/html", "foobar") testResponseHeaderReadSuccess(t, h, "HTTP/1.1 201 OK\r\nContent-Length: 123aa\r\nContent-Type: text/ht\r\n\r\naaa", 201, -2, "text/ht", "aaa") testResponseHeaderReadSuccess(t, h, "HTTP/1.1 200 OK\r\nContent-Length: aa124\r\nContent-Type: html\r\n\r\nxx", 200, -2, "html", "xx") // no content-type testResponseHeaderReadSuccess(t, h, "HTTP/1.1 400 OK\r\nContent-Length: 123\r\n\r\nfoiaaa", 400, 123, string(defaultContentType), "foiaaa") // no headers testResponseHeaderReadSuccess(t, h, "HTTP/1.1 200 OK\r\n\r\naaaabbb", 200, -2, string(defaultContentType), "aaaabbb") if !h.IsHTTP11() { t.Fatalf("expecting http/1.1 protocol") } // ancient http protocol testResponseHeaderReadSuccess(t, h, "HTTP/1.0 203 OK\r\nContent-Length: 123\r\nContent-Type: foobar\r\n\r\naaa", 203, 123, "foobar", "aaa") if h.IsHTTP11() { t.Fatalf("ancient protocol must be non-http/1.1") } if !h.ConnectionClose() { t.Fatalf("expecting connection: close for ancient protocol") } // ancient http protocol with 'Connection: keep-alive' header. testResponseHeaderReadSuccess(t, h, "HTTP/1.0 403 aa\r\nContent-Length: 0\r\nContent-Type: 2\r\nConnection: Keep-Alive\r\n\r\nww", 403, 0, "2", "ww") if h.IsHTTP11() { t.Fatalf("ancient protocol must be non-http/1.1") } if h.ConnectionClose() { t.Fatalf("expecting connection: keep-alive for ancient protocol") } } func TestRequestHeaderReadSuccess(t *testing.T) { h := &RequestHeader{} // simple headers testRequestHeaderReadSuccess(t, h, "GET /foo/bar HTTP/1.1\r\nHost: google.com\r\n\r\n", 0, "/foo/bar", "google.com", "", "", "") if h.ConnectionClose() { t.Fatalf("unexpected connection: close header") } // simple headers with body testRequestHeaderReadSuccess(t, h, "GET /a/bar HTTP/1.1\r\nHost: gole.com\r\nconneCTION: close\r\n\r\nfoobar", 0, "/a/bar", "gole.com", "", "", "foobar") if !h.ConnectionClose() { t.Fatalf("connection: close unset") } // ancient http protocol testRequestHeaderReadSuccess(t, h, "GET /bar HTTP/1.0\r\nHost: gole\r\n\r\npppp", 0, "/bar", "gole", "", "", "pppp") if h.IsHTTP11() { t.Fatalf("ancient http protocol cannot be http/1.1") } if !h.ConnectionClose() { t.Fatalf("expecting connectionClose for ancient http protocol") } // ancient http protocol with 'Connection: keep-alive' header testRequestHeaderReadSuccess(t, h, "GET /aa HTTP/1.0\r\nHost: bb\r\nConnection: keep-alive\r\n\r\nxxx", 0, "/aa", "bb", "", "", "xxx") if h.IsHTTP11() { t.Fatalf("ancient http protocol cannot be http/1.1") } if h.ConnectionClose() { t.Fatalf("unexpected 'connection: close' for ancient http protocol") } // complex headers with body testRequestHeaderReadSuccess(t, h, "GET /aabar HTTP/1.1\r\nAAA: bbb\r\nHost: ole.com\r\nAA: bb\r\n\r\nzzz", 0, "/aabar", "ole.com", "", "", "zzz") if !h.IsHTTP11() { t.Fatalf("expecting http/1.1 protocol") } if h.ConnectionClose() { t.Fatalf("unexpected connection: close") } // lf instead of crlf testRequestHeaderReadSuccess(t, h, "GET /foo/bar HTTP/1.1\nHost: google.com\n\n", 
0, "/foo/bar", "google.com", "", "", "") // post method testRequestHeaderReadSuccess(t, h, "POST /aaa?bbb HTTP/1.1\r\nHost: foobar.com\r\nContent-Length: 1235\r\nContent-Type: aaa\r\n\r\nabcdef", 1235, "/aaa?bbb", "foobar.com", "", "aaa", "abcdef") // zero-length headers with mixed crlf and lf testRequestHeaderReadSuccess(t, h, "GET /a HTTP/1.1\nHost: aaa\r\nZero: \n: Zero-Value\n\r\nxccv", 0, "/a", "aaa", "", "", "xccv") // no space after colon testRequestHeaderReadSuccess(t, h, "GET /a HTTP/1.1\nHost:aaaxd\n\nsdfds", 0, "/a", "aaaxd", "", "", "sdfds") // get with zero content-length testRequestHeaderReadSuccess(t, h, "GET /xxx HTTP/1.1\nHost: aaa.com\nContent-Length: 0\n\n", 0, "/xxx", "aaa.com", "", "", "") // get with non-zero content-length testRequestHeaderReadSuccess(t, h, "GET /xxx HTTP/1.1\nHost: aaa.com\nContent-Length: 123\n\n", 0, "/xxx", "aaa.com", "", "", "") // invalid case testRequestHeaderReadSuccess(t, h, "GET /aaa HTTP/1.1\nhoST: bbb.com\n\naas", 0, "/aaa", "bbb.com", "", "", "aas") // referer testRequestHeaderReadSuccess(t, h, "GET /asdf HTTP/1.1\nHost: aaa.com\nReferer: bb.com\n\naaa", 0, "/asdf", "aaa.com", "bb.com", "", "aaa") // duplicate host testRequestHeaderReadSuccess(t, h, "GET /aa HTTP/1.1\r\nHost: aaaaaa.com\r\nHost: bb.com\r\n\r\n", 0, "/aa", "bb.com", "", "", "") // post with duplicate content-type testRequestHeaderReadSuccess(t, h, "POST /a HTTP/1.1\r\nHost: aa\r\nContent-Type: ab\r\nContent-Length: 123\r\nContent-Type: xx\r\n\r\n", 123, "/a", "aa", "", "xx", "") // post with duplicate content-length testRequestHeaderReadSuccess(t, h, "POST /xx HTTP/1.1\r\nHost: aa\r\nContent-Type: s\r\nContent-Length: 13\r\nContent-Length: 1\r\n\r\n", 1, "/xx", "aa", "", "s", "") // non-post with content-type testRequestHeaderReadSuccess(t, h, "GET /aaa HTTP/1.1\r\nHost: bbb.com\r\nContent-Type: aaab\r\n\r\n", 0, "/aaa", "bbb.com", "", "aaab", "") // non-post with content-length testRequestHeaderReadSuccess(t, h, "HEAD / HTTP/1.1\r\nHost: aaa.com\r\nContent-Length: 123\r\n\r\n", 0, "/", "aaa.com", "", "", "") // non-post with content-type and content-length testRequestHeaderReadSuccess(t, h, "GET /aa HTTP/1.1\r\nHost: aa.com\r\nContent-Type: abd/test\r\nContent-Length: 123\r\n\r\n", 0, "/aa", "aa.com", "", "abd/test", "") // request uri with hostname testRequestHeaderReadSuccess(t, h, "GET http://gooGle.com/foO/%20bar?xxx#aaa HTTP/1.1\r\nHost: aa.cOM\r\n\r\ntrail", 0, "http://gooGle.com/foO/%20bar?xxx#aaa", "aa.cOM", "", "", "trail") // no protocol in the first line testRequestHeaderReadSuccess(t, h, "GET /foo/bar\r\nHost: google.com\r\n\r\nisdD", 0, "/foo/bar", "google.com", "", "", "isdD") // blank lines before the first line testRequestHeaderReadSuccess(t, h, "\r\n\n\r\nGET /aaa HTTP/1.1\r\nHost: aaa.com\r\n\r\nsss", 0, "/aaa", "aaa.com", "", "", "sss") // request uri with spaces testRequestHeaderReadSuccess(t, h, "GET /foo/ bar baz HTTP/1.1\r\nHost: aa.com\r\n\r\nxxx", 0, "/foo/ bar baz", "aa.com", "", "", "xxx") // no host testRequestHeaderReadSuccess(t, h, "GET /foo/bar HTTP/1.1\r\nFOObar: assdfd\r\n\r\naaa", 0, "/foo/bar", "", "", "", "aaa") // no host, no headers testRequestHeaderReadSuccess(t, h, "GET /foo/bar HTTP/1.1\r\n\r\nfoobar", 0, "/foo/bar", "", "", "", "foobar") // post with invalid content-length testRequestHeaderReadSuccess(t, h, "POST /a HTTP/1.1\r\nHost: bb\r\nContent-Type: aa\r\nContent-Length: dff\r\n\r\nqwerty", -2, "/a", "bb", "", "aa", "qwerty") // post without content-length and content-type testRequestHeaderReadSuccess(t, h, "POST /aaa 
HTTP/1.1\r\nHost: aaa.com\r\n\r\nzxc", -2, "/aaa", "aaa.com", "", "", "zxc") // post without content-type testRequestHeaderReadSuccess(t, h, "POST /abc HTTP/1.1\r\nHost: aa.com\r\nContent-Length: 123\r\n\r\npoiuy", 123, "/abc", "aa.com", "", "", "poiuy") // post without content-length testRequestHeaderReadSuccess(t, h, "POST /abc HTTP/1.1\r\nHost: aa.com\r\nContent-Type: adv\r\n\r\n123456", -2, "/abc", "aa.com", "", "adv", "123456") // invalid method testRequestHeaderReadSuccess(t, h, "POST /foo/bar HTTP/1.1\r\nHost: google.com\r\n\r\nmnbv", -2, "/foo/bar", "google.com", "", "", "mnbv") // put request testRequestHeaderReadSuccess(t, h, "PUT /faa HTTP/1.1\r\nHost: aaa.com\r\nContent-Length: 123\r\nContent-Type: aaa\r\n\r\nxwwere", 123, "/faa", "aaa.com", "", "aaa", "xwwere") } func TestResponseHeaderReadError(t *testing.T) { h := &ResponseHeader{} // incorrect first line testResponseHeaderReadError(t, h, "") testResponseHeaderReadError(t, h, "fo") testResponseHeaderReadError(t, h, "foobarbaz") testResponseHeaderReadError(t, h, "HTTP/1.1") testResponseHeaderReadError(t, h, "HTTP/1.1 ") testResponseHeaderReadError(t, h, "HTTP/1.1 s") // non-numeric status code testResponseHeaderReadError(t, h, "HTTP/1.1 foobar OK\r\nContent-Length: 123\r\nContent-Type: text/html\r\n\r\n") testResponseHeaderReadError(t, h, "HTTP/1.1 123foobar OK\r\nContent-Length: 123\r\nContent-Type: text/html\r\n\r\n") testResponseHeaderReadError(t, h, "HTTP/1.1 foobar344 OK\r\nContent-Length: 123\r\nContent-Type: text/html\r\n\r\n") // no headers testResponseHeaderReadError(t, h, "HTTP/1.1 200 OK\r\n") // no trailing crlf testResponseHeaderReadError(t, h, "HTTP/1.1 200 OK\r\nContent-Length: 123\r\nContent-Type: text/html\r\n") } func TestRequestHeaderReadError(t *testing.T) { h := &RequestHeader{} // incorrect first line testRequestHeaderReadError(t, h, "") testRequestHeaderReadError(t, h, "fo") testRequestHeaderReadError(t, h, "GET ") testRequestHeaderReadError(t, h, "GET / HTTP/1.1\r") // missing RequestURI testRequestHeaderReadError(t, h, "GET HTTP/1.1\r\nHost: google.com\r\n\r\n") } func testResponseHeaderReadError(t *testing.T, h *ResponseHeader, headers string) { r := bytes.NewBufferString(headers) br := bufio.NewReader(r) err := h.Read(br) if err == nil { t.Fatalf("Expecting error when reading response header %q", headers) } // make sure response header works after error testResponseHeaderReadSuccess(t, h, "HTTP/1.1 200 OK\r\nContent-Type: foo/bar\r\nContent-Length: 12345\r\n\r\nsss", 200, 12345, "foo/bar", "sss") } func testRequestHeaderReadError(t *testing.T, h *RequestHeader, headers string) { r := bytes.NewBufferString(headers) br := bufio.NewReader(r) err := h.Read(br) if err == nil { t.Fatalf("Expecting error when reading request header %q", headers) } // make sure request header works after error testRequestHeaderReadSuccess(t, h, "GET /foo/bar HTTP/1.1\r\nHost: aaaa\r\n\r\nxxx", 0, "/foo/bar", "aaaa", "", "", "xxx") } func testResponseHeaderReadSuccess(t *testing.T, h *ResponseHeader, headers string, expectedStatusCode, expectedContentLength int, expectedContentType, expectedTrailer string) { r := bytes.NewBufferString(headers) br := bufio.NewReader(r) err := h.Read(br) if err != nil { t.Fatalf("Unexpected error when parsing response headers: %s. 
headers=%q", err, headers) } verifyResponseHeader(t, h, expectedStatusCode, expectedContentLength, expectedContentType) verifyTrailer(t, br, expectedTrailer) } func testRequestHeaderReadSuccess(t *testing.T, h *RequestHeader, headers string, expectedContentLength int, expectedRequestURI, expectedHost, expectedReferer, expectedContentType, expectedTrailer string) { r := bytes.NewBufferString(headers) br := bufio.NewReader(r) err := h.Read(br) if err != nil { t.Fatalf("Unexpected error when parsing request headers: %s. headers=%q", err, headers) } verifyRequestHeader(t, h, expectedContentLength, expectedRequestURI, expectedHost, expectedReferer, expectedContentType) verifyTrailer(t, br, expectedTrailer) } func verifyResponseHeader(t *testing.T, h *ResponseHeader, expectedStatusCode, expectedContentLength int, expectedContentType string) { if h.StatusCode() != expectedStatusCode { t.Fatalf("Unexpected status code %d. Expected %d", h.StatusCode(), expectedStatusCode) } if h.ContentLength() != expectedContentLength { t.Fatalf("Unexpected content length %d. Expected %d", h.ContentLength(), expectedContentLength) } if string(h.Peek("Content-Type")) != expectedContentType { t.Fatalf("Unexpected content type %q. Expected %q", h.Peek("Content-Type"), expectedContentType) } } func verifyRequestHeader(t *testing.T, h *RequestHeader, expectedContentLength int, expectedRequestURI, expectedHost, expectedReferer, expectedContentType string) { if h.ContentLength() != expectedContentLength { t.Fatalf("Unexpected Content-Length %d. Expected %d", h.ContentLength(), expectedContentLength) } if string(h.RequestURI()) != expectedRequestURI { t.Fatalf("Unexpected RequestURI %q. Expected %q", h.RequestURI(), expectedRequestURI) } if string(h.Peek("Host")) != expectedHost { t.Fatalf("Unexpected host %q. Expected %q", h.Peek("Host"), expectedHost) } if string(h.Peek("Referer")) != expectedReferer { t.Fatalf("Unexpected referer %q. Expected %q", h.Peek("Referer"), expectedReferer) } if string(h.Peek("Content-Type")) != expectedContentType { t.Fatalf("Unexpected content-type %q. Expected %q", h.Peek("Content-Type"), expectedContentType) } } func verifyTrailer(t *testing.T, r *bufio.Reader, expectedTrailer string) { trailer, err := ioutil.ReadAll(r) if err != nil { t.Fatalf("Cannot read trailer: %s", err) } if !bytes.Equal(trailer, []byte(expectedTrailer)) { t.Fatalf("Unexpected trailer %q. 
Expected %q", trailer, expectedTrailer) } } golang-github-valyala-fasthttp-20160617/header_timing_test.go000066400000000000000000000065431273074646000242010ustar00rootroot00000000000000package fasthttp import ( "bufio" "bytes" "io" "testing" ) var strFoobar = []byte("foobar.com") type benchReadBuf struct { s []byte n int } func (r *benchReadBuf) Read(p []byte) (int, error) { if r.n == len(r.s) { return 0, io.EOF } n := copy(p, r.s[r.n:]) r.n += n return n, nil } func BenchmarkRequestHeaderRead(b *testing.B) { b.RunParallel(func(pb *testing.PB) { var h RequestHeader buf := &benchReadBuf{ s: []byte("GET /foo/bar HTTP/1.1\r\nHost: foobar.com\r\nUser-Agent: aaa.bbb\r\nReferer: http://google.com/aaa/bbb\r\n\r\n"), } br := bufio.NewReader(buf) for pb.Next() { buf.n = 0 br.Reset(buf) if err := h.Read(br); err != nil { b.Fatalf("unexpected error when reading header: %s", err) } } }) } func BenchmarkResponseHeaderRead(b *testing.B) { b.RunParallel(func(pb *testing.PB) { var h ResponseHeader buf := &benchReadBuf{ s: []byte("HTTP/1.1 200 OK\r\nContent-Type: text/html\r\nContent-Length: 1256\r\nServer: aaa 1/2.3\r\nTest: 1.2.3\r\n\r\n"), } br := bufio.NewReader(buf) for pb.Next() { buf.n = 0 br.Reset(buf) if err := h.Read(br); err != nil { b.Fatalf("unexpected error when reading header: %s", err) } } }) } func BenchmarkRequestHeaderWrite(b *testing.B) { b.RunParallel(func(pb *testing.PB) { var h RequestHeader h.SetRequestURI("/foo/bar") h.SetHost("foobar.com") h.SetUserAgent("aaa.bbb") h.SetReferer("http://google.com/aaa/bbb") var w ByteBuffer for pb.Next() { if _, err := h.WriteTo(&w); err != nil { b.Fatalf("unexpected error when writing header: %s", err) } w.Reset() } }) } func BenchmarkResponseHeaderWrite(b *testing.B) { b.RunParallel(func(pb *testing.PB) { var h ResponseHeader h.SetStatusCode(200) h.SetContentType("text/html") h.SetContentLength(1256) h.SetServer("aaa 1/2.3") h.Set("Test", "1.2.3") var w ByteBuffer for pb.Next() { if _, err := h.WriteTo(&w); err != nil { b.Fatalf("unexpected error when writing header: %s", err) } w.Reset() } }) } func BenchmarkRequestHeaderPeekBytesCanonical(b *testing.B) { b.RunParallel(func(pb *testing.PB) { var h RequestHeader h.SetBytesV("Host", strFoobar) for pb.Next() { v := h.PeekBytes(strHost) if !bytes.Equal(v, strFoobar) { b.Fatalf("unexpected result: %q. Expected %q", v, strFoobar) } } }) } func BenchmarkRequestHeaderPeekBytesNonCanonical(b *testing.B) { b.RunParallel(func(pb *testing.PB) { var h RequestHeader h.SetBytesV("Host", strFoobar) hostBytes := []byte("HOST") for pb.Next() { v := h.PeekBytes(hostBytes) if !bytes.Equal(v, strFoobar) { b.Fatalf("unexpected result: %q. 
Expected %q", v, strFoobar) } } }) } func BenchmarkNormalizeHeaderKeyCommonCase(b *testing.B) { src := []byte("User-Agent-Host-Content-Type-Content-Length-Server") benchmarkNormalizeHeaderKey(b, src) } func BenchmarkNormalizeHeaderKeyLowercase(b *testing.B) { src := []byte("user-agent-host-content-type-content-length-server") benchmarkNormalizeHeaderKey(b, src) } func BenchmarkNormalizeHeaderKeyUppercase(b *testing.B) { src := []byte("USER-AGENT-HOST-CONTENT-TYPE-CONTENT-LENGTH-SERVER") benchmarkNormalizeHeaderKey(b, src) } func benchmarkNormalizeHeaderKey(b *testing.B, src []byte) { b.RunParallel(func(pb *testing.PB) { buf := make([]byte, len(src)) for pb.Next() { copy(buf, src) normalizeHeaderKey(buf, false) } }) } golang-github-valyala-fasthttp-20160617/http.go000066400000000000000000001144451273074646000213230ustar00rootroot00000000000000package fasthttp import ( "bufio" "bytes" "errors" "fmt" "io" "mime/multipart" "os" "sync" ) // Request represents HTTP request. // // It is forbidden copying Request instances. Create new instances // and use CopyTo instead. // // Request instance MUST NOT be used from concurrently running goroutines. type Request struct { noCopy noCopy // Request header // // Copying Header by value is forbidden. Use pointer to Header instead. Header RequestHeader uri URI postArgs Args bodyStream io.Reader w requestBodyWriter body *ByteBuffer multipartForm *multipart.Form multipartFormBoundary string // Group bool members in order to reduce Request object size. parsedURI bool parsedPostArgs bool } // Response represents HTTP response. // // It is forbidden copying Response instances. Create new instances // and use CopyTo instead. // // Response instance MUST NOT be used from concurrently running goroutines. type Response struct { noCopy noCopy // Response header // // Copying Header by value is forbidden. Use pointer to Header instead. Header ResponseHeader bodyStream io.Reader w responseBodyWriter body *ByteBuffer // Response.Read() skips reading body if set to true. // Use it for reading HEAD responses. // // Response.Write() skips writing body if set to true. // Use it for writing HEAD responses. SkipBody bool // This is a hackish field for client implementation, which allows // avoiding body copying. keepBodyBuffer bool } // SetHost sets host for the request. func (req *Request) SetHost(host string) { req.URI().SetHost(host) } // SetHostBytes sets host for the request. func (req *Request) SetHostBytes(host []byte) { req.URI().SetHostBytes(host) } // Host returns the host for the given request. func (req *Request) Host() []byte { return req.URI().Host() } // SetRequestURI sets RequestURI. func (req *Request) SetRequestURI(requestURI string) { req.Header.SetRequestURI(requestURI) req.parsedURI = false } // SetRequestURIBytes sets RequestURI. func (req *Request) SetRequestURIBytes(requestURI []byte) { req.Header.SetRequestURIBytes(requestURI) req.parsedURI = false } // RequestURI returns request's URI. func (req *Request) RequestURI() []byte { if req.parsedURI { requestURI := req.uri.RequestURI() req.SetRequestURIBytes(requestURI) } return req.Header.RequestURI() } // StatusCode returns response status code. func (resp *Response) StatusCode() int { return resp.Header.StatusCode() } // SetStatusCode sets response status code. func (resp *Response) SetStatusCode(statusCode int) { resp.Header.SetStatusCode(statusCode) } // ConnectionClose returns true if 'Connection: close' header is set. 
func (resp *Response) ConnectionClose() bool { return resp.Header.ConnectionClose() } // SetConnectionClose sets 'Connection: close' header. func (resp *Response) SetConnectionClose() { resp.Header.SetConnectionClose() } // ConnectionClose returns true if 'Connection: close' header is set. func (req *Request) ConnectionClose() bool { return req.Header.ConnectionClose() } // SetConnectionClose sets 'Connection: close' header. func (req *Request) SetConnectionClose() { req.Header.SetConnectionClose() } // SendFile registers file on the given path to be used as response body // when Write is called. // // Note that SendFile doesn't set Content-Type, so set it yourself // with Header.SetContentType. func (resp *Response) SendFile(path string) error { f, err := os.Open(path) if err != nil { return err } fileInfo, err := f.Stat() if err != nil { f.Close() return err } size64 := fileInfo.Size() size := int(size64) if int64(size) != size64 { size = -1 } resp.Header.SetLastModified(fileInfo.ModTime()) resp.SetBodyStream(f, size) return nil } // SetBodyStream sets request body stream and, optionally body size. // // If bodySize is >= 0, then the bodyStream must provide exactly bodySize bytes // before returning io.EOF. // // If bodySize < 0, then bodyStream is read until io.EOF. // // bodyStream.Close() is called after finishing reading all body data // if it implements io.Closer. // // Note that GET and HEAD requests cannot have body. // // See also SetBodyStreamWriter. func (req *Request) SetBodyStream(bodyStream io.Reader, bodySize int) { req.ResetBody() req.bodyStream = bodyStream req.Header.SetContentLength(bodySize) } // SetBodyStream sets response body stream and, optionally body size. // // If bodySize is >= 0, then the bodyStream must provide exactly bodySize bytes // before returning io.EOF. // // If bodySize < 0, then bodyStream is read until io.EOF. // // bodyStream.Close() is called after finishing reading all body data // if it implements io.Closer. // // See also SetBodyStreamWriter. func (resp *Response) SetBodyStream(bodyStream io.Reader, bodySize int) { resp.ResetBody() resp.bodyStream = bodyStream resp.Header.SetContentLength(bodySize) } // IsBodyStream returns true if body is set via SetBodyStream* func (req *Request) IsBodyStream() bool { return req.bodyStream != nil } // IsBodyStream returns true if body is set via SetBodyStream* func (resp *Response) IsBodyStream() bool { return resp.bodyStream != nil } // SetBodyStreamWriter registers the given sw for populating request body. // // This function may be used in the following cases: // // * if request body is too big (more than 10MB). // * if request body is streamed from slow external sources. // * if request body must be streamed to the server in chunks // (aka `http client push` or `chunked transfer-encoding`). // // Note that GET and HEAD requests cannot have body. // /// See also SetBodyStream. func (req *Request) SetBodyStreamWriter(sw StreamWriter) { sr := NewStreamReader(sw) req.SetBodyStream(sr, -1) } // SetBodyStreamWriter registers the given sw for populating response body. // // This function may be used in the following cases: // // * if response body is too big (more than 10MB). // * if response body is streamed from slow external sources. // * if response body must be streamed to the client in chunks // (aka `http server push` or `chunked transfer-encoding`). // // See also SetBodyStream. 
func (resp *Response) SetBodyStreamWriter(sw StreamWriter) { sr := NewStreamReader(sw) resp.SetBodyStream(sr, -1) } // BodyWriter returns writer for populating response body. // // If used inside RequestHandler, the returned writer must not be used // after returning from RequestHandler. Use RequestCtx.Write // or SetBodyStreamWriter in this case. func (resp *Response) BodyWriter() io.Writer { resp.w.r = resp return &resp.w } // BodyWriter returns writer for populating request body. func (req *Request) BodyWriter() io.Writer { req.w.r = req return &req.w } type responseBodyWriter struct { r *Response } func (w *responseBodyWriter) Write(p []byte) (int, error) { w.r.AppendBody(p) return len(p), nil } type requestBodyWriter struct { r *Request } func (w *requestBodyWriter) Write(p []byte) (int, error) { w.r.AppendBody(p) return len(p), nil } // Body returns response body. func (resp *Response) Body() []byte { if resp.bodyStream != nil { bodyBuf := resp.bodyBuffer() bodyBuf.Reset() _, err := copyZeroAlloc(bodyBuf, resp.bodyStream) resp.closeBodyStream() if err != nil { bodyBuf.SetString(err.Error()) } } return resp.bodyBytes() } func (resp *Response) bodyBytes() []byte { if resp.body == nil { return nil } return resp.body.B } func (req *Request) bodyBytes() []byte { if req.body == nil { return nil } return req.body.B } func (resp *Response) bodyBuffer() *ByteBuffer { if resp.body == nil { resp.body = responseBodyPool.Acquire() } return resp.body } func (req *Request) bodyBuffer() *ByteBuffer { if req.body == nil { req.body = requestBodyPool.Acquire() } return req.body } var ( requestBodyPool byteBufferPool responseBodyPool byteBufferPool ) // BodyGunzip returns un-gzipped body data. // // This method may be used if the request header contains // 'Content-Encoding: gzip' for reading un-gzipped body. // Use Body for reading gzipped request body. func (req *Request) BodyGunzip() ([]byte, error) { return gunzipData(req.Body()) } // BodyGunzip returns un-gzipped body data. // // This method may be used if the response header contains // 'Content-Encoding: gzip' for reading un-gzipped body. // Use Body for reading gzipped response body. func (resp *Response) BodyGunzip() ([]byte, error) { return gunzipData(resp.Body()) } func gunzipData(p []byte) ([]byte, error) { var bb ByteBuffer _, err := WriteGunzip(&bb, p) if err != nil { return nil, err } return bb.B, nil } // BodyInflate returns inflated body data. // // This method may be used if the response header contains // 'Content-Encoding: deflate' for reading inflated request body. // Use Body for reading deflated request body. func (req *Request) BodyInflate() ([]byte, error) { return inflateData(req.Body()) } // BodyInflate returns inflated body data. // // This method may be used if the response header contains // 'Content-Encoding: deflate' for reading inflated response body. // Use Body for reading deflated response body. func (resp *Response) BodyInflate() ([]byte, error) { return inflateData(resp.Body()) } func inflateData(p []byte) ([]byte, error) { var bb ByteBuffer _, err := WriteInflate(&bb, p) if err != nil { return nil, err } return bb.B, nil } // BodyWriteTo writes request body to w. func (req *Request) BodyWriteTo(w io.Writer) error { if req.bodyStream != nil { _, err := copyZeroAlloc(w, req.bodyStream) req.closeBodyStream() return err } if req.onlyMultipartForm() { return WriteMultipartForm(w, req.multipartForm, req.multipartFormBoundary) } _, err := w.Write(req.bodyBytes()) return err } // BodyWriteTo writes response body to w. 
func (resp *Response) BodyWriteTo(w io.Writer) error { if resp.bodyStream != nil { _, err := copyZeroAlloc(w, resp.bodyStream) resp.closeBodyStream() return err } _, err := w.Write(resp.bodyBytes()) return err } // AppendBody appends p to response body. // // It is safe re-using p after the function returns. func (resp *Response) AppendBody(p []byte) { resp.AppendBodyString(b2s(p)) } // AppendBodyString appends s to response body. func (resp *Response) AppendBodyString(s string) { resp.closeBodyStream() resp.bodyBuffer().WriteString(s) } // SetBody sets response body. // // It is safe re-using body argument after the function returns. func (resp *Response) SetBody(body []byte) { resp.SetBodyString(b2s(body)) } // SetBodyString sets response body. func (resp *Response) SetBodyString(body string) { resp.closeBodyStream() bodyBuf := resp.bodyBuffer() bodyBuf.Reset() bodyBuf.WriteString(body) } // ResetBody resets response body. func (resp *Response) ResetBody() { resp.closeBodyStream() if resp.body != nil { if resp.keepBodyBuffer { resp.body.Reset() } else { responseBodyPool.Release(resp.body) resp.body = nil } } } // ReleaseBody retires the response body if it is greater than "size" bytes. // // This permits GC to reclaim the large buffer. If used, must be before // ReleaseResponse. func (resp *Response) ReleaseBody(size int) { if cap(resp.body.B) > size { resp.closeBodyStream() resp.body = nil } } // ReleaseBody retires the request body if it is greater than "size" bytes. // // This permits GC to reclaim the large buffer. If used, must be before // ReleaseRequest. func (req *Request) ReleaseBody(size int) { if cap(req.body.B) > size { req.closeBodyStream() req.body = nil } } // Body returns request body. func (req *Request) Body() []byte { if req.bodyStream != nil { bodyBuf := req.bodyBuffer() bodyBuf.Reset() _, err := copyZeroAlloc(bodyBuf, req.bodyStream) req.closeBodyStream() if err != nil { bodyBuf.SetString(err.Error()) } } else if req.onlyMultipartForm() { body, err := marshalMultipartForm(req.multipartForm, req.multipartFormBoundary) if err != nil { return []byte(err.Error()) } return body } return req.bodyBytes() } // AppendBody appends p to request body. // // It is safe re-using p after the function returns. func (req *Request) AppendBody(p []byte) { req.AppendBodyString(b2s(p)) } // AppendBodyString appends s to request body. func (req *Request) AppendBodyString(s string) { req.RemoveMultipartFormFiles() req.closeBodyStream() req.bodyBuffer().WriteString(s) } // SetBody sets request body. // // It is safe re-using body argument after the function returns. func (req *Request) SetBody(body []byte) { req.SetBodyString(b2s(body)) } // SetBodyString sets request body. func (req *Request) SetBodyString(body string) { req.RemoveMultipartFormFiles() req.closeBodyStream() req.bodyBuffer().SetString(body) } // ResetBody resets request body. func (req *Request) ResetBody() { req.RemoveMultipartFormFiles() req.closeBodyStream() if req.body != nil { requestBodyPool.Release(req.body) req.body = nil } } // CopyTo copies req contents to dst except of body stream. 
func (req *Request) CopyTo(dst *Request) { req.copyToSkipBody(dst) if req.body != nil { dst.bodyBuffer().Set(req.body.B) } else if dst.body != nil { dst.body.Reset() } } func (req *Request) copyToSkipBody(dst *Request) { dst.Reset() req.Header.CopyTo(&dst.Header) req.uri.CopyTo(&dst.uri) dst.parsedURI = req.parsedURI req.postArgs.CopyTo(&dst.postArgs) dst.parsedPostArgs = req.parsedPostArgs // do not copy multipartForm - it will be automatically // re-created on the first call to MultipartForm. } // CopyTo copies resp contents to dst except of body stream. func (resp *Response) CopyTo(dst *Response) { resp.copyToSkipBody(dst) if resp.body != nil { dst.bodyBuffer().Set(resp.body.B) } else if dst.body != nil { dst.body.Reset() } } func (resp *Response) copyToSkipBody(dst *Response) { dst.Reset() resp.Header.CopyTo(&dst.Header) dst.SkipBody = resp.SkipBody } func swapRequestBody(a, b *Request) { a.body, b.body = b.body, a.body a.bodyStream, b.bodyStream = b.bodyStream, a.bodyStream } func swapResponseBody(a, b *Response) { a.body, b.body = b.body, a.body a.bodyStream, b.bodyStream = b.bodyStream, a.bodyStream } // URI returns request URI func (req *Request) URI() *URI { req.parseURI() return &req.uri } func (req *Request) parseURI() { if req.parsedURI { return } req.parsedURI = true req.uri.parseQuick(req.Header.RequestURI(), &req.Header) } // PostArgs returns POST arguments. func (req *Request) PostArgs() *Args { req.parsePostArgs() return &req.postArgs } func (req *Request) parsePostArgs() { if req.parsedPostArgs { return } req.parsedPostArgs = true if !bytes.HasPrefix(req.Header.ContentType(), strPostArgsContentType) { return } req.postArgs.ParseBytes(req.bodyBytes()) } // ErrNoMultipartForm means that the request's Content-Type // isn't 'multipart/form-data'. var ErrNoMultipartForm = errors.New("request has no multipart/form-data Content-Type") // MultipartForm returns requests's multipart form. // // Returns ErrNoMultipartForm if request's Content-Type // isn't 'multipart/form-data'. // // RemoveMultipartFormFiles must be called after returned multipart form // is processed. func (req *Request) MultipartForm() (*multipart.Form, error) { if req.multipartForm != nil { return req.multipartForm, nil } req.multipartFormBoundary = string(req.Header.MultipartFormBoundary()) if len(req.multipartFormBoundary) == 0 { return nil, ErrNoMultipartForm } ce := req.Header.peek(strContentEncoding) body := req.bodyBytes() if bytes.Equal(ce, strGzip) { // Do not care about memory usage here. var err error if body, err = AppendGunzipBytes(nil, body); err != nil { return nil, fmt.Errorf("cannot gunzip request body: %s", err) } } else if len(ce) > 0 { return nil, fmt.Errorf("unsupported Content-Encoding: %q", ce) } f, err := readMultipartForm(bytes.NewReader(body), req.multipartFormBoundary, len(body), len(body)) if err != nil { return nil, err } req.multipartForm = f return f, nil } func marshalMultipartForm(f *multipart.Form, boundary string) ([]byte, error) { var buf ByteBuffer if err := WriteMultipartForm(&buf, f, boundary); err != nil { return nil, err } return buf.B, nil } // WriteMultipartForm writes the given multipart form f with the given // boundary to w. func WriteMultipartForm(w io.Writer, f *multipart.Form, boundary string) error { // Do not care about memory allocations here, since multipart // form processing is slooow. 
if len(boundary) == 0 { panic("BUG: form boundary cannot be empty") } mw := multipart.NewWriter(w) if err := mw.SetBoundary(boundary); err != nil { return fmt.Errorf("cannot use form boundary %q: %s", boundary, err) } // marshal values for k, vv := range f.Value { for _, v := range vv { if err := mw.WriteField(k, v); err != nil { return fmt.Errorf("cannot write form field %q value %q: %s", k, v, err) } } } // marshal files for k, fvv := range f.File { for _, fv := range fvv { vw, err := mw.CreateFormFile(k, fv.Filename) if err != nil { return fmt.Errorf("cannot create form file %q (%q): %s", k, fv.Filename, err) } fh, err := fv.Open() if err != nil { return fmt.Errorf("cannot open form file %q (%q): %s", k, fv.Filename, err) } if _, err = copyZeroAlloc(vw, fh); err != nil { return fmt.Errorf("error when copying form file %q (%q): %s", k, fv.Filename, err) } if err = fh.Close(); err != nil { return fmt.Errorf("cannot close form file %q (%q): %s", k, fv.Filename, err) } } } if err := mw.Close(); err != nil { return fmt.Errorf("error when closing multipart form writer: %s", err) } return nil } func readMultipartForm(r io.Reader, boundary string, size, maxInMemoryFileSize int) (*multipart.Form, error) { // Do not care about memory allocations here, since they are tiny // compared to multipart data (aka multi-MB files) usually sent // in multipart/form-data requests. if size <= 0 { panic(fmt.Sprintf("BUG: form size must be greater than 0. Given %d", size)) } lr := io.LimitReader(r, int64(size)) mr := multipart.NewReader(lr, boundary) f, err := mr.ReadForm(int64(maxInMemoryFileSize)) if err != nil { return nil, fmt.Errorf("cannot read multipart/form-data body: %s", err) } return f, nil } // Reset clears request contents. func (req *Request) Reset() { req.Header.Reset() req.resetSkipHeader() } func (req *Request) resetSkipHeader() { req.ResetBody() req.uri.Reset() req.parsedURI = false req.postArgs.Reset() req.parsedPostArgs = false } // RemoveMultipartFormFiles removes multipart/form-data temporary files // associated with the request. func (req *Request) RemoveMultipartFormFiles() { if req.multipartForm != nil { // Do not check for error, since these files may be deleted or moved // to new places by user code. req.multipartForm.RemoveAll() req.multipartForm = nil } req.multipartFormBoundary = "" } // Reset clears response contents. func (resp *Response) Reset() { resp.Header.Reset() resp.resetSkipHeader() resp.SkipBody = false } func (resp *Response) resetSkipHeader() { resp.ResetBody() } // Read reads request (including body) from the given r. // // RemoveMultipartFormFiles or Reset must be called after // reading multipart/form-data request in order to delete temporarily // uploaded files. // // If MayContinue returns true, the caller must: // // - Either send StatusExpectationFailed response if request headers don't // satisfy the caller. // - Or send StatusContinue response before reading request body // with ContinueReadBody. // - Or close the connection. // // io.EOF is returned if r is closed before reading the first header byte. func (req *Request) Read(r *bufio.Reader) error { return req.ReadLimitBody(r, 0) } const defaultMaxInMemoryFileSize = 16 * 1024 * 1024 var errGetOnly = errors.New("non-GET request received") // ReadLimitBody reads request from the given r, limiting the body size. // // If maxBodySize > 0 and the body size exceeds maxBodySize, // then ErrBodyTooLarge is returned. 
// // RemoveMultipartFormFiles or Reset must be called after // reading multipart/form-data request in order to delete temporarily // uploaded files. // // If MayContinue returns true, the caller must: // // - Either send StatusExpectationFailed response if request headers don't // satisfy the caller. // - Or send StatusContinue response before reading request body // with ContinueReadBody. // - Or close the connection. // // io.EOF is returned if r is closed before reading the first header byte. func (req *Request) ReadLimitBody(r *bufio.Reader, maxBodySize int) error { return req.readLimitBody(r, maxBodySize, false) } func (req *Request) readLimitBody(r *bufio.Reader, maxBodySize int, getOnly bool) error { req.resetSkipHeader() err := req.Header.Read(r) if err != nil { return err } if getOnly && !req.Header.IsGet() { return errGetOnly } if req.Header.noBody() { return nil } if req.MayContinue() { // 'Expect: 100-continue' header found. Let the caller deciding // whether to read request body or // to return StatusExpectationFailed. return nil } return req.ContinueReadBody(r, maxBodySize) } // MayContinue returns true if the request contains // 'Expect: 100-continue' header. // // The caller must do one of the following actions if MayContinue returns true: // // - Either send StatusExpectationFailed response if request headers don't // satisfy the caller. // - Or send StatusContinue response before reading request body // with ContinueReadBody. // - Or close the connection. func (req *Request) MayContinue() bool { return bytes.Equal(req.Header.peek(strExpect), str100Continue) } // ContinueReadBody reads request body if request header contains // 'Expect: 100-continue'. // // The caller must send StatusContinue response before calling this method. // // If maxBodySize > 0 and the body size exceeds maxBodySize, // then ErrBodyTooLarge is returned. func (req *Request) ContinueReadBody(r *bufio.Reader, maxBodySize int) error { var err error contentLength := req.Header.ContentLength() if contentLength > 0 { if maxBodySize > 0 && contentLength > maxBodySize { return ErrBodyTooLarge } // Pre-read multipart form data of known length. // This way we limit memory usage for large file uploads, since their contents // is streamed into temporary files if file size exceeds defaultMaxInMemoryFileSize. req.multipartFormBoundary = string(req.Header.MultipartFormBoundary()) if len(req.multipartFormBoundary) > 0 && len(req.Header.peek(strContentEncoding)) == 0 { req.multipartForm, err = readMultipartForm(r, req.multipartFormBoundary, contentLength, defaultMaxInMemoryFileSize) if err != nil { req.Reset() } return err } } if contentLength == -2 { // identity body has no sense for http requests, since // the end of body is determined by connection close. // So just ignore request body for requests without // 'Content-Length' and 'Transfer-Encoding' headers. req.Header.SetContentLength(0) return nil } bodyBuf := req.bodyBuffer() bodyBuf.Reset() bodyBuf.B, err = readBody(r, contentLength, maxBodySize, bodyBuf.B) if err != nil { req.Reset() return err } req.Header.SetContentLength(len(bodyBuf.B)) return nil } // Read reads response (including body) from the given r. // // io.EOF is returned if r is closed before reading the first header byte. func (resp *Response) Read(r *bufio.Reader) error { return resp.ReadLimitBody(r, 0) } // ReadLimitBody reads response from the given r, limiting the body size. // // If maxBodySize > 0 and the body size exceeds maxBodySize, // then ErrBodyTooLarge is returned. 
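// Hypothetical sketch of the 'Expect: 100-continue' flow described above:
// Read stops before the body when MayContinue reports true, and the body is
// then read explicitly with ContinueReadBody (on a real connection the caller
// would first send a StatusContinue response). Not part of the fasthttp
// sources; assumes its own *_example_test.go file (package fasthttp_test).

package fasthttp_test

import (
	"bufio"
	"bytes"
	"fmt"
	"log"

	"github.com/valyala/fasthttp"
)

func ExampleRequest_MayContinue() {
	s := "PUT /upload HTTP/1.1\r\nHost: example.com\r\nExpect: 100-continue\r\n" +
		"Content-Length: 5\r\nContent-Type: text/plain\r\n\r\nabcde"
	br := bufio.NewReader(bytes.NewBufferString(s))

	var req fasthttp.Request
	if err := req.Read(br); err != nil {
		log.Fatalf("cannot read request headers: %s", err)
	}

	if req.MayContinue() {
		// On a real connection the caller would send a StatusContinue
		// response here, or StatusExpectationFailed, or close the connection.
		if err := req.ContinueReadBody(br, 0); err != nil {
			log.Fatalf("cannot read request body: %s", err)
		}
	}

	fmt.Printf("%s\n", req.Body())
	// Output: abcde
}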
// // io.EOF is returned if r is closed before reading the first header byte. func (resp *Response) ReadLimitBody(r *bufio.Reader, maxBodySize int) error { resp.resetSkipHeader() err := resp.Header.Read(r) if err != nil { return err } if resp.Header.StatusCode() == StatusContinue { // Read the next response according to http://www.w3.org/Protocols/rfc2616/rfc2616-sec8.html . if err = resp.Header.Read(r); err != nil { return err } } if !resp.mustSkipBody() { bodyBuf := resp.bodyBuffer() bodyBuf.Reset() bodyBuf.B, err = readBody(r, resp.Header.ContentLength(), maxBodySize, bodyBuf.B) if err != nil { resp.Reset() return err } resp.Header.SetContentLength(len(bodyBuf.B)) } return nil } func (resp *Response) mustSkipBody() bool { return resp.SkipBody || resp.Header.mustSkipContentLength() } var errRequestHostRequired = errors.New("missing required Host header in request") // WriteTo writes request to w. It implements io.WriterTo. func (req *Request) WriteTo(w io.Writer) (int64, error) { return writeBufio(req, w) } // WriteTo writes response to w. It implements io.WriterTo. func (resp *Response) WriteTo(w io.Writer) (int64, error) { return writeBufio(resp, w) } func writeBufio(hw httpWriter, w io.Writer) (int64, error) { sw := acquireStatsWriter(w) bw := acquireBufioWriter(sw) err1 := hw.Write(bw) err2 := bw.Flush() releaseBufioWriter(bw) n := sw.bytesWritten releaseStatsWriter(sw) err := err1 if err == nil { err = err2 } return n, err } type statsWriter struct { w io.Writer bytesWritten int64 } func (w *statsWriter) Write(p []byte) (int, error) { n, err := w.w.Write(p) w.bytesWritten += int64(n) return n, err } func acquireStatsWriter(w io.Writer) *statsWriter { v := statsWriterPool.Get() if v == nil { return &statsWriter{ w: w, } } sw := v.(*statsWriter) sw.w = w return sw } func releaseStatsWriter(sw *statsWriter) { sw.w = nil sw.bytesWritten = 0 statsWriterPool.Put(sw) } var statsWriterPool sync.Pool func acquireBufioWriter(w io.Writer) *bufio.Writer { v := bufioWriterPool.Get() if v == nil { return bufio.NewWriter(w) } bw := v.(*bufio.Writer) bw.Reset(w) return bw } func releaseBufioWriter(bw *bufio.Writer) { bufioWriterPool.Put(bw) } var bufioWriterPool sync.Pool func (req *Request) onlyMultipartForm() bool { return req.multipartForm != nil && (req.body == nil || len(req.body.B) == 0) } // Write writes request to w. // // Write doesn't flush request to w for performance reasons. // // See also WriteTo. func (req *Request) Write(w *bufio.Writer) error { if len(req.Header.Host()) == 0 || req.parsedURI { uri := req.URI() host := uri.Host() if len(host) == 0 { return errRequestHostRequired } req.Header.SetHostBytes(host) req.Header.SetRequestURIBytes(uri.RequestURI()) } if req.bodyStream != nil { return req.writeBodyStream(w) } body := req.bodyBytes() var err error if req.onlyMultipartForm() { body, err = marshalMultipartForm(req.multipartForm, req.multipartFormBoundary) if err != nil { return fmt.Errorf("error when marshaling multipart form: %s", err) } req.Header.SetMultipartFormBoundary(req.multipartFormBoundary) } hasBody := !req.Header.noBody() if hasBody { req.Header.SetContentLength(len(body)) } if err = req.Header.Write(w); err != nil { return err } if hasBody { _, err = w.Write(body) } else if len(body) > 0 { return fmt.Errorf("non-zero body for non-POST request. body=%q", body) } return err } // WriteGzip writes response with gzipped body to w. // // The method gzips response body and sets 'Content-Encoding: gzip' // header before writing response to w. 
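// Hypothetical sketch for Request.Write as defined above: when the Host
// header is empty, it is derived from an absolute request URI (writing fails
// with a missing-host error for a relative URI). Not part of the fasthttp
// sources; assumes its own *_example_test.go file (package fasthttp_test),
// and the URI is illustrative.

package fasthttp_test

import (
	"bufio"
	"bytes"
	"fmt"
	"log"

	"github.com/valyala/fasthttp"
)

func ExampleRequest_Write() {
	var req fasthttp.Request
	req.SetRequestURI("http://example.com/foo/bar?baz=123")

	var buf bytes.Buffer
	bw := bufio.NewWriter(&buf)
	if err := req.Write(bw); err != nil {
		log.Fatalf("cannot write request: %s", err)
	}
	if err := bw.Flush(); err != nil {
		log.Fatalf("cannot flush request: %s", err)
	}

	// The Host header has been filled in from the absolute request URI.
	fmt.Println(bytes.Contains(buf.Bytes(), []byte("Host: example.com\r\n")))
	// Output: true
}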
// // WriteGzip doesn't flush response to w for performance reasons. func (resp *Response) WriteGzip(w *bufio.Writer) error { return resp.WriteGzipLevel(w, CompressDefaultCompression) } // WriteGzipLevel writes response with gzipped body to w. // // Level is the desired compression level: // // * CompressNoCompression // * CompressBestSpeed // * CompressBestCompression // * CompressDefaultCompression // // The method gzips response body and sets 'Content-Encoding: gzip' // header before writing response to w. // // WriteGzipLevel doesn't flush response to w for performance reasons. func (resp *Response) WriteGzipLevel(w *bufio.Writer, level int) error { if err := resp.gzipBody(level); err != nil { return err } return resp.Write(w) } // WriteDeflate writes response with deflated body to w. // // The method deflates response body and sets 'Content-Encoding: deflate' // header before writing response to w. // // WriteDeflate doesn't flush response to w for performance reasons. func (resp *Response) WriteDeflate(w *bufio.Writer) error { return resp.WriteDeflateLevel(w, CompressDefaultCompression) } // WriteDeflateLevel writes response with deflated body to w. // // Level is the desired compression level: // // * CompressNoCompression // * CompressBestSpeed // * CompressBestCompression // * CompressDefaultCompression // // The method deflates response body and sets 'Content-Encoding: deflate' // header before writing response to w. // // WriteDeflateLevel doesn't flush response to w for performance reasons. func (resp *Response) WriteDeflateLevel(w *bufio.Writer, level int) error { if err := resp.deflateBody(level); err != nil { return err } return resp.Write(w) } func (resp *Response) gzipBody(level int) error { // Do not care about memory allocations here, since gzip is slow // and allocates a lot of memory by itself. if resp.bodyStream != nil { bs := resp.bodyStream resp.bodyStream = NewStreamReader(func(sw *bufio.Writer) { zw := acquireGzipWriter(sw, level) copyZeroAlloc(zw, bs) releaseGzipWriter(zw) if bsc, ok := bs.(io.Closer); ok { bsc.Close() } }) } else { w := responseBodyPool.Acquire() zw := acquireGzipWriter(w, level) _, err := zw.Write(resp.bodyBytes()) releaseGzipWriter(zw) if err != nil { return err } // Hack: swap resp.body with w. responseBodyPool.Release(resp.body) resp.body = w } resp.Header.SetCanonical(strContentEncoding, strGzip) return nil } func (resp *Response) deflateBody(level int) error { // Do not care about memory allocations here, since flate is slow // and allocates a lot of memory by itself. if resp.bodyStream != nil { bs := resp.bodyStream resp.bodyStream = NewStreamReader(func(sw *bufio.Writer) { zw := acquireFlateWriter(sw, level) copyZeroAlloc(zw, bs) releaseFlateWriter(zw) if bsc, ok := bs.(io.Closer); ok { bsc.Close() } }) } else { w := responseBodyPool.Acquire() zw := acquireFlateWriter(w, level) _, err := zw.Write(resp.bodyBytes()) releaseFlateWriter(zw) if err != nil { return err } // Hack: swap resp.body with w. responseBodyPool.Release(resp.body) resp.body = w } resp.Header.SetCanonical(strContentEncoding, strDeflate) return nil } // Write writes response to w. // // Write doesn't flush response to w for performance reasons. // // See also WriteTo. 
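// Hypothetical sketch for Response.WriteGzip as documented above: the body is
// gzipped, 'Content-Encoding: gzip' is set, and the peer can recover the
// original body with BodyGunzip. Not part of the fasthttp sources; assumes
// its own *_example_test.go file (package fasthttp_test).

package fasthttp_test

import (
	"bufio"
	"bytes"
	"fmt"
	"log"

	"github.com/valyala/fasthttp"
)

func ExampleResponse_WriteGzip() {
	var resp fasthttp.Response
	resp.SetBodyString("compress me, please")

	var buf bytes.Buffer
	bw := bufio.NewWriter(&buf)
	if err := resp.WriteGzip(bw); err != nil {
		log.Fatalf("cannot write gzipped response: %s", err)
	}
	if err := bw.Flush(); err != nil {
		log.Fatalf("cannot flush response: %s", err)
	}

	var resp1 fasthttp.Response
	if err := resp1.Read(bufio.NewReader(&buf)); err != nil {
		log.Fatalf("cannot read response: %s", err)
	}
	fmt.Printf("%s\n", resp1.Header.Peek("Content-Encoding"))

	body, err := resp1.BodyGunzip()
	if err != nil {
		log.Fatalf("cannot gunzip body: %s", err)
	}
	fmt.Printf("%s\n", body)

	// Output:
	// gzip
	// compress me, please
}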
func (resp *Response) Write(w *bufio.Writer) error { sendBody := !resp.mustSkipBody() if resp.bodyStream != nil { return resp.writeBodyStream(w, sendBody) } body := resp.bodyBytes() bodyLen := len(body) if sendBody || bodyLen > 0 { resp.Header.SetContentLength(bodyLen) } if err := resp.Header.Write(w); err != nil { return err } if sendBody { if _, err := w.Write(body); err != nil { return err } } return nil } func (req *Request) writeBodyStream(w *bufio.Writer) error { var err error contentLength := req.Header.ContentLength() if contentLength < 0 { lrSize := limitedReaderSize(req.bodyStream) if lrSize >= 0 { contentLength = int(lrSize) if int64(contentLength) != lrSize { contentLength = -1 } if contentLength >= 0 { req.Header.SetContentLength(contentLength) } } } if contentLength >= 0 { if err = req.Header.Write(w); err == nil { err = writeBodyFixedSize(w, req.bodyStream, int64(contentLength)) } } else { req.Header.SetContentLength(-1) if err = req.Header.Write(w); err == nil { err = writeBodyChunked(w, req.bodyStream) } } err1 := req.closeBodyStream() if err == nil { err = err1 } return err } func (resp *Response) writeBodyStream(w *bufio.Writer, sendBody bool) error { var err error contentLength := resp.Header.ContentLength() if contentLength < 0 { lrSize := limitedReaderSize(resp.bodyStream) if lrSize >= 0 { contentLength = int(lrSize) if int64(contentLength) != lrSize { contentLength = -1 } if contentLength >= 0 { resp.Header.SetContentLength(contentLength) } } } if contentLength >= 0 { if err = resp.Header.Write(w); err == nil && sendBody { err = writeBodyFixedSize(w, resp.bodyStream, int64(contentLength)) } } else { resp.Header.SetContentLength(-1) if err = resp.Header.Write(w); err == nil && sendBody { err = writeBodyChunked(w, resp.bodyStream) } } err1 := resp.closeBodyStream() if err == nil { err = err1 } return err } func (req *Request) closeBodyStream() error { if req.bodyStream == nil { return nil } var err error if bsc, ok := req.bodyStream.(io.Closer); ok { err = bsc.Close() } req.bodyStream = nil return err } func (resp *Response) closeBodyStream() error { if resp.bodyStream == nil { return nil } var err error if bsc, ok := resp.bodyStream.(io.Closer); ok { err = bsc.Close() } resp.bodyStream = nil return err } // String returns request representation. // // Returns error message instead of request representation on error. // // Use Write instead of String for performance-critical code. func (req *Request) String() string { return getHTTPString(req) } // String returns response representation. // // Returns error message instead of response representation on error. // // Use Write instead of String for performance-critical code. 
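// Hypothetical sketch of the body-stream writing path above: a body taken
// from an io.Reader is written with chunked Transfer-Encoding when the size
// is given as -1, or with a fixed Content-Length when the size is known. Not
// part of the fasthttp sources; assumes its own *_example_test.go file
// (package fasthttp_test), and the body text is illustrative.

package fasthttp_test

import (
	"bufio"
	"bytes"
	"fmt"
	"log"

	"github.com/valyala/fasthttp"
)

func ExampleResponse_SetBodyStream() {
	body := "stream me"

	// Size -1 selects chunked Transfer-Encoding; pass len(body) instead to
	// emit a fixed Content-Length header.
	var resp fasthttp.Response
	resp.SetBodyStream(bytes.NewBufferString(body), -1)

	var buf bytes.Buffer
	bw := bufio.NewWriter(&buf)
	if err := resp.Write(bw); err != nil {
		log.Fatalf("cannot write response: %s", err)
	}
	if err := bw.Flush(); err != nil {
		log.Fatalf("cannot flush response: %s", err)
	}

	var resp1 fasthttp.Response
	if err := resp1.Read(bufio.NewReader(&buf)); err != nil {
		log.Fatalf("cannot read response: %s", err)
	}
	fmt.Printf("%s\n", resp1.Body())
	// Output: stream me
}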
func (resp *Response) String() string {
	return getHTTPString(resp)
}

func getHTTPString(hw httpWriter) string {
	w := AcquireByteBuffer()
	bw := bufio.NewWriter(w)
	if err := hw.Write(bw); err != nil {
		return err.Error()
	}
	if err := bw.Flush(); err != nil {
		return err.Error()
	}
	s := string(w.B)
	ReleaseByteBuffer(w)
	return s
}

type httpWriter interface {
	Write(w *bufio.Writer) error
}

func writeBodyChunked(w *bufio.Writer, r io.Reader) error {
	vbuf := copyBufPool.Get()
	buf := vbuf.([]byte)

	var err error
	var n int
	for {
		n, err = r.Read(buf)
		if n == 0 {
			if err == nil {
				panic("BUG: io.Reader returned 0, nil")
			}
			if err == io.EOF {
				if err = writeChunk(w, buf[:0]); err != nil {
					break
				}
				err = nil
			}
			break
		}
		if err = writeChunk(w, buf[:n]); err != nil {
			break
		}
	}

	copyBufPool.Put(vbuf)
	return err
}

func limitedReaderSize(r io.Reader) int64 {
	lr, ok := r.(*io.LimitedReader)
	if !ok {
		return -1
	}
	return lr.N
}

func writeBodyFixedSize(w *bufio.Writer, r io.Reader, size int64) error {
	if size > maxSmallFileSize {
		// w buffer must be empty for triggering
		// sendfile path in bufio.Writer.ReadFrom.
		if err := w.Flush(); err != nil {
			return err
		}
	}

	// Unwrap a single limited reader for triggering sendfile path
	// in net.TCPConn.ReadFrom.
	lr, ok := r.(*io.LimitedReader)
	if ok {
		r = lr.R
	}

	n, err := copyZeroAlloc(w, r)

	if ok {
		lr.N -= n
	}

	if n != size && err == nil {
		err = fmt.Errorf("copied %d bytes from body stream instead of %d bytes", n, size)
	}
	return err
}

func copyZeroAlloc(w io.Writer, r io.Reader) (int64, error) {
	vbuf := copyBufPool.Get()
	buf := vbuf.([]byte)
	n, err := io.CopyBuffer(w, r, buf)
	copyBufPool.Put(vbuf)
	return n, err
}

var copyBufPool = sync.Pool{
	New: func() interface{} {
		return make([]byte, 4096)
	},
}

func writeChunk(w *bufio.Writer, b []byte) error {
	n := len(b)
	writeHexInt(w, n)
	w.Write(strCRLF)
	w.Write(b)
	_, err := w.Write(strCRLF)
	err1 := w.Flush()
	if err == nil {
		err = err1
	}
	return err
}

// ErrBodyTooLarge is returned if either request or response body exceeds
// the given limit.
var ErrBodyTooLarge = errors.New("body size exceeds the given limit") func readBody(r *bufio.Reader, contentLength int, maxBodySize int, dst []byte) ([]byte, error) { dst = dst[:0] if contentLength >= 0 { if maxBodySize > 0 && contentLength > maxBodySize { return dst, ErrBodyTooLarge } return appendBodyFixedSize(r, dst, contentLength) } if contentLength == -1 { return readBodyChunked(r, maxBodySize, dst) } return readBodyIdentity(r, maxBodySize, dst) } func readBodyIdentity(r *bufio.Reader, maxBodySize int, dst []byte) ([]byte, error) { dst = dst[:cap(dst)] if len(dst) == 0 { dst = make([]byte, 1024) } offset := 0 for { nn, err := r.Read(dst[offset:]) if nn <= 0 { if err != nil { if err == io.EOF { return dst[:offset], nil } return dst[:offset], err } panic(fmt.Sprintf("BUG: bufio.Read() returned (%d, nil)", nn)) } offset += nn if maxBodySize > 0 && offset > maxBodySize { return dst[:offset], ErrBodyTooLarge } if len(dst) == offset { n := round2(2 * offset) if maxBodySize > 0 && n > maxBodySize { n = maxBodySize + 1 } b := make([]byte, n) copy(b, dst) dst = b } } } func appendBodyFixedSize(r *bufio.Reader, dst []byte, n int) ([]byte, error) { if n == 0 { return dst, nil } offset := len(dst) dstLen := offset + n if cap(dst) < dstLen { b := make([]byte, round2(dstLen)) copy(b, dst) dst = b } dst = dst[:dstLen] for { nn, err := r.Read(dst[offset:]) if nn <= 0 { if err != nil { if err == io.EOF { err = io.ErrUnexpectedEOF } return dst[:offset], err } panic(fmt.Sprintf("BUG: bufio.Read() returned (%d, nil)", nn)) } offset += nn if offset == dstLen { return dst, nil } } } func readBodyChunked(r *bufio.Reader, maxBodySize int, dst []byte) ([]byte, error) { if len(dst) > 0 { panic("BUG: expected zero-length buffer") } strCRLFLen := len(strCRLF) for { chunkSize, err := parseChunkSize(r) if err != nil { return dst, err } if maxBodySize > 0 && len(dst)+chunkSize > maxBodySize { return dst, ErrBodyTooLarge } dst, err = appendBodyFixedSize(r, dst, chunkSize+strCRLFLen) if err != nil { return dst, err } if !bytes.Equal(dst[len(dst)-strCRLFLen:], strCRLF) { return dst, fmt.Errorf("cannot find crlf at the end of chunk") } dst = dst[:len(dst)-strCRLFLen] if chunkSize == 0 { return dst, nil } } } func parseChunkSize(r *bufio.Reader) (int, error) { n, err := readHexInt(r) if err != nil { return -1, err } c, err := r.ReadByte() if err != nil { return -1, fmt.Errorf("cannot read '\r' char at the end of chunk size: %s", err) } if c != '\r' { return -1, fmt.Errorf("unexpected char %q at the end of chunk size. Expected %q", c, '\r') } c, err = r.ReadByte() if err != nil { return -1, fmt.Errorf("cannot read '\n' char at the end of chunk size: %s", err) } if c != '\n' { return -1, fmt.Errorf("unexpected char %q at the end of chunk size. Expected %q", c, '\n') } return n, nil } func round2(n int) int { if n <= 0 { return 0 } n-- x := uint(0) for n > 0 { n >>= 1 x++ } return 1 << x } golang-github-valyala-fasthttp-20160617/http_test.go000066400000000000000000001307231273074646000223570ustar00rootroot00000000000000package fasthttp import ( "bufio" "bytes" "fmt" "io" "io/ioutil" "mime/multipart" "strings" "testing" ) func TestRequestHostFromRequestURI(t *testing.T) { hExpected := "foobar.com" var req Request req.SetRequestURI("http://proxy-host:123/foobar?baz") req.SetHost(hExpected) h := req.Host() if string(h) != hExpected { t.Fatalf("unexpected host set: %q. 
Expecting %q", h, hExpected) } } func TestRequestHostFromHeader(t *testing.T) { hExpected := "foobar.com" var req Request req.Header.SetHost(hExpected) h := req.Host() if string(h) != hExpected { t.Fatalf("unexpected host set: %q. Expecting %q", h, hExpected) } } func TestRequestContentTypeWithCharsetIssue100(t *testing.T) { expectedContentType := "application/x-www-form-urlencoded; charset=UTF-8" expectedBody := "0123=56789" s := fmt.Sprintf("POST / HTTP/1.1\r\nContent-Type: %s\r\nContent-Length: %d\r\n\r\n%s", expectedContentType, len(expectedBody), expectedBody) br := bufio.NewReader(bytes.NewBufferString(s)) var r Request if err := r.Read(br); err != nil { t.Fatalf("unexpected error: %s", err) } body := r.Body() if string(body) != expectedBody { t.Fatalf("unexpected body %q. Expecting %q", body, expectedBody) } ct := r.Header.ContentType() if string(ct) != expectedContentType { t.Fatalf("unexpected content-type %q. Expecting %q", ct, expectedContentType) } args := r.PostArgs() if args.Len() != 1 { t.Fatalf("unexpected number of POST args: %d. Expecting 1", args.Len()) } av := args.Peek("0123") if string(av) != "56789" { t.Fatalf("unexpected POST arg value: %q. Expecting %q", av, "56789") } } func TestRequestReadMultipartFormWithFile(t *testing.T) { s := `POST /upload HTTP/1.1 Host: localhost:10000 Content-Length: 521 Content-Type: multipart/form-data; boundary=----WebKitFormBoundaryJwfATyF8tmxSJnLg ------WebKitFormBoundaryJwfATyF8tmxSJnLg Content-Disposition: form-data; name="f1" value1 ------WebKitFormBoundaryJwfATyF8tmxSJnLg Content-Disposition: form-data; name="fileaaa"; filename="TODO" Content-Type: application/octet-stream - SessionClient with referer and cookies support. - Client with requests' pipelining support. - ProxyHandler similar to FSHandler. - WebSockets. See https://tools.ietf.org/html/rfc6455 . - HTTP/2.0. See https://tools.ietf.org/html/rfc7540 . ------WebKitFormBoundaryJwfATyF8tmxSJnLg-- tailfoobar` br := bufio.NewReader(bytes.NewBufferString(s)) var r Request if err := r.Read(br); err != nil { t.Fatalf("unexpected error: %s", err) } tail, err := ioutil.ReadAll(br) if err != nil { t.Fatalf("unexpected error: %s", err) } if string(tail) != "tailfoobar" { t.Fatalf("unexpected tail %q. Expecting %q", tail, "tailfoobar") } f, err := r.MultipartForm() if err != nil { t.Fatalf("unexpected error: %s", err) } defer r.RemoveMultipartFormFiles() // verify values if len(f.Value) != 1 { t.Fatalf("unexpected number of values in multipart form: %d. Expecting 1", len(f.Value)) } for k, vv := range f.Value { if k != "f1" { t.Fatalf("unexpected value name %q. Expecting %q", k, "f1") } if len(vv) != 1 { t.Fatalf("unexpected number of values %d. Expecting 1", len(vv)) } v := vv[0] if v != "value1" { t.Fatalf("unexpected value %q. Expecting %q", v, "value1") } } // verify files if len(f.File) != 1 { t.Fatalf("unexpected number of file values in multipart form: %d. Expecting 1", len(f.File)) } for k, vv := range f.File { if k != "fileaaa" { t.Fatalf("unexpected file value name %q. Expecting %q", k, "fileaaa") } if len(vv) != 1 { t.Fatalf("unexpected number of file values %d. Expecting 1", len(vv)) } v := vv[0] if v.Filename != "TODO" { t.Fatalf("unexpected filename %q. Expecting %q", v.Filename, "TODO") } ct := v.Header.Get("Content-Type") if ct != "application/octet-stream" { t.Fatalf("unexpected content-type %q. 
Expecting %q", ct, "application/octet-stream") } } } func TestRequestRequestURI(t *testing.T) { var r Request // Set request uri via SetRequestURI() uri := "/foo/bar?baz" r.SetRequestURI(uri) if string(r.RequestURI()) != uri { t.Fatalf("unexpected request uri %q. Expecting %q", r.RequestURI(), uri) } // Set request uri via Request.URI().Update() r.Reset() uri = "/aa/bbb?ccc=sdfsdf" r.URI().Update(uri) if string(r.RequestURI()) != uri { t.Fatalf("unexpected request uri %q. Expecting %q", r.RequestURI(), uri) } // update query args in the request uri qa := r.URI().QueryArgs() qa.Reset() qa.Set("foo", "bar") uri = "/aa/bbb?foo=bar" if string(r.RequestURI()) != uri { t.Fatalf("unexpected request uri %q. Expecting %q", r.RequestURI(), uri) } } func TestRequestUpdateURI(t *testing.T) { var r Request r.Header.SetHost("aaa.bbb") r.SetRequestURI("/lkjkl/kjl") // Modify request uri and host via URI() object and make sure // the requestURI and Host header are properly updated u := r.URI() u.SetPath("/123/432.html") u.SetHost("foobar.com") a := u.QueryArgs() a.Set("aaa", "bcse") s := r.String() if !strings.HasPrefix(s, "GET /123/432.html?aaa=bcse") { t.Fatalf("cannot find %q in %q", "GET /123/432.html?aaa=bcse", s) } if strings.Index(s, "\r\nHost: foobar.com\r\n") < 0 { t.Fatalf("cannot find %q in %q", "\r\nHost: foobar.com\r\n", s) } } func TestRequestBodyStreamMultipleBodyCalls(t *testing.T) { var r Request s := "foobar baz abc" if r.IsBodyStream() { t.Fatalf("IsBodyStream must return false") } r.SetBodyStream(bytes.NewBufferString(s), len(s)) if !r.IsBodyStream() { t.Fatalf("IsBodyStream must return true") } for i := 0; i < 10; i++ { body := r.Body() if string(body) != s { t.Fatalf("unexpected body %q. Expecting %q. iteration %d", body, s, i) } } } func TestResponseBodyStreamMultipleBodyCalls(t *testing.T) { var r Response s := "foobar baz abc" if r.IsBodyStream() { t.Fatalf("IsBodyStream must return false") } r.SetBodyStream(bytes.NewBufferString(s), len(s)) if !r.IsBodyStream() { t.Fatalf("IsBodyStream must return true") } for i := 0; i < 10; i++ { body := r.Body() if string(body) != s { t.Fatalf("unexpected body %q. Expecting %q. 
iteration %d", body, s, i) } } } func TestRequestBodyWriteToPlain(t *testing.T) { var r Request expectedS := "foobarbaz" r.AppendBodyString(expectedS) testBodyWriteTo(t, &r, expectedS, true) } func TestResponseBodyWriteToPlain(t *testing.T) { var r Response expectedS := "foobarbaz" r.AppendBodyString(expectedS) testBodyWriteTo(t, &r, expectedS, true) } func TestResponseBodyWriteToStream(t *testing.T) { var r Response expectedS := "aaabbbccc" buf := bytes.NewBufferString(expectedS) if r.IsBodyStream() { t.Fatalf("IsBodyStream must return false") } r.SetBodyStream(buf, len(expectedS)) if !r.IsBodyStream() { t.Fatalf("IsBodyStream must return true") } testBodyWriteTo(t, &r, expectedS, false) } func TestRequestBodyWriteToMultipart(t *testing.T) { expectedS := "--foobar\r\nContent-Disposition: form-data; name=\"key_0\"\r\n\r\nvalue_0\r\n--foobar--\r\n" s := fmt.Sprintf("POST / HTTP/1.1\r\nHost: aaa\r\nContent-Type: multipart/form-data; boundary=foobar\r\nContent-Length: %d\r\n\r\n%s", len(expectedS), expectedS) var r Request br := bufio.NewReader(bytes.NewBufferString(s)) if err := r.Read(br); err != nil { t.Fatalf("unexpected error: %s", err) } testBodyWriteTo(t, &r, expectedS, true) } type bodyWriterTo interface { BodyWriteTo(io.Writer) error Body() []byte } func testBodyWriteTo(t *testing.T, bw bodyWriterTo, expectedS string, isRetainedBody bool) { var buf ByteBuffer if err := bw.BodyWriteTo(&buf); err != nil { t.Fatalf("unexpected error: %s", err) } s := buf.B if string(s) != expectedS { t.Fatalf("unexpected result %q. Expecting %q", s, expectedS) } body := bw.Body() if isRetainedBody { if string(body) != expectedS { t.Fatalf("unexpected body %q. Expecting %q", body, expectedS) } } else { if len(body) > 0 { t.Fatalf("unexpected non-zero body after BodyWriteTo: %q", body) } } } func TestRequestReadEOF(t *testing.T) { var r Request br := bufio.NewReader(&bytes.Buffer{}) err := r.Read(br) if err == nil { t.Fatalf("expecting error") } if err != io.EOF { t.Fatalf("unexpected error: %s. Expecting %s", err, io.EOF) } // incomplete request mustn't return io.EOF br = bufio.NewReader(bytes.NewBufferString("POST / HTTP/1.1\r\nContent-Type: aa\r\nContent-Length: 1234\r\n\r\nIncomplete body")) err = r.Read(br) if err == nil { t.Fatalf("expecting error") } if err == io.EOF { t.Fatalf("expecting non-EOF error") } } func TestResponseReadEOF(t *testing.T) { var r Response br := bufio.NewReader(&bytes.Buffer{}) err := r.Read(br) if err == nil { t.Fatalf("expecting error") } if err != io.EOF { t.Fatalf("unexpected error: %s. Expecting %s", err, io.EOF) } // incomplete response mustn't return io.EOF br = bufio.NewReader(bytes.NewBufferString("HTTP/1.1 200 OK\r\nContent-Type: aaa\r\nContent-Length: 123\r\n\r\nIncomplete body")) err = r.Read(br) if err == nil { t.Fatalf("expecting error") } if err == io.EOF { t.Fatalf("expecting non-EOF error") } } func TestResponseWriteTo(t *testing.T) { var r Response r.SetBodyString("foobar") s := r.String() var buf ByteBuffer n, err := r.WriteTo(&buf) if err != nil { t.Fatalf("unexpected error: %s", err) } if n != int64(len(s)) { t.Fatalf("unexpected response length %d. Expecting %d", n, len(s)) } if string(buf.B) != s { t.Fatalf("unexpected response %q. Expecting %q", buf.B, s) } } func TestRequestWriteTo(t *testing.T) { var r Request r.SetRequestURI("http://foobar.com/aaa/bbb") s := r.String() var buf ByteBuffer n, err := r.WriteTo(&buf) if err != nil { t.Fatalf("unexpected error: %s", err) } if n != int64(len(s)) { t.Fatalf("unexpected request length %d. 
Expecting %d", n, len(s)) } if string(buf.B) != s { t.Fatalf("unexpected request %q. Expecting %q", buf.B, s) } } func TestResponseSkipBody(t *testing.T) { var r Response // set StatusNotModified r.Header.SetStatusCode(StatusNotModified) r.SetBodyString("foobar") s := r.String() if strings.Contains(s, "\r\n\r\nfoobar") { t.Fatalf("unexpected non-zero body in response %q", s) } if strings.Contains(s, "Content-Length: ") { t.Fatalf("unexpected content-length in response %q", s) } if strings.Contains(s, "Content-Type: ") { t.Fatalf("unexpected content-type in response %q", s) } // set StatusNoContent r.Header.SetStatusCode(StatusNoContent) r.SetBodyString("foobar") s = r.String() if strings.Contains(s, "\r\n\r\nfoobar") { t.Fatalf("unexpected non-zero body in response %q", s) } if strings.Contains(s, "Content-Length: ") { t.Fatalf("unexpected content-length in response %q", s) } if strings.Contains(s, "Content-Type: ") { t.Fatalf("unexpected content-type in response %q", s) } // explicitly skip body r.Header.SetStatusCode(StatusOK) r.SkipBody = true r.SetBodyString("foobar") s = r.String() if strings.Contains(s, "\r\n\r\nfoobar") { t.Fatalf("unexpected non-zero body in response %q", s) } if !strings.Contains(s, "Content-Length: 6\r\n") { t.Fatalf("expecting content-length in response %q", s) } if !strings.Contains(s, "Content-Type: ") { t.Fatalf("expecting content-type in response %q", s) } } func TestRequestNoContentLength(t *testing.T) { var r Request r.Header.SetMethod("HEAD") r.Header.SetHost("foobar") s := r.String() if strings.Contains(s, "Content-Length: ") { t.Fatalf("unexpected content-length in HEAD request %q", s) } r.Header.SetMethod("POST") fmt.Fprintf(r.BodyWriter(), "foobar body") s = r.String() if !strings.Contains(s, "Content-Length: ") { t.Fatalf("missing content-length header in non-GET request %q", s) } } func TestRequestReadGzippedBody(t *testing.T) { var r Request bodyOriginal := "foo bar baz compress me better!" body := AppendGzipBytes(nil, []byte(bodyOriginal)) s := fmt.Sprintf("POST /foobar HTTP/1.1\r\nContent-Type: foo/bar\r\nContent-Encoding: gzip\r\nContent-Length: %d\r\n\r\n%s", len(body), body) br := bufio.NewReader(bytes.NewBufferString(s)) if err := r.Read(br); err != nil { t.Fatalf("unexpected error: %s", err) } if string(r.Header.Peek("Content-Encoding")) != "gzip" { t.Fatalf("unexpected content-encoding: %q. Expecting %q", r.Header.Peek("Content-Encoding"), "gzip") } if r.Header.ContentLength() != len(body) { t.Fatalf("unexpected content-length: %d. Expecting %d", r.Header.ContentLength(), len(body)) } if string(r.Body()) != string(body) { t.Fatalf("unexpected body: %q. Expecting %q", r.Body(), body) } bodyGunzipped, err := AppendGunzipBytes(nil, r.Body()) if err != nil { t.Fatalf("unexpected error when uncompressing data: %s", err) } if string(bodyGunzipped) != bodyOriginal { t.Fatalf("unexpected uncompressed body %q. Expecting %q", bodyGunzipped, bodyOriginal) } } func TestRequestReadPostNoBody(t *testing.T) { var r Request s := "POST /foo/bar HTTP/1.1\r\nContent-Type: aaa/bbb\r\n\r\naaaa" br := bufio.NewReader(bytes.NewBufferString(s)) if err := r.Read(br); err != nil { t.Fatalf("unexpected error: %s", err) } if string(r.Header.RequestURI()) != "/foo/bar" { t.Fatalf("unexpected request uri %q. Expecting %q", r.Header.RequestURI(), "/foo/bar") } if string(r.Header.ContentType()) != "aaa/bbb" { t.Fatalf("unexpected content-type %q. Expecting %q", r.Header.ContentType(), "aaa/bbb") } if len(r.Body()) != 0 { t.Fatalf("unexpected body found %q. 
Expecting empty body", r.Body()) } if r.Header.ContentLength() != 0 { t.Fatalf("unexpected content-length: %d. Expecting 0", r.Header.ContentLength()) } tail, err := ioutil.ReadAll(br) if err != nil { t.Fatalf("unexpected error: %s", err) } if string(tail) != "aaaa" { t.Fatalf("unexpected tail %q. Expecting %q", tail, "aaaa") } } func TestRequestContinueReadBody(t *testing.T) { s := "PUT /foo/bar HTTP/1.1\r\nExpect: 100-continue\r\nContent-Length: 5\r\nContent-Type: foo/bar\r\n\r\nabcdef4343" br := bufio.NewReader(bytes.NewBufferString(s)) var r Request if err := r.Read(br); err != nil { t.Fatalf("unexpected error: %s", err) } if !r.MayContinue() { t.Fatalf("MayContinue must return true") } if err := r.ContinueReadBody(br, 0); err != nil { t.Fatalf("error when reading request body: %s", err) } body := r.Body() if string(body) != "abcde" { t.Fatalf("unexpected body %q. Expecting %q", body, "abcde") } tail, err := ioutil.ReadAll(br) if err != nil { t.Fatalf("unexpected error: %s", err) } if string(tail) != "f4343" { t.Fatalf("unexpected tail %q. Expecting %q", tail, "f4343") } } func TestRequestMayContinue(t *testing.T) { var r Request if r.MayContinue() { t.Fatalf("MayContinue on empty request must return false") } r.Header.Set("Expect", "123sdfds") if r.MayContinue() { t.Fatalf("MayContinue on invalid Expect header must return false") } r.Header.Set("Expect", "100-continue") if !r.MayContinue() { t.Fatalf("MayContinue on 'Expect: 100-continue' header must return true") } } func TestResponseGzipStream(t *testing.T) { var r Response if r.IsBodyStream() { t.Fatalf("IsBodyStream must return false") } r.SetBodyStreamWriter(func(w *bufio.Writer) { fmt.Fprintf(w, "foo") w.Flush() w.Write([]byte("barbaz")) w.Flush() fmt.Fprintf(w, "1234") if err := w.Flush(); err != nil { t.Fatalf("unexpected error: %s", err) } }) if !r.IsBodyStream() { t.Fatalf("IsBodyStream must return true") } testResponseGzipExt(t, &r, "foobarbaz1234") } func TestResponseDeflateStream(t *testing.T) { var r Response if r.IsBodyStream() { t.Fatalf("IsBodyStream must return false") } r.SetBodyStreamWriter(func(w *bufio.Writer) { w.Write([]byte("foo")) w.Flush() fmt.Fprintf(w, "barbaz") w.Flush() w.Write([]byte("1234")) if err := w.Flush(); err != nil { t.Fatalf("unexpected error: %s", err) } }) if !r.IsBodyStream() { t.Fatalf("IsBodyStream must return true") } testResponseDeflateExt(t, &r, "foobarbaz1234") } func TestResponseDeflate(t *testing.T) { testResponseDeflate(t, "") testResponseDeflate(t, "abdasdfsdaa") testResponseDeflate(t, "asoiowqoieroqweiruqwoierqo") } func TestResponseGzip(t *testing.T) { testResponseGzip(t, "") testResponseGzip(t, "foobarbaz") testResponseGzip(t, "abasdwqpweoweporweprowepr") } func testResponseDeflate(t *testing.T, s string) { var r Response r.SetBodyString(s) testResponseDeflateExt(t, &r, s) } func testResponseDeflateExt(t *testing.T, r *Response, s string) { var buf bytes.Buffer bw := bufio.NewWriter(&buf) if err := r.WriteDeflate(bw); err != nil { t.Fatalf("unexpected error: %s", err) } if err := bw.Flush(); err != nil { t.Fatalf("unexpected error: %s", err) } var r1 Response br := bufio.NewReader(&buf) if err := r1.Read(br); err != nil { t.Fatalf("unexpected error: %s", err) } ce := r1.Header.Peek("Content-Encoding") if string(ce) != "deflate" { t.Fatalf("unexpected Content-Encoding %q. Expecting %q", ce, "deflate") } body, err := r1.BodyInflate() if err != nil { t.Fatalf("unexpected error: %s", err) } if string(body) != s { t.Fatalf("unexpected body %q. 
Expecting %q", body, s) } } func testResponseGzip(t *testing.T, s string) { var r Response r.SetBodyString(s) testResponseGzipExt(t, &r, s) } func testResponseGzipExt(t *testing.T, r *Response, s string) { var buf bytes.Buffer bw := bufio.NewWriter(&buf) if err := r.WriteGzip(bw); err != nil { t.Fatalf("unexpected error: %s", err) } if err := bw.Flush(); err != nil { t.Fatalf("unexpected error: %s", err) } var r1 Response br := bufio.NewReader(&buf) if err := r1.Read(br); err != nil { t.Fatalf("unexpected error: %s", err) } ce := r1.Header.Peek("Content-Encoding") if string(ce) != "gzip" { t.Fatalf("unexpected Content-Encoding %q. Expecting %q", ce, "gzip") } body, err := r1.BodyGunzip() if err != nil { t.Fatalf("unexpected error: %s", err) } if string(body) != s { t.Fatalf("unexpected body %q. Expecting %q", body, s) } } func TestRequestMultipartForm(t *testing.T) { var w bytes.Buffer mw := multipart.NewWriter(&w) for i := 0; i < 10; i++ { k := fmt.Sprintf("key_%d", i) v := fmt.Sprintf("value_%d", i) if err := mw.WriteField(k, v); err != nil { t.Fatalf("unexpected error: %s", err) } } boundary := mw.Boundary() if err := mw.Close(); err != nil { t.Fatalf("unexpected error: %s", err) } formData := w.Bytes() for i := 0; i < 5; i++ { formData = testRequestMultipartForm(t, boundary, formData, 10) } // verify request unmarshalling / marshalling s := "POST / HTTP/1.1\r\nHost: aaa\r\nContent-Type: multipart/form-data; boundary=foobar\r\nContent-Length: 213\r\n\r\n--foobar\r\nContent-Disposition: form-data; name=\"key_0\"\r\n\r\nvalue_0\r\n--foobar\r\nContent-Disposition: form-data; name=\"key_1\"\r\n\r\nvalue_1\r\n--foobar\r\nContent-Disposition: form-data; name=\"key_2\"\r\n\r\nvalue_2\r\n--foobar--\r\n" var req Request br := bufio.NewReader(bytes.NewBufferString(s)) if err := req.Read(br); err != nil { t.Fatalf("unexpected error: %s", err) } s = req.String() br = bufio.NewReader(bytes.NewBufferString(s)) if err := req.Read(br); err != nil { t.Fatalf("unexpected error: %s", err) } testRequestMultipartForm(t, "foobar", req.Body(), 3) } func testRequestMultipartForm(t *testing.T, boundary string, formData []byte, partsCount int) []byte { s := fmt.Sprintf("POST / HTTP/1.1\r\nHost: aaa\r\nContent-Type: multipart/form-data; boundary=%s\r\nContent-Length: %d\r\n\r\n%s", boundary, len(formData), formData) var req Request r := bytes.NewBufferString(s) br := bufio.NewReader(r) if err := req.Read(br); err != nil { t.Fatalf("unexpected error: %s", err) } f, err := req.MultipartForm() if err != nil { t.Fatalf("unexpected error: %s", err) } defer req.RemoveMultipartFormFiles() if len(f.File) > 0 { t.Fatalf("unexpected files found in the multipart form: %d", len(f.File)) } if len(f.Value) != partsCount { t.Fatalf("unexpected number of values found: %d. Expecting %d", len(f.Value), partsCount) } for k, vv := range f.Value { if len(vv) != 1 { t.Fatalf("unexpected number of values found for key=%q: %d. Expecting 1", k, len(vv)) } if !strings.HasPrefix(k, "key_") { t.Fatalf("unexpected key prefix=%q. Expecting %q", k, "key_") } v := vv[0] if !strings.HasPrefix(v, "value_") { t.Fatalf("unexpected value prefix=%q. 
expecting %q", v, "value_") } if k[len("key_"):] != v[len("value_"):] { t.Fatalf("key and value suffixes don't match: %q vs %q", k, v) } } return req.Body() } func TestResponseReadLimitBody(t *testing.T) { // response with content-length testResponseReadLimitBodySuccess(t, "HTTP/1.1 200 OK\r\nContent-Type: aa\r\nContent-Length: 10\r\n\r\n9876543210", 10) testResponseReadLimitBodySuccess(t, "HTTP/1.1 200 OK\r\nContent-Type: aa\r\nContent-Length: 10\r\n\r\n9876543210", 100) testResponseReadLimitBodyError(t, "HTTP/1.1 200 OK\r\nContent-Type: aa\r\nContent-Length: 10\r\n\r\n9876543210", 9) // chunked response testResponseReadLimitBodySuccess(t, "HTTP/1.1 200 OK\r\nContent-Type: aa\r\nTransfer-Encoding: chunked\r\n\r\n6\r\nfoobar\r\n3\r\nbaz\r\n0\r\n\r\n", 9) testResponseReadLimitBodySuccess(t, "HTTP/1.1 200 OK\r\nContent-Type: aa\r\nTransfer-Encoding: chunked\r\n\r\n6\r\nfoobar\r\n3\r\nbaz\r\n0\r\n\r\n", 100) testResponseReadLimitBodyError(t, "HTTP/1.1 200 OK\r\nContent-Type: aa\r\nTransfer-Encoding: chunked\r\n\r\n6\r\nfoobar\r\n3\r\nbaz\r\n0\r\n\r\n", 2) // identity response testResponseReadLimitBodySuccess(t, "HTTP/1.1 400 OK\r\nContent-Type: aa\r\n\r\n123456", 6) testResponseReadLimitBodySuccess(t, "HTTP/1.1 400 OK\r\nContent-Type: aa\r\n\r\n123456", 106) testResponseReadLimitBodyError(t, "HTTP/1.1 400 OK\r\nContent-Type: aa\r\n\r\n123456", 5) } func TestRequestReadLimitBody(t *testing.T) { // request with content-length testRequestReadLimitBodySuccess(t, "POST /foo HTTP/1.1\r\nHost: aaa.com\r\nContent-Length: 9\r\nContent-Type: aaa\r\n\r\n123456789", 9) testRequestReadLimitBodySuccess(t, "POST /foo HTTP/1.1\r\nHost: aaa.com\r\nContent-Length: 9\r\nContent-Type: aaa\r\n\r\n123456789", 92) testRequestReadLimitBodyError(t, "POST /foo HTTP/1.1\r\nHost: aaa.com\r\nContent-Length: 9\r\nContent-Type: aaa\r\n\r\n123456789", 5) // chunked request testRequestReadLimitBodySuccess(t, "POST /a HTTP/1.1\r\nHost: a.com\r\nTransfer-Encoding: chunked\r\nContent-Type: aa\r\n\r\n6\r\nfoobar\r\n3\r\nbaz\r\n0\r\n\r\n", 9) testRequestReadLimitBodySuccess(t, "POST /a HTTP/1.1\r\nHost: a.com\r\nTransfer-Encoding: chunked\r\nContent-Type: aa\r\n\r\n6\r\nfoobar\r\n3\r\nbaz\r\n0\r\n\r\n", 999) testRequestReadLimitBodyError(t, "POST /a HTTP/1.1\r\nHost: a.com\r\nTransfer-Encoding: chunked\r\nContent-Type: aa\r\n\r\n6\r\nfoobar\r\n3\r\nbaz\r\n0\r\n\r\n", 8) } func testResponseReadLimitBodyError(t *testing.T, s string, maxBodySize int) { var req Response r := bytes.NewBufferString(s) br := bufio.NewReader(r) err := req.ReadLimitBody(br, maxBodySize) if err == nil { t.Fatalf("expecting error. s=%q, maxBodySize=%d", s, maxBodySize) } if err != ErrBodyTooLarge { t.Fatalf("unexpected error: %s. Expecting %s. s=%q, maxBodySize=%d", err, ErrBodyTooLarge, s, maxBodySize) } } func testResponseReadLimitBodySuccess(t *testing.T, s string, maxBodySize int) { var req Response r := bytes.NewBufferString(s) br := bufio.NewReader(r) if err := req.ReadLimitBody(br, maxBodySize); err != nil { t.Fatalf("unexpected error: %s. s=%q, maxBodySize=%d", err, s, maxBodySize) } } func testRequestReadLimitBodyError(t *testing.T, s string, maxBodySize int) { var req Request r := bytes.NewBufferString(s) br := bufio.NewReader(r) err := req.ReadLimitBody(br, maxBodySize) if err == nil { t.Fatalf("expecting error. s=%q, maxBodySize=%d", s, maxBodySize) } if err != ErrBodyTooLarge { t.Fatalf("unexpected error: %s. Expecting %s. 
s=%q, maxBodySize=%d", err, ErrBodyTooLarge, s, maxBodySize) } } func testRequestReadLimitBodySuccess(t *testing.T, s string, maxBodySize int) { var req Request r := bytes.NewBufferString(s) br := bufio.NewReader(r) if err := req.ReadLimitBody(br, maxBodySize); err != nil { t.Fatalf("unexpected error: %s. s=%q, maxBodySize=%d", err, s, maxBodySize) } } func TestRequestString(t *testing.T) { var r Request r.SetRequestURI("http://foobar.com/aaa") s := r.String() expectedS := "GET /aaa HTTP/1.1\r\nUser-Agent: fasthttp\r\nHost: foobar.com\r\n\r\n" if s != expectedS { t.Fatalf("unexpected request: %q. Expecting %q", s, expectedS) } } func TestRequestBodyWriter(t *testing.T) { var r Request w := r.BodyWriter() for i := 0; i < 10; i++ { fmt.Fprintf(w, "%d", i) } if string(r.Body()) != "0123456789" { t.Fatalf("unexpected body %q. Expecting %q", r.Body(), "0123456789") } } func TestResponseBodyWriter(t *testing.T) { var r Response w := r.BodyWriter() for i := 0; i < 10; i++ { fmt.Fprintf(w, "%d", i) } if string(r.Body()) != "0123456789" { t.Fatalf("unexpected body %q. Expecting %q", r.Body(), "0123456789") } } func TestRequestWriteRequestURINoHost(t *testing.T) { var req Request req.Header.SetRequestURI("http://google.com/foo/bar?baz=aaa") var w bytes.Buffer bw := bufio.NewWriter(&w) if err := req.Write(bw); err != nil { t.Fatalf("unexpected error: %s", err) } if err := bw.Flush(); err != nil { t.Fatalf("unexepcted error: %s", err) } var req1 Request br := bufio.NewReader(&w) if err := req1.Read(br); err != nil { t.Fatalf("unexpected error: %s", err) } if string(req1.Header.Host()) != "google.com" { t.Fatalf("unexpected host: %q. Expecting %q", req1.Header.Host(), "google.com") } if string(req.Header.RequestURI()) != "/foo/bar?baz=aaa" { t.Fatalf("unexpected requestURI: %q. 
Expecting %q", req.Header.RequestURI(), "/foo/bar?baz=aaa") } // verify that Request.Write returns error on non-absolute RequestURI req.Reset() req.Header.SetRequestURI("/foo/bar") w.Reset() bw.Reset(&w) if err := req.Write(bw); err == nil { t.Fatalf("expecting error") } } func TestSetRequestBodyStreamFixedSize(t *testing.T) { testSetRequestBodyStream(t, "a", false) testSetRequestBodyStream(t, string(createFixedBody(4097)), false) testSetRequestBodyStream(t, string(createFixedBody(100500)), false) } func TestSetResponseBodyStreamFixedSize(t *testing.T) { testSetResponseBodyStream(t, "a", false) testSetResponseBodyStream(t, string(createFixedBody(4097)), false) testSetResponseBodyStream(t, string(createFixedBody(100500)), false) } func TestSetRequestBodyStreamChunked(t *testing.T) { testSetRequestBodyStream(t, "", true) body := "foobar baz aaa bbb ccc" testSetRequestBodyStream(t, body, true) body = string(createFixedBody(10001)) testSetRequestBodyStream(t, body, true) } func TestSetResponseBodyStreamChunked(t *testing.T) { testSetResponseBodyStream(t, "", true) body := "foobar baz aaa bbb ccc" testSetResponseBodyStream(t, body, true) body = string(createFixedBody(10001)) testSetResponseBodyStream(t, body, true) } func testSetRequestBodyStream(t *testing.T, body string, chunked bool) { var req Request req.Header.SetHost("foobar.com") req.Header.SetMethod("POST") bodySize := len(body) if chunked { bodySize = -1 } if req.IsBodyStream() { t.Fatalf("IsBodyStream must return false") } req.SetBodyStream(bytes.NewBufferString(body), bodySize) if !req.IsBodyStream() { t.Fatalf("IsBodyStream must return true") } var w bytes.Buffer bw := bufio.NewWriter(&w) if err := req.Write(bw); err != nil { t.Fatalf("unexpected error when writing request: %s. body=%q", err, body) } if err := bw.Flush(); err != nil { t.Fatalf("unexpected error when flushing request: %s. body=%q", err, body) } var req1 Request br := bufio.NewReader(&w) if err := req1.Read(br); err != nil { t.Fatalf("unexpected error when reading request: %s. body=%q", err, body) } if string(req1.Body()) != body { t.Fatalf("unexpected body %q. Expecting %q", req1.Body(), body) } } func testSetResponseBodyStream(t *testing.T, body string, chunked bool) { var resp Response bodySize := len(body) if chunked { bodySize = -1 } if resp.IsBodyStream() { t.Fatalf("IsBodyStream must return false") } resp.SetBodyStream(bytes.NewBufferString(body), bodySize) if !resp.IsBodyStream() { t.Fatalf("IsBodyStream must return true") } var w bytes.Buffer bw := bufio.NewWriter(&w) if err := resp.Write(bw); err != nil { t.Fatalf("unexpected error when writing response: %s. body=%q", err, body) } if err := bw.Flush(); err != nil { t.Fatalf("unexpected error when flushing response: %s. body=%q", err, body) } var resp1 Response br := bufio.NewReader(&w) if err := resp1.Read(br); err != nil { t.Fatalf("unexpected error when reading response: %s. body=%q", err, body) } if string(resp1.Body()) != body { t.Fatalf("unexpected body %q. Expecting %q", resp1.Body(), body) } } func TestRound2(t *testing.T) { testRound2(t, 0, 0) testRound2(t, 1, 1) testRound2(t, 2, 2) testRound2(t, 3, 4) testRound2(t, 4, 4) testRound2(t, 5, 8) testRound2(t, 7, 8) testRound2(t, 8, 8) testRound2(t, 9, 16) testRound2(t, 0x10001, 0x20000) } func testRound2(t *testing.T, n, expectedRound2 int) { if round2(n) != expectedRound2 { t.Fatalf("Unexpected round2(%d)=%d. 
Expected %d", n, round2(n), expectedRound2) } } func TestRequestReadChunked(t *testing.T) { var req Request s := "POST /foo HTTP/1.1\r\nHost: google.com\r\nTransfer-Encoding: chunked\r\nContent-Type: aa/bb\r\n\r\n3\r\nabc\r\n5\r\n12345\r\n0\r\n\r\ntrail" r := bytes.NewBufferString(s) rb := bufio.NewReader(r) err := req.Read(rb) if err != nil { t.Fatalf("Unexpected error when reading chunked request: %s", err) } expectedBody := "abc12345" if string(req.Body()) != expectedBody { t.Fatalf("Unexpected body %q. Expected %q", req.Body(), expectedBody) } verifyRequestHeader(t, &req.Header, 8, "/foo", "google.com", "", "aa/bb") verifyTrailer(t, rb, "trail") } func TestResponseReadWithoutBody(t *testing.T) { var resp Response testResponseReadWithoutBody(t, &resp, "HTTP/1.1 304 Not Modified\r\nContent-Type: aa\r\nContent-Length: 1235\r\n\r\nfoobar", false, 304, 1235, "aa", "foobar") testResponseReadWithoutBody(t, &resp, "HTTP/1.1 204 Foo Bar\r\nContent-Type: aab\r\nTransfer-Encoding: chunked\r\n\r\n123\r\nss", false, 204, -1, "aab", "123\r\nss") testResponseReadWithoutBody(t, &resp, "HTTP/1.1 123 AAA\r\nContent-Type: xxx\r\nContent-Length: 3434\r\n\r\naaaa", false, 123, 3434, "xxx", "aaaa") testResponseReadWithoutBody(t, &resp, "HTTP 200 OK\r\nContent-Type: text/xml\r\nContent-Length: 123\r\n\r\nxxxx", true, 200, 123, "text/xml", "xxxx") // '100 Continue' must be skipped. testResponseReadWithoutBody(t, &resp, "HTTP/1.1 100 Continue\r\nFoo-bar: baz\r\n\r\nHTTP/1.1 329 aaa\r\nContent-Type: qwe\r\nContent-Length: 894\r\n\r\nfoobar", true, 329, 894, "qwe", "foobar") } func testResponseReadWithoutBody(t *testing.T, resp *Response, s string, skipBody bool, expectedStatusCode, expectedContentLength int, expectedContentType, expectedTrailer string) { r := bytes.NewBufferString(s) rb := bufio.NewReader(r) resp.SkipBody = skipBody err := resp.Read(rb) if err != nil { t.Fatalf("Unexpected error when reading response without body: %s. response=%q", err, s) } if len(resp.Body()) != 0 { t.Fatalf("Unexpected response body %q. Expected %q. 
response=%q", resp.Body(), "", s) } verifyResponseHeader(t, &resp.Header, expectedStatusCode, expectedContentLength, expectedContentType) verifyTrailer(t, rb, expectedTrailer) // verify that ordinal response is read after null-body response resp.SkipBody = false testResponseReadSuccess(t, resp, "HTTP/1.1 300 OK\r\nContent-Length: 5\r\nContent-Type: bar\r\n\r\n56789aaa", 300, 5, "bar", "56789", "aaa") } func TestRequestSuccess(t *testing.T) { // empty method, user-agent and body testRequestSuccess(t, "", "/foo/bar", "google.com", "", "", "GET") // non-empty user-agent testRequestSuccess(t, "GET", "/foo/bar", "google.com", "MSIE", "", "GET") // non-empty method testRequestSuccess(t, "HEAD", "/aaa", "fobar", "", "", "HEAD") // POST method with body testRequestSuccess(t, "POST", "/bbb", "aaa.com", "Chrome aaa", "post body", "POST") // PUT method with body testRequestSuccess(t, "PUT", "/aa/bb", "a.com", "ome aaa", "put body", "PUT") // only host is set testRequestSuccess(t, "", "", "gooble.com", "", "", "GET") } func TestResponseSuccess(t *testing.T) { // 200 response testResponseSuccess(t, 200, "test/plain", "server", "foobar", 200, "test/plain", "server") // response with missing statusCode testResponseSuccess(t, 0, "text/plain", "server", "foobar", 200, "text/plain", "server") // response with missing server testResponseSuccess(t, 500, "aaa", "", "aaadfsd", 500, "aaa", string(defaultServerName)) // empty body testResponseSuccess(t, 200, "bbb", "qwer", "", 200, "bbb", "qwer") // missing content-type testResponseSuccess(t, 200, "", "asdfsd", "asdf", 200, string(defaultContentType), "asdfsd") } func testResponseSuccess(t *testing.T, statusCode int, contentType, serverName, body string, expectedStatusCode int, expectedContentType, expectedServerName string) { var resp Response resp.SetStatusCode(statusCode) resp.Header.Set("Content-Type", contentType) resp.Header.Set("Server", serverName) resp.SetBody([]byte(body)) w := &bytes.Buffer{} bw := bufio.NewWriter(w) err := resp.Write(bw) if err != nil { t.Fatalf("Unexpected error when calling Response.Write(): %s", err) } if err = bw.Flush(); err != nil { t.Fatalf("Unexpected error when flushing bufio.Writer: %s", err) } var resp1 Response br := bufio.NewReader(w) if err = resp1.Read(br); err != nil { t.Fatalf("Unexpected error when calling Response.Read(): %s", err) } if resp1.StatusCode() != expectedStatusCode { t.Fatalf("Unexpected status code: %d. Expected %d", resp1.StatusCode(), expectedStatusCode) } if resp1.Header.ContentLength() != len(body) { t.Fatalf("Unexpected content-length: %d. Expected %d", resp1.Header.ContentLength(), len(body)) } if string(resp1.Header.Peek("Content-Type")) != expectedContentType { t.Fatalf("Unexpected content-type: %q. Expected %q", resp1.Header.Peek("Content-Type"), expectedContentType) } if string(resp1.Header.Peek("Server")) != expectedServerName { t.Fatalf("Unexpected server: %q. Expected %q", resp1.Header.Peek("Server"), expectedServerName) } if !bytes.Equal(resp1.Body(), []byte(body)) { t.Fatalf("Unexpected body: %q. 
Expected %q", resp1.Body(), body) } } func TestRequestWriteError(t *testing.T) { // no host testRequestWriteError(t, "", "/foo/bar", "", "", "") // get with body testRequestWriteError(t, "GET", "/foo/bar", "aaa.com", "", "foobar") } func testRequestWriteError(t *testing.T, method, requestURI, host, userAgent, body string) { var req Request req.Header.SetMethod(method) req.Header.SetRequestURI(requestURI) req.Header.Set("Host", host) req.Header.Set("User-Agent", userAgent) req.SetBody([]byte(body)) w := &ByteBuffer{} bw := bufio.NewWriter(w) err := req.Write(bw) if err == nil { t.Fatalf("Expecting error when writing request=%#v", &req) } } func testRequestSuccess(t *testing.T, method, requestURI, host, userAgent, body, expectedMethod string) { var req Request req.Header.SetMethod(method) req.Header.SetRequestURI(requestURI) req.Header.Set("Host", host) req.Header.Set("User-Agent", userAgent) req.SetBody([]byte(body)) contentType := "foobar" if method == "POST" { req.Header.Set("Content-Type", contentType) } w := &bytes.Buffer{} bw := bufio.NewWriter(w) err := req.Write(bw) if err != nil { t.Fatalf("Unexpected error when calling Request.Write(): %s", err) } if err = bw.Flush(); err != nil { t.Fatalf("Unexpected error when flushing bufio.Writer: %s", err) } var req1 Request br := bufio.NewReader(w) if err = req1.Read(br); err != nil { t.Fatalf("Unexpected error when calling Request.Read(): %s", err) } if string(req1.Header.Method()) != expectedMethod { t.Fatalf("Unexpected method: %q. Expected %q", req1.Header.Method(), expectedMethod) } if len(requestURI) == 0 { requestURI = "/" } if string(req1.Header.RequestURI()) != requestURI { t.Fatalf("Unexpected RequestURI: %q. Expected %q", req1.Header.RequestURI(), requestURI) } if string(req1.Header.Peek("Host")) != host { t.Fatalf("Unexpected host: %q. Expected %q", req1.Header.Peek("Host"), host) } if len(userAgent) == 0 { userAgent = string(defaultUserAgent) } if string(req1.Header.Peek("User-Agent")) != userAgent { t.Fatalf("Unexpected user-agent: %q. Expected %q", req1.Header.Peek("User-Agent"), userAgent) } if !bytes.Equal(req1.Body(), []byte(body)) { t.Fatalf("Unexpected body: %q. Expected %q", req1.Body(), body) } if method == "POST" && string(req1.Header.Peek("Content-Type")) != contentType { t.Fatalf("Unexpected content-type: %q. 
Expected %q", req1.Header.Peek("Content-Type"), contentType) } } func TestResponseReadSuccess(t *testing.T) { resp := &Response{} // usual response testResponseReadSuccess(t, resp, "HTTP/1.1 200 OK\r\nContent-Length: 10\r\nContent-Type: foo/bar\r\n\r\n0123456789", 200, 10, "foo/bar", "0123456789", "") // zero response testResponseReadSuccess(t, resp, "HTTP/1.1 500 OK\r\nContent-Length: 0\r\nContent-Type: foo/bar\r\n\r\n", 500, 0, "foo/bar", "", "") // response with trailer testResponseReadSuccess(t, resp, "HTTP/1.1 300 OK\r\nContent-Length: 5\r\nContent-Type: bar\r\n\r\n56789aaa", 300, 5, "bar", "56789", "aaa") // no conent-length ('identity' transfer-encoding) testResponseReadSuccess(t, resp, "HTTP/1.1 200 OK\r\nContent-Type: foobar\r\n\r\nzxxc", 200, 4, "foobar", "zxxc", "") // explicitly stated 'Transfer-Encoding: identity' testResponseReadSuccess(t, resp, "HTTP/1.1 234 ss\r\nContent-Type: xxx\r\n\r\nxag", 234, 3, "xxx", "xag", "") // big 'identity' response body := string(createFixedBody(100500)) testResponseReadSuccess(t, resp, "HTTP/1.1 200 OK\r\nContent-Type: aa\r\n\r\n"+body, 200, 100500, "aa", body, "") // chunked response testResponseReadSuccess(t, resp, "HTTP/1.1 200 OK\r\nContent-Type: text/html\r\nTransfer-Encoding: chunked\r\n\r\n4\r\nqwer\r\n2\r\nty\r\n0\r\n\r\nzzzzz", 200, 6, "text/html", "qwerty", "zzzzz") // chunked response with non-chunked Transfer-Encoding. testResponseReadSuccess(t, resp, "HTTP/1.1 230 OK\r\nContent-Type: text\r\nTransfer-Encoding: aaabbb\r\n\r\n2\r\ner\r\n2\r\nty\r\n0\r\n\r\nwe", 230, 4, "text", "erty", "we") // zero chunked response testResponseReadSuccess(t, resp, "HTTP/1.1 200 OK\r\nContent-Type: text/html\r\nTransfer-Encoding: chunked\r\n\r\n0\r\n\r\nzzz", 200, 0, "text/html", "", "zzz") } func TestResponseReadError(t *testing.T) { resp := &Response{} // empty response testResponseReadError(t, resp, "") // invalid header testResponseReadError(t, resp, "foobar") // empty body testResponseReadError(t, resp, "HTTP/1.1 200 OK\r\nContent-Type: aaa\r\nContent-Length: 1234\r\n\r\n") // short body testResponseReadError(t, resp, "HTTP/1.1 200 OK\r\nContent-Type: aaa\r\nContent-Length: 1234\r\n\r\nshort") } func testResponseReadError(t *testing.T, resp *Response, response string) { r := bytes.NewBufferString(response) rb := bufio.NewReader(r) err := resp.Read(rb) if err == nil { t.Fatalf("Expecting error for response=%q", response) } testResponseReadSuccess(t, resp, "HTTP/1.1 303 Redisred sedfs sdf\r\nContent-Type: aaa\r\nContent-Length: 5\r\n\r\nHELLOaaa", 303, 5, "aaa", "HELLO", "aaa") } func testResponseReadSuccess(t *testing.T, resp *Response, response string, expectedStatusCode, expectedContentLength int, expectedContenType, expectedBody, expectedTrailer string) { r := bytes.NewBufferString(response) rb := bufio.NewReader(r) err := resp.Read(rb) if err != nil { t.Fatalf("Unexpected error: %s", err) } verifyResponseHeader(t, &resp.Header, expectedStatusCode, expectedContentLength, expectedContenType) if !bytes.Equal(resp.Body(), []byte(expectedBody)) { t.Fatalf("Unexpected body %q. 
Expected %q", resp.Body(), []byte(expectedBody)) } verifyTrailer(t, rb, expectedTrailer) } func TestReadBodyFixedSize(t *testing.T) { var b []byte // zero-size body testReadBodyFixedSize(t, b, 0) // small-size body testReadBodyFixedSize(t, b, 3) // medium-size body testReadBodyFixedSize(t, b, 1024) // large-size body testReadBodyFixedSize(t, b, 1024*1024) // smaller body after big one testReadBodyFixedSize(t, b, 34345) } func TestReadBodyChunked(t *testing.T) { var b []byte // zero-size body testReadBodyChunked(t, b, 0) // small-size body testReadBodyChunked(t, b, 5) // medium-size body testReadBodyChunked(t, b, 43488) // big body testReadBodyChunked(t, b, 3*1024*1024) // smaler body after big one testReadBodyChunked(t, b, 12343) } func TestRequestURI(t *testing.T) { host := "foobar.com" requestURI := "/aaa/bb+b%20d?ccc=ddd&qqq#1334dfds&=d" expectedPathOriginal := "/aaa/bb+b%20d" expectedPath := "/aaa/bb+b d" expectedQueryString := "ccc=ddd&qqq" expectedHash := "1334dfds&=d" var req Request req.Header.Set("Host", host) req.Header.SetRequestURI(requestURI) uri := req.URI() if string(uri.Host()) != host { t.Fatalf("Unexpected host %q. Expected %q", uri.Host(), host) } if string(uri.PathOriginal()) != expectedPathOriginal { t.Fatalf("Unexpected source path %q. Expected %q", uri.PathOriginal(), expectedPathOriginal) } if string(uri.Path()) != expectedPath { t.Fatalf("Unexpected path %q. Expected %q", uri.Path(), expectedPath) } if string(uri.QueryString()) != expectedQueryString { t.Fatalf("Unexpected query string %q. Expected %q", uri.QueryString(), expectedQueryString) } if string(uri.Hash()) != expectedHash { t.Fatalf("Unexpected hash %q. Expected %q", uri.Hash(), expectedHash) } } func TestRequestPostArgsSuccess(t *testing.T) { var req Request testRequestPostArgsSuccess(t, &req, "POST / HTTP/1.1\r\nHost: aaa.com\r\nContent-Type: application/x-www-form-urlencoded\r\nContent-Length: 0\r\n\r\n", 0, "foo=", "=") testRequestPostArgsSuccess(t, &req, "POST / HTTP/1.1\r\nHost: aaa.com\r\nContent-Type: application/x-www-form-urlencoded\r\nContent-Length: 18\r\n\r\nfoo&b%20r=b+z=&qwe", 3, "foo=", "b r=b z=", "qwe=") } func TestRequestPostArgsError(t *testing.T) { var req Request // non-post testRequestPostArgsError(t, &req, "GET /aa HTTP/1.1\r\nHost: aaa\r\n\r\n") // invalid content-type testRequestPostArgsError(t, &req, "POST /aa HTTP/1.1\r\nHost: aaa\r\nContent-Type: text/html\r\nContent-Length: 5\r\n\r\nabcde") } func testRequestPostArgsError(t *testing.T, req *Request, s string) { r := bytes.NewBufferString(s) br := bufio.NewReader(r) err := req.Read(br) if err != nil { t.Fatalf("Unexpected error when reading %q: %s", s, err) } ss := req.PostArgs().String() if len(ss) != 0 { t.Fatalf("unexpected post args: %q. Expecting empty post args", ss) } } func testRequestPostArgsSuccess(t *testing.T, req *Request, s string, expectedArgsLen int, expectedArgs ...string) { r := bytes.NewBufferString(s) br := bufio.NewReader(r) err := req.Read(br) if err != nil { t.Fatalf("Unexpected error when reading %q: %s", s, err) } args := req.PostArgs() if args.Len() != expectedArgsLen { t.Fatalf("Unexpected args len %d. Expected %d for %q", args.Len(), expectedArgsLen, s) } for _, x := range expectedArgs { tmp := strings.SplitN(x, "=", 2) k := tmp[0] v := tmp[1] vv := string(args.Peek(k)) if vv != v { t.Fatalf("Unexpected value for key %q: %q. 
Expected %q for %q", k, vv, v, s) } } } func testReadBodyChunked(t *testing.T, b []byte, bodySize int) { body := createFixedBody(bodySize) chunkedBody := createChunkedBody(body) expectedTrailer := []byte("chunked shit") chunkedBody = append(chunkedBody, expectedTrailer...) r := bytes.NewBuffer(chunkedBody) br := bufio.NewReader(r) b, err := readBody(br, -1, 0, nil) if err != nil { t.Fatalf("Unexpected error for bodySize=%d: %s. body=%q, chunkedBody=%q", bodySize, err, body, chunkedBody) } if !bytes.Equal(b, body) { t.Fatalf("Unexpected response read for bodySize=%d: %q. Expected %q. chunkedBody=%q", bodySize, b, body, chunkedBody) } verifyTrailer(t, br, string(expectedTrailer)) } func testReadBodyFixedSize(t *testing.T, b []byte, bodySize int) { body := createFixedBody(bodySize) expectedTrailer := []byte("traler aaaa") bodyWithTrailer := append(body, expectedTrailer...) r := bytes.NewBuffer(bodyWithTrailer) br := bufio.NewReader(r) b, err := readBody(br, bodySize, 0, nil) if err != nil { t.Fatalf("Unexpected error in ReadResponseBody(%d): %s", bodySize, err) } if !bytes.Equal(b, body) { t.Fatalf("Unexpected response read for bodySize=%d: %q. Expected %q", bodySize, b, body) } verifyTrailer(t, br, string(expectedTrailer)) } func createFixedBody(bodySize int) []byte { var b []byte for i := 0; i < bodySize; i++ { b = append(b, byte(i%10)+'0') } return b } func createChunkedBody(body []byte) []byte { var b []byte chunkSize := 1 for len(body) > 0 { if chunkSize > len(body) { chunkSize = len(body) } b = append(b, []byte(fmt.Sprintf("%x\r\n", chunkSize))...) b = append(b, body[:chunkSize]...) b = append(b, []byte("\r\n")...) body = body[chunkSize:] chunkSize++ } return append(b, []byte("0\r\n\r\n")...) } golang-github-valyala-fasthttp-20160617/nocopy.go000066400000000000000000000004141273074646000216410ustar00rootroot00000000000000package fasthttp // Embed this type into a struct, which mustn't be copied, // so `go vet` gives a warning if this struct is copied. // // See https://github.com/golang/go/issues/8005#issuecomment-190753527 for details. 
type noCopy struct{} func (*noCopy) Lock() {} golang-github-valyala-fasthttp-20160617/peripconn.go000066400000000000000000000033271273074646000223350ustar00rootroot00000000000000package fasthttp import ( "fmt" "net" "sync" ) type perIPConnCounter struct { pool sync.Pool lock sync.Mutex m map[uint32]int } func (cc *perIPConnCounter) Register(ip uint32) int { cc.lock.Lock() if cc.m == nil { cc.m = make(map[uint32]int) } n := cc.m[ip] + 1 cc.m[ip] = n cc.lock.Unlock() return n } func (cc *perIPConnCounter) Unregister(ip uint32) { cc.lock.Lock() if cc.m == nil { cc.lock.Unlock() panic("BUG: perIPConnCounter.Register() wasn't called") } n := cc.m[ip] - 1 if n < 0 { cc.lock.Unlock() panic(fmt.Sprintf("BUG: negative per-ip counter=%d for ip=%d", n, ip)) } cc.m[ip] = n cc.lock.Unlock() } type perIPConn struct { net.Conn ip uint32 perIPConnCounter *perIPConnCounter } func acquirePerIPConn(conn net.Conn, ip uint32, counter *perIPConnCounter) *perIPConn { v := counter.pool.Get() if v == nil { v = &perIPConn{ perIPConnCounter: counter, } } c := v.(*perIPConn) c.Conn = conn c.ip = ip return c } func releasePerIPConn(c *perIPConn) { c.Conn = nil c.perIPConnCounter.pool.Put(c) } func (c *perIPConn) Close() error { err := c.Conn.Close() c.perIPConnCounter.Unregister(c.ip) releasePerIPConn(c) return err } func getUint32IP(c net.Conn) uint32 { return ip2uint32(getConnIP4(c)) } func getConnIP4(c net.Conn) net.IP { addr := c.RemoteAddr() ipAddr, ok := addr.(*net.TCPAddr) if !ok { return net.IPv4zero } return ipAddr.IP.To4() } func ip2uint32(ip net.IP) uint32 { if len(ip) != 4 { return 0 } return uint32(ip[0])<<24 | uint32(ip[1])<<16 | uint32(ip[2])<<8 | uint32(ip[3]) } func uint322ip(ip uint32) net.IP { b := make([]byte, 4) b[0] = byte(ip >> 24) b[1] = byte(ip >> 16) b[2] = byte(ip >> 8) b[3] = byte(ip) return b } golang-github-valyala-fasthttp-20160617/peripconn_test.go000066400000000000000000000021121273074646000233630ustar00rootroot00000000000000package fasthttp import ( "testing" ) func TestIPxUint32(t *testing.T) { testIPxUint32(t, 0) testIPxUint32(t, 10) testIPxUint32(t, 0x12892392) } func testIPxUint32(t *testing.T, n uint32) { ip := uint322ip(n) nn := ip2uint32(ip) if n != nn { t.Fatalf("Unexpected value=%d for ip=%s. Expected %d", nn, ip, n) } } func TestPerIPConnCounter(t *testing.T) { var cc perIPConnCounter expectPanic(t, func() { cc.Unregister(123) }) for i := 1; i < 100; i++ { if n := cc.Register(123); n != i { t.Fatalf("Unexpected counter value=%d. Expected %d", n, i) } } n := cc.Register(456) if n != 1 { t.Fatalf("Unexpected counter value=%d. Expected 1", n) } for i := 1; i < 100; i++ { cc.Unregister(123) } cc.Unregister(456) expectPanic(t, func() { cc.Unregister(123) }) expectPanic(t, func() { cc.Unregister(456) }) n = cc.Register(123) if n != 1 { t.Fatalf("Unexpected counter value=%d. Expected 1", n) } cc.Unregister(123) } func expectPanic(t *testing.T, f func()) { defer func() { if r := recover(); r == nil { t.Fatalf("Expecting panic") } }() f() } golang-github-valyala-fasthttp-20160617/requestctx_setbodystreamwriter_example_test.go000066400000000000000000000013461273074646000315220ustar00rootroot00000000000000package fasthttp_test import ( "bufio" "fmt" "log" "time" "github.com/valyala/fasthttp" ) func ExampleRequestCtx_SetBodyStreamWriter() { // Start fasthttp server for streaming responses. 
if err := fasthttp.ListenAndServe(":8080", responseStreamHandler); err != nil { log.Fatalf("unexpected error in server: %s", err) } } func responseStreamHandler(ctx *fasthttp.RequestCtx) { // Send the response in chunks and wait for a second between each chunk. ctx.SetBodyStreamWriter(func(w *bufio.Writer) { for i := 0; i < 10; i++ { fmt.Fprintf(w, "this is a message number %d", i) // Do not forget flushing streamed data to the client. if err := w.Flush(); err != nil { return } time.Sleep(time.Second) } }) } golang-github-valyala-fasthttp-20160617/reuseport/000077500000000000000000000000001273074646000220345ustar00rootroot00000000000000golang-github-valyala-fasthttp-20160617/reuseport/LICENSE000066400000000000000000000020651273074646000230440ustar00rootroot00000000000000The MIT License (MIT) Copyright (c) 2014 Max Riveiro Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.golang-github-valyala-fasthttp-20160617/reuseport/reuseport.go000066400000000000000000000050051273074646000244130ustar00rootroot00000000000000// +build linux darwin dragonfly freebsd netbsd openbsd rumprun // Package reuseport provides TCP net.Listener with SO_REUSEPORT support. // // SO_REUSEPORT allows linear scaling server performance on multi-CPU servers. // See https://www.nginx.com/blog/socket-sharding-nginx-release-1-9-1/ for more details :) // // The package is based on https://github.com/kavu/go_reuseport . package reuseport import ( "errors" "fmt" "net" "os" "syscall" ) func getSockaddr(network, addr string) (sa syscall.Sockaddr, soType int, err error) { // TODO: add support for tcp and tcp6 networks. if network != "tcp4" { return nil, -1, errors.New("only tcp4 network is supported") } tcpAddr, err := net.ResolveTCPAddr(network, addr) if err != nil { return nil, -1, err } var sa4 syscall.SockaddrInet4 sa4.Port = tcpAddr.Port copy(sa4.Addr[:], tcpAddr.IP.To4()) return &sa4, syscall.AF_INET, nil } // ErrNoReusePort is returned if the OS doesn't support SO_REUSEPORT. type ErrNoReusePort struct { err error } // Error implements error interface. func (e *ErrNoReusePort) Error() string { return fmt.Sprintf("The OS doesn't support SO_REUSEPORT: %s", e.err) } // Listen returns TCP listener with SO_REUSEPORT option set. // // Only tcp4 network is supported. // // ErrNoReusePort error is returned if the system doesn't support SO_REUSEPORT. 
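//
// A minimal usage sketch, assuming a placeholder requestHandler and falling
// back to net.Listen when SO_REUSEPORT is unavailable:
//
//	ln, err := reuseport.Listen("tcp4", "127.0.0.1:8080")
//	if err != nil {
//		if _, ok := err.(*reuseport.ErrNoReusePort); ok {
//			// The OS lacks SO_REUSEPORT - use a regular listener instead.
//			ln, err = net.Listen("tcp4", "127.0.0.1:8080")
//		}
//		if err != nil {
//			log.Fatalf("cannot create listener: %s", err)
//		}
//	}
//	log.Fatal(fasthttp.Serve(ln, requestHandler))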
func Listen(network, addr string) (l net.Listener, err error) { var ( soType, fd int file *os.File sockaddr syscall.Sockaddr ) if sockaddr, soType, err = getSockaddr(network, addr); err != nil { return nil, err } syscall.ForkLock.RLock() fd, err = syscall.Socket(soType, syscall.SOCK_STREAM, syscall.IPPROTO_TCP) if err == nil { syscall.CloseOnExec(fd) } syscall.ForkLock.RUnlock() if err != nil { return nil, err } if err = syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1); err != nil { syscall.Close(fd) return nil, err } if err = syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, soReusePort, 1); err != nil { syscall.Close(fd) return nil, &ErrNoReusePort{err} } if err = syscall.Bind(fd, sockaddr); err != nil { syscall.Close(fd) return nil, err } if err = syscall.Listen(fd, syscall.SOMAXCONN); err != nil { syscall.Close(fd) return nil, err } name := fmt.Sprintf("reuseport.%d.%s.%s", os.Getpid(), network, addr) file = os.NewFile(uintptr(fd), name) if l, err = net.FileListener(file); err != nil { file.Close() return nil, err } if err = file.Close(); err != nil { l.Close() return nil, err } return l, err } golang-github-valyala-fasthttp-20160617/reuseport/reuseport_bsd.go000066400000000000000000000002161273074646000252420ustar00rootroot00000000000000// +build darwin dragonfly freebsd netbsd openbsd rumprun package reuseport import ( "syscall" ) const soReusePort = syscall.SO_REUSEPORT golang-github-valyala-fasthttp-20160617/reuseport/reuseport_example_test.go000066400000000000000000000007341273074646000271710ustar00rootroot00000000000000package reuseport_test import ( "fmt" "log" "github.com/valyala/fasthttp" "github.com/valyala/fasthttp/reuseport" ) func ExampleListen() { ln, err := reuseport.Listen("tcp4", "localhost:12345") if err != nil { log.Fatalf("error in reuseport listener: %s", err) } if err = fasthttp.Serve(ln, requestHandler); err != nil { log.Fatalf("error in fasthttp Server: %s", err) } } func requestHandler(ctx *fasthttp.RequestCtx) { fmt.Fprintf(ctx, "Hello, world!") } golang-github-valyala-fasthttp-20160617/reuseport/reuseport_linux.go000066400000000000000000000000751273074646000256340ustar00rootroot00000000000000// +build linux package reuseport const soReusePort = 0x0F golang-github-valyala-fasthttp-20160617/reuseport/reuseport_test.go000066400000000000000000000043531273074646000254570ustar00rootroot00000000000000package reuseport import ( "fmt" "io/ioutil" "net" "testing" "time" ) func TestNewListener(t *testing.T) { addr := "localhost:10081" serversCount := 20 requestsCount := 1000 var lns []net.Listener doneCh := make(chan struct{}, serversCount) for i := 0; i < serversCount; i++ { ln, err := Listen("tcp4", addr) if err != nil { t.Fatalf("cannot create listener %d: %s", i, err) } go func() { serveEcho(t, ln) doneCh <- struct{}{} }() lns = append(lns, ln) } for i := 0; i < requestsCount; i++ { c, err := net.Dial("tcp4", addr) if err != nil { t.Fatalf("%d. unexpected error when dialing: %s", i, err) } req := fmt.Sprintf("request number %d", i) if _, err = c.Write([]byte(req)); err != nil { t.Fatalf("%d. unexpected error when writing request: %s", i, err) } if err = c.(*net.TCPConn).CloseWrite(); err != nil { t.Fatalf("%d. unexpected error when closing write end of the connection: %s", i, err) } var resp []byte ch := make(chan struct{}) go func() { if resp, err = ioutil.ReadAll(c); err != nil { t.Fatalf("%d. unexpected error when reading response: %s", i, err) } close(ch) }() select { case <-ch: case <-time.After(200 * time.Millisecond): t.Fatalf("%d. 
timeout when waiting for response: %s", i, err) } if string(resp) != req { t.Fatalf("%d. unexpected response %q. Expecting %q", i, resp, req) } if err = c.Close(); err != nil { t.Fatalf("%d. unexpected error when closing connection: %s", i, err) } } for _, ln := range lns { if err := ln.Close(); err != nil { t.Fatalf("unexpected error when closing listener: %s", err) } } for i := 0; i < serversCount; i++ { select { case <-doneCh: case <-time.After(200 * time.Millisecond): t.Fatalf("timeout when waiting for servers to be closed") } } } func serveEcho(t *testing.T, ln net.Listener) { for { c, err := ln.Accept() if err != nil { break } req, err := ioutil.ReadAll(c) if err != nil { t.Fatalf("unepxected error when reading request: %s", err) } if _, err = c.Write(req); err != nil { t.Fatalf("unexpected error when writing response: %s", err) } if err = c.Close(); err != nil { t.Fatalf("unexpected error when closing connection: %s", err) } } } golang-github-valyala-fasthttp-20160617/server.go000066400000000000000000001457461273074646000216620ustar00rootroot00000000000000package fasthttp import ( "bufio" "crypto/tls" "errors" "fmt" "io" "log" "mime/multipart" "net" "os" "runtime/debug" "strings" "sync" "sync/atomic" "time" ) // ServeConn serves HTTP requests from the given connection // using the given handler. // // ServeConn returns nil if all requests from the c are successfully served. // It returns non-nil error otherwise. // // Connection c must immediately propagate all the data passed to Write() // to the client. Otherwise requests' processing may hang. // // ServeConn closes c before returning. func ServeConn(c net.Conn, handler RequestHandler) error { v := serverPool.Get() if v == nil { v = &Server{} } s := v.(*Server) s.Handler = handler err := s.ServeConn(c) s.Handler = nil serverPool.Put(v) return err } var serverPool sync.Pool // Serve serves incoming connections from the given listener // using the given handler. // // Serve blocks until the given listener returns permanent error. func Serve(ln net.Listener, handler RequestHandler) error { s := &Server{ Handler: handler, } return s.Serve(ln) } // ServeTLS serves HTTPS requests from the given net.Listener // using the given handler. // // certFile and keyFile are paths to TLS certificate and key files. func ServeTLS(ln net.Listener, certFile, keyFile string, handler RequestHandler) error { s := &Server{ Handler: handler, } return s.ServeTLS(ln, certFile, keyFile) } // ServeTLSEmbed serves HTTPS requests from the given net.Listener // using the given handler. // // certData and keyData must contain valid TLS certificate and key data. func ServeTLSEmbed(ln net.Listener, certData, keyData []byte, handler RequestHandler) error { s := &Server{ Handler: handler, } return s.ServeTLSEmbed(ln, certData, keyData) } // ListenAndServe serves HTTP requests from the given TCP addr // using the given handler. func ListenAndServe(addr string, handler RequestHandler) error { s := &Server{ Handler: handler, } return s.ListenAndServe(addr) } // ListenAndServeUNIX serves HTTP requests from the given UNIX addr // using the given handler. // // The function deletes existing file at addr before starting serving. // // The server sets the given file mode for the UNIX addr. func ListenAndServeUNIX(addr string, mode os.FileMode, handler RequestHandler) error { s := &Server{ Handler: handler, } return s.ListenAndServeUNIX(addr, mode) } // ListenAndServeTLS serves HTTPS requests from the given TCP addr // using the given handler. 
// // certFile and keyFile are paths to TLS certificate and key files. func ListenAndServeTLS(addr, certFile, keyFile string, handler RequestHandler) error { s := &Server{ Handler: handler, } return s.ListenAndServeTLS(addr, certFile, keyFile) } // ListenAndServeTLSEmbed serves HTTPS requests from the given TCP addr // using the given handler. // // certData and keyData must contain valid TLS certificate and key data. func ListenAndServeTLSEmbed(addr string, certData, keyData []byte, handler RequestHandler) error { s := &Server{ Handler: handler, } return s.ListenAndServeTLSEmbed(addr, certData, keyData) } // RequestHandler must process incoming requests. // // RequestHandler must call ctx.TimeoutError() before returning // if it keeps references to ctx and/or its' members after the return. // Consider wrapping RequestHandler into TimeoutHandler if response time // must be limited. type RequestHandler func(ctx *RequestCtx) // Server implements HTTP server. // // Default Server settings should satisfy the majority of Server users. // Adjust Server settings only if you really understand the consequences. // // It is forbidden copying Server instances. Create new Server instances // instead. // // It is safe to call Server methods from concurrently running goroutines. type Server struct { noCopy noCopy // Handler for processing incoming requests. Handler RequestHandler // Server name for sending in response headers. // // Default server name is used if left blank. Name string // The maximum number of concurrent connections the server may serve. // // DefaultConcurrency is used if not set. Concurrency int // Whether to disable keep-alive connections. // // The server will close all the incoming connections after sending // the first response to client if this option is set to true. // // By default keep-alive connections are enabled. DisableKeepalive bool // Per-connection buffer size for requests' reading. // This also limits the maximum header size. // // Increase this buffer if your clients send multi-KB RequestURIs // and/or multi-KB headers (for example, BIG cookies). // // Default buffer size is used if not set. ReadBufferSize int // Per-connection buffer size for responses' writing. // // Default buffer size is used if not set. WriteBufferSize int // Maximum duration for reading the full request (including body). // // This also limits the maximum duration for idle keep-alive // connections. // // By default request read timeout is unlimited. ReadTimeout time.Duration // Maximum duration for writing the full response (including body). // // By default response write timeout is unlimited. WriteTimeout time.Duration // Maximum number of concurrent client connections allowed per IP. // // By default unlimited number of concurrent connections // may be established to the server from a single IP address. MaxConnsPerIP int // Maximum number of requests served per connection. // // The server closes connection after the last request. // 'Connection: close' header is added to the last response. // // By default unlimited number of requests may be served per connection. MaxRequestsPerConn int // Maximum keep-alive connection lifetime. // // The server closes keep-alive connection after its' lifetime // expiration. // // See also ReadTimeout for limiting the duration of idle keep-alive // connections. // // By default keep-alive connection lifetime is unlimited. MaxKeepaliveDuration time.Duration // Maximum request body size. // // The server rejects requests with bodies exceeding this limit. 
// // By default request body size is unlimited. MaxRequestBodySize int // Aggressively reduces memory usage at the cost of higher CPU usage // if set to true. // // Try enabling this option only if the server consumes too much memory // serving mostly idle keep-alive connections (more than 1M concurrent // connections). This may reduce memory usage by up to 50%. // // Aggressive memory usage reduction is disabled by default. ReduceMemoryUsage bool // Rejects all non-GET requests if set to true. // // This option is useful as anti-DoS protection for servers // accepting only GET requests. The request size is limited // by ReadBufferSize if GetOnly is set. // // Server accepts all the requests by default. GetOnly bool // Logs all errors, including the most frequent // 'connection reset by peer', 'broken pipe' and 'connection timeout' // errors. Such errors are common in production serving real-world // clients. // // By default the most frequent errors such as // 'connection reset by peer', 'broken pipe' and 'connection timeout' // are suppressed in order to limit output log traffic. LogAllErrors bool // Header names are passed as-is without normalization // if this option is set. // // Disabled header names' normalization may be useful only for proxying // incoming requests to other servers expecting case-sensitive // header names. See https://github.com/valyala/fasthttp/issues/57 // for details. // // By default request and response header names are normalized, i.e. // The first letter and the first letters following dashes // are uppercased, while all the other letters are lowercased. // Examples: // // * HOST -> Host // * content-type -> Content-Type // * cONTENT-lenGTH -> Content-Length DisableHeaderNamesNormalizing bool // Logger, which is used by RequestCtx.Logger(). // // By default standard logger from log package is used. Logger Logger concurrency uint32 concurrencyCh chan struct{} perIPConnCounter perIPConnCounter serverName atomic.Value ctxPool sync.Pool readerPool sync.Pool writerPool sync.Pool hijackConnPool sync.Pool bytePool sync.Pool } // TimeoutHandler creates RequestHandler, which returns StatusRequestTimeout // error with the given msg to the client if h didn't return during // the given duration. // // The returned handler may return StatusTooManyRequests error with the given // msg to the client if there are more than Server.Concurrency concurrent // handlers h are running at the moment. func TimeoutHandler(h RequestHandler, timeout time.Duration, msg string) RequestHandler { if timeout <= 0 { return h } return func(ctx *RequestCtx) { concurrencyCh := ctx.s.concurrencyCh select { case concurrencyCh <- struct{}{}: default: ctx.Error(msg, StatusTooManyRequests) return } ch := ctx.timeoutCh if ch == nil { ch = make(chan struct{}, 1) ctx.timeoutCh = ch } go func() { h(ctx) ch <- struct{}{} <-concurrencyCh }() ctx.timeoutTimer = initTimer(ctx.timeoutTimer, timeout) select { case <-ch: case <-ctx.timeoutTimer.C: ctx.TimeoutError(msg) } stopTimer(ctx.timeoutTimer) } } // CompressHandler returns RequestHandler that transparently compresses // response body generated by h if the request contains 'gzip' or 'deflate' // 'Accept-Encoding' header. func CompressHandler(h RequestHandler) RequestHandler { return CompressHandlerLevel(h, CompressDefaultCompression) } // CompressHandlerLevel returns RequestHandler that transparently compresses // response body generated by h if the request contains 'gzip' or 'deflate' // 'Accept-Encoding' header. 
// // Level is the desired compression level: // // * CompressNoCompression // * CompressBestSpeed // * CompressBestCompression // * CompressDefaultCompression func CompressHandlerLevel(h RequestHandler, level int) RequestHandler { return func(ctx *RequestCtx) { h(ctx) ce := ctx.Response.Header.PeekBytes(strContentEncoding) if len(ce) > 0 { // Do not compress responses with non-empty // Content-Encoding. return } if ctx.Request.Header.HasAcceptEncodingBytes(strGzip) { ctx.Response.gzipBody(level) } else if ctx.Request.Header.HasAcceptEncodingBytes(strDeflate) { ctx.Response.deflateBody(level) } } } // RequestCtx contains incoming request and manages outgoing response. // // It is forbidden copying RequestCtx instances. // // RequestHandler should avoid holding references to incoming RequestCtx and/or // its' members after the return. // If holding RequestCtx references after the return is unavoidable // (for instance, ctx is passed to a separate goroutine and ctx lifetime cannot // be controlled), then the RequestHandler MUST call ctx.TimeoutError() // before return. // // It is unsafe modifying/reading RequestCtx instance from concurrently // running goroutines. The only exception is TimeoutError*, which may be called // while other goroutines accessing RequestCtx. type RequestCtx struct { noCopy noCopy // Incoming request. // // Copying Request by value is forbidden. Use pointer to Request instead. Request Request // Outgoing response. // // Copying Response by value is forbidden. Use pointer to Response instead. Response Response userValues userData lastReadDuration time.Duration connID uint64 connRequestNum uint64 connTime time.Time time time.Time logger ctxLogger s *Server c net.Conn fbr firstByteReader timeoutResponse *Response timeoutCh chan struct{} timeoutTimer *time.Timer hijackHandler HijackHandler } // HijackHandler must process the hijacked connection c. // // The connection c is automatically closed after returning from HijackHandler. // // The connection c must not be used after returning from the handler. type HijackHandler func(c net.Conn) // Hijack registers the given handler for connection hijacking. // // The handler is called after returning from RequestHandler // and sending http response. The current connection is passed // to the handler. The connection is automatically closed after // returning from the handler. // // The server skips calling the handler in the following cases: // // * 'Connection: close' header exists in either request or response. // * Unexpected error during response writing to the connection. // // The server stops processing requests from hijacked connections. // Server limits such as Concurrency, ReadTimeout, WriteTimeout, etc. // aren't applied to hijacked connections. // // The handler must not retain references to ctx members. // // Arbitrary 'Connection: Upgrade' protocols may be implemented // with HijackHandler. For instance, // // * WebSocket ( https://en.wikipedia.org/wiki/WebSocket ) // * HTTP/2.0 ( https://en.wikipedia.org/wiki/HTTP/2 ) // func (ctx *RequestCtx) Hijack(handler HijackHandler) { ctx.hijackHandler = handler } // SetUserValue stores the given value (arbitrary object) // under the given key in ctx. // // The value stored in ctx may be obtained by UserValue*. // // This functionality may be useful for passing arbitrary values between // functions involved in request processing. // // All the values are removed from ctx after returning from the top // RequestHandler. 
Additionally, Close method is called on each value // implementing io.Closer before removing the value from ctx. func (ctx *RequestCtx) SetUserValue(key string, value interface{}) { ctx.userValues.Set(key, value) } // SetUserValueBytes stores the given value (arbitrary object) // under the given key in ctx. // // The value stored in ctx may be obtained by UserValue*. // // This functionality may be useful for passing arbitrary values between // functions involved in request processing. // // All the values stored in ctx are deleted after returning from RequestHandler. func (ctx *RequestCtx) SetUserValueBytes(key []byte, value interface{}) { ctx.userValues.SetBytes(key, value) } // UserValue returns the value stored via SetUserValue* under the given key. func (ctx *RequestCtx) UserValue(key string) interface{} { return ctx.userValues.Get(key) } // UserValueBytes returns the value stored via SetUserValue* // under the given key. func (ctx *RequestCtx) UserValueBytes(key []byte) interface{} { return ctx.userValues.GetBytes(key) } // IsTLS returns true if the underlying connection is tls.Conn. // // tls.Conn is an encrypted connection (aka SSL, HTTPS). func (ctx *RequestCtx) IsTLS() bool { _, ok := ctx.c.(*tls.Conn) return ok } // TLSConnectionState returns TLS connection state. // // The function returns nil if the underlying connection isn't tls.Conn. // // The returned state may be used for verifying TLS version, client certificates, // etc. func (ctx *RequestCtx) TLSConnectionState() *tls.ConnectionState { tlsConn, ok := ctx.c.(*tls.Conn) if !ok { return nil } state := tlsConn.ConnectionState() return &state } type firstByteReader struct { c net.Conn ch byte byteRead bool } func (r *firstByteReader) Read(b []byte) (int, error) { if len(b) == 0 { return 0, nil } nn := 0 if !r.byteRead { b[0] = r.ch b = b[1:] r.byteRead = true nn = 1 } n, err := r.c.Read(b) return n + nn, err } // Logger is used for logging formatted messages. type Logger interface { // Printf must have the same semantics as log.Printf. Printf(format string, args ...interface{}) } var ctxLoggerLock sync.Mutex type ctxLogger struct { ctx *RequestCtx logger Logger } func (cl *ctxLogger) Printf(format string, args ...interface{}) { ctxLoggerLock.Lock() msg := fmt.Sprintf(format, args...) ctx := cl.ctx req := &ctx.Request cl.logger.Printf("%.3f #%016X - %s<->%s - %s %s - %s", time.Since(ctx.Time()).Seconds(), ctx.ID(), ctx.LocalAddr(), ctx.RemoteAddr(), req.Header.Method(), ctx.URI().FullURI(), msg) ctxLoggerLock.Unlock() } var zeroTCPAddr = &net.TCPAddr{ IP: net.IPv4zero, } // ID returns unique ID of the request. func (ctx *RequestCtx) ID() uint64 { return (ctx.connID << 32) | ctx.connRequestNum } // ConnID returns unique connection ID. // // This ID may be used to match distinct requests to the same incoming // connection. func (ctx *RequestCtx) ConnID() uint64 { return ctx.connID } // Time returns RequestHandler call time. func (ctx *RequestCtx) Time() time.Time { return ctx.time } // ConnTime returns the time server starts serving the connection // the current request came from. func (ctx *RequestCtx) ConnTime() time.Time { return ctx.connTime } // ConnRequestNum returns request sequence number // for the current connection. func (ctx *RequestCtx) ConnRequestNum() uint64 { return ctx.connRequestNum } // SetConnectionClose sets 'Connection: close' response header and closes // connection after the RequestHandler returns. 
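//
// The SetUserValue/UserValue API documented above is typically used for
// passing data from wrapping middleware to the final handler. A short
// sketch, assuming hypothetical withRequestID and showRequestID handlers
// and a "requestID" key:
//
//	func withRequestID(next fasthttp.RequestHandler) fasthttp.RequestHandler {
//		return func(ctx *fasthttp.RequestCtx) {
//			ctx.SetUserValue("requestID", rand.Int63())
//			next(ctx)
//		}
//	}
//
//	func showRequestID(ctx *fasthttp.RequestCtx) {
//		id, _ := ctx.UserValue("requestID").(int64)
//		fmt.Fprintf(ctx, "request id: %d", id)
//	}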
func (ctx *RequestCtx) SetConnectionClose() { ctx.Response.SetConnectionClose() } // SetStatusCode sets response status code. func (ctx *RequestCtx) SetStatusCode(statusCode int) { ctx.Response.SetStatusCode(statusCode) } // SetContentType sets response Content-Type. func (ctx *RequestCtx) SetContentType(contentType string) { ctx.Response.Header.SetContentType(contentType) } // SetContentTypeBytes sets response Content-Type. // // It is safe modifying contentType buffer after function return. func (ctx *RequestCtx) SetContentTypeBytes(contentType []byte) { ctx.Response.Header.SetContentTypeBytes(contentType) } // RequestURI returns RequestURI. // // This uri is valid until returning from RequestHandler. func (ctx *RequestCtx) RequestURI() []byte { return ctx.Request.Header.RequestURI() } // URI returns requested uri. // // The uri is valid until returning from RequestHandler. func (ctx *RequestCtx) URI() *URI { return ctx.Request.URI() } // Referer returns request referer. // // The referer is valid until returning from RequestHandler. func (ctx *RequestCtx) Referer() []byte { return ctx.Request.Header.Referer() } // UserAgent returns User-Agent header value from the request. func (ctx *RequestCtx) UserAgent() []byte { return ctx.Request.Header.UserAgent() } // Path returns requested path. // // The path is valid until returning from RequestHandler. func (ctx *RequestCtx) Path() []byte { return ctx.URI().Path() } // Host returns requested host. // // The host is valid until returning from RequestHandler. func (ctx *RequestCtx) Host() []byte { return ctx.URI().Host() } // QueryArgs returns query arguments from RequestURI. // // It doesn't return POST'ed arguments - use PostArgs() for this. // // Returned arguments are valid until returning from RequestHandler. // // See also PostArgs, FormValue and FormFile. func (ctx *RequestCtx) QueryArgs() *Args { return ctx.URI().QueryArgs() } // PostArgs returns POST arguments. // // It doesn't return query arguments from RequestURI - use QueryArgs for this. // // Returned arguments are valid until returning from RequestHandler. // // See also QueryArgs, FormValue and FormFile. func (ctx *RequestCtx) PostArgs() *Args { return ctx.Request.PostArgs() } // MultipartForm returns requests's multipart form. // // Returns ErrNoMultipartForm if request's content-type // isn't 'multipart/form-data'. // // All uploaded temporary files are automatically deleted after // returning from RequestHandler. Either move or copy uploaded files // into new place if you want retaining them. // // Use SaveMultipartFile function for permanently saving uploaded file. // // The returned form is valid until returning from RequestHandler. // // See also FormFile and FormValue. func (ctx *RequestCtx) MultipartForm() (*multipart.Form, error) { return ctx.Request.MultipartForm() } // FormFile returns uploaded file associated with the given multipart form key. // // The file is automatically deleted after returning from RequestHandler, // so either move or copy uploaded file into new place if you want retaining it. // // Use SaveMultipartFile function for permanently saving uploaded file. // // The returned file header is valid until returning from RequestHandler. 
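//
// A minimal upload-handling sketch, assuming a hypothetical "file" form key
// and target path:
//
//	func uploadHandler(ctx *fasthttp.RequestCtx) {
//		fh, err := ctx.FormFile("file")
//		if err != nil {
//			ctx.Error("file is required", fasthttp.StatusBadRequest)
//			return
//		}
//		// The uploaded temporary file is deleted after the handler returns,
//		// so persist it now.
//		if err := fasthttp.SaveMultipartFile(fh, "/tmp/uploaded.dat"); err != nil {
//			ctx.Error("cannot save file", fasthttp.StatusInternalServerError)
//			return
//		}
//		ctx.SetStatusCode(fasthttp.StatusOK)
//	}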
func (ctx *RequestCtx) FormFile(key string) (*multipart.FileHeader, error) { mf, err := ctx.MultipartForm() if err != nil { return nil, err } if mf.File == nil { return nil, err } fhh := mf.File[key] if fhh == nil { return nil, ErrMissingFile } return fhh[0], nil } // ErrMissingFile may be returned from FormFile when the is no uploaded file // associated with the given multipart form key. var ErrMissingFile = errors.New("there is no uploaded file associated with the given key") // SaveMultipartFile saves multipart file fh under the given filename path. func SaveMultipartFile(fh *multipart.FileHeader, path string) error { f, err := fh.Open() if err != nil { return err } defer f.Close() if ff, ok := f.(*os.File); ok { return os.Rename(ff.Name(), path) } ff, err := os.Create(path) if err != nil { return err } defer ff.Close() _, err = copyZeroAlloc(ff, f) return err } // FormValue returns form value associated with the given key. // // The value is searched in the following places: // // * Query string. // * POST or PUT body. // // There are more fine-grained methods for obtaining form values: // // * QueryArgs for obtaining values from query string. // * PostArgs for obtaining values from POST or PUT body. // * MultipartForm for obtaining values from multipart form. // * FormFile for obtaining uploaded files. // // The returned value is valid until returning from RequestHandler. func (ctx *RequestCtx) FormValue(key string) []byte { v := ctx.QueryArgs().Peek(key) if len(v) > 0 { return v } v = ctx.PostArgs().Peek(key) if len(v) > 0 { return v } mf, err := ctx.MultipartForm() if err == nil && mf.Value != nil { vv := mf.Value[key] if len(vv) > 0 { return []byte(vv[0]) } } return nil } // IsGet returns true if request method is GET. func (ctx *RequestCtx) IsGet() bool { return ctx.Request.Header.IsGet() } // IsPost returns true if request method is POST. func (ctx *RequestCtx) IsPost() bool { return ctx.Request.Header.IsPost() } // IsPut returns true if request method is PUT. func (ctx *RequestCtx) IsPut() bool { return ctx.Request.Header.IsPut() } // IsDelete returns true if request method is DELETE. func (ctx *RequestCtx) IsDelete() bool { return ctx.Request.Header.IsDelete() } // Method return request method. // // Returned value is valid until returning from RequestHandler. func (ctx *RequestCtx) Method() []byte { return ctx.Request.Header.Method() } // IsHead returns true if request method is HEAD. func (ctx *RequestCtx) IsHead() bool { return ctx.Request.Header.IsHead() } // RemoteAddr returns client address for the given request. // // Always returns non-nil result. func (ctx *RequestCtx) RemoteAddr() net.Addr { addr := ctx.c.RemoteAddr() if addr == nil { return zeroTCPAddr } return addr } // LocalAddr returns server address for the given request. // // Always returns non-nil result. func (ctx *RequestCtx) LocalAddr() net.Addr { addr := ctx.c.LocalAddr() if addr == nil { return zeroTCPAddr } return addr } // RemoteIP returns client ip for the given request. // // Always returns non-nil result. func (ctx *RequestCtx) RemoteIP() net.IP { x, ok := ctx.RemoteAddr().(*net.TCPAddr) if !ok { return net.IPv4zero } return x.IP } // Error sets response status code to the given value and sets response body // to the given message. func (ctx *RequestCtx) Error(msg string, statusCode int) { ctx.Response.Reset() ctx.SetStatusCode(statusCode) ctx.SetContentTypeBytes(defaultContentType) ctx.SetBodyString(msg) } // Success sets response Content-Type and body to the given values. 
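//
// FormValue (documented earlier) pairs naturally with Success and
// SuccessString. A short sketch, assuming a hypothetical "name" parameter:
//
//	func greetHandler(ctx *fasthttp.RequestCtx) {
//		// The query string is checked first, then the POST body,
//		// then the multipart form.
//		name := ctx.FormValue("name")
//		if len(name) == 0 {
//			ctx.Error("name is required", fasthttp.StatusBadRequest)
//			return
//		}
//		ctx.SuccessString("text/plain; charset=utf-8", "Hello, "+string(name)+"!")
//	}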
func (ctx *RequestCtx) Success(contentType string, body []byte) { ctx.SetContentType(contentType) ctx.SetBody(body) } // SuccessString sets response Content-Type and body to the given values. func (ctx *RequestCtx) SuccessString(contentType, body string) { ctx.SetContentType(contentType) ctx.SetBodyString(body) } // Redirect sets 'Location: uri' response header and sets the given statusCode. // // statusCode must have one of the following values: // // * StatusMovedPermanently (301) // * StatusFound (302) // * StatusSeeOther (303) // * StatusTemporaryRedirect (307) // // All other statusCode values are replaced by StatusFound (302). // // The redirect uri may be either absolute or relative to the current // request uri. func (ctx *RequestCtx) Redirect(uri string, statusCode int) { u := AcquireURI() ctx.URI().CopyTo(u) u.Update(uri) ctx.redirect(u.FullURI(), statusCode) ReleaseURI(u) } // RedirectBytes sets 'Location: uri' response header and sets // the given statusCode. // // statusCode must have one of the following values: // // * StatusMovedPermanently (301) // * StatusFound (302) // * StatusSeeOther (303) // * StatusTemporaryRedirect (307) // // All other statusCode values are replaced by StatusFound (302). // // The redirect uri may be either absolute or relative to the current // request uri. func (ctx *RequestCtx) RedirectBytes(uri []byte, statusCode int) { s := b2s(uri) ctx.Redirect(s, statusCode) } func (ctx *RequestCtx) redirect(uri []byte, statusCode int) { ctx.Response.Header.SetCanonical(strLocation, uri) statusCode = getRedirectStatusCode(statusCode) ctx.Response.SetStatusCode(statusCode) } func getRedirectStatusCode(statusCode int) int { if statusCode == StatusMovedPermanently || statusCode == StatusFound || statusCode == StatusSeeOther || statusCode == StatusTemporaryRedirect { return statusCode } return StatusFound } // SetBody sets response body to the given value. // // It is safe re-using body argument after the function returns. func (ctx *RequestCtx) SetBody(body []byte) { ctx.Response.SetBody(body) } // SetBodyString sets response body to the given value. func (ctx *RequestCtx) SetBodyString(body string) { ctx.Response.SetBodyString(body) } // ResetBody resets response body contents. func (ctx *RequestCtx) ResetBody() { ctx.Response.ResetBody() } // SendFile sends local file contents from the given path as response body. // // This is a shortcut to ServeFile(ctx, path). // // SendFile logs all the errors via ctx.Logger. // // See also ServeFile, FSHandler and FS. func (ctx *RequestCtx) SendFile(path string) { ServeFile(ctx, path) } // SendFileBytes sends local file contents from the given path as response body. // // This is a shortcut to ServeFileBytes(ctx, path). // // SendFileBytes logs all the errors via ctx.Logger. // // See also ServeFileBytes, FSHandler and FS. func (ctx *RequestCtx) SendFileBytes(path []byte) { ServeFileBytes(ctx, path) } // IfModifiedSince returns true if lastModified exceeds 'If-Modified-Since' // value from the request header. // // The function returns true also 'If-Modified-Since' request header is missing. func (ctx *RequestCtx) IfModifiedSince(lastModified time.Time) bool { ifModStr := ctx.Request.Header.peek(strIfModifiedSince) if len(ifModStr) == 0 { return true } ifMod, err := ParseHTTPDate(ifModStr) if err != nil { return true } lastModified = lastModified.Truncate(time.Second) return ifMod.Before(lastModified) } // NotModified resets response and sets '304 Not Modified' response status code. 
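//
// NotModified combines with IfModifiedSince (documented above) into the usual
// conditional-GET pattern. A sketch, assuming the fixed lastModified value
// stands in for the real modification time of the served resource:
//
//	func cachedHandler(ctx *fasthttp.RequestCtx) {
//		lastModified := time.Date(2016, 6, 17, 0, 0, 0, 0, time.UTC)
//		if !ctx.IfModifiedSince(lastModified) {
//			ctx.NotModified()
//			return
//		}
//		ctx.Response.Header.SetLastModified(lastModified)
//		ctx.SetBodyString("fresh content")
//	}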
func (ctx *RequestCtx) NotModified() { ctx.Response.Reset() ctx.SetStatusCode(StatusNotModified) } // NotFound resets response and sets '404 Not Found' response status code. func (ctx *RequestCtx) NotFound() { ctx.Response.Reset() ctx.SetStatusCode(StatusNotFound) ctx.SetBodyString("404 Page not found") } // Write writes p into response body. func (ctx *RequestCtx) Write(p []byte) (int, error) { ctx.Response.AppendBody(p) return len(p), nil } // WriteString appends s to response body. func (ctx *RequestCtx) WriteString(s string) (int, error) { ctx.Response.AppendBodyString(s) return len(s), nil } // PostBody returns POST request body. // // The returned value is valid until RequestHandler return. func (ctx *RequestCtx) PostBody() []byte { return ctx.Request.Body() } // SetBodyStream sets response body stream and, optionally body size. // // bodyStream.Close() is called after finishing reading all body data // if it implements io.Closer. // // If bodySize is >= 0, then bodySize bytes must be provided by bodyStream // before returning io.EOF. // // If bodySize < 0, then bodyStream is read until io.EOF. // // See also SetBodyStreamWriter. func (ctx *RequestCtx) SetBodyStream(bodyStream io.Reader, bodySize int) { ctx.Response.SetBodyStream(bodyStream, bodySize) } // SetBodyStreamWriter registers the given stream writer for populating // response body. // // Access to RequestCtx and/or its' members is forbidden from sw. // // This function may be used in the following cases: // // * if response body is too big (more than 10MB). // * if response body is streamed from slow external sources. // * if response body must be streamed to the client in chunks. // (aka `http server push`). func (ctx *RequestCtx) SetBodyStreamWriter(sw StreamWriter) { ctx.Response.SetBodyStreamWriter(sw) } // IsBodyStream returns true if response body is set via SetBodyStream*. func (ctx *RequestCtx) IsBodyStream() bool { return ctx.Response.IsBodyStream() } // Logger returns logger, which may be used for logging arbitrary // request-specific messages inside RequestHandler. // // Each message logged via returned logger contains request-specific information // such as request id, request duration, local address, remote address, // request method and request url. // // It is safe re-using returned logger for logging multiple messages // for the current request. // // The returned logger is valid until returning from RequestHandler. func (ctx *RequestCtx) Logger() Logger { if ctx.logger.ctx == nil { ctx.logger.ctx = ctx } if ctx.logger.logger == nil { ctx.logger.logger = ctx.s.logger() } return &ctx.logger } // TimeoutError sets response status code to StatusRequestTimeout and sets // body to the given msg. // // All response modifications after TimeoutError call are ignored. // // TimeoutError MUST be called before returning from RequestHandler if there are // references to ctx and/or its members in other goroutines remain. // // Usage of this function is discouraged. Prefer eliminating ctx references // from pending goroutines instead of using this function. func (ctx *RequestCtx) TimeoutError(msg string) { ctx.TimeoutErrorWithCode(msg, StatusRequestTimeout) } // TimeoutErrorWithCode sets response body to msg and response status // code to statusCode. // // All response modifications after TimeoutErrorWithCode call are ignored. // // TimeoutErrorWithCode MUST be called before returning from RequestHandler // if there are references to ctx and/or its members in other goroutines remain. 
// // Usage of this function is discouraged. Prefer eliminating ctx references // from pending goroutines instead of using this function. func (ctx *RequestCtx) TimeoutErrorWithCode(msg string, statusCode int) { var resp Response resp.SetStatusCode(statusCode) resp.SetBodyString(msg) ctx.TimeoutErrorWithResponse(&resp) } // TimeoutErrorWithResponse marks the ctx as timed out and sends the given // response to the client. // // All ctx modifications after TimeoutErrorWithResponse call are ignored. // // TimeoutErrorWithResponse MUST be called before returning from RequestHandler // if there are references to ctx and/or its members in other goroutines remain. // // Usage of this function is discouraged. Prefer eliminating ctx references // from pending goroutines instead of using this function. func (ctx *RequestCtx) TimeoutErrorWithResponse(resp *Response) { respCopy := &Response{} resp.CopyTo(respCopy) ctx.timeoutResponse = respCopy } // ListenAndServe serves HTTP requests from the given TCP4 addr. // // Pass custom listener to Serve if you need listening on non-TCP4 media // such as IPv6. func (s *Server) ListenAndServe(addr string) error { ln, err := net.Listen("tcp4", addr) if err != nil { return err } return s.Serve(ln) } // ListenAndServeUNIX serves HTTP requests from the given UNIX addr. // // The function deletes existing file at addr before starting serving. // // The server sets the given file mode for the UNIX addr. func (s *Server) ListenAndServeUNIX(addr string, mode os.FileMode) error { if err := os.Remove(addr); err != nil && !os.IsNotExist(err) { return fmt.Errorf("unexpected error when trying to remove unix socket file %q: %s", addr, err) } ln, err := net.Listen("unix", addr) if err != nil { return err } if err = os.Chmod(addr, mode); err != nil { return fmt.Errorf("cannot chmod %#o for %q: %s", mode, addr, err) } return s.Serve(ln) } // ListenAndServeTLS serves HTTPS requests from the given TCP4 addr. // // certFile and keyFile are paths to TLS certificate and key files. // // Pass custom listener to Serve if you need listening on non-TCP4 media // such as IPv6. func (s *Server) ListenAndServeTLS(addr, certFile, keyFile string) error { ln, err := net.Listen("tcp4", addr) if err != nil { return err } return s.ServeTLS(ln, certFile, keyFile) } // ListenAndServeTLSEmbed serves HTTPS requests from the given TCP4 addr. // // certData and keyData must contain valid TLS certificate and key data. // // Pass custom listener to Serve if you need listening on arbitrary media // such as IPv6. func (s *Server) ListenAndServeTLSEmbed(addr string, certData, keyData []byte) error { ln, err := net.Listen("tcp4", addr) if err != nil { return err } return s.ServeTLSEmbed(ln, certData, keyData) } // ServeTLS serves HTTPS requests from the given listener. // // certFile and keyFile are paths to TLS certificate and key files. func (s *Server) ServeTLS(ln net.Listener, certFile, keyFile string) error { lnTLS, err := newTLSListener(ln, certFile, keyFile) if err != nil { return err } return s.Serve(lnTLS) } // ServeTLSEmbed serves HTTPS requests from the given listener. // // certData and keyData must contain valid TLS certificate and key data. 
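//
// A minimal sketch, assuming certPEM and keyPEM hold PEM-encoded data
// compiled into the binary and requestHandler is a placeholder:
//
//	ln, err := net.Listen("tcp4", ":8443")
//	if err != nil {
//		log.Fatalf("cannot listen: %s", err)
//	}
//	s := &fasthttp.Server{Handler: requestHandler}
//	if err := s.ServeTLSEmbed(ln, certPEM, keyPEM); err != nil {
//		log.Fatalf("error in ServeTLSEmbed: %s", err)
//	}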
func (s *Server) ServeTLSEmbed(ln net.Listener, certData, keyData []byte) error { lnTLS, err := newTLSListenerEmbed(ln, certData, keyData) if err != nil { return err } return s.Serve(lnTLS) } func newTLSListener(ln net.Listener, certFile, keyFile string) (net.Listener, error) { cert, err := tls.LoadX509KeyPair(certFile, keyFile) if err != nil { return nil, fmt.Errorf("cannot load TLS key pair from certFile=%q and keyFile=%q: %s", certFile, keyFile, err) } return newCertListener(ln, &cert), nil } func newTLSListenerEmbed(ln net.Listener, certData, keyData []byte) (net.Listener, error) { cert, err := tls.X509KeyPair(certData, keyData) if err != nil { return nil, fmt.Errorf("cannot load TLS key pair from the provided certData(%d) and keyData(%d): %s", len(certData), len(keyData), err) } return newCertListener(ln, &cert), nil } func newCertListener(ln net.Listener, cert *tls.Certificate) net.Listener { tlsConfig := &tls.Config{ Certificates: []tls.Certificate{*cert}, PreferServerCipherSuites: true, } return tls.NewListener(ln, tlsConfig) } // DefaultConcurrency is the maximum number of concurrent connections // the Server may serve by default (i.e. if Server.Concurrency isn't set). const DefaultConcurrency = 256 * 1024 // Serve serves incoming connections from the given listener. // // Serve blocks until the given listener returns permanent error. func (s *Server) Serve(ln net.Listener) error { var lastOverflowErrorTime time.Time var lastPerIPErrorTime time.Time var c net.Conn var err error maxWorkersCount := s.getConcurrency() s.concurrencyCh = make(chan struct{}, maxWorkersCount) wp := &workerPool{ WorkerFunc: s.serveConn, MaxWorkersCount: maxWorkersCount, LogAllErrors: s.LogAllErrors, Logger: s.logger(), } wp.Start() for { if c, err = acceptConn(s, ln, &lastPerIPErrorTime); err != nil { wp.Stop() if err == io.EOF { return nil } return err } if !wp.Serve(c) { s.writeFastError(c, StatusServiceUnavailable, "The connection cannot be served because Server.Concurrency limit exceeded") c.Close() if time.Since(lastOverflowErrorTime) > time.Minute { s.logger().Printf("The incoming connection cannot be served, because %d concurrent connections are served. "+ "Try increasing Server.Concurrency", maxWorkersCount) lastOverflowErrorTime = time.Now() } // The current server reached concurrency limit, // so give other concurrently running servers a chance // accepting incoming connections on the same address. 
// // There is a hope other servers didn't reach their // concurrency limits yet :) time.Sleep(100 * time.Millisecond) } c = nil } } func acceptConn(s *Server, ln net.Listener, lastPerIPErrorTime *time.Time) (net.Conn, error) { for { c, err := ln.Accept() if err != nil { if c != nil { panic("BUG: net.Listener returned non-nil conn and non-nil error") } if netErr, ok := err.(net.Error); ok && netErr.Temporary() { s.logger().Printf("Temporary error when accepting new connections: %s", netErr) time.Sleep(time.Second) continue } if err != io.EOF && !strings.Contains(err.Error(), "use of closed network connection") { s.logger().Printf("Permanent error when accepting new connections: %s", err) return nil, err } return nil, io.EOF } if c == nil { panic("BUG: net.Listener returned (nil, nil)") } if s.MaxConnsPerIP > 0 { pic := wrapPerIPConn(s, c) if pic == nil { if time.Since(*lastPerIPErrorTime) > time.Minute { s.logger().Printf("The number of connections from %s exceeds MaxConnsPerIP=%d", getConnIP4(c), s.MaxConnsPerIP) *lastPerIPErrorTime = time.Now() } continue } c = pic } return c, nil } } func wrapPerIPConn(s *Server, c net.Conn) net.Conn { ip := getUint32IP(c) if ip == 0 { return c } n := s.perIPConnCounter.Register(ip) if n > s.MaxConnsPerIP { s.perIPConnCounter.Unregister(ip) s.writeFastError(c, StatusTooManyRequests, "The number of connections from your ip exceeds MaxConnsPerIP") c.Close() return nil } return acquirePerIPConn(c, ip, &s.perIPConnCounter) } var defaultLogger = Logger(log.New(os.Stderr, "", log.LstdFlags)) func (s *Server) logger() Logger { if s.Logger != nil { return s.Logger } return defaultLogger } var ( // ErrPerIPConnLimit may be returned from ServeConn if the number of connections // per ip exceeds Server.MaxConnsPerIP. ErrPerIPConnLimit = errors.New("too many connections per ip") // ErrConcurrencyLimit may be returned from ServeConn if the number // of concurrenty served connections exceeds Server.Concurrency. ErrConcurrencyLimit = errors.New("canot serve the connection because Server.Concurrency concurrent connections are served") // ErrKeepaliveTimeout is returned from ServeConn // if the connection lifetime exceeds MaxKeepaliveDuration. ErrKeepaliveTimeout = errors.New("exceeded MaxKeepaliveDuration") ) // ServeConn serves HTTP requests from the given connection. // // ServeConn returns nil if all requests from the c are successfully served. // It returns non-nil error otherwise. // // Connection c must immediately propagate all the data passed to Write() // to the client. Otherwise requests' processing may hang. // // ServeConn closes c before returning. 
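//
// A rough sketch of a custom accept loop built on ServeConn, assuming
// placeholder ln and requestHandler values:
//
//	s := &fasthttp.Server{Handler: requestHandler}
//	for {
//		conn, err := ln.Accept()
//		if err != nil {
//			log.Fatalf("accept error: %s", err)
//		}
//		go func(c net.Conn) {
//			// ServeConn blocks until all requests from c are served
//			// and closes c before returning.
//			if err := s.ServeConn(c); err != nil {
//				log.Printf("error serving connection: %s", err)
//			}
//		}(conn)
//	}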
func (s *Server) ServeConn(c net.Conn) error { if s.MaxConnsPerIP > 0 { pic := wrapPerIPConn(s, c) if pic == nil { return ErrPerIPConnLimit } c = pic } n := atomic.AddUint32(&s.concurrency, 1) if n > uint32(s.getConcurrency()) { atomic.AddUint32(&s.concurrency, ^uint32(0)) s.writeFastError(c, StatusServiceUnavailable, "The connection cannot be served because Server.Concurrency limit exceeded") c.Close() return ErrConcurrencyLimit } err := s.serveConn(c) atomic.AddUint32(&s.concurrency, ^uint32(0)) if err != errHijacked { err1 := c.Close() if err == nil { err = err1 } } else { err = nil } return err } var errHijacked = errors.New("connection has been hijacked") func (s *Server) getConcurrency() int { n := s.Concurrency if n <= 0 { n = DefaultConcurrency } return n } var globalConnID uint64 func nextConnID() uint64 { return atomic.AddUint64(&globalConnID, 1) } func (s *Server) serveConn(c net.Conn) error { serverName := s.getServerName() connRequestNum := uint64(0) connID := nextConnID() currentTime := time.Now() connTime := currentTime ctx := s.acquireCtx(c) ctx.connTime = connTime var ( br *bufio.Reader bw *bufio.Writer err error timeoutResponse *Response hijackHandler HijackHandler lastReadDeadlineTime time.Time lastWriteDeadlineTime time.Time connectionClose bool isHTTP11 bool ) for { connRequestNum++ ctx.time = currentTime if s.ReadTimeout > 0 || s.MaxKeepaliveDuration > 0 { lastReadDeadlineTime = s.updateReadDeadline(c, ctx, lastReadDeadlineTime) if lastReadDeadlineTime.IsZero() { err = ErrKeepaliveTimeout break } } if !(s.ReduceMemoryUsage || ctx.lastReadDuration > time.Second) || br != nil { if br == nil { br = acquireReader(ctx) } } else { br, err = acquireByteReader(&ctx) } if err == nil { if s.DisableHeaderNamesNormalizing { ctx.Request.Header.DisableNormalizing() ctx.Response.Header.DisableNormalizing() } err = ctx.Request.readLimitBody(br, s.MaxRequestBodySize, s.GetOnly) if br.Buffered() == 0 || err != nil { releaseReader(s, br) br = nil } } currentTime = time.Now() ctx.lastReadDuration = currentTime.Sub(ctx.time) if err != nil { if err == io.EOF { err = nil } break } // 'Expect: 100-continue' request handling. // See http://www.w3.org/Protocols/rfc2616/rfc2616-sec8.html for details. if !ctx.Request.Header.noBody() && ctx.Request.MayContinue() { // Send 'HTTP/1.1 100 Continue' response. if bw == nil { bw = acquireWriter(ctx) } bw.Write(strResponseContinue) err = bw.Flush() releaseWriter(s, bw) bw = nil if err != nil { break } // Read request body. if br == nil { br = acquireReader(ctx) } err = ctx.Request.ContinueReadBody(br, s.MaxRequestBodySize) if br.Buffered() == 0 || err != nil { releaseReader(s, br) br = nil } if err != nil { break } } connectionClose = s.DisableKeepalive || ctx.Request.Header.connectionCloseFast() isHTTP11 = ctx.Request.Header.IsHTTP11() ctx.Response.Header.SetServerBytes(serverName) ctx.connID = connID ctx.connRequestNum = connRequestNum ctx.connTime = connTime ctx.time = currentTime s.Handler(ctx) timeoutResponse = ctx.timeoutResponse if timeoutResponse != nil { ctx = s.acquireCtx(c) timeoutResponse.CopyTo(&ctx.Response) if br != nil { // Close connection, since br may be attached to the old ctx via ctx.fbr. 
ctx.SetConnectionClose() } } if !ctx.IsGet() && ctx.IsHead() { ctx.Response.SkipBody = true } ctx.Request.Reset() hijackHandler = ctx.hijackHandler ctx.hijackHandler = nil ctx.userValues.Reset() if s.MaxRequestsPerConn > 0 && connRequestNum >= uint64(s.MaxRequestsPerConn) { ctx.SetConnectionClose() } if s.WriteTimeout > 0 || s.MaxKeepaliveDuration > 0 { lastWriteDeadlineTime = s.updateWriteDeadline(c, ctx, lastWriteDeadlineTime) } // Verify Request.Header.connectionCloseFast() again, // since request handler might trigger full headers' parsing. connectionClose = connectionClose || ctx.Request.Header.connectionCloseFast() || ctx.Response.ConnectionClose() if connectionClose { ctx.Response.Header.SetCanonical(strConnection, strClose) } else if !isHTTP11 { // Set 'Connection: keep-alive' response header for non-HTTP/1.1 request. // There is no need in setting this header for http/1.1, since in http/1.1 // connections are keep-alive by default. ctx.Response.Header.SetCanonical(strConnection, strKeepAlive) } if len(ctx.Response.Header.Server()) == 0 { ctx.Response.Header.SetServerBytes(serverName) } if bw == nil { bw = acquireWriter(ctx) } if err = writeResponse(ctx, bw); err != nil { break } if br == nil || connectionClose { err = bw.Flush() releaseWriter(s, bw) bw = nil if err != nil { break } if connectionClose { break } } if hijackHandler != nil { var hjr io.Reader hjr = c if br != nil { hjr = br br = nil // br may point to ctx.fbr, so do not return ctx into pool. ctx = s.acquireCtx(c) } if bw != nil { err = bw.Flush() releaseWriter(s, bw) bw = nil if err != nil { break } } c.SetReadDeadline(zeroTime) c.SetWriteDeadline(zeroTime) go hijackConnHandler(hjr, c, s, hijackHandler) hijackHandler = nil err = errHijacked break } currentTime = time.Now() } if br != nil { releaseReader(s, br) } if bw != nil { releaseWriter(s, bw) } s.releaseCtx(ctx) return err } func (s *Server) updateReadDeadline(c net.Conn, ctx *RequestCtx, lastDeadlineTime time.Time) time.Time { readTimeout := s.ReadTimeout currentTime := ctx.time if s.MaxKeepaliveDuration > 0 { connTimeout := s.MaxKeepaliveDuration - currentTime.Sub(ctx.connTime) if connTimeout <= 0 { return zeroTime } if connTimeout < readTimeout { readTimeout = connTimeout } } // Optimization: update read deadline only if more than 25% // of the last read deadline exceeded. // See https://github.com/golang/go/issues/15133 for details. if currentTime.Sub(lastDeadlineTime) > (readTimeout >> 2) { if err := c.SetReadDeadline(currentTime.Add(readTimeout)); err != nil { panic(fmt.Sprintf("BUG: error in SetReadDeadline(%s): %s", readTimeout, err)) } lastDeadlineTime = currentTime } return lastDeadlineTime } func (s *Server) updateWriteDeadline(c net.Conn, ctx *RequestCtx, lastDeadlineTime time.Time) time.Time { writeTimeout := s.WriteTimeout if s.MaxKeepaliveDuration > 0 { connTimeout := s.MaxKeepaliveDuration - time.Since(ctx.connTime) if connTimeout <= 0 { // MaxKeepAliveDuration exceeded, but let's try sending response anyway // in 100ms with 'Connection: close' header. ctx.SetConnectionClose() connTimeout = 100 * time.Millisecond } if connTimeout < writeTimeout { writeTimeout = connTimeout } } // Optimization: update write deadline only if more than 25% // of the last write deadline exceeded. // See https://github.com/golang/go/issues/15133 for details. 
currentTime := time.Now() if currentTime.Sub(lastDeadlineTime) > (writeTimeout >> 2) { if err := c.SetWriteDeadline(currentTime.Add(writeTimeout)); err != nil { panic(fmt.Sprintf("BUG: error in SetWriteDeadline(%s): %s", writeTimeout, err)) } lastDeadlineTime = currentTime } return lastDeadlineTime } func hijackConnHandler(r io.Reader, c net.Conn, s *Server, h HijackHandler) { hjc := s.acquireHijackConn(r, c) defer func() { if r := recover(); r != nil { s.logger().Printf("panic on hijacked conn: %s\nStack trace:\n%s", r, debug.Stack()) } if br, ok := r.(*bufio.Reader); ok { releaseReader(s, br) } c.Close() s.releaseHijackConn(hjc) }() h(hjc) } func (s *Server) acquireHijackConn(r io.Reader, c net.Conn) *hijackConn { v := s.hijackConnPool.Get() if v == nil { hjc := &hijackConn{ Conn: c, r: r, } return hjc } hjc := v.(*hijackConn) hjc.Conn = c hjc.r = r return hjc } func (s *Server) releaseHijackConn(hjc *hijackConn) { hjc.Conn = nil hjc.r = nil s.hijackConnPool.Put(hjc) } type hijackConn struct { net.Conn r io.Reader } func (c hijackConn) Read(p []byte) (int, error) { return c.r.Read(p) } func (c hijackConn) Close() error { // hijacked conn is closed in hijackConnHandler. return nil } // LastTimeoutErrorResponse returns the last timeout response set // via TimeoutError* call. // // This function is intended for custom server implementations. func (ctx *RequestCtx) LastTimeoutErrorResponse() *Response { return ctx.timeoutResponse } func writeResponse(ctx *RequestCtx, w *bufio.Writer) error { if ctx.timeoutResponse != nil { panic("BUG: cannot write timed out response") } err := ctx.Response.Write(w) ctx.Response.Reset() return err } const ( defaultReadBufferSize = 4096 defaultWriteBufferSize = 4096 ) func acquireByteReader(ctxP **RequestCtx) (*bufio.Reader, error) { ctx := *ctxP s := ctx.s c := ctx.c t := ctx.time s.releaseCtx(ctx) // Make GC happy, so it could garbage collect ctx // while we waiting for the next request. ctx = nil *ctxP = nil v := s.bytePool.Get() if v == nil { v = make([]byte, 1) } b := v.([]byte) n, err := c.Read(b) ch := b[0] s.bytePool.Put(v) ctx = s.acquireCtx(c) ctx.time = t *ctxP = ctx if err != nil { // Treat all errors as EOF on unsuccessful read // of the first request byte. return nil, io.EOF } if n != 1 { panic("BUG: Reader must return at least one byte") } ctx.fbr.c = c ctx.fbr.ch = ch ctx.fbr.byteRead = false r := acquireReader(ctx) r.Reset(&ctx.fbr) return r, nil } func acquireReader(ctx *RequestCtx) *bufio.Reader { v := ctx.s.readerPool.Get() if v == nil { n := ctx.s.ReadBufferSize if n <= 0 { n = defaultReadBufferSize } return bufio.NewReaderSize(ctx.c, n) } r := v.(*bufio.Reader) r.Reset(ctx.c) return r } func releaseReader(s *Server, r *bufio.Reader) { s.readerPool.Put(r) } func acquireWriter(ctx *RequestCtx) *bufio.Writer { v := ctx.s.writerPool.Get() if v == nil { n := ctx.s.WriteBufferSize if n <= 0 { n = defaultWriteBufferSize } return bufio.NewWriterSize(ctx.c, n) } w := v.(*bufio.Writer) w.Reset(ctx.c) return w } func releaseWriter(s *Server, w *bufio.Writer) { s.writerPool.Put(w) } func (s *Server) acquireCtx(c net.Conn) *RequestCtx { v := s.ctxPool.Get() var ctx *RequestCtx if v == nil { v = &RequestCtx{ s: s, } } ctx = v.(*RequestCtx) ctx.c = c return ctx } // Init prepares ctx for passing to RequestHandler. // // remoteAddr and logger are optional. They are used by RequestCtx.Logger(). // // This function is intended for custom Server implementations. 
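//
// A minimal usage sketch, mirroring the pattern used in this package's tests
// (myHandler is an assumed RequestHandler, not part of this package):
//
//	var ctx RequestCtx
//	var req Request
//	req.SetRequestURI("http://example.com/foo?bar=baz")
//	ctx.Init(&req, nil, nil)
//	myHandler(&ctx)
//	statusCode := ctx.Response.StatusCode()
//	body := ctx.Response.Body()
//	_, _ = statusCode, body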
func (ctx *RequestCtx) Init(req *Request, remoteAddr net.Addr, logger Logger) { if remoteAddr == nil { remoteAddr = zeroTCPAddr } ctx.c = &fakeAddrer{ addr: remoteAddr, } if logger == nil { logger = defaultLogger } ctx.connID = nextConnID() ctx.logger.logger = logger ctx.s = &fakeServer req.CopyTo(&ctx.Request) ctx.Response.Reset() ctx.connRequestNum = 0 ctx.connTime = time.Now() ctx.time = ctx.connTime } var fakeServer Server type fakeAddrer struct { net.Conn addr net.Addr } func (fa *fakeAddrer) RemoteAddr() net.Addr { return fa.addr } func (fa *fakeAddrer) LocalAddr() net.Addr { return fa.addr } func (fa *fakeAddrer) Read(p []byte) (int, error) { panic("BUG: unexpected Read call") } func (fa *fakeAddrer) Write(p []byte) (int, error) { panic("BUG: unexpected Write call") } func (fa *fakeAddrer) Close() error { panic("BUG: unexpected Close call") } func (s *Server) releaseCtx(ctx *RequestCtx) { if ctx.timeoutResponse != nil { panic("BUG: cannot release timed out RequestCtx") } ctx.c = nil ctx.fbr.c = nil s.ctxPool.Put(ctx) } func (s *Server) getServerName() []byte { v := s.serverName.Load() var serverName []byte if v == nil { serverName = []byte(s.Name) if len(serverName) == 0 { serverName = defaultServerName } s.serverName.Store(serverName) } else { serverName = v.([]byte) } return serverName } func (s *Server) writeFastError(w io.Writer, statusCode int, msg string) { w.Write(statusLine(statusCode)) fmt.Fprintf(w, "Connection: close\r\n"+ "Server: %s\r\n"+ "Date: %s\r\n"+ "Content-Type: text/plain\r\n"+ "Content-Length: %d\r\n"+ "\r\n"+ "%s", s.getServerName(), serverDate.Load(), len(msg), msg) } golang-github-valyala-fasthttp-20160617/server_example_test.go000066400000000000000000000126441273074646000244220ustar00rootroot00000000000000package fasthttp_test import ( "fmt" "log" "math/rand" "net" "time" "github.com/valyala/fasthttp" ) func ExampleListenAndServe() { // The server will listen for incoming requests on this address. listenAddr := "127.0.0.1:80" // This function will be called by the server for each incoming request. // // RequestCtx provides a lot of functionality related to http request // processing. See RequestCtx docs for details. requestHandler := func(ctx *fasthttp.RequestCtx) { fmt.Fprintf(ctx, "Hello, world! Requested path is %q", ctx.Path()) } // Start the server with default settings. // Create Server instance for adjusting server settings. // // ListenAndServe returns only on error, so usually it blocks forever. if err := fasthttp.ListenAndServe(listenAddr, requestHandler); err != nil { log.Fatalf("error in ListenAndServe: %s", err) } } func ExampleServe() { // Create network listener for accepting incoming requests. // // Note that you are not limited by TCP listener - arbitrary // net.Listener may be used by the server. // For example, unix socket listener or TLS listener. ln, err := net.Listen("tcp4", "127.0.0.1:8080") if err != nil { log.Fatalf("error in net.Listen: %s", err) } // This function will be called by the server for each incoming request. // // RequestCtx provides a lot of functionality related to http request // processing. See RequestCtx docs for details. requestHandler := func(ctx *fasthttp.RequestCtx) { fmt.Fprintf(ctx, "Hello, world! Requested path is %q", ctx.Path()) } // Start the server with default settings. // Create Server instance for adjusting server settings. // // Serve returns on ln.Close() or error, so usually it blocks forever. 
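// For TLS-terminated serving, the same listener may instead be passed to
// ServeTLSEmbed together with in-memory certificate and key data (see its
// use in server_test.go); the request handler stays unchanged.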
if err := fasthttp.Serve(ln, requestHandler); err != nil { log.Fatalf("error in Serve: %s", err) } } func ExampleServer() { // This function will be called by the server for each incoming request. // // RequestCtx provides a lot of functionality related to http request // processing. See RequestCtx docs for details. requestHandler := func(ctx *fasthttp.RequestCtx) { fmt.Fprintf(ctx, "Hello, world! Requested path is %q", ctx.Path()) } // Create custom server. s := &fasthttp.Server{ Handler: requestHandler, // Every response will contain 'Server: My super server' header. Name: "My super server", // Other Server settings may be set here. } // Start the server listening for incoming requests on the given address. // // ListenAndServe returns only on error, so usually it blocks forever. if err := s.ListenAndServe("127.0.0.1:80"); err != nil { log.Fatalf("error in ListenAndServe: %s", err) } } func ExampleRequestCtx_Hijack() { // hijackHandler is called on hijacked connection. hijackHandler := func(c net.Conn) { fmt.Fprintf(c, "This message is sent over a hijacked connection to the client %s\n", c.RemoteAddr()) fmt.Fprintf(c, "Send me something and I'll echo it to you\n") var buf [1]byte for { if _, err := c.Read(buf[:]); err != nil { log.Printf("error when reading from hijacked connection: %s", err) return } fmt.Fprintf(c, "You sent me %q. Waiting for new data\n", buf[:]) } } // requestHandler is called for each incoming request. requestHandler := func(ctx *fasthttp.RequestCtx) { path := ctx.Path() switch { case string(path) == "/hijack": // Note that the connection is hijacked only after // returning from requestHandler and sending http response. ctx.Hijack(hijackHandler) // The connection will be hijacked after sending this response. fmt.Fprintf(ctx, "Hijacked the connection!") case string(path) == "/": fmt.Fprintf(ctx, "Root directory requested") default: fmt.Fprintf(ctx, "Requested path is %q", path) } } if err := fasthttp.ListenAndServe(":80", requestHandler); err != nil { log.Fatalf("error in ListenAndServe: %s", err) } } func ExampleRequestCtx_TimeoutError() { requestHandler := func(ctx *fasthttp.RequestCtx) { // Emulate long-running task, which touches ctx. doneCh := make(chan struct{}) go func() { workDuration := time.Millisecond * time.Duration(rand.Intn(2000)) time.Sleep(workDuration) fmt.Fprintf(ctx, "ctx has been accessed by long-running task\n") fmt.Fprintf(ctx, "The reuqestHandler may be finished by this time.\n") close(doneCh) }() select { case <-doneCh: fmt.Fprintf(ctx, "The task has been finished in less than a second") case <-time.After(time.Second): // Since the long-running task is still running and may access ctx, // we must call TimeoutError before returning from requestHandler. // // Otherwise the program will suffer from data races. ctx.TimeoutError("Timeout!") } } if err := fasthttp.ListenAndServe(":80", requestHandler); err != nil { log.Fatalf("error in ListenAndServe: %s", err) } } func ExampleRequestCtx_Logger() { requestHandler := func(ctx *fasthttp.RequestCtx) { if string(ctx.Path()) == "/top-secret" { ctx.Logger().Printf("Alarm! Alien intrusion detected!") ctx.Error("Access denied!", fasthttp.StatusForbidden) return } // Logger may be cached in local variables. 
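// Every message written via this logger is prefixed with the connection ID,
// the per-connection request number, the remote<->local addresses and the
// request line (see the expected logger output in server_test.go).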
logger := ctx.Logger() logger.Printf("Good request from User-Agent %q", ctx.Request.Header.UserAgent()) fmt.Fprintf(ctx, "Good request to %q", ctx.Path()) logger.Printf("Multiple log messages may be written during a single request") } if err := fasthttp.ListenAndServe(":80", requestHandler); err != nil { log.Fatalf("error in ListenAndServe: %s", err) } } golang-github-valyala-fasthttp-20160617/server_test.go000066400000000000000000001630211273074646000227030ustar00rootroot00000000000000package fasthttp import ( "bufio" "bytes" "crypto/tls" "fmt" "io" "io/ioutil" "net" "os" "strings" "sync" "testing" "time" "github.com/valyala/fasthttp/fasthttputil" ) func TestRequestCtxRedirect(t *testing.T) { testRequestCtxRedirect(t, "http://qqq/", "", "http://qqq/") testRequestCtxRedirect(t, "http://qqq/foo/bar?baz=111", "", "http://qqq/foo/bar?baz=111") testRequestCtxRedirect(t, "http://qqq/foo/bar?baz=111", "#aaa", "http://qqq/foo/bar?baz=111#aaa") testRequestCtxRedirect(t, "http://qqq/foo/bar?baz=111", "?abc=de&f", "http://qqq/foo/bar?abc=de&f") testRequestCtxRedirect(t, "http://qqq/foo/bar?baz=111", "?abc=de&f#sf", "http://qqq/foo/bar?abc=de&f#sf") testRequestCtxRedirect(t, "http://qqq/foo/bar?baz=111", "x.html", "http://qqq/foo/x.html") testRequestCtxRedirect(t, "http://qqq/foo/bar?baz=111", "x.html?a=1", "http://qqq/foo/x.html?a=1") testRequestCtxRedirect(t, "http://qqq/foo/bar?baz=111", "x.html#aaa=bbb&cc=ddd", "http://qqq/foo/x.html#aaa=bbb&cc=ddd") testRequestCtxRedirect(t, "http://qqq/foo/bar?baz=111", "x.html?b=1#aaa=bbb&cc=ddd", "http://qqq/foo/x.html?b=1#aaa=bbb&cc=ddd") testRequestCtxRedirect(t, "http://qqq/foo/bar?baz=111", "/x.html", "http://qqq/x.html") testRequestCtxRedirect(t, "http://qqq/foo/bar?baz=111", "/x.html#aaa=bbb&cc=ddd", "http://qqq/x.html#aaa=bbb&cc=ddd") testRequestCtxRedirect(t, "http://qqq/foo/bar?baz=111", "../x.html", "http://qqq/x.html") testRequestCtxRedirect(t, "http://qqq/foo/bar?baz=111", "../../x.html", "http://qqq/x.html") testRequestCtxRedirect(t, "http://qqq/foo/bar?baz=111", "./.././../x.html", "http://qqq/x.html") testRequestCtxRedirect(t, "http://qqq/foo/bar?baz=111", "http://foo.bar/baz", "http://foo.bar/baz") testRequestCtxRedirect(t, "http://qqq/foo/bar?baz=111", "https://foo.bar/baz", "https://foo.bar/baz") } func testRequestCtxRedirect(t *testing.T, origURL, redirectURL, expectedURL string) { var ctx RequestCtx var req Request req.SetRequestURI(origURL) ctx.Init(&req, nil, nil) ctx.Redirect(redirectURL, StatusFound) loc := ctx.Response.Header.Peek("Location") if string(loc) != expectedURL { t.Fatalf("unexpected redirect url %q. Expecting %q. origURL=%q, redirectURL=%q", loc, expectedURL, origURL, redirectURL) } } func TestServerResponseServerHeader(t *testing.T) { serverName := "foobar serv" s := &Server{ Handler: func(ctx *RequestCtx) { name := ctx.Response.Header.Server() if string(name) != serverName { fmt.Fprintf(ctx, "unexpected server name: %q. 
Expecting %q", name, serverName) } else { ctx.WriteString("OK") } // make sure the server name is sent to the client after ctx.Response.Reset() ctx.NotFound() }, Name: serverName, } ln := fasthttputil.NewInmemoryListener() serverCh := make(chan struct{}) go func() { if err := s.Serve(ln); err != nil { t.Fatalf("unexpected error: %s", err) } close(serverCh) }() clientCh := make(chan struct{}) go func() { c, err := ln.Dial() if err != nil { t.Fatalf("unexpected error: %s", err) } if _, err = c.Write([]byte("GET / HTTP/1.1\r\nHost: aa\r\n\r\n")); err != nil { t.Fatalf("unexpected error: %s", err) } br := bufio.NewReader(c) var resp Response if err = resp.Read(br); err != nil { t.Fatalf("unexpected error: %s", err) } if resp.StatusCode() != StatusNotFound { t.Fatalf("unexpected status code: %d. Expecting %d", resp.StatusCode(), StatusNotFound) } if string(resp.Body()) != "404 Page not found" { t.Fatalf("unexpected body: %q. Expecting %q", resp.Body(), "404 Page not found") } if string(resp.Header.Server()) != serverName { t.Fatalf("unexpected server header: %q. Expecting %q", resp.Header.Server(), serverName) } if err = c.Close(); err != nil { t.Fatalf("unexpected error: %s", err) } close(clientCh) }() select { case <-clientCh: case <-time.After(time.Second): t.Fatalf("timeout") } if err := ln.Close(); err != nil { t.Fatalf("unexpected error: %s", err) } select { case <-serverCh: case <-time.After(time.Second): t.Fatalf("timeout") } } func TestServerResponseBodyStream(t *testing.T) { ln := fasthttputil.NewInmemoryListener() readyCh := make(chan struct{}) h := func(ctx *RequestCtx) { ctx.SetConnectionClose() if ctx.IsBodyStream() { t.Fatalf("IsBodyStream must return false") } ctx.SetBodyStreamWriter(func(w *bufio.Writer) { fmt.Fprintf(w, "first") if err := w.Flush(); err != nil { return } <-readyCh fmt.Fprintf(w, "second") // there is no need to flush w here, since it will // be flushed automatically after returning from StreamWriter. }) if !ctx.IsBodyStream() { t.Fatalf("IsBodyStream must return true") } } serverCh := make(chan struct{}) go func() { if err := Serve(ln, h); err != nil { t.Fatalf("unexpected error: %s", err) } close(serverCh) }() clientCh := make(chan struct{}) go func() { c, err := ln.Dial() if err != nil { t.Fatalf("unexpected error: %s", err) } if _, err = c.Write([]byte("GET / HTTP/1.1\r\nHost: aa\r\n\r\n")); err != nil { t.Fatalf("unexpected error: %s", err) } br := bufio.NewReader(c) var respH ResponseHeader if err = respH.Read(br); err != nil { t.Fatalf("unexpected error: %s", err) } if respH.StatusCode() != StatusOK { t.Fatalf("unexpected status code: %d. Expecting %d", respH.StatusCode(), StatusOK) } buf := make([]byte, 1024) n, err := br.Read(buf) if err != nil { t.Fatalf("unexpected error: %s", err) } b := buf[:n] if string(b) != "5\r\nfirst\r\n" { t.Fatalf("unexpected result %q. Expecting %q", b, "5\r\nfirst\r\n") } close(readyCh) tail, err := ioutil.ReadAll(br) if err != nil { t.Fatalf("unexpected error: %s", err) } if string(tail) != "6\r\nsecond\r\n0\r\n\r\n" { t.Fatalf("unexpected tail %q. 
Expecting %q", tail, "6\r\nsecond\r\n0\r\n\r\n") } close(clientCh) }() select { case <-clientCh: case <-time.After(time.Second): t.Fatalf("timeout") } if err := ln.Close(); err != nil { t.Fatalf("unexpected error: %s", err) } select { case <-serverCh: case <-time.After(time.Second): t.Fatalf("timeout") } } func TestServerDisableKeepalive(t *testing.T) { s := &Server{ Handler: func(ctx *RequestCtx) { ctx.WriteString("OK") }, DisableKeepalive: true, } ln := fasthttputil.NewInmemoryListener() serverCh := make(chan struct{}) go func() { if err := s.Serve(ln); err != nil { t.Fatalf("unexpected error: %s", err) } close(serverCh) }() clientCh := make(chan struct{}) go func() { c, err := ln.Dial() if err != nil { t.Fatalf("unexpected error: %s", err) } if _, err = c.Write([]byte("GET / HTTP/1.1\r\nHost: aa\r\n\r\n")); err != nil { t.Fatalf("unexpected error: %s", err) } br := bufio.NewReader(c) var resp Response if err = resp.Read(br); err != nil { t.Fatalf("unexpected error: %s", err) } if resp.StatusCode() != StatusOK { t.Fatalf("unexpected status code: %d. Expecting %d", resp.StatusCode(), StatusOK) } if !resp.ConnectionClose() { t.Fatalf("expecting 'Connection: close' response header") } if string(resp.Body()) != "OK" { t.Fatalf("unexpected body: %q. Expecting %q", resp.Body(), "OK") } // make sure the connection is closed data, err := ioutil.ReadAll(br) if err != nil { t.Fatalf("unexpected error: %s", err) } if len(data) > 0 { t.Fatalf("unexpected data read from the connection: %q. Expecting empty data", data) } close(clientCh) }() select { case <-clientCh: case <-time.After(time.Second): t.Fatalf("timeout") } if err := ln.Close(); err != nil { t.Fatalf("unexpected error: %s", err) } select { case <-serverCh: case <-time.After(time.Second): t.Fatalf("timeout") } } func TestServerMaxConnsPerIPLimit(t *testing.T) { s := &Server{ Handler: func(ctx *RequestCtx) { ctx.WriteString("OK") }, MaxConnsPerIP: 1, Logger: &customLogger{}, } ln := fasthttputil.NewInmemoryListener() serverCh := make(chan struct{}) go func() { fakeLN := &fakeIPListener{ Listener: ln, } if err := s.Serve(fakeLN); err != nil { t.Fatalf("unexpected error: %s", err) } close(serverCh) }() clientCh := make(chan struct{}) go func() { c1, err := ln.Dial() if err != nil { t.Fatalf("unexpected error: %s", err) } c2, err := ln.Dial() if err != nil { t.Fatalf("unexpected error: %s", err) } br := bufio.NewReader(c2) var resp Response if err = resp.Read(br); err != nil { t.Fatalf("unexpected error: %s", err) } if resp.StatusCode() != StatusTooManyRequests { t.Fatalf("unexpected status code for the second connection: %d. Expecting %d", resp.StatusCode(), StatusTooManyRequests) } if _, err = c1.Write([]byte("GET / HTTP/1.1\r\nHost: aa\r\n\r\n")); err != nil { t.Fatalf("unexpected error when writing to the first connection: %s", err) } br = bufio.NewReader(c1) if err = resp.Read(br); err != nil { t.Fatalf("unexpected error: %s", err) } if resp.StatusCode() != StatusOK { t.Fatalf("unexpected status code for the first connection: %d. Expecting %d", resp.StatusCode(), StatusOK) } if string(resp.Body()) != "OK" { t.Fatalf("unexpected body for the first connection: %q. 
Expecting %q", resp.Body(), "OK") } close(clientCh) }() select { case <-clientCh: case <-time.After(time.Second): t.Fatalf("timeout") } if err := ln.Close(); err != nil { t.Fatalf("unexpected error: %s", err) } select { case <-serverCh: case <-time.After(time.Second): t.Fatalf("timeout") } } type fakeIPListener struct { net.Listener } func (ln *fakeIPListener) Accept() (net.Conn, error) { conn, err := ln.Listener.Accept() if err != nil { return nil, err } return &fakeIPConn{ Conn: conn, }, nil } type fakeIPConn struct { net.Conn } func (conn *fakeIPConn) RemoteAddr() net.Addr { addr, err := net.ResolveTCPAddr("tcp4", "1.2.3.4:5789") if err != nil { panic(fmt.Sprintf("BUG: unexpected error: %s", err)) } return addr } func TestServerConcurrencyLimit(t *testing.T) { s := &Server{ Handler: func(ctx *RequestCtx) { ctx.WriteString("OK") }, Concurrency: 1, Logger: &customLogger{}, } ln := fasthttputil.NewInmemoryListener() serverCh := make(chan struct{}) go func() { if err := s.Serve(ln); err != nil { t.Fatalf("unexpected error: %s", err) } close(serverCh) }() clientCh := make(chan struct{}) go func() { c1, err := ln.Dial() if err != nil { t.Fatalf("unexpected error: %s", err) } c2, err := ln.Dial() if err != nil { t.Fatalf("unexpected error: %s", err) } br := bufio.NewReader(c2) var resp Response if err = resp.Read(br); err != nil { t.Fatalf("unexpected error: %s", err) } if resp.StatusCode() != StatusServiceUnavailable { t.Fatalf("unexpected status code for the second connection: %d. Expecting %d", resp.StatusCode(), StatusServiceUnavailable) } if _, err = c1.Write([]byte("GET / HTTP/1.1\r\nHost: aa\r\n\r\n")); err != nil { t.Fatalf("unexpected error when writing to the first connection: %s", err) } br = bufio.NewReader(c1) if err = resp.Read(br); err != nil { t.Fatalf("unexpected error: %s", err) } if resp.StatusCode() != StatusOK { t.Fatalf("unexpected status code for the first connection: %d. Expecting %d", resp.StatusCode(), StatusOK) } if string(resp.Body()) != "OK" { t.Fatalf("unexpected body for the first connection: %q. Expecting %q", resp.Body(), "OK") } close(clientCh) }() select { case <-clientCh: case <-time.After(time.Second): t.Fatalf("timeout") } if err := ln.Close(); err != nil { t.Fatalf("unexpected error: %s", err) } select { case <-serverCh: case <-time.After(time.Second): t.Fatalf("timeout") } } func TestServerWriteFastError(t *testing.T) { s := &Server{ Name: "foobar", } var buf bytes.Buffer expectedBody := "access denied" s.writeFastError(&buf, StatusForbidden, expectedBody) br := bufio.NewReader(&buf) var resp Response if err := resp.Read(br); err != nil { t.Fatalf("unexpected error: %s", err) } if resp.StatusCode() != StatusForbidden { t.Fatalf("unexpected status code: %d. Expecting %d", resp.StatusCode(), StatusForbidden) } body := resp.Body() if string(body) != expectedBody { t.Fatalf("unexpected body: %q. Expecting %q", body, expectedBody) } server := string(resp.Header.Server()) if server != s.Name { t.Fatalf("unexpected server: %q. Expecting %q", server, s.Name) } contentType := string(resp.Header.ContentType()) if contentType != "text/plain" { t.Fatalf("unexpected content-type: %q. 
Expecting %q", contentType, "text/plain") } if !resp.Header.ConnectionClose() { t.Fatalf("expecting 'Connection: close' response header") } } func TestServerServeTLSEmbed(t *testing.T) { ln := fasthttputil.NewInmemoryListener() certFile := "./ssl-cert-snakeoil.pem" keyFile := "./ssl-cert-snakeoil.key" certData, err := ioutil.ReadFile(certFile) if err != nil { t.Fatalf("unexpected error when reading %q: %s", certFile, err) } keyData, err := ioutil.ReadFile(keyFile) if err != nil { t.Fatalf("unexpected error when reading %q: %s", keyFile, err) } // start the server ch := make(chan struct{}) go func() { err := ServeTLSEmbed(ln, certData, keyData, func(ctx *RequestCtx) { ctx.WriteString("success") }) if err != nil { t.Fatalf("unexpected error: %s", err) } close(ch) }() // establish connection to the server conn, err := ln.Dial() if err != nil { t.Fatalf("unexpected error: %s", err) } tlsConn := tls.Client(conn, &tls.Config{ InsecureSkipVerify: true, }) // send request if _, err = tlsConn.Write([]byte("GET / HTTP/1.1\r\nHost: aaa\r\n\r\n")); err != nil { t.Fatalf("unexpected error: %s", err) } // read response respCh := make(chan struct{}) go func() { br := bufio.NewReader(tlsConn) var resp Response if err := resp.Read(br); err != nil { t.Fatalf("unexpected error") } body := resp.Body() if string(body) != "success" { t.Fatalf("unexpected response body %q. Expecting %q", body, "success") } close(respCh) }() select { case <-respCh: case <-time.After(time.Second): t.Fatalf("timeout") } // close the server if err = ln.Close(); err != nil { t.Fatalf("unexpected error: %s", err) } select { case <-ch: case <-time.After(time.Second): t.Fatalf("timeout") } } func TestServerMultipartFormDataRequest(t *testing.T) { reqS := `POST /upload HTTP/1.1 Host: qwerty.com Content-Length: 521 Content-Type: multipart/form-data; boundary=----WebKitFormBoundaryJwfATyF8tmxSJnLg ------WebKitFormBoundaryJwfATyF8tmxSJnLg Content-Disposition: form-data; name="f1" value1 ------WebKitFormBoundaryJwfATyF8tmxSJnLg Content-Disposition: form-data; name="fileaaa"; filename="TODO" Content-Type: application/octet-stream - SessionClient with referer and cookies support. - Client with requests' pipelining support. - ProxyHandler similar to FSHandler. - WebSockets. See https://tools.ietf.org/html/rfc6455 . - HTTP/2.0. See https://tools.ietf.org/html/rfc7540 . ------WebKitFormBoundaryJwfATyF8tmxSJnLg-- GET / HTTP/1.1 Host: asbd Connection: close ` ln := fasthttputil.NewInmemoryListener() s := &Server{ Handler: func(ctx *RequestCtx) { switch string(ctx.Path()) { case "/upload": f, err := ctx.MultipartForm() if err != nil { t.Fatalf("unexpected error: %s", err) } if len(f.Value) != 1 { t.Fatalf("unexpected values %d. Expecting %d", len(f.Value), 1) } if len(f.File) != 1 { t.Fatalf("unexpected file values %d. Expecting %d", len(f.File), 1) } fv := ctx.FormValue("f1") if string(fv) != "value1" { t.Fatalf("unexpected form value: %q. 
Expecting %q", fv, "value1") } ctx.Redirect("/", StatusSeeOther) default: ctx.WriteString("non-upload") } }, } ch := make(chan struct{}) go func() { if err := s.Serve(ln); err != nil { t.Fatalf("unexpected error: %s", err) } close(ch) }() conn, err := ln.Dial() if err != nil { t.Fatalf("unexpected error: %s", err) } if _, err = conn.Write([]byte(reqS)); err != nil { t.Fatalf("unexpected error: %s", err) } var resp Response br := bufio.NewReader(conn) respCh := make(chan struct{}) go func() { if err := resp.Read(br); err != nil { t.Fatalf("error when reading response: %s", err) } if resp.StatusCode() != StatusSeeOther { t.Fatalf("unexpected status code %d. Expecting %d", resp.StatusCode(), StatusSeeOther) } loc := resp.Header.Peek("Location") if string(loc) != "http://qwerty.com/" { t.Fatalf("unexpected location %q. Expecting %q", loc, "http://qwerty.com/") } if err := resp.Read(br); err != nil { t.Fatalf("error when reading the second response: %s", err) } if resp.StatusCode() != StatusOK { t.Fatalf("unexpected status code: %d. Expecting %d", resp.StatusCode(), StatusOK) } body := resp.Body() if string(body) != "non-upload" { t.Fatalf("unexpected body %q. Expecting %q", body, "non-upload") } close(respCh) }() select { case <-respCh: case <-time.After(time.Second): t.Fatalf("timeout") } if err := ln.Close(); err != nil { t.Fatalf("error when closing listener: %s", err) } select { case <-ch: case <-time.After(time.Second): t.Fatalf("timeout when waiting for the server to stop") } } func TestServerDisableHeaderNamesNormalizing(t *testing.T) { headerName := "CASE-senSITive-HEAder-NAME" headerNameLower := strings.ToLower(headerName) headerValue := "foobar baz" s := &Server{ Handler: func(ctx *RequestCtx) { hv := ctx.Request.Header.Peek(headerName) if string(hv) != headerValue { t.Fatalf("unexpected header value for %q: %q. Expecting %q", headerName, hv, headerValue) } hv = ctx.Request.Header.Peek(headerNameLower) if len(hv) > 0 { t.Fatalf("unexpected header value for %q: %q. Expecting empty value", headerNameLower, hv) } ctx.Response.Header.Set(headerName, headerValue) ctx.WriteString("ok") ctx.SetContentType("aaa") }, DisableHeaderNamesNormalizing: true, } rw := &readWriter{} rw.r.WriteString(fmt.Sprintf("GET / HTTP/1.1\r\n%s: %s\r\nHost: google.com\r\n\r\n", headerName, headerValue)) ch := make(chan error) go func() { ch <- s.ServeConn(rw) }() select { case err := <-ch: if err != nil { t.Fatalf("Unexpected error from serveConn: %s", err) } case <-time.After(100 * time.Millisecond): t.Fatalf("timeout") } br := bufio.NewReader(&rw.w) var resp Response resp.Header.DisableNormalizing() if err := resp.Read(br); err != nil { t.Fatalf("unexpected error: %s", err) } hv := resp.Header.Peek(headerName) if string(hv) != headerValue { t.Fatalf("unexpected header value for %q: %q. Expecting %q", headerName, hv, headerValue) } hv = resp.Header.Peek(headerNameLower) if len(hv) > 0 { t.Fatalf("unexpected header value for %q: %q. 
Expecting empty value", headerNameLower, hv) } } func TestServerReduceMemoryUsageSerial(t *testing.T) { ln := fasthttputil.NewInmemoryListener() s := &Server{ Handler: func(ctx *RequestCtx) {}, ReduceMemoryUsage: true, } ch := make(chan struct{}) go func() { if err := s.Serve(ln); err != nil { t.Fatalf("unexpected error: %s", err) } close(ch) }() testServerRequests(t, ln) if err := ln.Close(); err != nil { t.Fatalf("error when closing listener: %s", err) } select { case <-ch: case <-time.After(time.Second): t.Fatalf("timeout when waiting for the server to stop") } } func TestServerReduceMemoryUsageConcurrent(t *testing.T) { ln := fasthttputil.NewInmemoryListener() s := &Server{ Handler: func(ctx *RequestCtx) {}, ReduceMemoryUsage: true, } ch := make(chan struct{}) go func() { if err := s.Serve(ln); err != nil { t.Fatalf("unexpected error: %s", err) } close(ch) }() gCh := make(chan struct{}) for i := 0; i < 10; i++ { go func() { testServerRequests(t, ln) gCh <- struct{}{} }() } for i := 0; i < 10; i++ { select { case <-gCh: case <-time.After(time.Second): t.Fatalf("timeout on goroutine %d", i) } } if err := ln.Close(); err != nil { t.Fatalf("error when closing listener: %s", err) } select { case <-ch: case <-time.After(time.Second): t.Fatalf("timeout when waiting for the server to stop") } } func testServerRequests(t *testing.T, ln *fasthttputil.InmemoryListener) { conn, err := ln.Dial() if err != nil { t.Fatalf("unexpected error: %s", err) } br := bufio.NewReader(conn) var resp Response for i := 0; i < 10; i++ { if _, err = fmt.Fprintf(conn, "GET / HTTP/1.1\r\nHost: aaa\r\n\r\n"); err != nil { t.Fatalf("unexpected error on iteration %d: %s", i, err) } respCh := make(chan struct{}) go func() { if err = resp.Read(br); err != nil { t.Fatalf("unexpected error when reading response on iteration %d: %s", i, err) } close(respCh) }() select { case <-respCh: case <-time.After(time.Second): t.Fatalf("timeout on iteration %d", i) } } if err = conn.Close(); err != nil { t.Fatalf("error when closing the connection: %s", err) } } func TestServerHTTP10ConnectionKeepAlive(t *testing.T) { ln := fasthttputil.NewInmemoryListener() ch := make(chan struct{}) go func() { err := Serve(ln, func(ctx *RequestCtx) { if string(ctx.Path()) == "/close" { ctx.SetConnectionClose() } }) if err != nil { t.Fatalf("unexpected error: %s", err) } close(ch) }() conn, err := ln.Dial() if err != nil { t.Fatalf("unexpected error: %s", err) } _, err = fmt.Fprintf(conn, "%s", "GET / HTTP/1.0\r\nHost: aaa\r\nConnection: keep-alive\r\n\r\n") if err != nil { t.Fatalf("error when writing request: %s", err) } _, err = fmt.Fprintf(conn, "%s", "GET /close HTTP/1.0\r\nHost: aaa\r\nConnection: keep-alive\r\n\r\n") if err != nil { t.Fatalf("error when writing request: %s", err) } br := bufio.NewReader(conn) var resp Response if err = resp.Read(br); err != nil { t.Fatalf("error when reading response: %s", err) } if resp.ConnectionClose() { t.Fatalf("response mustn't have 'Connection: close' header") } if err = resp.Read(br); err != nil { t.Fatalf("error when reading response: %s", err) } if !resp.ConnectionClose() { t.Fatalf("response must have 'Connection: close' header") } tailCh := make(chan struct{}) go func() { tail, err := ioutil.ReadAll(br) if err != nil { t.Fatalf("error when reading tail: %s", err) } if len(tail) > 0 { t.Fatalf("unexpected non-zero tail %q", tail) } close(tailCh) }() select { case <-tailCh: case <-time.After(time.Second): t.Fatalf("timeout when reading tail") } if err = conn.Close(); err != nil { t.Fatalf("error when 
closing the connection: %s", err) } if err = ln.Close(); err != nil { t.Fatalf("error when closing listener: %s", err) } select { case <-ch: case <-time.After(time.Second): t.Fatalf("timeout when waiting for the server to stop") } } func TestServerHTTP10ConnectionClose(t *testing.T) { ln := fasthttputil.NewInmemoryListener() ch := make(chan struct{}) go func() { err := Serve(ln, func(ctx *RequestCtx) { // The server must close the connection irregardless // of request and response state set inside request // handler, since the HTTP/1.0 request // had no 'Connection: keep-alive' header. ctx.Request.Header.ResetConnectionClose() ctx.Request.Header.Set("Connection", "keep-alive") ctx.Response.Header.ResetConnectionClose() ctx.Response.Header.Set("Connection", "keep-alive") }) if err != nil { t.Fatalf("unexpected error: %s", err) } close(ch) }() conn, err := ln.Dial() if err != nil { t.Fatalf("unexpected error: %s", err) } _, err = fmt.Fprintf(conn, "%s", "GET / HTTP/1.0\r\nHost: aaa\r\n\r\n") if err != nil { t.Fatalf("error when writing request: %s", err) } br := bufio.NewReader(conn) var resp Response if err = resp.Read(br); err != nil { t.Fatalf("error when reading response: %s", err) } if !resp.ConnectionClose() { t.Fatalf("HTTP1.0 response must have 'Connection: close' header") } tailCh := make(chan struct{}) go func() { tail, err := ioutil.ReadAll(br) if err != nil { t.Fatalf("error when reading tail: %s", err) } if len(tail) > 0 { t.Fatalf("unexpected non-zero tail %q", tail) } close(tailCh) }() select { case <-tailCh: case <-time.After(time.Second): t.Fatalf("timeout when reading tail") } if err = conn.Close(); err != nil { t.Fatalf("error when closing the connection: %s", err) } if err = ln.Close(); err != nil { t.Fatalf("error when closing listener: %s", err) } select { case <-ch: case <-time.After(time.Second): t.Fatalf("timeout when waiting for the server to stop") } } func TestRequestCtxFormValue(t *testing.T) { var ctx RequestCtx var req Request req.SetRequestURI("/foo/bar?baz=123&aaa=bbb") req.SetBodyString("qqq=port&mmm=sddd") req.Header.SetContentType("application/x-www-form-urlencoded") ctx.Init(&req, nil, nil) v := ctx.FormValue("baz") if string(v) != "123" { t.Fatalf("unexpected value %q. Expecting %q", v, "123") } v = ctx.FormValue("mmm") if string(v) != "sddd" { t.Fatalf("unexpected value %q. Expecting %q", v, "sddd") } v = ctx.FormValue("aaaasdfsdf") if len(v) > 0 { t.Fatalf("unexpected value for unknown key %q", v) } } func TestRequestCtxUserValue(t *testing.T) { var ctx RequestCtx for i := 0; i < 5; i++ { k := fmt.Sprintf("key-%d", i) ctx.SetUserValue(k, i) } for i := 5; i < 10; i++ { k := fmt.Sprintf("key-%d", i) ctx.SetUserValueBytes([]byte(k), i) } for i := 0; i < 10; i++ { k := fmt.Sprintf("key-%d", i) v := ctx.UserValue(k) n, ok := v.(int) if !ok || n != i { t.Fatalf("unexpected value obtained for key %q: %v. 
Expecting %d", k, v, i) } } } func TestServerHeadRequest(t *testing.T) { s := &Server{ Handler: func(ctx *RequestCtx) { fmt.Fprintf(ctx, "Request method is %q", ctx.Method()) ctx.SetContentType("aaa/bbb") }, } rw := &readWriter{} rw.r.WriteString("HEAD /foobar HTTP/1.1\r\nHost: aaa.com\r\n\r\n") ch := make(chan error) go func() { ch <- s.ServeConn(rw) }() select { case err := <-ch: if err != nil { t.Fatalf("Unexpected error from serveConn: %s", err) } case <-time.After(100 * time.Millisecond): t.Fatalf("timeout") } br := bufio.NewReader(&rw.w) var resp Response resp.SkipBody = true if err := resp.Read(br); err != nil { t.Fatalf("Unexpected error when parsing response: %s", err) } if resp.Header.StatusCode() != StatusOK { t.Fatalf("unexpected status code: %d. Expecting %d", resp.Header.StatusCode(), StatusOK) } if len(resp.Body()) > 0 { t.Fatalf("Unexpected non-zero body %q", resp.Body()) } if resp.Header.ContentLength() != 24 { t.Fatalf("unexpected content-length %d. Expecting %d", resp.Header.ContentLength(), 24) } if string(resp.Header.ContentType()) != "aaa/bbb" { t.Fatalf("unexpected content-type %q. Expecting %q", resp.Header.ContentType(), "aaa/bbb") } data, err := ioutil.ReadAll(br) if err != nil { t.Fatalf("Unexpected error when reading remaining data: %s", err) } if len(data) > 0 { t.Fatalf("unexpected remaining data %q", data) } } func TestServerExpect100Continue(t *testing.T) { s := &Server{ Handler: func(ctx *RequestCtx) { if !ctx.IsPost() { t.Fatalf("unexpected method %q. Expecting POST", ctx.Method()) } if string(ctx.Path()) != "/foo" { t.Fatalf("unexpected path %q. Expecting %q", ctx.Path(), "/foo") } ct := ctx.Request.Header.ContentType() if string(ct) != "a/b" { t.Fatalf("unexpectected content-type: %q. Expecting %q", ct, "a/b") } if string(ctx.PostBody()) != "12345" { t.Fatalf("unexpected body: %q. Expecting %q", ctx.PostBody(), "12345") } ctx.WriteString("foobar") }, } rw := &readWriter{} rw.r.WriteString("POST /foo HTTP/1.1\r\nHost: gle.com\r\nExpect: 100-continue\r\nContent-Length: 5\r\nContent-Type: a/b\r\n\r\n12345") ch := make(chan error) go func() { ch <- s.ServeConn(rw) }() select { case err := <-ch: if err != nil { t.Fatalf("Unexpected error from serveConn: %s", err) } case <-time.After(100 * time.Millisecond): t.Fatalf("timeout") } br := bufio.NewReader(&rw.w) verifyResponse(t, br, StatusOK, string(defaultContentType), "foobar") data, err := ioutil.ReadAll(br) if err != nil { t.Fatalf("Unexpected error when reading remaining data: %s", err) } if len(data) > 0 { t.Fatalf("unexpected remaining data %q", data) } } func TestCompressHandler(t *testing.T) { expectedBody := "foo/bar/baz" h := CompressHandler(func(ctx *RequestCtx) { ctx.Write([]byte(expectedBody)) }) var ctx RequestCtx var resp Response // verify uncompressed response h(&ctx) s := ctx.Response.String() br := bufio.NewReader(bytes.NewBufferString(s)) if err := resp.Read(br); err != nil { t.Fatalf("unexpected error: %s", err) } ce := resp.Header.Peek("Content-Encoding") if string(ce) != "" { t.Fatalf("unexpected Content-Encoding: %q. Expecting %q", ce, "") } body := resp.Body() if string(body) != expectedBody { t.Fatalf("unexpected body %q. 
Expecting %q", body, expectedBody) } // verify gzip-compressed response ctx.Request.Reset() ctx.Response.Reset() ctx.Request.Header.Set("Accept-Encoding", "gzip, deflate, sdhc") h(&ctx) s = ctx.Response.String() br = bufio.NewReader(bytes.NewBufferString(s)) if err := resp.Read(br); err != nil { t.Fatalf("unexpected error: %s", err) } ce = resp.Header.Peek("Content-Encoding") if string(ce) != "gzip" { t.Fatalf("unexpected Content-Encoding: %q. Expecting %q", ce, "gzip") } body, err := resp.BodyGunzip() if err != nil { t.Fatalf("unexpected error: %s", err) } if string(body) != expectedBody { t.Fatalf("unexpected body %q. Expecting %q", body, expectedBody) } // an attempt to compress already compressed response ctx.Request.Reset() ctx.Response.Reset() ctx.Request.Header.Set("Accept-Encoding", "gzip, deflate, sdhc") hh := CompressHandler(h) hh(&ctx) s = ctx.Response.String() br = bufio.NewReader(bytes.NewBufferString(s)) if err := resp.Read(br); err != nil { t.Fatalf("unexpected error: %s", err) } ce = resp.Header.Peek("Content-Encoding") if string(ce) != "gzip" { t.Fatalf("unexpected Content-Encoding: %q. Expecting %q", ce, "gzip") } body, err = resp.BodyGunzip() if err != nil { t.Fatalf("unexpected error: %s", err) } if string(body) != expectedBody { t.Fatalf("unexpected body %q. Expecting %q", body, expectedBody) } // verify deflate-compressed response ctx.Request.Reset() ctx.Response.Reset() ctx.Request.Header.Set("Accept-Encoding", "foobar, deflate, sdhc") h(&ctx) s = ctx.Response.String() br = bufio.NewReader(bytes.NewBufferString(s)) if err := resp.Read(br); err != nil { t.Fatalf("unexpected error: %s", err) } ce = resp.Header.Peek("Content-Encoding") if string(ce) != "deflate" { t.Fatalf("unexpected Content-Encoding: %q. Expecting %q", ce, "deflate") } body, err = resp.BodyInflate() if err != nil { t.Fatalf("unexpected error: %s", err) } if string(body) != expectedBody { t.Fatalf("unexpected body %q. Expecting %q", body, expectedBody) } } func TestRequestCtxWriteString(t *testing.T) { var ctx RequestCtx n, err := ctx.WriteString("foo") if err != nil { t.Fatalf("unexpected error: %s", err) } if n != 3 { t.Fatalf("unexpected n %d. Expecting 3", n) } n, err = ctx.WriteString("привет") if err != nil { t.Fatalf("unexpected error: %s", err) } if n != 12 { t.Fatalf("unexpected n=%d. Expecting 12", n) } s := ctx.Response.Body() if string(s) != "fooпривет" { t.Fatalf("unexpected response body %q. Expecting %q", s, "fooпривет") } } func TestServeConnNonHTTP11KeepAlive(t *testing.T) { rw := &readWriter{} rw.r.WriteString("GET /foo HTTP/1.0\r\nConnection: keep-alive\r\nHost: google.com\r\n\r\n") rw.r.WriteString("GET /bar HTTP/1.0\r\nHost: google.com\r\n\r\n") rw.r.WriteString("GET /must/be/ignored HTTP/1.0\r\nHost: google.com\r\n\r\n") requestsServed := 0 ch := make(chan struct{}) go func() { err := ServeConn(rw, func(ctx *RequestCtx) { requestsServed++ ctx.SuccessString("aaa/bbb", "foobar") }) if err != nil { t.Fatalf("unexpected error in ServeConn: %s", err) } close(ch) }() select { case <-ch: case <-time.After(time.Second): t.Fatalf("timeout") } br := bufio.NewReader(&rw.w) var resp Response // verify the first response if err := resp.Read(br); err != nil { t.Fatalf("Unexpected error when parsing response: %s", err) } if string(resp.Header.Peek("Connection")) != "keep-alive" { t.Fatalf("unexpected Connection header %q. 
Expecting %q", resp.Header.Peek("Connection"), "keep-alive") } if resp.Header.ConnectionClose() { t.Fatalf("unexpected Connection: close") } // verify the second response if err := resp.Read(br); err != nil { t.Fatalf("Unexpected error when parsing response: %s", err) } if string(resp.Header.Peek("Connection")) != "close" { t.Fatalf("unexpected Connection header %q. Expecting %q", resp.Header.Peek("Connection"), "close") } if !resp.Header.ConnectionClose() { t.Fatalf("expecting Connection: close") } data, err := ioutil.ReadAll(br) if err != nil { t.Fatalf("Unexpected error when reading remaining data: %s", err) } if len(data) != 0 { t.Fatalf("Unexpected data read after responses %q", data) } if requestsServed != 2 { t.Fatalf("unexpected number of requests served: %d. Expecting 2", requestsServed) } } func TestRequestCtxSetBodyStreamWriter(t *testing.T) { var ctx RequestCtx var req Request ctx.Init(&req, nil, defaultLogger) if ctx.IsBodyStream() { t.Fatalf("IsBodyStream must return false") } ctx.SetBodyStreamWriter(func(w *bufio.Writer) { fmt.Fprintf(w, "body writer line 1\n") if err := w.Flush(); err != nil { t.Fatalf("unexpected error: %s", err) } fmt.Fprintf(w, "body writer line 2\n") }) if !ctx.IsBodyStream() { t.Fatalf("IsBodyStream must return true") } s := ctx.Response.String() br := bufio.NewReader(bytes.NewBufferString(s)) var resp Response if err := resp.Read(br); err != nil { t.Fatalf("Error when reading response: %s", err) } body := string(resp.Body()) expectedBody := "body writer line 1\nbody writer line 2\n" if body != expectedBody { t.Fatalf("unexpected body: %q. Expecting %q", body, expectedBody) } } func TestRequestCtxIfModifiedSince(t *testing.T) { var ctx RequestCtx var req Request ctx.Init(&req, nil, defaultLogger) lastModified := time.Now().Add(-time.Hour) if !ctx.IfModifiedSince(lastModified) { t.Fatalf("IfModifiedSince must return true for non-existing If-Modified-Since header") } ctx.Request.Header.Set("If-Modified-Since", string(AppendHTTPDate(nil, lastModified))) if ctx.IfModifiedSince(lastModified) { t.Fatalf("If-Modified-Since current time must return false") } past := lastModified.Add(-time.Hour) if ctx.IfModifiedSince(past) { t.Fatalf("If-Modified-Since past time must return false") } future := lastModified.Add(time.Hour) if !ctx.IfModifiedSince(future) { t.Fatalf("If-Modified-Since future time must return true") } } func TestRequestCtxSendFileNotModified(t *testing.T) { var ctx RequestCtx var req Request ctx.Init(&req, nil, defaultLogger) filePath := "./server_test.go" lastModified, err := FileLastModified(filePath) if err != nil { t.Fatalf("unexpected error: %s", err) } ctx.Request.Header.Set("If-Modified-Since", string(AppendHTTPDate(nil, lastModified))) ctx.SendFile(filePath) s := ctx.Response.String() var resp Response br := bufio.NewReader(bytes.NewBufferString(s)) if err := resp.Read(br); err != nil { t.Fatalf("error when reading response: %s", err) } if resp.StatusCode() != StatusNotModified { t.Fatalf("unexpected status code: %d. 
Expecting %d", resp.StatusCode(), StatusNotModified) } if len(resp.Body()) > 0 { t.Fatalf("unexpected non-zero response body: %q", resp.Body()) } } func TestRequestCtxSendFileModified(t *testing.T) { var ctx RequestCtx var req Request ctx.Init(&req, nil, defaultLogger) filePath := "./server_test.go" lastModified, err := FileLastModified(filePath) if err != nil { t.Fatalf("unexpected error: %s", err) } lastModified = lastModified.Add(-time.Hour) ctx.Request.Header.Set("If-Modified-Since", string(AppendHTTPDate(nil, lastModified))) ctx.SendFile(filePath) s := ctx.Response.String() var resp Response br := bufio.NewReader(bytes.NewBufferString(s)) if err := resp.Read(br); err != nil { t.Fatalf("error when reading response: %s", err) } if resp.StatusCode() != StatusOK { t.Fatalf("unexpected status code: %d. Expecting %d", resp.StatusCode(), StatusOK) } f, err := os.Open(filePath) if err != nil { t.Fatalf("cannot open file: %s", err) } body, err := ioutil.ReadAll(f) f.Close() if err != nil { t.Fatalf("error when reading file: %s", err) } if !bytes.Equal(resp.Body(), body) { t.Fatalf("unexpected response body: %q. Expecting %q", resp.Body(), body) } } func TestRequestCtxSendFile(t *testing.T) { var ctx RequestCtx var req Request ctx.Init(&req, nil, defaultLogger) filePath := "./server_test.go" ctx.SendFile(filePath) w := &bytes.Buffer{} bw := bufio.NewWriter(w) if err := ctx.Response.Write(bw); err != nil { t.Fatalf("error when writing response: %s", err) } if err := bw.Flush(); err != nil { t.Fatalf("error when flushing response: %s", err) } var resp Response br := bufio.NewReader(w) if err := resp.Read(br); err != nil { t.Fatalf("error when reading response: %s", err) } if resp.StatusCode() != StatusOK { t.Fatalf("unexpected status code: %d. Expecting %d", resp.StatusCode(), StatusOK) } f, err := os.Open(filePath) if err != nil { t.Fatalf("cannot open file: %s", err) } body, err := ioutil.ReadAll(f) f.Close() if err != nil { t.Fatalf("error when reading file: %s", err) } if !bytes.Equal(resp.Body(), body) { t.Fatalf("unexpected response body: %q. Expecting %q", resp.Body(), body) } } func TestRequestCtxHijack(t *testing.T) { hijackStartCh := make(chan struct{}) hijackStopCh := make(chan struct{}) s := &Server{ Handler: func(ctx *RequestCtx) { ctx.Hijack(func(c net.Conn) { <-hijackStartCh b := make([]byte, 1) // ping-pong echo via hijacked conn for { n, err := c.Read(b) if n != 1 { if err == io.EOF { close(hijackStopCh) return } if err != nil { t.Fatalf("unexpected error: %s", err) } t.Fatalf("unexpected number of bytes read: %d. Expecting 1", n) } if _, err = c.Write(b); err != nil { t.Fatalf("unexpected error when writing data: %s", err) } } }) ctx.Success("foo/bar", []byte("hijack it!")) }, } hijackedString := "foobar baz hijacked!!!" rw := &readWriter{} rw.r.WriteString("GET /foo HTTP/1.1\r\nHost: google.com\r\n\r\n") rw.r.WriteString(hijackedString) ch := make(chan error) go func() { ch <- s.ServeConn(rw) }() select { case err := <-ch: if err != nil { t.Fatalf("Unexpected error from serveConn: %s", err) } case <-time.After(100 * time.Millisecond): t.Fatalf("timeout") } br := bufio.NewReader(&rw.w) verifyResponse(t, br, StatusOK, "foo/bar", "hijack it!") close(hijackStartCh) select { case <-hijackStopCh: case <-time.After(100 * time.Millisecond): t.Fatalf("timeout") } data, err := ioutil.ReadAll(br) if err != nil { t.Fatalf("Unexpected error when reading remaining data: %s", err) } if string(data) != hijackedString { t.Fatalf("Unexpected data read after the first response %q. 
Expecting %q", data, hijackedString) } } func TestRequestCtxInit(t *testing.T) { var ctx RequestCtx var logger customLogger globalConnID = 0x123456 ctx.Init(&ctx.Request, zeroTCPAddr, &logger) ip := ctx.RemoteIP() if !ip.IsUnspecified() { t.Fatalf("unexpected ip for bare RequestCtx: %q. Expected 0.0.0.0", ip) } ctx.Logger().Printf("foo bar %d", 10) expectedLog := "#0012345700000000 - 0.0.0.0:0<->0.0.0.0:0 - GET http:/// - foo bar 10\n" if logger.out != expectedLog { t.Fatalf("Unexpected log output: %q. Expected %q", logger.out, expectedLog) } } func TestTimeoutHandlerSuccess(t *testing.T) { ln := fasthttputil.NewInmemoryListener() h := func(ctx *RequestCtx) { if string(ctx.Path()) == "/" { ctx.Success("aaa/bbb", []byte("real response")) } } s := &Server{ Handler: TimeoutHandler(h, 10*time.Second, "timeout!!!"), } serverCh := make(chan struct{}) go func() { if err := s.Serve(ln); err != nil { t.Fatalf("unexepcted error: %s", err) } close(serverCh) }() concurrency := 20 clientCh := make(chan struct{}, concurrency) for i := 0; i < concurrency; i++ { go func() { conn, err := ln.Dial() if err != nil { t.Fatalf("unexepcted error: %s", err) } if _, err = conn.Write([]byte("GET / HTTP/1.1\r\nHost: google.com\r\n\r\n")); err != nil { t.Fatalf("unexpected error: %s", err) } br := bufio.NewReader(conn) verifyResponse(t, br, StatusOK, "aaa/bbb", "real response") clientCh <- struct{}{} }() } for i := 0; i < concurrency; i++ { select { case <-clientCh: case <-time.After(time.Second): t.Fatalf("timeout") } } if err := ln.Close(); err != nil { t.Fatalf("unexpected error: %s", err) } select { case <-serverCh: case <-time.After(time.Second): t.Fatalf("timeout") } } func TestTimeoutHandlerTimeout(t *testing.T) { ln := fasthttputil.NewInmemoryListener() readyCh := make(chan struct{}) doneCh := make(chan struct{}) h := func(ctx *RequestCtx) { ctx.Success("aaa/bbb", []byte("real response")) <-readyCh doneCh <- struct{}{} } s := &Server{ Handler: TimeoutHandler(h, 20*time.Millisecond, "timeout!!!"), } serverCh := make(chan struct{}) go func() { if err := s.Serve(ln); err != nil { t.Fatalf("unexepcted error: %s", err) } close(serverCh) }() concurrency := 20 clientCh := make(chan struct{}, concurrency) for i := 0; i < concurrency; i++ { go func() { conn, err := ln.Dial() if err != nil { t.Fatalf("unexepcted error: %s", err) } if _, err = conn.Write([]byte("GET / HTTP/1.1\r\nHost: google.com\r\n\r\n")); err != nil { t.Fatalf("unexpected error: %s", err) } br := bufio.NewReader(conn) verifyResponse(t, br, StatusRequestTimeout, string(defaultContentType), "timeout!!!") clientCh <- struct{}{} }() } for i := 0; i < concurrency; i++ { select { case <-clientCh: case <-time.After(time.Second): t.Fatalf("timeout") } } close(readyCh) for i := 0; i < concurrency; i++ { select { case <-doneCh: case <-time.After(time.Second): t.Fatalf("timeout") } } if err := ln.Close(); err != nil { t.Fatalf("unexpected error: %s", err) } select { case <-serverCh: case <-time.After(time.Second): t.Fatalf("timeout") } } func TestServerGetOnly(t *testing.T) { h := func(ctx *RequestCtx) { if !ctx.IsGet() { t.Fatalf("non-get request: %q", ctx.Method()) } ctx.Success("foo/bar", []byte("success")) } s := &Server{ Handler: h, GetOnly: true, } rw := &readWriter{} rw.r.WriteString("POST /foo HTTP/1.1\r\nHost: google.com\r\nContent-Length: 5\r\nContent-Type: aaa\r\n\r\n12345") ch := make(chan error) go func() { ch <- s.ServeConn(rw) }() select { case err := <-ch: if err == nil { t.Fatalf("expecting error") } if err != errGetOnly { t.Fatalf("Unexpected 
error from serveConn: %s. Expecting %s", err, errGetOnly) } case <-time.After(100 * time.Millisecond): t.Fatalf("timeout") } resp := rw.w.Bytes() if len(resp) > 0 { t.Fatalf("unexpected response %q. Expecting zero", resp) } } func TestServerTimeoutErrorWithResponse(t *testing.T) { s := &Server{ Handler: func(ctx *RequestCtx) { go func() { ctx.Success("aaa/bbb", []byte("xxxyyy")) }() var resp Response resp.SetStatusCode(123) resp.SetBodyString("foobar. Should be ignored") ctx.TimeoutErrorWithResponse(&resp) resp.SetStatusCode(456) resp.ResetBody() fmt.Fprintf(resp.BodyWriter(), "path=%s", ctx.Path()) resp.Header.SetContentType("foo/bar") ctx.TimeoutErrorWithResponse(&resp) }, } rw := &readWriter{} rw.r.WriteString("GET /foo HTTP/1.1\r\nHost: google.com\r\n\r\n") rw.r.WriteString("GET /bar HTTP/1.1\r\nHost: google.com\r\n\r\n") ch := make(chan error) go func() { ch <- s.ServeConn(rw) }() select { case err := <-ch: if err != nil { t.Fatalf("Unexpected error from serveConn: %s", err) } case <-time.After(100 * time.Millisecond): t.Fatalf("timeout") } br := bufio.NewReader(&rw.w) verifyResponse(t, br, 456, "foo/bar", "path=/foo") data, err := ioutil.ReadAll(br) if err != nil { t.Fatalf("Unexpected error when reading remaining data: %s", err) } if len(data) != 0 { t.Fatalf("Unexpected data read after the first response %q. Expecting %q", data, "") } } func TestServerTimeoutErrorWithCode(t *testing.T) { s := &Server{ Handler: func(ctx *RequestCtx) { go func() { ctx.Success("aaa/bbb", []byte("xxxyyy")) }() ctx.TimeoutErrorWithCode("should be ignored", 234) ctx.TimeoutErrorWithCode("stolen ctx", StatusBadRequest) }, } rw := &readWriter{} rw.r.WriteString("GET /foo HTTP/1.1\r\nHost: google.com\r\n\r\n") rw.r.WriteString("GET /foo HTTP/1.1\r\nHost: google.com\r\n\r\n") ch := make(chan error) go func() { ch <- s.ServeConn(rw) }() select { case err := <-ch: if err != nil { t.Fatalf("Unexpected error from serveConn: %s", err) } case <-time.After(100 * time.Millisecond): t.Fatalf("timeout") } br := bufio.NewReader(&rw.w) verifyResponse(t, br, StatusBadRequest, string(defaultContentType), "stolen ctx") data, err := ioutil.ReadAll(br) if err != nil { t.Fatalf("Unexpected error when reading remaining data: %s", err) } if len(data) != 0 { t.Fatalf("Unexpected data read after the first response %q. Expecting %q", data, "") } } func TestServerTimeoutError(t *testing.T) { s := &Server{ Handler: func(ctx *RequestCtx) { go func() { ctx.Success("aaa/bbb", []byte("xxxyyy")) }() ctx.TimeoutError("should be ignored") ctx.TimeoutError("stolen ctx") }, } rw := &readWriter{} rw.r.WriteString("GET /foo HTTP/1.1\r\nHost: google.com\r\n\r\n") rw.r.WriteString("GET /foo HTTP/1.1\r\nHost: google.com\r\n\r\n") ch := make(chan error) go func() { ch <- s.ServeConn(rw) }() select { case err := <-ch: if err != nil { t.Fatalf("Unexpected error from serveConn: %s", err) } case <-time.After(100 * time.Millisecond): t.Fatalf("timeout") } br := bufio.NewReader(&rw.w) verifyResponse(t, br, StatusRequestTimeout, string(defaultContentType), "stolen ctx") data, err := ioutil.ReadAll(br) if err != nil { t.Fatalf("Unexpected error when reading remaining data: %s", err) } if len(data) != 0 { t.Fatalf("Unexpected data read after the first response %q. 
Expecting %q", data, "") } } func TestServerMaxKeepaliveDuration(t *testing.T) { s := &Server{ Handler: func(ctx *RequestCtx) { time.Sleep(20 * time.Millisecond) }, MaxKeepaliveDuration: 10 * time.Millisecond, } rw := &readWriter{} rw.r.WriteString("GET /aaa HTTP/1.1\r\nHost: aa.com\r\n\r\n") rw.r.WriteString("GET /bbbb HTTP/1.1\r\nHost: bbb.com\r\n\r\n") ch := make(chan error) go func() { ch <- s.ServeConn(rw) }() select { case err := <-ch: if err != nil { t.Fatalf("Unexpected error from serveConn: %s", err) } case <-time.After(100 * time.Millisecond): t.Fatalf("timeout") } br := bufio.NewReader(&rw.w) var resp Response if err := resp.Read(br); err != nil { t.Fatalf("Unexpected error when parsing response: %s", err) } if !resp.ConnectionClose() { t.Fatalf("Response must have 'connection: close' header") } verifyResponseHeader(t, &resp.Header, 200, 0, string(defaultContentType)) data, err := ioutil.ReadAll(br) if err != nil { t.Fatalf("Unexpected error when reading remaining data: %s", err) } if len(data) != 0 { t.Fatalf("Unexpected data read after the first response %q. Expecting %q", data, "") } } func TestServerMaxRequestsPerConn(t *testing.T) { s := &Server{ Handler: func(ctx *RequestCtx) {}, MaxRequestsPerConn: 1, } rw := &readWriter{} rw.r.WriteString("GET /foo1 HTTP/1.1\r\nHost: google.com\r\n\r\n") rw.r.WriteString("GET /bar HTTP/1.1\r\nHost: aaa.com\r\n\r\n") ch := make(chan error) go func() { ch <- s.ServeConn(rw) }() select { case err := <-ch: if err != nil { t.Fatalf("Unexpected error from serveConn: %s", err) } case <-time.After(100 * time.Millisecond): t.Fatalf("timeout") } br := bufio.NewReader(&rw.w) var resp Response if err := resp.Read(br); err != nil { t.Fatalf("Unexpected error when parsing response: %s", err) } if !resp.ConnectionClose() { t.Fatalf("Response must have 'connection: close' header") } verifyResponseHeader(t, &resp.Header, 200, 0, string(defaultContentType)) data, err := ioutil.ReadAll(br) if err != nil { t.Fatalf("Unexpected error when reading remaining data: %s", err) } if len(data) != 0 { t.Fatalf("Unexpected data read after the first response %q. Expecting %q", data, "") } } func TestServerConnectionClose(t *testing.T) { s := &Server{ Handler: func(ctx *RequestCtx) { ctx.SetConnectionClose() }, } rw := &readWriter{} rw.r.WriteString("GET /foo1 HTTP/1.1\r\nHost: google.com\r\n\r\n") rw.r.WriteString("GET /must/be/ignored HTTP/1.1\r\nHost: aaa.com\r\n\r\n") ch := make(chan error) go func() { ch <- s.ServeConn(rw) }() select { case err := <-ch: if err != nil { t.Fatalf("Unexpected error from serveConn: %s", err) } case <-time.After(100 * time.Millisecond): t.Fatalf("timeout") } br := bufio.NewReader(&rw.w) var resp Response if err := resp.Read(br); err != nil { t.Fatalf("Unexpected error when parsing response: %s", err) } if !resp.ConnectionClose() { t.Fatalf("expecting Connection: close header") } data, err := ioutil.ReadAll(br) if err != nil { t.Fatalf("Unexpected error when reading remaining data: %s", err) } if len(data) != 0 { t.Fatalf("Unexpected data read after the first response %q. Expecting %q", data, "") } } func TestServerRequestNumAndTime(t *testing.T) { n := uint64(0) var connT time.Time s := &Server{ Handler: func(ctx *RequestCtx) { n++ if ctx.ConnRequestNum() != n { t.Fatalf("unexpected request number: %d. Expecting %d", ctx.ConnRequestNum(), n) } if connT.IsZero() { connT = ctx.ConnTime() } if ctx.ConnTime() != connT { t.Fatalf("unexpected serve conn time: %s. 
Expecting %s", ctx.ConnTime(), connT) } }, } rw := &readWriter{} rw.r.WriteString("GET /foo1 HTTP/1.1\r\nHost: google.com\r\n\r\n") rw.r.WriteString("GET /bar HTTP/1.1\r\nHost: google.com\r\n\r\n") rw.r.WriteString("GET /baz HTTP/1.1\r\nHost: google.com\r\n\r\n") ch := make(chan error) go func() { ch <- s.ServeConn(rw) }() select { case err := <-ch: if err != nil { t.Fatalf("Unexpected error from serveConn: %s", err) } case <-time.After(100 * time.Millisecond): t.Fatalf("timeout") } if n != 3 { t.Fatalf("unexpected number of requests served: %d. Expecting %d", n, 3) } br := bufio.NewReader(&rw.w) verifyResponse(t, br, 200, string(defaultContentType), "") } func TestServerEmptyResponse(t *testing.T) { s := &Server{ Handler: func(ctx *RequestCtx) { // do nothing :) }, } rw := &readWriter{} rw.r.WriteString("GET /foo1 HTTP/1.1\r\nHost: google.com\r\n\r\n") ch := make(chan error) go func() { ch <- s.ServeConn(rw) }() select { case err := <-ch: if err != nil { t.Fatalf("Unexpected error from serveConn: %s", err) } case <-time.After(100 * time.Millisecond): t.Fatalf("timeout") } br := bufio.NewReader(&rw.w) verifyResponse(t, br, 200, string(defaultContentType), "") } type customLogger struct { lock sync.Mutex out string } func (cl *customLogger) Printf(format string, args ...interface{}) { cl.lock.Lock() cl.out += fmt.Sprintf(format, args...)[6:] + "\n" cl.lock.Unlock() } func TestServerLogger(t *testing.T) { cl := &customLogger{} s := &Server{ Handler: func(ctx *RequestCtx) { logger := ctx.Logger() h := &ctx.Request.Header logger.Printf("begin") ctx.Success("text/html", []byte(fmt.Sprintf("requestURI=%s, body=%q, remoteAddr=%s", h.RequestURI(), ctx.Request.Body(), ctx.RemoteAddr()))) logger.Printf("end") }, Logger: cl, } rw := &readWriter{} rw.r.WriteString("GET /foo1 HTTP/1.1\r\nHost: google.com\r\n\r\n") rw.r.WriteString("POST /foo2 HTTP/1.1\r\nHost: aaa.com\r\nContent-Length: 5\r\nContent-Type: aa\r\n\r\nabcde") rwx := &readWriterRemoteAddr{ rw: rw, addr: &net.TCPAddr{ IP: []byte{1, 2, 3, 4}, Port: 8765, }, } globalConnID = 0 ch := make(chan error) go func() { ch <- s.ServeConn(rwx) }() select { case err := <-ch: if err != nil { t.Fatalf("Unexpected error from serveConn: %s", err) } case <-time.After(100 * time.Millisecond): t.Fatalf("timeout") } br := bufio.NewReader(&rw.w) verifyResponse(t, br, 200, "text/html", "requestURI=/foo1, body=\"\", remoteAddr=1.2.3.4:8765") verifyResponse(t, br, 200, "text/html", "requestURI=/foo2, body=\"abcde\", remoteAddr=1.2.3.4:8765") expectedLogOut := `#0000000100000001 - 1.2.3.4:8765<->1.2.3.4:8765 - GET http://google.com/foo1 - begin #0000000100000001 - 1.2.3.4:8765<->1.2.3.4:8765 - GET http://google.com/foo1 - end #0000000100000002 - 1.2.3.4:8765<->1.2.3.4:8765 - POST http://aaa.com/foo2 - begin #0000000100000002 - 1.2.3.4:8765<->1.2.3.4:8765 - POST http://aaa.com/foo2 - end ` if cl.out != expectedLogOut { t.Fatalf("Unexpected logger output: %q. 
Expected %q", cl.out, expectedLogOut) } } func TestServerRemoteAddr(t *testing.T) { s := &Server{ Handler: func(ctx *RequestCtx) { h := &ctx.Request.Header ctx.Success("text/html", []byte(fmt.Sprintf("requestURI=%s, remoteAddr=%s, remoteIP=%s", h.RequestURI(), ctx.RemoteAddr(), ctx.RemoteIP()))) }, } rw := &readWriter{} rw.r.WriteString("GET /foo1 HTTP/1.1\r\nHost: google.com\r\n\r\n") rwx := &readWriterRemoteAddr{ rw: rw, addr: &net.TCPAddr{ IP: []byte{1, 2, 3, 4}, Port: 8765, }, } ch := make(chan error) go func() { ch <- s.ServeConn(rwx) }() select { case err := <-ch: if err != nil { t.Fatalf("Unexpected error from serveConn: %s", err) } case <-time.After(100 * time.Millisecond): t.Fatalf("timeout") } br := bufio.NewReader(&rw.w) verifyResponse(t, br, 200, "text/html", "requestURI=/foo1, remoteAddr=1.2.3.4:8765, remoteIP=1.2.3.4") } type readWriterRemoteAddr struct { net.Conn rw io.ReadWriteCloser addr net.Addr } func (rw *readWriterRemoteAddr) Close() error { return rw.rw.Close() } func (rw *readWriterRemoteAddr) Read(b []byte) (int, error) { return rw.rw.Read(b) } func (rw *readWriterRemoteAddr) Write(b []byte) (int, error) { return rw.rw.Write(b) } func (rw *readWriterRemoteAddr) RemoteAddr() net.Addr { return rw.addr } func (rw *readWriterRemoteAddr) LocalAddr() net.Addr { return rw.addr } func TestServerConnError(t *testing.T) { s := &Server{ Handler: func(ctx *RequestCtx) { ctx.Error("foobar", 423) }, } rw := &readWriter{} rw.r.WriteString("GET /foo/bar?baz HTTP/1.1\r\nHost: google.com\r\n\r\n") ch := make(chan error) go func() { ch <- s.ServeConn(rw) }() select { case err := <-ch: if err != nil { t.Fatalf("Unexpected error from serveConn: %s", err) } case <-time.After(100 * time.Millisecond): t.Fatalf("timeout") } br := bufio.NewReader(&rw.w) var resp Response if err := resp.Read(br); err != nil { t.Fatalf("Unexpected error when reading response: %s", err) } if resp.Header.StatusCode() != 423 { t.Fatalf("Unexpected status code %d. Expected %d", resp.Header.StatusCode(), 423) } if resp.Header.ContentLength() != 6 { t.Fatalf("Unexpected Content-Length %d. Expected %d", resp.Header.ContentLength(), 6) } if !bytes.Equal(resp.Header.Peek("Content-Type"), defaultContentType) { t.Fatalf("Unexpected Content-Type %q. Expected %q", resp.Header.Peek("Content-Type"), defaultContentType) } if !bytes.Equal(resp.Body(), []byte("foobar")) { t.Fatalf("Unexpected body %q. 
Expected %q", resp.Body(), "foobar") } } func TestServeConnSingleRequest(t *testing.T) { s := &Server{ Handler: func(ctx *RequestCtx) { h := &ctx.Request.Header ctx.Success("aaa", []byte(fmt.Sprintf("requestURI=%s, host=%s", h.RequestURI(), h.Peek("Host")))) }, } rw := &readWriter{} rw.r.WriteString("GET /foo/bar?baz HTTP/1.1\r\nHost: google.com\r\n\r\n") ch := make(chan error) go func() { ch <- s.ServeConn(rw) }() select { case err := <-ch: if err != nil { t.Fatalf("Unexpected error from serveConn: %s", err) } case <-time.After(100 * time.Millisecond): t.Fatalf("timeout") } br := bufio.NewReader(&rw.w) verifyResponse(t, br, 200, "aaa", "requestURI=/foo/bar?baz, host=google.com") } func TestServeConnMultiRequests(t *testing.T) { s := &Server{ Handler: func(ctx *RequestCtx) { h := &ctx.Request.Header ctx.Success("aaa", []byte(fmt.Sprintf("requestURI=%s, host=%s", h.RequestURI(), h.Peek("Host")))) }, } rw := &readWriter{} rw.r.WriteString("GET /foo/bar?baz HTTP/1.1\r\nHost: google.com\r\n\r\nGET /abc HTTP/1.1\r\nHost: foobar.com\r\n\r\n") ch := make(chan error) go func() { ch <- s.ServeConn(rw) }() select { case err := <-ch: if err != nil { t.Fatalf("Unexpected error from serveConn: %s", err) } case <-time.After(100 * time.Millisecond): t.Fatalf("timeout") } br := bufio.NewReader(&rw.w) verifyResponse(t, br, 200, "aaa", "requestURI=/foo/bar?baz, host=google.com") verifyResponse(t, br, 200, "aaa", "requestURI=/abc, host=foobar.com") } func verifyResponse(t *testing.T, r *bufio.Reader, expectedStatusCode int, expectedContentType, expectedBody string) { var resp Response if err := resp.Read(r); err != nil { t.Fatalf("Unexpected error when parsing response: %s", err) } if !bytes.Equal(resp.Body(), []byte(expectedBody)) { t.Fatalf("Unexpected body %q. Expected %q", resp.Body(), []byte(expectedBody)) } verifyResponseHeader(t, &resp.Header, expectedStatusCode, len(resp.Body()), expectedContentType) } type readWriter struct { net.Conn r bytes.Buffer w bytes.Buffer } func (rw *readWriter) Close() error { return nil } func (rw *readWriter) Read(b []byte) (int, error) { return rw.r.Read(b) } func (rw *readWriter) Write(b []byte) (int, error) { return rw.w.Write(b) } func (rw *readWriter) RemoteAddr() net.Addr { return zeroTCPAddr } func (rw *readWriter) LocalAddr() net.Addr { return zeroTCPAddr } func (rw *readWriter) SetReadDeadline(t time.Time) error { return nil } func (rw *readWriter) SetWriteDeadline(t time.Time) error { return nil } golang-github-valyala-fasthttp-20160617/server_timing_test.go000066400000000000000000000257061273074646000242610ustar00rootroot00000000000000package fasthttp import ( "bytes" "fmt" "io" "io/ioutil" "net" "net/http" "runtime" "sync" "sync/atomic" "testing" "time" ) var defaultClientsCount = runtime.NumCPU() func BenchmarkServerGet1ReqPerConn(b *testing.B) { benchmarkServerGet(b, defaultClientsCount, 1) } func BenchmarkServerGet2ReqPerConn(b *testing.B) { benchmarkServerGet(b, defaultClientsCount, 2) } func BenchmarkServerGet10ReqPerConn(b *testing.B) { benchmarkServerGet(b, defaultClientsCount, 10) } func BenchmarkServerGet10KReqPerConn(b *testing.B) { benchmarkServerGet(b, defaultClientsCount, 10000) } func BenchmarkNetHTTPServerGet1ReqPerConn(b *testing.B) { benchmarkNetHTTPServerGet(b, defaultClientsCount, 1) } func BenchmarkNetHTTPServerGet2ReqPerConn(b *testing.B) { benchmarkNetHTTPServerGet(b, defaultClientsCount, 2) } func BenchmarkNetHTTPServerGet10ReqPerConn(b *testing.B) { benchmarkNetHTTPServerGet(b, defaultClientsCount, 10) } func 
BenchmarkNetHTTPServerGet10KReqPerConn(b *testing.B) { benchmarkNetHTTPServerGet(b, defaultClientsCount, 10000) } func BenchmarkServerPost1ReqPerConn(b *testing.B) { benchmarkServerPost(b, defaultClientsCount, 1) } func BenchmarkServerPost2ReqPerConn(b *testing.B) { benchmarkServerPost(b, defaultClientsCount, 2) } func BenchmarkServerPost10ReqPerConn(b *testing.B) { benchmarkServerPost(b, defaultClientsCount, 10) } func BenchmarkServerPost10KReqPerConn(b *testing.B) { benchmarkServerPost(b, defaultClientsCount, 10000) } func BenchmarkNetHTTPServerPost1ReqPerConn(b *testing.B) { benchmarkNetHTTPServerPost(b, defaultClientsCount, 1) } func BenchmarkNetHTTPServerPost2ReqPerConn(b *testing.B) { benchmarkNetHTTPServerPost(b, defaultClientsCount, 2) } func BenchmarkNetHTTPServerPost10ReqPerConn(b *testing.B) { benchmarkNetHTTPServerPost(b, defaultClientsCount, 10) } func BenchmarkNetHTTPServerPost10KReqPerConn(b *testing.B) { benchmarkNetHTTPServerPost(b, defaultClientsCount, 10000) } func BenchmarkServerGet1ReqPerConn10KClients(b *testing.B) { benchmarkServerGet(b, 10000, 1) } func BenchmarkServerGet2ReqPerConn10KClients(b *testing.B) { benchmarkServerGet(b, 10000, 2) } func BenchmarkServerGet10ReqPerConn10KClients(b *testing.B) { benchmarkServerGet(b, 10000, 10) } func BenchmarkServerGet100ReqPerConn10KClients(b *testing.B) { benchmarkServerGet(b, 10000, 100) } func BenchmarkNetHTTPServerGet1ReqPerConn10KClients(b *testing.B) { benchmarkNetHTTPServerGet(b, 10000, 1) } func BenchmarkNetHTTPServerGet2ReqPerConn10KClients(b *testing.B) { benchmarkNetHTTPServerGet(b, 10000, 2) } func BenchmarkNetHTTPServerGet10ReqPerConn10KClients(b *testing.B) { benchmarkNetHTTPServerGet(b, 10000, 10) } func BenchmarkNetHTTPServerGet100ReqPerConn10KClients(b *testing.B) { benchmarkNetHTTPServerGet(b, 10000, 100) } func BenchmarkServerHijack(b *testing.B) { clientsCount := 1000 requestsPerConn := 10000 ch := make(chan struct{}, b.N) responseBody := []byte("123") s := &Server{ Handler: func(ctx *RequestCtx) { ctx.Hijack(func(c net.Conn) { // emulate server loop :) err := ServeConn(c, func(ctx *RequestCtx) { ctx.Success("foobar", responseBody) registerServedRequest(b, ch) }) if err != nil { b.Fatalf("error when serving connection") } }) ctx.Success("foobar", responseBody) registerServedRequest(b, ch) }, Concurrency: 16 * clientsCount, } req := "GET /foo HTTP/1.1\r\nHost: google.com\r\n\r\n" benchmarkServer(b, s, clientsCount, requestsPerConn, req) verifyRequestsServed(b, ch) } func BenchmarkServerMaxConnsPerIP(b *testing.B) { clientsCount := 1000 requestsPerConn := 10 ch := make(chan struct{}, b.N) responseBody := []byte("123") s := &Server{ Handler: func(ctx *RequestCtx) { ctx.Success("foobar", responseBody) registerServedRequest(b, ch) }, MaxConnsPerIP: clientsCount * 2, Concurrency: 16 * clientsCount, } req := "GET /foo HTTP/1.1\r\nHost: google.com\r\n\r\n" benchmarkServer(b, s, clientsCount, requestsPerConn, req) verifyRequestsServed(b, ch) } func BenchmarkServerTimeoutError(b *testing.B) { clientsCount := 10 requestsPerConn := 1 ch := make(chan struct{}, b.N) n := uint32(0) responseBody := []byte("123") s := &Server{ Handler: func(ctx *RequestCtx) { if atomic.AddUint32(&n, 1)&7 == 0 { ctx.TimeoutError("xxx") go func() { ctx.Success("foobar", responseBody) }() } else { ctx.Success("foobar", responseBody) } registerServedRequest(b, ch) }, Concurrency: 16 * clientsCount, } req := "GET /foo HTTP/1.1\r\nHost: google.com\r\n\r\n" benchmarkServer(b, s, clientsCount, requestsPerConn, req) verifyRequestsServed(b, ch) } 
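// fakeServerConn and fakeListener below implement net.Conn and net.Listener
// entirely in memory for the benchmarks in this file: Accept hands out
// pre-created connections, each connection's Read replays the canned request
// bytes for the configured number of requests, Write discards the response,
// and Close returns the connection to the listener's pool. This keeps the
// benchmarks focused on request parsing and handler dispatch rather than on
// real socket I/O.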
type fakeServerConn struct { net.TCPConn ln *fakeListener requestsCount int pos int closed uint32 } func (c *fakeServerConn) Read(b []byte) (int, error) { nn := 0 reqLen := len(c.ln.request) for len(b) > 0 { if c.requestsCount == 0 { if nn == 0 { return 0, io.EOF } return nn, nil } pos := c.pos % reqLen n := copy(b, c.ln.request[pos:]) b = b[n:] nn += n c.pos += n if n+pos == reqLen { c.requestsCount-- } } return nn, nil } func (c *fakeServerConn) Write(b []byte) (int, error) { return len(b), nil } var fakeAddr = net.TCPAddr{ IP: []byte{1, 2, 3, 4}, Port: 12345, } func (c *fakeServerConn) RemoteAddr() net.Addr { return &fakeAddr } func (c *fakeServerConn) Close() error { if atomic.AddUint32(&c.closed, 1) == 1 { c.ln.ch <- c } return nil } func (c *fakeServerConn) SetReadDeadline(t time.Time) error { return nil } func (c *fakeServerConn) SetWriteDeadline(t time.Time) error { return nil } type fakeListener struct { lock sync.Mutex requestsCount int requestsPerConn int request []byte ch chan *fakeServerConn done chan struct{} closed bool } func (ln *fakeListener) Accept() (net.Conn, error) { ln.lock.Lock() if ln.requestsCount == 0 { ln.lock.Unlock() for len(ln.ch) < cap(ln.ch) { time.Sleep(10 * time.Millisecond) } ln.lock.Lock() if !ln.closed { close(ln.done) ln.closed = true } ln.lock.Unlock() return nil, io.EOF } requestsCount := ln.requestsPerConn if requestsCount > ln.requestsCount { requestsCount = ln.requestsCount } ln.requestsCount -= requestsCount ln.lock.Unlock() c := <-ln.ch c.requestsCount = requestsCount c.closed = 0 c.pos = 0 return c, nil } func (ln *fakeListener) Close() error { return nil } func (ln *fakeListener) Addr() net.Addr { return &fakeAddr } func newFakeListener(requestsCount, clientsCount, requestsPerConn int, request string) *fakeListener { ln := &fakeListener{ requestsCount: requestsCount, requestsPerConn: requestsPerConn, request: []byte(request), ch: make(chan *fakeServerConn, clientsCount), done: make(chan struct{}), } for i := 0; i < clientsCount; i++ { ln.ch <- &fakeServerConn{ ln: ln, } } return ln } var ( fakeResponse = []byte("Hello, world!") getRequest = "GET /foobar?baz HTTP/1.1\r\nHost: google.com\r\nUser-Agent: aaa/bbb/ccc/ddd/eee Firefox Chrome MSIE Opera\r\n" + "Referer: http://xxx.com/aaa?bbb=ccc\r\nCookie: foo=bar; baz=baraz; aa=aakslsdweriwereowriewroire\r\n\r\n" postRequest = fmt.Sprintf("POST /foobar?baz HTTP/1.1\r\nHost: google.com\r\nContent-Type: foo/bar\r\nContent-Length: %d\r\n"+ "User-Agent: Opera Chrome MSIE Firefox and other/1.2.34\r\nReferer: http://google.com/aaaa/bbb/ccc\r\n"+ "Cookie: foo=bar; baz=baraz; aa=aakslsdweriwereowriewroire\r\n\r\n%s", len(fakeResponse), fakeResponse) ) func benchmarkServerGet(b *testing.B, clientsCount, requestsPerConn int) { ch := make(chan struct{}, b.N) s := &Server{ Handler: func(ctx *RequestCtx) { if !ctx.IsGet() { b.Fatalf("Unexpected request method: %s", ctx.Method()) } ctx.Success("text/plain", fakeResponse) if requestsPerConn == 1 { ctx.SetConnectionClose() } registerServedRequest(b, ch) }, Concurrency: 16 * clientsCount, } benchmarkServer(b, s, clientsCount, requestsPerConn, getRequest) verifyRequestsServed(b, ch) } func benchmarkNetHTTPServerGet(b *testing.B, clientsCount, requestsPerConn int) { ch := make(chan struct{}, b.N) s := &http.Server{ Handler: http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { if req.Method != "GET" { b.Fatalf("Unexpected request method: %s", req.Method) } h := w.Header() h.Set("Content-Type", "text/plain") if requestsPerConn == 1 { h.Set("Connection", 
"close") } w.Write(fakeResponse) registerServedRequest(b, ch) }), } benchmarkServer(b, s, clientsCount, requestsPerConn, getRequest) verifyRequestsServed(b, ch) } func benchmarkServerPost(b *testing.B, clientsCount, requestsPerConn int) { ch := make(chan struct{}, b.N) s := &Server{ Handler: func(ctx *RequestCtx) { if !ctx.IsPost() { b.Fatalf("Unexpected request method: %s", ctx.Method()) } body := ctx.Request.Body() if !bytes.Equal(body, fakeResponse) { b.Fatalf("Unexpected body %q. Expected %q", body, fakeResponse) } ctx.Success("text/plain", body) if requestsPerConn == 1 { ctx.SetConnectionClose() } registerServedRequest(b, ch) }, Concurrency: 16 * clientsCount, } benchmarkServer(b, s, clientsCount, requestsPerConn, postRequest) verifyRequestsServed(b, ch) } func benchmarkNetHTTPServerPost(b *testing.B, clientsCount, requestsPerConn int) { ch := make(chan struct{}, b.N) s := &http.Server{ Handler: http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { if req.Method != "POST" { b.Fatalf("Unexpected request method: %s", req.Method) } body, err := ioutil.ReadAll(req.Body) if err != nil { b.Fatalf("Unexpected error: %s", err) } req.Body.Close() if !bytes.Equal(body, fakeResponse) { b.Fatalf("Unexpected body %q. Expected %q", body, fakeResponse) } h := w.Header() h.Set("Content-Type", "text/plain") if requestsPerConn == 1 { h.Set("Connection", "close") } w.Write(body) registerServedRequest(b, ch) }), } benchmarkServer(b, s, clientsCount, requestsPerConn, postRequest) verifyRequestsServed(b, ch) } func registerServedRequest(b *testing.B, ch chan<- struct{}) { select { case ch <- struct{}{}: default: b.Fatalf("More than %d requests served", cap(ch)) } } func verifyRequestsServed(b *testing.B, ch <-chan struct{}) { requestsServed := 0 for len(ch) > 0 { <-ch requestsServed++ } requestsSent := b.N for requestsServed < requestsSent { select { case <-ch: requestsServed++ case <-time.After(100 * time.Millisecond): b.Fatalf("Unexpected number of requests served %d. 
Expected %d", requestsServed, requestsSent) } } } type realServer interface { Serve(ln net.Listener) error } func benchmarkServer(b *testing.B, s realServer, clientsCount, requestsPerConn int, request string) { ln := newFakeListener(b.N, clientsCount, requestsPerConn, request) ch := make(chan struct{}) go func() { s.Serve(ln) ch <- struct{}{} }() <-ln.done select { case <-ch: case <-time.After(10 * time.Second): b.Fatalf("Server.Serve() didn't stop") } } golang-github-valyala-fasthttp-20160617/ssl-cert-snakeoil.key000066400000000000000000000032501273074646000240550ustar00rootroot00000000000000-----BEGIN PRIVATE KEY----- MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQD4IQusAs8PJdnG 3mURt/AXtgC+ceqLOatJ49JJE1VPTkMAy+oE1f1XvkMrYsHqmDf6GWVzgVXryL4U wq2/nJSm56ddhN55nI8oSN3dtywUB8/ShelEN73nlN77PeD9tl6NksPwWaKrqxq0 FlabRPZSQCfmgZbhDV8Sa8mfCkFU0G0lit6kLGceCKMvmW+9Bz7ebsYmVdmVMxmf IJStFD44lWFTdUc65WISKEdW2ELcUefb0zOLw+0PCbXFGJH5x5ktksW8+BBk2Hkg GeQRL/qPCccthbScO0VgNj3zJ3ZZL0ObSDAbvNDG85joeNjDNq5DT/BAZ0bOSbEF sh+f9BAzAgMBAAECggEBAJWv2cq7Jw6MVwSRxYca38xuD6TUNBopgBvjREixURW2 sNUaLuMb9Omp7fuOaE2N5rcJ+xnjPGIxh/oeN5MQctz9gwn3zf6vY+15h97pUb4D uGvYPRDaT8YVGS+X9NMZ4ZCmqW2lpWzKnCFoGHcy8yZLbcaxBsRdvKzwOYGoPiFb K2QuhXZ/1UPmqK9i2DFKtj40X6vBszTNboFxOVpXrPu0FJwLVSDf2hSZ4fMM0DH3 YqwKcYf5te+hxGKgrqRA3tn0NCWii0in6QIwXMC+kMw1ebg/tZKqyDLMNptAK8J+ DVw9m5X1seUHS5ehU/g2jrQrtK5WYn7MrFK4lBzlRwECgYEA/d1TeANYECDWRRDk B0aaRZs87Rwl/J9PsvbsKvtU/bX+OfSOUjOa9iQBqn0LmU8GqusEET/QVUfocVwV Bggf/5qDLxz100Rj0ags/yE/kNr0Bb31kkkKHFMnCT06YasR7qKllwrAlPJvQv9x IzBKq+T/Dx08Wep9bCRSFhzRCnsCgYEA+jdeZXTDr/Vz+D2B3nAw1frqYFfGnEVY wqmoK3VXMDkGuxsloO2rN+SyiUo3JNiQNPDub/t7175GH5pmKtZOlftePANsUjBj wZ1D0rI5Bxu/71ibIUYIRVmXsTEQkh/ozoh3jXCZ9+bLgYiYx7789IUZZSokFQ3D FICUT9KJ36kCgYAGoq9Y1rWJjmIrYfqj2guUQC+CfxbbGIrrwZqAsRsSmpwvhZ3m tiSZxG0quKQB+NfSxdvQW5ulbwC7Xc3K35F+i9pb8+TVBdeaFkw+yu6vaZmxQLrX fQM/pEjD7A7HmMIaO7QaU5SfEAsqdCTP56Y8AftMuNXn/8IRfo2KuGwaWwKBgFpU ILzJoVdlad9E/Rw7LjYhZfkv1uBVXIyxyKcfrkEXZSmozDXDdxsvcZCEfVHM6Ipk K/+7LuMcqp4AFEAEq8wTOdq6daFaHLkpt/FZK6M4TlruhtpFOPkoNc3e45eM83OT 6mziKINJC1CQ6m65sQHpBtjxlKMRG8rL/D6wx9s5AoGBAMRlqNPMwglT3hvDmsAt 9Lf9pdmhERUlHhD8bj8mDaBj2Aqv7f6VRJaYZqP403pKKQexuqcn80mtjkSAPFkN Cj7BVt/RXm5uoxDTnfi26RF9F6yNDEJ7UU9+peBr99aazF/fTgW/1GcMkQnum8uV c257YgaWmjK9uB0Y2r2VxS0G -----END PRIVATE KEY----- golang-github-valyala-fasthttp-20160617/ssl-cert-snakeoil.pem000066400000000000000000000017551273074646000240560ustar00rootroot00000000000000-----BEGIN CERTIFICATE----- MIICujCCAaKgAwIBAgIJAMbXnKZ/cikUMA0GCSqGSIb3DQEBCwUAMBUxEzARBgNV BAMTCnVidW50dS5uYW4wHhcNMTUwMjA0MDgwMTM5WhcNMjUwMjAxMDgwMTM5WjAV MRMwEQYDVQQDEwp1YnVudHUubmFuMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIB CgKCAQEA+CELrALPDyXZxt5lEbfwF7YAvnHqizmrSePSSRNVT05DAMvqBNX9V75D K2LB6pg3+hllc4FV68i+FMKtv5yUpuenXYTeeZyPKEjd3bcsFAfP0oXpRDe955Te +z3g/bZejZLD8Fmiq6satBZWm0T2UkAn5oGW4Q1fEmvJnwpBVNBtJYrepCxnHgij L5lvvQc+3m7GJlXZlTMZnyCUrRQ+OJVhU3VHOuViEihHVthC3FHn29Mzi8PtDwm1 xRiR+ceZLZLFvPgQZNh5IBnkES/6jwnHLYW0nDtFYDY98yd2WS9Dm0gwG7zQxvOY 6HjYwzauQ0/wQGdGzkmxBbIfn/QQMwIDAQABow0wCzAJBgNVHRMEAjAAMA0GCSqG SIb3DQEBCwUAA4IBAQBQjKm/4KN/iTgXbLTL3i7zaxYXFLXsnT1tF+ay4VA8aj98 L3JwRTciZ3A5iy/W4VSCt3eASwOaPWHKqDBB5RTtL73LoAqsWmO3APOGQAbixcQ2 45GXi05OKeyiYRi1Nvq7Unv9jUkRDHUYVPZVSAjCpsXzPhFkmZoTRxmx5l0ZF7Li K91lI5h+eFq0dwZwrmlPambyh1vQUi70VHv8DNToVU29kel7YLbxGbuqETfhrcy6 X+Mha6RYITkAn5FqsZcKMsc9eYGEF4l3XV+oS7q6xfTxktYJMFTI18J0lQ2Lv/CI whdMnYGntDQBE/iFCrJEGNsKGc38796GBOb5j+zd -----END CERTIFICATE----- 
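The snakeoil key and certificate above are test-only, self-signed credentials. A minimal sketch of serving them with fasthttp is shown below; it assumes the package-level ListenAndServeTLS helper and file paths relative to the repository root, and is not part of the package itself.

package main

import (
	"log"

	"github.com/valyala/fasthttp"
)

func main() {
	handler := func(ctx *fasthttp.RequestCtx) {
		ctx.Success("text/plain", []byte("hello over TLS"))
	}
	// ssl-cert-snakeoil.pem / ssl-cert-snakeoil.key are the self-signed test
	// files above; real deployments should use a proper certificate.
	if err := fasthttp.ListenAndServeTLS(":8443", "ssl-cert-snakeoil.pem", "ssl-cert-snakeoil.key", handler); err != nil {
		log.Fatalf("error in ListenAndServeTLS: %s", err)
	}
}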
golang-github-valyala-fasthttp-20160617/status.go000066400000000000000000000116271273074646000216650ustar00rootroot00000000000000
package fasthttp

import (
	"fmt"
	"sync/atomic"
)

// HTTP status codes were stolen from net/http.
const (
	StatusContinue = 100
	StatusSwitchingProtocols = 101
	StatusOK = 200
	StatusCreated = 201
	StatusAccepted = 202
	StatusNonAuthoritativeInfo = 203
	StatusNoContent = 204
	StatusResetContent = 205
	StatusPartialContent = 206
	StatusMultipleChoices = 300
	StatusMovedPermanently = 301
	StatusFound = 302
	StatusSeeOther = 303
	StatusNotModified = 304
	StatusUseProxy = 305
	StatusTemporaryRedirect = 307
	StatusBadRequest = 400
	StatusUnauthorized = 401
	StatusPaymentRequired = 402
	StatusForbidden = 403
	StatusNotFound = 404
	StatusMethodNotAllowed = 405
	StatusNotAcceptable = 406
	StatusProxyAuthRequired = 407
	StatusRequestTimeout = 408
	StatusConflict = 409
	StatusGone = 410
	StatusLengthRequired = 411
	StatusPreconditionFailed = 412
	StatusRequestEntityTooLarge = 413
	StatusRequestURITooLong = 414
	StatusUnsupportedMediaType = 415
	StatusRequestedRangeNotSatisfiable = 416
	StatusExpectationFailed = 417
	StatusTeapot = 418
	StatusPreconditionRequired = 428
	StatusTooManyRequests = 429
	StatusRequestHeaderFieldsTooLarge = 431
	StatusInternalServerError = 500
	StatusNotImplemented = 501
	StatusBadGateway = 502
	StatusServiceUnavailable = 503
	StatusGatewayTimeout = 504
	StatusHTTPVersionNotSupported = 505
	StatusNetworkAuthenticationRequired = 511
)

var (
	statusLines atomic.Value

	statusMessages = map[int]string{
		StatusContinue: "Continue",
		StatusSwitchingProtocols: "SwitchingProtocols",
		StatusOK: "OK",
		StatusCreated: "Created",
		StatusAccepted: "Accepted",
		StatusNonAuthoritativeInfo: "Non-Authoritative Info",
		StatusNoContent: "No Content",
		StatusResetContent: "Reset Content",
		StatusPartialContent: "Partial Content",
		StatusMultipleChoices: "Multiple Choices",
		StatusMovedPermanently: "Moved Permanently",
		StatusFound: "Found",
		StatusSeeOther: "See Other",
		StatusNotModified: "Not Modified",
		StatusUseProxy: "Use Proxy",
		StatusTemporaryRedirect: "Temporary Redirect",
		StatusBadRequest: "Bad Request",
		StatusUnauthorized: "Unauthorized",
		StatusPaymentRequired: "Payment Required",
		StatusForbidden: "Forbidden",
		StatusNotFound: "Not Found",
		StatusMethodNotAllowed: "Method Not Allowed",
		StatusNotAcceptable: "Not Acceptable",
		StatusProxyAuthRequired: "Proxy Auth Required",
		StatusRequestTimeout: "Request Timeout",
		StatusConflict: "Conflict",
		StatusGone: "Gone",
		StatusLengthRequired: "Length Required",
		StatusPreconditionFailed: "Precondition Failed",
		StatusRequestEntityTooLarge: "Request Entity Too Large",
		StatusRequestURITooLong: "Request URI Too Long",
		StatusUnsupportedMediaType: "Unsupported Media Type",
		StatusRequestedRangeNotSatisfiable: "Requested Range Not Satisfiable",
		StatusExpectationFailed: "Expectation Failed",
		StatusTeapot: "Teapot",
		StatusPreconditionRequired: "Precondition Required",
		StatusTooManyRequests: "Too Many Requests",
		StatusRequestHeaderFieldsTooLarge: "Request HeaderFields Too Large",
		StatusInternalServerError: "Internal Server Error",
		StatusNotImplemented: "Not Implemented",
		StatusBadGateway: "Bad Gateway",
		StatusServiceUnavailable: "Service Unavailable",
		StatusGatewayTimeout: "Gateway Timeout",
		StatusHTTPVersionNotSupported: "HTTP Version Not Supported",
		StatusNetworkAuthenticationRequired: "Network Authentication Required",
	}
)

// StatusMessage returns HTTP status message for the given status code.
func StatusMessage(statusCode int) string { s := statusMessages[statusCode] if s == "" { s = "Unknown Status Code" } return s } func init() { statusLines.Store(make(map[int][]byte)) } func statusLine(statusCode int) []byte { m := statusLines.Load().(map[int][]byte) h := m[statusCode] if h != nil { return h } statusText := StatusMessage(statusCode) h = []byte(fmt.Sprintf("HTTP/1.1 %d %s\r\n", statusCode, statusText)) newM := make(map[int][]byte, len(m)+1) for k, v := range m { newM[k] = v } newM[statusCode] = h statusLines.Store(newM) return h } golang-github-valyala-fasthttp-20160617/stream.go000066400000000000000000000024741273074646000216350ustar00rootroot00000000000000package fasthttp import ( "bufio" "io" "runtime/debug" "sync" "github.com/valyala/fasthttp/fasthttputil" ) // StreamWriter must write data to w. // // Usually StreamWriter writes data to w in a loop (aka 'data streaming'). // // StreamWriter must return immediately if w returns error. // // Since the written data is buffered, do not forget calling w.Flush // when the data must be propagated to reader. type StreamWriter func(w *bufio.Writer) // NewStreamReader returns a reader, which replays all the data generated by sw. // // The returned reader may be passed to Response.SetBodyStream. // // Close must be called on the returned reader after all the required data // has been read. Otherwise goroutine leak may occur. // // See also Response.SetBodyStreamWriter. func NewStreamReader(sw StreamWriter) io.ReadCloser { pc := fasthttputil.NewPipeConns() pw := pc.Conn1() pr := pc.Conn2() var bw *bufio.Writer v := streamWriterBufPool.Get() if v == nil { bw = bufio.NewWriter(pw) } else { bw = v.(*bufio.Writer) bw.Reset(pw) } go func() { defer func() { if r := recover(); r != nil { defaultLogger.Printf("panic in StreamWriter: %s\nStack trace:\n%s", r, debug.Stack()) } }() sw(bw) bw.Flush() pw.Close() streamWriterBufPool.Put(bw) }() return pr } var streamWriterBufPool sync.Pool golang-github-valyala-fasthttp-20160617/stream_test.go000066400000000000000000000041731273074646000226720ustar00rootroot00000000000000package fasthttp import ( "bufio" "fmt" "io" "io/ioutil" "testing" "time" ) func TestNewStreamReader(t *testing.T) { ch := make(chan struct{}) r := NewStreamReader(func(w *bufio.Writer) { fmt.Fprintf(w, "Hello, world\n") fmt.Fprintf(w, "Line #2\n") close(ch) }) data, err := ioutil.ReadAll(r) if err != nil { t.Fatalf("unexpected error: %s", err) } expectedData := "Hello, world\nLine #2\n" if string(data) != expectedData { t.Fatalf("unexpected data %q. Expecting %q", data, expectedData) } if err = r.Close(); err != nil { t.Fatalf("unexpected error") } select { case <-ch: case <-time.After(time.Second): t.Fatalf("timeout") } } func TestStreamReaderClose(t *testing.T) { firstLine := "the first line must pass" ch := make(chan error, 1) r := NewStreamReader(func(w *bufio.Writer) { fmt.Fprintf(w, "%s", firstLine) if err := w.Flush(); err != nil { ch <- fmt.Errorf("unexpected error on first flush: %s", err) return } data := createFixedBody(4000) for i := 0; i < 100; i++ { w.Write(data) } if err := w.Flush(); err == nil { ch <- fmt.Errorf("expecting error on the second flush") } ch <- nil }) buf := make([]byte, len(firstLine)) n, err := io.ReadFull(r, buf) if err != nil { t.Fatalf("unexpected error: %s", err) } if n != len(buf) { t.Fatalf("unexpected number of bytes read: %d. Expecting %d", n, len(buf)) } if string(buf) != firstLine { t.Fatalf("unexpected result: %q. 
Expecting %q", buf, firstLine) } if err := r.Close(); err != nil { t.Fatalf("unexpected error: %s", err) } select { case err := <-ch: if err != nil { t.Fatalf("error returned from stream reader: %s", err) } case <-time.After(time.Second): t.Fatalf("timeout when waiting for stream reader") } // read trailing data go func() { if _, err := ioutil.ReadAll(r); err != nil { ch <- fmt.Errorf("unexpected error when reading trailing data: %s", err) return } ch <- nil }() select { case err := <-ch: if err != nil { t.Fatalf("error returned when reading tail data: %s", err) } case <-time.After(time.Second): t.Fatalf("timeout when reading tail data") } } golang-github-valyala-fasthttp-20160617/stream_timing_test.go000066400000000000000000000025271273074646000242420ustar00rootroot00000000000000package fasthttp import ( "bufio" "io" "testing" "time" ) func BenchmarkStreamReader1(b *testing.B) { benchmarkStreamReader(b, 1) } func BenchmarkStreamReader10(b *testing.B) { benchmarkStreamReader(b, 10) } func BenchmarkStreamReader100(b *testing.B) { benchmarkStreamReader(b, 100) } func BenchmarkStreamReader1K(b *testing.B) { benchmarkStreamReader(b, 1000) } func BenchmarkStreamReader10K(b *testing.B) { benchmarkStreamReader(b, 10000) } func benchmarkStreamReader(b *testing.B, size int) { src := createFixedBody(size) b.SetBytes(int64(size)) b.RunParallel(func(pb *testing.PB) { dst := make([]byte, size) ch := make(chan error, 1) sr := NewStreamReader(func(w *bufio.Writer) { for pb.Next() { if _, err := w.Write(src); err != nil { ch <- err return } if err := w.Flush(); err != nil { ch <- err return } } ch <- nil }) for { if _, err := sr.Read(dst); err != nil { if err == io.EOF { break } b.Fatalf("unexpected error when reading from stream reader: %s", err) } } if err := sr.Close(); err != nil { b.Fatalf("unexpected error when closing stream reader: %s", err) } select { case err := <-ch: if err != nil { b.Fatalf("unexpected error from stream reader: %s", err) } case <-time.After(time.Second): b.Fatalf("timeout") } }) } golang-github-valyala-fasthttp-20160617/strings.go000066400000000000000000000047361273074646000220360ustar00rootroot00000000000000package fasthttp var ( defaultServerName = []byte("fasthttp") defaultUserAgent = []byte("fasthttp") defaultContentType = []byte("text/plain; charset=utf-8") ) var ( strSlash = []byte("/") strSlashSlash = []byte("//") strSlashDotDot = []byte("/..") strSlashDotSlash = []byte("/./") strSlashDotDotSlash = []byte("/../") strCRLF = []byte("\r\n") strHTTP = []byte("http") strHTTPS = []byte("https") strHTTP11 = []byte("HTTP/1.1") strColonSlashSlash = []byte("://") strColonSpace = []byte(": ") strGMT = []byte("GMT") strResponseContinue = []byte("HTTP/1.1 100 Continue\r\n\r\n") strGet = []byte("GET") strHead = []byte("HEAD") strPost = []byte("POST") strPut = []byte("PUT") strDelete = []byte("DELETE") strExpect = []byte("Expect") strConnection = []byte("Connection") strContentLength = []byte("Content-Length") strContentType = []byte("Content-Type") strDate = []byte("Date") strHost = []byte("Host") strReferer = []byte("Referer") strServer = []byte("Server") strTransferEncoding = []byte("Transfer-Encoding") strContentEncoding = []byte("Content-Encoding") strAcceptEncoding = []byte("Accept-Encoding") strUserAgent = []byte("User-Agent") strCookie = []byte("Cookie") strSetCookie = []byte("Set-Cookie") strLocation = []byte("Location") strIfModifiedSince = []byte("If-Modified-Since") strLastModified = []byte("Last-Modified") strAcceptRanges = []byte("Accept-Ranges") strRange = 
[]byte("Range") strContentRange = []byte("Content-Range") strCookieExpires = []byte("expires") strCookieDomain = []byte("domain") strCookiePath = []byte("path") strCookieHTTPOnly = []byte("HttpOnly") strCookieSecure = []byte("secure") strClose = []byte("close") strGzip = []byte("gzip") strDeflate = []byte("deflate") strKeepAlive = []byte("keep-alive") strKeepAliveCamelCase = []byte("Keep-Alive") strUpgrade = []byte("Upgrade") strChunked = []byte("chunked") strIdentity = []byte("identity") str100Continue = []byte("100-continue") strPostArgsContentType = []byte("application/x-www-form-urlencoded") strMultipartFormData = []byte("multipart/form-data") strBoundary = []byte("boundary") strBytes = []byte("bytes") ) golang-github-valyala-fasthttp-20160617/tcpdialer.go000066400000000000000000000213431273074646000223050ustar00rootroot00000000000000package fasthttp import ( "errors" "net" "strconv" "sync" "sync/atomic" "time" ) // Dial dials the given TCP addr using tcp4. // // This function has the following additional features comparing to net.Dial: // // * It reduces load on DNS resolver by caching resolved TCP addressed // for DefaultDNSCacheDuration. // * It dials all the resolved TCP addresses in round-robin manner until // connection is established. This may be useful if certain addresses // are temporarily unreachable. // * It returns ErrDialTimeout if connection cannot be established during // DefaultDialTimeout seconds. Use DialTimeout for customizing dial timeout. // // This dialer is intended for custom code wrapping before passing // to Client.Dial or HostClient.Dial. // // For instance, per-host counters and/or limits may be implemented // by such wrappers. // // The addr passed to the function must contain port. Example addr values: // // * foobar.baz:443 // * foo.bar:80 // * aaa.com:8080 func Dial(addr string) (net.Conn, error) { return getDialer(DefaultDialTimeout, false)(addr) } // DialTimeout dials the given TCP addr using tcp4 using the given timeout. // // This function has the following additional features comparing to net.Dial: // // * It reduces load on DNS resolver by caching resolved TCP addressed // for DefaultDNSCacheDuration. // * It dials all the resolved TCP addresses in round-robin manner until // connection is established. This may be useful if certain addresses // are temporarily unreachable. // // This dialer is intended for custom code wrapping before passing // to Client.Dial or HostClient.Dial. // // For instance, per-host counters and/or limits may be implemented // by such wrappers. // // The addr passed to the function must contain port. Example addr values: // // * foobar.baz:443 // * foo.bar:80 // * aaa.com:8080 func DialTimeout(addr string, timeout time.Duration) (net.Conn, error) { return getDialer(timeout, false)(addr) } // DialDualStack dials the given TCP addr using both tcp4 and tcp6. // // This function has the following additional features comparing to net.Dial: // // * It reduces load on DNS resolver by caching resolved TCP addressed // for DefaultDNSCacheDuration. // * It dials all the resolved TCP addresses in round-robin manner until // connection is established. This may be useful if certain addresses // are temporarily unreachable. // * It returns ErrDialTimeout if connection cannot be established during // DefaultDialTimeout seconds. Use DialDualStackTimeout for custom dial // timeout. // // This dialer is intended for custom code wrapping before passing // to Client.Dial or HostClient.Dial. 
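// A hypothetical wrapper might look like this (sketch only, not part of
// the package; dialCounter is an illustrative variable):
//
//	c := &fasthttp.Client{
//		Dial: func(addr string) (net.Conn, error) {
//			atomic.AddUint64(&dialCounter, 1)
//			return fasthttp.DialDualStack(addr)
//		},
//	}
//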
// // For instance, per-host counters and/or limits may be implemented // by such wrappers. // // The addr passed to the function must contain port. Example addr values: // // * foobar.baz:443 // * foo.bar:80 // * aaa.com:8080 func DialDualStack(addr string) (net.Conn, error) { return getDialer(DefaultDialTimeout, true)(addr) } // DialDualStackTimeout dials the given TCP addr using both tcp4 and tcp6 // using the given timeout. // // This function has the following additional features comparing to net.Dial: // // * It reduces load on DNS resolver by caching resolved TCP addressed // for DefaultDNSCacheDuration. // * It dials all the resolved TCP addresses in round-robin manner until // connection is established. This may be useful if certain addresses // are temporarily unreachable. // // This dialer is intended for custom code wrapping before passing // to Client.Dial or HostClient.Dial. // // For instance, per-host counters and/or limits may be implemented // by such wrappers. // // The addr passed to the function must contain port. Example addr values: // // * foobar.baz:443 // * foo.bar:80 // * aaa.com:8080 func DialDualStackTimeout(addr string, timeout time.Duration) (net.Conn, error) { return getDialer(timeout, true)(addr) } func getDialer(timeout time.Duration, dualStack bool) DialFunc { if timeout <= 0 { timeout = DefaultDialTimeout } timeoutRounded := int(timeout.Seconds()*10 + 9) m := dialMap if dualStack { m = dialDualStackMap } dialMapLock.Lock() d := m[timeoutRounded] if d == nil { dialer := dialerStd if dualStack { dialer = dialerDualStack } d = dialer.NewDial(timeout) m[timeoutRounded] = d } dialMapLock.Unlock() return d } var ( dialerStd = &tcpDialer{} dialerDualStack = &tcpDialer{DualStack: true} dialMap = make(map[int]DialFunc) dialDualStackMap = make(map[int]DialFunc) dialMapLock sync.Mutex ) type tcpDialer struct { DualStack bool tcpAddrsLock sync.Mutex tcpAddrsMap map[string]*tcpAddrEntry concurrencyCh chan struct{} once sync.Once } const maxDialConcurrency = 1000 func (d *tcpDialer) NewDial(timeout time.Duration) DialFunc { d.once.Do(func() { d.concurrencyCh = make(chan struct{}, maxDialConcurrency) d.tcpAddrsMap = make(map[string]*tcpAddrEntry) go d.tcpAddrsClean() }) return func(addr string) (net.Conn, error) { addrs, idx, err := d.getTCPAddrs(addr) if err != nil { return nil, err } network := "tcp4" if d.DualStack { network = "tcp" } var conn net.Conn n := uint32(len(addrs)) deadline := time.Now().Add(timeout) for n > 0 { conn, err = tryDial(network, &addrs[idx%n], deadline, d.concurrencyCh) if err == nil { return conn, nil } if err == ErrDialTimeout { return nil, err } idx++ n-- } return nil, err } } func tryDial(network string, addr *net.TCPAddr, deadline time.Time, concurrencyCh chan struct{}) (net.Conn, error) { timeout := -time.Since(deadline) if timeout <= 0 { return nil, ErrDialTimeout } select { case concurrencyCh <- struct{}{}: default: tc := acquireTimer(timeout) isTimeout := false select { case concurrencyCh <- struct{}{}: case <-tc.C: isTimeout = true } releaseTimer(tc) if isTimeout { return nil, ErrDialTimeout } } timeout = -time.Since(deadline) if timeout <= 0 { <-concurrencyCh return nil, ErrDialTimeout } chv := dialResultChanPool.Get() if chv == nil { chv = make(chan dialResult, 1) } ch := chv.(chan dialResult) go func() { var dr dialResult dr.conn, dr.err = net.DialTCP(network, nil, addr) ch <- dr <-concurrencyCh }() var ( conn net.Conn err error ) tc := acquireTimer(timeout) select { case dr := <-ch: conn = dr.conn err = dr.err 
dialResultChanPool.Put(ch) case <-tc.C: err = ErrDialTimeout } releaseTimer(tc) return conn, err } var dialResultChanPool sync.Pool type dialResult struct { conn net.Conn err error } // ErrDialTimeout is returned when TCP dialing is timed out. var ErrDialTimeout = errors.New("dialing to the given TCP address timed out") // DefaultDialTimeout is timeout used by Dial and DialDualStack // for establishing TCP connections. const DefaultDialTimeout = 3 * time.Second type tcpAddrEntry struct { addrs []net.TCPAddr addrsIdx uint32 resolveTime time.Time pending bool } // DefaultDNSCacheDuration is the duration for caching resolved TCP addresses // by Dial* functions. const DefaultDNSCacheDuration = time.Minute func (d *tcpDialer) tcpAddrsClean() { expireDuration := 2 * DefaultDNSCacheDuration for { time.Sleep(time.Second) t := time.Now() d.tcpAddrsLock.Lock() for k, e := range d.tcpAddrsMap { if t.Sub(e.resolveTime) > expireDuration { delete(d.tcpAddrsMap, k) } } d.tcpAddrsLock.Unlock() } } func (d *tcpDialer) getTCPAddrs(addr string) ([]net.TCPAddr, uint32, error) { d.tcpAddrsLock.Lock() e := d.tcpAddrsMap[addr] if e != nil && !e.pending && time.Since(e.resolveTime) > DefaultDNSCacheDuration { e.pending = true e = nil } d.tcpAddrsLock.Unlock() if e == nil { addrs, err := resolveTCPAddrs(addr, d.DualStack) if err != nil { d.tcpAddrsLock.Lock() e = d.tcpAddrsMap[addr] if e != nil && e.pending { e.pending = false } d.tcpAddrsLock.Unlock() return nil, 0, err } e = &tcpAddrEntry{ addrs: addrs, resolveTime: time.Now(), } d.tcpAddrsLock.Lock() d.tcpAddrsMap[addr] = e d.tcpAddrsLock.Unlock() } idx := uint32(0) if len(e.addrs) > 0 { idx = atomic.AddUint32(&e.addrsIdx, 1) } return e.addrs, idx, nil } func resolveTCPAddrs(addr string, dualStack bool) ([]net.TCPAddr, error) { host, portS, err := net.SplitHostPort(addr) if err != nil { return nil, err } port, err := strconv.Atoi(portS) if err != nil { return nil, err } ips, err := net.LookupIP(host) if err != nil { return nil, err } n := len(ips) addrs := make([]net.TCPAddr, 0, n) for i := 0; i < n; i++ { ip := ips[i] if !dualStack && ip.To4() == nil { continue } addrs = append(addrs, net.TCPAddr{ IP: ip, Port: port, }) } return addrs, nil } golang-github-valyala-fasthttp-20160617/timer.go000066400000000000000000000013471273074646000214600ustar00rootroot00000000000000package fasthttp import ( "sync" "time" ) func initTimer(t *time.Timer, timeout time.Duration) *time.Timer { if t == nil { return time.NewTimer(timeout) } if t.Reset(timeout) { panic("BUG: active timer trapped into initTimer()") } return t } func stopTimer(t *time.Timer) { if !t.Stop() { // Collect possibly added time from the channel // if timer has been stopped and nobody collected its' value. select { case <-t.C: default: } } } func acquireTimer(timeout time.Duration) *time.Timer { v := timerPool.Get() if v == nil { return time.NewTimer(timeout) } t := v.(*time.Timer) initTimer(t, timeout) return t } func releaseTimer(t *time.Timer) { stopTimer(t) timerPool.Put(t) } var timerPool sync.Pool golang-github-valyala-fasthttp-20160617/uri.go000066400000000000000000000274701273074646000211440ustar00rootroot00000000000000package fasthttp import ( "bytes" "io" "sync" ) // AcquireURI returns an empty URI instance from the pool. // // Release the URI with ReleaseURI after the URI is no longer needed. // This allows reducing GC load. func AcquireURI() *URI { return uriPool.Get().(*URI) } // ReleaseURI releases the URI acquired via AcquireURI. 
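// A typical acquire/release cycle looks like this (sketch):
//
//	u := AcquireURI()
//	u.Parse(nil, []byte("http://example.com/foo?bar=baz"))
//	// ... work with u ...
//	ReleaseURI(u)
//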
// // The released URI mustn't be used after releasing it, otherwise data races // may occur. func ReleaseURI(u *URI) { u.Reset() uriPool.Put(u) } var uriPool = &sync.Pool{ New: func() interface{} { return &URI{} }, } // URI represents URI :) . // // It is forbidden copying URI instances. Create new instance and use CopyTo // instead. // // URI instance MUST NOT be used from concurrently running goroutines. type URI struct { noCopy noCopy pathOriginal []byte scheme []byte path []byte queryString []byte hash []byte host []byte queryArgs Args parsedQueryArgs bool fullURI []byte requestURI []byte h *RequestHeader } // CopyTo copies uri contents to dst. func (u *URI) CopyTo(dst *URI) { dst.Reset() dst.pathOriginal = append(dst.pathOriginal[:0], u.pathOriginal...) dst.scheme = append(dst.scheme[:0], u.scheme...) dst.path = append(dst.path[:0], u.path...) dst.queryString = append(dst.queryString[:0], u.queryString...) dst.hash = append(dst.hash[:0], u.hash...) dst.host = append(dst.host[:0], u.host...) u.queryArgs.CopyTo(&dst.queryArgs) dst.parsedQueryArgs = u.parsedQueryArgs // fullURI and requestURI shouldn't be copied, since they are created // from scratch on each FullURI() and RequestURI() call. dst.h = u.h } // Hash returns URI hash, i.e. qwe of http://aaa.com/foo/bar?baz=123#qwe . // // The returned value is valid until the next URI method call. func (u *URI) Hash() []byte { return u.hash } // SetHash sets URI hash. func (u *URI) SetHash(hash string) { u.hash = append(u.hash[:0], hash...) } // SetHashBytes sets URI hash. func (u *URI) SetHashBytes(hash []byte) { u.hash = append(u.hash[:0], hash...) } // QueryString returns URI query string, // i.e. baz=123 of http://aaa.com/foo/bar?baz=123#qwe . // // The returned value is valid until the next URI method call. func (u *URI) QueryString() []byte { return u.queryString } // SetQueryString sets URI query string. func (u *URI) SetQueryString(queryString string) { u.queryString = append(u.queryString[:0], queryString...) u.parsedQueryArgs = false } // SetQueryStringBytes sets URI query string. func (u *URI) SetQueryStringBytes(queryString []byte) { u.queryString = append(u.queryString[:0], queryString...) u.parsedQueryArgs = false } // Path returns URI path, i.e. /foo/bar of http://aaa.com/foo/bar?baz=123#qwe . // // The returned path is always urldecoded and normalized, // i.e. '//f%20obar/baz/../zzz' becomes '/f obar/zzz'. // // The returned value is valid until the next URI method call. func (u *URI) Path() []byte { path := u.path if len(path) == 0 { path = strSlash } return path } // SetPath sets URI path. func (u *URI) SetPath(path string) { u.pathOriginal = append(u.pathOriginal[:0], path...) u.path = normalizePath(u.path, u.pathOriginal) } // SetPathBytes sets URI path. func (u *URI) SetPathBytes(path []byte) { u.pathOriginal = append(u.pathOriginal[:0], path...) u.path = normalizePath(u.path, u.pathOriginal) } // PathOriginal returns the original path from requestURI passed to URI.Parse(). // // The returned value is valid until the next URI method call. func (u *URI) PathOriginal() []byte { return u.pathOriginal } // Scheme returns URI scheme, i.e. http of http://aaa.com/foo/bar?baz=123#qwe . // // Returned scheme is always lowercased. // // The returned value is valid until the next URI method call. func (u *URI) Scheme() []byte { scheme := u.scheme if len(scheme) == 0 { scheme = strHTTP } return scheme } // SetScheme sets URI scheme, i.e. http, https, ftp, etc. 
func (u *URI) SetScheme(scheme string) { u.scheme = append(u.scheme[:0], scheme...) lowercaseBytes(u.scheme) } // SetSchemeBytes sets URI scheme, i.e. http, https, ftp, etc. func (u *URI) SetSchemeBytes(scheme []byte) { u.scheme = append(u.scheme[:0], scheme...) lowercaseBytes(u.scheme) } // Reset clears uri. func (u *URI) Reset() { u.pathOriginal = u.pathOriginal[:0] u.scheme = u.scheme[:0] u.path = u.path[:0] u.queryString = u.queryString[:0] u.hash = u.hash[:0] u.host = u.host[:0] u.queryArgs.Reset() u.parsedQueryArgs = false // There is no need in u.fullURI = u.fullURI[:0], since full uri // is calucalted on each call to FullURI(). // There is no need in u.requestURI = u.requestURI[:0], since requestURI // is calculated on each call to RequestURI(). u.h = nil } // Host returns host part, i.e. aaa.com of http://aaa.com/foo/bar?baz=123#qwe . // // Host is always lowercased. func (u *URI) Host() []byte { if len(u.host) == 0 && u.h != nil { u.host = append(u.host[:0], u.h.Host()...) lowercaseBytes(u.host) u.h = nil } return u.host } // SetHost sets host for the uri. func (u *URI) SetHost(host string) { u.host = append(u.host[:0], host...) lowercaseBytes(u.host) } // SetHostBytes sets host for the uri. func (u *URI) SetHostBytes(host []byte) { u.host = append(u.host[:0], host...) lowercaseBytes(u.host) } // Parse initializes URI from the given host and uri. func (u *URI) Parse(host, uri []byte) { u.parse(host, uri, nil) } func (u *URI) parseQuick(uri []byte, h *RequestHeader) { u.parse(nil, uri, h) } func (u *URI) parse(host, uri []byte, h *RequestHeader) { u.Reset() u.h = h scheme, host, uri := splitHostURI(host, uri) u.scheme = append(u.scheme, scheme...) lowercaseBytes(u.scheme) u.host = append(u.host, host...) lowercaseBytes(u.host) b := uri queryIndex := bytes.IndexByte(b, '?') fragmentIndex := bytes.IndexByte(b, '#') // Ignore query in fragment part if fragmentIndex >= 0 && queryIndex > fragmentIndex { queryIndex = -1 } if queryIndex < 0 && fragmentIndex < 0 { u.pathOriginal = append(u.pathOriginal, b...) u.path = normalizePath(u.path, u.pathOriginal) return } if queryIndex >= 0 { // Path is everything up to the start of the query u.pathOriginal = append(u.pathOriginal, b[:queryIndex]...) u.path = normalizePath(u.path, u.pathOriginal) if fragmentIndex < 0 { u.queryString = append(u.queryString, b[queryIndex+1:]...) } else { u.queryString = append(u.queryString, b[queryIndex+1:fragmentIndex]...) u.hash = append(u.hash, b[fragmentIndex+1:]...) } return } // fragmentIndex >= 0 && queryIndex < 0 // Path is up to the start of fragment u.pathOriginal = append(u.pathOriginal, b[:fragmentIndex]...) u.path = normalizePath(u.path, u.pathOriginal) u.hash = append(u.hash, b[fragmentIndex+1:]...) } func normalizePath(dst, src []byte) []byte { dst = dst[:0] dst = addLeadingSlash(dst, src) dst = decodeArgAppend(dst, src, false) // remove duplicate slashes b := dst bSize := len(b) for { n := bytes.Index(b, strSlashSlash) if n < 0 { break } b = b[n:] copy(b, b[1:]) b = b[:len(b)-1] bSize-- } dst = dst[:bSize] // remove /./ parts b = dst for { n := bytes.Index(b, strSlashDotSlash) if n < 0 { break } nn := n + len(strSlashDotSlash) - 1 copy(b[n:], b[nn:]) b = b[:len(b)-nn+n] } // remove /foo/../ parts for { n := bytes.Index(b, strSlashDotDotSlash) if n < 0 { break } nn := bytes.LastIndexByte(b[:n], '/') if nn < 0 { nn = 0 } n += len(strSlashDotDotSlash) - 1 copy(b[nn:], b[n:]) b = b[:len(b)-n+nn] } // remove trailing /foo/.. 
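	// e.g. "/aaa/.." becomes "/" and "/a/c/e/.." becomes "/a/c/"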
n := bytes.LastIndex(b, strSlashDotDot) if n >= 0 && n+len(strSlashDotDot) == len(b) { nn := bytes.LastIndexByte(b[:n], '/') if nn < 0 { return strSlash } b = b[:nn+1] } return b } // RequestURI returns RequestURI - i.e. URI without Scheme and Host. func (u *URI) RequestURI() []byte { dst := appendQuotedPath(u.requestURI[:0], u.Path()) if u.queryArgs.Len() > 0 { dst = append(dst, '?') dst = u.queryArgs.AppendBytes(dst) } else if len(u.queryString) > 0 { dst = append(dst, '?') dst = append(dst, u.queryString...) } if len(u.hash) > 0 { dst = append(dst, '#') dst = append(dst, u.hash...) } u.requestURI = dst return u.requestURI } // LastPathSegment returns the last part of uri path after '/'. // // Examples: // // * For /foo/bar/baz.html path returns baz.html. // * For /foo/bar/ returns empty byte slice. // * For /foobar.js returns foobar.js. func (u *URI) LastPathSegment() []byte { path := u.Path() n := bytes.LastIndexByte(path, '/') if n < 0 { return path } return path[n+1:] } // Update updates uri. // // The following newURI types are accepted: // // * Absolute, i.e. http://foobar.com/aaa/bb?cc . In this case the original // uri is replaced by newURI. // * Missing host, i.e. /aaa/bb?cc . In this case only RequestURI part // of the original uri is replaced. // * Relative path, i.e. xx?yy=abc . In this case the original RequestURI // is updated according to the new relative path. func (u *URI) Update(newURI string) { u.UpdateBytes(s2b(newURI)) } // UpdateBytes updates uri. // // The following newURI types are accepted: // // * Absolute, i.e. http://foobar.com/aaa/bb?cc . In this case the original // uri is replaced by newURI. // * Missing host, i.e. /aaa/bb?cc . In this case only RequestURI part // of the original uri is replaced. // * Relative path, i.e. xx?yy=abc . In this case the original RequestURI // is updated according to the new relative path. func (u *URI) UpdateBytes(newURI []byte) { u.requestURI = u.updateBytes(newURI, u.requestURI) } func (u *URI) updateBytes(newURI, buf []byte) []byte { if len(newURI) == 0 { return buf } if newURI[0] == '/' { // uri without host buf = u.appendSchemeHost(buf[:0]) buf = append(buf, newURI...) u.Parse(nil, buf) return buf } n := bytes.Index(newURI, strColonSlashSlash) if n >= 0 { // absolute uri u.Parse(nil, newURI) return buf } // relative path switch newURI[0] { case '?': // query string only update u.SetQueryStringBytes(newURI[1:]) return append(buf[:0], u.FullURI()...) case '#': // update only hash u.SetHashBytes(newURI[1:]) return append(buf[:0], u.FullURI()...) default: // update the last path part after the slash path := u.Path() n = bytes.LastIndexByte(path, '/') if n < 0 { panic("BUG: path must contain at least one slash") } buf = u.appendSchemeHost(buf[:0]) buf = appendQuotedPath(buf, path[:n+1]) buf = append(buf, newURI...) u.Parse(nil, buf) return buf } } // FullURI returns full uri in the form {Scheme}://{Host}{RequestURI}#{Hash}. func (u *URI) FullURI() []byte { u.fullURI = u.AppendBytes(u.fullURI[:0]) return u.fullURI } // AppendBytes appends full uri to dst and returns the extended dst. func (u *URI) AppendBytes(dst []byte) []byte { dst = u.appendSchemeHost(dst) return append(dst, u.RequestURI()...) } func (u *URI) appendSchemeHost(dst []byte) []byte { dst = append(dst, u.Scheme()...) dst = append(dst, strColonSlashSlash...) return append(dst, u.Host()...) } // WriteTo writes full uri to w. // // WriteTo implements io.WriterTo interface. 
func (u *URI) WriteTo(w io.Writer) (int64, error) { n, err := w.Write(u.FullURI()) return int64(n), err } // String returns full uri. func (u *URI) String() string { return string(u.FullURI()) } func splitHostURI(host, uri []byte) ([]byte, []byte, []byte) { n := bytes.Index(uri, strColonSlashSlash) if n < 0 { return strHTTP, host, uri } scheme := uri[:n] if bytes.IndexByte(scheme, '/') >= 0 { return strHTTP, host, uri } n += len(strColonSlashSlash) uri = uri[n:] n = bytes.IndexByte(uri, '/') if n < 0 { return scheme, uri, strSlash } return scheme, uri[:n], uri[n:] } // QueryArgs returns query args. func (u *URI) QueryArgs() *Args { u.parseQueryArgs() return &u.queryArgs } func (u *URI) parseQueryArgs() { if u.parsedQueryArgs { return } u.queryArgs.ParseBytes(u.queryString) u.parsedQueryArgs = true } golang-github-valyala-fasthttp-20160617/uri_test.go000066400000000000000000000240321273074646000221720ustar00rootroot00000000000000package fasthttp import ( "bytes" "fmt" "testing" "time" ) func TestURICopyToQueryArgs(t *testing.T) { var u URI a := u.QueryArgs() a.Set("foo", "bar") var u1 URI u.CopyTo(&u1) a1 := u1.QueryArgs() if string(a1.Peek("foo")) != "bar" { t.Fatalf("unexpected query args value %q. Expecting %q", a1.Peek("foo"), "bar") } } func TestURIAcquireReleaseSequential(t *testing.T) { testURIAcquireRelease(t) } func TestURIAcquireReleaseConcurrent(t *testing.T) { ch := make(chan struct{}, 10) for i := 0; i < 10; i++ { go func() { testURIAcquireRelease(t) ch <- struct{}{} }() } for i := 0; i < 10; i++ { select { case <-ch: case <-time.After(time.Second): t.Fatalf("timeout") } } } func testURIAcquireRelease(t *testing.T) { for i := 0; i < 10; i++ { u := AcquireURI() host := fmt.Sprintf("host.%d.com", i*23) path := fmt.Sprintf("/foo/%d/bar", i*17) queryArgs := "?foo=bar&baz=aass" u.Parse([]byte(host), []byte(path+queryArgs)) if string(u.Host()) != host { t.Fatalf("unexpected host %q. Expecting %q", u.Host(), host) } if string(u.Path()) != path { t.Fatalf("unexpected path %q. Expecting %q", u.Path(), path) } ReleaseURI(u) } } func TestURILastPathSegment(t *testing.T) { testURILastPathSegment(t, "", "") testURILastPathSegment(t, "/", "") testURILastPathSegment(t, "/foo/bar/", "") testURILastPathSegment(t, "/foobar.js", "foobar.js") testURILastPathSegment(t, "/foo/bar/baz.html", "baz.html") } func testURILastPathSegment(t *testing.T, path, expectedSegment string) { var u URI u.SetPath(path) segment := u.LastPathSegment() if string(segment) != expectedSegment { t.Fatalf("unexpected last path segment for path %q: %q. Expecting %q", path, segment, expectedSegment) } } func TestURIPathEscape(t *testing.T) { testURIPathEscape(t, "/foo/bar", "/foo/bar") testURIPathEscape(t, "/f_o-o=b:ar,b.c&q", "/f_o-o=b:ar,b.c&q") testURIPathEscape(t, "/aa?bb.тест~qq", "/aa%3Fbb.%D1%82%D0%B5%D1%81%D1%82~qq") } func testURIPathEscape(t *testing.T, path, expectedRequestURI string) { var u URI u.SetPath(path) requestURI := u.RequestURI() if string(requestURI) != expectedRequestURI { t.Fatalf("unexpected requestURI %q. Expecting %q. 
path %q", requestURI, expectedRequestURI, path) } } func TestURIUpdate(t *testing.T) { // full uri testURIUpdate(t, "http://foo.bar/baz?aaa=22#aaa", "https://aa.com/bb", "https://aa.com/bb") // empty uri testURIUpdate(t, "http://aaa.com/aaa.html?234=234#add", "", "http://aaa.com/aaa.html?234=234#add") // request uri testURIUpdate(t, "ftp://aaa/xxx/yyy?aaa=bb#aa", "/boo/bar?xx", "ftp://aaa/boo/bar?xx") // relative uri testURIUpdate(t, "http://foo.bar/baz/xxx.html?aaa=22#aaa", "bb.html?xx=12#pp", "http://foo.bar/baz/bb.html?xx=12#pp") testURIUpdate(t, "http://xx/a/b/c/d", "../qwe/p?zx=34", "http://xx/a/b/qwe/p?zx=34") testURIUpdate(t, "https://qqq/aaa.html?foo=bar", "?baz=434&aaa#xcv", "https://qqq/aaa.html?baz=434&aaa#xcv") testURIUpdate(t, "http://foo.bar/baz", "~a/%20b=c,тест?йцу=ке", "http://foo.bar/~a/%20b=c,%D1%82%D0%B5%D1%81%D1%82?йцу=ке") testURIUpdate(t, "http://foo.bar/baz", "/qwe#fragment", "http://foo.bar/qwe#fragment") testURIUpdate(t, "http://foobar/baz/xxx", "aaa.html#bb?cc=dd&ee=dfd", "http://foobar/baz/aaa.html#bb?cc=dd&ee=dfd") // hash testURIUpdate(t, "http://foo.bar/baz#aaa", "#fragment", "http://foo.bar/baz#fragment") } func testURIUpdate(t *testing.T, base, update, result string) { var u URI u.Parse(nil, []byte(base)) u.Update(update) s := u.String() if s != result { t.Fatalf("unexpected result %q. Expecting %q. base=%q, update=%q", s, result, base, update) } } func TestURIPathNormalize(t *testing.T) { var u URI // double slash testURIPathNormalize(t, &u, "/aa//bb", "/aa/bb") // triple slash testURIPathNormalize(t, &u, "/x///y/", "/x/y/") // multi slashes testURIPathNormalize(t, &u, "/abc//de///fg////", "/abc/de/fg/") // encoded slashes testURIPathNormalize(t, &u, "/xxxx%2fyyy%2f%2F%2F", "/xxxx/yyy/") // dotdot testURIPathNormalize(t, &u, "/aaa/..", "/") // dotdot with trailing slash testURIPathNormalize(t, &u, "/xxx/yyy/../", "/xxx/") // multi dotdots testURIPathNormalize(t, &u, "/aaa/bbb/ccc/../../ddd", "/aaa/ddd") // dotdots separated by other data testURIPathNormalize(t, &u, "/a/b/../c/d/../e/..", "/a/c/") // too many dotdots testURIPathNormalize(t, &u, "/aaa/../../../../xxx", "/xxx") testURIPathNormalize(t, &u, "/../../../../../..", "/") testURIPathNormalize(t, &u, "/../../../../../../", "/") // encoded dotdots testURIPathNormalize(t, &u, "/aaa%2Fbbb%2F%2E.%2Fxxx", "/aaa/xxx") // double slash with dotdots testURIPathNormalize(t, &u, "/aaa////..//b", "/b") // fake dotdot testURIPathNormalize(t, &u, "/aaa/..bbb/ccc/..", "/aaa/..bbb/") // single dot testURIPathNormalize(t, &u, "/a/./b/././c/./d.html", "/a/b/c/d.html") testURIPathNormalize(t, &u, "./foo/", "/foo/") testURIPathNormalize(t, &u, "./../.././../../aaa/bbb/../../../././../", "/") testURIPathNormalize(t, &u, "./a/./.././../b/./foo.html", "/b/foo.html") } func testURIPathNormalize(t *testing.T, u *URI, requestURI, expectedPath string) { u.Parse(nil, []byte(requestURI)) if string(u.Path()) != expectedPath { t.Fatalf("Unexpected path %q. Expected %q. 
requestURI=%q", u.Path(), expectedPath, requestURI) } } func TestURIFullURI(t *testing.T) { var args Args // empty scheme, path and hash testURIFullURI(t, "", "foobar.com", "", "", &args, "http://foobar.com/") // empty scheme and hash testURIFullURI(t, "", "aa.com", "/foo/bar", "", &args, "http://aa.com/foo/bar") // empty hash testURIFullURI(t, "fTP", "XXx.com", "/foo", "", &args, "ftp://xxx.com/foo") // empty args testURIFullURI(t, "https", "xx.com", "/", "aaa", &args, "https://xx.com/#aaa") // non-empty args and non-ASCII path args.Set("foo", "bar") args.Set("xxx", "йух") testURIFullURI(t, "", "xxx.com", "/тест123", "2er", &args, "http://xxx.com/%D1%82%D0%B5%D1%81%D1%82123?foo=bar&xxx=%D0%B9%D1%83%D1%85#2er") // test with empty args and non-empty query string var u URI u.Parse([]byte("google.com"), []byte("/foo?bar=baz&baraz#qqqq")) uri := u.FullURI() expectedURI := "http://google.com/foo?bar=baz&baraz#qqqq" if string(uri) != expectedURI { t.Fatalf("Unexpected URI: %q. Expected %q", uri, expectedURI) } } func testURIFullURI(t *testing.T, scheme, host, path, hash string, args *Args, expectedURI string) { var u URI u.SetScheme(scheme) u.SetHost(host) u.SetPath(path) u.SetHash(hash) args.CopyTo(u.QueryArgs()) uri := u.FullURI() if string(uri) != expectedURI { t.Fatalf("Unexpected URI: %q. Expected %q", uri, expectedURI) } } func TestURIParseNilHost(t *testing.T) { testURIParseScheme(t, "http://google.com/foo?bar#baz", "http") testURIParseScheme(t, "HTtP://google.com/", "http") testURIParseScheme(t, "://google.com/", "http") testURIParseScheme(t, "fTP://aaa.com", "ftp") testURIParseScheme(t, "httPS://aaa.com", "https") } func testURIParseScheme(t *testing.T, uri, expectedScheme string) { var u URI u.Parse(nil, []byte(uri)) if string(u.Scheme()) != expectedScheme { t.Fatalf("Unexpected scheme %q. Expected %q for uri %q", u.Scheme(), expectedScheme, uri) } } func TestURIParse(t *testing.T) { var u URI // no args testURIParse(t, &u, "aaa", "sdfdsf", "http://aaa/sdfdsf", "aaa", "/sdfdsf", "sdfdsf", "", "") // args testURIParse(t, &u, "xx", "/aa?ss", "http://xx/aa?ss", "xx", "/aa", "/aa", "ss", "") // args and hash testURIParse(t, &u, "foobar.com", "/a.b.c?def=gkl#mnop", "http://foobar.com/a.b.c?def=gkl#mnop", "foobar.com", "/a.b.c", "/a.b.c", "def=gkl", "mnop") // '?' 
and '#' in hash testURIParse(t, &u, "aaa.com", "/foo#bar?baz=aaa#bbb", "http://aaa.com/foo#bar?baz=aaa#bbb", "aaa.com", "/foo", "/foo", "", "bar?baz=aaa#bbb") // encoded path testURIParse(t, &u, "aa.com", "/Test%20+%20%D0%BF%D1%80%D0%B8?asdf=%20%20&s=12#sdf", "http://aa.com/Test%20%2B%20%D0%BF%D1%80%D0%B8?asdf=%20%20&s=12#sdf", "aa.com", "/Test + при", "/Test%20+%20%D0%BF%D1%80%D0%B8", "asdf=%20%20&s=12", "sdf") // host in uppercase testURIParse(t, &u, "FOObar.COM", "/bC?De=F#Gh", "http://foobar.com/bC?De=F#Gh", "foobar.com", "/bC", "/bC", "De=F", "Gh") // uri with hostname testURIParse(t, &u, "xxx.com", "http://aaa.com/foo/bar?baz=aaa#ddd", "http://aaa.com/foo/bar?baz=aaa#ddd", "aaa.com", "/foo/bar", "/foo/bar", "baz=aaa", "ddd") testURIParse(t, &u, "xxx.com", "https://ab.com/f/b%20r?baz=aaa#ddd", "https://ab.com/f/b%20r?baz=aaa#ddd", "ab.com", "/f/b r", "/f/b%20r", "baz=aaa", "ddd") // no slash after hostname in uri testURIParse(t, &u, "aaa.com", "http://google.com", "http://google.com/", "google.com", "/", "/", "", "") // uppercase hostname in uri testURIParse(t, &u, "abc.com", "http://GoGLE.com/aaa", "http://gogle.com/aaa", "gogle.com", "/aaa", "/aaa", "", "") // http:// in query params testURIParse(t, &u, "aaa.com", "/foo?bar=http://google.com", "http://aaa.com/foo?bar=http://google.com", "aaa.com", "/foo", "/foo", "bar=http://google.com", "") } func testURIParse(t *testing.T, u *URI, host, uri, expectedURI, expectedHost, expectedPath, expectedPathOriginal, expectedArgs, expectedHash string) { u.Parse([]byte(host), []byte(uri)) if !bytes.Equal(u.FullURI(), []byte(expectedURI)) { t.Fatalf("Unexpected uri %q. Expected %q. host=%q, uri=%q", u.FullURI(), expectedURI, host, uri) } if !bytes.Equal(u.Host(), []byte(expectedHost)) { t.Fatalf("Unexpected host %q. Expected %q. host=%q, uri=%q", u.Host(), expectedHost, host, uri) } if !bytes.Equal(u.PathOriginal(), []byte(expectedPathOriginal)) { t.Fatalf("Unexpected original path %q. Expected %q. host=%q, uri=%q", u.PathOriginal(), expectedPathOriginal, host, uri) } if !bytes.Equal(u.Path(), []byte(expectedPath)) { t.Fatalf("Unexpected path %q. Expected %q. host=%q, uri=%q", u.Path(), expectedPath, host, uri) } if !bytes.Equal(u.QueryString(), []byte(expectedArgs)) { t.Fatalf("Unexpected args %q. Expected %q. host=%q, uri=%q", u.QueryString(), expectedArgs, host, uri) } if !bytes.Equal(u.Hash(), []byte(expectedHash)) { t.Fatalf("Unexpected hash %q. Expected %q. host=%q, uri=%q", u.Hash(), expectedHash, host, uri) } } golang-github-valyala-fasthttp-20160617/uri_timing_test.go000066400000000000000000000022221273074646000235360ustar00rootroot00000000000000package fasthttp import ( "testing" ) func BenchmarkURIParsePath(b *testing.B) { benchmarkURIParse(b, "google.com", "/foo/bar") } func BenchmarkURIParsePathQueryString(b *testing.B) { benchmarkURIParse(b, "google.com", "/foo/bar?query=string&other=value") } func BenchmarkURIParsePathQueryStringHash(b *testing.B) { benchmarkURIParse(b, "google.com", "/foo/bar?query=string&other=value#hashstring") } func BenchmarkURIParseHostname(b *testing.B) { benchmarkURIParse(b, "google.com", "http://foobar.com/foo/bar?query=string&other=value#hashstring") } func BenchmarkURIFullURI(b *testing.B) { host := []byte("foobar.com") requestURI := []byte("/foobar/baz?aaa=bbb&ccc=ddd") uriLen := len(host) + len(requestURI) + 7 b.RunParallel(func(pb *testing.PB) { var u URI u.Parse(host, requestURI) for pb.Next() { uri := u.FullURI() if len(uri) != uriLen { b.Fatalf("unexpected uri len %d. 
Expecting %d", len(uri), uriLen) } } }) } func benchmarkURIParse(b *testing.B, host, uri string) { strHost, strURI := []byte(host), []byte(uri) b.RunParallel(func(pb *testing.PB) { var u URI for pb.Next() { u.Parse(strHost, strURI) } }) } golang-github-valyala-fasthttp-20160617/uri_unix.go000066400000000000000000000003121273074646000221710ustar00rootroot00000000000000// +build !windows package fasthttp func addLeadingSlash(dst, src []byte) []byte { // add leading slash for unix paths if len(src) == 0 || src[0] != '/' { dst = append(dst, '/') } return dst } golang-github-valyala-fasthttp-20160617/uri_windows.go000066400000000000000000000003251273074646000227040ustar00rootroot00000000000000// +build windows package fasthttp func addLeadingSlash(dst, src []byte) []byte { // zero length and "C:/" case if len(src) == 0 || (len(src) > 2 && src[1] != ':') { dst = append(dst, '/') } return dst } golang-github-valyala-fasthttp-20160617/uri_windows_test.go000066400000000000000000000003551273074646000237460ustar00rootroot00000000000000// +build windows package fasthttp import "testing" func TestURIPathNormalizeIssue86(t *testing.T) { // see https://github.com/valyala/fasthttp/issues/86 var u URI testURIPathNormalize(t, &u, `C:\a\b\c\fs.go`, `C:\a\b\c\fs.go`) } golang-github-valyala-fasthttp-20160617/userdata.go000066400000000000000000000021171273074646000221440ustar00rootroot00000000000000package fasthttp import ( "io" ) type userDataKV struct { key []byte value interface{} } type userData []userDataKV func (d *userData) Set(key string, value interface{}) { args := *d n := len(args) for i := 0; i < n; i++ { kv := &args[i] if string(kv.key) == key { kv.value = value return } } c := cap(args) if c > n { args = args[:n+1] kv := &args[n] kv.key = append(kv.key[:0], key...) kv.value = value *d = args return } kv := userDataKV{} kv.key = append(kv.key[:0], key...) kv.value = value *d = append(args, kv) } func (d *userData) SetBytes(key []byte, value interface{}) { d.Set(b2s(key), value) } func (d *userData) Get(key string) interface{} { args := *d n := len(args) for i := 0; i < n; i++ { kv := &args[i] if string(kv.key) == key { return kv.value } } return nil } func (d *userData) GetBytes(key []byte) interface{} { return d.Get(b2s(key)) } func (d *userData) Reset() { args := *d n := len(args) for i := 0; i < n; i++ { v := args[i].value if vc, ok := v.(io.Closer); ok { vc.Close() } } *d = (*d)[:0] } golang-github-valyala-fasthttp-20160617/userdata_test.go000066400000000000000000000026011273074646000232010ustar00rootroot00000000000000package fasthttp import ( "fmt" "reflect" "testing" ) func TestUserData(t *testing.T) { var u userData for i := 0; i < 10; i++ { key := []byte(fmt.Sprintf("key_%d", i)) u.SetBytes(key, i+5) testUserDataGet(t, &u, key, i+5) u.SetBytes(key, i) testUserDataGet(t, &u, key, i) } for i := 0; i < 10; i++ { key := []byte(fmt.Sprintf("key_%d", i)) testUserDataGet(t, &u, key, i) } u.Reset() for i := 0; i < 10; i++ { key := []byte(fmt.Sprintf("key_%d", i)) testUserDataGet(t, &u, key, nil) } } func testUserDataGet(t *testing.T, u *userData, key []byte, value interface{}) { v := u.GetBytes(key) if v == nil && value != nil { t.Fatalf("cannot obtain value for key=%q", key) } if !reflect.DeepEqual(v, value) { t.Fatalf("unexpected value for key=%q: %d. 
Expecting %d", key, v, value) } } func TestUserDataValueClose(t *testing.T) { var u userData closeCalls := 0 // store values implementing io.Closer for i := 0; i < 5; i++ { key := fmt.Sprintf("key_%d", i) u.Set(key, &closerValue{&closeCalls}) } // store values without io.Closer for i := 0; i < 10; i++ { key := fmt.Sprintf("key_noclose_%d", i) u.Set(key, i) } u.Reset() if closeCalls != 5 { t.Fatalf("unexpected number of Close calls: %d. Expecting 10", closeCalls) } } type closerValue struct { closeCalls *int } func (cv *closerValue) Close() error { (*cv.closeCalls)++ return nil } golang-github-valyala-fasthttp-20160617/userdata_timing_test.go000066400000000000000000000016661273074646000245620ustar00rootroot00000000000000package fasthttp import ( "testing" ) func BenchmarkUserDataCustom(b *testing.B) { keys := []string{"foobar", "baz", "aaa", "bsdfs"} b.RunParallel(func(pb *testing.PB) { var u userData var v interface{} = u for pb.Next() { for _, key := range keys { u.Set(key, v) } for _, key := range keys { vv := u.Get(key) if _, ok := vv.(userData); !ok { b.Fatalf("unexpected value %v for key %q", vv, key) } } u.Reset() } }) } func BenchmarkUserDataStdMap(b *testing.B) { keys := []string{"foobar", "baz", "aaa", "bsdfs"} b.RunParallel(func(pb *testing.PB) { u := make(map[string]interface{}) var v interface{} = u for pb.Next() { for _, key := range keys { u[key] = v } for _, key := range keys { vv := u[key] if _, ok := vv.(map[string]interface{}); !ok { b.Fatalf("unexpected value %v for key %q", vv, key) } } for k := range u { delete(u, k) } } }) } golang-github-valyala-fasthttp-20160617/workerpool.go000066400000000000000000000113131273074646000225350ustar00rootroot00000000000000package fasthttp import ( "net" "runtime" "runtime/debug" "strings" "sync" "time" ) // workerPool serves incoming connections via a pool of workers // in FILO order, i.e. the most recently stopped worker will serve the next // incoming connection. // // Such a scheme keeps CPU caches hot (in theory). type workerPool struct { // Function for serving server connections. // It must leave c unclosed. WorkerFunc func(c net.Conn) error MaxWorkersCount int LogAllErrors bool MaxIdleWorkerDuration time.Duration Logger Logger lock sync.Mutex workersCount int mustStop bool ready []*workerChan stopCh chan struct{} workerChanPool sync.Pool } type workerChan struct { lastUseTime time.Time ch chan net.Conn } func (wp *workerPool) Start() { if wp.stopCh != nil { panic("BUG: workerPool already started") } wp.stopCh = make(chan struct{}) stopCh := wp.stopCh go func() { var scratch []*workerChan for { wp.clean(&scratch) select { case <-stopCh: return default: time.Sleep(wp.getMaxIdleWorkerDuration()) } } }() } func (wp *workerPool) Stop() { if wp.stopCh == nil { panic("BUG: workerPool wasn't started") } close(wp.stopCh) wp.stopCh = nil // Stop all the workers waiting for incoming connections. // Do not wait for busy workers - they will stop after // serving the connection and noticing wp.mustStop = true. wp.lock.Lock() ready := wp.ready for i, ch := range ready { ch.ch <- nil ready[i] = nil } wp.ready = ready[:0] wp.mustStop = true wp.lock.Unlock() } func (wp *workerPool) getMaxIdleWorkerDuration() time.Duration { if wp.MaxIdleWorkerDuration <= 0 { return 10 * time.Second } return wp.MaxIdleWorkerDuration } func (wp *workerPool) clean(scratch *[]*workerChan) { maxIdleWorkerDuration := wp.getMaxIdleWorkerDuration() // Clean least recently used workers if they didn't serve connections // for more than maxIdleWorkerDuration. 
currentTime := time.Now() wp.lock.Lock() ready := wp.ready n := len(ready) i := 0 for i < n && currentTime.Sub(ready[i].lastUseTime) > maxIdleWorkerDuration { i++ } *scratch = append((*scratch)[:0], ready[:i]...) if i > 0 { m := copy(ready, ready[i:]) for i = m; i < n; i++ { ready[i] = nil } wp.ready = ready[:m] } wp.lock.Unlock() // Notify obsolete workers to stop. // This notification must be outside the wp.lock, since ch.ch // may be blocking and may consume a lot of time if many workers // are located on non-local CPUs. tmp := *scratch for i, ch := range tmp { ch.ch <- nil tmp[i] = nil } } func (wp *workerPool) Serve(c net.Conn) bool { ch := wp.getCh() if ch == nil { return false } ch.ch <- c return true } var workerChanCap = func() int { // Use blocking workerChan if GOMAXPROCS=1. // This immediately switches Serve to WorkerFunc, which results // in higher performance (under go1.5 at least). if runtime.GOMAXPROCS(0) == 1 { return 0 } // Use non-blocking workerChan if GOMAXPROCS>1, // since otherwise the Serve caller (Acceptor) may lag accepting // new connections if WorkerFunc is CPU-bound. return 1 }() func (wp *workerPool) getCh() *workerChan { var ch *workerChan createWorker := false wp.lock.Lock() ready := wp.ready n := len(ready) - 1 if n < 0 { if wp.workersCount < wp.MaxWorkersCount { createWorker = true wp.workersCount++ } } else { ch = ready[n] ready[n] = nil wp.ready = ready[:n] } wp.lock.Unlock() if ch == nil { if !createWorker { return nil } vch := wp.workerChanPool.Get() if vch == nil { vch = &workerChan{ ch: make(chan net.Conn, workerChanCap), } } ch = vch.(*workerChan) go func() { wp.workerFunc(ch) wp.workerChanPool.Put(vch) }() } return ch } func (wp *workerPool) release(ch *workerChan) bool { ch.lastUseTime = time.Now() wp.lock.Lock() if wp.mustStop { wp.lock.Unlock() return false } wp.ready = append(wp.ready, ch) wp.lock.Unlock() return true } func (wp *workerPool) workerFunc(ch *workerChan) { var c net.Conn defer func() { if r := recover(); r != nil { wp.Logger.Printf("panic: %s\nStack trace:\n%s", r, debug.Stack()) if c != nil { c.Close() } } wp.lock.Lock() wp.workersCount-- wp.lock.Unlock() }() var err error for c = range ch.ch { if c == nil { break } if err = wp.WorkerFunc(c); err != nil && err != errHijacked { errStr := err.Error() if wp.LogAllErrors || !(strings.Contains(errStr, "broken pipe") || strings.Contains(errStr, "reset by peer") || strings.Contains(errStr, "i/o timeout")) { wp.Logger.Printf("error when serving connection %q<->%q: %s", c.LocalAddr(), c.RemoteAddr(), err) } } if err != errHijacked { c.Close() } c = nil if !wp.release(ch) { break } } } golang-github-valyala-fasthttp-20160617/workerpool_test.go000066400000000000000000000126361273074646000236050ustar00rootroot00000000000000package fasthttp import ( "fmt" "io/ioutil" "net" "sync/atomic" "testing" "time" "github.com/valyala/fasthttp/fasthttputil" ) func TestWorkerPoolStartStopSerial(t *testing.T) { testWorkerPoolStartStop(t) } func TestWorkerPoolStartStopConcurrent(t *testing.T) { concurrency := 10 ch := make(chan struct{}, concurrency) for i := 0; i < concurrency; i++ { go func() { testWorkerPoolStartStop(t) ch <- struct{}{} }() } for i := 0; i < concurrency; i++ { select { case <-ch: case <-time.After(time.Second): t.Fatalf("timeout") } } } func testWorkerPoolStartStop(t *testing.T) { wp := &workerPool{ WorkerFunc: func(conn net.Conn) error { return nil }, MaxWorkersCount: 10, Logger: defaultLogger, } for i := 0; i < 10; i++ { wp.Start() wp.Stop() } } func 
TestWorkerPoolMaxWorkersCountSerial(t *testing.T) { testWorkerPoolMaxWorkersCountMulti(t) } func TestWorkerPoolMaxWorkersCountConcurrent(t *testing.T) { concurrency := 4 ch := make(chan struct{}, concurrency) for i := 0; i < concurrency; i++ { go func() { testWorkerPoolMaxWorkersCountMulti(t) ch <- struct{}{} }() } for i := 0; i < concurrency; i++ { select { case <-ch: case <-time.After(time.Second): t.Fatalf("timeout") } } } func testWorkerPoolMaxWorkersCountMulti(t *testing.T) { for i := 0; i < 5; i++ { testWorkerPoolMaxWorkersCount(t) } } func testWorkerPoolMaxWorkersCount(t *testing.T) { ready := make(chan struct{}) wp := &workerPool{ WorkerFunc: func(conn net.Conn) error { buf := make([]byte, 100) n, err := conn.Read(buf) if err != nil { t.Fatalf("unexpected error: %s", err) } buf = buf[:n] if string(buf) != "foobar" { t.Fatalf("unexpected data read: %q. Expecting %q", buf, "foobar") } if _, err = conn.Write([]byte("baz")); err != nil { t.Fatalf("unexpected error: %s", err) } <-ready return nil }, MaxWorkersCount: 10, Logger: defaultLogger, } wp.Start() ln := fasthttputil.NewInmemoryListener() clientCh := make(chan struct{}, wp.MaxWorkersCount) for i := 0; i < wp.MaxWorkersCount; i++ { go func() { conn, err := ln.Dial() if err != nil { t.Fatalf("unexpected error: %s", err) } if _, err = conn.Write([]byte("foobar")); err != nil { t.Fatalf("unexpected error: %s", err) } data, err := ioutil.ReadAll(conn) if err != nil { t.Fatalf("unexpected error: %s", err) } if string(data) != "baz" { t.Fatalf("unexpected value read: %q. Expecting %q", data, "baz") } if err = conn.Close(); err != nil { t.Fatalf("unexpected error: %s", err) } clientCh <- struct{}{} }() } for i := 0; i < wp.MaxWorkersCount; i++ { conn, err := ln.Accept() if err != nil { t.Fatalf("unexpected error: %s", err) } if !wp.Serve(conn) { t.Fatalf("worker pool must have enough workers to serve the conn") } } go func() { if _, err := ln.Dial(); err != nil { t.Fatalf("unexpected error: %s", err) } }() conn, err := ln.Accept() if err != nil { t.Fatalf("unexpected error: %s", err) } for i := 0; i < 5; i++ { if wp.Serve(conn) { t.Fatalf("worker pool must be full") } } if err = conn.Close(); err != nil { t.Fatalf("unexpected error: %s", err) } close(ready) for i := 0; i < wp.MaxWorkersCount; i++ { select { case <-clientCh: case <-time.After(time.Second): t.Fatalf("timeout") } } if err := ln.Close(); err != nil { t.Fatalf("unexpected error: %s", err) } wp.Stop() } func TestWorkerPoolPanicErrorSerial(t *testing.T) { testWorkerPoolPanicErrorMulti(t) } func TestWorkerPoolPanicErrorConcurrent(t *testing.T) { concurrency := 10 ch := make(chan struct{}, concurrency) for i := 0; i < concurrency; i++ { go func() { testWorkerPoolPanicErrorMulti(t) ch <- struct{}{} }() } for i := 0; i < concurrency; i++ { select { case <-ch: case <-time.After(time.Second): t.Fatalf("timeout") } } } func testWorkerPoolPanicErrorMulti(t *testing.T) { var globalCount uint64 wp := &workerPool{ WorkerFunc: func(conn net.Conn) error { count := atomic.AddUint64(&globalCount, 1) switch count % 3 { case 0: panic("foobar") case 1: return fmt.Errorf("fake error") } return nil }, MaxWorkersCount: 1000, MaxIdleWorkerDuration: time.Millisecond, Logger: &customLogger{}, } for i := 0; i < 10; i++ { testWorkerPoolPanicError(t, wp) } } func testWorkerPoolPanicError(t *testing.T, wp *workerPool) { wp.Start() ln := fasthttputil.NewInmemoryListener() clientsCount := 10 clientCh := make(chan struct{}, clientsCount) for i := 0; i < clientsCount; i++ { go func() { conn, err := ln.Dial() 
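			// Note (descriptive comment, not in the original source): the
			// WorkerFunc set in testWorkerPoolPanicErrorMulti never writes to
			// the connection (it panics, returns an error, or returns nil), so
			// each client expects the server side to close the connection
			// without sending any data.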
			if err != nil {
				t.Fatalf("unexpected error: %s", err)
			}
			data, err := ioutil.ReadAll(conn)
			if err != nil {
				t.Fatalf("unexpected error: %s", err)
			}
			if len(data) > 0 {
				t.Fatalf("unexpected data read: %q. Expecting empty data", data)
			}
			if err = conn.Close(); err != nil {
				t.Fatalf("unexpected error: %s", err)
			}
			clientCh <- struct{}{}
		}()
	}
	for i := 0; i < clientsCount; i++ {
		conn, err := ln.Accept()
		if err != nil {
			t.Fatalf("unexpected error: %s", err)
		}
		if !wp.Serve(conn) {
			t.Fatalf("worker pool mustn't be full")
		}
	}
	for i := 0; i < clientsCount; i++ {
		select {
		case <-clientCh:
		case <-time.After(time.Second):
			t.Fatalf("timeout")
		}
	}
	if err := ln.Close(); err != nil {
		t.Fatalf("unexpected error: %s", err)
	}
	wp.Stop()
}
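For orientation, below is a minimal sketch of how the unexported workerPool type is driven, mirroring what the tests above do: Start the pool, call Serve for each accepted connection, Stop on shutdown. The function name, the echo handler, and the listener argument are illustrative assumptions and are not part of the package; only workerPool, its fields, Start/Serve/Stop, and defaultLogger come from the sources above.
```
package fasthttp

import (
	"io"
	"net"
)

// exampleWorkerPoolUsage is an illustrative sketch only (not part of the
// package): it drives workerPool the same way the tests above do.
func exampleWorkerPoolUsage(ln net.Listener) {
	wp := &workerPool{
		// WorkerFunc must leave conn unclosed; the pool closes it after
		// WorkerFunc returns, unless the connection was hijacked.
		WorkerFunc: func(conn net.Conn) error {
			_, err := io.Copy(conn, conn) // trivial echo handler (assumption)
			return err
		},
		MaxWorkersCount: 100,
		Logger:          defaultLogger,
	}
	wp.Start()
	defer wp.Stop()

	for {
		conn, err := ln.Accept()
		if err != nil {
			return
		}
		// Serve returns false when all MaxWorkersCount workers are busy.
		if !wp.Serve(conn) {
			conn.Close()
		}
	}
}
```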