fasthttp-1.31.0/.github/workflows/lint.yml
name: Lint
on:
push:
branches:
- master
pull_request:
jobs:
lint:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- uses: actions/setup-go@v2
with:
go-version: 1.17.x
- run: go version
- run: diff -u <(echo -n) <(gofmt -d .)
- uses: golangci/golangci-lint-action@v2
with:
version: v1.28.3
fasthttp-1.31.0/.github/workflows/security.yml
name: Security
on:
push:
branches:
- master
pull_request:
jobs:
test:
strategy:
matrix:
go-version: [1.17.x]
platform: [ubuntu-latest]
runs-on: ${{ matrix.platform }}
steps:
- name: Install Go
uses: actions/setup-go@v1
with:
go-version: ${{ matrix.go-version }}
- name: Checkout code
uses: actions/checkout@v2
- name: Security
run: go get github.com/securego/gosec/cmd/gosec; `go env GOPATH`/bin/gosec -exclude=G104,G304 ./...
fasthttp-1.31.0/.github/workflows/test.yml
name: Test
on:
push:
branches:
- master
pull_request:
jobs:
test:
strategy:
matrix:
go-version: [1.15.x, 1.16.x, 1.17.x]
os: [ubuntu-latest, macos-latest, windows-latest]
runs-on: ${{ matrix.os }}
steps:
- uses: actions/checkout@v2
- uses: actions/setup-go@v2
with:
go-version: ${{ matrix.go-version }}
- run: go version
- run: go test ./...
- run: go test -race ./...
fasthttp-1.31.0/.gitignore
tags
*.pprof
*.fasthttp.gz
*.fasthttp.br
.idea
.DS_Store
vendor/
fasthttp-1.31.0/LICENSE
The MIT License (MIT)
Copyright (c) 2015-present Aliaksandr Valialkin, VertaMedia, Kirill Danshin, Erik Dubbelboer, FastHTTP Authors
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
fasthttp-1.31.0/README.md
# fasthttp [GoDoc](http://godoc.org/github.com/valyala/fasthttp) [Go Report Card](https://goreportcard.com/report/github.com/valyala/fasthttp)

Fast HTTP implementation for Go.
# fasthttp might not be for you!
fasthttp was designed for some high performance edge cases. **Unless** your server/client needs to handle **thousands of small to medium requests per second** and needs a consistent low millisecond response time, fasthttp might not be for you. **For most cases `net/http` is much better** as it's easier to use and can handle more cases. In most cases you won't even notice the performance difference.
## General info and links
Currently fasthttp is successfully used by [VertaMedia](https://vertamedia.com/)
in production, serving up to 200K rps from more than 1.5M concurrent keep-alive
connections per physical server.
[TechEmpower Benchmark round 19 results](https://www.techempower.com/benchmarks/#section=data-r19&hw=ph&test=plaintext)
[Server Benchmarks](#http-server-performance-comparison-with-nethttp)
[Client Benchmarks](#http-client-comparison-with-nethttp)
[Install](#install)
[Documentation](https://godoc.org/github.com/valyala/fasthttp)
[Examples from docs](https://godoc.org/github.com/valyala/fasthttp#pkg-examples)
[Code examples](examples)
[Awesome fasthttp tools](https://github.com/fasthttp)
[Switching from net/http to fasthttp](#switching-from-nethttp-to-fasthttp)
[Fasthttp best practices](#fasthttp-best-practices)
[Tricks with byte buffers](#tricks-with-byte-buffers)
[Related projects](#related-projects)
[FAQ](#faq)
## HTTP server performance comparison with [net/http](https://golang.org/pkg/net/http/)
In short, fasthttp server is up to 10 times faster than net/http.
Below are benchmark results.
*GOMAXPROCS=1*
net/http server:
```
$ GOMAXPROCS=1 go test -bench=NetHTTPServerGet -benchmem -benchtime=10s
BenchmarkNetHTTPServerGet1ReqPerConn 1000000 12052 ns/op 2297 B/op 29 allocs/op
BenchmarkNetHTTPServerGet2ReqPerConn 1000000 12278 ns/op 2327 B/op 24 allocs/op
BenchmarkNetHTTPServerGet10ReqPerConn 2000000 8903 ns/op 2112 B/op 19 allocs/op
BenchmarkNetHTTPServerGet10KReqPerConn 2000000 8451 ns/op 2058 B/op 18 allocs/op
BenchmarkNetHTTPServerGet1ReqPerConn10KClients 500000 26733 ns/op 3229 B/op 29 allocs/op
BenchmarkNetHTTPServerGet2ReqPerConn10KClients 1000000 23351 ns/op 3211 B/op 24 allocs/op
BenchmarkNetHTTPServerGet10ReqPerConn10KClients 1000000 13390 ns/op 2483 B/op 19 allocs/op
BenchmarkNetHTTPServerGet100ReqPerConn10KClients 1000000 13484 ns/op 2171 B/op 18 allocs/op
```
fasthttp server:
```
$ GOMAXPROCS=1 go test -bench=kServerGet -benchmem -benchtime=10s
BenchmarkServerGet1ReqPerConn 10000000 1559 ns/op 0 B/op 0 allocs/op
BenchmarkServerGet2ReqPerConn 10000000 1248 ns/op 0 B/op 0 allocs/op
BenchmarkServerGet10ReqPerConn 20000000 797 ns/op 0 B/op 0 allocs/op
BenchmarkServerGet10KReqPerConn 20000000 716 ns/op 0 B/op 0 allocs/op
BenchmarkServerGet1ReqPerConn10KClients 10000000 1974 ns/op 0 B/op 0 allocs/op
BenchmarkServerGet2ReqPerConn10KClients 10000000 1352 ns/op 0 B/op 0 allocs/op
BenchmarkServerGet10ReqPerConn10KClients 20000000 789 ns/op 2 B/op 0 allocs/op
BenchmarkServerGet100ReqPerConn10KClients 20000000 604 ns/op 0 B/op 0 allocs/op
```
*GOMAXPROCS=4*
net/http server:
```
$ GOMAXPROCS=4 go test -bench=NetHTTPServerGet -benchmem -benchtime=10s
BenchmarkNetHTTPServerGet1ReqPerConn-4 3000000 4529 ns/op 2389 B/op 29 allocs/op
BenchmarkNetHTTPServerGet2ReqPerConn-4 5000000 3896 ns/op 2418 B/op 24 allocs/op
BenchmarkNetHTTPServerGet10ReqPerConn-4 5000000 3145 ns/op 2160 B/op 19 allocs/op
BenchmarkNetHTTPServerGet10KReqPerConn-4 5000000 3054 ns/op 2065 B/op 18 allocs/op
BenchmarkNetHTTPServerGet1ReqPerConn10KClients-4 1000000 10321 ns/op 3710 B/op 30 allocs/op
BenchmarkNetHTTPServerGet2ReqPerConn10KClients-4 2000000 7556 ns/op 3296 B/op 24 allocs/op
BenchmarkNetHTTPServerGet10ReqPerConn10KClients-4 5000000 3905 ns/op 2349 B/op 19 allocs/op
BenchmarkNetHTTPServerGet100ReqPerConn10KClients-4 5000000 3435 ns/op 2130 B/op 18 allocs/op
```
fasthttp server:
```
$ GOMAXPROCS=4 go test -bench=kServerGet -benchmem -benchtime=10s
BenchmarkServerGet1ReqPerConn-4 10000000 1141 ns/op 0 B/op 0 allocs/op
BenchmarkServerGet2ReqPerConn-4 20000000 707 ns/op 0 B/op 0 allocs/op
BenchmarkServerGet10ReqPerConn-4 30000000 341 ns/op 0 B/op 0 allocs/op
BenchmarkServerGet10KReqPerConn-4 50000000 310 ns/op 0 B/op 0 allocs/op
BenchmarkServerGet1ReqPerConn10KClients-4 10000000 1119 ns/op 0 B/op 0 allocs/op
BenchmarkServerGet2ReqPerConn10KClients-4 20000000 644 ns/op 0 B/op 0 allocs/op
BenchmarkServerGet10ReqPerConn10KClients-4 30000000 346 ns/op 0 B/op 0 allocs/op
BenchmarkServerGet100ReqPerConn10KClients-4 50000000 282 ns/op 0 B/op 0 allocs/op
```
## HTTP client comparison with net/http
In short, fasthttp client is up to 10 times faster than net/http.
Below are benchmark results.
*GOMAXPROCS=1*
net/http client:
```
$ GOMAXPROCS=1 go test -bench='HTTPClient(Do|GetEndToEnd)' -benchmem -benchtime=10s
BenchmarkNetHTTPClientDoFastServer 1000000 12567 ns/op 2616 B/op 35 allocs/op
BenchmarkNetHTTPClientGetEndToEnd1TCP 200000 67030 ns/op 5028 B/op 56 allocs/op
BenchmarkNetHTTPClientGetEndToEnd10TCP 300000 51098 ns/op 5031 B/op 56 allocs/op
BenchmarkNetHTTPClientGetEndToEnd100TCP 300000 45096 ns/op 5026 B/op 55 allocs/op
BenchmarkNetHTTPClientGetEndToEnd1Inmemory 500000 24779 ns/op 5035 B/op 57 allocs/op
BenchmarkNetHTTPClientGetEndToEnd10Inmemory 1000000 26425 ns/op 5035 B/op 57 allocs/op
BenchmarkNetHTTPClientGetEndToEnd100Inmemory 500000 28515 ns/op 5045 B/op 57 allocs/op
BenchmarkNetHTTPClientGetEndToEnd1000Inmemory 500000 39511 ns/op 5096 B/op 56 allocs/op
```
fasthttp client:
```
$ GOMAXPROCS=1 go test -bench='kClient(Do|GetEndToEnd)' -benchmem -benchtime=10s
BenchmarkClientDoFastServer 20000000 865 ns/op 0 B/op 0 allocs/op
BenchmarkClientGetEndToEnd1TCP 1000000 18711 ns/op 0 B/op 0 allocs/op
BenchmarkClientGetEndToEnd10TCP 1000000 14664 ns/op 0 B/op 0 allocs/op
BenchmarkClientGetEndToEnd100TCP 1000000 14043 ns/op 1 B/op 0 allocs/op
BenchmarkClientGetEndToEnd1Inmemory 5000000 3965 ns/op 0 B/op 0 allocs/op
BenchmarkClientGetEndToEnd10Inmemory 3000000 4060 ns/op 0 B/op 0 allocs/op
BenchmarkClientGetEndToEnd100Inmemory 5000000 3396 ns/op 0 B/op 0 allocs/op
BenchmarkClientGetEndToEnd1000Inmemory 5000000 3306 ns/op 2 B/op 0 allocs/op
```
*GOMAXPROCS=4*
net/http client:
```
$ GOMAXPROCS=4 go test -bench='HTTPClient(Do|GetEndToEnd)' -benchmem -benchtime=10s
BenchmarkNetHTTPClientDoFastServer-4 2000000 8774 ns/op 2619 B/op 35 allocs/op
BenchmarkNetHTTPClientGetEndToEnd1TCP-4 500000 22951 ns/op 5047 B/op 56 allocs/op
BenchmarkNetHTTPClientGetEndToEnd10TCP-4 1000000 19182 ns/op 5037 B/op 55 allocs/op
BenchmarkNetHTTPClientGetEndToEnd100TCP-4 1000000 16535 ns/op 5031 B/op 55 allocs/op
BenchmarkNetHTTPClientGetEndToEnd1Inmemory-4 1000000 14495 ns/op 5038 B/op 56 allocs/op
BenchmarkNetHTTPClientGetEndToEnd10Inmemory-4 1000000 10237 ns/op 5034 B/op 56 allocs/op
BenchmarkNetHTTPClientGetEndToEnd100Inmemory-4 1000000 10125 ns/op 5045 B/op 56 allocs/op
BenchmarkNetHTTPClientGetEndToEnd1000Inmemory-4 1000000 11132 ns/op 5136 B/op 56 allocs/op
```
fasthttp client:
```
$ GOMAXPROCS=4 go test -bench='kClient(Do|GetEndToEnd)' -benchmem -benchtime=10s
BenchmarkClientDoFastServer-4 50000000 397 ns/op 0 B/op 0 allocs/op
BenchmarkClientGetEndToEnd1TCP-4 2000000 7388 ns/op 0 B/op 0 allocs/op
BenchmarkClientGetEndToEnd10TCP-4 2000000 6689 ns/op 0 B/op 0 allocs/op
BenchmarkClientGetEndToEnd100TCP-4 3000000 4927 ns/op 1 B/op 0 allocs/op
BenchmarkClientGetEndToEnd1Inmemory-4 10000000 1604 ns/op 0 B/op 0 allocs/op
BenchmarkClientGetEndToEnd10Inmemory-4 10000000 1458 ns/op 0 B/op 0 allocs/op
BenchmarkClientGetEndToEnd100Inmemory-4 10000000 1329 ns/op 0 B/op 0 allocs/op
BenchmarkClientGetEndToEnd1000Inmemory-4 10000000 1316 ns/op 5 B/op 0 allocs/op
```
## Install
```
go get -u github.com/valyala/fasthttp
```
## Switching from net/http to fasthttp
Unfortunately, fasthttp doesn't provide an API identical to net/http.
See the [FAQ](#faq) for details.
There is a [net/http -> fasthttp handler converter](https://godoc.org/github.com/valyala/fasthttp/fasthttpadaptor),
but it is better to write fasthttp request handlers by hand in order to use
all of the fasthttp advantages (especially high performance :) ).
Important points:
* Fasthttp works with [RequestHandler functions](https://godoc.org/github.com/valyala/fasthttp#RequestHandler)
instead of objects implementing [Handler interface](https://golang.org/pkg/net/http/#Handler).
Fortunately, it is easy to pass bound struct methods to fasthttp:
```go
type MyHandler struct {
foobar string
}
// request handler in net/http style, i.e. method bound to MyHandler struct.
func (h *MyHandler) HandleFastHTTP(ctx *fasthttp.RequestCtx) {
// notice that we may access MyHandler properties here - see h.foobar.
fmt.Fprintf(ctx, "Hello, world! Requested path is %q. Foobar is %q",
ctx.Path(), h.foobar)
}
// request handler in fasthttp style, i.e. just plain function.
func fastHTTPHandler(ctx *fasthttp.RequestCtx) {
fmt.Fprintf(ctx, "Hi there! RequestURI is %q", ctx.RequestURI())
}
// pass bound struct method to fasthttp
myHandler := &MyHandler{
foobar: "foobar",
}
fasthttp.ListenAndServe(":8080", myHandler.HandleFastHTTP)
// pass plain function to fasthttp
fasthttp.ListenAndServe(":8081", fastHTTPHandler)
```
* The [RequestHandler](https://godoc.org/github.com/valyala/fasthttp#RequestHandler)
accepts only one argument - [RequestCtx](https://godoc.org/github.com/valyala/fasthttp#RequestCtx).
It contains all the functionality required for http request processing
and response writing. Below is an example of a simple request handler conversion
from net/http to fasthttp.
```go
// net/http request handler
requestHandler := func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/foo":
fooHandler(w, r)
case "/bar":
barHandler(w, r)
default:
http.Error(w, "Unsupported path", http.StatusNotFound)
}
}
```
```go
// the corresponding fasthttp request handler
requestHandler := func(ctx *fasthttp.RequestCtx) {
switch string(ctx.Path()) {
case "/foo":
fooHandler(ctx)
case "/bar":
barHandler(ctx)
default:
ctx.Error("Unsupported path", fasthttp.StatusNotFound)
}
}
```
* Fasthttp allows setting response headers and writing response body
in an arbitrary order. There is no 'headers first, then body' restriction
like in net/http. The following code is valid for fasthttp:
```go
requestHandler := func(ctx *fasthttp.RequestCtx) {
// set some headers and status code first
ctx.SetContentType("foo/bar")
ctx.SetStatusCode(fasthttp.StatusOK)
// then write the first part of body
fmt.Fprintf(ctx, "this is the first part of body\n")
// then set more headers
ctx.Response.Header.Set("Foo-Bar", "baz")
// then write more body
fmt.Fprintf(ctx, "this is the second part of body\n")
// then override already written body
ctx.SetBody([]byte("this is completely new body contents"))
// then update status code
ctx.SetStatusCode(fasthttp.StatusNotFound)
// basically, anything may be updated many times before
// returning from RequestHandler.
//
// Unlike net/http fasthttp doesn't put response to the wire until
// returning from RequestHandler.
}
```
* Fasthttp doesn't provide [ServeMux](https://golang.org/pkg/net/http/#ServeMux),
but there are more powerful third-party routers and web frameworks
with fasthttp support:
* [fasthttp-routing](https://github.com/qiangxue/fasthttp-routing)
* [router](https://github.com/fasthttp/router)
* [lu](https://github.com/vincentLiuxiang/lu)
* [atreugo](https://github.com/savsgio/atreugo)
* [Fiber](https://github.com/gofiber/fiber)
* [Gearbox](https://github.com/gogearbox/gearbox)
net/http code with a simple ServeMux is trivially converted to fasthttp code:
```go
// net/http code
m := &http.ServeMux{}
m.HandleFunc("/foo", fooHandlerFunc)
m.HandleFunc("/bar", barHandlerFunc)
m.Handle("/baz", bazHandler)
http.ListenAndServe(":80", m)
```
```go
// the corresponding fasthttp code
m := func(ctx *fasthttp.RequestCtx) {
switch string(ctx.Path()) {
case "/foo":
fooHandlerFunc(ctx)
case "/bar":
barHandlerFunc(ctx)
case "/baz":
bazHandler.HandlerFunc(ctx)
default:
ctx.Error("not found", fasthttp.StatusNotFound)
}
}
fasthttp.ListenAndServe(":80", m)
```
* net/http -> fasthttp conversion table (a short sketch applying a few of these mappings follows the list):
* All the pseudocode below assumes w, r and ctx have these types:
```go
var (
w http.ResponseWriter
r *http.Request
ctx *fasthttp.RequestCtx
)
```
* r.Body -> [ctx.PostBody()](https://godoc.org/github.com/valyala/fasthttp#RequestCtx.PostBody)
* r.URL.Path -> [ctx.Path()](https://godoc.org/github.com/valyala/fasthttp#RequestCtx.Path)
* r.URL -> [ctx.URI()](https://godoc.org/github.com/valyala/fasthttp#RequestCtx.URI)
* r.Method -> [ctx.Method()](https://godoc.org/github.com/valyala/fasthttp#RequestCtx.Method)
* r.Header -> [ctx.Request.Header](https://godoc.org/github.com/valyala/fasthttp#RequestHeader)
* r.Header.Get() -> [ctx.Request.Header.Peek()](https://godoc.org/github.com/valyala/fasthttp#RequestHeader.Peek)
* r.Host -> [ctx.Host()](https://godoc.org/github.com/valyala/fasthttp#RequestCtx.Host)
* r.Form -> [ctx.QueryArgs()](https://godoc.org/github.com/valyala/fasthttp#RequestCtx.QueryArgs) +
[ctx.PostArgs()](https://godoc.org/github.com/valyala/fasthttp#RequestCtx.PostArgs)
* r.PostForm -> [ctx.PostArgs()](https://godoc.org/github.com/valyala/fasthttp#RequestCtx.PostArgs)
* r.FormValue() -> [ctx.FormValue()](https://godoc.org/github.com/valyala/fasthttp#RequestCtx.FormValue)
* r.FormFile() -> [ctx.FormFile()](https://godoc.org/github.com/valyala/fasthttp#RequestCtx.FormFile)
* r.MultipartForm -> [ctx.MultipartForm()](https://godoc.org/github.com/valyala/fasthttp#RequestCtx.MultipartForm)
* r.RemoteAddr -> [ctx.RemoteAddr()](https://godoc.org/github.com/valyala/fasthttp#RequestCtx.RemoteAddr)
* r.RequestURI -> [ctx.RequestURI()](https://godoc.org/github.com/valyala/fasthttp#RequestCtx.RequestURI)
* r.TLS -> [ctx.IsTLS()](https://godoc.org/github.com/valyala/fasthttp#RequestCtx.IsTLS)
* r.Cookie() -> [ctx.Request.Header.Cookie()](https://godoc.org/github.com/valyala/fasthttp#RequestHeader.Cookie)
* r.Referer() -> [ctx.Referer()](https://godoc.org/github.com/valyala/fasthttp#RequestCtx.Referer)
* r.UserAgent() -> [ctx.UserAgent()](https://godoc.org/github.com/valyala/fasthttp#RequestCtx.UserAgent)
* w.Header() -> [ctx.Response.Header](https://godoc.org/github.com/valyala/fasthttp#ResponseHeader)
* w.Header().Set() -> [ctx.Response.Header.Set()](https://godoc.org/github.com/valyala/fasthttp#ResponseHeader.Set)
* w.Header().Set("Content-Type") -> [ctx.SetContentType()](https://godoc.org/github.com/valyala/fasthttp#RequestCtx.SetContentType)
* w.Header().Set("Set-Cookie") -> [ctx.Response.Header.SetCookie()](https://godoc.org/github.com/valyala/fasthttp#ResponseHeader.SetCookie)
* w.Write() -> [ctx.Write()](https://godoc.org/github.com/valyala/fasthttp#RequestCtx.Write),
[ctx.SetBody()](https://godoc.org/github.com/valyala/fasthttp#RequestCtx.SetBody),
[ctx.SetBodyStream()](https://godoc.org/github.com/valyala/fasthttp#RequestCtx.SetBodyStream),
[ctx.SetBodyStreamWriter()](https://godoc.org/github.com/valyala/fasthttp#RequestCtx.SetBodyStreamWriter)
* w.WriteHeader() -> [ctx.SetStatusCode()](https://godoc.org/github.com/valyala/fasthttp#RequestCtx.SetStatusCode)
* w.(http.Hijacker).Hijack() -> [ctx.Hijack()](https://godoc.org/github.com/valyala/fasthttp#RequestCtx.Hijack)
* http.Error() -> [ctx.Error()](https://godoc.org/github.com/valyala/fasthttp#RequestCtx.Error)
* http.FileServer() -> [fasthttp.FSHandler()](https://godoc.org/github.com/valyala/fasthttp#FSHandler),
[fasthttp.FS](https://godoc.org/github.com/valyala/fasthttp#FS)
* http.ServeFile() -> [fasthttp.ServeFile()](https://godoc.org/github.com/valyala/fasthttp#ServeFile)
* http.Redirect() -> [ctx.Redirect()](https://godoc.org/github.com/valyala/fasthttp#RequestCtx.Redirect)
* http.NotFound() -> [ctx.NotFound()](https://godoc.org/github.com/valyala/fasthttp#RequestCtx.NotFound)
* http.StripPrefix() -> [fasthttp.PathRewriteFunc](https://godoc.org/github.com/valyala/fasthttp#PathRewriteFunc)
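A hedged sketch applying a few of the mappings above; the `X-Request-Id` header is just an illustrative name, and the usual `fmt`, `net/http` and `fasthttp` imports are assumed:

```go
// net/http style
func netHTTPHandler(w http.ResponseWriter, r *http.Request) {
	id := r.Header.Get("X-Request-Id")
	w.Header().Set("Content-Type", "text/plain")
	w.WriteHeader(http.StatusOK)
	fmt.Fprintf(w, "path=%s id=%s", r.URL.Path, id)
}

// the same handler written with the fasthttp equivalents from the table
func fastHTTPEquivalent(ctx *fasthttp.RequestCtx) {
	id := ctx.Request.Header.Peek("X-Request-Id") // r.Header.Get() -> ctx.Request.Header.Peek()
	ctx.SetContentType("text/plain")              // w.Header().Set("Content-Type") -> ctx.SetContentType()
	ctx.SetStatusCode(fasthttp.StatusOK)          // w.WriteHeader() -> ctx.SetStatusCode()
	fmt.Fprintf(ctx, "path=%s id=%s", ctx.Path(), id) // r.URL.Path -> ctx.Path()
}
```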
* *VERY IMPORTANT!* Fasthttp disallows holding references
to [RequestCtx](https://godoc.org/github.com/valyala/fasthttp#RequestCtx) or to its
members after returning from [RequestHandler](https://godoc.org/github.com/valyala/fasthttp#RequestHandler).
Otherwise [data races](http://blog.golang.org/race-detector) are inevitable.
Carefully inspect all net/http request handlers converted to fasthttp and check whether
they retain references to RequestCtx or to its members after returning.
RequestCtx provides the following _band aids_ for this case:
* Wrap RequestHandler into [TimeoutHandler](https://godoc.org/github.com/valyala/fasthttp#TimeoutHandler).
* Call [TimeoutError](https://godoc.org/github.com/valyala/fasthttp#RequestCtx.TimeoutError)
before returning from RequestHandler if there are references to RequestCtx or to its members.
See [the example](https://godoc.org/github.com/valyala/fasthttp#example-RequestCtx-TimeoutError)
for more details.
Use this brilliant tool - the [race detector](http://blog.golang.org/race-detector) -
to detect and eliminate data races in your program. If you detect a
data race related to fasthttp in your program, there is a high probability
you forgot to call [TimeoutError](https://godoc.org/github.com/valyala/fasthttp#RequestCtx.TimeoutError)
before returning from [RequestHandler](https://godoc.org/github.com/valyala/fasthttp#RequestHandler).
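As a rough illustration of the _band aid_ above (not the canonical example from the docs; `doSlowWork` is a hypothetical helper and the `time` import is assumed):

```go
requestHandler := func(ctx *fasthttp.RequestCtx) {
	done := make(chan struct{})
	go func() {
		// This goroutine keeps a reference to ctx and may write to it
		// after the handler would normally have returned.
		doSlowWork(ctx)
		close(done)
	}()

	select {
	case <-done:
		// Background work finished; it is safe to return normally.
	case <-time.After(time.Second):
		// References to ctx are still alive in the goroutine, so tell
		// fasthttp about it via TimeoutError before returning.
		ctx.TimeoutError("request timed out")
	}
}
```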
* Blind switching from net/http to fasthttp won't give you a performance boost.
While fasthttp is optimized for speed, its performance may easily be saturated
by a slow [RequestHandler](https://godoc.org/github.com/valyala/fasthttp#RequestHandler).
So [profile](http://blog.golang.org/profiling-go-programs) and optimize your
code after switching to fasthttp. For instance, use [quicktemplate](https://github.com/valyala/quicktemplate)
instead of [html/template](https://golang.org/pkg/html/template/).
* See also [fasthttputil](https://godoc.org/github.com/valyala/fasthttp/fasthttputil),
[fasthttpadaptor](https://godoc.org/github.com/valyala/fasthttp/fasthttpadaptor) and
[expvarhandler](https://godoc.org/github.com/valyala/fasthttp/expvarhandler).
## Performance optimization tips for multi-core systems
* Use a [reuseport](https://godoc.org/github.com/valyala/fasthttp/reuseport) listener (a minimal sketch follows this list).
* Run a separate server instance per CPU core with GOMAXPROCS=1.
* Pin each server instance to a separate CPU core using [taskset](http://linux.die.net/man/1/taskset).
* Ensure the interrupts of multiqueue network card are evenly distributed between CPU cores.
See [this article](https://blog.cloudflare.com/how-to-achieve-low-latency/) for details.
* Use the latest version of Go as each version contains performance improvements.
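A minimal sketch of the reuseport tip from the list above, assuming the `fasthttp/reuseport` sub-package; several copies of this process (one per CPU core, each pinned with taskset and run with GOMAXPROCS=1) can listen on the same port:

```go
package main

import (
	"log"

	"github.com/valyala/fasthttp"
	"github.com/valyala/fasthttp/reuseport"
)

func main() {
	// SO_REUSEPORT lets multiple server processes share the same address.
	ln, err := reuseport.Listen("tcp4", "0.0.0.0:8080")
	if err != nil {
		log.Fatalf("cannot listen: %s", err)
	}

	handler := func(ctx *fasthttp.RequestCtx) {
		ctx.WriteString("hello") //nolint:errcheck
	}
	if err := fasthttp.Serve(ln, handler); err != nil {
		log.Fatalf("error serving: %s", err)
	}
}
```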
## Fasthttp best practices
* Do not allocate objects and `[]byte` buffers - just reuse them as much
as possible. Fasthttp API design encourages this.
* [sync.Pool](https://golang.org/pkg/sync/#Pool) is your best friend (a small reuse sketch follows this list).
* [Profile your program](http://blog.golang.org/profiling-go-programs)
in production.
`go tool pprof --alloc_objects your-program mem.pprof` usually gives better
insight into optimization opportunities than `go tool pprof your-program cpu.pprof`.
* Write [tests and benchmarks](https://golang.org/pkg/testing/) for hot paths.
* Avoid conversion between `[]byte` and `string`, since this may result in memory
allocation+copy. Fasthttp API provides functions for both `[]byte` and `string` -
use these functions instead of converting manually between `[]byte` and `string`.
There are some exceptions - see [this wiki page](https://github.com/golang/go/wiki/CompilerOptimizations#string-and-byte)
for more details.
* Verify your tests and production code under
[race detector](https://golang.org/doc/articles/race_detector.html) on a regular basis.
* Prefer [quicktemplate](https://github.com/valyala/quicktemplate) instead of
[html/template](https://golang.org/pkg/html/template/) in your webserver.
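A small sketch of the reuse pattern from the list above (illustrative only; a per-request scratch buffer is recycled through a `sync.Pool` instead of being allocated on every call, and the `sync` import is assumed):

```go
var bufPool = sync.Pool{
	New: func() interface{} { return make([]byte, 0, 256) },
}

func handler(ctx *fasthttp.RequestCtx) {
	buf := bufPool.Get().([]byte)
	buf = buf[:0] // keep the backing array, start from an empty slice

	buf = append(buf, "requested path: "...)
	buf = append(buf, ctx.Path()...)
	ctx.Write(buf) //nolint:errcheck

	bufPool.Put(buf) //nolint:staticcheck // good enough for a sketch
}
```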
## Tricks with `[]byte` buffers
The following tricks are used by fasthttp. Use them in your code too.
* Standard Go functions accept nil buffers
```go
var (
// both buffers are uninitialized
dst []byte
src []byte
)
dst = append(dst, src...) // is legal if dst is nil and/or src is nil
copy(dst, src) // is legal if dst is nil and/or src is nil
(string(src) == "") // is true if src is nil
(len(src) == 0) // is true if src is nil
src = src[:0] // works like a charm with nil src
// this for loop doesn't panic if src is nil
for i, ch := range src {
doSomething(i, ch)
}
```
So throw away nil checks for `[]byte` buffers from your code. For example,
```go
srcLen := 0
if src != nil {
srcLen = len(src)
}
```
becomes
```go
srcLen := len(src)
```
* A string may be appended to a `[]byte` buffer with `append`
```go
dst = append(dst, "foobar"...)
```
* A `[]byte` buffer may be extended to its capacity.
```go
buf := make([]byte, 100)
a := buf[:10] // len(a) == 10, cap(a) == 100.
b := a[:100] // is valid, since cap(a) == 100.
```
* All fasthttp functions accept a nil `[]byte` buffer
```go
statusCode, body, err := fasthttp.Get(nil, "http://google.com/")
uintBuf := fasthttp.AppendUint(nil, 1234)
```
## Related projects
* [fasthttp](https://github.com/fasthttp) - various useful
helpers for projects based on fasthttp.
* [fasthttp-routing](https://github.com/qiangxue/fasthttp-routing) - fast and
powerful routing package for fasthttp servers.
* [http2](https://github.com/dgrr/http2) - HTTP/2 implementation for fasthttp.
* [router](https://github.com/fasthttp/router) - a high
performance fasthttp request router that scales well.
* [fastws](https://github.com/fasthttp/fastws) - Bloatless WebSocket package made for fasthttp
to handle Read/Write operations concurrently.
* [gramework](https://github.com/gramework/gramework) - a web framework made by one of fasthttp maintainers
* [lu](https://github.com/vincentLiuxiang/lu) - a high performance
go middleware web framework which is based on fasthttp.
* [websocket](https://github.com/fasthttp/websocket) - Gorilla-based
websocket implementation for fasthttp.
* [websocket](https://github.com/dgrr/websocket) - Event-based high-performance WebSocket library for zero-allocation
websocket servers and clients.
* [fasthttpsession](https://github.com/phachon/fasthttpsession) - a fast and powerful session package for fasthttp servers.
* [atreugo](https://github.com/savsgio/atreugo) - High performance and extensible micro web framework with zero memory allocations in hot paths.
* [kratgo](https://github.com/savsgio/kratgo) - Simple, lightweight and ultra-fast HTTP Cache to speed up your websites.
* [kit-plugins](https://github.com/wencan/kit-plugins/tree/master/transport/fasthttp) - go-kit transport implementation for fasthttp.
* [Fiber](https://github.com/gofiber/fiber) - An Expressjs inspired web framework running on Fasthttp
* [Gearbox](https://github.com/gogearbox/gearbox) - :gear: gearbox is a web framework written in Go with a focus on high performance and memory optimization
## FAQ
* *Why create yet another http package instead of optimizing net/http?*
Because the net/http API limits many optimization opportunities.
For example:
* net/http Request object lifetime isn't limited by request handler execution
time, so the server must create a new request object for each request instead
of reusing existing objects like fasthttp does.
* net/http headers are stored in a `map[string][]string`. So the server
must parse all the headers, convert them from `[]byte` to `string` and put
them into the map before calling user-provided request handler.
This all requires unnecessary memory allocations avoided by fasthttp.
* net/http client API requires creating a new response object for each request.
* *Why is the fasthttp API incompatible with net/http?*
Because the net/http API limits many optimization opportunities. See the answer
above for more details. Also, certain parts of the net/http API are suboptimal
in practice:
* Compare [net/http connection hijacking](https://golang.org/pkg/net/http/#Hijacker)
to [fasthttp connection hijacking](https://godoc.org/github.com/valyala/fasthttp#RequestCtx.Hijack).
* Compare [net/http Request.Body reading](https://golang.org/pkg/net/http/#Request)
to [fasthttp request body reading](https://godoc.org/github.com/valyala/fasthttp#RequestCtx.PostBody).
* *Why doesn't fasthttp support HTTP/2.0 and WebSockets?*
[HTTP/2.0 support](https://github.com/fasthttp/http2) is in progress. [WebSockets](https://github.com/fasthttp/websockets) have already been implemented.
Third parties also may use [RequestCtx.Hijack](https://godoc.org/github.com/valyala/fasthttp#RequestCtx.Hijack)
for implementing these goodies.
* *Are there known net/http advantages compared to fasthttp?*
Yes:
* net/http supports [HTTP/2.0 starting from go1.6](https://http2.golang.org/).
* net/http API is stable, while fasthttp API constantly evolves.
* net/http handles more HTTP corner cases.
* net/http can stream both request and response bodies
* net/http can handle bigger bodies as it doesn't read the whole body into memory
* net/http should contain fewer bugs, since it is used and tested by a much
wider audience.
* *Why does the fasthttp API prefer returning `[]byte` instead of `string`?*
Because `[]byte` to `string` conversion isn't free - it requires memory
allocation and a copy. Feel free to wrap the returned `[]byte` result in
`string()` if you prefer working with strings instead of byte slices.
But be aware that this has non-zero overhead.
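For example (illustrative only):
```go
p := ctx.Path()         // []byte view, no allocation
s := string(ctx.Path()) // allocates and copies; do this only when a string is really needed
```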
* *Which Go versions are supported by fasthttp?*
Go 1.15.x and newer. Older versions won't be supported.
* *Please provide real benchmark data and server information*
See [this issue](https://github.com/valyala/fasthttp/issues/4).
* *Are there plans to add request routing to fasthttp?*
There are no plans to add request routing into fasthttp.
Use third-party routers and web frameworks with fasthttp support:
* [fasthttp-routing](https://github.com/qiangxue/fasthttp-routing)
* [router](https://github.com/fasthttp/router)
* [gramework](https://github.com/gramework/gramework)
* [lu](https://github.com/vincentLiuxiang/lu)
* [atreugo](https://github.com/savsgio/atreugo)
* [Fiber](https://github.com/gofiber/fiber)
* [Gearbox](https://github.com/gogearbox/gearbox)
See also [this issue](https://github.com/valyala/fasthttp/issues/9) for more info.
* *I detected data race in fasthttp!*
Cool! [File a bug](https://github.com/valyala/fasthttp/issues/new). But before
doing so, check the following in your code:
* Make sure there are no references to [RequestCtx](https://godoc.org/github.com/valyala/fasthttp#RequestCtx)
or to its members after returning from [RequestHandler](https://godoc.org/github.com/valyala/fasthttp#RequestHandler).
* Make sure you call [TimeoutError](https://godoc.org/github.com/valyala/fasthttp#RequestCtx.TimeoutError)
before returning from [RequestHandler](https://godoc.org/github.com/valyala/fasthttp#RequestHandler)
if there are references to [RequestCtx](https://godoc.org/github.com/valyala/fasthttp#RequestCtx)
or to its members, which may be accessed by other goroutines.
* *I didn't find an answer to my question here*
Try exploring [these questions](https://github.com/valyala/fasthttp/issues?q=label%3Aquestion).
fasthttp-1.31.0/SECURITY.md
### TL;DR
We use a simplified version of [Golang Security Policy](https://golang.org/security).
For example, for now we skip CVE assignment.
### Reporting a Security Bug
Please report to us any issues you find. This document explains how to do that and what to expect in return.
All security bugs in our releases should be reported by email to oss-security@highload.solutions.
This mail is delivered to a small security team.
Your email will be acknowledged within 24 hours, and you'll receive a more detailed response
to your email within 72 hours indicating the next steps in handling your report.
For critical problems, you can encrypt your report using our PGP key (listed below).
Please use a descriptive subject line for your report email.
After the initial reply to your report, the security team will
endeavor to keep you informed of the progress being made towards a fix and full announcement.
These updates will be sent at least every five days.
In reality, this is more likely to be every 24-48 hours.
If you have not received a reply to your email within 48 hours or you have not heard from the security
team for the past five days please contact us by email to developers@highload.solutions or by Telegram message
to [our support](https://t.me/highload_support).
Please note that the developers@highload.solutions list includes all developers, some of whom may be outside our opensource security team.
When escalating on this list, please do not disclose the details of the issue.
Simply state that you're trying to reach a member of the security team.
### Flagging Existing Issues as Security-related
If you believe that an existing issue is security-related, we ask that you send an email to oss-security@highload.solutions.
The email should include the issue ID and a short description of why it should be handled according to this security policy.
### Disclosure Process
Our project uses the following disclosure process:
- Once the security report is received it is assigned a primary handler. This person coordinates the fix and release process.
- The issue is confirmed and a list of affected software is determined.
- Code is audited to find any potential similar problems.
- Fixes are prepared for the two most recent major releases and the head/master revision. These fixes are not yet committed to the public repository.
- To notify users, a new issue without security details is submitted to our GitHub repository.
- Three working days following this notification, the fixes are applied to the public repository and a new release is issued.
- On the date that the fixes are applied, an announcement is published in the issue.
This process can take some time, especially when coordination is required with maintainers of other projects.
Every effort will be made to handle the bug in as timely a manner as possible, however it's important that we follow
the process described above to ensure that disclosures are handled consistently.
### Receiving Security Updates
The best way to receive security announcements is to subscribe ("Watch") to our repository.
Any GitHub issues pertaining to a security issue will be prefixed with [security].
### Comments on This Policy
If you have any suggestions to improve this policy, please send an email to oss-security@highload.solutions for discussion.
### PGP Key for oss-security@highload.ltd
We accept PGP-encrypted email, but the majority of the security team are not regular PGP users
so it's somewhat inconvenient. Please only use PGP for critical security reports.
```
-----BEGIN PGP PUBLIC KEY BLOCK-----
mQINBFzdjYUBEACa3YN+QVSlnXofUjxr+YrmIaF+da0IUq+TRM4aqUXALsemEdGh
yIl7Z6qOOy1d2kPe6t//H9l/92lJ1X7i6aEBK4n/pnPZkwbpy9gGpebgvTZFvcbe
mFhF6k1FM35D8TxneJSjizPyGhJPqcr5qccqf8R64TlQx5Ud1JqT2l8P1C5N7gNS
lEYXq1h4zBCvTWk1wdeLRRPx7Bn6xrgmyu/k61dLoJDvpvWNATVFDA67oTrPgzTW
xtLbbk/xm0mK4a8zMzIpNyz1WkaJW9+4HFXaL+yKlsx7iHe2O7VlGoqS0kdeQup4
1HIw/P7yc0jBlNMLUzpuA6ElYUwESWsnCI71YY1x4rKgI+GqH1mWwgn7tteuXQtb
Zj0vEdjK3IKIOSbzbzAvSbDt8F1+o7EMtdy1eUysjKSQgFkDlT6JRmYvEup5/IoG
iknh/InQq9RmGFKii6pXWWoltC0ebfCwYOXvymyDdr/hYDqJeHS9Tenpy86Doaaf
HGf5nIFAMB2G5ctNpBwzNXR2MAWkeHQgdr5a1xmog0hS125usjnUTet3QeCyo4kd
gVouoOroMcqFFUXdYaMH4c3KWz0afhTmIaAsFFOv/eMdadVA4QyExTJf3TAoQ+kH
lKDlbOAIxEZWRPDFxMRixaVPQC+VxhBcaQ+yNoaUkM0V2m8u8sDBpzi1OQARAQAB
tDxPU1MgU2VjdXJpdHksIEhpZ2hsb2FkIExURCA8b3NzLXNlY3VyaXR5QGhpZ2hs
b2FkLnNvbHV0aW9ucz6JAlQEEwEIAD4WIQRljYp380uKq2g8TeqsQcvu+Qp2TAUC
XN2NhQIbAwUJB4YfgAULCQgHAgYVCgkICwIEFgIDAQIeAQIXgAAKCRCsQcvu+Qp2
TKmED/96YoQoOjD28blFFrigvAsiNcNNZoX9I0dX1lNpD83fBJf+/9i+x4jqUnI5
5XK/DFTDbhpw8kQBpxS9eEuIYnuo0RdLLp1ctNWTlpwfyHn92mGddl/uBdYHUuUk
cjhIQcFaCcWRY+EpamDlv1wmZ83IwBr8Hu5FS+/Msyw1TBvtTRVKW1KoGYMYoXLk
BzIglRPwn821B6s4BvK/RJnZkrmHMBZBfYMf+iSMSYd2yPmfT8wbcAjgjLfQa28U
gbt4u9xslgKjuM83IqwFfEXBnm7su3OouGWqc+62mQTsbnK65zRFnx6GXRXC1BAi
6m9Tm1PU0IiINz66ainquspkXYeHjd9hTwfR3BdFnzBTRRM01cKMFabWbLj8j0p8
fF4g9cxEdiLrzEF7Yz4WY0mI4Cpw4eJZfsHMc07Jn7QxfJhIoq+rqBOtEmTjnxMh
aWeykoXMHlZN4K0ZrAytozVH1D4bugWA9Zuzi9U3F9hrVVABm11yyhd2iSqI6/FR
GcCFOCBW1kEJbzoEguub+BV8LDi8ldljHalvur5k/VFhoDBxniYNsKmiCLVCmDWs
/nF84hCReAOJt0vDGwqHe3E2BFFPbKwdJLRNkjxBY0c/pvaV+JxbWQmaxDZNeIFV
hFcVGp48HNY3qLWZdsQIfT9m1masJFLVuq8Wx7bYs8Et5eFnH7kCDQRc3Y2FARAA
2DJWAxABydyIdCxgFNdqnYyWS46vh2DmLmRMqgasNlD0ozG4S9bszBsgnUI2Xs06
J76kFRh8MMHcu9I4lUKCQzfrA4uHkiOK5wvNCaWP+C6JUYNHsqPwk/ILO3gtQ/Ws
LLf/PW3rJZVOZB+WY8iaYc20l5vukTaVw4qbEi9dtLkJvVpNHt//+jayXU6s3ew1
2X5xdwyAZxaxlnzFaY/Xo/qR+bZhVFC0T9pAECnHv9TVhFGp0JE9ipPGnro5xTIS
LttdAkzv4AuSVTIgWgTkh8nN8t7STJqfPEv0I12nmmYHMUyTYOurkfskF3jY2x6x
8l02NQ4d5KdC3ReV1j51swrGcZCwsWNp51jnEXKwo+B0NM5OmoRrNJgF2iDgLehs
hP00ljU7cB8/1/7kdHZStYaUHICFOFqHzg415FlYm+jpY0nJp/b9BAO0d0/WYnEe
Xjihw8EVBAqzEt4kay1BQonZAypeYnGBJr7vNvdiP+mnRwly5qZSGiInxGvtZZFt
zL1E3osiF+muQxFcM63BeGdJeYXy+MoczkWa4WNggfcHlGAZkMYiv28zpr4PfrK9
mvj4Nu8s71PE9pPpBoZcNDf9v1sHuu96jDSITsPx5YMvvKZWhzJXFKzk6YgAsNH/
MF0G+/qmKJZpCdvtHKpYM1uHX85H81CwWJFfBPthyD8AEQEAAYkCPAQYAQgAJhYh
BGWNinfzS4qraDxN6qxBy+75CnZMBQJc3Y2FAhsMBQkHhh+AAAoJEKxBy+75CnZM
Rn8P/RyL1bhU4Q4WpvmlkepCAwNA0G3QvnKcSZNHEPE5h7H3IyrA/qy16A9eOsgm
sthsHYlo5A5lRIy4wPHkFCClMrMHdKuoS72//qgw+oOrBcwb7Te+Nas+ewhaJ7N9
vAX06vDH9bLl52CPbtats5+eBpePgP3HDPxd7CWHxq9bzJTbzqsTkN7JvoovR2dP
itPJDij7QYLYVEM1t7QxUVpVwAjDi/kCtC9ts5L+V0snF2n3bHZvu04EXdpvxOQI
pG/7Q+/WoI8NU6Bb/FA3tJGYIhSwI3SY+5XV/TAZttZaYSh2SD8vhc+eo+gW9sAN
xa+VESBQCht9+tKIwEwHs1efoRgFdbwwJ2c+33+XydQ6yjdXoX1mn2uyCr82jorZ
xTzbkY04zr7oZ+0fLpouOFg/mrSL4w2bWEhdHuyoVthLBjnRme0wXCaS3g3mYdLG
nSUkogOGOOvvvBtoq/vfx0Eu79piUtw5D8yQSrxLDuz8GxCrVRZ0tYIHb26aTE9G
cDsW/Lg5PjcY/LgVNEWOxDQDFVurlImnlVJFb3q+NrWvPbgeIEWwJDCay/z25SEH
k3bSOXLp8YGRnlkWUmoeL4g/CCK52iAAlfscZNoKMILhBnbCoD657jpa5GQKJj/U
Q8kjgr7kwV/RSosNV9HCPj30mVyiCQ1xg+ZLzMKXVCuBWd+G
=lnt2
-----END PGP PUBLIC KEY BLOCK-----
```
fasthttp-1.31.0/TODO
- SessionClient with referer and cookies support.
- ProxyHandler similar to FSHandler.
- WebSockets. See https://tools.ietf.org/html/rfc6455 .
- HTTP/2.0. See https://tools.ietf.org/html/rfc7540 .
fasthttp-1.31.0/allocation_test.go
//go:build !race
// +build !race
package fasthttp
import (
"net"
"testing"
)
func TestAllocationServeConn(t *testing.T) {
s := &Server{
Handler: func(ctx *RequestCtx) {
},
}
rw := &readWriter{}
// Make space for the request and response here so it
// doesn't allocate within the test.
rw.r.Grow(1024)
rw.w.Grow(1024)
n := testing.AllocsPerRun(100, func() {
rw.r.WriteString("GET / HTTP/1.1\r\nHost: google.com\r\nCookie: foo=bar\r\n\r\n")
if err := s.ServeConn(rw); err != nil {
t.Fatal(err)
}
// Reset the write buffer to make space for the next response.
rw.w.Reset()
})
if n != 0 {
t.Fatalf("expected 0 allocations, got %f", n)
}
}
func TestAllocationClient(t *testing.T) {
ln, err := net.Listen("tcp4", "127.0.0.1:0")
if err != nil {
t.Fatalf("cannot listen: %s", err)
}
defer ln.Close()
s := &Server{
Handler: func(ctx *RequestCtx) {
},
}
go s.Serve(ln) //nolint:errcheck
c := &Client{}
url := "http://test:test@" + ln.Addr().String() + "/foo?bar=baz"
n := testing.AllocsPerRun(100, func() {
req := AcquireRequest()
res := AcquireResponse()
req.SetRequestURI(url)
if err := c.Do(req, res); err != nil {
t.Fatal(err)
}
ReleaseRequest(req)
ReleaseResponse(res)
})
if n != 0 {
t.Fatalf("expected 0 allocations, got %f", n)
}
}
func TestAllocationURI(t *testing.T) {
uri := []byte("http://username:password@hello.%e4%b8%96%e7%95%8c.com/some/path?foo=bar#test")
n := testing.AllocsPerRun(100, func() {
u := AcquireURI()
u.Parse(nil, uri) //nolint:errcheck
ReleaseURI(u)
})
if n != 0 {
t.Fatalf("expected 0 allocations, got %f", n)
}
}
fasthttp-1.31.0/args.go
package fasthttp
import (
"bytes"
"errors"
"io"
"sort"
"sync"
"github.com/valyala/bytebufferpool"
)
const (
argsNoValue = true
argsHasValue = false
)
// AcquireArgs returns an empty Args object from the pool.
//
// The returned Args may be returned to the pool with ReleaseArgs
// when no longer needed. This allows reducing GC load.
func AcquireArgs() *Args {
return argsPool.Get().(*Args)
}
// ReleaseArgs returns the object acquired via AcquireArgs to the pool.
//
// Do not access the released Args object, otherwise data races may occur.
func ReleaseArgs(a *Args) {
a.Reset()
argsPool.Put(a)
}
var argsPool = &sync.Pool{
New: func() interface{} {
return &Args{}
},
}
// Args represents query arguments.
//
// Copying Args instances is forbidden. Create new instances instead
// and use CopyTo().
//
// Args instance MUST NOT be used from concurrently running goroutines.
type Args struct {
noCopy noCopy //nolint:unused,structcheck
args []argsKV
buf []byte
}
type argsKV struct {
key []byte
value []byte
noValue bool
}
// Reset clears query args.
func (a *Args) Reset() {
a.args = a.args[:0]
}
// CopyTo copies all args to dst.
func (a *Args) CopyTo(dst *Args) {
dst.Reset()
dst.args = copyArgs(dst.args, a.args)
}
// VisitAll calls f for each existing arg.
//
// f must not retain references to key and value after returning.
// Make key and/or value copies if you need storing them after returning.
func (a *Args) VisitAll(f func(key, value []byte)) {
visitArgs(a.args, f)
}
// Len returns the number of query args.
func (a *Args) Len() int {
return len(a.args)
}
// Parse parses the given string containing query args.
func (a *Args) Parse(s string) {
a.buf = append(a.buf[:0], s...)
a.ParseBytes(a.buf)
}
// ParseBytes parses the given b containing query args.
func (a *Args) ParseBytes(b []byte) {
a.Reset()
var s argsScanner
s.b = b
var kv *argsKV
a.args, kv = allocArg(a.args)
for s.next(kv) {
if len(kv.key) > 0 || len(kv.value) > 0 {
a.args, kv = allocArg(a.args)
}
}
a.args = releaseArg(a.args)
}
// String returns string representation of query args.
func (a *Args) String() string {
return string(a.QueryString())
}
// QueryString returns query string for the args.
//
// The returned value is valid until the Args is reused or released (ReleaseArgs).
// Do not store references to the returned value. Make copies instead.
func (a *Args) QueryString() []byte {
a.buf = a.AppendBytes(a.buf[:0])
return a.buf
}
// Sort sorts Args by key and then value using 'f' as comparison function.
//
// For example args.Sort(bytes.Compare)
func (a *Args) Sort(f func(x, y []byte) int) {
sort.SliceStable(a.args, func(i, j int) bool {
n := f(a.args[i].key, a.args[j].key)
if n == 0 {
return f(a.args[i].value, a.args[j].value) == -1
}
return n == -1
})
}
// AppendBytes appends query string to dst and returns the extended dst.
func (a *Args) AppendBytes(dst []byte) []byte {
for i, n := 0, len(a.args); i < n; i++ {
kv := &a.args[i]
dst = AppendQuotedArg(dst, kv.key)
if !kv.noValue {
dst = append(dst, '=')
if len(kv.value) > 0 {
dst = AppendQuotedArg(dst, kv.value)
}
}
if i+1 < n {
dst = append(dst, '&')
}
}
return dst
}
// WriteTo writes query string to w.
//
// WriteTo implements io.WriterTo interface.
func (a *Args) WriteTo(w io.Writer) (int64, error) {
n, err := w.Write(a.QueryString())
return int64(n), err
}
// Del deletes argument with the given key from query args.
func (a *Args) Del(key string) {
a.args = delAllArgs(a.args, key)
}
// DelBytes deletes argument with the given key from query args.
func (a *Args) DelBytes(key []byte) {
a.args = delAllArgs(a.args, b2s(key))
}
// Add adds 'key=value' argument.
//
// Multiple values for the same key may be added.
func (a *Args) Add(key, value string) {
a.args = appendArg(a.args, key, value, argsHasValue)
}
// AddBytesK adds 'key=value' argument.
//
// Multiple values for the same key may be added.
func (a *Args) AddBytesK(key []byte, value string) {
a.args = appendArg(a.args, b2s(key), value, argsHasValue)
}
// AddBytesV adds 'key=value' argument.
//
// Multiple values for the same key may be added.
func (a *Args) AddBytesV(key string, value []byte) {
a.args = appendArg(a.args, key, b2s(value), argsHasValue)
}
// AddBytesKV adds 'key=value' argument.
//
// Multiple values for the same key may be added.
func (a *Args) AddBytesKV(key, value []byte) {
a.args = appendArg(a.args, b2s(key), b2s(value), argsHasValue)
}
// AddNoValue adds only 'key' as argument without the '='.
//
// Multiple values for the same key may be added.
func (a *Args) AddNoValue(key string) {
a.args = appendArg(a.args, key, "", argsNoValue)
}
// AddBytesKNoValue adds only 'key' as argument without the '='.
//
// Multiple values for the same key may be added.
func (a *Args) AddBytesKNoValue(key []byte) {
a.args = appendArg(a.args, b2s(key), "", argsNoValue)
}
// Set sets 'key=value' argument.
func (a *Args) Set(key, value string) {
a.args = setArg(a.args, key, value, argsHasValue)
}
// SetBytesK sets 'key=value' argument.
func (a *Args) SetBytesK(key []byte, value string) {
a.args = setArg(a.args, b2s(key), value, argsHasValue)
}
// SetBytesV sets 'key=value' argument.
func (a *Args) SetBytesV(key string, value []byte) {
a.args = setArg(a.args, key, b2s(value), argsHasValue)
}
// SetBytesKV sets 'key=value' argument.
func (a *Args) SetBytesKV(key, value []byte) {
a.args = setArgBytes(a.args, key, value, argsHasValue)
}
// SetNoValue sets only 'key' as argument without the '='.
//
// Only the key appears in the argument string, like key1&key2
func (a *Args) SetNoValue(key string) {
a.args = setArg(a.args, key, "", argsNoValue)
}
// SetBytesKNoValue sets 'key' argument.
func (a *Args) SetBytesKNoValue(key []byte) {
a.args = setArg(a.args, b2s(key), "", argsNoValue)
}
// Peek returns query arg value for the given key.
//
// The returned value is valid until the Args is reused or released (ReleaseArgs).
// Do not store references to the returned value. Make copies instead.
func (a *Args) Peek(key string) []byte {
return peekArgStr(a.args, key)
}
// PeekBytes returns query arg value for the given key.
//
// The returned value is valid until the Args is reused or released (ReleaseArgs).
// Do not store references to the returned value. Make copies instead.
func (a *Args) PeekBytes(key []byte) []byte {
return peekArgBytes(a.args, key)
}
// PeekMulti returns all the arg values for the given key.
func (a *Args) PeekMulti(key string) [][]byte {
var values [][]byte
a.VisitAll(func(k, v []byte) {
if string(k) == key {
values = append(values, v)
}
})
return values
}
// PeekMultiBytes returns all the arg values for the given key.
func (a *Args) PeekMultiBytes(key []byte) [][]byte {
return a.PeekMulti(b2s(key))
}
// Has returns true if the given key exists in Args.
func (a *Args) Has(key string) bool {
return hasArg(a.args, key)
}
// HasBytes returns true if the given key exists in Args.
func (a *Args) HasBytes(key []byte) bool {
return hasArg(a.args, b2s(key))
}
// ErrNoArgValue is returned when Args value with the given key is missing.
var ErrNoArgValue = errors.New("no Args value for the given key")
// GetUint returns uint value for the given key.
func (a *Args) GetUint(key string) (int, error) {
value := a.Peek(key)
if len(value) == 0 {
return -1, ErrNoArgValue
}
return ParseUint(value)
}
// SetUint sets uint value for the given key.
func (a *Args) SetUint(key string, value int) {
bb := bytebufferpool.Get()
bb.B = AppendUint(bb.B[:0], value)
a.SetBytesV(key, bb.B)
bytebufferpool.Put(bb)
}
// SetUintBytes sets uint value for the given key.
func (a *Args) SetUintBytes(key []byte, value int) {
a.SetUint(b2s(key), value)
}
// GetUintOrZero returns uint value for the given key.
//
// Zero (0) is returned on error.
func (a *Args) GetUintOrZero(key string) int {
n, err := a.GetUint(key)
if err != nil {
n = 0
}
return n
}
// GetUfloat returns ufloat value for the given key.
func (a *Args) GetUfloat(key string) (float64, error) {
value := a.Peek(key)
if len(value) == 0 {
return -1, ErrNoArgValue
}
return ParseUfloat(value)
}
// GetUfloatOrZero returns ufloat value for the given key.
//
// Zero (0) is returned on error.
func (a *Args) GetUfloatOrZero(key string) float64 {
f, err := a.GetUfloat(key)
if err != nil {
f = 0
}
return f
}
// GetBool returns boolean value for the given key.
//
// true is returned for "1", "t", "T", "true", "TRUE", "True", "y", "yes", "Y", "YES", "Yes",
// otherwise false is returned.
func (a *Args) GetBool(key string) bool {
switch b2s(a.Peek(key)) {
// Support the same true cases as strconv.ParseBool
// See: https://github.com/golang/go/blob/4e1b11e2c9bdb0ddea1141eed487be1a626ff5be/src/strconv/atob.go#L12
// and Y and Yes versions.
case "1", "t", "T", "true", "TRUE", "True", "y", "yes", "Y", "YES", "Yes":
return true
default:
return false
}
}
func visitArgs(args []argsKV, f func(k, v []byte)) {
for i, n := 0, len(args); i < n; i++ {
kv := &args[i]
f(kv.key, kv.value)
}
}
func copyArgs(dst, src []argsKV) []argsKV {
if cap(dst) < len(src) {
tmp := make([]argsKV, len(src))
dst = dst[:cap(dst)] // copy all of dst.
copy(tmp, dst)
for i := len(dst); i < len(tmp); i++ {
// Make sure nothing is nil.
tmp[i].key = []byte{}
tmp[i].value = []byte{}
}
dst = tmp
}
n := len(src)
dst = dst[:n]
for i := 0; i < n; i++ {
dstKV := &dst[i]
srcKV := &src[i]
dstKV.key = append(dstKV.key[:0], srcKV.key...)
if srcKV.noValue {
dstKV.value = dstKV.value[:0]
} else {
dstKV.value = append(dstKV.value[:0], srcKV.value...)
}
dstKV.noValue = srcKV.noValue
}
return dst
}
func delAllArgsBytes(args []argsKV, key []byte) []argsKV {
return delAllArgs(args, b2s(key))
}
func delAllArgs(args []argsKV, key string) []argsKV {
for i, n := 0, len(args); i < n; i++ {
kv := &args[i]
if key == string(kv.key) {
tmp := *kv
copy(args[i:], args[i+1:])
n--
i--
args[n] = tmp
args = args[:n]
}
}
return args
}
func setArgBytes(h []argsKV, key, value []byte, noValue bool) []argsKV {
return setArg(h, b2s(key), b2s(value), noValue)
}
func setArg(h []argsKV, key, value string, noValue bool) []argsKV {
n := len(h)
for i := 0; i < n; i++ {
kv := &h[i]
if key == string(kv.key) {
if noValue {
kv.value = kv.value[:0]
} else {
kv.value = append(kv.value[:0], value...)
}
kv.noValue = noValue
return h
}
}
return appendArg(h, key, value, noValue)
}
func appendArgBytes(h []argsKV, key, value []byte, noValue bool) []argsKV {
return appendArg(h, b2s(key), b2s(value), noValue)
}
func appendArg(args []argsKV, key, value string, noValue bool) []argsKV {
var kv *argsKV
args, kv = allocArg(args)
kv.key = append(kv.key[:0], key...)
if noValue {
kv.value = kv.value[:0]
} else {
kv.value = append(kv.value[:0], value...)
}
kv.noValue = noValue
return args
}
func allocArg(h []argsKV) ([]argsKV, *argsKV) {
n := len(h)
if cap(h) > n {
h = h[:n+1]
} else {
h = append(h, argsKV{
value: []byte{},
})
}
return h, &h[n]
}
func releaseArg(h []argsKV) []argsKV {
return h[:len(h)-1]
}
func hasArg(h []argsKV, key string) bool {
for i, n := 0, len(h); i < n; i++ {
kv := &h[i]
if key == string(kv.key) {
return true
}
}
return false
}
func peekArgBytes(h []argsKV, k []byte) []byte {
for i, n := 0, len(h); i < n; i++ {
kv := &h[i]
if bytes.Equal(kv.key, k) {
return kv.value
}
}
return nil
}
func peekArgStr(h []argsKV, k string) []byte {
for i, n := 0, len(h); i < n; i++ {
kv := &h[i]
if string(kv.key) == k {
return kv.value
}
}
return nil
}
// argsScanner splits a raw query string into decoded key/value pairs.
type argsScanner struct {
b []byte
}
func (s *argsScanner) next(kv *argsKV) bool {
if len(s.b) == 0 {
return false
}
kv.noValue = argsHasValue
isKey := true
k := 0
for i, c := range s.b {
switch c {
case '=':
if isKey {
isKey = false
kv.key = decodeArgAppend(kv.key[:0], s.b[:i])
k = i + 1
}
case '&':
if isKey {
kv.key = decodeArgAppend(kv.key[:0], s.b[:i])
kv.value = kv.value[:0]
kv.noValue = argsNoValue
} else {
kv.value = decodeArgAppend(kv.value[:0], s.b[k:i])
}
s.b = s.b[i+1:]
return true
}
}
if isKey {
kv.key = decodeArgAppend(kv.key[:0], s.b)
kv.value = kv.value[:0]
kv.noValue = argsNoValue
} else {
kv.value = decodeArgAppend(kv.value[:0], s.b[k:])
}
s.b = s.b[len(s.b):]
return true
}
// decodeArgAppend decodes percent-encoded src, replaces '+' with ' ' and
// appends the result to dst.
func decodeArgAppend(dst, src []byte) []byte {
if bytes.IndexByte(src, '%') < 0 && bytes.IndexByte(src, '+') < 0 {
// fast path: src doesn't contain encoded chars
return append(dst, src...)
}
// slow path
for i := 0; i < len(src); i++ {
c := src[i]
if c == '%' {
if i+2 >= len(src) {
return append(dst, src[i:]...)
}
x2 := hex2intTable[src[i+2]]
x1 := hex2intTable[src[i+1]]
if x1 == 16 || x2 == 16 {
dst = append(dst, '%')
} else {
dst = append(dst, x1<<4|x2)
i += 2
}
} else if c == '+' {
dst = append(dst, ' ')
} else {
dst = append(dst, c)
}
}
return dst
}
// decodeArgAppendNoPlus is almost identical to decodeArgAppend, but it doesn't
// substitute '+' with ' '.
//
// The function is copy-pasted from decodeArgAppend due to the performance
// reasons only.
func decodeArgAppendNoPlus(dst, src []byte) []byte {
if bytes.IndexByte(src, '%') < 0 {
// fast path: src doesn't contain encoded chars
return append(dst, src...)
}
// slow path
for i := 0; i < len(src); i++ {
c := src[i]
if c == '%' {
if i+2 >= len(src) {
return append(dst, src[i:]...)
}
x2 := hex2intTable[src[i+2]]
x1 := hex2intTable[src[i+1]]
if x1 == 16 || x2 == 16 {
dst = append(dst, '%')
} else {
dst = append(dst, x1<<4|x2)
i += 2
}
} else {
dst = append(dst, c)
}
}
return dst
}
fasthttp-1.31.0/args_test.go
package fasthttp
import (
"bytes"
"fmt"
"net/url"
"reflect"
"strings"
"testing"
"time"
"github.com/valyala/bytebufferpool"
)
func TestDecodeArgAppend(t *testing.T) {
t.Parallel()
testDecodeArgAppend(t, "", "")
testDecodeArgAppend(t, "foobar", "foobar")
testDecodeArgAppend(t, "тест", "тест")
testDecodeArgAppend(t, "a%", "a%")
testDecodeArgAppend(t, "%a%21", "%a!")
testDecodeArgAppend(t, "ab%test", "ab%test")
testDecodeArgAppend(t, "d%тестF", "d%тестF")
testDecodeArgAppend(t, "a%\xffb%20c", "a%\xffb c")
testDecodeArgAppend(t, "foo%20bar", "foo bar")
testDecodeArgAppend(t, "f.o%2C1%3A2%2F4=%7E%60%21%40%23%24%25%5E%26*%28%29_-%3D%2B%5C%7C%2F%5B%5D%7B%7D%3B%3A%27%22%3C%3E%2C.%2F%3F",
"f.o,1:2/4=~`!@#$%^&*()_-=+\\|/[]{};:'\"<>,./?")
}
func testDecodeArgAppend(t *testing.T, s, expectedResult string) {
result := decodeArgAppend(nil, []byte(s))
if string(result) != expectedResult {
t.Fatalf("unexpected decodeArgAppend(%q)=%q; expecting %q", s, result, expectedResult)
}
}
func TestArgsAdd(t *testing.T) {
t.Parallel()
var a Args
a.Add("foo", "bar")
a.Add("foo", "baz")
a.Add("foo", "1")
a.Add("ba", "23")
a.Add("foo", "")
a.AddNoValue("foo")
if a.Len() != 6 {
t.Fatalf("unexpected number of elements: %d. Expecting 6", a.Len())
}
s := a.String()
expectedS := "foo=bar&foo=baz&foo=1&ba=23&foo=&foo"
if s != expectedS {
t.Fatalf("unexpected result: %q. Expecting %q", s, expectedS)
}
a.Sort(bytes.Compare)
ss := a.String()
expectedSS := "ba=23&foo=&foo&foo=1&foo=bar&foo=baz"
if ss != expectedSS {
t.Fatalf("unexpected result: %q. Expecting %q", ss, expectedSS)
}
var a1 Args
a1.Parse(s)
if a1.Len() != 6 {
t.Fatalf("unexpected number of elements: %d. Expecting 6", a.Len())
}
var barFound, bazFound, oneFound, emptyFound1, emptyFound2, baFound bool
a1.VisitAll(func(k, v []byte) {
switch string(k) {
case "foo":
switch string(v) {
case "bar":
barFound = true
case "baz":
bazFound = true
case "1":
oneFound = true
case "":
if emptyFound1 {
emptyFound2 = true
} else {
emptyFound1 = true
}
default:
t.Fatalf("unexpected value %q", v)
}
case "ba":
if string(v) != "23" {
t.Fatalf("unexpected value: %q. Expecting %q", v, "23")
}
baFound = true
default:
t.Fatalf("unexpected key found %q", k)
}
})
if !barFound || !bazFound || !oneFound || !emptyFound1 || !emptyFound2 || !baFound {
t.Fatalf("something is missing: %v, %v, %v, %v, %v, %v", barFound, bazFound, oneFound, emptyFound1, emptyFound2, baFound)
}
}
func TestArgsAcquireReleaseSequential(t *testing.T) {
testArgsAcquireRelease(t)
}
func TestArgsAcquireReleaseConcurrent(t *testing.T) {
ch := make(chan struct{}, 10)
for i := 0; i < 10; i++ {
go func() {
testArgsAcquireRelease(t)
ch <- struct{}{}
}()
}
for i := 0; i < 10; i++ {
select {
case <-ch:
case <-time.After(time.Second):
t.Fatalf("timeout")
}
}
}
func testArgsAcquireRelease(t *testing.T) {
a := AcquireArgs()
for i := 0; i < 10; i++ {
k := fmt.Sprintf("key_%d", i)
v := fmt.Sprintf("value_%d", i*3+123)
a.Set(k, v)
}
s := a.String()
a.Reset()
a.Parse(s)
for i := 0; i < 10; i++ {
k := fmt.Sprintf("key_%d", i)
expectedV := fmt.Sprintf("value_%d", i*3+123)
v := a.Peek(k)
if string(v) != expectedV {
t.Fatalf("unexpected value %q for key %q. Expecting %q", v, k, expectedV)
}
}
ReleaseArgs(a)
}
func TestArgsPeekMulti(t *testing.T) {
t.Parallel()
var a Args
a.Parse("foo=123&bar=121&foo=321&foo=&barz=sdf")
vv := a.PeekMulti("foo")
expectedVV := [][]byte{
[]byte("123"),
[]byte("321"),
[]byte(nil),
}
if !reflect.DeepEqual(vv, expectedVV) {
t.Fatalf("unexpected vv\n%#v\nExpecting\n%#v\n", vv, expectedVV)
}
vv = a.PeekMulti("aaaa")
if len(vv) > 0 {
t.Fatalf("expecting empty result for non-existing key. Got %#v", vv)
}
vv = a.PeekMulti("bar")
expectedVV = [][]byte{[]byte("121")}
if !reflect.DeepEqual(vv, expectedVV) {
t.Fatalf("unexpected vv\n%#v\nExpecting\n%#v\n", vv, expectedVV)
}
}
func TestArgsEscape(t *testing.T) {
t.Parallel()
testArgsEscape(t, "foo", "bar", "foo=bar")
// Test all characters
k := "f.o,1:2/4"
var v = make([]byte, 256)
for i := 0; i < 256; i++ {
v[i] = byte(i)
}
u := url.Values{}
u.Add(k, string(v))
testArgsEscape(t, k, string(v), u.Encode())
}
func testArgsEscape(t *testing.T, k, v, expectedS string) {
var a Args
a.Set(k, v)
s := a.String()
if s != expectedS {
t.Fatalf("unexpected args %q. Expecting %q. k=%q, v=%q", s, expectedS, k, v)
}
}
func TestPathEscape(t *testing.T) {
t.Parallel()
testPathEscape(t, "/foo/bar")
testPathEscape(t, "")
testPathEscape(t, "/")
testPathEscape(t, "//")
testPathEscape(t, "*") // See https://github.com/golang/go/issues/11202
// Test all characters
var pathSegment = make([]byte, 256)
for i := 0; i < 256; i++ {
pathSegment[i] = byte(i)
}
testPathEscape(t, "/foo/"+string(pathSegment))
}
func testPathEscape(t *testing.T, s string) {
u := url.URL{Path: s}
expectedS := u.EscapedPath()
res := string(appendQuotedPath(nil, []byte(s)))
if res != expectedS {
t.Fatalf("unexpected args %q. Expecting %q.", res, expectedS)
}
}
func TestArgsWriteTo(t *testing.T) {
t.Parallel()
s := "foo=bar&baz=123&aaa=bbb"
var a Args
a.Parse(s)
var w bytebufferpool.ByteBuffer
n, err := a.WriteTo(&w)
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
if n != int64(len(s)) {
t.Fatalf("unexpected n: %d. Expecting %d", n, len(s))
}
result := string(w.B)
if result != s {
t.Fatalf("unexpected result %q. Expecting %q", result, s)
}
}
func TestArgsGetBool(t *testing.T) {
t.Parallel()
testArgsGetBool(t, "", false)
testArgsGetBool(t, "0", false)
testArgsGetBool(t, "n", false)
testArgsGetBool(t, "no", false)
testArgsGetBool(t, "1", true)
testArgsGetBool(t, "y", true)
testArgsGetBool(t, "yes", true)
testArgsGetBool(t, "123", false)
testArgsGetBool(t, "foobar", false)
}
func testArgsGetBool(t *testing.T, value string, expectedResult bool) {
var a Args
a.Parse("v=" + value)
result := a.GetBool("v")
if result != expectedResult {
t.Fatalf("unexpected result %v. Expecting %v for value %q", result, expectedResult, value)
}
}
func TestArgsUint(t *testing.T) {
t.Parallel()
var a Args
a.SetUint("foo", 123)
a.SetUint("bar", 0)
a.SetUint("aaaa", 34566)
expectedS := "foo=123&bar=0&aaaa=34566"
s := string(a.QueryString())
if s != expectedS {
t.Fatalf("unexpected args %q. Expecting %q", s, expectedS)
}
if a.GetUintOrZero("foo") != 123 {
t.Fatalf("unexpected arg value %d. Expecting %d", a.GetUintOrZero("foo"), 123)
}
if a.GetUintOrZero("bar") != 0 {
t.Fatalf("unexpected arg value %d. Expecting %d", a.GetUintOrZero("bar"), 0)
}
if a.GetUintOrZero("aaaa") != 34566 {
t.Fatalf("unexpected arg value %d. Expecting %d", a.GetUintOrZero("aaaa"), 34566)
}
if string(a.Peek("foo")) != "123" {
t.Fatalf("unexpected arg value %q. Expecting %q", a.Peek("foo"), "123")
}
if string(a.Peek("bar")) != "0" {
t.Fatalf("unexpected arg value %q. Expecting %q", a.Peek("bar"), "0")
}
if string(a.Peek("aaaa")) != "34566" {
t.Fatalf("unexpected arg value %q. Expecting %q", a.Peek("aaaa"), "34566")
}
}
func TestArgsCopyTo(t *testing.T) {
t.Parallel()
var a Args
// empty args
testCopyTo(t, &a)
a.Set("foo", "bar")
testCopyTo(t, &a)
a.Set("xxx", "yyy")
a.AddNoValue("ba")
testCopyTo(t, &a)
a.Del("foo")
testCopyTo(t, &a)
}
func testCopyTo(t *testing.T, a *Args) {
keys := make(map[string]struct{})
a.VisitAll(func(k, v []byte) {
keys[string(k)] = struct{}{}
})
var b Args
a.CopyTo(&b)
if !reflect.DeepEqual(*a, b) { //nolint
t.Fatalf("ArgsCopyTo fail, a: \n%+v\nb: \n%+v\n", *a, b) //nolint
}
b.VisitAll(func(k, v []byte) {
if _, ok := keys[string(k)]; !ok {
t.Fatalf("unexpected key %q after copying from %q", k, a.String())
}
delete(keys, string(k))
})
if len(keys) > 0 {
t.Fatalf("missing keys %#v after copying from %q", keys, a.String())
}
}
func TestArgsVisitAll(t *testing.T) {
t.Parallel()
var a Args
a.Set("foo", "bar")
i := 0
a.VisitAll(func(k, v []byte) {
if string(k) != "foo" {
t.Fatalf("unexpected key %q. Expected %q", k, "foo")
}
if string(v) != "bar" {
t.Fatalf("unexpected value %q. Expected %q", v, "bar")
}
i++
})
if i != 1 {
t.Fatalf("unexpected number of VisitAll calls: %d. Expected %d", i, 1)
}
}
func TestArgsStringCompose(t *testing.T) {
t.Parallel()
var a Args
a.Set("foo", "bar")
a.Set("aa", "bbb")
a.Set("привет", "мир")
a.SetNoValue("bb")
a.Set("", "xxxx")
a.Set("cvx", "")
a.SetNoValue("novalue")
expectedS := "foo=bar&aa=bbb&%D0%BF%D1%80%D0%B8%D0%B2%D0%B5%D1%82=%D0%BC%D0%B8%D1%80&bb&=xxxx&cvx=&novalue"
s := a.String()
if s != expectedS {
t.Fatalf("Unexpected string %q. Exected %q", s, expectedS)
}
}
func TestArgsString(t *testing.T) {
t.Parallel()
var a Args
testArgsString(t, &a, "")
testArgsString(t, &a, "foobar")
testArgsString(t, &a, "foo=bar")
testArgsString(t, &a, "foo=bar&baz=sss")
testArgsString(t, &a, "")
testArgsString(t, &a, "f+o=x.x%2A-_8x%D0%BF%D1%80%D0%B8%D0%B2%D0%B5aaa&sdf=ss")
testArgsString(t, &a, "=asdfsdf")
}
func testArgsString(t *testing.T, a *Args, s string) {
a.Parse(s)
s1 := a.String()
if s != s1 {
t.Fatalf("Unexpected args %q. Expected %q", s1, s)
}
}
func TestArgsSetGetDel(t *testing.T) {
t.Parallel()
var a Args
if len(a.Peek("foo")) > 0 {
t.Fatalf("Unexpected value: %q", a.Peek("foo"))
}
if len(a.Peek("")) > 0 {
t.Fatalf("Unexpected value: %q", a.Peek(""))
}
a.Del("xxx")
for j := 0; j < 3; j++ {
for i := 0; i < 10; i++ {
k := fmt.Sprintf("foo%d", i)
v := fmt.Sprintf("bar_%d", i)
a.Set(k, v)
if string(a.Peek(k)) != v {
t.Fatalf("Unexpected value: %q. Expected %q", a.Peek(k), v)
}
}
}
for i := 0; i < 10; i++ {
k := fmt.Sprintf("foo%d", i)
v := fmt.Sprintf("bar_%d", i)
if string(a.Peek(k)) != v {
t.Fatalf("Unexpected value: %q. Expected %q", a.Peek(k), v)
}
a.Del(k)
if string(a.Peek(k)) != "" {
t.Fatalf("Unexpected value: %q. Expected %q", a.Peek(k), "")
}
}
a.Parse("aaa=xxx&bb=aa")
if string(a.Peek("foo0")) != "" {
t.Fatalf("Unexpected value %q", a.Peek("foo0"))
}
if string(a.Peek("aaa")) != "xxx" {
t.Fatalf("Unexpected value %q. Expected %q", a.Peek("aaa"), "xxx")
}
if string(a.Peek("bb")) != "aa" {
t.Fatalf("Unexpected value %q. Expected %q", a.Peek("bb"), "aa")
}
for i := 0; i < 10; i++ {
k := fmt.Sprintf("xx%d", i)
v := fmt.Sprintf("yy%d", i)
a.Set(k, v)
if string(a.Peek(k)) != v {
t.Fatalf("Unexpected value: %q. Expected %q", a.Peek(k), v)
}
}
for i := 5; i < 10; i++ {
k := fmt.Sprintf("xx%d", i)
v := fmt.Sprintf("yy%d", i)
if string(a.Peek(k)) != v {
t.Fatalf("Unexpected value: %q. Expected %q", a.Peek(k), v)
}
a.Del(k)
if string(a.Peek(k)) != "" {
t.Fatalf("Unexpected value: %q. Expected %q", a.Peek(k), "")
}
}
}
func TestArgsParse(t *testing.T) {
t.Parallel()
var a Args
// empty args
testArgsParse(t, &a, "", 0, "foo=", "bar=", "=")
// arg without value
testArgsParse(t, &a, "foo1", 1, "foo=", "bar=", "=")
// arg without value, but with equal sign
testArgsParse(t, &a, "foo2=", 1, "foo=", "bar=", "=")
// arg with value
testArgsParse(t, &a, "foo3=bar1", 1, "foo3=bar1", "bar=", "=")
// empty key
testArgsParse(t, &a, "=bar2", 1, "foo=", "=bar2", "bar2=")
// missing kv
testArgsParse(t, &a, "&&&&", 0, "foo=", "bar=", "=")
// multiple values with the same key
testArgsParse(t, &a, "x=1&x=2&x=3", 3, "x=1")
// multiple args
testArgsParse(t, &a, "&&&qw=er&tyx=124&&&zxc_ss=2234&&", 3, "qw=er", "tyx=124", "zxc_ss=2234")
// multiple args without values
testArgsParse(t, &a, "&&a&&b&&bar&baz", 4, "a=", "b=", "bar=", "baz=")
// values with '='
testArgsParse(t, &a, "zz=1&k=v=v=a=a=s", 2, "k=v=v=a=a=s", "zz=1")
// mixed '=' and '&'
testArgsParse(t, &a, "sss&z=dsf=&df", 3, "sss=", "z=dsf=", "df=")
// encoded args
testArgsParse(t, &a, "f+o%20o=%D0%BF%D1%80%D0%B8%D0%B2%D0%B5%D1%82+test", 1, "f o o=привет test")
// invalid percent encoding
testArgsParse(t, &a, "f%=x&qw%z=d%0k%20p&%%20=%%%20x", 3, "f%=x", "qw%z=d%0k p", "% =%% x")
// special chars
testArgsParse(t, &a, "a.b,c:d/e=f.g,h:i/q", 1, "a.b,c:d/e=f.g,h:i/q")
}
func TestArgsHas(t *testing.T) {
t.Parallel()
var a Args
// single arg
testArgsHas(t, &a, "foo", "foo")
testArgsHasNot(t, &a, "foo", "bar", "baz", "")
// multi args without values
testArgsHas(t, &a, "foo&bar", "foo", "bar")
testArgsHasNot(t, &a, "foo&bar", "", "aaaa")
// multi args
testArgsHas(t, &a, "b=xx&=aaa&c=", "b", "", "c")
testArgsHasNot(t, &a, "b=xx&=aaa&c=", "xx", "aaa", "foo")
// encoded args
testArgsHas(t, &a, "a+b=c+d%20%20e", "a b")
testArgsHasNot(t, &a, "a+b=c+d", "a+b", "c+d")
}
func testArgsHas(t *testing.T, a *Args, s string, expectedKeys ...string) {
a.Parse(s)
for _, key := range expectedKeys {
if !a.Has(key) {
t.Fatalf("Missing key %q in %q", key, s)
}
}
}
func testArgsHasNot(t *testing.T, a *Args, s string, unexpectedKeys ...string) {
a.Parse(s)
for _, key := range unexpectedKeys {
if a.Has(key) {
t.Fatalf("Unexpected key %q in %q", key, s)
}
}
}
func testArgsParse(t *testing.T, a *Args, s string, expectedLen int, expectedArgs ...string) {
a.Parse(s)
if a.Len() != expectedLen {
t.Fatalf("Unexpected args len %d. Expected %d. s=%q", a.Len(), expectedLen, s)
}
for _, xx := range expectedArgs {
tmp := strings.SplitN(xx, "=", 2)
k := tmp[0]
v := tmp[1]
buf := a.Peek(k)
if string(buf) != v {
t.Fatalf("Unexpected value for key=%q: %q. Expected %q. s=%q", k, buf, v, s)
}
}
}
func TestArgsDeleteAll(t *testing.T) {
t.Parallel()
var a Args
a.Add("q1", "foo")
a.Add("q1", "bar")
a.Add("q1", "baz")
a.Add("q1", "quux")
a.Add("q2", "1234")
a.Del("q1")
if a.Len() != 1 || a.Has("q1") {
t.Fatalf("Expected q1 arg to be completely deleted. Current Args: %s", a.String())
}
}
func TestIssue932(t *testing.T) {
t.Parallel()
var a []argsKV
a = setArg(a, "t1", "ok", argsHasValue)
a = setArg(a, "t2", "", argsHasValue)
a = setArg(a, "t1", "", argsHasValue)
a = setArgBytes(a, s2b("t3"), []byte{}, argsHasValue)
a = setArgBytes(a, s2b("t4"), nil, argsHasValue)
if peekArgStr(a, "t1") == nil {
t.Error("nil not expected for t1")
}
if peekArgStr(a, "t2") == nil {
t.Error("nil not expected for t2")
}
if peekArgStr(a, "t3") == nil {
t.Error("nil not expected for t3")
}
if peekArgStr(a, "t4") != nil {
t.Error("nil expected for t4")
}
}
fasthttp-1.31.0/args_timing_test.go 0000664 0000000 0000000 00000001046 14130360711 0017317 0 ustar 00root root 0000000 0000000 package fasthttp
import (
"bytes"
"testing"
)
func BenchmarkArgsParse(b *testing.B) {
s := []byte("foo=bar&baz=qqq&aaaaa=bbbb")
b.RunParallel(func(pb *testing.PB) {
var a Args
for pb.Next() {
a.ParseBytes(s)
}
})
}
func BenchmarkArgsPeek(b *testing.B) {
value := []byte("foobarbaz1234")
key := "foobarbaz"
b.RunParallel(func(pb *testing.PB) {
var a Args
a.SetBytesV(key, value)
for pb.Next() {
if !bytes.Equal(a.Peek(key), value) {
b.Fatalf("unexpected arg value %q. Expecting %q", a.Peek(key), value)
}
}
})
}
fasthttp-1.31.0/brotli.go 0000664 0000000 0000000 00000012137 14130360711 0015253 0 ustar 00root root 0000000 0000000 package fasthttp
import (
"bytes"
"fmt"
"io"
"sync"
"github.com/andybalholm/brotli"
"github.com/valyala/bytebufferpool"
"github.com/valyala/fasthttp/stackless"
)
// Supported compression levels.
const (
CompressBrotliNoCompression = 0
CompressBrotliBestSpeed = brotli.BestSpeed
CompressBrotliBestCompression = brotli.BestCompression
// Choose a default brotli compression level comparable to
// CompressDefaultCompression (gzip 6)
// See: https://github.com/valyala/fasthttp/issues/798#issuecomment-626293806
CompressBrotliDefaultCompression = 4
)
func acquireBrotliReader(r io.Reader) (*brotli.Reader, error) {
v := brotliReaderPool.Get()
if v == nil {
return brotli.NewReader(r), nil
}
zr := v.(*brotli.Reader)
if err := zr.Reset(r); err != nil {
return nil, err
}
return zr, nil
}
func releaseBrotliReader(zr *brotli.Reader) {
brotliReaderPool.Put(zr)
}
var brotliReaderPool sync.Pool
func acquireStacklessBrotliWriter(w io.Writer, level int) stackless.Writer {
nLevel := normalizeBrotliCompressLevel(level)
p := stacklessBrotliWriterPoolMap[nLevel]
v := p.Get()
if v == nil {
return stackless.NewWriter(w, func(w io.Writer) stackless.Writer {
return acquireRealBrotliWriter(w, level)
})
}
sw := v.(stackless.Writer)
sw.Reset(w)
return sw
}
func releaseStacklessBrotliWriter(sw stackless.Writer, level int) {
sw.Close()
nLevel := normalizeBrotliCompressLevel(level)
p := stacklessBrotliWriterPoolMap[nLevel]
p.Put(sw)
}
func acquireRealBrotliWriter(w io.Writer, level int) *brotli.Writer {
nLevel := normalizeBrotliCompressLevel(level)
p := realBrotliWriterPoolMap[nLevel]
v := p.Get()
if v == nil {
zw := brotli.NewWriterLevel(w, level)
return zw
}
zw := v.(*brotli.Writer)
zw.Reset(w)
return zw
}
func releaseRealBrotliWriter(zw *brotli.Writer, level int) {
zw.Close()
nLevel := normalizeBrotliCompressLevel(level)
p := realBrotliWriterPoolMap[nLevel]
p.Put(zw)
}
var (
stacklessBrotliWriterPoolMap = newCompressWriterPoolMap()
realBrotliWriterPoolMap = newCompressWriterPoolMap()
)
// AppendBrotliBytesLevel appends brotlied src to dst using the given
// compression level and returns the resulting dst.
//
// Supported compression levels are:
//
// * CompressBrotliNoCompression
// * CompressBrotliBestSpeed
// * CompressBrotliBestCompression
// * CompressBrotliDefaultCompression
func AppendBrotliBytesLevel(dst, src []byte, level int) []byte {
w := &byteSliceWriter{dst}
WriteBrotliLevel(w, src, level) //nolint:errcheck
return w.b
}
// WriteBrotliLevel writes brotlied p to w using the given compression level
// and returns the number of compressed bytes written to w.
//
// Supported compression levels are:
//
// * CompressBrotliNoCompression
// * CompressBrotliBestSpeed
// * CompressBrotliBestCompression
// * CompressBrotliDefaultCompression
func WriteBrotliLevel(w io.Writer, p []byte, level int) (int, error) {
switch w.(type) {
case *byteSliceWriter,
*bytes.Buffer,
*bytebufferpool.ByteBuffer:
// These writers don't block, so we can just use stacklessWriteBrotli
ctx := &compressCtx{
w: w,
p: p,
level: level,
}
stacklessWriteBrotli(ctx)
return len(p), nil
default:
zw := acquireStacklessBrotliWriter(w, level)
n, err := zw.Write(p)
releaseStacklessBrotliWriter(zw, level)
return n, err
}
}
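// exampleWriteBrotliLevel is an illustrative sketch and not part of the
// library API: it assumes an in-memory bytes.Buffer destination and shows how
// WriteBrotliLevel compresses a payload with one of the supported levels.
// Since bytes.Buffer never blocks, this takes the non-blocking path above.
func exampleWriteBrotliLevel(payload []byte) (*bytes.Buffer, error) {
	var buf bytes.Buffer
	if _, err := WriteBrotliLevel(&buf, payload, CompressBrotliBestSpeed); err != nil {
		return nil, err
	}
	return &buf, nil
}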
var stacklessWriteBrotli = stackless.NewFunc(nonblockingWriteBrotli)
func nonblockingWriteBrotli(ctxv interface{}) {
ctx := ctxv.(*compressCtx)
zw := acquireRealBrotliWriter(ctx.w, ctx.level)
_, err := zw.Write(ctx.p)
if err != nil {
panic(fmt.Sprintf("BUG: brotli.Writer.Write for len(p)=%d returned unexpected error: %s", len(ctx.p), err))
}
releaseRealBrotliWriter(zw, ctx.level)
}
// WriteBrotli writes brotlied p to w and returns the number of compressed
// bytes written to w.
func WriteBrotli(w io.Writer, p []byte) (int, error) {
return WriteBrotliLevel(w, p, CompressBrotliDefaultCompression)
}
// AppendBrotliBytes appends brotlied src to dst and returns the resulting dst.
func AppendBrotliBytes(dst, src []byte) []byte {
return AppendBrotliBytesLevel(dst, src, CompressBrotliDefaultCompression)
}
// WriteUnbrotli writes unbrotlied p to w and returns the number of uncompressed
// bytes written to w.
func WriteUnbrotli(w io.Writer, p []byte) (int, error) {
r := &byteSliceReader{p}
zr, err := acquireBrotliReader(r)
if err != nil {
return 0, err
}
n, err := copyZeroAlloc(w, zr)
releaseBrotliReader(zr)
nn := int(n)
if int64(nn) != n {
return 0, fmt.Errorf("too much data unbrotlied: %d", n)
}
return nn, err
}
// AppendUnbrotliBytes appends unbrotlied src to dst and returns the resulting dst.
func AppendUnbrotliBytes(dst, src []byte) ([]byte, error) {
w := &byteSliceWriter{dst}
_, err := WriteUnbrotli(w, src)
return w.b, err
}
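// exampleBrotliRoundTrip is an illustrative sketch and not part of the
// library: it compresses src with AppendBrotliBytes at the default level and
// restores it with AppendUnbrotliBytes, returning the decompressed bytes.
func exampleBrotliRoundTrip(src []byte) ([]byte, error) {
	// Passing nil as dst lets both helpers allocate the destination slice.
	compressed := AppendBrotliBytes(nil, src)
	return AppendUnbrotliBytes(nil, compressed)
}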
// normalizes compression level into [0..11], so it could be used as an index
// in *PoolMap.
func normalizeBrotliCompressLevel(level int) int {
// 0 is the lowest supported brotli compression level (brotli.BestSpeed)
// 11 is the highest supported brotli compression level (brotli.BestCompression)
if level < 0 || level > 11 {
level = CompressBrotliDefaultCompression
}
return level
}
fasthttp-1.31.0/brotli_test.go 0000664 0000000 0000000 00000010757 14130360711 0016320 0 ustar 00root root 0000000 0000000 package fasthttp
import (
"bufio"
"bytes"
"fmt"
"io/ioutil"
"testing"
)
func TestBrotliBytesSerial(t *testing.T) {
t.Parallel()
if err := testBrotliBytes(); err != nil {
t.Fatal(err)
}
}
func TestBrotliBytesConcurrent(t *testing.T) {
t.Parallel()
if err := testConcurrent(10, testBrotliBytes); err != nil {
t.Fatal(err)
}
}
func testBrotliBytes() error {
for _, s := range compressTestcases {
if err := testBrotliBytesSingleCase(s); err != nil {
return err
}
}
return nil
}
func testBrotliBytesSingleCase(s string) error {
prefix := []byte("foobar")
brotlipedS := AppendBrotliBytes(prefix, []byte(s))
if !bytes.Equal(brotlipedS[:len(prefix)], prefix) {
return fmt.Errorf("unexpected prefix when compressing %q: %q. Expecting %q", s, brotlipedS[:len(prefix)], prefix)
}
unbrotliedS, err := AppendUnbrotliBytes(prefix, brotlipedS[len(prefix):])
if err != nil {
return fmt.Errorf("unexpected error when uncompressing %q: %s", s, err)
}
if !bytes.Equal(unbrotliedS[:len(prefix)], prefix) {
return fmt.Errorf("unexpected prefix when uncompressing %q: %q. Expecting %q", s, unbrotliedS[:len(prefix)], prefix)
}
unbrotliedS = unbrotliedS[len(prefix):]
if string(unbrotliedS) != s {
return fmt.Errorf("unexpected uncompressed string %q. Expecting %q", unbrotliedS, s)
}
return nil
}
func TestBrotliCompressSerial(t *testing.T) {
t.Parallel()
if err := testBrotliCompress(); err != nil {
t.Fatal(err)
}
}
func TestBrotliCompressConcurrent(t *testing.T) {
t.Parallel()
if err := testConcurrent(10, testBrotliCompress); err != nil {
t.Fatal(err)
}
}
func testBrotliCompress() error {
for _, s := range compressTestcases {
if err := testBrotliCompressSingleCase(s); err != nil {
return err
}
}
return nil
}
func testBrotliCompressSingleCase(s string) error {
var buf bytes.Buffer
zw := acquireStacklessBrotliWriter(&buf, CompressDefaultCompression)
if _, err := zw.Write([]byte(s)); err != nil {
return fmt.Errorf("unexpected error: %s. s=%q", err, s)
}
releaseStacklessBrotliWriter(zw, CompressDefaultCompression)
zr, err := acquireBrotliReader(&buf)
if err != nil {
return fmt.Errorf("unexpected error: %s. s=%q", err, s)
}
body, err := ioutil.ReadAll(zr)
if err != nil {
return fmt.Errorf("unexpected error: %s. s=%q", err, s)
}
if string(body) != s {
return fmt.Errorf("unexpected string after decompression: %q. Expecting %q", body, s)
}
releaseBrotliReader(zr)
return nil
}
func TestCompressHandlerBrotliLevel(t *testing.T) {
t.Parallel()
expectedBody := string(createFixedBody(2e4))
h := CompressHandlerBrotliLevel(func(ctx *RequestCtx) {
ctx.Write([]byte(expectedBody)) //nolint:errcheck
}, CompressBrotliDefaultCompression, CompressDefaultCompression)
var ctx RequestCtx
var resp Response
// verify uncompressed response
h(&ctx)
s := ctx.Response.String()
br := bufio.NewReader(bytes.NewBufferString(s))
if err := resp.Read(br); err != nil {
t.Fatalf("unexpected error: %s", err)
}
ce := resp.Header.Peek(HeaderContentEncoding)
if string(ce) != "" {
t.Fatalf("unexpected Content-Encoding: %q. Expecting %q", ce, "")
}
body := resp.Body()
if string(body) != expectedBody {
t.Fatalf("unexpected body %q. Expecting %q", body, expectedBody)
}
// verify gzip-compressed response
ctx.Request.Reset()
ctx.Response.Reset()
ctx.Request.Header.Set("Accept-Encoding", "gzip, deflate, sdhc")
h(&ctx)
s = ctx.Response.String()
br = bufio.NewReader(bytes.NewBufferString(s))
if err := resp.Read(br); err != nil {
t.Fatalf("unexpected error: %s", err)
}
ce = resp.Header.Peek(HeaderContentEncoding)
if string(ce) != "gzip" {
t.Fatalf("unexpected Content-Encoding: %q. Expecting %q", ce, "gzip")
}
body, err := resp.BodyGunzip()
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
if string(body) != expectedBody {
t.Fatalf("unexpected body %q. Expecting %q", body, expectedBody)
}
// verify brotli-compressed response
ctx.Request.Reset()
ctx.Response.Reset()
ctx.Request.Header.Set("Accept-Encoding", "gzip, deflate, sdhc, br")
h(&ctx)
s = ctx.Response.String()
br = bufio.NewReader(bytes.NewBufferString(s))
if err := resp.Read(br); err != nil {
t.Fatalf("unexpected error: %s", err)
}
ce = resp.Header.Peek(HeaderContentEncoding)
if string(ce) != "br" {
t.Fatalf("unexpected Content-Encoding: %q. Expecting %q", ce, "br")
}
body, err = resp.BodyUnbrotli()
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
if string(body) != expectedBody {
t.Fatalf("unexpected body %q. Expecting %q", body, expectedBody)
}
}
fasthttp-1.31.0/bytesconv.go 0000664 0000000 0000000 00000020323 14130360711 0015770 0 ustar 00root root 0000000 0000000 //go:generate go run bytesconv_table_gen.go
package fasthttp
import (
"bufio"
"bytes"
"errors"
"fmt"
"io"
"math"
"net"
"reflect"
"strings"
"sync"
"time"
"unsafe"
)
// AppendHTMLEscape appends html-escaped s to dst and returns the extended dst.
func AppendHTMLEscape(dst []byte, s string) []byte {
if strings.IndexByte(s, '<') < 0 &&
strings.IndexByte(s, '>') < 0 &&
strings.IndexByte(s, '"') < 0 &&
strings.IndexByte(s, '\'') < 0 {
// fast path - nothing to escape
return append(dst, s...)
}
// slow path
var prev int
var sub string
for i, n := 0, len(s); i < n; i++ {
sub = ""
switch s[i] {
case '<':
sub = "<"
case '>':
sub = ">"
case '"':
sub = """
case '\'':
sub = "'"
}
if len(sub) > 0 {
dst = append(dst, s[prev:i]...)
dst = append(dst, sub...)
prev = i + 1
}
}
return append(dst, s[prev:]...)
}
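// exampleAppendHTMLEscape is an illustrative sketch and not part of the
// library: it shows how AppendHTMLEscape can reuse a destination buffer
// across calls to avoid per-call allocations.
func exampleAppendHTMLEscape(dst []byte, userInput string) []byte {
	// dst[:0] keeps the already allocated capacity while discarding
	// previously escaped content.
	return AppendHTMLEscape(dst[:0], userInput)
}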
// AppendHTMLEscapeBytes appends html-escaped s to dst and returns
// the extended dst.
func AppendHTMLEscapeBytes(dst, s []byte) []byte {
return AppendHTMLEscape(dst, b2s(s))
}
// AppendIPv4 appends string representation of the given ip v4 to dst
// and returns the extended dst.
func AppendIPv4(dst []byte, ip net.IP) []byte {
ip = ip.To4()
if ip == nil {
return append(dst, "non-v4 ip passed to AppendIPv4"...)
}
dst = AppendUint(dst, int(ip[0]))
for i := 1; i < 4; i++ {
dst = append(dst, '.')
dst = AppendUint(dst, int(ip[i]))
}
return dst
}
var errEmptyIPStr = errors.New("empty ip address string")
// ParseIPv4 parses ip address from ipStr into dst and returns the extended dst.
func ParseIPv4(dst net.IP, ipStr []byte) (net.IP, error) {
if len(ipStr) == 0 {
return dst, errEmptyIPStr
}
if len(dst) < net.IPv4len {
dst = make([]byte, net.IPv4len)
}
copy(dst, net.IPv4zero)
dst = dst.To4()
if dst == nil {
panic("BUG: dst must not be nil")
}
b := ipStr
for i := 0; i < 3; i++ {
n := bytes.IndexByte(b, '.')
if n < 0 {
return dst, fmt.Errorf("cannot find dot in ipStr %q", ipStr)
}
v, err := ParseUint(b[:n])
if err != nil {
return dst, fmt.Errorf("cannot parse ipStr %q: %s", ipStr, err)
}
if v > 255 {
return dst, fmt.Errorf("cannot parse ipStr %q: ip part cannot exceed 255: parsed %d", ipStr, v)
}
dst[i] = byte(v)
b = b[n+1:]
}
v, err := ParseUint(b)
if err != nil {
return dst, fmt.Errorf("cannot parse ipStr %q: %s", ipStr, err)
}
if v > 255 {
return dst, fmt.Errorf("cannot parse ipStr %q: ip part cannot exceed 255: parsed %d", ipStr, v)
}
dst[3] = byte(v)
return dst, nil
}
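// exampleParseIPv4RoundTrip is an illustrative sketch and not part of the
// library: it parses a textual IPv4 address into a reusable net.IP buffer
// and renders it back to dotted-decimal form with AppendIPv4.
func exampleParseIPv4RoundTrip(dst net.IP, ipStr []byte) ([]byte, error) {
	ip, err := ParseIPv4(dst, ipStr)
	if err != nil {
		return nil, err
	}
	return AppendIPv4(nil, ip), nil
}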
// AppendHTTPDate appends HTTP-compliant (RFC1123) representation of date
// to dst and returns the extended dst.
func AppendHTTPDate(dst []byte, date time.Time) []byte {
dst = date.In(time.UTC).AppendFormat(dst, time.RFC1123)
copy(dst[len(dst)-3:], strGMT)
return dst
}
// ParseHTTPDate parses HTTP-compliant (RFC1123) date.
func ParseHTTPDate(date []byte) (time.Time, error) {
return time.Parse(time.RFC1123, b2s(date))
}
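// exampleHTTPDateRoundTrip is an illustrative sketch and not part of the
// library: it formats a time.Time as an RFC1123 HTTP date with
// AppendHTTPDate and parses it back with ParseHTTPDate.
func exampleHTTPDateRoundTrip(t time.Time) (time.Time, error) {
	b := AppendHTTPDate(nil, t)
	return ParseHTTPDate(b)
}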
// AppendUint appends n to dst and returns the extended dst.
func AppendUint(dst []byte, n int) []byte {
if n < 0 {
panic("BUG: int must be positive")
}
var b [20]byte
buf := b[:]
i := len(buf)
var q int
for n >= 10 {
i--
q = n / 10
buf[i] = '0' + byte(n-q*10)
n = q
}
i--
buf[i] = '0' + byte(n)
dst = append(dst, buf[i:]...)
return dst
}
// ParseUint parses uint from buf.
func ParseUint(buf []byte) (int, error) {
v, n, err := parseUintBuf(buf)
if n != len(buf) {
return -1, errUnexpectedTrailingChar
}
return v, err
}
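// exampleAppendParseUint is an illustrative sketch and not part of the
// library: it round-trips a non-negative integer through AppendUint and
// ParseUint. AppendUint panics on negative input, so callers must pass n >= 0.
func exampleAppendParseUint(n int) (int, error) {
	b := AppendUint(nil, n)
	return ParseUint(b)
}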
var (
errEmptyInt = errors.New("empty integer")
errUnexpectedFirstChar = errors.New("unexpected first char found. Expecting 0-9")
errUnexpectedTrailingChar = errors.New("unexpected trailing char found. Expecting 0-9")
errTooLongInt = errors.New("too long int")
)
func parseUintBuf(b []byte) (int, int, error) {
n := len(b)
if n == 0 {
return -1, 0, errEmptyInt
}
v := 0
for i := 0; i < n; i++ {
c := b[i]
k := c - '0'
if k > 9 {
if i == 0 {
return -1, i, errUnexpectedFirstChar
}
return v, i, nil
}
vNew := 10*v + int(k)
// Test for overflow.
if vNew < v {
return -1, i, errTooLongInt
}
v = vNew
}
return v, n, nil
}
var (
errEmptyFloat = errors.New("empty float number")
errDuplicateFloatPoint = errors.New("duplicate point found in float number")
errUnexpectedFloatEnd = errors.New("unexpected end of float number")
errInvalidFloatExponent = errors.New("invalid float number exponent")
errUnexpectedFloatChar = errors.New("unexpected char found in float number")
)
// ParseUfloat parses unsigned float from buf.
func ParseUfloat(buf []byte) (float64, error) {
if len(buf) == 0 {
return -1, errEmptyFloat
}
b := buf
var v uint64
var offset = 1.0
var pointFound bool
for i, c := range b {
if c < '0' || c > '9' {
if c == '.' {
if pointFound {
return -1, errDuplicateFloatPoint
}
pointFound = true
continue
}
if c == 'e' || c == 'E' {
if i+1 >= len(b) {
return -1, errUnexpectedFloatEnd
}
b = b[i+1:]
minus := -1
switch b[0] {
case '+':
b = b[1:]
minus = 1
case '-':
b = b[1:]
default:
minus = 1
}
vv, err := ParseUint(b)
if err != nil {
return -1, errInvalidFloatExponent
}
return float64(v) * offset * math.Pow10(minus*int(vv)), nil
}
return -1, errUnexpectedFloatChar
}
v = 10*v + uint64(c-'0')
if pointFound {
offset /= 10
}
}
return float64(v) * offset, nil
}
var (
errEmptyHexNum = errors.New("empty hex number")
errTooLargeHexNum = errors.New("too large hex number")
)
func readHexInt(r *bufio.Reader) (int, error) {
n := 0
i := 0
var k int
for {
c, err := r.ReadByte()
if err != nil {
if err == io.EOF && i > 0 {
return n, nil
}
return -1, err
}
k = int(hex2intTable[c])
if k == 16 {
if i == 0 {
return -1, errEmptyHexNum
}
if err := r.UnreadByte(); err != nil {
return -1, err
}
return n, nil
}
if i >= maxHexIntChars {
return -1, errTooLargeHexNum
}
n = (n << 4) | k
i++
}
}
var hexIntBufPool sync.Pool
func writeHexInt(w *bufio.Writer, n int) error {
if n < 0 {
panic("BUG: int must be positive")
}
v := hexIntBufPool.Get()
if v == nil {
v = make([]byte, maxHexIntChars+1)
}
buf := v.([]byte)
i := len(buf) - 1
for {
buf[i] = lowerhex[n&0xf]
n >>= 4
if n == 0 {
break
}
i--
}
_, err := w.Write(buf[i:])
hexIntBufPool.Put(v)
return err
}
const (
upperhex = "0123456789ABCDEF"
lowerhex = "0123456789abcdef"
)
func lowercaseBytes(b []byte) {
for i := 0; i < len(b); i++ {
p := &b[i]
*p = toLowerTable[*p]
}
}
// b2s converts byte slice to a string without memory allocation.
// See https://groups.google.com/forum/#!msg/Golang-Nuts/ENgbUzYvCuU/90yGx7GUAgAJ .
//
// Note it may break if the string and/or slice header layout changes
// in future Go versions.
func b2s(b []byte) string {
/* #nosec G103 */
return *(*string)(unsafe.Pointer(&b))
}
// s2b converts string to a byte slice without memory allocation.
//
// Note it may break if the string and/or slice header layout changes
// in future Go versions.
func s2b(s string) (b []byte) {
/* #nosec G103 */
bh := (*reflect.SliceHeader)(unsafe.Pointer(&b))
/* #nosec G103 */
sh := (*reflect.StringHeader)(unsafe.Pointer(&s))
bh.Data = sh.Data
bh.Cap = sh.Len
bh.Len = sh.Len
return b
}
// AppendUnquotedArg appends url-decoded src to dst and returns appended dst.
//
// dst may point to src. In this case src will be overwritten.
func AppendUnquotedArg(dst, src []byte) []byte {
return decodeArgAppend(dst, src)
}
// AppendQuotedArg appends url-encoded src to dst and returns appended dst.
func AppendQuotedArg(dst, src []byte) []byte {
for _, c := range src {
switch {
case c == ' ':
dst = append(dst, '+')
case quotedArgShouldEscapeTable[int(c)] != 0:
dst = append(dst, '%', upperhex[c>>4], upperhex[c&0xf])
default:
dst = append(dst, c)
}
}
return dst
}
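// exampleQuotedArgRoundTrip is an illustrative sketch and not part of the
// library: it url-encodes a query argument with AppendQuotedArg and decodes
// it back with AppendUnquotedArg; the result equals the original arg.
func exampleQuotedArgRoundTrip(arg []byte) []byte {
	quoted := AppendQuotedArg(nil, arg)
	return AppendUnquotedArg(nil, quoted)
}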
func appendQuotedPath(dst, src []byte) []byte {
// Fix issue in https://github.com/golang/go/issues/11202
if len(src) == 1 && src[0] == '*' {
return append(dst, '*')
}
for _, c := range src {
if quotedPathShouldEscapeTable[int(c)] != 0 {
dst = append(dst, '%', upperhex[c>>4], upperhex[c&15])
} else {
dst = append(dst, c)
}
}
return dst
}
fasthttp-1.31.0/bytesconv_32.go 0000664 0000000 0000000 00000000213 14130360711 0016270 0 ustar 00root root 0000000 0000000 //go:build !amd64 && !arm64 && !ppc64 && !ppc64le
// +build !amd64,!arm64,!ppc64,!ppc64le
package fasthttp
const (
maxHexIntChars = 7
)
fasthttp-1.31.0/bytesconv_32_test.go 0000664 0000000 0000000 00000002433 14130360711 0017335 0 ustar 00root root 0000000 0000000 //go:build !amd64 && !arm64 && !ppc64 && !ppc64le
// +build !amd64,!arm64,!ppc64,!ppc64le
package fasthttp
import (
"testing"
)
func TestWriteHexInt(t *testing.T) {
t.Parallel()
testWriteHexInt(t, 0, "0")
testWriteHexInt(t, 1, "1")
testWriteHexInt(t, 0x123, "123")
testWriteHexInt(t, 0x7fffffff, "7fffffff")
}
func TestAppendUint(t *testing.T) {
t.Parallel()
testAppendUint(t, 0)
testAppendUint(t, 123)
testAppendUint(t, 0x7fffffff)
for i := 0; i < 2345; i++ {
testAppendUint(t, i)
}
}
func TestReadHexIntSuccess(t *testing.T) {
t.Parallel()
testReadHexIntSuccess(t, "0", 0)
testReadHexIntSuccess(t, "fF", 0xff)
testReadHexIntSuccess(t, "00abc", 0xabc)
testReadHexIntSuccess(t, "7ffffff", 0x7ffffff)
testReadHexIntSuccess(t, "000", 0)
testReadHexIntSuccess(t, "1234ZZZ", 0x1234)
}
func TestParseUintError32(t *testing.T) {
t.Parallel()
// Overflow by last digit: 2 ** 32 / 2 * 10 ** n
testParseUintError(t, "2147483648")
testParseUintError(t, "21474836480")
testParseUintError(t, "214748364800")
}
func TestParseUintSuccess(t *testing.T) {
t.Parallel()
testParseUintSuccess(t, "0", 0)
testParseUintSuccess(t, "123", 123)
testParseUintSuccess(t, "123456789", 123456789)
// Max supported value: 2 ** 32 / 2 - 1
testParseUintSuccess(t, "2147483647", 2147483647)
}
fasthttp-1.31.0/bytesconv_64.go 0000664 0000000 0000000 00000000204 14130360711 0016275 0 ustar 00root root 0000000 0000000 //go:build amd64 || arm64 || ppc64 || ppc64le
// +build amd64 arm64 ppc64 ppc64le
package fasthttp
const (
maxHexIntChars = 15
)
fasthttp-1.31.0/bytesconv_64_test.go 0000664 0000000 0000000 00000002737 14130360711 0017351 0 ustar 00root root 0000000 0000000 //go:build amd64 || arm64 || ppc64 || ppc64le
// +build amd64 arm64 ppc64 ppc64le
package fasthttp
import (
"testing"
)
func TestWriteHexInt(t *testing.T) {
t.Parallel()
testWriteHexInt(t, 0, "0")
testWriteHexInt(t, 1, "1")
testWriteHexInt(t, 0x123, "123")
testWriteHexInt(t, 0x7fffffffffffffff, "7fffffffffffffff")
}
func TestAppendUint(t *testing.T) {
t.Parallel()
testAppendUint(t, 0)
testAppendUint(t, 123)
testAppendUint(t, 0x7fffffffffffffff)
for i := 0; i < 2345; i++ {
testAppendUint(t, i)
}
}
func TestReadHexIntSuccess(t *testing.T) {
t.Parallel()
testReadHexIntSuccess(t, "0", 0)
testReadHexIntSuccess(t, "fF", 0xff)
testReadHexIntSuccess(t, "00abc", 0xabc)
testReadHexIntSuccess(t, "7fffffff", 0x7fffffff)
testReadHexIntSuccess(t, "000", 0)
testReadHexIntSuccess(t, "1234ZZZ", 0x1234)
testReadHexIntSuccess(t, "7ffffffffffffff", 0x7ffffffffffffff)
}
func TestParseUintError64(t *testing.T) {
t.Parallel()
// Overflow by last digit: 2 ** 64 / 2 * 10 ** n
testParseUintError(t, "9223372036854775808")
testParseUintError(t, "92233720368547758080")
testParseUintError(t, "922337203685477580800")
}
func TestParseUintSuccess(t *testing.T) {
t.Parallel()
testParseUintSuccess(t, "0", 0)
testParseUintSuccess(t, "123", 123)
testParseUintSuccess(t, "1234567890", 1234567890)
testParseUintSuccess(t, "123456789012345678", 123456789012345678)
// Max supported value: 2 ** 64 / 2 - 1
testParseUintSuccess(t, "9223372036854775807", 9223372036854775807)
}
fasthttp-1.31.0/bytesconv_table.go 0000664 0000000 0000000 00000011314 14130360711 0017137 0 ustar 00root root 0000000 0000000 package fasthttp
// Code generated by go run bytesconv_table_gen.go; DO NOT EDIT.
// See bytesconv_table_gen.go for more information about these tables.
const hex2intTable = "\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x00\x01\x02\x03\x04\x05\x06\a\b\t\x10\x10\x10\x10\x10\x10\x10\n\v\f\r\x0e\x0f\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\n\v\f\r\x0e\x0f\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10"
const toLowerTable = "\x00\x01\x02\x03\x04\x05\x06\a\b\t\n\v\f\r\x0e\x0f\x10\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f !\"#$%&'()*+,-./0123456789:;<=>?@abcdefghijklmnopqrstuvwxyz[\\]^_`abcdefghijklmnopqrstuvwxyz{|}~\u007f\x80\x81\x82\x83\x84\x85\x86\x87\x88\x89\x8a\x8b\x8c\x8d\x8e\x8f\x90\x91\x92\x93\x94\x95\x96\x97\x98\x99\x9a\x9b\x9c\x9d\x9e\x9f\xa0\xa1\xa2\xa3\xa4\xa5\xa6\xa7\xa8\xa9\xaa\xab\xac\xad\xae\xaf\xb0\xb1\xb2\xb3\xb4\xb5\xb6\xb7\xb8\xb9\xba\xbb\xbc\xbd\xbe\xbf\xc0\xc1\xc2\xc3\xc4\xc5\xc6\xc7\xc8\xc9\xca\xcb\xcc\xcd\xce\xcf\xd0\xd1\xd2\xd3\xd4\xd5\xd6\xd7\xd8\xd9\xda\xdb\xdc\xdd\xde\xdf\xe0\xe1\xe2\xe3\xe4\xe5\xe6\xe7\xe8\xe9\xea\xeb\xec\xed\xee\xef\xf0\xf1\xf2\xf3\xf4\xf5\xf6\xf7\xf8\xf9\xfa\xfb\xfc\xfd\xfe\xff"
const toUpperTable = "\x00\x01\x02\x03\x04\x05\x06\a\b\t\n\v\f\r\x0e\x0f\x10\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f !\"#$%&'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`ABCDEFGHIJKLMNOPQRSTUVWXYZ{|}~\u007f\x80\x81\x82\x83\x84\x85\x86\x87\x88\x89\x8a\x8b\x8c\x8d\x8e\x8f\x90\x91\x92\x93\x94\x95\x96\x97\x98\x99\x9a\x9b\x9c\x9d\x9e\x9f\xa0\xa1\xa2\xa3\xa4\xa5\xa6\xa7\xa8\xa9\xaa\xab\xac\xad\xae\xaf\xb0\xb1\xb2\xb3\xb4\xb5\xb6\xb7\xb8\xb9\xba\xbb\xbc\xbd\xbe\xbf\xc0\xc1\xc2\xc3\xc4\xc5\xc6\xc7\xc8\xc9\xca\xcb\xcc\xcd\xce\xcf\xd0\xd1\xd2\xd3\xd4\xd5\xd6\xd7\xd8\xd9\xda\xdb\xdc\xdd\xde\xdf\xe0\xe1\xe2\xe3\xe4\xe5\xe6\xe7\xe8\xe9\xea\xeb\xec\xed\xee\xef\xf0\xf1\xf2\xf3\xf4\xf5\xf6\xf7\xf8\xf9\xfa\xfb\xfc\xfd\xfe\xff"
const quotedArgShouldEscapeTable = "\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x01\x01\x01\x01\x01\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x01\x01\x01\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x01\x01\x00\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01"
const quotedPathShouldEscapeTable = "\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x00\x01\x00\x01\x01\x01\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x01\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x01\x01\x01\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x01\x01\x00\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01"
fasthttp-1.31.0/bytesconv_table_gen.go 0000664 0000000 0000000 00000004541 14130360711 0017774 0 ustar 00root root 0000000 0000000 //go:build ignore
// +build ignore
package main
import (
"bytes"
"fmt"
"io/ioutil"
"log"
)
const (
toLower = 'a' - 'A'
)
func main() {
hex2intTable := func() [256]byte {
var b [256]byte
for i := 0; i < 256; i++ {
c := byte(16)
if i >= '0' && i <= '9' {
c = byte(i) - '0'
} else if i >= 'a' && i <= 'f' {
c = byte(i) - 'a' + 10
} else if i >= 'A' && i <= 'F' {
c = byte(i) - 'A' + 10
}
b[i] = c
}
return b
}()
toLowerTable := func() [256]byte {
var a [256]byte
for i := 0; i < 256; i++ {
c := byte(i)
if c >= 'A' && c <= 'Z' {
c += toLower
}
a[i] = c
}
return a
}()
toUpperTable := func() [256]byte {
var a [256]byte
for i := 0; i < 256; i++ {
c := byte(i)
if c >= 'a' && c <= 'z' {
c -= toLower
}
a[i] = c
}
return a
}()
quotedArgShouldEscapeTable := func() [256]byte {
// According to RFC 3986 §2.3
var a [256]byte
for i := 0; i < 256; i++ {
a[i] = 1
}
// ALPHA
for i := int('a'); i <= int('z'); i++ {
a[i] = 0
}
for i := int('A'); i <= int('Z'); i++ {
a[i] = 0
}
// DIGIT
for i := int('0'); i <= int('9'); i++ {
a[i] = 0
}
// Unreserved characters
for _, v := range `-_.~` {
a[v] = 0
}
return a
}()
quotedPathShouldEscapeTable := func() [256]byte {
// The implementation here is equivalent to net/url's shouldEscape(s, encodePath)
//
// The RFC allows : @ & = + $ but saves / ; , for assigning
// meaning to individual path segments. This package
// only manipulates the path as a whole, so we allow those
// last three as well. That leaves only ? to escape.
var a = quotedArgShouldEscapeTable
for _, v := range `$&+,/:;=@` {
a[v] = 0
}
return a
}()
w := new(bytes.Buffer)
w.WriteString(pre)
fmt.Fprintf(w, "const hex2intTable = %q\n", hex2intTable)
fmt.Fprintf(w, "const toLowerTable = %q\n", toLowerTable)
fmt.Fprintf(w, "const toUpperTable = %q\n", toUpperTable)
fmt.Fprintf(w, "const quotedArgShouldEscapeTable = %q\n", quotedArgShouldEscapeTable)
fmt.Fprintf(w, "const quotedPathShouldEscapeTable = %q\n", quotedPathShouldEscapeTable)
if err := ioutil.WriteFile("bytesconv_table.go", w.Bytes(), 0660); err != nil {
log.Fatal(err)
}
}
const pre = `package fasthttp
// Code generated by go run bytesconv_table_gen.go; DO NOT EDIT.
// See bytesconv_table_gen.go for more information about these tables.
`
fasthttp-1.31.0/bytesconv_test.go 0000664 0000000 0000000 00000020011 14130360711 0017021 0 ustar 00root root 0000000 0000000 package fasthttp
import (
"bufio"
"bytes"
"fmt"
"net"
"testing"
"time"
"github.com/valyala/bytebufferpool"
)
func TestAppendHTMLEscape(t *testing.T) {
t.Parallel()
testAppendHTMLEscape(t, "", "")
testAppendHTMLEscape(t, "<", "<")
testAppendHTMLEscape(t, "a", "a")
testAppendHTMLEscape(t, `><"''`, "><"''")
testAppendHTMLEscape(t, "foaxxx", "fo<b x='ss'>a</b>xxx")
}
func testAppendHTMLEscape(t *testing.T, s, expectedS string) {
buf := AppendHTMLEscapeBytes(nil, []byte(s))
if string(buf) != expectedS {
t.Fatalf("unexpected html-escaped string %q. Expecting %q. Original string %q", buf, expectedS, s)
}
}
func TestParseIPv4(t *testing.T) {
t.Parallel()
testParseIPv4(t, "0.0.0.0", true)
testParseIPv4(t, "255.255.255.255", true)
testParseIPv4(t, "123.45.67.89", true)
// ipv6 shouldn't work
testParseIPv4(t, "2001:4860:0:2001::68", false)
// invalid ip
testParseIPv4(t, "foobar", false)
testParseIPv4(t, "1.2.3", false)
testParseIPv4(t, "123.456.789.11", false)
}
func testParseIPv4(t *testing.T, ipStr string, isValid bool) {
ip, err := ParseIPv4(nil, []byte(ipStr))
if isValid {
if err != nil {
t.Fatalf("unexpected error when parsing ip %q: %s", ipStr, err)
}
s := string(AppendIPv4(nil, ip))
if s != ipStr {
t.Fatalf("unexpected ip parsed %q. Expecting %q", s, ipStr)
}
} else {
if err == nil {
t.Fatalf("expecting error when parsing ip %q", ipStr)
}
}
}
func TestAppendIPv4(t *testing.T) {
t.Parallel()
testAppendIPv4(t, "0.0.0.0", true)
testAppendIPv4(t, "127.0.0.1", true)
testAppendIPv4(t, "8.8.8.8", true)
testAppendIPv4(t, "123.45.67.89", true)
// ipv6 shouldn't work
testAppendIPv4(t, "2001:4860:0:2001::68", false)
}
func testAppendIPv4(t *testing.T, ipStr string, isValid bool) {
ip := net.ParseIP(ipStr)
if ip == nil {
t.Fatalf("cannot parse ip %q", ipStr)
}
s := string(AppendIPv4(nil, ip))
if isValid {
if s != ipStr {
t.Fatalf("unexpected ip %q. Expecting %q", s, ipStr)
}
} else {
ipStr = "non-v4 ip passed to AppendIPv4"
if s != ipStr {
t.Fatalf("unexpected ip %q. Expecting %q", s, ipStr)
}
}
}
func testAppendUint(t *testing.T, n int) {
expectedS := fmt.Sprintf("%d", n)
s := AppendUint(nil, n)
if string(s) != expectedS {
t.Fatalf("unexpected uint %q. Expecting %q. n=%d", s, expectedS, n)
}
}
func testWriteHexInt(t *testing.T, n int, expectedS string) {
var w bytebufferpool.ByteBuffer
bw := bufio.NewWriter(&w)
if err := writeHexInt(bw, n); err != nil {
t.Fatalf("unexpected error when writing hex %x: %s", n, err)
}
if err := bw.Flush(); err != nil {
t.Fatalf("unexpected error when flushing hex %x: %s", n, err)
}
s := string(w.B)
if s != expectedS {
t.Fatalf("unexpected hex after writing %q. Expected %q", s, expectedS)
}
}
func TestReadHexIntError(t *testing.T) {
t.Parallel()
testReadHexIntError(t, "")
testReadHexIntError(t, "ZZZ")
testReadHexIntError(t, "-123")
testReadHexIntError(t, "+434")
}
func testReadHexIntError(t *testing.T, s string) {
r := bytes.NewBufferString(s)
br := bufio.NewReader(r)
n, err := readHexInt(br)
if err == nil {
t.Fatalf("expecting error when reading hex int %q", s)
}
if n >= 0 {
t.Fatalf("unexpected hex value read %d for hex int %q. must be negative", n, s)
}
}
func testReadHexIntSuccess(t *testing.T, s string, expectedN int) {
r := bytes.NewBufferString(s)
br := bufio.NewReader(r)
n, err := readHexInt(br)
if err != nil {
t.Fatalf("unexpected error: %s. s=%q", err, s)
}
if n != expectedN {
t.Fatalf("unexpected hex int %d. Expected %d. s=%q", n, expectedN, s)
}
}
func TestAppendHTTPDate(t *testing.T) {
t.Parallel()
d := time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC)
s := string(AppendHTTPDate(nil, d))
expectedS := "Tue, 10 Nov 2009 23:00:00 GMT"
if s != expectedS {
t.Fatalf("unexpected date %q. Expecting %q", s, expectedS)
}
b := []byte("prefix")
s = string(AppendHTTPDate(b, d))
if s[:len(b)] != string(b) {
t.Fatalf("unexpected prefix %q. Expecting %q", s[:len(b)], b)
}
s = s[len(b):]
if s != expectedS {
t.Fatalf("unexpected date %q. Expecting %q", s, expectedS)
}
}
func TestParseUintError(t *testing.T) {
t.Parallel()
// empty string
testParseUintError(t, "")
// negative value
testParseUintError(t, "-123")
// non-num
testParseUintError(t, "foobar234")
// non-num chars at the end
testParseUintError(t, "123w")
// floating point num
testParseUintError(t, "1234.545")
// too big num
testParseUintError(t, "12345678901234567890")
testParseUintError(t, "1234567890123456789012")
}
func TestParseUfloatSuccess(t *testing.T) {
t.Parallel()
testParseUfloatSuccess(t, "0", 0)
testParseUfloatSuccess(t, "1.", 1.)
testParseUfloatSuccess(t, ".1", 0.1)
testParseUfloatSuccess(t, "123.456", 123.456)
testParseUfloatSuccess(t, "123", 123)
testParseUfloatSuccess(t, "1234e2", 1234e2)
testParseUfloatSuccess(t, "1234E-5", 1234e-5)
testParseUfloatSuccess(t, "1.234e+3", 1.234e+3)
}
func TestParseUfloatError(t *testing.T) {
t.Parallel()
// empty num
testParseUfloatError(t, "")
// negative num
testParseUfloatError(t, "-123.53")
// non-num chars
testParseUfloatError(t, "123sdfsd")
testParseUfloatError(t, "sdsf234")
testParseUfloatError(t, "sdfdf")
// non-num chars in exponent
testParseUfloatError(t, "123e3s")
testParseUfloatError(t, "12.3e-op")
testParseUfloatError(t, "123E+SS5")
// duplicate point
testParseUfloatError(t, "1.3.4")
// duplicate exponent
testParseUfloatError(t, "123e5e6")
// missing exponent
testParseUfloatError(t, "123534e")
}
func testParseUfloatError(t *testing.T, s string) {
n, err := ParseUfloat([]byte(s))
if err == nil {
t.Fatalf("Expecting error when parsing %q. obtained %f", s, n)
}
if n >= 0 {
t.Fatalf("Expecting negative num instead of %f when parsing %q", n, s)
}
}
func testParseUfloatSuccess(t *testing.T, s string, expectedF float64) {
f, err := ParseUfloat([]byte(s))
if err != nil {
t.Fatalf("Unexpected error when parsing %q: %s", s, err)
}
delta := f - expectedF
if delta < 0 {
delta = -delta
}
if delta > expectedF*1e-10 {
t.Fatalf("Unexpected value when parsing %q: %f. Expected %f", s, f, expectedF)
}
}
func testParseUintError(t *testing.T, s string) {
n, err := ParseUint([]byte(s))
if err == nil {
t.Fatalf("Expecting error when parsing %q. obtained %d", s, n)
}
if n >= 0 {
t.Fatalf("Unexpected n=%d when parsing %q. Expected negative num", n, s)
}
}
func testParseUintSuccess(t *testing.T, s string, expectedN int) {
n, err := ParseUint([]byte(s))
if err != nil {
t.Fatalf("Unexpected error when parsing %q: %s", s, err)
}
if n != expectedN {
t.Fatalf("Unexpected value %d. Expected %d. num=%q", n, expectedN, s)
}
}
func TestAppendUnquotedArg(t *testing.T) {
t.Parallel()
testAppendUnquotedArg(t, "", "")
testAppendUnquotedArg(t, "abc", "abc")
testAppendUnquotedArg(t, "тест.abc", "тест.abc")
testAppendUnquotedArg(t, "%D1%82%D0%B5%D1%81%D1%82%20%=&;:", "тест %=&;:")
}
func testAppendUnquotedArg(t *testing.T, s, expectedS string) {
// test appending to nil
result := AppendUnquotedArg(nil, []byte(s))
if string(result) != expectedS {
t.Fatalf("Unexpected AppendUnquotedArg(%q)=%q, want %q", s, result, expectedS)
}
// test appending to prefix
prefix := "prefix"
dst := []byte(prefix)
dst = AppendUnquotedArg(dst, []byte(s))
if !bytes.HasPrefix(dst, []byte(prefix)) {
t.Fatalf("Unexpected prefix for AppendUnquotedArg(%q)=%q, want %q", s, dst, prefix)
}
result = dst[len(prefix):]
if string(result) != expectedS {
t.Fatalf("Unexpected AppendUnquotedArg(%q)=%q, want %q", s, result, expectedS)
}
// test in-place appending
result = []byte(s)
result = AppendUnquotedArg(result[:0], result)
if string(result) != expectedS {
t.Fatalf("Unexpected AppendUnquotedArg(%q)=%q, want %q", s, result, expectedS)
}
// verify AppendQuotedArg <-> AppendUnquotedArg conversion
quotedS := AppendQuotedArg(nil, []byte(s))
unquotedS := AppendUnquotedArg(nil, quotedS)
if s != string(unquotedS) {
t.Fatalf("Unexpected AppendUnquotedArg(AppendQuotedArg(%q))=%q, want %q", s, unquotedS, s)
}
}
fasthttp-1.31.0/bytesconv_timing_test.go 0000664 0000000 0000000 00000006632 14130360711 0020405 0 ustar 00root root 0000000 0000000 package fasthttp
import (
"bufio"
"html"
"net"
"testing"
"github.com/valyala/bytebufferpool"
)
func BenchmarkAppendHTMLEscape(b *testing.B) {
sOrig := "foobarbazxxxyyyzzz"
sExpected := string(AppendHTMLEscape(nil, sOrig))
b.RunParallel(func(pb *testing.PB) {
var buf []byte
for pb.Next() {
for i := 0; i < 10; i++ {
buf = AppendHTMLEscape(buf[:0], sOrig)
if string(buf) != sExpected {
b.Fatalf("unexpected escaped string: %s. Expecting %s", buf, sExpected)
}
}
}
})
}
func BenchmarkHTMLEscapeString(b *testing.B) {
sOrig := "foobarbazxxxyyyzzz"
sExpected := html.EscapeString(sOrig)
b.RunParallel(func(pb *testing.PB) {
var s string
for pb.Next() {
for i := 0; i < 10; i++ {
s = html.EscapeString(sOrig)
if s != sExpected {
b.Fatalf("unexpected escaped string: %s. Expecting %s", s, sExpected)
}
}
}
})
}
func BenchmarkParseIPv4(b *testing.B) {
ipStr := []byte("123.145.167.189")
b.RunParallel(func(pb *testing.PB) {
var ip net.IP
var err error
for pb.Next() {
ip, err = ParseIPv4(ip, ipStr)
if err != nil {
b.Fatalf("unexpected error: %s", err)
}
}
})
}
func BenchmarkAppendIPv4(b *testing.B) {
ip := net.ParseIP("123.145.167.189")
b.RunParallel(func(pb *testing.PB) {
var buf []byte
for pb.Next() {
buf = AppendIPv4(buf[:0], ip)
}
})
}
func BenchmarkWriteHexInt(b *testing.B) {
b.RunParallel(func(pb *testing.PB) {
var w bytebufferpool.ByteBuffer
bw := bufio.NewWriter(&w)
i := 0
for pb.Next() {
writeHexInt(bw, i) //nolint:errcheck
i++
if i > 0x7fffffff {
i = 0
}
w.Reset()
bw.Reset(&w)
}
})
}
func BenchmarkParseUint(b *testing.B) {
b.RunParallel(func(pb *testing.PB) {
buf := []byte("1234567")
for pb.Next() {
n, err := ParseUint(buf)
if err != nil {
b.Fatalf("unexpected error: %s", err)
}
if n != 1234567 {
b.Fatalf("unexpected result: %d. Expecting %s", n, buf)
}
}
})
}
func BenchmarkAppendUint(b *testing.B) {
b.RunParallel(func(pb *testing.PB) {
var buf []byte
i := 0
for pb.Next() {
buf = AppendUint(buf[:0], i)
i++
if i > 0x7fffffff {
i = 0
}
}
})
}
func BenchmarkLowercaseBytesNoop(b *testing.B) {
src := []byte("foobarbaz_lowercased_all")
b.RunParallel(func(pb *testing.PB) {
s := make([]byte, len(src))
for pb.Next() {
copy(s, src)
lowercaseBytes(s)
}
})
}
func BenchmarkLowercaseBytesAll(b *testing.B) {
src := []byte("FOOBARBAZ_UPPERCASED_ALL")
b.RunParallel(func(pb *testing.PB) {
s := make([]byte, len(src))
for pb.Next() {
copy(s, src)
lowercaseBytes(s)
}
})
}
func BenchmarkLowercaseBytesMixed(b *testing.B) {
src := []byte("Foobarbaz_Uppercased_Mix")
b.RunParallel(func(pb *testing.PB) {
s := make([]byte, len(src))
for pb.Next() {
copy(s, src)
lowercaseBytes(s)
}
})
}
func BenchmarkAppendUnquotedArgFastPath(b *testing.B) {
src := []byte("foobarbaz no quoted chars fdskjsdf jklsdfdfskljd;aflskjdsaf fdsklj fsdkj fsdl kfjsdlk jfsdklj fsdfsdf sdfkflsd")
b.RunParallel(func(pb *testing.PB) {
var dst []byte
for pb.Next() {
dst = AppendUnquotedArg(dst[:0], src)
}
})
}
func BenchmarkAppendUnquotedArgSlowPath(b *testing.B) {
src := []byte("D0%B4%20%D0%B0%D0%B2%D0%BB%D0%B4%D1%84%D1%8B%D0%B0%D0%BE%20%D1%84%D0%B2%D0%B6%D0%BB%D0%B4%D1%8B%20%D0%B0%D0%BE")
b.RunParallel(func(pb *testing.PB) {
var dst []byte
for pb.Next() {
dst = AppendUnquotedArg(dst[:0], src)
}
})
}
fasthttp-1.31.0/client.go 0000664 0000000 0000000 00000231046 14130360711 0015240 0 ustar 00root root 0000000 0000000 package fasthttp
import (
"bufio"
"bytes"
"crypto/tls"
"errors"
"fmt"
"io"
"net"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
)
// Do performs the given http request and fills the given http response.
//
// Request must contain at least non-zero RequestURI with full url (including
// scheme and host) or non-zero Host header + RequestURI.
//
// Client determines the server to be requested in the following order:
//
// - from RequestURI if it contains full url with scheme and host;
// - from Host header otherwise.
//
// The function doesn't follow redirects. Use Get* for following redirects.
//
// Response is ignored if resp is nil.
//
// ErrNoFreeConns is returned if all DefaultMaxConnsPerHost connections
// to the requested host are busy.
//
// It is recommended obtaining req and resp via AcquireRequest
// and AcquireResponse in performance-critical code.
func Do(req *Request, resp *Response) error {
return defaultClient.Do(req, resp)
}
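// exampleDo is an illustrative sketch and not part of the library API: it
// shows the recommended Acquire/Release pattern around Do. The url parameter
// is an assumption of this example and must contain scheme and host.
func exampleDo(url string) ([]byte, error) {
	req := AcquireRequest()
	resp := AcquireResponse()
	defer ReleaseRequest(req)
	defer ReleaseResponse(resp)
	req.SetRequestURI(url)
	if err := Do(req, resp); err != nil {
		return nil, err
	}
	// Copy the body, since resp is returned to the pool on release.
	return append([]byte(nil), resp.Body()...), nil
}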
// DoTimeout performs the given request and waits for response during
// the given timeout duration.
//
// Request must contain at least non-zero RequestURI with full url (including
// scheme and host) or non-zero Host header + RequestURI.
//
// Client determines the server to be requested in the following order:
//
// - from RequestURI if it contains full url with scheme and host;
// - from Host header otherwise.
//
// The function doesn't follow redirects. Use Get* for following redirects.
//
// Response is ignored if resp is nil.
//
// ErrTimeout is returned if the response wasn't returned during
// the given timeout.
//
// ErrNoFreeConns is returned if all DefaultMaxConnsPerHost connections
// to the requested host are busy.
//
// It is recommended obtaining req and resp via AcquireRequest
// and AcquireResponse in performance-critical code.
//
// Warning: DoTimeout does not terminate the request itself. The request will
// continue in the background and the response will be discarded.
// If requests take too long and the connection pool gets filled up please
// try using a Client and setting a ReadTimeout.
func DoTimeout(req *Request, resp *Response, timeout time.Duration) error {
return defaultClient.DoTimeout(req, resp, timeout)
}
// DoDeadline performs the given request and waits for response until
// the given deadline.
//
// Request must contain at least non-zero RequestURI with full url (including
// scheme and host) or non-zero Host header + RequestURI.
//
// Client determines the server to be requested in the following order:
//
// - from RequestURI if it contains full url with scheme and host;
// - from Host header otherwise.
//
// The function doesn't follow redirects. Use Get* for following redirects.
//
// Response is ignored if resp is nil.
//
// ErrTimeout is returned if the response wasn't returned until
// the given deadline.
//
// ErrNoFreeConns is returned if all DefaultMaxConnsPerHost connections
// to the requested host are busy.
//
// It is recommended obtaining req and resp via AcquireRequest
// and AcquireResponse in performance-critical code.
func DoDeadline(req *Request, resp *Response, deadline time.Time) error {
return defaultClient.DoDeadline(req, resp, deadline)
}
// DoRedirects performs the given http request and fills the given http response,
// following up to maxRedirectsCount redirects. When the redirect count exceeds
// maxRedirectsCount, ErrTooManyRedirects is returned.
//
// Request must contain at least non-zero RequestURI with full url (including
// scheme and host) or non-zero Host header + RequestURI.
//
// Client determines the server to be requested in the following order:
//
// - from RequestURI if it contains full url with scheme and host;
// - from Host header otherwise.
//
// Response is ignored if resp is nil.
//
// ErrNoFreeConns is returned if all DefaultMaxConnsPerHost connections
// to the requested host are busy.
//
// It is recommended obtaining req and resp via AcquireRequest
// and AcquireResponse in performance-critical code.
func DoRedirects(req *Request, resp *Response, maxRedirectsCount int) error {
_, _, err := doRequestFollowRedirects(req, resp, req.URI().String(), maxRedirectsCount, &defaultClient)
return err
}
// Get returns the status code and body of url.
//
// The contents of dst will be replaced by the body and returned, if the dst
// is too small a new slice will be allocated.
//
// The function follows redirects. Use Do* for manually handling redirects.
func Get(dst []byte, url string) (statusCode int, body []byte, err error) {
return defaultClient.Get(dst, url)
}
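// exampleGet is an illustrative sketch and not part of the library: it
// fetches url with the package-level Get helper, reusing dst for the body to
// reduce allocations across calls, and treats non-200 responses as errors.
func exampleGet(dst []byte, url string) ([]byte, error) {
	statusCode, body, err := Get(dst, url)
	if err != nil {
		return nil, err
	}
	if statusCode != StatusOK {
		return nil, fmt.Errorf("unexpected status code: %d", statusCode)
	}
	return body, nil
}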
// GetTimeout returns the status code and body of url.
//
// The contents of dst will be replaced by the body and returned, if the dst
// is too small a new slice will be allocated.
//
// The function follows redirects. Use Do* for manually handling redirects.
//
// ErrTimeout error is returned if url contents couldn't be fetched
// during the given timeout.
func GetTimeout(dst []byte, url string, timeout time.Duration) (statusCode int, body []byte, err error) {
return defaultClient.GetTimeout(dst, url, timeout)
}
// GetDeadline returns the status code and body of url.
//
// The contents of dst will be replaced by the body and returned, if the dst
// is too small a new slice will be allocated.
//
// The function follows redirects. Use Do* for manually handling redirects.
//
// ErrTimeout error is returned if url contents couldn't be fetched
// until the given deadline.
func GetDeadline(dst []byte, url string, deadline time.Time) (statusCode int, body []byte, err error) {
return defaultClient.GetDeadline(dst, url, deadline)
}
// Post sends POST request to the given url with the given POST arguments.
//
// The contents of dst will be replaced by the body and returned, if the dst
// is too small a new slice will be allocated.
//
// The function follows redirects. Use Do* for manually handling redirects.
//
// Empty POST body is sent if postArgs is nil.
func Post(dst []byte, url string, postArgs *Args) (statusCode int, body []byte, err error) {
return defaultClient.Post(dst, url, postArgs)
}
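// examplePost is an illustrative sketch and not part of the library: it sends
// a form-encoded POST request using the pooled Args helpers and returns the
// status code and body.
func examplePost(url string) (int, []byte, error) {
	args := AcquireArgs()
	defer ReleaseArgs(args)
	args.Set("foo", "bar")
	// Passing nil as dst lets Post allocate the body slice.
	return Post(nil, url, args)
}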
var defaultClient Client
// Client implements http client.
//
// Copying Client by value is prohibited. Create new instance instead.
//
// It is safe calling Client methods from concurrently running goroutines.
//
// The fields of a Client should not be changed while it is in use.
type Client struct {
noCopy noCopy //nolint:unused,structcheck
// Client name. Used in User-Agent request header.
//
// Default client name is used if not set.
Name string
// NoDefaultUserAgentHeader when set to true, causes the default
// User-Agent header to be excluded from the Request.
NoDefaultUserAgentHeader bool
// Callback for establishing new connections to hosts.
//
// Default Dial is used if not set.
Dial DialFunc
// Attempt to connect to both ipv4 and ipv6 addresses if set to true.
//
// This option is used only if default TCP dialer is used,
// i.e. if Dial is blank.
//
// By default client connects only to ipv4 addresses,
// since unfortunately ipv6 remains broken in many networks worldwide :)
DialDualStack bool
// TLS config for https connections.
//
// Default TLS config is used if not set.
TLSConfig *tls.Config
// Maximum number of connections per each host which may be established.
//
// DefaultMaxConnsPerHost is used if not set.
MaxConnsPerHost int
// Idle keep-alive connections are closed after this duration.
//
// By default idle connections are closed
// after DefaultMaxIdleConnDuration.
MaxIdleConnDuration time.Duration
// Keep-alive connections are closed after this duration.
//
// By default connection duration is unlimited.
MaxConnDuration time.Duration
// Maximum number of attempts for idempotent calls
//
// DefaultMaxIdemponentCallAttempts is used if not set.
MaxIdemponentCallAttempts int
// Per-connection buffer size for responses' reading.
// This also limits the maximum header size.
//
// Default buffer size is used if 0.
ReadBufferSize int
// Per-connection buffer size for requests' writing.
//
// Default buffer size is used if 0.
WriteBufferSize int
// Maximum duration for full response reading (including body).
//
// By default response read timeout is unlimited.
ReadTimeout time.Duration
// Maximum duration for full request writing (including body).
//
// By default request write timeout is unlimited.
WriteTimeout time.Duration
// Maximum response body size.
//
// The client returns ErrBodyTooLarge if this limit is greater than 0
// and response body is greater than the limit.
//
// By default response body size is unlimited.
MaxResponseBodySize int
// Header names are passed as-is without normalization
// if this option is set.
//
// Disabled header names' normalization may be useful only for proxying
// responses to other clients expecting case-sensitive
// header names. See https://github.com/valyala/fasthttp/issues/57
// for details.
//
// By default request and response header names are normalized, i.e.
// The first letter and the first letters following dashes
// are uppercased, while all the other letters are lowercased.
// Examples:
//
// * HOST -> Host
// * content-type -> Content-Type
// * cONTENT-lenGTH -> Content-Length
DisableHeaderNamesNormalizing bool
// Path values are sent as-is without normalization
//
// Disabled path normalization may be useful for proxying incoming requests
// to servers that are expecting paths to be forwarded as-is.
//
// By default path values are normalized, i.e.
// extra slashes are removed, special characters are encoded.
DisablePathNormalizing bool
// Maximum duration for waiting for a free connection.
//
// By default it will not wait; ErrNoFreeConns is returned immediately.
MaxConnWaitTimeout time.Duration
// RetryIf controls whether a retry should be attempted after an error.
//
// By default the isIdempotent function is used.
RetryIf RetryIfFunc
mLock sync.Mutex
m map[string]*HostClient
ms map[string]*HostClient
readerPool sync.Pool
writerPool sync.Pool
}
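// A minimal configuration sketch for the Client above (field values are
// arbitrary examples, not recommendations; assumes the caller imports "time"):
//
//	c := &fasthttp.Client{
//		Name:                "my-client/1.0",
//		MaxConnsPerHost:     100,
//		MaxIdleConnDuration: 30 * time.Second,
//		ReadTimeout:         10 * time.Second,
//		WriteTimeout:        10 * time.Second,
//	}
//	statusCode, body, err := c.Get(nil, "https://example.com/")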
// Get returns the status code and body of url.
//
// The contents of dst will be replaced by the body and returned, if the dst
// is too small a new slice will be allocated.
//
// The function follows redirects. Use Do* for manually handling redirects.
func (c *Client) Get(dst []byte, url string) (statusCode int, body []byte, err error) {
return clientGetURL(dst, url, c)
}
// GetTimeout returns the status code and body of url.
//
// The contents of dst will be replaced by the body and returned, if the dst
// is too small a new slice will be allocated.
//
// The function follows redirects. Use Do* for manually handling redirects.
//
// ErrTimeout error is returned if url contents couldn't be fetched
// during the given timeout.
func (c *Client) GetTimeout(dst []byte, url string, timeout time.Duration) (statusCode int, body []byte, err error) {
return clientGetURLTimeout(dst, url, timeout, c)
}
// GetDeadline returns the status code and body of url.
//
// The contents of dst will be replaced by the body and returned, if the dst
// is too small a new slice will be allocated.
//
// The function follows redirects. Use Do* for manually handling redirects.
//
// ErrTimeout error is returned if url contents couldn't be fetched
// until the given deadline.
func (c *Client) GetDeadline(dst []byte, url string, deadline time.Time) (statusCode int, body []byte, err error) {
return clientGetURLDeadline(dst, url, deadline, c)
}
// Post sends POST request to the given url with the given POST arguments.
//
// The contents of dst will be replaced by the body and returned, if the dst
// is too small a new slice will be allocated.
//
// The function follows redirects. Use Do* for manually handling redirects.
//
// Empty POST body is sent if postArgs is nil.
func (c *Client) Post(dst []byte, url string, postArgs *Args) (statusCode int, body []byte, err error) {
return clientPostURL(dst, url, postArgs, c)
}
// DoTimeout performs the given request and waits for response during
// the given timeout duration.
//
// Request must contain at least non-zero RequestURI with full url (including
// scheme and host) or non-zero Host header + RequestURI.
//
// Client determines the server to be requested in the following order:
//
// - from RequestURI if it contains full url with scheme and host;
// - from Host header otherwise.
//
// The function doesn't follow redirects. Use Get* for following redirects.
//
// Response is ignored if resp is nil.
//
// ErrTimeout is returned if the response wasn't returned during
// the given timeout.
//
// ErrNoFreeConns is returned if all Client.MaxConnsPerHost connections
// to the requested host are busy.
//
// It is recommended obtaining req and resp via AcquireRequest
// and AcquireResponse in performance-critical code.
//
// Warning: DoTimeout does not terminate the request itself. The request will
// continue in the background and the response will be discarded.
// If requests take too long and the connection pool gets filled up please
// try setting a ReadTimeout.
func (c *Client) DoTimeout(req *Request, resp *Response, timeout time.Duration) error {
return clientDoTimeout(req, resp, timeout, c)
}
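// A minimal DoTimeout sketch (assumes c is a *fasthttp.Client configured as
// above and the caller imports "time"); remember that on ErrTimeout the
// request keeps running in the background, so setting ReadTimeout on the
// Client is advisable:
//
//	req := fasthttp.AcquireRequest()
//	resp := fasthttp.AcquireResponse()
//	defer fasthttp.ReleaseRequest(req)
//	defer fasthttp.ReleaseResponse(resp)
//
//	req.SetRequestURI("http://example.com/api")
//	err := c.DoTimeout(req, resp, 5*time.Second)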
// DoDeadline performs the given request and waits for response until
// the given deadline.
//
// Request must contain at least non-zero RequestURI with full url (including
// scheme and host) or non-zero Host header + RequestURI.
//
// Client determines the server to be requested in the following order:
//
// - from RequestURI if it contains full url with scheme and host;
// - from Host header otherwise.
//
// The function doesn't follow redirects. Use Get* for following redirects.
//
// Response is ignored if resp is nil.
//
// ErrTimeout is returned if the response wasn't returned until
// the given deadline.
//
// ErrNoFreeConns is returned if all Client.MaxConnsPerHost connections
// to the requested host are busy.
//
// It is recommended obtaining req and resp via AcquireRequest
// and AcquireResponse in performance-critical code.
func (c *Client) DoDeadline(req *Request, resp *Response, deadline time.Time) error {
return clientDoDeadline(req, resp, deadline, c)
}
// DoRedirects performs the given http request and fills the given http response,
// following up to maxRedirectsCount redirects. When the redirect count exceeds
// maxRedirectsCount, ErrTooManyRedirects is returned.
//
// Request must contain at least non-zero RequestURI with full url (including
// scheme and host) or non-zero Host header + RequestURI.
//
// Client determines the server to be requested in the following order:
//
// - from RequestURI if it contains full url with scheme and host;
// - from Host header otherwise.
//
// Response is ignored if resp is nil.
//
// ErrNoFreeConns is returned if all DefaultMaxConnsPerHost connections
// to the requested host are busy.
//
// It is recommended obtaining req and resp via AcquireRequest
// and AcquireResponse in performance-critical code.
func (c *Client) DoRedirects(req *Request, resp *Response, maxRedirectsCount int) error {
_, _, err := doRequestFollowRedirects(req, resp, req.URI().String(), maxRedirectsCount, c)
return err
}
// Do performs the given http request and fills the given http response.
//
// Request must contain at least non-zero RequestURI with full url (including
// scheme and host) or non-zero Host header + RequestURI.
//
// Client determines the server to be requested in the following order:
//
// - from RequestURI if it contains full url with scheme and host;
// - from Host header otherwise.
//
// Response is ignored if resp is nil.
//
// The function doesn't follow redirects. Use Get* for following redirects.
//
// ErrNoFreeConns is returned if all Client.MaxConnsPerHost connections
// to the requested host are busy.
//
// It is recommended obtaining req and resp via AcquireRequest
// and AcquireResponse in performance-critical code.
func (c *Client) Do(req *Request, resp *Response) error {
uri := req.URI()
if uri == nil {
return ErrorInvalidURI
}
host := uri.Host()
isTLS := false
scheme := uri.Scheme()
if bytes.Equal(scheme, strHTTPS) {
isTLS = true
} else if !bytes.Equal(scheme, strHTTP) {
return fmt.Errorf("unsupported protocol %q. http and https are supported", scheme)
}
startCleaner := false
c.mLock.Lock()
m := c.m
if isTLS {
m = c.ms
}
if m == nil {
m = make(map[string]*HostClient)
if isTLS {
c.ms = m
} else {
c.m = m
}
}
hc := m[string(host)]
if hc == nil {
hc = &HostClient{
Addr: addMissingPort(string(host), isTLS),
Name: c.Name,
NoDefaultUserAgentHeader: c.NoDefaultUserAgentHeader,
Dial: c.Dial,
DialDualStack: c.DialDualStack,
IsTLS: isTLS,
TLSConfig: c.TLSConfig,
MaxConns: c.MaxConnsPerHost,
MaxIdleConnDuration: c.MaxIdleConnDuration,
MaxConnDuration: c.MaxConnDuration,
MaxIdemponentCallAttempts: c.MaxIdemponentCallAttempts,
ReadBufferSize: c.ReadBufferSize,
WriteBufferSize: c.WriteBufferSize,
ReadTimeout: c.ReadTimeout,
WriteTimeout: c.WriteTimeout,
MaxResponseBodySize: c.MaxResponseBodySize,
DisableHeaderNamesNormalizing: c.DisableHeaderNamesNormalizing,
DisablePathNormalizing: c.DisablePathNormalizing,
MaxConnWaitTimeout: c.MaxConnWaitTimeout,
RetryIf: c.RetryIf,
clientReaderPool: &c.readerPool,
clientWriterPool: &c.writerPool,
}
m[string(host)] = hc
if len(m) == 1 {
startCleaner = true
}
}
c.mLock.Unlock()
if startCleaner {
go c.mCleaner(m)
}
return hc.Do(req, resp)
}
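// A minimal Do sketch showing the two documented ways of addressing the server
// (assumes c is a *fasthttp.Client; no redirects are followed):
//
//	req := fasthttp.AcquireRequest()
//	resp := fasthttp.AcquireResponse()
//	defer fasthttp.ReleaseRequest(req)
//	defer fasthttp.ReleaseResponse(resp)
//
//	// Variant 1: full URL including scheme and host.
//	req.SetRequestURI("http://example.com/items")
//
//	// Variant 2: Host header plus a relative RequestURI.
//	// req.Header.SetHost("example.com")
//	// req.SetRequestURI("/items")
//
//	err := c.Do(req, resp)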
// CloseIdleConnections closes any connections which were previously
// connected from previous requests but are now sitting idle in a
// "keep-alive" state. It does not interrupt any connections currently
// in use.
func (c *Client) CloseIdleConnections() {
c.mLock.Lock()
for _, v := range c.m {
v.CloseIdleConnections()
}
for _, v := range c.ms {
v.CloseIdleConnections()
}
c.mLock.Unlock()
}
func (c *Client) mCleaner(m map[string]*HostClient) {
mustStop := false
sleep := c.MaxIdleConnDuration
if sleep < time.Second {
sleep = time.Second
} else if sleep > 10*time.Second {
sleep = 10 * time.Second
}
for {
c.mLock.Lock()
for k, v := range m {
v.connsLock.Lock()
shouldRemove := v.connsCount == 0
v.connsLock.Unlock()
if shouldRemove {
delete(m, k)
}
}
if len(m) == 0 {
mustStop = true
}
c.mLock.Unlock()
if mustStop {
break
}
time.Sleep(sleep)
}
}
// DefaultMaxConnsPerHost is the maximum number of concurrent connections
// http client may establish per host by default (i.e. if
// Client.MaxConnsPerHost isn't set).
const DefaultMaxConnsPerHost = 512
// DefaultMaxIdleConnDuration is the default duration before idle keep-alive
// connection is closed.
const DefaultMaxIdleConnDuration = 10 * time.Second
// DefaultMaxIdemponentCallAttempts is the default number of attempts for idempotent calls.
const DefaultMaxIdemponentCallAttempts = 5
// DialFunc must establish connection to addr.
//
// There is no need to establish a TLS (SSL) connection for https.
// The client automatically converts the connection to TLS
// if HostClient.IsTLS is set.
//
// TCP address passed to DialFunc always contains host and port.
// Example TCP addr values:
//
// - foobar.com:80
// - foobar.com:443
// - foobar.com:8080
type DialFunc func(addr string) (net.Conn, error)
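// A minimal custom DialFunc sketch using the standard net package (the 3s
// dial timeout is an arbitrary example; assumes the caller imports "net" and
// "time"):
//
//	c := &fasthttp.Client{
//		Dial: func(addr string) (net.Conn, error) {
//			// addr always contains both host and port, e.g. "example.com:443".
//			return net.DialTimeout("tcp", addr, 3*time.Second)
//		},
//	}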
// RetryIfFunc defines the signature of the function that decides whether
// a failed request should be retried.
//
// The Request that produced the error is passed to RetryIfFunc.
type RetryIfFunc func(request *Request) bool
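// A minimal RetryIf sketch that retries only requests which are safe to
// repeat (this roughly mirrors the default isIdempotent behaviour;
// illustrative only):
//
//	c := &fasthttp.Client{
//		RetryIf: func(req *fasthttp.Request) bool {
//			return req.Header.IsGet() || req.Header.IsHead()
//		},
//	}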
// TransportFunc wraps every request/response.
type TransportFunc func(*Request, *Response) error
// HostClient balances http requests among hosts listed in Addr.
//
// HostClient may be used for balancing load among multiple upstream hosts.
// While multiple addresses passed to HostClient.Addr may be used for balancing
// load among them, it would be better using LBClient instead, since HostClient
// may unevenly balance load among upstream hosts.
//
// It is forbidden copying HostClient instances. Create new instances instead.
//
// It is safe calling HostClient methods from concurrently running goroutines.
type HostClient struct {
noCopy noCopy //nolint:unused,structcheck
// Comma-separated list of upstream HTTP server host addresses,
// which are passed to Dial in a round-robin manner.
//
// Each address may contain port if default dialer is used.
// For example,
//
// - foobar.com:80
// - foobar.com:443
// - foobar.com:8080
Addr string
// Client name. Used in User-Agent request header.
Name string
// NoDefaultUserAgentHeader when set to true, causes the default
// User-Agent header to be excluded from the Request.
NoDefaultUserAgentHeader bool
// Callback for establishing new connection to the host.
//
// Default Dial is used if not set.
Dial DialFunc
// Attempt to connect to both ipv4 and ipv6 host addresses
// if set to true.
//
// This option is used only if default TCP dialer is used,
// i.e. if Dial is blank.
//
// By default client connects only to ipv4 addresses,
// since unfortunately ipv6 remains broken in many networks worldwide :)
DialDualStack bool
// Whether to use TLS (aka SSL or HTTPS) for host connections.
IsTLS bool
// Optional TLS config.
TLSConfig *tls.Config
// Maximum number of connections which may be established to all hosts
// listed in Addr.
//
// You can change this value while the HostClient is being used
// using HostClient.SetMaxConns(value)
//
// DefaultMaxConnsPerHost is used if not set.
MaxConns int
// Keep-alive connections are closed after this duration.
//
// By default connection duration is unlimited.
MaxConnDuration time.Duration
// Idle keep-alive connections are closed after this duration.
//
// By default idle connections are closed
// after DefaultMaxIdleConnDuration.
MaxIdleConnDuration time.Duration
// Maximum number of attempts for idempotent calls
//
// DefaultMaxIdemponentCallAttempts is used if not set.
MaxIdemponentCallAttempts int
// Per-connection buffer size for responses' reading.
// This also limits the maximum header size.
//
// Default buffer size is used if 0.
ReadBufferSize int
// Per-connection buffer size for requests' writing.
//
// Default buffer size is used if 0.
WriteBufferSize int
// Maximum duration for full response reading (including body).
//
// By default response read timeout is unlimited.
ReadTimeout time.Duration
// Maximum duration for full request writing (including body).
//
// By default request write timeout is unlimited.
WriteTimeout time.Duration
// Maximum response body size.
//
// The client returns ErrBodyTooLarge if this limit is greater than 0
// and response body is greater than the limit.
//
// By default response body size is unlimited.
MaxResponseBodySize int
// Header names are passed as-is without normalization
// if this option is set.
//
// Disabled header names' normalization may be useful only for proxying
// responses to other clients expecting case-sensitive
// header names. See https://github.com/valyala/fasthttp/issues/57
// for details.
//
// By default request and response header names are normalized, i.e.
// The first letter and the first letters following dashes
// are uppercased, while all the other letters are lowercased.
// Examples:
//
// * HOST -> Host
// * content-type -> Content-Type
// * cONTENT-lenGTH -> Content-Length
DisableHeaderNamesNormalizing bool
// Path values are sent as-is without normalization
//
// Disabled path normalization may be useful for proxying incoming requests
// to servers that are expecting paths to be forwarded as-is.
//
// By default path values are normalized, i.e.
// extra slashes are removed, special characters are encoded.
DisablePathNormalizing bool
// Will not log potentially sensitive content in error logs
//
// This option is useful for servers that handle sensitive data
// in the request/response.
//
// Client logs full errors by default.
SecureErrorLogMessage bool
// Maximum duration for waiting for a free connection.
//
// By default it will not wait; ErrNoFreeConns is returned immediately.
MaxConnWaitTimeout time.Duration
// RetryIf controls whether a retry should be attempted after an error.
//
// By default the isIdempotent function is used.
RetryIf RetryIfFunc
// Transport defines a transport-like mechanism that wraps every request/response.
Transport TransportFunc
clientName atomic.Value
lastUseTime uint32
connsLock sync.Mutex
connsCount int
conns []*clientConn
connsWait *wantConnQueue
addrsLock sync.Mutex
addrs []string
addrIdx uint32
tlsConfigMap map[string]*tls.Config
tlsConfigMapLock sync.Mutex
readerPool sync.Pool
writerPool sync.Pool
clientReaderPool *sync.Pool
clientWriterPool *sync.Pool
pendingRequests int32
connsCleanerRun bool
}
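// A minimal HostClient configuration sketch (the addresses are placeholders;
// round-robin dialing over the comma-separated Addr list is described above):
//
//	hc := &fasthttp.HostClient{
//		Addr:     "backend1.local:8080,backend2.local:8080",
//		MaxConns: 64,
//	}
//	statusCode, body, err := hc.Get(nil, "http://backend1.local:8080/health")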
type clientConn struct {
c net.Conn
createdTime time.Time
lastUseTime time.Time
}
var startTimeUnix = time.Now().Unix()
// LastUseTime returns the time when the client was last used.
func (c *HostClient) LastUseTime() time.Time {
n := atomic.LoadUint32(&c.lastUseTime)
return time.Unix(startTimeUnix+int64(n), 0)
}
// Get returns the status code and body of url.
//
// The contents of dst will be replaced by the body and returned, if the dst
// is too small a new slice will be allocated.
//
// The function follows redirects. Use Do* for manually handling redirects.
func (c *HostClient) Get(dst []byte, url string) (statusCode int, body []byte, err error) {
return clientGetURL(dst, url, c)
}
// GetTimeout returns the status code and body of url.
//
// The contents of dst will be replaced by the body and returned, if the dst
// is too small a new slice will be allocated.
//
// The function follows redirects. Use Do* for manually handling redirects.
//
// ErrTimeout error is returned if url contents couldn't be fetched
// during the given timeout.
func (c *HostClient) GetTimeout(dst []byte, url string, timeout time.Duration) (statusCode int, body []byte, err error) {
return clientGetURLTimeout(dst, url, timeout, c)
}
// GetDeadline returns the status code and body of url.
//
// The contents of dst will be replaced by the body and returned, if the dst
// is too small a new slice will be allocated.
//
// The function follows redirects. Use Do* for manually handling redirects.
//
// ErrTimeout error is returned if url contents couldn't be fetched
// until the given deadline.
func (c *HostClient) GetDeadline(dst []byte, url string, deadline time.Time) (statusCode int, body []byte, err error) {
return clientGetURLDeadline(dst, url, deadline, c)
}
// Post sends POST request to the given url with the given POST arguments.
//
// The contents of dst will be replaced by the body and returned, if the dst
// is too small a new slice will be allocated.
//
// The function follows redirects. Use Do* for manually handling redirects.
//
// Empty POST body is sent if postArgs is nil.
func (c *HostClient) Post(dst []byte, url string, postArgs *Args) (statusCode int, body []byte, err error) {
return clientPostURL(dst, url, postArgs, c)
}
type clientDoer interface {
Do(req *Request, resp *Response) error
}
func clientGetURL(dst []byte, url string, c clientDoer) (statusCode int, body []byte, err error) {
req := AcquireRequest()
statusCode, body, err = doRequestFollowRedirectsBuffer(req, dst, url, c)
ReleaseRequest(req)
return statusCode, body, err
}
func clientGetURLTimeout(dst []byte, url string, timeout time.Duration, c clientDoer) (statusCode int, body []byte, err error) {
deadline := time.Now().Add(timeout)
return clientGetURLDeadline(dst, url, deadline, c)
}
type clientURLResponse struct {
statusCode int
body []byte
err error
}
func clientGetURLDeadline(dst []byte, url string, deadline time.Time, c clientDoer) (statusCode int, body []byte, err error) {
timeout := -time.Since(deadline)
if timeout <= 0 {
return 0, dst, ErrTimeout
}
var ch chan clientURLResponse
chv := clientURLResponseChPool.Get()
if chv == nil {
chv = make(chan clientURLResponse, 1)
}
ch = chv.(chan clientURLResponse)
// Note that the request continues executing after ErrTimeout until the
// client-specific ReadTimeout is exceeded. This helps limit the load
// on slow hosts to MaxConns* concurrent requests.
//
// Without this 'hack' the load on a slow host could exceed MaxConns*
// concurrent requests, since requests that time out on the client side
// usually keep executing on the host.
var mu sync.Mutex
var timedout, responded bool
go func() {
req := AcquireRequest()
statusCodeCopy, bodyCopy, errCopy := doRequestFollowRedirectsBuffer(req, dst, url, c)
mu.Lock()
{
if !timedout {
ch <- clientURLResponse{
statusCode: statusCodeCopy,
body: bodyCopy,
err: errCopy,
}
responded = true
}
}
mu.Unlock()
ReleaseRequest(req)
}()
tc := AcquireTimer(timeout)
select {
case resp := <-ch:
statusCode = resp.statusCode
body = resp.body
err = resp.err
case <-tc.C:
mu.Lock()
{
if responded {
resp := <-ch
statusCode = resp.statusCode
body = resp.body
err = resp.err
} else {
timedout = true
err = ErrTimeout
body = dst
}
}
mu.Unlock()
}
ReleaseTimer(tc)
clientURLResponseChPool.Put(chv)
return statusCode, body, err
}
var clientURLResponseChPool sync.Pool
func clientPostURL(dst []byte, url string, postArgs *Args, c clientDoer) (statusCode int, body []byte, err error) {
req := AcquireRequest()
req.Header.SetMethod(MethodPost)
req.Header.SetContentTypeBytes(strPostArgsContentType)
if postArgs != nil {
if _, err := postArgs.WriteTo(req.BodyWriter()); err != nil {
return 0, nil, err
}
}
statusCode, body, err = doRequestFollowRedirectsBuffer(req, dst, url, c)
ReleaseRequest(req)
return statusCode, body, err
}
var (
// ErrMissingLocation is returned by clients when the Location header is missing on
// an HTTP response with a redirect status code.
ErrMissingLocation = errors.New("missing Location header for http redirect")
// ErrTooManyRedirects is returned by clients when the number of redirects followed
// exceed the max count.
ErrTooManyRedirects = errors.New("too many redirects detected when doing the request")
// ErrHostClientRedirectToDifferentScheme is returned when a redirect points to
// a different protocol; HostClients can only follow redirects within the same protocol.
ErrHostClientRedirectToDifferentScheme = errors.New("HostClient can't follow redirects to a different protocol, please use Client instead")
)
const defaultMaxRedirectsCount = 16
func doRequestFollowRedirectsBuffer(req *Request, dst []byte, url string, c clientDoer) (statusCode int, body []byte, err error) {
resp := AcquireResponse()
bodyBuf := resp.bodyBuffer()
resp.keepBodyBuffer = true
oldBody := bodyBuf.B
bodyBuf.B = dst
statusCode, _, err = doRequestFollowRedirects(req, resp, url, defaultMaxRedirectsCount, c)
body = bodyBuf.B
bodyBuf.B = oldBody
resp.keepBodyBuffer = false
ReleaseResponse(resp)
return statusCode, body, err
}
func doRequestFollowRedirects(req *Request, resp *Response, url string, maxRedirectsCount int, c clientDoer) (statusCode int, body []byte, err error) {
redirectsCount := 0
for {
req.SetRequestURI(url)
if err := req.parseURI(); err != nil {
return 0, nil, err
}
if err = c.Do(req, resp); err != nil {
break
}
statusCode = resp.Header.StatusCode()
if !StatusCodeIsRedirect(statusCode) {
break
}
redirectsCount++
if redirectsCount > maxRedirectsCount {
err = ErrTooManyRedirects
break
}
location := resp.Header.peek(strLocation)
if len(location) == 0 {
err = ErrMissingLocation
break
}
url = getRedirectURL(url, location)
}
return statusCode, body, err
}
func getRedirectURL(baseURL string, location []byte) string {
u := AcquireURI()
u.Update(baseURL)
u.UpdateBytes(location)
redirectURL := u.String()
ReleaseURI(u)
return redirectURL
}
// StatusCodeIsRedirect returns true if the status code indicates a redirect.
func StatusCodeIsRedirect(statusCode int) bool {
return statusCode == StatusMovedPermanently ||
statusCode == StatusFound ||
statusCode == StatusSeeOther ||
statusCode == StatusTemporaryRedirect ||
statusCode == StatusPermanentRedirect
}
var (
requestPool sync.Pool
responsePool sync.Pool
)
// AcquireRequest returns an empty Request instance from request pool.
//
// The returned Request instance may be passed to ReleaseRequest when it is
// no longer needed. This allows Request recycling, reduces GC pressure
// and usually improves performance.
func AcquireRequest() *Request {
v := requestPool.Get()
if v == nil {
return &Request{}
}
return v.(*Request)
}
// ReleaseRequest returns req acquired via AcquireRequest to request pool.
//
// It is forbidden to access req and/or its members after returning
// it to the request pool.
func ReleaseRequest(req *Request) {
req.Reset()
requestPool.Put(req)
}
// AcquireResponse returns an empty Response instance from response pool.
//
// The returned Response instance may be passed to ReleaseResponse when it is
// no longer needed. This allows Response recycling, reduces GC pressure
// and usually improves performance.
func AcquireResponse() *Response {
v := responsePool.Get()
if v == nil {
return &Response{}
}
return v.(*Response)
}
// ReleaseResponse returns resp acquired via AcquireResponse to the response pool.
//
// It is forbidden to access resp and/or its members after returning
// it to the response pool.
func ReleaseResponse(resp *Response) {
resp.Reset()
responsePool.Put(resp)
}
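// A minimal sketch of the acquire/release life cycle described above; req and
// resp must not be used after they are returned to the pools:
//
//	req := fasthttp.AcquireRequest()
//	resp := fasthttp.AcquireResponse()
//
//	req.SetRequestURI("http://example.com/")
//	err := fasthttp.Do(req, resp)
//	body := append([]byte(nil), resp.Body()...) // copy the body if it is needed later
//
//	fasthttp.ReleaseRequest(req)
//	fasthttp.ReleaseResponse(resp) // req and resp must not be touched from here on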
// DoTimeout performs the given request and waits for response during
// the given timeout duration.
//
// Request must contain at least non-zero RequestURI with full url (including
// scheme and host) or non-zero Host header + RequestURI.
//
// The function doesn't follow redirects. Use Get* for following redirects.
//
// Response is ignored if resp is nil.
//
// ErrTimeout is returned if the response wasn't returned during
// the given timeout.
//
// ErrNoFreeConns is returned if all HostClient.MaxConns connections
// to the host are busy.
//
// It is recommended obtaining req and resp via AcquireRequest
// and AcquireResponse in performance-critical code.
//
// Warning: DoTimeout does not terminate the request itself. The request will
// continue in the background and the response will be discarded.
// If requests take too long and the connection pool gets filled up please
// try setting a ReadTimeout.
func (c *HostClient) DoTimeout(req *Request, resp *Response, timeout time.Duration) error {
return clientDoTimeout(req, resp, timeout, c)
}
// DoDeadline performs the given request and waits for response until
// the given deadline.
//
// Request must contain at least non-zero RequestURI with full url (including
// scheme and host) or non-zero Host header + RequestURI.
//
// The function doesn't follow redirects. Use Get* for following redirects.
//
// Response is ignored if resp is nil.
//
// ErrTimeout is returned if the response wasn't returned until
// the given deadline.
//
// ErrNoFreeConns is returned if all HostClient.MaxConns connections
// to the host are busy.
//
// It is recommended obtaining req and resp via AcquireRequest
// and AcquireResponse in performance-critical code.
func (c *HostClient) DoDeadline(req *Request, resp *Response, deadline time.Time) error {
return clientDoDeadline(req, resp, deadline, c)
}
// DoRedirects performs the given http request and fills the given http response,
// following up to maxRedirectsCount redirects. When the redirect count exceeds
// maxRedirectsCount, ErrTooManyRedirects is returned.
//
// Request must contain at least non-zero RequestURI with full url (including
// scheme and host) or non-zero Host header + RequestURI.
//
// Client determines the server to be requested in the following order:
//
// - from RequestURI if it contains full url with scheme and host;
// - from Host header otherwise.
//
// Response is ignored if resp is nil.
//
// ErrNoFreeConns is returned if all DefaultMaxConnsPerHost connections
// to the requested host are busy.
//
// It is recommended obtaining req and resp via AcquireRequest
// and AcquireResponse in performance-critical code.
func (c *HostClient) DoRedirects(req *Request, resp *Response, maxRedirectsCount int) error {
_, _, err := doRequestFollowRedirects(req, resp, req.URI().String(), maxRedirectsCount, c)
return err
}
func clientDoTimeout(req *Request, resp *Response, timeout time.Duration, c clientDoer) error {
deadline := time.Now().Add(timeout)
return clientDoDeadline(req, resp, deadline, c)
}
func clientDoDeadline(req *Request, resp *Response, deadline time.Time, c clientDoer) error {
timeout := -time.Since(deadline)
if timeout <= 0 {
return ErrTimeout
}
var ch chan error
chv := errorChPool.Get()
if chv == nil {
chv = make(chan error, 1)
}
ch = chv.(chan error)
// Make req and resp copies, since on timeout they no longer
// may be accessed.
reqCopy := AcquireRequest()
req.copyToSkipBody(reqCopy)
swapRequestBody(req, reqCopy)
respCopy := AcquireResponse()
if resp != nil {
// Not calling resp.copyToSkipBody(respCopy) here to avoid
// unexpected messing with headers
respCopy.SkipBody = resp.SkipBody
}
// Note that the request continues executing after ErrTimeout until the
// client-specific ReadTimeout is exceeded. This helps limit the load
// on slow hosts to MaxConns* concurrent requests.
//
// Without this 'hack' the load on a slow host could exceed MaxConns*
// concurrent requests, since requests that time out on the client side
// usually keep executing on the host.
var mu sync.Mutex
var timedout, responded bool
go func() {
reqCopy.timeout = timeout
errDo := c.Do(reqCopy, respCopy)
mu.Lock()
{
if !timedout {
if resp != nil {
respCopy.copyToSkipBody(resp)
swapResponseBody(resp, respCopy)
}
swapRequestBody(reqCopy, req)
ch <- errDo
responded = true
}
}
mu.Unlock()
ReleaseResponse(respCopy)
ReleaseRequest(reqCopy)
}()
tc := AcquireTimer(timeout)
var err error
select {
case err = <-ch:
case <-tc.C:
mu.Lock()
{
if responded {
err = <-ch
} else {
timedout = true
err = ErrTimeout
}
}
mu.Unlock()
}
ReleaseTimer(tc)
errorChPool.Put(chv)
return err
}
var errorChPool sync.Pool
// Do performs the given http request and sets the corresponding response.
//
// Request must contain at least non-zero RequestURI with full url (including
// scheme and host) or non-zero Host header + RequestURI.
//
// The function doesn't follow redirects. Use Get* for following redirects.
//
// Response is ignored if resp is nil.
//
// ErrNoFreeConns is returned if all HostClient.MaxConns connections
// to the host are busy.
//
// It is recommended obtaining req and resp via AcquireRequest
// and AcquireResponse in performance-critical code.
func (c *HostClient) Do(req *Request, resp *Response) error {
var err error
var retry bool
maxAttempts := c.MaxIdemponentCallAttempts
if maxAttempts <= 0 {
maxAttempts = DefaultMaxIdemponentCallAttempts
}
isRequestRetryable := isIdempotent
if c.RetryIf != nil {
isRequestRetryable = c.RetryIf
}
attempts := 0
hasBodyStream := req.IsBodyStream()
atomic.AddInt32(&c.pendingRequests, 1)
for {
retry, err = c.do(req, resp)
if err == nil || !retry {
break
}
if hasBodyStream {
break
}
if !isRequestRetryable(req) {
// Retry non-idempotent requests if the server closes
// the connection before sending the response.
//
// This case is possible if the server closes the idle
// keep-alive connection on timeout.
//
// Apache and nginx usually do this.
if err != io.EOF {
break
}
}
attempts++
if attempts >= maxAttempts {
break
}
}
atomic.AddInt32(&c.pendingRequests, -1)
if err == io.EOF {
err = ErrConnectionClosed
}
return err
}
// PendingRequests returns the current number of requests the client
// is executing.
//
// This function may be used for balancing load among multiple HostClient
// instances.
func (c *HostClient) PendingRequests() int {
return int(atomic.LoadInt32(&c.pendingRequests))
}
func isIdempotent(req *Request) bool {
return req.Header.IsGet() || req.Header.IsHead() || req.Header.IsPut()
}
func (c *HostClient) do(req *Request, resp *Response) (bool, error) {
nilResp := false
if resp == nil {
nilResp = true
resp = AcquireResponse()
}
ok, err := c.doNonNilReqResp(req, resp)
if nilResp {
ReleaseResponse(resp)
}
return ok, err
}
func (c *HostClient) doNonNilReqResp(req *Request, resp *Response) (bool, error) {
if req == nil {
panic("BUG: req cannot be nil")
}
if resp == nil {
panic("BUG: resp cannot be nil")
}
// Secure header error logs configuration
resp.secureErrorLogMessage = c.SecureErrorLogMessage
resp.Header.secureErrorLogMessage = c.SecureErrorLogMessage
req.secureErrorLogMessage = c.SecureErrorLogMessage
req.Header.secureErrorLogMessage = c.SecureErrorLogMessage
if c.IsTLS != bytes.Equal(req.uri.Scheme(), strHTTPS) {
return false, ErrHostClientRedirectToDifferentScheme
}
atomic.StoreUint32(&c.lastUseTime, uint32(time.Now().Unix()-startTimeUnix))
// Free up resources occupied by response before sending the request,
// so the GC may reclaim these resources (e.g. response body).
// backing up SkipBody in case it was set explicitly
customSkipBody := resp.SkipBody
resp.Reset()
resp.SkipBody = customSkipBody
req.URI().DisablePathNormalizing = c.DisablePathNormalizing
userAgentOld := req.Header.UserAgent()
if len(userAgentOld) == 0 {
req.Header.userAgent = append(req.Header.userAgent[:0], c.getClientName()...)
}
if c.Transport != nil {
err := c.Transport(req, resp)
return err == nil, err
}
cc, err := c.acquireConn(req.timeout, req.ConnectionClose())
if err != nil {
return false, err
}
conn := cc.c
resp.parseNetConn(conn)
if c.WriteTimeout > 0 {
// Set Deadline every time, since golang has fixed the performance issue
// See https://github.com/golang/go/issues/15133#issuecomment-271571395 for details
currentTime := time.Now()
if err = conn.SetWriteDeadline(currentTime.Add(c.WriteTimeout)); err != nil {
c.closeConn(cc)
return true, err
}
}
resetConnection := false
if c.MaxConnDuration > 0 && time.Since(cc.createdTime) > c.MaxConnDuration && !req.ConnectionClose() {
req.SetConnectionClose()
resetConnection = true
}
bw := c.acquireWriter(conn)
err = req.Write(bw)
if resetConnection {
req.Header.ResetConnectionClose()
}
if err == nil {
err = bw.Flush()
}
if err != nil {
c.releaseWriter(bw)
c.closeConn(cc)
return true, err
}
c.releaseWriter(bw)
if c.ReadTimeout > 0 {
// Set Deadline every time, since golang has fixed the performance issue
// See https://github.com/golang/go/issues/15133#issuecomment-271571395 for details
currentTime := time.Now()
if err = conn.SetReadDeadline(currentTime.Add(c.ReadTimeout)); err != nil {
c.closeConn(cc)
return true, err
}
}
if customSkipBody || req.Header.IsHead() {
resp.SkipBody = true
}
if c.DisableHeaderNamesNormalizing {
resp.Header.DisableNormalizing()
}
br := c.acquireReader(conn)
if err = resp.ReadLimitBody(br, c.MaxResponseBodySize); err != nil {
c.releaseReader(br)
c.closeConn(cc)
// Don't retry in case of ErrBodyTooLarge since we will just get the same again.
retry := err != ErrBodyTooLarge
return retry, err
}
c.releaseReader(br)
if resetConnection || req.ConnectionClose() || resp.ConnectionClose() {
c.closeConn(cc)
} else {
c.releaseConn(cc)
}
return false, err
}
var (
// ErrNoFreeConns is returned when no free connections are available
// to the given host.
//
// Increase the allowed number of connections per host if you
// see this error.
ErrNoFreeConns = errors.New("no free connections available to host")
// ErrConnectionClosed may be returned from client methods if the server
// closes connection before returning the first response byte.
//
// If you see this error, then either fix the server by returning
// 'Connection: close' response header before closing the connection
// or add 'Connection: close' request header before sending requests
// to broken server.
ErrConnectionClosed = errors.New("the server closed connection before returning the first response byte. " +
"Make sure the server returns 'Connection: close' response header before closing the connection")
)
type timeoutError struct{}
func (e *timeoutError) Error() string {
return "timeout"
}
// Only implement the Timeout() function of the net.Error interface.
// This allows for checks like:
//
// if x, ok := err.(interface{ Timeout() bool }); ok && x.Timeout() {
func (e *timeoutError) Timeout() bool {
return true
}
// ErrTimeout is returned from timed out calls.
var ErrTimeout = &timeoutError{}
// SetMaxConns sets up the maximum number of connections which may be established to all hosts listed in Addr.
func (c *HostClient) SetMaxConns(newMaxConns int) {
c.connsLock.Lock()
c.MaxConns = newMaxConns
c.connsLock.Unlock()
}
func (c *HostClient) acquireConn(reqTimeout time.Duration, connectionClose bool) (cc *clientConn, err error) {
createConn := false
startCleaner := false
var n int
c.connsLock.Lock()
n = len(c.conns)
if n == 0 {
maxConns := c.MaxConns
if maxConns <= 0 {
maxConns = DefaultMaxConnsPerHost
}
if c.connsCount < maxConns {
c.connsCount++
createConn = true
if !c.connsCleanerRun && !connectionClose {
startCleaner = true
c.connsCleanerRun = true
}
}
} else {
n--
cc = c.conns[n]
c.conns[n] = nil
c.conns = c.conns[:n]
}
c.connsLock.Unlock()
if cc != nil {
return cc, nil
}
if !createConn {
if c.MaxConnWaitTimeout <= 0 {
return nil, ErrNoFreeConns
}
// reqTimeout   c.MaxConnWaitTimeout   wait duration
// d1           d2                     min(d1, d2)
// 0(not set)   d2                     d2
// d1           0(don't wait)          0(don't wait)
// 0(not set)   0(don't wait)          0(don't wait)
timeout := c.MaxConnWaitTimeout
timeoutOverridden := false
// reqTimeout == 0 means not set
if reqTimeout > 0 && reqTimeout < timeout {
timeout = reqTimeout
timeoutOverridden = true
}
// wait for a free connection
tc := AcquireTimer(timeout)
defer ReleaseTimer(tc)
w := &wantConn{
ready: make(chan struct{}, 1),
}
defer func() {
if err != nil {
w.cancel(c, err)
}
}()
c.queueForIdle(w)
select {
case <-w.ready:
return w.conn, w.err
case <-tc.C:
if timeoutOverridden {
return nil, ErrTimeout
}
return nil, ErrNoFreeConns
}
}
if startCleaner {
go c.connsCleaner()
}
conn, err := c.dialHostHard()
if err != nil {
c.decConnsCount()
return nil, err
}
cc = acquireClientConn(conn)
return cc, nil
}
func (c *HostClient) queueForIdle(w *wantConn) {
c.connsLock.Lock()
defer c.connsLock.Unlock()
if c.connsWait == nil {
c.connsWait = &wantConnQueue{}
}
c.connsWait.clearFront()
c.connsWait.pushBack(w)
}
func (c *HostClient) dialConnFor(w *wantConn) {
conn, err := c.dialHostHard()
if err != nil {
w.tryDeliver(nil, err)
c.decConnsCount()
return
}
cc := acquireClientConn(conn)
delivered := w.tryDeliver(cc, nil)
if !delivered {
// not delivered, return idle connection
c.releaseConn(cc)
}
}
// CloseIdleConnections closes any connections which were previously
// connected from previous requests but are now sitting idle in a
// "keep-alive" state. It does not interrupt any connections currently
// in use.
func (c *HostClient) CloseIdleConnections() {
c.connsLock.Lock()
scratch := append([]*clientConn{}, c.conns...)
for i := range c.conns {
c.conns[i] = nil
}
c.conns = c.conns[:0]
c.connsLock.Unlock()
for _, cc := range scratch {
c.closeConn(cc)
}
}
func (c *HostClient) connsCleaner() {
var (
scratch []*clientConn
maxIdleConnDuration = c.MaxIdleConnDuration
)
if maxIdleConnDuration <= 0 {
maxIdleConnDuration = DefaultMaxIdleConnDuration
}
for {
currentTime := time.Now()
// Determine idle connections to be closed.
c.connsLock.Lock()
conns := c.conns
n := len(conns)
i := 0
for i < n && currentTime.Sub(conns[i].lastUseTime) > maxIdleConnDuration {
i++
}
sleepFor := maxIdleConnDuration
if i < n {
// + 1 so we actually sleep past the expiration time and not up to it.
// Otherwise the > check above would still fail.
sleepFor = maxIdleConnDuration - currentTime.Sub(conns[i].lastUseTime) + 1
}
scratch = append(scratch[:0], conns[:i]...)
if i > 0 {
m := copy(conns, conns[i:])
for i = m; i < n; i++ {
conns[i] = nil
}
c.conns = conns[:m]
}
c.connsLock.Unlock()
// Close idle connections.
for i, cc := range scratch {
c.closeConn(cc)
scratch[i] = nil
}
// Determine whether to stop the connsCleaner.
c.connsLock.Lock()
mustStop := c.connsCount == 0
if mustStop {
c.connsCleanerRun = false
}
c.connsLock.Unlock()
if mustStop {
break
}
time.Sleep(sleepFor)
}
}
func (c *HostClient) closeConn(cc *clientConn) {
c.decConnsCount()
cc.c.Close()
releaseClientConn(cc)
}
func (c *HostClient) decConnsCount() {
if c.MaxConnWaitTimeout <= 0 {
c.connsLock.Lock()
c.connsCount--
c.connsLock.Unlock()
return
}
c.connsLock.Lock()
defer c.connsLock.Unlock()
dialed := false
if q := c.connsWait; q != nil && q.len() > 0 {
for q.len() > 0 {
w := q.popFront()
if w.waiting() {
go c.dialConnFor(w)
dialed = true
break
}
}
}
if !dialed {
c.connsCount--
}
}
// ConnsCount returns the current number of connections of the HostClient.
func (c *HostClient) ConnsCount() int {
c.connsLock.Lock()
defer c.connsLock.Unlock()
return c.connsCount
}
func acquireClientConn(conn net.Conn) *clientConn {
v := clientConnPool.Get()
if v == nil {
v = &clientConn{}
}
cc := v.(*clientConn)
cc.c = conn
cc.createdTime = time.Now()
return cc
}
func releaseClientConn(cc *clientConn) {
// Reset all fields.
*cc = clientConn{}
clientConnPool.Put(cc)
}
var clientConnPool sync.Pool
func (c *HostClient) releaseConn(cc *clientConn) {
cc.lastUseTime = time.Now()
if c.MaxConnWaitTimeout <= 0 {
c.connsLock.Lock()
c.conns = append(c.conns, cc)
c.connsLock.Unlock()
return
}
// try to deliver an idle connection to a *wantConn
c.connsLock.Lock()
defer c.connsLock.Unlock()
delivered := false
if q := c.connsWait; q != nil && q.len() > 0 {
for q.len() > 0 {
w := q.popFront()
if w.waiting() {
delivered = w.tryDeliver(cc, nil)
break
}
}
}
if !delivered {
c.conns = append(c.conns, cc)
}
}
func (c *HostClient) acquireWriter(conn net.Conn) *bufio.Writer {
var v interface{}
if c.clientWriterPool != nil {
v = c.clientWriterPool.Get()
if v == nil {
n := c.WriteBufferSize
if n <= 0 {
n = defaultWriteBufferSize
}
return bufio.NewWriterSize(conn, n)
}
} else {
v = c.writerPool.Get()
if v == nil {
n := c.WriteBufferSize
if n <= 0 {
n = defaultWriteBufferSize
}
return bufio.NewWriterSize(conn, n)
}
}
bw := v.(*bufio.Writer)
bw.Reset(conn)
return bw
}
func (c *HostClient) releaseWriter(bw *bufio.Writer) {
if c.clientWriterPool != nil {
c.clientWriterPool.Put(bw)
} else {
c.writerPool.Put(bw)
}
}
func (c *HostClient) acquireReader(conn net.Conn) *bufio.Reader {
var v interface{}
if c.clientReaderPool != nil {
v = c.clientReaderPool.Get()
if v == nil {
n := c.ReadBufferSize
if n <= 0 {
n = defaultReadBufferSize
}
return bufio.NewReaderSize(conn, n)
}
} else {
v = c.readerPool.Get()
if v == nil {
n := c.ReadBufferSize
if n <= 0 {
n = defaultReadBufferSize
}
return bufio.NewReaderSize(conn, n)
}
}
br := v.(*bufio.Reader)
br.Reset(conn)
return br
}
func (c *HostClient) releaseReader(br *bufio.Reader) {
if c.clientReaderPool != nil {
c.clientReaderPool.Put(br)
} else {
c.readerPool.Put(br)
}
}
func newClientTLSConfig(c *tls.Config, addr string) *tls.Config {
if c == nil {
c = &tls.Config{}
} else {
c = c.Clone()
}
if c.ClientSessionCache == nil {
c.ClientSessionCache = tls.NewLRUClientSessionCache(0)
}
if len(c.ServerName) == 0 {
serverName := tlsServerName(addr)
if serverName == "*" {
c.InsecureSkipVerify = true
} else {
c.ServerName = serverName
}
}
return c
}
func tlsServerName(addr string) string {
if !strings.Contains(addr, ":") {
return addr
}
host, _, err := net.SplitHostPort(addr)
if err != nil {
return "*"
}
return host
}
func (c *HostClient) nextAddr() string {
c.addrsLock.Lock()
if c.addrs == nil {
c.addrs = strings.Split(c.Addr, ",")
}
addr := c.addrs[0]
if len(c.addrs) > 1 {
addr = c.addrs[c.addrIdx%uint32(len(c.addrs))]
c.addrIdx++
}
c.addrsLock.Unlock()
return addr
}
func (c *HostClient) dialHostHard() (conn net.Conn, err error) {
// attempt to dial all the available hosts before giving up.
c.addrsLock.Lock()
n := len(c.addrs)
c.addrsLock.Unlock()
if n == 0 {
// It looks like c.addrs isn't initialized yet.
n = 1
}
timeout := c.ReadTimeout + c.WriteTimeout
if timeout <= 0 {
timeout = DefaultDialTimeout
}
deadline := time.Now().Add(timeout)
for n > 0 {
addr := c.nextAddr()
tlsConfig := c.cachedTLSConfig(addr)
conn, err = dialAddr(addr, c.Dial, c.DialDualStack, c.IsTLS, tlsConfig, c.WriteTimeout)
if err == nil {
return conn, nil
}
if time.Since(deadline) >= 0 {
break
}
n--
}
return nil, err
}
func (c *HostClient) cachedTLSConfig(addr string) *tls.Config {
if !c.IsTLS {
return nil
}
c.tlsConfigMapLock.Lock()
if c.tlsConfigMap == nil {
c.tlsConfigMap = make(map[string]*tls.Config)
}
cfg := c.tlsConfigMap[addr]
if cfg == nil {
cfg = newClientTLSConfig(c.TLSConfig, addr)
c.tlsConfigMap[addr] = cfg
}
c.tlsConfigMapLock.Unlock()
return cfg
}
// ErrTLSHandshakeTimeout indicates a timeout during the TLS handshake.
var ErrTLSHandshakeTimeout = errors.New("tls handshake timed out")
var timeoutErrorChPool sync.Pool
func tlsClientHandshake(rawConn net.Conn, tlsConfig *tls.Config, timeout time.Duration) (net.Conn, error) {
tc := AcquireTimer(timeout)
defer ReleaseTimer(tc)
var ch chan error
chv := timeoutErrorChPool.Get()
if chv == nil {
chv = make(chan error)
}
ch = chv.(chan error)
defer timeoutErrorChPool.Put(chv)
conn := tls.Client(rawConn, tlsConfig)
go func() {
ch <- conn.Handshake()
}()
select {
case <-tc.C:
rawConn.Close()
<-ch
return nil, ErrTLSHandshakeTimeout
case err := <-ch:
if err != nil {
rawConn.Close()
return nil, err
}
return conn, nil
}
}
func dialAddr(addr string, dial DialFunc, dialDualStack, isTLS bool, tlsConfig *tls.Config, timeout time.Duration) (net.Conn, error) {
if dial == nil {
if dialDualStack {
dial = DialDualStack
} else {
dial = Dial
}
addr = addMissingPort(addr, isTLS)
}
conn, err := dial(addr)
if err != nil {
return nil, err
}
if conn == nil {
panic("BUG: DialFunc returned (nil, nil)")
}
_, isTLSAlready := conn.(*tls.Conn)
if isTLS && !isTLSAlready {
if timeout == 0 {
return tls.Client(conn, tlsConfig), nil
}
return tlsClientHandshake(conn, tlsConfig, timeout)
}
return conn, nil
}
func (c *HostClient) getClientName() []byte {
v := c.clientName.Load()
var clientName []byte
if v == nil {
clientName = []byte(c.Name)
if len(clientName) == 0 && !c.NoDefaultUserAgentHeader {
clientName = defaultUserAgent
}
c.clientName.Store(clientName)
} else {
clientName = v.([]byte)
}
return clientName
}
func addMissingPort(addr string, isTLS bool) string {
n := strings.Index(addr, ":")
if n >= 0 {
return addr
}
port := 80
if isTLS {
port = 443
}
return net.JoinHostPort(addr, strconv.Itoa(port))
}
// A wantConn records state about a wanted connection
// (that is, an active call to getConn).
// The conn may be gotten by dialing or by finding an idle connection,
// or a cancellation may make the conn no longer wanted.
// These three options are racing against each other and use
// wantConn to coordinate and agree about the winning outcome.
//
// inspired by net/http/transport.go
type wantConn struct {
ready chan struct{}
mu sync.Mutex // protects conn, err, close(ready)
conn *clientConn
err error
}
// waiting reports whether w is still waiting for an answer (connection or error).
func (w *wantConn) waiting() bool {
select {
case <-w.ready:
return false
default:
return true
}
}
// tryDeliver attempts to deliver conn, err to w and reports whether it succeeded.
func (w *wantConn) tryDeliver(conn *clientConn, err error) bool {
w.mu.Lock()
defer w.mu.Unlock()
if w.conn != nil || w.err != nil {
return false
}
w.conn = conn
w.err = err
if w.conn == nil && w.err == nil {
panic("fasthttp: internal error: misuse of tryDeliver")
}
close(w.ready)
return true
}
// cancel marks w as no longer wanting a result (for example, due to cancellation).
// If a connection has been delivered already, cancel returns it with c.releaseConn.
func (w *wantConn) cancel(c *HostClient, err error) {
w.mu.Lock()
if w.conn == nil && w.err == nil {
close(w.ready) // catch misbehavior in future delivery
}
conn := w.conn
w.conn = nil
w.err = err
w.mu.Unlock()
if conn != nil {
c.releaseConn(conn)
}
}
// A wantConnQueue is a queue of wantConns.
//
// inspired by net/http/transport.go
type wantConnQueue struct {
// This is a queue, not a deque.
// It is split into two stages - head[headPos:] and tail.
// popFront is trivial (headPos++) on the first stage, and
// pushBack is trivial (append) on the second stage.
// If the first stage is empty, popFront can swap the
// first and second stages to remedy the situation.
//
// This two-stage split is analogous to the use of two lists
// in Okasaki's purely functional queue but without the
// overhead of reversing the list when swapping stages.
head []*wantConn
headPos int
tail []*wantConn
}
// len returns the number of items in the queue.
func (q *wantConnQueue) len() int {
return len(q.head) - q.headPos + len(q.tail)
}
// pushBack adds w to the back of the queue.
func (q *wantConnQueue) pushBack(w *wantConn) {
q.tail = append(q.tail, w)
}
// popFront removes and returns the wantConn at the front of the queue.
func (q *wantConnQueue) popFront() *wantConn {
if q.headPos >= len(q.head) {
if len(q.tail) == 0 {
return nil
}
// Pick up tail as new head, clear tail.
q.head, q.headPos, q.tail = q.tail, 0, q.head[:0]
}
w := q.head[q.headPos]
q.head[q.headPos] = nil
q.headPos++
return w
}
// peekFront returns the wantConn at the front of the queue without removing it.
func (q *wantConnQueue) peekFront() *wantConn {
if q.headPos < len(q.head) {
return q.head[q.headPos]
}
if len(q.tail) > 0 {
return q.tail[0]
}
return nil
}
// clearFront pops any wantConns that are no longer waiting from the head of the
// queue, reporting whether any were popped.
func (q *wantConnQueue) clearFront() (cleaned bool) {
for {
w := q.peekFront()
if w == nil || w.waiting() {
return cleaned
}
q.popFront()
cleaned = true
}
}
// PipelineClient pipelines requests over a limited set of concurrent
// connections to the given Addr.
//
// This client may be used in highly loaded HTTP-based RPC systems for reducing
// context switches and network level overhead.
// See https://en.wikipedia.org/wiki/HTTP_pipelining for details.
//
// It is forbidden copying PipelineClient instances. Create new instances
// instead.
//
// It is safe calling PipelineClient methods from concurrently running
// goroutines.
type PipelineClient struct {
noCopy noCopy //nolint:unused,structcheck
// Address of the host to connect to.
Addr string
// PipelineClient name. Used in User-Agent request header.
Name string
// NoDefaultUserAgentHeader when set to true, causes the default
// User-Agent header to be excluded from the Request.
NoDefaultUserAgentHeader bool
// The maximum number of concurrent connections to the Addr.
//
// A single connection is used by default.
MaxConns int
// The maximum number of pending pipelined requests over
// a single connection to Addr.
//
// DefaultMaxPendingRequests is used by default.
MaxPendingRequests int
// The maximum delay before sending pipelined requests as a batch
// to the server.
//
// By default requests are sent immediately to the server.
MaxBatchDelay time.Duration
// Callback for connection establishing to the host.
//
// Default Dial is used if not set.
Dial DialFunc
// Attempt to connect to both ipv4 and ipv6 host addresses
// if set to true.
//
// This option is used only if default TCP dialer is used,
// i.e. if Dial is blank.
//
// By default client connects only to ipv4 addresses,
// since unfortunately ipv6 remains broken in many networks worldwide :)
DialDualStack bool
// Response header names are passed as-is without normalization
// if this option is set.
//
// Disabled header names' normalization may be useful only for proxying
// responses to other clients expecting case-sensitive
// header names. See https://github.com/valyala/fasthttp/issues/57
// for details.
//
// By default request and response header names are normalized, i.e.
// The first letter and the first letters following dashes
// are uppercased, while all the other letters are lowercased.
// Examples:
//
// * HOST -> Host
// * content-type -> Content-Type
// * cONTENT-lenGTH -> Content-Length
DisableHeaderNamesNormalizing bool
// Path values are sent as-is without normalization
//
// Disabled path normalization may be useful for proxying incoming requests
// to servers that are expecting paths to be forwarded as-is.
//
// By default path values are normalized, i.e.
// extra slashes are removed, special characters are encoded.
DisablePathNormalizing bool
// Whether to use TLS (aka SSL or HTTPS) for host connections.
IsTLS bool
// Optional TLS config.
TLSConfig *tls.Config
// Idle connection to the host is closed after this duration.
//
// By default idle connection is closed after
// DefaultMaxIdleConnDuration.
MaxIdleConnDuration time.Duration
// Buffer size for responses' reading.
// This also limits the maximum header size.
//
// Default buffer size is used if 0.
ReadBufferSize int
// Buffer size for requests' writing.
//
// Default buffer size is used if 0.
WriteBufferSize int
// Maximum duration for full response reading (including body).
//
// By default response read timeout is unlimited.
ReadTimeout time.Duration
// Maximum duration for full request writing (including body).
//
// By default request write timeout is unlimited.
WriteTimeout time.Duration
// Logger for logging client errors.
//
// By default standard logger from log package is used.
Logger Logger
connClients []*pipelineConnClient
connClientsLock sync.Mutex
}
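// A minimal PipelineClient configuration sketch (the address and limits are
// arbitrary examples; assumes the caller imports "time"):
//
//	pc := &fasthttp.PipelineClient{
//		Addr:               "rpc-backend.local:8080",
//		MaxConns:           2,
//		MaxPendingRequests: 1024,
//		MaxBatchDelay:      time.Millisecond,
//	}
//
//	req := fasthttp.AcquireRequest()
//	resp := fasthttp.AcquireResponse()
//	req.SetRequestURI("http://rpc-backend.local:8080/call")
//	err := pc.DoTimeout(req, resp, time.Second)
//	fasthttp.ReleaseRequest(req)
//	fasthttp.ReleaseResponse(resp)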
type pipelineConnClient struct {
noCopy noCopy //nolint:unused,structcheck
Addr string
Name string
NoDefaultUserAgentHeader bool
MaxPendingRequests int
MaxBatchDelay time.Duration
Dial DialFunc
DialDualStack bool
DisableHeaderNamesNormalizing bool
DisablePathNormalizing bool
IsTLS bool
TLSConfig *tls.Config
MaxIdleConnDuration time.Duration
ReadBufferSize int
WriteBufferSize int
ReadTimeout time.Duration
WriteTimeout time.Duration
Logger Logger
workPool sync.Pool
chLock sync.Mutex
chW chan *pipelineWork
chR chan *pipelineWork
tlsConfigLock sync.Mutex
tlsConfig *tls.Config
clientName atomic.Value
}
type pipelineWork struct {
reqCopy Request
respCopy Response
req *Request
resp *Response
t *time.Timer
deadline time.Time
err error
done chan struct{}
}
// DoTimeout performs the given request and waits for response during
// the given timeout duration.
//
// Request must contain at least non-zero RequestURI with full url (including
// scheme and host) or non-zero Host header + RequestURI.
//
// The function doesn't follow redirects.
//
// Response is ignored if resp is nil.
//
// ErrTimeout is returned if the response wasn't returned during
// the given timeout.
//
// It is recommended obtaining req and resp via AcquireRequest
// and AcquireResponse in performance-critical code.
//
// Warning: DoTimeout does not terminate the request itself. The request will
// continue in the background and the response will be discarded.
// If requests take too long and the connection pool gets filled up please
// try setting a ReadTimeout.
func (c *PipelineClient) DoTimeout(req *Request, resp *Response, timeout time.Duration) error {
return c.DoDeadline(req, resp, time.Now().Add(timeout))
}
// DoDeadline performs the given request and waits for response until
// the given deadline.
//
// Request must contain at least non-zero RequestURI with full url (including
// scheme and host) or non-zero Host header + RequestURI.
//
// The function doesn't follow redirects.
//
// Response is ignored if resp is nil.
//
// ErrTimeout is returned if the response wasn't returned until
// the given deadline.
//
// It is recommended obtaining req and resp via AcquireRequest
// and AcquireResponse in performance-critical code.
func (c *PipelineClient) DoDeadline(req *Request, resp *Response, deadline time.Time) error {
return c.getConnClient().DoDeadline(req, resp, deadline)
}
func (c *pipelineConnClient) DoDeadline(req *Request, resp *Response, deadline time.Time) error {
c.init()
timeout := -time.Since(deadline)
if timeout < 0 {
return ErrTimeout
}
if c.DisablePathNormalizing {
req.URI().DisablePathNormalizing = true
}
userAgentOld := req.Header.UserAgent()
if len(userAgentOld) == 0 {
req.Header.userAgent = append(req.Header.userAgent[:0], c.getClientName()...)
}
w := acquirePipelineWork(&c.workPool, timeout)
w.respCopy.Header.disableNormalizing = c.DisableHeaderNamesNormalizing
w.req = &w.reqCopy
w.resp = &w.respCopy
// Make a copy of the request in order to avoid data races on timeouts
req.copyToSkipBody(&w.reqCopy)
swapRequestBody(req, &w.reqCopy)
// Put the request to outgoing queue
select {
case c.chW <- w:
// Fast path: len(c.chW) < cap(c.chW)
default:
// Slow path
select {
case c.chW <- w:
case <-w.t.C:
releasePipelineWork(&c.workPool, w)
return ErrTimeout
}
}
// Wait for the response
var err error
select {
case <-w.done:
if resp != nil {
w.respCopy.copyToSkipBody(resp)
swapResponseBody(resp, &w.respCopy)
}
err = w.err
releasePipelineWork(&c.workPool, w)
case <-w.t.C:
err = ErrTimeout
}
return err
}
// Do performs the given http request and sets the corresponding response.
//
// Request must contain at least non-zero RequestURI with full url (including
// scheme and host) or non-zero Host header + RequestURI.
//
// The function doesn't follow redirects. Use Get* for following redirects.
//
// Response is ignored if resp is nil.
//
// It is recommended to obtain req and resp via AcquireRequest
// and AcquireResponse in performance-critical code.
func (c *PipelineClient) Do(req *Request, resp *Response) error {
return c.getConnClient().Do(req, resp)
}
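// Illustrative sketch (added for clarity, not part of the original source):
// Do has no timeout, but it may fail fast with ErrPipelineOverflow when the
// outgoing queue is full, so callers typically retry after a short pause
// (the package tests use the same pattern). pc, req and resp are assumed to
// be prepared as in the sketches above.
//
//	for {
//		err := pc.Do(req, resp)
//		if err == ErrPipelineOverflow {
//			time.Sleep(10 * time.Millisecond)
//			continue
//		}
//		// handle err and resp here
//		break
//	}
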
func (c *pipelineConnClient) Do(req *Request, resp *Response) error {
c.init()
if c.DisablePathNormalizing {
req.URI().DisablePathNormalizing = true
}
userAgentOld := req.Header.UserAgent()
if len(userAgentOld) == 0 {
req.Header.userAgent = append(req.Header.userAgent[:0], c.getClientName()...)
}
w := acquirePipelineWork(&c.workPool, 0)
w.req = req
if resp != nil {
resp.Header.disableNormalizing = c.DisableHeaderNamesNormalizing
w.resp = resp
} else {
w.resp = &w.respCopy
}
// Put the request to outgoing queue
select {
case c.chW <- w:
default:
// Try substituting the oldest w with the current one.
select {
case wOld := <-c.chW:
wOld.err = ErrPipelineOverflow
wOld.done <- struct{}{}
default:
}
select {
case c.chW <- w:
default:
releasePipelineWork(&c.workPool, w)
return ErrPipelineOverflow
}
}
// Wait for the response
<-w.done
err := w.err
releasePipelineWork(&c.workPool, w)
return err
}
func (c *PipelineClient) getConnClient() *pipelineConnClient {
c.connClientsLock.Lock()
cc := c.getConnClientUnlocked()
c.connClientsLock.Unlock()
return cc
}
func (c *PipelineClient) getConnClientUnlocked() *pipelineConnClient {
if len(c.connClients) == 0 {
return c.newConnClient()
}
// Return the client with the minimum number of pending requests.
minCC := c.connClients[0]
minReqs := minCC.PendingRequests()
if minReqs == 0 {
return minCC
}
for i := 1; i < len(c.connClients); i++ {
cc := c.connClients[i]
reqs := cc.PendingRequests()
if reqs == 0 {
return cc
}
if reqs < minReqs {
minCC = cc
minReqs = reqs
}
}
maxConns := c.MaxConns
if maxConns <= 0 {
maxConns = 1
}
if len(c.connClients) < maxConns {
return c.newConnClient()
}
return minCC
}
func (c *PipelineClient) newConnClient() *pipelineConnClient {
cc := &pipelineConnClient{
Addr: c.Addr,
Name: c.Name,
NoDefaultUserAgentHeader: c.NoDefaultUserAgentHeader,
MaxPendingRequests: c.MaxPendingRequests,
MaxBatchDelay: c.MaxBatchDelay,
Dial: c.Dial,
DialDualStack: c.DialDualStack,
DisableHeaderNamesNormalizing: c.DisableHeaderNamesNormalizing,
DisablePathNormalizing: c.DisablePathNormalizing,
IsTLS: c.IsTLS,
TLSConfig: c.TLSConfig,
MaxIdleConnDuration: c.MaxIdleConnDuration,
ReadBufferSize: c.ReadBufferSize,
WriteBufferSize: c.WriteBufferSize,
ReadTimeout: c.ReadTimeout,
WriteTimeout: c.WriteTimeout,
Logger: c.Logger,
}
c.connClients = append(c.connClients, cc)
return cc
}
// ErrPipelineOverflow may be returned from PipelineClient.Do*
// if the requests' queue overflows.
var ErrPipelineOverflow = errors.New("pipelined requests' queue has been overflown. Increase MaxConns and/or MaxPendingRequests")
// DefaultMaxPendingRequests is the default value
// for PipelineClient.MaxPendingRequests.
const DefaultMaxPendingRequests = 1024
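// A tuning sketch (added for illustration, not part of the original source):
// if ErrPipelineOverflow shows up under load, the queues can be widened as the
// error text suggests. The numbers below are placeholders, not recommendations.
//
//	pc := &PipelineClient{
//		Addr:               "example.com:80",
//		MaxConns:           4,
//		MaxPendingRequests: 4 * DefaultMaxPendingRequests,
//	}
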
func (c *pipelineConnClient) init() {
c.chLock.Lock()
if c.chR == nil {
maxPendingRequests := c.MaxPendingRequests
if maxPendingRequests <= 0 {
maxPendingRequests = DefaultMaxPendingRequests
}
c.chR = make(chan *pipelineWork, maxPendingRequests)
if c.chW == nil {
c.chW = make(chan *pipelineWork, maxPendingRequests)
}
go func() {
// Keep restarting the worker if it fails (connection errors for example).
for {
if err := c.worker(); err != nil {
c.logger().Printf("error in PipelineClient(%q): %s", c.Addr, err)
if netErr, ok := err.(net.Error); ok && netErr.Temporary() {
// Throttle client reconnections on temporary errors
time.Sleep(time.Second)
}
} else {
c.chLock.Lock()
stop := len(c.chR) == 0 && len(c.chW) == 0
if !stop {
c.chR = nil
c.chW = nil
}
c.chLock.Unlock()
if stop {
break
}
}
}
}()
}
c.chLock.Unlock()
}
func (c *pipelineConnClient) worker() error {
tlsConfig := c.cachedTLSConfig()
conn, err := dialAddr(c.Addr, c.Dial, c.DialDualStack, c.IsTLS, tlsConfig, c.WriteTimeout)
if err != nil {
return err
}
// Start reader and writer
stopW := make(chan struct{})
doneW := make(chan error)
go func() {
doneW <- c.writer(conn, stopW)
}()
stopR := make(chan struct{})
doneR := make(chan error)
go func() {
doneR <- c.reader(conn, stopR)
}()
// Wait until reader and writer are stopped
select {
case err = <-doneW:
conn.Close()
close(stopR)
<-doneR
case err = <-doneR:
conn.Close()
close(stopW)
<-doneW
}
// Notify pending readers
for len(c.chR) > 0 {
w := <-c.chR
w.err = errPipelineConnStopped
w.done <- struct{}{}
}
return err
}
func (c *pipelineConnClient) cachedTLSConfig() *tls.Config {
if !c.IsTLS {
return nil
}
c.tlsConfigLock.Lock()
cfg := c.tlsConfig
if cfg == nil {
cfg = newClientTLSConfig(c.TLSConfig, c.Addr)
c.tlsConfig = cfg
}
c.tlsConfigLock.Unlock()
return cfg
}
func (c *pipelineConnClient) writer(conn net.Conn, stopCh <-chan struct{}) error {
writeBufferSize := c.WriteBufferSize
if writeBufferSize <= 0 {
writeBufferSize = defaultWriteBufferSize
}
bw := bufio.NewWriterSize(conn, writeBufferSize)
defer bw.Flush()
chR := c.chR
chW := c.chW
writeTimeout := c.WriteTimeout
maxIdleConnDuration := c.MaxIdleConnDuration
if maxIdleConnDuration <= 0 {
maxIdleConnDuration = DefaultMaxIdleConnDuration
}
maxBatchDelay := c.MaxBatchDelay
var (
stopTimer = time.NewTimer(time.Hour)
flushTimer = time.NewTimer(time.Hour)
flushTimerCh <-chan time.Time
instantTimerCh = make(chan time.Time)
w *pipelineWork
err error
)
close(instantTimerCh)
for {
againChW:
select {
case w = <-chW:
// Fast path: len(chW) > 0
default:
// Slow path
stopTimer.Reset(maxIdleConnDuration)
select {
case w = <-chW:
case <-stopTimer.C:
return nil
case <-stopCh:
return nil
case <-flushTimerCh:
if err = bw.Flush(); err != nil {
return err
}
flushTimerCh = nil
goto againChW
}
}
if !w.deadline.IsZero() && time.Since(w.deadline) >= 0 {
w.err = ErrTimeout
w.done <- struct{}{}
continue
}
w.resp.parseNetConn(conn)
if writeTimeout > 0 {
// Set Deadline every time, since golang has fixed the performance issue
// See https://github.com/golang/go/issues/15133#issuecomment-271571395 for details
currentTime := time.Now()
if err = conn.SetWriteDeadline(currentTime.Add(writeTimeout)); err != nil {
w.err = err
w.done <- struct{}{}
return err
}
}
if err = w.req.Write(bw); err != nil {
w.err = err
w.done <- struct{}{}
return err
}
if flushTimerCh == nil && (len(chW) == 0 || len(chR) == cap(chR)) {
if maxBatchDelay > 0 {
flushTimer.Reset(maxBatchDelay)
flushTimerCh = flushTimer.C
} else {
flushTimerCh = instantTimerCh
}
}
againChR:
select {
case chR <- w:
// Fast path: len(chR) < cap(chR)
default:
// Slow path
select {
case chR <- w:
case <-stopCh:
w.err = errPipelineConnStopped
w.done <- struct{}{}
return nil
case <-flushTimerCh:
if err = bw.Flush(); err != nil {
w.err = err
w.done <- struct{}{}
return err
}
flushTimerCh = nil
goto againChR
}
}
}
}
func (c *pipelineConnClient) reader(conn net.Conn, stopCh <-chan struct{}) error {
readBufferSize := c.ReadBufferSize
if readBufferSize <= 0 {
readBufferSize = defaultReadBufferSize
}
br := bufio.NewReaderSize(conn, readBufferSize)
chR := c.chR
readTimeout := c.ReadTimeout
var (
w *pipelineWork
err error
)
for {
select {
case w = <-chR:
// Fast path: len(chR) > 0
default:
// Slow path
select {
case w = <-chR:
case <-stopCh:
return nil
}
}
if readTimeout > 0 {
// Set Deadline every time, since golang has fixed the performance issue
// See https://github.com/golang/go/issues/15133#issuecomment-271571395 for details
currentTime := time.Now()
if err = conn.SetReadDeadline(currentTime.Add(readTimeout)); err != nil {
w.err = err
w.done <- struct{}{}
return err
}
}
if err = w.resp.Read(br); err != nil {
w.err = err
w.done <- struct{}{}
return err
}
w.done <- struct{}{}
}
}
func (c *pipelineConnClient) logger() Logger {
if c.Logger != nil {
return c.Logger
}
return defaultLogger
}
// PendingRequests returns the current number of pending requests pipelined
// to the server.
//
// This number may exceed MaxPendingRequests*MaxConns by up to two times, since
// each connection to the server may hold up to MaxPendingRequests requests
// queued for sending in addition to up to MaxPendingRequests requests already
// sent and awaiting responses.
//
// This function may be used for balancing load among multiple PipelineClient
// instances.
func (c *PipelineClient) PendingRequests() int {
c.connClientsLock.Lock()
n := 0
for _, cc := range c.connClients {
n += cc.PendingRequests()
}
c.connClientsLock.Unlock()
return n
}
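// Illustrative sketch (added for clarity, not part of the original source):
// using PendingRequests to pick the least-loaded instance from a set of
// PipelineClient values, as the comment above suggests. clients is a
// hypothetical []*PipelineClient; req and resp are prepared as usual.
//
//	best := clients[0]
//	for _, pc := range clients[1:] {
//		if pc.PendingRequests() < best.PendingRequests() {
//			best = pc
//		}
//	}
//	err := best.Do(req, resp)
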
func (c *pipelineConnClient) PendingRequests() int {
c.init()
c.chLock.Lock()
n := len(c.chR) + len(c.chW)
c.chLock.Unlock()
return n
}
func (c *pipelineConnClient) getClientName() []byte {
v := c.clientName.Load()
var clientName []byte
if v == nil {
clientName = []byte(c.Name)
if len(clientName) == 0 && !c.NoDefaultUserAgentHeader {
clientName = defaultUserAgent
}
c.clientName.Store(clientName)
} else {
clientName = v.([]byte)
}
return clientName
}
var errPipelineConnStopped = errors.New("pipeline connection has been stopped")
func acquirePipelineWork(pool *sync.Pool, timeout time.Duration) *pipelineWork {
v := pool.Get()
if v == nil {
v = &pipelineWork{
done: make(chan struct{}, 1),
}
}
w := v.(*pipelineWork)
if timeout > 0 {
if w.t == nil {
w.t = time.NewTimer(timeout)
} else {
w.t.Reset(timeout)
}
w.deadline = time.Now().Add(timeout)
} else {
w.deadline = zeroTime
}
return w
}
func releasePipelineWork(pool *sync.Pool, w *pipelineWork) {
if w.t != nil {
w.t.Stop()
}
w.reqCopy.Reset()
w.respCopy.Reset()
w.req = nil
w.resp = nil
w.err = nil
pool.Put(w)
}
fasthttp-1.31.0/client_example_test.go 0000664 0000000 0000000 00000002040 14130360711 0020000 0 ustar 00root root 0000000 0000000 package fasthttp_test
import (
"log"
"github.com/valyala/fasthttp"
)
func ExampleHostClient() {
// Prepare a client, which fetches webpages via an HTTP proxy listening
// on localhost:8080.
c := &fasthttp.HostClient{
Addr: "localhost:8080",
}
// Fetch google page via local proxy.
statusCode, body, err := c.Get(nil, "http://google.com/foo/bar")
if err != nil {
log.Fatalf("Error when loading google page through local proxy: %s", err)
}
if statusCode != fasthttp.StatusOK {
log.Fatalf("Unexpected status code: %d. Expecting %d", statusCode, fasthttp.StatusOK)
}
useResponseBody(body)
// Fetch foobar page via local proxy. Reuse body buffer.
statusCode, body, err = c.Get(body, "http://foobar.com/google/com")
if err != nil {
log.Fatalf("Error when loading foobar page through local proxy: %s", err)
}
if statusCode != fasthttp.StatusOK {
log.Fatalf("Unexpected status code: %d. Expecting %d", statusCode, fasthttp.StatusOK)
}
useResponseBody(body)
}
func useResponseBody(body []byte) {
// Do something with body :)
}
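// An additional illustrative example (added for clarity, not part of the
// original file): reusing pooled Request/Response objects with HostClient.Do.
// The proxy address mirrors the placeholder used in ExampleHostClient above.
func ExampleHostClient_Do() {
	c := &fasthttp.HostClient{
		Addr: "localhost:8080",
	}

	req := fasthttp.AcquireRequest()
	resp := fasthttp.AcquireResponse()
	defer fasthttp.ReleaseRequest(req)
	defer fasthttp.ReleaseResponse(resp)

	// Fetch a page via the local proxy using the pooled objects.
	req.SetRequestURI("http://google.com/foo/bar")
	if err := c.Do(req, resp); err != nil {
		log.Fatalf("Error when loading page through local proxy: %s", err)
	}
	if resp.StatusCode() != fasthttp.StatusOK {
		log.Fatalf("Unexpected status code: %d. Expecting %d", resp.StatusCode(), fasthttp.StatusOK)
	}
	useResponseBody(resp.Body())
}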
fasthttp-1.31.0/client_test.go 0000664 0000000 0000000 00000170232 14130360711 0016276 0 ustar 00root root 0000000 0000000 package fasthttp
import (
"bufio"
"bytes"
"crypto/tls"
"fmt"
"io"
"net"
"net/url"
"os"
"regexp"
"runtime"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/valyala/fasthttp/fasthttputil"
)
func TestCloseIdleConnections(t *testing.T) {
t.Parallel()
ln := fasthttputil.NewInmemoryListener()
s := &Server{
Handler: func(ctx *RequestCtx) {
},
}
go func() {
if err := s.Serve(ln); err != nil {
t.Error(err)
}
}()
c := &Client{
Dial: func(addr string) (net.Conn, error) {
return ln.Dial()
},
}
if _, _, err := c.Get(nil, "http://google.com"); err != nil {
t.Fatal(err)
}
connsLen := func() int {
c.mLock.Lock()
defer c.mLock.Unlock()
if _, ok := c.m["google.com"]; !ok {
return 0
}
c.m["google.com"].connsLock.Lock()
defer c.m["google.com"].connsLock.Unlock()
return len(c.m["google.com"].conns)
}
if conns := connsLen(); conns > 1 {
t.Errorf("expected 1 conns got %d", conns)
}
c.CloseIdleConnections()
if conns := connsLen(); conns > 0 {
t.Errorf("expected 0 conns got %d", conns)
}
}
func TestPipelineClientSetUserAgent(t *testing.T) {
t.Parallel()
testPipelineClientSetUserAgent(t, 0)
}
func TestPipelineClientSetUserAgentTimeout(t *testing.T) {
t.Parallel()
testPipelineClientSetUserAgent(t, time.Second)
}
func testPipelineClientSetUserAgent(t *testing.T, timeout time.Duration) {
ln := fasthttputil.NewInmemoryListener()
userAgentSeen := ""
s := &Server{
Handler: func(ctx *RequestCtx) {
userAgentSeen = string(ctx.UserAgent())
},
}
go s.Serve(ln) //nolint:errcheck
userAgent := "I'm not fasthttp"
c := &HostClient{
Name: userAgent,
Dial: func(addr string) (net.Conn, error) {
return ln.Dial()
},
}
req := AcquireRequest()
res := AcquireResponse()
req.SetRequestURI("http://example.com")
var err error
if timeout <= 0 {
err = c.Do(req, res)
} else {
err = c.DoTimeout(req, res, timeout)
}
if err != nil {
t.Fatal(err)
}
if userAgentSeen != userAgent {
t.Fatalf("User-Agent defers %q != %q", userAgentSeen, userAgent)
}
}
func TestPipelineClientIssue832(t *testing.T) {
t.Parallel()
ln := fasthttputil.NewInmemoryListener()
req := AcquireRequest()
// Don't defer ReleaseRequest as we use it in a goroutine that might not be done at the end.
req.SetHost("example.com")
res := AcquireResponse()
// Don't defer ReleaseResponse as we use it in a goroutine that might not be done at the end.
client := PipelineClient{
Dial: func(addr string) (net.Conn, error) {
return ln.Dial()
},
ReadTimeout: time.Millisecond * 10,
Logger: &testLogger{}, // Ignore log output.
}
attempts := 10
go func() {
for i := 0; i < attempts; i++ {
c, err := ln.Accept()
if err != nil {
t.Error(err)
}
if c != nil {
go func() {
time.Sleep(time.Millisecond * 50)
c.Close()
}()
}
}
}()
done := make(chan int)
go func() {
defer close(done)
for i := 0; i < attempts; i++ {
if err := client.Do(req, res); err == nil {
t.Error("error expected")
}
}
}()
select {
case <-time.After(time.Second * 2):
t.Fatal("PipelineClient did not restart worker")
case <-done:
}
}
func TestClientInvalidURI(t *testing.T) {
t.Parallel()
ln := fasthttputil.NewInmemoryListener()
requests := int64(0)
s := &Server{
Handler: func(ctx *RequestCtx) {
atomic.AddInt64(&requests, 1)
},
}
go s.Serve(ln) //nolint:errcheck
c := &Client{
Dial: func(addr string) (net.Conn, error) {
return ln.Dial()
},
}
req, res := AcquireRequest(), AcquireResponse()
defer func() {
ReleaseRequest(req)
ReleaseResponse(res)
}()
req.Header.SetMethod(MethodGet)
req.SetRequestURI("http://example.com\r\n\r\nGET /\r\n\r\n")
err := c.Do(req, res)
if err == nil {
t.Fatal("expected error (missing required Host header in request)")
}
if n := atomic.LoadInt64(&requests); n != 0 {
t.Fatalf("0 requests expected, got %d", n)
}
}
func TestClientGetWithBody(t *testing.T) {
t.Parallel()
ln := fasthttputil.NewInmemoryListener()
s := &Server{
Handler: func(ctx *RequestCtx) {
body := ctx.Request.Body()
ctx.Write(body) //nolint:errcheck
},
}
go s.Serve(ln) //nolint:errcheck
c := &Client{
Dial: func(addr string) (net.Conn, error) {
return ln.Dial()
},
}
req, res := AcquireRequest(), AcquireResponse()
defer func() {
ReleaseRequest(req)
ReleaseResponse(res)
}()
req.Header.SetMethod(MethodGet)
req.SetRequestURI("http://example.com")
req.SetBodyString("test")
err := c.Do(req, res)
if err != nil {
t.Fatal(err)
}
if len(res.Body()) == 0 {
t.Fatal("missing request body")
}
}
func TestClientURLAuth(t *testing.T) {
t.Parallel()
cases := map[string]string{
"user:pass@": "Basic dXNlcjpwYXNz",
"foo:@": "Basic Zm9vOg==",
":@": "",
"@": "",
"": "",
}
ch := make(chan string, 1)
ln := fasthttputil.NewInmemoryListener()
s := &Server{
Handler: func(ctx *RequestCtx) {
ch <- string(ctx.Request.Header.Peek(HeaderAuthorization))
},
}
go s.Serve(ln) //nolint:errcheck
c := &Client{
Dial: func(addr string) (net.Conn, error) {
return ln.Dial()
},
}
for up, expected := range cases {
req := AcquireRequest()
req.Header.SetMethod(MethodGet)
req.SetRequestURI("http://" + up + "example.com/foo/bar")
if err := c.Do(req, nil); err != nil {
t.Fatal(err)
}
val := <-ch
if val != expected {
t.Fatalf("wrong %s header: %s expected %s", HeaderAuthorization, val, expected)
}
}
}
func TestClientNilResp(t *testing.T) {
t.Parallel()
ln := fasthttputil.NewInmemoryListener()
s := &Server{
Handler: func(ctx *RequestCtx) {
},
}
go s.Serve(ln) //nolint:errcheck
c := &Client{
Dial: func(addr string) (net.Conn, error) {
return ln.Dial()
},
}
req := AcquireRequest()
req.Header.SetMethod(MethodGet)
req.SetRequestURI("http://example.com")
if err := c.Do(req, nil); err != nil {
t.Fatal(err)
}
if err := c.DoTimeout(req, nil, time.Second); err != nil {
t.Fatal(err)
}
ln.Close()
}
func TestPipelineClientNilResp(t *testing.T) {
t.Parallel()
ln := fasthttputil.NewInmemoryListener()
s := &Server{
Handler: func(ctx *RequestCtx) {
},
}
go s.Serve(ln) //nolint:errcheck
c := &PipelineClient{
Dial: func(addr string) (net.Conn, error) {
return ln.Dial()
},
}
req := AcquireRequest()
req.Header.SetMethod(MethodGet)
req.SetRequestURI("http://example.com")
if err := c.Do(req, nil); err != nil {
t.Fatal(err)
}
if err := c.DoTimeout(req, nil, time.Second); err != nil {
t.Fatal(err)
}
if err := c.DoDeadline(req, nil, time.Now().Add(time.Second)); err != nil {
t.Fatal(err)
}
}
func TestClientParseConn(t *testing.T) {
t.Parallel()
network := "tcp"
ln, _ := net.Listen(network, "127.0.0.1:0")
s := &Server{
Handler: func(ctx *RequestCtx) {
},
}
go s.Serve(ln) //nolint:errcheck
host := ln.Addr().String()
c := &Client{}
req, res := AcquireRequest(), AcquireResponse()
defer func() {
ReleaseRequest(req)
ReleaseResponse(res)
}()
req.SetRequestURI("http://" + host + "")
if err := c.Do(req, res); err != nil {
t.Fatal(err)
}
if res.RemoteAddr().Network() != network {
t.Fatalf("req RemoteAddr parse network fail: %s, hope: %s", res.RemoteAddr().Network(), network)
}
if host != res.RemoteAddr().String() {
t.Fatalf("req RemoteAddr parse addr fail: %s, hope: %s", res.RemoteAddr().String(), host)
}
if !regexp.MustCompile(`^127\.0\.0\.1:[0-9]{4,5}$`).MatchString(res.LocalAddr().String()) {
t.Fatalf("res LocalAddr addr match fail: %s, hope match: %s", res.LocalAddr().String(), "^127.0.0.1:[0-9]{4,5}$")
}
}
func TestClientPostArgs(t *testing.T) {
t.Parallel()
ln := fasthttputil.NewInmemoryListener()
s := &Server{
Handler: func(ctx *RequestCtx) {
body := ctx.Request.Body()
if len(body) == 0 {
return
}
ctx.Write(body) //nolint:errcheck
},
}
go s.Serve(ln) //nolint:errcheck
c := &Client{
Dial: func(addr string) (net.Conn, error) {
return ln.Dial()
},
}
req, res := AcquireRequest(), AcquireResponse()
defer func() {
ReleaseRequest(req)
ReleaseResponse(res)
}()
args := req.PostArgs()
args.Add("addhttp2", "support")
args.Add("fast", "http")
req.Header.SetMethod(MethodPost)
req.SetRequestURI("http://make.fasthttp.great?again")
err := c.Do(req, res)
if err != nil {
t.Fatal(err)
}
if len(res.Body()) == 0 {
t.Fatal("cannot set args as body")
}
}
func TestClientRedirectSameSchema(t *testing.T) {
t.Parallel()
listenHTTPS1 := testClientRedirectListener(t, true)
defer listenHTTPS1.Close()
listenHTTPS2 := testClientRedirectListener(t, true)
defer listenHTTPS2.Close()
sHTTPS1 := testClientRedirectChangingSchemaServer(t, listenHTTPS1, listenHTTPS1, true)
defer sHTTPS1.Stop()
sHTTPS2 := testClientRedirectChangingSchemaServer(t, listenHTTPS2, listenHTTPS2, false)
defer sHTTPS2.Stop()
destURL := fmt.Sprintf("https://%s/baz", listenHTTPS1.Addr().String())
urlParsed, err := url.Parse(destURL)
if err != nil {
t.Fatal(err)
return
}
reqClient := &HostClient{
IsTLS: true,
Addr: urlParsed.Host,
TLSConfig: &tls.Config{
InsecureSkipVerify: true,
},
}
statusCode, _, err := reqClient.GetTimeout(nil, destURL, 4000*time.Millisecond)
if err != nil {
t.Fatalf("HostClient error: %s", err)
return
}
if statusCode != 200 {
t.Fatalf("HostClient error code response %d", statusCode)
return
}
}
func TestClientRedirectClientChangingSchemaHttp2Https(t *testing.T) {
t.Parallel()
listenHTTPS := testClientRedirectListener(t, true)
defer listenHTTPS.Close()
listenHTTP := testClientRedirectListener(t, false)
defer listenHTTP.Close()
sHTTPS := testClientRedirectChangingSchemaServer(t, listenHTTPS, listenHTTP, true)
defer sHTTPS.Stop()
sHTTP := testClientRedirectChangingSchemaServer(t, listenHTTPS, listenHTTP, false)
defer sHTTP.Stop()
destURL := fmt.Sprintf("http://%s/baz", listenHTTP.Addr().String())
reqClient := &Client{
TLSConfig: &tls.Config{
InsecureSkipVerify: true,
},
}
statusCode, _, err := reqClient.GetTimeout(nil, destURL, 4000*time.Millisecond)
if err != nil {
t.Fatalf("HostClient error: %s", err)
return
}
if statusCode != 200 {
t.Fatalf("HostClient error code response %d", statusCode)
return
}
}
func TestClientRedirectHostClientChangingSchemaHttp2Https(t *testing.T) {
t.Parallel()
listenHTTPS := testClientRedirectListener(t, true)
defer listenHTTPS.Close()
listenHTTP := testClientRedirectListener(t, false)
defer listenHTTP.Close()
sHTTPS := testClientRedirectChangingSchemaServer(t, listenHTTPS, listenHTTP, true)
defer sHTTPS.Stop()
sHTTP := testClientRedirectChangingSchemaServer(t, listenHTTPS, listenHTTP, false)
defer sHTTP.Stop()
destURL := fmt.Sprintf("http://%s/baz", listenHTTP.Addr().String())
urlParsed, err := url.Parse(destURL)
if err != nil {
t.Fatal(err)
return
}
reqClient := &HostClient{
Addr: urlParsed.Host,
TLSConfig: &tls.Config{
InsecureSkipVerify: true,
},
}
_, _, err = reqClient.GetTimeout(nil, destURL, 4000*time.Millisecond)
if err != ErrHostClientRedirectToDifferentScheme {
t.Fatal("expected HostClient error")
}
}
func testClientRedirectListener(t *testing.T, isTLS bool) net.Listener {
var ln net.Listener
var err error
var tlsConfig *tls.Config
if isTLS {
certData, keyData, kerr := GenerateTestCertificate("localhost")
if kerr != nil {
t.Fatal(kerr)
}
cert, kerr := tls.X509KeyPair(certData, keyData)
if kerr != nil {
t.Fatal(kerr)
}
tlsConfig = &tls.Config{
Certificates: []tls.Certificate{cert},
}
ln, err = tls.Listen("tcp", "localhost:0", tlsConfig)
} else {
ln, err = net.Listen("tcp", "localhost:0")
}
if err != nil {
t.Fatalf("cannot listen isTLS %v: %s", isTLS, err)
}
return ln
}
func testClientRedirectChangingSchemaServer(t *testing.T, https, http net.Listener, isTLS bool) *testEchoServer {
s := &Server{
Handler: func(ctx *RequestCtx) {
if ctx.IsTLS() {
ctx.SetStatusCode(200)
} else {
ctx.Redirect(fmt.Sprintf("https://%s/baz", https.Addr().String()), 301)
}
},
}
var ln net.Listener
if isTLS {
ln = https
} else {
ln = http
}
ch := make(chan struct{})
go func() {
err := s.Serve(ln)
if err != nil {
t.Errorf("unexpected error returned from Serve(): %s", err)
}
close(ch)
}()
return &testEchoServer{
s: s,
ln: ln,
ch: ch,
t: t,
}
}
func TestClientHeaderCase(t *testing.T) {
t.Parallel()
ln := fasthttputil.NewInmemoryListener()
defer ln.Close()
go func() {
c, err := ln.Accept()
if err != nil {
t.Error(err)
}
c.Write([]byte("HTTP/1.1 200 OK\r\n" + //nolint:errcheck
"content-type: text/plain\r\n" +
"transfer-encoding: chunked\r\n\r\n" +
"24\r\nThis is the data in the first chunk \r\n" +
"1B\r\nand this is the second one \r\n" +
"0\r\n\r\n",
))
}()
c := &Client{
Dial: func(addr string) (net.Conn, error) {
return ln.Dial()
},
ReadTimeout: time.Millisecond * 10,
// Even without name normalizing we should parse headers correctly.
DisableHeaderNamesNormalizing: true,
}
code, body, err := c.Get(nil, "http://example.com")
if err != nil {
t.Error(err)
} else if code != 200 {
t.Errorf("expected status code 200 got %d", code)
} else if string(body) != "This is the data in the first chunk and this is the second one " {
t.Errorf("wrong body: %q", body)
}
}
func TestClientReadTimeout(t *testing.T) {
t.Parallel()
ln := fasthttputil.NewInmemoryListener()
timeout := false
s := &Server{
Handler: func(ctx *RequestCtx) {
if timeout {
time.Sleep(time.Second)
} else {
timeout = true
}
},
Logger: &testLogger{}, // Don't print closed pipe errors.
}
go s.Serve(ln) //nolint:errcheck
c := &HostClient{
ReadTimeout: time.Millisecond * 400,
MaxIdemponentCallAttempts: 1,
Dial: func(addr string) (net.Conn, error) {
return ln.Dial()
},
}
req := AcquireRequest()
res := AcquireResponse()
req.SetRequestURI("http://localhost")
// Setting Connection: Close will make the connection be
// returned to the pool.
req.SetConnectionClose()
if err := c.Do(req, res); err != nil {
t.Fatal(err)
}
ReleaseRequest(req)
ReleaseResponse(res)
done := make(chan struct{})
go func() {
req := AcquireRequest()
res := AcquireResponse()
req.SetRequestURI("http://localhost")
req.SetConnectionClose()
if err := c.Do(req, res); err != ErrTimeout {
t.Errorf("expected ErrTimeout got %#v", err)
}
ReleaseRequest(req)
ReleaseResponse(res)
close(done)
}()
select {
case <-done:
// This shouldn't take longer than the timeout times the number of requests it is going to try to do.
// Give it an extra second just to be sure.
case <-time.After(c.ReadTimeout*time.Duration(c.MaxIdemponentCallAttempts) + time.Second):
t.Fatal("Client.ReadTimeout didn't work")
}
}
func TestClientDefaultUserAgent(t *testing.T) {
t.Parallel()
ln := fasthttputil.NewInmemoryListener()
userAgentSeen := ""
s := &Server{
Handler: func(ctx *RequestCtx) {
userAgentSeen = string(ctx.UserAgent())
},
}
go s.Serve(ln) //nolint:errcheck
c := &Client{
Dial: func(addr string) (net.Conn, error) {
return ln.Dial()
},
}
req := AcquireRequest()
res := AcquireResponse()
req.SetRequestURI("http://example.com")
err := c.Do(req, res)
if err != nil {
t.Fatal(err)
}
if userAgentSeen != string(defaultUserAgent) {
t.Fatalf("User-Agent defers %q != %q", userAgentSeen, defaultUserAgent)
}
}
func TestClientSetUserAgent(t *testing.T) {
t.Parallel()
ln := fasthttputil.NewInmemoryListener()
userAgentSeen := ""
s := &Server{
Handler: func(ctx *RequestCtx) {
userAgentSeen = string(ctx.UserAgent())
},
}
go s.Serve(ln) //nolint:errcheck
userAgent := "I'm not fasthttp"
c := &Client{
Name: userAgent,
Dial: func(addr string) (net.Conn, error) {
return ln.Dial()
},
}
req := AcquireRequest()
res := AcquireResponse()
req.SetRequestURI("http://example.com")
err := c.Do(req, res)
if err != nil {
t.Fatal(err)
}
if userAgentSeen != userAgent {
t.Fatalf("User-Agent defers %q != %q", userAgentSeen, userAgent)
}
}
func TestClientNoUserAgent(t *testing.T) {
ln := fasthttputil.NewInmemoryListener()
userAgentSeen := ""
s := &Server{
Handler: func(ctx *RequestCtx) {
userAgentSeen = string(ctx.UserAgent())
},
}
go s.Serve(ln) //nolint:errcheck
c := &Client{
NoDefaultUserAgentHeader: true,
Dial: func(addr string) (net.Conn, error) {
return ln.Dial()
},
}
req := AcquireRequest()
res := AcquireResponse()
req.SetRequestURI("http://example.com")
err := c.Do(req, res)
if err != nil {
t.Fatal(err)
}
if userAgentSeen != "" {
t.Fatalf("User-Agent wrong %q != %q", userAgentSeen, "")
}
}
func TestClientDoWithCustomHeaders(t *testing.T) {
t.Parallel()
// make sure that the client sends all the request headers and body.
ln := fasthttputil.NewInmemoryListener()
c := &Client{
Dial: func(addr string) (net.Conn, error) {
return ln.Dial()
},
}
uri := "/foo/bar/baz?a=b&cd=12"
headers := map[string]string{
"Foo": "bar",
"Host": "xxx.com",
"Content-Type": "asdfsdf",
"a-b-c-d-f": "",
}
body := "request body"
ch := make(chan error)
go func() {
conn, err := ln.Accept()
if err != nil {
ch <- fmt.Errorf("cannot accept client connection: %s", err)
return
}
br := bufio.NewReader(conn)
var req Request
if err = req.Read(br); err != nil {
ch <- fmt.Errorf("cannot read client request: %s", err)
return
}
if string(req.Header.Method()) != MethodPost {
ch <- fmt.Errorf("unexpected request method: %q. Expecting %q", req.Header.Method(), MethodPost)
return
}
reqURI := req.RequestURI()
if string(reqURI) != uri {
ch <- fmt.Errorf("unexpected request uri: %q. Expecting %q", reqURI, uri)
return
}
for k, v := range headers {
hv := req.Header.Peek(k)
if string(hv) != v {
ch <- fmt.Errorf("unexpected value for header %q: %q. Expecting %q", k, hv, v)
return
}
}
cl := req.Header.ContentLength()
if cl != len(body) {
ch <- fmt.Errorf("unexpected content-length %d. Expecting %d", cl, len(body))
return
}
reqBody := req.Body()
if string(reqBody) != body {
ch <- fmt.Errorf("unexpected request body: %q. Expecting %q", reqBody, body)
return
}
var resp Response
bw := bufio.NewWriter(conn)
if err = resp.Write(bw); err != nil {
ch <- fmt.Errorf("cannot send response: %s", err)
return
}
if err = bw.Flush(); err != nil {
ch <- fmt.Errorf("cannot flush response: %s", err)
return
}
ch <- nil
}()
var req Request
req.Header.SetMethod(MethodPost)
req.SetRequestURI(uri)
for k, v := range headers {
req.Header.Set(k, v)
}
req.SetBodyString(body)
var resp Response
err := c.DoTimeout(&req, &resp, time.Second)
if err != nil {
t.Fatalf("error when doing request: %s", err)
}
select {
case <-ch:
case <-time.After(5 * time.Second):
t.Fatalf("timeout")
}
}
func TestPipelineClientDoSerial(t *testing.T) {
t.Parallel()
testPipelineClientDoConcurrent(t, 1, 0, 0)
}
func TestPipelineClientDoConcurrent(t *testing.T) {
t.Parallel()
testPipelineClientDoConcurrent(t, 10, 0, 1)
}
func TestPipelineClientDoBatchDelayConcurrent(t *testing.T) {
t.Parallel()
testPipelineClientDoConcurrent(t, 10, 5*time.Millisecond, 1)
}
func TestPipelineClientDoBatchDelayConcurrentMultiConn(t *testing.T) {
t.Parallel()
testPipelineClientDoConcurrent(t, 10, 5*time.Millisecond, 3)
}
func testPipelineClientDoConcurrent(t *testing.T, concurrency int, maxBatchDelay time.Duration, maxConns int) {
ln := fasthttputil.NewInmemoryListener()
s := &Server{
Handler: func(ctx *RequestCtx) {
ctx.WriteString("OK") //nolint:errcheck
},
}
serverStopCh := make(chan struct{})
go func() {
if err := s.Serve(ln); err != nil {
t.Errorf("unexpected error: %s", err)
}
close(serverStopCh)
}()
c := &PipelineClient{
Dial: func(addr string) (net.Conn, error) {
return ln.Dial()
},
MaxConns: maxConns,
MaxPendingRequests: concurrency,
MaxBatchDelay: maxBatchDelay,
Logger: &testLogger{},
}
clientStopCh := make(chan struct{}, concurrency)
for i := 0; i < concurrency; i++ {
go func() {
testPipelineClientDo(t, c)
clientStopCh <- struct{}{}
}()
}
for i := 0; i < concurrency; i++ {
select {
case <-clientStopCh:
case <-time.After(3 * time.Second):
t.Fatalf("timeout")
}
}
if c.PendingRequests() != 0 {
t.Fatalf("unexpected number of pending requests: %d. Expecting zero", c.PendingRequests())
}
if err := ln.Close(); err != nil {
t.Fatalf("unexpected error: %s", err)
}
select {
case <-serverStopCh:
case <-time.After(time.Second):
t.Fatalf("timeout")
}
}
func testPipelineClientDo(t *testing.T, c *PipelineClient) {
var err error
req := AcquireRequest()
req.SetRequestURI("http://foobar/baz")
resp := AcquireResponse()
for i := 0; i < 10; i++ {
if i&1 == 0 {
err = c.DoTimeout(req, resp, time.Second)
} else {
err = c.Do(req, resp)
}
if err != nil {
if err == ErrPipelineOverflow {
time.Sleep(10 * time.Millisecond)
continue
}
t.Fatalf("unexpected error on iteration %d: %s", i, err)
}
if resp.StatusCode() != StatusOK {
t.Fatalf("unexpected status code: %d. Expecting %d", resp.StatusCode(), StatusOK)
}
body := string(resp.Body())
if body != "OK" {
t.Fatalf("unexpected body: %q. Expecting %q", body, "OK")
}
// sleep for a while, so the connection to the host may expire.
if i%5 == 0 {
time.Sleep(30 * time.Millisecond)
}
}
ReleaseRequest(req)
ReleaseResponse(resp)
}
func TestPipelineClientDoDisableHeaderNamesNormalizing(t *testing.T) {
t.Parallel()
testPipelineClientDisableHeaderNamesNormalizing(t, 0)
}
func TestPipelineClientDoTimeoutDisableHeaderNamesNormalizing(t *testing.T) {
t.Parallel()
testPipelineClientDisableHeaderNamesNormalizing(t, time.Second)
}
func testPipelineClientDisableHeaderNamesNormalizing(t *testing.T, timeout time.Duration) {
ln := fasthttputil.NewInmemoryListener()
s := &Server{
Handler: func(ctx *RequestCtx) {
ctx.Response.Header.Set("foo-BAR", "baz")
},
DisableHeaderNamesNormalizing: true,
}
serverStopCh := make(chan struct{})
go func() {
if err := s.Serve(ln); err != nil {
t.Errorf("unexpected error: %s", err)
}
close(serverStopCh)
}()
c := &PipelineClient{
Dial: func(addr string) (net.Conn, error) {
return ln.Dial()
},
DisableHeaderNamesNormalizing: true,
}
var req Request
req.SetRequestURI("http://aaaai.com/bsdf?sddfsd")
var resp Response
for i := 0; i < 5; i++ {
if timeout > 0 {
if err := c.DoTimeout(&req, &resp, timeout); err != nil {
t.Fatalf("unexpected error: %s", err)
}
} else {
if err := c.Do(&req, &resp); err != nil {
t.Fatalf("unexpected error: %s", err)
}
}
hv := resp.Header.Peek("foo-BAR")
if string(hv) != "baz" {
t.Fatalf("unexpected header value: %q. Expecting %q", hv, "baz")
}
hv = resp.Header.Peek("Foo-Bar")
if len(hv) > 0 {
t.Fatalf("unexpected non-empty header value %q", hv)
}
}
if err := ln.Close(); err != nil {
t.Fatalf("unexpected error: %s", err)
}
select {
case <-serverStopCh:
case <-time.After(time.Second):
t.Fatalf("timeout")
}
}
func TestClientDoTimeoutDisableHeaderNamesNormalizing(t *testing.T) {
t.Parallel()
ln := fasthttputil.NewInmemoryListener()
s := &Server{
Handler: func(ctx *RequestCtx) {
ctx.Response.Header.Set("foo-BAR", "baz")
},
DisableHeaderNamesNormalizing: true,
}
serverStopCh := make(chan struct{})
go func() {
if err := s.Serve(ln); err != nil {
t.Errorf("unexpected error: %s", err)
}
close(serverStopCh)
}()
c := &Client{
Dial: func(addr string) (net.Conn, error) {
return ln.Dial()
},
DisableHeaderNamesNormalizing: true,
}
var req Request
req.SetRequestURI("http://aaaai.com/bsdf?sddfsd")
var resp Response
for i := 0; i < 5; i++ {
if err := c.DoTimeout(&req, &resp, time.Second); err != nil {
t.Fatalf("unexpected error: %s", err)
}
hv := resp.Header.Peek("foo-BAR")
if string(hv) != "baz" {
t.Fatalf("unexpected header value: %q. Expecting %q", hv, "baz")
}
hv = resp.Header.Peek("Foo-Bar")
if len(hv) > 0 {
t.Fatalf("unexpected non-empty header value %q", hv)
}
}
if err := ln.Close(); err != nil {
t.Fatalf("unexpected error: %s", err)
}
select {
case <-serverStopCh:
case <-time.After(time.Second):
t.Fatalf("timeout")
}
}
func TestClientDoTimeoutDisablePathNormalizing(t *testing.T) {
t.Parallel()
ln := fasthttputil.NewInmemoryListener()
s := &Server{
Handler: func(ctx *RequestCtx) {
uri := ctx.URI()
uri.DisablePathNormalizing = true
ctx.Response.Header.Set("received-uri", string(uri.FullURI()))
},
}
serverStopCh := make(chan struct{})
go func() {
if err := s.Serve(ln); err != nil {
t.Errorf("unexpected error: %s", err)
}
close(serverStopCh)
}()
c := &Client{
Dial: func(addr string) (net.Conn, error) {
return ln.Dial()
},
DisablePathNormalizing: true,
}
urlWithEncodedPath := "http://example.com/encoded/Y%2BY%2FY%3D/stuff"
var req Request
req.SetRequestURI(urlWithEncodedPath)
var resp Response
for i := 0; i < 5; i++ {
if err := c.DoTimeout(&req, &resp, time.Second); err != nil {
t.Fatalf("unexpected error: %s", err)
}
hv := resp.Header.Peek("received-uri")
if string(hv) != urlWithEncodedPath {
t.Fatalf("request uri was normalized: %q. Expecting %q", hv, urlWithEncodedPath)
}
}
if err := ln.Close(); err != nil {
t.Fatalf("unexpected error: %s", err)
}
select {
case <-serverStopCh:
case <-time.After(time.Second):
t.Fatalf("timeout")
}
}
func TestHostClientPendingRequests(t *testing.T) {
t.Parallel()
const concurrency = 10
doneCh := make(chan struct{})
readyCh := make(chan struct{}, concurrency)
s := &Server{
Handler: func(ctx *RequestCtx) {
readyCh <- struct{}{}
<-doneCh
},
}
ln := fasthttputil.NewInmemoryListener()
serverStopCh := make(chan struct{})
go func() {
if err := s.Serve(ln); err != nil {
t.Errorf("unexpected error: %s", err)
}
close(serverStopCh)
}()
c := &HostClient{
Addr: "foobar",
Dial: func(addr string) (net.Conn, error) {
return ln.Dial()
},
}
pendingRequests := c.PendingRequests()
if pendingRequests != 0 {
t.Fatalf("non-zero pendingRequests: %d", pendingRequests)
}
resultCh := make(chan error, concurrency)
for i := 0; i < concurrency; i++ {
go func() {
req := AcquireRequest()
req.SetRequestURI("http://foobar/baz")
resp := AcquireResponse()
if err := c.DoTimeout(req, resp, 10*time.Second); err != nil {
resultCh <- fmt.Errorf("unexpected error: %s", err)
return
}
if resp.StatusCode() != StatusOK {
resultCh <- fmt.Errorf("unexpected status code %d. Expecting %d", resp.StatusCode(), StatusOK)
return
}
resultCh <- nil
}()
}
// wait until all the requests reach the server
for i := 0; i < concurrency; i++ {
select {
case <-readyCh:
case <-time.After(time.Second):
t.Fatalf("timeout")
}
}
pendingRequests = c.PendingRequests()
if pendingRequests != concurrency {
t.Fatalf("unexpected pendingRequests: %d. Expecting %d", pendingRequests, concurrency)
}
// unblock request handlers on the server and wait until all the requests are finished.
close(doneCh)
for i := 0; i < concurrency; i++ {
select {
case err := <-resultCh:
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
case <-time.After(time.Second):
t.Fatalf("timeout")
}
}
pendingRequests = c.PendingRequests()
if pendingRequests != 0 {
t.Fatalf("non-zero pendingRequests: %d", pendingRequests)
}
// stop the server
if err := ln.Close(); err != nil {
t.Fatalf("unexpected error: %s", err)
}
select {
case <-serverStopCh:
case <-time.After(time.Second):
t.Fatalf("timeout")
}
}
func TestHostClientMaxConnsWithDeadline(t *testing.T) {
t.Parallel()
var (
emptyBodyCount uint8
ln = fasthttputil.NewInmemoryListener()
timeout = 200 * time.Millisecond
wg sync.WaitGroup
)
s := &Server{
Handler: func(ctx *RequestCtx) {
if len(ctx.PostBody()) == 0 {
emptyBodyCount++
}
ctx.WriteString("foo") //nolint:errcheck
},
}
serverStopCh := make(chan struct{})
go func() {
if err := s.Serve(ln); err != nil {
t.Errorf("unexpected error: %s", err)
}
close(serverStopCh)
}()
c := &HostClient{
Addr: "foobar",
Dial: func(addr string) (net.Conn, error) {
return ln.Dial()
},
MaxConns: 1,
}
for i := 0; i < 5; i++ {
wg.Add(1)
go func() {
defer wg.Done()
req := AcquireRequest()
req.SetRequestURI("http://foobar/baz")
req.Header.SetMethod(MethodPost)
req.SetBodyString("bar")
resp := AcquireResponse()
for {
if err := c.DoDeadline(req, resp, time.Now().Add(timeout)); err != nil {
if err == ErrNoFreeConns {
time.Sleep(time.Millisecond)
continue
}
t.Errorf("unexpected error: %s", err)
}
break
}
if resp.StatusCode() != StatusOK {
t.Errorf("unexpected status code %d. Expecting %d", resp.StatusCode(), StatusOK)
}
body := resp.Body()
if string(body) != "foo" {
t.Errorf("unexpected body %q. Expecting %q", body, "abcd")
}
}()
}
wg.Wait()
if err := ln.Close(); err != nil {
t.Fatalf("unexpected error: %s", err)
}
select {
case <-serverStopCh:
case <-time.After(time.Second):
t.Fatalf("timeout")
}
if emptyBodyCount > 0 {
t.Fatalf("at least one request body was empty")
}
}
func TestHostClientMaxConnDuration(t *testing.T) {
t.Parallel()
ln := fasthttputil.NewInmemoryListener()
connectionCloseCount := uint32(0)
s := &Server{
Handler: func(ctx *RequestCtx) {
ctx.WriteString("abcd") //nolint:errcheck
if ctx.Request.ConnectionClose() {
atomic.AddUint32(&connectionCloseCount, 1)
}
},
}
serverStopCh := make(chan struct{})
go func() {
if err := s.Serve(ln); err != nil {
t.Errorf("unexpected error: %s", err)
}
close(serverStopCh)
}()
c := &HostClient{
Addr: "foobar",
Dial: func(addr string) (net.Conn, error) {
return ln.Dial()
},
MaxConnDuration: 10 * time.Millisecond,
}
for i := 0; i < 5; i++ {
statusCode, body, err := c.Get(nil, "http://aaaa.com/bbb/cc")
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
if statusCode != StatusOK {
t.Fatalf("unexpected status code %d. Expecting %d", statusCode, StatusOK)
}
if string(body) != "abcd" {
t.Fatalf("unexpected body %q. Expecting %q", body, "abcd")
}
time.Sleep(c.MaxConnDuration)
}
if err := ln.Close(); err != nil {
t.Fatalf("unexpected error: %s", err)
}
select {
case <-serverStopCh:
case <-time.After(time.Second):
t.Fatalf("timeout")
}
if connectionCloseCount == 0 {
t.Fatalf("expecting at least one 'Connection: close' request header")
}
}
func TestHostClientMultipleAddrs(t *testing.T) {
t.Parallel()
ln := fasthttputil.NewInmemoryListener()
s := &Server{
Handler: func(ctx *RequestCtx) {
ctx.Write(ctx.Host()) //nolint:errcheck
ctx.SetConnectionClose()
},
}
serverStopCh := make(chan struct{})
go func() {
if err := s.Serve(ln); err != nil {
t.Errorf("unexpected error: %s", err)
}
close(serverStopCh)
}()
dialsCount := make(map[string]int)
c := &HostClient{
Addr: "foo,bar,baz",
Dial: func(addr string) (net.Conn, error) {
dialsCount[addr]++
return ln.Dial()
},
}
for i := 0; i < 9; i++ {
statusCode, body, err := c.Get(nil, "http://foobar/baz/aaa?bbb=ddd")
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
if statusCode != StatusOK {
t.Fatalf("unexpected status code %d. Expecting %d", statusCode, StatusOK)
}
if string(body) != "foobar" {
t.Fatalf("unexpected body %q. Expecting %q", body, "foobar")
}
}
if err := ln.Close(); err != nil {
t.Fatalf("unexpected error: %s", err)
}
select {
case <-serverStopCh:
case <-time.After(time.Second):
t.Fatalf("timeout")
}
if len(dialsCount) != 3 {
t.Fatalf("unexpected dialsCount size %d. Expecting 3", len(dialsCount))
}
for _, k := range []string{"foo", "bar", "baz"} {
if dialsCount[k] != 3 {
t.Fatalf("unexpected dialsCount for %q. Expecting 3", k)
}
}
}
func TestClientFollowRedirects(t *testing.T) {
t.Parallel()
s := &Server{
Handler: func(ctx *RequestCtx) {
switch string(ctx.Path()) {
case "/foo":
u := ctx.URI()
u.Update("/xy?z=wer")
ctx.Redirect(u.String(), StatusFound)
case "/xy":
u := ctx.URI()
u.Update("/bar")
ctx.Redirect(u.String(), StatusFound)
default:
ctx.Success("text/plain", ctx.Path())
}
},
}
ln := fasthttputil.NewInmemoryListener()
serverStopCh := make(chan struct{})
go func() {
if err := s.Serve(ln); err != nil {
t.Errorf("unexpected error: %s", err)
}
close(serverStopCh)
}()
c := &HostClient{
Addr: "xxx",
Dial: func(addr string) (net.Conn, error) {
return ln.Dial()
},
}
for i := 0; i < 10; i++ {
statusCode, body, err := c.GetTimeout(nil, "http://xxx/foo", time.Second)
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
if statusCode != StatusOK {
t.Fatalf("unexpected status code: %d", statusCode)
}
if string(body) != "/bar" {
t.Fatalf("unexpected response %q. Expecting %q", body, "/bar")
}
}
for i := 0; i < 10; i++ {
statusCode, body, err := c.Get(nil, "http://xxx/aaab/sss")
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
if statusCode != StatusOK {
t.Fatalf("unexpected status code: %d", statusCode)
}
if string(body) != "/aaab/sss" {
t.Fatalf("unexpected response %q. Expecting %q", body, "/aaab/sss")
}
}
for i := 0; i < 10; i++ {
req := AcquireRequest()
resp := AcquireResponse()
req.SetRequestURI("http://xxx/foo")
err := c.DoRedirects(req, resp, 16)
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
if statusCode := resp.StatusCode(); statusCode != StatusOK {
t.Fatalf("unexpected status code: %d", statusCode)
}
if body := string(resp.Body()); body != "/bar" {
t.Fatalf("unexpected response %q. Expecting %q", body, "/bar")
}
ReleaseRequest(req)
ReleaseResponse(resp)
}
req := AcquireRequest()
resp := AcquireResponse()
req.SetRequestURI("http://xxx/foo")
err := c.DoRedirects(req, resp, 0)
if have, want := err, ErrTooManyRedirects; have != want {
t.Fatalf("want error: %v, have %v", want, have)
}
ReleaseRequest(req)
ReleaseResponse(resp)
}
func TestClientGetTimeoutSuccess(t *testing.T) {
t.Parallel()
s := startEchoServer(t, "tcp", "127.0.0.1:")
defer s.Stop()
testClientGetTimeoutSuccess(t, &defaultClient, "http://"+s.Addr(), 100)
}
func TestClientGetTimeoutSuccessConcurrent(t *testing.T) {
t.Parallel()
s := startEchoServer(t, "tcp", "127.0.0.1:")
defer s.Stop()
var wg sync.WaitGroup
for i := 0; i < 10; i++ {
wg.Add(1)
go func() {
defer wg.Done()
testClientGetTimeoutSuccess(t, &defaultClient, "http://"+s.Addr(), 100)
}()
}
wg.Wait()
}
func TestClientDoTimeoutSuccess(t *testing.T) {
t.Parallel()
s := startEchoServer(t, "tcp", "127.0.0.1:")
defer s.Stop()
testClientDoTimeoutSuccess(t, &defaultClient, "http://"+s.Addr(), 100)
}
func TestClientDoTimeoutSuccessConcurrent(t *testing.T) {
t.Parallel()
s := startEchoServer(t, "tcp", "127.0.0.1:")
defer s.Stop()
var wg sync.WaitGroup
for i := 0; i < 10; i++ {
wg.Add(1)
go func() {
defer wg.Done()
testClientDoTimeoutSuccess(t, &defaultClient, "http://"+s.Addr(), 100)
}()
}
wg.Wait()
}
func TestClientGetTimeoutError(t *testing.T) {
t.Parallel()
c := &Client{
Dial: func(addr string) (net.Conn, error) {
return &readTimeoutConn{t: time.Second}, nil
},
}
testClientGetTimeoutError(t, c, 100)
}
func TestClientGetTimeoutErrorConcurrent(t *testing.T) {
t.Parallel()
c := &Client{
Dial: func(addr string) (net.Conn, error) {
return &readTimeoutConn{t: time.Second}, nil
},
MaxConnsPerHost: 1000,
}
var wg sync.WaitGroup
for i := 0; i < 10; i++ {
wg.Add(1)
go func() {
defer wg.Done()
testClientGetTimeoutError(t, c, 100)
}()
}
wg.Wait()
}
func TestClientDoTimeoutError(t *testing.T) {
t.Parallel()
c := &Client{
Dial: func(addr string) (net.Conn, error) {
return &readTimeoutConn{t: time.Second}, nil
},
}
testClientDoTimeoutError(t, c, 100)
}
func TestClientDoTimeoutErrorConcurrent(t *testing.T) {
t.Parallel()
c := &Client{
Dial: func(addr string) (net.Conn, error) {
return &readTimeoutConn{t: time.Second}, nil
},
MaxConnsPerHost: 1000,
}
var wg sync.WaitGroup
for i := 0; i < 10; i++ {
wg.Add(1)
go func() {
defer wg.Done()
testClientDoTimeoutError(t, c, 100)
}()
}
wg.Wait()
}
func testClientDoTimeoutError(t *testing.T, c *Client, n int) {
var req Request
var resp Response
req.SetRequestURI("http://foobar.com/baz")
for i := 0; i < n; i++ {
err := c.DoTimeout(&req, &resp, time.Millisecond)
if err == nil {
t.Fatalf("expecting error")
}
if err != ErrTimeout {
t.Fatalf("unexpected error: %s. Expecting %s", err, ErrTimeout)
}
}
}
func testClientGetTimeoutError(t *testing.T, c *Client, n int) {
buf := make([]byte, 10)
for i := 0; i < n; i++ {
statusCode, body, err := c.GetTimeout(buf, "http://foobar.com/baz", time.Millisecond)
if err == nil {
t.Fatalf("expecting error")
}
if err != ErrTimeout {
t.Fatalf("unexpected error: %s. Expecting %s", err, ErrTimeout)
}
if statusCode != 0 {
t.Fatalf("unexpected statusCode=%d. Expecting %d", statusCode, 0)
}
if body == nil {
t.Fatalf("body must be non-nil")
}
}
}
type readTimeoutConn struct {
net.Conn
t time.Duration
}
func (r *readTimeoutConn) Read(p []byte) (int, error) {
time.Sleep(r.t)
return 0, io.EOF
}
func (r *readTimeoutConn) Write(p []byte) (int, error) {
return len(p), nil
}
func (r *readTimeoutConn) Close() error {
return nil
}
func (r *readTimeoutConn) LocalAddr() net.Addr {
return nil
}
func (r *readTimeoutConn) RemoteAddr() net.Addr {
return nil
}
func TestClientNonIdempotentRetry(t *testing.T) {
t.Parallel()
dialsCount := 0
c := &Client{
Dial: func(addr string) (net.Conn, error) {
dialsCount++
switch dialsCount {
case 1, 2:
return &readErrorConn{}, nil
case 3:
return &singleReadConn{
s: "HTTP/1.1 345 OK\r\nContent-Type: foobar\r\nContent-Length: 7\r\n\r\n0123456",
}, nil
default:
t.Fatalf("unexpected number of dials: %d", dialsCount)
}
panic("unreachable")
},
}
// This POST must succeed, since the readErrorConn closes
// the connection before sending any response.
// So the client must retry the non-idempotent request.
dialsCount = 0
statusCode, body, err := c.Post(nil, "http://foobar/a/b", nil)
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
if statusCode != 345 {
t.Fatalf("unexpected status code: %d. Expecting 345", statusCode)
}
if string(body) != "0123456" {
t.Fatalf("unexpected body: %q. Expecting %q", body, "0123456")
}
// Verify that idempotent GET succeeds.
dialsCount = 0
statusCode, body, err = c.Get(nil, "http://foobar/a/b")
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
if statusCode != 345 {
t.Fatalf("unexpected status code: %d. Expecting 345", statusCode)
}
if string(body) != "0123456" {
t.Fatalf("unexpected body: %q. Expecting %q", body, "0123456")
}
}
func TestClientNonIdempotentRetry_BodyStream(t *testing.T) {
t.Parallel()
dialsCount := 0
c := &Client{
Dial: func(addr string) (net.Conn, error) {
dialsCount++
switch dialsCount {
case 1, 2:
return &readErrorConn{}, nil
case 3:
return &singleEchoConn{
b: []byte("HTTP/1.1 345 OK\r\nContent-Type: foobar\r\n\r\n"),
}, nil
default:
t.Fatalf("unexpected number of dials: %d", dialsCount)
}
panic("unreachable")
},
}
dialsCount = 0
req := Request{}
res := Response{}
req.SetRequestURI("http://foobar/a/b")
req.Header.SetMethod("POST")
body := bytes.NewBufferString("test")
req.SetBodyStream(body, body.Len())
err := c.Do(&req, &res)
if err == nil {
t.Fatal("expected error from being unable to retry a bodyStream")
}
}
func TestClientIdempotentRequest(t *testing.T) {
t.Parallel()
dialsCount := 0
c := &Client{
Dial: func(addr string) (net.Conn, error) {
dialsCount++
switch dialsCount {
case 1:
return &singleReadConn{
s: "invalid response",
}, nil
case 2:
return &writeErrorConn{}, nil
case 3:
return &readErrorConn{}, nil
case 4:
return &singleReadConn{
s: "HTTP/1.1 345 OK\r\nContent-Type: foobar\r\nContent-Length: 7\r\n\r\n0123456",
}, nil
default:
t.Fatalf("unexpected number of dials: %d", dialsCount)
}
panic("unreachable")
},
}
// idempotent GET must succeed.
statusCode, body, err := c.Get(nil, "http://foobar/a/b")
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
if statusCode != 345 {
t.Fatalf("unexpected status code: %d. Expecting 345", statusCode)
}
if string(body) != "0123456" {
t.Fatalf("unexpected body: %q. Expecting %q", body, "0123456")
}
var args Args
// non-idempotent POST must fail on incorrect singleReadConn
dialsCount = 0
_, _, err = c.Post(nil, "http://foobar/a/b", &args)
if err == nil {
t.Fatalf("expecting error")
}
// non-idempotent POST must fail on incorrect singleReadConn
dialsCount = 0
_, _, err = c.Post(nil, "http://foobar/a/b", nil)
if err == nil {
t.Fatalf("expecting error")
}
}
func TestClientRetryRequestWithCustomDecider(t *testing.T) {
t.Parallel()
dialsCount := 0
c := &Client{
Dial: func(addr string) (net.Conn, error) {
dialsCount++
switch dialsCount {
case 1:
return &singleReadConn{
s: "invalid response",
}, nil
case 2:
return &writeErrorConn{}, nil
case 3:
return &readErrorConn{}, nil
case 4:
return &singleReadConn{
s: "HTTP/1.1 345 OK\r\nContent-Type: foobar\r\nContent-Length: 7\r\n\r\n0123456",
}, nil
default:
t.Fatalf("unexpected number of dials: %d", dialsCount)
}
panic("unreachable")
},
RetryIf: func(req *Request) bool {
return req.URI().String() == "http://foobar/a/b"
},
}
var args Args
// Post must succeed for http://foobar/a/b uri.
statusCode, body, err := c.Post(nil, "http://foobar/a/b", &args)
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
if statusCode != 345 {
t.Fatalf("unexpected status code: %d. Expecting 345", statusCode)
}
if string(body) != "0123456" {
t.Fatalf("unexpected body: %q. Expecting %q", body, "0123456")
}
// POST must fail for http://foobar/a/b/c uri.
dialsCount = 0
_, _, err = c.Post(nil, "http://foobar/a/b/c", &args)
if err == nil {
t.Fatalf("expecting error")
}
}
func TestHostClientTransport(t *testing.T) {
t.Parallel()
ln := fasthttputil.NewInmemoryListener()
s := &Server{
Handler: func(ctx *RequestCtx) {
ctx.WriteString("abcd") //nolint:errcheck
},
}
serverStopCh := make(chan struct{})
go func() {
if err := s.Serve(ln); err != nil {
t.Errorf("unexpected error: %s", err)
}
close(serverStopCh)
}()
c := &HostClient{
Addr: "foobar",
Transport: func() TransportFunc {
c, _ := ln.Dial()
br := bufio.NewReader(c)
bw := bufio.NewWriter(c)
return func(req *Request, res *Response) error {
if err := req.Write(bw); err != nil {
return err
}
if err := bw.Flush(); err != nil {
return err
}
return res.Read(br)
}
}(),
}
for i := 0; i < 5; i++ {
statusCode, body, err := c.Get(nil, "http://aaaa.com/bbb/cc")
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
if statusCode != StatusOK {
t.Fatalf("unexpected status code %d. Expecting %d", statusCode, StatusOK)
}
if string(body) != "abcd" {
t.Fatalf("unexpected body %q. Expecting %q", body, "abcd")
}
}
if err := ln.Close(); err != nil {
t.Fatalf("unexpected error: %s", err)
}
select {
case <-serverStopCh:
case <-time.After(time.Second):
t.Fatalf("timeout")
}
}
type writeErrorConn struct {
net.Conn
}
func (w *writeErrorConn) Write(p []byte) (int, error) {
return 1, fmt.Errorf("error")
}
func (w *writeErrorConn) Close() error {
return nil
}
func (w *writeErrorConn) LocalAddr() net.Addr {
return nil
}
func (w *writeErrorConn) RemoteAddr() net.Addr {
return nil
}
type readErrorConn struct {
net.Conn
}
func (r *readErrorConn) Read(p []byte) (int, error) {
return 0, fmt.Errorf("error")
}
func (r *readErrorConn) Write(p []byte) (int, error) {
return len(p), nil
}
func (r *readErrorConn) Close() error {
return nil
}
func (r *readErrorConn) LocalAddr() net.Addr {
return nil
}
func (r *readErrorConn) RemoteAddr() net.Addr {
return nil
}
type singleReadConn struct {
net.Conn
s string
n int
}
func (r *singleReadConn) Read(p []byte) (int, error) {
if len(r.s) == r.n {
return 0, io.EOF
}
n := copy(p, []byte(r.s[r.n:]))
r.n += n
return n, nil
}
func (r *singleReadConn) Write(p []byte) (int, error) {
return len(p), nil
}
func (r *singleReadConn) Close() error {
return nil
}
func (r *singleReadConn) LocalAddr() net.Addr {
return nil
}
func (r *singleReadConn) RemoteAddr() net.Addr {
return nil
}
type singleEchoConn struct {
net.Conn
b []byte
n int
}
func (r *singleEchoConn) Read(p []byte) (int, error) {
if len(r.b) == r.n {
return 0, io.EOF
}
n := copy(p, r.b[r.n:])
r.n += n
return n, nil
}
func (r *singleEchoConn) Write(p []byte) (int, error) {
r.b = append(r.b, p...)
return len(p), nil
}
func (r *singleEchoConn) Close() error {
return nil
}
func (r *singleEchoConn) LocalAddr() net.Addr {
return nil
}
func (r *singleEchoConn) RemoteAddr() net.Addr {
return nil
}
func TestSingleEchoConn(t *testing.T) {
t.Parallel()
c := &Client{
Dial: func(addr string) (net.Conn, error) {
return &singleEchoConn{
b: []byte("HTTP/1.1 345 OK\r\nContent-Type: foobar\r\n\r\n"),
}, nil
},
}
req := Request{}
res := Response{}
req.SetRequestURI("http://foobar/a/b")
req.Header.SetMethod("POST")
req.Header.Set("Content-Type", "text/plain")
body := bytes.NewBufferString("test")
req.SetBodyStream(body, body.Len())
err := c.Do(&req, &res)
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
if res.StatusCode() != 345 {
t.Fatalf("unexpected status code: %d. Expecting 345", res.StatusCode())
}
expected := "POST /a/b HTTP/1.1\r\nUser-Agent: fasthttp\r\nHost: foobar\r\nContent-Type: text/plain\r\nContent-Length: 4\r\n\r\ntest"
if string(res.Body()) != expected {
t.Fatalf("unexpected body: %q. Expecting %q", res.Body(), expected)
}
}
func TestClientHTTPSInvalidServerName(t *testing.T) {
t.Parallel()
sHTTPS := startEchoServerTLS(t, "tcp", "127.0.0.1:")
defer sHTTPS.Stop()
var c Client
for i := 0; i < 10; i++ {
_, _, err := c.GetTimeout(nil, "https://"+sHTTPS.Addr(), time.Second)
if err == nil {
t.Fatalf("expecting TLS error")
}
}
}
func TestClientHTTPSConcurrent(t *testing.T) {
t.Parallel()
sHTTP := startEchoServer(t, "tcp", "127.0.0.1:")
defer sHTTP.Stop()
sHTTPS := startEchoServerTLS(t, "tcp", "127.0.0.1:")
defer sHTTPS.Stop()
c := &Client{
TLSConfig: &tls.Config{
InsecureSkipVerify: true,
},
}
var wg sync.WaitGroup
for i := 0; i < 4; i++ {
wg.Add(1)
addr := "http://" + sHTTP.Addr()
if i&1 != 0 {
addr = "https://" + sHTTPS.Addr()
}
go func() {
defer wg.Done()
testClientGet(t, c, addr, 20)
testClientPost(t, c, addr, 10)
}()
}
wg.Wait()
}
func TestClientManyServers(t *testing.T) {
t.Parallel()
var addrs []string
for i := 0; i < 10; i++ {
s := startEchoServer(t, "tcp", "127.0.0.1:")
defer s.Stop()
addrs = append(addrs, s.Addr())
}
var wg sync.WaitGroup
for i := 0; i < 4; i++ {
wg.Add(1)
addr := "http://" + addrs[i]
go func() {
defer wg.Done()
testClientGet(t, &defaultClient, addr, 20)
testClientPost(t, &defaultClient, addr, 10)
}()
}
wg.Wait()
}
func TestClientGet(t *testing.T) {
t.Parallel()
s := startEchoServer(t, "tcp", "127.0.0.1:")
defer s.Stop()
testClientGet(t, &defaultClient, "http://"+s.Addr(), 100)
}
func TestClientPost(t *testing.T) {
t.Parallel()
s := startEchoServer(t, "tcp", "127.0.0.1:")
defer s.Stop()
testClientPost(t, &defaultClient, "http://"+s.Addr(), 100)
}
func TestClientConcurrent(t *testing.T) {
t.Parallel()
s := startEchoServer(t, "tcp", "127.0.0.1:")
defer s.Stop()
addr := "http://" + s.Addr()
var wg sync.WaitGroup
for i := 0; i < 10; i++ {
wg.Add(1)
go func() {
defer wg.Done()
testClientGet(t, &defaultClient, addr, 30)
testClientPost(t, &defaultClient, addr, 10)
}()
}
wg.Wait()
}
func skipIfNotUnix(tb testing.TB) {
switch runtime.GOOS {
case "android", "nacl", "plan9", "windows":
tb.Skipf("%s does not support unix sockets", runtime.GOOS)
}
if runtime.GOOS == "darwin" && (runtime.GOARCH == "arm" || runtime.GOARCH == "arm64") {
tb.Skip("iOS does not support unix, unixgram")
}
}
func TestHostClientGet(t *testing.T) {
t.Parallel()
skipIfNotUnix(t)
addr := "TestHostClientGet.unix"
s := startEchoServer(t, "unix", addr)
defer s.Stop()
c := createEchoClient(t, "unix", addr)
testHostClientGet(t, c, 100)
}
func TestHostClientPost(t *testing.T) {
t.Parallel()
skipIfNotUnix(t)
addr := "./TestHostClientPost.unix"
s := startEchoServer(t, "unix", addr)
defer s.Stop()
c := createEchoClient(t, "unix", addr)
testHostClientPost(t, c, 100)
}
func TestHostClientConcurrent(t *testing.T) {
t.Parallel()
skipIfNotUnix(t)
addr := "./TestHostClientConcurrent.unix"
s := startEchoServer(t, "unix", addr)
defer s.Stop()
c := createEchoClient(t, "unix", addr)
var wg sync.WaitGroup
for i := 0; i < 10; i++ {
wg.Add(1)
go func() {
defer wg.Done()
testHostClientGet(t, c, 30)
testHostClientPost(t, c, 10)
}()
}
wg.Wait()
}
func testClientGet(t *testing.T, c clientGetter, addr string, n int) {
var buf []byte
for i := 0; i < n; i++ {
uri := fmt.Sprintf("%s/foo/%d?bar=baz", addr, i)
statusCode, body, err := c.Get(buf, uri)
buf = body
if err != nil {
t.Fatalf("unexpected error when doing http request: %s", err)
}
if statusCode != StatusOK {
t.Fatalf("unexpected status code: %d. Expecting %d", statusCode, StatusOK)
}
resultURI := string(body)
if resultURI != uri {
t.Fatalf("unexpected uri %q. Expecting %q", resultURI, uri)
}
}
}
func testClientDoTimeoutSuccess(t *testing.T, c *Client, addr string, n int) {
var req Request
var resp Response
for i := 0; i < n; i++ {
uri := fmt.Sprintf("%s/foo/%d?bar=baz", addr, i)
req.SetRequestURI(uri)
if err := c.DoTimeout(&req, &resp, time.Second); err != nil {
t.Fatalf("unexpected error: %s", err)
}
if resp.StatusCode() != StatusOK {
t.Fatalf("unexpected status code: %d. Expecting %d", resp.StatusCode(), StatusOK)
}
resultURI := string(resp.Body())
if strings.HasPrefix(uri, "https") {
resultURI = uri[:5] + resultURI[4:]
}
if resultURI != uri {
t.Fatalf("unexpected uri %q. Expecting %q", resultURI, uri)
}
}
}
func testClientGetTimeoutSuccess(t *testing.T, c *Client, addr string, n int) {
var buf []byte
for i := 0; i < n; i++ {
uri := fmt.Sprintf("%s/foo/%d?bar=baz", addr, i)
statusCode, body, err := c.GetTimeout(buf, uri, time.Second)
buf = body
if err != nil {
t.Fatalf("unexpected error when doing http request: %s", err)
}
if statusCode != StatusOK {
t.Fatalf("unexpected status code: %d. Expecting %d", statusCode, StatusOK)
}
resultURI := string(body)
if strings.HasPrefix(uri, "https") {
resultURI = uri[:5] + resultURI[4:]
}
if resultURI != uri {
t.Fatalf("unexpected uri %q. Expecting %q", resultURI, uri)
}
}
}
func testClientPost(t *testing.T, c clientPoster, addr string, n int) {
var buf []byte
var args Args
for i := 0; i < n; i++ {
uri := fmt.Sprintf("%s/foo/%d?bar=baz", addr, i)
args.Set("xx", fmt.Sprintf("yy%d", i))
args.Set("zzz", fmt.Sprintf("qwe_%d", i))
argsS := args.String()
statusCode, body, err := c.Post(buf, uri, &args)
buf = body
if err != nil {
t.Fatalf("unexpected error when doing http request: %s", err)
}
if statusCode != StatusOK {
t.Fatalf("unexpected status code: %d. Expecting %d", statusCode, StatusOK)
}
s := string(body)
if s != argsS {
t.Fatalf("unexpected response %q. Expecting %q", s, argsS)
}
}
}
func testHostClientGet(t *testing.T, c *HostClient, n int) {
testClientGet(t, c, "http://google.com", n)
}
func testHostClientPost(t *testing.T, c *HostClient, n int) {
testClientPost(t, c, "http://post-host.com", n)
}
type clientPoster interface {
Post(dst []byte, uri string, postArgs *Args) (int, []byte, error)
}
type clientGetter interface {
Get(dst []byte, uri string) (int, []byte, error)
}
func createEchoClient(t *testing.T, network, addr string) *HostClient {
return &HostClient{
Addr: addr,
Dial: func(addr string) (net.Conn, error) {
return net.Dial(network, addr)
},
}
}
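// testEchoServer wraps a fasthttp Server and its listener so tests can
// obtain the listen address and stop the server with a bounded wait.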
type testEchoServer struct {
s *Server
ln net.Listener
ch chan struct{}
t *testing.T
}
func (s *testEchoServer) Stop() {
s.ln.Close()
select {
case <-s.ch:
case <-time.After(time.Second):
s.t.Fatalf("timeout when waiting for server close")
}
}
func (s *testEchoServer) Addr() string {
return s.ln.Addr().String()
}
func startEchoServerTLS(t *testing.T, network, addr string) *testEchoServer {
return startEchoServerExt(t, network, addr, true)
}
func startEchoServer(t *testing.T, network, addr string) *testEchoServer {
return startEchoServerExt(t, network, addr, false)
}
func startEchoServerExt(t *testing.T, network, addr string, isTLS bool) *testEchoServer {
if network == "unix" {
os.Remove(addr)
}
var ln net.Listener
var err error
if isTLS {
certData, keyData, kerr := GenerateTestCertificate("localhost")
if kerr != nil {
t.Fatal(kerr)
}
cert, kerr := tls.X509KeyPair(certData, keyData)
if kerr != nil {
t.Fatal(kerr)
}
tlsConfig := &tls.Config{
Certificates: []tls.Certificate{cert},
}
ln, err = tls.Listen(network, addr, tlsConfig)
} else {
ln, err = net.Listen(network, addr)
}
if err != nil {
t.Fatalf("cannot listen %q: %s", addr, err)
}
s := &Server{
Handler: func(ctx *RequestCtx) {
if ctx.IsGet() {
ctx.Success("text/plain", ctx.URI().FullURI())
} else if ctx.IsPost() {
ctx.PostArgs().WriteTo(ctx) //nolint:errcheck
}
},
Logger: &testLogger{}, // Ignore log output.
}
ch := make(chan struct{})
go func() {
err := s.Serve(ln)
if err != nil {
t.Errorf("unexpected error returned from Serve(): %s", err)
}
close(ch)
}()
return &testEchoServer{
s: s,
ln: ln,
ch: ch,
t: t,
}
}
func TestClientTLSHandshakeTimeout(t *testing.T) {
t.Parallel()
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatal(err)
}
addr := listener.Addr().String()
defer listener.Close()
complete := make(chan bool)
defer close(complete)
go func() {
conn, err := listener.Accept()
if err != nil {
t.Error(err)
return
}
<-complete
conn.Close()
}()
client := Client{
WriteTimeout: 100 * time.Millisecond,
ReadTimeout: 100 * time.Millisecond,
}
_, _, err = client.Get(nil, "https://"+addr)
if err == nil {
t.Fatal("tlsClientHandshake completed successfully")
}
if err != ErrTLSHandshakeTimeout {
t.Errorf("resulting error not a timeout: %v\nType %T: %#v", err, err, err)
}
}
func TestHostClientMaxConnWaitTimeoutSuccess(t *testing.T) {
t.Parallel()
var (
emptyBodyCount uint8
ln = fasthttputil.NewInmemoryListener()
wg sync.WaitGroup
)
s := &Server{
Handler: func(ctx *RequestCtx) {
if len(ctx.PostBody()) == 0 {
emptyBodyCount++
}
time.Sleep(5 * time.Millisecond)
ctx.WriteString("foo") //nolint:errcheck
},
}
serverStopCh := make(chan struct{})
go func() {
if err := s.Serve(ln); err != nil {
t.Errorf("unexpected error: %s", err)
}
close(serverStopCh)
}()
c := &HostClient{
Addr: "foobar",
Dial: func(addr string) (net.Conn, error) {
return ln.Dial()
},
MaxConns: 1,
MaxConnWaitTimeout: time.Second * 2,
}
for i := 0; i < 5; i++ {
wg.Add(1)
go func() {
defer wg.Done()
req := AcquireRequest()
req.SetRequestURI("http://foobar/baz")
req.Header.SetMethod(MethodPost)
req.SetBodyString("bar")
resp := AcquireResponse()
if err := c.Do(req, resp); err != nil {
t.Errorf("unexpected error: %s", err)
}
if resp.StatusCode() != StatusOK {
t.Errorf("unexpected status code %d. Expecting %d", resp.StatusCode(), StatusOK)
}
body := resp.Body()
if string(body) != "foo" {
t.Errorf("unexpected body %q. Expecting %q", body, "abcd")
}
}()
}
wg.Wait()
if c.connsWait.len() > 0 {
t.Errorf("connsWait has %v items remaining", c.connsWait.len())
}
if err := ln.Close(); err != nil {
t.Fatalf("unexpected error: %s", err)
}
select {
case <-serverStopCh:
case <-time.After(time.Second * 5):
t.Fatalf("timeout")
}
if emptyBodyCount > 0 {
t.Fatalf("at least one request body was empty")
}
}
func TestHostClientMaxConnWaitTimeoutError(t *testing.T) {
t.Parallel()
var (
emptyBodyCount uint8
ln = fasthttputil.NewInmemoryListener()
wg sync.WaitGroup
)
s := &Server{
Handler: func(ctx *RequestCtx) {
if len(ctx.PostBody()) == 0 {
emptyBodyCount++
}
time.Sleep(5 * time.Millisecond)
ctx.WriteString("foo") //nolint:errcheck
},
}
serverStopCh := make(chan struct{})
go func() {
if err := s.Serve(ln); err != nil {
t.Errorf("unexpected error: %s", err)
}
close(serverStopCh)
}()
c := &HostClient{
Addr: "foobar",
Dial: func(addr string) (net.Conn, error) {
return ln.Dial()
},
MaxConns: 1,
MaxConnWaitTimeout: 10 * time.Millisecond,
}
var errNoFreeConnsCount uint32
for i := 0; i < 5; i++ {
wg.Add(1)
go func() {
defer wg.Done()
req := AcquireRequest()
req.SetRequestURI("http://foobar/baz")
req.Header.SetMethod(MethodPost)
req.SetBodyString("bar")
resp := AcquireResponse()
if err := c.Do(req, resp); err != nil {
if err != ErrNoFreeConns {
t.Errorf("unexpected error: %s. Expecting %s", err, ErrNoFreeConns)
}
atomic.AddUint32(&errNoFreeConnsCount, 1)
} else {
if resp.StatusCode() != StatusOK {
t.Errorf("unexpected status code %d. Expecting %d", resp.StatusCode(), StatusOK)
}
body := resp.Body()
if string(body) != "foo" {
t.Errorf("unexpected body %q. Expecting %q", body, "abcd")
}
}
}()
}
wg.Wait()
// Prevent a race condition with the conns cleaner that might still be running.
c.connsLock.Lock()
defer c.connsLock.Unlock()
if c.connsWait.len() > 0 {
t.Errorf("connsWait has %v items remaining", c.connsWait.len())
}
if errNoFreeConnsCount == 0 {
t.Errorf("unexpected errorCount: %d. Expecting > 0", errNoFreeConnsCount)
}
if err := ln.Close(); err != nil {
t.Fatalf("unexpected error: %s", err)
}
select {
case <-serverStopCh:
case <-time.After(time.Second):
t.Fatalf("timeout")
}
if emptyBodyCount > 0 {
t.Fatalf("at least one request body was empty")
}
}
func TestHostClientMaxConnWaitTimeoutWithEarlierDeadline(t *testing.T) {
t.Parallel()
var (
emptyBodyCount uint8
ln = fasthttputil.NewInmemoryListener()
wg sync.WaitGroup
// make the request deadline expire earlier than the conns wait timeout
sleep = 100 * time.Millisecond
timeout = 10 * time.Millisecond
maxConnWaitTimeout = 50 * time.Millisecond
)
s := &Server{
Handler: func(ctx *RequestCtx) {
if len(ctx.PostBody()) == 0 {
emptyBodyCount++
}
time.Sleep(sleep)
ctx.WriteString("foo") //nolint:errcheck
},
}
serverStopCh := make(chan struct{})
go func() {
if err := s.Serve(ln); err != nil {
t.Errorf("unexpected error: %s", err)
}
close(serverStopCh)
}()
c := &HostClient{
Addr: "foobar",
Dial: func(addr string) (net.Conn, error) {
return ln.Dial()
},
MaxConns: 1,
MaxConnWaitTimeout: maxConnWaitTimeout,
}
var errTimeoutCount uint32
for i := 0; i < 5; i++ {
wg.Add(1)
go func() {
defer wg.Done()
req := AcquireRequest()
req.SetRequestURI("http://foobar/baz")
req.Header.SetMethod(MethodPost)
req.SetBodyString("bar")
resp := AcquireResponse()
if err := c.DoDeadline(req, resp, time.Now().Add(timeout)); err != nil {
if err != ErrTimeout {
t.Errorf("unexpected error: %s. Expecting %s", err, ErrTimeout)
}
atomic.AddUint32(&errTimeoutCount, 1)
} else {
if resp.StatusCode() != StatusOK {
t.Errorf("unexpected status code %d. Expecting %d", resp.StatusCode(), StatusOK)
}
body := resp.Body()
if string(body) != "foo" {
t.Errorf("unexpected body %q. Expecting %q", body, "abcd")
}
}
}()
}
wg.Wait()
c.connsLock.Lock()
for {
w := c.connsWait.popFront()
if w == nil {
break
}
w.mu.Lock()
if w.err != nil && w.err != ErrTimeout {
t.Errorf("unexpected error: %s. Expecting %s", w.err, ErrTimeout)
}
w.mu.Unlock()
}
c.connsLock.Unlock()
if errTimeoutCount == 0 {
t.Errorf("unexpected errTimeoutCount: %d. Expecting > 0", errTimeoutCount)
}
if err := ln.Close(); err != nil {
t.Fatalf("unexpected error: %s", err)
}
select {
case <-serverStopCh:
case <-time.After(time.Second):
t.Fatalf("timeout")
}
if emptyBodyCount > 0 {
t.Fatalf("at least one request body was empty")
}
}
fasthttp-1.31.0/client_timing_test.go 0000664 0000000 0000000 00000037465 14130360711 0017657 0 ustar 00root root 0000000 0000000 package fasthttp
import (
"bytes"
"fmt"
"io/ioutil"
"net"
"net/http"
"runtime"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/valyala/fasthttp/fasthttputil"
)
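// fakeClientConn is a fake net.Conn used by the client benchmarks: writes are
// discarded (they only signal via ch that a request was sent), while reads
// replay the canned response stored in s.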
type fakeClientConn struct {
net.Conn
s []byte
n int
ch chan struct{}
}
func (c *fakeClientConn) Write(b []byte) (int, error) {
c.ch <- struct{}{}
return len(b), nil
}
func (c *fakeClientConn) Read(b []byte) (int, error) {
if c.n == 0 {
// wait for request :)
<-c.ch
}
n := 0
for len(b) > 0 {
if c.n == len(c.s) {
c.n = 0
return n, nil
}
n = copy(b, c.s[c.n:])
c.n += n
b = b[n:]
}
return n, nil
}
func (c *fakeClientConn) Close() error {
releaseFakeServerConn(c)
return nil
}
func (c *fakeClientConn) LocalAddr() net.Addr {
return &net.TCPAddr{
IP: []byte{1, 2, 3, 4},
Port: 8765,
}
}
func (c *fakeClientConn) RemoteAddr() net.Addr {
return &net.TCPAddr{
IP: []byte{1, 2, 3, 4},
Port: 8765,
}
}
func releaseFakeServerConn(c *fakeClientConn) {
c.n = 0
fakeClientConnPool.Put(c)
}
func acquireFakeServerConn(s []byte) *fakeClientConn {
v := fakeClientConnPool.Get()
if v == nil {
c := &fakeClientConn{
s: s,
ch: make(chan struct{}, 1),
}
return c
}
return v.(*fakeClientConn)
}
var fakeClientConnPool sync.Pool
func BenchmarkClientGetTimeoutFastServer(b *testing.B) {
body := []byte("123456789099")
s := []byte(fmt.Sprintf("HTTP/1.1 200 OK\r\nContent-Type: text/plain\r\nContent-Length: %d\r\n\r\n%s", len(body), body))
c := &Client{
Dial: func(addr string) (net.Conn, error) {
return acquireFakeServerConn(s), nil
},
}
nn := uint32(0)
b.RunParallel(func(pb *testing.PB) {
url := fmt.Sprintf("http://foobar%d.com/aaa/bbb", atomic.AddUint32(&nn, 1))
var statusCode int
var bodyBuf []byte
var err error
for pb.Next() {
statusCode, bodyBuf, err = c.GetTimeout(bodyBuf[:0], url, time.Second)
if err != nil {
b.Fatalf("unexpected error: %s", err)
}
if statusCode != StatusOK {
b.Fatalf("unexpected status code: %d", statusCode)
}
if !bytes.Equal(bodyBuf, body) {
b.Fatalf("unexpected response body: %q. Expected %q", bodyBuf, body)
}
}
})
}
func BenchmarkClientDoFastServer(b *testing.B) {
body := []byte("012345678912")
s := []byte(fmt.Sprintf("HTTP/1.1 200 OK\r\nContent-Type: text/plain\r\nContent-Length: %d\r\n\r\n%s", len(body), body))
c := &Client{
Dial: func(addr string) (net.Conn, error) {
return acquireFakeServerConn(s), nil
},
MaxConnsPerHost: runtime.GOMAXPROCS(-1),
}
nn := uint32(0)
b.RunParallel(func(pb *testing.PB) {
var req Request
var resp Response
req.Header.SetRequestURI(fmt.Sprintf("http://foobar%d.com/aaa/bbb", atomic.AddUint32(&nn, 1)))
for pb.Next() {
if err := c.Do(&req, &resp); err != nil {
b.Fatalf("unexpected error: %s", err)
}
if resp.Header.StatusCode() != StatusOK {
b.Fatalf("unexpected status code: %d", resp.Header.StatusCode())
}
if !bytes.Equal(resp.Body(), body) {
b.Fatalf("unexpected response body: %q. Expected %q", resp.Body(), body)
}
}
})
}
func BenchmarkNetHTTPClientDoFastServer(b *testing.B) {
body := []byte("012345678912")
s := []byte(fmt.Sprintf("HTTP/1.1 200 OK\r\nContent-Type: text/plain\r\nContent-Length: %d\r\n\r\n%s", len(body), body))
c := &http.Client{
Transport: &http.Transport{
Dial: func(network, addr string) (net.Conn, error) {
return acquireFakeServerConn(s), nil
},
MaxIdleConnsPerHost: runtime.GOMAXPROCS(-1),
},
}
nn := uint32(0)
b.RunParallel(func(pb *testing.PB) {
req, err := http.NewRequest(MethodGet, fmt.Sprintf("http://foobar%d.com/aaa/bbb", atomic.AddUint32(&nn, 1)), nil)
if err != nil {
b.Fatalf("unexpected error: %s", err)
}
for pb.Next() {
resp, err := c.Do(req)
if err != nil {
b.Fatalf("unexpected error: %s", err)
}
if resp.StatusCode != http.StatusOK {
b.Fatalf("unexpected status code: %d", resp.StatusCode)
}
respBody, err := ioutil.ReadAll(resp.Body)
resp.Body.Close()
if err != nil {
b.Fatalf("unexpected error when reading response body: %s", err)
}
if !bytes.Equal(respBody, body) {
b.Fatalf("unexpected response body: %q. Expected %q", respBody, body)
}
}
})
}
func fasthttpEchoHandler(ctx *RequestCtx) {
ctx.Success("text/plain", ctx.RequestURI())
}
func nethttpEchoHandler(w http.ResponseWriter, r *http.Request) {
w.Header().Set(HeaderContentType, "text/plain")
w.Write([]byte(r.RequestURI)) //nolint:errcheck
}
func BenchmarkClientGetEndToEnd1TCP(b *testing.B) {
benchmarkClientGetEndToEndTCP(b, 1)
}
func BenchmarkClientGetEndToEnd10TCP(b *testing.B) {
benchmarkClientGetEndToEndTCP(b, 10)
}
func BenchmarkClientGetEndToEnd100TCP(b *testing.B) {
benchmarkClientGetEndToEndTCP(b, 100)
}
func benchmarkClientGetEndToEndTCP(b *testing.B, parallelism int) {
addr := "127.0.0.1:8543"
ln, err := net.Listen("tcp4", addr)
if err != nil {
b.Fatalf("cannot listen %q: %s", addr, err)
}
ch := make(chan struct{})
go func() {
if err := Serve(ln, fasthttpEchoHandler); err != nil {
b.Errorf("error when serving requests: %s", err)
}
close(ch)
}()
c := &Client{
MaxConnsPerHost: runtime.GOMAXPROCS(-1) * parallelism,
}
requestURI := "/foo/bar?baz=123"
url := "http://" + addr + requestURI
b.SetParallelism(parallelism)
b.RunParallel(func(pb *testing.PB) {
var buf []byte
for pb.Next() {
statusCode, body, err := c.Get(buf, url)
if err != nil {
b.Fatalf("unexpected error: %s", err)
}
if statusCode != StatusOK {
b.Fatalf("unexpected status code: %d. Expecting %d", statusCode, StatusOK)
}
if string(body) != requestURI {
b.Fatalf("unexpected response %q. Expecting %q", body, requestURI)
}
buf = body
}
})
ln.Close()
select {
case <-ch:
case <-time.After(time.Second):
b.Fatalf("server wasn't stopped")
}
}
func BenchmarkNetHTTPClientGetEndToEnd1TCP(b *testing.B) {
benchmarkNetHTTPClientGetEndToEndTCP(b, 1)
}
func BenchmarkNetHTTPClientGetEndToEnd10TCP(b *testing.B) {
benchmarkNetHTTPClientGetEndToEndTCP(b, 10)
}
func BenchmarkNetHTTPClientGetEndToEnd100TCP(b *testing.B) {
benchmarkNetHTTPClientGetEndToEndTCP(b, 100)
}
func benchmarkNetHTTPClientGetEndToEndTCP(b *testing.B, parallelism int) {
addr := "127.0.0.1:8542"
ln, err := net.Listen("tcp4", addr)
if err != nil {
b.Fatalf("cannot listen %q: %s", addr, err)
}
ch := make(chan struct{})
go func() {
if err := http.Serve(ln, http.HandlerFunc(nethttpEchoHandler)); err != nil && !strings.Contains(
err.Error(), "use of closed network connection") {
b.Errorf("error when serving requests: %s", err)
}
close(ch)
}()
c := &http.Client{
Transport: &http.Transport{
MaxIdleConnsPerHost: parallelism * runtime.GOMAXPROCS(-1),
},
}
requestURI := "/foo/bar?baz=123"
url := "http://" + addr + requestURI
b.SetParallelism(parallelism)
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
resp, err := c.Get(url)
if err != nil {
b.Fatalf("unexpected error: %s", err)
}
if resp.StatusCode != http.StatusOK {
b.Fatalf("unexpected status code: %d. Expecting %d", resp.StatusCode, http.StatusOK)
}
body, err := ioutil.ReadAll(resp.Body)
resp.Body.Close()
if err != nil {
b.Fatalf("unexpected error when reading response body: %s", err)
}
if string(body) != requestURI {
b.Fatalf("unexpected response %q. Expecting %q", body, requestURI)
}
}
})
ln.Close()
select {
case <-ch:
case <-time.After(time.Second):
b.Fatalf("server wasn't stopped")
}
}
func BenchmarkClientGetEndToEnd1Inmemory(b *testing.B) {
benchmarkClientGetEndToEndInmemory(b, 1)
}
func BenchmarkClientGetEndToEnd10Inmemory(b *testing.B) {
benchmarkClientGetEndToEndInmemory(b, 10)
}
func BenchmarkClientGetEndToEnd100Inmemory(b *testing.B) {
benchmarkClientGetEndToEndInmemory(b, 100)
}
func BenchmarkClientGetEndToEnd1000Inmemory(b *testing.B) {
benchmarkClientGetEndToEndInmemory(b, 1000)
}
func BenchmarkClientGetEndToEnd10KInmemory(b *testing.B) {
benchmarkClientGetEndToEndInmemory(b, 10000)
}
func benchmarkClientGetEndToEndInmemory(b *testing.B, parallelism int) {
ln := fasthttputil.NewInmemoryListener()
ch := make(chan struct{})
go func() {
if err := Serve(ln, fasthttpEchoHandler); err != nil {
b.Errorf("error when serving requests: %s", err)
}
close(ch)
}()
c := &Client{
MaxConnsPerHost: runtime.GOMAXPROCS(-1) * parallelism,
Dial: func(addr string) (net.Conn, error) { return ln.Dial() },
}
requestURI := "/foo/bar?baz=123"
url := "http://unused.host" + requestURI
b.SetParallelism(parallelism)
b.RunParallel(func(pb *testing.PB) {
var buf []byte
for pb.Next() {
statusCode, body, err := c.Get(buf, url)
if err != nil {
b.Fatalf("unexpected error: %s", err)
}
if statusCode != StatusOK {
b.Fatalf("unexpected status code: %d. Expecting %d", statusCode, StatusOK)
}
if string(body) != requestURI {
b.Fatalf("unexpected response %q. Expecting %q", body, requestURI)
}
buf = body
}
})
ln.Close()
select {
case <-ch:
case <-time.After(time.Second):
b.Fatalf("server wasn't stopped")
}
}
func BenchmarkNetHTTPClientGetEndToEnd1Inmemory(b *testing.B) {
benchmarkNetHTTPClientGetEndToEndInmemory(b, 1)
}
func BenchmarkNetHTTPClientGetEndToEnd10Inmemory(b *testing.B) {
benchmarkNetHTTPClientGetEndToEndInmemory(b, 10)
}
func BenchmarkNetHTTPClientGetEndToEnd100Inmemory(b *testing.B) {
benchmarkNetHTTPClientGetEndToEndInmemory(b, 100)
}
func BenchmarkNetHTTPClientGetEndToEnd1000Inmemory(b *testing.B) {
benchmarkNetHTTPClientGetEndToEndInmemory(b, 1000)
}
func benchmarkNetHTTPClientGetEndToEndInmemory(b *testing.B, parallelism int) {
ln := fasthttputil.NewInmemoryListener()
ch := make(chan struct{})
go func() {
if err := http.Serve(ln, http.HandlerFunc(nethttpEchoHandler)); err != nil && !strings.Contains(
err.Error(), "use of closed network connection") {
b.Errorf("error when serving requests: %s", err)
}
close(ch)
}()
c := &http.Client{
Transport: &http.Transport{
Dial: func(_, _ string) (net.Conn, error) { return ln.Dial() },
MaxIdleConnsPerHost: parallelism * runtime.GOMAXPROCS(-1),
},
}
requestURI := "/foo/bar?baz=123"
url := "http://unused.host" + requestURI
b.SetParallelism(parallelism)
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
resp, err := c.Get(url)
if err != nil {
b.Fatalf("unexpected error: %s", err)
}
if resp.StatusCode != http.StatusOK {
b.Fatalf("unexpected status code: %d. Expecting %d", resp.StatusCode, http.StatusOK)
}
body, err := ioutil.ReadAll(resp.Body)
resp.Body.Close()
if err != nil {
b.Fatalf("unexpected error when reading response body: %s", err)
}
if string(body) != requestURI {
b.Fatalf("unexpected response %q. Expecting %q", body, requestURI)
}
}
})
ln.Close()
select {
case <-ch:
case <-time.After(time.Second):
b.Fatalf("server wasn't stopped")
}
}
func BenchmarkClientEndToEndBigResponse1Inmemory(b *testing.B) {
benchmarkClientEndToEndBigResponseInmemory(b, 1)
}
func BenchmarkClientEndToEndBigResponse10Inmemory(b *testing.B) {
benchmarkClientEndToEndBigResponseInmemory(b, 10)
}
func benchmarkClientEndToEndBigResponseInmemory(b *testing.B, parallelism int) {
bigResponse := createFixedBody(1024 * 1024)
h := func(ctx *RequestCtx) {
ctx.SetContentType("text/plain")
ctx.Write(bigResponse) //nolint:errcheck
}
ln := fasthttputil.NewInmemoryListener()
ch := make(chan struct{})
go func() {
if err := Serve(ln, h); err != nil {
b.Errorf("error when serving requests: %s", err)
}
close(ch)
}()
c := &Client{
MaxConnsPerHost: runtime.GOMAXPROCS(-1) * parallelism,
Dial: func(addr string) (net.Conn, error) { return ln.Dial() },
}
requestURI := "/foo/bar?baz=123"
url := "http://unused.host" + requestURI
b.SetParallelism(parallelism)
b.RunParallel(func(pb *testing.PB) {
var req Request
req.SetRequestURI(url)
var resp Response
for pb.Next() {
if err := c.DoTimeout(&req, &resp, 5*time.Second); err != nil {
b.Fatalf("unexpected error: %s", err)
}
if resp.StatusCode() != StatusOK {
b.Fatalf("unexpected status code: %d. Expecting %d", resp.StatusCode(), StatusOK)
}
body := resp.Body()
if !bytes.Equal(bigResponse, body) {
b.Fatalf("unexpected response %q. Expecting %q", body, bigResponse)
}
}
})
ln.Close()
select {
case <-ch:
case <-time.After(time.Second):
b.Fatalf("server wasn't stopped")
}
}
func BenchmarkNetHTTPClientEndToEndBigResponse1Inmemory(b *testing.B) {
benchmarkNetHTTPClientEndToEndBigResponseInmemory(b, 1)
}
func BenchmarkNetHTTPClientEndToEndBigResponse10Inmemory(b *testing.B) {
benchmarkNetHTTPClientEndToEndBigResponseInmemory(b, 10)
}
func benchmarkNetHTTPClientEndToEndBigResponseInmemory(b *testing.B, parallelism int) {
bigResponse := createFixedBody(1024 * 1024)
h := func(w http.ResponseWriter, r *http.Request) {
w.Header().Set(HeaderContentType, "text/plain")
w.Write(bigResponse) //nolint:errcheck
}
ln := fasthttputil.NewInmemoryListener()
ch := make(chan struct{})
go func() {
if err := http.Serve(ln, http.HandlerFunc(h)); err != nil && !strings.Contains(
err.Error(), "use of closed network connection") {
b.Errorf("error when serving requests: %s", err)
}
close(ch)
}()
c := &http.Client{
Transport: &http.Transport{
Dial: func(_, _ string) (net.Conn, error) { return ln.Dial() },
MaxIdleConnsPerHost: parallelism * runtime.GOMAXPROCS(-1),
},
Timeout: 5 * time.Second,
}
requestURI := "/foo/bar?baz=123"
url := "http://unused.host" + requestURI
b.SetParallelism(parallelism)
b.RunParallel(func(pb *testing.PB) {
req, err := http.NewRequest(MethodGet, url, nil)
if err != nil {
b.Fatalf("unexpected error: %s", err)
}
for pb.Next() {
resp, err := c.Do(req)
if err != nil {
b.Fatalf("unexpected error: %s", err)
}
if resp.StatusCode != http.StatusOK {
b.Fatalf("unexpected status code: %d. Expecting %d", resp.StatusCode, http.StatusOK)
}
body, err := ioutil.ReadAll(resp.Body)
resp.Body.Close()
if err != nil {
b.Fatalf("unexpected error when reading response body: %s", err)
}
if !bytes.Equal(bigResponse, body) {
b.Fatalf("unexpected response %q. Expecting %q", body, bigResponse)
}
}
})
ln.Close()
select {
case <-ch:
case <-time.After(time.Second):
b.Fatalf("server wasn't stopped")
}
}
func BenchmarkPipelineClient1(b *testing.B) {
benchmarkPipelineClient(b, 1)
}
func BenchmarkPipelineClient10(b *testing.B) {
benchmarkPipelineClient(b, 10)
}
func BenchmarkPipelineClient100(b *testing.B) {
benchmarkPipelineClient(b, 100)
}
func BenchmarkPipelineClient1000(b *testing.B) {
benchmarkPipelineClient(b, 1000)
}
func benchmarkPipelineClient(b *testing.B, parallelism int) {
h := func(ctx *RequestCtx) {
ctx.WriteString("foobar") //nolint:errcheck
}
ln := fasthttputil.NewInmemoryListener()
ch := make(chan struct{})
go func() {
if err := Serve(ln, h); err != nil {
b.Errorf("error when serving requests: %s", err)
}
close(ch)
}()
maxConns := runtime.GOMAXPROCS(-1)
c := &PipelineClient{
Dial: func(addr string) (net.Conn, error) { return ln.Dial() },
ReadBufferSize: 1024 * 1024,
WriteBufferSize: 1024 * 1024,
MaxConns: maxConns,
MaxPendingRequests: parallelism * maxConns,
}
requestURI := "/foo/bar?baz=123"
url := "http://unused.host" + requestURI
b.SetParallelism(parallelism)
b.RunParallel(func(pb *testing.PB) {
var req Request
req.SetRequestURI(url)
var resp Response
for pb.Next() {
if err := c.Do(&req, &resp); err != nil {
b.Fatalf("unexpected error: %s", err)
}
if resp.StatusCode() != StatusOK {
b.Fatalf("unexpected status code: %d. Expecting %d", resp.StatusCode(), StatusOK)
}
body := resp.Body()
if string(body) != "foobar" {
b.Fatalf("unexpected response %q. Expecting %q", body, "foobar")
}
}
})
ln.Close()
select {
case <-ch:
case <-time.After(time.Second):
b.Fatalf("server wasn't stopped")
}
}
fasthttp-1.31.0/client_timing_wait_test.go 0000664 0000000 0000000 00000010373 14130360711 0020670 0 ustar 00root root 0000000 0000000 //go:build go1.11
// +build go1.11
package fasthttp
import (
"io/ioutil"
"net"
"net/http"
"strings"
"testing"
"time"
"github.com/valyala/fasthttp/fasthttputil"
)
func newFasthttpSleepEchoHandler(sleep time.Duration) RequestHandler {
return func(ctx *RequestCtx) {
time.Sleep(sleep)
ctx.Success("text/plain", ctx.RequestURI())
}
}
func BenchmarkClientGetEndToEndWaitConn1Inmemory(b *testing.B) {
benchmarkClientGetEndToEndWaitConnInmemory(b, 1)
}
func BenchmarkClientGetEndToEndWaitConn10Inmemory(b *testing.B) {
benchmarkClientGetEndToEndWaitConnInmemory(b, 10)
}
func BenchmarkClientGetEndToEndWaitConn100Inmemory(b *testing.B) {
benchmarkClientGetEndToEndWaitConnInmemory(b, 100)
}
func BenchmarkClientGetEndToEndWaitConn1000Inmemory(b *testing.B) {
benchmarkClientGetEndToEndWaitConnInmemory(b, 1000)
}
func benchmarkClientGetEndToEndWaitConnInmemory(b *testing.B, parallelism int) {
ln := fasthttputil.NewInmemoryListener()
ch := make(chan struct{})
sleepDuration := 50 * time.Millisecond
go func() {
if err := Serve(ln, newFasthttpSleepEchoHandler(sleepDuration)); err != nil {
b.Errorf("error when serving requests: %s", err)
}
close(ch)
}()
c := &Client{
MaxConnsPerHost: 1,
Dial: func(addr string) (net.Conn, error) { return ln.Dial() },
MaxConnWaitTimeout: 5 * time.Second,
}
requestURI := "/foo/bar?baz=123&sleep=10ms"
url := "http://unused.host" + requestURI
b.SetParallelism(parallelism)
b.RunParallel(func(pb *testing.PB) {
var buf []byte
for pb.Next() {
statusCode, body, err := c.Get(buf, url)
if err != nil {
if err != ErrNoFreeConns {
b.Fatalf("unexpected error: %s", err)
}
} else {
if statusCode != StatusOK {
b.Fatalf("unexpected status code: %d. Expecting %d", statusCode, StatusOK)
}
if string(body) != requestURI {
b.Fatalf("unexpected response %q. Expecting %q", body, requestURI)
}
}
buf = body
}
})
ln.Close()
select {
case <-ch:
case <-time.After(time.Second):
b.Fatalf("server wasn't stopped")
}
}
func newNethttpSleepEchoHandler(sleep time.Duration) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
time.Sleep(sleep)
w.Header().Set(HeaderContentType, "text/plain")
w.Write([]byte(r.RequestURI)) //nolint:errcheck
}
}
func BenchmarkNetHTTPClientGetEndToEndWaitConn1Inmemory(b *testing.B) {
benchmarkNetHTTPClientGetEndToEndWaitConnInmemory(b, 1)
}
func BenchmarkNetHTTPClientGetEndToEndWaitConn10Inmemory(b *testing.B) {
benchmarkNetHTTPClientGetEndToEndWaitConnInmemory(b, 10)
}
func BenchmarkNetHTTPClientGetEndToEndWaitConn100Inmemory(b *testing.B) {
benchmarkNetHTTPClientGetEndToEndWaitConnInmemory(b, 100)
}
func BenchmarkNetHTTPClientGetEndToEndWaitConn1000Inmemory(b *testing.B) {
benchmarkNetHTTPClientGetEndToEndWaitConnInmemory(b, 1000)
}
func benchmarkNetHTTPClientGetEndToEndWaitConnInmemory(b *testing.B, parallelism int) {
ln := fasthttputil.NewInmemoryListener()
ch := make(chan struct{})
sleep := 50 * time.Millisecond
go func() {
if err := http.Serve(ln, newNethttpSleepEchoHandler(sleep)); err != nil && !strings.Contains(
err.Error(), "use of closed network connection") {
b.Errorf("error when serving requests: %s", err)
}
close(ch)
}()
c := &http.Client{
Transport: &http.Transport{
Dial: func(_, _ string) (net.Conn, error) { return ln.Dial() },
MaxConnsPerHost: 1,
},
Timeout: 5 * time.Second,
}
requestURI := "/foo/bar?baz=123"
url := "http://unused.host" + requestURI
b.SetParallelism(parallelism)
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
resp, err := c.Get(url)
if err != nil {
if netErr, ok := err.(net.Error); !ok || !netErr.Timeout() {
b.Fatalf("unexpected error: %s", err)
}
} else {
if resp.StatusCode != http.StatusOK {
b.Fatalf("unexpected status code: %d. Expecting %d", resp.StatusCode, http.StatusOK)
}
body, err := ioutil.ReadAll(resp.Body)
resp.Body.Close()
if err != nil {
b.Fatalf("unexpected error when reading response body: %s", err)
}
if string(body) != requestURI {
b.Fatalf("unexpected response %q. Expecting %q", body, requestURI)
}
}
}
})
ln.Close()
select {
case <-ch:
case <-time.After(time.Second):
b.Fatalf("server wasn't stopped")
}
}
fasthttp-1.31.0/coarseTime.go 0000664 0000000 0000000 00000000472 14130360711 0016052 0 ustar 00root root 0000000 0000000 package fasthttp
import (
"time"
)
// CoarseTimeNow returns the current time truncated to the nearest second.
//
// Deprecated: This is slower than calling time.Now() directly.
// This is now just a shortcut for time.Now().Truncate(time.Second).
func CoarseTimeNow() time.Time {
return time.Now().Truncate(time.Second)
}
fasthttp-1.31.0/coarseTime_test.go 0000664 0000000 0000000 00000001176 14130360711 0017113 0 ustar 00root root 0000000 0000000 package fasthttp
import (
"sync/atomic"
"testing"
"time"
)
func BenchmarkCoarseTimeNow(b *testing.B) {
var zeroTimeCount uint64
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
t := CoarseTimeNow()
if t.IsZero() {
atomic.AddUint64(&zeroTimeCount, 1)
}
}
})
if zeroTimeCount > 0 {
b.Fatalf("zeroTimeCount must be zero")
}
}
func BenchmarkTimeNow(b *testing.B) {
var zeroTimeCount uint64
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
t := time.Now()
if t.IsZero() {
atomic.AddUint64(&zeroTimeCount, 1)
}
}
})
if zeroTimeCount > 0 {
b.Fatalf("zeroTimeCount must be zero")
}
}
fasthttp-1.31.0/compress.go 0000664 0000000 0000000 00000026122 14130360711 0015612 0 ustar 00root root 0000000 0000000 package fasthttp
import (
"bytes"
"fmt"
"io"
"os"
"sync"
"github.com/klauspost/compress/flate"
"github.com/klauspost/compress/gzip"
"github.com/klauspost/compress/zlib"
"github.com/valyala/bytebufferpool"
"github.com/valyala/fasthttp/stackless"
)
// Supported compression levels.
const (
CompressNoCompression = flate.NoCompression
CompressBestSpeed = flate.BestSpeed
CompressBestCompression = flate.BestCompression
CompressDefaultCompression = 6 // flate.DefaultCompression
CompressHuffmanOnly = -2 // flate.HuffmanOnly
)
func acquireGzipReader(r io.Reader) (*gzip.Reader, error) {
v := gzipReaderPool.Get()
if v == nil {
return gzip.NewReader(r)
}
zr := v.(*gzip.Reader)
if err := zr.Reset(r); err != nil {
return nil, err
}
return zr, nil
}
func releaseGzipReader(zr *gzip.Reader) {
zr.Close()
gzipReaderPool.Put(zr)
}
var gzipReaderPool sync.Pool
func acquireFlateReader(r io.Reader) (io.ReadCloser, error) {
v := flateReaderPool.Get()
if v == nil {
zr, err := zlib.NewReader(r)
if err != nil {
return nil, err
}
return zr, nil
}
zr := v.(io.ReadCloser)
if err := resetFlateReader(zr, r); err != nil {
return nil, err
}
return zr, nil
}
func releaseFlateReader(zr io.ReadCloser) {
zr.Close()
flateReaderPool.Put(zr)
}
func resetFlateReader(zr io.ReadCloser, r io.Reader) error {
zrr, ok := zr.(zlib.Resetter)
if !ok {
panic("BUG: zlib.Reader doesn't implement zlib.Resetter???")
}
return zrr.Reset(r, nil)
}
var flateReaderPool sync.Pool
func acquireStacklessGzipWriter(w io.Writer, level int) stackless.Writer {
nLevel := normalizeCompressLevel(level)
p := stacklessGzipWriterPoolMap[nLevel]
v := p.Get()
if v == nil {
return stackless.NewWriter(w, func(w io.Writer) stackless.Writer {
return acquireRealGzipWriter(w, level)
})
}
sw := v.(stackless.Writer)
sw.Reset(w)
return sw
}
func releaseStacklessGzipWriter(sw stackless.Writer, level int) {
sw.Close()
nLevel := normalizeCompressLevel(level)
p := stacklessGzipWriterPoolMap[nLevel]
p.Put(sw)
}
func acquireRealGzipWriter(w io.Writer, level int) *gzip.Writer {
nLevel := normalizeCompressLevel(level)
p := realGzipWriterPoolMap[nLevel]
v := p.Get()
if v == nil {
zw, err := gzip.NewWriterLevel(w, level)
if err != nil {
panic(fmt.Sprintf("BUG: unexpected error from gzip.NewWriterLevel(%d): %s", level, err))
}
return zw
}
zw := v.(*gzip.Writer)
zw.Reset(w)
return zw
}
func releaseRealGzipWriter(zw *gzip.Writer, level int) {
zw.Close()
nLevel := normalizeCompressLevel(level)
p := realGzipWriterPoolMap[nLevel]
p.Put(zw)
}
var (
stacklessGzipWriterPoolMap = newCompressWriterPoolMap()
realGzipWriterPoolMap = newCompressWriterPoolMap()
)
// AppendGzipBytesLevel appends gzipped src to dst using the given
// compression level and returns the resulting dst.
//
// Supported compression levels are:
//
// * CompressNoCompression
// * CompressBestSpeed
// * CompressBestCompression
// * CompressDefaultCompression
// * CompressHuffmanOnly
func AppendGzipBytesLevel(dst, src []byte, level int) []byte {
w := &byteSliceWriter{dst}
WriteGzipLevel(w, src, level) //nolint:errcheck
return w.b
}
// WriteGzipLevel writes gzipped p to w using the given compression level
// and returns the number of compressed bytes written to w.
//
// Supported compression levels are:
//
// * CompressNoCompression
// * CompressBestSpeed
// * CompressBestCompression
// * CompressDefaultCompression
// * CompressHuffmanOnly
func WriteGzipLevel(w io.Writer, p []byte, level int) (int, error) {
switch w.(type) {
case *byteSliceWriter,
*bytes.Buffer,
*bytebufferpool.ByteBuffer:
// These writers don't block, so we can just use stacklessWriteGzip
ctx := &compressCtx{
w: w,
p: p,
level: level,
}
stacklessWriteGzip(ctx)
return len(p), nil
default:
zw := acquireStacklessGzipWriter(w, level)
n, err := zw.Write(p)
releaseStacklessGzipWriter(zw, level)
return n, err
}
}
var stacklessWriteGzip = stackless.NewFunc(nonblockingWriteGzip)
func nonblockingWriteGzip(ctxv interface{}) {
ctx := ctxv.(*compressCtx)
zw := acquireRealGzipWriter(ctx.w, ctx.level)
_, err := zw.Write(ctx.p)
if err != nil {
panic(fmt.Sprintf("BUG: gzip.Writer.Write for len(p)=%d returned unexpected error: %s", len(ctx.p), err))
}
releaseRealGzipWriter(zw, ctx.level)
}
// WriteGzip writes gzipped p to w and returns the number of compressed
// bytes written to w.
func WriteGzip(w io.Writer, p []byte) (int, error) {
return WriteGzipLevel(w, p, CompressDefaultCompression)
}
// AppendGzipBytes appends gzipped src to dst and returns the resulting dst.
func AppendGzipBytes(dst, src []byte) []byte {
return AppendGzipBytesLevel(dst, src, CompressDefaultCompression)
}
// WriteGunzip writes ungzipped p to w and returns the number of uncompressed
// bytes written to w.
func WriteGunzip(w io.Writer, p []byte) (int, error) {
r := &byteSliceReader{p}
zr, err := acquireGzipReader(r)
if err != nil {
return 0, err
}
n, err := copyZeroAlloc(w, zr)
releaseGzipReader(zr)
nn := int(n)
if int64(nn) != n {
return 0, fmt.Errorf("too much data gunzipped: %d", n)
}
return nn, err
}
// AppendGunzipBytes appends gunzipped src to dst and returns the resulting dst.
func AppendGunzipBytes(dst, src []byte) ([]byte, error) {
w := &byteSliceWriter{dst}
_, err := WriteGunzip(w, src)
return w.b, err
}
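// A minimal round-trip sketch using the append-style gzip helpers above:
//
//	compressed := AppendGzipBytes(nil, []byte("hello fasthttp"))
//	plain, err := AppendGunzipBytes(nil, compressed)
//	// on success err == nil and string(plain) == "hello fasthttp"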
// AppendDeflateBytesLevel appends deflated src to dst using the given
// compression level and returns the resulting dst.
//
// Supported compression levels are:
//
// * CompressNoCompression
// * CompressBestSpeed
// * CompressBestCompression
// * CompressDefaultCompression
// * CompressHuffmanOnly
func AppendDeflateBytesLevel(dst, src []byte, level int) []byte {
w := &byteSliceWriter{dst}
WriteDeflateLevel(w, src, level) //nolint:errcheck
return w.b
}
// WriteDeflateLevel writes deflated p to w using the given compression level
// and returns the number of compressed bytes written to w.
//
// Supported compression levels are:
//
// * CompressNoCompression
// * CompressBestSpeed
// * CompressBestCompression
// * CompressDefaultCompression
// * CompressHuffmanOnly
func WriteDeflateLevel(w io.Writer, p []byte, level int) (int, error) {
switch w.(type) {
case *byteSliceWriter,
*bytes.Buffer,
*bytebufferpool.ByteBuffer:
// These writers don't block, so we can just use stacklessWriteDeflate
ctx := &compressCtx{
w: w,
p: p,
level: level,
}
stacklessWriteDeflate(ctx)
return len(p), nil
default:
zw := acquireStacklessDeflateWriter(w, level)
n, err := zw.Write(p)
releaseStacklessDeflateWriter(zw, level)
return n, err
}
}
var stacklessWriteDeflate = stackless.NewFunc(nonblockingWriteDeflate)
func nonblockingWriteDeflate(ctxv interface{}) {
ctx := ctxv.(*compressCtx)
zw := acquireRealDeflateWriter(ctx.w, ctx.level)
_, err := zw.Write(ctx.p)
if err != nil {
panic(fmt.Sprintf("BUG: zlib.Writer.Write for len(p)=%d returned unexpected error: %s", len(ctx.p), err))
}
releaseRealDeflateWriter(zw, ctx.level)
}
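// compressCtx carries the destination writer, the payload and the compression
// level passed to the stackless gzip/deflate helpers above.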
type compressCtx struct {
w io.Writer
p []byte
level int
}
// WriteDeflate writes deflated p to w and returns the number of compressed
// bytes written to w.
func WriteDeflate(w io.Writer, p []byte) (int, error) {
return WriteDeflateLevel(w, p, CompressDefaultCompression)
}
// AppendDeflateBytes appends deflated src to dst and returns the resulting dst.
func AppendDeflateBytes(dst, src []byte) []byte {
return AppendDeflateBytesLevel(dst, src, CompressDefaultCompression)
}
// WriteInflate writes inflated p to w and returns the number of uncompressed
// bytes written to w.
func WriteInflate(w io.Writer, p []byte) (int, error) {
r := &byteSliceReader{p}
zr, err := acquireFlateReader(r)
if err != nil {
return 0, err
}
n, err := copyZeroAlloc(w, zr)
releaseFlateReader(zr)
nn := int(n)
if int64(nn) != n {
return 0, fmt.Errorf("too much data inflated: %d", n)
}
return nn, err
}
// AppendInflateBytes appends inflated src to dst and returns the resulting dst.
func AppendInflateBytes(dst, src []byte) ([]byte, error) {
w := &byteSliceWriter{dst}
_, err := WriteInflate(w, src)
return w.b, err
}
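// byteSliceWriter and byteSliceReader are small io.Writer/io.Reader adapters
// over byte slices used by the Append*/Write* helpers above.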
type byteSliceWriter struct {
b []byte
}
func (w *byteSliceWriter) Write(p []byte) (int, error) {
w.b = append(w.b, p...)
return len(p), nil
}
type byteSliceReader struct {
b []byte
}
func (r *byteSliceReader) Read(p []byte) (int, error) {
if len(r.b) == 0 {
return 0, io.EOF
}
n := copy(p, r.b)
r.b = r.b[n:]
return n, nil
}
func (r *byteSliceReader) ReadByte() (byte, error) {
if len(r.b) == 0 {
return 0, io.EOF
}
n := r.b[0]
r.b = r.b[1:]
return n, nil
}
func acquireStacklessDeflateWriter(w io.Writer, level int) stackless.Writer {
nLevel := normalizeCompressLevel(level)
p := stacklessDeflateWriterPoolMap[nLevel]
v := p.Get()
if v == nil {
return stackless.NewWriter(w, func(w io.Writer) stackless.Writer {
return acquireRealDeflateWriter(w, level)
})
}
sw := v.(stackless.Writer)
sw.Reset(w)
return sw
}
func releaseStacklessDeflateWriter(sw stackless.Writer, level int) {
sw.Close()
nLevel := normalizeCompressLevel(level)
p := stacklessDeflateWriterPoolMap[nLevel]
p.Put(sw)
}
func acquireRealDeflateWriter(w io.Writer, level int) *zlib.Writer {
nLevel := normalizeCompressLevel(level)
p := realDeflateWriterPoolMap[nLevel]
v := p.Get()
if v == nil {
zw, err := zlib.NewWriterLevel(w, level)
if err != nil {
panic(fmt.Sprintf("BUG: unexpected error from zlib.NewWriterLevel(%d): %s", level, err))
}
return zw
}
zw := v.(*zlib.Writer)
zw.Reset(w)
return zw
}
func releaseRealDeflateWriter(zw *zlib.Writer, level int) {
zw.Close()
nLevel := normalizeCompressLevel(level)
p := realDeflateWriterPoolMap[nLevel]
p.Put(zw)
}
var (
stacklessDeflateWriterPoolMap = newCompressWriterPoolMap()
realDeflateWriterPoolMap = newCompressWriterPoolMap()
)
func newCompressWriterPoolMap() []*sync.Pool {
// Initialize pools for all the compression levels defined
// in https://golang.org/pkg/compress/flate/#pkg-constants .
// Compression levels are normalized with normalizeCompressLevel,
// so they fit [0..11].
var m []*sync.Pool
for i := 0; i < 12; i++ {
m = append(m, &sync.Pool{})
}
return m
}
func isFileCompressible(f *os.File, minCompressRatio float64) bool {
// Try compressing the first 4kb of the file
// and see if it can be compressed by more than
// the given minCompressRatio.
b := bytebufferpool.Get()
zw := acquireStacklessGzipWriter(b, CompressDefaultCompression)
lr := &io.LimitedReader{
R: f,
N: 4096,
}
_, err := copyZeroAlloc(zw, lr)
releaseStacklessGzipWriter(zw, CompressDefaultCompression)
f.Seek(0, 0) //nolint:errcheck
if err != nil {
return false
}
n := 4096 - lr.N
zn := len(b.B)
bytebufferpool.Put(b)
return float64(zn) < float64(n)*minCompressRatio
}
// normalizes compression level into [0..11], so it could be used as an index
// in *PoolMap.
func normalizeCompressLevel(level int) int {
// -2 is the lowest compression level - CompressHuffmanOnly
// 9 is the highest compression level - CompressBestCompression
if level < -2 || level > 9 {
level = CompressDefaultCompression
}
return level + 2
}
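// A few illustrative values of the normalization above:
//
//	normalizeCompressLevel(CompressHuffmanOnly)     // -2 -> 0
//	normalizeCompressLevel(CompressNoCompression)   //  0 -> 2
//	normalizeCompressLevel(CompressBestCompression) //  9 -> 11
//	normalizeCompressLevel(100)                     // out of range -> CompressDefaultCompression+2 = 8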
fasthttp-1.31.0/compress_test.go 0000664 0000000 0000000 00000012534 14130360711 0016653 0 ustar 00root root 0000000 0000000 package fasthttp
import (
"bytes"
"fmt"
"io/ioutil"
"testing"
"time"
)
var compressTestcases = func() []string {
a := []string{
"",
"foobar",
"выфаодлодл одлфываыв sd2 k34",
}
bigS := createFixedBody(1e4)
a = append(a, string(bigS))
return a
}()
func TestGzipBytesSerial(t *testing.T) {
t.Parallel()
if err := testGzipBytes(); err != nil {
t.Fatal(err)
}
}
func TestGzipBytesConcurrent(t *testing.T) {
t.Parallel()
if err := testConcurrent(10, testGzipBytes); err != nil {
t.Fatal(err)
}
}
func TestDeflateBytesSerial(t *testing.T) {
t.Parallel()
if err := testDeflateBytes(); err != nil {
t.Fatal(err)
}
}
func TestDeflateBytesConcurrent(t *testing.T) {
t.Parallel()
if err := testConcurrent(10, testDeflateBytes); err != nil {
t.Fatal(err)
}
}
func testGzipBytes() error {
for _, s := range compressTestcases {
if err := testGzipBytesSingleCase(s); err != nil {
return err
}
}
return nil
}
func testDeflateBytes() error {
for _, s := range compressTestcases {
if err := testDeflateBytesSingleCase(s); err != nil {
return err
}
}
return nil
}
func testGzipBytesSingleCase(s string) error {
prefix := []byte("foobar")
gzippedS := AppendGzipBytes(prefix, []byte(s))
if !bytes.Equal(gzippedS[:len(prefix)], prefix) {
return fmt.Errorf("unexpected prefix when compressing %q: %q. Expecting %q", s, gzippedS[:len(prefix)], prefix)
}
gunzippedS, err := AppendGunzipBytes(prefix, gzippedS[len(prefix):])
if err != nil {
return fmt.Errorf("unexpected error when uncompressing %q: %s", s, err)
}
if !bytes.Equal(gunzippedS[:len(prefix)], prefix) {
return fmt.Errorf("unexpected prefix when uncompressing %q: %q. Expecting %q", s, gunzippedS[:len(prefix)], prefix)
}
gunzippedS = gunzippedS[len(prefix):]
if string(gunzippedS) != s {
return fmt.Errorf("unexpected uncompressed string %q. Expecting %q", gunzippedS, s)
}
return nil
}
func testDeflateBytesSingleCase(s string) error {
prefix := []byte("foobar")
deflatedS := AppendDeflateBytes(prefix, []byte(s))
if !bytes.Equal(deflatedS[:len(prefix)], prefix) {
return fmt.Errorf("unexpected prefix when compressing %q: %q. Expecting %q", s, deflatedS[:len(prefix)], prefix)
}
inflatedS, err := AppendInflateBytes(prefix, deflatedS[len(prefix):])
if err != nil {
return fmt.Errorf("unexpected error when uncompressing %q: %s", s, err)
}
if !bytes.Equal(inflatedS[:len(prefix)], prefix) {
return fmt.Errorf("unexpected prefix when uncompressing %q: %q. Expecting %q", s, inflatedS[:len(prefix)], prefix)
}
inflatedS = inflatedS[len(prefix):]
if string(inflatedS) != s {
return fmt.Errorf("unexpected uncompressed string %q. Expecting %q", inflatedS, s)
}
return nil
}
func TestGzipCompressSerial(t *testing.T) {
t.Parallel()
if err := testGzipCompress(); err != nil {
t.Fatal(err)
}
}
func TestGzipCompressConcurrent(t *testing.T) {
t.Parallel()
if err := testConcurrent(10, testGzipCompress); err != nil {
t.Fatal(err)
}
}
func TestFlateCompressSerial(t *testing.T) {
t.Parallel()
if err := testFlateCompress(); err != nil {
t.Fatal(err)
}
}
func TestFlateCompressConcurrent(t *testing.T) {
t.Parallel()
if err := testConcurrent(10, testFlateCompress); err != nil {
t.Fatal(err)
}
}
func testGzipCompress() error {
for _, s := range compressTestcases {
if err := testGzipCompressSingleCase(s); err != nil {
return err
}
}
return nil
}
func testFlateCompress() error {
for _, s := range compressTestcases {
if err := testFlateCompressSingleCase(s); err != nil {
return err
}
}
return nil
}
func testGzipCompressSingleCase(s string) error {
var buf bytes.Buffer
zw := acquireStacklessGzipWriter(&buf, CompressDefaultCompression)
if _, err := zw.Write([]byte(s)); err != nil {
return fmt.Errorf("unexpected error: %s. s=%q", err, s)
}
releaseStacklessGzipWriter(zw, CompressDefaultCompression)
zr, err := acquireGzipReader(&buf)
if err != nil {
return fmt.Errorf("unexpected error: %s. s=%q", err, s)
}
body, err := ioutil.ReadAll(zr)
if err != nil {
return fmt.Errorf("unexpected error: %s. s=%q", err, s)
}
if string(body) != s {
return fmt.Errorf("unexpected string after decompression: %q. Expecting %q", body, s)
}
releaseGzipReader(zr)
return nil
}
func testFlateCompressSingleCase(s string) error {
var buf bytes.Buffer
zw := acquireStacklessDeflateWriter(&buf, CompressDefaultCompression)
if _, err := zw.Write([]byte(s)); err != nil {
return fmt.Errorf("unexpected error: %s. s=%q", err, s)
}
releaseStacklessDeflateWriter(zw, CompressDefaultCompression)
zr, err := acquireFlateReader(&buf)
if err != nil {
return fmt.Errorf("unexpected error: %s. s=%q", err, s)
}
body, err := ioutil.ReadAll(zr)
if err != nil {
return fmt.Errorf("unexpected error: %s. s=%q", err, s)
}
if string(body) != s {
return fmt.Errorf("unexpected string after decompression: %q. Expecting %q", body, s)
}
releaseFlateReader(zr)
return nil
}
func testConcurrent(concurrency int, f func() error) error {
ch := make(chan error, concurrency)
for i := 0; i < concurrency; i++ {
go func(idx int) {
err := f()
if err != nil {
ch <- fmt.Errorf("error in goroutine %d: %s", idx, err)
}
ch <- nil
}(i)
}
for i := 0; i < concurrency; i++ {
select {
case err := <-ch:
if err != nil {
return err
}
case <-time.After(time.Second):
return fmt.Errorf("timeout")
}
}
return nil
}
fasthttp-1.31.0/cookie.go 0000664 0000000 0000000 00000032745 14130360711 0015240 0 ustar 00root root 0000000 0000000 package fasthttp
import (
"bytes"
"errors"
"io"
"sync"
"time"
)
var zeroTime time.Time
var (
// CookieExpireDelete may be set on Cookie.Expire for expiring the given cookie.
CookieExpireDelete = time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC)
// CookieExpireUnlimited indicates that the cookie doesn't expire.
CookieExpireUnlimited = zeroTime
)
// CookieSameSite is an enum for the mode in which the SameSite flag should be set for the given cookie.
// See https://tools.ietf.org/html/draft-ietf-httpbis-cookie-same-site-00 for details.
type CookieSameSite int
const (
// CookieSameSiteDisabled removes the SameSite flag
CookieSameSiteDisabled CookieSameSite = iota
// CookieSameSiteDefaultMode sets the SameSite flag
CookieSameSiteDefaultMode
// CookieSameSiteLaxMode sets the SameSite flag with the "Lax" parameter
CookieSameSiteLaxMode
// CookieSameSiteStrictMode sets the SameSite flag with the "Strict" parameter
CookieSameSiteStrictMode
// CookieSameSiteNoneMode sets the SameSite flag with the "None" parameter
// see https://tools.ietf.org/html/draft-west-cookie-incrementalism-00
CookieSameSiteNoneMode
)
// AcquireCookie returns an empty Cookie object from the pool.
//
// The returned object may be returned to the pool with ReleaseCookie.
// This allows reducing GC load.
func AcquireCookie() *Cookie {
return cookiePool.Get().(*Cookie)
}
// ReleaseCookie returns the Cookie object acquired with AcquireCookie back
// to the pool.
//
// Do not access released Cookie object, otherwise data races may occur.
func ReleaseCookie(c *Cookie) {
c.Reset()
cookiePool.Put(c)
}
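// Typical acquire/release usage sketch:
//
//	c := AcquireCookie()
//	defer ReleaseCookie(c)
//	c.SetKey("session")
//	c.SetValue("abc123")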
var cookiePool = &sync.Pool{
New: func() interface{} {
return &Cookie{}
},
}
// Cookie represents HTTP response cookie.
//
// Do not copy Cookie objects. Create new object and use CopyTo instead.
//
// Cookie instance MUST NOT be used from concurrently running goroutines.
type Cookie struct {
noCopy noCopy //nolint:unused,structcheck
key []byte
value []byte
expire time.Time
maxAge int
domain []byte
path []byte
httpOnly bool
secure bool
sameSite CookieSameSite
bufKV argsKV
buf []byte
}
// CopyTo copies src cookie to c.
func (c *Cookie) CopyTo(src *Cookie) {
c.Reset()
c.key = append(c.key, src.key...)
c.value = append(c.value, src.value...)
c.expire = src.expire
c.maxAge = src.maxAge
c.domain = append(c.domain, src.domain...)
c.path = append(c.path, src.path...)
c.httpOnly = src.httpOnly
c.secure = src.secure
c.sameSite = src.sameSite
}
// HTTPOnly returns true if the cookie is http only.
func (c *Cookie) HTTPOnly() bool {
return c.httpOnly
}
// SetHTTPOnly sets cookie's httpOnly flag to the given value.
func (c *Cookie) SetHTTPOnly(httpOnly bool) {
c.httpOnly = httpOnly
}
// Secure returns true if the cookie is secure.
func (c *Cookie) Secure() bool {
return c.secure
}
// SetSecure sets cookie's secure flag to the given value.
func (c *Cookie) SetSecure(secure bool) {
c.secure = secure
}
// SameSite returns the SameSite mode.
func (c *Cookie) SameSite() CookieSameSite {
return c.sameSite
}
// SetSameSite sets the cookie's SameSite flag to the given value.
// Setting the value to CookieSameSiteNoneMode also sets Secure to true to avoid rejection by browsers.
func (c *Cookie) SetSameSite(mode CookieSameSite) {
c.sameSite = mode
if mode == CookieSameSiteNoneMode {
c.SetSecure(true)
}
}
// Path returns cookie path.
func (c *Cookie) Path() []byte {
return c.path
}
// SetPath sets cookie path.
func (c *Cookie) SetPath(path string) {
c.buf = append(c.buf[:0], path...)
c.path = normalizePath(c.path, c.buf)
}
// SetPathBytes sets cookie path.
func (c *Cookie) SetPathBytes(path []byte) {
c.buf = append(c.buf[:0], path...)
c.path = normalizePath(c.path, c.buf)
}
// Domain returns cookie domain.
//
// The returned value is valid until the Cookie is reused or released (ReleaseCookie).
// Do not store references to the returned value. Make copies instead.
func (c *Cookie) Domain() []byte {
return c.domain
}
// SetDomain sets cookie domain.
func (c *Cookie) SetDomain(domain string) {
c.domain = append(c.domain[:0], domain...)
}
// SetDomainBytes sets cookie domain.
func (c *Cookie) SetDomainBytes(domain []byte) {
c.domain = append(c.domain[:0], domain...)
}
// MaxAge returns the number of seconds until the cookie is meant to expire,
// or 0 if no max age is set.
func (c *Cookie) MaxAge() int {
return c.maxAge
}
// SetMaxAge sets cookie expiration time based on seconds. This takes precedence
// over any absolute expiry set on the cookie.
//
// Set max age to 0 to unset it.
func (c *Cookie) SetMaxAge(seconds int) {
c.maxAge = seconds
}
// Expire returns cookie expiration time.
//
// CookieExpireUnlimited is returned if the cookie doesn't expire.
func (c *Cookie) Expire() time.Time {
expire := c.expire
if expire.IsZero() {
expire = CookieExpireUnlimited
}
return expire
}
// SetExpire sets cookie expiration time.
//
// Set expiration time to CookieExpireDelete for expiring (deleting)
// the cookie on the client.
//
// By default cookie lifetime is limited by browser session.
func (c *Cookie) SetExpire(expire time.Time) {
c.expire = expire
}
// Value returns cookie value.
//
// The returned value is valid until the Cookie is reused or released (ReleaseCookie).
// Do not store references to the returned value. Make copies instead.
func (c *Cookie) Value() []byte {
return c.value
}
// SetValue sets cookie value.
func (c *Cookie) SetValue(value string) {
c.value = append(c.value[:0], value...)
}
// SetValueBytes sets cookie value.
func (c *Cookie) SetValueBytes(value []byte) {
c.value = append(c.value[:0], value...)
}
// Key returns cookie name.
//
// The returned value is valid until the Cookie is reused or released (ReleaseCookie).
// Do not store references to the returned value. Make copies instead.
func (c *Cookie) Key() []byte {
return c.key
}
// SetKey sets cookie name.
func (c *Cookie) SetKey(key string) {
c.key = append(c.key[:0], key...)
}
// SetKeyBytes sets cookie name.
func (c *Cookie) SetKeyBytes(key []byte) {
c.key = append(c.key[:0], key...)
}
// Reset clears the cookie.
func (c *Cookie) Reset() {
c.key = c.key[:0]
c.value = c.value[:0]
c.expire = zeroTime
c.maxAge = 0
c.domain = c.domain[:0]
c.path = c.path[:0]
c.httpOnly = false
c.secure = false
c.sameSite = CookieSameSiteDisabled
}
// AppendBytes appends cookie representation to dst and returns
// the extended dst.
func (c *Cookie) AppendBytes(dst []byte) []byte {
if len(c.key) > 0 {
dst = append(dst, c.key...)
dst = append(dst, '=')
}
dst = append(dst, c.value...)
if c.maxAge > 0 {
dst = append(dst, ';', ' ')
dst = append(dst, strCookieMaxAge...)
dst = append(dst, '=')
dst = AppendUint(dst, c.maxAge)
} else if !c.expire.IsZero() {
c.bufKV.value = AppendHTTPDate(c.bufKV.value[:0], c.expire)
dst = append(dst, ';', ' ')
dst = append(dst, strCookieExpires...)
dst = append(dst, '=')
dst = append(dst, c.bufKV.value...)
}
if len(c.domain) > 0 {
dst = appendCookiePart(dst, strCookieDomain, c.domain)
}
if len(c.path) > 0 {
dst = appendCookiePart(dst, strCookiePath, c.path)
}
if c.httpOnly {
dst = append(dst, ';', ' ')
dst = append(dst, strCookieHTTPOnly...)
}
if c.secure {
dst = append(dst, ';', ' ')
dst = append(dst, strCookieSecure...)
}
switch c.sameSite {
case CookieSameSiteDefaultMode:
dst = append(dst, ';', ' ')
dst = append(dst, strCookieSameSite...)
case CookieSameSiteLaxMode:
dst = append(dst, ';', ' ')
dst = append(dst, strCookieSameSite...)
dst = append(dst, '=')
dst = append(dst, strCookieSameSiteLax...)
case CookieSameSiteStrictMode:
dst = append(dst, ';', ' ')
dst = append(dst, strCookieSameSite...)
dst = append(dst, '=')
dst = append(dst, strCookieSameSiteStrict...)
case CookieSameSiteNoneMode:
dst = append(dst, ';', ' ')
dst = append(dst, strCookieSameSite...)
dst = append(dst, '=')
dst = append(dst, strCookieSameSiteNone...)
}
return dst
}
// Cookie returns cookie representation.
//
// The returned value is valid until the Cookie is reused or released (ReleaseCookie).
// Do not store references to the returned value. Make copies instead.
func (c *Cookie) Cookie() []byte {
c.buf = c.AppendBytes(c.buf[:0])
return c.buf
}
// String returns cookie representation.
func (c *Cookie) String() string {
return string(c.Cookie())
}
// WriteTo writes cookie representation to w.
//
// WriteTo implements io.WriterTo interface.
func (c *Cookie) WriteTo(w io.Writer) (int64, error) {
n, err := w.Write(c.Cookie())
return int64(n), err
}
var errNoCookies = errors.New("no cookies found")
// Parse parses Set-Cookie header.
func (c *Cookie) Parse(src string) error {
c.buf = append(c.buf[:0], src...)
return c.ParseBytes(c.buf)
}
// ParseBytes parses Set-Cookie header.
func (c *Cookie) ParseBytes(src []byte) error {
c.Reset()
var s cookieScanner
s.b = src
kv := &c.bufKV
if !s.next(kv) {
return errNoCookies
}
c.key = append(c.key, kv.key...)
c.value = append(c.value, kv.value...)
for s.next(kv) {
if len(kv.key) != 0 {
// Case insensitive switch on first char
switch kv.key[0] | 0x20 {
case 'm':
if caseInsensitiveCompare(strCookieMaxAge, kv.key) {
maxAge, err := ParseUint(kv.value)
if err != nil {
return err
}
c.maxAge = maxAge
}
case 'e': // "expires"
if caseInsensitiveCompare(strCookieExpires, kv.key) {
v := b2s(kv.value)
// Try the same two formats as net/http
// See: https://github.com/golang/go/blob/00379be17e63a5b75b3237819392d2dc3b313a27/src/net/http/cookie.go#L133-L135
exptime, err := time.ParseInLocation(time.RFC1123, v, time.UTC)
if err != nil {
exptime, err = time.Parse("Mon, 02-Jan-2006 15:04:05 MST", v)
if err != nil {
return err
}
}
c.expire = exptime
}
case 'd': // "domain"
if caseInsensitiveCompare(strCookieDomain, kv.key) {
c.domain = append(c.domain, kv.value...)
}
case 'p': // "path"
if caseInsensitiveCompare(strCookiePath, kv.key) {
c.path = append(c.path, kv.value...)
}
case 's': // "samesite"
if caseInsensitiveCompare(strCookieSameSite, kv.key) {
if len(kv.value) > 0 {
// Case insensitive switch on first char
switch kv.value[0] | 0x20 {
case 'l': // "lax"
if caseInsensitiveCompare(strCookieSameSiteLax, kv.value) {
c.sameSite = CookieSameSiteLaxMode
}
case 's': // "strict"
if caseInsensitiveCompare(strCookieSameSiteStrict, kv.value) {
c.sameSite = CookieSameSiteStrictMode
}
case 'n': // "none"
if caseInsensitiveCompare(strCookieSameSiteNone, kv.value) {
c.sameSite = CookieSameSiteNoneMode
}
}
}
}
}
} else if len(kv.value) != 0 {
// Case insensitive switch on first char
switch kv.value[0] | 0x20 {
case 'h': // "httponly"
if caseInsensitiveCompare(strCookieHTTPOnly, kv.value) {
c.httpOnly = true
}
case 's': // "secure"
if caseInsensitiveCompare(strCookieSecure, kv.value) {
c.secure = true
} else if caseInsensitiveCompare(strCookieSameSite, kv.value) {
c.sameSite = CookieSameSiteDefaultMode
}
}
} // else empty or no match
}
return nil
}
func appendCookiePart(dst, key, value []byte) []byte {
dst = append(dst, ';', ' ')
dst = append(dst, key...)
dst = append(dst, '=')
return append(dst, value...)
}
func getCookieKey(dst, src []byte) []byte {
n := bytes.IndexByte(src, '=')
if n >= 0 {
src = src[:n]
}
return decodeCookieArg(dst, src, false)
}
func appendRequestCookieBytes(dst []byte, cookies []argsKV) []byte {
for i, n := 0, len(cookies); i < n; i++ {
kv := &cookies[i]
if len(kv.key) > 0 {
dst = append(dst, kv.key...)
dst = append(dst, '=')
}
dst = append(dst, kv.value...)
if i+1 < n {
dst = append(dst, ';', ' ')
}
}
return dst
}
// For Response we cannot use the above function, as response cookies
// already contain the key= in the value.
func appendResponseCookieBytes(dst []byte, cookies []argsKV) []byte {
for i, n := 0, len(cookies); i < n; i++ {
kv := &cookies[i]
dst = append(dst, kv.value...)
if i+1 < n {
dst = append(dst, ';', ' ')
}
}
return dst
}
func parseRequestCookies(cookies []argsKV, src []byte) []argsKV {
var s cookieScanner
s.b = src
var kv *argsKV
cookies, kv = allocArg(cookies)
for s.next(kv) {
if len(kv.key) > 0 || len(kv.value) > 0 {
cookies, kv = allocArg(cookies)
}
}
return releaseArg(cookies)
}
type cookieScanner struct {
b []byte
}
func (s *cookieScanner) next(kv *argsKV) bool {
b := s.b
if len(b) == 0 {
return false
}
isKey := true
k := 0
for i, c := range b {
switch c {
case '=':
if isKey {
isKey = false
kv.key = decodeCookieArg(kv.key, b[:i], false)
k = i + 1
}
case ';':
if isKey {
kv.key = kv.key[:0]
}
kv.value = decodeCookieArg(kv.value, b[k:i], true)
s.b = b[i+1:]
return true
}
}
if isKey {
kv.key = kv.key[:0]
}
kv.value = decodeCookieArg(kv.value, b[k:], true)
s.b = b[len(b):]
return true
}
func decodeCookieArg(dst, src []byte, skipQuotes bool) []byte {
for len(src) > 0 && src[0] == ' ' {
src = src[1:]
}
for len(src) > 0 && src[len(src)-1] == ' ' {
src = src[:len(src)-1]
}
if skipQuotes {
if len(src) > 1 && src[0] == '"' && src[len(src)-1] == '"' {
src = src[1 : len(src)-1]
}
}
return append(dst[:0], src...)
}
// caseInsensitiveCompare does a case insensitive equality comparison of
// two []byte. Assumes only letters need to be matched.
func caseInsensitiveCompare(a, b []byte) bool {
if len(a) != len(b) {
return false
}
for i := 0; i < len(a); i++ {
if a[i]|0x20 != b[i]|0x20 {
return false
}
}
return true
}
fasthttp-1.31.0/cookie_test.go 0000664 0000000 0000000 00000025176 14130360711 0016277 0 ustar 00root root 0000000 0000000 package fasthttp
import (
"strings"
"testing"
"time"
)
func TestCookiePanic(t *testing.T) {
t.Parallel()
var c Cookie
if err := c.Parse(";SAMeSITe="); err != nil {
t.Error(err)
}
}
func TestCookieValueWithEqualAndSpaceChars(t *testing.T) {
t.Parallel()
testCookieValueWithEqualAndSpaceChars(t, "sth1", "/", "MTQ2NjU5NTcwN3xfUVduVXk4aG9jSmZaNzNEb1dGa1VjekY1bG9vMmxSWlJBZUN2Q1ZtZVFNMTk2YU9YaWtCVmY1eDRWZXd3M3Q5RTJRZnZMbk5mWklSSFZJcVlXTDhiSFFHWWdpdFVLd1hwbXR2UUN4QlJ1N3BITFpkS3Y4PXzDvPNn6JVDBFB2wYVYPHdkdlZBm6n1_0QB3_GWwE40Tg ==")
testCookieValueWithEqualAndSpaceChars(t, "sth2", "/", "123")
testCookieValueWithEqualAndSpaceChars(t, "sth3", "/", "123 == 1")
}
func testCookieValueWithEqualAndSpaceChars(t *testing.T, expectedName, expectedPath, expectedValue string) {
var c Cookie
c.SetKey(expectedName)
c.SetPath(expectedPath)
c.SetValue(expectedValue)
s := c.String()
var c1 Cookie
if err := c1.Parse(s); err != nil {
t.Fatalf("unexpected error: %s", err)
}
name := c1.Key()
if string(name) != expectedName {
t.Fatalf("unexpected name %q. Expecting %q", name, expectedName)
}
path := c1.Path()
if string(path) != expectedPath {
t.Fatalf("unexpected path %q. Expecting %q", path, expectedPath)
}
value := c1.Value()
if string(value) != expectedValue {
t.Fatalf("unexpected value %q. Expecting %q", value, expectedValue)
}
}
func TestCookieSecureHttpOnly(t *testing.T) {
t.Parallel()
var c Cookie
if err := c.Parse("foo=bar; HttpOnly; secure"); err != nil {
t.Fatalf("unexpected error: %s", err)
}
if !c.Secure() {
t.Fatalf("secure must be set")
}
if !c.HTTPOnly() {
t.Fatalf("HttpOnly must be set")
}
s := c.String()
if !strings.Contains(s, "; secure") {
t.Fatalf("missing secure flag in cookie %q", s)
}
if !strings.Contains(s, "; HttpOnly") {
t.Fatalf("missing HttpOnly flag in cookie %q", s)
}
}
func TestCookieSecure(t *testing.T) {
t.Parallel()
var c Cookie
if err := c.Parse("foo=bar; secure"); err != nil {
t.Fatalf("unexpected error: %s", err)
}
if !c.Secure() {
t.Fatalf("secure must be set")
}
s := c.String()
if !strings.Contains(s, "; secure") {
t.Fatalf("missing secure flag in cookie %q", s)
}
if err := c.Parse("foo=bar"); err != nil {
t.Fatalf("unexpected error: %s", err)
}
if c.Secure() {
t.Fatalf("Unexpected secure flag set")
}
s = c.String()
if strings.Contains(s, "secure") {
t.Fatalf("unexpected secure flag in cookie %q", s)
}
}
func TestCookieSameSite(t *testing.T) {
t.Parallel()
var c Cookie
if err := c.Parse("foo=bar; samesite"); err != nil {
t.Fatalf("unexpected error: %s", err)
}
if c.SameSite() != CookieSameSiteDefaultMode {
t.Fatalf("SameSite must be set")
}
s := c.String()
if !strings.Contains(s, "; SameSite") {
t.Fatalf("missing SameSite flag in cookie %q", s)
}
if err := c.Parse("foo=bar; samesite=lax"); err != nil {
t.Fatalf("unexpected error: %s", err)
}
if c.SameSite() != CookieSameSiteLaxMode {
t.Fatalf("SameSite Lax Mode must be set")
}
s = c.String()
if !strings.Contains(s, "; SameSite=Lax") {
t.Fatalf("missing SameSite flag in cookie %q", s)
}
if err := c.Parse("foo=bar; samesite=strict"); err != nil {
t.Fatalf("unexpected error: %s", err)
}
if c.SameSite() != CookieSameSiteStrictMode {
t.Fatalf("SameSite Strict Mode must be set")
}
s = c.String()
if !strings.Contains(s, "; SameSite=Strict") {
t.Fatalf("missing SameSite flag in cookie %q", s)
}
if err := c.Parse("foo=bar; samesite=none"); err != nil {
t.Fatalf("unexpected error: %s", err)
}
if c.SameSite() != CookieSameSiteNoneMode {
t.Fatalf("SameSite None Mode must be set")
}
s = c.String()
if !strings.Contains(s, "; SameSite=None") {
t.Fatalf("missing SameSite flag in cookie %q", s)
}
if err := c.Parse("foo=bar"); err != nil {
t.Fatalf("unexpected error: %s", err)
}
c.SetSameSite(CookieSameSiteNoneMode)
s = c.String()
if !strings.Contains(s, "; SameSite=None") {
t.Fatalf("missing SameSite flag in cookie %q", s)
}
if !strings.Contains(s, "; secure") {
t.Fatalf("missing Secure flag in cookie %q", s)
}
if err := c.Parse("foo=bar"); err != nil {
t.Fatalf("unexpected error: %s", err)
}
if c.SameSite() != CookieSameSiteDisabled {
t.Fatalf("Unexpected SameSite flag set")
}
s = c.String()
if strings.Contains(s, "SameSite") {
t.Fatalf("unexpected SameSite flag in cookie %q", s)
}
}
func TestCookieMaxAge(t *testing.T) {
t.Parallel()
var c Cookie
maxAge := 100
if err := c.Parse("foo=bar; max-age=100"); err != nil {
t.Fatalf("unexpected error: %s", err)
}
if maxAge != c.MaxAge() {
t.Fatalf("max-age must be set")
}
s := c.String()
if !strings.Contains(s, "; max-age=100") {
t.Fatalf("missing max-age flag in cookie %q", s)
}
if err := c.Parse("foo=bar; expires=Tue, 10 Nov 2009 23:00:00 GMT; max-age=100;"); err != nil {
t.Fatalf("unexpected error: %s", err)
}
if maxAge != c.MaxAge() {
t.Fatalf("max-age ignored")
}
s = c.String()
if s != "foo=bar; max-age=100" {
t.Fatalf("missing max-age in cookie %q", s)
}
expires := time.Unix(100, 0)
c.SetExpire(expires)
s = c.String()
if s != "foo=bar; max-age=100" {
t.Fatalf("expires should be ignored due to max-age: %q", s)
}
c.SetMaxAge(0)
s = c.String()
if s != "foo=bar; expires=Thu, 01 Jan 1970 00:01:40 GMT" {
t.Fatalf("missing expires %q", s)
}
}
func TestCookieHttpOnly(t *testing.T) {
t.Parallel()
var c Cookie
if err := c.Parse("foo=bar; HttpOnly"); err != nil {
t.Fatalf("unexpected error: %s", err)
}
if !c.HTTPOnly() {
t.Fatalf("HTTPOnly must be set")
}
s := c.String()
if !strings.Contains(s, "; HttpOnly") {
t.Fatalf("missing HttpOnly flag in cookie %q", s)
}
if err := c.Parse("foo=bar"); err != nil {
t.Fatalf("unexpected error: %s", err)
}
if c.HTTPOnly() {
t.Fatalf("Unexpected HTTPOnly flag set")
}
s = c.String()
if strings.Contains(s, "HttpOnly") {
t.Fatalf("unexpected HttpOnly flag in cookie %q", s)
}
}
func TestCookieAcquireReleaseSequential(t *testing.T) {
t.Parallel()
testCookieAcquireRelease(t)
}
func TestCookieAcquireReleaseConcurrent(t *testing.T) {
t.Parallel()
ch := make(chan struct{}, 10)
for i := 0; i < 10; i++ {
go func() {
testCookieAcquireRelease(t)
ch <- struct{}{}
}()
}
for i := 0; i < 10; i++ {
select {
case <-ch:
case <-time.After(time.Second):
t.Fatalf("timeout")
}
}
}
func testCookieAcquireRelease(t *testing.T) {
c := AcquireCookie()
key := "foo"
c.SetKey(key)
value := "bar"
c.SetValue(value)
domain := "foo.bar.com"
c.SetDomain(domain)
path := "/foi/bar/aaa"
c.SetPath(path)
s := c.String()
c.Reset()
if err := c.Parse(s); err != nil {
t.Fatalf("unexpected error: %s", err)
}
if string(c.Key()) != key {
t.Fatalf("unexpected cookie name %q. Expecting %q", c.Key(), key)
}
if string(c.Value()) != value {
t.Fatalf("unexpected cookie value %q. Expecting %q", c.Value(), value)
}
if string(c.Domain()) != domain {
t.Fatalf("unexpected domain %q. Expecting %q", c.Domain(), domain)
}
if string(c.Path()) != path {
t.Fatalf("unexpected path %q. Expecting %q", c.Path(), path)
}
ReleaseCookie(c)
}
func TestCookieParse(t *testing.T) {
t.Parallel()
testCookieParse(t, "foo", "foo")
testCookieParse(t, "foo=bar", "foo=bar")
testCookieParse(t, "foo=", "foo=")
testCookieParse(t, `foo="bar"`, "foo=bar")
testCookieParse(t, `"foo"=bar`, `"foo"=bar`)
testCookieParse(t, "foo=bar; Domain=aaa.com; PATH=/foo/bar", "foo=bar; domain=aaa.com; path=/foo/bar")
testCookieParse(t, "foo=bar; max-age= 101 ; expires= Tue, 10 Nov 2009 23:00:00 GMT", "foo=bar; max-age=101")
testCookieParse(t, " xxx = yyy ; path=/a/b;;;domain=foobar.com ; expires= Tue, 10 Nov 2009 23:00:00 GMT ; ;;",
"xxx=yyy; expires=Tue, 10 Nov 2009 23:00:00 GMT; domain=foobar.com; path=/a/b")
}
func testCookieParse(t *testing.T, s, expectedS string) {
var c Cookie
if err := c.Parse(s); err != nil {
t.Fatalf("unexpected error: %s", err)
}
result := string(c.Cookie())
if result != expectedS {
t.Fatalf("unexpected cookies %q. Expecting %q. Original %q", result, expectedS, s)
}
}
func TestCookieAppendBytes(t *testing.T) {
t.Parallel()
c := &Cookie{}
testCookieAppendBytes(t, c, "", "bar", "bar")
testCookieAppendBytes(t, c, "foo", "", "foo=")
testCookieAppendBytes(t, c, "ффф", "12 лодлы", "ффф=12 лодлы")
c.SetDomain("foobar.com")
testCookieAppendBytes(t, c, "a", "b", "a=b; domain=foobar.com")
c.SetPath("/a/b")
testCookieAppendBytes(t, c, "aa", "bb", "aa=bb; domain=foobar.com; path=/a/b")
c.SetExpire(CookieExpireDelete)
testCookieAppendBytes(t, c, "xxx", "yyy", "xxx=yyy; expires=Tue, 10 Nov 2009 23:00:00 GMT; domain=foobar.com; path=/a/b")
}
func testCookieAppendBytes(t *testing.T, c *Cookie, key, value, expectedS string) {
c.SetKey(key)
c.SetValue(value)
result := string(c.AppendBytes(nil))
if result != expectedS {
t.Fatalf("Unexpected cookie %q. Expecting %q", result, expectedS)
}
}
func TestParseRequestCookies(t *testing.T) {
t.Parallel()
testParseRequestCookies(t, "", "")
testParseRequestCookies(t, "=", "")
testParseRequestCookies(t, "foo", "foo")
testParseRequestCookies(t, "=foo", "foo")
testParseRequestCookies(t, "bar=", "bar=")
testParseRequestCookies(t, "xxx=aa;bb=c; =d; ;;e=g", "xxx=aa; bb=c; d; e=g")
testParseRequestCookies(t, "a;b;c; d=1;d=2", "a; b; c; d=1; d=2")
testParseRequestCookies(t, " %D0%B8%D0%B2%D0%B5%D1%82=a%20b%3Bc ;s%20s=aaa ", "%D0%B8%D0%B2%D0%B5%D1%82=a%20b%3Bc; s%20s=aaa")
}
func testParseRequestCookies(t *testing.T, s, expectedS string) {
cookies := parseRequestCookies(nil, []byte(s))
ss := string(appendRequestCookieBytes(nil, cookies))
if ss != expectedS {
t.Fatalf("Unexpected cookies after parsing: %q. Expecting %q. String to parse %q", ss, expectedS, s)
}
}
func TestAppendRequestCookieBytes(t *testing.T) {
t.Parallel()
testAppendRequestCookieBytes(t, "=", "")
testAppendRequestCookieBytes(t, "foo=", "foo=")
testAppendRequestCookieBytes(t, "=bar", "bar")
testAppendRequestCookieBytes(t, "привет=a bc&s s=aaa", "привет=a bc; s s=aaa")
}
func testAppendRequestCookieBytes(t *testing.T, s, expectedS string) {
kvs := strings.Split(s, "&")
cookies := make([]argsKV, 0, len(kvs))
for _, ss := range kvs {
tmp := strings.SplitN(ss, "=", 2)
if len(tmp) != 2 {
t.Fatalf("Cannot find '=' in %q, part of %q", ss, s)
}
cookies = append(cookies, argsKV{
key: []byte(tmp[0]),
value: []byte(tmp[1]),
})
}
prefix := "foobar"
result := string(appendRequestCookieBytes([]byte(prefix), cookies))
if result[:len(prefix)] != prefix {
t.Fatalf("unexpected prefix %q. Expecting %q for cookie %q", result[:len(prefix)], prefix, s)
}
result = result[len(prefix):]
if result != expectedS {
t.Fatalf("Unexpected result %q. Expecting %q for cookie %q", result, expectedS, s)
}
}
fasthttp-1.31.0/cookie_timing_test.go 0000664 0000000 0000000 00000001465 14130360711 0017641 0 ustar 00root root 0000000 0000000 package fasthttp
import (
"testing"
)
func BenchmarkCookieParseMin(b *testing.B) {
var c Cookie
s := []byte("xxx=yyy")
for i := 0; i < b.N; i++ {
if err := c.ParseBytes(s); err != nil {
b.Fatalf("unexpected error when parsing cookies: %s", err)
}
}
}
func BenchmarkCookieParseNoExpires(b *testing.B) {
var c Cookie
s := []byte("xxx=yyy; domain=foobar.com; path=/a/b")
for i := 0; i < b.N; i++ {
if err := c.ParseBytes(s); err != nil {
b.Fatalf("unexpected error when parsing cookies: %s", err)
}
}
}
func BenchmarkCookieParseFull(b *testing.B) {
var c Cookie
s := []byte("xxx=yyy; expires=Tue, 10 Nov 2009 23:00:00 GMT; domain=foobar.com; path=/a/b")
for i := 0; i < b.N; i++ {
if err := c.ParseBytes(s); err != nil {
b.Fatalf("unexpected error when parsing cookies: %s", err)
}
}
}
fasthttp-1.31.0/doc.go 0000664 0000000 0000000 00000002575 14130360711 0014532 0 ustar 00root root 0000000 0000000 /*
Package fasthttp provides fast HTTP server and client API.
Fasthttp provides the following features:
* Optimized for speed. Easily handles more than 100K qps and more than 1M
concurrent keep-alive connections on modern hardware.
* Optimized for low memory usage.
* Easy 'Connection: Upgrade' support via RequestCtx.Hijack.
* Server provides the following anti-DoS limits:
* The number of concurrent connections.
* The number of concurrent connections per client IP.
* The number of requests per connection.
* Request read timeout.
* Response write timeout.
* Maximum request header size.
* Maximum request body size.
* Maximum request execution time.
* Maximum keep-alive connection lifetime.
* Early filtering out non-GET requests.
* A lot of additional useful info is exposed to request handler:
* Server and client address.
* Per-request logger.
* Unique request id.
* Request start time.
* Connection start time.
* Request sequence number for the current connection.
* Client supports automatic retry of idempotent requests on failure.
* Fasthttp API is designed with the ability to extend existing client
and server implementations or to write custom client and server
implementations from scratch.
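A minimal server sketch (the address and handler body are illustrative,
not part of the API):
log.Fatal(fasthttp.ListenAndServe(":8080", func(ctx *fasthttp.RequestCtx) {
fmt.Fprintf(ctx, "Hello, world!")
}))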
*/
package fasthttp
fasthttp-1.31.0/examples/ 0000775 0000000 0000000 00000000000 14130360711 0015243 5 ustar 00root root 0000000 0000000 fasthttp-1.31.0/examples/README.md 0000664 0000000 0000000 00000000134 14130360711 0016520 0 ustar 00root root 0000000 0000000 # Code examples
* [HelloWorld server](helloworldserver)
* [Static file server](fileserver)
fasthttp-1.31.0/examples/fileserver/ 0000775 0000000 0000000 00000000000 14130360711 0017411 5 ustar 00root root 0000000 0000000 fasthttp-1.31.0/examples/fileserver/.gitignore 0000664 0000000 0000000 00000000013 14130360711 0021373 0 ustar 00root root 0000000 0000000 fileserver
fasthttp-1.31.0/examples/fileserver/Makefile 0000664 0000000 0000000 00000000222 14130360711 0021045 0 ustar 00root root 0000000 0000000 fileserver: clean
go get -u github.com/valyala/fasthttp
go get -u github.com/valyala/fasthttp/expvarhandler
go build
clean:
rm -f fileserver
fasthttp-1.31.0/examples/fileserver/README.md 0000664 0000000 0000000 00000004132 14130360711 0020670 0 ustar 00root root 0000000 0000000 # Static file server example
* Serves files from the given directory.
* Supports transparent response compression.
* Supports byte range responses.
* Generates directory index pages.
* Supports TLS (aka SSL or HTTPS).
* Supports virtual hosts.
* Exports various stats on /stats path.
# How to build
```
make
```
# How to run
```
./fileserver -h
./fileserver -addr=tcp.addr.to.listen:to -dir=/path/to/directory/to/serve
```
# fileserver vs nginx performance comparison
Serving the default nginx path (`/usr/share/nginx/html` on Ubuntu).
* nginx
```
$ ./wrk -t 4 -c 16 -d 10 http://localhost:80
Running 10s test @ http://localhost:80
4 threads and 16 connections
Thread Stats Avg Stdev Max +/- Stdev
Latency 397.76us 1.08ms 20.23ms 95.19%
Req/Sec 21.20k 2.49k 31.34k 79.65%
850220 requests in 10.10s, 695.65MB read
Requests/sec: 84182.71
Transfer/sec: 68.88MB
```
* fileserver
```
$ ./wrk -t 4 -c 16 -d 10 http://localhost:8080
Running 10s test @ http://localhost:8080
4 threads and 16 connections
Thread Stats Avg Stdev Max +/- Stdev
Latency 447.99us 1.59ms 27.20ms 94.79%
Req/Sec 37.13k 3.99k 47.86k 76.00%
1478457 requests in 10.02s, 1.03GB read
Requests/sec: 147597.06
Transfer/sec: 105.15MB
```
8 pipelined requests
* nginx
```
$ ./wrk -s pipeline.lua -t 4 -c 16 -d 10 http://localhost:80 -- 8
Running 10s test @ http://localhost:80
4 threads and 16 connections
Thread Stats Avg Stdev Max +/- Stdev
Latency 1.34ms 2.15ms 30.91ms 92.16%
Req/Sec 33.54k 7.36k 108.12k 76.81%
1339908 requests in 10.10s, 1.07GB read
Requests/sec: 132705.81
Transfer/sec: 108.58MB
```
* fileserver
```
$ ./wrk -s pipeline.lua -t 4 -c 16 -d 10 http://localhost:8080 -- 8
Running 10s test @ http://localhost:8080
4 threads and 16 connections
Thread Stats Avg Stdev Max +/- Stdev
Latency 2.08ms 6.33ms 88.26ms 92.83%
Req/Sec 116.54k 14.66k 167.98k 69.00%
4642226 requests in 10.03s, 3.23GB read
Requests/sec: 462769.41
Transfer/sec: 329.67MB
```
fasthttp-1.31.0/examples/fileserver/fileserver.go 0000664 0000000 0000000 00000007301 14130360711 0022107 0 ustar 00root root 0000000 0000000 // Example static file server.
//
// Serves static files from the given directory.
// Exports various stats at /stats .
package main
import (
"expvar"
"flag"
"log"
"github.com/valyala/fasthttp"
"github.com/valyala/fasthttp/expvarhandler"
)
var (
addr = flag.String("addr", "localhost:8080", "TCP address to listen to")
addrTLS = flag.String("addrTLS", "", "TCP address to listen to TLS (aka SSL or HTTPS) requests. Leave empty for disabling TLS")
byteRange = flag.Bool("byteRange", false, "Enables byte range requests if set to true")
certFile = flag.String("certFile", "./ssl-cert.pem", "Path to TLS certificate file")
compress = flag.Bool("compress", false, "Enables transparent response compression if set to true")
dir = flag.String("dir", "/usr/share/nginx/html", "Directory to serve static files from")
generateIndexPages = flag.Bool("generateIndexPages", true, "Whether to generate directory index pages")
keyFile = flag.String("keyFile", "./ssl-cert.key", "Path to TLS key file")
vhost = flag.Bool("vhost", false, "Enables virtual hosting by prepending the requested path with the requested hostname")
)
func main() {
// Parse command-line flags.
flag.Parse()
// Setup FS handler
fs := &fasthttp.FS{
Root: *dir,
IndexNames: []string{"index.html"},
GenerateIndexPages: *generateIndexPages,
Compress: *compress,
AcceptByteRange: *byteRange,
}
if *vhost {
fs.PathRewrite = fasthttp.NewVHostPathRewriter(0)
}
fsHandler := fs.NewRequestHandler()
// Create RequestHandler serving server stats on /stats and files
// on other requested paths.
// /stats output may be filtered using regexps. For example:
//
// * /stats?r=fs will show only stats (expvars) containing 'fs'
// in their names.
requestHandler := func(ctx *fasthttp.RequestCtx) {
switch string(ctx.Path()) {
case "/stats":
expvarhandler.ExpvarHandler(ctx)
default:
fsHandler(ctx)
updateFSCounters(ctx)
}
}
// Start HTTP server.
if len(*addr) > 0 {
log.Printf("Starting HTTP server on %q", *addr)
go func() {
if err := fasthttp.ListenAndServe(*addr, requestHandler); err != nil {
log.Fatalf("error in ListenAndServe: %s", err)
}
}()
}
// Start HTTPS server.
if len(*addrTLS) > 0 {
log.Printf("Starting HTTPS server on %q", *addrTLS)
go func() {
if err := fasthttp.ListenAndServeTLS(*addrTLS, *certFile, *keyFile, requestHandler); err != nil {
log.Fatalf("error in ListenAndServeTLS: %s", err)
}
}()
}
log.Printf("Serving files from directory %q", *dir)
log.Printf("See stats at http://%s/stats", *addr)
// Wait forever.
select {}
}
func updateFSCounters(ctx *fasthttp.RequestCtx) {
// Increment the number of fsHandler calls.
fsCalls.Add(1)
// Update other stats counters
resp := &ctx.Response
switch resp.StatusCode() {
case fasthttp.StatusOK:
fsOKResponses.Add(1)
fsResponseBodyBytes.Add(int64(resp.Header.ContentLength()))
case fasthttp.StatusNotModified:
fsNotModifiedResponses.Add(1)
case fasthttp.StatusNotFound:
fsNotFoundResponses.Add(1)
default:
fsOtherResponses.Add(1)
}
}
// Various counters - see https://golang.org/pkg/expvar/ for details.
var (
// Counter for total number of fs calls
fsCalls = expvar.NewInt("fsCalls")
// Counters for various response status codes
fsOKResponses = expvar.NewInt("fsOKResponses")
fsNotModifiedResponses = expvar.NewInt("fsNotModifiedResponses")
fsNotFoundResponses = expvar.NewInt("fsNotFoundResponses")
fsOtherResponses = expvar.NewInt("fsOtherResponses")
// Total size in bytes for OK response bodies served.
fsResponseBodyBytes = expvar.NewInt("fsResponseBodyBytes")
)
fasthttp-1.31.0/examples/fileserver/ssl-cert-snakeoil.key 0000664 0000000 0000000 00000003250 14130360711 0023462 0 ustar 00root root 0000000 0000000 -----BEGIN PRIVATE KEY-----
MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQD4IQusAs8PJdnG
3mURt/AXtgC+ceqLOatJ49JJE1VPTkMAy+oE1f1XvkMrYsHqmDf6GWVzgVXryL4U
wq2/nJSm56ddhN55nI8oSN3dtywUB8/ShelEN73nlN77PeD9tl6NksPwWaKrqxq0
FlabRPZSQCfmgZbhDV8Sa8mfCkFU0G0lit6kLGceCKMvmW+9Bz7ebsYmVdmVMxmf
IJStFD44lWFTdUc65WISKEdW2ELcUefb0zOLw+0PCbXFGJH5x5ktksW8+BBk2Hkg
GeQRL/qPCccthbScO0VgNj3zJ3ZZL0ObSDAbvNDG85joeNjDNq5DT/BAZ0bOSbEF
sh+f9BAzAgMBAAECggEBAJWv2cq7Jw6MVwSRxYca38xuD6TUNBopgBvjREixURW2
sNUaLuMb9Omp7fuOaE2N5rcJ+xnjPGIxh/oeN5MQctz9gwn3zf6vY+15h97pUb4D
uGvYPRDaT8YVGS+X9NMZ4ZCmqW2lpWzKnCFoGHcy8yZLbcaxBsRdvKzwOYGoPiFb
K2QuhXZ/1UPmqK9i2DFKtj40X6vBszTNboFxOVpXrPu0FJwLVSDf2hSZ4fMM0DH3
YqwKcYf5te+hxGKgrqRA3tn0NCWii0in6QIwXMC+kMw1ebg/tZKqyDLMNptAK8J+
DVw9m5X1seUHS5ehU/g2jrQrtK5WYn7MrFK4lBzlRwECgYEA/d1TeANYECDWRRDk
B0aaRZs87Rwl/J9PsvbsKvtU/bX+OfSOUjOa9iQBqn0LmU8GqusEET/QVUfocVwV
Bggf/5qDLxz100Rj0ags/yE/kNr0Bb31kkkKHFMnCT06YasR7qKllwrAlPJvQv9x
IzBKq+T/Dx08Wep9bCRSFhzRCnsCgYEA+jdeZXTDr/Vz+D2B3nAw1frqYFfGnEVY
wqmoK3VXMDkGuxsloO2rN+SyiUo3JNiQNPDub/t7175GH5pmKtZOlftePANsUjBj
wZ1D0rI5Bxu/71ibIUYIRVmXsTEQkh/ozoh3jXCZ9+bLgYiYx7789IUZZSokFQ3D
FICUT9KJ36kCgYAGoq9Y1rWJjmIrYfqj2guUQC+CfxbbGIrrwZqAsRsSmpwvhZ3m
tiSZxG0quKQB+NfSxdvQW5ulbwC7Xc3K35F+i9pb8+TVBdeaFkw+yu6vaZmxQLrX
fQM/pEjD7A7HmMIaO7QaU5SfEAsqdCTP56Y8AftMuNXn/8IRfo2KuGwaWwKBgFpU
ILzJoVdlad9E/Rw7LjYhZfkv1uBVXIyxyKcfrkEXZSmozDXDdxsvcZCEfVHM6Ipk
K/+7LuMcqp4AFEAEq8wTOdq6daFaHLkpt/FZK6M4TlruhtpFOPkoNc3e45eM83OT
6mziKINJC1CQ6m65sQHpBtjxlKMRG8rL/D6wx9s5AoGBAMRlqNPMwglT3hvDmsAt
9Lf9pdmhERUlHhD8bj8mDaBj2Aqv7f6VRJaYZqP403pKKQexuqcn80mtjkSAPFkN
Cj7BVt/RXm5uoxDTnfi26RF9F6yNDEJ7UU9+peBr99aazF/fTgW/1GcMkQnum8uV
c257YgaWmjK9uB0Y2r2VxS0G
-----END PRIVATE KEY-----
fasthttp-1.31.0/examples/fileserver/ssl-cert-snakeoil.pem 0000664 0000000 0000000 00000001755 14130360711 0023463 0 ustar 00root root 0000000 0000000 -----BEGIN CERTIFICATE-----
MIICujCCAaKgAwIBAgIJAMbXnKZ/cikUMA0GCSqGSIb3DQEBCwUAMBUxEzARBgNV
BAMTCnVidW50dS5uYW4wHhcNMTUwMjA0MDgwMTM5WhcNMjUwMjAxMDgwMTM5WjAV
MRMwEQYDVQQDEwp1YnVudHUubmFuMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIB
CgKCAQEA+CELrALPDyXZxt5lEbfwF7YAvnHqizmrSePSSRNVT05DAMvqBNX9V75D
K2LB6pg3+hllc4FV68i+FMKtv5yUpuenXYTeeZyPKEjd3bcsFAfP0oXpRDe955Te
+z3g/bZejZLD8Fmiq6satBZWm0T2UkAn5oGW4Q1fEmvJnwpBVNBtJYrepCxnHgij
L5lvvQc+3m7GJlXZlTMZnyCUrRQ+OJVhU3VHOuViEihHVthC3FHn29Mzi8PtDwm1
xRiR+ceZLZLFvPgQZNh5IBnkES/6jwnHLYW0nDtFYDY98yd2WS9Dm0gwG7zQxvOY
6HjYwzauQ0/wQGdGzkmxBbIfn/QQMwIDAQABow0wCzAJBgNVHRMEAjAAMA0GCSqG
SIb3DQEBCwUAA4IBAQBQjKm/4KN/iTgXbLTL3i7zaxYXFLXsnT1tF+ay4VA8aj98
L3JwRTciZ3A5iy/W4VSCt3eASwOaPWHKqDBB5RTtL73LoAqsWmO3APOGQAbixcQ2
45GXi05OKeyiYRi1Nvq7Unv9jUkRDHUYVPZVSAjCpsXzPhFkmZoTRxmx5l0ZF7Li
K91lI5h+eFq0dwZwrmlPambyh1vQUi70VHv8DNToVU29kel7YLbxGbuqETfhrcy6
X+Mha6RYITkAn5FqsZcKMsc9eYGEF4l3XV+oS7q6xfTxktYJMFTI18J0lQ2Lv/CI
whdMnYGntDQBE/iFCrJEGNsKGc38796GBOb5j+zd
-----END CERTIFICATE-----
fasthttp-1.31.0/examples/helloworldserver/ 0000775 0000000 0000000 00000000000 14130360711 0020645 5 ustar 00root root 0000000 0000000 fasthttp-1.31.0/examples/helloworldserver/.gitignore 0000664 0000000 0000000 00000000021 14130360711 0022626 0 ustar 00root root 0000000 0000000 helloworldserver
fasthttp-1.31.0/examples/helloworldserver/Makefile 0000664 0000000 0000000 00000000151 14130360711 0022302 0 ustar 00root root 0000000 0000000 helloworldserver: clean
go get -u github.com/valyala/fasthttp
go build
clean:
rm -f helloworldserver
fasthttp-1.31.0/examples/helloworldserver/README.md 0000664 0000000 0000000 00000000353 14130360711 0022125 0 ustar 00root root 0000000 0000000 # HelloWorld server example
* Displays various request info.
* Sets response headers and cookies.
* Supports transparent compression.
# How to build
```
make
```
# How to run
```
./helloworldserver -addr=tcp.addr.to.listen:to
```
fasthttp-1.31.0/examples/helloworldserver/helloworldserver.go 0000664 0000000 0000000 00000002776 14130360711 0024612 0 ustar 00root root 0000000 0000000 package main
import (
"flag"
"fmt"
"log"
"github.com/valyala/fasthttp"
)
var (
addr = flag.String("addr", ":8080", "TCP address to listen to")
compress = flag.Bool("compress", false, "Whether to enable transparent response compression")
)
func main() {
flag.Parse()
h := requestHandler
if *compress {
h = fasthttp.CompressHandler(h)
}
if err := fasthttp.ListenAndServe(*addr, h); err != nil {
log.Fatalf("Error in ListenAndServe: %s", err)
}
}
func requestHandler(ctx *fasthttp.RequestCtx) {
fmt.Fprintf(ctx, "Hello, world!\n\n")
fmt.Fprintf(ctx, "Request method is %q\n", ctx.Method())
fmt.Fprintf(ctx, "RequestURI is %q\n", ctx.RequestURI())
fmt.Fprintf(ctx, "Requested path is %q\n", ctx.Path())
fmt.Fprintf(ctx, "Host is %q\n", ctx.Host())
fmt.Fprintf(ctx, "Query string is %q\n", ctx.QueryArgs())
fmt.Fprintf(ctx, "User-Agent is %q\n", ctx.UserAgent())
fmt.Fprintf(ctx, "Connection has been established at %s\n", ctx.ConnTime())
fmt.Fprintf(ctx, "Request has been started at %s\n", ctx.Time())
fmt.Fprintf(ctx, "Serial request number for the current connection is %d\n", ctx.ConnRequestNum())
fmt.Fprintf(ctx, "Your ip is %q\n\n", ctx.RemoteIP())
fmt.Fprintf(ctx, "Raw request is:\n---CUT---\n%s\n---CUT---", &ctx.Request)
ctx.SetContentType("text/plain; charset=utf8")
// Set arbitrary headers
ctx.Response.Header.Set("X-My-Header", "my-header-value")
// Set cookies
var c fasthttp.Cookie
c.SetKey("cookie-name")
c.SetValue("cookie-value")
ctx.Response.Header.SetCookie(&c)
}
fasthttp-1.31.0/examples/letsencrypt/ 0000775 0000000 0000000 00000000000 14130360711 0017617 5 ustar 00root root 0000000 0000000 fasthttp-1.31.0/examples/letsencrypt/letsencryptserver.go 0000664 0000000 0000000 00000001474 14130360711 0023757 0 ustar 00root root 0000000 0000000 package main
import (
"crypto/tls"
"net"
"github.com/valyala/fasthttp"
"golang.org/x/crypto/acme"
"golang.org/x/crypto/acme/autocert"
)
func requestHandler(ctx *fasthttp.RequestCtx) {
ctx.SetBodyString("hello from https!")
}
func main() {
m := &autocert.Manager{
Prompt: autocert.AcceptTOS,
HostPolicy: autocert.HostWhitelist("example.com"), // Replace with your domain.
Cache: autocert.DirCache("./certs"),
}
cfg := &tls.Config{
GetCertificate: m.GetCertificate,
NextProtos: []string{
"http/1.1", acme.ALPNProto,
},
}
// Let's Encrypt tls-alpn-01 only works on port 443.
ln, err := net.Listen("tcp4", "0.0.0.0:443") /* #nosec G102 */
if err != nil {
panic(err)
}
lnTls := tls.NewListener(ln, cfg)
if err := fasthttp.Serve(lnTls, requestHandler); err != nil {
panic(err)
}
}
fasthttp-1.31.0/examples/multidomain/ 0000775 0000000 0000000 00000000000 14130360711 0017565 5 ustar 00root root 0000000 0000000 fasthttp-1.31.0/examples/multidomain/Makefile 0000664 0000000 0000000 00000000132 14130360711 0021221 0 ustar 00root root 0000000 0000000 writer: clean
go get -u github.com/valyala/fasthttp
go build
clean:
rm -f multidomain
fasthttp-1.31.0/examples/multidomain/README.md 0000664 0000000 0000000 00000000233 14130360711 0021042 0 ustar 00root root 0000000 0000000 # Multidomain using SSL certs example
* Prints two messages depending on visited host.
# How to build
```
make
```
# How to run
```
./multidomain
```
fasthttp-1.31.0/examples/multidomain/multidomain.go 0000664 0000000 0000000 00000002427 14130360711 0022443 0 ustar 00root root 0000000 0000000 package main
import (
"fmt"
"github.com/valyala/fasthttp"
)
var domains = make(map[string]fasthttp.RequestHandler)
func main() {
server := &fasthttp.Server{
// You can check the access using openssl command:
// $ openssl s_client -connect localhost:8080 << EOF
// > GET /
// > Host: localhost
// > EOF
//
// $ openssl s_client -connect localhost:8080 << EOF
// > GET /
// > Host: 127.0.0.1:8080
// > EOF
//
Handler: func(ctx *fasthttp.RequestCtx) {
h, ok := domains[string(ctx.Host())]
if !ok {
ctx.NotFound()
return
}
h(ctx)
},
}
// preparing first host
cert, priv, err := fasthttp.GenerateTestCertificate("localhost:8080")
if err != nil {
panic(err)
}
domains["localhost:8080"] = func(ctx *fasthttp.RequestCtx) {
ctx.Write([]byte("You are accessing to localhost:8080\n"))
}
err = server.AppendCertEmbed(cert, priv)
if err != nil {
panic(err)
}
// preparing second host
cert, priv, err = fasthttp.GenerateTestCertificate("127.0.0.1")
if err != nil {
panic(err)
}
domains["127.0.0.1:8080"] = func(ctx *fasthttp.RequestCtx) {
ctx.Write([]byte("You are accessing to 127.0.0.1:8080\n"))
}
err = server.AppendCertEmbed(cert, priv)
if err != nil {
panic(err)
}
fmt.Println(server.ListenAndServeTLS(":8080", "", ""))
}
fasthttp-1.31.0/expvarhandler/ 0000775 0000000 0000000 00000000000 14130360711 0016270 5 ustar 00root root 0000000 0000000 fasthttp-1.31.0/expvarhandler/expvar.go 0000664 0000000 0000000 00000002647 14130360711 0020135 0 ustar 00root root 0000000 0000000 // Package expvarhandler provides fasthttp-compatible request handler
// serving expvars.
package expvarhandler
import (
"expvar"
"fmt"
"regexp"
"github.com/valyala/fasthttp"
)
var (
expvarHandlerCalls = expvar.NewInt("expvarHandlerCalls")
expvarRegexpErrors = expvar.NewInt("expvarRegexpErrors")
defaultRE = regexp.MustCompile(".")
)
// ExpvarHandler dumps a JSON representation of expvars to the HTTP response.
//
// Expvars may be filtered by a regexp provided via the 'r' query argument.
//
// See https://golang.org/pkg/expvar/ for details.
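//
// Example usage (a sketch; the /stats route is an arbitrary choice):
//
// requestHandler := func(ctx *fasthttp.RequestCtx) {
// switch string(ctx.Path()) {
// case "/stats":
// expvarhandler.ExpvarHandler(ctx)
// default:
// ctx.NotFound()
// }
// }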
func ExpvarHandler(ctx *fasthttp.RequestCtx) {
expvarHandlerCalls.Add(1)
ctx.Response.Reset()
r, err := getExpvarRegexp(ctx)
if err != nil {
expvarRegexpErrors.Add(1)
fmt.Fprintf(ctx, "Error when obtaining expvar regexp: %s", err)
ctx.SetStatusCode(fasthttp.StatusBadRequest)
return
}
fmt.Fprintf(ctx, "{\n")
first := true
expvar.Do(func(kv expvar.KeyValue) {
if r.MatchString(kv.Key) {
if !first {
fmt.Fprintf(ctx, ",\n")
}
first = false
fmt.Fprintf(ctx, "\t%q: %s", kv.Key, kv.Value)
}
})
fmt.Fprintf(ctx, "\n}\n")
ctx.SetContentType("application/json; charset=utf-8")
}
func getExpvarRegexp(ctx *fasthttp.RequestCtx) (*regexp.Regexp, error) {
r := string(ctx.QueryArgs().Peek("r"))
if len(r) == 0 {
return defaultRE, nil
}
rr, err := regexp.Compile(r)
if err != nil {
return nil, fmt.Errorf("cannot parse r=%q: %s", r, err)
}
return rr, nil
}
fasthttp-1.31.0/expvarhandler/expvar_test.go 0000664 0000000 0000000 00000002667 14130360711 0021176 0 ustar 00root root 0000000 0000000 package expvarhandler
import (
"encoding/json"
"expvar"
"strings"
"testing"
"github.com/valyala/fasthttp"
)
func TestExpvarHandlerBasic(t *testing.T) {
t.Parallel()
expvar.Publish("customVar", expvar.Func(func() interface{} {
return "foobar"
}))
var ctx fasthttp.RequestCtx
expvarHandlerCalls.Set(0)
ExpvarHandler(&ctx)
body := ctx.Response.Body()
var m map[string]interface{}
if err := json.Unmarshal(body, &m); err != nil {
t.Fatalf("unexpected error: %s", err)
}
if _, ok := m["cmdline"]; !ok {
t.Fatalf("cannot locate cmdline expvar")
}
if _, ok := m["memstats"]; !ok {
t.Fatalf("cannot locate memstats expvar")
}
v := m["customVar"]
sv, ok := v.(string)
if !ok {
t.Fatalf("unexpected custom var type %T. Expecting string", v)
}
if sv != "foobar" {
t.Fatalf("unexpected custom var value: %q. Expecting %q", v, "foobar")
}
v = m["expvarHandlerCalls"]
fv, ok := v.(float64)
if !ok {
t.Fatalf("unexpected expvarHandlerCalls type %T. Expecting float64", v)
}
if int(fv) != 1 {
t.Fatalf("unexpected value for expvarHandlerCalls: %v. Expecting %v", fv, 1)
}
}
func TestExpvarHandlerRegexp(t *testing.T) {
var ctx fasthttp.RequestCtx
ctx.QueryArgs().Set("r", "cmd")
ExpvarHandler(&ctx)
body := string(ctx.Response.Body())
if !strings.Contains(body, `"cmdline"`) {
t.Fatalf("missing 'cmdline' expvar")
}
if strings.Contains(body, `"memstats"`) {
t.Fatalf("unexpected memstats expvar found")
}
}
fasthttp-1.31.0/fasthttpadaptor/ 0000775 0000000 0000000 00000000000 14130360711 0016635 5 ustar 00root root 0000000 0000000 fasthttp-1.31.0/fasthttpadaptor/adaptor.go 0000664 0000000 0000000 00000007045 14130360711 0020624 0 ustar 00root root 0000000 0000000 // Package fasthttpadaptor provides helper functions for converting net/http
// request handlers to fasthttp request handlers.
package fasthttpadaptor
import (
"net/http"
"github.com/valyala/fasthttp"
)
// NewFastHTTPHandlerFunc wraps net/http handler func to fasthttp
// request handler, so it can be passed to fasthttp server.
//
// While this function may be used for easy switching from net/http to fasthttp,
// it has the following drawbacks compared to using a manually written fasthttp
// request handler:
//
// * A lot of useful functionality provided by fasthttp is missing
// from net/http handler.
// * net/http -> fasthttp handler conversion has some overhead,
// so the returned handler will always be slower than a manually written
// fasthttp handler.
//
// So it is advisable to use this function only for quick net/http -> fasthttp
// switching. Then manually convert net/http handlers to fasthttp handlers
// according to https://github.com/valyala/fasthttp#switching-from-nethttp-to-fasthttp .
func NewFastHTTPHandlerFunc(h http.HandlerFunc) fasthttp.RequestHandler {
return NewFastHTTPHandler(h)
}
// NewFastHTTPHandler wraps net/http handler to fasthttp request handler,
// so it can be passed to fasthttp server.
//
// While this function may be used for easy switching from net/http to fasthttp,
// it has the following drawbacks compared to using a manually written fasthttp
// request handler:
//
// * A lot of useful functionality provided by fasthttp is missing
// from net/http handler.
// * net/http -> fasthttp handler conversion has some overhead,
// so the returned handler will always be slower than a manually written
// fasthttp handler.
//
// So it is advisable to use this function only for quick net/http -> fasthttp
// switching. Then manually convert net/http handlers to fasthttp handlers
// according to https://github.com/valyala/fasthttp#switching-from-nethttp-to-fasthttp .
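//
// Example usage (a sketch; the wrapped net/http handler is illustrative):
//
// nethttpHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// fmt.Fprintf(w, "Hello from net/http!")
// })
// fasthttp.ListenAndServe(":8080", fasthttpadaptor.NewFastHTTPHandler(nethttpHandler)) //nolint:errcheck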
func NewFastHTTPHandler(h http.Handler) fasthttp.RequestHandler {
return func(ctx *fasthttp.RequestCtx) {
var r http.Request
if err := ConvertRequest(ctx, &r, true); err != nil {
ctx.Logger().Printf("cannot parse requestURI %q: %s", r.RequestURI, err)
ctx.Error("Internal Server Error", fasthttp.StatusInternalServerError)
return
}
var w netHTTPResponseWriter
h.ServeHTTP(&w, r.WithContext(ctx))
ctx.SetStatusCode(w.StatusCode())
haveContentType := false
for k, vv := range w.Header() {
if k == fasthttp.HeaderContentType {
haveContentType = true
}
for _, v := range vv {
ctx.Response.Header.Add(k, v)
}
}
if !haveContentType {
// From net/http.ResponseWriter.Write:
// If the Header does not contain a Content-Type line, Write adds a Content-Type set
// to the result of passing the initial 512 bytes of written data to DetectContentType.
l := 512
if len(w.body) < 512 {
l = len(w.body)
}
ctx.Response.Header.Set(fasthttp.HeaderContentType, http.DetectContentType(w.body[:l]))
}
ctx.Write(w.body) //nolint:errcheck
}
}
type netHTTPResponseWriter struct {
statusCode int
h http.Header
body []byte
}
func (w *netHTTPResponseWriter) StatusCode() int {
if w.statusCode == 0 {
return http.StatusOK
}
return w.statusCode
}
func (w *netHTTPResponseWriter) Header() http.Header {
if w.h == nil {
w.h = make(http.Header)
}
return w.h
}
func (w *netHTTPResponseWriter) WriteHeader(statusCode int) {
w.statusCode = statusCode
}
func (w *netHTTPResponseWriter) Write(p []byte) (int, error) {
w.body = append(w.body, p...)
return len(p), nil
}
fasthttp-1.31.0/fasthttpadaptor/adaptor_test.go 0000664 0000000 0000000 00000012272 14130360711 0021661 0 ustar 00root root 0000000 0000000 package fasthttpadaptor
import (
"fmt"
"io/ioutil"
"net"
"net/http"
"net/url"
"reflect"
"testing"
"github.com/valyala/fasthttp"
)
func TestNewFastHTTPHandler(t *testing.T) {
t.Parallel()
expectedMethod := fasthttp.MethodPost
expectedProto := "HTTP/1.1"
expectedProtoMajor := 1
expectedProtoMinor := 1
expectedRequestURI := "/foo/bar?baz=123"
expectedBody := "body 123 foo bar baz"
expectedContentLength := len(expectedBody)
expectedHost := "foobar.com"
expectedRemoteAddr := "1.2.3.4:6789"
expectedHeader := map[string]string{
"Foo-Bar": "baz",
"Abc": "defg",
"XXX-Remote-Addr": "123.43.4543.345",
}
expectedURL, err := url.ParseRequestURI(expectedRequestURI)
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
expectedContextKey := "contextKey"
expectedContextValue := "contextValue"
callsCount := 0
nethttpH := func(w http.ResponseWriter, r *http.Request) {
callsCount++
if r.Method != expectedMethod {
t.Fatalf("unexpected method %q. Expecting %q", r.Method, expectedMethod)
}
if r.Proto != expectedProto {
t.Fatalf("unexpected proto %q. Expecting %q", r.Proto, expectedProto)
}
if r.ProtoMajor != expectedProtoMajor {
t.Fatalf("unexpected protoMajor %d. Expecting %d", r.ProtoMajor, expectedProtoMajor)
}
if r.ProtoMinor != expectedProtoMinor {
t.Fatalf("unexpected protoMinor %d. Expecting %d", r.ProtoMinor, expectedProtoMinor)
}
if r.RequestURI != expectedRequestURI {
t.Fatalf("unexpected requestURI %q. Expecting %q", r.RequestURI, expectedRequestURI)
}
if r.ContentLength != int64(expectedContentLength) {
t.Fatalf("unexpected contentLength %d. Expecting %d", r.ContentLength, expectedContentLength)
}
if len(r.TransferEncoding) != 0 {
t.Fatalf("unexpected transferEncoding %q. Expecting []", r.TransferEncoding)
}
if r.Host != expectedHost {
t.Fatalf("unexpected host %q. Expecting %q", r.Host, expectedHost)
}
if r.RemoteAddr != expectedRemoteAddr {
t.Fatalf("unexpected remoteAddr %q. Expecting %q", r.RemoteAddr, expectedRemoteAddr)
}
body, err := ioutil.ReadAll(r.Body)
r.Body.Close()
if err != nil {
t.Fatalf("unexpected error when reading request body: %s", err)
}
if string(body) != expectedBody {
t.Fatalf("unexpected body %q. Expecting %q", body, expectedBody)
}
if !reflect.DeepEqual(r.URL, expectedURL) {
t.Fatalf("unexpected URL: %#v. Expecting %#v", r.URL, expectedURL)
}
if r.Context().Value(expectedContextKey) != expectedContextValue {
t.Fatalf("unexpected context value for key %q. Expecting %q", expectedContextKey, expectedContextValue)
}
for k, expectedV := range expectedHeader {
v := r.Header.Get(k)
if v != expectedV {
t.Fatalf("unexpected header value %q for key %q. Expecting %q", v, k, expectedV)
}
}
w.Header().Set("Header1", "value1")
w.Header().Set("Header2", "value2")
w.WriteHeader(http.StatusBadRequest)
fmt.Fprintf(w, "request body is %q", body)
}
fasthttpH := NewFastHTTPHandler(http.HandlerFunc(nethttpH))
fasthttpH = setContextValueMiddleware(fasthttpH, expectedContextKey, expectedContextValue)
var ctx fasthttp.RequestCtx
var req fasthttp.Request
req.Header.SetMethod(expectedMethod)
req.SetRequestURI(expectedRequestURI)
req.Header.SetHost(expectedHost)
req.BodyWriter().Write([]byte(expectedBody)) // nolint:errcheck
for k, v := range expectedHeader {
req.Header.Set(k, v)
}
remoteAddr, err := net.ResolveTCPAddr("tcp", expectedRemoteAddr)
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
ctx.Init(&req, remoteAddr, nil)
fasthttpH(&ctx)
if callsCount != 1 {
t.Fatalf("unexpected callsCount: %d. Expecting 1", callsCount)
}
resp := &ctx.Response
if resp.StatusCode() != fasthttp.StatusBadRequest {
t.Fatalf("unexpected statusCode: %d. Expecting %d", resp.StatusCode(), fasthttp.StatusBadRequest)
}
if string(resp.Header.Peek("Header1")) != "value1" {
t.Fatalf("unexpected header value: %q. Expecting %q", resp.Header.Peek("Header1"), "value1")
}
if string(resp.Header.Peek("Header2")) != "value2" {
t.Fatalf("unexpected header value: %q. Expecting %q", resp.Header.Peek("Header2"), "value2")
}
expectedResponseBody := fmt.Sprintf("request body is %q", expectedBody)
if string(resp.Body()) != expectedResponseBody {
t.Fatalf("unexpected response body %q. Expecting %q", resp.Body(), expectedResponseBody)
}
}
func setContextValueMiddleware(next fasthttp.RequestHandler, key string, value interface{}) fasthttp.RequestHandler {
return func(ctx *fasthttp.RequestCtx) {
ctx.SetUserValue(key, value)
next(ctx)
}
}
func TestContentType(t *testing.T) {
t.Parallel()
nethttpH := func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("")) //nolint:errcheck
}
fasthttpH := NewFastHTTPHandler(http.HandlerFunc(nethttpH))
var ctx fasthttp.RequestCtx
var req fasthttp.Request
req.SetRequestURI("http://example.com")
remoteAddr, err := net.ResolveTCPAddr("tcp", "1.2.3.4:80")
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
ctx.Init(&req, remoteAddr, nil)
fasthttpH(&ctx)
resp := &ctx.Response
got := string(resp.Header.Peek("Content-Type"))
expected := "text/html; charset=utf-8"
if got != expected {
t.Errorf("expected %q got %q", expected, got)
}
}
fasthttp-1.31.0/fasthttpadaptor/request.go 0000664 0000000 0000000 00000002333 14130360711 0020655 0 ustar 00root root 0000000 0000000 package fasthttpadaptor
import (
"bytes"
"io/ioutil"
"net/http"
"net/url"
"github.com/valyala/fasthttp"
)
// ConvertRequest converts a fasthttp.Request to an http.Request.
//
// forServer should be set to true when the http.Request is going to be passed to an http.Handler.
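//
// Example usage inside a fasthttp.RequestHandler (a sketch):
//
// var r http.Request
// if err := fasthttpadaptor.ConvertRequest(ctx, &r, true); err != nil {
// ctx.Error("cannot convert request", fasthttp.StatusInternalServerError)
// return
// }
// // r can now be handed to code that expects an *http.Request.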
func ConvertRequest(ctx *fasthttp.RequestCtx, r *http.Request, forServer bool) error {
body := ctx.PostBody()
strRequestURI := string(ctx.RequestURI())
rURL, err := url.ParseRequestURI(strRequestURI)
if err != nil {
return err
}
r.Method = string(ctx.Method())
r.Proto = "HTTP/1.1"
r.ProtoMajor = 1
r.ProtoMinor = 1
r.ContentLength = int64(len(body))
r.RemoteAddr = ctx.RemoteAddr().String()
r.Host = string(ctx.Host())
r.TLS = ctx.TLSConnectionState()
r.Body = ioutil.NopCloser(bytes.NewReader(body))
r.URL = rURL
if forServer {
r.RequestURI = strRequestURI
}
if r.Header == nil {
r.Header = make(http.Header)
} else if len(r.Header) > 0 {
for k := range r.Header {
delete(r.Header, k)
}
}
ctx.Request.Header.VisitAll(func(k, v []byte) {
sk := string(k)
sv := string(v)
switch sk {
case "Transfer-Encoding":
r.TransferEncoding = append(r.TransferEncoding, sv)
default:
r.Header.Set(sk, sv)
}
})
return nil
}
fasthttp-1.31.0/fasthttpproxy/ 0000775 0000000 0000000 00000000000 14130360711 0016364 5 ustar 00root root 0000000 0000000 fasthttp-1.31.0/fasthttpproxy/http.go 0000664 0000000 0000000 00000003474 14130360711 0017702 0 ustar 00root root 0000000 0000000 package fasthttpproxy
import (
"bufio"
"encoding/base64"
"fmt"
"net"
"strings"
"time"
"github.com/valyala/fasthttp"
)
// FasthttpHTTPDialer returns a fasthttp.DialFunc that dials using
// the provided HTTP proxy.
//
// Example usage:
// c := &fasthttp.Client{
// Dial: fasthttpproxy.FasthttpHTTPDialer("username:password@localhost:9050"),
// }
func FasthttpHTTPDialer(proxy string) fasthttp.DialFunc {
return FasthttpHTTPDialerTimeout(proxy, 0)
}
// FasthttpHTTPDialerTimeout returns a fasthttp.DialFunc that dials using
// the provided HTTP proxy using the given timeout.
//
// Example usage:
// c := &fasthttp.Client{
// Dial: fasthttpproxy.FasthttpHTTPDialerTimeout("username:password@localhost:9050", time.Second * 2),
// }
func FasthttpHTTPDialerTimeout(proxy string, timeout time.Duration) fasthttp.DialFunc {
var auth string
if strings.Contains(proxy, "@") {
split := strings.Split(proxy, "@")
auth = base64.StdEncoding.EncodeToString([]byte(split[0]))
proxy = split[1]
}
return func(addr string) (net.Conn, error) {
var conn net.Conn
var err error
if timeout == 0 {
conn, err = fasthttp.Dial(proxy)
} else {
conn, err = fasthttp.DialTimeout(proxy, timeout)
}
if err != nil {
return nil, err
}
req := "CONNECT " + addr + " HTTP/1.1\r\n"
if auth != "" {
req += "Proxy-Authorization: Basic " + auth + "\r\n"
}
req += "\r\n"
if _, err := conn.Write([]byte(req)); err != nil {
return nil, err
}
res := fasthttp.AcquireResponse()
defer fasthttp.ReleaseResponse(res)
res.SkipBody = true
if err := res.Read(bufio.NewReader(conn)); err != nil {
conn.Close()
return nil, err
}
if res.Header.StatusCode() != 200 {
conn.Close()
return nil, fmt.Errorf("could not connect to proxy: %s status code: %d", proxy, res.Header.StatusCode())
}
return conn, nil
}
}
fasthttp-1.31.0/fasthttpproxy/proxy_env.go 0000664 0000000 0000000 00000006015 14130360711 0020746 0 ustar 00root root 0000000 0000000 package fasthttpproxy
import (
"bufio"
"encoding/base64"
"fmt"
"net"
"net/url"
"sync/atomic"
"time"
"golang.org/x/net/http/httpproxy"
"github.com/valyala/fasthttp"
)
const (
httpsScheme = "https"
httpScheme = "http"
tlsPort = "443"
)
// FasthttpProxyHTTPDialer returns a fasthttp.DialFunc that dials using
// the env(HTTP_PROXY, HTTPS_PROXY and NO_PROXY) configured HTTP proxy.
//
// Example usage:
// c := &fasthttp.Client{
// Dial: FasthttpProxyHTTPDialer(),
// }
func FasthttpProxyHTTPDialer() fasthttp.DialFunc {
return FasthttpProxyHTTPDialerTimeout(0)
}
// FasthttpProxyHTTPDialerTimeout returns a fasthttp.DialFunc that dials using
// the env(HTTP_PROXY, HTTPS_PROXY and NO_PROXY) configured HTTP proxy using the given timeout.
//
// Example usage:
// c := &fasthttp.Client{
// Dial: FasthttpProxyHTTPDialerTimeout(time.Second * 2),
// }
func FasthttpProxyHTTPDialerTimeout(timeout time.Duration) fasthttp.DialFunc {
proxier := httpproxy.FromEnvironment().ProxyFunc()
// encoded auth barrier for http and https proxy.
authHTTPStorage := &atomic.Value{}
authHTTPSStorage := &atomic.Value{}
return func(addr string) (net.Conn, error) {
_, port, err := net.SplitHostPort(addr)
if err != nil {
return nil, fmt.Errorf("unexpected addr format: %v", err)
}
reqURL := &url.URL{Host: addr, Scheme: httpScheme}
if port == tlsPort {
reqURL.Scheme = httpsScheme
}
proxyURL, err := proxier(reqURL)
if err != nil {
return nil, err
}
if proxyURL == nil {
if timeout == 0 {
return fasthttp.Dial(addr)
}
return fasthttp.DialTimeout(addr, timeout)
}
var conn net.Conn
if timeout == 0 {
conn, err = fasthttp.Dial(proxyURL.Host)
} else {
conn, err = fasthttp.DialTimeout(proxyURL.Host, timeout)
}
if err != nil {
return nil, err
}
req := "CONNECT " + addr + " HTTP/1.1\r\n"
if proxyURL.User != nil {
authBarrierStorage := authHTTPStorage
if port == tlsPort {
authBarrierStorage = authHTTPSStorage
}
auth := authBarrierStorage.Load()
if auth == nil {
authBarrier := base64.StdEncoding.EncodeToString([]byte(proxyURL.User.String()))
auth := &authBarrier
authBarrierStorage.Store(auth)
}
req += "Proxy-Authorization: Basic " + *auth.(*string) + "\r\n"
}
req += "\r\n"
if _, err := conn.Write([]byte(req)); err != nil {
return nil, err
}
res := fasthttp.AcquireResponse()
defer fasthttp.ReleaseResponse(res)
res.SkipBody = true
if err := res.Read(bufio.NewReader(conn)); err != nil {
if connErr := conn.Close(); connErr != nil {
return nil, fmt.Errorf("conn close err %v followed by read conn err %v", connErr, err)
}
return nil, err
}
if res.Header.StatusCode() != 200 {
if connErr := conn.Close(); connErr != nil {
return nil, fmt.Errorf(
"conn close err %v followed by connect to proxy: code: %d body %s",
connErr, res.StatusCode(), string(res.Body()))
}
return nil, fmt.Errorf("could not connect to proxy: code: %d body %s", res.StatusCode(), string(res.Body()))
}
return conn, nil
}
}
fasthttp-1.31.0/fasthttpproxy/socks5.go 0000664 0000000 0000000 00000001661 14130360711 0020126 0 ustar 00root root 0000000 0000000 package fasthttpproxy
import (
"net"
"net/url"
"github.com/valyala/fasthttp"
"golang.org/x/net/proxy"
)
// FasthttpSocksDialer returns a fasthttp.DialFunc that dials using
// the provided SOCKS5 proxy.
//
// Example usage:
// c := &fasthttp.Client{
// Dial: fasthttpproxy.FasthttpSocksDialer("socks5://localhost:9050"),
// }
func FasthttpSocksDialer(proxyAddr string) fasthttp.DialFunc {
var (
u *url.URL
err error
dialer proxy.Dialer
)
if u, err = url.Parse(proxyAddr); err == nil {
dialer, err = proxy.FromURL(u, proxy.Direct)
}
// It would be nice if we could return the error here, but we can't
// change our API, so we keep returning it from the returned Dial function.
// Besides, the implementation of proxy.SOCKS5() at the time of writing
// always returns nil as the error.
return func(addr string) (net.Conn, error) {
if err != nil {
return nil, err
}
return dialer.Dial("tcp", addr)
}
}
fasthttp-1.31.0/fasthttputil/ 0000775 0000000 0000000 00000000000 14130360711 0016160 5 ustar 00root root 0000000 0000000 fasthttp-1.31.0/fasthttputil/doc.go 0000664 0000000 0000000 00000000126 14130360711 0017253 0 ustar 00root root 0000000 0000000 // Package fasthttputil provides utility functions for fasthttp.
package fasthttputil
fasthttp-1.31.0/fasthttputil/ecdsa.key 0000664 0000000 0000000 00000000343 14130360711 0017751 0 ustar 00root root 0000000 0000000 -----BEGIN EC PRIVATE KEY-----
MHcCAQEEIBpQbZ6a5jL1Yh4wdP6yZk4MKjYWArD/QOLENFw8vbELoAoGCCqGSM49
AwEHoUQDQgAEKQCZWgE2IBhb47ot8MIs1D4KSisHYlZ41IWyeutpjb0fjwwIhimh
pl1Qld1/d2j3Z3vVyfa5yD+ncV7qCFZuSg==
-----END EC PRIVATE KEY-----
fasthttp-1.31.0/fasthttputil/ecdsa.pem 0000664 0000000 0000000 00000001052 14130360711 0017740 0 ustar 00root root 0000000 0000000 -----BEGIN CERTIFICATE-----
MIIBbTCCAROgAwIBAgIQPo718S+K+G7hc1SgTEU4QDAKBggqhkjOPQQDAjASMRAw
DgYDVQQKEwdBY21lIENvMB4XDTE3MDQyMDIxMDExNFoXDTE4MDQyMDIxMDExNFow
EjEQMA4GA1UEChMHQWNtZSBDbzBZMBMGByqGSM49AgEGCCqGSM49AwEHA0IABCkA
mVoBNiAYW+O6LfDCLNQ+CkorB2JWeNSFsnrraY29H48MCIYpoaZdUJXdf3do92d7
1cn2ucg/p3Fe6ghWbkqjSzBJMA4GA1UdDwEB/wQEAwIFoDATBgNVHSUEDDAKBggr
BgEFBQcDATAMBgNVHRMBAf8EAjAAMBQGA1UdEQQNMAuCCWxvY2FsaG9zdDAKBggq
hkjOPQQDAgNIADBFAiEAoLAIQkvSuIcHUqyWroA6yWYw2fznlRH/uO9/hMCxUCEC
IClRYb/5O9eD/Eq/ozPnwNpsQHOeYefEhadJ/P82y0lG
-----END CERTIFICATE-----
fasthttp-1.31.0/fasthttputil/inmemory_listener.go 0000664 0000000 0000000 00000004342 14130360711 0022256 0 ustar 00root root 0000000 0000000 package fasthttputil
import (
"errors"
"net"
"sync"
)
// ErrInmemoryListenerClosed indicates that the InmemoryListener is already closed.
var ErrInmemoryListenerClosed = errors.New("InmemoryListener is already closed: use of closed network connection")
// InmemoryListener provides an in-memory dialer<->net.Listener implementation.
//
// It may be used either for fast in-process client<->server communications
// without network stack overhead or for client<->server tests.
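//
// Example usage in a test (a sketch; requestHandler is assumed to exist and
// error handling is omitted):
//
// ln := fasthttputil.NewInmemoryListener()
// go fasthttp.Serve(ln, requestHandler) //nolint:errcheck
// c := &fasthttp.Client{
// Dial: func(addr string) (net.Conn, error) { return ln.Dial() },
// }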
type InmemoryListener struct {
lock sync.Mutex
closed bool
conns chan acceptConn
}
type acceptConn struct {
conn net.Conn
accepted chan struct{}
}
// NewInmemoryListener returns a new in-memory dialer<->net.Listener.
func NewInmemoryListener() *InmemoryListener {
return &InmemoryListener{
conns: make(chan acceptConn, 1024),
}
}
// Accept implements net.Listener's Accept.
//
// It is safe calling Accept from concurrently running goroutines.
//
// Accept returns new connection per each Dial call.
func (ln *InmemoryListener) Accept() (net.Conn, error) {
c, ok := <-ln.conns
if !ok {
return nil, ErrInmemoryListenerClosed
}
close(c.accepted)
return c.conn, nil
}
// Close implements net.Listener's Close.
func (ln *InmemoryListener) Close() error {
var err error
ln.lock.Lock()
if !ln.closed {
close(ln.conns)
ln.closed = true
} else {
err = ErrInmemoryListenerClosed
}
ln.lock.Unlock()
return err
}
// Addr implements net.Listener's Addr.
func (ln *InmemoryListener) Addr() net.Addr {
return &net.UnixAddr{
Name: "InmemoryListener",
Net: "memory",
}
}
// Dial creates new client<->server connection.
// Just like a real Dial it only returns once the server
// has accepted the connection.
//
// It is safe calling Dial from concurrently running goroutines.
func (ln *InmemoryListener) Dial() (net.Conn, error) {
pc := NewPipeConns()
cConn := pc.Conn1()
sConn := pc.Conn2()
ln.lock.Lock()
accepted := make(chan struct{})
if !ln.closed {
ln.conns <- acceptConn{sConn, accepted}
// Wait until the connection has been accepted.
<-accepted
} else {
sConn.Close() //nolint:errcheck
cConn.Close() //nolint:errcheck
cConn = nil
}
ln.lock.Unlock()
if cConn == nil {
return nil, ErrInmemoryListenerClosed
}
return cConn, nil
}
fasthttp-1.31.0/fasthttputil/inmemory_listener_test.go 0000664 0000000 0000000 00000010356 14130360711 0023317 0 ustar 00root root 0000000 0000000 package fasthttputil
import (
"bytes"
"context"
"fmt"
"io"
"io/ioutil"
"net"
"net/http"
"sync"
"testing"
"time"
)
func TestInmemoryListener(t *testing.T) {
t.Parallel()
ln := NewInmemoryListener()
ch := make(chan struct{})
for i := 0; i < 10; i++ {
go func(n int) {
conn, err := ln.Dial()
if err != nil {
t.Errorf("unexpected error: %s", err)
}
defer conn.Close()
req := fmt.Sprintf("request_%d", n)
nn, err := conn.Write([]byte(req))
if err != nil {
t.Errorf("unexpected error: %s", err)
}
if nn != len(req) {
t.Errorf("unexpected number of bytes written: %d. Expecting %d", nn, len(req))
}
buf := make([]byte, 30)
nn, err = conn.Read(buf)
if err != nil {
t.Errorf("unexpected error: %s", err)
}
buf = buf[:nn]
resp := fmt.Sprintf("response_%d", n)
if nn != len(resp) {
t.Errorf("unexpected number of bytes read: %d. Expecting %d", nn, len(resp))
}
if string(buf) != resp {
t.Errorf("unexpected response %q. Expecting %q", buf, resp)
}
ch <- struct{}{}
}(i)
}
serverCh := make(chan struct{})
go func() {
for {
conn, err := ln.Accept()
if err != nil {
close(serverCh)
return
}
defer conn.Close()
buf := make([]byte, 30)
n, err := conn.Read(buf)
if err != nil {
t.Errorf("unexpected error: %s", err)
}
buf = buf[:n]
if !bytes.HasPrefix(buf, []byte("request_")) {
t.Errorf("unexpected request prefix %q. Expecting %q", buf, "request_")
}
resp := fmt.Sprintf("response_%s", buf[len("request_"):])
n, err = conn.Write([]byte(resp))
if err != nil {
t.Errorf("unexpected error: %s", err)
}
if n != len(resp) {
t.Errorf("unexpected number of bytes written: %d. Expecting %d", n, len(resp))
}
}
}()
for i := 0; i < 10; i++ {
select {
case <-ch:
case <-time.After(time.Second):
t.Fatalf("timeout")
}
}
if err := ln.Close(); err != nil {
t.Fatalf("unexpected error: %s", err)
}
select {
case <-serverCh:
case <-time.After(time.Second):
t.Fatalf("timeout")
}
}
// echoServerHandler implements http.Handler.
type echoServerHandler struct {
t *testing.T
}
func (s *echoServerHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(200)
time.Sleep(time.Millisecond * 100)
if _, err := io.Copy(w, r.Body); err != nil {
s.t.Fatalf("unexpected error: %s", err)
}
}
func testInmemoryListenerHTTP(t *testing.T, f func(t *testing.T, client *http.Client)) {
ln := NewInmemoryListener()
defer ln.Close()
client := &http.Client{
Transport: &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return ln.Dial()
},
},
Timeout: time.Second,
}
server := &http.Server{
Handler: &echoServerHandler{t},
}
go func() {
if err := server.Serve(ln); err != nil && err != http.ErrServerClosed {
t.Errorf("unexpected error: %s", err)
}
}()
f(t, client)
ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*100)
defer cancel()
server.Shutdown(ctx) //nolint:errcheck
}
func testInmemoryListenerHTTPSingle(t *testing.T, client *http.Client, content string) {
res, err := client.Post("http://...", "text/plain", bytes.NewBufferString(content))
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
b, err := ioutil.ReadAll(res.Body)
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
s := string(b)
if string(b) != content {
t.Fatalf("unexpected response %s, expecting %s", s, content)
}
}
func TestInmemoryListenerHTTPSingle(t *testing.T) {
t.Parallel()
testInmemoryListenerHTTP(t, func(t *testing.T, client *http.Client) {
testInmemoryListenerHTTPSingle(t, client, "request")
})
}
func TestInmemoryListenerHTTPSerial(t *testing.T) {
t.Parallel()
testInmemoryListenerHTTP(t, func(t *testing.T, client *http.Client) {
for i := 0; i < 10; i++ {
testInmemoryListenerHTTPSingle(t, client, fmt.Sprintf("request_%d", i))
}
})
}
func TestInmemoryListenerHTTPConcurrent(t *testing.T) {
t.Parallel()
testInmemoryListenerHTTP(t, func(t *testing.T, client *http.Client) {
var wg sync.WaitGroup
for i := 0; i < 10; i++ {
wg.Add(1)
go func(i int) {
defer wg.Done()
testInmemoryListenerHTTPSingle(t, client, fmt.Sprintf("request_%d", i))
}(i)
}
wg.Wait()
})
}
fasthttp-1.31.0/fasthttputil/inmemory_listener_timing_test.go 0000664 0000000 0000000 00000012272 14130360711 0024665 0 ustar 00root root 0000000 0000000 package fasthttputil_test
import (
"crypto/tls"
"net"
"testing"
"github.com/valyala/fasthttp"
"github.com/valyala/fasthttp/fasthttputil"
)
// BenchmarkPlainStreaming measures end-to-end plaintext streaming performance
// for fasthttp client and server.
//
// It issues http requests over a small number of keep-alive connections.
func BenchmarkPlainStreaming(b *testing.B) {
benchmark(b, streamingHandler, false)
}
// BenchmarkPlainHandshake measures end-to-end plaintext handshake performance
// for fasthttp client and server.
//
// It re-establishes a new connection for each http request.
func BenchmarkPlainHandshake(b *testing.B) {
benchmark(b, handshakeHandler, false)
}
// BenchmarkTLSStreaming measures end-to-end TLS streaming performance
// for fasthttp client and server.
//
// It issues http requests over a small number of TLS keep-alive connections.
func BenchmarkTLSStreaming(b *testing.B) {
benchmark(b, streamingHandler, true)
}
// BenchmarkTLSHandshake measures end-to-end TLS handshake performance
// for fasthttp client and server.
//
// It re-establishes a new TLS connection for each http request.
func BenchmarkTLSHandshakeRSAWithClientSessionCache(b *testing.B) {
bc := &benchConfig{
IsTLS: true,
DisableClientSessionCache: false,
}
benchmarkExt(b, handshakeHandler, bc)
}
func BenchmarkTLSHandshakeRSAWithoutClientSessionCache(b *testing.B) {
bc := &benchConfig{
IsTLS: true,
DisableClientSessionCache: true,
}
benchmarkExt(b, handshakeHandler, bc)
}
func BenchmarkTLSHandshakeECDSAWithClientSessionCache(b *testing.B) {
bc := &benchConfig{
IsTLS: true,
DisableClientSessionCache: false,
UseECDSA: true,
}
benchmarkExt(b, handshakeHandler, bc)
}
func BenchmarkTLSHandshakeECDSAWithoutClientSessionCache(b *testing.B) {
bc := &benchConfig{
IsTLS: true,
DisableClientSessionCache: true,
UseECDSA: true,
}
benchmarkExt(b, handshakeHandler, bc)
}
func BenchmarkTLSHandshakeECDSAWithCurvesWithClientSessionCache(b *testing.B) {
bc := &benchConfig{
IsTLS: true,
DisableClientSessionCache: false,
UseCurves: true,
UseECDSA: true,
}
benchmarkExt(b, handshakeHandler, bc)
}
func BenchmarkTLSHandshakeECDSAWithCurvesWithoutClientSessionCache(b *testing.B) {
bc := &benchConfig{
IsTLS: true,
DisableClientSessionCache: true,
UseCurves: true,
UseECDSA: true,
}
benchmarkExt(b, handshakeHandler, bc)
}
func benchmark(b *testing.B, h fasthttp.RequestHandler, isTLS bool) {
bc := &benchConfig{
IsTLS: isTLS,
}
benchmarkExt(b, h, bc)
}
type benchConfig struct {
IsTLS bool
DisableClientSessionCache bool
UseCurves bool
UseECDSA bool
}
func benchmarkExt(b *testing.B, h fasthttp.RequestHandler, bc *benchConfig) {
var serverTLSConfig, clientTLSConfig *tls.Config
if bc.IsTLS {
certFile := "rsa.pem"
keyFile := "rsa.key"
if bc.UseECDSA {
certFile = "ecdsa.pem"
keyFile = "ecdsa.key"
}
cert, err := tls.LoadX509KeyPair(certFile, keyFile)
if err != nil {
b.Fatalf("cannot load TLS certificate from certFile=%q, keyFile=%q: %s", certFile, keyFile, err)
}
serverTLSConfig = &tls.Config{
Certificates: []tls.Certificate{cert},
PreferServerCipherSuites: true,
}
serverTLSConfig.CurvePreferences = []tls.CurveID{}
if bc.UseCurves {
serverTLSConfig.CurvePreferences = []tls.CurveID{
tls.CurveP256,
}
}
clientTLSConfig = &tls.Config{
InsecureSkipVerify: true,
}
if bc.DisableClientSessionCache {
clientTLSConfig.ClientSessionCache = fakeSessionCache{}
}
}
ln := fasthttputil.NewInmemoryListener()
serverStopCh := make(chan struct{})
go func() {
serverLn := net.Listener(ln)
if serverTLSConfig != nil {
serverLn = tls.NewListener(serverLn, serverTLSConfig)
}
if err := fasthttp.Serve(serverLn, h); err != nil {
b.Errorf("unexpected error in server: %s", err)
}
close(serverStopCh)
}()
c := &fasthttp.HostClient{
Dial: func(addr string) (net.Conn, error) {
return ln.Dial()
},
IsTLS: clientTLSConfig != nil,
TLSConfig: clientTLSConfig,
}
b.RunParallel(func(pb *testing.PB) {
runRequests(b, pb, c)
})
ln.Close()
<-serverStopCh
}
func streamingHandler(ctx *fasthttp.RequestCtx) {
ctx.WriteString("foobar") //nolint:errcheck
}
func handshakeHandler(ctx *fasthttp.RequestCtx) {
streamingHandler(ctx)
// Explicitly close connection after each response.
ctx.SetConnectionClose()
}
func runRequests(b *testing.B, pb *testing.PB, c *fasthttp.HostClient) {
var req fasthttp.Request
req.SetRequestURI("http://foo.bar/baz")
var resp fasthttp.Response
for pb.Next() {
if err := c.Do(&req, &resp); err != nil {
b.Fatalf("unexpected error: %s", err)
}
if resp.StatusCode() != fasthttp.StatusOK {
b.Fatalf("unexpected status code: %d. Expecting %d", resp.StatusCode(), fasthttp.StatusOK)
}
}
}
type fakeSessionCache struct{}
func (fakeSessionCache) Get(sessionKey string) (*tls.ClientSessionState, bool) {
return nil, false
}
func (fakeSessionCache) Put(sessionKey string, cs *tls.ClientSessionState) {
// no-op
}
fasthttp-1.31.0/fasthttputil/pipeconns.go 0000664 0000000 0000000 00000013765 14130360711 0020521 0 ustar 00root root 0000000 0000000 package fasthttputil
import (
"errors"
"io"
"net"
"sync"
"time"
)
// NewPipeConns returns a new bi-directional connection pipe.
//
// PipeConns is NOT safe for concurrent use by multiple goroutines!
func NewPipeConns() *PipeConns {
ch1 := make(chan *byteBuffer, 4)
ch2 := make(chan *byteBuffer, 4)
pc := &PipeConns{
stopCh: make(chan struct{}),
}
pc.c1.rCh = ch1
pc.c1.wCh = ch2
pc.c2.rCh = ch2
pc.c2.wCh = ch1
pc.c1.pc = pc
pc.c2.pc = pc
return pc
}
// PipeConns provides a bi-directional connection pipe,
// which uses in-process memory as a transport.
//
// PipeConns must be created by calling NewPipeConns.
//
// PipeConns has the following additional features compared to connections
// returned from net.Pipe():
//
// * It is faster.
// * It buffers Write calls, so there is no need to have concurrent goroutine
// calling Read in order to unblock each Write call.
// * It supports read and write deadlines.
//
// PipeConns is NOT safe for concurrent use by multiple goroutines!
type PipeConns struct {
c1 pipeConn
c2 pipeConn
stopCh chan struct{}
stopChLock sync.Mutex
}
// Conn1 returns the first end of bi-directional pipe.
//
// Data written to Conn1 may be read from Conn2.
// Data written to Conn2 may be read from Conn1.
func (pc *PipeConns) Conn1() net.Conn {
return &pc.c1
}
// Conn2 returns the second end of bi-directional pipe.
//
// Data written to Conn2 may be read from Conn1.
// Data written to Conn1 may be read from Conn2.
func (pc *PipeConns) Conn2() net.Conn {
return &pc.c2
}
// Close closes pipe connections.
func (pc *PipeConns) Close() error {
pc.stopChLock.Lock()
select {
case <-pc.stopCh:
default:
close(pc.stopCh)
}
pc.stopChLock.Unlock()
return nil
}
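// A minimal usage sketch: bytes written to one end become readable on the
// other end, entirely in memory.
//
//	pc := NewPipeConns()
//	c1, c2 := pc.Conn1(), pc.Conn2()
//
//	c1.Write([]byte("ping")) //nolint:errcheck
//	buf := make([]byte, 4)
//	c2.Read(buf) //nolint:errcheck
//	// buf now holds "ping". pc.Close() releases both ends.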
type pipeConn struct {
b *byteBuffer
bb []byte
rCh chan *byteBuffer
wCh chan *byteBuffer
pc *PipeConns
readDeadlineTimer *time.Timer
writeDeadlineTimer *time.Timer
readDeadlineCh <-chan time.Time
writeDeadlineCh <-chan time.Time
readDeadlineChLock sync.Mutex
}
func (c *pipeConn) Write(p []byte) (int, error) {
b := acquireByteBuffer()
b.b = append(b.b[:0], p...)
select {
case <-c.pc.stopCh:
releaseByteBuffer(b)
return 0, errConnectionClosed
default:
}
select {
case c.wCh <- b:
default:
select {
case c.wCh <- b:
case <-c.writeDeadlineCh:
c.writeDeadlineCh = closedDeadlineCh
return 0, ErrTimeout
case <-c.pc.stopCh:
releaseByteBuffer(b)
return 0, errConnectionClosed
}
}
return len(p), nil
}
func (c *pipeConn) Read(p []byte) (int, error) {
mayBlock := true
nn := 0
for len(p) > 0 {
n, err := c.read(p, mayBlock)
nn += n
if err != nil {
if !mayBlock && err == errWouldBlock {
err = nil
}
return nn, err
}
p = p[n:]
mayBlock = false
}
return nn, nil
}
func (c *pipeConn) read(p []byte, mayBlock bool) (int, error) {
if len(c.bb) == 0 {
if err := c.readNextByteBuffer(mayBlock); err != nil {
return 0, err
}
}
n := copy(p, c.bb)
c.bb = c.bb[n:]
return n, nil
}
func (c *pipeConn) readNextByteBuffer(mayBlock bool) error {
releaseByteBuffer(c.b)
c.b = nil
select {
case c.b = <-c.rCh:
default:
if !mayBlock {
return errWouldBlock
}
c.readDeadlineChLock.Lock()
readDeadlineCh := c.readDeadlineCh
c.readDeadlineChLock.Unlock()
select {
case c.b = <-c.rCh:
case <-readDeadlineCh:
c.readDeadlineChLock.Lock()
c.readDeadlineCh = closedDeadlineCh
c.readDeadlineChLock.Unlock()
// rCh may contain data when deadline is reached.
// Read the data before returning ErrTimeout.
select {
case c.b = <-c.rCh:
default:
return ErrTimeout
}
case <-c.pc.stopCh:
// rCh may contain data when stopCh is closed.
// Read the data before returning EOF.
select {
case c.b = <-c.rCh:
default:
return io.EOF
}
}
}
c.bb = c.b.b
return nil
}
var (
errWouldBlock = errors.New("would block")
errConnectionClosed = errors.New("connection closed")
)
type timeoutError struct {
}
func (e *timeoutError) Error() string {
return "timeout"
}
// Only implement the Timeout() function of the net.Error interface.
// This allows for checks like:
//
// if x, ok := err.(interface{ Timeout() bool }); ok && x.Timeout() {
func (e *timeoutError) Timeout() bool {
return true
}
var (
// ErrTimeout is returned from Read() or Write() on timeout.
ErrTimeout = &timeoutError{}
)
func (c *pipeConn) Close() error {
return c.pc.Close()
}
func (c *pipeConn) LocalAddr() net.Addr {
return pipeAddr(0)
}
func (c *pipeConn) RemoteAddr() net.Addr {
return pipeAddr(0)
}
func (c *pipeConn) SetDeadline(deadline time.Time) error {
c.SetReadDeadline(deadline) //nolint:errcheck
c.SetWriteDeadline(deadline) //nolint:errcheck
return nil
}
func (c *pipeConn) SetReadDeadline(deadline time.Time) error {
if c.readDeadlineTimer == nil {
c.readDeadlineTimer = time.NewTimer(time.Hour)
}
readDeadlineCh := updateTimer(c.readDeadlineTimer, deadline)
c.readDeadlineChLock.Lock()
c.readDeadlineCh = readDeadlineCh
c.readDeadlineChLock.Unlock()
return nil
}
func (c *pipeConn) SetWriteDeadline(deadline time.Time) error {
if c.writeDeadlineTimer == nil {
c.writeDeadlineTimer = time.NewTimer(time.Hour)
}
c.writeDeadlineCh = updateTimer(c.writeDeadlineTimer, deadline)
return nil
}
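// Deadline behavior sketch: once the internal write buffer is full and the
// write deadline has expired, Write returns ErrTimeout. The one-millisecond
// deadline below is an arbitrary illustrative value.
//
//	pc := NewPipeConns()
//	c := pc.Conn1()
//	c.SetWriteDeadline(time.Now().Add(time.Millisecond)) //nolint:errcheck
//	for {
//		if _, err := c.Write([]byte("x")); err == ErrTimeout {
//			break
//		}
//	}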
func updateTimer(t *time.Timer, deadline time.Time) <-chan time.Time {
if !t.Stop() {
select {
case <-t.C:
default:
}
}
if deadline.IsZero() {
return nil
}
d := -time.Since(deadline)
if d <= 0 {
return closedDeadlineCh
}
t.Reset(d)
return t.C
}
var closedDeadlineCh = func() <-chan time.Time {
ch := make(chan time.Time)
close(ch)
return ch
}()
type pipeAddr int
func (pipeAddr) Network() string {
return "pipe"
}
func (pipeAddr) String() string {
return "pipe"
}
type byteBuffer struct {
b []byte
}
func acquireByteBuffer() *byteBuffer {
return byteBufferPool.Get().(*byteBuffer)
}
func releaseByteBuffer(b *byteBuffer) {
if b != nil {
byteBufferPool.Put(b)
}
}
var byteBufferPool = &sync.Pool{
New: func() interface{} {
return &byteBuffer{
b: make([]byte, 1024),
}
},
}
fasthttp-1.31.0/fasthttputil/pipeconns_test.go 0000664 0000000 0000000 00000016561 14130360711 0021555 0 ustar 00root root 0000000 0000000 package fasthttputil
import (
"bytes"
"fmt"
"io"
"io/ioutil"
"net"
"testing"
"time"
)
func TestPipeConnsWriteTimeout(t *testing.T) {
t.Parallel()
pc := NewPipeConns()
c1 := pc.Conn1()
deadline := time.Now().Add(time.Millisecond)
if err := c1.SetWriteDeadline(deadline); err != nil {
t.Fatalf("unexpected error: %s", err)
}
data := []byte("foobar")
for {
_, err := c1.Write(data)
if err != nil {
if err == ErrTimeout {
break
}
t.Fatalf("unexpected error: %s", err)
}
}
for i := 0; i < 10; i++ {
_, err := c1.Write(data)
if err == nil {
t.Fatalf("expecting error")
}
if err != ErrTimeout {
t.Fatalf("unexpected error: %s. Expecting %s", err, ErrTimeout)
}
}
// read the written data
c2 := pc.Conn2()
if err := c2.SetReadDeadline(time.Now().Add(10 * time.Millisecond)); err != nil {
t.Fatalf("unexpected error: %s", err)
}
for {
_, err := c2.Read(data)
if err != nil {
if err == ErrTimeout {
break
}
t.Fatalf("unexpected error: %s", err)
}
}
for i := 0; i < 10; i++ {
_, err := c2.Read(data)
if err == nil {
t.Fatalf("expecting error")
}
if err != ErrTimeout {
t.Fatalf("unexpected error: %s. Expecting %s", err, ErrTimeout)
}
}
}
func TestPipeConnsPositiveReadTimeout(t *testing.T) {
t.Parallel()
testPipeConnsReadTimeout(t, time.Millisecond)
}
func TestPipeConnsNegativeReadTimeout(t *testing.T) {
t.Parallel()
testPipeConnsReadTimeout(t, -time.Second)
}
var zeroTime time.Time
func testPipeConnsReadTimeout(t *testing.T, timeout time.Duration) {
pc := NewPipeConns()
c1 := pc.Conn1()
deadline := time.Now().Add(timeout)
if err := c1.SetReadDeadline(deadline); err != nil {
t.Fatalf("unexpected error: %s", err)
}
var buf [1]byte
for i := 0; i < 10; i++ {
_, err := c1.Read(buf[:])
if err == nil {
t.Fatalf("expecting error on iteration %d", i)
}
if err != ErrTimeout {
t.Fatalf("unexpected error on iteration %d: %s. Expecting %s", i, err, ErrTimeout)
}
}
// disable deadline and send data from c2 to c1
if err := c1.SetReadDeadline(zeroTime); err != nil {
t.Fatalf("unexpected error: %s", err)
}
data := []byte("foobar")
c2 := pc.Conn2()
if _, err := c2.Write(data); err != nil {
t.Fatalf("unexpected error: %s", err)
}
dataBuf := make([]byte, len(data))
if _, err := io.ReadFull(c1, dataBuf); err != nil {
t.Fatalf("unexpected error: %s", err)
}
if !bytes.Equal(data, dataBuf) {
t.Fatalf("unexpected data received: %q. Expecting %q", dataBuf, data)
}
}
func TestPipeConnsCloseWhileReadWriteConcurrent(t *testing.T) {
t.Parallel()
concurrency := 4
ch := make(chan struct{}, concurrency)
for i := 0; i < concurrency; i++ {
go func() {
testPipeConnsCloseWhileReadWriteSerial(t)
ch <- struct{}{}
}()
}
for i := 0; i < concurrency; i++ {
select {
case <-ch:
case <-time.After(5 * time.Second):
t.Fatalf("timeout")
}
}
}
func TestPipeConnsCloseWhileReadWriteSerial(t *testing.T) {
t.Parallel()
testPipeConnsCloseWhileReadWriteSerial(t)
}
func testPipeConnsCloseWhileReadWriteSerial(t *testing.T) {
for i := 0; i < 10; i++ {
testPipeConnsCloseWhileReadWrite(t)
}
}
func testPipeConnsCloseWhileReadWrite(t *testing.T) {
pc := NewPipeConns()
c1 := pc.Conn1()
c2 := pc.Conn2()
readCh := make(chan error)
go func() {
var err error
if _, err = io.Copy(ioutil.Discard, c1); err != nil {
if err != errConnectionClosed {
err = fmt.Errorf("unexpected error: %s", err)
} else {
err = nil
}
}
readCh <- err
}()
writeCh := make(chan error)
go func() {
var err error
for {
if _, err = c2.Write([]byte("foobar")); err != nil {
if err != errConnectionClosed {
err = fmt.Errorf("unexpected error: %s", err)
} else {
err = nil
}
break
}
}
writeCh <- err
}()
time.Sleep(10 * time.Millisecond)
if err := c1.Close(); err != nil {
t.Fatalf("unexpected error: %s", err)
}
if err := c2.Close(); err != nil {
t.Fatalf("unexpected error: %s", err)
}
select {
case err := <-readCh:
if err != nil {
t.Fatalf("unexpected error in reader: %s", err)
}
case <-time.After(time.Second):
t.Fatalf("timeout")
}
select {
case err := <-writeCh:
if err != nil {
t.Fatalf("unexpected error in writer: %s", err)
}
case <-time.After(time.Second):
t.Fatalf("timeout")
}
}
func TestPipeConnsReadWriteSerial(t *testing.T) {
t.Parallel()
testPipeConnsReadWriteSerial(t)
}
func TestPipeConnsReadWriteConcurrent(t *testing.T) {
t.Parallel()
testConcurrency(t, 10, testPipeConnsReadWriteSerial)
}
func testPipeConnsReadWriteSerial(t *testing.T) {
pc := NewPipeConns()
testPipeConnsReadWrite(t, pc.Conn1(), pc.Conn2())
pc = NewPipeConns()
testPipeConnsReadWrite(t, pc.Conn2(), pc.Conn1())
}
func testPipeConnsReadWrite(t *testing.T, c1, c2 net.Conn) {
defer c1.Close()
defer c2.Close()
var buf [32]byte
for i := 0; i < 10; i++ {
// The first write
s1 := fmt.Sprintf("foo_%d", i)
n, err := c1.Write([]byte(s1))
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
if n != len(s1) {
t.Fatalf("unexpected number of bytes written: %d. Expecting %d", n, len(s1))
}
// The second write
s2 := fmt.Sprintf("bar_%d", i)
n, err = c1.Write([]byte(s2))
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
if n != len(s2) {
t.Fatalf("unexpected number of bytes written: %d. Expecting %d", n, len(s2))
}
// Read data written above in two writes
s := s1 + s2
n, err = c2.Read(buf[:])
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
if n != len(s) {
t.Fatalf("unexpected number of bytes read: %d. Expecting %d", n, len(s))
}
if string(buf[:n]) != s {
t.Fatalf("unexpected string read: %q. Expecting %q", buf[:n], s)
}
}
}
func TestPipeConnsCloseSerial(t *testing.T) {
t.Parallel()
testPipeConnsCloseSerial(t)
}
func TestPipeConnsCloseConcurrent(t *testing.T) {
t.Parallel()
testConcurrency(t, 10, testPipeConnsCloseSerial)
}
func testPipeConnsCloseSerial(t *testing.T) {
pc := NewPipeConns()
testPipeConnsClose(t, pc.Conn1(), pc.Conn2())
pc = NewPipeConns()
testPipeConnsClose(t, pc.Conn2(), pc.Conn1())
}
func testPipeConnsClose(t *testing.T, c1, c2 net.Conn) {
if err := c1.Close(); err != nil {
t.Fatalf("unexpected error: %s", err)
}
var buf [10]byte
// attempt writing to closed conn
for i := 0; i < 10; i++ {
n, err := c1.Write(buf[:])
if err == nil {
t.Fatalf("expecting error")
}
if n != 0 {
t.Fatalf("unexpected number of bytes written: %d. Expecting 0", n)
}
}
// attempt reading from closed conn
for i := 0; i < 10; i++ {
n, err := c2.Read(buf[:])
if err == nil {
t.Fatalf("expecting error")
}
if err != io.EOF {
t.Fatalf("unexpected error: %s. Expecting %s", err, io.EOF)
}
if n != 0 {
t.Fatalf("unexpected number of bytes read: %d. Expecting 0", n)
}
}
if err := c2.Close(); err != nil {
t.Fatalf("unexpected error: %s", err)
}
// attempt closing already closed conns
for i := 0; i < 10; i++ {
if err := c1.Close(); err != nil {
t.Fatalf("unexpected error: %s", err)
}
if err := c2.Close(); err != nil {
t.Fatalf("unexpected error: %s", err)
}
}
}
func testConcurrency(t *testing.T, concurrency int, f func(*testing.T)) {
ch := make(chan struct{}, concurrency)
for i := 0; i < concurrency; i++ {
go func() {
f(t)
ch <- struct{}{}
}()
}
for i := 0; i < concurrency; i++ {
select {
case <-ch:
case <-time.After(time.Second):
t.Fatalf("timeout")
}
}
}
fasthttp-1.31.0/fasthttputil/rsa.key 0000664 0000000 0000000 00000003250 14130360711 0017457 0 ustar 00root root 0000000 0000000 -----BEGIN PRIVATE KEY-----
MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQD4IQusAs8PJdnG
3mURt/AXtgC+ceqLOatJ49JJE1VPTkMAy+oE1f1XvkMrYsHqmDf6GWVzgVXryL4U
wq2/nJSm56ddhN55nI8oSN3dtywUB8/ShelEN73nlN77PeD9tl6NksPwWaKrqxq0
FlabRPZSQCfmgZbhDV8Sa8mfCkFU0G0lit6kLGceCKMvmW+9Bz7ebsYmVdmVMxmf
IJStFD44lWFTdUc65WISKEdW2ELcUefb0zOLw+0PCbXFGJH5x5ktksW8+BBk2Hkg
GeQRL/qPCccthbScO0VgNj3zJ3ZZL0ObSDAbvNDG85joeNjDNq5DT/BAZ0bOSbEF
sh+f9BAzAgMBAAECggEBAJWv2cq7Jw6MVwSRxYca38xuD6TUNBopgBvjREixURW2
sNUaLuMb9Omp7fuOaE2N5rcJ+xnjPGIxh/oeN5MQctz9gwn3zf6vY+15h97pUb4D
uGvYPRDaT8YVGS+X9NMZ4ZCmqW2lpWzKnCFoGHcy8yZLbcaxBsRdvKzwOYGoPiFb
K2QuhXZ/1UPmqK9i2DFKtj40X6vBszTNboFxOVpXrPu0FJwLVSDf2hSZ4fMM0DH3
YqwKcYf5te+hxGKgrqRA3tn0NCWii0in6QIwXMC+kMw1ebg/tZKqyDLMNptAK8J+
DVw9m5X1seUHS5ehU/g2jrQrtK5WYn7MrFK4lBzlRwECgYEA/d1TeANYECDWRRDk
B0aaRZs87Rwl/J9PsvbsKvtU/bX+OfSOUjOa9iQBqn0LmU8GqusEET/QVUfocVwV
Bggf/5qDLxz100Rj0ags/yE/kNr0Bb31kkkKHFMnCT06YasR7qKllwrAlPJvQv9x
IzBKq+T/Dx08Wep9bCRSFhzRCnsCgYEA+jdeZXTDr/Vz+D2B3nAw1frqYFfGnEVY
wqmoK3VXMDkGuxsloO2rN+SyiUo3JNiQNPDub/t7175GH5pmKtZOlftePANsUjBj
wZ1D0rI5Bxu/71ibIUYIRVmXsTEQkh/ozoh3jXCZ9+bLgYiYx7789IUZZSokFQ3D
FICUT9KJ36kCgYAGoq9Y1rWJjmIrYfqj2guUQC+CfxbbGIrrwZqAsRsSmpwvhZ3m
tiSZxG0quKQB+NfSxdvQW5ulbwC7Xc3K35F+i9pb8+TVBdeaFkw+yu6vaZmxQLrX
fQM/pEjD7A7HmMIaO7QaU5SfEAsqdCTP56Y8AftMuNXn/8IRfo2KuGwaWwKBgFpU
ILzJoVdlad9E/Rw7LjYhZfkv1uBVXIyxyKcfrkEXZSmozDXDdxsvcZCEfVHM6Ipk
K/+7LuMcqp4AFEAEq8wTOdq6daFaHLkpt/FZK6M4TlruhtpFOPkoNc3e45eM83OT
6mziKINJC1CQ6m65sQHpBtjxlKMRG8rL/D6wx9s5AoGBAMRlqNPMwglT3hvDmsAt
9Lf9pdmhERUlHhD8bj8mDaBj2Aqv7f6VRJaYZqP403pKKQexuqcn80mtjkSAPFkN
Cj7BVt/RXm5uoxDTnfi26RF9F6yNDEJ7UU9+peBr99aazF/fTgW/1GcMkQnum8uV
c257YgaWmjK9uB0Y2r2VxS0G
-----END PRIVATE KEY-----
fasthttp-1.31.0/fasthttputil/rsa.pem 0000664 0000000 0000000 00000001755 14130360711 0017460 0 ustar 00root root 0000000 0000000 -----BEGIN CERTIFICATE-----
MIICujCCAaKgAwIBAgIJAMbXnKZ/cikUMA0GCSqGSIb3DQEBCwUAMBUxEzARBgNV
BAMTCnVidW50dS5uYW4wHhcNMTUwMjA0MDgwMTM5WhcNMjUwMjAxMDgwMTM5WjAV
MRMwEQYDVQQDEwp1YnVudHUubmFuMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIB
CgKCAQEA+CELrALPDyXZxt5lEbfwF7YAvnHqizmrSePSSRNVT05DAMvqBNX9V75D
K2LB6pg3+hllc4FV68i+FMKtv5yUpuenXYTeeZyPKEjd3bcsFAfP0oXpRDe955Te
+z3g/bZejZLD8Fmiq6satBZWm0T2UkAn5oGW4Q1fEmvJnwpBVNBtJYrepCxnHgij
L5lvvQc+3m7GJlXZlTMZnyCUrRQ+OJVhU3VHOuViEihHVthC3FHn29Mzi8PtDwm1
xRiR+ceZLZLFvPgQZNh5IBnkES/6jwnHLYW0nDtFYDY98yd2WS9Dm0gwG7zQxvOY
6HjYwzauQ0/wQGdGzkmxBbIfn/QQMwIDAQABow0wCzAJBgNVHRMEAjAAMA0GCSqG
SIb3DQEBCwUAA4IBAQBQjKm/4KN/iTgXbLTL3i7zaxYXFLXsnT1tF+ay4VA8aj98
L3JwRTciZ3A5iy/W4VSCt3eASwOaPWHKqDBB5RTtL73LoAqsWmO3APOGQAbixcQ2
45GXi05OKeyiYRi1Nvq7Unv9jUkRDHUYVPZVSAjCpsXzPhFkmZoTRxmx5l0ZF7Li
K91lI5h+eFq0dwZwrmlPambyh1vQUi70VHv8DNToVU29kel7YLbxGbuqETfhrcy6
X+Mha6RYITkAn5FqsZcKMsc9eYGEF4l3XV+oS7q6xfTxktYJMFTI18J0lQ2Lv/CI
whdMnYGntDQBE/iFCrJEGNsKGc38796GBOb5j+zd
-----END CERTIFICATE-----
fasthttp-1.31.0/fs.go 0000664 0000000 0000000 00000110630 14130360711 0014365 0 ustar 00root root 0000000 0000000 package fasthttp
import (
"bytes"
"errors"
"fmt"
"html"
"io"
"io/ioutil"
"mime"
"net/http"
"os"
"path/filepath"
"sort"
"strings"
"sync"
"time"
"github.com/andybalholm/brotli"
"github.com/klauspost/compress/gzip"
"github.com/valyala/bytebufferpool"
)
// ServeFileBytesUncompressed returns HTTP response containing file contents
// from the given path.
//
// Directory contents are returned if path points to a directory.
//
// ServeFileBytes may be used for saving network traffic when serving files
// with good compression ratio.
//
// See also RequestCtx.SendFileBytes.
func ServeFileBytesUncompressed(ctx *RequestCtx, path []byte) {
ServeFileUncompressed(ctx, b2s(path))
}
// ServeFileUncompressed returns HTTP response containing file contents
// from the given path.
//
// Directory contents are returned if path points to a directory.
//
// ServeFile may be used for saving network traffic when serving files
// with good compression ratio.
//
// See also RequestCtx.SendFile.
func ServeFileUncompressed(ctx *RequestCtx, path string) {
ctx.Request.Header.DelBytes(strAcceptEncoding)
ServeFile(ctx, path)
}
// ServeFileBytes returns HTTP response containing compressed file contents
// from the given path.
//
// HTTP response may contain uncompressed file contents in the following cases:
//
// * Missing 'Accept-Encoding: gzip' request header.
// * No write access to directory containing the file.
//
// Directory contents are returned if path points to a directory.
//
// Use ServeFileBytesUncompressed if you don't need to serve compressed
// file contents.
//
// See also RequestCtx.SendFileBytes.
func ServeFileBytes(ctx *RequestCtx, path []byte) {
ServeFile(ctx, b2s(path))
}
// ServeFile returns HTTP response containing compressed file contents
// from the given path.
//
// HTTP response may contain uncompressed file contents in the following cases:
//
// * Missing 'Accept-Encoding: gzip' request header.
// * No write access to directory containing the file.
//
// Directory contents are returned if path points to a directory.
//
// Use ServeFileUncompressed if you don't need to serve compressed file contents.
//
// See also RequestCtx.SendFile.
func ServeFile(ctx *RequestCtx, path string) {
rootFSOnce.Do(func() {
rootFSHandler = rootFS.NewRequestHandler()
})
if len(path) == 0 || path[0] != '/' {
// extend relative path to absolute path
hasTrailingSlash := len(path) > 0 && path[len(path)-1] == '/'
var err error
if path, err = filepath.Abs(path); err != nil {
ctx.Logger().Printf("cannot resolve path %q to absolute file path: %s", path, err)
ctx.Error("Internal Server Error", StatusInternalServerError)
return
}
if hasTrailingSlash {
path += "/"
}
}
ctx.Request.SetRequestURI(path)
rootFSHandler(ctx)
}
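// A minimal usage sketch; the file path below is hypothetical:
//
//	func staticHandler(ctx *RequestCtx) {
//		// Serves ./static/index.html, compressed transparently when the
//		// client sends a suitable Accept-Encoding header.
//		ServeFile(ctx, "./static/index.html")
//	}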
var (
rootFSOnce sync.Once
rootFS = &FS{
Root: "/",
GenerateIndexPages: true,
Compress: true,
CompressBrotli: true,
AcceptByteRange: true,
}
rootFSHandler RequestHandler
)
// PathRewriteFunc must return new request path based on arbitrary ctx
// info such as ctx.Path().
//
// Path rewriter is used in FS for translating the current request
// to the local filesystem path relative to FS.Root.
//
// The returned path must not contain '/../' substrings for security reasons,
// since such paths may refer to files outside FS.Root.
//
// The returned path may refer to ctx members. For example, ctx.Path().
type PathRewriteFunc func(ctx *RequestCtx) []byte
// NewVHostPathRewriter returns path rewriter, which strips slashesCount
// leading slashes from the path and prepends the path with request's host,
// thus simplifying virtual hosting for static files.
//
// Examples:
//
// * host=foobar.com, slashesCount=0, original path="/foo/bar".
// Resulting path: "/foobar.com/foo/bar"
//
// * host=img.aaa.com, slashesCount=1, original path="/images/123/456.jpg"
// Resulting path: "/img.aaa.com/123/456.jpg"
//
func NewVHostPathRewriter(slashesCount int) PathRewriteFunc {
return func(ctx *RequestCtx) []byte {
path := stripLeadingSlashes(ctx.Path(), slashesCount)
host := ctx.Host()
if n := bytes.IndexByte(host, '/'); n >= 0 {
host = nil
}
if len(host) == 0 {
host = strInvalidHost
}
b := bytebufferpool.Get()
b.B = append(b.B, '/')
b.B = append(b.B, host...)
b.B = append(b.B, path...)
ctx.URI().SetPathBytes(b.B)
bytebufferpool.Put(b)
return ctx.Path()
}
}
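// A minimal usage sketch; the root directory below is hypothetical:
//
//	fs := &FS{
//		Root:        "/var/www",
//		PathRewrite: NewVHostPathRewriter(0),
//	}
//	handler := fs.NewRequestHandler()
//	// A request for http://foobar.com/foo/bar is then served from
//	// /var/www/foobar.com/foo/bar.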
var strInvalidHost = []byte("invalid-host")
// NewPathSlashesStripper returns path rewriter, which strips slashesCount
// leading slashes from the path.
//
// Examples:
//
// * slashesCount = 0, original path: "/foo/bar", result: "/foo/bar"
// * slashesCount = 1, original path: "/foo/bar", result: "/bar"
// * slashesCount = 2, original path: "/foo/bar", result: ""
//
// The returned path rewriter may be used as FS.PathRewrite .
func NewPathSlashesStripper(slashesCount int) PathRewriteFunc {
return func(ctx *RequestCtx) []byte {
return stripLeadingSlashes(ctx.Path(), slashesCount)
}
}
// NewPathPrefixStripper returns path rewriter, which removes prefixSize bytes
// from the path prefix.
//
// Examples:
//
// * prefixSize = 0, original path: "/foo/bar", result: "/foo/bar"
// * prefixSize = 3, original path: "/foo/bar", result: "o/bar"
// * prefixSize = 7, original path: "/foo/bar", result: "r"
//
// The returned path rewriter may be used as FS.PathRewrite .
func NewPathPrefixStripper(prefixSize int) PathRewriteFunc {
return func(ctx *RequestCtx) []byte {
path := ctx.Path()
if len(path) >= prefixSize {
path = path[prefixSize:]
}
return path
}
}
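// A minimal usage sketch for the path strippers above; the root directory
// below is hypothetical:
//
//	fs := &FS{
//		Root:        "/var/www/assets",
//		PathRewrite: NewPathSlashesStripper(1),
//	}
//	// A request for /static/css/main.css is then served from
//	// /var/www/assets/css/main.css.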
// FS represents settings for request handler serving static files
// from the local filesystem.
//
// It is prohibited copying FS values. Create new values instead.
type FS struct {
noCopy noCopy //nolint:unused,structcheck
// Path to the root directory to serve files from.
Root string
// List of index file names to try opening during directory access.
//
// For example:
//
// * index.html
// * index.htm
// * my-super-index.xml
//
// By default the list is empty.
IndexNames []string
// Index pages for directories without files matching IndexNames
// are automatically generated if set.
//
// Directory index generation may be quite slow for directories
// with many files (more than 1K), so enabling index page
// generation for such directories is discouraged.
//
// By default index pages aren't generated.
GenerateIndexPages bool
// Transparently compresses responses if set to true.
//
// The server tries minimizing CPU usage by caching compressed files.
// It adds CompressedFileSuffix suffix to the original file name and
// tries saving the resulting compressed file under the new file name.
// So it is advisable to give the server write access to Root
// and to all inner folders in order to minimize CPU usage when serving
// compressed responses.
//
// Transparent compression is disabled by default.
Compress bool
// Uses brotli encoding and falls back to gzip in responses if set to true; uses gzip if set to false.
//
// This value makes sense only if Compress is set.
//
// Brotli encoding is disabled by default.
CompressBrotli bool
// Enables byte range requests if set to true.
//
// Byte range requests are disabled by default.
AcceptByteRange bool
// Path rewriting function.
//
// By default request path is not modified.
PathRewrite PathRewriteFunc
// PathNotFound fires when a file is not found in the filesystem.
// It allows replacing the default "Cannot open requested path"
// server response, giving the programmer control of the server flow.
//
// By default PathNotFound returns
// "Cannot open requested path"
PathNotFound RequestHandler
// Expiration duration for inactive file handlers.
//
// FSHandlerCacheDuration is used by default.
CacheDuration time.Duration
// Suffix to add to the name of cached compressed file.
//
// This value makes sense only if Compress is set.
//
// FSCompressedFileSuffix is used by default.
CompressedFileSuffix string
// Suffixes list to add to compressedFileSuffix depending on encoding
//
// This value makes sense only if Compress is set.
//
// FSCompressedFileSuffixes is used by default.
CompressedFileSuffixes map[string]string
// If CleanStop is set, the channel can be closed to stop the cleanup handlers
// for the FS RequestHandlers created with NewRequestHandler.
// NEVER close this channel while the handler is still being used!
CleanStop chan struct{}
once sync.Once
h RequestHandler
}
// FSCompressedFileSuffix is the suffix FS adds to the original file names
// when trying to store compressed file under the new file name.
// See FS.Compress for details.
const FSCompressedFileSuffix = ".fasthttp.gz"
// FSCompressedFileSuffixes is the suffixes FS adds to the original file names depending on encoding
// when trying to store compressed file under the new file name.
// See FS.Compress for details.
var FSCompressedFileSuffixes = map[string]string{
"gzip": ".fasthttp.gz",
"br": ".fasthttp.br",
}
// FSHandlerCacheDuration is the default expiration duration for inactive
// file handlers opened by FS.
const FSHandlerCacheDuration = 10 * time.Second
// FSHandler returns request handler serving static files from
// the given root folder.
//
// stripSlashes indicates how many leading slashes must be stripped
// from requested path before searching requested file in the root folder.
// Examples:
//
// * stripSlashes = 0, original path: "/foo/bar", result: "/foo/bar"
// * stripSlashes = 1, original path: "/foo/bar", result: "/bar"
// * stripSlashes = 2, original path: "/foo/bar", result: ""
//
// The returned request handler automatically generates index pages
// for directories without index.html.
//
// The returned handler caches requested file handles
// for FSHandlerCacheDuration.
// Make sure your program has enough 'max open files' limit aka
// 'ulimit -n' if root folder contains many files.
//
// Do not create multiple request handler instances for the same
// (root, stripSlashes) arguments - just reuse a single instance.
// Otherwise a goroutine leak will occur.
func FSHandler(root string, stripSlashes int) RequestHandler {
fs := &FS{
Root: root,
IndexNames: []string{"index.html"},
GenerateIndexPages: true,
AcceptByteRange: true,
}
if stripSlashes > 0 {
fs.PathRewrite = NewPathSlashesStripper(stripSlashes)
}
return fs.NewRequestHandler()
}
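// A minimal usage sketch; the root directory and listen address below are
// hypothetical:
//
//	// Serve /static/foo/bar.css from /var/www/assets/foo/bar.css
//	// by stripping the leading path element.
//	h := FSHandler("/var/www/assets", 1)
//	ListenAndServe(":8080", h) //nolint:errcheck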
// NewRequestHandler returns new request handler with the given FS settings.
//
// The returned handler caches requested file handles
// for FS.CacheDuration.
// Make sure your program has enough 'max open files' limit aka
// 'ulimit -n' if FS.Root folder contains many files.
//
// Do not create multiple request handlers from a single FS instance -
// just reuse a single request handler.
func (fs *FS) NewRequestHandler() RequestHandler {
fs.once.Do(fs.initRequestHandler)
return fs.h
}
func (fs *FS) initRequestHandler() {
root := fs.Root
// serve files from the current working directory if root is empty
if len(root) == 0 {
root = "."
}
// strip trailing slashes from the root path
for len(root) > 0 && root[len(root)-1] == '/' {
root = root[:len(root)-1]
}
cacheDuration := fs.CacheDuration
if cacheDuration <= 0 {
cacheDuration = FSHandlerCacheDuration
}
compressedFileSuffixes := fs.CompressedFileSuffixes
if len(compressedFileSuffixes["br"]) == 0 || len(compressedFileSuffixes["gzip"]) == 0 ||
compressedFileSuffixes["br"] == compressedFileSuffixes["gzip"] {
compressedFileSuffixes = FSCompressedFileSuffixes
}
if len(fs.CompressedFileSuffix) > 0 {
compressedFileSuffixes["gzip"] = fs.CompressedFileSuffix
compressedFileSuffixes["br"] = FSCompressedFileSuffixes["br"]
}
h := &fsHandler{
root: root,
indexNames: fs.IndexNames,
pathRewrite: fs.PathRewrite,
generateIndexPages: fs.GenerateIndexPages,
compress: fs.Compress,
compressBrotli: fs.CompressBrotli,
pathNotFound: fs.PathNotFound,
acceptByteRange: fs.AcceptByteRange,
cacheDuration: cacheDuration,
compressedFileSuffixes: compressedFileSuffixes,
cache: make(map[string]*fsFile),
cacheBrotli: make(map[string]*fsFile),
cacheGzip: make(map[string]*fsFile),
}
go func() {
var pendingFiles []*fsFile
clean := func() {
pendingFiles = h.cleanCache(pendingFiles)
}
if fs.CleanStop != nil {
t := time.NewTicker(cacheDuration / 2)
for {
select {
case <-t.C:
clean()
case _, stillOpen := <-fs.CleanStop:
// Ignore values sent on the channel; only stop when it is closed.
if !stillOpen {
t.Stop()
return
}
}
}
}
for {
time.Sleep(cacheDuration / 2)
clean()
}
}()
fs.h = h.handleRequest
}
type fsHandler struct {
root string
indexNames []string
pathRewrite PathRewriteFunc
pathNotFound RequestHandler
generateIndexPages bool
compress bool
compressBrotli bool
acceptByteRange bool
cacheDuration time.Duration
compressedFileSuffixes map[string]string
cache map[string]*fsFile
cacheBrotli map[string]*fsFile
cacheGzip map[string]*fsFile
cacheLock sync.Mutex
smallFileReaderPool sync.Pool
}
type fsFile struct {
h *fsHandler
f *os.File
dirIndex []byte
contentType string
contentLength int
compressed bool
lastModified time.Time
lastModifiedStr []byte
t time.Time
readersCount int
bigFiles []*bigFileReader
bigFilesLock sync.Mutex
}
func (ff *fsFile) NewReader() (io.Reader, error) {
if ff.isBig() {
r, err := ff.bigFileReader()
if err != nil {
ff.decReadersCount()
}
return r, err
}
return ff.smallFileReader(), nil
}
func (ff *fsFile) smallFileReader() io.Reader {
v := ff.h.smallFileReaderPool.Get()
if v == nil {
v = &fsSmallFileReader{}
}
r := v.(*fsSmallFileReader)
r.ff = ff
r.endPos = ff.contentLength
if r.startPos > 0 {
panic("BUG: fsSmallFileReader with non-nil startPos found in the pool")
}
return r
}
// files bigger than this size are sent with sendfile
const maxSmallFileSize = 2 * 4096
func (ff *fsFile) isBig() bool {
return ff.contentLength > maxSmallFileSize && len(ff.dirIndex) == 0
}
func (ff *fsFile) bigFileReader() (io.Reader, error) {
if ff.f == nil {
panic("BUG: ff.f must be non-nil in bigFileReader")
}
var r io.Reader
ff.bigFilesLock.Lock()
n := len(ff.bigFiles)
if n > 0 {
r = ff.bigFiles[n-1]
ff.bigFiles = ff.bigFiles[:n-1]
}
ff.bigFilesLock.Unlock()
if r != nil {
return r, nil
}
f, err := os.Open(ff.f.Name())
if err != nil {
return nil, fmt.Errorf("cannot open already opened file: %s", err)
}
return &bigFileReader{
f: f,
ff: ff,
r: f,
}, nil
}
func (ff *fsFile) Release() {
if ff.f != nil {
ff.f.Close()
if ff.isBig() {
ff.bigFilesLock.Lock()
for _, r := range ff.bigFiles {
r.f.Close()
}
ff.bigFilesLock.Unlock()
}
}
}
func (ff *fsFile) decReadersCount() {
ff.h.cacheLock.Lock()
ff.readersCount--
if ff.readersCount < 0 {
panic("BUG: negative fsFile.readersCount!")
}
ff.h.cacheLock.Unlock()
}
// bigFileReader attempts to trigger sendfile
// for sending big files over the wire.
type bigFileReader struct {
f *os.File
ff *fsFile
r io.Reader
lr io.LimitedReader
}
func (r *bigFileReader) UpdateByteRange(startPos, endPos int) error {
if _, err := r.f.Seek(int64(startPos), 0); err != nil {
return err
}
r.r = &r.lr
r.lr.R = r.f
r.lr.N = int64(endPos - startPos + 1)
return nil
}
func (r *bigFileReader) Read(p []byte) (int, error) {
return r.r.Read(p)
}
func (r *bigFileReader) WriteTo(w io.Writer) (int64, error) {
if rf, ok := w.(io.ReaderFrom); ok {
// fast path. Sendfile must be triggered
return rf.ReadFrom(r.r)
}
// slow path
return copyZeroAlloc(w, r.r)
}
func (r *bigFileReader) Close() error {
r.r = r.f
n, err := r.f.Seek(0, 0)
if err == nil {
if n != 0 {
panic("BUG: File.Seek(0,0) returned (non-zero, nil)")
}
ff := r.ff
ff.bigFilesLock.Lock()
ff.bigFiles = append(ff.bigFiles, r)
ff.bigFilesLock.Unlock()
} else {
r.f.Close()
}
r.ff.decReadersCount()
return err
}
type fsSmallFileReader struct {
ff *fsFile
startPos int
endPos int
}
func (r *fsSmallFileReader) Close() error {
ff := r.ff
ff.decReadersCount()
r.ff = nil
r.startPos = 0
r.endPos = 0
ff.h.smallFileReaderPool.Put(r)
return nil
}
func (r *fsSmallFileReader) UpdateByteRange(startPos, endPos int) error {
r.startPos = startPos
r.endPos = endPos + 1
return nil
}
func (r *fsSmallFileReader) Read(p []byte) (int, error) {
tailLen := r.endPos - r.startPos
if tailLen <= 0 {
return 0, io.EOF
}
if len(p) > tailLen {
p = p[:tailLen]
}
ff := r.ff
if ff.f != nil {
n, err := ff.f.ReadAt(p, int64(r.startPos))
r.startPos += n
return n, err
}
n := copy(p, ff.dirIndex[r.startPos:])
r.startPos += n
return n, nil
}
func (r *fsSmallFileReader) WriteTo(w io.Writer) (int64, error) {
ff := r.ff
var n int
var err error
if ff.f == nil {
n, err = w.Write(ff.dirIndex[r.startPos:r.endPos])
return int64(n), err
}
if rf, ok := w.(io.ReaderFrom); ok {
return rf.ReadFrom(r)
}
curPos := r.startPos
bufv := copyBufPool.Get()
buf := bufv.([]byte)
for err == nil {
tailLen := r.endPos - curPos
if tailLen <= 0 {
break
}
if len(buf) > tailLen {
buf = buf[:tailLen]
}
n, err = ff.f.ReadAt(buf, int64(curPos))
nw, errw := w.Write(buf[:n])
curPos += nw
if errw == nil && nw != n {
panic("BUG: Write(p) returned (n, nil), where n != len(p)")
}
if err == nil {
err = errw
}
}
copyBufPool.Put(bufv)
if err == io.EOF {
err = nil
}
return int64(curPos - r.startPos), err
}
func (h *fsHandler) cleanCache(pendingFiles []*fsFile) []*fsFile {
var filesToRelease []*fsFile
h.cacheLock.Lock()
// Close files which couldn't be closed before due to non-zero
// readers count on the previous run.
var remainingFiles []*fsFile
for _, ff := range pendingFiles {
if ff.readersCount > 0 {
remainingFiles = append(remainingFiles, ff)
} else {
filesToRelease = append(filesToRelease, ff)
}
}
pendingFiles = remainingFiles
pendingFiles, filesToRelease = cleanCacheNolock(h.cache, pendingFiles, filesToRelease, h.cacheDuration)
pendingFiles, filesToRelease = cleanCacheNolock(h.cacheBrotli, pendingFiles, filesToRelease, h.cacheDuration)
pendingFiles, filesToRelease = cleanCacheNolock(h.cacheGzip, pendingFiles, filesToRelease, h.cacheDuration)
h.cacheLock.Unlock()
for _, ff := range filesToRelease {
ff.Release()
}
return pendingFiles
}
func cleanCacheNolock(cache map[string]*fsFile, pendingFiles, filesToRelease []*fsFile, cacheDuration time.Duration) ([]*fsFile, []*fsFile) {
t := time.Now()
for k, ff := range cache {
if t.Sub(ff.t) > cacheDuration {
if ff.readersCount > 0 {
// There are pending readers on stale file handle,
// so we cannot close it. Put it into pendingFiles
// so it will be closed later.
pendingFiles = append(pendingFiles, ff)
} else {
filesToRelease = append(filesToRelease, ff)
}
delete(cache, k)
}
}
return pendingFiles, filesToRelease
}
func (h *fsHandler) handleRequest(ctx *RequestCtx) {
var path []byte
if h.pathRewrite != nil {
path = h.pathRewrite(ctx)
} else {
path = ctx.Path()
}
hasTrailingSlash := len(path) > 0 && path[len(path)-1] == '/'
path = stripTrailingSlashes(path)
if n := bytes.IndexByte(path, 0); n >= 0 {
ctx.Logger().Printf("cannot serve path with nil byte at position %d: %q", n, path)
ctx.Error("Are you a hacker?", StatusBadRequest)
return
}
if h.pathRewrite != nil {
// There is no need to check for '/../' if path = ctx.Path(),
// since ctx.Path must normalize and sanitize the path.
if n := bytes.Index(path, strSlashDotDotSlash); n >= 0 {
ctx.Logger().Printf("cannot serve path with '/../' at position %d due to security reasons: %q", n, path)
ctx.Error("Internal Server Error", StatusInternalServerError)
return
}
}
mustCompress := false
fileCache := h.cache
fileEncoding := ""
byteRange := ctx.Request.Header.peek(strRange)
if len(byteRange) == 0 && h.compress {
if h.compressBrotli && ctx.Request.Header.HasAcceptEncodingBytes(strBr) {
mustCompress = true
fileCache = h.cacheBrotli
fileEncoding = "br"
} else if ctx.Request.Header.HasAcceptEncodingBytes(strGzip) {
mustCompress = true
fileCache = h.cacheGzip
fileEncoding = "gzip"
}
}
h.cacheLock.Lock()
ff, ok := fileCache[string(path)]
if ok {
ff.readersCount++
}
h.cacheLock.Unlock()
if !ok {
pathStr := string(path)
filePath := h.root + pathStr
var err error
ff, err = h.openFSFile(filePath, mustCompress, fileEncoding)
if mustCompress && err == errNoCreatePermission {
ctx.Logger().Printf("insufficient permissions for saving compressed file for %q. Serving uncompressed file. "+
"Allow write access to the directory with this file in order to improve fasthttp performance", filePath)
mustCompress = false
ff, err = h.openFSFile(filePath, mustCompress, fileEncoding)
}
if err == errDirIndexRequired {
if !hasTrailingSlash {
ctx.RedirectBytes(append(path, '/'), StatusFound)
return
}
ff, err = h.openIndexFile(ctx, filePath, mustCompress, fileEncoding)
if err != nil {
ctx.Logger().Printf("cannot open dir index %q: %s", filePath, err)
ctx.Error("Directory index is forbidden", StatusForbidden)
return
}
} else if err != nil {
ctx.Logger().Printf("cannot open file %q: %s", filePath, err)
if h.pathNotFound == nil {
ctx.Error("Cannot open requested path", StatusNotFound)
} else {
ctx.SetStatusCode(StatusNotFound)
h.pathNotFound(ctx)
}
return
}
h.cacheLock.Lock()
ff1, ok := fileCache[pathStr]
if !ok {
fileCache[pathStr] = ff
ff.readersCount++
} else {
ff1.readersCount++
}
h.cacheLock.Unlock()
if ok {
// The file has been already opened by another
// goroutine, so close the current file and use
// the file opened by another goroutine instead.
ff.Release()
ff = ff1
}
}
if !ctx.IfModifiedSince(ff.lastModified) {
ff.decReadersCount()
ctx.NotModified()
return
}
r, err := ff.NewReader()
if err != nil {
ctx.Logger().Printf("cannot obtain file reader for path=%q: %s", path, err)
ctx.Error("Internal Server Error", StatusInternalServerError)
return
}
hdr := &ctx.Response.Header
if ff.compressed {
if fileEncoding == "br" {
hdr.SetCanonical(strContentEncoding, strBr)
} else if fileEncoding == "gzip" {
hdr.SetCanonical(strContentEncoding, strGzip)
}
}
statusCode := StatusOK
contentLength := ff.contentLength
if h.acceptByteRange {
hdr.SetCanonical(strAcceptRanges, strBytes)
if len(byteRange) > 0 {
startPos, endPos, err := ParseByteRange(byteRange, contentLength)
if err != nil {
r.(io.Closer).Close()
ctx.Logger().Printf("cannot parse byte range %q for path=%q: %s", byteRange, path, err)
ctx.Error("Range Not Satisfiable", StatusRequestedRangeNotSatisfiable)
return
}
if err = r.(byteRangeUpdater).UpdateByteRange(startPos, endPos); err != nil {
r.(io.Closer).Close()
ctx.Logger().Printf("cannot seek byte range %q for path=%q: %s", byteRange, path, err)
ctx.Error("Internal Server Error", StatusInternalServerError)
return
}
hdr.SetContentRange(startPos, endPos, contentLength)
contentLength = endPos - startPos + 1
statusCode = StatusPartialContent
}
}
hdr.SetCanonical(strLastModified, ff.lastModifiedStr)
if !ctx.IsHead() {
ctx.SetBodyStream(r, contentLength)
} else {
ctx.Response.ResetBody()
ctx.Response.SkipBody = true
ctx.Response.Header.SetContentLength(contentLength)
if rc, ok := r.(io.Closer); ok {
if err := rc.Close(); err != nil {
ctx.Logger().Printf("cannot close file reader: %s", err)
ctx.Error("Internal Server Error", StatusInternalServerError)
return
}
}
}
hdr.noDefaultContentType = true
if len(hdr.ContentType()) == 0 {
ctx.SetContentType(ff.contentType)
}
ctx.SetStatusCode(statusCode)
}
type byteRangeUpdater interface {
UpdateByteRange(startPos, endPos int) error
}
// ParseByteRange parses 'Range: bytes=...' header value.
//
// It follows https://www.w3.org/Protocols/rfc2616/rfc2616-sec14.html#sec14.35 .
func ParseByteRange(byteRange []byte, contentLength int) (startPos, endPos int, err error) {
b := byteRange
if !bytes.HasPrefix(b, strBytes) {
return 0, 0, fmt.Errorf("unsupported range units: %q. Expecting %q", byteRange, strBytes)
}
b = b[len(strBytes):]
if len(b) == 0 || b[0] != '=' {
return 0, 0, fmt.Errorf("missing byte range in %q", byteRange)
}
b = b[1:]
n := bytes.IndexByte(b, '-')
if n < 0 {
return 0, 0, fmt.Errorf("missing the end position of byte range in %q", byteRange)
}
if n == 0 {
v, err := ParseUint(b[n+1:])
if err != nil {
return 0, 0, err
}
startPos := contentLength - v
if startPos < 0 {
startPos = 0
}
return startPos, contentLength - 1, nil
}
if startPos, err = ParseUint(b[:n]); err != nil {
return 0, 0, err
}
if startPos >= contentLength {
return 0, 0, fmt.Errorf("the start position of byte range cannot exceed %d. byte range %q", contentLength-1, byteRange)
}
b = b[n+1:]
if len(b) == 0 {
return startPos, contentLength - 1, nil
}
if endPos, err = ParseUint(b); err != nil {
return 0, 0, err
}
if endPos >= contentLength {
endPos = contentLength - 1
}
if endPos < startPos {
return 0, 0, fmt.Errorf("the start position of byte range cannot exceed the end position. byte range %q", byteRange)
}
return startPos, endPos, nil
}
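// Usage sketch with illustrative values:
//
//	// "bytes=0-499" of a 1000-byte file selects the first 500 bytes.
//	startPos, endPos, err := ParseByteRange([]byte("bytes=0-499"), 1000)
//	// startPos == 0, endPos == 499, err == nil
//
//	// "bytes=-500" selects the trailing 500 bytes:
//	// startPos == 500, endPos == 999.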
func (h *fsHandler) openIndexFile(ctx *RequestCtx, dirPath string, mustCompress bool, fileEncoding string) (*fsFile, error) {
for _, indexName := range h.indexNames {
indexFilePath := dirPath + "/" + indexName
ff, err := h.openFSFile(indexFilePath, mustCompress, fileEncoding)
if err == nil {
return ff, nil
}
if !os.IsNotExist(err) {
return nil, fmt.Errorf("cannot open file %q: %s", indexFilePath, err)
}
}
if !h.generateIndexPages {
return nil, fmt.Errorf("cannot access directory without index page. Directory %q", dirPath)
}
return h.createDirIndex(ctx.URI(), dirPath, mustCompress, fileEncoding)
}
var (
errDirIndexRequired = errors.New("directory index required")
errNoCreatePermission = errors.New("no 'create file' permissions")
)
func (h *fsHandler) createDirIndex(base *URI, dirPath string, mustCompress bool, fileEncoding string) (*fsFile, error) {
w := &bytebufferpool.ByteBuffer{}
basePathEscaped := html.EscapeString(string(base.Path()))
fmt.Fprintf(w, "
%s", basePathEscaped)
fmt.Fprintf(w, "%s
", basePathEscaped)
fmt.Fprintf(w, "")
if len(basePathEscaped) > 1 {
var parentURI URI
base.CopyTo(&parentURI)
parentURI.Update(string(base.Path()) + "/..")
parentPathEscaped := html.EscapeString(string(parentURI.Path()))
fmt.Fprintf(w, `<li><a href="%s" class="dir">..</a></li>`, parentPathEscaped)
}
f, err := os.Open(dirPath)
if err != nil {
return nil, err
}
fileinfos, err := f.Readdir(0)
f.Close()
if err != nil {
return nil, err
}
fm := make(map[string]os.FileInfo, len(fileinfos))
filenames := make([]string, 0, len(fileinfos))
nestedContinue:
for _, fi := range fileinfos {
name := fi.Name()
for _, cfs := range h.compressedFileSuffixes {
if strings.HasSuffix(name, cfs) {
// Do not show compressed files on index page.
continue nestedContinue
}
}
fm[name] = fi
filenames = append(filenames, name)
}
var u URI
base.CopyTo(&u)
u.Update(string(u.Path()) + "/")
sort.Strings(filenames)
for _, name := range filenames {
u.Update(name)
pathEscaped := html.EscapeString(string(u.Path()))
fi := fm[name]
auxStr := "dir"
className := "dir"
if !fi.IsDir() {
auxStr = fmt.Sprintf("file, %d bytes", fi.Size())
className = "file"
}
fmt.Fprintf(w, `<li><a href="%s" class="%s">%s</a>, %s, last modified %s</li>`,
pathEscaped, className, html.EscapeString(name), auxStr, fsModTime(fi.ModTime()))
}
fmt.Fprintf(w, "
")
if mustCompress {
var zbuf bytebufferpool.ByteBuffer
if fileEncoding == "br" {
zbuf.B = AppendBrotliBytesLevel(zbuf.B, w.B, CompressDefaultCompression)
} else if fileEncoding == "gzip" {
zbuf.B = AppendGzipBytesLevel(zbuf.B, w.B, CompressDefaultCompression)
}
w = &zbuf
}
dirIndex := w.B
lastModified := time.Now()
ff := &fsFile{
h: h,
dirIndex: dirIndex,
contentType: "text/html; charset=utf-8",
contentLength: len(dirIndex),
compressed: mustCompress,
lastModified: lastModified,
lastModifiedStr: AppendHTTPDate(nil, lastModified),
t: lastModified,
}
return ff, nil
}
const (
fsMinCompressRatio = 0.8
fsMaxCompressibleFileSize = 8 * 1024 * 1024
)
func (h *fsHandler) compressAndOpenFSFile(filePath string, fileEncoding string) (*fsFile, error) {
f, err := os.Open(filePath)
if err != nil {
return nil, err
}
fileInfo, err := f.Stat()
if err != nil {
f.Close()
return nil, fmt.Errorf("cannot obtain info for file %q: %s", filePath, err)
}
if fileInfo.IsDir() {
f.Close()
return nil, errDirIndexRequired
}
if strings.HasSuffix(filePath, h.compressedFileSuffixes[fileEncoding]) ||
fileInfo.Size() > fsMaxCompressibleFileSize ||
!isFileCompressible(f, fsMinCompressRatio) {
return h.newFSFile(f, fileInfo, false, "")
}
compressedFilePath := filePath + h.compressedFileSuffixes[fileEncoding]
absPath, err := filepath.Abs(compressedFilePath)
if err != nil {
f.Close()
return nil, fmt.Errorf("cannot determine absolute path for %q: %s", compressedFilePath, err)
}
flock := getFileLock(absPath)
flock.Lock()
ff, err := h.compressFileNolock(f, fileInfo, filePath, compressedFilePath, fileEncoding)
flock.Unlock()
return ff, err
}
func (h *fsHandler) compressFileNolock(f *os.File, fileInfo os.FileInfo, filePath, compressedFilePath string, fileEncoding string) (*fsFile, error) {
// Attempt to open compressed file created by another concurrent
// goroutine.
// It is safe to open such a file, since its creation
// is guarded by a file mutex - see the getFileLock call.
if _, err := os.Stat(compressedFilePath); err == nil {
f.Close()
return h.newCompressedFSFile(compressedFilePath, fileEncoding)
}
// Create temporary file, so concurrent goroutines don't use
// it until it is created.
tmpFilePath := compressedFilePath + ".tmp"
zf, err := os.Create(tmpFilePath)
if err != nil {
f.Close()
if !os.IsPermission(err) {
return nil, fmt.Errorf("cannot create temporary file %q: %s", tmpFilePath, err)
}
return nil, errNoCreatePermission
}
if fileEncoding == "br" {
zw := acquireStacklessBrotliWriter(zf, CompressDefaultCompression)
_, err = copyZeroAlloc(zw, f)
if err1 := zw.Flush(); err == nil {
err = err1
}
releaseStacklessBrotliWriter(zw, CompressDefaultCompression)
} else if fileEncoding == "gzip" {
zw := acquireStacklessGzipWriter(zf, CompressDefaultCompression)
_, err = copyZeroAlloc(zw, f)
if err1 := zw.Flush(); err == nil {
err = err1
}
releaseStacklessGzipWriter(zw, CompressDefaultCompression)
}
zf.Close()
f.Close()
if err != nil {
return nil, fmt.Errorf("error when compressing file %q to %q: %s", filePath, tmpFilePath, err)
}
if err = os.Chtimes(tmpFilePath, time.Now(), fileInfo.ModTime()); err != nil {
return nil, fmt.Errorf("cannot change modification time to %s for tmp file %q: %s",
fileInfo.ModTime(), tmpFilePath, err)
}
if err = os.Rename(tmpFilePath, compressedFilePath); err != nil {
return nil, fmt.Errorf("cannot move compressed file from %q to %q: %s", tmpFilePath, compressedFilePath, err)
}
return h.newCompressedFSFile(compressedFilePath, fileEncoding)
}
func (h *fsHandler) newCompressedFSFile(filePath string, fileEncoding string) (*fsFile, error) {
f, err := os.Open(filePath)
if err != nil {
return nil, fmt.Errorf("cannot open compressed file %q: %s", filePath, err)
}
fileInfo, err := f.Stat()
if err != nil {
f.Close()
return nil, fmt.Errorf("cannot obtain info for compressed file %q: %s", filePath, err)
}
return h.newFSFile(f, fileInfo, true, fileEncoding)
}
func (h *fsHandler) openFSFile(filePath string, mustCompress bool, fileEncoding string) (*fsFile, error) {
filePathOriginal := filePath
if mustCompress {
filePath += h.compressedFileSuffixes[fileEncoding]
}
f, err := os.Open(filePath)
if err != nil {
if mustCompress && os.IsNotExist(err) {
return h.compressAndOpenFSFile(filePathOriginal, fileEncoding)
}
return nil, err
}
fileInfo, err := f.Stat()
if err != nil {
f.Close()
return nil, fmt.Errorf("cannot obtain info for file %q: %s", filePath, err)
}
if fileInfo.IsDir() {
f.Close()
if mustCompress {
return nil, fmt.Errorf("directory with unexpected suffix found: %q. Suffix: %q",
filePath, h.compressedFileSuffixes[fileEncoding])
}
return nil, errDirIndexRequired
}
if mustCompress {
fileInfoOriginal, err := os.Stat(filePathOriginal)
if err != nil {
f.Close()
return nil, fmt.Errorf("cannot obtain info for original file %q: %s", filePathOriginal, err)
}
// Only re-create the compressed file if there was more than a second between the mod times.
// On MacOS the gzip seems to truncate the nanoseconds in the mod time causing the original file
// to look newer than the gzipped file.
if fileInfoOriginal.ModTime().Sub(fileInfo.ModTime()) >= time.Second {
// The compressed file became stale. Re-create it.
f.Close()
os.Remove(filePath)
return h.compressAndOpenFSFile(filePathOriginal, fileEncoding)
}
}
return h.newFSFile(f, fileInfo, mustCompress, fileEncoding)
}
func (h *fsHandler) newFSFile(f *os.File, fileInfo os.FileInfo, compressed bool, fileEncoding string) (*fsFile, error) {
n := fileInfo.Size()
contentLength := int(n)
if n != int64(contentLength) {
f.Close()
return nil, fmt.Errorf("too big file: %d bytes", n)
}
// detect content-type
ext := fileExtension(fileInfo.Name(), compressed, h.compressedFileSuffixes[fileEncoding])
contentType := mime.TypeByExtension(ext)
if len(contentType) == 0 {
data, err := readFileHeader(f, compressed, fileEncoding)
if err != nil {
return nil, fmt.Errorf("cannot read header of the file %q: %s", f.Name(), err)
}
contentType = http.DetectContentType(data)
}
lastModified := fileInfo.ModTime()
ff := &fsFile{
h: h,
f: f,
contentType: contentType,
contentLength: contentLength,
compressed: compressed,
lastModified: lastModified,
lastModifiedStr: AppendHTTPDate(nil, lastModified),
t: time.Now(),
}
return ff, nil
}
func readFileHeader(f *os.File, compressed bool, fileEncoding string) ([]byte, error) {
r := io.Reader(f)
var (
br *brotli.Reader
zr *gzip.Reader
)
if compressed {
var err error
if fileEncoding == "br" {
if br, err = acquireBrotliReader(f); err != nil {
return nil, err
}
r = br
} else if fileEncoding == "gzip" {
if zr, err = acquireGzipReader(f); err != nil {
return nil, err
}
r = zr
}
}
lr := &io.LimitedReader{
R: r,
N: 512,
}
data, err := ioutil.ReadAll(lr)
if _, err := f.Seek(0, 0); err != nil {
return nil, err
}
if br != nil {
releaseBrotliReader(br)
}
if zr != nil {
releaseGzipReader(zr)
}
return data, err
}
func stripLeadingSlashes(path []byte, stripSlashes int) []byte {
for stripSlashes > 0 && len(path) > 0 {
if path[0] != '/' {
panic("BUG: path must start with slash")
}
n := bytes.IndexByte(path[1:], '/')
if n < 0 {
path = path[:0]
break
}
path = path[n+1:]
stripSlashes--
}
return path
}
func stripTrailingSlashes(path []byte) []byte {
for len(path) > 0 && path[len(path)-1] == '/' {
path = path[:len(path)-1]
}
return path
}
func fileExtension(path string, compressed bool, compressedFileSuffix string) string {
if compressed && strings.HasSuffix(path, compressedFileSuffix) {
path = path[:len(path)-len(compressedFileSuffix)]
}
n := strings.LastIndexByte(path, '.')
if n < 0 {
return ""
}
return path[n:]
}
// FileLastModified returns last modified time for the file.
func FileLastModified(path string) (time.Time, error) {
f, err := os.Open(path)
if err != nil {
return zeroTime, err
}
fileInfo, err := f.Stat()
f.Close()
if err != nil {
return zeroTime, err
}
return fsModTime(fileInfo.ModTime()), nil
}
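// Illustrative sketch (not part of the original source): FileLastModified can
// feed a Last-Modified response header. The file path below is hypothetical.
func exampleFileLastModifiedSketch(h *ResponseHeader) error {
lastModified, err := FileLastModified("/var/www/static/index.html")
if err != nil {
return err
}
// SetLastModified formats the time as an HTTP date on the response header.
h.SetLastModified(lastModified)
return nil
}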
func fsModTime(t time.Time) time.Time {
return t.In(time.UTC).Truncate(time.Second)
}
var (
filesLockMap = make(map[string]*sync.Mutex)
filesLockMapLock sync.Mutex
)
func getFileLock(absPath string) *sync.Mutex {
filesLockMapLock.Lock()
flock := filesLockMap[absPath]
if flock == nil {
flock = &sync.Mutex{}
filesLockMap[absPath] = flock
}
filesLockMapLock.Unlock()
return flock
}
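// Illustrative sketch (not part of the original source): getFileLock hands out a
// per-path mutex so only one goroutine compresses a given file at a time. A
// caller would typically wrap the compression step like this (the path below is
// hypothetical):
func exampleGetFileLockSketch() {
absPath := "/var/www/static/app.js"
flock := getFileLock(absPath)
flock.Lock()
// ... create the compressed copy of absPath exactly once ...
flock.Unlock()
}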
fasthttp-1.31.0/fs_example_test.go 0000664 0000000 0000000 00000001101 14130360711 0017127 0 ustar 00root root 0000000 0000000 package fasthttp_test
import (
"log"
"github.com/valyala/fasthttp"
)
func ExampleFS() {
fs := &fasthttp.FS{
// Path to directory to serve.
Root: "/var/www/static-site",
// Generate index pages if client requests directory contents.
GenerateIndexPages: true,
// Enable transparent compression to save network traffic.
Compress: true,
}
// Create request handler for serving static files.
h := fs.NewRequestHandler()
// Start the server.
if err := fasthttp.ListenAndServe(":8080", h); err != nil {
log.Fatalf("error in ListenAndServe: %s", err)
}
}
fasthttp-1.31.0/fs_handler_example_test.go 0000664 0000000 0000000 00000002221 14130360711 0020630 0 ustar 00root root 0000000 0000000 package fasthttp_test
import (
"bytes"
"log"
"github.com/valyala/fasthttp"
)
// Set up file handlers (aka 'file server config')
var (
// Handler for serving images from /img/ path,
// i.e. /img/foo/bar.jpg will be served from
// /var/www/images/foo/bar.jpg .
imgPrefix = []byte("/img/")
imgHandler = fasthttp.FSHandler("/var/www/images", 1)
// Handler for serving css from /static/css/ path,
// i.e. /static/css/foo/bar.css will be served from
// /home/dev/css/foo/bar.css .
cssPrefix = []byte("/static/css/")
cssHandler = fasthttp.FSHandler("/home/dev/css", 2)
// Handler for serving the rest of requests,
// i.e. /foo/bar/baz.html will be served from
// /var/www/files/foo/bar/baz.html .
filesHandler = fasthttp.FSHandler("/var/www/files", 0)
)
// Main request handler
func requestHandler(ctx *fasthttp.RequestCtx) {
path := ctx.Path()
switch {
case bytes.HasPrefix(path, imgPrefix):
imgHandler(ctx)
case bytes.HasPrefix(path, cssPrefix):
cssHandler(ctx)
default:
filesHandler(ctx)
}
}
func ExampleFSHandler() {
if err := fasthttp.ListenAndServe(":80", requestHandler); err != nil {
log.Fatalf("Error in server: %s", err)
}
}
fasthttp-1.31.0/fs_test.go 0000664 0000000 0000000 00000053405 14130360711 0015432 0 ustar 00root root 0000000 0000000 package fasthttp
import (
"bufio"
"bytes"
"fmt"
"io"
"io/ioutil"
"math/rand"
"os"
"path"
"runtime"
"sort"
"testing"
"time"
)
type TestLogger struct {
t *testing.T
}
func (t TestLogger) Printf(format string, args ...interface{}) {
t.t.Logf(format, args...)
}
func TestNewVHostPathRewriter(t *testing.T) {
t.Parallel()
var ctx RequestCtx
var req Request
req.Header.SetHost("foobar.com")
req.SetRequestURI("/foo/bar/baz")
ctx.Init(&req, nil, nil)
f := NewVHostPathRewriter(0)
path := f(&ctx)
expectedPath := "/foobar.com/foo/bar/baz"
if string(path) != expectedPath {
t.Fatalf("unexpected path %q. Expecting %q", path, expectedPath)
}
ctx.Request.Reset()
ctx.Request.SetRequestURI("https://aaa.bbb.cc/one/two/three/four?asdf=dsf")
f = NewVHostPathRewriter(2)
path = f(&ctx)
expectedPath = "/aaa.bbb.cc/three/four"
if string(path) != expectedPath {
t.Fatalf("unexpected path %q. Expecting %q", path, expectedPath)
}
}
func TestNewVHostPathRewriterMaliciousHost(t *testing.T) {
t.Parallel()
var ctx RequestCtx
var req Request
req.Header.SetHost("/../../../etc/passwd")
req.SetRequestURI("/foo/bar/baz")
ctx.Init(&req, nil, nil)
f := NewVHostPathRewriter(0)
path := f(&ctx)
expectedPath := "/invalid-host/"
if string(path) != expectedPath {
t.Fatalf("unexpected path %q. Expecting %q", path, expectedPath)
}
}
func testPathNotFound(t *testing.T, pathNotFoundFunc RequestHandler) {
var ctx RequestCtx
var req Request
req.SetRequestURI("http//some.url/file")
ctx.Init(&req, nil, TestLogger{t})
stop := make(chan struct{})
defer close(stop)
fs := &FS{
Root: "./",
PathNotFound: pathNotFoundFunc,
CleanStop: stop,
}
fs.NewRequestHandler()(&ctx)
if pathNotFoundFunc == nil {
// The default not-found message must be returned.
if !bytes.Equal(ctx.Response.Body(),
[]byte("Cannot open requested path")) {
t.Fatalf("response differs. Response: %q", ctx.Response.Body())
}
} else {
// The custom handler's message must be returned instead of the default one.
if bytes.Equal(ctx.Response.Body(),
[]byte("Cannot open requested path")) {
t.Fatalf("response differs. Response: %q", ctx.Response.Body())
}
}
}
func TestPathNotFound(t *testing.T) {
t.Parallel()
testPathNotFound(t, nil)
}
func TestPathNotFoundFunc(t *testing.T) {
t.Parallel()
testPathNotFound(t, func(ctx *RequestCtx) {
ctx.WriteString("Not found hehe") //nolint:errcheck
})
}
func TestServeFileHead(t *testing.T) {
// This test can't run in parallel as files in / might be changed by other tests.
var ctx RequestCtx
var req Request
req.Header.SetMethod(MethodHead)
req.SetRequestURI("http://foobar.com/baz")
ctx.Init(&req, nil, nil)
ServeFile(&ctx, "fs.go")
var resp Response
resp.SkipBody = true
s := ctx.Response.String()
br := bufio.NewReader(bytes.NewBufferString(s))
if err := resp.Read(br); err != nil {
t.Fatalf("unexpected error: %s", err)
}
ce := resp.Header.Peek(HeaderContentEncoding)
if len(ce) > 0 {
t.Fatalf("Unexpected 'Content-Encoding' %q", ce)
}
body := resp.Body()
if len(body) > 0 {
t.Fatalf("unexpected response body %q. Expecting empty body", body)
}
expectedBody, err := getFileContents("/fs.go")
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
contentLength := resp.Header.ContentLength()
if contentLength != len(expectedBody) {
t.Fatalf("unexpected Content-Length: %d. expecting %d", contentLength, len(expectedBody))
}
}
func TestServeFileSmallNoReadFrom(t *testing.T) {
t.Parallel()
teststr := "hello, world!"
tempdir, err := ioutil.TempDir("", "httpexpect")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(tempdir)
if err := ioutil.WriteFile(
path.Join(tempdir, "hello"), []byte(teststr), 0666); err != nil {
t.Fatal(err)
}
var ctx RequestCtx
var req Request
req.SetRequestURI("http://foobar.com/baz")
ctx.Init(&req, nil, nil)
ServeFile(&ctx, path.Join(tempdir, "hello"))
reader, ok := ctx.Response.bodyStream.(*fsSmallFileReader)
if !ok {
t.Fatal("expected fsSmallFileReader")
}
buf := bytes.NewBuffer(nil)
n, err := reader.WriteTo(pureWriter{buf})
if err != nil {
t.Fatal(err)
}
if n != int64(len(teststr)) {
t.Fatalf("expected %d bytes, got %d bytes", len(teststr), n)
}
body := buf.String()
if body != teststr {
t.Fatalf("expected '%s'", teststr)
}
}
type pureWriter struct {
w io.Writer
}
func (pw pureWriter) Write(p []byte) (nn int, err error) {
return pw.w.Write(p)
}
func TestServeFileCompressed(t *testing.T) {
// This test can't run in parallel as files in / might be changed by other tests.
var ctx RequestCtx
ctx.Init(&Request{}, nil, nil)
var resp Response
// request compressed gzip file
ctx.Request.SetRequestURI("http://foobar.com/baz")
ctx.Request.Header.Set(HeaderAcceptEncoding, "gzip")
ServeFile(&ctx, "fs.go")
s := ctx.Response.String()
br := bufio.NewReader(bytes.NewBufferString(s))
if err := resp.Read(br); err != nil {
t.Fatalf("unexpected error: %s", err)
}
ce := resp.Header.Peek(HeaderContentEncoding)
if string(ce) != "gzip" {
t.Fatalf("Unexpected 'Content-Encoding' %q. Expecting %q", ce, "gzip")
}
body, err := resp.BodyGunzip()
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
expectedBody, err := getFileContents("/fs.go")
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
if !bytes.Equal(body, expectedBody) {
t.Fatalf("unexpected body %q. expecting %q", body, expectedBody)
}
// request compressed brotli file
ctx.Request.Reset()
ctx.Request.SetRequestURI("http://foobar.com/baz")
ctx.Request.Header.Set(HeaderAcceptEncoding, "br")
ServeFile(&ctx, "fs.go")
s = ctx.Response.String()
br = bufio.NewReader(bytes.NewBufferString(s))
if err = resp.Read(br); err != nil {
t.Fatalf("unexpected error: %s", err)
}
ce = resp.Header.Peek(HeaderContentEncoding)
if string(ce) != "br" {
t.Fatalf("Unexpected 'Content-Encoding' %q. Expecting %q", ce, "br")
}
body, err = resp.BodyUnbrotli()
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
expectedBody, err = getFileContents("/fs.go")
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
if !bytes.Equal(body, expectedBody) {
t.Fatalf("unexpected body %q. expecting %q", body, expectedBody)
}
}
func TestServeFileUncompressed(t *testing.T) {
// This test can't run in parallel as files in / might be changed by other tests.
var ctx RequestCtx
var req Request
req.SetRequestURI("http://foobar.com/baz")
req.Header.Set(HeaderAcceptEncoding, "gzip")
ctx.Init(&req, nil, nil)
ServeFileUncompressed(&ctx, "fs.go")
var resp Response
s := ctx.Response.String()
br := bufio.NewReader(bytes.NewBufferString(s))
if err := resp.Read(br); err != nil {
t.Fatalf("unexpected error: %s", err)
}
ce := resp.Header.Peek(HeaderContentEncoding)
if len(ce) > 0 {
t.Fatalf("Unexpected 'Content-Encoding' %q", ce)
}
body := resp.Body()
expectedBody, err := getFileContents("/fs.go")
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
if !bytes.Equal(body, expectedBody) {
t.Fatalf("unexpected body %q. expecting %q", body, expectedBody)
}
}
func TestFSByteRangeConcurrent(t *testing.T) {
// This test can't run in parallel as files in / might be changed by other tests.
stop := make(chan struct{})
defer close(stop)
fs := &FS{
Root: ".",
AcceptByteRange: true,
CleanStop: stop,
}
h := fs.NewRequestHandler()
concurrency := 10
ch := make(chan struct{}, concurrency)
for i := 0; i < concurrency; i++ {
go func() {
for j := 0; j < 5; j++ {
testFSByteRange(t, h, "/fs.go")
testFSByteRange(t, h, "/README.md")
}
ch <- struct{}{}
}()
}
for i := 0; i < concurrency; i++ {
select {
case <-time.After(time.Second):
t.Fatalf("timeout")
case <-ch:
}
}
}
func TestFSByteRangeSingleThread(t *testing.T) {
// This test can't run in parallel as files in / might be changed by other tests.
stop := make(chan struct{})
defer close(stop)
fs := &FS{
Root: ".",
AcceptByteRange: true,
CleanStop: stop,
}
h := fs.NewRequestHandler()
testFSByteRange(t, h, "/fs.go")
testFSByteRange(t, h, "/README.md")
}
func testFSByteRange(t *testing.T, h RequestHandler, filePath string) {
var ctx RequestCtx
ctx.Init(&Request{}, nil, nil)
expectedBody, err := getFileContents(filePath)
if err != nil {
t.Fatalf("cannot read file %q: %s", filePath, err)
}
fileSize := len(expectedBody)
startPos := rand.Intn(fileSize)
endPos := rand.Intn(fileSize)
if endPos < startPos {
startPos, endPos = endPos, startPos
}
ctx.Request.SetRequestURI(filePath)
ctx.Request.Header.SetByteRange(startPos, endPos)
h(&ctx)
var resp Response
s := ctx.Response.String()
br := bufio.NewReader(bytes.NewBufferString(s))
if err := resp.Read(br); err != nil {
t.Fatalf("unexpected error: %s. filePath=%q", err, filePath)
}
if resp.StatusCode() != StatusPartialContent {
t.Fatalf("unexpected status code: %d. Expecting %d. filePath=%q", resp.StatusCode(), StatusPartialContent, filePath)
}
cr := resp.Header.Peek(HeaderContentRange)
expectedCR := fmt.Sprintf("bytes %d-%d/%d", startPos, endPos, fileSize)
if string(cr) != expectedCR {
t.Fatalf("unexpected content-range %q. Expecting %q. filePath=%q", cr, expectedCR, filePath)
}
body := resp.Body()
bodySize := endPos - startPos + 1
if len(body) != bodySize {
t.Fatalf("unexpected body size %d. Expecting %d. filePath=%q, startPos=%d, endPos=%d",
len(body), bodySize, filePath, startPos, endPos)
}
expectedBody = expectedBody[startPos : endPos+1]
if !bytes.Equal(body, expectedBody) {
t.Fatalf("unexpected body %q. Expecting %q. filePath=%q, startPos=%d, endPos=%d",
body, expectedBody, filePath, startPos, endPos)
}
}
func getFileContents(path string) ([]byte, error) {
path = "." + path
f, err := os.Open(path)
if err != nil {
return nil, err
}
defer f.Close()
return ioutil.ReadAll(f)
}
func TestParseByteRangeSuccess(t *testing.T) {
t.Parallel()
testParseByteRangeSuccess(t, "bytes=0-0", 1, 0, 0)
testParseByteRangeSuccess(t, "bytes=1234-6789", 6790, 1234, 6789)
testParseByteRangeSuccess(t, "bytes=123-", 456, 123, 455)
testParseByteRangeSuccess(t, "bytes=-1", 1, 0, 0)
testParseByteRangeSuccess(t, "bytes=-123", 456, 333, 455)
// End position exceeding content-length. It should be updated to content-length-1.
// See https://www.w3.org/Protocols/rfc2616/rfc2616-sec14.html#sec14.35
testParseByteRangeSuccess(t, "bytes=1-2345", 234, 1, 233)
testParseByteRangeSuccess(t, "bytes=0-2345", 2345, 0, 2344)
// Start position overflow. Whole range must be returned.
// See https://www.w3.org/Protocols/rfc2616/rfc2616-sec14.html#sec14.35
testParseByteRangeSuccess(t, "bytes=-567", 56, 0, 55)
}
func testParseByteRangeSuccess(t *testing.T, v string, contentLength, startPos, endPos int) {
startPos1, endPos1, err := ParseByteRange([]byte(v), contentLength)
if err != nil {
t.Fatalf("unexpected error: %s. v=%q, contentLength=%d", err, v, contentLength)
}
if startPos1 != startPos {
t.Fatalf("unexpected startPos=%d. Expecting %d. v=%q, contentLength=%d", startPos1, startPos, v, contentLength)
}
if endPos1 != endPos {
t.Fatalf("unexpected endPos=%d. Expectind %d. v=%q, contentLenght=%d", endPos1, endPos, v, contentLength)
}
}
func TestParseByteRangeError(t *testing.T) {
t.Parallel()
// invalid value
testParseByteRangeError(t, "asdfasdfas", 1234)
// invalid units
testParseByteRangeError(t, "foobar=1-34", 600)
// missing '-'
testParseByteRangeError(t, "bytes=1234", 1235)
// non-numeric range
testParseByteRangeError(t, "bytes=foobar", 123)
testParseByteRangeError(t, "bytes=1-foobar", 123)
testParseByteRangeError(t, "bytes=df-344", 545)
// multiple byte ranges
testParseByteRangeError(t, "bytes=1-2,4-6", 123)
// byte range exceeding contentLength
testParseByteRangeError(t, "bytes=123-", 12)
// startPos exceeding endPos
testParseByteRangeError(t, "bytes=123-34", 1234)
}
func testParseByteRangeError(t *testing.T, v string, contentLength int) {
_, _, err := ParseByteRange([]byte(v), contentLength)
if err == nil {
t.Fatalf("expecting error when parsing byte range %q", v)
}
}
func TestFSCompressConcurrent(t *testing.T) {
// This test can't run in parallel as files in / might be changed by other tests.
stop := make(chan struct{})
defer close(stop)
fs := &FS{
Root: ".",
GenerateIndexPages: true,
Compress: true,
CompressBrotli: true,
CleanStop: stop,
}
h := fs.NewRequestHandler()
concurrency := 4
ch := make(chan struct{}, concurrency)
for i := 0; i < concurrency; i++ {
go func() {
for j := 0; j < 5; j++ {
testFSCompress(t, h, "/fs.go")
testFSCompress(t, h, "/")
testFSCompress(t, h, "/README.md")
}
ch <- struct{}{}
}()
}
for i := 0; i < concurrency; i++ {
select {
case <-ch:
case <-time.After(time.Second * 3):
t.Fatalf("timeout")
}
}
}
func TestFSCompressSingleThread(t *testing.T) {
// This test can't run in parallel as files in / might be changed by other tests.
stop := make(chan struct{})
defer close(stop)
fs := &FS{
Root: ".",
GenerateIndexPages: true,
Compress: true,
CompressBrotli: true,
CleanStop: stop,
}
h := fs.NewRequestHandler()
testFSCompress(t, h, "/fs.go")
testFSCompress(t, h, "/")
testFSCompress(t, h, "/README.md")
}
func testFSCompress(t *testing.T, h RequestHandler, filePath string) {
var ctx RequestCtx
ctx.Init(&Request{}, nil, nil)
var resp Response
// request uncompressed file
ctx.Request.Reset()
ctx.Request.SetRequestURI(filePath)
h(&ctx)
s := ctx.Response.String()
br := bufio.NewReader(bytes.NewBufferString(s))
if err := resp.Read(br); err != nil {
t.Errorf("unexpected error: %s. filePath=%q", err, filePath)
}
if resp.StatusCode() != StatusOK {
t.Errorf("unexpected status code: %d. Expecting %d. filePath=%q", resp.StatusCode(), StatusOK, filePath)
}
ce := resp.Header.Peek(HeaderContentEncoding)
if string(ce) != "" {
t.Errorf("unexpected content-encoding %q. Expecting empty string. filePath=%q", ce, filePath)
}
body := string(resp.Body())
// request compressed gzip file
ctx.Request.Reset()
ctx.Request.SetRequestURI(filePath)
ctx.Request.Header.Set(HeaderAcceptEncoding, "gzip")
h(&ctx)
s = ctx.Response.String()
br = bufio.NewReader(bytes.NewBufferString(s))
if err := resp.Read(br); err != nil {
t.Errorf("unexpected error: %s. filePath=%q", err, filePath)
}
if resp.StatusCode() != StatusOK {
t.Errorf("unexpected status code: %d. Expecting %d. filePath=%q", resp.StatusCode(), StatusOK, filePath)
}
ce = resp.Header.Peek(HeaderContentEncoding)
if string(ce) != "gzip" {
t.Errorf("unexpected content-encoding %q. Expecting %q. filePath=%q", ce, "gzip", filePath)
}
zbody, err := resp.BodyGunzip()
if err != nil {
t.Errorf("unexpected error when gunzipping response body: %s. filePath=%q", err, filePath)
}
if string(zbody) != body {
t.Errorf("unexpected body len=%d. Expected len=%d. FilePath=%q", len(zbody), len(body), filePath)
}
// request compressed brotli file
ctx.Request.Reset()
ctx.Request.SetRequestURI(filePath)
ctx.Request.Header.Set(HeaderAcceptEncoding, "br")
h(&ctx)
s = ctx.Response.String()
br = bufio.NewReader(bytes.NewBufferString(s))
if err = resp.Read(br); err != nil {
t.Errorf("unexpected error: %s. filePath=%q", err, filePath)
}
if resp.StatusCode() != StatusOK {
t.Errorf("unexpected status code: %d. Expecting %d. filePath=%q", resp.StatusCode(), StatusOK, filePath)
}
ce = resp.Header.Peek(HeaderContentEncoding)
if string(ce) != "br" {
t.Errorf("unexpected content-encoding %q. Expecting %q. filePath=%q", ce, "br", filePath)
}
zbody, err = resp.BodyUnbrotli()
if err != nil {
t.Errorf("unexpected error when unbrotling response body: %s. filePath=%q", err, filePath)
}
if string(zbody) != body {
t.Errorf("unexpected body len=%d. Expected len=%d. FilePath=%q", len(zbody), len(body), filePath)
}
}
func TestFSHandlerSingleThread(t *testing.T) {
// This test can't run in parallel as files in / might be changed by other tests.
requestHandler := FSHandler(".", 0)
f, err := os.Open(".")
if err != nil {
t.Fatalf("cannot open cwd: %s", err)
}
filenames, err := f.Readdirnames(0)
f.Close()
if err != nil {
t.Fatalf("cannot read dirnames in cwd: %s", err)
}
sort.Strings(filenames)
for i := 0; i < 3; i++ {
fsHandlerTest(t, requestHandler, filenames)
}
}
func TestFSHandlerConcurrent(t *testing.T) {
// This test can't run in parallel as files in / might be changed by other tests.
requestHandler := FSHandler(".", 0)
f, err := os.Open(".")
if err != nil {
t.Fatalf("cannot open cwd: %s", err)
}
filenames, err := f.Readdirnames(0)
f.Close()
if err != nil {
t.Fatalf("cannot read dirnames in cwd: %s", err)
}
sort.Strings(filenames)
concurrency := 10
ch := make(chan struct{}, concurrency)
for j := 0; j < concurrency; j++ {
go func() {
for i := 0; i < 3; i++ {
fsHandlerTest(t, requestHandler, filenames)
}
ch <- struct{}{}
}()
}
for j := 0; j < concurrency; j++ {
select {
case <-ch:
case <-time.After(time.Second):
t.Fatalf("timeout")
}
}
}
func fsHandlerTest(t *testing.T, requestHandler RequestHandler, filenames []string) {
var ctx RequestCtx
var req Request
ctx.Init(&req, nil, defaultLogger)
ctx.Request.Header.SetHost("foobar.com")
filesTested := 0
for _, name := range filenames {
f, err := os.Open(name)
if err != nil {
t.Fatalf("cannot open file %q: %s", name, err)
}
stat, err := f.Stat()
if err != nil {
t.Fatalf("cannot get file stat %q: %s", name, err)
}
if stat.IsDir() {
f.Close()
continue
}
data, err := ioutil.ReadAll(f)
f.Close()
if err != nil {
t.Fatalf("cannot read file contents %q: %s", name, err)
}
ctx.URI().Update(name)
requestHandler(&ctx)
if ctx.Response.bodyStream == nil {
t.Fatalf("response body stream must be non-empty")
}
body, err := ioutil.ReadAll(ctx.Response.bodyStream)
if err != nil {
t.Fatalf("error when reading response body stream: %s", err)
}
if !bytes.Equal(body, data) {
t.Fatalf("unexpected body returned: %q. Expecting %q", body, data)
}
filesTested++
if filesTested >= 10 {
break
}
}
// verify index page generation
ctx.URI().Update("/")
requestHandler(&ctx)
if ctx.Response.bodyStream == nil {
t.Fatalf("response body stream must be non-empty")
}
body, err := ioutil.ReadAll(ctx.Response.bodyStream)
if err != nil {
t.Fatalf("error when reading response body stream: %s", err)
}
if len(body) == 0 {
t.Fatalf("index page must be non-empty")
}
}
func TestStripPathSlashes(t *testing.T) {
t.Parallel()
testStripPathSlashes(t, "", 0, "")
testStripPathSlashes(t, "", 10, "")
testStripPathSlashes(t, "/", 0, "")
testStripPathSlashes(t, "/", 1, "")
testStripPathSlashes(t, "/", 10, "")
testStripPathSlashes(t, "/foo/bar/baz", 0, "/foo/bar/baz")
testStripPathSlashes(t, "/foo/bar/baz", 1, "/bar/baz")
testStripPathSlashes(t, "/foo/bar/baz", 2, "/baz")
testStripPathSlashes(t, "/foo/bar/baz", 3, "")
testStripPathSlashes(t, "/foo/bar/baz", 10, "")
// trailing slash
testStripPathSlashes(t, "/foo/bar/", 0, "/foo/bar")
testStripPathSlashes(t, "/foo/bar/", 1, "/bar")
testStripPathSlashes(t, "/foo/bar/", 2, "")
testStripPathSlashes(t, "/foo/bar/", 3, "")
}
func testStripPathSlashes(t *testing.T, path string, stripSlashes int, expectedPath string) {
s := stripLeadingSlashes([]byte(path), stripSlashes)
s = stripTrailingSlashes(s)
if string(s) != expectedPath {
t.Fatalf("unexpected path after stripping %q with stripSlashes=%d: %q. Expecting %q", path, stripSlashes, s, expectedPath)
}
}
func TestFileExtension(t *testing.T) {
t.Parallel()
testFileExtension(t, "foo.bar", false, "zzz", ".bar")
testFileExtension(t, "foobar", false, "zzz", "")
testFileExtension(t, "foo.bar.baz", false, "zzz", ".baz")
testFileExtension(t, "", false, "zzz", "")
testFileExtension(t, "/a/b/c.d/efg.jpg", false, ".zzz", ".jpg")
testFileExtension(t, "foo.bar", true, ".zzz", ".bar")
testFileExtension(t, "foobar.zzz", true, ".zzz", "")
testFileExtension(t, "foo.bar.baz.fasthttp.gz", true, ".fasthttp.gz", ".baz")
testFileExtension(t, "", true, ".zzz", "")
testFileExtension(t, "/a/b/c.d/efg.jpg.xxx", true, ".xxx", ".jpg")
}
func testFileExtension(t *testing.T, path string, compressed bool, compressedFileSuffix, expectedExt string) {
ext := fileExtension(path, compressed, compressedFileSuffix)
if ext != expectedExt {
t.Fatalf("unexpected file extension for file %q: %q. Expecting %q", path, ext, expectedExt)
}
}
func TestServeFileContentType(t *testing.T) {
// This test can't run in parallel as files in / might be changed by other tests.
var ctx RequestCtx
var req Request
req.Header.SetMethod(MethodGet)
req.SetRequestURI("http://foobar.com/baz")
ctx.Init(&req, nil, nil)
ServeFile(&ctx, "testdata/test.png")
var resp Response
s := ctx.Response.String()
br := bufio.NewReader(bytes.NewBufferString(s))
if err := resp.Read(br); err != nil {
t.Fatalf("unexpected error: %s", err)
}
expected := []byte("image/png")
if !bytes.Equal(resp.Header.ContentType(), expected) {
t.Fatalf("Unexpected Content-Type, expected: %q got %q", expected, resp.Header.ContentType())
}
}
func TestServeFileDirectoryRedirect(t *testing.T) {
t.Parallel()
if runtime.GOOS == "windows" {
t.SkipNow()
}
var ctx RequestCtx
var req Request
req.SetRequestURI("http://foobar.com")
ctx.Init(&req, nil, nil)
ctx.Request.Reset()
ctx.Response.Reset()
ServeFile(&ctx, "fasthttputil")
if ctx.Response.StatusCode() != StatusFound {
t.Fatalf("Unexpected status code %d for directory '/fasthttputil' without trailing slash. Expecting %d.", ctx.Response.StatusCode(), StatusFound)
}
ctx.Request.Reset()
ctx.Response.Reset()
ServeFile(&ctx, "fasthttputil/")
if ctx.Response.StatusCode() != StatusOK {
t.Fatalf("Unexpected status code %d for directory '/fasthttputil/' with trailing slash. Expecting %d.", ctx.Response.StatusCode(), StatusOK)
}
ctx.Request.Reset()
ctx.Response.Reset()
ServeFile(&ctx, "fs.go")
if ctx.Response.StatusCode() != StatusOK {
t.Fatalf("Unexpected status code %d for file '/fs.go'. Expecting %d.", ctx.Response.StatusCode(), StatusOK)
}
}
fasthttp-1.31.0/fuzzit/ 0000775 0000000 0000000 00000000000 14130360711 0014760 5 ustar 00root root 0000000 0000000 fasthttp-1.31.0/fuzzit/cookie/ 0000775 0000000 0000000 00000000000 14130360711 0016231 5 ustar 00root root 0000000 0000000 fasthttp-1.31.0/fuzzit/cookie/cookie_fuzz.go 0000664 0000000 0000000 00000000532 14130360711 0021107 0 ustar 00root root 0000000 0000000 //go:build gofuzz
// +build gofuzz
package fuzz
import (
"bytes"
"github.com/valyala/fasthttp"
)
func Fuzz(data []byte) int {
c := fasthttp.AcquireCookie()
defer fasthttp.ReleaseCookie(c)
if err := c.ParseBytes(data); err != nil {
return 0
}
w := bytes.Buffer{}
if _, err := c.WriteTo(&w); err != nil {
return 0
}
return 1
}
fasthttp-1.31.0/fuzzit/request/ 0000775 0000000 0000000 00000000000 14130360711 0016450 5 ustar 00root root 0000000 0000000 fasthttp-1.31.0/fuzzit/request/request_fuzz.go 0000664 0000000 0000000 00000000635 14130360711 0021551 0 ustar 00root root 0000000 0000000 //go:build gofuzz
// +build gofuzz
package fuzz
import (
"bufio"
"bytes"
"github.com/valyala/fasthttp"
)
func Fuzz(data []byte) int {
req := fasthttp.AcquireRequest()
defer fasthttp.ReleaseRequest(req)
if err := req.ReadLimitBody(bufio.NewReader(bytes.NewReader(data)), 1024*1024); err != nil {
return 0
}
w := bytes.Buffer{}
if _, err := req.WriteTo(&w); err != nil {
return 0
}
return 1
}
fasthttp-1.31.0/fuzzit/response/ 0000775 0000000 0000000 00000000000 14130360711 0016616 5 ustar 00root root 0000000 0000000 fasthttp-1.31.0/fuzzit/response/response_fuzz.go 0000664 0000000 0000000 00000000637 14130360711 0022067 0 ustar 00root root 0000000 0000000 //go:build gofuzz
// +build gofuzz
package fuzz
import (
"bufio"
"bytes"
"github.com/valyala/fasthttp"
)
func Fuzz(data []byte) int {
res := fasthttp.AcquireResponse()
defer fasthttp.ReleaseResponse(res)
if err := res.ReadLimitBody(bufio.NewReader(bytes.NewReader(data)), 1024*1024); err != nil {
return 0
}
w := bytes.Buffer{}
if _, err := res.WriteTo(&w); err != nil {
return 0
}
return 1
}
fasthttp-1.31.0/fuzzit/url/ 0000775 0000000 0000000 00000000000 14130360711 0015562 5 ustar 00root root 0000000 0000000 fasthttp-1.31.0/fuzzit/url/url_fuzz.go 0000664 0000000 0000000 00000000457 14130360711 0017777 0 ustar 00root root 0000000 0000000 //go:build gofuzz
// +build gofuzz
package fuzz
import (
"bytes"
"github.com/valyala/fasthttp"
)
func Fuzz(data []byte) int {
u := fasthttp.AcquireURI()
defer fasthttp.ReleaseURI(u)
u.UpdateBytes(data)
w := bytes.Buffer{}
if _, err := u.WriteTo(&w); err != nil {
return 0
}
return 1
}
fasthttp-1.31.0/go.mod 0000664 0000000 0000000 00000000567 14130360711 0014543 0 ustar 00root root 0000000 0000000 module github.com/valyala/fasthttp
go 1.12
require (
github.com/andybalholm/brotli v1.0.2
github.com/klauspost/compress v1.13.4
github.com/valyala/bytebufferpool v1.0.0
github.com/valyala/tcplisten v1.0.0
golang.org/x/crypto v0.0.0-20210513164829-c07d793c2f9a
golang.org/x/net v0.0.0-20210510120150-4163338589ed
golang.org/x/sys v0.0.0-20210514084401-e8d321eab015
)
fasthttp-1.31.0/go.sum 0000664 0000000 0000000 00000004231 14130360711 0014560 0 ustar 00root root 0000000 0000000 github.com/andybalholm/brotli v1.0.2 h1:JKnhI/XQ75uFBTiuzXpzFrUriDPiZjlOSzh6wXogP0E=
github.com/andybalholm/brotli v1.0.2/go.mod h1:loMXtMfwqflxFJPmdbJO0a3KNoPuLBgiu3qAvBg8x/Y=
github.com/golang/snappy v0.0.3/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
github.com/klauspost/compress v1.13.4 h1:0zhec2I8zGnjWcKyLl6i3gPqKANCCn5e9xmviEEeX6s=
github.com/klauspost/compress v1.13.4/go.mod h1:8dP1Hq4DHOhN9w426knH3Rhby4rFm6D8eO+e+Dq5Gzg=
github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
github.com/valyala/tcplisten v1.0.0 h1:rBHj/Xf+E1tRGZyWIWwJDiRY0zc1Js+CV5DqwacVSA8=
github.com/valyala/tcplisten v1.0.0/go.mod h1:T0xQ8SeCZGxckz9qRXTfG43PvQ/mcWh7FwZEA7Ioqkc=
golang.org/x/crypto v0.0.0-20210513164829-c07d793c2f9a h1:kr2P4QFmQr29mSLA43kwrOcgcReGTfbE9N577tCTuBc=
golang.org/x/crypto v0.0.0-20210513164829-c07d793c2f9a/go.mod h1:P+XmwS30IXTQdn5tA2iutPOUgjI07+tq3H3K9MVA1s8=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20210510120150-4163338589ed h1:p9UgmWI9wKpfYmgaV/IZKGdXc5qEK45tDwwwDyjS26I=
golang.org/x/net v0.0.0-20210510120150-4163338589ed/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210514084401-e8d321eab015 h1:hZR0X1kPW+nwyJ9xRxqZk1vx5RUObAPBdKVvXPDUH/E=
golang.org/x/sys v0.0.0-20210514084401-e8d321eab015/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6 h1:aRYxNxv6iGQlyVaZmk6ZgYEDa+Jg18DxebPSrd6bg1M=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
fasthttp-1.31.0/header.go 0000664 0000000 0000000 00000201415 14130360711 0015207 0 ustar 00root root 0000000 0000000 package fasthttp
import (
"bufio"
"bytes"
"errors"
"fmt"
"io"
"sync"
"sync/atomic"
"time"
)
const (
rChar = byte('\r')
nChar = byte('\n')
)
// ResponseHeader represents HTTP response header.
//
// It is forbidden copying ResponseHeader instances.
// Create new instances instead and use CopyTo.
//
// ResponseHeader instance MUST NOT be used from concurrently running
// goroutines.
type ResponseHeader struct {
noCopy noCopy //nolint:unused,structcheck
disableNormalizing bool
noHTTP11 bool
connectionClose bool
noDefaultContentType bool
noDefaultDate bool
statusCode int
contentLength int
contentLengthBytes []byte
secureErrorLogMessage bool
contentType []byte
server []byte
h []argsKV
bufKV argsKV
cookies []argsKV
}
// RequestHeader represents HTTP request header.
//
// It is forbidden copying RequestHeader instances.
// Create new instances instead and use CopyTo.
//
// RequestHeader instance MUST NOT be used from concurrently running
// goroutines.
type RequestHeader struct {
noCopy noCopy //nolint:unused,structcheck
disableNormalizing bool
noHTTP11 bool
connectionClose bool
// These two fields have been moved close to other bool fields
// for reducing RequestHeader object size.
cookiesCollected bool
contentLength int
contentLengthBytes []byte
secureErrorLogMessage bool
method []byte
requestURI []byte
proto []byte
host []byte
contentType []byte
userAgent []byte
h []argsKV
bufKV argsKV
cookies []argsKV
// stores an immutable copy of headers as they were received from the
// wire.
rawHeaders []byte
}
// SetContentRange sets 'Content-Range: bytes startPos-endPos/contentLength'
// header.
func (h *ResponseHeader) SetContentRange(startPos, endPos, contentLength int) {
b := h.bufKV.value[:0]
b = append(b, strBytes...)
b = append(b, ' ')
b = AppendUint(b, startPos)
b = append(b, '-')
b = AppendUint(b, endPos)
b = append(b, '/')
b = AppendUint(b, contentLength)
h.bufKV.value = b
h.SetCanonical(strContentRange, h.bufKV.value)
}
// SetByteRange sets 'Range: bytes=startPos-endPos' header.
//
// * If startPos is negative, then 'bytes=-startPos' value is set.
// * If endPos is negative, then 'bytes=startPos-' value is set.
func (h *RequestHeader) SetByteRange(startPos, endPos int) {
b := h.bufKV.value[:0]
b = append(b, strBytes...)
b = append(b, '=')
if startPos >= 0 {
b = AppendUint(b, startPos)
} else {
endPos = -startPos
}
b = append(b, '-')
if endPos >= 0 {
b = AppendUint(b, endPos)
}
h.bufKV.value = b
h.SetCanonical(strRange, h.bufKV.value)
}
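// Illustrative sketch (not part of the original source): requesting a byte range
// with SetByteRange. The URI below is hypothetical.
func exampleSetByteRangeSketch() {
req := AcquireRequest()
defer ReleaseRequest(req)
req.SetRequestURI("http://example.com/big.bin")
// Ask for the first 1024 bytes; this produces "Range: bytes=0-1023".
req.Header.SetByteRange(0, 1023)
}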
// StatusCode returns response status code.
func (h *ResponseHeader) StatusCode() int {
if h.statusCode == 0 {
return StatusOK
}
return h.statusCode
}
// SetStatusCode sets response status code.
func (h *ResponseHeader) SetStatusCode(statusCode int) {
h.statusCode = statusCode
}
// SetLastModified sets 'Last-Modified' header to the given value.
func (h *ResponseHeader) SetLastModified(t time.Time) {
h.bufKV.value = AppendHTTPDate(h.bufKV.value[:0], t)
h.SetCanonical(strLastModified, h.bufKV.value)
}
// ConnectionClose returns true if 'Connection: close' header is set.
func (h *ResponseHeader) ConnectionClose() bool {
return h.connectionClose
}
// SetConnectionClose sets 'Connection: close' header.
func (h *ResponseHeader) SetConnectionClose() {
h.connectionClose = true
}
// ResetConnectionClose clears 'Connection: close' header if it exists.
func (h *ResponseHeader) ResetConnectionClose() {
if h.connectionClose {
h.connectionClose = false
h.h = delAllArgsBytes(h.h, strConnection)
}
}
// ConnectionClose returns true if 'Connection: close' header is set.
func (h *RequestHeader) ConnectionClose() bool {
return h.connectionClose
}
// SetConnectionClose sets 'Connection: close' header.
func (h *RequestHeader) SetConnectionClose() {
h.connectionClose = true
}
// ResetConnectionClose clears 'Connection: close' header if it exists.
func (h *RequestHeader) ResetConnectionClose() {
if h.connectionClose {
h.connectionClose = false
h.h = delAllArgsBytes(h.h, strConnection)
}
}
// ConnectionUpgrade returns true if 'Connection: Upgrade' header is set.
func (h *ResponseHeader) ConnectionUpgrade() bool {
return hasHeaderValue(h.Peek(HeaderConnection), strUpgrade)
}
// ConnectionUpgrade returns true if 'Connection: Upgrade' header is set.
func (h *RequestHeader) ConnectionUpgrade() bool {
return hasHeaderValue(h.Peek(HeaderConnection), strUpgrade)
}
// PeekCookie returns the response cookie with the given key.
func (h *ResponseHeader) PeekCookie(key string) []byte {
return peekArgStr(h.cookies, key)
}
// ContentLength returns Content-Length header value.
//
// It may be negative:
// -1 means Transfer-Encoding: chunked.
// -2 means Transfer-Encoding: identity.
func (h *ResponseHeader) ContentLength() int {
return h.contentLength
}
// SetContentLength sets Content-Length header value.
//
// Content-Length may be negative:
// -1 means Transfer-Encoding: chunked.
// -2 means Transfer-Encoding: identity.
func (h *ResponseHeader) SetContentLength(contentLength int) {
if h.mustSkipContentLength() {
return
}
h.contentLength = contentLength
if contentLength >= 0 {
h.contentLengthBytes = AppendUint(h.contentLengthBytes[:0], contentLength)
h.h = delAllArgsBytes(h.h, strTransferEncoding)
} else {
h.contentLengthBytes = h.contentLengthBytes[:0]
value := strChunked
if contentLength == -2 {
h.SetConnectionClose()
value = strIdentity
}
h.h = setArgBytes(h.h, strTransferEncoding, value, argsHasValue)
}
}
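// Illustrative sketch (not part of the original source): negative values passed
// to SetContentLength select a Transfer-Encoding instead of emitting a
// Content-Length header.
func exampleSetContentLengthSketch() {
var h ResponseHeader
h.SetContentLength(1024) // emits "Content-Length: 1024"
h.SetContentLength(-1) // emits "Transfer-Encoding: chunked"
h.SetContentLength(-2) // emits "Transfer-Encoding: identity" and marks the connection to be closed
}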
func (h *ResponseHeader) mustSkipContentLength() bool {
// From http/1.1 specs:
// All 1xx (informational), 204 (no content), and 304 (not modified) responses MUST NOT include a message-body
statusCode := h.StatusCode()
// Fast path.
if statusCode < 100 || statusCode == StatusOK {
return false
}
// Slow path.
return statusCode == StatusNotModified || statusCode == StatusNoContent || statusCode < 200
}
// ContentLength returns Content-Length header value.
//
// It may be negative:
// -1 means Transfer-Encoding: chunked.
func (h *RequestHeader) ContentLength() int {
return h.realContentLength()
}
// realContentLength returns the actual Content-Length set in the request,
// including positive lengths for GET/HEAD requests.
func (h *RequestHeader) realContentLength() int {
return h.contentLength
}
// SetContentLength sets Content-Length header value.
//
// Negative content-length sets 'Transfer-Encoding: chunked' header.
func (h *RequestHeader) SetContentLength(contentLength int) {
h.contentLength = contentLength
if contentLength >= 0 {
h.contentLengthBytes = AppendUint(h.contentLengthBytes[:0], contentLength)
h.h = delAllArgsBytes(h.h, strTransferEncoding)
} else {
h.contentLengthBytes = h.contentLengthBytes[:0]
h.h = setArgBytes(h.h, strTransferEncoding, strChunked, argsHasValue)
}
}
func (h *ResponseHeader) isCompressibleContentType() bool {
contentType := h.ContentType()
return bytes.HasPrefix(contentType, strTextSlash) ||
bytes.HasPrefix(contentType, strApplicationSlash) ||
bytes.HasPrefix(contentType, strImageSVG) ||
bytes.HasPrefix(contentType, strImageIcon) ||
bytes.HasPrefix(contentType, strFontSlash) ||
bytes.HasPrefix(contentType, strMultipartSlash)
}
// ContentType returns Content-Type header value.
func (h *ResponseHeader) ContentType() []byte {
contentType := h.contentType
if !h.noDefaultContentType && len(h.contentType) == 0 {
contentType = defaultContentType
}
return contentType
}
// SetContentType sets Content-Type header value.
func (h *ResponseHeader) SetContentType(contentType string) {
h.contentType = append(h.contentType[:0], contentType...)
}
// SetContentTypeBytes sets Content-Type header value.
func (h *ResponseHeader) SetContentTypeBytes(contentType []byte) {
h.contentType = append(h.contentType[:0], contentType...)
}
// Server returns Server header value.
func (h *ResponseHeader) Server() []byte {
return h.server
}
// SetServer sets Server header value.
func (h *ResponseHeader) SetServer(server string) {
h.server = append(h.server[:0], server...)
}
// SetServerBytes sets Server header value.
func (h *ResponseHeader) SetServerBytes(server []byte) {
h.server = append(h.server[:0], server...)
}
// ContentType returns Content-Type header value.
func (h *RequestHeader) ContentType() []byte {
return h.contentType
}
// SetContentType sets Content-Type header value.
func (h *RequestHeader) SetContentType(contentType string) {
h.contentType = append(h.contentType[:0], contentType...)
}
// SetContentTypeBytes sets Content-Type header value.
func (h *RequestHeader) SetContentTypeBytes(contentType []byte) {
h.contentType = append(h.contentType[:0], contentType...)
}
// SetMultipartFormBoundary sets the following Content-Type:
// 'multipart/form-data; boundary=...'
// where ... is substituted by the given boundary.
func (h *RequestHeader) SetMultipartFormBoundary(boundary string) {
b := h.bufKV.value[:0]
b = append(b, strMultipartFormData...)
b = append(b, ';', ' ')
b = append(b, strBoundary...)
b = append(b, '=')
b = append(b, boundary...)
h.bufKV.value = b
h.SetContentTypeBytes(h.bufKV.value)
}
// SetMultipartFormBoundaryBytes sets the following Content-Type:
// 'multipart/form-data; boundary=...'
// where ... is substituted by the given boundary.
func (h *RequestHeader) SetMultipartFormBoundaryBytes(boundary []byte) {
b := h.bufKV.value[:0]
b = append(b, strMultipartFormData...)
b = append(b, ';', ' ')
b = append(b, strBoundary...)
b = append(b, '=')
b = append(b, boundary...)
h.bufKV.value = b
h.SetContentTypeBytes(h.bufKV.value)
}
// MultipartFormBoundary returns boundary part
// from 'multipart/form-data; boundary=...' Content-Type.
func (h *RequestHeader) MultipartFormBoundary() []byte {
b := h.ContentType()
if !bytes.HasPrefix(b, strMultipartFormData) {
return nil
}
b = b[len(strMultipartFormData):]
if len(b) == 0 || b[0] != ';' {
return nil
}
var n int
for len(b) > 0 {
n++
for len(b) > n && b[n] == ' ' {
n++
}
b = b[n:]
if !bytes.HasPrefix(b, strBoundary) {
if n = bytes.IndexByte(b, ';'); n < 0 {
return nil
}
continue
}
b = b[len(strBoundary):]
if len(b) == 0 || b[0] != '=' {
return nil
}
b = b[1:]
if n = bytes.IndexByte(b, ';'); n >= 0 {
b = b[:n]
}
if len(b) > 1 && b[0] == '"' && b[len(b)-1] == '"' {
b = b[1 : len(b)-1]
}
return b
}
return nil
}
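// Illustrative sketch (not part of the original source): setting and reading back
// a multipart boundary. The boundary string below is hypothetical.
func exampleMultipartFormBoundarySketch() {
var h RequestHeader
h.SetMultipartFormBoundary("my-boundary-123")
// Content-Type is now "multipart/form-data; boundary=my-boundary-123".
boundary := h.MultipartFormBoundary() // []byte("my-boundary-123")
_ = boundary
}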
// Host returns Host header value.
func (h *RequestHeader) Host() []byte {
return h.host
}
// SetHost sets Host header value.
func (h *RequestHeader) SetHost(host string) {
h.host = append(h.host[:0], host...)
}
// SetHostBytes sets Host header value.
func (h *RequestHeader) SetHostBytes(host []byte) {
h.host = append(h.host[:0], host...)
}
// UserAgent returns User-Agent header value.
func (h *RequestHeader) UserAgent() []byte {
return h.userAgent
}
// SetUserAgent sets User-Agent header value.
func (h *RequestHeader) SetUserAgent(userAgent string) {
h.userAgent = append(h.userAgent[:0], userAgent...)
}
// SetUserAgentBytes sets User-Agent header value.
func (h *RequestHeader) SetUserAgentBytes(userAgent []byte) {
h.userAgent = append(h.userAgent[:0], userAgent...)
}
// Referer returns Referer header value.
func (h *RequestHeader) Referer() []byte {
return h.PeekBytes(strReferer)
}
// SetReferer sets Referer header value.
func (h *RequestHeader) SetReferer(referer string) {
h.SetBytesK(strReferer, referer)
}
// SetRefererBytes sets Referer header value.
func (h *RequestHeader) SetRefererBytes(referer []byte) {
h.SetCanonical(strReferer, referer)
}
// Method returns HTTP request method.
func (h *RequestHeader) Method() []byte {
if len(h.method) == 0 {
return []byte(MethodGet)
}
return h.method
}
// SetMethod sets HTTP request method.
func (h *RequestHeader) SetMethod(method string) {
h.method = append(h.method[:0], method...)
}
// SetMethodBytes sets HTTP request method.
func (h *RequestHeader) SetMethodBytes(method []byte) {
h.method = append(h.method[:0], method...)
}
// Protocol returns HTTP protocol.
func (h *RequestHeader) Protocol() []byte {
if len(h.proto) == 0 {
return strHTTP11
}
return h.proto
}
// SetProtocol sets HTTP request protocol.
func (h *RequestHeader) SetProtocol(method string) {
h.proto = append(h.proto[:0], method...)
h.noHTTP11 = !bytes.Equal(h.proto, strHTTP11)
}
// SetProtocolBytes sets HTTP request protocol.
func (h *RequestHeader) SetProtocolBytes(method []byte) {
h.proto = append(h.proto[:0], method...)
h.noHTTP11 = !bytes.Equal(h.proto, strHTTP11)
}
// RequestURI returns RequestURI from the first HTTP request line.
func (h *RequestHeader) RequestURI() []byte {
requestURI := h.requestURI
if len(requestURI) == 0 {
requestURI = strSlash
}
return requestURI
}
// SetRequestURI sets RequestURI for the first HTTP request line.
// RequestURI must be properly encoded.
// Use URI.RequestURI for constructing proper RequestURI if unsure.
func (h *RequestHeader) SetRequestURI(requestURI string) {
h.requestURI = append(h.requestURI[:0], requestURI...)
}
// SetRequestURIBytes sets RequestURI for the first HTTP request line.
// RequestURI must be properly encoded.
// Use URI.RequestURI for constructing proper RequestURI if unsure.
func (h *RequestHeader) SetRequestURIBytes(requestURI []byte) {
h.requestURI = append(h.requestURI[:0], requestURI...)
}
// IsGet returns true if request method is GET.
func (h *RequestHeader) IsGet() bool {
return string(h.Method()) == MethodGet
}
// IsPost returns true if request method is POST.
func (h *RequestHeader) IsPost() bool {
return string(h.Method()) == MethodPost
}
// IsPut returns true if request method is PUT.
func (h *RequestHeader) IsPut() bool {
return string(h.Method()) == MethodPut
}
// IsHead returns true if request method is HEAD.
func (h *RequestHeader) IsHead() bool {
return string(h.Method()) == MethodHead
}
// IsDelete returns true if request method is DELETE.
func (h *RequestHeader) IsDelete() bool {
return string(h.Method()) == MethodDelete
}
// IsConnect returns true if request method is CONNECT.
func (h *RequestHeader) IsConnect() bool {
return string(h.Method()) == MethodConnect
}
// IsOptions returns true if request method is OPTIONS.
func (h *RequestHeader) IsOptions() bool {
return string(h.Method()) == MethodOptions
}
// IsTrace returns true if request method is TRACE.
func (h *RequestHeader) IsTrace() bool {
return string(h.Method()) == MethodTrace
}
// IsPatch returns true if request method is PATCH.
func (h *RequestHeader) IsPatch() bool {
return string(h.Method()) == MethodPatch
}
// IsHTTP11 returns true if the request is HTTP/1.1.
func (h *RequestHeader) IsHTTP11() bool {
return !h.noHTTP11
}
// IsHTTP11 returns true if the response is HTTP/1.1.
func (h *ResponseHeader) IsHTTP11() bool {
return !h.noHTTP11
}
// HasAcceptEncoding returns true if the header contains
// the given Accept-Encoding value.
func (h *RequestHeader) HasAcceptEncoding(acceptEncoding string) bool {
h.bufKV.value = append(h.bufKV.value[:0], acceptEncoding...)
return h.HasAcceptEncodingBytes(h.bufKV.value)
}
// HasAcceptEncodingBytes returns true if the header contains
// the given Accept-Encoding value.
func (h *RequestHeader) HasAcceptEncodingBytes(acceptEncoding []byte) bool {
ae := h.peek(strAcceptEncoding)
n := bytes.Index(ae, acceptEncoding)
if n < 0 {
return false
}
b := ae[n+len(acceptEncoding):]
if len(b) > 0 && b[0] != ',' {
return false
}
if n == 0 {
return true
}
return ae[n-1] == ' '
}
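// Illustrative sketch (not part of the original source): a handler could use
// HasAcceptEncoding to decide whether a gzip-compressed response is acceptable
// to the client.
func exampleHasAcceptEncodingSketch(h *RequestHeader) bool {
// True only if the request's Accept-Encoding header lists "gzip".
return h.HasAcceptEncoding("gzip")
}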
// Len returns the number of headers set,
// i.e. the number of times f is called in VisitAll.
func (h *ResponseHeader) Len() int {
n := 0
h.VisitAll(func(k, v []byte) { n++ })
return n
}
// Len returns the number of headers set,
// i.e. the number of times f is called in VisitAll.
func (h *RequestHeader) Len() int {
n := 0
h.VisitAll(func(k, v []byte) { n++ })
return n
}
// DisableNormalizing disables header names' normalization.
//
// By default all the header names are normalized by uppercasing
// the first letter and all the first letters following dashes,
// while lowercasing all the other letters.
// Examples:
//
// * CONNECTION -> Connection
// * conteNT-tYPE -> Content-Type
// * foo-bar-baz -> Foo-Bar-Baz
//
// Disable header names' normalization only if you know what you are doing.
func (h *RequestHeader) DisableNormalizing() {
h.disableNormalizing = true
}
// EnableNormalizing enables header names' normalization.
//
// Header names are normalized by uppercasing the first letter and
// all the first letters following dashes, while lowercasing all
// the other letters.
// Examples:
//
// * CONNECTION -> Connection
// * conteNT-tYPE -> Content-Type
// * foo-bar-baz -> Foo-Bar-Baz
//
// This is enabled by default unless disabled using DisableNormalizing()
func (h *RequestHeader) EnableNormalizing() {
h.disableNormalizing = false
}
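// Illustrative sketch (not part of the original source): with normalization
// disabled the header key is stored exactly as given. The header name below is
// hypothetical.
func exampleDisableNormalizingSketch() {
var h RequestHeader
h.DisableNormalizing()
h.Set("x-custom-HEADER", "1") // stored as "x-custom-HEADER", not "X-Custom-Header"
}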
// DisableNormalizing disables header names' normalization.
//
// By default all the header names are normalized by uppercasing
// the first letter and all the first letters following dashes,
// while lowercasing all the other letters.
// Examples:
//
// * CONNECTION -> Connection
// * conteNT-tYPE -> Content-Type
// * foo-bar-baz -> Foo-Bar-Baz
//
// Disable header names' normalization only if you know what you are doing.
func (h *ResponseHeader) DisableNormalizing() {
h.disableNormalizing = true
}
// EnableNormalizing enables header names' normalization.
//
// Header names are normalized by uppercasing the first letter and
// all the first letters following dashes, while lowercasing all
// the other letters.
// Examples:
//
// * CONNECTION -> Connection
// * conteNT-tYPE -> Content-Type
// * foo-bar-baz -> Foo-Bar-Baz
//
// This is enabled by default unless disabled using DisableNormalizing()
func (h *ResponseHeader) EnableNormalizing() {
h.disableNormalizing = false
}
// SetNoDefaultContentType allows you to control if a default Content-Type header will be set (false) or not (true).
func (h *ResponseHeader) SetNoDefaultContentType(noDefaultContentType bool) {
h.noDefaultContentType = noDefaultContentType
}
// Reset clears response header.
func (h *ResponseHeader) Reset() {
h.disableNormalizing = false
h.SetNoDefaultContentType(false)
h.noDefaultDate = false
h.resetSkipNormalize()
}
func (h *ResponseHeader) resetSkipNormalize() {
h.noHTTP11 = false
h.connectionClose = false
h.statusCode = 0
h.contentLength = 0
h.contentLengthBytes = h.contentLengthBytes[:0]
h.contentType = h.contentType[:0]
h.server = h.server[:0]
h.h = h.h[:0]
h.cookies = h.cookies[:0]
}
// Reset clears request header.
func (h *RequestHeader) Reset() {
h.disableNormalizing = false
h.resetSkipNormalize()
}
func (h *RequestHeader) resetSkipNormalize() {
h.noHTTP11 = false
h.connectionClose = false
h.contentLength = 0
h.contentLengthBytes = h.contentLengthBytes[:0]
h.method = h.method[:0]
h.proto = h.proto[:0]
h.requestURI = h.requestURI[:0]
h.host = h.host[:0]
h.contentType = h.contentType[:0]
h.userAgent = h.userAgent[:0]
h.h = h.h[:0]
h.cookies = h.cookies[:0]
h.cookiesCollected = false
h.rawHeaders = h.rawHeaders[:0]
}
// CopyTo copies all the headers to dst.
func (h *ResponseHeader) CopyTo(dst *ResponseHeader) {
dst.Reset()
dst.disableNormalizing = h.disableNormalizing
dst.noHTTP11 = h.noHTTP11
dst.connectionClose = h.connectionClose
dst.noDefaultContentType = h.noDefaultContentType
dst.noDefaultDate = h.noDefaultDate
dst.statusCode = h.statusCode
dst.contentLength = h.contentLength
dst.contentLengthBytes = append(dst.contentLengthBytes[:0], h.contentLengthBytes...)
dst.contentType = append(dst.contentType[:0], h.contentType...)
dst.server = append(dst.server[:0], h.server...)
dst.h = copyArgs(dst.h, h.h)
dst.cookies = copyArgs(dst.cookies, h.cookies)
}
// CopyTo copies all the headers to dst.
func (h *RequestHeader) CopyTo(dst *RequestHeader) {
dst.Reset()
dst.disableNormalizing = h.disableNormalizing
dst.noHTTP11 = h.noHTTP11
dst.connectionClose = h.connectionClose
dst.contentLength = h.contentLength
dst.contentLengthBytes = append(dst.contentLengthBytes[:0], h.contentLengthBytes...)
dst.method = append(dst.method[:0], h.method...)
dst.proto = append(dst.proto[:0], h.proto...)
dst.requestURI = append(dst.requestURI[:0], h.requestURI...)
dst.host = append(dst.host[:0], h.host...)
dst.contentType = append(dst.contentType[:0], h.contentType...)
dst.userAgent = append(dst.userAgent[:0], h.userAgent...)
dst.h = copyArgs(dst.h, h.h)
dst.cookies = copyArgs(dst.cookies, h.cookies)
dst.cookiesCollected = h.cookiesCollected
dst.rawHeaders = append(dst.rawHeaders[:0], h.rawHeaders...)
}
// VisitAll calls f for each header.
//
// f must not retain references to key and/or value after returning.
// Copy key and/or value contents before returning if you need retaining them.
func (h *ResponseHeader) VisitAll(f func(key, value []byte)) {
if len(h.contentLengthBytes) > 0 {
f(strContentLength, h.contentLengthBytes)
}
contentType := h.ContentType()
if len(contentType) > 0 {
f(strContentType, contentType)
}
server := h.Server()
if len(server) > 0 {
f(strServer, server)
}
if len(h.cookies) > 0 {
visitArgs(h.cookies, func(k, v []byte) {
f(strSetCookie, v)
})
}
visitArgs(h.h, f)
if h.ConnectionClose() {
f(strConnection, strClose)
}
}
// VisitAllCookie calls f for each response cookie.
//
// Cookie name is passed in key and the whole Set-Cookie header value
// is passed in value on each f invocation. Value may be parsed
// with Cookie.ParseBytes().
//
// f must not retain references to key and/or value after returning.
func (h *ResponseHeader) VisitAllCookie(f func(key, value []byte)) {
visitArgs(h.cookies, f)
}
// VisitAllCookie calls f for each request cookie.
//
// f must not retain references to key and/or value after returning.
func (h *RequestHeader) VisitAllCookie(f func(key, value []byte)) {
h.collectCookies()
visitArgs(h.cookies, f)
}
// VisitAll calls f for each header.
//
// f must not retain references to key and/or value after returning.
// Copy key and/or value contents before returning if you need retaining them.
//
// To get the headers in order they were received use VisitAllInOrder.
func (h *RequestHeader) VisitAll(f func(key, value []byte)) {
host := h.Host()
if len(host) > 0 {
f(strHost, host)
}
if len(h.contentLengthBytes) > 0 {
f(strContentLength, h.contentLengthBytes)
}
contentType := h.ContentType()
if len(contentType) > 0 {
f(strContentType, contentType)
}
userAgent := h.UserAgent()
if len(userAgent) > 0 {
f(strUserAgent, userAgent)
}
h.collectCookies()
if len(h.cookies) > 0 {
h.bufKV.value = appendRequestCookieBytes(h.bufKV.value[:0], h.cookies)
f(strCookie, h.bufKV.value)
}
visitArgs(h.h, f)
if h.ConnectionClose() {
f(strConnection, strClose)
}
}
// VisitAllInOrder calls f for each header in the order they were received.
//
// f must not retain references to key and/or value after returning.
// Copy key and/or value contents before returning if you need retaining them.
//
// This function is slightly slower than VisitAll because it has to reparse the
// raw headers to get the order.
func (h *RequestHeader) VisitAllInOrder(f func(key, value []byte)) {
var s headerScanner
s.b = h.rawHeaders
s.disableNormalizing = h.disableNormalizing
for s.next() {
if len(s.key) > 0 {
f(s.key, s.value)
}
}
}
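// Illustrative sketch (not part of the original source): collecting header names
// in the order they were received on the wire using VisitAllInOrder.
func exampleVisitAllInOrderSketch(h *RequestHeader) []string {
var keys []string
h.VisitAllInOrder(func(key, value []byte) {
// key/value must not be retained, so copy them before storing.
keys = append(keys, string(key))
})
return keys
}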
// Del deletes header with the given key.
func (h *ResponseHeader) Del(key string) {
k := getHeaderKeyBytes(&h.bufKV, key, h.disableNormalizing)
h.del(k)
}
// DelBytes deletes header with the given key.
func (h *ResponseHeader) DelBytes(key []byte) {
h.bufKV.key = append(h.bufKV.key[:0], key...)
normalizeHeaderKey(h.bufKV.key, h.disableNormalizing)
h.del(h.bufKV.key)
}
func (h *ResponseHeader) del(key []byte) {
switch string(key) {
case HeaderContentType:
h.contentType = h.contentType[:0]
case HeaderServer:
h.server = h.server[:0]
case HeaderSetCookie:
h.cookies = h.cookies[:0]
case HeaderContentLength:
h.contentLength = 0
h.contentLengthBytes = h.contentLengthBytes[:0]
case HeaderConnection:
h.connectionClose = false
}
h.h = delAllArgsBytes(h.h, key)
}
// Del deletes header with the given key.
func (h *RequestHeader) Del(key string) {
k := getHeaderKeyBytes(&h.bufKV, key, h.disableNormalizing)
h.del(k)
}
// DelBytes deletes header with the given key.
func (h *RequestHeader) DelBytes(key []byte) {
h.bufKV.key = append(h.bufKV.key[:0], key...)
normalizeHeaderKey(h.bufKV.key, h.disableNormalizing)
h.del(h.bufKV.key)
}
func (h *RequestHeader) del(key []byte) {
switch string(key) {
case HeaderHost:
h.host = h.host[:0]
case HeaderContentType:
h.contentType = h.contentType[:0]
case HeaderUserAgent:
h.userAgent = h.userAgent[:0]
case HeaderCookie:
h.cookies = h.cookies[:0]
case HeaderContentLength:
h.contentLength = 0
h.contentLengthBytes = h.contentLengthBytes[:0]
case HeaderConnection:
h.connectionClose = false
}
h.h = delAllArgsBytes(h.h, key)
}
// setSpecialHeader handles special headers and returns true when a header is processed.
func (h *ResponseHeader) setSpecialHeader(key, value []byte) bool {
if len(key) == 0 {
return false
}
switch key[0] | 0x20 {
case 'c':
if caseInsensitiveCompare(strContentType, key) {
h.SetContentTypeBytes(value)
return true
} else if caseInsensitiveCompare(strContentLength, key) {
if contentLength, err := parseContentLength(value); err == nil {
h.contentLength = contentLength
h.contentLengthBytes = append(h.contentLengthBytes[:0], value...)
}
return true
} else if caseInsensitiveCompare(strConnection, key) {
if bytes.Equal(strClose, value) {
h.SetConnectionClose()
} else {
h.ResetConnectionClose()
h.h = setArgBytes(h.h, key, value, argsHasValue)
}
return true
}
case 's':
if caseInsensitiveCompare(strServer, key) {
h.SetServerBytes(value)
return true
} else if caseInsensitiveCompare(strSetCookie, key) {
var kv *argsKV
h.cookies, kv = allocArg(h.cookies)
kv.key = getCookieKey(kv.key, value)
kv.value = append(kv.value[:0], value...)
return true
}
case 't':
if caseInsensitiveCompare(strTransferEncoding, key) {
// Transfer-Encoding is managed automatically.
return true
}
case 'd':
if caseInsensitiveCompare(strDate, key) {
// Date is managed automatically.
return true
}
}
return false
}
// setSpecialHeader handles special headers and returns true when a header is processed.
func (h *RequestHeader) setSpecialHeader(key, value []byte) bool {
if len(key) == 0 {
return false
}
switch key[0] | 0x20 {
case 'c':
if caseInsensitiveCompare(strContentType, key) {
h.SetContentTypeBytes(value)
return true
} else if caseInsensitiveCompare(strContentLength, key) {
if contentLength, err := parseContentLength(value); err == nil {
h.contentLength = contentLength
h.contentLengthBytes = append(h.contentLengthBytes[:0], value...)
}
return true
} else if caseInsensitiveCompare(strConnection, key) {
if bytes.Equal(strClose, value) {
h.SetConnectionClose()
} else {
h.ResetConnectionClose()
h.h = setArgBytes(h.h, key, value, argsHasValue)
}
return true
} else if caseInsensitiveCompare(strCookie, key) {
h.collectCookies()
h.cookies = parseRequestCookies(h.cookies, value)
return true
}
case 't':
if caseInsensitiveCompare(strTransferEncoding, key) {
// Transfer-Encoding is managed automatically.
return true
}
case 'h':
if caseInsensitiveCompare(strHost, key) {
h.SetHostBytes(value)
return true
}
case 'u':
if caseInsensitiveCompare(strUserAgent, key) {
h.SetUserAgentBytes(value)
return true
}
}
return false
}
// Add adds the given 'key: value' header.
//
// Multiple headers with the same key may be added with this function.
// Use Set for setting a single header for the given key.
//
// The Content-Type, Content-Length, Connection, Server, Set-Cookie,
// Transfer-Encoding and Date headers can only be set once and will
// overwrite the previous value.
func (h *ResponseHeader) Add(key, value string) {
h.AddBytesKV(s2b(key), s2b(value))
}
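
// Illustrative sketch, not part of the original source: how Add behaves for
// plain headers versus special headers. The "Warning" header and its values are
// assumptions for the example; Content-Type is handled by setSpecialHeader, so
// repeated Adds keep only the last value.
func exampleResponseHeaderAddRepeated(h *ResponseHeader) {
    // Non-special keys accumulate: two Warning lines end up in the header.
    h.Add("Warning", "199 miscellaneous warning")
    h.Add("Warning", "299 another warning")
    // Content-Type is routed through its dedicated setter, so only the last value survives.
    h.Add(HeaderContentType, "text/plain")
    h.Add(HeaderContentType, "text/html")
}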
// AddBytesK adds the given 'key: value' header.
//
// Multiple headers with the same key may be added with this function.
// Use SetBytesK for setting a single header for the given key.
//
// The Content-Type, Content-Length, Connection, Server, Set-Cookie,
// Transfer-Encoding and Date headers can only be set once and will
// overwrite the previous value.
func (h *ResponseHeader) AddBytesK(key []byte, value string) {
h.AddBytesKV(key, s2b(value))
}
// AddBytesV adds the given 'key: value' header.
//
// Multiple headers with the same key may be added with this function.
// Use SetBytesV for setting a single header for the given key.
//
// The Content-Type, Content-Length, Connection, Server, Set-Cookie,
// Transfer-Encoding and Date headers can only be set once and will
// overwrite the previous value.
func (h *ResponseHeader) AddBytesV(key string, value []byte) {
h.AddBytesKV(s2b(key), value)
}
// AddBytesKV adds the given 'key: value' header.
//
// Multiple headers with the same key may be added with this function.
// Use SetBytesKV for setting a single header for the given key.
//
// The Content-Type, Content-Length, Connection, Server, Set-Cookie,
// Transfer-Encoding and Date headers can only be set once and will
// overwrite the previous value.
func (h *ResponseHeader) AddBytesKV(key, value []byte) {
if h.setSpecialHeader(key, value) {
return
}
k := getHeaderKeyBytes(&h.bufKV, b2s(key), h.disableNormalizing)
h.h = appendArgBytes(h.h, k, value, argsHasValue)
}
// Set sets the given 'key: value' header.
//
// Use Add for setting multiple header values under the same key.
func (h *ResponseHeader) Set(key, value string) {
initHeaderKV(&h.bufKV, key, value, h.disableNormalizing)
h.SetCanonical(h.bufKV.key, h.bufKV.value)
}
// SetBytesK sets the given 'key: value' header.
//
// Use AddBytesK for setting multiple header values under the same key.
func (h *ResponseHeader) SetBytesK(key []byte, value string) {
h.bufKV.value = append(h.bufKV.value[:0], value...)
h.SetBytesKV(key, h.bufKV.value)
}
// SetBytesV sets the given 'key: value' header.
//
// Use AddBytesV for setting multiple header values under the same key.
func (h *ResponseHeader) SetBytesV(key string, value []byte) {
k := getHeaderKeyBytes(&h.bufKV, key, h.disableNormalizing)
h.SetCanonical(k, value)
}
// SetBytesKV sets the given 'key: value' header.
//
// Use AddBytesKV for setting multiple header values under the same key.
func (h *ResponseHeader) SetBytesKV(key, value []byte) {
h.bufKV.key = append(h.bufKV.key[:0], key...)
normalizeHeaderKey(h.bufKV.key, h.disableNormalizing)
h.SetCanonical(h.bufKV.key, value)
}
// SetCanonical sets the given 'key: value' header assuming that
// key is in canonical form.
func (h *ResponseHeader) SetCanonical(key, value []byte) {
if h.setSpecialHeader(key, value) {
return
}
h.h = setArgBytes(h.h, key, value, argsHasValue)
}
// SetCookie sets the given response cookie.
//
// It is safe to re-use the cookie after the function returns.
func (h *ResponseHeader) SetCookie(cookie *Cookie) {
h.cookies = setArgBytes(h.cookies, cookie.Key(), cookie.Cookie(), argsHasValue)
}
// SetCookie sets 'key: value' cookies.
func (h *RequestHeader) SetCookie(key, value string) {
h.collectCookies()
h.cookies = setArg(h.cookies, key, value, argsHasValue)
}
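
// Illustrative sketch, not part of the original source: setting request cookies
// and reading one back. The cookie names and values are assumptions for the example.
func exampleRequestCookies(h *RequestHeader) string {
    h.SetCookie("lang", "en")
    h.SetCookie("theme", "dark")
    // Cookie returns the value stored under the given key, or nil if missing.
    return string(h.Cookie("lang"))
}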
// SetCookieBytesK sets 'key: value' cookies.
func (h *RequestHeader) SetCookieBytesK(key []byte, value string) {
h.SetCookie(b2s(key), value)
}
// SetCookieBytesKV sets 'key: value' cookies.
func (h *RequestHeader) SetCookieBytesKV(key, value []byte) {
h.SetCookie(b2s(key), b2s(value))
}
// DelClientCookie instructs the client to remove the given cookie.
// This doesn't work for cookies with a specific domain or path;
// you should delete those manually like this:
//
// c := AcquireCookie()
// c.SetKey(key)
// c.SetDomain("example.com")
// c.SetPath("/path")
// c.SetExpire(CookieExpireDelete)
// h.SetCookie(c)
// ReleaseCookie(c)
//
// Use DelCookie if you just want to remove the cookie from the response header.
func (h *ResponseHeader) DelClientCookie(key string) {
h.DelCookie(key)
c := AcquireCookie()
c.SetKey(key)
c.SetExpire(CookieExpireDelete)
h.SetCookie(c)
ReleaseCookie(c)
}
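
// Illustrative sketch, not part of the original source: the difference between
// dropping a Set-Cookie line from this response (DelCookie) and asking the
// client to expire its stored copy (DelClientCookie). The cookie name and value
// are assumptions for the example.
func exampleDelClientCookieUsage(h *ResponseHeader) {
    c := AcquireCookie()
    c.SetKey("session")
    c.SetValue("abc123")
    h.SetCookie(c)
    ReleaseCookie(c)

    // Removes the Set-Cookie line from this response; the client keeps its copy.
    h.DelCookie("session")

    // Emits an already-expired "session" cookie so the client deletes its copy too.
    h.DelClientCookie("session")
}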
// DelClientCookieBytes instructs the client to remove the given cookie.
// This doesn't work for cookies with a specific domain or path;
// you should delete those manually like this:
//
// c := AcquireCookie()
// c.SetKey(key)
// c.SetDomain("example.com")
// c.SetPath("/path")
// c.SetExpire(CookieExpireDelete)
// h.SetCookie(c)
// ReleaseCookie(c)
//
// Use DelCookieBytes if you just want to remove the cookie from the response header.
func (h *ResponseHeader) DelClientCookieBytes(key []byte) {
h.DelClientCookie(b2s(key))
}
// DelCookie removes cookie under the given key from response header.
//
// Note that DelCookie doesn't remove the cookie from the client.
// Use DelClientCookie instead.
func (h *ResponseHeader) DelCookie(key string) {
h.cookies = delAllArgs(h.cookies, key)
}
// DelCookieBytes removes cookie under the given key from response header.
//
// Note that DelCookieBytes doesn't remove the cookie from the client.
// Use DelClientCookieBytes instead.
func (h *ResponseHeader) DelCookieBytes(key []byte) {
h.DelCookie(b2s(key))
}
// DelCookie removes cookie under the given key.
func (h *RequestHeader) DelCookie(key string) {
h.collectCookies()
h.cookies = delAllArgs(h.cookies, key)
}
// DelCookieBytes removes cookie under the given key.
func (h *RequestHeader) DelCookieBytes(key []byte) {
h.DelCookie(b2s(key))
}
// DelAllCookies removes all the cookies from response headers.
func (h *ResponseHeader) DelAllCookies() {
h.cookies = h.cookies[:0]
}
// DelAllCookies removes all the cookies from request headers.
func (h *RequestHeader) DelAllCookies() {
h.collectCookies()
h.cookies = h.cookies[:0]
}
// Add adds the given 'key: value' header.
//
// Multiple headers with the same key may be added with this function.
// Use Set for setting a single header for the given key.
func (h *RequestHeader) Add(key, value string) {
h.AddBytesKV(s2b(key), s2b(value))
}
// AddBytesK adds the given 'key: value' header.
//
// Multiple headers with the same key may be added with this function.
// Use SetBytesK for setting a single header for the given key.
func (h *RequestHeader) AddBytesK(key []byte, value string) {
h.AddBytesKV(key, s2b(value))
}
// AddBytesV adds the given 'key: value' header.
//
// Multiple headers with the same key may be added with this function.
// Use SetBytesV for setting a single header for the given key.
func (h *RequestHeader) AddBytesV(key string, value []byte) {
h.AddBytesKV(s2b(key), value)
}
// AddBytesKV adds the given 'key: value' header.
//
// Multiple headers with the same key may be added with this function.
// Use SetBytesKV for setting a single header for the given key.
//
// The Content-Type, Content-Length, Connection, Cookie,
// Transfer-Encoding, Host and User-Agent headers can only be set once
// and will overwrite the previous value.
func (h *RequestHeader) AddBytesKV(key, value []byte) {
if h.setSpecialHeader(key, value) {
return
}
k := getHeaderKeyBytes(&h.bufKV, b2s(key), h.disableNormalizing)
h.h = appendArgBytes(h.h, k, value, argsHasValue)
}
// Set sets the given 'key: value' header.
//
// Use Add for setting multiple header values under the same key.
func (h *RequestHeader) Set(key, value string) {
initHeaderKV(&h.bufKV, key, value, h.disableNormalizing)
h.SetCanonical(h.bufKV.key, h.bufKV.value)
}
// SetBytesK sets the given 'key: value' header.
//
// Use AddBytesK for setting multiple header values under the same key.
func (h *RequestHeader) SetBytesK(key []byte, value string) {
h.bufKV.value = append(h.bufKV.value[:0], value...)
h.SetBytesKV(key, h.bufKV.value)
}
// SetBytesV sets the given 'key: value' header.
//
// Use AddBytesV for setting multiple header values under the same key.
func (h *RequestHeader) SetBytesV(key string, value []byte) {
k := getHeaderKeyBytes(&h.bufKV, key, h.disableNormalizing)
h.SetCanonical(k, value)
}
// SetBytesKV sets the given 'key: value' header.
//
// Use AddBytesKV for setting multiple header values under the same key.
func (h *RequestHeader) SetBytesKV(key, value []byte) {
h.bufKV.key = append(h.bufKV.key[:0], key...)
normalizeHeaderKey(h.bufKV.key, h.disableNormalizing)
h.SetCanonical(h.bufKV.key, value)
}
// SetCanonical sets the given 'key: value' header assuming that
// key is in canonical form.
func (h *RequestHeader) SetCanonical(key, value []byte) {
if h.setSpecialHeader(key, value) {
return
}
h.h = setArgBytes(h.h, key, value, argsHasValue)
}
// Peek returns header value for the given key.
//
// The returned value is valid until the response is released,
// either through ReleaseResponse or your request handler returning.
// Do not store references to the returned value. Make copies instead.
func (h *ResponseHeader) Peek(key string) []byte {
k := getHeaderKeyBytes(&h.bufKV, key, h.disableNormalizing)
return h.peek(k)
}
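
// Illustrative sketch, not part of the original source: peeked values point into
// the header's internal buffers, so copy them before the response is released or
// the handler returns. The helper name is an assumption for the example.
func examplePeekAndCopy(h *ResponseHeader) string {
    v := h.Peek(HeaderContentType)
    // Converting to string copies the bytes, making the value safe to retain.
    return string(v)
}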
// PeekBytes returns header value for the given key.
//
// The returned value is valid until the response is released,
// either through ReleaseResponse or your request handler returning.
// Do not store references to returned value. Make copies instead.
func (h *ResponseHeader) PeekBytes(key []byte) []byte {
h.bufKV.key = append(h.bufKV.key[:0], key...)
normalizeHeaderKey(h.bufKV.key, h.disableNormalizing)
return h.peek(h.bufKV.key)
}
// Peek returns header value for the given key.
//
// The returned value is valid until the request is released,
// either through ReleaseRequest or your request handler returning.
// Do not store references to returned value. Make copies instead.
func (h *RequestHeader) Peek(key string) []byte {
k := getHeaderKeyBytes(&h.bufKV, key, h.disableNormalizing)
return h.peek(k)
}
// PeekBytes returns header value for the given key.
//
// The returned value is valid until the request is released,
// either through ReleaseRequest or your request handler returning.
// Do not store references to returned value. Make copies instead.
func (h *RequestHeader) PeekBytes(key []byte) []byte {
h.bufKV.key = append(h.bufKV.key[:0], key...)
normalizeHeaderKey(h.bufKV.key, h.disableNormalizing)
return h.peek(h.bufKV.key)
}
func (h *ResponseHeader) peek(key []byte) []byte {
switch string(key) {
case HeaderContentType:
return h.ContentType()
case HeaderServer:
return h.Server()
case HeaderConnection:
if h.ConnectionClose() {
return strClose
}
return peekArgBytes(h.h, key)
case HeaderContentLength:
return h.contentLengthBytes
case HeaderSetCookie:
return appendResponseCookieBytes(nil, h.cookies)
default:
return peekArgBytes(h.h, key)
}
}
func (h *RequestHeader) peek(key []byte) []byte {
switch string(key) {
case HeaderHost:
return h.Host()
case HeaderContentType:
return h.ContentType()
case HeaderUserAgent:
return h.UserAgent()
case HeaderConnection:
if h.ConnectionClose() {
return strClose
}
return peekArgBytes(h.h, key)
case HeaderContentLength:
return h.contentLengthBytes
case HeaderCookie:
if h.cookiesCollected {
return appendRequestCookieBytes(nil, h.cookies)
}
return peekArgBytes(h.h, key)
default:
return peekArgBytes(h.h, key)
}
}
// Cookie returns cookie for the given key.
func (h *RequestHeader) Cookie(key string) []byte {
h.collectCookies()
return peekArgStr(h.cookies, key)
}
// CookieBytes returns cookie for the given key.
func (h *RequestHeader) CookieBytes(key []byte) []byte {
h.collectCookies()
return peekArgBytes(h.cookies, key)
}
// Cookie fills cookie for the given cookie.Key.
//
// Returns false if cookie with the given cookie.Key is missing.
func (h *ResponseHeader) Cookie(cookie *Cookie) bool {
v := peekArgBytes(h.cookies, cookie.Key())
if v == nil {
return false
}
cookie.ParseBytes(v) //nolint:errcheck
return true
}
// Read reads response header from r.
//
// io.EOF is returned if r is closed before reading the first header byte.
func (h *ResponseHeader) Read(r *bufio.Reader) error {
n := 1
for {
err := h.tryRead(r, n)
if err == nil {
return nil
}
if err != errNeedMore {
h.resetSkipNormalize()
return err
}
n = r.Buffered() + 1
}
}
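
// Illustrative sketch, not part of the original source: parsing a serialized
// response header from a bufio.Reader. The raw header bytes below are made up
// for the example; Read keeps retrying internally until the whole header block
// (terminated by an empty line) is buffered.
func exampleResponseHeaderRead() (*ResponseHeader, error) {
    raw := "HTTP/1.1 200 OK\r\nContent-Length: 5\r\nContent-Type: text/plain\r\n\r\n"
    var h ResponseHeader
    br := bufio.NewReader(bytes.NewBufferString(raw))
    if err := h.Read(br); err != nil {
        return nil, err
    }
    return &h, nil
}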
func (h *ResponseHeader) tryRead(r *bufio.Reader, n int) error {
h.resetSkipNormalize()
b, err := r.Peek(n)
if len(b) == 0 {
// Return ErrTimeout on any timeout.
if x, ok := err.(interface{ Timeout() bool }); ok && x.Timeout() {
return ErrTimeout
}
// treat all other errors on the first byte read as EOF
if n == 1 || err == io.EOF {
return io.EOF
}
// This works around a Go 1.6 bug. See https://github.com/golang/go/issues/14121 .
if err == bufio.ErrBufferFull {
if h.secureErrorLogMessage {
return &ErrSmallBuffer{
error: fmt.Errorf("error when reading response headers"),
}
}
return &ErrSmallBuffer{
error: fmt.Errorf("error when reading response headers: %s", errSmallBuffer),
}
}
return fmt.Errorf("error when reading response headers: %s", err)
}
b = mustPeekBuffered(r)
headersLen, errParse := h.parse(b)
if errParse != nil {
return headerError("response", err, errParse, b, h.secureErrorLogMessage)
}
mustDiscard(r, headersLen)
return nil
}
func headerError(typ string, err, errParse error, b []byte, secureErrorLogMessage bool) error {
if errParse != errNeedMore {
return headerErrorMsg(typ, errParse, b, secureErrorLogMessage)
}
if err == nil {
return errNeedMore
}
// Buggy servers may leave trailing CRLFs after http body.
// Treat this case as EOF.
if isOnlyCRLF(b) {
return io.EOF
}
if err != bufio.ErrBufferFull {
return headerErrorMsg(typ, err, b, secureErrorLogMessage)
}
return &ErrSmallBuffer{
error: headerErrorMsg(typ, errSmallBuffer, b, secureErrorLogMessage),
}
}
func headerErrorMsg(typ string, err error, b []byte, secureErrorLogMessage bool) error {
if secureErrorLogMessage {
return fmt.Errorf("error when reading %s headers: %s. Buffer size=%d", typ, err, len(b))
}
return fmt.Errorf("error when reading %s headers: %s. Buffer size=%d, contents: %s", typ, err, len(b), bufferSnippet(b))
}
// Read reads request header from r.
//
// io.EOF is returned if r is closed before reading the first header byte.
func (h *RequestHeader) Read(r *bufio.Reader) error {
return h.readLoop(r, true)
}
// readLoop reads the request header from r, optionally looping until enough data has been buffered.
//
// io.EOF is returned if r is closed before reading the first header byte.
func (h *RequestHeader) readLoop(r *bufio.Reader, waitForMore bool) error {
n := 1
for {
err := h.tryRead(r, n)
if err == nil {
return nil
}
if !waitForMore || err != errNeedMore {
h.resetSkipNormalize()
return err
}
n = r.Buffered() + 1
}
}
func (h *RequestHeader) tryRead(r *bufio.Reader, n int) error {
h.resetSkipNormalize()
b, err := r.Peek(n)
if len(b) == 0 {
if err == io.EOF {
return err
}
if err == nil {
panic("bufio.Reader.Peek() returned nil, nil")
}
// This works around a Go 1.6 bug. See https://github.com/golang/go/issues/14121 .
if err == bufio.ErrBufferFull {
return &ErrSmallBuffer{
error: fmt.Errorf("error when reading request headers: %s", errSmallBuffer),
}
}
// n == 1 on the first read for the request.
if n == 1 {
// We didn't read a single byte.
return ErrNothingRead{err}
}
return fmt.Errorf("error when reading request headers: %s", err)
}
b = mustPeekBuffered(r)
headersLen, errParse := h.parse(b)
if errParse != nil {
return headerError("request", err, errParse, b, h.secureErrorLogMessage)
}
mustDiscard(r, headersLen)
return nil
}
func bufferSnippet(b []byte) string {
n := len(b)
start := 200
end := n - start
if start >= end {
start = n
end = n
}
bStart, bEnd := b[:start], b[end:]
if len(bEnd) == 0 {
return fmt.Sprintf("%q", b)
}
return fmt.Sprintf("%q...%q", bStart, bEnd)
}
func isOnlyCRLF(b []byte) bool {
for _, ch := range b {
if ch != rChar && ch != nChar {
return false
}
}
return true
}
func updateServerDate() {
refreshServerDate()
go func() {
for {
time.Sleep(time.Second)
refreshServerDate()
}
}()
}
var (
serverDate atomic.Value
serverDateOnce sync.Once // serverDateOnce.Do(updateServerDate)
)
func refreshServerDate() {
b := AppendHTTPDate(nil, time.Now())
serverDate.Store(b)
}
// Write writes response header to w.
func (h *ResponseHeader) Write(w *bufio.Writer) error {
_, err := w.Write(h.Header())
return err
}
// WriteTo writes response header to w.
//
// WriteTo implements io.WriterTo interface.
func (h *ResponseHeader) WriteTo(w io.Writer) (int64, error) {
n, err := w.Write(h.Header())
return int64(n), err
}
// Header returns response header representation.
//
// The returned value is valid until the response is released,
// either through ReleaseResponse or your request handler returning.
// Do not store references to returned value. Make copies instead.
func (h *ResponseHeader) Header() []byte {
h.bufKV.value = h.AppendBytes(h.bufKV.value[:0])
return h.bufKV.value
}
// String returns response header representation.
func (h *ResponseHeader) String() string {
return string(h.Header())
}
// AppendBytes appends response header representation to dst and returns
// the extended dst.
func (h *ResponseHeader) AppendBytes(dst []byte) []byte {
statusCode := h.StatusCode()
if statusCode < 0 {
statusCode = StatusOK
}
dst = append(dst, statusLine(statusCode)...)
server := h.Server()
if len(server) != 0 {
dst = appendHeaderLine(dst, strServer, server)
}
if !h.noDefaultDate {
serverDateOnce.Do(updateServerDate)
dst = appendHeaderLine(dst, strDate, serverDate.Load().([]byte))
}
// Append Content-Type only for responses with a non-zero body length
// or if it is explicitly set.
// See https://github.com/valyala/fasthttp/issues/28 .
if h.ContentLength() != 0 || len(h.contentType) > 0 {
contentType := h.ContentType()
if len(contentType) > 0 {
dst = appendHeaderLine(dst, strContentType, contentType)
}
}
if len(h.contentLengthBytes) > 0 {
dst = appendHeaderLine(dst, strContentLength, h.contentLengthBytes)
}
for i, n := 0, len(h.h); i < n; i++ {
kv := &h.h[i]
if h.noDefaultDate || !bytes.Equal(kv.key, strDate) {
dst = appendHeaderLine(dst, kv.key, kv.value)
}
}
n := len(h.cookies)
if n > 0 {
for i := 0; i < n; i++ {
kv := &h.cookies[i]
dst = appendHeaderLine(dst, strSetCookie, kv.value)
}
}
if h.ConnectionClose() {
dst = appendHeaderLine(dst, strConnection, strClose)
}
return append(dst, strCRLF...)
}
// Write writes request header to w.
func (h *RequestHeader) Write(w *bufio.Writer) error {
_, err := w.Write(h.Header())
return err
}
// WriteTo writes request header to w.
//
// WriteTo implements io.WriterTo interface.
func (h *RequestHeader) WriteTo(w io.Writer) (int64, error) {
n, err := w.Write(h.Header())
return int64(n), err
}
// Header returns request header representation.
//
// The returned value is valid until the request is released,
// either through ReleaseRequest or your request handler returning.
// Do not store references to returned value. Make copies instead.
func (h *RequestHeader) Header() []byte {
h.bufKV.value = h.AppendBytes(h.bufKV.value[:0])
return h.bufKV.value
}
// RawHeaders returns raw header key/value bytes.
//
// Depending on server configuration, header keys may be normalized to
// capital-case in place.
//
// This copy is set aside during parsing, so an empty slice is returned for all
// cases where parsing did not happen. Similarly, the request line is not stored
// during parsing and cannot be returned.
//
// The slice is not safe to use after the handler returns.
func (h *RequestHeader) RawHeaders() []byte {
return h.rawHeaders
}
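
// Illustrative sketch, not part of the original source: retaining raw header
// bytes beyond the handler requires an explicit copy, because the underlying
// slice is reused between requests. The helper name is an assumption.
func exampleCopyRawHeaders(h *RequestHeader) []byte {
    return append([]byte(nil), h.RawHeaders()...)
}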
// String returns request header representation.
func (h *RequestHeader) String() string {
return string(h.Header())
}
// AppendBytes appends request header representation to dst and returns
// the extended dst.
func (h *RequestHeader) AppendBytes(dst []byte) []byte {
dst = append(dst, h.Method()...)
dst = append(dst, ' ')
dst = append(dst, h.RequestURI()...)
dst = append(dst, ' ')
dst = append(dst, h.Protocol()...)
dst = append(dst, strCRLF...)
userAgent := h.UserAgent()
if len(userAgent) > 0 {
dst = appendHeaderLine(dst, strUserAgent, userAgent)
}
host := h.Host()
if len(host) > 0 {
dst = appendHeaderLine(dst, strHost, host)
}
contentType := h.ContentType()
if len(contentType) == 0 && !h.ignoreBody() {
contentType = strDefaultContentType
}
if len(contentType) > 0 {
dst = appendHeaderLine(dst, strContentType, contentType)
}
if len(h.contentLengthBytes) > 0 {
dst = appendHeaderLine(dst, strContentLength, h.contentLengthBytes)
}
for i, n := 0, len(h.h); i < n; i++ {
kv := &h.h[i]
dst = appendHeaderLine(dst, kv.key, kv.value)
}
// There is no need to call h.collectCookies() here: if the cookies aren't collected yet,
// they are all still located in h.h.
n := len(h.cookies)
if n > 0 {
dst = append(dst, strCookie...)
dst = append(dst, strColonSpace...)
dst = appendRequestCookieBytes(dst, h.cookies)
dst = append(dst, strCRLF...)
}
if h.ConnectionClose() {
dst = appendHeaderLine(dst, strConnection, strClose)
}
return append(dst, strCRLF...)
}
func appendHeaderLine(dst, key, value []byte) []byte {
dst = append(dst, key...)
dst = append(dst, strColonSpace...)
dst = append(dst, value...)
return append(dst, strCRLF...)
}
func (h *ResponseHeader) parse(buf []byte) (int, error) {
m, err := h.parseFirstLine(buf)
if err != nil {
return 0, err
}
n, err := h.parseHeaders(buf[m:])
if err != nil {
return 0, err
}
return m + n, nil
}
func (h *RequestHeader) ignoreBody() bool {
return h.IsGet() || h.IsHead()
}
func (h *RequestHeader) parse(buf []byte) (int, error) {
m, err := h.parseFirstLine(buf)
if err != nil {
return 0, err
}
h.rawHeaders, _, err = readRawHeaders(h.rawHeaders[:0], buf[m:])
if err != nil {
return 0, err
}
var n int
n, err = h.parseHeaders(buf[m:])
if err != nil {
return 0, err
}
return m + n, nil
}
func (h *ResponseHeader) parseFirstLine(buf []byte) (int, error) {
bNext := buf
var b []byte
var err error
for len(b) == 0 {
if b, bNext, err = nextLine(bNext); err != nil {
return 0, err
}
}
// parse protocol
n := bytes.IndexByte(b, ' ')
if n < 0 {
if h.secureErrorLogMessage {
return 0, fmt.Errorf("cannot find whitespace in the first line of response")
}
return 0, fmt.Errorf("cannot find whitespace in the first line of response %q", buf)
}
h.noHTTP11 = !bytes.Equal(b[:n], strHTTP11)
b = b[n+1:]
// parse status code
h.statusCode, n, err = parseUintBuf(b)
if err != nil {
if h.secureErrorLogMessage {
return 0, fmt.Errorf("cannot parse response status code: %s", err)
}
return 0, fmt.Errorf("cannot parse response status code: %s. Response %q", err, buf)
}
if len(b) > n && b[n] != ' ' {
if h.secureErrorLogMessage {
return 0, fmt.Errorf("unexpected char at the end of status code")
}
return 0, fmt.Errorf("unexpected char at the end of status code. Response %q", buf)
}
return len(buf) - len(bNext), nil
}
func (h *RequestHeader) parseFirstLine(buf []byte) (int, error) {
bNext := buf
var b []byte
var err error
for len(b) == 0 {
if b, bNext, err = nextLine(bNext); err != nil {
return 0, err
}
}
// parse method
n := bytes.IndexByte(b, ' ')
if n <= 0 {
if h.secureErrorLogMessage {
return 0, fmt.Errorf("cannot find http request method")
}
return 0, fmt.Errorf("cannot find http request method in %q", buf)
}
h.method = append(h.method[:0], b[:n]...)
b = b[n+1:]
protoStr := strHTTP11
// parse requestURI
n = bytes.LastIndexByte(b, ' ')
if n < 0 {
h.noHTTP11 = true
n = len(b)
protoStr = strHTTP10
} else if n == 0 {
if h.secureErrorLogMessage {
return 0, fmt.Errorf("requestURI cannot be empty")
}
return 0, fmt.Errorf("requestURI cannot be empty in %q", buf)
} else if !bytes.Equal(b[n+1:], strHTTP11) {
h.noHTTP11 = true
protoStr = b[n+1:]
}
h.proto = append(h.proto[:0], protoStr...)
h.requestURI = append(h.requestURI[:0], b[:n]...)
return len(buf) - len(bNext), nil
}
func readRawHeaders(dst, buf []byte) ([]byte, int, error) {
n := bytes.IndexByte(buf, nChar)
if n < 0 {
return dst[:0], 0, errNeedMore
}
if (n == 1 && buf[0] == rChar) || n == 0 {
// empty headers
return dst, n + 1, nil
}
n++
b := buf
m := n
for {
b = b[m:]
m = bytes.IndexByte(b, nChar)
if m < 0 {
return dst, 0, errNeedMore
}
m++
n += m
if (m == 2 && b[0] == rChar) || m == 1 {
dst = append(dst, buf[:n]...)
return dst, n, nil
}
}
}
func (h *ResponseHeader) parseHeaders(buf []byte) (int, error) {
// 'identity' content-length by default
h.contentLength = -2
var s headerScanner
s.b = buf
s.disableNormalizing = h.disableNormalizing
var err error
var kv *argsKV
for s.next() {
if len(s.key) > 0 {
switch s.key[0] | 0x20 {
case 'c':
if caseInsensitiveCompare(s.key, strContentType) {
h.contentType = append(h.contentType[:0], s.value...)
continue
}
if caseInsensitiveCompare(s.key, strContentLength) {
if h.contentLength != -1 {
if h.contentLength, err = parseContentLength(s.value); err != nil {
h.contentLength = -2
} else {
h.contentLengthBytes = append(h.contentLengthBytes[:0], s.value...)
}
}
continue
}
if caseInsensitiveCompare(s.key, strConnection) {
if bytes.Equal(s.value, strClose) {
h.connectionClose = true
} else {
h.connectionClose = false
h.h = appendArgBytes(h.h, s.key, s.value, argsHasValue)
}
continue
}
case 's':
if caseInsensitiveCompare(s.key, strServer) {
h.server = append(h.server[:0], s.value...)
continue
}
if caseInsensitiveCompare(s.key, strSetCookie) {
h.cookies, kv = allocArg(h.cookies)
kv.key = getCookieKey(kv.key, s.value)
kv.value = append(kv.value[:0], s.value...)
continue
}
case 't':
if caseInsensitiveCompare(s.key, strTransferEncoding) {
if len(s.value) > 0 && !bytes.Equal(s.value, strIdentity) {
h.contentLength = -1
h.h = setArgBytes(h.h, strTransferEncoding, strChunked, argsHasValue)
}
continue
}
}
h.h = appendArgBytes(h.h, s.key, s.value, argsHasValue)
}
}
if s.err != nil {
h.connectionClose = true
return 0, s.err
}
if h.contentLength < 0 {
h.contentLengthBytes = h.contentLengthBytes[:0]
}
if h.contentLength == -2 && !h.ConnectionUpgrade() && !h.mustSkipContentLength() {
h.h = setArgBytes(h.h, strTransferEncoding, strIdentity, argsHasValue)
h.connectionClose = true
}
if h.noHTTP11 && !h.connectionClose {
// close connection for non-http/1.1 response unless 'Connection: keep-alive' is set.
v := peekArgBytes(h.h, strConnection)
h.connectionClose = !hasHeaderValue(v, strKeepAlive)
}
return len(buf) - len(s.b), nil
}
func (h *RequestHeader) parseHeaders(buf []byte) (int, error) {
h.contentLength = -2
var s headerScanner
s.b = buf
s.disableNormalizing = h.disableNormalizing
var err error
for s.next() {
if len(s.key) > 0 {
// Spaces between the header key and colon are not allowed.
// See RFC 7230, Section 3.2.4.
if bytes.IndexByte(s.key, ' ') != -1 || bytes.IndexByte(s.key, '\t') != -1 {
err = fmt.Errorf("invalid header key %q", s.key)
continue
}
switch s.key[0] | 0x20 {
case 'h':
if caseInsensitiveCompare(s.key, strHost) {
h.host = append(h.host[:0], s.value...)
continue
}
case 'u':
if caseInsensitiveCompare(s.key, strUserAgent) {
h.userAgent = append(h.userAgent[:0], s.value...)
continue
}
case 'c':
if caseInsensitiveCompare(s.key, strContentType) {
h.contentType = append(h.contentType[:0], s.value...)
continue
}
if caseInsensitiveCompare(s.key, strContentLength) {
if h.contentLength != -1 {
var nerr error
if h.contentLength, nerr = parseContentLength(s.value); nerr != nil {
if err == nil {
err = nerr
}
h.contentLength = -2
} else {
h.contentLengthBytes = append(h.contentLengthBytes[:0], s.value...)
}
}
continue
}
if caseInsensitiveCompare(s.key, strConnection) {
if bytes.Equal(s.value, strClose) {
h.connectionClose = true
} else {
h.connectionClose = false
h.h = appendArgBytes(h.h, s.key, s.value, argsHasValue)
}
continue
}
case 't':
if caseInsensitiveCompare(s.key, strTransferEncoding) {
if !bytes.Equal(s.value, strIdentity) {
h.contentLength = -1
h.h = setArgBytes(h.h, strTransferEncoding, strChunked, argsHasValue)
}
continue
}
}
}
h.h = appendArgBytes(h.h, s.key, s.value, argsHasValue)
}
if s.err != nil && err == nil {
err = s.err
}
if err != nil {
h.connectionClose = true
return 0, err
}
if h.contentLength < 0 {
h.contentLengthBytes = h.contentLengthBytes[:0]
}
if h.noHTTP11 && !h.connectionClose {
// close connection for non-http/1.1 request unless 'Connection: keep-alive' is set.
v := peekArgBytes(h.h, strConnection)
h.connectionClose = !hasHeaderValue(v, strKeepAlive)
}
return s.hLen, nil
}
func (h *RequestHeader) collectCookies() {
if h.cookiesCollected {
return
}
for i, n := 0, len(h.h); i < n; i++ {
kv := &h.h[i]
if caseInsensitiveCompare(kv.key, strCookie) {
h.cookies = parseRequestCookies(h.cookies, kv.value)
tmp := *kv
copy(h.h[i:], h.h[i+1:])
n--
i--
h.h[n] = tmp
h.h = h.h[:n]
}
}
h.cookiesCollected = true
}
func parseContentLength(b []byte) (int, error) {
v, n, err := parseUintBuf(b)
if err != nil {
return -1, err
}
if n != len(b) {
return -1, fmt.Errorf("non-numeric chars at the end of Content-Length")
}
return v, nil
}
type headerScanner struct {
b []byte
key []byte
value []byte
err error
// hLen stores header subslice len
hLen int
disableNormalizing bool
// nextColon and nextNewLine cache the lookups made while checking whether the
// next line starts a new header entry or continues a multi-line value of the
// current entry. That check already locates the next colon and newline, so their
// indexes are reused on the next iteration instead of being searched for again.
nextColon int
nextNewLine int
initialized bool
}
func (s *headerScanner) next() bool {
if !s.initialized {
s.nextColon = -1
s.nextNewLine = -1
s.initialized = true
}
bLen := len(s.b)
if bLen >= 2 && s.b[0] == rChar && s.b[1] == nChar {
s.b = s.b[2:]
s.hLen += 2
return false
}
if bLen >= 1 && s.b[0] == nChar {
s.b = s.b[1:]
s.hLen++
return false
}
var n int
if s.nextColon >= 0 {
n = s.nextColon
s.nextColon = -1
} else {
n = bytes.IndexByte(s.b, ':')
// There can't be a \n inside the header name, check for this.
x := bytes.IndexByte(s.b, nChar)
if x < 0 {
// A header name should always at some point be followed by a \n
// even if it's the one that terminates the header block.
s.err = errNeedMore
return false
}
if x < n {
// There was a \n before the :
s.err = errInvalidName
return false
}
}
if n < 0 {
s.err = errNeedMore
return false
}
s.key = s.b[:n]
normalizeHeaderKey(s.key, s.disableNormalizing)
n++
for len(s.b) > n && s.b[n] == ' ' {
n++
// The newline index is relative to s.b, and the lines below trim s.b by n,
// so the cached newline index shifts as well. It is safe for it to become
// negative: a negative value marks it invalid, and the newline is searched for again.
s.nextNewLine--
}
s.hLen += n
s.b = s.b[n:]
if s.nextNewLine >= 0 {
n = s.nextNewLine
s.nextNewLine = -1
} else {
n = bytes.IndexByte(s.b, nChar)
}
if n < 0 {
s.err = errNeedMore
return false
}
isMultiLineValue := false
for {
if n+1 >= len(s.b) {
break
}
if s.b[n+1] != ' ' && s.b[n+1] != '\t' {
break
}
d := bytes.IndexByte(s.b[n+1:], nChar)
if d <= 0 {
break
} else if d == 1 && s.b[n+1] == rChar {
break
}
e := n + d + 1
if c := bytes.IndexByte(s.b[n+1:e], ':'); c >= 0 {
s.nextColon = c
s.nextNewLine = d - c - 1
break
}
isMultiLineValue = true
n = e
}
if n >= len(s.b) {
s.err = errNeedMore
return false
}
oldB := s.b
s.value = s.b[:n]
s.hLen += n + 1
s.b = s.b[n+1:]
if n > 0 && s.value[n-1] == rChar {
n--
}
for n > 0 && s.value[n-1] == ' ' {
n--
}
s.value = s.value[:n]
if isMultiLineValue {
s.value, s.b, s.hLen = normalizeHeaderValue(s.value, oldB, s.hLen)
}
return true
}
type headerValueScanner struct {
b []byte
value []byte
}
func (s *headerValueScanner) next() bool {
b := s.b
if len(b) == 0 {
return false
}
n := bytes.IndexByte(b, ',')
if n < 0 {
s.value = stripSpace(b)
s.b = b[len(b):]
return true
}
s.value = stripSpace(b[:n])
s.b = b[n+1:]
return true
}
func stripSpace(b []byte) []byte {
for len(b) > 0 && b[0] == ' ' {
b = b[1:]
}
for len(b) > 0 && b[len(b)-1] == ' ' {
b = b[:len(b)-1]
}
return b
}
func hasHeaderValue(s, value []byte) bool {
var vs headerValueScanner
vs.b = s
for vs.next() {
if caseInsensitiveCompare(vs.value, value) {
return true
}
}
return false
}
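
// Illustrative sketch, not part of the original source: hasHeaderValue splits a
// header value on commas, strips surrounding spaces and compares each token
// case-insensitively, so both lookups below succeed. The header value is made up.
func exampleHasHeaderValueUsage() (bool, bool) {
    v := []byte("keep-alive, Upgrade")
    return hasHeaderValue(v, []byte("keep-alive")), hasHeaderValue(v, []byte("upgrade"))
}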
func nextLine(b []byte) ([]byte, []byte, error) {
nNext := bytes.IndexByte(b, nChar)
if nNext < 0 {
return nil, nil, errNeedMore
}
n := nNext
if n > 0 && b[n-1] == rChar {
n--
}
return b[:n], b[nNext+1:], nil
}
func initHeaderKV(kv *argsKV, key, value string, disableNormalizing bool) {
kv.key = getHeaderKeyBytes(kv, key, disableNormalizing)
// https://tools.ietf.org/html/rfc7230#section-3.2.4
kv.value = append(kv.value[:0], value...)
kv.value = removeNewLines(kv.value)
}
func getHeaderKeyBytes(kv *argsKV, key string, disableNormalizing bool) []byte {
kv.key = append(kv.key[:0], key...)
normalizeHeaderKey(kv.key, disableNormalizing)
return kv.key
}
func normalizeHeaderValue(ov, ob []byte, headerLength int) (nv, nb []byte, nhl int) {
nv = ov
length := len(ov)
if length <= 0 {
return
}
write := 0
shrunk := 0
lineStart := false
for read := 0; read < length; read++ {
c := ov[read]
if c == rChar || c == nChar {
shrunk++
if c == nChar {
lineStart = true
}
continue
} else if lineStart && c == '\t' {
c = ' '
} else {
lineStart = false
}
nv[write] = c
write++
}
nv = nv[:write]
copy(ob[write:], ob[write+shrunk:])
// Check if we need to skip \r\n or just \n
skip := 0
if ob[write] == rChar {
if ob[write+1] == nChar {
skip += 2
} else {
skip++
}
} else if ob[write] == nChar {
skip++
}
nb = ob[write+skip : len(ob)-shrunk]
nhl = headerLength - shrunk
return
}
func normalizeHeaderKey(b []byte, disableNormalizing bool) {
if disableNormalizing {
return
}
n := len(b)
if n == 0 {
return
}
b[0] = toUpperTable[b[0]]
for i := 1; i < n; i++ {
p := &b[i]
if *p == '-' {
i++
if i < n {
b[i] = toUpperTable[b[i]]
}
continue
}
*p = toLowerTable[*p]
}
}
// removeNewLines replaces `\r` and `\n` with a space character
func removeNewLines(raw []byte) []byte {
// check if a `\r` is present and save the position.
// if no `\r` is found, check if a `\n` is present.
foundR := bytes.IndexByte(raw, rChar)
foundN := bytes.IndexByte(raw, nChar)
start := 0
if foundN != -1 {
if foundR > foundN {
start = foundN
} else if foundR != -1 {
start = foundR
}
} else if foundR != -1 {
start = foundR
} else {
return raw
}
for i := start; i < len(raw); i++ {
switch raw[i] {
case rChar, nChar:
raw[i] = ' '
default:
continue
}
}
return raw
}
// AppendNormalizedHeaderKey appends normalized header key (name) to dst
// and returns the resulting dst.
//
// Normalized header key starts with uppercase letter. The first letters
// after dashes are also uppercased. All the other letters are lowercased.
// Examples:
//
// * coNTENT-TYPe -> Content-Type
// * HOST -> Host
// * foo-bar-baz -> Foo-Bar-Baz
func AppendNormalizedHeaderKey(dst []byte, key string) []byte {
dst = append(dst, key...)
normalizeHeaderKey(dst[len(dst)-len(key):], false)
return dst
}
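
// Illustrative sketch, not part of the original source: normalizing a header key
// into a reusable buffer. The input key is made up; the call appends
// "Content-Type" to dst.
func exampleAppendNormalizedHeaderKeyUsage() []byte {
    dst := make([]byte, 0, 32)
    return AppendNormalizedHeaderKey(dst, "coNTENT-TYPe")
}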
// AppendNormalizedHeaderKeyBytes appends normalized header key (name) to dst
// and returns the resulting dst.
//
// Normalized header key starts with uppercase letter. The first letters
// after dashes are also uppercased. All the other letters are lowercased.
// Examples:
//
// * coNTENT-TYPe -> Content-Type
// * HOST -> Host
// * foo-bar-baz -> Foo-Bar-Baz
func AppendNormalizedHeaderKeyBytes(dst, key []byte) []byte {
return AppendNormalizedHeaderKey(dst, b2s(key))
}
var (
errNeedMore = errors.New("need more data: cannot find trailing lf")
errInvalidName = errors.New("invalid header name")
errSmallBuffer = errors.New("small read buffer. Increase ReadBufferSize")
)
// ErrNothingRead is returned when a keep-alive connection is closed,
// either because the remote closed it or because of a read timeout.
type ErrNothingRead struct {
error
}
// ErrSmallBuffer is returned when the provided buffer size is too small
// for reading request and/or response headers.
//
// Increasing the ReadBufferSize value on the Server or client should reduce the number
// of such errors.
type ErrSmallBuffer struct {
error
}
func mustPeekBuffered(r *bufio.Reader) []byte {
buf, err := r.Peek(r.Buffered())
if len(buf) == 0 || err != nil {
panic(fmt.Sprintf("bufio.Reader.Peek() returned unexpected data (%q, %v)", buf, err))
}
return buf
}
func mustDiscard(r *bufio.Reader, n int) {
if _, err := r.Discard(n); err != nil {
panic(fmt.Sprintf("bufio.Reader.Discard(%d) failed: %s", n, err))
}
}
fasthttp-1.31.0/header_regression_test.go 0000664 0000000 0000000 00000004635 14130360711 0020513 0 ustar 00root root 0000000 0000000 package fasthttp
import (
"bufio"
"bytes"
"fmt"
"strings"
"testing"
)
func TestIssue28ResponseWithoutBodyNoContentType(t *testing.T) {
t.Parallel()
var r Response
// Empty response without content-type
s := r.String()
if strings.Contains(s, "Content-Type") {
t.Fatalf("unexpected Content-Type found in response header with empty body: %q", s)
}
// Explicitly set content-type
r.Header.SetContentType("foo/bar")
s = r.String()
if !strings.Contains(s, "Content-Type: foo/bar\r\n") {
t.Fatalf("missing explicitly set content-type for empty response: %q", s)
}
// Non-empty response.
r.Reset()
r.SetBodyString("foobar")
s = r.String()
if !strings.Contains(s, fmt.Sprintf("Content-Type: %s\r\n", defaultContentType)) {
t.Fatalf("missing default content-type for non-empty response: %q", s)
}
// Non-empty response with custom content-type.
r.Header.SetContentType("aaa/bbb")
s = r.String()
if !strings.Contains(s, "Content-Type: aaa/bbb\r\n") {
t.Fatalf("missing custom content-type: %q", s)
}
}
func TestIssue6RequestHeaderSetContentType(t *testing.T) {
t.Parallel()
testIssue6RequestHeaderSetContentType(t, MethodGet)
testIssue6RequestHeaderSetContentType(t, MethodPost)
testIssue6RequestHeaderSetContentType(t, MethodPut)
testIssue6RequestHeaderSetContentType(t, MethodPatch)
}
func testIssue6RequestHeaderSetContentType(t *testing.T, method string) {
contentType := "application/json"
contentLength := 123
var h RequestHeader
h.SetMethod(method)
h.SetRequestURI("http://localhost/test")
h.SetContentType(contentType)
h.SetContentLength(contentLength)
issue6VerifyRequestHeader(t, &h, contentType, contentLength, method)
s := h.String()
var h1 RequestHeader
br := bufio.NewReader(bytes.NewBufferString(s))
if err := h1.Read(br); err != nil {
t.Fatalf("unexpected error: %s", err)
}
issue6VerifyRequestHeader(t, &h1, contentType, contentLength, method)
}
func issue6VerifyRequestHeader(t *testing.T, h *RequestHeader, contentType string, contentLength int, method string) {
if string(h.ContentType()) != contentType {
t.Fatalf("unexpected content-type: %q. Expecting %q. method=%q", h.ContentType(), contentType, method)
}
if string(h.Method()) != method {
t.Fatalf("unexpected method: %q. Expecting %q", h.Method(), method)
}
if h.ContentLength() != contentLength {
t.Fatalf("unexpected content-length: %d. Expecting %d. method=%q", h.ContentLength(), contentLength, method)
}
}
fasthttp-1.31.0/header_test.go 0000664 0000000 0000000 00000231071 14130360711 0016247 0 ustar 00root root 0000000 0000000 package fasthttp
import (
"bufio"
"bytes"
"encoding/base64"
"fmt"
"io"
"io/ioutil"
"net/http"
"reflect"
"strings"
"testing"
)
func TestResponseHeaderAddContentType(t *testing.T) {
t.Parallel()
var h ResponseHeader
h.Add("Content-Type", "test")
got := string(h.Peek("Content-Type"))
expected := "test"
if got != expected {
t.Errorf("expected %q got %q", expected, got)
}
var buf bytes.Buffer
h.WriteTo(&buf) //nolint:errcheck
if n := strings.Count(buf.String(), "Content-Type: "); n != 1 {
t.Errorf("Content-Type occurred %d times", n)
}
}
func TestResponseHeaderMultiLineValue(t *testing.T) {
t.Parallel()
s := "HTTP/1.1 200 OK\r\n" +
"EmptyValue1:\r\n" +
"Content-Type: foo/bar;\r\n\tnewline;\r\n another/newline\r\n" +
"Foo: Bar\r\n" +
"Multi-Line: one;\r\n two\r\n" +
"Values: v1;\r\n v2; v3;\r\n v4;\tv5\r\n" +
"\r\n"
header := new(ResponseHeader)
if _, err := header.parse([]byte(s)); err != nil {
t.Fatalf("parse headers with multi-line values failed, %s", err)
}
response, err := http.ReadResponse(bufio.NewReader(strings.NewReader(s)), nil)
if err != nil {
t.Fatalf("parse response using net/http failed, %s", err)
}
for name, vals := range response.Header {
got := string(header.Peek(name))
want := vals[0]
if got != want {
t.Errorf("unexpected %s got: %q want: %q", name, got, want)
}
}
}
func TestResponseHeaderMultiLineName(t *testing.T) {
t.Parallel()
s := "HTTP/1.1 200 OK\r\n" +
"Host: golang.org\r\n" +
"Gopher-New-\r\n" +
" Line: This is a header on multiple lines\r\n" +
"\r\n"
header := new(ResponseHeader)
if _, err := header.parse([]byte(s)); err != errInvalidName {
m := make(map[string]string)
header.VisitAll(func(key, value []byte) {
m[string(key)] = string(value)
})
t.Errorf("expected error, got %q (%v)", m, err)
}
}
func TestResponseHeaderMultiLinePaniced(t *testing.T) {
t.Parallel()
// Input generated by fuzz testing that caused the parser to panic.
s, _ := base64.StdEncoding.DecodeString("aAEAIDoKKDoKICA6CgkKCiA6CiA6CgkpCiA6CiA6CiA6Cig6CiAgOgoJCgogOgogOgoJKQogOgogOgogOgogOgogOgoJOg86CiA6CiA6Cig6CiAyCg==")
header := new(RequestHeader)
header.parse(s) //nolint:errcheck
}
func TestResponseHeaderEmptyValueFromHeader(t *testing.T) {
t.Parallel()
var h1 ResponseHeader
h1.SetContentType("foo/bar")
h1.Set("EmptyValue1", "")
h1.Set("EmptyValue2", " ")
s := h1.String()
var h ResponseHeader
br := bufio.NewReader(bytes.NewBufferString(s))
if err := h.Read(br); err != nil {
t.Fatalf("unexpected error: %s", err)
}
if string(h.ContentType()) != string(h1.ContentType()) {
t.Fatalf("unexpected content-type: %q. Expecting %q", h.ContentType(), h1.ContentType())
}
v1 := h.Peek("EmptyValue1")
if len(v1) > 0 {
t.Fatalf("expecting empty value. Got %q", v1)
}
v2 := h.Peek("EmptyValue2")
if len(v2) > 0 {
t.Fatalf("expecting empty value. Got %q", v2)
}
}
func TestResponseHeaderEmptyValueFromString(t *testing.T) {
t.Parallel()
s := "HTTP/1.1 200 OK\r\n" +
"EmptyValue1:\r\n" +
"Content-Type: foo/bar\r\n" +
"EmptyValue2: \r\n" +
"\r\n"
var h ResponseHeader
br := bufio.NewReader(bytes.NewBufferString(s))
if err := h.Read(br); err != nil {
t.Fatalf("unexpected error: %s", err)
}
if string(h.ContentType()) != "foo/bar" {
t.Fatalf("unexpected content-type: %q. Expecting %q", h.ContentType(), "foo/bar")
}
v1 := h.Peek("EmptyValue1")
if len(v1) > 0 {
t.Fatalf("expecting empty value. Got %q", v1)
}
v2 := h.Peek("EmptyValue2")
if len(v2) > 0 {
t.Fatalf("expecting empty value. Got %q", v2)
}
}
func TestRequestHeaderEmptyValueFromHeader(t *testing.T) {
t.Parallel()
var h1 RequestHeader
h1.SetRequestURI("/foo/bar")
h1.SetHost("foobar")
h1.Set("EmptyValue1", "")
h1.Set("EmptyValue2", " ")
s := h1.String()
var h RequestHeader
br := bufio.NewReader(bytes.NewBufferString(s))
if err := h.Read(br); err != nil {
t.Fatalf("unexpected error: %s", err)
}
if string(h.Host()) != string(h1.Host()) {
t.Fatalf("unexpected host: %q. Expecting %q", h.Host(), h1.Host())
}
v1 := h.Peek("EmptyValue1")
if len(v1) > 0 {
t.Fatalf("expecting empty value. Got %q", v1)
}
v2 := h.Peek("EmptyValue2")
if len(v2) > 0 {
t.Fatalf("expecting empty value. Got %q", v2)
}
}
func TestRequestHeaderEmptyValueFromString(t *testing.T) {
t.Parallel()
s := "GET / HTTP/1.1\r\n" +
"EmptyValue1:\r\n" +
"Host: foobar\r\n" +
"EmptyValue2: \r\n" +
"\r\n"
var h RequestHeader
br := bufio.NewReader(bytes.NewBufferString(s))
if err := h.Read(br); err != nil {
t.Fatalf("unexpected error: %s", err)
}
if string(h.Host()) != "foobar" {
t.Fatalf("unexpected host: %q. Expecting %q", h.Host(), "foobar")
}
v1 := h.Peek("EmptyValue1")
if len(v1) > 0 {
t.Fatalf("expecting empty value. Got %q", v1)
}
v2 := h.Peek("EmptyValue2")
if len(v2) > 0 {
t.Fatalf("expecting empty value. Got %q", v2)
}
}
func TestRequestRawHeaders(t *testing.T) {
t.Parallel()
kvs := "hOsT: foobar\r\n" +
"value: b\r\n" +
"\r\n"
t.Run("normalized", func(t *testing.T) {
s := "GET / HTTP/1.1\r\n" + kvs
exp := kvs
var h RequestHeader
br := bufio.NewReader(bytes.NewBufferString(s))
if err := h.Read(br); err != nil {
t.Fatalf("unexpected error: %s", err)
}
if string(h.Host()) != "foobar" {
t.Fatalf("unexpected host: %q. Expecting %q", h.Host(), "foobar")
}
v2 := h.Peek("Value")
if !bytes.Equal(v2, []byte{'b'}) {
t.Fatalf("expecting non empty value. Got %q", v2)
}
if raw := h.RawHeaders(); string(raw) != exp {
t.Fatalf("expected header %q, got %q", exp, raw)
}
})
for _, n := range []int{0, 1, 4, 8} {
t.Run(fmt.Sprintf("post-%dk", n), func(t *testing.T) {
l := 1024 * n
body := make([]byte, l)
for i := range body {
body[i] = 'a'
}
cl := fmt.Sprintf("Content-Length: %d\r\n", l)
s := "POST / HTTP/1.1\r\n" + cl + kvs + string(body)
exp := cl + kvs
var h RequestHeader
br := bufio.NewReader(bytes.NewBufferString(s))
if err := h.Read(br); err != nil {
t.Fatalf("unexpected error: %s", err)
}
if string(h.Host()) != "foobar" {
t.Fatalf("unexpected host: %q. Expecting %q", h.Host(), "foobar")
}
v2 := h.Peek("Value")
if !bytes.Equal(v2, []byte{'b'}) {
t.Fatalf("expecting non empty value. Got %q", v2)
}
if raw := h.RawHeaders(); string(raw) != exp {
t.Fatalf("expected header %q, got %q", exp, raw)
}
})
}
t.Run("http10", func(t *testing.T) {
s := "GET / HTTP/1.0\r\n" + kvs
exp := kvs
var h RequestHeader
br := bufio.NewReader(bytes.NewBufferString(s))
if err := h.Read(br); err != nil {
t.Fatalf("unexpected error: %s", err)
}
if string(h.Host()) != "foobar" {
t.Fatalf("unexpected host: %q. Expecting %q", h.Host(), "foobar")
}
v2 := h.Peek("Value")
if !bytes.Equal(v2, []byte{'b'}) {
t.Fatalf("expecting non empty value. Got %q", v2)
}
if raw := h.RawHeaders(); string(raw) != exp {
t.Fatalf("expected header %q, got %q", exp, raw)
}
})
t.Run("no-kvs", func(t *testing.T) {
s := "GET / HTTP/1.1\r\n\r\n"
exp := ""
var h RequestHeader
h.DisableNormalizing()
br := bufio.NewReader(bytes.NewBufferString(s))
if err := h.Read(br); err != nil {
t.Fatalf("unexpected error: %s", err)
}
if string(h.Host()) != "" {
t.Fatalf("unexpected host: %q. Expecting %q", h.Host(), "")
}
v1 := h.Peek("NoKey")
if len(v1) > 0 {
t.Fatalf("expecting empty value. Got %q", v1)
}
if raw := h.RawHeaders(); string(raw) != exp {
t.Fatalf("expected header %q, got %q", exp, raw)
}
})
}
func TestRequestHeaderSetCookieWithSpecialChars(t *testing.T) {
t.Parallel()
var h RequestHeader
h.Set("Cookie", "ID&14")
s := h.String()
if !strings.Contains(s, "Cookie: ID&14") {
t.Fatalf("Missing cookie in request header: [%s]", s)
}
var h1 RequestHeader
br := bufio.NewReader(bytes.NewBufferString(s))
if err := h1.Read(br); err != nil {
t.Fatalf("unexpected error: %s", err)
}
cookie := h1.Peek(HeaderCookie)
if string(cookie) != "ID&14" {
t.Fatalf("unexpected cooke: %q. Expecting %q", cookie, "ID&14")
}
cookie = h1.Cookie("")
if string(cookie) != "ID&14" {
t.Fatalf("unexpected cooke: %q. Expecting %q", cookie, "ID&14")
}
}
func TestResponseHeaderDefaultStatusCode(t *testing.T) {
t.Parallel()
var h ResponseHeader
statusCode := h.StatusCode()
if statusCode != StatusOK {
t.Fatalf("unexpected status code: %d. Expecting %d", statusCode, StatusOK)
}
}
func TestResponseHeaderDelClientCookie(t *testing.T) {
t.Parallel()
cookieName := "foobar"
var h ResponseHeader
c := AcquireCookie()
c.SetKey(cookieName)
c.SetValue("aasdfsdaf")
h.SetCookie(c)
h.DelClientCookieBytes([]byte(cookieName))
if !h.Cookie(c) {
t.Fatalf("expecting cookie %q", c.Key())
}
if !c.Expire().Equal(CookieExpireDelete) {
t.Fatalf("unexpected cookie expiration time: %s. Expecting %s", c.Expire(), CookieExpireDelete)
}
if len(c.Value()) > 0 {
t.Fatalf("unexpected cookie value: %q. Expecting empty value", c.Value())
}
ReleaseCookie(c)
}
func TestResponseHeaderAdd(t *testing.T) {
t.Parallel()
m := make(map[string]struct{})
var h ResponseHeader
h.Add("aaa", "bbb")
h.Add("content-type", "xxx")
m["bbb"] = struct{}{}
m["xxx"] = struct{}{}
for i := 0; i < 10; i++ {
v := fmt.Sprintf("%d", i)
h.Add("Foo-Bar", v)
m[v] = struct{}{}
}
if h.Len() != 12 {
t.Fatalf("unexpected header len %d. Expecting 12", h.Len())
}
h.VisitAll(func(k, v []byte) {
switch string(k) {
case "Aaa", "Foo-Bar", "Content-Type":
if _, ok := m[string(v)]; !ok {
t.Fatalf("unexpected value found %q. key %q", v, k)
}
delete(m, string(v))
default:
t.Fatalf("unexpected key found: %q", k)
}
})
if len(m) > 0 {
t.Fatalf("%d headers are missed", len(m))
}
s := h.String()
br := bufio.NewReader(bytes.NewBufferString(s))
var h1 ResponseHeader
if err := h1.Read(br); err != nil {
t.Fatalf("unexpected error: %s", err)
}
h.VisitAll(func(k, v []byte) {
switch string(k) {
case "Aaa", "Foo-Bar", "Content-Type":
m[string(v)] = struct{}{}
default:
t.Fatalf("unexpected key found: %q", k)
}
})
if len(m) != 12 {
t.Fatalf("unexpected number of headers: %d. Expecting 12", len(m))
}
}
func TestRequestHeaderAdd(t *testing.T) {
t.Parallel()
m := make(map[string]struct{})
var h RequestHeader
h.Add("aaa", "bbb")
h.Add("user-agent", "xxx")
m["bbb"] = struct{}{}
m["xxx"] = struct{}{}
for i := 0; i < 10; i++ {
v := fmt.Sprintf("%d", i)
h.Add("Foo-Bar", v)
m[v] = struct{}{}
}
if h.Len() != 12 {
t.Fatalf("unexpected header len %d. Expecting 12", h.Len())
}
h.VisitAll(func(k, v []byte) {
switch string(k) {
case "Aaa", "Foo-Bar", "User-Agent":
if _, ok := m[string(v)]; !ok {
t.Fatalf("unexpected value found %q. key %q", v, k)
}
delete(m, string(v))
default:
t.Fatalf("unexpected key found: %q", k)
}
})
if len(m) > 0 {
t.Fatalf("%d headers are missed", len(m))
}
s := h.String()
br := bufio.NewReader(bytes.NewBufferString(s))
var h1 RequestHeader
if err := h1.Read(br); err != nil {
t.Fatalf("unexpected error: %s", err)
}
h.VisitAll(func(k, v []byte) {
switch string(k) {
case "Aaa", "Foo-Bar", "User-Agent":
m[string(v)] = struct{}{}
default:
t.Fatalf("unexpected key found: %q", k)
}
})
if len(m) != 12 {
t.Fatalf("unexpected number of headers: %d. Expecting 12", len(m))
}
s1 := h1.String()
if s != s1 {
t.Fatalf("unexpected headers %q. Expecting %q", s1, s)
}
}
func TestHasHeaderValue(t *testing.T) {
t.Parallel()
testHasHeaderValue(t, "foobar", "foobar", true)
testHasHeaderValue(t, "foobar", "foo", false)
testHasHeaderValue(t, "foobar", "bar", false)
testHasHeaderValue(t, "keep-alive, Upgrade", "keep-alive", true)
testHasHeaderValue(t, "keep-alive , Upgrade", "Upgrade", true)
testHasHeaderValue(t, "keep-alive, Upgrade", "Upgrade-foo", false)
testHasHeaderValue(t, "keep-alive, Upgrade", "Upgr", false)
testHasHeaderValue(t, "foo , bar, baz ,", "foo", true)
testHasHeaderValue(t, "foo , bar, baz ,", "bar", true)
testHasHeaderValue(t, "foo , bar, baz ,", "baz", true)
testHasHeaderValue(t, "foo , bar, baz ,", "ba", false)
testHasHeaderValue(t, "foo, ", "", true)
testHasHeaderValue(t, "foo", "", false)
}
func testHasHeaderValue(t *testing.T, s, value string, has bool) {
ok := hasHeaderValue([]byte(s), []byte(value))
if ok != has {
t.Fatalf("unexpected hasHeaderValue(%q, %q)=%v. Expecting %v", s, value, ok, has)
}
}
func TestRequestHeaderDel(t *testing.T) {
t.Parallel()
var h RequestHeader
h.Set("Foo-Bar", "baz")
h.Set("aaa", "bbb")
h.Set(HeaderConnection, "keep-alive")
h.Set("Content-Type", "aaa")
h.Set(HeaderHost, "aaabbb")
h.Set("User-Agent", "asdfas")
h.Set("Content-Length", "1123")
h.Set("Cookie", "foobar=baz")
h.Del("foo-bar")
h.Del("connection")
h.DelBytes([]byte("content-type"))
h.Del("Host")
h.Del("user-agent")
h.Del("content-length")
h.Del("cookie")
hv := h.Peek("aaa")
if string(hv) != "bbb" {
t.Fatalf("unexpected header value: %q. Expecting %q", hv, "bbb")
}
hv = h.Peek("Foo-Bar")
if len(hv) > 0 {
t.Fatalf("non-zero value: %q", hv)
}
hv = h.Peek(HeaderConnection)
if len(hv) > 0 {
t.Fatalf("non-zero value: %q", hv)
}
hv = h.Peek(HeaderContentType)
if len(hv) > 0 {
t.Fatalf("non-zero value: %q", hv)
}
hv = h.Peek(HeaderHost)
if len(hv) > 0 {
t.Fatalf("non-zero value: %q", hv)
}
hv = h.Peek(HeaderUserAgent)
if len(hv) > 0 {
t.Fatalf("non-zero value: %q", hv)
}
hv = h.Peek(HeaderContentLength)
if len(hv) > 0 {
t.Fatalf("non-zero value: %q", hv)
}
hv = h.Peek(HeaderCookie)
if len(hv) > 0 {
t.Fatalf("non-zero value: %q", hv)
}
cv := h.Cookie("foobar")
if len(cv) > 0 {
t.Fatalf("unexpected cookie obtianed: %q", cv)
}
if h.ContentLength() != 0 {
t.Fatalf("unexpected content-length: %d. Expecting 0", h.ContentLength())
}
}
func TestResponseHeaderDel(t *testing.T) {
t.Parallel()
var h ResponseHeader
h.Set("Foo-Bar", "baz")
h.Set("aaa", "bbb")
h.Set(HeaderConnection, "keep-alive")
h.Set(HeaderContentType, "aaa")
h.Set(HeaderServer, "aaabbb")
h.Set(HeaderContentLength, "1123")
var c Cookie
c.SetKey("foo")
c.SetValue("bar")
h.SetCookie(&c)
h.Del("foo-bar")
h.Del("connection")
h.DelBytes([]byte("content-type"))
h.Del(HeaderServer)
h.Del("content-length")
h.Del("set-cookie")
hv := h.Peek("aaa")
if string(hv) != "bbb" {
t.Fatalf("unexpected header value: %q. Expecting %q", hv, "bbb")
}
hv = h.Peek("Foo-Bar")
if len(hv) > 0 {
t.Fatalf("non-zero header value: %q", hv)
}
hv = h.Peek(HeaderConnection)
if len(hv) > 0 {
t.Fatalf("non-zero value: %q", hv)
}
hv = h.Peek(HeaderContentType)
if string(hv) != string(defaultContentType) {
t.Fatalf("unexpected content-type: %q. Expecting %q", hv, defaultContentType)
}
hv = h.Peek(HeaderServer)
if len(hv) > 0 {
t.Fatalf("non-zero value: %q", hv)
}
hv = h.Peek(HeaderContentLength)
if len(hv) > 0 {
t.Fatalf("non-zero value: %q", hv)
}
if h.Cookie(&c) {
t.Fatalf("unexpected cookie obtianed: %q", &c)
}
if h.ContentLength() != 0 {
t.Fatalf("unexpected content-length: %d. Expecting 0", h.ContentLength())
}
}
func TestAppendNormalizedHeaderKeyBytes(t *testing.T) {
t.Parallel()
testAppendNormalizedHeaderKeyBytes(t, "", "")
testAppendNormalizedHeaderKeyBytes(t, "Content-Type", "Content-Type")
testAppendNormalizedHeaderKeyBytes(t, "foO-bAr-BAZ", "Foo-Bar-Baz")
}
func testAppendNormalizedHeaderKeyBytes(t *testing.T, key, expectedKey string) {
buf := []byte("foobar")
result := AppendNormalizedHeaderKeyBytes(buf, []byte(key))
normalizedKey := result[len(buf):]
if string(normalizedKey) != expectedKey {
t.Fatalf("unexpected normalized key %q. Expecting %q", normalizedKey, expectedKey)
}
}
func TestRequestHeaderHTTP10ConnectionClose(t *testing.T) {
t.Parallel()
s := "GET / HTTP/1.0\r\nHost: foobar\r\n\r\n"
var h RequestHeader
br := bufio.NewReader(bytes.NewBufferString(s))
if err := h.Read(br); err != nil {
t.Fatalf("unexpected error: %s", err)
}
if !h.ConnectionClose() {
t.Fatalf("expecting 'Connection: close' request header")
}
}
func TestRequestHeaderHTTP10ConnectionKeepAlive(t *testing.T) {
t.Parallel()
s := "GET / HTTP/1.0\r\nHost: foobar\r\nConnection: keep-alive\r\n\r\n"
var h RequestHeader
br := bufio.NewReader(bytes.NewBufferString(s))
if err := h.Read(br); err != nil {
t.Fatalf("unexpected error: %s", err)
}
if h.ConnectionClose() {
t.Fatalf("unexpected 'Connection: close' request header")
}
}
func TestBufferSnippet(t *testing.T) {
t.Parallel()
testBufferSnippet(t, "", `""`)
testBufferSnippet(t, "foobar", `"foobar"`)
b := string(createFixedBody(199))
bExpected := fmt.Sprintf("%q", b)
testBufferSnippet(t, b, bExpected)
for i := 0; i < 10; i++ {
b += "foobar"
bExpected = fmt.Sprintf("%q", b)
testBufferSnippet(t, b, bExpected)
}
b = string(createFixedBody(400))
bExpected = fmt.Sprintf("%q", b)
testBufferSnippet(t, b, bExpected)
for i := 0; i < 10; i++ {
b += "sadfqwer"
bExpected = fmt.Sprintf("%q...%q", b[:200], b[len(b)-200:])
testBufferSnippet(t, b, bExpected)
}
}
func testBufferSnippet(t *testing.T, buf, expectedSnippet string) {
snippet := bufferSnippet([]byte(buf))
if snippet != expectedSnippet {
t.Fatalf("unexpected snippet %s. Expecting %s", snippet, expectedSnippet)
}
}
func TestResponseHeaderTrailingCRLFSuccess(t *testing.T) {
t.Parallel()
trailingCRLF := "\r\n\r\n\r\n"
s := "HTTP/1.1 200 OK\r\nContent-Type: aa\r\nContent-Length: 123\r\n\r\n" + trailingCRLF
var r ResponseHeader
br := bufio.NewReader(bytes.NewBufferString(s))
if err := r.Read(br); err != nil {
t.Fatalf("unexpected error: %s", err)
}
// try reading the trailing CRLF. It must return EOF
err := r.Read(br)
if err == nil {
t.Fatalf("expecting error")
}
if err != io.EOF {
t.Fatalf("unexpected error: %s. Expecting %s", err, io.EOF)
}
}
func TestResponseHeaderTrailingCRLFError(t *testing.T) {
t.Parallel()
trailingCRLF := "\r\nerror\r\n\r\n"
s := "HTTP/1.1 200 OK\r\nContent-Type: aa\r\nContent-Length: 123\r\n\r\n" + trailingCRLF
var r ResponseHeader
br := bufio.NewReader(bytes.NewBufferString(s))
if err := r.Read(br); err != nil {
t.Fatalf("unexpected error: %s", err)
}
// try reading the trailing data. It must return a non-EOF error
err := r.Read(br)
if err == nil {
t.Fatalf("expecting error")
}
if err == io.EOF {
t.Fatalf("unexpected error: %s", err)
}
}
func TestRequestHeaderTrailingCRLFSuccess(t *testing.T) {
t.Parallel()
trailingCRLF := "\r\n\r\n\r\n"
s := "GET / HTTP/1.1\r\nHost: aaa.com\r\n\r\n" + trailingCRLF
var r RequestHeader
br := bufio.NewReader(bytes.NewBufferString(s))
if err := r.Read(br); err != nil {
t.Fatalf("unexpected error: %s", err)
}
// try reading the trailing CRLF. It must return EOF
err := r.Read(br)
if err == nil {
t.Fatalf("expecting error")
}
if err != io.EOF {
t.Fatalf("unexpected error: %s. Expecting %s", err, io.EOF)
}
}
func TestRequestHeaderTrailingCRLFError(t *testing.T) {
t.Parallel()
trailingCRLF := "\r\nerror\r\n\r\n"
s := "GET / HTTP/1.1\r\nHost: aaa.com\r\n\r\n" + trailingCRLF
var r RequestHeader
br := bufio.NewReader(bytes.NewBufferString(s))
if err := r.Read(br); err != nil {
t.Fatalf("unexpected error: %s", err)
}
// try reading the trailing data. It must return a non-EOF error
err := r.Read(br)
if err == nil {
t.Fatalf("expecting error")
}
if err == io.EOF {
t.Fatalf("unexpected error: %s", err)
}
}
func TestRequestHeaderReadEOF(t *testing.T) {
t.Parallel()
var r RequestHeader
br := bufio.NewReader(&bytes.Buffer{})
err := r.Read(br)
if err == nil {
t.Fatalf("expecting error")
}
if err != io.EOF {
t.Fatalf("unexpected error: %s. Expecting %s", err, io.EOF)
}
// incomplete request header mustn't return io.EOF
br = bufio.NewReader(bytes.NewBufferString("GET "))
err = r.Read(br)
if err == nil {
t.Fatalf("expecting error")
}
if err == io.EOF {
t.Fatalf("expecting non-EOF error")
}
}
func TestResponseHeaderReadEOF(t *testing.T) {
t.Parallel()
var r ResponseHeader
br := bufio.NewReader(&bytes.Buffer{})
err := r.Read(br)
if err == nil {
t.Fatalf("expecting error")
}
if err != io.EOF {
t.Fatalf("unexpected error: %s. Expecting %s", err, io.EOF)
}
// incomplete response header mustn't return io.EOF
br = bufio.NewReader(bytes.NewBufferString("HTTP/1.1 "))
err = r.Read(br)
if err == nil {
t.Fatalf("expecting error")
}
if err == io.EOF {
t.Fatalf("expecting non-EOF error")
}
}
func TestResponseHeaderOldVersion(t *testing.T) {
t.Parallel()
var h ResponseHeader
s := "HTTP/1.0 200 OK\r\nContent-Length: 5\r\nContent-Type: aaa\r\n\r\n12345"
s += "HTTP/1.0 200 OK\r\nContent-Length: 2\r\nContent-Type: ass\r\nConnection: keep-alive\r\n\r\n42"
br := bufio.NewReader(bytes.NewBufferString(s))
if err := h.Read(br); err != nil {
t.Fatalf("unexpected error: %s", err)
}
if !h.ConnectionClose() {
t.Fatalf("expecting 'Connection: close' for the response with old http protocol")
}
if err := h.Read(br); err != nil {
t.Fatalf("unexpected error: %s", err)
}
if h.ConnectionClose() {
t.Fatalf("unexpected 'Connection: close' for keep-alive response with old http protocol")
}
}
func TestRequestHeaderSetByteRange(t *testing.T) {
t.Parallel()
testRequestHeaderSetByteRange(t, 0, 10, "bytes=0-10")
testRequestHeaderSetByteRange(t, 123, -1, "bytes=123-")
testRequestHeaderSetByteRange(t, -234, 58349, "bytes=-234")
}
func testRequestHeaderSetByteRange(t *testing.T, startPos, endPos int, expectedV string) {
var h RequestHeader
h.SetByteRange(startPos, endPos)
v := h.Peek(HeaderRange)
if string(v) != expectedV {
t.Fatalf("unexpected range: %q. Expecting %q. startPos=%d, endPos=%d", v, expectedV, startPos, endPos)
}
}
func TestResponseHeaderSetContentRange(t *testing.T) {
t.Parallel()
testResponseHeaderSetContentRange(t, 0, 0, 1, "bytes 0-0/1")
testResponseHeaderSetContentRange(t, 123, 456, 789, "bytes 123-456/789")
}
func testResponseHeaderSetContentRange(t *testing.T, startPos, endPos, contentLength int, expectedV string) {
var h ResponseHeader
h.SetContentRange(startPos, endPos, contentLength)
v := h.Peek(HeaderContentRange)
if string(v) != expectedV {
t.Fatalf("unexpected content-range: %q. Expecting %q. startPos=%d, endPos=%d, contentLength=%d",
v, expectedV, startPos, endPos, contentLength)
}
}
func TestRequestHeaderHasAcceptEncoding(t *testing.T) {
t.Parallel()
testRequestHeaderHasAcceptEncoding(t, "", "gzip", false)
testRequestHeaderHasAcceptEncoding(t, "gzip", "sdhc", false)
testRequestHeaderHasAcceptEncoding(t, "deflate", "deflate", true)
testRequestHeaderHasAcceptEncoding(t, "gzip, deflate, sdhc", "gzi", false)
testRequestHeaderHasAcceptEncoding(t, "gzip, deflate, sdhc", "dhc", false)
testRequestHeaderHasAcceptEncoding(t, "gzip, deflate, sdhc", "sdh", false)
testRequestHeaderHasAcceptEncoding(t, "gzip, deflate, sdhc", "zip", false)
testRequestHeaderHasAcceptEncoding(t, "gzip, deflate, sdhc", "flat", false)
testRequestHeaderHasAcceptEncoding(t, "gzip, deflate, sdhc", "flate", false)
testRequestHeaderHasAcceptEncoding(t, "gzip, deflate, sdhc", "def", false)
testRequestHeaderHasAcceptEncoding(t, "gzip, deflate, sdhc", "gzip", true)
testRequestHeaderHasAcceptEncoding(t, "gzip, deflate, sdhc", "deflate", true)
testRequestHeaderHasAcceptEncoding(t, "gzip, deflate, sdhc", "sdhc", true)
}
func testRequestHeaderHasAcceptEncoding(t *testing.T, ae, v string, resultExpected bool) {
var h RequestHeader
h.Set(HeaderAcceptEncoding, ae)
result := h.HasAcceptEncoding(v)
if result != resultExpected {
t.Fatalf("unexpected result in HasAcceptEncoding(%q, %q): %v. Expecting %v", ae, v, result, resultExpected)
}
}
func TestRequestMultipartFormBoundary(t *testing.T) {
t.Parallel()
testRequestMultipartFormBoundary(t, "POST / HTTP/1.1\r\nContent-Type: multipart/form-data; boundary=foobar\r\n\r\n", "foobar")
// incorrect content-type
testRequestMultipartFormBoundary(t, "POST / HTTP/1.1\r\nContent-Type: foo/bar\r\n\r\n", "")
// empty boundary
testRequestMultipartFormBoundary(t, "POST / HTTP/1.1\r\nContent-Type: multipart/form-data; boundary=\r\n\r\n", "")
// missing boundary
testRequestMultipartFormBoundary(t, "POST / HTTP/1.1\r\nContent-Type: multipart/form-data\r\n\r\n", "")
// boundary after other content-type params
testRequestMultipartFormBoundary(t, "POST / HTTP/1.1\r\nContent-Type: multipart/form-data; foo=bar; boundary=--aaabb \r\n\r\n", "--aaabb")
// quoted boundary
testRequestMultipartFormBoundary(t, "POST / HTTP/1.1\r\nContent-Type: multipart/form-data; boundary=\"foobar\"\r\n\r\n", "foobar")
var h RequestHeader
h.SetMultipartFormBoundary("foobarbaz")
b := h.MultipartFormBoundary()
if string(b) != "foobarbaz" {
t.Fatalf("unexpected boundary %q. Expecting %q", b, "foobarbaz")
}
}
func testRequestMultipartFormBoundary(t *testing.T, s, boundary string) {
var h RequestHeader
r := bytes.NewBufferString(s)
br := bufio.NewReader(r)
if err := h.Read(br); err != nil {
t.Fatalf("unexpected error: %s. s=%q, boundary=%q", err, s, boundary)
}
b := h.MultipartFormBoundary()
if string(b) != boundary {
t.Fatalf("unexpected boundary %q. Expecting %q. s=%q", b, boundary, s)
}
}
func TestResponseHeaderConnectionUpgrade(t *testing.T) {
t.Parallel()
testResponseHeaderConnectionUpgrade(t, "HTTP/1.1 200 OK\r\nContent-Length: 10\r\nConnection: Upgrade, HTTP2-Settings\r\n\r\n",
true, true)
testResponseHeaderConnectionUpgrade(t, "HTTP/1.1 200 OK\r\nContent-Length: 10\r\nConnection: keep-alive, Upgrade\r\n\r\n",
true, true)
// non-http/1.1 protocol has 'connection: close' by default, which also disables 'connection: upgrade'
testResponseHeaderConnectionUpgrade(t, "HTTP/1.0 200 OK\r\nContent-Length: 10\r\nConnection: Upgrade, HTTP2-Settings\r\n\r\n",
false, false)
// explicit keep-alive for non-http/1.1, so 'connection: upgrade' works
testResponseHeaderConnectionUpgrade(t, "HTTP/1.0 200 OK\r\nContent-Length: 10\r\nConnection: Upgrade, keep-alive\r\n\r\n",
true, true)
// implicit keep-alive for http/1.1
testResponseHeaderConnectionUpgrade(t, "HTTP/1.1 200 OK\r\nContent-Length: 10\r\n\r\n", false, true)
// no content-length, so 'connection: close' is assumed
testResponseHeaderConnectionUpgrade(t, "HTTP/1.1 200 OK\r\n\r\n", false, false)
}
func testResponseHeaderConnectionUpgrade(t *testing.T, s string, isUpgrade, isKeepAlive bool) {
var h ResponseHeader
r := bytes.NewBufferString(s)
br := bufio.NewReader(r)
if err := h.Read(br); err != nil {
t.Fatalf("unexpected error: %s. Response header %q", err, s)
}
upgrade := h.ConnectionUpgrade()
if upgrade != isUpgrade {
t.Fatalf("unexpected 'connection: upgrade' when parsing response header: %v. Expecting %v. header %q. v=%q",
upgrade, isUpgrade, s, h.Peek("Connection"))
}
keepAlive := !h.ConnectionClose()
if keepAlive != isKeepAlive {
t.Fatalf("unexpected 'connection: keep-alive' when parsing response header: %v. Expecting %v. header %q. v=%q",
keepAlive, isKeepAlive, s, &h)
}
}
func TestRequestHeaderConnectionUpgrade(t *testing.T) {
t.Parallel()
testRequestHeaderConnectionUpgrade(t, "GET /foobar HTTP/1.1\r\nConnection: Upgrade, HTTP2-Settings\r\nHost: foobar.com\r\n\r\n",
true, true)
testRequestHeaderConnectionUpgrade(t, "GET /foobar HTTP/1.1\r\nConnection: keep-alive,Upgrade\r\nHost: foobar.com\r\n\r\n",
true, true)
// non-http/1.1 has 'connection: close' by default, which resets 'connection: upgrade'
testRequestHeaderConnectionUpgrade(t, "GET /foobar HTTP/1.0\r\nConnection: Upgrade, HTTP2-Settings\r\nHost: foobar.com\r\n\r\n",
false, false)
// explicit 'connection: keep-alive' in non-http/1.1
testRequestHeaderConnectionUpgrade(t, "GET /foobar HTTP/1.0\r\nConnection: foo, Upgrade, keep-alive\r\nHost: foobar.com\r\n\r\n",
true, true)
// no upgrade
testRequestHeaderConnectionUpgrade(t, "GET /foobar HTTP/1.1\r\nConnection: Upgradess, foobar\r\nHost: foobar.com\r\n\r\n",
false, true)
testRequestHeaderConnectionUpgrade(t, "GET /foobar HTTP/1.1\r\nHost: foobar.com\r\n\r\n",
false, true)
// explicit connection close
testRequestHeaderConnectionUpgrade(t, "GET /foobar HTTP/1.1\r\nConnection: close\r\nHost: foobar.com\r\n\r\n",
false, false)
}
func testRequestHeaderConnectionUpgrade(t *testing.T, s string, isUpgrade, isKeepAlive bool) {
var h RequestHeader
r := bytes.NewBufferString(s)
br := bufio.NewReader(r)
if err := h.Read(br); err != nil {
t.Fatalf("unexpected error: %s. Request header %q", err, s)
}
upgrade := h.ConnectionUpgrade()
if upgrade != isUpgrade {
t.Fatalf("unexpected 'connection: upgrade' when parsing request header: %v. Expecting %v. header %q",
upgrade, isUpgrade, s)
}
keepAlive := !h.ConnectionClose()
if keepAlive != isKeepAlive {
t.Fatalf("unexpected 'connection: keep-alive' when parsing request header: %v. Expecting %v. header %q",
keepAlive, isKeepAlive, s)
}
}
func TestRequestHeaderProxyWithCookie(t *testing.T) {
t.Parallel()
// Proxy request header (read it, then write it without touching any headers).
var h RequestHeader
r := bytes.NewBufferString("GET /foo HTTP/1.1\r\nFoo: bar\r\nHost: aaa.com\r\nCookie: foo=bar; bazzz=aaaaaaa; x=y\r\nCookie: aqqqqq=123\r\n\r\n")
br := bufio.NewReader(r)
if err := h.Read(br); err != nil {
t.Fatalf("unexpected error: %s", err)
}
w := &bytes.Buffer{}
bw := bufio.NewWriter(w)
if err := h.Write(bw); err != nil {
t.Fatalf("unexpected error: %s", err)
}
if err := bw.Flush(); err != nil {
t.Fatalf("unexpected error: %s", err)
}
var h1 RequestHeader
br.Reset(w)
if err := h1.Read(br); err != nil {
t.Fatalf("unexpected error: %s", err)
}
if string(h1.RequestURI()) != "/foo" {
t.Fatalf("unexpected requestURI: %q. Expecting %q", h1.RequestURI(), "/foo")
}
if string(h1.Host()) != "aaa.com" {
t.Fatalf("unexpected host: %q. Expecting %q", h1.Host(), "aaa.com")
}
if string(h1.Peek("Foo")) != "bar" {
t.Fatalf("unexpected Foo: %q. Expecting %q", h1.Peek("Foo"), "bar")
}
if string(h1.Cookie("foo")) != "bar" {
t.Fatalf("unexpected coookie foo=%q. Expecting %q", h1.Cookie("foo"), "bar")
}
if string(h1.Cookie("bazzz")) != "aaaaaaa" {
t.Fatalf("unexpected cookie bazzz=%q. Expecting %q", h1.Cookie("bazzz"), "aaaaaaa")
}
if string(h1.Cookie("x")) != "y" {
t.Fatalf("unexpected cookie x=%q. Expecting %q", h1.Cookie("x"), "y")
}
if string(h1.Cookie("aqqqqq")) != "123" {
t.Fatalf("unexpected cookie aqqqqq=%q. Expecting %q", h1.Cookie("aqqqqq"), "123")
}
}
func TestResponseHeaderFirstByteReadEOF(t *testing.T) {
t.Parallel()
var h ResponseHeader
r := &errorReader{fmt.Errorf("non-eof error")}
br := bufio.NewReader(r)
err := h.Read(br)
if err == nil {
t.Fatalf("expecting error")
}
if err != io.EOF {
t.Fatalf("unexpected error %s. Expecting %s", err, io.EOF)
}
}
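// errorReader is an io.Reader stub that always fails with the configured error.
// It is used to exercise header read error paths.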
type errorReader struct {
err error
}
func (r *errorReader) Read(p []byte) (int, error) {
return 0, r.err
}
func TestRequestHeaderEmptyMethod(t *testing.T) {
t.Parallel()
var h RequestHeader
if !h.IsGet() {
t.Fatalf("empty method must be equivalent to GET")
}
}
func TestResponseHeaderHTTPVer(t *testing.T) {
t.Parallel()
// non-http/1.1
testResponseHeaderHTTPVer(t, "HTTP/1.0 200 OK\r\nContent-Type: aaa\r\nContent-Length: 123\r\n\r\n", true)
testResponseHeaderHTTPVer(t, "HTTP/0.9 200 OK\r\nContent-Type: aaa\r\nContent-Length: 123\r\n\r\n", true)
testResponseHeaderHTTPVer(t, "foobar 200 OK\r\nContent-Type: aaa\r\nContent-Length: 123\r\n\r\n", true)
// http/1.1
testResponseHeaderHTTPVer(t, "HTTP/1.1 200 OK\r\nContent-Type: aaa\r\nContent-Length: 123\r\n\r\n", false)
}
func TestRequestHeaderHTTPVer(t *testing.T) {
t.Parallel()
// non-http/1.1
testRequestHeaderHTTPVer(t, "GET / HTTP/1.0\r\nHost: aa.com\r\n\r\n", true)
testRequestHeaderHTTPVer(t, "GET / HTTP/0.9\r\nHost: aa.com\r\n\r\n", true)
testRequestHeaderHTTPVer(t, "GET / foobar\r\nHost: aa.com\r\n\r\n", true)
// empty http version
testRequestHeaderHTTPVer(t, "GET /\r\nHost: aaa.com\r\n\r\n", true)
testRequestHeaderHTTPVer(t, "GET / \r\nHost: aaa.com\r\n\r\n", true)
// http/1.1
testRequestHeaderHTTPVer(t, "GET / HTTP/1.1\r\nHost: a.com\r\n\r\n", false)
}
func testResponseHeaderHTTPVer(t *testing.T, s string, connectionClose bool) {
var h ResponseHeader
r := bytes.NewBufferString(s)
br := bufio.NewReader(r)
if err := h.Read(br); err != nil {
t.Fatalf("unexpected error: %s. response=%q", err, s)
}
if h.ConnectionClose() != connectionClose {
t.Fatalf("unexpected connectionClose %v. Expecting %v. response=%q", h.ConnectionClose(), connectionClose, s)
}
}
func testRequestHeaderHTTPVer(t *testing.T, s string, connectionClose bool) {
var h RequestHeader
r := bytes.NewBufferString(s)
br := bufio.NewReader(r)
if err := h.Read(br); err != nil {
t.Fatalf("unexpected error: %s. request=%q", err, s)
}
if h.ConnectionClose() != connectionClose {
t.Fatalf("unexpected connectionClose %v. Expecting %v. request=%q", h.ConnectionClose(), connectionClose, s)
}
}
func TestResponseHeaderCopyTo(t *testing.T) {
t.Parallel()
var h ResponseHeader
h.Set(HeaderSetCookie, "foo=bar")
h.Set(HeaderContentType, "foobar")
h.Set("AAA-BBB", "aaaa")
var h1 ResponseHeader
h.CopyTo(&h1)
if !bytes.Equal(h1.Peek("Set-cookie"), h.Peek("Set-Cookie")) {
t.Fatalf("unexpected cookie %q. Expected %q", h1.Peek("set-cookie"), h.Peek("set-cookie"))
}
if !bytes.Equal(h1.Peek(HeaderContentType), h.Peek(HeaderContentType)) {
t.Fatalf("unexpected content-type %q. Expected %q", h1.Peek("content-type"), h.Peek("content-type"))
}
if !bytes.Equal(h1.Peek("aaa-bbb"), h.Peek("AAA-BBB")) {
t.Fatalf("unexpected aaa-bbb %q. Expected %q", h1.Peek("aaa-bbb"), h.Peek("aaa-bbb"))
}
// flush buf
h.bufKV = argsKV{}
h1.bufKV = argsKV{}
if !reflect.DeepEqual(h, h1) { //nolint:govet
t.Fatalf("ResponseHeaderCopyTo fail, src: \n%+v\ndst: \n%+v\n", h, h1) //nolint:govet
}
}
func TestRequestHeaderCopyTo(t *testing.T) {
t.Parallel()
var h RequestHeader
h.Set(HeaderCookie, "aa=bb; cc=dd")
h.Set(HeaderContentType, "foobar")
h.Set(HeaderHost, "aaaa")
h.Set("aaaxxx", "123")
var h1 RequestHeader
h.CopyTo(&h1)
if !bytes.Equal(h1.Peek("cookie"), h.Peek(HeaderCookie)) {
t.Fatalf("unexpected cookie after copying: %q. Expected %q", h1.Peek("cookie"), h.Peek("cookie"))
}
if !bytes.Equal(h1.Peek("content-type"), h.Peek(HeaderContentType)) {
t.Fatalf("unexpected content-type %q. Expected %q", h1.Peek("content-type"), h.Peek("content-type"))
}
if !bytes.Equal(h1.Peek("host"), h.Peek("host")) {
t.Fatalf("unexpected host %q. Expected %q", h1.Peek("host"), h.Peek("host"))
}
if !bytes.Equal(h1.Peek("aaaxxx"), h.Peek("aaaxxx")) {
t.Fatalf("unexpected aaaxxx %q. Expected %q", h1.Peek("aaaxxx"), h.Peek("aaaxxx"))
}
// flush buf
h.bufKV = argsKV{}
h1.bufKV = argsKV{}
if !reflect.DeepEqual(h, h1) { //nolint:govet
t.Fatalf("RequestHeaderCopyTo fail, src: \n%+v\ndst: \n%+v\n", h, h1) //nolint:govet
}
}
func TestResponseContentTypeNoDefaultNotEmpty(t *testing.T) {
t.Parallel()
var h ResponseHeader
h.SetNoDefaultContentType(true)
h.SetContentLength(5)
headers := h.String()
if strings.Contains(headers, "Content-Type: \r\n") {
t.Fatalf("ResponseContentTypeNoDefaultNotEmpty fail, response: \n%+v\noutcome: \n%q\n", h, headers) //nolint:govet
}
}
func TestRequestContentTypeDefaultNotEmpty(t *testing.T) {
t.Parallel()
var h RequestHeader
h.SetMethod(MethodPost)
h.SetContentLength(5)
w := &bytes.Buffer{}
bw := bufio.NewWriter(w)
if err := h.Write(bw); err != nil {
t.Fatalf("Unexpected error: %s", err)
}
if err := bw.Flush(); err != nil {
t.Fatalf("Unexpected error: %s", err)
}
var h1 RequestHeader
br := bufio.NewReader(w)
if err := h1.Read(br); err != nil {
t.Fatalf("Unexpected error: %s", err)
}
if string(h1.contentType) != "application/octet-stream" {
t.Fatalf("unexpected Content-Type %q. Expecting %q", h1.contentType, "application/octet-stream")
}
}
func TestResponseDateNoDefaultNotEmpty(t *testing.T) {
t.Parallel()
var h ResponseHeader
h.noDefaultDate = true
headers := h.String()
if strings.Contains(headers, "\r\nDate: ") {
t.Fatalf("ResponseDateNoDefaultNotEmpty fail, response: \n%+v\noutcome: \n%q\n", h, headers) //nolint:govet
}
}
func TestRequestHeaderConnectionClose(t *testing.T) {
t.Parallel()
var h RequestHeader
h.Set(HeaderConnection, "close")
h.Set(HeaderHost, "foobar")
if !h.ConnectionClose() {
t.Fatalf("connection: close not set")
}
var w bytes.Buffer
bw := bufio.NewWriter(&w)
if err := h.Write(bw); err != nil {
t.Fatalf("unexpected error: %s", err)
}
if err := bw.Flush(); err != nil {
t.Fatalf("unexpected error: %s", err)
}
var h1 RequestHeader
br := bufio.NewReader(&w)
if err := h1.Read(br); err != nil {
t.Fatalf("error when reading request header: %s", err)
}
if !h1.ConnectionClose() {
t.Fatalf("unexpected connection: close value: %v", h1.ConnectionClose())
}
if string(h1.Peek(HeaderConnection)) != "close" {
t.Fatalf("unexpected connection value: %q. Expecting %q", h.Peek("Connection"), "close")
}
}
func TestRequestHeaderSetCookie(t *testing.T) {
t.Parallel()
var h RequestHeader
h.Set("Cookie", "foo=bar; baz=aaa")
h.Set("cOOkie", "xx=yyy")
if string(h.Cookie("foo")) != "bar" {
t.Fatalf("Unexpected cookie %q. Expecting %q", h.Cookie("foo"), "bar")
}
if string(h.Cookie("baz")) != "aaa" {
t.Fatalf("Unexpected cookie %q. Expecting %q", h.Cookie("baz"), "aaa")
}
if string(h.Cookie("xx")) != "yyy" {
t.Fatalf("unexpected cookie %q. Expecting %q", h.Cookie("xx"), "yyy")
}
}
func TestResponseHeaderSetCookie(t *testing.T) {
t.Parallel()
var h ResponseHeader
h.Set("set-cookie", "foo=bar; path=/aa/bb; domain=aaa.com")
h.Set(HeaderSetCookie, "aaaaa=bxx")
var c Cookie
c.SetKey("foo")
if !h.Cookie(&c) {
t.Fatalf("cannot obtain %q cookie", c.Key())
}
if string(c.Value()) != "bar" {
t.Fatalf("unexpected cookie value %q. Expected %q", c.Value(), "bar")
}
if string(c.Path()) != "/aa/bb" {
t.Fatalf("unexpected cookie path %q. Expected %q", c.Path(), "/aa/bb")
}
if string(c.Domain()) != "aaa.com" {
t.Fatalf("unexpected cookie domain %q. Expected %q", c.Domain(), "aaa.com")
}
c.SetKey("aaaaa")
if !h.Cookie(&c) {
t.Fatalf("cannot obtain %q cookie", c.Key())
}
if string(c.Value()) != "bxx" {
t.Fatalf("unexpected cookie value %q. Expecting %q", c.Value(), "bxx")
}
}
func TestResponseHeaderVisitAll(t *testing.T) {
t.Parallel()
var h ResponseHeader
r := bytes.NewBufferString("HTTP/1.1 200 OK\r\nContent-Type: aa\r\nContent-Length: 123\r\nSet-Cookie: aa=bb; path=/foo/bar\r\nSet-Cookie: ccc\r\n\r\n")
br := bufio.NewReader(r)
if err := h.Read(br); err != nil {
t.Fatalf("Unexpected error: %s", err)
}
if h.Len() != 4 {
t.Fatalf("Unexpected number of headers: %d. Expected 4", h.Len())
}
contentLengthCount := 0
contentTypeCount := 0
cookieCount := 0
h.VisitAll(func(key, value []byte) {
k := string(key)
v := string(value)
switch k {
case HeaderContentLength:
if v != string(h.Peek(k)) {
t.Fatalf("unexpected content-length: %q. Expecting %q", v, h.Peek(k))
}
contentLengthCount++
case HeaderContentType:
if v != string(h.Peek(k)) {
t.Fatalf("Unexpected content-type: %q. Expected %q", v, h.Peek(k))
}
contentTypeCount++
case HeaderSetCookie:
if cookieCount == 0 && v != "aa=bb; path=/foo/bar" {
t.Fatalf("unexpected cookie header: %q. Expected %q", v, "aa=bb; path=/foo/bar")
}
if cookieCount == 1 && v != "ccc" {
t.Fatalf("unexpected cookie header: %q. Expected %q", v, "ccc")
}
cookieCount++
default:
t.Fatalf("unexpected header %q=%q", k, v)
}
})
if contentLengthCount != 1 {
t.Fatalf("unexpected number of content-length headers: %d. Expected 1", contentLengthCount)
}
if contentTypeCount != 1 {
t.Fatalf("unexpected number of content-type headers: %d. Expected 1", contentTypeCount)
}
if cookieCount != 2 {
t.Fatalf("unexpected number of cookie header: %d. Expected 2", cookieCount)
}
}
func TestRequestHeaderVisitAll(t *testing.T) {
t.Parallel()
var h RequestHeader
r := bytes.NewBufferString("GET / HTTP/1.1\r\nHost: aa.com\r\nXX: YYY\r\nXX: ZZ\r\nCookie: a=b; c=d\r\n\r\n")
br := bufio.NewReader(r)
if err := h.Read(br); err != nil {
t.Fatalf("Unexpected error: %s", err)
}
if h.Len() != 4 {
t.Fatalf("Unexpected number of header: %d. Expected 4", h.Len())
}
hostCount := 0
xxCount := 0
cookieCount := 0
h.VisitAll(func(key, value []byte) {
k := string(key)
v := string(value)
switch k {
case HeaderHost:
if v != string(h.Peek(k)) {
t.Fatalf("Unexpected host value %q. Expected %q", v, h.Peek(k))
}
hostCount++
case "Xx":
if xxCount == 0 && v != "YYY" {
t.Fatalf("Unexpected value %q. Expected %q", v, "YYY")
}
if xxCount == 1 && v != "ZZ" {
t.Fatalf("Unexpected value %q. Expected %q", v, "ZZ")
}
xxCount++
case HeaderCookie:
if v != "a=b; c=d" {
t.Fatalf("Unexpected cookie %q. Expected %q", v, "a=b; c=d")
}
cookieCount++
default:
t.Fatalf("Unexpected header %q=%q", k, v)
}
})
if hostCount != 1 {
t.Fatalf("Unexpected number of host headers detected %d. Expected 1", hostCount)
}
if xxCount != 2 {
t.Fatalf("Unexpected number of xx headers detected %d. Expected 2", xxCount)
}
if cookieCount != 1 {
t.Fatalf("Unexpected number of cookie headers %d. Expected 1", cookieCount)
}
}
func TestResponseHeaderVisitAllInOrder(t *testing.T) {
t.Parallel()
var h RequestHeader
r := bytes.NewBufferString("GET / HTTP/1.1\r\nContent-Type: aa\r\nCookie: a=b\r\nHost: example.com\r\nUser-Agent: xxx\r\n\r\n")
br := bufio.NewReader(r)
if err := h.Read(br); err != nil {
t.Fatalf("Unexpected error: %s", err)
}
if h.Len() != 4 {
t.Fatalf("Unexpected number of headers: %d. Expected 4", h.Len())
}
order := []string{
HeaderContentType,
HeaderCookie,
HeaderHost,
HeaderUserAgent,
}
values := []string{
"aa",
"a=b",
"example.com",
"xxx",
}
h.VisitAllInOrder(func(key, value []byte) {
if len(order) == 0 {
t.Fatalf("no more headers expected, got %q", key)
}
if order[0] != string(key) {
t.Fatalf("expected header %q got %q", order[0], key)
}
if values[0] != string(value) {
t.Fatalf("expected header value %q got %q", values[0], value)
}
order = order[1:]
values = values[1:]
})
}
func TestResponseHeaderCookie(t *testing.T) {
t.Parallel()
var h ResponseHeader
var c Cookie
c.SetKey("foobar")
c.SetValue("aaa")
h.SetCookie(&c)
c.SetKey("йцук")
c.SetDomain("foobar.com")
h.SetCookie(&c)
c.Reset()
c.SetKey("foobar")
if !h.Cookie(&c) {
t.Fatalf("Cannot find cookie %q", c.Key())
}
var expectedC1 Cookie
expectedC1.SetKey("foobar")
expectedC1.SetValue("aaa")
if !equalCookie(&expectedC1, &c) {
t.Fatalf("unexpected cookie\n%#v\nExpected\n%#v\n", &c, &expectedC1)
}
c.SetKey("йцук")
if !h.Cookie(&c) {
t.Fatalf("cannot find cookie %q", c.Key())
}
var expectedC2 Cookie
expectedC2.SetKey("йцук")
expectedC2.SetValue("aaa")
expectedC2.SetDomain("foobar.com")
if !equalCookie(&expectedC2, &c) {
t.Fatalf("unexpected cookie\n%v\nExpected\n%v\n", &c, &expectedC2)
}
h.VisitAllCookie(func(key, value []byte) {
var cc Cookie
if err := cc.ParseBytes(value); err != nil {
t.Fatal(err)
}
if !bytes.Equal(key, cc.Key()) {
t.Fatalf("Unexpected cookie key %q. Expected %q", key, cc.Key())
}
switch {
case bytes.Equal(key, []byte("foobar")):
if !equalCookie(&expectedC1, &cc) {
t.Fatalf("unexpected cookie\n%v\nExpected\n%v\n", &cc, &expectedC1)
}
case bytes.Equal(key, []byte("йцук")):
if !equalCookie(&expectedC2, &cc) {
t.Fatalf("unexpected cookie\n%v\nExpected\n%v\n", &cc, &expectedC2)
}
default:
t.Fatalf("unexpected cookie key %q", key)
}
})
w := &bytes.Buffer{}
bw := bufio.NewWriter(w)
if err := h.Write(bw); err != nil {
t.Fatalf("unexpected error: %s", err)
}
if err := bw.Flush(); err != nil {
t.Fatalf("unexpected error: %s", err)
}
h.DelAllCookies()
var h1 ResponseHeader
br := bufio.NewReader(w)
if err := h1.Read(br); err != nil {
t.Fatalf("unexpected error: %s", err)
}
c.SetKey("foobar")
if !h1.Cookie(&c) {
t.Fatalf("Cannot find cookie %q", c.Key())
}
if !equalCookie(&expectedC1, &c) {
t.Fatalf("unexpected cookie\n%v\nExpected\n%v\n", &c, &expectedC1)
}
h1.DelCookie("foobar")
if h.Cookie(&c) {
t.Fatalf("Unexpected cookie found: %v", &c)
}
if h1.Cookie(&c) {
t.Fatalf("Unexpected cookie found: %v", &c)
}
c.SetKey("йцук")
if !h1.Cookie(&c) {
t.Fatalf("cannot find cookie %q", c.Key())
}
if !equalCookie(&expectedC2, &c) {
t.Fatalf("unexpected cookie\n%v\nExpected\n%v\n", &c, &expectedC2)
}
h1.DelCookie("йцук")
if h.Cookie(&c) {
t.Fatalf("Unexpected cookie found: %v", &c)
}
if h1.Cookie(&c) {
t.Fatalf("Unexpected cookie found: %v", &c)
}
}
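// equalCookie reports whether two cookies have the same key, value, expiry, domain and path.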
func equalCookie(c1, c2 *Cookie) bool {
if !bytes.Equal(c1.Key(), c2.Key()) {
return false
}
if !bytes.Equal(c1.Value(), c2.Value()) {
return false
}
if !c1.Expire().Equal(c2.Expire()) {
return false
}
if !bytes.Equal(c1.Domain(), c2.Domain()) {
return false
}
if !bytes.Equal(c1.Path(), c2.Path()) {
return false
}
return true
}
func TestRequestHeaderCookie(t *testing.T) {
t.Parallel()
var h RequestHeader
h.SetRequestURI("/foobar")
h.Set(HeaderHost, "foobar.com")
h.SetCookie("foo", "bar")
h.SetCookie("привет", "мир")
if string(h.Cookie("foo")) != "bar" {
t.Fatalf("Unexpected cookie value %q. Exepcted %q", h.Cookie("foo"), "bar")
}
if string(h.Cookie("привет")) != "мир" {
t.Fatalf("Unexpected cookie value %q. Expected %q", h.Cookie("привет"), "мир")
}
w := &bytes.Buffer{}
bw := bufio.NewWriter(w)
if err := h.Write(bw); err != nil {
t.Fatalf("Unexpected error: %s", err)
}
if err := bw.Flush(); err != nil {
t.Fatalf("Unexpected error: %s", err)
}
var h1 RequestHeader
br := bufio.NewReader(w)
if err := h1.Read(br); err != nil {
t.Fatalf("Unexpected error: %s", err)
}
if !bytes.Equal(h1.Cookie("foo"), h.Cookie("foo")) {
t.Fatalf("Unexpected cookie value %q. Exepcted %q", h1.Cookie("foo"), h.Cookie("foo"))
}
h1.DelCookie("foo")
if len(h1.Cookie("foo")) > 0 {
t.Fatalf("Unexpected cookie found: %q", h1.Cookie("foo"))
}
if !bytes.Equal(h1.Cookie("привет"), h.Cookie("привет")) {
t.Fatalf("Unexpected cookie value %q. Expected %q", h1.Cookie("привет"), h.Cookie("привет"))
}
h1.DelCookie("привет")
if len(h1.Cookie("привет")) > 0 {
t.Fatalf("Unexpected cookie found: %q", h1.Cookie("привет"))
}
h.DelAllCookies()
if len(h.Cookie("foo")) > 0 {
t.Fatalf("Unexpected cookie found: %q", h.Cookie("foo"))
}
if len(h.Cookie("привет")) > 0 {
t.Fatalf("Unexpected cookie found: %q", h.Cookie("привет"))
}
}
func TestResponseHeaderCookieIssue4(t *testing.T) {
t.Parallel()
var h ResponseHeader
c := AcquireCookie()
c.SetKey("foo")
c.SetValue("bar")
h.SetCookie(c)
if string(h.Peek(HeaderSetCookie)) != "foo=bar" {
t.Fatalf("Unexpected Set-Cookie header %q. Expected %q", h.Peek(HeaderSetCookie), "foo=bar")
}
cookieSeen := false
h.VisitAll(func(key, value []byte) {
switch string(key) {
case HeaderSetCookie:
cookieSeen = true
}
})
if !cookieSeen {
t.Fatalf("Set-Cookie not present in VisitAll")
}
c = AcquireCookie()
c.SetKey("foo")
h.Cookie(c)
if string(c.Value()) != "bar" {
t.Fatalf("Unexpected cookie value %q. Exepcted %q", c.Value(), "bar")
}
if string(h.Peek(HeaderSetCookie)) != "foo=bar" {
t.Fatalf("Unexpected Set-Cookie header %q. Expected %q", h.Peek(HeaderSetCookie), "foo=bar")
}
cookieSeen = false
h.VisitAll(func(key, value []byte) {
switch string(key) {
case HeaderSetCookie:
cookieSeen = true
}
})
if !cookieSeen {
t.Fatalf("Set-Cookie not present in VisitAll")
}
}
func TestRequestHeaderCookieIssue313(t *testing.T) {
t.Parallel()
var h RequestHeader
h.SetRequestURI("/")
h.Set(HeaderHost, "foobar.com")
h.SetCookie("foo", "bar")
if string(h.Peek(HeaderCookie)) != "foo=bar" {
t.Fatalf("Unexpected Cookie header %q. Expected %q", h.Peek(HeaderCookie), "foo=bar")
}
cookieSeen := false
h.VisitAll(func(key, value []byte) {
switch string(key) {
case HeaderCookie:
cookieSeen = true
}
})
if !cookieSeen {
t.Fatalf("Cookie not present in VisitAll")
}
if string(h.Cookie("foo")) != "bar" {
t.Fatalf("Unexpected cookie value %q. Exepcted %q", h.Cookie("foo"), "bar")
}
if string(h.Peek(HeaderCookie)) != "foo=bar" {
t.Fatalf("Unexpected Cookie header %q. Expected %q", h.Peek(HeaderCookie), "foo=bar")
}
cookieSeen = false
h.VisitAll(func(key, value []byte) {
switch string(key) {
case HeaderCookie:
cookieSeen = true
}
})
if !cookieSeen {
t.Fatalf("Cookie not present in VisitAll")
}
}
func TestRequestHeaderMethod(t *testing.T) {
t.Parallel()
// common http methods
testRequestHeaderMethod(t, MethodGet)
testRequestHeaderMethod(t, MethodPost)
testRequestHeaderMethod(t, MethodHead)
testRequestHeaderMethod(t, MethodDelete)
// non-standard http methods
testRequestHeaderMethod(t, "foobar")
testRequestHeaderMethod(t, "ABC")
}
func testRequestHeaderMethod(t *testing.T, expectedMethod string) {
var h RequestHeader
h.SetMethod(expectedMethod)
m := h.Method()
if string(m) != expectedMethod {
t.Fatalf("unexpected method: %q. Expecting %q", m, expectedMethod)
}
s := h.String()
var h1 RequestHeader
br := bufio.NewReader(bytes.NewBufferString(s))
if err := h1.Read(br); err != nil {
t.Fatalf("unexpected error: %s", err)
}
m1 := h1.Method()
if string(m) != string(m1) {
t.Fatalf("unexpected method: %q. Expecting %q", m, m1)
}
}
func TestRequestHeaderSetGet(t *testing.T) {
t.Parallel()
h := &RequestHeader{}
h.SetRequestURI("/aa/bbb")
h.SetMethod(MethodPost)
h.Set("foo", "bar")
h.Set("host", "12345")
h.Set("content-type", "aaa/bbb")
h.Set("content-length", "1234")
h.Set("user-agent", "aaabbb")
h.Set("referer", "axcv")
h.Set("baz", "xxxxx")
h.Set("transfer-encoding", "chunked")
h.Set("connection", "close")
expectRequestHeaderGet(t, h, "Foo", "bar")
expectRequestHeaderGet(t, h, HeaderHost, "12345")
expectRequestHeaderGet(t, h, HeaderContentType, "aaa/bbb")
expectRequestHeaderGet(t, h, HeaderContentLength, "1234")
expectRequestHeaderGet(t, h, "USER-AGent", "aaabbb")
expectRequestHeaderGet(t, h, HeaderReferer, "axcv")
expectRequestHeaderGet(t, h, "baz", "xxxxx")
expectRequestHeaderGet(t, h, HeaderTransferEncoding, "")
expectRequestHeaderGet(t, h, "connecTION", "close")
if !h.ConnectionClose() {
t.Fatalf("unset connection: close")
}
if h.ContentLength() != 1234 {
t.Fatalf("Unexpected content-length %d. Expected %d", h.ContentLength(), 1234)
}
w := &bytes.Buffer{}
bw := bufio.NewWriter(w)
err := h.Write(bw)
if err != nil {
t.Fatalf("Unexpected error when writing request header: %s", err)
}
if err := bw.Flush(); err != nil {
t.Fatalf("Unexpected error when flushing request header: %s", err)
}
var h1 RequestHeader
br := bufio.NewReader(w)
if err = h1.Read(br); err != nil {
t.Fatalf("Unexpected error when reading request header: %s", err)
}
if h1.ContentLength() != h.ContentLength() {
t.Fatalf("Unexpected Content-Length %d. Expected %d", h1.ContentLength(), h.ContentLength())
}
expectRequestHeaderGet(t, &h1, "Foo", "bar")
expectRequestHeaderGet(t, &h1, "HOST", "12345")
expectRequestHeaderGet(t, &h1, HeaderContentType, "aaa/bbb")
expectRequestHeaderGet(t, &h1, HeaderContentLength, "1234")
expectRequestHeaderGet(t, &h1, "USER-AGent", "aaabbb")
expectRequestHeaderGet(t, &h1, HeaderReferer, "axcv")
expectRequestHeaderGet(t, &h1, "baz", "xxxxx")
expectRequestHeaderGet(t, &h1, HeaderTransferEncoding, "")
expectRequestHeaderGet(t, &h1, HeaderConnection, "close")
if !h1.ConnectionClose() {
t.Fatalf("unset connection: close")
}
}
func TestResponseHeaderSetGet(t *testing.T) {
t.Parallel()
h := &ResponseHeader{}
h.Set("foo", "bar")
h.Set("content-type", "aaa/bbb")
h.Set("connection", "close")
h.Set("content-length", "1234")
h.Set(HeaderServer, "aaaa")
h.Set("baz", "xxxxx")
h.Set(HeaderTransferEncoding, "chunked")
expectResponseHeaderGet(t, h, "Foo", "bar")
expectResponseHeaderGet(t, h, HeaderContentType, "aaa/bbb")
expectResponseHeaderGet(t, h, HeaderConnection, "close")
expectResponseHeaderGet(t, h, HeaderContentLength, "1234")
expectResponseHeaderGet(t, h, "seRVer", "aaaa")
expectResponseHeaderGet(t, h, "baz", "xxxxx")
expectResponseHeaderGet(t, h, HeaderTransferEncoding, "")
if h.ContentLength() != 1234 {
t.Fatalf("Unexpected content-length %d. Expected %d", h.ContentLength(), 1234)
}
if !h.ConnectionClose() {
t.Fatalf("Unexpected Connection: close value %v. Expected %v", h.ConnectionClose(), true)
}
w := &bytes.Buffer{}
bw := bufio.NewWriter(w)
err := h.Write(bw)
if err != nil {
t.Fatalf("Unexpected error when writing response header: %s", err)
}
if err := bw.Flush(); err != nil {
t.Fatalf("Unexpected error when flushing response header: %s", err)
}
var h1 ResponseHeader
br := bufio.NewReader(w)
if err = h1.Read(br); err != nil {
t.Fatalf("Unexpected error when reading response header: %s", err)
}
if h1.ContentLength() != h.ContentLength() {
t.Fatalf("Unexpected Content-Length %d. Expected %d", h1.ContentLength(), h.ContentLength())
}
if h1.ConnectionClose() != h.ConnectionClose() {
t.Fatalf("unexpected connection: close %v. Expected %v", h1.ConnectionClose(), h.ConnectionClose())
}
expectResponseHeaderGet(t, &h1, "Foo", "bar")
expectResponseHeaderGet(t, &h1, HeaderContentType, "aaa/bbb")
expectResponseHeaderGet(t, &h1, HeaderConnection, "close")
expectResponseHeaderGet(t, &h1, "seRVer", "aaaa")
expectResponseHeaderGet(t, &h1, "baz", "xxxxx")
}
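// expectRequestHeaderGet fails the test if h.Peek(key) does not return expectedValue.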
func expectRequestHeaderGet(t *testing.T, h *RequestHeader, key, expectedValue string) {
if string(h.Peek(key)) != expectedValue {
t.Fatalf("Unexpected value for key %q: %q. Expected %q", key, h.Peek(key), expectedValue)
}
}
func expectResponseHeaderGet(t *testing.T, h *ResponseHeader, key, expectedValue string) {
if string(h.Peek(key)) != expectedValue {
t.Fatalf("Unexpected value for key %q: %q. Expected %q", key, h.Peek(key), expectedValue)
}
}
func TestResponseHeaderConnectionClose(t *testing.T) {
t.Parallel()
testResponseHeaderConnectionClose(t, true)
testResponseHeaderConnectionClose(t, false)
}
func testResponseHeaderConnectionClose(t *testing.T, connectionClose bool) {
h := &ResponseHeader{}
if connectionClose {
h.SetConnectionClose()
}
h.SetContentLength(123)
w := &bytes.Buffer{}
bw := bufio.NewWriter(w)
err := h.Write(bw)
if err != nil {
t.Fatalf("Unexpected error when writing response header: %s", err)
}
if err := bw.Flush(); err != nil {
t.Fatalf("Unexpected error when flushing response header: %s", err)
}
var h1 ResponseHeader
br := bufio.NewReader(w)
err = h1.Read(br)
if err != nil {
t.Fatalf("Unexpected error when reading response header: %s", err)
}
if h1.ConnectionClose() != h.ConnectionClose() {
t.Fatalf("Unexpected value for ConnectionClose: %v. Expected %v", h1.ConnectionClose(), h.ConnectionClose())
}
}
func TestRequestHeaderTooBig(t *testing.T) {
t.Parallel()
s := "GET / HTTP/1.1\r\nHost: aaa.com\r\n" + getHeaders(10500) + "\r\n"
r := bytes.NewBufferString(s)
br := bufio.NewReaderSize(r, 4096)
h := &RequestHeader{}
err := h.Read(br)
if err == nil {
t.Fatalf("Expecting error when reading too big header")
}
}
func TestResponseHeaderTooBig(t *testing.T) {
t.Parallel()
s := "HTTP/1.1 200 OK\r\nContent-Type: sss\r\nContent-Length: 0\r\n" + getHeaders(100500) + "\r\n"
r := bytes.NewBufferString(s)
br := bufio.NewReaderSize(r, 4096)
h := &ResponseHeader{}
err := h.Read(br)
if err == nil {
t.Fatalf("Expecting error when reading too big header")
}
}
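// bufioPeekReader returns its string in progressively larger chunks on each Read call,
// forcing bufio.Reader to refill its buffer several times while a header is parsed.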
type bufioPeekReader struct {
s string
n int
}
func (r *bufioPeekReader) Read(b []byte) (int, error) {
if len(r.s) == 0 {
return 0, io.EOF
}
r.n++
n := r.n
if len(r.s) < n {
n = len(r.s)
}
src := []byte(r.s[:n])
r.s = r.s[n:]
n = copy(b, src)
return n, nil
}
func TestRequestHeaderBufioPeek(t *testing.T) {
t.Parallel()
r := &bufioPeekReader{
s: "GET / HTTP/1.1\r\nHost: foobar.com\r\n" + getHeaders(10) + "\r\naaaa",
}
br := bufio.NewReaderSize(r, 4096)
h := &RequestHeader{}
if err := h.Read(br); err != nil {
t.Fatalf("Unexpected error when reading request: %s", err)
}
verifyRequestHeader(t, h, -2, "/", "foobar.com", "", "")
verifyTrailer(t, br, "aaaa")
}
func TestResponseHeaderBufioPeek(t *testing.T) {
t.Parallel()
r := &bufioPeekReader{
s: "HTTP/1.1 200 OK\r\nContent-Length: 10\r\nContent-Type: aaa\r\n" + getHeaders(10) + "\r\n0123456789",
}
br := bufio.NewReaderSize(r, 4096)
h := &ResponseHeader{}
if err := h.Read(br); err != nil {
t.Fatalf("Unexpected error when reading response: %s", err)
}
verifyResponseHeader(t, h, 200, 10, "aaa")
verifyTrailer(t, br, "0123456789")
}
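// getHeaders builds n synthetic "Header_i: Value_i" lines for constructing
// multi-header or oversized requests and responses.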
func getHeaders(n int) string {
var h []string
for i := 0; i < n; i++ {
h = append(h, fmt.Sprintf("Header_%d: Value_%d\r\n", i, i))
}
return strings.Join(h, "")
}
func TestResponseHeaderReadSuccess(t *testing.T) {
t.Parallel()
h := &ResponseHeader{}
// straight order of content-length and content-type
testResponseHeaderReadSuccess(t, h, "HTTP/1.1 200 OK\r\nContent-Length: 123\r\nContent-Type: text/html\r\n\r\n",
200, 123, "text/html", "")
if h.ConnectionClose() {
t.Fatalf("unexpected connection: close")
}
// reverse order of content-length and content-type
testResponseHeaderReadSuccess(t, h, "HTTP/1.1 202 OK\r\nContent-Type: text/plain; encoding=utf-8\r\nContent-Length: 543\r\nConnection: close\r\n\r\n",
202, 543, "text/plain; encoding=utf-8", "")
if !h.ConnectionClose() {
t.Fatalf("expecting connection: close")
}
// transfer-encoding: chunked
testResponseHeaderReadSuccess(t, h, "HTTP/1.1 505 Internal error\r\nContent-Type: text/html\r\nTransfer-Encoding: chunked\r\n\r\n",
505, -1, "text/html", "")
if h.ConnectionClose() {
t.Fatalf("unexpected connection: close")
}
// reverse order of content-type and transfer-encoding
testResponseHeaderReadSuccess(t, h, "HTTP/1.1 343 foobar\r\nTransfer-Encoding: chunked\r\nContent-Type: text/json\r\n\r\n",
343, -1, "text/json", "")
// additional headers
testResponseHeaderReadSuccess(t, h, "HTTP/1.1 100 Continue\r\nFoobar: baz\r\nContent-Type: aaa/bbb\r\nUser-Agent: x\r\nContent-Length: 123\r\nZZZ: werer\r\n\r\n",
100, 123, "aaa/bbb", "")
// trailer (aka body)
testResponseHeaderReadSuccess(t, h, "HTTP/1.1 200 OK\r\nContent-Type: text/plain\r\nContent-Length: 32245\r\n\r\nqwert aaa",
200, 32245, "text/plain", "qwert aaa")
// ancient http protocol
testResponseHeaderReadSuccess(t, h, "HTTP/0.9 300 OK\r\nContent-Length: 123\r\nContent-Type: text/html\r\n\r\nqqqq",
300, 123, "text/html", "qqqq")
// lf instead of crlf
testResponseHeaderReadSuccess(t, h, "HTTP/1.1 200 OK\nContent-Length: 123\nContent-Type: text/html\n\n",
200, 123, "text/html", "")
// Zero-length headers with mixed crlf and lf
testResponseHeaderReadSuccess(t, h, "HTTP/1.1 400 OK\nContent-Length: 345\nZero-Value: \r\nContent-Type: aaa\n: zero-key\r\n\r\nooa",
400, 345, "aaa", "ooa")
// No space after colon
testResponseHeaderReadSuccess(t, h, "HTTP/1.1 200 OK\nContent-Length:34\nContent-Type: sss\n\naaaa",
200, 34, "sss", "aaaa")
// invalid case
testResponseHeaderReadSuccess(t, h, "HTTP/1.1 400 OK\nconTEnt-leNGTH: 123\nConTENT-TYPE: ass\n\n",
400, 123, "ass", "")
// duplicate content-length
testResponseHeaderReadSuccess(t, h, "HTTP/1.1 200 OK\r\nContent-Length: 456\r\nContent-Type: foo/bar\r\nContent-Length: 321\r\n\r\n",
200, 321, "foo/bar", "")
// duplicate content-type
testResponseHeaderReadSuccess(t, h, "HTTP/1.1 200 OK\r\nContent-Length: 234\r\nContent-Type: foo/bar\r\nContent-Type: baz/bar\r\n\r\n",
200, 234, "baz/bar", "")
// both transfer-encoding: chunked and content-length
testResponseHeaderReadSuccess(t, h, "HTTP/1.1 200 OK\r\nContent-Type: foo/bar\r\nContent-Length: 123\r\nTransfer-Encoding: chunked\r\n\r\n",
200, -1, "foo/bar", "")
testResponseHeaderReadSuccess(t, h, "HTTP/1.1 300 OK\r\nContent-Type: foo/barr\r\nTransfer-Encoding: chunked\r\nContent-Length: 354\r\n\r\n",
300, -1, "foo/barr", "")
// duplicate transfer-encoding: chunked
testResponseHeaderReadSuccess(t, h, "HTTP/1.1 200 OK\r\nContent-Type: text/html\r\nTransfer-Encoding: chunked\r\nTransfer-Encoding: chunked\r\n\r\n",
200, -1, "text/html", "")
// no reason string in the first line
testResponseHeaderReadSuccess(t, h, "HTTP/1.1 456\r\nContent-Type: xxx/yyy\r\nContent-Length: 134\r\n\r\naaaxxx",
456, 134, "xxx/yyy", "aaaxxx")
// blank lines before the first line
testResponseHeaderReadSuccess(t, h, "\r\nHTTP/1.1 200 OK\r\nContent-Type: aa\r\nContent-Length: 0\r\n\r\nsss",
200, 0, "aa", "sss")
if h.ConnectionClose() {
t.Fatalf("unexpected connection: close")
}
// no content-length (informational responses)
testResponseHeaderReadSuccess(t, h, "HTTP/1.1 101 OK\r\n\r\n",
101, -2, "text/plain; charset=utf-8", "")
if h.ConnectionClose() {
t.Fatalf("expecting connection: keep-alive for informational response")
}
// no content-length (no-content responses)
testResponseHeaderReadSuccess(t, h, "HTTP/1.1 204 OK\r\n\r\n",
204, -2, "text/plain; charset=utf-8", "")
if h.ConnectionClose() {
t.Fatalf("expecting connection: keep-alive for no-content response")
}
// no content-length (not-modified responses)
testResponseHeaderReadSuccess(t, h, "HTTP/1.1 304 OK\r\n\r\n",
304, -2, "text/plain; charset=utf-8", "")
if h.ConnectionClose() {
t.Fatalf("expecting connection: keep-alive for not-modified response")
}
// no content-length (identity transfer-encoding)
testResponseHeaderReadSuccess(t, h, "HTTP/1.1 200 OK\r\nContent-Type: foo/bar\r\n\r\nabcdefg",
200, -2, "foo/bar", "abcdefg")
if !h.ConnectionClose() {
t.Fatalf("expecting connection: close for identity response")
}
// non-numeric content-length
testResponseHeaderReadSuccess(t, h, "HTTP/1.1 200 OK\r\nContent-Length: faaa\r\nContent-Type: text/html\r\n\r\nfoobar",
200, -2, "text/html", "foobar")
testResponseHeaderReadSuccess(t, h, "HTTP/1.1 201 OK\r\nContent-Length: 123aa\r\nContent-Type: text/ht\r\n\r\naaa",
201, -2, "text/ht", "aaa")
testResponseHeaderReadSuccess(t, h, "HTTP/1.1 200 OK\r\nContent-Length: aa124\r\nContent-Type: html\r\n\r\nxx",
200, -2, "html", "xx")
// no content-type
testResponseHeaderReadSuccess(t, h, "HTTP/1.1 400 OK\r\nContent-Length: 123\r\n\r\nfoiaaa",
400, 123, string(defaultContentType), "foiaaa")
// no content-type and no default
h.SetNoDefaultContentType(true)
testResponseHeaderReadSuccess(t, h, "HTTP/1.1 400 OK\r\nContent-Length: 123\r\n\r\nfoiaaa",
400, 123, "", "foiaaa")
h.SetNoDefaultContentType(false)
// no headers
testResponseHeaderReadSuccess(t, h, "HTTP/1.1 200 OK\r\n\r\naaaabbb",
200, -2, string(defaultContentType), "aaaabbb")
if !h.IsHTTP11() {
t.Fatalf("expecting http/1.1 protocol")
}
// ancient http protocol
testResponseHeaderReadSuccess(t, h, "HTTP/1.0 203 OK\r\nContent-Length: 123\r\nContent-Type: foobar\r\n\r\naaa",
203, 123, "foobar", "aaa")
if h.IsHTTP11() {
t.Fatalf("ancient protocol must be non-http/1.1")
}
if !h.ConnectionClose() {
t.Fatalf("expecting connection: close for ancient protocol")
}
// ancient http protocol with 'Connection: keep-alive' header.
testResponseHeaderReadSuccess(t, h, "HTTP/1.0 403 aa\r\nContent-Length: 0\r\nContent-Type: 2\r\nConnection: Keep-Alive\r\n\r\nww",
403, 0, "2", "ww")
if h.IsHTTP11() {
t.Fatalf("ancient protocol must be non-http/1.1")
}
if h.ConnectionClose() {
t.Fatalf("expecting connection: keep-alive for ancient protocol")
}
}
func TestRequestHeaderReadSuccess(t *testing.T) {
t.Parallel()
h := &RequestHeader{}
// simple headers
testRequestHeaderReadSuccess(t, h, "GET /foo/bar HTTP/1.1\r\nHost: google.com\r\n\r\n",
-2, "/foo/bar", "google.com", "", "", "")
if h.ConnectionClose() {
t.Fatalf("unexpected connection: close header")
}
// simple headers with body
testRequestHeaderReadSuccess(t, h, "GET /a/bar HTTP/1.1\r\nHost: gole.com\r\nconneCTION: close\r\n\r\nfoobar",
-2, "/a/bar", "gole.com", "", "", "foobar")
if !h.ConnectionClose() {
t.Fatalf("connection: close unset")
}
// ancient http protocol
testRequestHeaderReadSuccess(t, h, "GET /bar HTTP/1.0\r\nHost: gole\r\n\r\npppp",
-2, "/bar", "gole", "", "", "pppp")
if h.IsHTTP11() {
t.Fatalf("ancient http protocol cannot be http/1.1")
}
if !h.ConnectionClose() {
t.Fatalf("expecting connectionClose for ancient http protocol")
}
// ancient http protocol with 'Connection: keep-alive' header
testRequestHeaderReadSuccess(t, h, "GET /aa HTTP/1.0\r\nHost: bb\r\nConnection: keep-alive\r\n\r\nxxx",
-2, "/aa", "bb", "", "", "xxx")
if h.IsHTTP11() {
t.Fatalf("ancient http protocol cannot be http/1.1")
}
if h.ConnectionClose() {
t.Fatalf("unexpected 'connection: close' for ancient http protocol")
}
// complex headers with body
testRequestHeaderReadSuccess(t, h, "GET /aabar HTTP/1.1\r\nAAA: bbb\r\nHost: ole.com\r\nAA: bb\r\n\r\nzzz",
-2, "/aabar", "ole.com", "", "", "zzz")
if !h.IsHTTP11() {
t.Fatalf("expecting http/1.1 protocol")
}
if h.ConnectionClose() {
t.Fatalf("unexpected connection: close")
}
// lf instead of crlf
testRequestHeaderReadSuccess(t, h, "GET /foo/bar HTTP/1.1\nHost: google.com\n\n",
-2, "/foo/bar", "google.com", "", "", "")
// post method
testRequestHeaderReadSuccess(t, h, "POST /aaa?bbb HTTP/1.1\r\nHost: foobar.com\r\nContent-Length: 1235\r\nContent-Type: aaa\r\n\r\nabcdef",
1235, "/aaa?bbb", "foobar.com", "", "aaa", "abcdef")
// zero-length headers with mixed crlf and lf
testRequestHeaderReadSuccess(t, h, "GET /a HTTP/1.1\nHost: aaa\r\nZero: \n: Zero-Value\n\r\nxccv",
-2, "/a", "aaa", "", "", "xccv")
// no space after colon
testRequestHeaderReadSuccess(t, h, "GET /a HTTP/1.1\nHost:aaaxd\n\nsdfds",
-2, "/a", "aaaxd", "", "", "sdfds")
// get with zero content-length
testRequestHeaderReadSuccess(t, h, "GET /xxx HTTP/1.1\nHost: aaa.com\nContent-Length: 0\n\n",
0, "/xxx", "aaa.com", "", "", "")
// get with non-zero content-length
testRequestHeaderReadSuccess(t, h, "GET /xxx HTTP/1.1\nHost: aaa.com\nContent-Length: 123\n\n",
123, "/xxx", "aaa.com", "", "", "")
// invalid case
testRequestHeaderReadSuccess(t, h, "GET /aaa HTTP/1.1\nhoST: bbb.com\n\naas",
-2, "/aaa", "bbb.com", "", "", "aas")
// referer
testRequestHeaderReadSuccess(t, h, "GET /asdf HTTP/1.1\nHost: aaa.com\nReferer: bb.com\n\naaa",
-2, "/asdf", "aaa.com", "bb.com", "", "aaa")
// duplicate host
testRequestHeaderReadSuccess(t, h, "GET /aa HTTP/1.1\r\nHost: aaaaaa.com\r\nHost: bb.com\r\n\r\n",
-2, "/aa", "bb.com", "", "", "")
// post with duplicate content-type
testRequestHeaderReadSuccess(t, h, "POST /a HTTP/1.1\r\nHost: aa\r\nContent-Type: ab\r\nContent-Length: 123\r\nContent-Type: xx\r\n\r\n",
123, "/a", "aa", "", "xx", "")
// post with duplicate content-length
testRequestHeaderReadSuccess(t, h, "POST /xx HTTP/1.1\r\nHost: aa\r\nContent-Type: s\r\nContent-Length: 13\r\nContent-Length: 1\r\n\r\n",
1, "/xx", "aa", "", "s", "")
// non-post with content-type
testRequestHeaderReadSuccess(t, h, "GET /aaa HTTP/1.1\r\nHost: bbb.com\r\nContent-Type: aaab\r\n\r\n",
-2, "/aaa", "bbb.com", "", "aaab", "")
// non-post with content-length
testRequestHeaderReadSuccess(t, h, "HEAD / HTTP/1.1\r\nHost: aaa.com\r\nContent-Length: 123\r\n\r\n",
123, "/", "aaa.com", "", "", "")
// non-post with content-type and content-length
testRequestHeaderReadSuccess(t, h, "GET /aa HTTP/1.1\r\nHost: aa.com\r\nContent-Type: abd/test\r\nContent-Length: 123\r\n\r\n",
123, "/aa", "aa.com", "", "abd/test", "")
// request uri with hostname
testRequestHeaderReadSuccess(t, h, "GET http://gooGle.com/foO/%20bar?xxx#aaa HTTP/1.1\r\nHost: aa.cOM\r\n\r\ntrail",
-2, "http://gooGle.com/foO/%20bar?xxx#aaa", "aa.cOM", "", "", "trail")
// no protocol in the first line
testRequestHeaderReadSuccess(t, h, "GET /foo/bar\r\nHost: google.com\r\n\r\nisdD",
-2, "/foo/bar", "google.com", "", "", "isdD")
// blank lines before the first line
testRequestHeaderReadSuccess(t, h, "\r\n\n\r\nGET /aaa HTTP/1.1\r\nHost: aaa.com\r\n\r\nsss",
-2, "/aaa", "aaa.com", "", "", "sss")
// request uri with spaces
testRequestHeaderReadSuccess(t, h, "GET /foo/ bar baz HTTP/1.1\r\nHost: aa.com\r\n\r\nxxx",
-2, "/foo/ bar baz", "aa.com", "", "", "xxx")
// no host
testRequestHeaderReadSuccess(t, h, "GET /foo/bar HTTP/1.1\r\nFOObar: assdfd\r\n\r\naaa",
-2, "/foo/bar", "", "", "", "aaa")
// no host, no headers
testRequestHeaderReadSuccess(t, h, "GET /foo/bar HTTP/1.1\r\n\r\nfoobar",
-2, "/foo/bar", "", "", "", "foobar")
// post without content-length and content-type
testRequestHeaderReadSuccess(t, h, "POST /aaa HTTP/1.1\r\nHost: aaa.com\r\n\r\nzxc",
-2, "/aaa", "aaa.com", "", "", "zxc")
// post without content-type
testRequestHeaderReadSuccess(t, h, "POST /abc HTTP/1.1\r\nHost: aa.com\r\nContent-Length: 123\r\n\r\npoiuy",
123, "/abc", "aa.com", "", "", "poiuy")
// post without content-length
testRequestHeaderReadSuccess(t, h, "POST /abc HTTP/1.1\r\nHost: aa.com\r\nContent-Type: adv\r\n\r\n123456",
-2, "/abc", "aa.com", "", "adv", "123456")
// invalid method
testRequestHeaderReadSuccess(t, h, "POST /foo/bar HTTP/1.1\r\nHost: google.com\r\n\r\nmnbv",
-2, "/foo/bar", "google.com", "", "", "mnbv")
// put request
testRequestHeaderReadSuccess(t, h, "PUT /faa HTTP/1.1\r\nHost: aaa.com\r\nContent-Length: 123\r\nContent-Type: aaa\r\n\r\nxwwere",
123, "/faa", "aaa.com", "", "aaa", "xwwere")
}
func TestResponseHeaderReadError(t *testing.T) {
t.Parallel()
h := &ResponseHeader{}
// incorrect first line
testResponseHeaderReadError(t, h, "")
testResponseHeaderReadError(t, h, "fo")
testResponseHeaderReadError(t, h, "foobarbaz")
testResponseHeaderReadError(t, h, "HTTP/1.1")
testResponseHeaderReadError(t, h, "HTTP/1.1 ")
testResponseHeaderReadError(t, h, "HTTP/1.1 s")
// non-numeric status code
testResponseHeaderReadError(t, h, "HTTP/1.1 foobar OK\r\nContent-Length: 123\r\nContent-Type: text/html\r\n\r\n")
testResponseHeaderReadError(t, h, "HTTP/1.1 123foobar OK\r\nContent-Length: 123\r\nContent-Type: text/html\r\n\r\n")
testResponseHeaderReadError(t, h, "HTTP/1.1 foobar344 OK\r\nContent-Length: 123\r\nContent-Type: text/html\r\n\r\n")
// no headers
testResponseHeaderReadError(t, h, "HTTP/1.1 200 OK\r\n")
// no trailing crlf
testResponseHeaderReadError(t, h, "HTTP/1.1 200 OK\r\nContent-Length: 123\r\nContent-Type: text/html\r\n")
}
func TestResponseHeaderReadErrorSecureLog(t *testing.T) {
t.Parallel()
h := &ResponseHeader{
secureErrorLogMessage: true,
}
// incorrect first line
testResponseHeaderReadSecuredError(t, h, "fo")
testResponseHeaderReadSecuredError(t, h, "foobarbaz")
testResponseHeaderReadSecuredError(t, h, "HTTP/1.1")
testResponseHeaderReadSecuredError(t, h, "HTTP/1.1 ")
testResponseHeaderReadSecuredError(t, h, "HTTP/1.1 s")
// non-numeric status code
testResponseHeaderReadSecuredError(t, h, "HTTP/1.1 foobar OK\r\nContent-Length: 123\r\nContent-Type: text/html\r\n\r\n")
testResponseHeaderReadSecuredError(t, h, "HTTP/1.1 123foobar OK\r\nContent-Length: 123\r\nContent-Type: text/html\r\n\r\n")
testResponseHeaderReadSecuredError(t, h, "HTTP/1.1 foobar344 OK\r\nContent-Length: 123\r\nContent-Type: text/html\r\n\r\n")
// no headers
testResponseHeaderReadSecuredError(t, h, "HTTP/1.1 200 OK\r\n")
// no trailing crlf
testResponseHeaderReadSecuredError(t, h, "HTTP/1.1 200 OK\r\nContent-Length: 123\r\nContent-Type: text/html\r\n")
}
func TestRequestHeaderReadError(t *testing.T) {
t.Parallel()
h := &RequestHeader{}
// incorrect first line
testRequestHeaderReadError(t, h, "")
testRequestHeaderReadError(t, h, "fo")
testRequestHeaderReadError(t, h, "GET ")
testRequestHeaderReadError(t, h, "GET / HTTP/1.1\r")
// missing RequestURI
testRequestHeaderReadError(t, h, "GET HTTP/1.1\r\nHost: google.com\r\n\r\n")
// post with invalid content-length
testRequestHeaderReadError(t, h, "POST /a HTTP/1.1\r\nHost: bb\r\nContent-Type: aa\r\nContent-Length: dff\r\n\r\nqwerty")
}
func TestRequestHeaderReadSecuredError(t *testing.T) {
t.Parallel()
h := &RequestHeader{
secureErrorLogMessage: true,
}
// incorrect first line
testRequestHeaderReadSecuredError(t, h, "fo")
testRequestHeaderReadSecuredError(t, h, "GET ")
testRequestHeaderReadSecuredError(t, h, "GET / HTTP/1.1\r")
// missing RequestURI
testRequestHeaderReadSecuredError(t, h, "GET HTTP/1.1\r\nHost: google.com\r\n\r\n")
// post with invalid content-length
testRequestHeaderReadSecuredError(t, h, "POST /a HTTP/1.1\r\nHost: bb\r\nContent-Type: aa\r\nContent-Length: dff\r\n\r\nqwerty")
}
func testResponseHeaderReadError(t *testing.T, h *ResponseHeader, headers string) {
r := bytes.NewBufferString(headers)
br := bufio.NewReader(r)
err := h.Read(br)
if err == nil {
t.Fatalf("Expecting error when reading response header %q", headers)
}
// make sure response header works after error
testResponseHeaderReadSuccess(t, h, "HTTP/1.1 200 OK\r\nContent-Type: foo/bar\r\nContent-Length: 12345\r\n\r\nsss",
200, 12345, "foo/bar", "sss")
}
func testResponseHeaderReadSecuredError(t *testing.T, h *ResponseHeader, headers string) {
r := bytes.NewBufferString(headers)
br := bufio.NewReader(r)
err := h.Read(br)
if err == nil {
t.Fatalf("Expecting error when reading response header %q", headers)
}
if strings.Contains(err.Error(), headers) {
t.Fatalf("Not expecting header content in err %q", err)
}
// make sure response header works after error
testResponseHeaderReadSuccess(t, h, "HTTP/1.1 200 OK\r\nContent-Type: foo/bar\r\nContent-Length: 12345\r\n\r\nsss",
200, 12345, "foo/bar", "sss")
}
func testRequestHeaderReadError(t *testing.T, h *RequestHeader, headers string) {
r := bytes.NewBufferString(headers)
br := bufio.NewReader(r)
err := h.Read(br)
if err == nil {
t.Fatalf("Expecting error when reading request header %q", headers)
}
// make sure request header works after error
testRequestHeaderReadSuccess(t, h, "GET /foo/bar HTTP/1.1\r\nHost: aaaa\r\n\r\nxxx",
-2, "/foo/bar", "aaaa", "", "", "xxx")
}
func testRequestHeaderReadSecuredError(t *testing.T, h *RequestHeader, headers string) {
r := bytes.NewBufferString(headers)
br := bufio.NewReader(r)
err := h.Read(br)
if err == nil {
t.Fatalf("Expecting error when reading request header %q", headers)
}
if strings.Contains(err.Error(), headers) {
t.Fatalf("Not expecting header content in err %q", err)
}
// make sure request header works after error
testRequestHeaderReadSuccess(t, h, "GET /foo/bar HTTP/1.1\r\nHost: aaaa\r\n\r\nxxx",
-2, "/foo/bar", "aaaa", "", "", "xxx")
}
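// testResponseHeaderReadSuccess parses headers from the given string and verifies
// the status code, content length, content type and any trailing payload left in the reader.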
func testResponseHeaderReadSuccess(t *testing.T, h *ResponseHeader, headers string, expectedStatusCode, expectedContentLength int,
expectedContentType, expectedTrailer string) {
r := bytes.NewBufferString(headers)
br := bufio.NewReader(r)
err := h.Read(br)
if err != nil {
t.Fatalf("Unexpected error when parsing response headers: %s. headers=%q", err, headers)
}
verifyResponseHeader(t, h, expectedStatusCode, expectedContentLength, expectedContentType)
verifyTrailer(t, br, expectedTrailer)
}
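// testRequestHeaderReadSuccess parses headers from the given string and verifies
// the content length, request URI, host, referer, content type and any trailing payload.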
func testRequestHeaderReadSuccess(t *testing.T, h *RequestHeader, headers string, expectedContentLength int,
expectedRequestURI, expectedHost, expectedReferer, expectedContentType, expectedTrailer string) {
r := bytes.NewBufferString(headers)
br := bufio.NewReader(r)
err := h.Read(br)
if err != nil {
t.Fatalf("Unexpected error when parsing request headers: %s. headers=%q", err, headers)
}
verifyRequestHeader(t, h, expectedContentLength, expectedRequestURI, expectedHost, expectedReferer, expectedContentType)
verifyTrailer(t, br, expectedTrailer)
}
func verifyResponseHeader(t *testing.T, h *ResponseHeader, expectedStatusCode, expectedContentLength int, expectedContentType string) {
if h.StatusCode() != expectedStatusCode {
t.Fatalf("Unexpected status code %d. Expected %d", h.StatusCode(), expectedStatusCode)
}
if h.ContentLength() != expectedContentLength {
t.Fatalf("Unexpected content length %d. Expected %d", h.ContentLength(), expectedContentLength)
}
if string(h.Peek(HeaderContentType)) != expectedContentType {
t.Fatalf("Unexpected content type %q. Expected %q", h.Peek(HeaderContentType), expectedContentType)
}
}
func verifyResponseHeaderConnection(t *testing.T, h *ResponseHeader, expectConnection string) {
if string(h.Peek(HeaderConnection)) != expectConnection {
t.Fatalf("Unexpected Connection %q. Expected %q", h.Peek(HeaderConnection), expectConnection)
}
}
func verifyRequestHeader(t *testing.T, h *RequestHeader, expectedContentLength int,
expectedRequestURI, expectedHost, expectedReferer, expectedContentType string) {
if h.ContentLength() != expectedContentLength {
t.Fatalf("Unexpected Content-Length %d. Expected %d", h.ContentLength(), expectedContentLength)
}
if string(h.RequestURI()) != expectedRequestURI {
t.Fatalf("Unexpected RequestURI %q. Expected %q", h.RequestURI(), expectedRequestURI)
}
if string(h.Peek(HeaderHost)) != expectedHost {
t.Fatalf("Unexpected host %q. Expected %q", h.Peek(HeaderHost), expectedHost)
}
if string(h.Peek(HeaderReferer)) != expectedReferer {
t.Fatalf("Unexpected referer %q. Expected %q", h.Peek(HeaderReferer), expectedReferer)
}
if string(h.Peek(HeaderContentType)) != expectedContentType {
t.Fatalf("Unexpected content-type %q. Expected %q", h.Peek(HeaderContentType), expectedContentType)
}
}
func verifyTrailer(t *testing.T, r *bufio.Reader, expectedTrailer string) {
trailer, err := ioutil.ReadAll(r)
if err != nil {
t.Fatalf("Cannot read trailer: %s", err)
}
if !bytes.Equal(trailer, []byte(expectedTrailer)) {
t.Fatalf("Unexpected trailer %q. Expected %q", trailer, expectedTrailer)
}
}
fasthttp-1.31.0/header_timing_test.go 0000664 0000000 0000000 00000010645 14130360711 0017620 0 ustar 00root root 0000000 0000000 package fasthttp
import (
"bufio"
"bytes"
"io"
"strconv"
"testing"
"github.com/valyala/bytebufferpool"
)
var strFoobar = []byte("foobar.com")
type benchReadBuf struct {
s []byte
n int
}
func (r *benchReadBuf) Read(p []byte) (int, error) {
if r.n == len(r.s) {
return 0, io.EOF
}
n := copy(p, r.s[r.n:])
r.n += n
return n, nil
}
func BenchmarkRequestHeaderRead(b *testing.B) {
b.RunParallel(func(pb *testing.PB) {
var h RequestHeader
buf := &benchReadBuf{
s: []byte("GET /foo/bar HTTP/1.1\r\nHost: foobar.com\r\nUser-Agent: aaa.bbb\r\nReferer: http://google.com/aaa/bbb\r\n\r\n"),
}
br := bufio.NewReader(buf)
for pb.Next() {
buf.n = 0
br.Reset(buf)
if err := h.Read(br); err != nil {
b.Fatalf("unexpected error when reading header: %s", err)
}
}
})
}
func BenchmarkResponseHeaderRead(b *testing.B) {
b.RunParallel(func(pb *testing.PB) {
var h ResponseHeader
buf := &benchReadBuf{
s: []byte("HTTP/1.1 200 OK\r\nContent-Type: text/html\r\nContent-Length: 1256\r\nServer: aaa 1/2.3\r\nTest: 1.2.3\r\n\r\n"),
}
br := bufio.NewReader(buf)
for pb.Next() {
buf.n = 0
br.Reset(buf)
if err := h.Read(br); err != nil {
b.Fatalf("unexpected error when reading header: %s", err)
}
}
})
}
func BenchmarkRequestHeaderWrite(b *testing.B) {
b.RunParallel(func(pb *testing.PB) {
var h RequestHeader
h.SetRequestURI("/foo/bar")
h.SetHost("foobar.com")
h.SetUserAgent("aaa.bbb")
h.SetReferer("http://google.com/aaa/bbb")
var w bytebufferpool.ByteBuffer
for pb.Next() {
if _, err := h.WriteTo(&w); err != nil {
b.Fatalf("unexpected error when writing header: %s", err)
}
w.Reset()
}
})
}
func BenchmarkResponseHeaderWrite(b *testing.B) {
b.RunParallel(func(pb *testing.PB) {
var h ResponseHeader
h.SetStatusCode(200)
h.SetContentType("text/html")
h.SetContentLength(1256)
h.SetServer("aaa 1/2.3")
h.Set("Test", "1.2.3")
var w bytebufferpool.ByteBuffer
for pb.Next() {
if _, err := h.WriteTo(&w); err != nil {
b.Fatalf("unexpected error when writing header: %s", err)
}
w.Reset()
}
})
}
func BenchmarkRequestHeaderPeekBytesCanonical(b *testing.B) {
b.RunParallel(func(pb *testing.PB) {
var h RequestHeader
h.SetBytesV("Host", strFoobar)
for pb.Next() {
v := h.PeekBytes(strHost)
if !bytes.Equal(v, strFoobar) {
b.Fatalf("unexpected result: %q. Expected %q", v, strFoobar)
}
}
})
}
func BenchmarkRequestHeaderPeekBytesNonCanonical(b *testing.B) {
b.RunParallel(func(pb *testing.PB) {
var h RequestHeader
h.SetBytesV("Host", strFoobar)
hostBytes := []byte("HOST")
for pb.Next() {
v := h.PeekBytes(hostBytes)
if !bytes.Equal(v, strFoobar) {
b.Fatalf("unexpected result: %q. Expected %q", v, strFoobar)
}
}
})
}
func BenchmarkNormalizeHeaderKeyCommonCase(b *testing.B) {
src := []byte("User-Agent-Host-Content-Type-Content-Length-Server")
benchmarkNormalizeHeaderKey(b, src)
}
func BenchmarkNormalizeHeaderKeyLowercase(b *testing.B) {
src := []byte("user-agent-host-content-type-content-length-server")
benchmarkNormalizeHeaderKey(b, src)
}
func BenchmarkNormalizeHeaderKeyUppercase(b *testing.B) {
src := []byte("USER-AGENT-HOST-CONTENT-TYPE-CONTENT-LENGTH-SERVER")
benchmarkNormalizeHeaderKey(b, src)
}
func benchmarkNormalizeHeaderKey(b *testing.B, src []byte) {
b.RunParallel(func(pb *testing.PB) {
buf := make([]byte, len(src))
for pb.Next() {
copy(buf, src)
normalizeHeaderKey(buf, false)
}
})
}
func BenchmarkRemoveNewLines(b *testing.B) {
type testcase struct {
value string
expectedValue string
}
var testcases = []testcase{
{value: "MaliciousValue", expectedValue: "MaliciousValue"},
{value: "MaliciousValue\r\n", expectedValue: "MaliciousValue "},
{value: "Malicious\nValue", expectedValue: "Malicious Value"},
{value: "Malicious\rValue", expectedValue: "Malicious Value"},
}
for i, tcase := range testcases {
caseName := strconv.FormatInt(int64(i), 10)
b.Run(caseName, func(subB *testing.B) {
subB.ReportAllocs()
var h RequestHeader
for i := 0; i < subB.N; i++ {
h.Set("Test", tcase.value)
}
subB.StopTimer()
actualValue := string(h.Peek("Test"))
if actualValue != tcase.expectedValue {
subB.Errorf("unexpected value, got: %+v", actualValue)
}
})
}
}
func BenchmarkRequestHeaderIsGet(b *testing.B) {
req := &RequestHeader{method: []byte(MethodGet)}
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
req.IsGet()
}
})
}
fasthttp-1.31.0/headers.go 0000664 0000000 0000000 00000013146 14130360711 0015374 0 ustar 00root root 0000000 0000000 package fasthttp
// Headers
const (
// Authentication
HeaderAuthorization = "Authorization"
HeaderProxyAuthenticate = "Proxy-Authenticate"
HeaderProxyAuthorization = "Proxy-Authorization"
HeaderWWWAuthenticate = "WWW-Authenticate"
// Caching
HeaderAge = "Age"
HeaderCacheControl = "Cache-Control"
HeaderClearSiteData = "Clear-Site-Data"
HeaderExpires = "Expires"
HeaderPragma = "Pragma"
HeaderWarning = "Warning"
// Client hints
HeaderAcceptCH = "Accept-CH"
HeaderAcceptCHLifetime = "Accept-CH-Lifetime"
HeaderContentDPR = "Content-DPR"
HeaderDPR = "DPR"
HeaderEarlyData = "Early-Data"
HeaderSaveData = "Save-Data"
HeaderViewportWidth = "Viewport-Width"
HeaderWidth = "Width"
// Conditionals
HeaderETag = "ETag"
HeaderIfMatch = "If-Match"
HeaderIfModifiedSince = "If-Modified-Since"
HeaderIfNoneMatch = "If-None-Match"
HeaderIfUnmodifiedSince = "If-Unmodified-Since"
HeaderLastModified = "Last-Modified"
HeaderVary = "Vary"
// Connection management
HeaderConnection = "Connection"
HeaderKeepAlive = "Keep-Alive"
// Content negotiation
HeaderAccept = "Accept"
HeaderAcceptCharset = "Accept-Charset"
HeaderAcceptEncoding = "Accept-Encoding"
HeaderAcceptLanguage = "Accept-Language"
// Controls
HeaderCookie = "Cookie"
HeaderExpect = "Expect"
HeaderMaxForwards = "Max-Forwards"
HeaderSetCookie = "Set-Cookie"
// CORS
HeaderAccessControlAllowCredentials = "Access-Control-Allow-Credentials"
HeaderAccessControlAllowHeaders = "Access-Control-Allow-Headers"
HeaderAccessControlAllowMethods = "Access-Control-Allow-Methods"
HeaderAccessControlAllowOrigin = "Access-Control-Allow-Origin"
HeaderAccessControlExposeHeaders = "Access-Control-Expose-Headers"
HeaderAccessControlMaxAge = "Access-Control-Max-Age"
HeaderAccessControlRequestHeaders = "Access-Control-Request-Headers"
HeaderAccessControlRequestMethod = "Access-Control-Request-Method"
HeaderOrigin = "Origin"
HeaderTimingAllowOrigin = "Timing-Allow-Origin"
HeaderXPermittedCrossDomainPolicies = "X-Permitted-Cross-Domain-Policies"
// Do Not Track
HeaderDNT = "DNT"
HeaderTk = "Tk"
// Downloads
HeaderContentDisposition = "Content-Disposition"
// Message body information
HeaderContentEncoding = "Content-Encoding"
HeaderContentLanguage = "Content-Language"
HeaderContentLength = "Content-Length"
HeaderContentLocation = "Content-Location"
HeaderContentType = "Content-Type"
// Proxies
HeaderForwarded = "Forwarded"
HeaderVia = "Via"
HeaderXForwardedFor = "X-Forwarded-For"
HeaderXForwardedHost = "X-Forwarded-Host"
HeaderXForwardedProto = "X-Forwarded-Proto"
// Redirects
HeaderLocation = "Location"
// Request context
HeaderFrom = "From"
HeaderHost = "Host"
HeaderReferer = "Referer"
HeaderReferrerPolicy = "Referrer-Policy"
HeaderUserAgent = "User-Agent"
// Response context
HeaderAllow = "Allow"
HeaderServer = "Server"
// Range requests
HeaderAcceptRanges = "Accept-Ranges"
HeaderContentRange = "Content-Range"
HeaderIfRange = "If-Range"
HeaderRange = "Range"
// Security
HeaderContentSecurityPolicy = "Content-Security-Policy"
HeaderContentSecurityPolicyReportOnly = "Content-Security-Policy-Report-Only"
HeaderCrossOriginResourcePolicy = "Cross-Origin-Resource-Policy"
HeaderExpectCT = "Expect-CT"
HeaderFeaturePolicy = "Feature-Policy"
HeaderPublicKeyPins = "Public-Key-Pins"
HeaderPublicKeyPinsReportOnly = "Public-Key-Pins-Report-Only"
HeaderStrictTransportSecurity = "Strict-Transport-Security"
HeaderUpgradeInsecureRequests = "Upgrade-Insecure-Requests"
HeaderXContentTypeOptions = "X-Content-Type-Options"
HeaderXDownloadOptions = "X-Download-Options"
HeaderXFrameOptions = "X-Frame-Options"
HeaderXPoweredBy = "X-Powered-By"
HeaderXXSSProtection = "X-XSS-Protection"
// Server-sent event
HeaderLastEventID = "Last-Event-ID"
HeaderNEL = "NEL"
HeaderPingFrom = "Ping-From"
HeaderPingTo = "Ping-To"
HeaderReportTo = "Report-To"
// Transfer coding
HeaderTE = "TE"
HeaderTrailer = "Trailer"
HeaderTransferEncoding = "Transfer-Encoding"
// WebSockets
HeaderSecWebSocketAccept = "Sec-WebSocket-Accept"
HeaderSecWebSocketExtensions = "Sec-WebSocket-Extensions"
HeaderSecWebSocketKey = "Sec-WebSocket-Key"
HeaderSecWebSocketProtocol = "Sec-WebSocket-Protocol"
HeaderSecWebSocketVersion = "Sec-WebSocket-Version"
// Other
HeaderAcceptPatch = "Accept-Patch"
HeaderAcceptPushPolicy = "Accept-Push-Policy"
HeaderAcceptSignature = "Accept-Signature"
HeaderAltSvc = "Alt-Svc"
HeaderDate = "Date"
HeaderIndex = "Index"
HeaderLargeAllocation = "Large-Allocation"
HeaderLink = "Link"
HeaderPushPolicy = "Push-Policy"
HeaderRetryAfter = "Retry-After"
HeaderServerTiming = "Server-Timing"
HeaderSignature = "Signature"
HeaderSignedHeaders = "Signed-Headers"
HeaderSourceMap = "SourceMap"
HeaderUpgrade = "Upgrade"
HeaderXDNSPrefetchControl = "X-DNS-Prefetch-Control"
HeaderXPingback = "X-Pingback"
HeaderXRequestedWith = "X-Requested-With"
HeaderXRobotsTag = "X-Robots-Tag"
HeaderXUACompatible = "X-UA-Compatible"
)
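// Illustrative sketch, not part of the original file: the typed constants
// above are meant to be used instead of raw header-name strings, which
// avoids typos and keeps header names canonical across the codebase.
func exampleHeaderConstants(h *ResponseHeader) []byte {
	h.Set(HeaderContentType, "application/json")
	h.Set(HeaderCacheControl, "no-store")
	return h.Peek(HeaderContentType)
}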
fasthttp-1.31.0/http.go 0000664 0000000 0000000 00000153416 14130360711 0014745 0 ustar 00root root 0000000 0000000 package fasthttp
import (
"bufio"
"bytes"
"compress/gzip"
"encoding/base64"
"errors"
"fmt"
"io"
"mime/multipart"
"net"
"os"
"sync"
"time"
"github.com/valyala/bytebufferpool"
)
// Request represents HTTP request.
//
// Copying Request instances is forbidden. Create new instances
// and use CopyTo instead.
//
// Request instance MUST NOT be used from concurrently running goroutines.
type Request struct {
noCopy noCopy //nolint:unused,structcheck
// Request header
//
// Copying Header by value is forbidden. Use pointer to Header instead.
Header RequestHeader
uri URI
postArgs Args
bodyStream io.Reader
w requestBodyWriter
body *bytebufferpool.ByteBuffer
bodyRaw []byte
multipartForm *multipart.Form
multipartFormBoundary string
secureErrorLogMessage bool
// Group bool members in order to reduce Request object size.
parsedURI bool
parsedPostArgs bool
keepBodyBuffer bool
// Used by Server to indicate the request was received on an HTTPS endpoint.
// Client/HostClient shouldn't use this field but should depend on the uri.scheme instead.
isTLS bool
// Request timeout. Usually set by DoDeadline or DoTimeout
// if <= 0, means not set
timeout time.Duration
}
// Response represents HTTP response.
//
// Copying Response instances is forbidden. Create new instances
// and use CopyTo instead.
//
// Response instance MUST NOT be used from concurrently running goroutines.
type Response struct {
noCopy noCopy //nolint:unused,structcheck
// Response header
//
// Copying Header by value is forbidden. Use pointer to Header instead.
Header ResponseHeader
// Flush headers as soon as possible without waiting for first body bytes.
// Relevant for bodyStream only.
ImmediateHeaderFlush bool
bodyStream io.Reader
w responseBodyWriter
body *bytebufferpool.ByteBuffer
bodyRaw []byte
// Response.Read() skips reading body if set to true.
// Use it for reading HEAD responses.
//
// Response.Write() skips writing body if set to true.
// Use it for writing HEAD responses.
SkipBody bool
keepBodyBuffer bool
secureErrorLogMessage bool
// Remote TCPAddr taken from the underlying net.Conn.
raddr net.Addr
// Local TCPAddr taken from the underlying net.Conn.
laddr net.Addr
}
// SetHost sets host for the request.
func (req *Request) SetHost(host string) {
req.URI().SetHost(host)
}
// SetHostBytes sets host for the request.
func (req *Request) SetHostBytes(host []byte) {
req.URI().SetHostBytes(host)
}
// Host returns the host for the given request.
func (req *Request) Host() []byte {
return req.URI().Host()
}
// SetRequestURI sets RequestURI.
func (req *Request) SetRequestURI(requestURI string) {
req.Header.SetRequestURI(requestURI)
req.parsedURI = false
}
// SetRequestURIBytes sets RequestURI.
func (req *Request) SetRequestURIBytes(requestURI []byte) {
req.Header.SetRequestURIBytes(requestURI)
req.parsedURI = false
}
// RequestURI returns request's URI.
func (req *Request) RequestURI() []byte {
if req.parsedURI {
requestURI := req.uri.RequestURI()
req.SetRequestURIBytes(requestURI)
}
return req.Header.RequestURI()
}
// StatusCode returns response status code.
func (resp *Response) StatusCode() int {
return resp.Header.StatusCode()
}
// SetStatusCode sets response status code.
func (resp *Response) SetStatusCode(statusCode int) {
resp.Header.SetStatusCode(statusCode)
}
// ConnectionClose returns true if 'Connection: close' header is set.
func (resp *Response) ConnectionClose() bool {
return resp.Header.ConnectionClose()
}
// SetConnectionClose sets 'Connection: close' header.
func (resp *Response) SetConnectionClose() {
resp.Header.SetConnectionClose()
}
// ConnectionClose returns true if 'Connection: close' header is set.
func (req *Request) ConnectionClose() bool {
return req.Header.ConnectionClose()
}
// SetConnectionClose sets 'Connection: close' header.
func (req *Request) SetConnectionClose() {
req.Header.SetConnectionClose()
}
// SendFile registers file on the given path to be used as response body
// when Write is called.
//
// Note that SendFile doesn't set Content-Type, so set it yourself
// with Header.SetContentType.
func (resp *Response) SendFile(path string) error {
f, err := os.Open(path)
if err != nil {
return err
}
fileInfo, err := f.Stat()
if err != nil {
f.Close()
return err
}
size64 := fileInfo.Size()
size := int(size64)
if int64(size) != size64 {
size = -1
}
resp.Header.SetLastModified(fileInfo.ModTime())
resp.SetBodyStream(f, size)
return nil
}
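// Illustrative sketch, not part of the original file: serving a file as the
// response body. The path below is hypothetical; since SendFile leaves
// Content-Type untouched, it is set explicitly here.
func exampleSendFile(resp *Response) error {
	if err := resp.SendFile("/var/www/index.html"); err != nil {
		return err
	}
	resp.Header.SetContentType("text/html; charset=utf-8")
	return nil
}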
// SetBodyStream sets request body stream and, optionally, body size.
//
// If bodySize is >= 0, then the bodyStream must provide exactly bodySize bytes
// before returning io.EOF.
//
// If bodySize < 0, then bodyStream is read until io.EOF.
//
// bodyStream.Close() is called after finishing reading all body data
// if it implements io.Closer.
//
// Note that GET and HEAD requests cannot have a body.
//
// See also SetBodyStreamWriter.
func (req *Request) SetBodyStream(bodyStream io.Reader, bodySize int) {
req.ResetBody()
req.bodyStream = bodyStream
req.Header.SetContentLength(bodySize)
}
// SetBodyStream sets response body stream and, optionally, body size.
//
// If bodySize is >= 0, then the bodyStream must provide exactly bodySize bytes
// before returning io.EOF.
//
// If bodySize < 0, then bodyStream is read until io.EOF.
//
// bodyStream.Close() is called after finishing reading all body data
// if it implements io.Closer.
//
// See also SetBodyStreamWriter.
func (resp *Response) SetBodyStream(bodyStream io.Reader, bodySize int) {
resp.ResetBody()
resp.bodyStream = bodyStream
resp.Header.SetContentLength(bodySize)
}
// IsBodyStream returns true if body is set via SetBodyStream*
func (req *Request) IsBodyStream() bool {
return req.bodyStream != nil
}
// IsBodyStream returns true if body is set via SetBodyStream*
func (resp *Response) IsBodyStream() bool {
return resp.bodyStream != nil
}
// SetBodyStreamWriter registers the given sw for populating request body.
//
// This function may be used in the following cases:
//
// * if request body is too big (more than 10MB).
// * if request body is streamed from slow external sources.
// * if request body must be streamed to the server in chunks
// (aka `http client push` or `chunked transfer-encoding`).
//
// Note that GET and HEAD requests cannot have a body.
//
// See also SetBodyStream.
func (req *Request) SetBodyStreamWriter(sw StreamWriter) {
sr := NewStreamReader(sw)
req.SetBodyStream(sr, -1)
}
// SetBodyStreamWriter registers the given sw for populating response body.
//
// This function may be used in the following cases:
//
// * if response body is too big (more than 10MB).
// * if response body is streamed from slow external sources.
// * if response body must be streamed to the client in chunks
// (aka `http server push` or `chunked transfer-encoding`).
//
// See also SetBodyStream.
func (resp *Response) SetBodyStreamWriter(sw StreamWriter) {
sr := NewStreamReader(sw)
resp.SetBodyStream(sr, -1)
}
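// Illustrative sketch, not part of the original file: streaming a response
// body in pieces via SetBodyStreamWriter. Data written to w is sent to the
// client with chunked transfer-encoding; Flush pushes what was written so far.
func exampleBodyStreamWriter(resp *Response) {
	resp.SetBodyStreamWriter(func(w *bufio.Writer) {
		fmt.Fprintln(w, "first chunk")
		w.Flush() //nolint:errcheck
		fmt.Fprintln(w, "second chunk")
	})
}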
// BodyWriter returns writer for populating response body.
//
// If used inside RequestHandler, the returned writer must not be used
// after returning from RequestHandler. Use RequestCtx.Write
// or SetBodyStreamWriter in this case.
func (resp *Response) BodyWriter() io.Writer {
resp.w.r = resp
return &resp.w
}
// BodyWriter returns writer for populating request body.
func (req *Request) BodyWriter() io.Writer {
req.w.r = req
return &req.w
}
type responseBodyWriter struct {
r *Response
}
func (w *responseBodyWriter) Write(p []byte) (int, error) {
w.r.AppendBody(p)
return len(p), nil
}
type requestBodyWriter struct {
r *Request
}
func (w *requestBodyWriter) Write(p []byte) (int, error) {
w.r.AppendBody(p)
return len(p), nil
}
func (resp *Response) parseNetConn(conn net.Conn) {
resp.raddr = conn.RemoteAddr()
resp.laddr = conn.LocalAddr()
}
// RemoteAddr returns the remote network address. The Addr returned is shared
// by all invocations of RemoteAddr, so do not modify it.
func (resp *Response) RemoteAddr() net.Addr {
return resp.raddr
}
// LocalAddr returns the local network address. The Addr returned is shared
// by all invocations of LocalAddr, so do not modify it.
func (resp *Response) LocalAddr() net.Addr {
return resp.laddr
}
// Body returns response body.
//
// The returned value is valid until the response is released,
// either through ReleaseResponse or by your request handler returning.
// Do not store references to returned value. Make copies instead.
func (resp *Response) Body() []byte {
if resp.bodyStream != nil {
bodyBuf := resp.bodyBuffer()
bodyBuf.Reset()
_, err := copyZeroAlloc(bodyBuf, resp.bodyStream)
resp.closeBodyStream() //nolint:errcheck
if err != nil {
bodyBuf.SetString(err.Error())
}
}
return resp.bodyBytes()
}
func (resp *Response) bodyBytes() []byte {
if resp.bodyRaw != nil {
return resp.bodyRaw
}
if resp.body == nil {
return nil
}
return resp.body.B
}
func (req *Request) bodyBytes() []byte {
if req.bodyRaw != nil {
return req.bodyRaw
}
if req.bodyStream != nil {
bodyBuf := req.bodyBuffer()
bodyBuf.Reset()
_, err := copyZeroAlloc(bodyBuf, req.bodyStream)
req.closeBodyStream() //nolint:errcheck
if err != nil {
bodyBuf.SetString(err.Error())
}
}
if req.body == nil {
return nil
}
return req.body.B
}
func (resp *Response) bodyBuffer() *bytebufferpool.ByteBuffer {
if resp.body == nil {
resp.body = responseBodyPool.Get()
}
resp.bodyRaw = nil
return resp.body
}
func (req *Request) bodyBuffer() *bytebufferpool.ByteBuffer {
if req.body == nil {
req.body = requestBodyPool.Get()
}
req.bodyRaw = nil
return req.body
}
var (
responseBodyPool bytebufferpool.Pool
requestBodyPool bytebufferpool.Pool
)
// BodyGunzip returns un-gzipped body data.
//
// This method may be used if the request header contains
// 'Content-Encoding: gzip' for reading un-gzipped body.
// Use Body for reading gzipped request body.
func (req *Request) BodyGunzip() ([]byte, error) {
return gunzipData(req.Body())
}
// BodyGunzip returns un-gzipped body data.
//
// This method may be used if the response header contains
// 'Content-Encoding: gzip' for reading un-gzipped body.
// Use Body for reading gzipped response body.
func (resp *Response) BodyGunzip() ([]byte, error) {
return gunzipData(resp.Body())
}
func gunzipData(p []byte) ([]byte, error) {
var bb bytebufferpool.ByteBuffer
_, err := WriteGunzip(&bb, p)
if err != nil {
return nil, err
}
return bb.B, nil
}
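// Illustrative sketch, not part of the original file: transparently reading a
// possibly gzipped response body by checking the Content-Encoding header.
func exampleReadMaybeGzippedBody(resp *Response) ([]byte, error) {
	if bytes.Equal(resp.Header.Peek(HeaderContentEncoding), strGzip) {
		return resp.BodyGunzip()
	}
	return resp.Body(), nil
}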
// BodyUnbrotli returns un-brotlied body data.
//
// This method may be used if the request header contains
// 'Content-Encoding: br' for reading un-brotlied body.
// Use Body for reading brotlied request body.
func (req *Request) BodyUnbrotli() ([]byte, error) {
return unBrotliData(req.Body())
}
// BodyUnbrotli returns un-brotlied body data.
//
// This method may be used if the response header contains
// 'Content-Encoding: br' for reading un-brotlied body.
// Use Body for reading brotlied response body.
func (resp *Response) BodyUnbrotli() ([]byte, error) {
return unBrotliData(resp.Body())
}
func unBrotliData(p []byte) ([]byte, error) {
var bb bytebufferpool.ByteBuffer
_, err := WriteUnbrotli(&bb, p)
if err != nil {
return nil, err
}
return bb.B, nil
}
// BodyInflate returns inflated body data.
//
// This method may be used if the request header contains
// 'Content-Encoding: deflate' for reading inflated request body.
// Use Body for reading deflated request body.
func (req *Request) BodyInflate() ([]byte, error) {
return inflateData(req.Body())
}
// BodyInflate returns inflated body data.
//
// This method may be used if the response header contains
// 'Content-Encoding: deflate' for reading inflated response body.
// Use Body for reading deflated response body.
func (resp *Response) BodyInflate() ([]byte, error) {
return inflateData(resp.Body())
}
func (ctx *RequestCtx) RequestBodyStream() io.Reader {
return ctx.Request.bodyStream
}
func inflateData(p []byte) ([]byte, error) {
var bb bytebufferpool.ByteBuffer
_, err := WriteInflate(&bb, p)
if err != nil {
return nil, err
}
return bb.B, nil
}
// BodyWriteTo writes request body to w.
func (req *Request) BodyWriteTo(w io.Writer) error {
if req.bodyStream != nil {
_, err := copyZeroAlloc(w, req.bodyStream)
req.closeBodyStream() //nolint:errcheck
return err
}
if req.onlyMultipartForm() {
return WriteMultipartForm(w, req.multipartForm, req.multipartFormBoundary)
}
_, err := w.Write(req.bodyBytes())
return err
}
// BodyWriteTo writes response body to w.
func (resp *Response) BodyWriteTo(w io.Writer) error {
if resp.bodyStream != nil {
_, err := copyZeroAlloc(w, resp.bodyStream)
resp.closeBodyStream() //nolint:errcheck
return err
}
_, err := w.Write(resp.bodyBytes())
return err
}
// AppendBody appends p to response body.
//
// It is safe to re-use p after the function returns.
func (resp *Response) AppendBody(p []byte) {
resp.closeBodyStream() //nolint:errcheck
resp.bodyBuffer().Write(p) //nolint:errcheck
}
// AppendBodyString appends s to response body.
func (resp *Response) AppendBodyString(s string) {
resp.closeBodyStream() //nolint:errcheck
resp.bodyBuffer().WriteString(s) //nolint:errcheck
}
// SetBody sets response body.
//
// It is safe to re-use the body argument after the function returns.
func (resp *Response) SetBody(body []byte) {
resp.closeBodyStream() //nolint:errcheck
bodyBuf := resp.bodyBuffer()
bodyBuf.Reset()
bodyBuf.Write(body) //nolint:errcheck
}
// SetBodyString sets response body.
func (resp *Response) SetBodyString(body string) {
resp.closeBodyStream() //nolint:errcheck
bodyBuf := resp.bodyBuffer()
bodyBuf.Reset()
bodyBuf.WriteString(body) //nolint:errcheck
}
// ResetBody resets response body.
func (resp *Response) ResetBody() {
resp.bodyRaw = nil
resp.closeBodyStream() //nolint:errcheck
if resp.body != nil {
if resp.keepBodyBuffer {
resp.body.Reset()
} else {
responseBodyPool.Put(resp.body)
resp.body = nil
}
}
}
// SetBodyRaw sets response body, but without copying it.
//
// From this point onward the body argument must not be changed.
func (resp *Response) SetBodyRaw(body []byte) {
resp.ResetBody()
resp.bodyRaw = body
}
// SetBodyRaw sets request body, but without copying it.
//
// From this point onward the body argument must not be changed.
func (req *Request) SetBodyRaw(body []byte) {
req.ResetBody()
req.bodyRaw = body
}
// ReleaseBody retires the response body if it is greater than "size" bytes.
//
// This permits GC to reclaim the large buffer. If used, must be before
// ReleaseResponse.
//
// Use this method only if you really understand how it works.
// The majority of workloads don't need this method.
func (resp *Response) ReleaseBody(size int) {
resp.bodyRaw = nil
if cap(resp.body.B) > size {
resp.closeBodyStream() //nolint:errcheck
resp.body = nil
}
}
// ReleaseBody retires the request body if it is greater than "size" bytes.
//
// This permits GC to reclaim the large buffer. If used, must be before
// ReleaseRequest.
//
// Use this method only if you really understand how it works.
// The majority of workloads don't need this method.
func (req *Request) ReleaseBody(size int) {
req.bodyRaw = nil
if cap(req.body.B) > size {
req.closeBodyStream() //nolint:errcheck
req.body = nil
}
}
// SwapBody swaps response body with the given body and returns
// the previous response body.
//
// It is forbidden to use the body passed to SwapBody after
// the function returns.
func (resp *Response) SwapBody(body []byte) []byte {
bb := resp.bodyBuffer()
if resp.bodyStream != nil {
bb.Reset()
_, err := copyZeroAlloc(bb, resp.bodyStream)
resp.closeBodyStream() //nolint:errcheck
if err != nil {
bb.Reset()
bb.SetString(err.Error())
}
}
resp.bodyRaw = nil
oldBody := bb.B
bb.B = body
return oldBody
}
// SwapBody swaps request body with the given body and returns
// the previous request body.
//
// It is forbidden to use the body passed to SwapBody after
// the function returns.
func (req *Request) SwapBody(body []byte) []byte {
bb := req.bodyBuffer()
if req.bodyStream != nil {
bb.Reset()
_, err := copyZeroAlloc(bb, req.bodyStream)
req.closeBodyStream() //nolint:errcheck
if err != nil {
bb.Reset()
bb.SetString(err.Error())
}
}
req.bodyRaw = nil
oldBody := bb.B
bb.B = body
return oldBody
}
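// Illustrative sketch, not part of the original file: SwapBody exchanges the
// internal body buffer with a caller-owned slice without copying. The slice
// passed in must not be used by the caller afterwards.
func exampleSwapRequestBody(req *Request, newBody []byte) []byte {
	oldBody := req.SwapBody(newBody)
	return oldBody
}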
// Body returns request body.
//
// The returned value is valid until the request is released,
// either through ReleaseRequest or by your request handler returning.
// Do not store references to returned value. Make copies instead.
func (req *Request) Body() []byte {
if req.bodyRaw != nil {
return req.bodyRaw
} else if req.onlyMultipartForm() {
body, err := marshalMultipartForm(req.multipartForm, req.multipartFormBoundary)
if err != nil {
return []byte(err.Error())
}
return body
}
return req.bodyBytes()
}
// AppendBody appends p to request body.
//
// It is safe to re-use p after the function returns.
func (req *Request) AppendBody(p []byte) {
req.RemoveMultipartFormFiles()
req.closeBodyStream() //nolint:errcheck
req.bodyBuffer().Write(p) //nolint:errcheck
}
// AppendBodyString appends s to request body.
func (req *Request) AppendBodyString(s string) {
req.RemoveMultipartFormFiles()
req.closeBodyStream() //nolint:errcheck
req.bodyBuffer().WriteString(s) //nolint:errcheck
}
// SetBody sets request body.
//
// It is safe to re-use the body argument after the function returns.
func (req *Request) SetBody(body []byte) {
req.RemoveMultipartFormFiles()
req.closeBodyStream() //nolint:errcheck
req.bodyBuffer().Set(body)
}
// SetBodyString sets request body.
func (req *Request) SetBodyString(body string) {
req.RemoveMultipartFormFiles()
req.closeBodyStream() //nolint:errcheck
req.bodyBuffer().SetString(body)
}
// ResetBody resets request body.
func (req *Request) ResetBody() {
req.bodyRaw = nil
req.RemoveMultipartFormFiles()
req.closeBodyStream() //nolint:errcheck
if req.body != nil {
if req.keepBodyBuffer {
req.body.Reset()
} else {
requestBodyPool.Put(req.body)
req.body = nil
}
}
}
// CopyTo copies req contents to dst except the body stream.
func (req *Request) CopyTo(dst *Request) {
req.copyToSkipBody(dst)
if req.bodyRaw != nil {
dst.bodyRaw = req.bodyRaw
if dst.body != nil {
dst.body.Reset()
}
} else if req.body != nil {
dst.bodyBuffer().Set(req.body.B)
} else if dst.body != nil {
dst.body.Reset()
}
}
func (req *Request) copyToSkipBody(dst *Request) {
dst.Reset()
req.Header.CopyTo(&dst.Header)
req.uri.CopyTo(&dst.uri)
dst.parsedURI = req.parsedURI
req.postArgs.CopyTo(&dst.postArgs)
dst.parsedPostArgs = req.parsedPostArgs
dst.isTLS = req.isTLS
// do not copy multipartForm - it will be automatically
// re-created on the first call to MultipartForm.
}
// CopyTo copies resp contents to dst except the body stream.
func (resp *Response) CopyTo(dst *Response) {
resp.copyToSkipBody(dst)
if resp.bodyRaw != nil {
dst.bodyRaw = resp.bodyRaw
if dst.body != nil {
dst.body.Reset()
}
} else if resp.body != nil {
dst.bodyBuffer().Set(resp.body.B)
} else if dst.body != nil {
dst.body.Reset()
}
}
func (resp *Response) copyToSkipBody(dst *Response) {
dst.Reset()
resp.Header.CopyTo(&dst.Header)
dst.SkipBody = resp.SkipBody
dst.raddr = resp.raddr
dst.laddr = resp.laddr
}
func swapRequestBody(a, b *Request) {
a.body, b.body = b.body, a.body
a.bodyRaw, b.bodyRaw = b.bodyRaw, a.bodyRaw
a.bodyStream, b.bodyStream = b.bodyStream, a.bodyStream
}
func swapResponseBody(a, b *Response) {
a.body, b.body = b.body, a.body
a.bodyRaw, b.bodyRaw = b.bodyRaw, a.bodyRaw
a.bodyStream, b.bodyStream = b.bodyStream, a.bodyStream
}
// URI returns request URI
func (req *Request) URI() *URI {
req.parseURI() //nolint:errcheck
return &req.uri
}
func (req *Request) parseURI() error {
if req.parsedURI {
return nil
}
req.parsedURI = true
return req.uri.parse(req.Header.Host(), req.Header.RequestURI(), req.isTLS)
}
// PostArgs returns POST arguments.
func (req *Request) PostArgs() *Args {
req.parsePostArgs()
return &req.postArgs
}
func (req *Request) parsePostArgs() {
if req.parsedPostArgs {
return
}
req.parsedPostArgs = true
if !bytes.HasPrefix(req.Header.ContentType(), strPostArgsContentType) {
return
}
req.postArgs.ParseBytes(req.bodyBytes())
}
// ErrNoMultipartForm means that the request's Content-Type
// isn't 'multipart/form-data'.
var ErrNoMultipartForm = errors.New("request has no multipart/form-data Content-Type")
// MultipartForm returns the request's multipart form.
//
// Returns ErrNoMultipartForm if request's Content-Type
// isn't 'multipart/form-data'.
//
// RemoveMultipartFormFiles must be called after the returned multipart form
// is processed.
func (req *Request) MultipartForm() (*multipart.Form, error) {
if req.multipartForm != nil {
return req.multipartForm, nil
}
req.multipartFormBoundary = string(req.Header.MultipartFormBoundary())
if len(req.multipartFormBoundary) == 0 {
return nil, ErrNoMultipartForm
}
var err error
ce := req.Header.peek(strContentEncoding)
if req.bodyStream != nil {
bodyStream := req.bodyStream
if bytes.Equal(ce, strGzip) {
// Do not care about memory usage here.
if bodyStream, err = gzip.NewReader(bodyStream); err != nil {
return nil, fmt.Errorf("cannot gunzip request body: %s", err)
}
} else if len(ce) > 0 {
return nil, fmt.Errorf("unsupported Content-Encoding: %q", ce)
}
mr := multipart.NewReader(bodyStream, req.multipartFormBoundary)
req.multipartForm, err = mr.ReadForm(8 * 1024)
if err != nil {
return nil, fmt.Errorf("cannot read multipart/form-data body: %s", err)
}
} else {
body := req.bodyBytes()
if bytes.Equal(ce, strGzip) {
// Do not care about memory usage here.
if body, err = AppendGunzipBytes(nil, body); err != nil {
return nil, fmt.Errorf("cannot gunzip request body: %s", err)
}
} else if len(ce) > 0 {
return nil, fmt.Errorf("unsupported Content-Encoding: %q", ce)
}
req.multipartForm, err = readMultipartForm(bytes.NewReader(body), req.multipartFormBoundary, len(body), len(body))
if err != nil {
return nil, err
}
}
return req.multipartForm, nil
}
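// Illustrative sketch, not part of the original file: iterating over the
// values and files of a parsed multipart/form-data request. Field names are
// whatever the client sent; none are assumed here.
func exampleMultipartForm(req *Request) error {
	form, err := req.MultipartForm()
	if err != nil {
		return err
	}
	defer req.RemoveMultipartFormFiles()
	for name, values := range form.Value {
		fmt.Printf("field %q = %v\n", name, values)
	}
	for name, files := range form.File {
		fmt.Printf("file field %q carries %d file(s)\n", name, len(files))
	}
	return nil
}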
func marshalMultipartForm(f *multipart.Form, boundary string) ([]byte, error) {
var buf bytebufferpool.ByteBuffer
if err := WriteMultipartForm(&buf, f, boundary); err != nil {
return nil, err
}
return buf.B, nil
}
// WriteMultipartForm writes the given multipart form f with the given
// boundary to w.
func WriteMultipartForm(w io.Writer, f *multipart.Form, boundary string) error {
// Do not care about memory allocations here, since multipart
// form processing is slow.
if len(boundary) == 0 {
panic("BUG: form boundary cannot be empty")
}
mw := multipart.NewWriter(w)
if err := mw.SetBoundary(boundary); err != nil {
return fmt.Errorf("cannot use form boundary %q: %s", boundary, err)
}
// marshal values
for k, vv := range f.Value {
for _, v := range vv {
if err := mw.WriteField(k, v); err != nil {
return fmt.Errorf("cannot write form field %q value %q: %s", k, v, err)
}
}
}
// marshal files
for k, fvv := range f.File {
for _, fv := range fvv {
vw, err := mw.CreatePart(fv.Header)
if err != nil {
return fmt.Errorf("cannot create form file %q (%q): %s", k, fv.Filename, err)
}
fh, err := fv.Open()
if err != nil {
return fmt.Errorf("cannot open form file %q (%q): %s", k, fv.Filename, err)
}
if _, err = copyZeroAlloc(vw, fh); err != nil {
return fmt.Errorf("error when copying form file %q (%q): %s", k, fv.Filename, err)
}
if err = fh.Close(); err != nil {
return fmt.Errorf("cannot close form file %q (%q): %s", k, fv.Filename, err)
}
}
}
if err := mw.Close(); err != nil {
return fmt.Errorf("error when closing multipart form writer: %s", err)
}
return nil
}
func readMultipartForm(r io.Reader, boundary string, size, maxInMemoryFileSize int) (*multipart.Form, error) {
// Do not care about memory allocations here, since they are tiny
// compared to multipart data (aka multi-MB files) usually sent
// in multipart/form-data requests.
if size <= 0 {
return nil, fmt.Errorf("form size must be greater than 0. Given %d", size)
}
lr := io.LimitReader(r, int64(size))
mr := multipart.NewReader(lr, boundary)
f, err := mr.ReadForm(int64(maxInMemoryFileSize))
if err != nil {
return nil, fmt.Errorf("cannot read multipart/form-data body: %s", err)
}
return f, nil
}
// Reset clears request contents.
func (req *Request) Reset() {
req.Header.Reset()
req.resetSkipHeader()
req.timeout = 0
}
func (req *Request) resetSkipHeader() {
req.ResetBody()
req.uri.Reset()
req.parsedURI = false
req.postArgs.Reset()
req.parsedPostArgs = false
req.isTLS = false
}
// RemoveMultipartFormFiles removes multipart/form-data temporary files
// associated with the request.
func (req *Request) RemoveMultipartFormFiles() {
if req.multipartForm != nil {
// Do not check for error, since these files may be deleted or moved
// to new places by user code.
req.multipartForm.RemoveAll() //nolint:errcheck
req.multipartForm = nil
}
req.multipartFormBoundary = ""
}
// Reset clears response contents.
func (resp *Response) Reset() {
resp.Header.Reset()
resp.resetSkipHeader()
resp.SkipBody = false
resp.raddr = nil
resp.laddr = nil
resp.ImmediateHeaderFlush = false
}
func (resp *Response) resetSkipHeader() {
resp.ResetBody()
}
// Read reads request (including body) from the given r.
//
// RemoveMultipartFormFiles or Reset must be called after
// reading multipart/form-data request in order to delete temporarily
// uploaded files.
//
// If MayContinue returns true, the caller must:
//
// - Either send StatusExpectationFailed response if request headers don't
// satisfy the caller.
// - Or send StatusContinue response before reading request body
// with ContinueReadBody.
// - Or close the connection.
//
// io.EOF is returned if r is closed before reading the first header byte.
func (req *Request) Read(r *bufio.Reader) error {
return req.ReadLimitBody(r, 0)
}
const defaultMaxInMemoryFileSize = 16 * 1024 * 1024
// ErrGetOnly is returned when server expects only GET requests,
// but some other type of request came (Server.GetOnly option is true).
var ErrGetOnly = errors.New("non-GET request received")
// ReadLimitBody reads request from the given r, limiting the body size.
//
// If maxBodySize > 0 and the body size exceeds maxBodySize,
// then ErrBodyTooLarge is returned.
//
// RemoveMultipartFormFiles or Reset must be called after
// reading multipart/form-data request in order to delete temporarily
// uploaded files.
//
// If MayContinue returns true, the caller must:
//
// - Either send StatusExpectationFailed response if request headers don't
// satisfy the caller.
// - Or send StatusContinue response before reading request body
// with ContinueReadBody.
// - Or close the connection.
//
// io.EOF is returned if r is closed before reading the first header byte.
func (req *Request) ReadLimitBody(r *bufio.Reader, maxBodySize int) error {
req.resetSkipHeader()
if err := req.Header.Read(r); err != nil {
return err
}
return req.readLimitBody(r, maxBodySize, false, true)
}
func (req *Request) readLimitBody(r *bufio.Reader, maxBodySize int, getOnly bool, preParseMultipartForm bool) error {
// Do not reset the request here - the caller must reset it before
// calling this method.
if getOnly && !req.Header.IsGet() {
return ErrGetOnly
}
if req.MayContinue() {
// 'Expect: 100-continue' header found. Let the caller decide
// whether to read the request body or
// to return StatusExpectationFailed.
return nil
}
return req.ContinueReadBody(r, maxBodySize, preParseMultipartForm)
}
func (req *Request) readBodyStream(r *bufio.Reader, maxBodySize int, getOnly bool, preParseMultipartForm bool) error {
// Do not reset the request here - the caller must reset it before
// calling this method.
if getOnly && !req.Header.IsGet() {
return ErrGetOnly
}
if req.MayContinue() {
// 'Expect: 100-continue' header found. Let the caller decide
// whether to read the request body or
// to return StatusExpectationFailed.
return nil
}
return req.ContinueReadBodyStream(r, maxBodySize, preParseMultipartForm)
}
// MayContinue returns true if the request contains
// 'Expect: 100-continue' header.
//
// The caller must do one of the following actions if MayContinue returns true:
//
// - Either send StatusExpectationFailed response if request headers don't
// satisfy the caller.
// - Or send StatusContinue response before reading request body
// with ContinueReadBody.
// - Or close the connection.
func (req *Request) MayContinue() bool {
return bytes.Equal(req.Header.peek(strExpect), str100Continue)
}
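// Illustrative sketch, not part of the original file: the low-level
// 'Expect: 100-continue' handshake described above, as code reading requests
// manually would drive it.
func exampleHandleExpectContinue(req *Request, br *bufio.Reader, bw *bufio.Writer) error {
	if !req.MayContinue() {
		return nil // the body (if any) was already read by Request.Read
	}
	// Accept the body: tell the client to proceed, then read the body.
	if _, err := bw.WriteString("HTTP/1.1 100 Continue\r\n\r\n"); err != nil {
		return err
	}
	if err := bw.Flush(); err != nil {
		return err
	}
	return req.ContinueReadBody(br, 0)
}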
// ContinueReadBody reads request body if request header contains
// 'Expect: 100-continue'.
//
// The caller must send StatusContinue response before calling this method.
//
// If maxBodySize > 0 and the body size exceeds maxBodySize,
// then ErrBodyTooLarge is returned.
func (req *Request) ContinueReadBody(r *bufio.Reader, maxBodySize int, preParseMultipartForm ...bool) error {
var err error
contentLength := req.Header.realContentLength()
if contentLength > 0 {
if maxBodySize > 0 && contentLength > maxBodySize {
return ErrBodyTooLarge
}
if len(preParseMultipartForm) == 0 || preParseMultipartForm[0] {
// Pre-read multipart form data of known length.
// This way we limit memory usage for large file uploads, since their contents
// are streamed into temporary files if the file size exceeds defaultMaxInMemoryFileSize.
req.multipartFormBoundary = string(req.Header.MultipartFormBoundary())
if len(req.multipartFormBoundary) > 0 && len(req.Header.peek(strContentEncoding)) == 0 {
req.multipartForm, err = readMultipartForm(r, req.multipartFormBoundary, contentLength, defaultMaxInMemoryFileSize)
if err != nil {
req.Reset()
}
return err
}
}
}
if contentLength == -2 {
// an identity body makes no sense for http requests, since
// the end of body is determined by connection close.
// So just ignore request body for requests without
// 'Content-Length' and 'Transfer-Encoding' headers.
// refer to https://tools.ietf.org/html/rfc7230#section-3.3.2
if !req.Header.ignoreBody() {
req.Header.SetContentLength(0)
}
return nil
}
bodyBuf := req.bodyBuffer()
bodyBuf.Reset()
bodyBuf.B, err = readBody(r, contentLength, maxBodySize, bodyBuf.B)
if err != nil {
req.Reset()
return err
}
req.Header.SetContentLength(len(bodyBuf.B))
return nil
}
// ContinueReadBodyStream reads request body if request header contains
// 'Expect: 100-continue'.
//
// The caller must send StatusContinue response before calling this method.
//
// If maxBodySize > 0 and the body size exceeds maxBodySize,
// then ErrBodyTooLarge is returned.
func (req *Request) ContinueReadBodyStream(r *bufio.Reader, maxBodySize int, preParseMultipartForm ...bool) error {
var err error
contentLength := req.Header.realContentLength()
if contentLength > 0 {
if len(preParseMultipartForm) == 0 || preParseMultipartForm[0] {
// Pre-read multipart form data of known length.
// This way we limit memory usage for large file uploads, since their contents
// are streamed into temporary files if the file size exceeds defaultMaxInMemoryFileSize.
req.multipartFormBoundary = b2s(req.Header.MultipartFormBoundary())
if len(req.multipartFormBoundary) > 0 && len(req.Header.peek(strContentEncoding)) == 0 {
req.multipartForm, err = readMultipartForm(r, req.multipartFormBoundary, contentLength, defaultMaxInMemoryFileSize)
if err != nil {
req.Reset()
}
return err
}
}
}
if contentLength == -2 {
// an identity body makes no sense for http requests, since
// the end of body is determined by connection close.
// So just ignore request body for requests without
// 'Content-Length' and 'Transfer-Encoding' headers.
req.Header.SetContentLength(0)
return nil
}
bodyBuf := req.bodyBuffer()
bodyBuf.Reset()
bodyBuf.B, err = readBodyWithStreaming(r, contentLength, maxBodySize, bodyBuf.B)
if err != nil {
if err == ErrBodyTooLarge {
req.Header.SetContentLength(contentLength)
req.body = bodyBuf
req.bodyStream = acquireRequestStream(bodyBuf, r, contentLength)
return nil
}
if err == errChunkedStream {
req.body = bodyBuf
req.bodyStream = acquireRequestStream(bodyBuf, r, -1)
return nil
}
req.Reset()
return err
}
req.body = bodyBuf
req.bodyStream = acquireRequestStream(bodyBuf, r, contentLength)
req.Header.SetContentLength(contentLength)
return nil
}
// Read reads response (including body) from the given r.
//
// io.EOF is returned if r is closed before reading the first header byte.
func (resp *Response) Read(r *bufio.Reader) error {
return resp.ReadLimitBody(r, 0)
}
// ReadLimitBody reads response from the given r, limiting the body size.
//
// If maxBodySize > 0 and the body size exceeds maxBodySize,
// then ErrBodyTooLarge is returned.
//
// io.EOF is returned if r is closed before reading the first header byte.
func (resp *Response) ReadLimitBody(r *bufio.Reader, maxBodySize int) error {
resp.resetSkipHeader()
err := resp.Header.Read(r)
if err != nil {
return err
}
if resp.Header.StatusCode() == StatusContinue {
// Read the next response according to http://www.w3.org/Protocols/rfc2616/rfc2616-sec8.html .
if err = resp.Header.Read(r); err != nil {
return err
}
}
if !resp.mustSkipBody() {
bodyBuf := resp.bodyBuffer()
bodyBuf.Reset()
bodyBuf.B, err = readBody(r, resp.Header.ContentLength(), maxBodySize, bodyBuf.B)
if err != nil {
return err
}
resp.Header.SetContentLength(len(bodyBuf.B))
}
return nil
}
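// Illustrative sketch, not part of the original file: reading a response from
// an arbitrary connection while capping the body at 1 MiB.
func exampleReadResponseLimited(conn net.Conn) (*Response, error) {
	resp := AcquireResponse()
	br := bufio.NewReader(conn)
	if err := resp.ReadLimitBody(br, 1024*1024); err != nil {
		ReleaseResponse(resp)
		return nil, err
	}
	return resp, nil
}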
func (resp *Response) mustSkipBody() bool {
return resp.SkipBody || resp.Header.mustSkipContentLength()
}
var errRequestHostRequired = errors.New("missing required Host header in request")
// WriteTo writes request to w. It implements io.WriterTo.
func (req *Request) WriteTo(w io.Writer) (int64, error) {
return writeBufio(req, w)
}
// WriteTo writes response to w. It implements io.WriterTo.
func (resp *Response) WriteTo(w io.Writer) (int64, error) {
return writeBufio(resp, w)
}
func writeBufio(hw httpWriter, w io.Writer) (int64, error) {
sw := acquireStatsWriter(w)
bw := acquireBufioWriter(sw)
err1 := hw.Write(bw)
err2 := bw.Flush()
releaseBufioWriter(bw)
n := sw.bytesWritten
releaseStatsWriter(sw)
err := err1
if err == nil {
err = err2
}
return n, err
}
type statsWriter struct {
w io.Writer
bytesWritten int64
}
func (w *statsWriter) Write(p []byte) (int, error) {
n, err := w.w.Write(p)
w.bytesWritten += int64(n)
return n, err
}
func acquireStatsWriter(w io.Writer) *statsWriter {
v := statsWriterPool.Get()
if v == nil {
return &statsWriter{
w: w,
}
}
sw := v.(*statsWriter)
sw.w = w
return sw
}
func releaseStatsWriter(sw *statsWriter) {
sw.w = nil
sw.bytesWritten = 0
statsWriterPool.Put(sw)
}
var statsWriterPool sync.Pool
func acquireBufioWriter(w io.Writer) *bufio.Writer {
v := bufioWriterPool.Get()
if v == nil {
return bufio.NewWriter(w)
}
bw := v.(*bufio.Writer)
bw.Reset(w)
return bw
}
func releaseBufioWriter(bw *bufio.Writer) {
bufioWriterPool.Put(bw)
}
var bufioWriterPool sync.Pool
func (req *Request) onlyMultipartForm() bool {
return req.multipartForm != nil && (req.body == nil || len(req.body.B) == 0)
}
// Write writes request to w.
//
// Write doesn't flush request to w for performance reasons.
//
// See also WriteTo.
func (req *Request) Write(w *bufio.Writer) error {
if len(req.Header.Host()) == 0 || req.parsedURI {
uri := req.URI()
host := uri.Host()
if len(host) == 0 {
return errRequestHostRequired
}
req.Header.SetHostBytes(host)
req.Header.SetRequestURIBytes(uri.RequestURI())
if len(uri.username) > 0 {
// RequestHeader.SetBytesKV only uses RequestHeader.bufKV.key
// So we are free to use RequestHeader.bufKV.value as a scratch pad for
// the base64 encoding.
nl := len(uri.username) + len(uri.password) + 1
nb := nl + len(strBasicSpace)
tl := nb + base64.StdEncoding.EncodedLen(nl)
if tl > cap(req.Header.bufKV.value) {
req.Header.bufKV.value = make([]byte, 0, tl)
}
buf := req.Header.bufKV.value[:0]
buf = append(buf, uri.username...)
buf = append(buf, strColon...)
buf = append(buf, uri.password...)
buf = append(buf, strBasicSpace...)
base64.StdEncoding.Encode(buf[nb:tl], buf[:nl])
req.Header.SetBytesKV(strAuthorization, buf[nl:tl])
}
}
if req.bodyStream != nil {
return req.writeBodyStream(w)
}
body := req.bodyBytes()
var err error
if req.onlyMultipartForm() {
body, err = marshalMultipartForm(req.multipartForm, req.multipartFormBoundary)
if err != nil {
return fmt.Errorf("error when marshaling multipart form: %s", err)
}
req.Header.SetMultipartFormBoundary(req.multipartFormBoundary)
}
hasBody := false
if len(body) == 0 {
body = req.postArgs.QueryString()
}
if len(body) != 0 || !req.Header.ignoreBody() {
hasBody = true
req.Header.SetContentLength(len(body))
}
if err = req.Header.Write(w); err != nil {
return err
}
if hasBody {
_, err = w.Write(body)
} else if len(body) > 0 {
if req.secureErrorLogMessage {
return fmt.Errorf("non-zero body for non-POST request")
}
return fmt.Errorf("non-zero body for non-POST request. body=%q", body)
}
return err
}
// WriteGzip writes response with gzipped body to w.
//
// The method gzips response body and sets 'Content-Encoding: gzip'
// header before writing response to w.
//
// WriteGzip doesn't flush response to w for performance reasons.
func (resp *Response) WriteGzip(w *bufio.Writer) error {
return resp.WriteGzipLevel(w, CompressDefaultCompression)
}
// WriteGzipLevel writes response with gzipped body to w.
//
// Level is the desired compression level:
//
// * CompressNoCompression
// * CompressBestSpeed
// * CompressBestCompression
// * CompressDefaultCompression
// * CompressHuffmanOnly
//
// The method gzips response body and sets 'Content-Encoding: gzip'
// header before writing response to w.
//
// WriteGzipLevel doesn't flush response to w for performance reasons.
func (resp *Response) WriteGzipLevel(w *bufio.Writer, level int) error {
if err := resp.gzipBody(level); err != nil {
return err
}
return resp.Write(w)
}
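// Illustrative sketch, not part of the original file: writing a response with
// the fastest gzip level and flushing the buffered writer afterwards.
func exampleWriteGzipFast(resp *Response, w *bufio.Writer) error {
	if err := resp.WriteGzipLevel(w, CompressBestSpeed); err != nil {
		return err
	}
	return w.Flush()
}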
// WriteDeflate writes response with deflated body to w.
//
// The method deflates response body and sets 'Content-Encoding: deflate'
// header before writing response to w.
//
// WriteDeflate doesn't flush response to w for performance reasons.
func (resp *Response) WriteDeflate(w *bufio.Writer) error {
return resp.WriteDeflateLevel(w, CompressDefaultCompression)
}
// WriteDeflateLevel writes response with deflated body to w.
//
// Level is the desired compression level:
//
// * CompressNoCompression
// * CompressBestSpeed
// * CompressBestCompression
// * CompressDefaultCompression
// * CompressHuffmanOnly
//
// The method deflates response body and sets 'Content-Encoding: deflate'
// header before writing response to w.
//
// WriteDeflateLevel doesn't flush response to w for performance reasons.
func (resp *Response) WriteDeflateLevel(w *bufio.Writer, level int) error {
if err := resp.deflateBody(level); err != nil {
return err
}
return resp.Write(w)
}
func (resp *Response) brotliBody(level int) error {
if len(resp.Header.peek(strContentEncoding)) > 0 {
// It looks like the body is already compressed.
// Do not compress it again.
return nil
}
if !resp.Header.isCompressibleContentType() {
// The content-type cannot be compressed.
return nil
}
if resp.bodyStream != nil {
// Reset Content-Length to -1, since it is impossible
// to determine the body size ahead of streamed compression.
// For https://github.com/valyala/fasthttp/issues/176 .
resp.Header.SetContentLength(-1)
// Do not care about memory allocations here, since brotli is slow
// and allocates a lot of memory by itself.
bs := resp.bodyStream
resp.bodyStream = NewStreamReader(func(sw *bufio.Writer) {
zw := acquireStacklessBrotliWriter(sw, level)
fw := &flushWriter{
wf: zw,
bw: sw,
}
copyZeroAlloc(fw, bs) //nolint:errcheck
releaseStacklessBrotliWriter(zw, level)
if bsc, ok := bs.(io.Closer); ok {
bsc.Close()
}
})
} else {
bodyBytes := resp.bodyBytes()
if len(bodyBytes) < minCompressLen {
// There is no sense in spending CPU time on small body compression,
// since there is a very high probability that the compressed
// body size will be bigger than the original body size.
return nil
}
w := responseBodyPool.Get()
w.B = AppendBrotliBytesLevel(w.B, bodyBytes, level)
// Hack: swap resp.body with w.
if resp.body != nil {
responseBodyPool.Put(resp.body)
}
resp.body = w
resp.bodyRaw = nil
}
resp.Header.SetCanonical(strContentEncoding, strBr)
return nil
}
func (resp *Response) gzipBody(level int) error {
if len(resp.Header.peek(strContentEncoding)) > 0 {
// It looks like the body is already compressed.
// Do not compress it again.
return nil
}
if !resp.Header.isCompressibleContentType() {
// The content-type cannot be compressed.
return nil
}
if resp.bodyStream != nil {
// Reset Content-Length to -1, since it is impossible
// to determine the body size ahead of streamed compression.
// For https://github.com/valyala/fasthttp/issues/176 .
resp.Header.SetContentLength(-1)
// Do not care about memory allocations here, since gzip is slow
// and allocates a lot of memory by itself.
bs := resp.bodyStream
resp.bodyStream = NewStreamReader(func(sw *bufio.Writer) {
zw := acquireStacklessGzipWriter(sw, level)
fw := &flushWriter{
wf: zw,
bw: sw,
}
copyZeroAlloc(fw, bs) //nolint:errcheck
releaseStacklessGzipWriter(zw, level)
if bsc, ok := bs.(io.Closer); ok {
bsc.Close()
}
})
} else {
bodyBytes := resp.bodyBytes()
if len(bodyBytes) < minCompressLen {
// There is no sense in spending CPU time on small body compression,
// since there is a very high probability that the compressed
// body size will be bigger than the original body size.
return nil
}
w := responseBodyPool.Get()
w.B = AppendGzipBytesLevel(w.B, bodyBytes, level)
// Hack: swap resp.body with w.
if resp.body != nil {
responseBodyPool.Put(resp.body)
}
resp.body = w
resp.bodyRaw = nil
}
resp.Header.SetCanonical(strContentEncoding, strGzip)
return nil
}
func (resp *Response) deflateBody(level int) error {
if len(resp.Header.peek(strContentEncoding)) > 0 {
// It looks like the body is already compressed.
// Do not compress it again.
return nil
}
if !resp.Header.isCompressibleContentType() {
// The content-type cannot be compressed.
return nil
}
if resp.bodyStream != nil {
// Reset Content-Length to -1, since it is impossible
// to determine the body size ahead of streamed compression.
// For https://github.com/valyala/fasthttp/issues/176 .
resp.Header.SetContentLength(-1)
// Do not care about memory allocations here, since flate is slow
// and allocates a lot of memory by itself.
bs := resp.bodyStream
resp.bodyStream = NewStreamReader(func(sw *bufio.Writer) {
zw := acquireStacklessDeflateWriter(sw, level)
fw := &flushWriter{
wf: zw,
bw: sw,
}
copyZeroAlloc(fw, bs) //nolint:errcheck
releaseStacklessDeflateWriter(zw, level)
if bsc, ok := bs.(io.Closer); ok {
bsc.Close()
}
})
} else {
bodyBytes := resp.bodyBytes()
if len(bodyBytes) < minCompressLen {
// There is no sense in spending CPU time on small body compression,
// since there is a very high probability that the compressed
// body size will be bigger than the original body size.
return nil
}
w := responseBodyPool.Get()
w.B = AppendDeflateBytesLevel(w.B, bodyBytes, level)
// Hack: swap resp.body with w.
if resp.body != nil {
responseBodyPool.Put(resp.body)
}
resp.body = w
resp.bodyRaw = nil
}
resp.Header.SetCanonical(strContentEncoding, strDeflate)
return nil
}
// Bodies with sizes smaller than minCompressLen aren't compressed at all
const minCompressLen = 200
type writeFlusher interface {
io.Writer
Flush() error
}
type flushWriter struct {
wf writeFlusher
bw *bufio.Writer
}
func (w *flushWriter) Write(p []byte) (int, error) {
n, err := w.wf.Write(p)
if err != nil {
return 0, err
}
if err = w.wf.Flush(); err != nil {
return 0, err
}
if err = w.bw.Flush(); err != nil {
return 0, err
}
return n, nil
}
// Write writes response to w.
//
// Write doesn't flush response to w for performance reasons.
//
// See also WriteTo.
func (resp *Response) Write(w *bufio.Writer) error {
sendBody := !resp.mustSkipBody()
if resp.bodyStream != nil {
return resp.writeBodyStream(w, sendBody)
}
body := resp.bodyBytes()
bodyLen := len(body)
if sendBody || bodyLen > 0 {
resp.Header.SetContentLength(bodyLen)
}
if err := resp.Header.Write(w); err != nil {
return err
}
if sendBody {
if _, err := w.Write(body); err != nil {
return err
}
}
return nil
}
func (req *Request) writeBodyStream(w *bufio.Writer) error {
var err error
contentLength := req.Header.ContentLength()
if contentLength < 0 {
lrSize := limitedReaderSize(req.bodyStream)
if lrSize >= 0 {
contentLength = int(lrSize)
if int64(contentLength) != lrSize {
contentLength = -1
}
if contentLength >= 0 {
req.Header.SetContentLength(contentLength)
}
}
}
if contentLength >= 0 {
if err = req.Header.Write(w); err == nil {
err = writeBodyFixedSize(w, req.bodyStream, int64(contentLength))
}
} else {
req.Header.SetContentLength(-1)
if err = req.Header.Write(w); err == nil {
err = writeBodyChunked(w, req.bodyStream)
}
}
err1 := req.closeBodyStream()
if err == nil {
err = err1
}
return err
}
// ErrBodyStreamWritePanic is returned when a panic happens while writing the body stream.
type ErrBodyStreamWritePanic struct {
error
}
func (resp *Response) writeBodyStream(w *bufio.Writer, sendBody bool) (err error) {
defer func() {
if r := recover(); r != nil {
err = &ErrBodyStreamWritePanic{
error: fmt.Errorf("panic while writing body stream: %+v", r),
}
}
}()
contentLength := resp.Header.ContentLength()
if contentLength < 0 {
lrSize := limitedReaderSize(resp.bodyStream)
if lrSize >= 0 {
contentLength = int(lrSize)
if int64(contentLength) != lrSize {
contentLength = -1
}
if contentLength >= 0 {
resp.Header.SetContentLength(contentLength)
}
}
}
if contentLength >= 0 {
if err = resp.Header.Write(w); err == nil {
if resp.ImmediateHeaderFlush {
err = w.Flush()
}
if err == nil && sendBody {
err = writeBodyFixedSize(w, resp.bodyStream, int64(contentLength))
}
}
} else {
resp.Header.SetContentLength(-1)
if err = resp.Header.Write(w); err == nil {
if resp.ImmediateHeaderFlush {
err = w.Flush()
}
if err == nil && sendBody {
err = writeBodyChunked(w, resp.bodyStream)
}
}
}
err1 := resp.closeBodyStream()
if err == nil {
err = err1
}
return err
}
func (req *Request) closeBodyStream() error {
if req.bodyStream == nil {
return nil
}
var err error
if bsc, ok := req.bodyStream.(io.Closer); ok {
err = bsc.Close()
}
req.bodyStream = nil
return err
}
func (resp *Response) closeBodyStream() error {
if resp.bodyStream == nil {
return nil
}
var err error
if bsc, ok := resp.bodyStream.(io.Closer); ok {
err = bsc.Close()
}
resp.bodyStream = nil
return err
}
// String returns request representation.
//
// Returns error message instead of request representation on error.
//
// Use Write instead of String for performance-critical code.
func (req *Request) String() string {
return getHTTPString(req)
}
// String returns response representation.
//
// Returns error message instead of response representation on error.
//
// Use Write instead of String for performance-critical code.
func (resp *Response) String() string {
return getHTTPString(resp)
}
func getHTTPString(hw httpWriter) string {
w := bytebufferpool.Get()
bw := bufio.NewWriter(w)
if err := hw.Write(bw); err != nil {
return err.Error()
}
if err := bw.Flush(); err != nil {
return err.Error()
}
s := string(w.B)
bytebufferpool.Put(w)
return s
}
type httpWriter interface {
Write(w *bufio.Writer) error
}
func writeBodyChunked(w *bufio.Writer, r io.Reader) error {
vbuf := copyBufPool.Get()
buf := vbuf.([]byte)
var err error
var n int
for {
n, err = r.Read(buf)
if n == 0 {
if err == nil {
panic("BUG: io.Reader returned 0, nil")
}
if err == io.EOF {
if err = writeChunk(w, buf[:0]); err != nil {
break
}
err = nil
}
break
}
if err = writeChunk(w, buf[:n]); err != nil {
break
}
}
copyBufPool.Put(vbuf)
return err
}
func limitedReaderSize(r io.Reader) int64 {
lr, ok := r.(*io.LimitedReader)
if !ok {
return -1
}
return lr.N
}
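// Illustrative sketch, not part of the original source: wrapping a stream in
// io.LimitReader is enough for the writers above to recover a concrete size:
//
//	lr := io.LimitReader(src, 42) // src is any io.Reader; lr is *io.LimitedReader with N == 42
//	// limitedReaderSize(lr) == 42, so the body is written with
//	// "Content-Length: 42" instead of chunked transfer encoding.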
func writeBodyFixedSize(w *bufio.Writer, r io.Reader, size int64) error {
if size > maxSmallFileSize {
// The w buffer must be empty in order to trigger the
// sendfile path in bufio.Writer.ReadFrom.
if err := w.Flush(); err != nil {
return err
}
}
n, err := copyZeroAlloc(w, r)
if n != size && err == nil {
err = fmt.Errorf("copied %d bytes from body stream instead of %d bytes", n, size)
}
return err
}
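// Illustrative sketch, not part of the original source: the flush above keeps
// the bufio buffer empty so that bufio.Writer.ReadFrom can delegate the copy
// to the underlying connection, which on supported platforms may use sendfile
// when the body stream is an *os.File, e.g.
//
//	f, err := os.Open("large.bin") // hypothetical file name
//	// handle err and stat the file for its size...
//	resp.SetBodyStream(f, fileSize)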
func copyZeroAlloc(w io.Writer, r io.Reader) (int64, error) {
vbuf := copyBufPool.Get()
buf := vbuf.([]byte)
n, err := io.CopyBuffer(w, r, buf)
copyBufPool.Put(vbuf)
return n, err
}
var copyBufPool = sync.Pool{
New: func() interface{} {
return make([]byte, 4096)
},
}
func writeChunk(w *bufio.Writer, b []byte) error {
n := len(b)
if err := writeHexInt(w, n); err != nil {
return err
}
if _, err := w.Write(strCRLF); err != nil {
return err
}
if _, err := w.Write(b); err != nil {
return err
}
_, err := w.Write(strCRLF)
err1 := w.Flush()
if err == nil {
err = err1
}
return err
}
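// Wire-format sketch (illustrative, not part of the original source):
// writeChunk(w, []byte("hello")) emits
//
//	5\r\nhello\r\n
//
// and the zero-length chunk written by writeBodyChunked at EOF emits the
// terminating 0\r\n\r\n sequence.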
// ErrBodyTooLarge is returned if either request or response body exceeds
// the given limit.
var ErrBodyTooLarge = errors.New("body size exceeds the given limit")
func readBody(r *bufio.Reader, contentLength int, maxBodySize int, dst []byte) ([]byte, error) {
dst = dst[:0]
if contentLength >= 0 {
if maxBodySize > 0 && contentLength > maxBodySize {
return dst, ErrBodyTooLarge
}
return appendBodyFixedSize(r, dst, contentLength)
}
if contentLength == -1 {
return readBodyChunked(r, maxBodySize, dst)
}
return readBodyIdentity(r, maxBodySize, dst)
}
var errChunkedStream = errors.New("chunked stream")
func readBodyWithStreaming(r *bufio.Reader, contentLength int, maxBodySize int, dst []byte) (b []byte, err error) {
if contentLength == -1 {
// handled in requestStream.Read()
return b, errChunkedStream
}
dst = dst[:0]
readN := maxBodySize
if readN > contentLength {
readN = contentLength
}
if readN > 8*1024 {
readN = 8 * 1024
}
if contentLength >= 0 && maxBodySize >= contentLength {
b, err = appendBodyFixedSize(r, dst, readN)
} else {
b, err = readBodyIdentity(r, readN, dst)
}
if err != nil {
return b, err
}
if contentLength > maxBodySize {
return b, ErrBodyTooLarge
}
return b, nil
}
func readBodyIdentity(r *bufio.Reader, maxBodySize int, dst []byte) ([]byte, error) {
dst = dst[:cap(dst)]
if len(dst) == 0 {
dst = make([]byte, 1024)
}
offset := 0
for {
nn, err := r.Read(dst[offset:])
if nn <= 0 {
if err != nil {
if err == io.EOF {
return dst[:offset], nil
}
return dst[:offset], err
}
panic(fmt.Sprintf("BUG: bufio.Read() returned (%d, nil)", nn))
}
offset += nn
if maxBodySize > 0 && offset > maxBodySize {
return dst[:offset], ErrBodyTooLarge
}
if len(dst) == offset {
n := round2(2 * offset)
if maxBodySize > 0 && n > maxBodySize {
n = maxBodySize + 1
}
b := make([]byte, n)
copy(b, dst)
dst = b
}
}
}
func appendBodyFixedSize(r *bufio.Reader, dst []byte, n int) ([]byte, error) {
if n == 0 {
return dst, nil
}
offset := len(dst)
dstLen := offset + n
if cap(dst) < dstLen {
b := make([]byte, round2(dstLen))
copy(b, dst)
dst = b
}
dst = dst[:dstLen]
for {
nn, err := r.Read(dst[offset:])
if nn <= 0 {
if err != nil {
if err == io.EOF {
err = io.ErrUnexpectedEOF
}
return dst[:offset], err
}
panic(fmt.Sprintf("BUG: bufio.Read() returned (%d, nil)", nn))
}
offset += nn
if offset == dstLen {
return dst, nil
}
}
}
// ErrBrokenChunk is returned when the server receives a broken chunked body (Transfer-Encoding: chunked).
type ErrBrokenChunk struct {
error
}
func readBodyChunked(r *bufio.Reader, maxBodySize int, dst []byte) ([]byte, error) {
if len(dst) > 0 {
panic("BUG: expected zero-length buffer")
}
strCRLFLen := len(strCRLF)
for {
chunkSize, err := parseChunkSize(r)
if err != nil {
return dst, err
}
if maxBodySize > 0 && len(dst)+chunkSize > maxBodySize {
return dst, ErrBodyTooLarge
}
dst, err = appendBodyFixedSize(r, dst, chunkSize+strCRLFLen)
if err != nil {
return dst, err
}
if !bytes.Equal(dst[len(dst)-strCRLFLen:], strCRLF) {
return dst, ErrBrokenChunk{
error: fmt.Errorf("cannot find crlf at the end of chunk"),
}
}
dst = dst[:len(dst)-strCRLFLen]
if chunkSize == 0 {
return dst, nil
}
}
}
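// Illustrative sketch, not part of the original source: readBodyChunked
// consumes input of the form used by the tests in http_test.go, e.g.
//
//	6\r\nfoobar\r\n3\r\nbaz\r\n0\r\n\r\n
//
// appending "foobarbaz" to dst; the final zero-size chunk terminates the loop.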
func parseChunkSize(r *bufio.Reader) (int, error) {
n, err := readHexInt(r)
if err != nil {
return -1, err
}
for {
c, err := r.ReadByte()
if err != nil {
return -1, ErrBrokenChunk{
error: fmt.Errorf("cannot read '\r' char at the end of chunk size: %s", err),
}
}
// Skip any trailing whitespace after chunk size.
if c == ' ' {
continue
}
if err := r.UnreadByte(); err != nil {
return -1, ErrBrokenChunk{
error: fmt.Errorf("cannot unread '\r' char at the end of chunk size: %s", err),
}
}
break
}
err = readCrLf(r)
if err != nil {
return -1, err
}
return n, nil
}
func readCrLf(r *bufio.Reader) error {
for _, exp := range []byte{'\r', '\n'} {
c, err := r.ReadByte()
if err != nil {
return ErrBrokenChunk{
error: fmt.Errorf("cannot read %q char at the end of chunk size: %s", exp, err),
}
}
if c != exp {
return ErrBrokenChunk{
error: fmt.Errorf("unexpected char %q at the end of chunk size. Expected %q", c, exp),
}
}
}
return nil
}
func round2(n int) int {
if n <= 0 {
return 0
}
x := uint32(n - 1)
x |= x >> 1
x |= x >> 2
x |= x >> 4
x |= x >> 8
x |= x >> 16
return int(x + 1)
}
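// Worked examples (illustrative, mirroring TestRound2 in http_test.go):
//
//	round2(0) == 0
//	round2(3) == 4
//	round2(8) == 8
//	round2(9) == 16
//	round2(0x10001) == 0x20000
//
// i.e. round2 rounds n up to the next power of two, which keeps body-buffer
// growth amortized when reading identity and fixed-size bodies.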
fasthttp-1.31.0/http_test.go 0000664 0000000 0000000 00000200647 14130360711 0016003 0 ustar 00root root 0000000 0000000 package fasthttp
import (
"bufio"
"bytes"
"fmt"
"io"
"io/ioutil"
"mime/multipart"
"reflect"
"strconv"
"strings"
"testing"
"time"
"github.com/valyala/bytebufferpool"
)
func TestResponseEmptyTransferEncoding(t *testing.T) {
t.Parallel()
var r Response
body := "Some body"
br := bufio.NewReader(bytes.NewBufferString("HTTP/1.1 200 OK\r\nContent-Type: aaa\r\nTransfer-Encoding: \r\nContent-Length: 9\r\n\r\n" + body))
err := r.Read(br)
if err != nil {
t.Fatal(err)
}
if got := string(r.Body()); got != body {
t.Fatalf("expected %q got %q", body, got)
}
}
// Don't send the fragment/hash/# part of a URL to the server.
func TestFragmentInURIRequest(t *testing.T) {
t.Parallel()
var req Request
req.SetRequestURI("https://docs.gitlab.com/ee/user/project/integrations/webhooks.html#events")
var b bytes.Buffer
req.WriteTo(&b) //nolint:errcheck
got := b.String()
expected := "GET /ee/user/project/integrations/webhooks.html HTTP/1.1\r\nHost: docs.gitlab.com\r\n\r\n"
if got != expected {
t.Errorf("got %q expected %q", got, expected)
}
}
func TestIssue875(t *testing.T) {
t.Parallel()
type testcase struct {
uri string
expectedRedirect string
expectedLocation string
}
var testcases = []testcase{
{
uri: `http://localhost:3000/?redirect=foo%0d%0aSet-Cookie:%20SESSIONID=MaliciousValue%0d%0a`,
expectedRedirect: "foo\r\nSet-Cookie: SESSIONID=MaliciousValue\r\n",
expectedLocation: "Location: foo Set-Cookie: SESSIONID=MaliciousValue",
},
{
uri: `http://localhost:3000/?redirect=foo%0dSet-Cookie:%20SESSIONID=MaliciousValue%0d%0a`,
expectedRedirect: "foo\rSet-Cookie: SESSIONID=MaliciousValue\r\n",
expectedLocation: "Location: foo Set-Cookie: SESSIONID=MaliciousValue",
},
{
uri: `http://localhost:3000/?redirect=foo%0aSet-Cookie:%20SESSIONID=MaliciousValue%0d%0a`,
expectedRedirect: "foo\nSet-Cookie: SESSIONID=MaliciousValue\r\n",
expectedLocation: "Location: foo Set-Cookie: SESSIONID=MaliciousValue",
},
}
for i, tcase := range testcases {
caseName := strconv.FormatInt(int64(i), 10)
t.Run(caseName, func(subT *testing.T) {
ctx := &RequestCtx{
Request: Request{},
Response: Response{},
}
ctx.Request.SetRequestURI(tcase.uri)
q := string(ctx.QueryArgs().Peek("redirect"))
if q != tcase.expectedRedirect {
subT.Errorf("unexpected redirect query value, got: %+v", q)
}
ctx.Response.Header.Set("Location", q)
if !strings.Contains(ctx.Response.String(), tcase.expectedLocation) {
subT.Errorf("invalid escaping, got\n%s", ctx.Response.String())
}
})
}
}
func TestRequestCopyTo(t *testing.T) {
t.Parallel()
var req Request
// empty copy
testRequestCopyTo(t, &req)
// init
expectedContentType := "application/x-www-form-urlencoded; charset=UTF-8"
expectedHost := "test.com"
expectedBody := "0123=56789"
s := fmt.Sprintf("POST / HTTP/1.1\r\nHost: %s\r\nContent-Type: %s\r\nContent-Length: %d\r\n\r\n%s",
expectedHost, expectedContentType, len(expectedBody), expectedBody)
br := bufio.NewReader(bytes.NewBufferString(s))
if err := req.Read(br); err != nil {
t.Fatalf("unexpected error: %s", err)
}
testRequestCopyTo(t, &req)
}
func TestResponseCopyTo(t *testing.T) {
t.Parallel()
var resp Response
// empty copy
testResponseCopyTo(t, &resp)
// init resp
resp.laddr = zeroTCPAddr
resp.SkipBody = true
resp.Header.SetStatusCode(200)
resp.SetBodyString("test")
testResponseCopyTo(t, &resp)
}
func testRequestCopyTo(t *testing.T, src *Request) {
var dst Request
src.CopyTo(&dst)
if !reflect.DeepEqual(*src, dst) { //nolint:govet
t.Fatalf("RequestCopyTo fail, src: \n%+v\ndst: \n%+v\n", *src, dst) //nolint:govet
}
}
func testResponseCopyTo(t *testing.T, src *Response) {
var dst Response
src.CopyTo(&dst)
if !reflect.DeepEqual(*src, dst) { //nolint:govet
t.Fatalf("ResponseCopyTo fail, src: \n%+v\ndst: \n%+v\n", *src, dst) //nolint:govet
}
}
func TestResponseBodyStreamDeflate(t *testing.T) {
t.Parallel()
body := createFixedBody(1e5)
// Verifies https://github.com/valyala/fasthttp/issues/176
// when Content-Length is explicitly set.
testResponseBodyStreamDeflate(t, body, len(body))
// Verifies that 'transfer-encoding: chunked' works as expected.
testResponseBodyStreamDeflate(t, body, -1)
}
func TestResponseBodyStreamGzip(t *testing.T) {
t.Parallel()
body := createFixedBody(1e5)
// Verifies https://github.com/valyala/fasthttp/issues/176
// when Content-Length is explicitly set.
testResponseBodyStreamGzip(t, body, len(body))
// Verifies that 'transfer-encoding: chunked' works as expected.
testResponseBodyStreamGzip(t, body, -1)
}
func testResponseBodyStreamDeflate(t *testing.T, body []byte, bodySize int) {
var r Response
r.SetBodyStream(bytes.NewReader(body), bodySize)
w := &bytes.Buffer{}
bw := bufio.NewWriter(w)
if err := r.WriteDeflate(bw); err != nil {
t.Fatalf("unexpected error: %s", err)
}
if err := bw.Flush(); err != nil {
t.Fatalf("unexpected error: %s", err)
}
var resp Response
br := bufio.NewReader(w)
if err := resp.Read(br); err != nil {
t.Fatalf("unexpected error: %s", err)
}
respBody, err := resp.BodyInflate()
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
if !bytes.Equal(respBody, body) {
t.Fatalf("unexpected body: %q. Expecting %q", respBody, body)
}
}
func testResponseBodyStreamGzip(t *testing.T, body []byte, bodySize int) {
var r Response
r.SetBodyStream(bytes.NewReader(body), bodySize)
w := &bytes.Buffer{}
bw := bufio.NewWriter(w)
if err := r.WriteGzip(bw); err != nil {
t.Fatalf("unexpected error: %s", err)
}
if err := bw.Flush(); err != nil {
t.Fatalf("unexpected error: %s", err)
}
var resp Response
br := bufio.NewReader(w)
if err := resp.Read(br); err != nil {
t.Fatalf("unexpected error: %s", err)
}
respBody, err := resp.BodyGunzip()
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
if !bytes.Equal(respBody, body) {
t.Fatalf("unexpected body: %q. Expecting %q", respBody, body)
}
}
func TestResponseWriteGzipNilBody(t *testing.T) {
t.Parallel()
var r Response
w := &bytes.Buffer{}
bw := bufio.NewWriter(w)
if err := r.WriteGzip(bw); err != nil {
t.Fatalf("unexpected error: %s", err)
}
if err := bw.Flush(); err != nil {
t.Fatalf("unexpected error: %s", err)
}
}
func TestResponseWriteDeflateNilBody(t *testing.T) {
t.Parallel()
var r Response
w := &bytes.Buffer{}
bw := bufio.NewWriter(w)
if err := r.WriteDeflate(bw); err != nil {
t.Fatalf("unexpected error: %s", err)
}
if err := bw.Flush(); err != nil {
t.Fatalf("unexpected error: %s", err)
}
}
func TestResponseSwapBodySerial(t *testing.T) {
t.Parallel()
testResponseSwapBody(t)
}
func TestResponseSwapBodyConcurrent(t *testing.T) {
t.Parallel()
ch := make(chan struct{})
for i := 0; i < 10; i++ {
go func() {
testResponseSwapBody(t)
ch <- struct{}{}
}()
}
for i := 0; i < 10; i++ {
select {
case <-ch:
case <-time.After(time.Second):
t.Fatalf("timeout")
}
}
}
func testResponseSwapBody(t *testing.T) {
var b []byte
r := AcquireResponse()
for i := 0; i < 20; i++ {
bOrig := r.Body()
b = r.SwapBody(b)
if !bytes.Equal(bOrig, b) {
t.Fatalf("unexpected body returned: %q. Expecting %q", b, bOrig)
}
r.AppendBodyString("foobar")
}
s := "aaaabbbbcccc"
b = b[:0]
for i := 0; i < 10; i++ {
r.SetBodyStream(bytes.NewBufferString(s), len(s))
b = r.SwapBody(b)
if string(b) != s {
t.Fatalf("unexpected body returned: %q. Expecting %q", b, s)
}
b = r.SwapBody(b)
if len(b) > 0 {
t.Fatalf("unexpected body with non-zero size returned: %q", b)
}
}
ReleaseResponse(r)
}
func TestRequestSwapBodySerial(t *testing.T) {
t.Parallel()
testRequestSwapBody(t)
}
func TestRequestSwapBodyConcurrent(t *testing.T) {
t.Parallel()
ch := make(chan struct{})
for i := 0; i < 10; i++ {
go func() {
testRequestSwapBody(t)
ch <- struct{}{}
}()
}
for i := 0; i < 10; i++ {
select {
case <-ch:
case <-time.After(time.Second):
t.Fatalf("timeout")
}
}
}
func testRequestSwapBody(t *testing.T) {
var b []byte
r := AcquireRequest()
for i := 0; i < 20; i++ {
bOrig := r.Body()
b = r.SwapBody(b)
if !bytes.Equal(bOrig, b) {
t.Fatalf("unexpected body returned: %q. Expecting %q", b, bOrig)
}
r.AppendBodyString("foobar")
}
s := "aaaabbbbcccc"
b = b[:0]
for i := 0; i < 10; i++ {
r.SetBodyStream(bytes.NewBufferString(s), len(s))
b = r.SwapBody(b)
if string(b) != s {
t.Fatalf("unexpected body returned: %q. Expecting %q", b, s)
}
b = r.SwapBody(b)
if len(b) > 0 {
t.Fatalf("unexpected body with non-zero size returned: %q", b)
}
}
ReleaseRequest(r)
}
func TestRequestHostFromRequestURI(t *testing.T) {
t.Parallel()
hExpected := "foobar.com"
var req Request
req.SetRequestURI("http://proxy-host:123/foobar?baz")
req.SetHost(hExpected)
h := req.Host()
if string(h) != hExpected {
t.Fatalf("unexpected host set: %q. Expecting %q", h, hExpected)
}
}
func TestRequestHostFromHeader(t *testing.T) {
t.Parallel()
hExpected := "foobar.com"
var req Request
req.Header.SetHost(hExpected)
h := req.Host()
if string(h) != hExpected {
t.Fatalf("unexpected host set: %q. Expecting %q", h, hExpected)
}
}
func TestRequestContentTypeWithCharsetIssue100(t *testing.T) {
t.Parallel()
expectedContentType := "application/x-www-form-urlencoded; charset=UTF-8"
expectedBody := "0123=56789"
s := fmt.Sprintf("POST / HTTP/1.1\r\nContent-Type: %s\r\nContent-Length: %d\r\n\r\n%s",
expectedContentType, len(expectedBody), expectedBody)
br := bufio.NewReader(bytes.NewBufferString(s))
var r Request
if err := r.Read(br); err != nil {
t.Fatalf("unexpected error: %s", err)
}
body := r.Body()
if string(body) != expectedBody {
t.Fatalf("unexpected body %q. Expecting %q", body, expectedBody)
}
ct := r.Header.ContentType()
if string(ct) != expectedContentType {
t.Fatalf("unexpected content-type %q. Expecting %q", ct, expectedContentType)
}
args := r.PostArgs()
if args.Len() != 1 {
t.Fatalf("unexpected number of POST args: %d. Expecting 1", args.Len())
}
av := args.Peek("0123")
if string(av) != "56789" {
t.Fatalf("unexpected POST arg value: %q. Expecting %q", av, "56789")
}
}
func TestRequestReadMultipartFormWithFile(t *testing.T) {
t.Parallel()
s := `POST /upload HTTP/1.1
Host: localhost:10000
Content-Length: 521
Content-Type: multipart/form-data; boundary=----WebKitFormBoundaryJwfATyF8tmxSJnLg
------WebKitFormBoundaryJwfATyF8tmxSJnLg
Content-Disposition: form-data; name="f1"
value1
------WebKitFormBoundaryJwfATyF8tmxSJnLg
Content-Disposition: form-data; name="fileaaa"; filename="TODO"
Content-Type: application/octet-stream
- SessionClient with referer and cookies support.
- Client with requests' pipelining support.
- ProxyHandler similar to FSHandler.
- WebSockets. See https://tools.ietf.org/html/rfc6455 .
- HTTP/2.0. See https://tools.ietf.org/html/rfc7540 .
------WebKitFormBoundaryJwfATyF8tmxSJnLg--
tailfoobar`
br := bufio.NewReader(bytes.NewBufferString(s))
var r Request
if err := r.Read(br); err != nil {
t.Fatalf("unexpected error: %s", err)
}
tail, err := ioutil.ReadAll(br)
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
if string(tail) != "tailfoobar" {
t.Fatalf("unexpected tail %q. Expecting %q", tail, "tailfoobar")
}
f, err := r.MultipartForm()
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
defer r.RemoveMultipartFormFiles()
// verify values
if len(f.Value) != 1 {
t.Fatalf("unexpected number of values in multipart form: %d. Expecting 1", len(f.Value))
}
for k, vv := range f.Value {
if k != "f1" {
t.Fatalf("unexpected value name %q. Expecting %q", k, "f1")
}
if len(vv) != 1 {
t.Fatalf("unexpected number of values %d. Expecting 1", len(vv))
}
v := vv[0]
if v != "value1" {
t.Fatalf("unexpected value %q. Expecting %q", v, "value1")
}
}
// verify files
if len(f.File) != 1 {
t.Fatalf("unexpected number of file values in multipart form: %d. Expecting 1", len(f.File))
}
for k, vv := range f.File {
if k != "fileaaa" {
t.Fatalf("unexpected file value name %q. Expecting %q", k, "fileaaa")
}
if len(vv) != 1 {
t.Fatalf("unexpected number of file values %d. Expecting 1", len(vv))
}
v := vv[0]
if v.Filename != "TODO" {
t.Fatalf("unexpected filename %q. Expecting %q", v.Filename, "TODO")
}
ct := v.Header.Get("Content-Type")
if ct != "application/octet-stream" {
t.Fatalf("unexpected content-type %q. Expecting %q", ct, "application/octet-stream")
}
}
}
func TestRequestRequestURI(t *testing.T) {
t.Parallel()
var r Request
// Set request uri via SetRequestURI()
uri := "/foo/bar?baz"
r.SetRequestURI(uri)
if string(r.RequestURI()) != uri {
t.Fatalf("unexpected request uri %q. Expecting %q", r.RequestURI(), uri)
}
// Set request uri via Request.URI().Update()
r.Reset()
uri = "/aa/bbb?ccc=sdfsdf"
r.URI().Update(uri)
if string(r.RequestURI()) != uri {
t.Fatalf("unexpected request uri %q. Expecting %q", r.RequestURI(), uri)
}
// update query args in the request uri
qa := r.URI().QueryArgs()
qa.Reset()
qa.Set("foo", "bar")
uri = "/aa/bbb?foo=bar"
if string(r.RequestURI()) != uri {
t.Fatalf("unexpected request uri %q. Expecting %q", r.RequestURI(), uri)
}
}
func TestRequestUpdateURI(t *testing.T) {
t.Parallel()
var r Request
r.Header.SetHost("aaa.bbb")
r.SetRequestURI("/lkjkl/kjl")
// Modify request uri and host via URI() object and make sure
// the requestURI and Host header are properly updated
u := r.URI()
u.SetPath("/123/432.html")
u.SetHost("foobar.com")
a := u.QueryArgs()
a.Set("aaa", "bcse")
s := r.String()
if !strings.HasPrefix(s, "GET /123/432.html?aaa=bcse") {
t.Fatalf("cannot find %q in %q", "GET /123/432.html?aaa=bcse", s)
}
if !strings.Contains(s, "\r\nHost: foobar.com\r\n") {
t.Fatalf("cannot find %q in %q", "\r\nHost: foobar.com\r\n", s)
}
}
func TestRequestBodyStreamMultipleBodyCalls(t *testing.T) {
t.Parallel()
var r Request
s := "foobar baz abc"
if r.IsBodyStream() {
t.Fatalf("IsBodyStream must return false")
}
r.SetBodyStream(bytes.NewBufferString(s), len(s))
if !r.IsBodyStream() {
t.Fatalf("IsBodyStream must return true")
}
for i := 0; i < 10; i++ {
body := r.Body()
if string(body) != s {
t.Fatalf("unexpected body %q. Expecting %q. iteration %d", body, s, i)
}
}
}
func TestResponseBodyStreamMultipleBodyCalls(t *testing.T) {
t.Parallel()
var r Response
s := "foobar baz abc"
if r.IsBodyStream() {
t.Fatalf("IsBodyStream must return false")
}
r.SetBodyStream(bytes.NewBufferString(s), len(s))
if !r.IsBodyStream() {
t.Fatalf("IsBodyStream must return true")
}
for i := 0; i < 10; i++ {
body := r.Body()
if string(body) != s {
t.Fatalf("unexpected body %q. Expecting %q. iteration %d", body, s, i)
}
}
}
func TestRequestBodyWriteToPlain(t *testing.T) {
t.Parallel()
var r Request
expectedS := "foobarbaz"
r.AppendBodyString(expectedS)
testBodyWriteTo(t, &r, expectedS, true)
}
func TestResponseBodyWriteToPlain(t *testing.T) {
t.Parallel()
var r Response
expectedS := "foobarbaz"
r.AppendBodyString(expectedS)
testBodyWriteTo(t, &r, expectedS, true)
}
func TestResponseBodyWriteToStream(t *testing.T) {
t.Parallel()
var r Response
expectedS := "aaabbbccc"
buf := bytes.NewBufferString(expectedS)
if r.IsBodyStream() {
t.Fatalf("IsBodyStream must return false")
}
r.SetBodyStream(buf, len(expectedS))
if !r.IsBodyStream() {
t.Fatalf("IsBodyStream must return true")
}
testBodyWriteTo(t, &r, expectedS, false)
}
func TestRequestBodyWriteToMultipart(t *testing.T) {
t.Parallel()
expectedS := "--foobar\r\nContent-Disposition: form-data; name=\"key_0\"\r\n\r\nvalue_0\r\n--foobar--\r\n"
s := fmt.Sprintf("POST / HTTP/1.1\r\nHost: aaa\r\nContent-Type: multipart/form-data; boundary=foobar\r\nContent-Length: %d\r\n\r\n%s",
len(expectedS), expectedS)
var r Request
br := bufio.NewReader(bytes.NewBufferString(s))
if err := r.Read(br); err != nil {
t.Fatalf("unexpected error: %s", err)
}
testBodyWriteTo(t, &r, expectedS, true)
}
type bodyWriterTo interface {
BodyWriteTo(io.Writer) error
Body() []byte
}
func testBodyWriteTo(t *testing.T, bw bodyWriterTo, expectedS string, isRetainedBody bool) {
var buf bytebufferpool.ByteBuffer
if err := bw.BodyWriteTo(&buf); err != nil {
t.Fatalf("unexpected error: %s", err)
}
s := buf.B
if string(s) != expectedS {
t.Fatalf("unexpected result %q. Expecting %q", s, expectedS)
}
body := bw.Body()
if isRetainedBody {
if string(body) != expectedS {
t.Fatalf("unexpected body %q. Expecting %q", body, expectedS)
}
} else {
if len(body) > 0 {
t.Fatalf("unexpected non-zero body after BodyWriteTo: %q", body)
}
}
}
func TestRequestReadEOF(t *testing.T) {
t.Parallel()
var r Request
br := bufio.NewReader(&bytes.Buffer{})
err := r.Read(br)
if err == nil {
t.Fatalf("expecting error")
}
if err != io.EOF {
t.Fatalf("unexpected error: %s. Expecting %s", err, io.EOF)
}
// incomplete request mustn't return io.EOF
br = bufio.NewReader(bytes.NewBufferString("POST / HTTP/1.1\r\nContent-Type: aa\r\nContent-Length: 1234\r\n\r\nIncomplete body"))
err = r.Read(br)
if err == nil {
t.Fatalf("expecting error")
}
if err == io.EOF {
t.Fatalf("expecting non-EOF error")
}
}
func TestResponseReadEOF(t *testing.T) {
t.Parallel()
var r Response
br := bufio.NewReader(&bytes.Buffer{})
err := r.Read(br)
if err == nil {
t.Fatalf("expecting error")
}
if err != io.EOF {
t.Fatalf("unexpected error: %s. Expecting %s", err, io.EOF)
}
// incomplete response mustn't return io.EOF
br = bufio.NewReader(bytes.NewBufferString("HTTP/1.1 200 OK\r\nContent-Type: aaa\r\nContent-Length: 123\r\n\r\nIncomplete body"))
err = r.Read(br)
if err == nil {
t.Fatalf("expecting error")
}
if err == io.EOF {
t.Fatalf("expecting non-EOF error")
}
}
func TestRequestReadNoBody(t *testing.T) {
t.Parallel()
var r Request
br := bufio.NewReader(bytes.NewBufferString("GET / HTTP/1.1\r\n\r\n"))
err := r.Read(br)
r.SetHost("foobar")
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
s := r.String()
if strings.Contains(s, "Content-Length: ") {
t.Fatalf("unexpected Content-Length")
}
}
func TestResponseWriteTo(t *testing.T) {
t.Parallel()
var r Response
r.SetBodyString("foobar")
s := r.String()
var buf bytebufferpool.ByteBuffer
n, err := r.WriteTo(&buf)
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
if n != int64(len(s)) {
t.Fatalf("unexpected response length %d. Expecting %d", n, len(s))
}
if string(buf.B) != s {
t.Fatalf("unexpected response %q. Expecting %q", buf.B, s)
}
}
func TestRequestWriteTo(t *testing.T) {
t.Parallel()
var r Request
r.SetRequestURI("http://foobar.com/aaa/bbb")
s := r.String()
var buf bytebufferpool.ByteBuffer
n, err := r.WriteTo(&buf)
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
if n != int64(len(s)) {
t.Fatalf("unexpected request length %d. Expecting %d", n, len(s))
}
if string(buf.B) != s {
t.Fatalf("unexpected request %q. Expecting %q", buf.B, s)
}
}
func TestResponseSkipBody(t *testing.T) {
t.Parallel()
var r Response
// set StatusNotModified
r.Header.SetStatusCode(StatusNotModified)
r.SetBodyString("foobar")
s := r.String()
if strings.Contains(s, "\r\n\r\nfoobar") {
t.Fatalf("unexpected non-zero body in response %q", s)
}
if strings.Contains(s, "Content-Length: ") {
t.Fatalf("unexpected content-length in response %q", s)
}
if strings.Contains(s, "Content-Type: ") {
t.Fatalf("unexpected content-type in response %q", s)
}
// set StatusNoContent
r.Header.SetStatusCode(StatusNoContent)
r.SetBodyString("foobar")
s = r.String()
if strings.Contains(s, "\r\n\r\nfoobar") {
t.Fatalf("unexpected non-zero body in response %q", s)
}
if strings.Contains(s, "Content-Length: ") {
t.Fatalf("unexpected content-length in response %q", s)
}
if strings.Contains(s, "Content-Type: ") {
t.Fatalf("unexpected content-type in response %q", s)
}
// explicitly skip body
r.Header.SetStatusCode(StatusOK)
r.SkipBody = true
r.SetBodyString("foobar")
s = r.String()
if strings.Contains(s, "\r\n\r\nfoobar") {
t.Fatalf("unexpected non-zero body in response %q", s)
}
if !strings.Contains(s, "Content-Length: 6\r\n") {
t.Fatalf("expecting content-length in response %q", s)
}
if !strings.Contains(s, "Content-Type: ") {
t.Fatalf("expecting content-type in response %q", s)
}
}
func TestRequestNoContentLength(t *testing.T) {
t.Parallel()
var r Request
r.Header.SetMethod(MethodHead)
r.Header.SetHost("foobar")
s := r.String()
if strings.Contains(s, "Content-Length: ") {
t.Fatalf("unexpected content-length in HEAD request %q", s)
}
r.Header.SetMethod(MethodPost)
fmt.Fprintf(r.BodyWriter(), "foobar body")
s = r.String()
if !strings.Contains(s, "Content-Length: ") {
t.Fatalf("missing content-length header in non-GET request %q", s)
}
}
func TestRequestReadGzippedBody(t *testing.T) {
t.Parallel()
var r Request
bodyOriginal := "foo bar baz compress me better!"
body := AppendGzipBytes(nil, []byte(bodyOriginal))
s := fmt.Sprintf("POST /foobar HTTP/1.1\r\nContent-Type: foo/bar\r\nContent-Encoding: gzip\r\nContent-Length: %d\r\n\r\n%s",
len(body), body)
br := bufio.NewReader(bytes.NewBufferString(s))
if err := r.Read(br); err != nil {
t.Fatalf("unexpected error: %s", err)
}
if string(r.Header.Peek(HeaderContentEncoding)) != "gzip" {
t.Fatalf("unexpected content-encoding: %q. Expecting %q", r.Header.Peek(HeaderContentEncoding), "gzip")
}
if r.Header.ContentLength() != len(body) {
t.Fatalf("unexpected content-length: %d. Expecting %d", r.Header.ContentLength(), len(body))
}
if string(r.Body()) != string(body) {
t.Fatalf("unexpected body: %q. Expecting %q", r.Body(), body)
}
bodyGunzipped, err := AppendGunzipBytes(nil, r.Body())
if err != nil {
t.Fatalf("unexpected error when uncompressing data: %s", err)
}
if string(bodyGunzipped) != bodyOriginal {
t.Fatalf("unexpected uncompressed body %q. Expecting %q", bodyGunzipped, bodyOriginal)
}
}
func TestRequestReadPostNoBody(t *testing.T) {
t.Parallel()
var r Request
s := "POST /foo/bar HTTP/1.1\r\nContent-Type: aaa/bbb\r\n\r\naaaa"
br := bufio.NewReader(bytes.NewBufferString(s))
if err := r.Read(br); err != nil {
t.Fatalf("unexpected error: %s", err)
}
if string(r.Header.RequestURI()) != "/foo/bar" {
t.Fatalf("unexpected request uri %q. Expecting %q", r.Header.RequestURI(), "/foo/bar")
}
if string(r.Header.ContentType()) != "aaa/bbb" {
t.Fatalf("unexpected content-type %q. Expecting %q", r.Header.ContentType(), "aaa/bbb")
}
if len(r.Body()) != 0 {
t.Fatalf("unexpected body found %q. Expecting empty body", r.Body())
}
if r.Header.ContentLength() != 0 {
t.Fatalf("unexpected content-length: %d. Expecting 0", r.Header.ContentLength())
}
tail, err := ioutil.ReadAll(br)
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
if string(tail) != "aaaa" {
t.Fatalf("unexpected tail %q. Expecting %q", tail, "aaaa")
}
}
func TestRequestContinueReadBody(t *testing.T) {
t.Parallel()
s := "PUT /foo/bar HTTP/1.1\r\nExpect: 100-continue\r\nContent-Length: 5\r\nContent-Type: foo/bar\r\n\r\nabcdef4343"
br := bufio.NewReader(bytes.NewBufferString(s))
var r Request
if err := r.Read(br); err != nil {
t.Fatalf("unexpected error: %s", err)
}
if !r.MayContinue() {
t.Fatalf("MayContinue must return true")
}
if err := r.ContinueReadBody(br, 0, true); err != nil {
t.Fatalf("error when reading request body: %s", err)
}
body := r.Body()
if string(body) != "abcde" {
t.Fatalf("unexpected body %q. Expecting %q", body, "abcde")
}
tail, err := ioutil.ReadAll(br)
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
if string(tail) != "f4343" {
t.Fatalf("unexpected tail %q. Expecting %q", tail, "f4343")
}
}
func TestRequestContinueReadBodyDisablePrereadMultipartForm(t *testing.T) {
t.Parallel()
var w bytes.Buffer
mw := multipart.NewWriter(&w)
for i := 0; i < 10; i++ {
k := fmt.Sprintf("key_%d", i)
v := fmt.Sprintf("value_%d", i)
if err := mw.WriteField(k, v); err != nil {
t.Fatalf("unexpected error: %s", err)
}
}
boundary := mw.Boundary()
if err := mw.Close(); err != nil {
t.Fatalf("unexpected error: %s", err)
}
formData := w.Bytes()
s := fmt.Sprintf("POST / HTTP/1.1\r\nHost: aaa\r\nContent-Type: multipart/form-data; boundary=%s\r\nContent-Length: %d\r\n\r\n%s",
boundary, len(formData), formData)
br := bufio.NewReader(bytes.NewBufferString(s))
var r Request
if err := r.Header.Read(br); err != nil {
t.Fatalf("unexpected error reading headers: %s", err)
}
if err := r.readLimitBody(br, 10000, false, false); err != nil {
t.Fatalf("unexpected error reading body: %s", err)
}
if r.multipartForm != nil {
t.Fatalf("The multipartForm of the Request must be nil")
}
if string(formData) != string(r.Body()) {
t.Fatalf("The body given must equal the body in the Request")
}
}
func TestRequestMayContinue(t *testing.T) {
t.Parallel()
var r Request
if r.MayContinue() {
t.Fatalf("MayContinue on empty request must return false")
}
r.Header.Set("Expect", "123sdfds")
if r.MayContinue() {
t.Fatalf("MayContinue on invalid Expect header must return false")
}
r.Header.Set("Expect", "100-continue")
if !r.MayContinue() {
t.Fatalf("MayContinue on 'Expect: 100-continue' header must return true")
}
}
func TestResponseGzipStream(t *testing.T) {
t.Parallel()
var r Response
if r.IsBodyStream() {
t.Fatalf("IsBodyStream must return false")
}
r.SetBodyStreamWriter(func(w *bufio.Writer) {
fmt.Fprintf(w, "foo")
w.Flush()
time.Sleep(time.Millisecond)
w.Write([]byte("barbaz")) //nolint:errcheck
w.Flush() //nolint:errcheck
time.Sleep(time.Millisecond)
fmt.Fprintf(w, "1234") //nolint:errcheck
if err := w.Flush(); err != nil {
t.Fatalf("unexpected error: %s", err)
}
})
if !r.IsBodyStream() {
t.Fatalf("IsBodyStream must return true")
}
testResponseGzipExt(t, &r, "foobarbaz1234")
}
func TestResponseDeflateStream(t *testing.T) {
t.Parallel()
var r Response
if r.IsBodyStream() {
t.Fatalf("IsBodyStream must return false")
}
r.SetBodyStreamWriter(func(w *bufio.Writer) {
w.Write([]byte("foo")) //nolint:errcheck
w.Flush() //nolint:errcheck
fmt.Fprintf(w, "barbaz") //nolint:errcheck
w.Flush() //nolint:errcheck
w.Write([]byte("1234")) //nolint:errcheck
if err := w.Flush(); err != nil {
t.Fatalf("unexpected error: %s", err)
}
})
if !r.IsBodyStream() {
t.Fatalf("IsBodyStream must return true")
}
testResponseDeflateExt(t, &r, "foobarbaz1234")
}
func TestResponseDeflate(t *testing.T) {
t.Parallel()
for _, s := range compressTestcases {
testResponseDeflate(t, s)
}
}
func TestResponseGzip(t *testing.T) {
t.Parallel()
for _, s := range compressTestcases {
testResponseGzip(t, s)
}
}
func testResponseDeflate(t *testing.T, s string) {
var r Response
r.SetBodyString(s)
testResponseDeflateExt(t, &r, s)
// make sure the uncompressible Content-Type isn't compressed
r.Reset()
r.Header.SetContentType("image/jpeg")
r.SetBodyString(s)
testResponseDeflateExt(t, &r, s)
}
func testResponseDeflateExt(t *testing.T, r *Response, s string) {
isCompressible := isCompressibleResponse(r, s)
var buf bytes.Buffer
var err error
bw := bufio.NewWriter(&buf)
if err = r.WriteDeflate(bw); err != nil {
t.Fatalf("unexpected error: %s", err)
}
if err = bw.Flush(); err != nil {
t.Fatalf("unexpected error: %s", err)
}
var r1 Response
br := bufio.NewReader(&buf)
if err = r1.Read(br); err != nil {
t.Fatalf("unexpected error: %s", err)
}
ce := r1.Header.Peek(HeaderContentEncoding)
var body []byte
if isCompressible {
if string(ce) != "deflate" {
t.Fatalf("unexpected Content-Encoding %q. Expecting %q. len(s)=%d, Content-Type: %q",
ce, "deflate", len(s), r.Header.ContentType())
}
body, err = r1.BodyInflate()
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
} else {
if len(ce) > 0 {
t.Fatalf("expecting empty Content-Encoding. Got %q", ce)
}
body = r1.Body()
}
if string(body) != s {
t.Fatalf("unexpected body %q. Expecting %q", body, s)
}
}
func testResponseGzip(t *testing.T, s string) {
var r Response
r.SetBodyString(s)
testResponseGzipExt(t, &r, s)
// make sure the uncompressible Content-Type isn't compressed
r.Reset()
r.Header.SetContentType("image/jpeg")
r.SetBodyString(s)
testResponseGzipExt(t, &r, s)
}
func testResponseGzipExt(t *testing.T, r *Response, s string) {
isCompressible := isCompressibleResponse(r, s)
var buf bytes.Buffer
var err error
bw := bufio.NewWriter(&buf)
if err = r.WriteGzip(bw); err != nil {
t.Fatalf("unexpected error: %s", err)
}
if err = bw.Flush(); err != nil {
t.Fatalf("unexpected error: %s", err)
}
var r1 Response
br := bufio.NewReader(&buf)
if err = r1.Read(br); err != nil {
t.Fatalf("unexpected error: %s", err)
}
ce := r1.Header.Peek(HeaderContentEncoding)
var body []byte
if isCompressible {
if string(ce) != "gzip" {
t.Fatalf("unexpected Content-Encoding %q. Expecting %q. len(s)=%d, Content-Type: %q",
ce, "gzip", len(s), r.Header.ContentType())
}
body, err = r1.BodyGunzip()
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
} else {
if len(ce) > 0 {
t.Fatalf("Expecting empty Content-Encoding. Got %q", ce)
}
body = r1.Body()
}
if string(body) != s {
t.Fatalf("unexpected body %q. Expecting %q", body, s)
}
}
func isCompressibleResponse(r *Response, s string) bool {
isCompressible := r.Header.isCompressibleContentType()
if isCompressible && len(s) < minCompressLen && !r.IsBodyStream() {
isCompressible = false
}
return isCompressible
}
func TestRequestMultipartForm(t *testing.T) {
t.Parallel()
var w bytes.Buffer
mw := multipart.NewWriter(&w)
for i := 0; i < 10; i++ {
k := fmt.Sprintf("key_%d", i)
v := fmt.Sprintf("value_%d", i)
if err := mw.WriteField(k, v); err != nil {
t.Fatalf("unexpected error: %s", err)
}
}
boundary := mw.Boundary()
if err := mw.Close(); err != nil {
t.Fatalf("unexpected error: %s", err)
}
formData := w.Bytes()
for i := 0; i < 5; i++ {
formData = testRequestMultipartForm(t, boundary, formData, 10)
}
// verify request unmarshalling / marshalling
s := "POST / HTTP/1.1\r\nHost: aaa\r\nContent-Type: multipart/form-data; boundary=foobar\r\nContent-Length: 213\r\n\r\n--foobar\r\nContent-Disposition: form-data; name=\"key_0\"\r\n\r\nvalue_0\r\n--foobar\r\nContent-Disposition: form-data; name=\"key_1\"\r\n\r\nvalue_1\r\n--foobar\r\nContent-Disposition: form-data; name=\"key_2\"\r\n\r\nvalue_2\r\n--foobar--\r\n"
var req Request
br := bufio.NewReader(bytes.NewBufferString(s))
if err := req.Read(br); err != nil {
t.Fatalf("unexpected error: %s", err)
}
s = req.String()
br = bufio.NewReader(bytes.NewBufferString(s))
if err := req.Read(br); err != nil {
t.Fatalf("unexpected error: %s", err)
}
testRequestMultipartForm(t, "foobar", req.Body(), 3)
}
func testRequestMultipartForm(t *testing.T, boundary string, formData []byte, partsCount int) []byte {
s := fmt.Sprintf("POST / HTTP/1.1\r\nHost: aaa\r\nContent-Type: multipart/form-data; boundary=%s\r\nContent-Length: %d\r\n\r\n%s",
boundary, len(formData), formData)
var req Request
r := bytes.NewBufferString(s)
br := bufio.NewReader(r)
if err := req.Read(br); err != nil {
t.Fatalf("unexpected error: %s", err)
}
f, err := req.MultipartForm()
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
defer req.RemoveMultipartFormFiles()
if len(f.File) > 0 {
t.Fatalf("unexpected files found in the multipart form: %d", len(f.File))
}
if len(f.Value) != partsCount {
t.Fatalf("unexpected number of values found: %d. Expecting %d", len(f.Value), partsCount)
}
for k, vv := range f.Value {
if len(vv) != 1 {
t.Fatalf("unexpected number of values found for key=%q: %d. Expecting 1", k, len(vv))
}
if !strings.HasPrefix(k, "key_") {
t.Fatalf("unexpected key prefix=%q. Expecting %q", k, "key_")
}
v := vv[0]
if !strings.HasPrefix(v, "value_") {
t.Fatalf("unexpected value prefix=%q. expecting %q", v, "value_")
}
if k[len("key_"):] != v[len("value_"):] {
t.Fatalf("key and value suffixes don't match: %q vs %q", k, v)
}
}
return req.Body()
}
func TestResponseReadLimitBody(t *testing.T) {
t.Parallel()
// response with content-length
testResponseReadLimitBodySuccess(t, "HTTP/1.1 200 OK\r\nContent-Type: aa\r\nContent-Length: 10\r\n\r\n9876543210", 10)
testResponseReadLimitBodySuccess(t, "HTTP/1.1 200 OK\r\nContent-Type: aa\r\nContent-Length: 10\r\n\r\n9876543210", 100)
testResponseReadLimitBodyError(t, "HTTP/1.1 200 OK\r\nContent-Type: aa\r\nContent-Length: 10\r\n\r\n9876543210", 9)
// chunked response
testResponseReadLimitBodySuccess(t, "HTTP/1.1 200 OK\r\nContent-Type: aa\r\nTransfer-Encoding: chunked\r\n\r\n6\r\nfoobar\r\n3\r\nbaz\r\n0\r\n\r\n", 9)
testResponseReadLimitBodySuccess(t, "HTTP/1.1 200 OK\r\nContent-Type: aa\r\nTransfer-Encoding: chunked\r\n\r\n6\r\nfoobar\r\n3\r\nbaz\r\n0\r\n\r\n", 100)
testResponseReadLimitBodyError(t, "HTTP/1.1 200 OK\r\nContent-Type: aa\r\nTransfer-Encoding: chunked\r\n\r\n6\r\nfoobar\r\n3\r\nbaz\r\n0\r\n\r\n", 2)
// identity response
testResponseReadLimitBodySuccess(t, "HTTP/1.1 400 OK\r\nContent-Type: aa\r\n\r\n123456", 6)
testResponseReadLimitBodySuccess(t, "HTTP/1.1 400 OK\r\nContent-Type: aa\r\n\r\n123456", 106)
testResponseReadLimitBodyError(t, "HTTP/1.1 400 OK\r\nContent-Type: aa\r\n\r\n123456", 5)
}
func TestRequestReadLimitBody(t *testing.T) {
t.Parallel()
// request with content-length
testRequestReadLimitBodySuccess(t, "POST /foo HTTP/1.1\r\nHost: aaa.com\r\nContent-Length: 9\r\nContent-Type: aaa\r\n\r\n123456789", 9)
testRequestReadLimitBodySuccess(t, "POST /foo HTTP/1.1\r\nHost: aaa.com\r\nContent-Length: 9\r\nContent-Type: aaa\r\n\r\n123456789", 92)
testRequestReadLimitBodyError(t, "POST /foo HTTP/1.1\r\nHost: aaa.com\r\nContent-Length: 9\r\nContent-Type: aaa\r\n\r\n123456789", 5)
// chunked request
testRequestReadLimitBodySuccess(t, "POST /a HTTP/1.1\r\nHost: a.com\r\nTransfer-Encoding: chunked\r\nContent-Type: aa\r\n\r\n6\r\nfoobar\r\n3\r\nbaz\r\n0\r\n\r\n", 9)
testRequestReadLimitBodySuccess(t, "POST /a HTTP/1.1\r\nHost: a.com\r\nTransfer-Encoding: chunked\r\nContent-Type: aa\r\n\r\n6\r\nfoobar\r\n3\r\nbaz\r\n0\r\n\r\n", 999)
testRequestReadLimitBodyError(t, "POST /a HTTP/1.1\r\nHost: a.com\r\nTransfer-Encoding: chunked\r\nContent-Type: aa\r\n\r\n6\r\nfoobar\r\n3\r\nbaz\r\n0\r\n\r\n", 8)
}
func testResponseReadLimitBodyError(t *testing.T, s string, maxBodySize int) {
var req Response
r := bytes.NewBufferString(s)
br := bufio.NewReader(r)
err := req.ReadLimitBody(br, maxBodySize)
if err == nil {
t.Fatalf("expecting error. s=%q, maxBodySize=%d", s, maxBodySize)
}
if err != ErrBodyTooLarge {
t.Fatalf("unexpected error: %s. Expecting %s. s=%q, maxBodySize=%d", err, ErrBodyTooLarge, s, maxBodySize)
}
}
func testResponseReadLimitBodySuccess(t *testing.T, s string, maxBodySize int) {
var req Response
r := bytes.NewBufferString(s)
br := bufio.NewReader(r)
if err := req.ReadLimitBody(br, maxBodySize); err != nil {
t.Fatalf("unexpected error: %s. s=%q, maxBodySize=%d", err, s, maxBodySize)
}
}
func testRequestReadLimitBodyError(t *testing.T, s string, maxBodySize int) {
var req Request
r := bytes.NewBufferString(s)
br := bufio.NewReader(r)
err := req.ReadLimitBody(br, maxBodySize)
if err == nil {
t.Fatalf("expecting error. s=%q, maxBodySize=%d", s, maxBodySize)
}
if err != ErrBodyTooLarge {
t.Fatalf("unexpected error: %s. Expecting %s. s=%q, maxBodySize=%d", err, ErrBodyTooLarge, s, maxBodySize)
}
}
func testRequestReadLimitBodySuccess(t *testing.T, s string, maxBodySize int) {
var req Request
r := bytes.NewBufferString(s)
br := bufio.NewReader(r)
if err := req.ReadLimitBody(br, maxBodySize); err != nil {
t.Fatalf("unexpected error: %s. s=%q, maxBodySize=%d", err, s, maxBodySize)
}
}
func TestRequestString(t *testing.T) {
t.Parallel()
var r Request
r.SetRequestURI("http://foobar.com/aaa")
s := r.String()
expectedS := "GET /aaa HTTP/1.1\r\nHost: foobar.com\r\n\r\n"
if s != expectedS {
t.Fatalf("unexpected request: %q. Expecting %q", s, expectedS)
}
}
func TestRequestBodyWriter(t *testing.T) {
var r Request
w := r.BodyWriter()
for i := 0; i < 10; i++ {
fmt.Fprintf(w, "%d", i)
}
if string(r.Body()) != "0123456789" {
t.Fatalf("unexpected body %q. Expecting %q", r.Body(), "0123456789")
}
}
func TestResponseBodyWriter(t *testing.T) {
t.Parallel()
var r Response
w := r.BodyWriter()
for i := 0; i < 10; i++ {
fmt.Fprintf(w, "%d", i)
}
if string(r.Body()) != "0123456789" {
t.Fatalf("unexpected body %q. Expecting %q", r.Body(), "0123456789")
}
}
func TestRequestWriteRequestURINoHost(t *testing.T) {
t.Parallel()
var req Request
req.Header.SetRequestURI("http://google.com/foo/bar?baz=aaa")
var w bytes.Buffer
bw := bufio.NewWriter(&w)
if err := req.Write(bw); err != nil {
t.Fatalf("unexpected error: %s", err)
}
if err := bw.Flush(); err != nil {
t.Fatalf("unexepcted error: %s", err)
}
var req1 Request
br := bufio.NewReader(&w)
if err := req1.Read(br); err != nil {
t.Fatalf("unexpected error: %s", err)
}
if string(req1.Header.Host()) != "google.com" {
t.Fatalf("unexpected host: %q. Expecting %q", req1.Header.Host(), "google.com")
}
if string(req.Header.RequestURI()) != "/foo/bar?baz=aaa" {
t.Fatalf("unexpected requestURI: %q. Expecting %q", req.Header.RequestURI(), "/foo/bar?baz=aaa")
}
// verify that Request.Write returns error on non-absolute RequestURI
req.Reset()
req.Header.SetRequestURI("/foo/bar")
w.Reset()
bw.Reset(&w)
if err := req.Write(bw); err == nil {
t.Fatalf("expecting error")
}
}
func TestSetRequestBodyStreamFixedSize(t *testing.T) {
t.Parallel()
testSetRequestBodyStream(t, "a", false)
testSetRequestBodyStream(t, string(createFixedBody(4097)), false)
testSetRequestBodyStream(t, string(createFixedBody(100500)), false)
}
func TestSetResponseBodyStreamFixedSize(t *testing.T) {
t.Parallel()
testSetResponseBodyStream(t, "a", false)
testSetResponseBodyStream(t, string(createFixedBody(4097)), false)
testSetResponseBodyStream(t, string(createFixedBody(100500)), false)
}
func TestSetRequestBodyStreamChunked(t *testing.T) {
t.Parallel()
testSetRequestBodyStream(t, "", true)
body := "foobar baz aaa bbb ccc"
testSetRequestBodyStream(t, body, true)
body = string(createFixedBody(10001))
testSetRequestBodyStream(t, body, true)
}
func TestSetResponseBodyStreamChunked(t *testing.T) {
t.Parallel()
testSetResponseBodyStream(t, "", true)
body := "foobar baz aaa bbb ccc"
testSetResponseBodyStream(t, body, true)
body = string(createFixedBody(10001))
testSetResponseBodyStream(t, body, true)
}
func testSetRequestBodyStream(t *testing.T, body string, chunked bool) {
var req Request
req.Header.SetHost("foobar.com")
req.Header.SetMethod(MethodPost)
bodySize := len(body)
if chunked {
bodySize = -1
}
if req.IsBodyStream() {
t.Fatalf("IsBodyStream must return false")
}
req.SetBodyStream(bytes.NewBufferString(body), bodySize)
if !req.IsBodyStream() {
t.Fatalf("IsBodyStream must return true")
}
var w bytes.Buffer
bw := bufio.NewWriter(&w)
if err := req.Write(bw); err != nil {
t.Fatalf("unexpected error when writing request: %s. body=%q", err, body)
}
if err := bw.Flush(); err != nil {
t.Fatalf("unexpected error when flushing request: %s. body=%q", err, body)
}
var req1 Request
br := bufio.NewReader(&w)
if err := req1.Read(br); err != nil {
t.Fatalf("unexpected error when reading request: %s. body=%q", err, body)
}
if string(req1.Body()) != body {
t.Fatalf("unexpected body %q. Expecting %q", req1.Body(), body)
}
}
func testSetResponseBodyStream(t *testing.T, body string, chunked bool) {
var resp Response
bodySize := len(body)
if chunked {
bodySize = -1
}
if resp.IsBodyStream() {
t.Fatalf("IsBodyStream must return false")
}
resp.SetBodyStream(bytes.NewBufferString(body), bodySize)
if !resp.IsBodyStream() {
t.Fatalf("IsBodyStream must return true")
}
var w bytes.Buffer
bw := bufio.NewWriter(&w)
if err := resp.Write(bw); err != nil {
t.Fatalf("unexpected error when writing response: %s. body=%q", err, body)
}
if err := bw.Flush(); err != nil {
t.Fatalf("unexpected error when flushing response: %s. body=%q", err, body)
}
var resp1 Response
br := bufio.NewReader(&w)
if err := resp1.Read(br); err != nil {
t.Fatalf("unexpected error when reading response: %s. body=%q", err, body)
}
if string(resp1.Body()) != body {
t.Fatalf("unexpected body %q. Expecting %q", resp1.Body(), body)
}
}
func TestRound2(t *testing.T) {
t.Parallel()
testRound2(t, 0, 0)
testRound2(t, 1, 1)
testRound2(t, 2, 2)
testRound2(t, 3, 4)
testRound2(t, 4, 4)
testRound2(t, 5, 8)
testRound2(t, 7, 8)
testRound2(t, 8, 8)
testRound2(t, 9, 16)
testRound2(t, 0x10001, 0x20000)
}
func testRound2(t *testing.T, n, expectedRound2 int) {
if round2(n) != expectedRound2 {
t.Fatalf("Unexpected round2(%d)=%d. Expected %d", n, round2(n), expectedRound2)
}
}
func TestRequestReadChunked(t *testing.T) {
t.Parallel()
var req Request
s := "POST /foo HTTP/1.1\r\nHost: google.com\r\nTransfer-Encoding: chunked\r\nContent-Type: aa/bb\r\n\r\n3\r\nabc\r\n5\r\n12345\r\n0\r\n\r\ntrail"
r := bytes.NewBufferString(s)
rb := bufio.NewReader(r)
err := req.Read(rb)
if err != nil {
t.Fatalf("Unexpected error when reading chunked request: %s", err)
}
expectedBody := "abc12345"
if string(req.Body()) != expectedBody {
t.Fatalf("Unexpected body %q. Expected %q", req.Body(), expectedBody)
}
verifyRequestHeader(t, &req.Header, 8, "/foo", "google.com", "", "aa/bb")
verifyTrailer(t, rb, "trail")
}
// See: https://github.com/erikdubbelboer/fasthttp/issues/34
func TestRequestChunkedWhitespace(t *testing.T) {
t.Parallel()
var req Request
s := "POST /foo HTTP/1.1\r\nHost: google.com\r\nTransfer-Encoding: chunked\r\nContent-Type: aa/bb\r\n\r\n3 \r\nabc\r\n0\r\n\r\n"
r := bytes.NewBufferString(s)
rb := bufio.NewReader(r)
err := req.Read(rb)
if err != nil {
t.Fatalf("Unexpected error when reading chunked request: %s", err)
}
expectedBody := "abc"
if string(req.Body()) != expectedBody {
t.Fatalf("Unexpected body %q. Expected %q", req.Body(), expectedBody)
}
}
func TestResponseReadWithoutBody(t *testing.T) {
t.Parallel()
var resp Response
testResponseReadWithoutBody(t, &resp, "HTTP/1.1 304 Not Modified\r\nContent-Type: aa\r\nContent-Length: 1235\r\n\r\nfoobar", false,
304, 1235, "aa", "foobar")
testResponseReadWithoutBody(t, &resp, "HTTP/1.1 204 Foo Bar\r\nContent-Type: aab\r\nTransfer-Encoding: chunked\r\n\r\n123\r\nss", false,
204, -1, "aab", "123\r\nss")
testResponseReadWithoutBody(t, &resp, "HTTP/1.1 123 AAA\r\nContent-Type: xxx\r\nContent-Length: 3434\r\n\r\naaaa", false,
123, 3434, "xxx", "aaaa")
testResponseReadWithoutBody(t, &resp, "HTTP 200 OK\r\nContent-Type: text/xml\r\nContent-Length: 123\r\n\r\nxxxx", true,
200, 123, "text/xml", "xxxx")
// '100 Continue' must be skipped.
testResponseReadWithoutBody(t, &resp, "HTTP/1.1 100 Continue\r\nFoo-bar: baz\r\n\r\nHTTP/1.1 329 aaa\r\nContent-Type: qwe\r\nContent-Length: 894\r\n\r\nfoobar", true,
329, 894, "qwe", "foobar")
}
func testResponseReadWithoutBody(t *testing.T, resp *Response, s string, skipBody bool,
expectedStatusCode, expectedContentLength int, expectedContentType, expectedTrailer string) {
r := bytes.NewBufferString(s)
rb := bufio.NewReader(r)
resp.SkipBody = skipBody
err := resp.Read(rb)
if err != nil {
t.Fatalf("Unexpected error when reading response without body: %s. response=%q", err, s)
}
if len(resp.Body()) != 0 {
t.Fatalf("Unexpected response body %q. Expected %q. response=%q", resp.Body(), "", s)
}
verifyResponseHeader(t, &resp.Header, expectedStatusCode, expectedContentLength, expectedContentType)
verifyTrailer(t, rb, expectedTrailer)
// verify that an ordinary response is read after the null-body response
resp.SkipBody = false
testResponseReadSuccess(t, resp, "HTTP/1.1 300 OK\r\nContent-Length: 5\r\nContent-Type: bar\r\n\r\n56789aaa",
300, 5, "bar", "56789", "aaa")
}
func TestRequestSuccess(t *testing.T) {
t.Parallel()
// empty method, user-agent and body
testRequestSuccess(t, "", "/foo/bar", "google.com", "", "", MethodGet)
// non-empty user-agent
testRequestSuccess(t, MethodGet, "/foo/bar", "google.com", "MSIE", "", MethodGet)
// non-empty method
testRequestSuccess(t, MethodHead, "/aaa", "fobar", "", "", MethodHead)
// POST method with body
testRequestSuccess(t, MethodPost, "/bbb", "aaa.com", "Chrome aaa", "post body", MethodPost)
// PUT method with body
testRequestSuccess(t, MethodPut, "/aa/bb", "a.com", "ome aaa", "put body", MethodPut)
// only host is set
testRequestSuccess(t, "", "", "gooble.com", "", "", MethodGet)
// get with body
testRequestSuccess(t, MethodGet, "/foo/bar", "aaa.com", "", "foobar", MethodGet)
}
func TestResponseSuccess(t *testing.T) {
t.Parallel()
// 200 response
testResponseSuccess(t, 200, "test/plain", "server", "foobar",
200, "test/plain", "server")
// response with missing statusCode
testResponseSuccess(t, 0, "text/plain", "server", "foobar",
200, "text/plain", "server")
// response with missing server
testResponseSuccess(t, 500, "aaa", "", "aaadfsd",
500, "aaa", "")
// empty body
testResponseSuccess(t, 200, "bbb", "qwer", "",
200, "bbb", "qwer")
// missing content-type
testResponseSuccess(t, 200, "", "asdfsd", "asdf",
200, string(defaultContentType), "asdfsd")
}
func testResponseSuccess(t *testing.T, statusCode int, contentType, serverName, body string,
expectedStatusCode int, expectedContentType, expectedServerName string) {
var resp Response
resp.SetStatusCode(statusCode)
resp.Header.Set("Content-Type", contentType)
resp.Header.Set("Server", serverName)
resp.SetBody([]byte(body))
w := &bytes.Buffer{}
bw := bufio.NewWriter(w)
err := resp.Write(bw)
if err != nil {
t.Fatalf("Unexpected error when calling Response.Write(): %s", err)
}
if err = bw.Flush(); err != nil {
t.Fatalf("Unexpected error when flushing bufio.Writer: %s", err)
}
var resp1 Response
br := bufio.NewReader(w)
if err = resp1.Read(br); err != nil {
t.Fatalf("Unexpected error when calling Response.Read(): %s", err)
}
if resp1.StatusCode() != expectedStatusCode {
t.Fatalf("Unexpected status code: %d. Expected %d", resp1.StatusCode(), expectedStatusCode)
}
if resp1.Header.ContentLength() != len(body) {
t.Fatalf("Unexpected content-length: %d. Expected %d", resp1.Header.ContentLength(), len(body))
}
if string(resp1.Header.Peek(HeaderContentType)) != expectedContentType {
t.Fatalf("Unexpected content-type: %q. Expected %q", resp1.Header.Peek(HeaderContentType), expectedContentType)
}
if string(resp1.Header.Peek(HeaderServer)) != expectedServerName {
t.Fatalf("Unexpected server: %q. Expected %q", resp1.Header.Peek(HeaderServer), expectedServerName)
}
if !bytes.Equal(resp1.Body(), []byte(body)) {
t.Fatalf("Unexpected body: %q. Expected %q", resp1.Body(), body)
}
}
func TestRequestWriteError(t *testing.T) {
t.Parallel()
// no host
testRequestWriteError(t, "", "/foo/bar", "", "", "")
}
func testRequestWriteError(t *testing.T, method, requestURI, host, userAgent, body string) {
var req Request
req.Header.SetMethod(method)
req.Header.SetRequestURI(requestURI)
req.Header.Set(HeaderHost, host)
req.Header.Set(HeaderUserAgent, userAgent)
req.SetBody([]byte(body))
w := &bytebufferpool.ByteBuffer{}
bw := bufio.NewWriter(w)
err := req.Write(bw)
if err == nil {
t.Fatalf("Expecting error when writing request=%#v", &req)
}
}
func testRequestSuccess(t *testing.T, method, requestURI, host, userAgent, body, expectedMethod string) {
var req Request
req.Header.SetMethod(method)
req.Header.SetRequestURI(requestURI)
req.Header.Set(HeaderHost, host)
req.Header.Set(HeaderUserAgent, userAgent)
req.SetBody([]byte(body))
contentType := "foobar"
if method == MethodPost {
req.Header.Set(HeaderContentType, contentType)
}
w := &bytes.Buffer{}
bw := bufio.NewWriter(w)
err := req.Write(bw)
if err != nil {
t.Fatalf("Unexpected error when calling Request.Write(): %s", err)
}
if err = bw.Flush(); err != nil {
t.Fatalf("Unexpected error when flushing bufio.Writer: %s", err)
}
var req1 Request
br := bufio.NewReader(w)
if err = req1.Read(br); err != nil {
t.Fatalf("Unexpected error when calling Request.Read(): %s", err)
}
if string(req1.Header.Method()) != expectedMethod {
t.Fatalf("Unexpected method: %q. Expected %q", req1.Header.Method(), expectedMethod)
}
if len(requestURI) == 0 {
requestURI = "/"
}
if string(req1.Header.RequestURI()) != requestURI {
t.Fatalf("Unexpected RequestURI: %q. Expected %q", req1.Header.RequestURI(), requestURI)
}
if string(req1.Header.Peek(HeaderHost)) != host {
t.Fatalf("Unexpected host: %q. Expected %q", req1.Header.Peek(HeaderHost), host)
}
if string(req1.Header.Peek(HeaderUserAgent)) != userAgent {
t.Fatalf("Unexpected user-agent: %q. Expected %q", req1.Header.Peek(HeaderUserAgent), userAgent)
}
if !bytes.Equal(req1.Body(), []byte(body)) {
t.Fatalf("Unexpected body: %q. Expected %q", req1.Body(), body)
}
if method == MethodPost && string(req1.Header.Peek(HeaderContentType)) != contentType {
t.Fatalf("Unexpected content-type: %q. Expected %q", req1.Header.Peek(HeaderContentType), contentType)
}
}
func TestResponseReadSuccess(t *testing.T) {
t.Parallel()
resp := &Response{}
// usual response
testResponseReadSuccess(t, resp, "HTTP/1.1 200 OK\r\nContent-Length: 10\r\nContent-Type: foo/bar\r\n\r\n0123456789",
200, 10, "foo/bar", "0123456789", "")
// zero response
testResponseReadSuccess(t, resp, "HTTP/1.1 500 OK\r\nContent-Length: 0\r\nContent-Type: foo/bar\r\n\r\n",
500, 0, "foo/bar", "", "")
// response with trailer
testResponseReadSuccess(t, resp, "HTTP/1.1 300 OK\r\nContent-Length: 5\r\nContent-Type: bar\r\n\r\n56789aaa",
300, 5, "bar", "56789", "aaa")
// no content-length ('identity' transfer-encoding)
testResponseReadSuccess(t, resp, "HTTP/1.1 200 OK\r\nContent-Type: foobar\r\n\r\nzxxc",
200, 4, "foobar", "zxxc", "")
// explicitly stated 'Transfer-Encoding: identity'
testResponseReadSuccess(t, resp, "HTTP/1.1 234 ss\r\nContent-Type: xxx\r\n\r\nxag",
234, 3, "xxx", "xag", "")
// big 'identity' response
body := string(createFixedBody(100500))
testResponseReadSuccess(t, resp, "HTTP/1.1 200 OK\r\nContent-Type: aa\r\n\r\n"+body,
200, 100500, "aa", body, "")
// chunked response
testResponseReadSuccess(t, resp, "HTTP/1.1 200 OK\r\nContent-Type: text/html\r\nTransfer-Encoding: chunked\r\n\r\n4\r\nqwer\r\n2\r\nty\r\n0\r\n\r\nzzzzz",
200, 6, "text/html", "qwerty", "zzzzz")
// chunked response with non-chunked Transfer-Encoding.
testResponseReadSuccess(t, resp, "HTTP/1.1 230 OK\r\nContent-Type: text\r\nTransfer-Encoding: aaabbb\r\n\r\n2\r\ner\r\n2\r\nty\r\n0\r\n\r\nwe",
230, 4, "text", "erty", "we")
// zero chunked response
testResponseReadSuccess(t, resp, "HTTP/1.1 200 OK\r\nContent-Type: text/html\r\nTransfer-Encoding: chunked\r\n\r\n0\r\n\r\nzzz",
200, 0, "text/html", "", "zzz")
}
func TestResponseReadError(t *testing.T) {
t.Parallel()
resp := &Response{}
// empty response
testResponseReadError(t, resp, "")
// invalid header
testResponseReadError(t, resp, "foobar")
// empty body
testResponseReadError(t, resp, "HTTP/1.1 200 OK\r\nContent-Type: aaa\r\nContent-Length: 1234\r\n\r\n")
// short body
testResponseReadError(t, resp, "HTTP/1.1 200 OK\r\nContent-Type: aaa\r\nContent-Length: 1234\r\n\r\nshort")
}
func testResponseReadError(t *testing.T, resp *Response, response string) {
r := bytes.NewBufferString(response)
rb := bufio.NewReader(r)
err := resp.Read(rb)
if err == nil {
t.Fatalf("Expecting error for response=%q", response)
}
testResponseReadSuccess(t, resp, "HTTP/1.1 303 Redisred sedfs sdf\r\nContent-Type: aaa\r\nContent-Length: 5\r\n\r\nHELLOaaa",
303, 5, "aaa", "HELLO", "aaa")
}
func testResponseReadSuccess(t *testing.T, resp *Response, response string, expectedStatusCode, expectedContentLength int,
expectedContentType, expectedBody, expectedTrailer string) {
r := bytes.NewBufferString(response)
rb := bufio.NewReader(r)
err := resp.Read(rb)
if err != nil {
t.Fatalf("Unexpected error: %s", err)
}
verifyResponseHeader(t, &resp.Header, expectedStatusCode, expectedContentLength, expectedContentType)
if !bytes.Equal(resp.Body(), []byte(expectedBody)) {
t.Fatalf("Unexpected body %q. Expected %q", resp.Body(), []byte(expectedBody))
}
verifyTrailer(t, rb, expectedTrailer)
}
func TestReadBodyFixedSize(t *testing.T) {
t.Parallel()
// zero-size body
testReadBodyFixedSize(t, 0)
// small-size body
testReadBodyFixedSize(t, 3)
// medium-size body
testReadBodyFixedSize(t, 1024)
// large-size body
testReadBodyFixedSize(t, 1024*1024)
// smaller body after big one
testReadBodyFixedSize(t, 34345)
}
func TestReadBodyChunked(t *testing.T) {
t.Parallel()
// zero-size body
testReadBodyChunked(t, 0)
// small-size body
testReadBodyChunked(t, 5)
// medium-size body
testReadBodyChunked(t, 43488)
// big body
testReadBodyChunked(t, 3*1024*1024)
// smaller body after big one
testReadBodyChunked(t, 12343)
}
func TestRequestURITLS(t *testing.T) {
t.Parallel()
uriNoScheme := "//foobar.com/baz/aa?bb=dd&dd#sdf"
requestURI := "http:" + uriNoScheme
requestURITLS := "https:" + uriNoScheme
var req Request
req.isTLS = true
req.SetRequestURI(requestURI)
uri := req.URI().String()
if uri != requestURITLS {
t.Fatalf("unexpected request uri: %q. Expecting %q", uri, requestURITLS)
}
req.Reset()
req.SetRequestURI(requestURI)
uri = req.URI().String()
if uri != requestURI {
t.Fatalf("unexpected request uri: %q. Expecting %q", uri, requestURI)
}
}
func TestRequestURI(t *testing.T) {
t.Parallel()
host := "foobar.com"
requestURI := "/aaa/bb+b%20d?ccc=ddd&qqq#1334dfds&=d"
expectedPathOriginal := "/aaa/bb+b%20d"
expectedPath := "/aaa/bb+b d"
expectedQueryString := "ccc=ddd&qqq"
expectedHash := "1334dfds&=d"
var req Request
req.Header.Set(HeaderHost, host)
req.Header.SetRequestURI(requestURI)
uri := req.URI()
if string(uri.Host()) != host {
t.Fatalf("Unexpected host %q. Expected %q", uri.Host(), host)
}
if string(uri.PathOriginal()) != expectedPathOriginal {
t.Fatalf("Unexpected source path %q. Expected %q", uri.PathOriginal(), expectedPathOriginal)
}
if string(uri.Path()) != expectedPath {
t.Fatalf("Unexpected path %q. Expected %q", uri.Path(), expectedPath)
}
if string(uri.QueryString()) != expectedQueryString {
t.Fatalf("Unexpected query string %q. Expected %q", uri.QueryString(), expectedQueryString)
}
if string(uri.Hash()) != expectedHash {
t.Fatalf("Unexpected hash %q. Expected %q", uri.Hash(), expectedHash)
}
}
func TestRequestPostArgsSuccess(t *testing.T) {
t.Parallel()
var req Request
testRequestPostArgsSuccess(t, &req, "POST / HTTP/1.1\r\nHost: aaa.com\r\nContent-Type: application/x-www-form-urlencoded\r\nContent-Length: 0\r\n\r\n", 0, "foo=", "=")
testRequestPostArgsSuccess(t, &req, "POST / HTTP/1.1\r\nHost: aaa.com\r\nContent-Type: application/x-www-form-urlencoded\r\nContent-Length: 18\r\n\r\nfoo&b%20r=b+z=&qwe", 3, "foo=", "b r=b z=", "qwe=")
}
func TestRequestPostArgsError(t *testing.T) {
t.Parallel()
var req Request
// non-post
testRequestPostArgsError(t, &req, "GET /aa HTTP/1.1\r\nHost: aaa\r\n\r\n")
// invalid content-type
testRequestPostArgsError(t, &req, "POST /aa HTTP/1.1\r\nHost: aaa\r\nContent-Type: text/html\r\nContent-Length: 5\r\n\r\nabcde")
}
func testRequestPostArgsError(t *testing.T, req *Request, s string) {
r := bytes.NewBufferString(s)
br := bufio.NewReader(r)
err := req.Read(br)
if err != nil {
t.Fatalf("Unexpected error when reading %q: %s", s, err)
}
ss := req.PostArgs().String()
if len(ss) != 0 {
t.Fatalf("unexpected post args: %q. Expecting empty post args", ss)
}
}
func testRequestPostArgsSuccess(t *testing.T, req *Request, s string, expectedArgsLen int, expectedArgs ...string) {
r := bytes.NewBufferString(s)
br := bufio.NewReader(r)
err := req.Read(br)
if err != nil {
t.Fatalf("Unexpected error when reading %q: %s", s, err)
}
args := req.PostArgs()
if args.Len() != expectedArgsLen {
t.Fatalf("Unexpected args len %d. Expected %d for %q", args.Len(), expectedArgsLen, s)
}
for _, x := range expectedArgs {
tmp := strings.SplitN(x, "=", 2)
k := tmp[0]
v := tmp[1]
vv := string(args.Peek(k))
if vv != v {
t.Fatalf("Unexpected value for key %q: %q. Expected %q for %q", k, vv, v, s)
}
}
}
func testReadBodyChunked(t *testing.T, bodySize int) {
body := createFixedBody(bodySize)
chunkedBody := createChunkedBody(body)
expectedTrailer := []byte("chunked shit")
chunkedBody = append(chunkedBody, expectedTrailer...)
r := bytes.NewBuffer(chunkedBody)
br := bufio.NewReader(r)
b, err := readBody(br, -1, 0, nil)
if err != nil {
t.Fatalf("Unexpected error for bodySize=%d: %s. body=%q, chunkedBody=%q", bodySize, err, body, chunkedBody)
}
if !bytes.Equal(b, body) {
t.Fatalf("Unexpected response read for bodySize=%d: %q. Expected %q. chunkedBody=%q", bodySize, b, body, chunkedBody)
}
verifyTrailer(t, br, string(expectedTrailer))
}
func testReadBodyFixedSize(t *testing.T, bodySize int) {
body := createFixedBody(bodySize)
expectedTrailer := []byte("traler aaaa")
bodyWithTrailer := append(body, expectedTrailer...)
r := bytes.NewBuffer(bodyWithTrailer)
br := bufio.NewReader(r)
b, err := readBody(br, bodySize, 0, nil)
if err != nil {
t.Fatalf("Unexpected error in ReadResponseBody(%d): %s", bodySize, err)
}
if !bytes.Equal(b, body) {
t.Fatalf("Unexpected response read for bodySize=%d: %q. Expected %q", bodySize, b, body)
}
verifyTrailer(t, br, string(expectedTrailer))
}
func createFixedBody(bodySize int) []byte {
var b []byte
for i := 0; i < bodySize; i++ {
b = append(b, byte(i%10)+'0')
}
return b
}
func createChunkedBody(body []byte) []byte {
var b []byte
chunkSize := 1
for len(body) > 0 {
if chunkSize > len(body) {
chunkSize = len(body)
}
b = append(b, []byte(fmt.Sprintf("%x\r\n", chunkSize))...)
b = append(b, body[:chunkSize]...)
b = append(b, []byte("\r\n")...)
body = body[chunkSize:]
chunkSize++
}
return append(b, []byte("0\r\n\r\n")...)
}
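// Worked example of what createChunkedBody above produces (derived from the code,
// shown here only as an illustration): for the 5-byte body "abcde" the result is
// "1\r\na\r\n2\r\nbc\r\n2\r\nde\r\n0\r\n\r\n" - the chunk size grows by one on each
// iteration, is capped by the remaining body, and the zero-length chunk terminates the stream.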
func TestWriteMultipartForm(t *testing.T) {
t.Parallel()
var w bytes.Buffer
s := strings.Replace(`--foo
Content-Disposition: form-data; name="key"
value
--foo
Content-Disposition: form-data; name="file"; filename="test.json"
Content-Type: application/json
{"foo": "bar"}
--foo--
`, "\n", "\r\n", -1)
mr := multipart.NewReader(strings.NewReader(s), "foo")
form, err := mr.ReadForm(1024)
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
if err := WriteMultipartForm(&w, form, "foo"); err != nil {
t.Fatalf("unexpected error: %s", err)
}
if w.String() != s {
t.Fatalf("unexpected output %q", w.Bytes())
}
}
func TestResponseRawBodySet(t *testing.T) {
t.Parallel()
var resp Response
expectedS := "test"
body := []byte(expectedS)
resp.SetBodyRaw(body)
testBodyWriteTo(t, &resp, expectedS, true)
}
func TestRequestRawBodySet(t *testing.T) {
t.Parallel()
var r Request
expectedS := "test"
body := []byte(expectedS)
r.SetBodyRaw(body)
testBodyWriteTo(t, &r, expectedS, true)
}
func TestResponseRawBodyReset(t *testing.T) {
t.Parallel()
var resp Response
body := []byte("test")
resp.SetBodyRaw(body)
resp.ResetBody()
testBodyWriteTo(t, &resp, "", true)
}
func TestRequestRawBodyReset(t *testing.T) {
t.Parallel()
var r Request
body := []byte("test")
r.SetBodyRaw(body)
r.ResetBody()
testBodyWriteTo(t, &r, "", true)
}
func TestResponseRawBodyCopyTo(t *testing.T) {
t.Parallel()
var resp Response
expectedS := "test"
body := []byte(expectedS)
resp.SetBodyRaw(body)
testResponseCopyTo(t, &resp)
}
func TestRequestRawBodyCopyTo(t *testing.T) {
t.Parallel()
var a Request
body := []byte("test")
a.SetBodyRaw(body)
var b Request
a.CopyTo(&b)
testBodyWriteTo(t, &a, "test", true)
testBodyWriteTo(t, &b, "test", true)
}
type testReader struct {
read chan (int)
cb chan (struct{})
onClose func() error
}
func (r *testReader) Read(b []byte) (int, error) {
read := <-r.read
if read == -1 {
return 0, io.EOF
}
r.cb <- struct{}{}
for i := 0; i < read; i++ {
b[i] = 'x'
}
return read, nil
}
func (r *testReader) Close() error {
if r.onClose != nil {
return r.onClose()
}
return nil
}
func TestResponseImmediateHeaderFlushRegressionFixedLength(t *testing.T) {
t.Parallel()
var r Response
expectedS := "aaabbbccc"
buf := bytes.NewBufferString(expectedS)
r.SetBodyStream(buf, len(expectedS))
r.ImmediateHeaderFlush = true
testBodyWriteTo(t, &r, expectedS, false)
}
func TestResponseImmediateHeaderFlushRegressionChunked(t *testing.T) {
t.Parallel()
var r Response
expectedS := "aaabbbccc"
buf := bytes.NewBufferString(expectedS)
r.SetBodyStream(buf, -1)
r.ImmediateHeaderFlush = true
testBodyWriteTo(t, &r, expectedS, false)
}
func TestResponseImmediateHeaderFlushFixedLength(t *testing.T) {
t.Parallel()
var r Response
r.ImmediateHeaderFlush = true
ch := make(chan int)
cb := make(chan struct{})
buf := &testReader{read: ch, cb: cb}
r.SetBodyStream(buf, 3)
b := []byte{}
w := bytes.NewBuffer(b)
bb := bufio.NewWriter(w)
bw := &r
waitForIt := make(chan struct{})
go func() {
if err := bw.Write(bb); err != nil {
t.Errorf("unexpected error: %s", err)
}
waitForIt <- struct{}{}
}()
ch <- 3
if !strings.Contains(w.String(), "Content-Length: 3") {
t.Fatalf("Expected headers to be flushed")
}
if strings.Contains(w.String(), "xxx") {
t.Fatalf("Did not expext body to be written yet")
}
<-cb
ch <- -1
<-waitForIt
}
func TestResponseImmediateHeaderFlushFixedLengthSkipBody(t *testing.T) {
t.Parallel()
var r Response
r.ImmediateHeaderFlush = true
r.SkipBody = true
ch := make(chan int)
cb := make(chan struct{})
buf := &testReader{read: ch, cb: cb}
r.SetBodyStream(buf, 0)
b := []byte{}
w := bytes.NewBuffer(b)
bb := bufio.NewWriter(w)
var headersOnClose string
buf.onClose = func() error {
headersOnClose = w.String()
return nil
}
bw := &r
if err := bw.Write(bb); err != nil {
t.Errorf("unexpected error: %s", err)
}
if !strings.Contains(headersOnClose, "Content-Length: 0") {
t.Fatalf("Expected headers to be eagerly flushed")
}
}
func TestResponseImmediateHeaderFlushChunked(t *testing.T) {
t.Parallel()
var r Response
r.ImmediateHeaderFlush = true
ch := make(chan int)
cb := make(chan struct{})
buf := &testReader{read: ch, cb: cb}
r.SetBodyStream(buf, -1)
b := []byte{}
w := bytes.NewBuffer(b)
bb := bufio.NewWriter(w)
bw := &r
waitForIt := make(chan struct{})
go func() {
if err := bw.Write(bb); err != nil {
t.Errorf("unexpected error: %s", err)
}
waitForIt <- struct{}{}
}()
ch <- 3
if !strings.Contains(w.String(), "Transfer-Encoding: chunked") {
t.Fatalf("Expected headers to be flushed")
}
if strings.Contains(w.String(), "xxx") {
t.Fatalf("Did not expext body to be written yet")
}
<-cb
ch <- -1
<-waitForIt
}
func TestResponseImmediateHeaderFlushChunkedNoBody(t *testing.T) {
t.Parallel()
var r Response
r.ImmediateHeaderFlush = true
r.SkipBody = true
ch := make(chan int)
cb := make(chan struct{})
buf := &testReader{read: ch, cb: cb}
r.SetBodyStream(buf, -1)
b := []byte{}
w := bytes.NewBuffer(b)
bb := bufio.NewWriter(w)
var headersOnClose string
buf.onClose = func() error {
headersOnClose = w.String()
return nil
}
bw := &r
if err := bw.Write(bb); err != nil {
t.Errorf("unexpected error: %s", err)
}
if !strings.Contains(headersOnClose, "Transfer-Encoding: chunked") {
t.Fatalf("Expected headers to be eagerly flushed")
}
}
type ErroneousBodyStream struct {
errOnRead bool
errOnClose bool
}
func (ebs *ErroneousBodyStream) Read(p []byte) (n int, err error) {
if ebs.errOnRead {
panic("reading erroneous body stream")
}
return 0, io.EOF
}
func (ebs *ErroneousBodyStream) Close() error {
if ebs.errOnClose {
panic("closing erroneous body stream")
}
return nil
}
func TestResponseBodyStreamErrorOnPanicDuringRead(t *testing.T) {
t.Parallel()
var resp Response
var w bytes.Buffer
bw := bufio.NewWriter(&w)
ebs := &ErroneousBodyStream{errOnRead: true, errOnClose: false}
resp.SetBodyStream(ebs, 42)
err := resp.Write(bw)
if err == nil {
t.Fatalf("expected error when writing response.")
}
e, ok := err.(*ErrBodyStreamWritePanic)
if !ok {
t.Fatalf("expected error struct to be *ErrBodyStreamWritePanic, got: %+v.", e)
}
if e.Error() != "panic while writing body stream: reading erroneous body stream" {
t.Fatalf("unexpected error value, got: %+v.", e.Error())
}
}
func TestResponseBodyStreamErrorOnPanicDuringClose(t *testing.T) {
t.Parallel()
var resp Response
var w bytes.Buffer
bw := bufio.NewWriter(&w)
ebs := &ErroneousBodyStream{errOnRead: false, errOnClose: true}
resp.SetBodyStream(ebs, 42)
err := resp.Write(bw)
if err == nil {
t.Fatalf("expected error when writing response.")
}
e, ok := err.(*ErrBodyStreamWritePanic)
if !ok {
t.Fatalf("expected error struct to be *ErrBodyStreamWritePanic, got: %+v.", e)
}
if e.Error() != "panic while writing body stream: closing erroneous body stream" {
t.Fatalf("unexpected error value, got: %+v.", e.Error())
}
}
fasthttp-1.31.0/lbclient.go 0000664 0000000 0000000 00000010202 14130360711 0015543 0 ustar 00root root 0000000 0000000 package fasthttp
import (
"sync"
"sync/atomic"
"time"
)
// BalancingClient is the interface for clients, which may be passed
// to LBClient.Clients.
type BalancingClient interface {
DoDeadline(req *Request, resp *Response, deadline time.Time) error
PendingRequests() int
}
// LBClient balances requests among available LBClient.Clients.
//
// It has the following features:
//
// - Balances load among available clients using 'least loaded' + 'least total'
// hybrid technique.
// - Dynamically decreases load on unhealthy clients.
//
// Copying LBClient instances is forbidden. Create new instances instead.
//
// It is safe to call LBClient methods from concurrently running goroutines.
type LBClient struct {
noCopy noCopy //nolint:unused,structcheck
// Clients must contain non-zero clients list.
// Incoming requests are balanced among these clients.
Clients []BalancingClient
// HealthCheck is a callback called after each request.
//
// The request, response and the error returned by the client
// is passed to HealthCheck, so the callback may determine whether
// the client is healthy.
//
// Load on the current client is decreased if HealthCheck returns false.
//
// By default HealthCheck returns false if err != nil.
HealthCheck func(req *Request, resp *Response, err error) bool
// Timeout is the request timeout used when calling LBClient.Do.
//
// DefaultLBClientTimeout is used by default.
Timeout time.Duration
cs []*lbClient
once sync.Once
}
// DefaultLBClientTimeout is the default request timeout used by LBClient
// when calling LBClient.Do.
//
// The timeout may be overridden via LBClient.Timeout.
const DefaultLBClientTimeout = time.Second
// DoDeadline calls DoDeadline on the least loaded client.
func (cc *LBClient) DoDeadline(req *Request, resp *Response, deadline time.Time) error {
return cc.get().DoDeadline(req, resp, deadline)
}
// DoTimeout calculates the deadline and calls DoDeadline on the least loaded client.
func (cc *LBClient) DoTimeout(req *Request, resp *Response, timeout time.Duration) error {
deadline := time.Now().Add(timeout)
return cc.get().DoDeadline(req, resp, deadline)
}
// Do calculates the deadline using LBClient.Timeout and calls DoDeadline
// on the least loaded client.
func (cc *LBClient) Do(req *Request, resp *Response) error {
timeout := cc.Timeout
if timeout <= 0 {
timeout = DefaultLBClientTimeout
}
return cc.DoTimeout(req, resp, timeout)
}
func (cc *LBClient) init() {
if len(cc.Clients) == 0 {
panic("BUG: LBClient.Clients cannot be empty")
}
for _, c := range cc.Clients {
cc.cs = append(cc.cs, &lbClient{
c: c,
healthCheck: cc.HealthCheck,
})
}
}
func (cc *LBClient) get() *lbClient {
cc.once.Do(cc.init)
cs := cc.cs
minC := cs[0]
minN := minC.PendingRequests()
minT := atomic.LoadUint64(&minC.total)
for _, c := range cs[1:] {
n := c.PendingRequests()
t := atomic.LoadUint64(&c.total)
if n < minN || (n == minN && t < minT) {
minC = c
minN = n
minT = t
}
}
return minC
}
type lbClient struct {
c BalancingClient
healthCheck func(req *Request, resp *Response, err error) bool
penalty uint32
// total amount of requests handled.
total uint64
}
func (c *lbClient) DoDeadline(req *Request, resp *Response, deadline time.Time) error {
err := c.c.DoDeadline(req, resp, deadline)
if !c.isHealthy(req, resp, err) && c.incPenalty() {
// Penalize the client returning error, so the next requests
// are routed to other clients.
time.AfterFunc(penaltyDuration, c.decPenalty)
} else {
atomic.AddUint64(&c.total, 1)
}
return err
}
func (c *lbClient) PendingRequests() int {
n := c.c.PendingRequests()
m := atomic.LoadUint32(&c.penalty)
return n + int(m)
}
func (c *lbClient) isHealthy(req *Request, resp *Response, err error) bool {
if c.healthCheck == nil {
return err == nil
}
return c.healthCheck(req, resp, err)
}
func (c *lbClient) incPenalty() bool {
m := atomic.AddUint32(&c.penalty, 1)
if m > maxPenalty {
c.decPenalty()
return false
}
return true
}
func (c *lbClient) decPenalty() {
atomic.AddUint32(&c.penalty, ^uint32(0))
}
const (
maxPenalty = 300
penaltyDuration = 3 * time.Second
)
fasthttp-1.31.0/lbclient_example_test.go 0000664 0000000 0000000 00000001622 14130360711 0020323 0 ustar 00root root 0000000 0000000 package fasthttp_test
import (
"fmt"
"log"
"github.com/valyala/fasthttp"
)
func ExampleLBClient() {
// Requests will be spread among these servers.
servers := []string{
"google.com:80",
"foobar.com:8080",
"127.0.0.1:123",
}
// Prepare clients for each server
var lbc fasthttp.LBClient
for _, addr := range servers {
c := &fasthttp.HostClient{
Addr: addr,
}
lbc.Clients = append(lbc.Clients, c)
}
// Send requests to load-balanced servers
var req fasthttp.Request
var resp fasthttp.Response
for i := 0; i < 10; i++ {
url := fmt.Sprintf("http://abcedfg/foo/bar/%d", i)
req.SetRequestURI(url)
if err := lbc.Do(&req, &resp); err != nil {
log.Fatalf("Error when sending request: %s", err)
}
if resp.StatusCode() != fasthttp.StatusOK {
log.Fatalf("unexpected status code: %d. Expecting %d", resp.StatusCode(), fasthttp.StatusOK)
}
useResponseBody(resp.Body())
}
}
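// The function below is an illustrative sketch, not part of the original examples:
// it shows how a custom HealthCheck could be attached so that any transport error
// or 5xx response temporarily penalizes a backend. The addresses and the 5xx rule
// are assumptions made for the illustration, not package defaults.
func exampleLBClientHealthCheckSketch() {
	var lbc fasthttp.LBClient
	for _, addr := range []string{"10.0.0.1:80", "10.0.0.2:80"} {
		lbc.Clients = append(lbc.Clients, &fasthttp.HostClient{
			Addr: addr,
		})
	}
	// Mark a backend unhealthy on network errors and on 5xx responses,
	// so subsequent requests prefer the other backends for a while.
	lbc.HealthCheck = func(req *fasthttp.Request, resp *fasthttp.Response, err error) bool {
		return err == nil && resp.StatusCode() < fasthttp.StatusInternalServerError
	}
	lbc.Timeout = fasthttp.DefaultLBClientTimeout
	_ = lbc.Do // used exactly like in ExampleLBClient above
}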
fasthttp-1.31.0/methods.go 0000664 0000000 0000000 00000000736 14130360711 0015425 0 ustar 00root root 0000000 0000000 package fasthttp
// HTTP methods were copied from net/http.
const (
MethodGet = "GET" // RFC 7231, 4.3.1
MethodHead = "HEAD" // RFC 7231, 4.3.2
MethodPost = "POST" // RFC 7231, 4.3.3
MethodPut = "PUT" // RFC 7231, 4.3.4
MethodPatch = "PATCH" // RFC 5789
MethodDelete = "DELETE" // RFC 7231, 4.3.5
MethodConnect = "CONNECT" // RFC 7231, 4.3.6
MethodOptions = "OPTIONS" // RFC 7231, 4.3.7
MethodTrace = "TRACE" // RFC 7231, 4.3.8
)
fasthttp-1.31.0/nocopy.go 0000664 0000000 0000000 00000000612 14130360711 0015262 0 ustar 00root root 0000000 0000000 package fasthttp
// Embed this type into a struct, which mustn't be copied,
// so `go vet` gives a warning if this struct is copied.
//
// See https://github.com/golang/go/issues/8005#issuecomment-190753527 for details.
// and also: https://stackoverflow.com/questions/52494458/nocopy-minimal-example
type noCopy struct{} //nolint:unused
func (*noCopy) Lock() {}
func (*noCopy) Unlock() {}
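// Illustrative sketch (mustNotBeCopied is not a type defined in this package):
// embedding noCopy as a field is all that is required; `go vet`'s copylocks
// check then flags any copy of the outer struct by value.
//
//	type mustNotBeCopied struct {
//		noCopy noCopy //nolint:unused,structcheck
//		state  map[string]int
//	}
//
//	func use(v mustNotBeCopied) {} // vet: use passes lock by value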
fasthttp-1.31.0/peripconn.go 0000664 0000000 0000000 00000003327 14130360711 0015756 0 ustar 00root root 0000000 0000000 package fasthttp
import (
"fmt"
"net"
"sync"
)
type perIPConnCounter struct {
pool sync.Pool
lock sync.Mutex
m map[uint32]int
}
func (cc *perIPConnCounter) Register(ip uint32) int {
cc.lock.Lock()
if cc.m == nil {
cc.m = make(map[uint32]int)
}
n := cc.m[ip] + 1
cc.m[ip] = n
cc.lock.Unlock()
return n
}
func (cc *perIPConnCounter) Unregister(ip uint32) {
cc.lock.Lock()
if cc.m == nil {
cc.lock.Unlock()
panic("BUG: perIPConnCounter.Register() wasn't called")
}
n := cc.m[ip] - 1
if n < 0 {
cc.lock.Unlock()
panic(fmt.Sprintf("BUG: negative per-ip counter=%d for ip=%d", n, ip))
}
cc.m[ip] = n
cc.lock.Unlock()
}
type perIPConn struct {
net.Conn
ip uint32
perIPConnCounter *perIPConnCounter
}
func acquirePerIPConn(conn net.Conn, ip uint32, counter *perIPConnCounter) *perIPConn {
v := counter.pool.Get()
if v == nil {
v = &perIPConn{
perIPConnCounter: counter,
}
}
c := v.(*perIPConn)
c.Conn = conn
c.ip = ip
return c
}
func releasePerIPConn(c *perIPConn) {
c.Conn = nil
c.perIPConnCounter.pool.Put(c)
}
func (c *perIPConn) Close() error {
err := c.Conn.Close()
c.perIPConnCounter.Unregister(c.ip)
releasePerIPConn(c)
return err
}
func getUint32IP(c net.Conn) uint32 {
return ip2uint32(getConnIP4(c))
}
func getConnIP4(c net.Conn) net.IP {
addr := c.RemoteAddr()
ipAddr, ok := addr.(*net.TCPAddr)
if !ok {
return net.IPv4zero
}
return ipAddr.IP.To4()
}
func ip2uint32(ip net.IP) uint32 {
if len(ip) != 4 {
return 0
}
return uint32(ip[0])<<24 | uint32(ip[1])<<16 | uint32(ip[2])<<8 | uint32(ip[3])
}
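// For example, ip2uint32(net.IPv4(192, 168, 0, 1).To4()) packs the octets
// big-endian into 192<<24 | 168<<16 | 0<<8 | 1 = 0xC0A80001, and
// uint322ip(0xC0A80001) below restores the same 192.168.0.1 address.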
func uint322ip(ip uint32) net.IP {
b := make([]byte, 4)
b[0] = byte(ip >> 24)
b[1] = byte(ip >> 16)
b[2] = byte(ip >> 8)
b[3] = byte(ip)
return b
}
fasthttp-1.31.0/peripconn_test.go 0000664 0000000 0000000 00000002150 14130360711 0017006 0 ustar 00root root 0000000 0000000 package fasthttp
import (
"testing"
)
func TestIPxUint32(t *testing.T) {
t.Parallel()
testIPxUint32(t, 0)
testIPxUint32(t, 10)
testIPxUint32(t, 0x12892392)
}
func testIPxUint32(t *testing.T, n uint32) {
ip := uint322ip(n)
nn := ip2uint32(ip)
if n != nn {
t.Fatalf("Unexpected value=%d for ip=%s. Expected %d", nn, ip, n)
}
}
func TestPerIPConnCounter(t *testing.T) {
t.Parallel()
var cc perIPConnCounter
expectPanic(t, func() { cc.Unregister(123) })
for i := 1; i < 100; i++ {
if n := cc.Register(123); n != i {
t.Fatalf("Unexpected counter value=%d. Expected %d", n, i)
}
}
n := cc.Register(456)
if n != 1 {
t.Fatalf("Unexpected counter value=%d. Expected 1", n)
}
for i := 1; i < 100; i++ {
cc.Unregister(123)
}
cc.Unregister(456)
expectPanic(t, func() { cc.Unregister(123) })
expectPanic(t, func() { cc.Unregister(456) })
n = cc.Register(123)
if n != 1 {
t.Fatalf("Unexpected counter value=%d. Expected 1", n)
}
cc.Unregister(123)
}
func expectPanic(t *testing.T, f func()) {
defer func() {
if r := recover(); r == nil {
t.Fatalf("Expecting panic")
}
}()
f()
}
fasthttp-1.31.0/pprofhandler/ 0000775 0000000 0000000 00000000000 14130360711 0016111 5 ustar 00root root 0000000 0000000 fasthttp-1.31.0/pprofhandler/pprof.go 0000664 0000000 0000000 00000002571 14130360711 0017573 0 ustar 00root root 0000000 0000000 package pprofhandler
import (
"net/http/pprof"
rtp "runtime/pprof"
"strings"
"github.com/valyala/fasthttp"
"github.com/valyala/fasthttp/fasthttpadaptor"
)
var (
cmdline = fasthttpadaptor.NewFastHTTPHandlerFunc(pprof.Cmdline)
profile = fasthttpadaptor.NewFastHTTPHandlerFunc(pprof.Profile)
symbol = fasthttpadaptor.NewFastHTTPHandlerFunc(pprof.Symbol)
trace = fasthttpadaptor.NewFastHTTPHandlerFunc(pprof.Trace)
index = fasthttpadaptor.NewFastHTTPHandlerFunc(pprof.Index)
)
// PprofHandler serves server runtime profiling data in the format expected by the pprof visualization tool.
//
// See https://golang.org/pkg/net/http/pprof/ for details.
func PprofHandler(ctx *fasthttp.RequestCtx) {
ctx.Response.Header.Set("Content-Type", "text/html")
if strings.HasPrefix(string(ctx.Path()), "/debug/pprof/cmdline") {
cmdline(ctx)
} else if strings.HasPrefix(string(ctx.Path()), "/debug/pprof/profile") {
profile(ctx)
} else if strings.HasPrefix(string(ctx.Path()), "/debug/pprof/symbol") {
symbol(ctx)
} else if strings.HasPrefix(string(ctx.Path()), "/debug/pprof/trace") {
trace(ctx)
} else {
for _, v := range rtp.Profiles() {
ppName := v.Name()
if strings.HasPrefix(string(ctx.Path()), "/debug/pprof/"+ppName) {
namedHandler := fasthttpadaptor.NewFastHTTPHandlerFunc(pprof.Handler(ppName).ServeHTTP)
namedHandler(ctx)
return
}
}
index(ctx)
}
}
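// Usage sketch (illustrative only; the listen address and the route-prefix
// check are assumptions made for this example, not part of the package API):
//
//	fasthttp.ListenAndServe(":8080", func(ctx *fasthttp.RequestCtx) {
//		if strings.HasPrefix(string(ctx.Path()), "/debug/pprof") {
//			pprofhandler.PprofHandler(ctx)
//			return
//		}
//		ctx.Error("Unsupported path", fasthttp.StatusNotFound)
//	})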
fasthttp-1.31.0/prefork/ 0000775 0000000 0000000 00000000000 14130360711 0015075 5 ustar 00root root 0000000 0000000 fasthttp-1.31.0/prefork/README.md 0000664 0000000 0000000 00000004142 14130360711 0016355 0 ustar 00root root 0000000 0000000 # Prefork
Server prefork implementation.
Preforking the master process into several child processes increases performance, because Go doesn't have to share and manage memory between cores.
**WARNING: using prefork prevents the use of any global state! Things like in-memory caches won't work.**
- How it works:
```go
import (
"github.com/valyala/fasthttp"
"github.com/valyala/fasthttp/prefork"
)
server := &fasthttp.Server{
// Your configuration
}
// Wraps the server with prefork
preforkServer := prefork.New(server)
if err := preforkServer.ListenAndServe(":8080"); err != nil {
panic(err)
}
```
## Benchmarks
Environment:
- Machine: MacBook Pro 13-inch, 2017
- OS: MacOS 10.15.3
- Go: go1.13.6 darwin/amd64
Handler code:
```go
func requestHandler(ctx *fasthttp.RequestCtx) {
// Simulates some hard work
time.Sleep(100 * time.Millisecond)
}
```
Test command:
```bash
$ wrk -H 'Host: localhost' -H 'Accept: text/plain,text/html;q=0.9,application/xhtml+xml;q=0.9,application/xml;q=0.8,*/*;q=0.7' -H 'Connection: keep-alive' --latency -d 15 -c 512 --timeout 8 -t 4 http://localhost:8080
```
Results:
- prefork
```bash
Running 15s test @ http://localhost:8080
4 threads and 512 connections
Thread Stats Avg Stdev Max +/- Stdev
Latency 4.75ms 4.27ms 126.24ms 97.45%
Req/Sec 26.46k 4.16k 71.18k 88.72%
Latency Distribution
50% 4.55ms
75% 4.82ms
90% 5.46ms
99% 15.49ms
1581916 requests in 15.09s, 140.30MB read
Socket errors: connect 0, read 318, write 0, timeout 0
Requests/sec: 104861.58
Transfer/sec: 9.30MB
```
- **non**-prefork
```bash
Running 15s test @ http://localhost:8080
4 threads and 512 connections
Thread Stats Avg Stdev Max +/- Stdev
Latency 6.42ms 11.83ms 177.19ms 96.42%
Req/Sec 24.96k 5.83k 56.83k 82.93%
Latency Distribution
50% 4.53ms
75% 4.93ms
90% 6.94ms
99% 74.54ms
1472441 requests in 15.09s, 130.59MB read
Socket errors: connect 0, read 265, write 0, timeout 0
Requests/sec: 97553.34
Transfer/sec: 8.65MB
```
fasthttp-1.31.0/prefork/prefork.go 0000664 0000000 0000000 00000014376 14130360711 0017107 0 ustar 00root root 0000000 0000000 package prefork
import (
"errors"
"flag"
"log"
"net"
"os"
"os/exec"
"runtime"
"github.com/valyala/fasthttp"
"github.com/valyala/fasthttp/reuseport"
)
const (
preforkChildFlag = "-prefork-child"
defaultNetwork = "tcp4"
)
var (
defaultLogger = Logger(log.New(os.Stderr, "", log.LstdFlags))
// ErrOverRecovery is returned when the times of starting over child prefork processes exceed
// the threshold.
ErrOverRecovery = errors.New("exceeding the value of RecoverThreshold")
// ErrOnlyReuseportOnWindows is returned when Reuseport is false.
ErrOnlyReuseportOnWindows = errors.New("windows only supports Reuseport = true")
)
// Logger is used for logging formatted messages.
type Logger interface {
// Printf must have the same semantics as log.Printf.
Printf(format string, args ...interface{})
}
// Prefork implements fasthttp server prefork
//
// Preforking the master process (using all cores) into several child processes
// increases performance significantly, because Go doesn't have to share
// and manage memory between cores
//
// WARNING: using prefork prevents the use of any global state!
// Things like in-memory caches won't work.
type Prefork struct {
// The network must be "tcp", "tcp4" or "tcp6".
//
// By default is "tcp4"
Network string
// Flag to use a listener with SO_REUSEPORT; if not set, a file-based Listener will be used
// See: https://www.nginx.com/blog/socket-sharding-nginx-release-1-9-1/
//
// It's disabled by default
Reuseport bool
// Child prefork processes may exit with failure and will be restarted until the number
// of restarts reaches RecoverThreshold; after that the master returns and terminates the server.
RecoverThreshold int
// By default standard logger from log package is used.
Logger Logger
ServeFunc func(ln net.Listener) error
ServeTLSFunc func(ln net.Listener, certFile, keyFile string) error
ServeTLSEmbedFunc func(ln net.Listener, certData, keyData []byte) error
ln net.Listener
files []*os.File
}
func init() { //nolint:gochecknoinits
// Define the flag so the program doesn't break when the user adds their own flags
// and runs `flag.Parse()`
flag.Bool(preforkChildFlag[1:], false, "Is a child process")
}
// IsChild checks if the current thread/process is a child
func IsChild() bool {
for _, arg := range os.Args[1:] {
if arg == preforkChildFlag {
return true
}
}
return false
}
// New wraps the fasthttp server to run with preforked processes
func New(s *fasthttp.Server) *Prefork {
return &Prefork{
Network: defaultNetwork,
RecoverThreshold: runtime.GOMAXPROCS(0) / 2,
Logger: s.Logger,
ServeFunc: s.Serve,
ServeTLSFunc: s.ServeTLS,
ServeTLSEmbedFunc: s.ServeTLSEmbed,
}
}
func (p *Prefork) logger() Logger {
if p.Logger != nil {
return p.Logger
}
return defaultLogger
}
func (p *Prefork) listen(addr string) (net.Listener, error) {
runtime.GOMAXPROCS(1)
if p.Network == "" {
p.Network = defaultNetwork
}
if p.Reuseport {
return reuseport.Listen(p.Network, addr)
}
return net.FileListener(os.NewFile(3, ""))
}
func (p *Prefork) setTCPListenerFiles(addr string) error {
if p.Network == "" {
p.Network = defaultNetwork
}
tcpAddr, err := net.ResolveTCPAddr(p.Network, addr)
if err != nil {
return err
}
tcplistener, err := net.ListenTCP(p.Network, tcpAddr)
if err != nil {
return err
}
p.ln = tcplistener
fl, err := tcplistener.File()
if err != nil {
return err
}
p.files = []*os.File{fl}
return nil
}
func (p *Prefork) doCommand() (*exec.Cmd, error) {
/* #nosec G204 */
cmd := exec.Command(os.Args[0], append(os.Args[1:], preforkChildFlag)...)
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
cmd.ExtraFiles = p.files
return cmd, cmd.Start()
}
func (p *Prefork) prefork(addr string) (err error) {
if !p.Reuseport {
if runtime.GOOS == "windows" {
return ErrOnlyReuseportOnWindows
}
if err = p.setTCPListenerFiles(addr); err != nil {
return
}
// defer for closing the net.Listener opened by setTCPListenerFiles.
defer func() {
e := p.ln.Close()
if err == nil {
err = e
}
}()
}
type procSig struct {
pid int
err error
}
goMaxProcs := runtime.GOMAXPROCS(0)
sigCh := make(chan procSig, goMaxProcs)
childProcs := make(map[int]*exec.Cmd)
defer func() {
for _, proc := range childProcs {
_ = proc.Process.Kill()
}
}()
for i := 0; i < goMaxProcs; i++ {
var cmd *exec.Cmd
if cmd, err = p.doCommand(); err != nil {
p.logger().Printf("failed to start a child prefork process, error: %v\n", err)
return
}
childProcs[cmd.Process.Pid] = cmd
go func() {
sigCh <- procSig{cmd.Process.Pid, cmd.Wait()}
}()
}
var exitedProcs int
for sig := range sigCh {
delete(childProcs, sig.pid)
p.logger().Printf("one of the child prefork processes exited with "+
"error: %v", sig.err)
if exitedProcs++; exitedProcs > p.RecoverThreshold {
p.logger().Printf("child prefork processes exit too many times, "+
"which exceeds the value of RecoverThreshold(%d), "+
"exiting the master process.\n", exitedProcs)
err = ErrOverRecovery
break
}
var cmd *exec.Cmd
if cmd, err = p.doCommand(); err != nil {
break
}
childProcs[cmd.Process.Pid] = cmd
go func() {
sigCh <- procSig{cmd.Process.Pid, cmd.Wait()}
}()
}
return
}
// ListenAndServe serves HTTP requests from the given TCP addr
func (p *Prefork) ListenAndServe(addr string) error {
if IsChild() {
ln, err := p.listen(addr)
if err != nil {
return err
}
p.ln = ln
return p.ServeFunc(ln)
}
return p.prefork(addr)
}
// ListenAndServeTLS serves HTTPS requests from the given TCP addr
//
// certFile and keyFile are paths to TLS certificate and key files.
func (p *Prefork) ListenAndServeTLS(addr, certKey, certFile string) error {
if IsChild() {
ln, err := p.listen(addr)
if err != nil {
return err
}
p.ln = ln
return p.ServeTLSFunc(ln, certFile, certKey)
}
return p.prefork(addr)
}
// ListenAndServeTLSEmbed serves HTTPS requests from the given TCP addr
//
// certData and keyData must contain valid TLS certificate and key data.
func (p *Prefork) ListenAndServeTLSEmbed(addr string, certData, keyData []byte) error {
if IsChild() {
ln, err := p.listen(addr)
if err != nil {
return err
}
p.ln = ln
return p.ServeTLSEmbedFunc(ln, certData, keyData)
}
return p.prefork(addr)
}
fasthttp-1.31.0/prefork/prefork_test.go 0000664 0000000 0000000 00000010224 14130360711 0020132 0 ustar 00root root 0000000 0000000 package prefork
import (
"fmt"
"math/rand"
"net"
"os"
"reflect"
"runtime"
"testing"
"github.com/valyala/fasthttp"
)
func setUp() {
os.Args = append(os.Args, preforkChildFlag)
}
func tearDown() {
os.Args = os.Args[:len(os.Args)-1]
}
func getAddr() string {
return fmt.Sprintf("0.0.0.0:%d", rand.Intn(9000-3000)+3000)
}
func Test_IsChild(t *testing.T) {
// This test can't run parallel as it modifies os.Args.
v := IsChild()
if v {
t.Errorf("IsChild() == %v, want %v", v, false)
}
setUp()
defer tearDown()
v = IsChild()
if !v {
t.Errorf("IsChild() == %v, want %v", v, true)
}
}
func Test_New(t *testing.T) {
t.Parallel()
s := &fasthttp.Server{}
p := New(s)
if p.Network != defaultNetwork {
t.Errorf("Prefork.Netork == %s, want %s", p.Network, defaultNetwork)
}
if reflect.ValueOf(p.ServeFunc).Pointer() != reflect.ValueOf(s.Serve).Pointer() {
t.Errorf("Prefork.ServeFunc == %p, want %p", p.ServeFunc, s.Serve)
}
if reflect.ValueOf(p.ServeTLSFunc).Pointer() != reflect.ValueOf(s.ServeTLS).Pointer() {
t.Errorf("Prefork.ServeTLSFunc == %p, want %p", p.ServeTLSFunc, s.ServeTLS)
}
if reflect.ValueOf(p.ServeTLSEmbedFunc).Pointer() != reflect.ValueOf(s.ServeTLSEmbed).Pointer() {
t.Errorf("Prefork.ServeTLSFunc == %p, want %p", p.ServeTLSEmbedFunc, s.ServeTLSEmbed)
}
}
func Test_listen(t *testing.T) {
t.Parallel()
p := &Prefork{
Reuseport: true,
}
addr := getAddr()
ln, err := p.listen(addr)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
ln.Close()
lnAddr := ln.Addr().String()
if lnAddr != addr {
t.Errorf("Prefork.Addr == %s, want %s", lnAddr, addr)
}
if p.Network != defaultNetwork {
t.Errorf("Prefork.Network == %s, want %s", p.Network, defaultNetwork)
}
procs := runtime.GOMAXPROCS(0)
if procs != 1 {
t.Errorf("GOMAXPROCS == %d, want %d", procs, 1)
}
}
func Test_setTCPListenerFiles(t *testing.T) {
t.Parallel()
if runtime.GOOS == "windows" {
t.SkipNow()
}
p := &Prefork{}
addr := getAddr()
err := p.setTCPListenerFiles(addr)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
if p.ln == nil {
t.Fatal("Prefork.ln is nil")
}
p.ln.Close()
lnAddr := p.ln.Addr().String()
if lnAddr != addr {
t.Errorf("Prefork.Addr == %s, want %s", lnAddr, addr)
}
if p.Network != defaultNetwork {
t.Errorf("Prefork.Network == %s, want %s", p.Network, defaultNetwork)
}
if len(p.files) != 1 {
t.Errorf("Prefork.files == %d, want %d", len(p.files), 1)
}
}
func Test_ListenAndServe(t *testing.T) {
// This test can't run parallel as it modifies os.Args.
setUp()
defer tearDown()
s := &fasthttp.Server{}
p := New(s)
p.Reuseport = true
p.ServeFunc = func(ln net.Listener) error {
return nil
}
addr := getAddr()
err := p.ListenAndServe(addr)
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
p.ln.Close()
lnAddr := p.ln.Addr().String()
if lnAddr != addr {
t.Errorf("Prefork.Addr == %s, want %s", lnAddr, addr)
}
if p.ln == nil {
t.Error("Prefork.ln is nil")
}
}
func Test_ListenAndServeTLS(t *testing.T) {
// This test can't run parallel as it modifies os.Args.
setUp()
defer tearDown()
s := &fasthttp.Server{}
p := New(s)
p.Reuseport = true
p.ServeTLSFunc = func(ln net.Listener, certFile, keyFile string) error {
return nil
}
addr := getAddr()
err := p.ListenAndServeTLS(addr, "./key", "./cert")
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
p.ln.Close()
lnAddr := p.ln.Addr().String()
if lnAddr != addr {
t.Errorf("Prefork.Addr == %s, want %s", lnAddr, addr)
}
if p.ln == nil {
t.Error("Prefork.ln is nil")
}
}
func Test_ListenAndServeTLSEmbed(t *testing.T) {
// This test can't run parallel as it modifies os.Args.
setUp()
defer tearDown()
s := &fasthttp.Server{}
p := New(s)
p.Reuseport = true
p.ServeTLSEmbedFunc = func(ln net.Listener, certData, keyData []byte) error {
return nil
}
addr := getAddr()
err := p.ListenAndServeTLSEmbed(addr, []byte("key"), []byte("cert"))
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
p.ln.Close()
lnAddr := p.ln.Addr().String()
if lnAddr != addr {
t.Errorf("Prefork.Addr == %s, want %s", lnAddr, addr)
}
if p.ln == nil {
t.Error("Prefork.ln is nil")
}
}
fasthttp-1.31.0/requestctx_setbodystreamwriter_example_test.go 0000664 0000000 0000000 00000001346 14130360711 0025143 0 ustar 00root root 0000000 0000000 package fasthttp_test
import (
"bufio"
"fmt"
"log"
"time"
"github.com/valyala/fasthttp"
)
func ExampleRequestCtx_SetBodyStreamWriter() {
// Start fasthttp server for streaming responses.
if err := fasthttp.ListenAndServe(":8080", responseStreamHandler); err != nil {
log.Fatalf("unexpected error in server: %s", err)
}
}
func responseStreamHandler(ctx *fasthttp.RequestCtx) {
// Send the response in chunks and wait for a second between each chunk.
ctx.SetBodyStreamWriter(func(w *bufio.Writer) {
for i := 0; i < 10; i++ {
fmt.Fprintf(w, "this is a message number %d", i)
// Do not forget flushing streamed data to the client.
if err := w.Flush(); err != nil {
return
}
time.Sleep(time.Second)
}
})
}
fasthttp-1.31.0/reuseport/ 0000775 0000000 0000000 00000000000 14130360711 0015455 5 ustar 00root root 0000000 0000000 fasthttp-1.31.0/reuseport/LICENSE 0000664 0000000 0000000 00000002065 14130360711 0016465 0 ustar 00root root 0000000 0000000 The MIT License (MIT)
Copyright (c) 2014 Max Riveiro
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE. fasthttp-1.31.0/reuseport/reuseport.go 0000664 0000000 0000000 00000002512 14130360711 0020034 0 ustar 00root root 0000000 0000000 //go:build !windows
// +build !windows
// Package reuseport provides TCP net.Listener with SO_REUSEPORT support.
//
// SO_REUSEPORT allows linear scaling server performance on multi-CPU servers.
// See https://www.nginx.com/blog/socket-sharding-nginx-release-1-9-1/ for more details :)
//
// The package is based on https://github.com/kavu/go_reuseport .
package reuseport
import (
"net"
"strings"
"github.com/valyala/tcplisten"
)
// Listen returns TCP listener with SO_REUSEPORT option set.
//
// The returned listener tries enabling the following TCP options, which usually
// have positive impact on performance:
//
// - TCP_DEFER_ACCEPT. This option expects that the server reads from accepted
// connections before writing to them.
//
// - TCP_FASTOPEN. See https://lwn.net/Articles/508865/ for details.
//
// Use https://github.com/valyala/tcplisten if you want customizing
// these options.
//
// Only tcp4 and tcp6 networks are supported.
//
// ErrNoReusePort error is returned if the system doesn't support SO_REUSEPORT.
func Listen(network, addr string) (net.Listener, error) {
ln, err := cfg.NewListener(network, addr)
if err != nil && strings.Contains(err.Error(), "SO_REUSEPORT") {
return nil, &ErrNoReusePort{err}
}
return ln, err
}
var cfg = &tcplisten.Config{
ReusePort: true,
DeferAccept: true,
FastOpen: true,
}
fasthttp-1.31.0/reuseport/reuseport_error.go 0000664 0000000 0000000 00000000456 14130360711 0021252 0 ustar 00root root 0000000 0000000 package reuseport
import (
"fmt"
)
// ErrNoReusePort is returned if the OS doesn't support SO_REUSEPORT.
type ErrNoReusePort struct {
err error
}
// Error implements error interface.
func (e *ErrNoReusePort) Error() string {
return fmt.Sprintf("The OS doesn't support SO_REUSEPORT: %s", e.err)
}
fasthttp-1.31.0/reuseport/reuseport_example_test.go 0000664 0000000 0000000 00000000734 14130360711 0022612 0 ustar 00root root 0000000 0000000 package reuseport_test
import (
"fmt"
"log"
"github.com/valyala/fasthttp"
"github.com/valyala/fasthttp/reuseport"
)
func ExampleListen() {
ln, err := reuseport.Listen("tcp4", "localhost:12345")
if err != nil {
log.Fatalf("error in reuseport listener: %s", err)
}
if err = fasthttp.Serve(ln, requestHandler); err != nil {
log.Fatalf("error in fasthttp Server: %s", err)
}
}
func requestHandler(ctx *fasthttp.RequestCtx) {
fmt.Fprintf(ctx, "Hello, world!")
}
fasthttp-1.31.0/reuseport/reuseport_test.go 0000664 0000000 0000000 00000005343 14130360711 0021100 0 ustar 00root root 0000000 0000000 package reuseport
import (
"fmt"
"io/ioutil"
"net"
"testing"
"time"
)
func TestTCP4(t *testing.T) {
t.Parallel()
testNewListener(t, "tcp4", "localhost:10081", 20, 1000)
}
func TestTCP6(t *testing.T) {
t.Parallel()
// Run this test only if tcp6 interface exists.
if hasLocalIPv6(t) {
testNewListener(t, "tcp6", "[::1]:10082", 20, 1000)
}
}
func hasLocalIPv6(t *testing.T) bool {
addrs, err := net.InterfaceAddrs()
if err != nil {
t.Fatalf("cannot obtain local interfaces: %s", err)
}
for _, a := range addrs {
if a.String() == "::1/128" {
return true
}
}
return false
}
func testNewListener(t *testing.T, network, addr string, serversCount, requestsCount int) {
var lns []net.Listener
doneCh := make(chan struct{}, serversCount)
for i := 0; i < serversCount; i++ {
ln, err := Listen(network, addr)
if err != nil {
t.Fatalf("cannot create listener %d: %s", i, err)
}
go func() {
serveEcho(t, ln)
doneCh <- struct{}{}
}()
lns = append(lns, ln)
}
for i := 0; i < requestsCount; i++ {
c, err := net.Dial(network, addr)
if err != nil {
t.Fatalf("%d. unexpected error when dialing: %s", i, err)
}
req := fmt.Sprintf("request number %d", i)
if _, err = c.Write([]byte(req)); err != nil {
t.Fatalf("%d. unexpected error when writing request: %s", i, err)
}
if err = c.(*net.TCPConn).CloseWrite(); err != nil {
t.Fatalf("%d. unexpected error when closing write end of the connection: %s", i, err)
}
var resp []byte
ch := make(chan struct{})
go func() {
if resp, err = ioutil.ReadAll(c); err != nil {
t.Errorf("%d. unexpected error when reading response: %s", i, err)
}
close(ch)
}()
select {
case <-ch:
case <-time.After(250 * time.Millisecond):
t.Fatalf("%d. timeout when waiting for response", i)
}
if string(resp) != req {
t.Fatalf("%d. unexpected response %q. Expecting %q", i, resp, req)
}
if err = c.Close(); err != nil {
t.Fatalf("%d. unexpected error when closing connection: %s", i, err)
}
}
for _, ln := range lns {
if err := ln.Close(); err != nil {
t.Fatalf("unexpected error when closing listener: %s", err)
}
}
for i := 0; i < serversCount; i++ {
select {
case <-doneCh:
case <-time.After(200 * time.Millisecond):
t.Fatalf("timeout when waiting for servers to be closed")
}
}
}
func serveEcho(t *testing.T, ln net.Listener) {
for {
c, err := ln.Accept()
if err != nil {
break
}
req, err := ioutil.ReadAll(c)
if err != nil {
t.Fatalf("unexpected error when reading request: %s", err)
}
if _, err = c.Write(req); err != nil {
t.Fatalf("unexpected error when writing response: %s", err)
}
if err = c.Close(); err != nil {
t.Fatalf("unexpected error when closing connection: %s", err)
}
}
}
fasthttp-1.31.0/reuseport/reuseport_windows.go 0000664 0000000 0000000 00000001206 14130360711 0021605 0 ustar 00root root 0000000 0000000 package reuseport
import (
"context"
"net"
"syscall"
"golang.org/x/sys/windows"
)
var listenConfig = net.ListenConfig{
Control: func(network, address string, c syscall.RawConn) (err error) {
return c.Control(func(fd uintptr) {
err = windows.SetsockoptInt(windows.Handle(fd), windows.SOL_SOCKET, windows.SO_REUSEADDR, 1)
})
},
}
// Listen returns TCP listener with SO_REUSEADDR option set, SO_REUSEPORT is not supported on Windows, so it uses
// SO_REUSEADDR as an alternative to achieve the same effect.
func Listen(network, addr string) (net.Listener, error) {
return listenConfig.Listen(context.Background(), network, addr)
}
fasthttp-1.31.0/server.go 0000664 0000000 0000000 00000240606 14130360711 0015272 0 ustar 00root root 0000000 0000000 package fasthttp
import (
"bufio"
"context"
"crypto/tls"
"errors"
"fmt"
"io"
"log"
"mime/multipart"
"net"
"os"
"strings"
"sync"
"sync/atomic"
"time"
)
var errNoCertOrKeyProvided = errors.New("cert or key has not been provided")
var (
// ErrAlreadyServing is returned when calling Serve on a Server
// that is already serving connections.
ErrAlreadyServing = errors.New("Server is already serving connections")
)
// ServeConn serves HTTP requests from the given connection
// using the given handler.
//
// ServeConn returns nil if all requests from the c are successfully served.
// It returns non-nil error otherwise.
//
// Connection c must immediately propagate all the data passed to Write()
// to the client. Otherwise requests' processing may hang.
//
// ServeConn closes c before returning.
func ServeConn(c net.Conn, handler RequestHandler) error {
v := serverPool.Get()
if v == nil {
v = &Server{}
}
s := v.(*Server)
s.Handler = handler
err := s.ServeConn(c)
s.Handler = nil
serverPool.Put(v)
return err
}
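// Usage sketch (illustrative; the listener setup and the handler body are
// assumptions made for this example, not defaults of the package):
//
//	ln, err := net.Listen("tcp", ":8080")
//	if err != nil {
//		log.Fatalf("cannot listen: %s", err)
//	}
//	for {
//		c, err := ln.Accept()
//		if err != nil {
//			break
//		}
//		go func(conn net.Conn) {
//			if err := ServeConn(conn, func(ctx *RequestCtx) {
//				ctx.SetBodyString("hello")
//			}); err != nil {
//				log.Printf("error serving connection: %s", err)
//			}
//		}(c)
//	}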
var serverPool sync.Pool
// Serve serves incoming connections from the given listener
// using the given handler.
//
// Serve blocks until the given listener returns permanent error.
func Serve(ln net.Listener, handler RequestHandler) error {
s := &Server{
Handler: handler,
}
return s.Serve(ln)
}
// ServeTLS serves HTTPS requests from the given net.Listener
// using the given handler.
//
// certFile and keyFile are paths to TLS certificate and key files.
func ServeTLS(ln net.Listener, certFile, keyFile string, handler RequestHandler) error {
s := &Server{
Handler: handler,
}
return s.ServeTLS(ln, certFile, keyFile)
}
// ServeTLSEmbed serves HTTPS requests from the given net.Listener
// using the given handler.
//
// certData and keyData must contain valid TLS certificate and key data.
func ServeTLSEmbed(ln net.Listener, certData, keyData []byte, handler RequestHandler) error {
s := &Server{
Handler: handler,
}
return s.ServeTLSEmbed(ln, certData, keyData)
}
// ListenAndServe serves HTTP requests from the given TCP addr
// using the given handler.
func ListenAndServe(addr string, handler RequestHandler) error {
s := &Server{
Handler: handler,
}
return s.ListenAndServe(addr)
}
// ListenAndServeUNIX serves HTTP requests from the given UNIX addr
// using the given handler.
//
// The function deletes existing file at addr before starting serving.
//
// The server sets the given file mode for the UNIX addr.
func ListenAndServeUNIX(addr string, mode os.FileMode, handler RequestHandler) error {
s := &Server{
Handler: handler,
}
return s.ListenAndServeUNIX(addr, mode)
}
// ListenAndServeTLS serves HTTPS requests from the given TCP addr
// using the given handler.
//
// certFile and keyFile are paths to TLS certificate and key files.
func ListenAndServeTLS(addr, certFile, keyFile string, handler RequestHandler) error {
s := &Server{
Handler: handler,
}
return s.ListenAndServeTLS(addr, certFile, keyFile)
}
// ListenAndServeTLSEmbed serves HTTPS requests from the given TCP addr
// using the given handler.
//
// certData and keyData must contain valid TLS certificate and key data.
func ListenAndServeTLSEmbed(addr string, certData, keyData []byte, handler RequestHandler) error {
s := &Server{
Handler: handler,
}
return s.ListenAndServeTLSEmbed(addr, certData, keyData)
}
// RequestHandler must process incoming requests.
//
// RequestHandler must call ctx.TimeoutError() before returning
// if it keeps references to ctx and/or its members after the return.
// Consider wrapping RequestHandler into TimeoutHandler if response time
// must be limited.
type RequestHandler func(ctx *RequestCtx)
// ServeHandler must process tls.Config.NextProto negotiated requests.
type ServeHandler func(c net.Conn) error
// Server implements HTTP server.
//
// Default Server settings should satisfy the majority of Server users.
// Adjust Server settings only if you really understand the consequences.
//
// It is forbidden copying Server instances. Create new Server instances
// instead.
//
// It is safe to call Server methods from concurrently running goroutines.
type Server struct {
noCopy noCopy //nolint:unused,structcheck
// Handler for processing incoming requests.
//
// Take into account that no `panic` recovery is done by `fasthttp` (thus any `panic` will take down the entire server).
// Instead the user should use `recover` to handle these situations.
Handler RequestHandler
// ErrorHandler for returning a response in case of an error while receiving or parsing the request.
//
// The following is a non-exhaustive list of errors that can be expected as argument:
// * io.EOF
// * io.ErrUnexpectedEOF
// * ErrGetOnly
// * ErrSmallBuffer
// * ErrBodyTooLarge
// * ErrBrokenChunks
ErrorHandler func(ctx *RequestCtx, err error)
// HeaderReceived is called after receiving the header
//
// non zero RequestConfig field values will overwrite the default configs
HeaderReceived func(header *RequestHeader) RequestConfig
// ContinueHandler is called after receiving the Expect 100 Continue Header
//
// https://www.w3.org/Protocols/rfc2616/rfc2616-sec8.html#sec8.2.3
// https://www.w3.org/Protocols/rfc2616/rfc2616-sec10.html#sec10.1.1
// Using ContinueHandler a server can decide whether or not
// to read a potentially large request body based on the headers
//
// The default is to automatically read request bodies of Expect 100 Continue requests
// like they are normal requests
ContinueHandler func(header *RequestHeader) bool
// Server name for sending in response headers.
//
// Default server name is used if left blank.
Name string
// The maximum number of concurrent connections the server may serve.
//
// DefaultConcurrency is used if not set.
//
// Concurrency only works if you either call Serve once, or only call ServeConn multiple times.
// It works with ListenAndServe as well.
Concurrency int
// Per-connection buffer size for requests' reading.
// This also limits the maximum header size.
//
// Increase this buffer if your clients send multi-KB RequestURIs
// and/or multi-KB headers (for example, BIG cookies).
//
// Default buffer size is used if not set.
ReadBufferSize int
// Per-connection buffer size for responses' writing.
//
// Default buffer size is used if not set.
WriteBufferSize int
// ReadTimeout is the amount of time allowed to read
// the full request including body. The connection's read
// deadline is reset when the connection opens, or for
// keep-alive connections after the first byte has been read.
//
// By default request read timeout is unlimited.
ReadTimeout time.Duration
// WriteTimeout is the maximum duration before timing out
// writes of the response. It is reset after the request handler
// has returned.
//
// By default response write timeout is unlimited.
WriteTimeout time.Duration
// IdleTimeout is the maximum amount of time to wait for the
// next request when keep-alive is enabled. If IdleTimeout
// is zero, the value of ReadTimeout is used.
IdleTimeout time.Duration
// Maximum number of concurrent client connections allowed per IP.
//
// By default unlimited number of concurrent connections
// may be established to the server from a single IP address.
MaxConnsPerIP int
// Maximum number of requests served per connection.
//
// The server closes connection after the last request.
// 'Connection: close' header is added to the last response.
//
// By default unlimited number of requests may be served per connection.
MaxRequestsPerConn int
// MaxKeepaliveDuration is a no-op and only left here for backwards compatibility.
// Deprecated: Use IdleTimeout instead.
MaxKeepaliveDuration time.Duration
// Period between tcp keep-alive messages.
//
// TCP keep-alive period is determined by operation system by default.
TCPKeepalivePeriod time.Duration
// Maximum request body size.
//
// The server rejects requests with bodies exceeding this limit.
//
// Request body size is limited by DefaultMaxRequestBodySize by default.
MaxRequestBodySize int
// Whether to disable keep-alive connections.
//
// The server will close all the incoming connections after sending
// the first response to client if this option is set to true.
//
// By default keep-alive connections are enabled.
DisableKeepalive bool
// Whether to enable tcp keep-alive connections.
//
// Whether the operating system should send tcp keep-alive messages on the tcp connection.
//
// By default tcp keep-alive connections are disabled.
TCPKeepalive bool
// Aggressively reduces memory usage at the cost of higher CPU usage
// if set to true.
//
// Try enabling this option only if the server consumes too much memory
// serving mostly idle keep-alive connections. This may reduce memory
// usage by more than 50%.
//
// Aggressive memory usage reduction is disabled by default.
ReduceMemoryUsage bool
// Rejects all non-GET requests if set to true.
//
// This option is useful as anti-DoS protection for servers
// accepting only GET requests. The request size is limited
// by ReadBufferSize if GetOnly is set.
//
// Server accepts all the requests by default.
GetOnly bool
// Will not pre parse Multipart Form data if set to true.
//
// This option is useful for servers that desire to treat
// multipart form data as a binary blob, or choose when to parse the data.
//
// Server pre parses multipart form data by default.
DisablePreParseMultipartForm bool
// Logs all errors, including the most frequent
// 'connection reset by peer', 'broken pipe' and 'connection timeout'
// errors. Such errors are common in production serving real-world
// clients.
//
// By default the most frequent errors such as
// 'connection reset by peer', 'broken pipe' and 'connection timeout'
// are suppressed in order to limit output log traffic.
LogAllErrors bool
// Will not log potentially sensitive content in error logs
//
// This option is useful for servers that handle sensitive data
// in the request/response.
//
// Server logs all full errors by default.
SecureErrorLogMessage bool
// Header names are passed as-is without normalization
// if this option is set.
//
// Disabled header names' normalization may be useful only for proxying
// incoming requests to other servers expecting case-sensitive
// header names. See https://github.com/valyala/fasthttp/issues/57
// for details.
//
// By default request and response header names are normalized, i.e.
// The first letter and the first letters following dashes
// are uppercased, while all the other letters are lowercased.
// Examples:
//
// * HOST -> Host
// * content-type -> Content-Type
// * cONTENT-lenGTH -> Content-Length
DisableHeaderNamesNormalizing bool
// SleepWhenConcurrencyLimitsExceeded is the duration to sleep if
// the concurrency limit is exceeded (default [when 0]: don't sleep
// and accept new connections immediately).
SleepWhenConcurrencyLimitsExceeded time.Duration
// NoDefaultServerHeader, when set to true, causes the default Server header
// to be excluded from the Response.
//
// The default Server header value is the value of the Name field or an
// internal default value in its absence. With this option set to true,
// the only time a Server header will be sent is if a non-zero length
// value is explicitly provided during a request.
NoDefaultServerHeader bool
// NoDefaultDate, when set to true, causes the default Date
// header to be excluded from the Response.
//
// The default Date header value is the current date value. When
// set to true, the Date will not be present.
NoDefaultDate bool
// NoDefaultContentType, when set to true, causes the default Content-Type
// header to be excluded from the Response.
//
// The default Content-Type header value is the internal default value. When
// set to true, the Content-Type will not be present.
NoDefaultContentType bool
// KeepHijackedConns is an opt-in disable of connection
// close by fasthttp after connections' HijackHandler returns.
// This allows saving goroutines, e.g. when fasthttp is used to upgrade
// http connections to WS and the connection goes to another handler,
// which will close it when needed.
KeepHijackedConns bool
// CloseOnShutdown when true adds a `Connection: close` header when the server is shutting down.
CloseOnShutdown bool
// StreamRequestBody enables request body streaming,
// and calls the handler sooner when the given body is
// larger than the current limit.
StreamRequestBody bool
// ConnState specifies an optional callback function that is
// called when a client connection changes state. See the
// ConnState type and associated constants for details.
ConnState func(net.Conn, ConnState)
// Logger, which is used by RequestCtx.Logger().
//
// By default standard logger from log package is used.
Logger Logger
tlsConfig *tls.Config
nextProtos map[string]ServeHandler
concurrency uint32
concurrencyCh chan struct{}
perIPConnCounter perIPConnCounter
serverName atomic.Value
ctxPool sync.Pool
readerPool sync.Pool
writerPool sync.Pool
hijackConnPool sync.Pool
// We need to know our listeners so we can close them in Shutdown().
ln []net.Listener
mu sync.Mutex
open int32
stop int32
done chan struct{}
}
// TimeoutHandler creates RequestHandler, which returns StatusRequestTimeout
// error with the given msg to the client if h didn't return during
// the given duration.
//
// The returned handler may return StatusTooManyRequests error with the given
// msg to the client if there are more than Server.Concurrency concurrent
// handlers h running at the moment.
func TimeoutHandler(h RequestHandler, timeout time.Duration, msg string) RequestHandler {
return TimeoutWithCodeHandler(h, timeout, msg, StatusRequestTimeout)
}
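// A minimal usage sketch for TimeoutHandler (illustrative only; the slow
// handler below and the listen address are hypothetical):
//
// slow := func(ctx *fasthttp.RequestCtx) {
// 	time.Sleep(5 * time.Second)
// 	ctx.WriteString("done") //nolint:errcheck
// }
// h := fasthttp.TimeoutHandler(slow, 2*time.Second, "request timed out")
// log.Fatal(fasthttp.ListenAndServe(":8080", h))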
// TimeoutWithCodeHandler creates RequestHandler, which returns an error with
// the given msg and status code to the client if h didn't return during
// the given duration.
//
// The returned handler may return StatusTooManyRequests error with the given
// msg to the client if there are more than Server.Concurrency concurrent
// handlers h running at the moment.
func TimeoutWithCodeHandler(h RequestHandler, timeout time.Duration, msg string, statusCode int) RequestHandler {
if timeout <= 0 {
return h
}
return func(ctx *RequestCtx) {
concurrencyCh := ctx.s.concurrencyCh
select {
case concurrencyCh <- struct{}{}:
default:
ctx.Error(msg, StatusTooManyRequests)
return
}
ch := ctx.timeoutCh
if ch == nil {
ch = make(chan struct{}, 1)
ctx.timeoutCh = ch
}
go func() {
h(ctx)
ch <- struct{}{}
<-concurrencyCh
}()
ctx.timeoutTimer = initTimer(ctx.timeoutTimer, timeout)
select {
case <-ch:
case <-ctx.timeoutTimer.C:
ctx.TimeoutErrorWithCode(msg, statusCode)
}
stopTimer(ctx.timeoutTimer)
}
}
// RequestConfig configures per-request deadlines and body limits.
type RequestConfig struct {
// ReadTimeout is the maximum duration for reading the entire
// request body.
// A zero value means that default values will be honored.
ReadTimeout time.Duration
// WriteTimeout is the maximum duration before timing out
// writes of the response.
// A zero value means that default values will be honored.
WriteTimeout time.Duration
// Maximum request body size.
// A zero value means that default values will be honored.
MaxRequestBodySize int
}
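// RequestConfig values are typically returned from the Server.HeaderReceived
// callback. A hedged sketch (the handler, the "/upload" path and the limits
// below are hypothetical):
//
// s := &fasthttp.Server{
// 	Handler: myHandler,
// 	HeaderReceived: func(h *fasthttp.RequestHeader) fasthttp.RequestConfig {
// 		if string(h.RequestURI()) == "/upload" {
// 			return fasthttp.RequestConfig{
// 				ReadTimeout:        time.Minute,
// 				MaxRequestBodySize: 100 * 1024 * 1024,
// 			}
// 		}
// 		return fasthttp.RequestConfig{}
// 	},
// }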
// CompressHandler returns RequestHandler that transparently compresses
// response body generated by h if the request contains 'gzip' or 'deflate'
// 'Accept-Encoding' header.
func CompressHandler(h RequestHandler) RequestHandler {
return CompressHandlerLevel(h, CompressDefaultCompression)
}
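// A minimal usage sketch for CompressHandler (illustrative only; myHandler
// and the listen address are hypothetical):
//
// h := fasthttp.CompressHandler(myHandler)
// log.Fatal(fasthttp.ListenAndServe(":8080", h))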
// CompressHandlerLevel returns RequestHandler that transparently compresses
// response body generated by h if the request contains a 'gzip' or 'deflate'
// 'Accept-Encoding' header.
//
// Level is the desired compression level:
//
// * CompressNoCompression
// * CompressBestSpeed
// * CompressBestCompression
// * CompressDefaultCompression
// * CompressHuffmanOnly
func CompressHandlerLevel(h RequestHandler, level int) RequestHandler {
return func(ctx *RequestCtx) {
h(ctx)
if ctx.Request.Header.HasAcceptEncodingBytes(strGzip) {
ctx.Response.gzipBody(level) //nolint:errcheck
} else if ctx.Request.Header.HasAcceptEncodingBytes(strDeflate) {
ctx.Response.deflateBody(level) //nolint:errcheck
}
}
}
// CompressHandlerBrotliLevel returns RequestHandler that transparently compresses
// response body generated by h if the request contains a 'br', 'gzip' or 'deflate'
// 'Accept-Encoding' header.
//
// brotliLevel is the desired compression level for brotli.
//
// * CompressBrotliNoCompression
// * CompressBrotliBestSpeed
// * CompressBrotliBestCompression
// * CompressBrotliDefaultCompression
//
// otherLevel is the desired compression level for gzip and deflate.
//
// * CompressNoCompression
// * CompressBestSpeed
// * CompressBestCompression
// * CompressDefaultCompression
// * CompressHuffmanOnly
func CompressHandlerBrotliLevel(h RequestHandler, brotliLevel, otherLevel int) RequestHandler {
return func(ctx *RequestCtx) {
h(ctx)
if ctx.Request.Header.HasAcceptEncodingBytes(strBr) {
ctx.Response.brotliBody(brotliLevel) //nolint:errcheck
} else if ctx.Request.Header.HasAcceptEncodingBytes(strGzip) {
ctx.Response.gzipBody(otherLevel) //nolint:errcheck
} else if ctx.Request.Header.HasAcceptEncodingBytes(strDeflate) {
ctx.Response.deflateBody(otherLevel) //nolint:errcheck
}
}
}
// RequestCtx contains incoming request and manages outgoing response.
//
// Copying RequestCtx instances is forbidden.
//
// RequestHandler should avoid holding references to incoming RequestCtx and/or
// its members after the return.
// If holding RequestCtx references after the return is unavoidable
// (for instance, ctx is passed to a separate goroutine and ctx lifetime cannot
// be controlled), then the RequestHandler MUST call ctx.TimeoutError()
// before return.
//
// It is unsafe modifying/reading RequestCtx instance from concurrently
// running goroutines. The only exception is TimeoutError*, which may be called
// while other goroutines are accessing RequestCtx.
type RequestCtx struct {
noCopy noCopy //nolint:unused,structcheck
// Incoming request.
//
// Copying Request by value is forbidden. Use pointer to Request instead.
Request Request
// Outgoing response.
//
// Copying Response by value is forbidden. Use pointer to Response instead.
Response Response
userValues userData
connID uint64
connRequestNum uint64
connTime time.Time
remoteAddr net.Addr
time time.Time
logger ctxLogger
s *Server
c net.Conn
fbr firstByteReader
timeoutResponse *Response
timeoutCh chan struct{}
timeoutTimer *time.Timer
hijackHandler HijackHandler
hijackNoResponse bool
}
// HijackHandler must process the hijacked connection c.
//
// If KeepHijackedConns is disabled, which is by default,
// the connection c is automatically closed after returning from HijackHandler.
//
// The connection c must not be used after returning from the handler, if KeepHijackedConns is disabled.
//
// When KeepHijackedConns is enabled, fasthttp will not Close() the connection,
// you must do it when you need it. You must not use c in any way after calling Close().
type HijackHandler func(c net.Conn)
// Hijack registers the given handler for connection hijacking.
//
// The handler is called after returning from RequestHandler
// and sending http response. The current connection is passed
// to the handler. The connection is automatically closed after
// returning from the handler.
//
// The server skips calling the handler in the following cases:
//
// * 'Connection: close' header exists in either request or response.
// * Unexpected error during response writing to the connection.
//
// The server stops processing requests from hijacked connections.
//
// Server limits such as Concurrency, ReadTimeout, WriteTimeout, etc.
// aren't applied to hijacked connections.
//
// The handler must not retain references to ctx members.
//
// Arbitrary 'Connection: Upgrade' protocols may be implemented
// with HijackHandler. For instance,
//
// * WebSocket ( https://en.wikipedia.org/wiki/WebSocket )
// * HTTP/2.0 ( https://en.wikipedia.org/wiki/HTTP/2 )
//
func (ctx *RequestCtx) Hijack(handler HijackHandler) {
ctx.hijackHandler = handler
}
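// A hedged usage sketch for Hijack (illustrative only): the handler takes over
// the raw connection after the response has been sent.
//
// ctx.Hijack(func(c net.Conn) {
// 	// fasthttp closes c after this handler returns (unless KeepHijackedConns is set).
// 	fmt.Fprintf(c, "hijacked connection says hi\r\n") //nolint:errcheck
// })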
// HijackSetNoResponse changes the behavior of hijacking a request.
// If HijackSetNoResponse is called with false fasthttp will send a response
// to the client before calling the HijackHandler (default). If HijackSetNoResponse
// is called with true no response is sent back before calling the
// HijackHandler supplied in the Hijack function.
func (ctx *RequestCtx) HijackSetNoResponse(noResponse bool) {
ctx.hijackNoResponse = noResponse
}
// Hijacked returns true after Hijack is called.
func (ctx *RequestCtx) Hijacked() bool {
return ctx.hijackHandler != nil
}
// SetUserValue stores the given value (arbitrary object)
// under the given key in ctx.
//
// The value stored in ctx may be obtained by UserValue*.
//
// This functionality may be useful for passing arbitrary values between
// functions involved in request processing.
//
// All the values are removed from ctx after returning from the top
// RequestHandler. Additionally, Close method is called on each value
// implementing io.Closer before removing the value from ctx.
func (ctx *RequestCtx) SetUserValue(key string, value interface{}) {
ctx.userValues.Set(key, value)
}
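// A small sketch of passing a value through ctx (illustrative only; the
// "userID" key and the middleware mentioned are hypothetical):
//
// // In an authentication middleware:
// ctx.SetUserValue("userID", 42)
//
// // Later, in the wrapped handler:
// if id, ok := ctx.UserValue("userID").(int); ok {
// 	ctx.Logger().Printf("serving user %d", id)
// }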
// SetUserValueBytes stores the given value (arbitrary object)
// under the given key in ctx.
//
// The value stored in ctx may be obtained by UserValue*.
//
// This functionality may be useful for passing arbitrary values between
// functions involved in request processing.
//
// All the values stored in ctx are deleted after returning from RequestHandler.
func (ctx *RequestCtx) SetUserValueBytes(key []byte, value interface{}) {
ctx.userValues.SetBytes(key, value)
}
// UserValue returns the value stored via SetUserValue* under the given key.
func (ctx *RequestCtx) UserValue(key string) interface{} {
return ctx.userValues.Get(key)
}
// UserValueBytes returns the value stored via SetUserValue*
// under the given key.
func (ctx *RequestCtx) UserValueBytes(key []byte) interface{} {
return ctx.userValues.GetBytes(key)
}
// VisitUserValues calls visitor for each existing userValue.
//
// visitor must not retain references to key and value after returning.
// Make key and/or value copies if you need storing them after returning.
func (ctx *RequestCtx) VisitUserValues(visitor func([]byte, interface{})) {
for i, n := 0, len(ctx.userValues); i < n; i++ {
kv := &ctx.userValues[i]
visitor(kv.key, kv.value)
}
}
// ResetUserValues allows resetting user values from the request context.
func (ctx *RequestCtx) ResetUserValues() {
ctx.userValues.Reset()
}
// RemoveUserValue removes the given key and the value under it in ctx.
func (ctx *RequestCtx) RemoveUserValue(key string) {
ctx.userValues.Remove(key)
}
// RemoveUserValueBytes removes the given key and the value under it in ctx.
func (ctx *RequestCtx) RemoveUserValueBytes(key []byte) {
ctx.userValues.RemoveBytes(key)
}
type connTLSer interface {
Handshake() error
ConnectionState() tls.ConnectionState
}
// IsTLS returns true if the underlying connection is tls.Conn.
//
// tls.Conn is an encrypted connection (aka SSL, HTTPS).
func (ctx *RequestCtx) IsTLS() bool {
// cast to (connTLSer) instead of (*tls.Conn), since it catches
// cases with overridden tls.Conn such as:
//
// type customConn struct {
// *tls.Conn
//
// // other custom fields here
// }
// perIPConn wraps the net.Conn in the Conn field
if pic, ok := ctx.c.(*perIPConn); ok {
_, ok := pic.Conn.(connTLSer)
return ok
}
_, ok := ctx.c.(connTLSer)
return ok
}
// TLSConnectionState returns TLS connection state.
//
// The function returns nil if the underlying connection isn't tls.Conn.
//
// The returned state may be used for verifying TLS version, client certificates,
// etc.
func (ctx *RequestCtx) TLSConnectionState() *tls.ConnectionState {
tlsConn, ok := ctx.c.(connTLSer)
if !ok {
return nil
}
state := tlsConn.ConnectionState()
return &state
}
// Conn returns a reference to the underlying net.Conn.
//
// WARNING: Only use this method if you know what you are doing!
//
// Reading from or writing to the returned connection will end badly!
func (ctx *RequestCtx) Conn() net.Conn {
return ctx.c
}
type firstByteReader struct {
c net.Conn
ch byte
byteRead bool
}
func (r *firstByteReader) Read(b []byte) (int, error) {
if len(b) == 0 {
return 0, nil
}
nn := 0
if !r.byteRead {
b[0] = r.ch
b = b[1:]
r.byteRead = true
nn = 1
}
n, err := r.c.Read(b)
return n + nn, err
}
// Logger is used for logging formatted messages.
type Logger interface {
// Printf must have the same semantics as log.Printf.
Printf(format string, args ...interface{})
}
var ctxLoggerLock sync.Mutex
type ctxLogger struct {
ctx *RequestCtx
logger Logger
}
func (cl *ctxLogger) Printf(format string, args ...interface{}) {
msg := fmt.Sprintf(format, args...)
ctxLoggerLock.Lock()
cl.logger.Printf("%.3f %s - %s", time.Since(cl.ctx.ConnTime()).Seconds(), cl.ctx.String(), msg)
ctxLoggerLock.Unlock()
}
var zeroTCPAddr = &net.TCPAddr{
IP: net.IPv4zero,
}
// String returns a unique string representation of the ctx.
//
// The returned value may be useful for logging.
func (ctx *RequestCtx) String() string {
return fmt.Sprintf("#%016X - %s<->%s - %s %s", ctx.ID(), ctx.LocalAddr(), ctx.RemoteAddr(), ctx.Request.Header.Method(), ctx.URI().FullURI())
}
// ID returns the unique ID of the request.
func (ctx *RequestCtx) ID() uint64 {
return (ctx.connID << 32) | ctx.connRequestNum
}
// ConnID returns unique connection ID.
//
// This ID may be used to match distinct requests to the same incoming
// connection.
func (ctx *RequestCtx) ConnID() uint64 {
return ctx.connID
}
// Time returns RequestHandler call time.
func (ctx *RequestCtx) Time() time.Time {
return ctx.time
}
// ConnTime returns the time the server started serving the connection
// the current request came from.
func (ctx *RequestCtx) ConnTime() time.Time {
return ctx.connTime
}
// ConnRequestNum returns request sequence number
// for the current connection.
//
// Sequence starts with 1.
func (ctx *RequestCtx) ConnRequestNum() uint64 {
return ctx.connRequestNum
}
// SetConnectionClose sets 'Connection: close' response header and closes
// connection after the RequestHandler returns.
func (ctx *RequestCtx) SetConnectionClose() {
ctx.Response.SetConnectionClose()
}
// SetStatusCode sets response status code.
func (ctx *RequestCtx) SetStatusCode(statusCode int) {
ctx.Response.SetStatusCode(statusCode)
}
// SetContentType sets response Content-Type.
func (ctx *RequestCtx) SetContentType(contentType string) {
ctx.Response.Header.SetContentType(contentType)
}
// SetContentTypeBytes sets response Content-Type.
//
// It is safe to modify the contentType buffer after the function returns.
func (ctx *RequestCtx) SetContentTypeBytes(contentType []byte) {
ctx.Response.Header.SetContentTypeBytes(contentType)
}
// RequestURI returns RequestURI.
//
// The returned bytes are valid until your request handler returns.
func (ctx *RequestCtx) RequestURI() []byte {
return ctx.Request.Header.RequestURI()
}
// URI returns requested uri.
//
// This uri is valid until your request handler returns.
func (ctx *RequestCtx) URI() *URI {
return ctx.Request.URI()
}
// Referer returns request referer.
//
// The returned bytes are valid until your request handler returns.
func (ctx *RequestCtx) Referer() []byte {
return ctx.Request.Header.Referer()
}
// UserAgent returns User-Agent header value from the request.
//
// The returned bytes are valid until your request handler returns.
func (ctx *RequestCtx) UserAgent() []byte {
return ctx.Request.Header.UserAgent()
}
// Path returns requested path.
//
// The returned bytes are valid until your request handler returns.
func (ctx *RequestCtx) Path() []byte {
return ctx.URI().Path()
}
// Host returns requested host.
//
// The returned bytes are valid until your request handler returns.
func (ctx *RequestCtx) Host() []byte {
return ctx.URI().Host()
}
// QueryArgs returns query arguments from RequestURI.
//
// It doesn't return POST'ed arguments - use PostArgs() for this.
//
// See also PostArgs, FormValue and FormFile.
//
// These args are valid until your request handler returns.
func (ctx *RequestCtx) QueryArgs() *Args {
return ctx.URI().QueryArgs()
}
// PostArgs returns POST arguments.
//
// It doesn't return query arguments from RequestURI - use QueryArgs for this.
//
// See also QueryArgs, FormValue and FormFile.
//
// These args are valid until your request handler returns.
func (ctx *RequestCtx) PostArgs() *Args {
return ctx.Request.PostArgs()
}
// MultipartForm returns the request's multipart form.
//
// Returns ErrNoMultipartForm if request's content-type
// isn't 'multipart/form-data'.
//
// All uploaded temporary files are automatically deleted after
// returning from RequestHandler. Either move or copy uploaded files
// into a new place if you want to retain them.
//
// Use SaveMultipartFile function for permanently saving uploaded file.
//
// The returned form is valid until your request handler returns.
//
// See also FormFile and FormValue.
func (ctx *RequestCtx) MultipartForm() (*multipart.Form, error) {
return ctx.Request.MultipartForm()
}
// FormFile returns uploaded file associated with the given multipart form key.
//
// The file is automatically deleted after returning from RequestHandler,
// so either move or copy the uploaded file into a new place if you want to retain it.
//
// Use SaveMultipartFile function for permanently saving uploaded file.
//
// The returned file header is valid until your request handler returns.
func (ctx *RequestCtx) FormFile(key string) (*multipart.FileHeader, error) {
mf, err := ctx.MultipartForm()
if err != nil {
return nil, err
}
if mf.File == nil {
return nil, err
}
fhh := mf.File[key]
if fhh == nil {
return nil, ErrMissingFile
}
return fhh[0], nil
}
// ErrMissingFile may be returned from FormFile when there is no uploaded file
// associated with the given multipart form key.
var ErrMissingFile = errors.New("there is no uploaded file associated with the given key")
// SaveMultipartFile saves multipart file fh under the given filename path.
func SaveMultipartFile(fh *multipart.FileHeader, path string) (err error) {
var (
f multipart.File
ff *os.File
)
f, err = fh.Open()
if err != nil {
return
}
var ok bool
if ff, ok = f.(*os.File); ok {
// Windows can't rename files that are opened.
if err = f.Close(); err != nil {
return
}
// If renaming fails we try the normal copying method.
// Renaming could fail if the files are on different devices.
if os.Rename(ff.Name(), path) == nil {
return nil
}
// Reopen f for the code below.
if f, err = fh.Open(); err != nil {
return
}
}
defer func() {
e := f.Close()
if err == nil {
err = e
}
}()
if ff, err = os.Create(path); err != nil {
return
}
defer func() {
e := ff.Close()
if err == nil {
err = e
}
}()
_, err = copyZeroAlloc(ff, f)
return
}
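// A hedged sketch combining FormFile and SaveMultipartFile inside a handler
// (illustrative only; the "upload" field name and the target path are hypothetical):
//
// fh, err := ctx.FormFile("upload")
// if err != nil {
// 	ctx.Error("missing file", fasthttp.StatusBadRequest)
// 	return
// }
// if err := fasthttp.SaveMultipartFile(fh, "/tmp/uploaded.bin"); err != nil {
// 	ctx.Error("cannot save file", fasthttp.StatusInternalServerError)
// 	return
// }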
// FormValue returns form value associated with the given key.
//
// The value is searched in the following places:
//
// * Query string.
// * POST or PUT body.
//
// There are more fine-grained methods for obtaining form values:
//
// * QueryArgs for obtaining values from query string.
// * PostArgs for obtaining values from POST or PUT body.
// * MultipartForm for obtaining values from multipart form.
// * FormFile for obtaining uploaded files.
//
// The returned value is valid until your request handler returns.
func (ctx *RequestCtx) FormValue(key string) []byte {
v := ctx.QueryArgs().Peek(key)
if len(v) > 0 {
return v
}
v = ctx.PostArgs().Peek(key)
if len(v) > 0 {
return v
}
mf, err := ctx.MultipartForm()
if err == nil && mf.Value != nil {
vv := mf.Value[key]
if len(vv) > 0 {
return []byte(vv[0])
}
}
return nil
}
// IsGet returns true if request method is GET.
func (ctx *RequestCtx) IsGet() bool {
return ctx.Request.Header.IsGet()
}
// IsPost returns true if request method is POST.
func (ctx *RequestCtx) IsPost() bool {
return ctx.Request.Header.IsPost()
}
// IsPut returns true if request method is PUT.
func (ctx *RequestCtx) IsPut() bool {
return ctx.Request.Header.IsPut()
}
// IsDelete returns true if request method is DELETE.
func (ctx *RequestCtx) IsDelete() bool {
return ctx.Request.Header.IsDelete()
}
// IsConnect returns true if request method is CONNECT.
func (ctx *RequestCtx) IsConnect() bool {
return ctx.Request.Header.IsConnect()
}
// IsOptions returns true if request method is OPTIONS.
func (ctx *RequestCtx) IsOptions() bool {
return ctx.Request.Header.IsOptions()
}
// IsTrace returns true if request method is TRACE.
func (ctx *RequestCtx) IsTrace() bool {
return ctx.Request.Header.IsTrace()
}
// IsPatch returns true if request method is PATCH.
func (ctx *RequestCtx) IsPatch() bool {
return ctx.Request.Header.IsPatch()
}
// Method return request method.
//
// Returned value is valid until your request handler returns.
func (ctx *RequestCtx) Method() []byte {
return ctx.Request.Header.Method()
}
// IsHead returns true if request method is HEAD.
func (ctx *RequestCtx) IsHead() bool {
return ctx.Request.Header.IsHead()
}
// RemoteAddr returns client address for the given request.
//
// Always returns non-nil result.
func (ctx *RequestCtx) RemoteAddr() net.Addr {
if ctx.remoteAddr != nil {
return ctx.remoteAddr
}
if ctx.c == nil {
return zeroTCPAddr
}
addr := ctx.c.RemoteAddr()
if addr == nil {
return zeroTCPAddr
}
return addr
}
// SetRemoteAddr sets remote address to the given value.
//
// Set nil value to restore the default behaviour of using
// the connection's remote address.
func (ctx *RequestCtx) SetRemoteAddr(remoteAddr net.Addr) {
ctx.remoteAddr = remoteAddr
}
// LocalAddr returns server address for the given request.
//
// Always returns non-nil result.
func (ctx *RequestCtx) LocalAddr() net.Addr {
if ctx.c == nil {
return zeroTCPAddr
}
addr := ctx.c.LocalAddr()
if addr == nil {
return zeroTCPAddr
}
return addr
}
// RemoteIP returns the client ip the request came from.
//
// Always returns non-nil result.
func (ctx *RequestCtx) RemoteIP() net.IP {
return addrToIP(ctx.RemoteAddr())
}
// LocalIP returns the server ip the request came to.
//
// Always returns non-nil result.
func (ctx *RequestCtx) LocalIP() net.IP {
return addrToIP(ctx.LocalAddr())
}
func addrToIP(addr net.Addr) net.IP {
x, ok := addr.(*net.TCPAddr)
if !ok {
return net.IPv4zero
}
return x.IP
}
// Error sets response status code to the given value and sets response body
// to the given message.
//
// Warning: this will reset the response headers and body already set!
func (ctx *RequestCtx) Error(msg string, statusCode int) {
ctx.Response.Reset()
ctx.SetStatusCode(statusCode)
ctx.SetContentTypeBytes(defaultContentType)
ctx.SetBodyString(msg)
}
// Success sets response Content-Type and body to the given values.
func (ctx *RequestCtx) Success(contentType string, body []byte) {
ctx.SetContentType(contentType)
ctx.SetBody(body)
}
// SuccessString sets response Content-Type and body to the given values.
func (ctx *RequestCtx) SuccessString(contentType, body string) {
ctx.SetContentType(contentType)
ctx.SetBodyString(body)
}
// Redirect sets 'Location: uri' response header and sets the given statusCode.
//
// statusCode must have one of the following values:
//
// * StatusMovedPermanently (301)
// * StatusFound (302)
// * StatusSeeOther (303)
// * StatusTemporaryRedirect (307)
// * StatusPermanentRedirect (308)
//
// All other statusCode values are replaced by StatusFound (302).
//
// The redirect uri may be either absolute or relative to the current
// request uri. Fasthttp will always send an absolute uri back to the client.
// To send a relative uri you can use the following code:
//
// strLocation = []byte("Location") // Put this with your top level var () declarations.
// ctx.Response.Header.SetCanonical(strLocation, []byte("/relative?uri"))
// ctx.Response.SetStatusCode(fasthttp.StatusMovedPermanently)
//
func (ctx *RequestCtx) Redirect(uri string, statusCode int) {
u := AcquireURI()
ctx.URI().CopyTo(u)
u.Update(uri)
ctx.redirect(u.FullURI(), statusCode)
ReleaseURI(u)
}
// RedirectBytes sets 'Location: uri' response header and sets
// the given statusCode.
//
// statusCode must have one of the following values:
//
// * StatusMovedPermanently (301)
// * StatusFound (302)
// * StatusSeeOther (303)
// * StatusTemporaryRedirect (307)
// * StatusPermanentRedirect (308)
//
// All other statusCode values are replaced by StatusFound (302).
//
// The redirect uri may be either absolute or relative to the current
// request uri. Fasthttp will always send an absolute uri back to the client.
// To send a relative uri you can use the following code:
//
// strLocation = []byte("Location") // Put this with your top level var () declarations.
// ctx.Response.Header.SetCanonical(strLocation, []byte("/relative?uri"))
// ctx.Response.SetStatusCode(fasthttp.StatusMovedPermanently)
//
func (ctx *RequestCtx) RedirectBytes(uri []byte, statusCode int) {
s := b2s(uri)
ctx.Redirect(s, statusCode)
}
func (ctx *RequestCtx) redirect(uri []byte, statusCode int) {
ctx.Response.Header.SetCanonical(strLocation, uri)
statusCode = getRedirectStatusCode(statusCode)
ctx.Response.SetStatusCode(statusCode)
}
func getRedirectStatusCode(statusCode int) int {
if statusCode == StatusMovedPermanently || statusCode == StatusFound ||
statusCode == StatusSeeOther || statusCode == StatusTemporaryRedirect ||
statusCode == StatusPermanentRedirect {
return statusCode
}
return StatusFound
}
// SetBody sets response body to the given value.
//
// It is safe to re-use the body argument after the function returns.
func (ctx *RequestCtx) SetBody(body []byte) {
ctx.Response.SetBody(body)
}
// SetBodyString sets response body to the given value.
func (ctx *RequestCtx) SetBodyString(body string) {
ctx.Response.SetBodyString(body)
}
// ResetBody resets response body contents.
func (ctx *RequestCtx) ResetBody() {
ctx.Response.ResetBody()
}
// SendFile sends local file contents from the given path as response body.
//
// This is a shortcut to ServeFile(ctx, path).
//
// SendFile logs all the errors via ctx.Logger.
//
// See also ServeFile, FSHandler and FS.
func (ctx *RequestCtx) SendFile(path string) {
ServeFile(ctx, path)
}
// SendFileBytes sends local file contents from the given path as response body.
//
// This is a shortcut to ServeFileBytes(ctx, path).
//
// SendFileBytes logs all the errors via ctx.Logger.
//
// See also ServeFileBytes, FSHandler and FS.
func (ctx *RequestCtx) SendFileBytes(path []byte) {
ServeFileBytes(ctx, path)
}
// IfModifiedSince returns true if lastModified exceeds 'If-Modified-Since'
// value from the request header.
//
// The function also returns true if the 'If-Modified-Since' request header is missing.
func (ctx *RequestCtx) IfModifiedSince(lastModified time.Time) bool {
ifModStr := ctx.Request.Header.peek(strIfModifiedSince)
if len(ifModStr) == 0 {
return true
}
ifMod, err := ParseHTTPDate(ifModStr)
if err != nil {
return true
}
lastModified = lastModified.Truncate(time.Second)
return ifMod.Before(lastModified)
}
// NotModified resets response and sets '304 Not Modified' response status code.
func (ctx *RequestCtx) NotModified() {
ctx.Response.Reset()
ctx.SetStatusCode(StatusNotModified)
}
// NotFound resets response and sets '404 Not Found' response status code.
func (ctx *RequestCtx) NotFound() {
ctx.Response.Reset()
ctx.SetStatusCode(StatusNotFound)
ctx.SetBodyString("404 Page not found")
}
// Write writes p into response body.
func (ctx *RequestCtx) Write(p []byte) (int, error) {
ctx.Response.AppendBody(p)
return len(p), nil
}
// WriteString appends s to response body.
func (ctx *RequestCtx) WriteString(s string) (int, error) {
ctx.Response.AppendBodyString(s)
return len(s), nil
}
// PostBody returns POST request body.
//
// The returned bytes are valid until your request handler returns.
func (ctx *RequestCtx) PostBody() []byte {
return ctx.Request.Body()
}
// SetBodyStream sets response body stream and, optionally, body size.
//
// bodyStream.Close() is called after finishing reading all body data
// if it implements io.Closer.
//
// If bodySize is >= 0, then bodySize bytes must be provided by bodyStream
// before returning io.EOF.
//
// If bodySize < 0, then bodyStream is read until io.EOF.
//
// See also SetBodyStreamWriter.
func (ctx *RequestCtx) SetBodyStream(bodyStream io.Reader, bodySize int) {
ctx.Response.SetBodyStream(bodyStream, bodySize)
}
// SetBodyStreamWriter registers the given stream writer for populating
// response body.
//
// Access to RequestCtx and/or its members is forbidden from sw.
//
// This function may be used in the following cases:
//
// * if response body is too big (more than 10MB).
// * if response body is streamed from slow external sources.
// * if response body must be streamed to the client in chunks.
// (aka `http server push`).
func (ctx *RequestCtx) SetBodyStreamWriter(sw StreamWriter) {
ctx.Response.SetBodyStreamWriter(sw)
}
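// A minimal streaming sketch (illustrative only; the chunk count and delay
// are arbitrary):
//
// ctx.SetBodyStreamWriter(func(w *bufio.Writer) {
// 	for i := 0; i < 10; i++ {
// 		fmt.Fprintf(w, "chunk %d\n", i)
// 		if err := w.Flush(); err != nil {
// 			return
// 		}
// 		time.Sleep(time.Second)
// 	}
// })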
// IsBodyStream returns true if response body is set via SetBodyStream*.
func (ctx *RequestCtx) IsBodyStream() bool {
return ctx.Response.IsBodyStream()
}
// Logger returns logger, which may be used for logging arbitrary
// request-specific messages inside RequestHandler.
//
// Each message logged via returned logger contains request-specific information
// such as request id, request duration, local address, remote address,
// request method and request url.
//
// It is safe to re-use the returned logger for logging multiple messages
// for the current request.
//
// The returned logger is valid until your request handler returns.
func (ctx *RequestCtx) Logger() Logger {
if ctx.logger.ctx == nil {
ctx.logger.ctx = ctx
}
if ctx.logger.logger == nil {
ctx.logger.logger = ctx.s.logger()
}
return &ctx.logger
}
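// A one-line usage sketch (illustrative only):
//
// ctx.Logger().Printf("served %q in handler", ctx.Path())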
// TimeoutError sets response status code to StatusRequestTimeout and sets
// body to the given msg.
//
// All response modifications after TimeoutError call are ignored.
//
// TimeoutError MUST be called before returning from RequestHandler if references
// to ctx and/or its members remain in other goroutines.
//
// Usage of this function is discouraged. Prefer eliminating ctx references
// from pending goroutines instead of using this function.
func (ctx *RequestCtx) TimeoutError(msg string) {
ctx.TimeoutErrorWithCode(msg, StatusRequestTimeout)
}
// TimeoutErrorWithCode sets response body to msg and response status
// code to statusCode.
//
// All response modifications after TimeoutErrorWithCode call are ignored.
//
// TimeoutErrorWithCode MUST be called before returning from RequestHandler
// if references to ctx and/or its members remain in other goroutines.
//
// Usage of this function is discouraged. Prefer eliminating ctx references
// from pending goroutines instead of using this function.
func (ctx *RequestCtx) TimeoutErrorWithCode(msg string, statusCode int) {
var resp Response
resp.SetStatusCode(statusCode)
resp.SetBodyString(msg)
ctx.TimeoutErrorWithResponse(&resp)
}
// TimeoutErrorWithResponse marks the ctx as timed out and sends the given
// response to the client.
//
// All ctx modifications after TimeoutErrorWithResponse call are ignored.
//
// TimeoutErrorWithResponse MUST be called before returning from RequestHandler
// if references to ctx and/or its members remain in other goroutines.
//
// Usage of this function is discouraged. Prefer eliminating ctx references
// from pending goroutines instead of using this function.
func (ctx *RequestCtx) TimeoutErrorWithResponse(resp *Response) {
respCopy := &Response{}
resp.CopyTo(respCopy)
ctx.timeoutResponse = respCopy
}
// NextProto adds nph to be processed when key is negotiated while the TLS
// connection is being established.
//
// This function can only be called before the server is started.
func (s *Server) NextProto(key string, nph ServeHandler) {
if s.nextProtos == nil {
s.nextProtos = make(map[string]ServeHandler)
}
s.configTLS()
s.tlsConfig.NextProtos = append(s.tlsConfig.NextProtos, key)
s.nextProtos[key] = nph
}
func (s *Server) getNextProto(c net.Conn) (proto string, err error) {
if tlsConn, ok := c.(connTLSer); ok {
if s.ReadTimeout > 0 {
if err := c.SetReadDeadline(time.Now().Add(s.ReadTimeout)); err != nil {
panic(fmt.Sprintf("BUG: error in SetReadDeadline(%s): %s", s.ReadTimeout, err))
}
}
if s.WriteTimeout > 0 {
if err := c.SetWriteDeadline(time.Now().Add(s.WriteTimeout)); err != nil {
panic(fmt.Sprintf("BUG: error in SetWriteDeadline(%s): %s", s.WriteTimeout, err))
}
}
err = tlsConn.Handshake()
if err == nil {
proto = tlsConn.ConnectionState().NegotiatedProtocol
}
}
return
}
// tcpKeepaliveListener sets TCP keep-alive timeouts on accepted
// connections. It's used by ListenAndServe, ListenAndServeTLS and
// ListenAndServeTLSEmbed so dead TCP connections (e.g. closing laptop mid-download)
// eventually go away.
type tcpKeepaliveListener struct {
*net.TCPListener
keepalive bool
keepalivePeriod time.Duration
}
func (ln tcpKeepaliveListener) Accept() (net.Conn, error) {
tc, err := ln.AcceptTCP()
if err != nil {
return nil, err
}
if err := tc.SetKeepAlive(ln.keepalive); err != nil {
tc.Close() //nolint:errcheck
return nil, err
}
if ln.keepalivePeriod > 0 {
if err := tc.SetKeepAlivePeriod(ln.keepalivePeriod); err != nil {
tc.Close() //nolint:errcheck
return nil, err
}
}
return tc, nil
}
// ListenAndServe serves HTTP requests from the given TCP4 addr.
//
// Pass custom listener to Serve if you need listening on non-TCP4 media
// such as IPv6.
//
// Accepted connections are configured to enable TCP keep-alives.
func (s *Server) ListenAndServe(addr string) error {
ln, err := net.Listen("tcp4", addr)
if err != nil {
return err
}
if tcpln, ok := ln.(*net.TCPListener); ok {
return s.Serve(tcpKeepaliveListener{
TCPListener: tcpln,
keepalive: s.TCPKeepalive,
keepalivePeriod: s.TCPKeepalivePeriod,
})
}
return s.Serve(ln)
}
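// A minimal server sketch using ListenAndServe (illustrative only; myHandler,
// the server name, the timeouts and the listen address are hypothetical):
//
// s := &fasthttp.Server{
// 	Handler:      myHandler,
// 	Name:         "my-server",
// 	ReadTimeout:  10 * time.Second,
// 	WriteTimeout: 10 * time.Second,
// }
// if err := s.ListenAndServe(":8080"); err != nil {
// 	log.Fatalf("error in ListenAndServe: %s", err)
// }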
// ListenAndServeUNIX serves HTTP requests from the given UNIX addr.
//
// The function deletes the existing file at addr before it starts serving.
//
// The server sets the given file mode for the UNIX addr.
func (s *Server) ListenAndServeUNIX(addr string, mode os.FileMode) error {
if err := os.Remove(addr); err != nil && !os.IsNotExist(err) {
return fmt.Errorf("unexpected error when trying to remove unix socket file %q: %s", addr, err)
}
ln, err := net.Listen("unix", addr)
if err != nil {
return err
}
if err = os.Chmod(addr, mode); err != nil {
return fmt.Errorf("cannot chmod %#o for %q: %s", mode, addr, err)
}
return s.Serve(ln)
}
// ListenAndServeTLS serves HTTPS requests from the given TCP4 addr.
//
// certFile and keyFile are paths to TLS certificate and key files.
//
// Pass custom listener to Serve if you need listening on non-TCP4 media
// such as IPv6.
//
// If the certFile or keyFile has not been provided to the server structure,
// the function will use the previously added TLS configuration.
//
// Accepted connections are configured to enable TCP keep-alives.
func (s *Server) ListenAndServeTLS(addr, certFile, keyFile string) error {
ln, err := net.Listen("tcp4", addr)
if err != nil {
return err
}
if tcpln, ok := ln.(*net.TCPListener); ok {
return s.ServeTLS(tcpKeepaliveListener{
TCPListener: tcpln,
keepalive: s.TCPKeepalive,
keepalivePeriod: s.TCPKeepalivePeriod,
}, certFile, keyFile)
}
return s.ServeTLS(ln, certFile, keyFile)
}
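// A hedged TLS sketch (illustrative only; myHandler and the certificate and
// key paths are hypothetical):
//
// s := &fasthttp.Server{Handler: myHandler}
// if err := s.ListenAndServeTLS(":8443", "cert.pem", "key.pem"); err != nil {
// 	log.Fatalf("error in ListenAndServeTLS: %s", err)
// }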
// ListenAndServeTLSEmbed serves HTTPS requests from the given TCP4 addr.
//
// certData and keyData must contain valid TLS certificate and key data.
//
// Pass custom listener to Serve if you need listening on arbitrary media
// such as IPv6.
//
// If the certFile or keyFile has not been provided to the server structure,
// the function will use previously added TLS configuration.
//
// Accepted connections are configured to enable TCP keep-alives.
func (s *Server) ListenAndServeTLSEmbed(addr string, certData, keyData []byte) error {
ln, err := net.Listen("tcp4", addr)
if err != nil {
return err
}
if tcpln, ok := ln.(*net.TCPListener); ok {
return s.ServeTLSEmbed(tcpKeepaliveListener{
TCPListener: tcpln,
keepalive: s.TCPKeepalive,
keepalivePeriod: s.TCPKeepalivePeriod,
}, certData, keyData)
}
return s.ServeTLSEmbed(ln, certData, keyData)
}
// ServeTLS serves HTTPS requests from the given listener.
//
// certFile and keyFile are paths to TLS certificate and key files.
//
// If the certFile or keyFile has not been provided to the server structure,
// the function will use previously added TLS configuration.
func (s *Server) ServeTLS(ln net.Listener, certFile, keyFile string) error {
s.mu.Lock()
err := s.AppendCert(certFile, keyFile)
if err != nil && err != errNoCertOrKeyProvided {
s.mu.Unlock()
return err
}
if s.tlsConfig == nil {
s.mu.Unlock()
return errNoCertOrKeyProvided
}
// BuildNameToCertificate has been deprecated since 1.14.
// But since we also support older versions we'll keep this here.
s.tlsConfig.BuildNameToCertificate() //nolint:staticcheck
s.mu.Unlock()
return s.Serve(
tls.NewListener(ln, s.tlsConfig),
)
}
// ServeTLSEmbed serves HTTPS requests from the given listener.
//
// certData and keyData must contain valid TLS certificate and key data.
//
// If the certFile or keyFile has not been provided to the server structure,
// the function will use previously added TLS configuration.
func (s *Server) ServeTLSEmbed(ln net.Listener, certData, keyData []byte) error {
s.mu.Lock()
err := s.AppendCertEmbed(certData, keyData)
if err != nil && err != errNoCertOrKeyProvided {
s.mu.Unlock()
return err
}
if s.tlsConfig == nil {
s.mu.Unlock()
return errNoCertOrKeyProvided
}
// BuildNameToCertificate has been deprecated since 1.14.
// But since we also support older versions we'll keep this here.
s.tlsConfig.BuildNameToCertificate() //nolint:staticcheck
s.mu.Unlock()
return s.Serve(
tls.NewListener(ln, s.tlsConfig),
)
}
// AppendCert appends certificate and keyfile to TLS Configuration.
//
// This function allows the programmer to handle multiple domains
// in one server structure. See examples/multidomain
func (s *Server) AppendCert(certFile, keyFile string) error {
if len(certFile) == 0 && len(keyFile) == 0 {
return errNoCertOrKeyProvided
}
cert, err := tls.LoadX509KeyPair(certFile, keyFile)
if err != nil {
return fmt.Errorf("cannot load TLS key pair from certFile=%q and keyFile=%q: %s", certFile, keyFile, err)
}
s.configTLS()
s.tlsConfig.Certificates = append(s.tlsConfig.Certificates, cert)
return nil
}
// AppendCertEmbed does the same as AppendCert but using in-memory data.
func (s *Server) AppendCertEmbed(certData, keyData []byte) error {
if len(certData) == 0 && len(keyData) == 0 {
return errNoCertOrKeyProvided
}
cert, err := tls.X509KeyPair(certData, keyData)
if err != nil {
return fmt.Errorf("cannot load TLS key pair from the provided certData(%d) and keyData(%d): %s",
len(certData), len(keyData), err)
}
s.configTLS()
s.tlsConfig.Certificates = append(s.tlsConfig.Certificates, cert)
return nil
}
func (s *Server) configTLS() {
if s.tlsConfig == nil {
s.tlsConfig = &tls.Config{
PreferServerCipherSuites: true,
}
}
}
// DefaultConcurrency is the maximum number of concurrent connections
// the Server may serve by default (i.e. if Server.Concurrency isn't set).
const DefaultConcurrency = 256 * 1024
// Serve serves incoming connections from the given listener.
//
// Serve blocks until the given listener returns permanent error.
func (s *Server) Serve(ln net.Listener) error {
var lastOverflowErrorTime time.Time
var lastPerIPErrorTime time.Time
var c net.Conn
var err error
maxWorkersCount := s.getConcurrency()
s.mu.Lock()
{
s.ln = append(s.ln, ln)
if s.done == nil {
s.done = make(chan struct{})
}
if s.concurrencyCh == nil {
s.concurrencyCh = make(chan struct{}, maxWorkersCount)
}
}
s.mu.Unlock()
wp := &workerPool{
WorkerFunc: s.serveConn,
MaxWorkersCount: maxWorkersCount,
LogAllErrors: s.LogAllErrors,
Logger: s.logger(),
connState: s.setState,
}
wp.Start()
// Count our waiting to accept a connection as an open connection.
// This way we can't get into any weird state where just after accepting
// a connection Shutdown is called which reads open as 0 because it isn't
// incremented yet.
atomic.AddInt32(&s.open, 1)
defer atomic.AddInt32(&s.open, -1)
for {
if c, err = acceptConn(s, ln, &lastPerIPErrorTime); err != nil {
wp.Stop()
if err == io.EOF {
return nil
}
return err
}
s.setState(c, StateNew)
atomic.AddInt32(&s.open, 1)
if !wp.Serve(c) {
atomic.AddInt32(&s.open, -1)
s.writeFastError(c, StatusServiceUnavailable,
"The connection cannot be served because Server.Concurrency limit exceeded")
c.Close()
s.setState(c, StateClosed)
if time.Since(lastOverflowErrorTime) > time.Minute {
s.logger().Printf("The incoming connection cannot be served, because %d concurrent connections are served. "+
"Try increasing Server.Concurrency", maxWorkersCount)
lastOverflowErrorTime = time.Now()
}
// The current server reached concurrency limit,
// so give other concurrently running servers a chance
// to accept incoming connections on the same address.
//
// There is a hope other servers didn't reach their
// concurrency limits yet :)
//
// See also: https://github.com/valyala/fasthttp/pull/485#discussion_r239994990
if s.SleepWhenConcurrencyLimitsExceeded > 0 {
time.Sleep(s.SleepWhenConcurrencyLimitsExceeded)
}
}
c = nil
}
}
// Shutdown gracefully shuts down the server without interrupting any active connections.
// Shutdown works by first closing all open listeners and then waiting indefinitely for all connections to return to idle and then shut down.
//
// When Shutdown is called, Serve, ListenAndServe, and ListenAndServeTLS immediately return nil.
// Make sure the program doesn't exit and waits instead for Shutdown to return.
//
// Shutdown does not close keepalive connections, so it's recommended to set ReadTimeout and IdleTimeout to something other than 0.
func (s *Server) Shutdown() error {
s.mu.Lock()
defer s.mu.Unlock()
atomic.StoreInt32(&s.stop, 1)
defer atomic.StoreInt32(&s.stop, 0)
if s.ln == nil {
return nil
}
for _, ln := range s.ln {
if err := ln.Close(); err != nil {
return err
}
}
if s.done != nil {
close(s.done)
}
// Closing the listener will make Serve() call Stop on the worker pool.
// Setting .stop to 1 will make serveConn() break out of its loop.
// Now we just have to wait until all workers are done.
for {
if open := atomic.LoadInt32(&s.open); open == 0 {
break
}
// This is not an optimal solution but using a sync.WaitGroup
// here causes data races, as it's hard to prevent Add() from being called
// while Wait() is waiting.
time.Sleep(time.Millisecond * 100)
}
s.done = nil
s.ln = nil
return nil
}
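// A graceful-shutdown sketch (illustrative only; myHandler, the listen
// address and the signal handling shown are hypothetical):
//
// s := &fasthttp.Server{Handler: myHandler}
// go func() {
// 	if err := s.ListenAndServe(":8080"); err != nil {
// 		log.Printf("server stopped: %s", err)
// 	}
// }()
//
// sig := make(chan os.Signal, 1)
// signal.Notify(sig, os.Interrupt)
// <-sig // wait for Ctrl-C
// if err := s.Shutdown(); err != nil {
// 	log.Printf("error during shutdown: %s", err)
// }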
func acceptConn(s *Server, ln net.Listener, lastPerIPErrorTime *time.Time) (net.Conn, error) {
for {
c, err := ln.Accept()
if err != nil {
if c != nil {
panic("BUG: net.Listener returned non-nil conn and non-nil error")
}
if netErr, ok := err.(net.Error); ok && netErr.Temporary() {
s.logger().Printf("Temporary error when accepting new connections: %s", netErr)
time.Sleep(time.Second)
continue
}
if err != io.EOF && !strings.Contains(err.Error(), "use of closed network connection") {
s.logger().Printf("Permanent error when accepting new connections: %s", err)
return nil, err
}
return nil, io.EOF
}
if c == nil {
panic("BUG: net.Listener returned (nil, nil)")
}
if s.MaxConnsPerIP > 0 {
pic := wrapPerIPConn(s, c)
if pic == nil {
if time.Since(*lastPerIPErrorTime) > time.Minute {
s.logger().Printf("The number of connections from %s exceeds MaxConnsPerIP=%d",
getConnIP4(c), s.MaxConnsPerIP)
*lastPerIPErrorTime = time.Now()
}
continue
}
c = pic
}
return c, nil
}
}
func wrapPerIPConn(s *Server, c net.Conn) net.Conn {
ip := getUint32IP(c)
if ip == 0 {
return c
}
n := s.perIPConnCounter.Register(ip)
if n > s.MaxConnsPerIP {
s.perIPConnCounter.Unregister(ip)
s.writeFastError(c, StatusTooManyRequests, "The number of connections from your ip exceeds MaxConnsPerIP")
c.Close()
return nil
}
return acquirePerIPConn(c, ip, &s.perIPConnCounter)
}
var defaultLogger = Logger(log.New(os.Stderr, "", log.LstdFlags))
func (s *Server) logger() Logger {
if s.Logger != nil {
return s.Logger
}
return defaultLogger
}
var (
// ErrPerIPConnLimit may be returned from ServeConn if the number of connections
// per ip exceeds Server.MaxConnsPerIP.
ErrPerIPConnLimit = errors.New("too many connections per ip")
// ErrConcurrencyLimit may be returned from ServeConn if the number
// of concurrently served connections exceeds Server.Concurrency.
ErrConcurrencyLimit = errors.New("cannot serve the connection because Server.Concurrency concurrent connections are served")
)
// ServeConn serves HTTP requests from the given connection.
//
// ServeConn returns nil if all requests from c are successfully served.
// It returns non-nil error otherwise.
//
// Connection c must immediately propagate all the data passed to Write()
// to the client. Otherwise requests' processing may hang.
//
// ServeConn closes c before returning.
func (s *Server) ServeConn(c net.Conn) error {
if s.MaxConnsPerIP > 0 {
pic := wrapPerIPConn(s, c)
if pic == nil {
return ErrPerIPConnLimit
}
c = pic
}
n := atomic.AddUint32(&s.concurrency, 1)
if n > uint32(s.getConcurrency()) {
atomic.AddUint32(&s.concurrency, ^uint32(0))
s.writeFastError(c, StatusServiceUnavailable, "The connection cannot be served because Server.Concurrency limit exceeded")
c.Close()
return ErrConcurrencyLimit
}
atomic.AddInt32(&s.open, 1)
err := s.serveConn(c)
atomic.AddUint32(&s.concurrency, ^uint32(0))
if err != errHijacked {
err1 := c.Close()
s.setState(c, StateClosed)
if err == nil {
err = err1
}
} else {
err = nil
s.setState(c, StateHijacked)
}
return err
}
var errHijacked = errors.New("connection has been hijacked")
// GetCurrentConcurrency returns the number of currently served
// connections.
//
// This function is intended to be used by monitoring systems.
func (s *Server) GetCurrentConcurrency() uint32 {
return atomic.LoadUint32(&s.concurrency)
}
// GetOpenConnectionsCount returns the number of opened connections.
//
// This function is intended to be used by monitoring systems.
func (s *Server) GetOpenConnectionsCount() int32 {
if atomic.LoadInt32(&s.stop) == 0 {
// Decrement by one to avoid reporting the extra open value that gets
// counted while the server is listening.
return atomic.LoadInt32(&s.open) - 1
}
// This is not perfect, because s.stop could have changed to zero
// before we load the value of s.open. However, in the common case
// this avoids underreporting open connections by 1 during server shutdown.
return atomic.LoadInt32(&s.open)
}
func (s *Server) getConcurrency() int {
n := s.Concurrency
if n <= 0 {
n = DefaultConcurrency
}
return n
}
var globalConnID uint64
func nextConnID() uint64 {
return atomic.AddUint64(&globalConnID, 1)
}
// DefaultMaxRequestBodySize is the maximum request body size the server
// reads by default.
//
// See Server.MaxRequestBodySize for details.
const DefaultMaxRequestBodySize = 4 * 1024 * 1024
func (s *Server) idleTimeout() time.Duration {
if s.IdleTimeout != 0 {
return s.IdleTimeout
}
return s.ReadTimeout
}
func (s *Server) serveConnCleanup() {
atomic.AddInt32(&s.open, -1)
atomic.AddUint32(&s.concurrency, ^uint32(0))
}
func (s *Server) serveConn(c net.Conn) (err error) {
defer s.serveConnCleanup()
atomic.AddUint32(&s.concurrency, 1)
var proto string
if proto, err = s.getNextProto(c); err != nil {
return
}
if handler, ok := s.nextProtos[proto]; ok {
// Remove read or write deadlines that might have previously been set.
// The next handler is responsible for setting its own deadlines.
if s.ReadTimeout > 0 || s.WriteTimeout > 0 {
if err := c.SetDeadline(zeroTime); err != nil {
panic(fmt.Sprintf("BUG: error in SetDeadline(zeroTime): %s", err))
}
}
return handler(c)
}
var serverName []byte
if !s.NoDefaultServerHeader {
serverName = s.getServerName()
}
connRequestNum := uint64(0)
connID := nextConnID()
connTime := time.Now()
maxRequestBodySize := s.MaxRequestBodySize
if maxRequestBodySize <= 0 {
maxRequestBodySize = DefaultMaxRequestBodySize
}
writeTimeout := s.WriteTimeout
previousWriteTimeout := time.Duration(0)
ctx := s.acquireCtx(c)
ctx.connTime = connTime
isTLS := ctx.IsTLS()
var (
br *bufio.Reader
bw *bufio.Writer
timeoutResponse *Response
hijackHandler HijackHandler
hijackNoResponse bool
connectionClose bool
isHTTP11 bool
reqReset bool
continueReadingRequest bool = true
)
for {
connRequestNum++
// If this is a keep-alive connection set the idle timeout.
if connRequestNum > 1 {
if d := s.idleTimeout(); d > 0 {
if err := c.SetReadDeadline(time.Now().Add(d)); err != nil {
panic(fmt.Sprintf("BUG: error in SetReadDeadline(%s): %s", d, err))
}
}
}
if !s.ReduceMemoryUsage || br != nil {
if br == nil {
br = acquireReader(ctx)
}
// If this is a keep-alive connection we want to try and read the first bytes
// within the idle time.
if connRequestNum > 1 {
var b []byte
b, err = br.Peek(1)
if len(b) == 0 {
// If reading from a keep-alive connection returns nothing it means
// the connection was closed (either timeout or from the other side).
if err != io.EOF {
err = ErrNothingRead{err}
}
}
}
} else {
// If this is a keep-alive connection acquireByteReader will try to peek
// a couple of bytes already so the idle timeout will already be used.
br, err = acquireByteReader(&ctx)
}
ctx.Request.isTLS = isTLS
ctx.Response.Header.noDefaultContentType = s.NoDefaultContentType
ctx.Response.Header.noDefaultDate = s.NoDefaultDate
// Secure header error logs configuration
ctx.Request.Header.secureErrorLogMessage = s.SecureErrorLogMessage
ctx.Response.Header.secureErrorLogMessage = s.SecureErrorLogMessage
ctx.Request.secureErrorLogMessage = s.SecureErrorLogMessage
ctx.Response.secureErrorLogMessage = s.SecureErrorLogMessage
if err == nil {
if s.ReadTimeout > 0 {
if err := c.SetReadDeadline(time.Now().Add(s.ReadTimeout)); err != nil {
panic(fmt.Sprintf("BUG: error in SetReadDeadline(%s): %s", s.ReadTimeout, err))
}
} else if s.IdleTimeout > 0 && connRequestNum > 1 {
// If this was an idle connection and the server has an IdleTimeout but
// no ReadTimeout then we should remove the ReadTimeout.
if err := c.SetReadDeadline(zeroTime); err != nil {
panic(fmt.Sprintf("BUG: error in SetReadDeadline(zeroTime): %s", err))
}
}
if s.DisableHeaderNamesNormalizing {
ctx.Request.Header.DisableNormalizing()
ctx.Response.Header.DisableNormalizing()
}
// Reading Headers.
//
// If we have a pipelined response in the outgoing buffer,
// we only want to try and read the next headers once.
// If we have to wait for the next request we flush the
// outgoing buffer first so the client doesn't have to wait.
if bw != nil && bw.Buffered() > 0 {
err = ctx.Request.Header.readLoop(br, false)
if err == errNeedMore {
err = bw.Flush()
if err != nil {
break
}
err = ctx.Request.Header.Read(br)
}
} else {
err = ctx.Request.Header.Read(br)
}
if err == nil {
if onHdrRecv := s.HeaderReceived; onHdrRecv != nil {
reqConf := onHdrRecv(&ctx.Request.Header)
if reqConf.ReadTimeout > 0 {
deadline := time.Now().Add(reqConf.ReadTimeout)
if err := c.SetReadDeadline(deadline); err != nil {
panic(fmt.Sprintf("BUG: error in SetReadDeadline(%s): %s", deadline, err))
}
}
if reqConf.MaxRequestBodySize > 0 {
maxRequestBodySize = reqConf.MaxRequestBodySize
}
if reqConf.WriteTimeout > 0 {
writeTimeout = reqConf.WriteTimeout
}
}
// Read the request body.
if s.StreamRequestBody {
err = ctx.Request.readBodyStream(br, maxRequestBodySize, s.GetOnly, !s.DisablePreParseMultipartForm)
} else {
err = ctx.Request.readLimitBody(br, maxRequestBodySize, s.GetOnly, !s.DisablePreParseMultipartForm)
}
}
if err == nil {
// If we read any bytes off the wire, we're active.
s.setState(c, StateActive)
}
if (s.ReduceMemoryUsage && br.Buffered() == 0) || err != nil {
releaseReader(s, br)
br = nil
}
}
if err != nil {
if err == io.EOF {
err = nil
} else if nr, ok := err.(ErrNothingRead); ok {
if connRequestNum > 1 {
// This is not the first request and we haven't read a single byte
// of a new request yet. This means it's just a keep-alive connection
// closing down either because the remote closed it or because
// of a read timeout on our side. Either way just close the connection
// and don't return any error response.
err = nil
} else {
err = nr.error
}
}
if err != nil {
bw = s.writeErrorResponse(bw, ctx, serverName, err)
}
break
}
// 'Expect: 100-continue' request handling.
// See https://www.w3.org/Protocols/rfc2616/rfc2616-sec8.html#sec8.2.3 for details.
if ctx.Request.MayContinue() {
// Allow the ability to deny reading the incoming request body
if s.ContinueHandler != nil {
if continueReadingRequest = s.ContinueHandler(&ctx.Request.Header); !continueReadingRequest {
if br != nil {
br.Reset(ctx.c)
}
ctx.SetStatusCode(StatusExpectationFailed)
}
}
if continueReadingRequest {
if bw == nil {
bw = acquireWriter(ctx)
}
// Send 'HTTP/1.1 100 Continue' response.
_, err = bw.Write(strResponseContinue)
if err != nil {
break
}
err = bw.Flush()
if err != nil {
break
}
if s.ReduceMemoryUsage {
releaseWriter(s, bw)
bw = nil
}
// Read request body.
if br == nil {
br = acquireReader(ctx)
}
if s.StreamRequestBody {
err = ctx.Request.ContinueReadBodyStream(br, maxRequestBodySize, !s.DisablePreParseMultipartForm)
} else {
err = ctx.Request.ContinueReadBody(br, maxRequestBodySize, !s.DisablePreParseMultipartForm)
}
if (s.ReduceMemoryUsage && br.Buffered() == 0) || err != nil {
releaseReader(s, br)
br = nil
}
if err != nil {
bw = s.writeErrorResponse(bw, ctx, serverName, err)
break
}
}
}
connectionClose = s.DisableKeepalive || ctx.Request.Header.ConnectionClose()
isHTTP11 = ctx.Request.Header.IsHTTP11()
if serverName != nil {
ctx.Response.Header.SetServerBytes(serverName)
}
ctx.connID = connID
ctx.connRequestNum = connRequestNum
ctx.time = time.Now()
// If reading the request body was denied, the handler should not be called.
if continueReadingRequest {
s.Handler(ctx)
}
timeoutResponse = ctx.timeoutResponse
if timeoutResponse != nil {
// Acquire a new ctx because the old one will still be in use by the timed out handler.
ctx = s.acquireCtx(c)
timeoutResponse.CopyTo(&ctx.Response)
}
if !ctx.IsGet() && ctx.IsHead() {
ctx.Response.SkipBody = true
}
reqReset = true
ctx.Request.Reset()
hijackHandler = ctx.hijackHandler
ctx.hijackHandler = nil
hijackNoResponse = ctx.hijackNoResponse && hijackHandler != nil
ctx.hijackNoResponse = false
if s.MaxRequestsPerConn > 0 && connRequestNum >= uint64(s.MaxRequestsPerConn) {
ctx.SetConnectionClose()
}
if writeTimeout > 0 {
if err := c.SetWriteDeadline(time.Now().Add(writeTimeout)); err != nil {
panic(fmt.Sprintf("BUG: error in SetWriteDeadline(%s): %s", writeTimeout, err))
}
previousWriteTimeout = writeTimeout
} else if previousWriteTimeout > 0 {
// We don't want a write timeout but we previously set one, remove it.
if err := c.SetWriteDeadline(zeroTime); err != nil {
panic(fmt.Sprintf("BUG: error in SetWriteDeadline(zeroTime): %s", err))
}
previousWriteTimeout = 0
}
connectionClose = connectionClose || ctx.Response.ConnectionClose() || (s.CloseOnShutdown && atomic.LoadInt32(&s.stop) == 1)
if connectionClose {
ctx.Response.Header.SetCanonical(strConnection, strClose)
} else if !isHTTP11 {
// Set 'Connection: keep-alive' response header for non-HTTP/1.1 request.
// There is no need in setting this header for http/1.1, since in http/1.1
// connections are keep-alive by default.
ctx.Response.Header.SetCanonical(strConnection, strKeepAlive)
}
if serverName != nil && len(ctx.Response.Header.Server()) == 0 {
ctx.Response.Header.SetServerBytes(serverName)
}
if !hijackNoResponse {
if bw == nil {
bw = acquireWriter(ctx)
}
if err = writeResponse(ctx, bw); err != nil {
break
}
// Only flush the writer if we don't have another request in the pipeline.
// This is a bit of an ugly optimization for https://www.techempower.com/benchmarks/
// This benchmark will send 16 pipelined requests. It is faster to pack as many responses
// in a TCP packet and send it back at once than waiting for a flush every request.
// In real world circumstances this behaviour could be argued as being wrong.
if br == nil || br.Buffered() == 0 || connectionClose {
err = bw.Flush()
if err != nil {
break
}
}
if connectionClose {
break
}
if s.ReduceMemoryUsage && hijackHandler == nil {
releaseWriter(s, bw)
bw = nil
}
}
if hijackHandler != nil {
var hjr io.Reader = c
if br != nil {
hjr = br
br = nil
// br may point to ctx.fbr, so do not return ctx into pool below.
ctx = nil
}
if bw != nil {
err = bw.Flush()
if err != nil {
break
}
releaseWriter(s, bw)
bw = nil
}
err = c.SetDeadline(zeroTime)
if err != nil {
break
}
go hijackConnHandler(hjr, c, s, hijackHandler)
err = errHijacked
break
}
if ctx.Request.bodyStream != nil {
if rs, ok := ctx.Request.bodyStream.(*requestStream); ok {
releaseRequestStream(rs)
}
}
s.setState(c, StateIdle)
ctx.userValues.Reset()
if atomic.LoadInt32(&s.stop) == 1 {
err = nil
break
}
}
if br != nil {
releaseReader(s, br)
}
if bw != nil {
releaseWriter(s, bw)
}
if ctx != nil {
// In unexpected cases the for loop will break
// before the request reset call. In such cases, call it before
// release to fix #548.
if !reqReset {
ctx.Request.Reset()
}
s.releaseCtx(ctx)
}
return
}
func (s *Server) setState(nc net.Conn, state ConnState) {
if hook := s.ConnState; hook != nil {
hook(nc, state)
}
}
func hijackConnHandler(r io.Reader, c net.Conn, s *Server, h HijackHandler) {
hjc := s.acquireHijackConn(r, c)
h(hjc)
if br, ok := r.(*bufio.Reader); ok {
releaseReader(s, br)
}
if !s.KeepHijackedConns {
c.Close()
s.releaseHijackConn(hjc)
}
}
func (s *Server) acquireHijackConn(r io.Reader, c net.Conn) *hijackConn {
v := s.hijackConnPool.Get()
if v == nil {
hjc := &hijackConn{
Conn: c,
r: r,
s: s,
}
return hjc
}
hjc := v.(*hijackConn)
hjc.Conn = c
hjc.r = r
return hjc
}
func (s *Server) releaseHijackConn(hjc *hijackConn) {
hjc.Conn = nil
hjc.r = nil
s.hijackConnPool.Put(hjc)
}
type hijackConn struct {
net.Conn
r io.Reader
s *Server
}
func (c *hijackConn) UnsafeConn() net.Conn {
return c.Conn
}
func (c *hijackConn) Read(p []byte) (int, error) {
return c.r.Read(p)
}
func (c *hijackConn) Close() error {
if !c.s.KeepHijackedConns {
// When we do not keep hijacked connections,
// the connection is closed in hijackConnHandler.
return nil
}
conn := c.Conn
c.s.releaseHijackConn(c)
return conn.Close()
}
// LastTimeoutErrorResponse returns the last timeout response set
// via TimeoutError* call.
//
// This function is intended for custom server implementations.
func (ctx *RequestCtx) LastTimeoutErrorResponse() *Response {
return ctx.timeoutResponse
}
func writeResponse(ctx *RequestCtx, w *bufio.Writer) error {
if ctx.timeoutResponse != nil {
panic("BUG: cannot write timed out response")
}
err := ctx.Response.Write(w)
ctx.Response.Reset()
return err
}
const (
defaultReadBufferSize = 4096
defaultWriteBufferSize = 4096
)
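// Illustrative sketch (not part of the original source): the defaults above are
// used only when Server.ReadBufferSize / Server.WriteBufferSize are left at zero,
// as acquireReader and acquireWriter below show. newTunedServer is a hypothetical
// helper demonstrating how the buffer sizes would typically be configured.
func newTunedServer(h RequestHandler) *Server {
return &Server{
Handler: h,
ReadBufferSize: 8 * 1024, // overrides defaultReadBufferSize
WriteBufferSize: 8 * 1024, // overrides defaultWriteBufferSize
}
}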
func acquireByteReader(ctxP **RequestCtx) (*bufio.Reader, error) {
ctx := *ctxP
s := ctx.s
c := ctx.c
s.releaseCtx(ctx)
// Make the GC happy, so it can collect ctx
// while we are waiting for the next request.
ctx = nil
*ctxP = nil
var b [1]byte
n, err := c.Read(b[:])
ctx = s.acquireCtx(c)
*ctxP = ctx
if err != nil {
// Treat all errors as EOF on unsuccessful read
// of the first request byte.
return nil, io.EOF
}
if n != 1 {
panic("BUG: Reader must return at least one byte")
}
ctx.fbr.c = c
ctx.fbr.ch = b[0]
ctx.fbr.byteRead = false
r := acquireReader(ctx)
r.Reset(&ctx.fbr)
return r, nil
}
func acquireReader(ctx *RequestCtx) *bufio.Reader {
v := ctx.s.readerPool.Get()
if v == nil {
n := ctx.s.ReadBufferSize
if n <= 0 {
n = defaultReadBufferSize
}
return bufio.NewReaderSize(ctx.c, n)
}
r := v.(*bufio.Reader)
r.Reset(ctx.c)
return r
}
func releaseReader(s *Server, r *bufio.Reader) {
s.readerPool.Put(r)
}
func acquireWriter(ctx *RequestCtx) *bufio.Writer {
v := ctx.s.writerPool.Get()
if v == nil {
n := ctx.s.WriteBufferSize
if n <= 0 {
n = defaultWriteBufferSize
}
return bufio.NewWriterSize(ctx.c, n)
}
w := v.(*bufio.Writer)
w.Reset(ctx.c)
return w
}
func releaseWriter(s *Server, w *bufio.Writer) {
s.writerPool.Put(w)
}
func (s *Server) acquireCtx(c net.Conn) (ctx *RequestCtx) {
v := s.ctxPool.Get()
if v == nil {
ctx = &RequestCtx{
s: s,
}
keepBodyBuffer := !s.ReduceMemoryUsage
ctx.Request.keepBodyBuffer = keepBodyBuffer
ctx.Response.keepBodyBuffer = keepBodyBuffer
} else {
ctx = v.(*RequestCtx)
}
ctx.c = c
return
}
// Init2 prepares ctx for passing to RequestHandler.
//
// conn is used only for determining local and remote addresses.
//
// This function is intended for custom Server implementations.
// See https://github.com/valyala/httpteleport for details.
func (ctx *RequestCtx) Init2(conn net.Conn, logger Logger, reduceMemoryUsage bool) {
ctx.c = conn
ctx.remoteAddr = nil
ctx.logger.logger = logger
ctx.connID = nextConnID()
ctx.s = fakeServer
ctx.connRequestNum = 0
ctx.connTime = time.Now()
keepBodyBuffer := !reduceMemoryUsage
ctx.Request.keepBodyBuffer = keepBodyBuffer
ctx.Response.keepBodyBuffer = keepBodyBuffer
}
// Init prepares ctx for passing to RequestHandler.
//
// remoteAddr and logger are optional. They are used by RequestCtx.Logger().
//
// This function is intended for custom Server implementations.
func (ctx *RequestCtx) Init(req *Request, remoteAddr net.Addr, logger Logger) {
if remoteAddr == nil {
remoteAddr = zeroTCPAddr
}
c := &fakeAddrer{
laddr: zeroTCPAddr,
raddr: remoteAddr,
}
if logger == nil {
logger = defaultLogger
}
ctx.Init2(c, logger, true)
req.CopyTo(&ctx.Request)
}
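// Illustrative sketch (not part of the original source): Init/Init2 make it
// possible to exercise a RequestHandler directly, without a net.Conn or a
// running Server. callHandler is a hypothetical helper; the URI is arbitrary.
func callHandler(handler RequestHandler) string {
var req Request
req.SetRequestURI("http://example.com/ping")
var ctx RequestCtx
ctx.Init(&req, nil, nil)
handler(&ctx)
// The handler's output can be inspected via ctx.Response.
return string(ctx.Response.Body())
}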
// Deadline returns the time when work done on behalf of this context
// should be canceled. Deadline returns ok==false when no deadline is
// set. Successive calls to Deadline return the same results.
//
// This method always returns a zero deadline and false; it is only present
// to make RequestCtx implement the context interface.
func (ctx *RequestCtx) Deadline() (deadline time.Time, ok bool) {
return
}
// Done returns a channel that's closed when work done on behalf of this
// context should be canceled. Done may return nil if this context can
// never be canceled. Successive calls to Done return the same value.
func (ctx *RequestCtx) Done() <-chan struct{} {
return ctx.s.done
}
// Err returns a non-nil error value after Done is closed,
// successive calls to Err return the same error.
// If Done is not yet closed, Err returns nil.
// If Done is closed, Err returns a non-nil error explaining why:
// Canceled if the context was canceled (via server Shutdown)
// or DeadlineExceeded if the context's deadline passed.
func (ctx *RequestCtx) Err() error {
select {
case <-ctx.s.done:
return context.Canceled
default:
return nil
}
}
// Value returns the value associated with this context for key, or nil
// if no value is associated with key. Successive calls to Value with
// the same key return the same result.
//
// This method is present to make RequestCtx implement the context interface.
// This method is the same as calling ctx.UserValue(key)
func (ctx *RequestCtx) Value(key interface{}) interface{} {
if keyString, ok := key.(string); ok {
return ctx.UserValue(keyString)
}
return nil
}
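// Illustrative sketch (not part of the original source): because RequestCtx
// implements Deadline, Done, Err and Value, it satisfies context.Context and
// may be passed to any API expecting one. runWithContext is a hypothetical
// helper shown only to demonstrate the assignment.
func runWithContext(ctx *RequestCtx, f func(context.Context)) {
var c context.Context = ctx // compiles because *RequestCtx implements context.Context
f(c)
}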
var fakeServer = &Server{
// Initialize concurrencyCh for TimeoutHandler
concurrencyCh: make(chan struct{}, DefaultConcurrency),
}
type fakeAddrer struct {
net.Conn
laddr net.Addr
raddr net.Addr
}
func (fa *fakeAddrer) RemoteAddr() net.Addr {
return fa.raddr
}
func (fa *fakeAddrer) LocalAddr() net.Addr {
return fa.laddr
}
func (fa *fakeAddrer) Read(p []byte) (int, error) {
panic("BUG: unexpected Read call")
}
func (fa *fakeAddrer) Write(p []byte) (int, error) {
panic("BUG: unexpected Write call")
}
func (fa *fakeAddrer) Close() error {
panic("BUG: unexpected Close call")
}
func (s *Server) releaseCtx(ctx *RequestCtx) {
if ctx.timeoutResponse != nil {
panic("BUG: cannot release timed out RequestCtx")
}
ctx.c = nil
ctx.remoteAddr = nil
ctx.fbr.c = nil
ctx.userValues.Reset()
s.ctxPool.Put(ctx)
}
func (s *Server) getServerName() []byte {
v := s.serverName.Load()
var serverName []byte
if v == nil {
serverName = []byte(s.Name)
if len(serverName) == 0 {
serverName = defaultServerName
}
s.serverName.Store(serverName)
} else {
serverName = v.([]byte)
}
return serverName
}
func (s *Server) writeFastError(w io.Writer, statusCode int, msg string) {
w.Write(statusLine(statusCode)) //nolint:errcheck
server := ""
if !s.NoDefaultServerHeader {
server = fmt.Sprintf("Server: %s\r\n", s.getServerName())
}
date := ""
if !s.NoDefaultDate {
serverDateOnce.Do(updateServerDate)
date = fmt.Sprintf("Date: %s\r\n", serverDate.Load())
}
fmt.Fprintf(w, "Connection: close\r\n"+
server+
date+
"Content-Type: text/plain\r\n"+
"Content-Length: %d\r\n"+
"\r\n"+
"%s",
len(msg), msg)
}
func defaultErrorHandler(ctx *RequestCtx, err error) {
if _, ok := err.(*ErrSmallBuffer); ok {
ctx.Error("Too big request header", StatusRequestHeaderFieldsTooLarge)
} else if netErr, ok := err.(*net.OpError); ok && netErr.Timeout() {
ctx.Error("Request timeout", StatusRequestTimeout)
} else {
ctx.Error("Error when parsing request", StatusBadRequest)
}
}
func (s *Server) writeErrorResponse(bw *bufio.Writer, ctx *RequestCtx, serverName []byte, err error) *bufio.Writer {
errorHandler := defaultErrorHandler
if s.ErrorHandler != nil {
errorHandler = s.ErrorHandler
}
errorHandler(ctx, err)
if serverName != nil {
ctx.Response.Header.SetServerBytes(serverName)
}
ctx.SetConnectionClose()
if bw == nil {
bw = acquireWriter(ctx)
}
writeResponse(ctx, bw) //nolint:errcheck
bw.Flush()
return bw
}
// A ConnState represents the state of a client connection to a server.
// It's used by the optional Server.ConnState hook.
type ConnState int
const (
// StateNew represents a new connection that is expected to
// send a request immediately. Connections begin at this
// state and then transition to either StateActive or
// StateClosed.
StateNew ConnState = iota
// StateActive represents a connection that has read 1 or more
// bytes of a request. The Server.ConnState hook for
// StateActive fires before the request has entered a handler
// and doesn't fire again until the request has been
// handled. After the request is handled, the state
// transitions to StateClosed, StateHijacked, or StateIdle.
// For HTTP/2, StateActive fires on the transition from zero
// to one active request, and only transitions away once all
// active requests are complete. That means that ConnState
// cannot be used to do per-request work; ConnState only notes
// the overall state of the connection.
StateActive
// StateIdle represents a connection that has finished
// handling a request and is in the keep-alive state, waiting
// for a new request. Connections transition from StateIdle
// to either StateActive or StateClosed.
StateIdle
// StateHijacked represents a hijacked connection.
// This is a terminal state. It does not transition to StateClosed.
StateHijacked
// StateClosed represents a closed connection.
// This is a terminal state. Hijacked connections do not
// transition to StateClosed.
StateClosed
)
var stateName = map[ConnState]string{
StateNew: "new",
StateActive: "active",
StateIdle: "idle",
StateHijacked: "hijacked",
StateClosed: "closed",
}
func (c ConnState) String() string {
return stateName[c]
}
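// Illustrative sketch (not part of the original source): a ConnState hook that
// tracks the number of connections currently owned by the server, decrementing
// on the two terminal states documented above. The counter is hypothetical.
func exampleConnStateHook(open *int32) func(net.Conn, ConnState) {
return func(c net.Conn, state ConnState) {
switch state {
case StateNew:
atomic.AddInt32(open, 1)
case StateClosed, StateHijacked:
atomic.AddInt32(open, -1)
}
}
}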
fasthttp-1.31.0/server_example_test.go 0000664 0000000 0000000 00000012644 14130360711 0020043 0 ustar 00root root 0000000 0000000 package fasthttp_test
import (
"fmt"
"log"
"math/rand"
"net"
"time"
"github.com/valyala/fasthttp"
)
func ExampleListenAndServe() {
// The server will listen for incoming requests on this address.
listenAddr := "127.0.0.1:80"
// This function will be called by the server for each incoming request.
//
// RequestCtx provides a lot of functionality related to http request
// processing. See RequestCtx docs for details.
requestHandler := func(ctx *fasthttp.RequestCtx) {
fmt.Fprintf(ctx, "Hello, world! Requested path is %q", ctx.Path())
}
// Start the server with default settings.
// Create Server instance for adjusting server settings.
//
// ListenAndServe returns only on error, so usually it blocks forever.
if err := fasthttp.ListenAndServe(listenAddr, requestHandler); err != nil {
log.Fatalf("error in ListenAndServe: %s", err)
}
}
func ExampleServe() {
// Create network listener for accepting incoming requests.
//
// Note that you are not limited by TCP listener - arbitrary
// net.Listener may be used by the server.
// For example, unix socket listener or TLS listener.
ln, err := net.Listen("tcp4", "127.0.0.1:8080")
if err != nil {
log.Fatalf("error in net.Listen: %s", err)
}
// This function will be called by the server for each incoming request.
//
// RequestCtx provides a lot of functionality related to http request
// processing. See RequestCtx docs for details.
requestHandler := func(ctx *fasthttp.RequestCtx) {
fmt.Fprintf(ctx, "Hello, world! Requested path is %q", ctx.Path())
}
// Start the server with default settings.
// Create Server instance for adjusting server settings.
//
// Serve returns on ln.Close() or error, so usually it blocks forever.
if err := fasthttp.Serve(ln, requestHandler); err != nil {
log.Fatalf("error in Serve: %s", err)
}
}
func ExampleServer() {
// This function will be called by the server for each incoming request.
//
// RequestCtx provides a lot of functionality related to http request
// processing. See RequestCtx docs for details.
requestHandler := func(ctx *fasthttp.RequestCtx) {
fmt.Fprintf(ctx, "Hello, world! Requested path is %q", ctx.Path())
}
// Create custom server.
s := &fasthttp.Server{
Handler: requestHandler,
// Every response will contain 'Server: My super server' header.
Name: "My super server",
// Other Server settings may be set here.
}
// Start the server listening for incoming requests on the given address.
//
// ListenAndServe returns only on error, so usually it blocks forever.
if err := s.ListenAndServe("127.0.0.1:80"); err != nil {
log.Fatalf("error in ListenAndServe: %s", err)
}
}
func ExampleRequestCtx_Hijack() {
// hijackHandler is called on hijacked connection.
hijackHandler := func(c net.Conn) {
fmt.Fprintf(c, "This message is sent over a hijacked connection to the client %s\n", c.RemoteAddr())
fmt.Fprintf(c, "Send me something and I'll echo it to you\n")
var buf [1]byte
for {
if _, err := c.Read(buf[:]); err != nil {
log.Printf("error when reading from hijacked connection: %s", err)
return
}
fmt.Fprintf(c, "You sent me %q. Waiting for new data\n", buf[:])
}
}
// requestHandler is called for each incoming request.
requestHandler := func(ctx *fasthttp.RequestCtx) {
path := ctx.Path()
switch {
case string(path) == "/hijack":
// Note that the connection is hijacked only after
// returning from requestHandler and sending http response.
ctx.Hijack(hijackHandler)
// The connection will be hijacked after sending this response.
fmt.Fprintf(ctx, "Hijacked the connection!")
case string(path) == "/":
fmt.Fprintf(ctx, "Root directory requested")
default:
fmt.Fprintf(ctx, "Requested path is %q", path)
}
}
if err := fasthttp.ListenAndServe(":80", requestHandler); err != nil {
log.Fatalf("error in ListenAndServe: %s", err)
}
}
func ExampleRequestCtx_TimeoutError() {
requestHandler := func(ctx *fasthttp.RequestCtx) {
// Emulate a long-running task which touches ctx.
doneCh := make(chan struct{})
go func() {
workDuration := time.Millisecond * time.Duration(rand.Intn(2000))
time.Sleep(workDuration)
fmt.Fprintf(ctx, "ctx has been accessed by long-running task\n")
fmt.Fprintf(ctx, "The reuqestHandler may be finished by this time.\n")
close(doneCh)
}()
select {
case <-doneCh:
fmt.Fprintf(ctx, "The task has been finished in less than a second")
case <-time.After(time.Second):
// Since the long-running task is still running and may access ctx,
// we must call TimeoutError before returning from requestHandler.
//
// Otherwise the program will suffer from data races.
ctx.TimeoutError("Timeout!")
}
}
if err := fasthttp.ListenAndServe(":80", requestHandler); err != nil {
log.Fatalf("error in ListenAndServe: %s", err)
}
}
func ExampleRequestCtx_Logger() {
requestHandler := func(ctx *fasthttp.RequestCtx) {
if string(ctx.Path()) == "/top-secret" {
ctx.Logger().Printf("Alarm! Alien intrusion detected!")
ctx.Error("Access denied!", fasthttp.StatusForbidden)
return
}
// Logger may be cached in local variables.
logger := ctx.Logger()
logger.Printf("Good request from User-Agent %q", ctx.Request.Header.UserAgent())
fmt.Fprintf(ctx, "Good request to %q", ctx.Path())
logger.Printf("Multiple log messages may be written during a single request")
}
if err := fasthttp.ListenAndServe(":80", requestHandler); err != nil {
log.Fatalf("error in ListenAndServe: %s", err)
}
}
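func ExampleRequestCtx_SetUserValue() {
// A minimal sketch (added for illustration, not shipped with the original
// examples): user values carry request-scoped data from middleware to the
// final handler. The "requestID" key is hypothetical.
withRequestID := func(next fasthttp.RequestHandler) fasthttp.RequestHandler {
return func(ctx *fasthttp.RequestCtx) {
ctx.SetUserValue("requestID", rand.Int63())
next(ctx)
}
}
requestHandler := func(ctx *fasthttp.RequestCtx) {
fmt.Fprintf(ctx, "request id: %v", ctx.UserValue("requestID"))
}
if err := fasthttp.ListenAndServe(":80", withRequestID(requestHandler)); err != nil {
log.Fatalf("error in ListenAndServe: %s", err)
}
}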
fasthttp-1.31.0/server_test.go 0000664 0000000 0000000 00000267427 14130360711 0016343 0 ustar 00root root 0000000 0000000 package fasthttp
import (
"bufio"
"bytes"
"context"
"crypto/tls"
"fmt"
"io"
"io/ioutil"
"mime/multipart"
"net"
"os"
"reflect"
"strings"
"sync"
"testing"
"time"
"github.com/valyala/fasthttp/fasthttputil"
)
// Make sure RequestCtx implements context.Context
var _ context.Context = &RequestCtx{}
func TestServerCRNLAfterPost_Pipeline(t *testing.T) {
t.Parallel()
s := &Server{
Handler: func(ctx *RequestCtx) {
},
Logger: &testLogger{},
}
ln := fasthttputil.NewInmemoryListener()
defer ln.Close()
go func() {
if err := s.Serve(ln); err != nil {
t.Errorf("unexpected error: %s", err)
}
}()
c, err := ln.Dial()
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
defer c.Close()
if _, err = c.Write([]byte("POST / HTTP/1.1\r\nHost: golang.org\r\nContent-Length: 3\r\n\r\nABC" +
"\r\n\r\n" + // <-- this stuff is bogus, but we'll ignore it
"GET / HTTP/1.1\r\nHost: golang.org\r\n\r\n")); err != nil {
t.Fatal(err)
}
br := bufio.NewReader(c)
var resp Response
if err := resp.Read(br); err != nil {
t.Fatalf("unexpected error: %s", err)
}
if resp.StatusCode() != StatusOK {
t.Fatalf("unexpected status code: %d. Expecting %d", resp.StatusCode(), StatusOK)
}
if err := resp.Read(br); err != nil {
t.Fatalf("unexpected error: %s", err)
}
if resp.StatusCode() != StatusOK {
t.Fatalf("unexpected status code: %d. Expecting %d", resp.StatusCode(), StatusOK)
}
}
func TestServerCRNLAfterPost(t *testing.T) {
t.Parallel()
s := &Server{
Handler: func(ctx *RequestCtx) {
},
Logger: &testLogger{},
ReadTimeout: time.Millisecond * 100,
}
ln := fasthttputil.NewInmemoryListener()
defer ln.Close()
go func() {
if err := s.Serve(ln); err != nil {
t.Errorf("unexpected error: %s", err)
}
}()
c, err := ln.Dial()
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
defer c.Close()
if _, err = c.Write([]byte("POST / HTTP/1.1\r\nHost: golang.org\r\nContent-Length: 3\r\n\r\nABC" +
"\r\n\r\n", // <-- this stuff is bogus, but we'll ignore it
)); err != nil {
t.Fatal(err)
}
br := bufio.NewReader(c)
var resp Response
if err := resp.Read(br); err != nil {
t.Fatalf("unexpected error: %s", err)
}
if resp.StatusCode() != StatusOK {
t.Fatalf("unexpected status code: %d. Expecting %d", resp.StatusCode(), StatusOK)
}
if err := resp.Read(br); err == nil {
t.Fatal("expected error") // We didn't send a request so we should get an error here.
}
}
func TestServerPipelineFlush(t *testing.T) {
t.Parallel()
s := &Server{
Handler: func(ctx *RequestCtx) {
},
}
ln := fasthttputil.NewInmemoryListener()
go func() {
if err := s.Serve(ln); err != nil {
t.Errorf("unexpected error: %s", err)
}
}()
c, err := ln.Dial()
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
if _, err = c.Write([]byte("GET /foo1 HTTP/1.1\r\nHost: google.com\r\n\r\n")); err != nil {
t.Fatal(err)
}
// Write a partial request.
if _, err = c.Write([]byte("GET /foo1 HTTP/1.1\r\nHost: ")); err != nil {
t.Fatal(err)
}
go func() {
// Wait for 200ms to finish the request
time.Sleep(time.Millisecond * 200)
if _, err = c.Write([]byte("google.com\r\n\r\n")); err != nil {
t.Error(err)
}
}()
start := time.Now()
br := bufio.NewReader(c)
var resp Response
if err := resp.Read(br); err != nil {
t.Fatalf("unexpected error: %s", err)
}
if resp.StatusCode() != StatusOK {
t.Fatalf("unexpected status code: %d. Expecting %d", resp.StatusCode(), StatusOK)
}
// Since the second request takes 200ms to finish we expect the first one to be flushed earlier.
d := time.Since(start)
if d > time.Millisecond*100 {
t.Fatalf("had to wait for %v", d)
}
if err := resp.Read(br); err != nil {
t.Fatalf("unexpected error: %s", err)
}
if resp.StatusCode() != StatusOK {
t.Fatalf("unexpected status code: %d. Expecting %d", resp.StatusCode(), StatusOK)
}
}
func TestServerInvalidHeader(t *testing.T) {
t.Parallel()
s := &Server{
Handler: func(ctx *RequestCtx) {
if ctx.Request.Header.Peek("Foo") != nil || ctx.Request.Header.Peek("Foo ") != nil {
t.Error("expected Foo header")
}
},
Logger: &testLogger{},
}
ln := fasthttputil.NewInmemoryListener()
go func() {
if err := s.Serve(ln); err != nil {
t.Errorf("unexpected error: %s", err)
}
}()
c, err := ln.Dial()
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
if _, err = c.Write([]byte("POST /foo HTTP/1.1\r\nHost: gle.com\r\nFoo : bar\r\nContent-Length: 5\r\n\r\n12345")); err != nil {
t.Fatal(err)
}
br := bufio.NewReader(c)
var resp Response
if err := resp.Read(br); err != nil {
t.Fatalf("unexpected error: %s", err)
}
if resp.StatusCode() != StatusBadRequest {
t.Fatalf("unexpected status code: %d. Expecting %d", resp.StatusCode(), StatusBadRequest)
}
c, err = ln.Dial()
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
if _, err = c.Write([]byte("GET /foo HTTP/1.1\r\nHost: gle.com\r\nFoo : bar\r\n\r\n")); err != nil {
t.Fatal(err)
}
br = bufio.NewReader(c)
if err := resp.Read(br); err != nil {
t.Fatalf("unexpected error: %s", err)
}
if resp.StatusCode() != StatusBadRequest {
t.Fatalf("unexpected status code: %d. Expecting %d", resp.StatusCode(), StatusBadRequest)
}
if err := c.Close(); err != nil {
t.Fatalf("unexpected error: %s", err)
}
if err := ln.Close(); err != nil {
t.Fatalf("unexpected error: %s", err)
}
}
func TestServerConnState(t *testing.T) {
t.Parallel()
states := make([]string, 0)
s := &Server{
Handler: func(ctx *RequestCtx) {},
ConnState: func(conn net.Conn, state ConnState) {
states = append(states, state.String())
},
}
ln := fasthttputil.NewInmemoryListener()
serverCh := make(chan struct{})
go func() {
if err := s.Serve(ln); err != nil {
t.Errorf("unexpected error: %s", err)
}
close(serverCh)
}()
clientCh := make(chan struct{})
go func() {
c, err := ln.Dial()
if err != nil {
t.Errorf("unexpected error: %s", err)
}
br := bufio.NewReader(c)
// Send 2 requests on the same connection.
for i := 0; i < 2; i++ {
if _, err = c.Write([]byte("GET / HTTP/1.1\r\nHost: aa\r\n\r\n")); err != nil {
t.Errorf("unexpected error: %s", err)
}
var resp Response
if err := resp.Read(br); err != nil {
t.Errorf("unexpected error: %s", err)
}
if resp.StatusCode() != StatusOK {
t.Errorf("unexpected status code: %d. Expecting %d", resp.StatusCode(), StatusOK)
}
}
if err := c.Close(); err != nil {
t.Errorf("unexpected error: %s", err)
}
// Give the server a little bit of time to transition the connection to the close state.
time.Sleep(time.Millisecond * 100)
close(clientCh)
}()
select {
case <-clientCh:
case <-time.After(time.Second):
t.Fatal("timeout")
}
if err := ln.Close(); err != nil {
t.Fatalf("unexpected error: %s", err)
}
select {
case <-serverCh:
case <-time.After(time.Second):
t.Fatal("timeout")
}
// 2 requests so we go to active and idle twice.
expected := []string{"new", "active", "idle", "active", "idle", "closed"}
if !reflect.DeepEqual(expected, states) {
t.Fatalf("wrong state, expected %s, got %s", expected, states)
}
}
func TestSaveMultipartFile(t *testing.T) {
t.Parallel()
filea := "This is a test file."
fileb := strings.Repeat("test", 64)
mr := multipart.NewReader(strings.NewReader(""+
"--foo\r\n"+
"Content-Disposition: form-data; name=\"filea\"; filename=\"filea.txt\"\r\n"+
"Content-Type: text/plain\r\n"+
"\r\n"+
filea+"\r\n"+
"--foo\r\n"+
"Content-Disposition: form-data; name=\"fileb\"; filename=\"fileb.txt\"\r\n"+
"Content-Type: text/plain\r\n"+
"\r\n"+
fileb+"\r\n"+
"--foo--\r\n",
), "foo")
f, err := mr.ReadForm(64)
if err != nil {
t.Fatal(err)
}
if err := SaveMultipartFile(f.File["filea"][0], "filea.txt"); err != nil {
t.Fatal(err)
}
defer os.Remove("filea.txt")
if c, err := ioutil.ReadFile("filea.txt"); err != nil {
t.Fatal(err)
} else if string(c) != filea {
t.Fatalf("filea changed expected %q got %q", filea, c)
}
// Make sure fileb was saved to a file.
if ff, err := f.File["fileb"][0].Open(); err != nil {
t.Fatal("expected FileHeader.Open to work")
} else if _, ok := ff.(*os.File); !ok {
t.Fatal("expected fileb to be an os.File")
} else {
ff.Close()
}
if err := SaveMultipartFile(f.File["fileb"][0], "fileb.txt"); err != nil {
t.Fatal(err)
}
defer os.Remove("fileb.txt")
if c, err := ioutil.ReadFile("fileb.txt"); err != nil {
t.Fatal(err)
} else if string(c) != fileb {
t.Fatalf("fileb changed expected %q got %q", fileb, c)
}
}
func TestServerName(t *testing.T) {
t.Parallel()
s := &Server{
Handler: func(ctx *RequestCtx) {
},
}
getResponse := func() []byte {
rw := &readWriter{}
rw.r.WriteString("GET / HTTP/1.1\r\nHost: google.com\r\n\r\n")
if err := s.ServeConn(rw); err != nil {
t.Fatalf("Unexpected error from serveConn: %s", err)
}
resp, err := ioutil.ReadAll(&rw.w)
if err != nil {
t.Fatalf("Unexpected error from ReadAll: %s", err)
}
return resp
}
resp := getResponse()
if !bytes.Contains(resp, []byte("\r\nServer: "+string(defaultServerName)+"\r\n")) {
t.Fatalf("Unexpected response %q expected Server: "+string(defaultServerName), resp)
}
// We can't just overwrite s.Name as fasthttp caches the name in an atomic.Value
s = &Server{
Handler: func(ctx *RequestCtx) {
},
Name: "foobar",
}
resp = getResponse()
if !bytes.Contains(resp, []byte("\r\nServer: foobar\r\n")) {
t.Fatalf("Unexpected response %q expected Server: foobar", resp)
}
s = &Server{
Handler: func(ctx *RequestCtx) {
},
NoDefaultServerHeader: true,
NoDefaultContentType: true,
NoDefaultDate: true,
}
resp = getResponse()
if bytes.Contains(resp, []byte("\r\nServer: ")) {
t.Fatalf("Unexpected response %q expected no Server header", resp)
}
if bytes.Contains(resp, []byte("\r\nContent-Type: ")) {
t.Fatalf("Unexpected response %q expected no Content-Type header", resp)
}
if bytes.Contains(resp, []byte("\r\nDate: ")) {
t.Fatalf("Unexpected response %q expected no Date header", resp)
}
}
func TestRequestCtxString(t *testing.T) {
t.Parallel()
var ctx RequestCtx
s := ctx.String()
expectedS := "#0000000000000000 - 0.0.0.0:0<->0.0.0.0:0 - GET http:///"
if s != expectedS {
t.Fatalf("unexpected ctx.String: %q. Expecting %q", s, expectedS)
}
ctx.Request.SetRequestURI("https://foobar.com/aaa?bb=c")
s = ctx.String()
expectedS = "#0000000000000000 - 0.0.0.0:0<->0.0.0.0:0 - GET https://foobar.com/aaa?bb=c"
if s != expectedS {
t.Fatalf("unexpected ctx.String: %q. Expecting %q", s, expectedS)
}
}
func TestServerErrSmallBuffer(t *testing.T) {
t.Parallel()
s := &Server{
Handler: func(ctx *RequestCtx) {
ctx.WriteString("shouldn't be never called") //nolint:errcheck
},
ReadBufferSize: 20,
}
rw := &readWriter{}
rw.r.WriteString("GET / HTTP/1.1\r\nHost: aabb.com\r\nVERY-long-Header: sdfdfsd dsf dsaf dsf df fsd\r\n\r\n")
ch := make(chan error)
go func() {
ch <- s.ServeConn(rw)
}()
var serverErr error
select {
case serverErr = <-ch:
case <-time.After(200 * time.Millisecond):
t.Fatal("timeout")
}
if serverErr == nil {
t.Fatal("expected error")
}
br := bufio.NewReader(&rw.w)
var resp Response
if err := resp.Read(br); err != nil {
t.Fatalf("unexpected error: %s", err)
}
statusCode := resp.StatusCode()
if statusCode != StatusRequestHeaderFieldsTooLarge {
t.Fatalf("unexpected status code: %d. Expecting %d", statusCode, StatusRequestHeaderFieldsTooLarge)
}
if !resp.ConnectionClose() {
t.Fatal("missing 'Connection: close' response header")
}
expectedErr := errSmallBuffer.Error()
if !strings.Contains(serverErr.Error(), expectedErr) {
t.Fatalf("unexpected log output: %v. Expecting %q", serverErr, expectedErr)
}
}
func TestRequestCtxIsTLS(t *testing.T) {
t.Parallel()
var ctx RequestCtx
// tls.Conn
ctx.c = &tls.Conn{}
if !ctx.IsTLS() {
t.Fatal("IsTLS must return true")
}
// non-tls.Conn
ctx.c = &readWriter{}
if ctx.IsTLS() {
t.Fatal("IsTLS must return false")
}
// overridden tls.Conn
ctx.c = &struct {
*tls.Conn
fooBar bool
}{}
if !ctx.IsTLS() {
t.Fatal("IsTLS must return true")
}
ctx.c = &perIPConn{Conn: &tls.Conn{}}
if !ctx.IsTLS() {
t.Fatal("IsTLS must return true")
}
}
func TestRequestCtxRedirectHTTPSSchemeless(t *testing.T) {
t.Parallel()
var ctx RequestCtx
s := "GET /foo/bar?baz HTTP/1.1\nHost: aaa.com\n\n"
br := bufio.NewReader(bytes.NewBufferString(s))
if err := ctx.Request.Read(br); err != nil {
t.Fatalf("cannot read request: %s", err)
}
ctx.Request.isTLS = true
ctx.Redirect("//foobar.com/aa/bbb", StatusFound)
location := ctx.Response.Header.Peek(HeaderLocation)
expectedLocation := "https://foobar.com/aa/bbb"
if string(location) != expectedLocation {
t.Fatalf("Unexpected location: %q. Expecting %q", location, expectedLocation)
}
}
func TestRequestCtxRedirect(t *testing.T) {
t.Parallel()
testRequestCtxRedirect(t, "http://qqq/", "", "http://qqq/")
testRequestCtxRedirect(t, "http://qqq/foo/bar?baz=111", "", "http://qqq/foo/bar?baz=111")
testRequestCtxRedirect(t, "http://qqq/foo/bar?baz=111", "#aaa", "http://qqq/foo/bar?baz=111#aaa")
testRequestCtxRedirect(t, "http://qqq/foo/bar?baz=111", "?abc=de&f", "http://qqq/foo/bar?abc=de&f")
testRequestCtxRedirect(t, "http://qqq/foo/bar?baz=111", "?abc=de&f#sf", "http://qqq/foo/bar?abc=de&f#sf")
testRequestCtxRedirect(t, "http://qqq/foo/bar?baz=111", "x.html", "http://qqq/foo/x.html")
testRequestCtxRedirect(t, "http://qqq/foo/bar?baz=111", "x.html?a=1", "http://qqq/foo/x.html?a=1")
testRequestCtxRedirect(t, "http://qqq/foo/bar?baz=111", "x.html#aaa=bbb&cc=ddd", "http://qqq/foo/x.html#aaa=bbb&cc=ddd")
testRequestCtxRedirect(t, "http://qqq/foo/bar?baz=111", "x.html?b=1#aaa=bbb&cc=ddd", "http://qqq/foo/x.html?b=1#aaa=bbb&cc=ddd")
testRequestCtxRedirect(t, "http://qqq/foo/bar?baz=111", "/x.html", "http://qqq/x.html")
testRequestCtxRedirect(t, "http://qqq/foo/bar?baz=111", "/x.html#aaa=bbb&cc=ddd", "http://qqq/x.html#aaa=bbb&cc=ddd")
testRequestCtxRedirect(t, "http://qqq/foo/bar?baz=111", "../x.html", "http://qqq/x.html")
testRequestCtxRedirect(t, "http://qqq/foo/bar?baz=111", "../../x.html", "http://qqq/x.html")
testRequestCtxRedirect(t, "http://qqq/foo/bar?baz=111", "./.././../x.html", "http://qqq/x.html")
testRequestCtxRedirect(t, "http://qqq/foo/bar?baz=111", "http://foo.bar/baz", "http://foo.bar/baz")
testRequestCtxRedirect(t, "http://qqq/foo/bar?baz=111", "https://foo.bar/baz", "https://foo.bar/baz")
testRequestCtxRedirect(t, "https://foo.com/bar?aaa", "//google.com/aaa?bb", "https://google.com/aaa?bb")
}
func testRequestCtxRedirect(t *testing.T, origURL, redirectURL, expectedURL string) {
var ctx RequestCtx
var req Request
req.SetRequestURI(origURL)
ctx.Init(&req, nil, nil)
ctx.Redirect(redirectURL, StatusFound)
loc := ctx.Response.Header.Peek(HeaderLocation)
if string(loc) != expectedURL {
t.Fatalf("unexpected redirect url %q. Expecting %q. origURL=%q, redirectURL=%q", loc, expectedURL, origURL, redirectURL)
}
}
func TestServerResponseServerHeader(t *testing.T) {
t.Parallel()
serverName := "foobar serv"
s := &Server{
Handler: func(ctx *RequestCtx) {
name := ctx.Response.Header.Server()
if string(name) != serverName {
fmt.Fprintf(ctx, "unexpected server name: %q. Expecting %q", name, serverName)
} else {
ctx.WriteString("OK") //nolint:errcheck
}
// make sure the server name is sent to the client after ctx.Response.Reset()
ctx.NotFound()
},
Name: serverName,
}
ln := fasthttputil.NewInmemoryListener()
serverCh := make(chan struct{})
go func() {
if err := s.Serve(ln); err != nil {
t.Errorf("unexpected error: %s", err)
}
close(serverCh)
}()
clientCh := make(chan struct{})
go func() {
c, err := ln.Dial()
if err != nil {
t.Errorf("unexpected error: %s", err)
}
if _, err = c.Write([]byte("GET / HTTP/1.1\r\nHost: aa\r\n\r\n")); err != nil {
t.Errorf("unexpected error: %s", err)
}
br := bufio.NewReader(c)
var resp Response
if err = resp.Read(br); err != nil {
t.Errorf("unexpected error: %s", err)
}
if resp.StatusCode() != StatusNotFound {
t.Errorf("unexpected status code: %d. Expecting %d", resp.StatusCode(), StatusNotFound)
}
if string(resp.Body()) != "404 Page not found" {
t.Errorf("unexpected body: %q. Expecting %q", resp.Body(), "404 Page not found")
}
if string(resp.Header.Server()) != serverName {
t.Errorf("unexpected server header: %q. Expecting %q", resp.Header.Server(), serverName)
}
if err = c.Close(); err != nil {
t.Errorf("unexpected error: %s", err)
}
close(clientCh)
}()
select {
case <-clientCh:
case <-time.After(time.Second):
t.Fatal("timeout")
}
if err := ln.Close(); err != nil {
t.Fatalf("unexpected error: %s", err)
}
select {
case <-serverCh:
case <-time.After(time.Second):
t.Fatal("timeout")
}
}
func TestServerResponseBodyStream(t *testing.T) {
t.Parallel()
ln := fasthttputil.NewInmemoryListener()
readyCh := make(chan struct{})
h := func(ctx *RequestCtx) {
ctx.SetConnectionClose()
if ctx.IsBodyStream() {
t.Fatal("IsBodyStream must return false")
}
ctx.SetBodyStreamWriter(func(w *bufio.Writer) {
fmt.Fprintf(w, "first")
if err := w.Flush(); err != nil {
return
}
<-readyCh
fmt.Fprintf(w, "second")
// there is no need to flush w here, since it will
// be flushed automatically after returning from StreamWriter.
})
if !ctx.IsBodyStream() {
t.Fatal("IsBodyStream must return true")
}
}
serverCh := make(chan struct{})
go func() {
if err := Serve(ln, h); err != nil {
t.Errorf("unexpected error: %s", err)
}
close(serverCh)
}()
clientCh := make(chan struct{})
go func() {
c, err := ln.Dial()
if err != nil {
t.Errorf("unexpected error: %s", err)
}
if _, err = c.Write([]byte("GET / HTTP/1.1\r\nHost: aa\r\n\r\n")); err != nil {
t.Errorf("unexpected error: %s", err)
}
br := bufio.NewReader(c)
var respH ResponseHeader
if err = respH.Read(br); err != nil {
t.Errorf("unexpected error: %s", err)
}
if respH.StatusCode() != StatusOK {
t.Errorf("unexpected status code: %d. Expecting %d", respH.StatusCode(), StatusOK)
}
buf := make([]byte, 1024)
n, err := br.Read(buf)
if err != nil {
t.Errorf("unexpected error: %s", err)
}
b := buf[:n]
if string(b) != "5\r\nfirst\r\n" {
t.Errorf("unexpected result %q. Expecting %q", b, "5\r\nfirst\r\n")
}
close(readyCh)
tail, err := ioutil.ReadAll(br)
if err != nil {
t.Errorf("unexpected error: %s", err)
}
if string(tail) != "6\r\nsecond\r\n0\r\n\r\n" {
t.Errorf("unexpected tail %q. Expecting %q", tail, "6\r\nsecond\r\n0\r\n\r\n")
}
close(clientCh)
}()
select {
case <-clientCh:
case <-time.After(time.Second):
t.Fatal("timeout")
}
if err := ln.Close(); err != nil {
t.Fatalf("unexpected error: %s", err)
}
select {
case <-serverCh:
case <-time.After(time.Second):
t.Fatal("timeout")
}
}
func TestServerDisableKeepalive(t *testing.T) {
t.Parallel()
s := &Server{
Handler: func(ctx *RequestCtx) {
ctx.WriteString("OK") //nolint:errcheck
},
DisableKeepalive: true,
}
ln := fasthttputil.NewInmemoryListener()
serverCh := make(chan struct{})
go func() {
if err := s.Serve(ln); err != nil {
t.Errorf("unexpected error: %s", err)
}
close(serverCh)
}()
clientCh := make(chan struct{})
go func() {
c, err := ln.Dial()
if err != nil {
t.Errorf("unexpected error: %s", err)
}
if _, err = c.Write([]byte("GET / HTTP/1.1\r\nHost: aa\r\n\r\n")); err != nil {
t.Errorf("unexpected error: %s", err)
}
br := bufio.NewReader(c)
var resp Response
if err = resp.Read(br); err != nil {
t.Errorf("unexpected error: %s", err)
}
if resp.StatusCode() != StatusOK {
t.Errorf("unexpected status code: %d. Expecting %d", resp.StatusCode(), StatusOK)
}
if !resp.ConnectionClose() {
t.Error("expecting 'Connection: close' response header")
}
if string(resp.Body()) != "OK" {
t.Errorf("unexpected body: %q. Expecting %q", resp.Body(), "OK")
}
// make sure the connection is closed
data, err := ioutil.ReadAll(br)
if err != nil {
t.Errorf("unexpected error: %s", err)
}
if len(data) > 0 {
t.Errorf("unexpected data read from the connection: %q. Expecting empty data", data)
}
close(clientCh)
}()
select {
case <-clientCh:
case <-time.After(time.Second):
t.Fatal("timeout")
}
if err := ln.Close(); err != nil {
t.Fatalf("unexpected error: %s", err)
}
select {
case <-serverCh:
case <-time.After(time.Second):
t.Fatal("timeout")
}
}
func TestServerMaxConnsPerIPLimit(t *testing.T) {
t.Parallel()
s := &Server{
Handler: func(ctx *RequestCtx) {
ctx.WriteString("OK") //nolint:errcheck
},
MaxConnsPerIP: 1,
Logger: &testLogger{},
}
ln := fasthttputil.NewInmemoryListener()
serverCh := make(chan struct{})
go func() {
fakeLN := &fakeIPListener{
Listener: ln,
}
if err := s.Serve(fakeLN); err != nil {
t.Errorf("unexpected error: %s", err)
}
close(serverCh)
}()
clientCh := make(chan struct{})
go func() {
c1, err := ln.Dial()
if err != nil {
t.Errorf("unexpected error: %s", err)
}
c2, err := ln.Dial()
if err != nil {
t.Errorf("unexpected error: %s", err)
}
br := bufio.NewReader(c2)
var resp Response
if err = resp.Read(br); err != nil {
t.Errorf("unexpected error: %s", err)
}
if resp.StatusCode() != StatusTooManyRequests {
t.Errorf("unexpected status code for the second connection: %d. Expecting %d",
resp.StatusCode(), StatusTooManyRequests)
}
if _, err = c1.Write([]byte("GET / HTTP/1.1\r\nHost: aa\r\n\r\n")); err != nil {
t.Errorf("unexpected error when writing to the first connection: %s", err)
}
br = bufio.NewReader(c1)
if err = resp.Read(br); err != nil {
t.Errorf("unexpected error: %s", err)
}
if resp.StatusCode() != StatusOK {
t.Errorf("unexpected status code for the first connection: %d. Expecting %d",
resp.StatusCode(), StatusOK)
}
if string(resp.Body()) != "OK" {
t.Errorf("unexpected body for the first connection: %q. Expecting %q", resp.Body(), "OK")
}
close(clientCh)
}()
select {
case <-clientCh:
case <-time.After(time.Second):
t.Fatal("timeout")
}
if err := ln.Close(); err != nil {
t.Fatalf("unexpected error: %s", err)
}
select {
case <-serverCh:
case <-time.After(time.Second):
t.Fatal("timeout")
}
}
type fakeIPListener struct {
net.Listener
}
func (ln *fakeIPListener) Accept() (net.Conn, error) {
conn, err := ln.Listener.Accept()
if err != nil {
return nil, err
}
return &fakeIPConn{
Conn: conn,
}, nil
}
type fakeIPConn struct {
net.Conn
}
func (conn *fakeIPConn) RemoteAddr() net.Addr {
addr, err := net.ResolveTCPAddr("tcp4", "1.2.3.4:5789")
if err != nil {
panic(fmt.Sprintf("BUG: unexpected error: %s", err))
}
return addr
}
func TestServerConcurrencyLimit(t *testing.T) {
t.Parallel()
s := &Server{
Handler: func(ctx *RequestCtx) {
ctx.WriteString("OK") //nolint:errcheck
},
Concurrency: 1,
Logger: &testLogger{},
}
ln := fasthttputil.NewInmemoryListener()
serverCh := make(chan struct{})
go func() {
if err := s.Serve(ln); err != nil {
t.Errorf("unexpected error: %s", err)
}
close(serverCh)
}()
clientCh := make(chan struct{})
go func() {
c1, err := ln.Dial()
if err != nil {
t.Errorf("unexpected error: %s", err)
}
c2, err := ln.Dial()
if err != nil {
t.Errorf("unexpected error: %s", err)
}
br := bufio.NewReader(c2)
var resp Response
if err = resp.Read(br); err != nil {
t.Errorf("unexpected error: %s", err)
}
if resp.StatusCode() != StatusServiceUnavailable {
t.Errorf("unexpected status code for the second connection: %d. Expecting %d",
resp.StatusCode(), StatusServiceUnavailable)
}
if _, err = c1.Write([]byte("GET / HTTP/1.1\r\nHost: aa\r\n\r\n")); err != nil {
t.Errorf("unexpected error when writing to the first connection: %s", err)
}
br = bufio.NewReader(c1)
if err = resp.Read(br); err != nil {
t.Errorf("unexpected error: %s", err)
}
if resp.StatusCode() != StatusOK {
t.Errorf("unexpected status code for the first connection: %d. Expecting %d",
resp.StatusCode(), StatusOK)
}
if string(resp.Body()) != "OK" {
t.Errorf("unexpected body for the first connection: %q. Expecting %q", resp.Body(), "OK")
}
close(clientCh)
}()
select {
case <-clientCh:
case <-time.After(time.Second):
t.Fatal("timeout")
}
if err := ln.Close(); err != nil {
t.Fatalf("unexpected error: %s", err)
}
select {
case <-serverCh:
case <-time.After(time.Second):
t.Fatal("timeout")
}
}
func TestServerWriteFastError(t *testing.T) {
t.Parallel()
s := &Server{
Name: "foobar",
}
var buf bytes.Buffer
expectedBody := "access denied"
s.writeFastError(&buf, StatusForbidden, expectedBody)
br := bufio.NewReader(&buf)
var resp Response
if err := resp.Read(br); err != nil {
t.Fatalf("unexpected error: %s", err)
}
if resp.StatusCode() != StatusForbidden {
t.Fatalf("unexpected status code: %d. Expecting %d", resp.StatusCode(), StatusForbidden)
}
body := resp.Body()
if string(body) != expectedBody {
t.Fatalf("unexpected body: %q. Expecting %q", body, expectedBody)
}
server := string(resp.Header.Server())
if server != s.Name {
t.Fatalf("unexpected server: %q. Expecting %q", server, s.Name)
}
contentType := string(resp.Header.ContentType())
if contentType != "text/plain" {
t.Fatalf("unexpected content-type: %q. Expecting %q", contentType, "text/plain")
}
if !resp.Header.ConnectionClose() {
t.Fatal("expecting 'Connection: close' response header")
}
}
func TestServerTLS(t *testing.T) {
t.Parallel()
text := []byte("Make fasthttp great again")
ln := fasthttputil.NewInmemoryListener()
s := &Server{
Handler: func(ctx *RequestCtx) {
ctx.Write(text) //nolint:errcheck
},
}
certData, keyData, err := GenerateTestCertificate("localhost")
if err != nil {
t.Fatal(err)
}
err = s.AppendCertEmbed(certData, keyData)
if err != nil {
t.Fatal(err)
}
go func() {
err = s.ServeTLS(ln, "", "")
if err != nil {
t.Error(err)
}
}()
c := &Client{
ReadTimeout: time.Second * 2,
Dial: func(addr string) (net.Conn, error) {
return ln.Dial()
},
TLSConfig: &tls.Config{
InsecureSkipVerify: true,
},
}
req, res := AcquireRequest(), AcquireResponse()
req.SetRequestURI("https://some.url")
err = c.Do(req, res)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(text, res.Body()) {
t.Fatal("error transmitting information")
}
}
func TestServerTLSReadTimeout(t *testing.T) {
t.Parallel()
ln := fasthttputil.NewInmemoryListener()
s := &Server{
ReadTimeout: time.Millisecond * 500,
Logger: &testLogger{}, // Ignore log output.
Handler: func(ctx *RequestCtx) {
},
}
certData, keyData, err := GenerateTestCertificate("localhost")
if err != nil {
t.Fatal(err)
}
err = s.AppendCertEmbed(certData, keyData)
if err != nil {
t.Fatal(err)
}
go func() {
err = s.ServeTLS(ln, "", "")
if err != nil {
t.Error(err)
}
}()
c, err := ln.Dial()
if err != nil {
t.Error(err)
}
r := make(chan error)
go func() {
b := make([]byte, 1)
_, err := c.Read(b)
c.Close()
r <- err
}()
select {
case err = <-r:
case <-time.After(time.Second):
}
if err == nil {
t.Error("server didn't close connection after timeout")
}
}
func TestServerServeTLSEmbed(t *testing.T) {
t.Parallel()
ln := fasthttputil.NewInmemoryListener()
certData, keyData, err := GenerateTestCertificate("localhost")
if err != nil {
t.Fatal(err)
}
// start the server
ch := make(chan struct{})
go func() {
err := ServeTLSEmbed(ln, certData, keyData, func(ctx *RequestCtx) {
if !ctx.IsTLS() {
ctx.Error("expecting tls", StatusBadRequest)
return
}
scheme := ctx.URI().Scheme()
if string(scheme) != "https" {
ctx.Error(fmt.Sprintf("unexpected scheme=%q. Expecting %q", scheme, "https"), StatusBadRequest)
return
}
ctx.WriteString("success") //nolint:errcheck
})
if err != nil {
t.Errorf("unexpected error: %s", err)
}
close(ch)
}()
// establish connection to the server
conn, err := ln.Dial()
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
tlsConn := tls.Client(conn, &tls.Config{
InsecureSkipVerify: true,
})
// send request
if _, err = tlsConn.Write([]byte("GET / HTTP/1.1\r\nHost: aaa\r\n\r\n")); err != nil {
t.Fatalf("unexpected error: %s", err)
}
// read response
respCh := make(chan struct{})
go func() {
br := bufio.NewReader(tlsConn)
var resp Response
if err := resp.Read(br); err != nil {
t.Error("unexpected error")
}
body := resp.Body()
if string(body) != "success" {
t.Errorf("unexpected response body %q. Expecting %q", body, "success")
}
close(respCh)
}()
select {
case <-respCh:
case <-time.After(time.Second):
t.Fatal("timeout")
}
// close the server
if err = ln.Close(); err != nil {
t.Fatalf("unexpected error: %s", err)
}
select {
case <-ch:
case <-time.After(time.Second):
t.Fatal("timeout")
}
}
func TestServerMultipartFormDataRequest(t *testing.T) {
t.Parallel()
for _, test := range []struct {
StreamRequestBody bool
DisablePreParseMultipartForm bool
}{
{false, false},
{false, true},
{true, false},
{true, true},
} {
reqS := `POST /upload HTTP/1.1
Host: qwerty.com
Content-Length: 521
Content-Type: multipart/form-data; boundary=----WebKitFormBoundaryJwfATyF8tmxSJnLg
------WebKitFormBoundaryJwfATyF8tmxSJnLg
Content-Disposition: form-data; name="f1"
value1
------WebKitFormBoundaryJwfATyF8tmxSJnLg
Content-Disposition: form-data; name="fileaaa"; filename="TODO"
Content-Type: application/octet-stream
- SessionClient with referer and cookies support.
- Client with requests' pipelining support.
- ProxyHandler similar to FSHandler.
- WebSockets. See https://tools.ietf.org/html/rfc6455 .
- HTTP/2.0. See https://tools.ietf.org/html/rfc7540 .
------WebKitFormBoundaryJwfATyF8tmxSJnLg--
GET / HTTP/1.1
Host: asbd
Connection: close
`
ln := fasthttputil.NewInmemoryListener()
s := &Server{
StreamRequestBody: test.StreamRequestBody,
DisablePreParseMultipartForm: test.DisablePreParseMultipartForm,
Handler: func(ctx *RequestCtx) {
switch string(ctx.Path()) {
case "/upload":
f, err := ctx.MultipartForm()
if err != nil {
t.Errorf("unexpected error: %s", err)
}
if len(f.Value) != 1 {
t.Errorf("unexpected values %d. Expecting %d", len(f.Value), 1)
}
if len(f.File) != 1 {
t.Errorf("unexpected file values %d. Expecting %d", len(f.File), 1)
}
fv := ctx.FormValue("f1")
if string(fv) != "value1" {
t.Errorf("unexpected form value: %q. Expecting %q", fv, "value1")
}
ctx.Redirect("/", StatusSeeOther)
default:
ctx.WriteString("non-upload") //nolint:errcheck
}
},
}
ch := make(chan struct{})
go func() {
if err := s.Serve(ln); err != nil {
t.Errorf("unexpected error: %s", err)
}
close(ch)
}()
conn, err := ln.Dial()
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
if _, err = conn.Write([]byte(reqS)); err != nil {
t.Fatalf("unexpected error: %s", err)
}
var resp Response
br := bufio.NewReader(conn)
respCh := make(chan struct{})
go func() {
if err := resp.Read(br); err != nil {
t.Errorf("error when reading response: %s", err)
}
if resp.StatusCode() != StatusSeeOther {
t.Errorf("unexpected status code %d. Expecting %d", resp.StatusCode(), StatusSeeOther)
}
loc := resp.Header.Peek(HeaderLocation)
if string(loc) != "http://qwerty.com/" {
t.Errorf("unexpected location %q. Expecting %q", loc, "http://qwerty.com/")
}
if err := resp.Read(br); err != nil {
t.Errorf("error when reading the second response: %s", err)
}
if resp.StatusCode() != StatusOK {
t.Errorf("unexpected status code: %d. Expecting %d", resp.StatusCode(), StatusOK)
}
body := resp.Body()
if string(body) != "non-upload" {
t.Errorf("unexpected body %q. Expecting %q", body, "non-upload")
}
close(respCh)
}()
select {
case <-respCh:
case <-time.After(time.Second):
t.Fatal("timeout")
}
if err := ln.Close(); err != nil {
t.Fatalf("error when closing listener: %s", err)
}
select {
case <-ch:
case <-time.After(time.Second):
t.Fatal("timeout when waiting for the server to stop")
}
}
}
func TestServerGetWithContent(t *testing.T) {
t.Parallel()
h := func(ctx *RequestCtx) {
ctx.Success("foo/bar", []byte("success"))
}
s := &Server{
Handler: h,
}
rw := &readWriter{}
rw.r.WriteString("GET / HTTP/1.1\r\nHost: mm.com\r\nContent-Length: 5\r\n\r\nabcde")
if err := s.ServeConn(rw); err != nil {
t.Fatalf("Unexpected error from serveConn: %s", err)
}
resp := rw.w.String()
if !strings.HasSuffix(resp, "success") {
t.Fatalf("unexpected response %s.", resp)
}
}
func TestServerDisableHeaderNamesNormalizing(t *testing.T) {
t.Parallel()
headerName := "CASE-senSITive-HEAder-NAME"
headerNameLower := strings.ToLower(headerName)
headerValue := "foobar baz"
s := &Server{
Handler: func(ctx *RequestCtx) {
hv := ctx.Request.Header.Peek(headerName)
if string(hv) != headerValue {
t.Errorf("unexpected header value for %q: %q. Expecting %q", headerName, hv, headerValue)
}
hv = ctx.Request.Header.Peek(headerNameLower)
if len(hv) > 0 {
t.Errorf("unexpected header value for %q: %q. Expecting empty value", headerNameLower, hv)
}
ctx.Response.Header.Set(headerName, headerValue)
ctx.WriteString("ok") //nolint:errcheck
ctx.SetContentType("aaa")
},
DisableHeaderNamesNormalizing: true,
}
rw := &readWriter{}
rw.r.WriteString(fmt.Sprintf("GET / HTTP/1.1\r\n%s: %s\r\nHost: google.com\r\n\r\n", headerName, headerValue))
if err := s.ServeConn(rw); err != nil {
t.Fatalf("Unexpected error from serveConn: %s", err)
}
br := bufio.NewReader(&rw.w)
var resp Response
resp.Header.DisableNormalizing()
if err := resp.Read(br); err != nil {
t.Fatalf("unexpected error: %s", err)
}
hv := resp.Header.Peek(headerName)
if string(hv) != headerValue {
t.Fatalf("unexpected header value for %q: %q. Expecting %q", headerName, hv, headerValue)
}
hv = resp.Header.Peek(headerNameLower)
if len(hv) > 0 {
t.Fatalf("unexpected header value for %q: %q. Expecting empty value", headerNameLower, hv)
}
}
func TestServerReduceMemoryUsageSerial(t *testing.T) {
t.Parallel()
ln := fasthttputil.NewInmemoryListener()
s := &Server{
Handler: func(ctx *RequestCtx) {},
ReduceMemoryUsage: true,
}
ch := make(chan struct{})
go func() {
if err := s.Serve(ln); err != nil {
t.Errorf("unexpected error: %s", err)
}
close(ch)
}()
testServerRequests(t, ln)
if err := ln.Close(); err != nil {
t.Fatalf("error when closing listener: %s", err)
}
select {
case <-ch:
case <-time.After(time.Second):
t.Fatal("timeout when waiting for the server to stop")
}
}
func TestServerReduceMemoryUsageConcurrent(t *testing.T) {
t.Parallel()
ln := fasthttputil.NewInmemoryListener()
s := &Server{
Handler: func(ctx *RequestCtx) {},
ReduceMemoryUsage: true,
}
ch := make(chan struct{})
go func() {
if err := s.Serve(ln); err != nil {
t.Errorf("unexpected error: %s", err)
}
close(ch)
}()
gCh := make(chan struct{})
for i := 0; i < 10; i++ {
go func() {
testServerRequests(t, ln)
gCh <- struct{}{}
}()
}
for i := 0; i < 10; i++ {
select {
case <-gCh:
case <-time.After(time.Second):
t.Fatalf("timeout on goroutine %d", i)
}
}
if err := ln.Close(); err != nil {
t.Fatalf("error when closing listener: %s", err)
}
select {
case <-ch:
case <-time.After(time.Second):
t.Fatal("timeout when waiting for the server to stop")
}
}
func testServerRequests(t *testing.T, ln *fasthttputil.InmemoryListener) {
conn, err := ln.Dial()
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
br := bufio.NewReader(conn)
var resp Response
for i := 0; i < 10; i++ {
if _, err = fmt.Fprintf(conn, "GET / HTTP/1.1\r\nHost: aaa\r\n\r\n"); err != nil {
t.Fatalf("unexpected error on iteration %d: %s", i, err)
}
respCh := make(chan struct{})
go func() {
if err = resp.Read(br); err != nil {
t.Errorf("unexpected error when reading response on iteration %d: %s", i, err)
}
close(respCh)
}()
select {
case <-respCh:
case <-time.After(time.Second):
t.Fatalf("timeout on iteration %d", i)
}
}
if err = conn.Close(); err != nil {
t.Fatalf("error when closing the connection: %s", err)
}
}
func TestServerHTTP10ConnectionKeepAlive(t *testing.T) {
t.Parallel()
ln := fasthttputil.NewInmemoryListener()
ch := make(chan struct{})
go func() {
err := Serve(ln, func(ctx *RequestCtx) {
if string(ctx.Path()) == "/close" {
ctx.SetConnectionClose()
}
})
if err != nil {
t.Errorf("unexpected error: %s", err)
}
close(ch)
}()
conn, err := ln.Dial()
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
_, err = fmt.Fprintf(conn, "%s", "GET / HTTP/1.0\r\nHost: aaa\r\nConnection: keep-alive\r\n\r\n")
if err != nil {
t.Fatalf("error when writing request: %s", err)
}
_, err = fmt.Fprintf(conn, "%s", "GET /close HTTP/1.0\r\nHost: aaa\r\nConnection: keep-alive\r\n\r\n")
if err != nil {
t.Fatalf("error when writing request: %s", err)
}
br := bufio.NewReader(conn)
var resp Response
if err = resp.Read(br); err != nil {
t.Fatalf("error when reading response: %s", err)
}
if resp.ConnectionClose() {
t.Fatal("response mustn't have 'Connection: close' header")
}
if err = resp.Read(br); err != nil {
t.Fatalf("error when reading response: %s", err)
}
if !resp.ConnectionClose() {
t.Fatal("response must have 'Connection: close' header")
}
tailCh := make(chan struct{})
go func() {
tail, err := ioutil.ReadAll(br)
if err != nil {
t.Errorf("error when reading tail: %s", err)
}
if len(tail) > 0 {
t.Errorf("unexpected non-zero tail %q", tail)
}
close(tailCh)
}()
select {
case <-tailCh:
case <-time.After(time.Second):
t.Fatal("timeout when reading tail")
}
if err = conn.Close(); err != nil {
t.Fatalf("error when closing the connection: %s", err)
}
if err = ln.Close(); err != nil {
t.Fatalf("error when closing listener: %s", err)
}
select {
case <-ch:
case <-time.After(time.Second):
t.Fatal("timeout when waiting for the server to stop")
}
}
func TestServerHTTP10ConnectionClose(t *testing.T) {
t.Parallel()
ln := fasthttputil.NewInmemoryListener()
ch := make(chan struct{})
go func() {
err := Serve(ln, func(ctx *RequestCtx) {
// The server must close the connection regardless
// of request and response state set inside request
// handler, since the HTTP/1.0 request
// had no 'Connection: keep-alive' header.
ctx.Request.Header.ResetConnectionClose()
ctx.Request.Header.Set(HeaderConnection, "keep-alive")
ctx.Response.Header.ResetConnectionClose()
ctx.Response.Header.Set(HeaderConnection, "keep-alive")
})
if err != nil {
t.Errorf("unexpected error: %s", err)
}
close(ch)
}()
conn, err := ln.Dial()
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
_, err = fmt.Fprintf(conn, "%s", "GET / HTTP/1.0\r\nHost: aaa\r\n\r\n")
if err != nil {
t.Fatalf("error when writing request: %s", err)
}
br := bufio.NewReader(conn)
var resp Response
if err = resp.Read(br); err != nil {
t.Fatalf("error when reading response: %s", err)
}
if !resp.ConnectionClose() {
t.Fatal("HTTP1.0 response must have 'Connection: close' header")
}
tailCh := make(chan struct{})
go func() {
tail, err := ioutil.ReadAll(br)
if err != nil {
t.Errorf("error when reading tail: %s", err)
}
if len(tail) > 0 {
t.Errorf("unexpected non-zero tail %q", tail)
}
close(tailCh)
}()
select {
case <-tailCh:
case <-time.After(time.Second):
t.Fatal("timeout when reading tail")
}
if err = conn.Close(); err != nil {
t.Fatalf("error when closing the connection: %s", err)
}
if err = ln.Close(); err != nil {
t.Fatalf("error when closing listener: %s", err)
}
select {
case <-ch:
case <-time.After(time.Second):
t.Fatal("timeout when waiting for the server to stop")
}
}
func TestRequestCtxFormValue(t *testing.T) {
t.Parallel()
var ctx RequestCtx
var req Request
req.SetRequestURI("/foo/bar?baz=123&aaa=bbb")
req.SetBodyString("qqq=port&mmm=sddd")
req.Header.SetContentType("application/x-www-form-urlencoded")
ctx.Init(&req, nil, nil)
v := ctx.FormValue("baz")
if string(v) != "123" {
t.Fatalf("unexpected value %q. Expecting %q", v, "123")
}
v = ctx.FormValue("mmm")
if string(v) != "sddd" {
t.Fatalf("unexpected value %q. Expecting %q", v, "sddd")
}
v = ctx.FormValue("aaaasdfsdf")
if len(v) > 0 {
t.Fatalf("unexpected value for unknown key %q", v)
}
}
func TestRequestCtxUserValue(t *testing.T) {
t.Parallel()
var ctx RequestCtx
for i := 0; i < 5; i++ {
k := fmt.Sprintf("key-%d", i)
ctx.SetUserValue(k, i)
}
for i := 5; i < 10; i++ {
k := fmt.Sprintf("key-%d", i)
ctx.SetUserValueBytes([]byte(k), i)
}
for i := 0; i < 10; i++ {
k := fmt.Sprintf("key-%d", i)
v := ctx.UserValue(k)
n, ok := v.(int)
if !ok || n != i {
t.Fatalf("unexpected value obtained for key %q: %v. Expecting %d", k, v, i)
}
}
vlen := 0
ctx.VisitUserValues(func(key []byte, value interface{}) {
vlen++
v := ctx.UserValueBytes(key)
if v != value {
t.Fatalf("unexpected value obtained from VisitUserValues for key: %q, expecting: %#v but got: %#v", key, v, value)
}
})
if len(ctx.userValues) != vlen {
t.Fatalf("the length of user values returned from VisitUserValues is not equal to the length of the userValues, expecting: %d but got: %d", len(ctx.userValues), vlen)
}
ctx.ResetUserValues()
for i := 0; i < 10; i++ {
k := fmt.Sprintf("key-%d", i)
v := ctx.UserValue(k)
if v != nil {
t.Fatalf("unexpected value obtained for key %q: %v. Expecting nil", k, v)
}
}
}
func TestServerHeadRequest(t *testing.T) {
t.Parallel()
s := &Server{
Handler: func(ctx *RequestCtx) {
fmt.Fprintf(ctx, "Request method is %q", ctx.Method())
ctx.SetContentType("aaa/bbb")
},
}
rw := &readWriter{}
rw.r.WriteString("HEAD /foobar HTTP/1.1\r\nHost: aaa.com\r\n\r\n")
if err := s.ServeConn(rw); err != nil {
t.Fatalf("Unexpected error from serveConn: %s", err)
}
br := bufio.NewReader(&rw.w)
var resp Response
resp.SkipBody = true
if err := resp.Read(br); err != nil {
t.Fatalf("Unexpected error when parsing response: %s", err)
}
if resp.Header.StatusCode() != StatusOK {
t.Fatalf("unexpected status code: %d. Expecting %d", resp.Header.StatusCode(), StatusOK)
}
if len(resp.Body()) > 0 {
t.Fatalf("Unexpected non-zero body %q", resp.Body())
}
if resp.Header.ContentLength() != 24 {
t.Fatalf("unexpected content-length %d. Expecting %d", resp.Header.ContentLength(), 24)
}
if string(resp.Header.ContentType()) != "aaa/bbb" {
t.Fatalf("unexpected content-type %q. Expecting %q", resp.Header.ContentType(), "aaa/bbb")
}
data, err := ioutil.ReadAll(br)
if err != nil {
t.Fatalf("Unexpected error when reading remaining data: %s", err)
}
if len(data) > 0 {
t.Fatalf("unexpected remaining data %q", data)
}
}
func TestServerExpect100Continue(t *testing.T) {
t.Parallel()
s := &Server{
Handler: func(ctx *RequestCtx) {
if !ctx.IsPost() {
t.Errorf("unexpected method %q. Expecting POST", ctx.Method())
}
if string(ctx.Path()) != "/foo" {
t.Errorf("unexpected path %q. Expecting %q", ctx.Path(), "/foo")
}
ct := ctx.Request.Header.ContentType()
if string(ct) != "a/b" {
t.Errorf("unexpectected content-type: %q. Expecting %q", ct, "a/b")
}
if string(ctx.PostBody()) != "12345" {
t.Errorf("unexpected body: %q. Expecting %q", ctx.PostBody(), "12345")
}
ctx.WriteString("foobar") //nolint:errcheck
},
}
rw := &readWriter{}
rw.r.WriteString("POST /foo HTTP/1.1\r\nHost: gle.com\r\nExpect: 100-continue\r\nContent-Length: 5\r\nContent-Type: a/b\r\n\r\n12345")
if err := s.ServeConn(rw); err != nil {
t.Fatalf("Unexpected error from serveConn: %s", err)
}
br := bufio.NewReader(&rw.w)
verifyResponse(t, br, StatusOK, string(defaultContentType), "foobar")
data, err := ioutil.ReadAll(br)
if err != nil {
t.Fatalf("Unexpected error when reading remaining data: %s", err)
}
if len(data) > 0 {
t.Fatalf("unexpected remaining data %q", data)
}
}
func TestServerContinueHandler(t *testing.T) {
t.Parallel()
acceptContentLength := 5
s := &Server{
ContinueHandler: func(headers *RequestHeader) bool {
if !headers.IsPost() {
t.Errorf("unexpected method %q. Expecting POST", headers.Method())
}
ct := headers.ContentType()
if string(ct) != "a/b" {
t.Errorf("unexpectected content-type: %q. Expecting %q", ct, "a/b")
}
// Reject any request whose content length isn't the accepted length.
return headers.contentLength == acceptContentLength
},
Handler: func(ctx *RequestCtx) {
if ctx.Request.Header.contentLength != acceptContentLength {
t.Errorf("all requests with content-length: other than %d, should be denied", acceptContentLength)
}
if !ctx.IsPost() {
t.Errorf("unexpected method %q. Expecting POST", ctx.Method())
}
if string(ctx.Path()) != "/foo" {
t.Errorf("unexpected path %q. Expecting %q", ctx.Path(), "/foo")
}
ct := ctx.Request.Header.ContentType()
if string(ct) != "a/b" {
t.Errorf("unexpectected content-type: %q. Expecting %q", ct, "a/b")
}
if string(ctx.PostBody()) != "12345" {
t.Errorf("unexpected body: %q. Expecting %q", ctx.PostBody(), "12345")
}
ctx.WriteString("foobar") //nolint:errcheck
},
}
sendRequest := func(rw *readWriter, expectedStatusCode int, expectedResponse string) {
if err := s.ServeConn(rw); err != nil {
t.Fatalf("Unexpected error from serveConn: %s", err)
}
br := bufio.NewReader(&rw.w)
verifyResponse(t, br, expectedStatusCode, string(defaultContentType), expectedResponse)
data, err := ioutil.ReadAll(br)
if err != nil {
t.Fatalf("Unexpected error when reading remaining data: %s", err)
}
if len(data) > 0 {
t.Fatalf("unexpected remaining data %q", data)
}
}
// The same server should not fail when handling the three different types of requests
// Regular requests
// Expect 100 continue accepted
// Expect 100 continue denied
rw := &readWriter{}
for i := 0; i < 25; i++ {
// Regular requests without Expect 100 continue header
rw.r.Reset()
rw.r.WriteString("POST /foo HTTP/1.1\r\nHost: gle.com\r\nContent-Length: 5\r\nContent-Type: a/b\r\n\r\n12345")
sendRequest(rw, StatusOK, "foobar")
// Regular Expect 100 continue requests that are accepted
rw.r.Reset()
rw.r.WriteString("POST /foo HTTP/1.1\r\nHost: gle.com\r\nExpect: 100-continue\r\nContent-Length: 5\r\nContent-Type: a/b\r\n\r\n12345")
sendRequest(rw, StatusOK, "foobar")
// Requests being denied
rw.r.Reset()
rw.r.WriteString("POST /foo HTTP/1.1\r\nHost: gle.com\r\nExpect: 100-continue\r\nContent-Length: 6\r\nContent-Type: a/b\r\n\r\n123456")
sendRequest(rw, StatusExpectationFailed, "")
}
}
func TestCompressHandler(t *testing.T) {
t.Parallel()
expectedBody := string(createFixedBody(2e4))
h := CompressHandler(func(ctx *RequestCtx) {
ctx.Write([]byte(expectedBody)) //nolint:errcheck
})
var ctx RequestCtx
var resp Response
// verify uncompressed response
h(&ctx)
s := ctx.Response.String()
br := bufio.NewReader(bytes.NewBufferString(s))
if err := resp.Read(br); err != nil {
t.Fatalf("unexpected error: %s", err)
}
ce := resp.Header.Peek(HeaderContentEncoding)
if string(ce) != "" {
t.Fatalf("unexpected Content-Encoding: %q. Expecting %q", ce, "")
}
body := resp.Body()
if string(body) != expectedBody {
t.Fatalf("unexpected body %q. Expecting %q", body, expectedBody)
}
// verify gzip-compressed response
ctx.Request.Reset()
ctx.Response.Reset()
ctx.Request.Header.Set("Accept-Encoding", "gzip, deflate, sdhc")
h(&ctx)
s = ctx.Response.String()
br = bufio.NewReader(bytes.NewBufferString(s))
if err := resp.Read(br); err != nil {
t.Fatalf("unexpected error: %s", err)
}
ce = resp.Header.Peek(HeaderContentEncoding)
if string(ce) != "gzip" {
t.Fatalf("unexpected Content-Encoding: %q. Expecting %q", ce, "gzip")
}
body, err := resp.BodyGunzip()
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
if string(body) != expectedBody {
t.Fatalf("unexpected body %q. Expecting %q", body, expectedBody)
}
// an attempt to compress already compressed response
ctx.Request.Reset()
ctx.Response.Reset()
ctx.Request.Header.Set("Accept-Encoding", "gzip, deflate, sdhc")
hh := CompressHandler(h)
hh(&ctx)
s = ctx.Response.String()
br = bufio.NewReader(bytes.NewBufferString(s))
if err := resp.Read(br); err != nil {
t.Fatalf("unexpected error: %s", err)
}
ce = resp.Header.Peek(HeaderContentEncoding)
if string(ce) != "gzip" {
t.Fatalf("unexpected Content-Encoding: %q. Expecting %q", ce, "gzip")
}
body, err = resp.BodyGunzip()
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
if string(body) != expectedBody {
t.Fatalf("unexpected body %q. Expecting %q", body, expectedBody)
}
// verify deflate-compressed response
ctx.Request.Reset()
ctx.Response.Reset()
ctx.Request.Header.Set(HeaderAcceptEncoding, "foobar, deflate, sdhc")
h(&ctx)
s = ctx.Response.String()
br = bufio.NewReader(bytes.NewBufferString(s))
if err := resp.Read(br); err != nil {
t.Fatalf("unexpected error: %s", err)
}
ce = resp.Header.Peek(HeaderContentEncoding)
if string(ce) != "deflate" {
t.Fatalf("unexpected Content-Encoding: %q. Expecting %q", ce, "deflate")
}
body, err = resp.BodyInflate()
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
if string(body) != expectedBody {
t.Fatalf("unexpected body %q. Expecting %q", body, expectedBody)
}
}
func TestRequestCtxWriteString(t *testing.T) {
t.Parallel()
var ctx RequestCtx
n, err := ctx.WriteString("foo")
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
if n != 3 {
t.Fatalf("unexpected n %d. Expecting 3", n)
}
n, err = ctx.WriteString("привет")
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
if n != 12 {
t.Fatalf("unexpected n=%d. Expecting 12", n)
}
s := ctx.Response.Body()
if string(s) != "fooпривет" {
t.Fatalf("unexpected response body %q. Expecting %q", s, "fooпривет")
}
}
func TestServeConnNonHTTP11KeepAlive(t *testing.T) {
t.Parallel()
rw := &readWriter{}
rw.r.WriteString("GET /foo HTTP/1.0\r\nConnection: keep-alive\r\nHost: google.com\r\n\r\n")
rw.r.WriteString("GET /bar HTTP/1.0\r\nHost: google.com\r\n\r\n")
rw.r.WriteString("GET /must/be/ignored HTTP/1.0\r\nHost: google.com\r\n\r\n")
requestsServed := 0
ch := make(chan struct{})
go func() {
err := ServeConn(rw, func(ctx *RequestCtx) {
requestsServed++
ctx.SuccessString("aaa/bbb", "foobar")
})
if err != nil {
t.Errorf("unexpected error in ServeConn: %s", err)
}
close(ch)
}()
select {
case <-ch:
case <-time.After(time.Second):
t.Fatal("timeout")
}
br := bufio.NewReader(&rw.w)
var resp Response
// verify the first response
if err := resp.Read(br); err != nil {
t.Fatalf("Unexpected error when parsing response: %s", err)
}
if string(resp.Header.Peek(HeaderConnection)) != "keep-alive" {
t.Fatalf("unexpected Connection header %q. Expecting %q", resp.Header.Peek(HeaderConnection), "keep-alive")
}
if resp.Header.ConnectionClose() {
t.Fatal("unexpected Connection: close")
}
// verify the second response
if err := resp.Read(br); err != nil {
t.Fatalf("Unexpected error when parsing response: %s", err)
}
if string(resp.Header.Peek(HeaderConnection)) != "close" {
t.Fatalf("unexpected Connection header %q. Expecting %q", resp.Header.Peek(HeaderConnection), "close")
}
if !resp.Header.ConnectionClose() {
t.Fatal("expecting Connection: close")
}
data, err := ioutil.ReadAll(br)
if err != nil {
t.Fatalf("Unexpected error when reading remaining data: %s", err)
}
if len(data) != 0 {
t.Fatalf("Unexpected data read after responses %q", data)
}
if requestsServed != 2 {
t.Fatalf("unexpected number of requests served: %d. Expecting 2", requestsServed)
}
}
func TestRequestCtxSetBodyStreamWriter(t *testing.T) {
t.Parallel()
var ctx RequestCtx
var req Request
ctx.Init(&req, nil, defaultLogger)
if ctx.IsBodyStream() {
t.Fatal("IsBodyStream must return false")
}
ctx.SetBodyStreamWriter(func(w *bufio.Writer) {
fmt.Fprintf(w, "body writer line 1\n")
if err := w.Flush(); err != nil {
t.Errorf("unexpected error: %s", err)
}
fmt.Fprintf(w, "body writer line 2\n")
})
if !ctx.IsBodyStream() {
t.Fatal("IsBodyStream must return true")
}
s := ctx.Response.String()
br := bufio.NewReader(bytes.NewBufferString(s))
var resp Response
if err := resp.Read(br); err != nil {
t.Fatalf("Error when reading response: %s", err)
}
body := string(resp.Body())
expectedBody := "body writer line 1\nbody writer line 2\n"
if body != expectedBody {
t.Fatalf("unexpected body: %q. Expecting %q", body, expectedBody)
}
}
func TestRequestCtxIfModifiedSince(t *testing.T) {
t.Parallel()
var ctx RequestCtx
var req Request
ctx.Init(&req, nil, defaultLogger)
lastModified := time.Now().Add(-time.Hour)
if !ctx.IfModifiedSince(lastModified) {
t.Fatal("IfModifiedSince must return true for non-existing If-Modified-Since header")
}
ctx.Request.Header.Set("If-Modified-Since", string(AppendHTTPDate(nil, lastModified)))
if ctx.IfModifiedSince(lastModified) {
t.Fatal("If-Modified-Since current time must return false")
}
past := lastModified.Add(-time.Hour)
if ctx.IfModifiedSince(past) {
t.Fatal("If-Modified-Since past time must return false")
}
future := lastModified.Add(time.Hour)
if !ctx.IfModifiedSince(future) {
t.Fatal("If-Modified-Since future time must return true")
}
}
func TestRequestCtxSendFileNotModified(t *testing.T) {
t.Parallel()
var ctx RequestCtx
var req Request
ctx.Init(&req, nil, defaultLogger)
filePath := "./server_test.go"
lastModified, err := FileLastModified(filePath)
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
ctx.Request.Header.Set("If-Modified-Since", string(AppendHTTPDate(nil, lastModified)))
ctx.SendFile(filePath)
s := ctx.Response.String()
var resp Response
br := bufio.NewReader(bytes.NewBufferString(s))
if err := resp.Read(br); err != nil {
t.Fatalf("error when reading response: %s", err)
}
if resp.StatusCode() != StatusNotModified {
t.Fatalf("unexpected status code: %d. Expecting %d", resp.StatusCode(), StatusNotModified)
}
if len(resp.Body()) > 0 {
t.Fatalf("unexpected non-zero response body: %q", resp.Body())
}
}
func TestRequestCtxSendFileModified(t *testing.T) {
t.Parallel()
var ctx RequestCtx
var req Request
ctx.Init(&req, nil, defaultLogger)
filePath := "./server_test.go"
lastModified, err := FileLastModified(filePath)
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
lastModified = lastModified.Add(-time.Hour)
ctx.Request.Header.Set("If-Modified-Since", string(AppendHTTPDate(nil, lastModified)))
ctx.SendFile(filePath)
s := ctx.Response.String()
var resp Response
br := bufio.NewReader(bytes.NewBufferString(s))
if err := resp.Read(br); err != nil {
t.Fatalf("error when reading response: %s", err)
}
if resp.StatusCode() != StatusOK {
t.Fatalf("unexpected status code: %d. Expecting %d", resp.StatusCode(), StatusOK)
}
f, err := os.Open(filePath)
if err != nil {
t.Fatalf("cannot open file: %s", err)
}
body, err := ioutil.ReadAll(f)
f.Close()
if err != nil {
t.Fatalf("error when reading file: %s", err)
}
if !bytes.Equal(resp.Body(), body) {
t.Fatalf("unexpected response body: %q. Expecting %q", resp.Body(), body)
}
}
func TestRequestCtxSendFile(t *testing.T) {
t.Parallel()
var ctx RequestCtx
var req Request
ctx.Init(&req, nil, defaultLogger)
filePath := "./server_test.go"
ctx.SendFile(filePath)
w := &bytes.Buffer{}
bw := bufio.NewWriter(w)
if err := ctx.Response.Write(bw); err != nil {
t.Fatalf("error when writing response: %s", err)
}
if err := bw.Flush(); err != nil {
t.Fatalf("error when flushing response: %s", err)
}
var resp Response
br := bufio.NewReader(w)
if err := resp.Read(br); err != nil {
t.Fatalf("error when reading response: %s", err)
}
if resp.StatusCode() != StatusOK {
t.Fatalf("unexpected status code: %d. Expecting %d", resp.StatusCode(), StatusOK)
}
f, err := os.Open(filePath)
if err != nil {
t.Fatalf("cannot open file: %s", err)
}
body, err := ioutil.ReadAll(f)
f.Close()
if err != nil {
t.Fatalf("error when reading file: %s", err)
}
if !bytes.Equal(resp.Body(), body) {
t.Fatalf("unexpected response body: %q. Expecting %q", resp.Body(), body)
}
}
func TestRequestCtxHijack(t *testing.T) {
t.Parallel()
hijackStartCh := make(chan struct{})
hijackStopCh := make(chan struct{})
s := &Server{
Handler: func(ctx *RequestCtx) {
if ctx.Hijacked() {
t.Error("connection mustn't be hijacked")
}
ctx.Hijack(func(c net.Conn) {
<-hijackStartCh
b := make([]byte, 1)
// ping-pong echo via hijacked conn
for {
n, err := c.Read(b)
if n != 1 {
if err == io.EOF {
close(hijackStopCh)
return
}
if err != nil {
t.Errorf("unexpected error: %s", err)
}
t.Errorf("unexpected number of bytes read: %d. Expecting 1", n)
}
if _, err = c.Write(b); err != nil {
t.Errorf("unexpected error when writing data: %s", err)
}
}
})
if !ctx.Hijacked() {
t.Error("connection must be hijacked")
}
ctx.Success("foo/bar", []byte("hijack it!"))
},
}
hijackedString := "foobar baz hijacked!!!"
rw := &readWriter{}
rw.r.WriteString("GET /foo HTTP/1.1\r\nHost: google.com\r\n\r\n")
rw.r.WriteString(hijackedString)
if err := s.ServeConn(rw); err != nil {
t.Fatalf("Unexpected error from serveConn: %s", err)
}
br := bufio.NewReader(&rw.w)
verifyResponse(t, br, StatusOK, "foo/bar", "hijack it!")
close(hijackStartCh)
select {
case <-hijackStopCh:
case <-time.After(100 * time.Millisecond):
t.Fatal("timeout")
}
data, err := ioutil.ReadAll(br)
if err != nil {
t.Fatalf("Unexpected error when reading remaining data: %s", err)
}
if string(data) != hijackedString {
t.Fatalf("Unexpected data read after the first response %q. Expecting %q", data, hijackedString)
}
}
func TestRequestCtxHijackNoResponse(t *testing.T) {
t.Parallel()
hijackDone := make(chan error)
s := &Server{
Handler: func(ctx *RequestCtx) {
ctx.Hijack(func(c net.Conn) {
_, err := c.Write([]byte("test"))
hijackDone <- err
})
ctx.HijackSetNoResponse(true)
},
}
rw := &readWriter{}
rw.r.WriteString("GET /foo HTTP/1.1\r\nHost: google.com\r\nContent-Length: 0\r\n\r\n")
if err := s.ServeConn(rw); err != nil {
t.Fatalf("Unexpected error from serveConn: %s", err)
}
select {
case err := <-hijackDone:
if err != nil {
t.Fatalf("Unexpected error from hijack: %s", err)
}
case <-time.After(100 * time.Millisecond):
t.Fatal("timeout")
}
if got := rw.w.String(); got != "test" {
t.Errorf(`expected "test", got %q`, got)
}
}
func TestRequestCtxNoHijackNoResponse(t *testing.T) {
t.Parallel()
s := &Server{
Handler: func(ctx *RequestCtx) {
io.WriteString(ctx, "test") //nolint:errcheck
ctx.HijackSetNoResponse(true)
},
}
rw := &readWriter{}
rw.r.WriteString("GET /foo HTTP/1.1\r\nHost: google.com\r\nContent-Length: 0\r\n\r\n")
if err := s.ServeConn(rw); err != nil {
t.Fatalf("Unexpected error from serveConn: %s", err)
}
bf := bufio.NewReader(
strings.NewReader(rw.w.String()),
)
resp := AcquireResponse()
resp.Read(bf) //nolint:errcheck
if got := string(resp.Body()); got != "test" {
t.Errorf(`expected "test", got %q`, got)
}
}
func TestRequestCtxInit(t *testing.T) {
// This test can't run parallel as it modifies globalConnID.
var ctx RequestCtx
var logger testLogger
globalConnID = 0x123456
ctx.Init(&ctx.Request, zeroTCPAddr, &logger)
ip := ctx.RemoteIP()
if !ip.IsUnspecified() {
t.Fatalf("unexpected ip for bare RequestCtx: %q. Expected 0.0.0.0", ip)
}
ctx.Logger().Printf("foo bar %d", 10)
expectedLog := "#0012345700000000 - 0.0.0.0:0<->0.0.0.0:0 - GET http:/// - foo bar 10\n"
if logger.out != expectedLog {
t.Fatalf("Unexpected log output: %q. Expected %q", logger.out, expectedLog)
}
}
func TestTimeoutHandlerSuccess(t *testing.T) {
t.Parallel()
ln := fasthttputil.NewInmemoryListener()
h := func(ctx *RequestCtx) {
if string(ctx.Path()) == "/" {
ctx.Success("aaa/bbb", []byte("real response"))
}
}
s := &Server{
Handler: TimeoutHandler(h, 10*time.Second, "timeout!!!"),
}
serverCh := make(chan struct{})
go func() {
if err := s.Serve(ln); err != nil {
t.Errorf("unexepcted error: %s", err)
}
close(serverCh)
}()
concurrency := 20
clientCh := make(chan struct{}, concurrency)
for i := 0; i < concurrency; i++ {
go func() {
conn, err := ln.Dial()
if err != nil {
t.Errorf("unexepcted error: %s", err)
}
if _, err = conn.Write([]byte("GET / HTTP/1.1\r\nHost: google.com\r\n\r\n")); err != nil {
t.Errorf("unexpected error: %s", err)
}
br := bufio.NewReader(conn)
verifyResponse(t, br, StatusOK, "aaa/bbb", "real response")
clientCh <- struct{}{}
}()
}
for i := 0; i < concurrency; i++ {
select {
case <-clientCh:
case <-time.After(time.Second):
t.Fatal("timeout")
}
}
if err := ln.Close(); err != nil {
t.Fatalf("unexpected error: %s", err)
}
select {
case <-serverCh:
case <-time.After(time.Second):
t.Fatal("timeout")
}
}
func TestTimeoutHandlerTimeout(t *testing.T) {
t.Parallel()
ln := fasthttputil.NewInmemoryListener()
readyCh := make(chan struct{})
doneCh := make(chan struct{})
h := func(ctx *RequestCtx) {
ctx.Success("aaa/bbb", []byte("real response"))
<-readyCh
doneCh <- struct{}{}
}
s := &Server{
Handler: TimeoutHandler(h, 20*time.Millisecond, "timeout!!!"),
}
serverCh := make(chan struct{})
go func() {
if err := s.Serve(ln); err != nil {
t.Errorf("unexepcted error: %s", err)
}
close(serverCh)
}()
concurrency := 20
clientCh := make(chan struct{}, concurrency)
for i := 0; i < concurrency; i++ {
go func() {
conn, err := ln.Dial()
if err != nil {
t.Errorf("unexpected error: %s", err)
}
if _, err = conn.Write([]byte("GET / HTTP/1.1\r\nHost: google.com\r\n\r\n")); err != nil {
t.Errorf("unexpected error: %s", err)
}
br := bufio.NewReader(conn)
verifyResponse(t, br, StatusRequestTimeout, string(defaultContentType), "timeout!!!")
clientCh <- struct{}{}
}()
}
for i := 0; i < concurrency; i++ {
select {
case <-clientCh:
case <-time.After(time.Second):
t.Fatal("timeout")
}
}
close(readyCh)
for i := 0; i < concurrency; i++ {
select {
case <-doneCh:
case <-time.After(time.Second):
t.Fatal("timeout")
}
}
if err := ln.Close(); err != nil {
t.Fatalf("unexpected error: %s", err)
}
select {
case <-serverCh:
case <-time.After(time.Second):
t.Fatal("timeout")
}
}
func TestTimeoutHandlerTimeoutReuse(t *testing.T) {
t.Parallel()
ln := fasthttputil.NewInmemoryListener()
h := func(ctx *RequestCtx) {
if string(ctx.Path()) == "/timeout" {
time.Sleep(time.Second)
}
ctx.SetBodyString("ok")
}
s := &Server{
Handler: TimeoutHandler(h, 500*time.Millisecond, "timeout!!!"),
}
go func() {
if err := s.Serve(ln); err != nil {
t.Errorf("unexepcted error: %s", err)
}
}()
conn, err := ln.Dial()
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
br := bufio.NewReader(conn)
if _, err = conn.Write([]byte("GET /timeout HTTP/1.1\r\nHost: google.com\r\n\r\n")); err != nil {
t.Fatalf("unexpected error: %s", err)
}
verifyResponse(t, br, StatusRequestTimeout, string(defaultContentType), "timeout!!!")
if _, err = conn.Write([]byte("GET / HTTP/1.1\r\nHost: google.com\r\n\r\n")); err != nil {
t.Fatalf("unexpected error: %s", err)
}
verifyResponse(t, br, StatusOK, string(defaultContentType), "ok")
if err := ln.Close(); err != nil {
t.Fatalf("unexpected error: %s", err)
}
}
func TestServerGetOnly(t *testing.T) {
t.Parallel()
h := func(ctx *RequestCtx) {
if !ctx.IsGet() {
t.Errorf("non-get request: %q", ctx.Method())
}
ctx.Success("foo/bar", []byte("success"))
}
s := &Server{
Handler: h,
GetOnly: true,
}
rw := &readWriter{}
rw.r.WriteString("POST /foo HTTP/1.1\r\nHost: google.com\r\nContent-Length: 5\r\nContent-Type: aaa\r\n\r\n12345")
ch := make(chan error)
go func() {
ch <- s.ServeConn(rw)
}()
select {
case err := <-ch:
if err == nil {
t.Fatal("expecting error")
}
if err != ErrGetOnly {
t.Fatalf("Unexpected error from serveConn: %s. Expecting %s", err, ErrGetOnly)
}
case <-time.After(100 * time.Millisecond):
t.Fatal("timeout")
}
br := bufio.NewReader(&rw.w)
var resp Response
if err := resp.Read(br); err != nil {
t.Fatalf("unexpected error: %s", err)
}
statusCode := resp.StatusCode()
if statusCode != StatusBadRequest {
t.Fatalf("unexpected status code: %d. Expecting %d", statusCode, StatusBadRequest)
}
if !resp.ConnectionClose() {
t.Fatal("missing 'Connection: close' response header")
}
}
func TestServerTimeoutErrorWithResponse(t *testing.T) {
t.Parallel()
s := &Server{
Handler: func(ctx *RequestCtx) {
go func() {
ctx.Success("aaa/bbb", []byte("xxxyyy"))
}()
var resp Response
resp.SetStatusCode(123)
resp.SetBodyString("foobar. Should be ignored")
ctx.TimeoutErrorWithResponse(&resp)
resp.SetStatusCode(456)
resp.ResetBody()
fmt.Fprintf(resp.BodyWriter(), "path=%s", ctx.Path())
resp.Header.SetContentType("foo/bar")
ctx.TimeoutErrorWithResponse(&resp)
},
}
rw := &readWriter{}
rw.r.WriteString("GET /foo HTTP/1.1\r\nHost: google.com\r\n\r\n")
rw.r.WriteString("GET /bar HTTP/1.1\r\nHost: google.com\r\n\r\n")
if err := s.ServeConn(rw); err != nil {
t.Fatalf("Unexpected error from serveConn: %s", err)
}
br := bufio.NewReader(&rw.w)
verifyResponse(t, br, 456, "foo/bar", "path=/foo")
verifyResponse(t, br, 456, "foo/bar", "path=/bar")
data, err := ioutil.ReadAll(br)
if err != nil {
t.Fatalf("Unexpected error when reading remaining data: %s", err)
}
if len(data) != 0 {
t.Fatalf("Unexpected data read after the first response %q. Expecting %q", data, "")
}
}
func TestServerTimeoutErrorWithCode(t *testing.T) {
t.Parallel()
s := &Server{
Handler: func(ctx *RequestCtx) {
go func() {
ctx.Success("aaa/bbb", []byte("xxxyyy"))
}()
ctx.TimeoutErrorWithCode("should be ignored", 234)
ctx.TimeoutErrorWithCode("stolen ctx", StatusBadRequest)
},
}
rw := &readWriter{}
rw.r.WriteString("GET /foo HTTP/1.1\r\nHost: google.com\r\n\r\n")
rw.r.WriteString("GET /foo HTTP/1.1\r\nHost: google.com\r\n\r\n")
if err := s.ServeConn(rw); err != nil {
t.Fatalf("Unexpected error from serveConn: %s", err)
}
br := bufio.NewReader(&rw.w)
verifyResponse(t, br, StatusBadRequest, string(defaultContentType), "stolen ctx")
verifyResponse(t, br, StatusBadRequest, string(defaultContentType), "stolen ctx")
data, err := ioutil.ReadAll(br)
if err != nil {
t.Fatalf("Unexpected error when reading remaining data: %s", err)
}
if len(data) != 0 {
t.Fatalf("Unexpected data read after the first response %q. Expecting %q", data, "")
}
}
func TestServerTimeoutError(t *testing.T) {
t.Parallel()
s := &Server{
Handler: func(ctx *RequestCtx) {
go func() {
ctx.Success("aaa/bbb", []byte("xxxyyy"))
}()
ctx.TimeoutError("should be ignored")
ctx.TimeoutError("stolen ctx")
},
}
rw := &readWriter{}
rw.r.WriteString("GET /foo HTTP/1.1\r\nHost: google.com\r\n\r\n")
rw.r.WriteString("GET /foo HTTP/1.1\r\nHost: google.com\r\n\r\n")
if err := s.ServeConn(rw); err != nil {
t.Fatalf("Unexpected error from serveConn: %s", err)
}
br := bufio.NewReader(&rw.w)
verifyResponse(t, br, StatusRequestTimeout, string(defaultContentType), "stolen ctx")
verifyResponse(t, br, StatusRequestTimeout, string(defaultContentType), "stolen ctx")
data, err := ioutil.ReadAll(br)
if err != nil {
t.Fatalf("Unexpected error when reading remaining data: %s", err)
}
if len(data) != 0 {
t.Fatalf("Unexpected data read after the first response %q. Expecting %q", data, "")
}
}
func TestServerMaxRequestsPerConn(t *testing.T) {
t.Parallel()
s := &Server{
Handler: func(ctx *RequestCtx) {},
MaxRequestsPerConn: 1,
}
rw := &readWriter{}
rw.r.WriteString("GET /foo1 HTTP/1.1\r\nHost: google.com\r\n\r\n")
rw.r.WriteString("GET /bar HTTP/1.1\r\nHost: aaa.com\r\n\r\n")
if err := s.ServeConn(rw); err != nil {
t.Fatalf("Unexpected error from serveConn: %s", err)
}
br := bufio.NewReader(&rw.w)
var resp Response
if err := resp.Read(br); err != nil {
t.Fatalf("Unexpected error when parsing response: %s", err)
}
if !resp.ConnectionClose() {
t.Fatal("Response must have 'connection: close' header")
}
verifyResponseHeader(t, &resp.Header, 200, 0, string(defaultContentType))
data, err := ioutil.ReadAll(br)
if err != nil {
t.Fatalf("Unexpected error when reading remaining data: %s", err)
}
if len(data) != 0 {
t.Fatalf("Unexpected data read after the first response %q. Expecting %q", data, "")
}
}
func TestServerConnectionClose(t *testing.T) {
t.Parallel()
s := &Server{
Handler: func(ctx *RequestCtx) {
ctx.SetConnectionClose()
},
}
rw := &readWriter{}
rw.r.WriteString("GET /foo1 HTTP/1.1\r\nHost: google.com\r\n\r\n")
rw.r.WriteString("GET /must/be/ignored HTTP/1.1\r\nHost: aaa.com\r\n\r\n")
if err := s.ServeConn(rw); err != nil {
t.Fatalf("Unexpected error from serveConn: %s", err)
}
br := bufio.NewReader(&rw.w)
var resp Response
if err := resp.Read(br); err != nil {
t.Fatalf("Unexpected error when parsing response: %s", err)
}
if !resp.ConnectionClose() {
t.Fatal("expecting Connection: close header")
}
data, err := ioutil.ReadAll(br)
if err != nil {
t.Fatalf("Unexpected error when reading remaining data: %s", err)
}
if len(data) != 0 {
t.Fatalf("Unexpected data read after the first response %q. Expecting %q", data, "")
}
}
func TestServerRequestNumAndTime(t *testing.T) {
t.Parallel()
n := uint64(0)
var connT time.Time
s := &Server{
Handler: func(ctx *RequestCtx) {
n++
if ctx.ConnRequestNum() != n {
t.Errorf("unexpected request number: %d. Expecting %d", ctx.ConnRequestNum(), n)
}
if connT.IsZero() {
connT = ctx.ConnTime()
}
if ctx.ConnTime() != connT {
t.Errorf("unexpected serve conn time: %s. Expecting %s", ctx.ConnTime(), connT)
}
},
}
rw := &readWriter{}
rw.r.WriteString("GET /foo1 HTTP/1.1\r\nHost: google.com\r\n\r\n")
rw.r.WriteString("GET /bar HTTP/1.1\r\nHost: google.com\r\n\r\n")
rw.r.WriteString("GET /baz HTTP/1.1\r\nHost: google.com\r\n\r\n")
if err := s.ServeConn(rw); err != nil {
t.Fatalf("Unexpected error from serveConn: %s", err)
}
if n != 3 {
t.Fatalf("unexpected number of requests served: %d. Expecting %d", n, 3)
}
br := bufio.NewReader(&rw.w)
verifyResponse(t, br, 200, string(defaultContentType), "")
}
func TestServerEmptyResponse(t *testing.T) {
t.Parallel()
s := &Server{
Handler: func(ctx *RequestCtx) {
// do nothing :)
},
}
rw := &readWriter{}
rw.r.WriteString("GET /foo1 HTTP/1.1\r\nHost: google.com\r\n\r\n")
if err := s.ServeConn(rw); err != nil {
t.Fatalf("Unexpected error from serveConn: %s", err)
}
br := bufio.NewReader(&rw.w)
verifyResponse(t, br, 200, string(defaultContentType), "")
}
func TestServerLogger(t *testing.T) {
// This test can't run parallel as it modifies globalConnID.
cl := &testLogger{}
s := &Server{
Handler: func(ctx *RequestCtx) {
logger := ctx.Logger()
h := &ctx.Request.Header
logger.Printf("begin")
ctx.Success("text/html", []byte(fmt.Sprintf("requestURI=%s, body=%q, remoteAddr=%s",
h.RequestURI(), ctx.Request.Body(), ctx.RemoteAddr())))
logger.Printf("end")
},
Logger: cl,
}
rw := &readWriter{}
rw.r.WriteString("GET /foo1 HTTP/1.1\r\nHost: google.com\r\n\r\n")
rw.r.WriteString("POST /foo2 HTTP/1.1\r\nHost: aaa.com\r\nContent-Length: 5\r\nContent-Type: aa\r\n\r\nabcde")
rwx := &readWriterRemoteAddr{
rw: rw,
addr: &net.TCPAddr{
IP: []byte{1, 2, 3, 4},
Port: 8765,
},
}
globalConnID = 0
if err := s.ServeConn(rwx); err != nil {
t.Fatalf("Unexpected error from serveConn: %s", err)
}
br := bufio.NewReader(&rw.w)
verifyResponse(t, br, 200, "text/html", "requestURI=/foo1, body=\"\", remoteAddr=1.2.3.4:8765")
verifyResponse(t, br, 200, "text/html", "requestURI=/foo2, body=\"abcde\", remoteAddr=1.2.3.4:8765")
expectedLogOut := `#0000000100000001 - 1.2.3.4:8765<->1.2.3.4:8765 - GET http://google.com/foo1 - begin
#0000000100000001 - 1.2.3.4:8765<->1.2.3.4:8765 - GET http://google.com/foo1 - end
#0000000100000002 - 1.2.3.4:8765<->1.2.3.4:8765 - POST http://aaa.com/foo2 - begin
#0000000100000002 - 1.2.3.4:8765<->1.2.3.4:8765 - POST http://aaa.com/foo2 - end
`
if cl.out != expectedLogOut {
t.Fatalf("Unexpected logger output: %q. Expected %q", cl.out, expectedLogOut)
}
}
func TestServerRemoteAddr(t *testing.T) {
t.Parallel()
s := &Server{
Handler: func(ctx *RequestCtx) {
h := &ctx.Request.Header
ctx.Success("text/html", []byte(fmt.Sprintf("requestURI=%s, remoteAddr=%s, remoteIP=%s",
h.RequestURI(), ctx.RemoteAddr(), ctx.RemoteIP())))
},
}
rw := &readWriter{}
rw.r.WriteString("GET /foo1 HTTP/1.1\r\nHost: google.com\r\n\r\n")
rwx := &readWriterRemoteAddr{
rw: rw,
addr: &net.TCPAddr{
IP: []byte{1, 2, 3, 4},
Port: 8765,
},
}
if err := s.ServeConn(rwx); err != nil {
t.Fatalf("Unexpected error from serveConn: %s", err)
}
br := bufio.NewReader(&rw.w)
verifyResponse(t, br, 200, "text/html", "requestURI=/foo1, remoteAddr=1.2.3.4:8765, remoteIP=1.2.3.4")
}
func TestServerCustomRemoteAddr(t *testing.T) {
t.Parallel()
customRemoteAddrHandler := func(h RequestHandler) RequestHandler {
return func(ctx *RequestCtx) {
ctx.SetRemoteAddr(&net.TCPAddr{
IP: []byte{1, 2, 3, 5},
Port: 0,
})
h(ctx)
}
}
s := &Server{
Handler: customRemoteAddrHandler(func(ctx *RequestCtx) {
h := &ctx.Request.Header
ctx.Success("text/html", []byte(fmt.Sprintf("requestURI=%s, remoteAddr=%s, remoteIP=%s",
h.RequestURI(), ctx.RemoteAddr(), ctx.RemoteIP())))
}),
}
rw := &readWriter{}
rw.r.WriteString("GET /foo1 HTTP/1.1\r\nHost: google.com\r\n\r\n")
rwx := &readWriterRemoteAddr{
rw: rw,
addr: &net.TCPAddr{
IP: []byte{1, 2, 3, 4},
Port: 8765,
},
}
if err := s.ServeConn(rwx); err != nil {
t.Fatalf("Unexpected error from serveConn: %s", err)
}
br := bufio.NewReader(&rw.w)
verifyResponse(t, br, 200, "text/html", "requestURI=/foo1, remoteAddr=1.2.3.5:0, remoteIP=1.2.3.5")
}
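// readWriterRemoteAddr wraps an in-memory readWriter and reports a fixed
// address from RemoteAddr and LocalAddr so handlers observe a deterministic
// peer address.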
type readWriterRemoteAddr struct {
net.Conn
rw io.ReadWriteCloser
addr net.Addr
}
func (rw *readWriterRemoteAddr) Close() error {
return rw.rw.Close()
}
func (rw *readWriterRemoteAddr) Read(b []byte) (int, error) {
return rw.rw.Read(b)
}
func (rw *readWriterRemoteAddr) Write(b []byte) (int, error) {
return rw.rw.Write(b)
}
func (rw *readWriterRemoteAddr) RemoteAddr() net.Addr {
return rw.addr
}
func (rw *readWriterRemoteAddr) LocalAddr() net.Addr {
return rw.addr
}
func TestServerConnError(t *testing.T) {
t.Parallel()
s := &Server{
Handler: func(ctx *RequestCtx) {
ctx.Error("foobar", 423)
},
}
rw := &readWriter{}
rw.r.WriteString("GET /foo/bar?baz HTTP/1.1\r\nHost: google.com\r\n\r\n")
if err := s.ServeConn(rw); err != nil {
t.Fatalf("Unexpected error from serveConn: %s", err)
}
br := bufio.NewReader(&rw.w)
var resp Response
if err := resp.Read(br); err != nil {
t.Fatalf("Unexpected error when reading response: %s", err)
}
if resp.Header.StatusCode() != 423 {
t.Fatalf("Unexpected status code %d. Expected %d", resp.Header.StatusCode(), 423)
}
if resp.Header.ContentLength() != 6 {
t.Fatalf("Unexpected Content-Length %d. Expected %d", resp.Header.ContentLength(), 6)
}
if !bytes.Equal(resp.Header.Peek(HeaderContentType), defaultContentType) {
t.Fatalf("Unexpected Content-Type %q. Expected %q", resp.Header.Peek(HeaderContentType), defaultContentType)
}
if !bytes.Equal(resp.Body(), []byte("foobar")) {
t.Fatalf("Unexpected body %q. Expected %q", resp.Body(), "foobar")
}
}
func TestServeConnSingleRequest(t *testing.T) {
t.Parallel()
s := &Server{
Handler: func(ctx *RequestCtx) {
h := &ctx.Request.Header
ctx.Success("aaa", []byte(fmt.Sprintf("requestURI=%s, host=%s", h.RequestURI(), h.Peek(HeaderHost))))
},
}
rw := &readWriter{}
rw.r.WriteString("GET /foo/bar?baz HTTP/1.1\r\nHost: google.com\r\n\r\n")
if err := s.ServeConn(rw); err != nil {
t.Fatalf("Unexpected error from serveConn: %s", err)
}
br := bufio.NewReader(&rw.w)
verifyResponse(t, br, 200, "aaa", "requestURI=/foo/bar?baz, host=google.com")
}
func TestServeConnMultiRequests(t *testing.T) {
t.Parallel()
s := &Server{
Handler: func(ctx *RequestCtx) {
h := &ctx.Request.Header
ctx.Success("aaa", []byte(fmt.Sprintf("requestURI=%s, host=%s", h.RequestURI(), h.Peek(HeaderHost))))
},
}
rw := &readWriter{}
rw.r.WriteString("GET /foo/bar?baz HTTP/1.1\r\nHost: google.com\r\n\r\nGET /abc HTTP/1.1\r\nHost: foobar.com\r\n\r\n")
if err := s.ServeConn(rw); err != nil {
t.Fatalf("Unexpected error from serveConn: %s", err)
}
br := bufio.NewReader(&rw.w)
verifyResponse(t, br, 200, "aaa", "requestURI=/foo/bar?baz, host=google.com")
verifyResponse(t, br, 200, "aaa", "requestURI=/abc, host=foobar.com")
}
func TestShutdown(t *testing.T) {
t.Parallel()
ln := fasthttputil.NewInmemoryListener()
s := &Server{
Handler: func(ctx *RequestCtx) {
time.Sleep(time.Millisecond * 500)
ctx.Success("aaa/bbb", []byte("real response"))
},
}
serveCh := make(chan struct{})
go func() {
if err := s.Serve(ln); err != nil {
t.Errorf("unexepcted error: %s", err)
}
_, err := ln.Dial()
if err == nil {
t.Error("server is still listening")
}
serveCh <- struct{}{}
}()
clientCh := make(chan struct{})
go func() {
conn, err := ln.Dial()
if err != nil {
t.Errorf("unexepcted error: %s", err)
}
if _, err = conn.Write([]byte("GET / HTTP/1.1\r\nHost: google.com\r\n\r\n")); err != nil {
t.Errorf("unexpected error: %s", err)
}
br := bufio.NewReader(conn)
resp := verifyResponse(t, br, StatusOK, "aaa/bbb", "real response")
verifyResponseHeaderConnection(t, &resp.Header, "")
clientCh <- struct{}{}
}()
time.Sleep(time.Millisecond * 100)
shutdownCh := make(chan struct{})
go func() {
if err := s.Shutdown(); err != nil {
t.Errorf("unexepcted error: %s", err)
}
shutdownCh <- struct{}{}
}()
done := 0
for {
select {
case <-time.After(time.Second * 2):
t.Fatal("shutdown took too long")
case <-serveCh:
done++
case <-clientCh:
done++
case <-shutdownCh:
done++
}
if done == 3 {
return
}
}
}
func TestCloseOnShutdown(t *testing.T) {
t.Parallel()
ln := fasthttputil.NewInmemoryListener()
s := &Server{
Handler: func(ctx *RequestCtx) {
time.Sleep(time.Millisecond * 500)
ctx.Success("aaa/bbb", []byte("real response"))
},
CloseOnShutdown: true,
}
serveCh := make(chan struct{})
go func() {
if err := s.Serve(ln); err != nil {
t.Errorf("unexepcted error: %s", err)
}
_, err := ln.Dial()
if err == nil {
t.Error("server is still listening")
}
serveCh <- struct{}{}
}()
clientCh := make(chan struct{})
go func() {
conn, err := ln.Dial()
if err != nil {
t.Errorf("unexepcted error: %s", err)
}
if _, err = conn.Write([]byte("GET / HTTP/1.1\r\nHost: google.com\r\n\r\n")); err != nil {
t.Errorf("unexpected error: %s", err)
}
br := bufio.NewReader(conn)
resp := verifyResponse(t, br, StatusOK, "aaa/bbb", "real response")
verifyResponseHeaderConnection(t, &resp.Header, "close")
clientCh <- struct{}{}
}()
time.Sleep(time.Millisecond * 100)
shutdownCh := make(chan struct{})
go func() {
if err := s.Shutdown(); err != nil {
t.Errorf("unexepcted error: %s", err)
}
shutdownCh <- struct{}{}
}()
done := 0
for {
select {
case <-time.After(time.Second):
t.Fatal("shutdown took too long")
case <-serveCh:
done++
case <-clientCh:
done++
case <-shutdownCh:
done++
}
if done == 3 {
return
}
}
}
func TestShutdownReuse(t *testing.T) {
t.Parallel()
ln := fasthttputil.NewInmemoryListener()
s := &Server{
Handler: func(ctx *RequestCtx) {
ctx.Success("aaa/bbb", []byte("real response"))
},
ReadTimeout: time.Millisecond * 100,
Logger: &testLogger{}, // Ignore log output.
}
go func() {
if err := s.Serve(ln); err != nil {
t.Errorf("unexepcted error: %s", err)
}
}()
conn, err := ln.Dial()
if err != nil {
t.Fatalf("unexepcted error: %s", err)
}
if _, err = conn.Write([]byte("GET / HTTP/1.1\r\nHost: google.com\r\n\r\n")); err != nil {
t.Fatalf("unexpected error: %s", err)
}
br := bufio.NewReader(conn)
verifyResponse(t, br, StatusOK, "aaa/bbb", "real response")
if err := s.Shutdown(); err != nil {
t.Fatalf("unexepcted error: %s", err)
}
ln = fasthttputil.NewInmemoryListener()
go func() {
if err := s.Serve(ln); err != nil {
t.Errorf("unexepcted error: %s", err)
}
}()
conn, err = ln.Dial()
if err != nil {
t.Fatalf("unexepcted error: %s", err)
}
if _, err = conn.Write([]byte("GET / HTTP/1.1\r\nHost: google.com\r\n\r\n")); err != nil {
t.Fatalf("unexpected error: %s", err)
}
br = bufio.NewReader(conn)
verifyResponse(t, br, StatusOK, "aaa/bbb", "real response")
if err := s.Shutdown(); err != nil {
t.Fatalf("unexepcted error: %s", err)
}
}
func TestShutdownDone(t *testing.T) {
t.Parallel()
ln := fasthttputil.NewInmemoryListener()
s := &Server{
Handler: func(ctx *RequestCtx) {
<-ctx.Done()
ctx.Success("aaa/bbb", []byte("real response"))
},
}
go func() {
if err := s.Serve(ln); err != nil {
t.Errorf("unexepcted error: %s", err)
}
}()
conn, err := ln.Dial()
if err != nil {
t.Fatalf("unexepcted error: %s", err)
}
if _, err = conn.Write([]byte("GET / HTTP/1.1\r\nHost: google.com\r\n\r\n")); err != nil {
t.Fatalf("unexpected error: %s", err)
}
go func() {
// Shutdown won't return if the connection doesn't close,
// which doesn't happen until we read the response.
if err := s.Shutdown(); err != nil {
t.Errorf("unexepcted error: %s", err)
}
}()
// We can only reach this point and get a valid response
// if reading from ctx.Done() returned.
br := bufio.NewReader(conn)
verifyResponse(t, br, StatusOK, "aaa/bbb", "real response")
}
func TestShutdownErr(t *testing.T) {
t.Parallel()
ln := fasthttputil.NewInmemoryListener()
s := &Server{
Handler: func(ctx *RequestCtx) {
// This will panic, but I was not able to intercept it with recover()
c, cancel := context.WithCancel(ctx)
defer cancel()
<-c.Done()
ctx.Success("aaa/bbb", []byte("real response"))
},
}
go func() {
if err := s.Serve(ln); err != nil {
t.Errorf("unexepcted error: %s", err)
}
}()
conn, err := ln.Dial()
if err != nil {
t.Fatalf("unexepcted error: %s", err)
}
if _, err = conn.Write([]byte("GET / HTTP/1.1\r\nHost: google.com\r\n\r\n")); err != nil {
t.Fatalf("unexpected error: %s", err)
}
go func() {
// Shutdown won't return if the connection doesn't close,
// which doesn't happen until we read the response.
if err := s.Shutdown(); err != nil {
t.Errorf("unexepcted error: %s", err)
}
}()
// We can only reach this point and get a valid response
// if reading from ctx.Done() returned.
br := bufio.NewReader(conn)
verifyResponse(t, br, StatusOK, "aaa/bbb", "real response")
}
func TestMultipleServe(t *testing.T) {
t.Parallel()
s := &Server{
Handler: func(ctx *RequestCtx) {
ctx.Success("aaa/bbb", []byte("real response"))
},
}
ln1 := fasthttputil.NewInmemoryListener()
ln2 := fasthttputil.NewInmemoryListener()
go func() {
if err := s.Serve(ln1); err != nil {
t.Errorf("unexepcted error: %s", err)
}
}()
go func() {
if err := s.Serve(ln2); err != nil {
t.Errorf("unexepcted error: %s", err)
}
}()
conn, err := ln1.Dial()
if err != nil {
t.Fatalf("unexepcted error: %s", err)
}
if _, err = conn.Write([]byte("GET / HTTP/1.1\r\nHost: google.com\r\n\r\n")); err != nil {
t.Fatalf("unexpected error: %s", err)
}
br := bufio.NewReader(conn)
verifyResponse(t, br, StatusOK, "aaa/bbb", "real response")
conn, err = ln2.Dial()
if err != nil {
t.Fatalf("unexepcted error: %s", err)
}
if _, err = conn.Write([]byte("GET / HTTP/1.1\r\nHost: google.com\r\n\r\n")); err != nil {
t.Fatalf("unexpected error: %s", err)
}
br = bufio.NewReader(conn)
verifyResponse(t, br, StatusOK, "aaa/bbb", "real response")
}
func TestMaxBodySizePerRequest(t *testing.T) {
t.Parallel()
s := &Server{
Handler: func(ctx *RequestCtx) {
// do nothing :)
},
HeaderReceived: func(header *RequestHeader) RequestConfig {
return RequestConfig{
MaxRequestBodySize: 5 << 10,
}
},
ReadTimeout: time.Second * 5,
WriteTimeout: time.Second * 5,
MaxRequestBodySize: 1 << 20,
}
rw := &readWriter{}
rw.r.WriteString(fmt.Sprintf("POST /foo2 HTTP/1.1\r\nHost: aaa.com\r\nContent-Length: %d\r\nContent-Type: aa\r\n\r\n%s", (5<<10)+1, strings.Repeat("a", (5<<10)+1)))
if err := s.ServeConn(rw); err != ErrBodyTooLarge {
t.Fatalf("Unexpected error from serveConn: %s", err)
}
}
func TestStreamRequestBody(t *testing.T) {
t.Parallel()
part1 := strings.Repeat("1", 1<<15)
part2 := strings.Repeat("2", 1<<16)
contentLength := len(part1) + len(part2)
next := make(chan struct{})
s := &Server{
Handler: func(ctx *RequestCtx) {
checkReader(t, ctx.RequestBodyStream(), part1)
close(next)
checkReader(t, ctx.RequestBodyStream(), part2)
},
StreamRequestBody: true,
}
pipe := fasthttputil.NewPipeConns()
cc, sc := pipe.Conn1(), pipe.Conn2()
// write headers and part1 body
if _, err := cc.Write([]byte(fmt.Sprintf("POST /foo2 HTTP/1.1\r\nHost: aaa.com\r\nContent-Length: %d\r\nContent-Type: aa\r\n\r\n", contentLength))); err != nil {
t.Fatal(err)
}
if _, err := cc.Write([]byte(part1)); err != nil {
t.Fatal(err)
}
ch := make(chan error)
go func() {
ch <- s.ServeConn(sc)
}()
select {
case <-next:
case <-time.After(500 * time.Millisecond):
t.Fatal("part1 timeout")
}
if _, err := cc.Write([]byte(part2)); err != nil {
t.Fatal(err)
}
if err := sc.Close(); err != nil {
t.Fatal(err)
}
select {
case err := <-ch:
if err == nil || err.Error() != "connection closed" { // fasthttputil.errConnectionClosed is private so do a string match.
t.Fatalf("Unexpected error from serveConn: %s", err)
}
case <-time.After(500 * time.Millisecond):
t.Fatal("part2 timeout")
}
}
func TestStreamRequestBodyExceedMaxSize(t *testing.T) {
t.Parallel()
part1 := strings.Repeat("1", 1<<18)
part2 := strings.Repeat("2", 1<<20-1<<18)
contentLength := len(part1) + len(part2)
next := make(chan struct{})
s := &Server{
Handler: func(ctx *RequestCtx) {
checkReader(t, ctx.RequestBodyStream(), part1)
close(next)
checkReader(t, ctx.RequestBodyStream(), part2)
},
DisableKeepalive: true,
StreamRequestBody: true,
MaxRequestBodySize: 1,
}
pipe := fasthttputil.NewPipeConns()
cc, sc := pipe.Conn1(), pipe.Conn2()
// write headers and part1 body
if _, err := cc.Write([]byte(fmt.Sprintf("POST /foo2 HTTP/1.1\r\nHost: aaa.com\r\nContent-Length: %d\r\nContent-Type: aa\r\n\r\n%s", contentLength, part1))); err != nil {
t.Error(err)
}
ch := make(chan error)
go func() {
ch <- s.ServeConn(sc)
}()
select {
case <-next:
case <-time.After(500 * time.Millisecond):
t.Fatal("part1 timeout")
}
if _, err := cc.Write([]byte(part2)); err != nil {
t.Error(err)
}
select {
case err := <-ch:
if err != nil {
t.Error(err)
}
case <-time.After(500 * time.Millisecond):
t.Fatal("part2 timeout")
}
}
func TestStreamBodyRequestContentLength(t *testing.T) {
t.Parallel()
content := strings.Repeat("1", 1<<15) // 32K
contentLength := len(content)
s := &Server{
Handler: func(ctx *RequestCtx) {
realContentLength := ctx.Request.Header.ContentLength()
if realContentLength != contentLength {
t.Fatal("incorrect content length")
}
},
MaxRequestBodySize: 1 * 1024 * 1024, // 1M
StreamRequestBody: true,
}
pipe := fasthttputil.NewPipeConns()
cc, sc := pipe.Conn1(), pipe.Conn2()
if _, err := cc.Write([]byte(fmt.Sprintf("POST /foo2 HTTP/1.1\r\nHost: aaa.com\r\nContent-Length: %d\r\nContent-Type: aa\r\n\r\n%s", contentLength, content))); err != nil {
t.Fatal(err)
}
ch := make(chan error)
go func() {
ch <- s.ServeConn(sc)
}()
if err := sc.Close(); err != nil {
t.Fatal(err)
}
select {
case err := <-ch:
if err == nil || err.Error() != "connection closed" { // fasthttputil.errConnectionClosed is private so do a string match.
t.Fatalf("Unexpected error from serveConn: %s", err)
}
case <-time.After(time.Second):
t.Fatal("test timeout")
}
}
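// checkReader reads exactly len(expected) bytes from r and fails the test if
// the read fails or the bytes differ from expected.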
func checkReader(t *testing.T, r io.Reader, expected string) {
b := make([]byte, len(expected))
if _, err := io.ReadFull(r, b); err != nil {
t.Fatalf("Unexpected error from reader: %s", err)
}
if string(b) != expected {
t.Fatal("incorrect request body")
}
}
func TestMaxReadTimeoutPerRequest(t *testing.T) {
t.Parallel()
headers := []byte(fmt.Sprintf("POST /foo2 HTTP/1.1\r\nHost: aaa.com\r\nContent-Length: %d\r\nContent-Type: aa\r\n\r\n", 5*1024))
s := &Server{
Handler: func(ctx *RequestCtx) {
t.Error("shouldn't reach handler")
},
HeaderReceived: func(header *RequestHeader) RequestConfig {
return RequestConfig{
ReadTimeout: time.Millisecond,
}
},
ReadBufferSize: len(headers),
ReadTimeout: time.Second * 5,
WriteTimeout: time.Second * 5,
}
pipe := fasthttputil.NewPipeConns()
cc, sc := pipe.Conn1(), pipe.Conn2()
go func() {
// write headers
_, err := cc.Write(headers)
if err != nil {
t.Error(err)
}
// write body
for i := 0; i < 5*1024; i++ {
time.Sleep(time.Millisecond)
cc.Write([]byte{'a'}) //nolint:errcheck
}
}()
ch := make(chan error)
go func() {
ch <- s.ServeConn(sc)
}()
select {
case err := <-ch:
if err == nil || !strings.EqualFold(err.Error(), "timeout") {
t.Fatalf("Unexpected error from serveConn: %s", err)
}
case <-time.After(time.Second):
t.Fatal("test timeout")
}
}
func TestMaxWriteTimeoutPerRequest(t *testing.T) {
t.Parallel()
headers := []byte("GET /foo2 HTTP/1.1\r\nHost: aaa.com\r\nContent-Type: aa\r\n\r\n")
s := &Server{
Handler: func(ctx *RequestCtx) {
ctx.SetBodyStreamWriter(func(w *bufio.Writer) {
var buf [192]byte
for {
w.Write(buf[:]) //nolint:errcheck
}
})
},
HeaderReceived: func(header *RequestHeader) RequestConfig {
return RequestConfig{
WriteTimeout: time.Millisecond,
}
},
ReadBufferSize: 192,
ReadTimeout: time.Second * 5,
WriteTimeout: time.Second * 5,
}
pipe := fasthttputil.NewPipeConns()
cc, sc := pipe.Conn1(), pipe.Conn2()
var resp Response
go func() {
// write headers
_, err := cc.Write(headers)
if err != nil {
t.Error(err)
}
br := bufio.NewReaderSize(cc, 192)
err = resp.Header.Read(br)
if err != nil {
t.Error(err)
}
var chunk [192]byte
for {
time.Sleep(time.Millisecond)
br.Read(chunk[:]) //nolint:errcheck
}
}()
ch := make(chan error)
go func() {
ch <- s.ServeConn(sc)
}()
select {
case err := <-ch:
if err == nil || !strings.EqualFold(err.Error(), "timeout") {
t.Fatalf("Unexpected error from serveConn: %s", err)
}
case <-time.After(time.Second):
t.Fatal("test timeout")
}
}
func TestIncompleteBodyReturnsUnexpectedEOF(t *testing.T) {
t.Parallel()
rw := &readWriter{}
rw.r.WriteString("POST /foo HTTP/1.1\r\nHost: google.com\r\nContent-Length: 5\r\n\r\n123")
s := &Server{
Handler: func(ctx *RequestCtx) {},
}
ch := make(chan error)
go func() {
ch <- s.ServeConn(rw)
}()
if err := <-ch; err == nil || err.Error() != "unexpected EOF" {
t.Fatal(err)
}
}
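// verifyResponse reads a single response from r, checks its status code,
// content type and body, and returns the parsed response for further checks.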
func verifyResponse(t *testing.T, r *bufio.Reader, expectedStatusCode int, expectedContentType, expectedBody string) *Response {
var resp Response
if err := resp.Read(r); err != nil {
t.Fatalf("Unexpected error when parsing response: %s", err)
}
if !bytes.Equal(resp.Body(), []byte(expectedBody)) {
t.Fatalf("Unexpected body %q. Expected %q", resp.Body(), []byte(expectedBody))
}
verifyResponseHeader(t, &resp.Header, expectedStatusCode, len(resp.Body()), expectedContentType)
return &resp
}
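// readWriter is an in-memory net.Conn stub used by the serve tests: requests
// are pre-written into r and everything the server writes is collected in w.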
type readWriter struct {
net.Conn
r bytes.Buffer
w bytes.Buffer
}
func (rw *readWriter) Close() error {
return nil
}
func (rw *readWriter) Read(b []byte) (int, error) {
return rw.r.Read(b)
}
func (rw *readWriter) Write(b []byte) (int, error) {
return rw.w.Write(b)
}
func (rw *readWriter) RemoteAddr() net.Addr {
return zeroTCPAddr
}
func (rw *readWriter) LocalAddr() net.Addr {
return zeroTCPAddr
}
func (rw *readWriter) SetDeadline(t time.Time) error {
return nil
}
func (rw *readWriter) SetReadDeadline(t time.Time) error {
return nil
}
func (rw *readWriter) SetWriteDeadline(t time.Time) error {
return nil
}
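// testLogger collects formatted log output into a single string so tests can
// compare it against an expected value. Printf drops the first six characters
// of each formatted message (the variable prefix the context logger prepends)
// so the expected output stays deterministic.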
type testLogger struct {
lock sync.Mutex
out string
}
func (cl *testLogger) Printf(format string, args ...interface{}) {
cl.lock.Lock()
cl.out += fmt.Sprintf(format, args...)[6:] + "\n"
cl.lock.Unlock()
}
fasthttp-1.31.0/server_timing_test.go 0000664 0000000 0000000 00000026405 14130360711 0017677 0 ustar 00root root 0000000 0000000 package fasthttp
import (
"bytes"
"fmt"
"io"
"io/ioutil"
"net"
"net/http"
"runtime"
"sync"
"sync/atomic"
"testing"
"time"
)
var defaultClientsCount = runtime.NumCPU()
func BenchmarkRequestCtxRedirect(b *testing.B) {
b.RunParallel(func(pb *testing.PB) {
var ctx RequestCtx
for pb.Next() {
ctx.Request.SetRequestURI("http://aaa.com/fff/ss.html?sdf")
ctx.Redirect("/foo/bar?baz=111", StatusFound)
}
})
}
func BenchmarkServerGet1ReqPerConn(b *testing.B) {
benchmarkServerGet(b, defaultClientsCount, 1)
}
func BenchmarkServerGet2ReqPerConn(b *testing.B) {
benchmarkServerGet(b, defaultClientsCount, 2)
}
func BenchmarkServerGet10ReqPerConn(b *testing.B) {
benchmarkServerGet(b, defaultClientsCount, 10)
}
func BenchmarkServerGet10KReqPerConn(b *testing.B) {
benchmarkServerGet(b, defaultClientsCount, 10000)
}
func BenchmarkNetHTTPServerGet1ReqPerConn(b *testing.B) {
benchmarkNetHTTPServerGet(b, defaultClientsCount, 1)
}
func BenchmarkNetHTTPServerGet2ReqPerConn(b *testing.B) {
benchmarkNetHTTPServerGet(b, defaultClientsCount, 2)
}
func BenchmarkNetHTTPServerGet10ReqPerConn(b *testing.B) {
benchmarkNetHTTPServerGet(b, defaultClientsCount, 10)
}
func BenchmarkNetHTTPServerGet10KReqPerConn(b *testing.B) {
benchmarkNetHTTPServerGet(b, defaultClientsCount, 10000)
}
func BenchmarkServerPost1ReqPerConn(b *testing.B) {
benchmarkServerPost(b, defaultClientsCount, 1)
}
func BenchmarkServerPost2ReqPerConn(b *testing.B) {
benchmarkServerPost(b, defaultClientsCount, 2)
}
func BenchmarkServerPost10ReqPerConn(b *testing.B) {
benchmarkServerPost(b, defaultClientsCount, 10)
}
func BenchmarkServerPost10KReqPerConn(b *testing.B) {
benchmarkServerPost(b, defaultClientsCount, 10000)
}
func BenchmarkNetHTTPServerPost1ReqPerConn(b *testing.B) {
benchmarkNetHTTPServerPost(b, defaultClientsCount, 1)
}
func BenchmarkNetHTTPServerPost2ReqPerConn(b *testing.B) {
benchmarkNetHTTPServerPost(b, defaultClientsCount, 2)
}
func BenchmarkNetHTTPServerPost10ReqPerConn(b *testing.B) {
benchmarkNetHTTPServerPost(b, defaultClientsCount, 10)
}
func BenchmarkNetHTTPServerPost10KReqPerConn(b *testing.B) {
benchmarkNetHTTPServerPost(b, defaultClientsCount, 10000)
}
func BenchmarkServerGet1ReqPerConn10KClients(b *testing.B) {
benchmarkServerGet(b, 10000, 1)
}
func BenchmarkServerGet2ReqPerConn10KClients(b *testing.B) {
benchmarkServerGet(b, 10000, 2)
}
func BenchmarkServerGet10ReqPerConn10KClients(b *testing.B) {
benchmarkServerGet(b, 10000, 10)
}
func BenchmarkServerGet100ReqPerConn10KClients(b *testing.B) {
benchmarkServerGet(b, 10000, 100)
}
func BenchmarkNetHTTPServerGet1ReqPerConn10KClients(b *testing.B) {
benchmarkNetHTTPServerGet(b, 10000, 1)
}
func BenchmarkNetHTTPServerGet2ReqPerConn10KClients(b *testing.B) {
benchmarkNetHTTPServerGet(b, 10000, 2)
}
func BenchmarkNetHTTPServerGet10ReqPerConn10KClients(b *testing.B) {
benchmarkNetHTTPServerGet(b, 10000, 10)
}
func BenchmarkNetHTTPServerGet100ReqPerConn10KClients(b *testing.B) {
benchmarkNetHTTPServerGet(b, 10000, 100)
}
func BenchmarkServerHijack(b *testing.B) {
clientsCount := 1000
requestsPerConn := 10000
ch := make(chan struct{}, b.N)
responseBody := []byte("123")
s := &Server{
Handler: func(ctx *RequestCtx) {
ctx.Hijack(func(c net.Conn) {
// emulate server loop :)
err := ServeConn(c, func(ctx *RequestCtx) {
ctx.Success("foobar", responseBody)
registerServedRequest(b, ch)
})
if err != nil {
b.Fatalf("error when serving connection")
}
})
ctx.Success("foobar", responseBody)
registerServedRequest(b, ch)
},
Concurrency: 16 * clientsCount,
}
req := "GET /foo HTTP/1.1\r\nHost: google.com\r\n\r\n"
benchmarkServer(b, s, clientsCount, requestsPerConn, req)
verifyRequestsServed(b, ch)
}
func BenchmarkServerMaxConnsPerIP(b *testing.B) {
clientsCount := 1000
requestsPerConn := 10
ch := make(chan struct{}, b.N)
responseBody := []byte("123")
s := &Server{
Handler: func(ctx *RequestCtx) {
ctx.Success("foobar", responseBody)
registerServedRequest(b, ch)
},
MaxConnsPerIP: clientsCount * 2,
Concurrency: 16 * clientsCount,
}
req := "GET /foo HTTP/1.1\r\nHost: google.com\r\n\r\n"
benchmarkServer(b, s, clientsCount, requestsPerConn, req)
verifyRequestsServed(b, ch)
}
func BenchmarkServerTimeoutError(b *testing.B) {
clientsCount := 10
requestsPerConn := 1
ch := make(chan struct{}, b.N)
n := uint32(0)
responseBody := []byte("123")
s := &Server{
Handler: func(ctx *RequestCtx) {
if atomic.AddUint32(&n, 1)&7 == 0 {
ctx.TimeoutError("xxx")
go func() {
ctx.Success("foobar", responseBody)
}()
} else {
ctx.Success("foobar", responseBody)
}
registerServedRequest(b, ch)
},
Concurrency: 16 * clientsCount,
}
req := "GET /foo HTTP/1.1\r\nHost: google.com\r\n\r\n"
benchmarkServer(b, s, clientsCount, requestsPerConn, req)
verifyRequestsServed(b, ch)
}
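// fakeServerConn is an in-memory net.Conn used by the benchmarks: Read replays
// the listener's canned request, Write discards all data, and Close returns
// the connection to the listener's pool.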
type fakeServerConn struct {
net.TCPConn
ln *fakeListener
requestsCount int
pos int
closed uint32
}
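// Read copies the canned request into b, serving up to requestsCount requests
// back to back, and reports io.EOF once all of them have been consumed.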
func (c *fakeServerConn) Read(b []byte) (int, error) {
nn := 0
reqLen := len(c.ln.request)
for len(b) > 0 {
if c.requestsCount == 0 {
if nn == 0 {
return 0, io.EOF
}
return nn, nil
}
pos := c.pos % reqLen
n := copy(b, c.ln.request[pos:])
b = b[n:]
nn += n
c.pos += n
if n+pos == reqLen {
c.requestsCount--
}
}
return nn, nil
}
func (c *fakeServerConn) Write(b []byte) (int, error) {
return len(b), nil
}
var fakeAddr = net.TCPAddr{
IP: []byte{1, 2, 3, 4},
Port: 12345,
}
func (c *fakeServerConn) RemoteAddr() net.Addr {
return &fakeAddr
}
func (c *fakeServerConn) Close() error {
if atomic.AddUint32(&c.closed, 1) == 1 {
c.ln.ch <- c
}
return nil
}
func (c *fakeServerConn) SetReadDeadline(t time.Time) error {
return nil
}
func (c *fakeServerConn) SetWriteDeadline(t time.Time) error {
return nil
}
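// fakeListener is an in-memory net.Listener used by the benchmarks. It hands
// out pooled fakeServerConns preloaded with canned requests until the total
// request budget is spent.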
type fakeListener struct {
lock sync.Mutex
requestsCount int
requestsPerConn int
request []byte
ch chan *fakeServerConn
done chan struct{}
closed bool
}
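// Accept returns a pooled connection loaded with up to requestsPerConn
// requests. Once the request budget is exhausted it waits for all connections
// to be returned to the pool, closes the done channel and reports io.EOF so
// Serve stops.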
func (ln *fakeListener) Accept() (net.Conn, error) {
ln.lock.Lock()
if ln.requestsCount == 0 {
ln.lock.Unlock()
for len(ln.ch) < cap(ln.ch) {
time.Sleep(10 * time.Millisecond)
}
ln.lock.Lock()
if !ln.closed {
close(ln.done)
ln.closed = true
}
ln.lock.Unlock()
return nil, io.EOF
}
requestsCount := ln.requestsPerConn
if requestsCount > ln.requestsCount {
requestsCount = ln.requestsCount
}
ln.requestsCount -= requestsCount
ln.lock.Unlock()
c := <-ln.ch
c.requestsCount = requestsCount
c.closed = 0
c.pos = 0
return c, nil
}
func (ln *fakeListener) Close() error {
return nil
}
func (ln *fakeListener) Addr() net.Addr {
return &fakeAddr
}
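// newFakeListener builds a fakeListener that serves requestsCount copies of
// request, at most requestsPerConn per connection, using a pool of
// clientsCount fake connections.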
func newFakeListener(requestsCount, clientsCount, requestsPerConn int, request string) *fakeListener {
ln := &fakeListener{
requestsCount: requestsCount,
requestsPerConn: requestsPerConn,
request: []byte(request),
ch: make(chan *fakeServerConn, clientsCount),
done: make(chan struct{}),
}
for i := 0; i < clientsCount; i++ {
ln.ch <- &fakeServerConn{
ln: ln,
}
}
return ln
}
var (
fakeResponse = []byte("Hello, world!")
getRequest = "GET /foobar?baz HTTP/1.1\r\nHost: google.com\r\nUser-Agent: aaa/bbb/ccc/ddd/eee Firefox Chrome MSIE Opera\r\n" +
"Referer: http://xxx.com/aaa?bbb=ccc\r\nCookie: foo=bar; baz=baraz; aa=aakslsdweriwereowriewroire\r\n\r\n"
postRequest = fmt.Sprintf("POST /foobar?baz HTTP/1.1\r\nHost: google.com\r\nContent-Type: foo/bar\r\nContent-Length: %d\r\n"+
"User-Agent: Opera Chrome MSIE Firefox and other/1.2.34\r\nReferer: http://google.com/aaaa/bbb/ccc\r\n"+
"Cookie: foo=bar; baz=baraz; aa=aakslsdweriwereowriewroire\r\n\r\n%s",
len(fakeResponse), fakeResponse)
)
func benchmarkServerGet(b *testing.B, clientsCount, requestsPerConn int) {
ch := make(chan struct{}, b.N)
s := &Server{
Handler: func(ctx *RequestCtx) {
if !ctx.IsGet() {
b.Fatalf("Unexpected request method: %s", ctx.Method())
}
ctx.Success("text/plain", fakeResponse)
if requestsPerConn == 1 {
ctx.SetConnectionClose()
}
registerServedRequest(b, ch)
},
Concurrency: 16 * clientsCount,
}
benchmarkServer(b, s, clientsCount, requestsPerConn, getRequest)
verifyRequestsServed(b, ch)
}
func benchmarkNetHTTPServerGet(b *testing.B, clientsCount, requestsPerConn int) {
ch := make(chan struct{}, b.N)
s := &http.Server{
Handler: http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
if req.Method != MethodGet {
b.Fatalf("Unexpected request method: %s", req.Method)
}
h := w.Header()
h.Set("Content-Type", "text/plain")
if requestsPerConn == 1 {
h.Set(HeaderConnection, "close")
}
w.Write(fakeResponse) //nolint:errcheck
registerServedRequest(b, ch)
}),
}
benchmarkServer(b, s, clientsCount, requestsPerConn, getRequest)
verifyRequestsServed(b, ch)
}
func benchmarkServerPost(b *testing.B, clientsCount, requestsPerConn int) {
ch := make(chan struct{}, b.N)
s := &Server{
Handler: func(ctx *RequestCtx) {
if !ctx.IsPost() {
b.Fatalf("Unexpected request method: %s", ctx.Method())
}
body := ctx.Request.Body()
if !bytes.Equal(body, fakeResponse) {
b.Fatalf("Unexpected body %q. Expected %q", body, fakeResponse)
}
ctx.Success("text/plain", body)
if requestsPerConn == 1 {
ctx.SetConnectionClose()
}
registerServedRequest(b, ch)
},
Concurrency: 16 * clientsCount,
}
benchmarkServer(b, s, clientsCount, requestsPerConn, postRequest)
verifyRequestsServed(b, ch)
}
func benchmarkNetHTTPServerPost(b *testing.B, clientsCount, requestsPerConn int) {
ch := make(chan struct{}, b.N)
s := &http.Server{
Handler: http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
if req.Method != MethodPost {
b.Fatalf("Unexpected request method: %s", req.Method)
}
body, err := ioutil.ReadAll(req.Body)
if err != nil {
b.Fatalf("Unexpected error: %s", err)
}
req.Body.Close()
if !bytes.Equal(body, fakeResponse) {
b.Fatalf("Unexpected body %q. Expected %q", body, fakeResponse)
}
h := w.Header()
h.Set("Content-Type", "text/plain")
if requestsPerConn == 1 {
h.Set(HeaderConnection, "close")
}
w.Write(body) //nolint:errcheck
registerServedRequest(b, ch)
}),
}
benchmarkServer(b, s, clientsCount, requestsPerConn, postRequest)
verifyRequestsServed(b, ch)
}
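// registerServedRequest records one served request on ch and fails the
// benchmark if more requests are served than were sent (cap(ch) == b.N).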
func registerServedRequest(b *testing.B, ch chan<- struct{}) {
select {
case ch <- struct{}{}:
default:
b.Fatalf("More than %d requests served", cap(ch))
}
}
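// verifyRequestsServed drains ch and fails the benchmark unless all b.N sent
// requests are eventually reported as served.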
func verifyRequestsServed(b *testing.B, ch <-chan struct{}) {
requestsServed := 0
for len(ch) > 0 {
<-ch
requestsServed++
}
requestsSent := b.N
for requestsServed < requestsSent {
select {
case <-ch:
requestsServed++
case <-time.After(100 * time.Millisecond):
b.Fatalf("Unexpected number of requests served %d. Expected %d", requestsServed, requestsSent)
}
}
}
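// realServer is the minimal interface shared by Server and net/http.Server so
// both can be driven by benchmarkServer.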
type realServer interface {
Serve(ln net.Listener) error
}
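// benchmarkServer serves b.N copies of request through a fakeListener using
// the given number of clients and requests per connection, and fails if the
// server doesn't stop within ten seconds after the listener is exhausted.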
func benchmarkServer(b *testing.B, s realServer, clientsCount, requestsPerConn int, request string) {
ln := newFakeListener(b.N, clientsCount, requestsPerConn, request)
ch := make(chan struct{})
go func() {
s.Serve(ln) //nolint:errcheck
ch <- struct{}{}
}()
<-ln.done
select {
case <-ch:
case <-time.After(10 * time.Second):
b.Fatalf("Server.Serve() didn't stop")
}
}
fasthttp-1.31.0/stackless/ 0000775 0000000 0000000 00000000000 14130360711 0015421 5 ustar 00root root 0000000 0000000 fasthttp-1.31.0/stackless/doc.go 0000664 0000000 0000000 00000000217 14130360711 0016515 0 ustar 00root root 0000000 0000000 // Package stackless provides functionality that may save stack space
// for a high number of concurrently running goroutines.
package stackless
fasthttp-1.31.0/stackless/func.go 0000664 0000000 0000000 00000003106 14130360711 0016703 0 ustar 00root root 0000000 0000000 package stackless
import (
"runtime"
"sync"
)
// NewFunc returns stackless wrapper for the function f.
//
// Unlike f, the returned stackless wrapper doesn't use stack space
// on the goroutine that calls it.
// The wrapper may save a lot of stack space if the following conditions
// are met:
//
// - f doesn't contain blocking calls on network, I/O or channels;
// - f uses a lot of stack space;
// - the wrapper is called from a high number of concurrent goroutines.
//
// The stackless wrapper returns false if the call cannot be processed
// at the moment due to high load.
func NewFunc(f func(ctx interface{})) func(ctx interface{}) bool {
if f == nil {
panic("BUG: f cannot be nil")
}
funcWorkCh := make(chan *funcWork, runtime.GOMAXPROCS(-1)*2048)
onceInit := func() {
n := runtime.GOMAXPROCS(-1)
for i := 0; i < n; i++ {
go funcWorker(funcWorkCh, f)
}
}
var once sync.Once
return func(ctx interface{}) bool {
once.Do(onceInit)
fw := getFuncWork()
fw.ctx = ctx
select {
case funcWorkCh <- fw:
default:
putFuncWork(fw)
return false
}
<-fw.done
putFuncWork(fw)
return true
}
}
func funcWorker(funcWorkCh <-chan *funcWork, f func(ctx interface{})) {
for fw := range funcWorkCh {
f(fw.ctx)
fw.done <- struct{}{}
}
}
func getFuncWork() *funcWork {
v := funcWorkPool.Get()
if v == nil {
v = &funcWork{
done: make(chan struct{}, 1),
}
}
return v.(*funcWork)
}
func putFuncWork(fw *funcWork) {
fw.ctx = nil
funcWorkPool.Put(fw)
}
var funcWorkPool sync.Pool
type funcWork struct {
ctx interface{}
done chan struct{}
}
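// Illustrative usage sketch, not part of the original archive: how calling
// code might wrap an expensive function with NewFunc. The wrapper runs f on a
// pooled worker goroutine and returns false when the call is rejected due to
// high load. The file layout (an external stackless_test package), the addCtx
// type and the stacklessAdd helper are assumptions made for this example only.
package stackless_test

import (
	"fmt"

	"github.com/valyala/fasthttp/stackless"
)

// addCtx carries the input and the output of a single wrapped call.
type addCtx struct {
	a, b int
	sum  int
}

// stacklessAdd executes the addition on a worker goroutine managed by the
// stackless package instead of on the caller's stack.
var stacklessAdd = stackless.NewFunc(func(ctx interface{}) {
	c := ctx.(*addCtx)
	c.sum = c.a + c.b // stands in for a deeply recursive, stack-hungry function
})

func ExampleNewFunc() {
	c := &addCtx{a: 2, b: 3}
	if !stacklessAdd(c) {
		// The call was rejected because all workers are busy; a real caller
		// could retry later or fall back to calling the function directly.
		fmt.Println("rejected due to high load")
		return
	}
	fmt.Println(c.sum)
	// Output: 5
}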
fasthttp-1.31.0/stackless/func_test.go 0000664 0000000 0000000 00000003130 14130360711 0017737 0 ustar 00root root 0000000 0000000 package stackless
import (
"fmt"
"sync/atomic"
"testing"
"time"
)
func TestNewFuncSimple(t *testing.T) {
t.Parallel()
var n uint64
f := NewFunc(func(ctx interface{}) {
atomic.AddUint64(&n, uint64(ctx.(int)))
})
iterations := 4 * 1024
for i := 0; i < iterations; i++ {
if !f(2) {
t.Fatalf("f mustn't return false")
}
}
if n != uint64(2*iterations) {
t.Fatalf("Unexpected n: %d. Expecting %d", n, 2*iterations)
}
}
func TestNewFuncMulti(t *testing.T) {
t.Parallel()
var n1, n2 uint64
f1 := NewFunc(func(ctx interface{}) {
atomic.AddUint64(&n1, uint64(ctx.(int)))
})
f2 := NewFunc(func(ctx interface{}) {
atomic.AddUint64(&n2, uint64(ctx.(int)))
})
iterations := 4 * 1024
f1Done := make(chan error, 1)
go func() {
var err error
for i := 0; i < iterations; i++ {
if !f1(3) {
err = fmt.Errorf("f1 mustn't return false")
break
}
}
f1Done <- err
}()
f2Done := make(chan error, 1)
go func() {
var err error
for i := 0; i < iterations; i++ {
if !f2(5) {
err = fmt.Errorf("f2 mustn't return false")
break
}
}
f2Done <- err
}()
select {
case err := <-f1Done:
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
case <-time.After(time.Second):
t.Fatalf("timeout")
}
select {
case err := <-f2Done:
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
case <-time.After(time.Second):
t.Fatalf("timeout")
}
if n1 != uint64(3*iterations) {
t.Fatalf("unexpected n1: %d. Expecting %d", n1, 3*iterations)
}
if n2 != uint64(5*iterations) {
t.Fatalf("unexpected n2: %d. Expecting %d", n2, 5*iterations)
}
}
fasthttp-1.31.0/stackless/func_timing_test.go 0000664 0000000 0000000 00000001267 14130360711 0021317 0 ustar 00root root 0000000 0000000 package stackless
import (
"sync/atomic"
"testing"
)
func BenchmarkFuncOverhead(b *testing.B) {
var n uint64
f := NewFunc(func(ctx interface{}) {
atomic.AddUint64(&n, *(ctx.(*uint64)))
})
b.RunParallel(func(pb *testing.PB) {
x := uint64(1)
for pb.Next() {
if !f(&x) {
b.Fatalf("f mustn't return false")
}
}
})
if n != uint64(b.N) {
b.Fatalf("unexected n: %d. Expecting %d", n, b.N)
}
}
func BenchmarkFuncPure(b *testing.B) {
var n uint64
f := func(x *uint64) {
atomic.AddUint64(&n, *x)
}
b.RunParallel(func(pb *testing.PB) {
x := uint64(1)
for pb.Next() {
f(&x)
}
})
if n != uint64(b.N) {
b.Fatalf("unexected n: %d. Expecting %d", n, b.N)
}
}
fasthttp-1.31.0/stackless/writer.go 0000664 0000000 0000000 00000004532 14130360711 0017270 0 ustar 00root root 0000000 0000000 package stackless
import (
"errors"
"fmt"
"io"
"github.com/valyala/bytebufferpool"
)
// Writer is an interface stackless writer must conform to.
//
// The interface contains common subset for Writers from compress/* packages.
type Writer interface {
Write(p []byte) (int, error)
Flush() error
Close() error
Reset(w io.Writer)
}
// NewWriterFunc must return a new writer that will be wrapped into
// a stackless writer.
type NewWriterFunc func(w io.Writer) Writer
// NewWriter creates a stackless writer around a writer returned
// from newWriter.
//
// The returned writer writes data to dstW.
//
// Writers that use a lot of stack space may be wrapped into a stackless writer,
// thus saving stack space for a high number of concurrently running goroutines.
func NewWriter(dstW io.Writer, newWriter NewWriterFunc) Writer {
w := &writer{
dstW: dstW,
}
w.zw = newWriter(&w.xw)
return w
}
type writer struct {
dstW io.Writer
zw Writer
xw xWriter
err error
n int
p []byte
op op
}
type op int
const (
opWrite op = iota
opFlush
opClose
opReset
)
func (w *writer) Write(p []byte) (int, error) {
w.p = p
err := w.do(opWrite)
w.p = nil
return w.n, err
}
func (w *writer) Flush() error {
return w.do(opFlush)
}
func (w *writer) Close() error {
return w.do(opClose)
}
func (w *writer) Reset(dstW io.Writer) {
w.xw.Reset()
w.do(opReset) //nolint:errcheck
w.dstW = dstW
}
func (w *writer) do(op op) error {
w.op = op
if !stacklessWriterFunc(w) {
return errHighLoad
}
err := w.err
if err != nil {
return err
}
if w.xw.bb != nil && len(w.xw.bb.B) > 0 {
_, err = w.dstW.Write(w.xw.bb.B)
}
w.xw.Reset()
return err
}
var errHighLoad = errors.New("cannot compress data due to high load")
var stacklessWriterFunc = NewFunc(writerFunc)
func writerFunc(ctx interface{}) {
w := ctx.(*writer)
switch w.op {
case opWrite:
w.n, w.err = w.zw.Write(w.p)
case opFlush:
w.err = w.zw.Flush()
case opClose:
w.err = w.zw.Close()
case opReset:
w.zw.Reset(&w.xw)
w.err = nil
default:
panic(fmt.Sprintf("BUG: unexpected op: %d", w.op))
}
}
type xWriter struct {
bb *bytebufferpool.ByteBuffer
}
func (w *xWriter) Write(p []byte) (int, error) {
if w.bb == nil {
w.bb = bufferPool.Get()
}
return w.bb.Write(p)
}
func (w *xWriter) Reset() {
if w.bb != nil {
bufferPool.Put(w.bb)
w.bb = nil
}
}
var bufferPool bytebufferpool.Pool
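// Illustrative usage sketch, not part of the original archive: wrapping a
// gzip.Writer with NewWriter so the compression work does not consume stack
// space on the calling goroutine. This mirrors the pattern exercised by the
// tests that follow; the external test package and the compress helper are
// assumptions made for this example only.
package stackless_test

import (
	"bytes"
	"compress/gzip"
	"io"

	"github.com/valyala/fasthttp/stackless"
)

// compress gzips src into dst through a stackless writer.
func compress(dst *bytes.Buffer, src []byte) error {
	w := stackless.NewWriter(dst, func(dw io.Writer) stackless.Writer {
		return gzip.NewWriter(dw)
	})
	if _, err := w.Write(src); err != nil {
		return err
	}
	// Close flushes the buffered, compressed data into dst. Like Write and
	// Flush it may report the package's "high load" error, which a caller
	// should treat as retryable.
	return w.Close()
}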
fasthttp-1.31.0/stackless/writer_test.go 0000664 0000000 0000000 00000005363 14130360711 0020332 0 ustar 00root root 0000000 0000000 package stackless
import (
"bytes"
"compress/flate"
"compress/gzip"
"fmt"
"io"
"io/ioutil"
"testing"
"time"
)
func TestCompressFlateSerial(t *testing.T) {
t.Parallel()
if err := testCompressFlate(); err != nil {
t.Fatalf("unexpected error: %s", err)
}
}
func TestCompressFlateConcurrent(t *testing.T) {
t.Parallel()
if err := testConcurrent(testCompressFlate, 10); err != nil {
t.Fatalf("unexpected error: %s", err)
}
}
func testCompressFlate() error {
return testWriter(func(w io.Writer) Writer {
zw, err := flate.NewWriter(w, flate.DefaultCompression)
if err != nil {
panic(fmt.Sprintf("BUG: unexpected error: %s", err))
}
return zw
}, func(r io.Reader) io.Reader {
return flate.NewReader(r)
})
}
func TestCompressGzipSerial(t *testing.T) {
t.Parallel()
if err := testCompressGzip(); err != nil {
t.Fatalf("unexpected error: %s", err)
}
}
func TestCompressGzipConcurrent(t *testing.T) {
t.Parallel()
if err := testConcurrent(testCompressGzip, 10); err != nil {
t.Fatalf("unexpected error: %s", err)
}
}
func testCompressGzip() error {
return testWriter(func(w io.Writer) Writer {
return gzip.NewWriter(w)
}, func(r io.Reader) io.Reader {
zr, err := gzip.NewReader(r)
if err != nil {
panic(fmt.Sprintf("BUG: cannot create gzip reader: %s", err))
}
return zr
})
}
func testWriter(newWriter NewWriterFunc, newReader func(io.Reader) io.Reader) error {
dstW := &bytes.Buffer{}
w := NewWriter(dstW, newWriter)
for i := 0; i < 5; i++ {
if err := testWriterReuse(w, dstW, newReader); err != nil {
return fmt.Errorf("unexpected error when re-using writer on iteration %d: %s", i, err)
}
dstW = &bytes.Buffer{}
w.Reset(dstW)
}
return nil
}
func testWriterReuse(w Writer, r io.Reader, newReader func(io.Reader) io.Reader) error {
wantW := &bytes.Buffer{}
mw := io.MultiWriter(w, wantW)
for i := 0; i < 30; i++ {
fmt.Fprintf(mw, "foobar %d\n", i)
if i%13 == 0 {
if err := w.Flush(); err != nil {
return fmt.Errorf("error on flush: %s", err)
}
}
}
w.Close()
zr := newReader(r)
data, err := ioutil.ReadAll(zr)
if err != nil {
return fmt.Errorf("unexpected error: %s, data=%q", err, data)
}
wantData := wantW.Bytes()
if !bytes.Equal(data, wantData) {
return fmt.Errorf("unexpected data: %q. Expecting %q", data, wantData)
}
return nil
}
func testConcurrent(testFunc func() error, concurrency int) error {
ch := make(chan error, concurrency)
for i := 0; i < concurrency; i++ {
go func() {
ch <- testFunc()
}()
}
for i := 0; i < concurrency; i++ {
select {
case err := <-ch:
if err != nil {
return fmt.Errorf("unexpected error on goroutine %d: %s", i, err)
}
case <-time.After(time.Second):
return fmt.Errorf("timeout on goroutine %d", i)
}
}
return nil
}
fasthttp-1.31.0/status.go 0000664 0000000 0000000 00000017756 14130360711 0015317 0 ustar 00root root 0000000 0000000 package fasthttp
import (
"fmt"
"strconv"
)
const (
statusMessageMin = 100
statusMessageMax = 511
)
// HTTP status codes were stolen from net/http.
const (
StatusContinue = 100 // RFC 7231, 6.2.1
StatusSwitchingProtocols = 101 // RFC 7231, 6.2.2
StatusProcessing = 102 // RFC 2518, 10.1
StatusEarlyHints = 103 // RFC 8297
StatusOK = 200 // RFC 7231, 6.3.1
StatusCreated = 201 // RFC 7231, 6.3.2
StatusAccepted = 202 // RFC 7231, 6.3.3
StatusNonAuthoritativeInfo = 203 // RFC 7231, 6.3.4
StatusNoContent = 204 // RFC 7231, 6.3.5
StatusResetContent = 205 // RFC 7231, 6.3.6
StatusPartialContent = 206 // RFC 7233, 4.1
StatusMultiStatus = 207 // RFC 4918, 11.1
StatusAlreadyReported = 208 // RFC 5842, 7.1
StatusIMUsed = 226 // RFC 3229, 10.4.1
StatusMultipleChoices = 300 // RFC 7231, 6.4.1
StatusMovedPermanently = 301 // RFC 7231, 6.4.2
StatusFound = 302 // RFC 7231, 6.4.3
StatusSeeOther = 303 // RFC 7231, 6.4.4
StatusNotModified = 304 // RFC 7232, 4.1
StatusUseProxy = 305 // RFC 7231, 6.4.5
_ = 306 // RFC 7231, 6.4.6 (Unused)
StatusTemporaryRedirect = 307 // RFC 7231, 6.4.7
StatusPermanentRedirect = 308 // RFC 7538, 3
StatusBadRequest = 400 // RFC 7231, 6.5.1
StatusUnauthorized = 401 // RFC 7235, 3.1
StatusPaymentRequired = 402 // RFC 7231, 6.5.2
StatusForbidden = 403 // RFC 7231, 6.5.3
StatusNotFound = 404 // RFC 7231, 6.5.4
StatusMethodNotAllowed = 405 // RFC 7231, 6.5.5
StatusNotAcceptable = 406 // RFC 7231, 6.5.6
StatusProxyAuthRequired = 407 // RFC 7235, 3.2
StatusRequestTimeout = 408 // RFC 7231, 6.5.7
StatusConflict = 409 // RFC 7231, 6.5.8
StatusGone = 410 // RFC 7231, 6.5.9
StatusLengthRequired = 411 // RFC 7231, 6.5.10
StatusPreconditionFailed = 412 // RFC 7232, 4.2
StatusRequestEntityTooLarge = 413 // RFC 7231, 6.5.11
StatusRequestURITooLong = 414 // RFC 7231, 6.5.12
StatusUnsupportedMediaType = 415 // RFC 7231, 6.5.13
StatusRequestedRangeNotSatisfiable = 416 // RFC 7233, 4.4
StatusExpectationFailed = 417 // RFC 7231, 6.5.14
StatusTeapot = 418 // RFC 7168, 2.3.3
StatusMisdirectedRequest = 421 // RFC 7540, 9.1.2
StatusUnprocessableEntity = 422 // RFC 4918, 11.2
StatusLocked = 423 // RFC 4918, 11.3
StatusFailedDependency = 424 // RFC 4918, 11.4
StatusUpgradeRequired = 426 // RFC 7231, 6.5.15
StatusPreconditionRequired = 428 // RFC 6585, 3
StatusTooManyRequests = 429 // RFC 6585, 4
StatusRequestHeaderFieldsTooLarge = 431 // RFC 6585, 5
StatusUnavailableForLegalReasons = 451 // RFC 7725, 3
StatusInternalServerError = 500 // RFC 7231, 6.6.1
StatusNotImplemented = 501 // RFC 7231, 6.6.2
StatusBadGateway = 502 // RFC 7231, 6.6.3
StatusServiceUnavailable = 503 // RFC 7231, 6.6.4
StatusGatewayTimeout = 504 // RFC 7231, 6.6.5
StatusHTTPVersionNotSupported = 505 // RFC 7231, 6.6.6
StatusVariantAlsoNegotiates = 506 // RFC 2295, 8.1
StatusInsufficientStorage = 507 // RFC 4918, 11.5
StatusLoopDetected = 508 // RFC 5842, 7.2
StatusNotExtended = 510 // RFC 2774, 7
StatusNetworkAuthenticationRequired = 511 // RFC 6585, 6
)
var (
statusLines = make([][]byte, statusMessageMax+1)
statusMessages = []string{
StatusContinue: "Continue",
StatusSwitchingProtocols: "Switching Protocols",
StatusProcessing: "Processing",
StatusEarlyHints: "Early Hints",
StatusOK: "OK",
StatusCreated: "Created",
StatusAccepted: "Accepted",
StatusNonAuthoritativeInfo: "Non-Authoritative Information",
StatusNoContent: "No Content",
StatusResetContent: "Reset Content",
StatusPartialContent: "Partial Content",
StatusMultiStatus: "Multi-Status",
StatusAlreadyReported: "Already Reported",
StatusIMUsed: "IM Used",
StatusMultipleChoices: "Multiple Choices",
StatusMovedPermanently: "Moved Permanently",
StatusFound: "Found",
StatusSeeOther: "See Other",
StatusNotModified: "Not Modified",
StatusUseProxy: "Use Proxy",
StatusTemporaryRedirect: "Temporary Redirect",
StatusPermanentRedirect: "Permanent Redirect",
StatusBadRequest: "Bad Request",
StatusUnauthorized: "Unauthorized",
StatusPaymentRequired: "Payment Required",
StatusForbidden: "Forbidden",
StatusNotFound: "Not Found",
StatusMethodNotAllowed: "Method Not Allowed",
StatusNotAcceptable: "Not Acceptable",
StatusProxyAuthRequired: "Proxy Authentication Required",
StatusRequestTimeout: "Request Timeout",
StatusConflict: "Conflict",
StatusGone: "Gone",
StatusLengthRequired: "Length Required",
StatusPreconditionFailed: "Precondition Failed",
StatusRequestEntityTooLarge: "Request Entity Too Large",
StatusRequestURITooLong: "Request URI Too Long",
StatusUnsupportedMediaType: "Unsupported Media Type",
StatusRequestedRangeNotSatisfiable: "Requested Range Not Satisfiable",
StatusExpectationFailed: "Expectation Failed",
StatusTeapot: "I'm a teapot",
StatusMisdirectedRequest: "Misdirected Request",
StatusUnprocessableEntity: "Unprocessable Entity",
StatusLocked: "Locked",
StatusFailedDependency: "Failed Dependency",
StatusUpgradeRequired: "Upgrade Required",
StatusPreconditionRequired: "Precondition Required",
StatusTooManyRequests: "Too Many Requests",
StatusRequestHeaderFieldsTooLarge: "Request Header Fields Too Large",
StatusUnavailableForLegalReasons: "Unavailable For Legal Reasons",
StatusInternalServerError: "Internal Server Error",
StatusNotImplemented: "Not Implemented",
StatusBadGateway: "Bad Gateway",
StatusServiceUnavailable: "Service Unavailable",
StatusGatewayTimeout: "Gateway Timeout",
StatusHTTPVersionNotSupported: "HTTP Version Not Supported",
StatusVariantAlsoNegotiates: "Variant Also Negotiates",
StatusInsufficientStorage: "Insufficient Storage",
StatusLoopDetected: "Loop Detected",
StatusNotExtended: "Not Extended",
StatusNetworkAuthenticationRequired: "Network Authentication Required",
}
)
// StatusMessage returns HTTP status message for the given status code.
func StatusMessage(statusCode int) string {
if statusCode < statusMessageMin || statusCode > statusMessageMax {
return "Unknown Status Code"
}
s := statusMessages[statusCode]
if s == "" {
s = "Unknown Status Code"
}
return s
}
func init() {
// Fill all valid status lines
for i := 0; i < len(statusLines); i++ {
statusLines[i] = []byte(fmt.Sprintf("HTTP/1.1 %d %s\r\n", i, StatusMessage(i)))
}
}
func statusLine(statusCode int) []byte {
if statusCode < 0 || statusCode > statusMessageMax {
return invalidStatusLine(statusCode)
}
return statusLines[statusCode]
}
func invalidStatusLine(statusCode int) []byte {
statusText := StatusMessage(statusCode)
// xxx placeholder of status code
var line = make([]byte, 0, len("HTTP/1.1 xxx \r\n")+len(statusText))
line = append(line, "HTTP/1.1 "...)
line = strconv.AppendInt(line, int64(statusCode), 10)
line = append(line, ' ')
line = append(line, statusText...)
line = append(line, "\r\n"...)
return line
}
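// Illustrative sketch, not part of the original archive: StatusMessage maps a
// status code to its canonical reason phrase and falls back to
// "Unknown Status Code" for anything outside the 100..511 table, which is the
// same fallback statusLine relies on. A hypothetical external example file is
// assumed here.
package fasthttp_test

import (
	"fmt"

	"github.com/valyala/fasthttp"
)

func ExampleStatusMessage() {
	fmt.Println(fasthttp.StatusMessage(fasthttp.StatusTooManyRequests))
	fmt.Println(fasthttp.StatusMessage(520)) // outside the known range
	// Output:
	// Too Many Requests
	// Unknown Status Code
}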
fasthttp-1.31.0/status_test.go 0000664 0000000 0000000 00000001346 14130360711 0016342 0 ustar 00root root 0000000 0000000 package fasthttp
import (
"bytes"
"testing"
)
func TestStatusLine(t *testing.T) {
t.Parallel()
testStatusLine(t, -1, []byte("HTTP/1.1 -1 Unknown Status Code\r\n"))
testStatusLine(t, 99, []byte("HTTP/1.1 99 Unknown Status Code\r\n"))
testStatusLine(t, 200, []byte("HTTP/1.1 200 OK\r\n"))
testStatusLine(t, 512, []byte("HTTP/1.1 512 Unknown Status Code\r\n"))
testStatusLine(t, 512, []byte("HTTP/1.1 512 Unknown Status Code\r\n"))
testStatusLine(t, 520, []byte("HTTP/1.1 520 Unknown Status Code\r\n"))
}
func testStatusLine(t *testing.T, statusCode int, expected []byte) {
line := statusLine(statusCode)
if !bytes.Equal(expected, line) {
t.Fatalf("unexpected status line %s. Expecting %s", string(line), string(expected))
}
}
fasthttp-1.31.0/status_timing_test.go 0000664 0000000 0000000 00000001302 14130360711 0017701 0 ustar 00root root 0000000 0000000 package fasthttp
import (
"bytes"
"testing"
)
func BenchmarkStatusLine99(b *testing.B) {
benchmarkStatusLine(b, 99, []byte("HTTP/1.1 99 Unknown Status Code\r\n"))
}
func BenchmarkStatusLine200(b *testing.B) {
benchmarkStatusLine(b, 200, []byte("HTTP/1.1 200 OK\r\n"))
}
func BenchmarkStatusLine512(b *testing.B) {
benchmarkStatusLine(b, 512, []byte("HTTP/1.1 512 Unknown Status Code\r\n"))
}
func benchmarkStatusLine(b *testing.B, statusCode int, expected []byte) {
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
line := statusLine(statusCode)
if !bytes.Equal(expected, line) {
b.Fatalf("unexpected status line %s. Expecting %s", string(line), string(expected))
}
}
})
}
fasthttp-1.31.0/stream.go 0000664 0000000 0000000 00000002223 14130360711 0015246 0 ustar 00root root 0000000 0000000 package fasthttp
import (
"bufio"
"io"
"sync"
"github.com/valyala/fasthttp/fasthttputil"
)
// StreamWriter must write data to w.
//
// Usually StreamWriter writes data to w in a loop (aka 'data streaming').
//
// StreamWriter must return immediately if w returns error.
//
// Since the written data is buffered, do not forget to call w.Flush
// when the data must be propagated to the reader.
type StreamWriter func(w *bufio.Writer)
// NewStreamReader returns a reader, which replays all the data generated by sw.
//
// The returned reader may be passed to Response.SetBodyStream.
//
// Close must be called on the returned reader after all the required data
// has been read. Otherwise a goroutine leak may occur.
//
// See also Response.SetBodyStreamWriter.
func NewStreamReader(sw StreamWriter) io.ReadCloser {
pc := fasthttputil.NewPipeConns()
pw := pc.Conn1()
pr := pc.Conn2()
var bw *bufio.Writer
v := streamWriterBufPool.Get()
if v == nil {
bw = bufio.NewWriter(pw)
} else {
bw = v.(*bufio.Writer)
bw.Reset(pw)
}
go func() {
sw(bw)
bw.Flush()
pw.Close()
streamWriterBufPool.Put(bw)
}()
return pr
}
var streamWriterBufPool sync.Pool
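// Illustrative usage sketch, not part of the original archive: NewStreamReader
// turns a StreamWriter into an io.ReadCloser that can be consumed directly or
// handed to Response.SetBodyStream. The hypothetical external example file is
// an assumption; the reading pattern mirrors the tests that follow.
package fasthttp_test

import (
	"bufio"
	"fmt"
	"io/ioutil"
	"log"

	"github.com/valyala/fasthttp"
)

func ExampleNewStreamReader() {
	r := fasthttp.NewStreamReader(func(w *bufio.Writer) {
		fmt.Fprintf(w, "line #1\n")
		fmt.Fprintf(w, "line #2\n")
		// No explicit Flush is required here: NewStreamReader flushes the
		// buffered writer and closes the pipe once this function returns.
	})
	defer r.Close()

	data, err := ioutil.ReadAll(r)
	if err != nil {
		log.Fatal(err)
	}
	fmt.Print(string(data))
	// Output:
	// line #1
	// line #2
}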
fasthttp-1.31.0/stream_test.go 0000664 0000000 0000000 00000004253 14130360711 0016312 0 ustar 00root root 0000000 0000000 package fasthttp
import (
"bufio"
"fmt"
"io"
"io/ioutil"
"testing"
"time"
)
func TestNewStreamReader(t *testing.T) {
t.Parallel()
ch := make(chan struct{})
r := NewStreamReader(func(w *bufio.Writer) {
fmt.Fprintf(w, "Hello, world\n")
fmt.Fprintf(w, "Line #2\n")
close(ch)
})
data, err := ioutil.ReadAll(r)
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
expectedData := "Hello, world\nLine #2\n"
if string(data) != expectedData {
t.Fatalf("unexpected data %q. Expecting %q", data, expectedData)
}
if err = r.Close(); err != nil {
t.Fatalf("unexpected error")
}
select {
case <-ch:
case <-time.After(time.Second):
t.Fatalf("timeout")
}
}
func TestStreamReaderClose(t *testing.T) {
t.Parallel()
firstLine := "the first line must pass"
ch := make(chan error, 1)
r := NewStreamReader(func(w *bufio.Writer) {
fmt.Fprintf(w, "%s", firstLine)
if err := w.Flush(); err != nil {
ch <- fmt.Errorf("unexpected error on first flush: %s", err)
return
}
data := createFixedBody(4000)
for i := 0; i < 100; i++ {
w.Write(data) //nolint:errcheck
}
if err := w.Flush(); err == nil {
ch <- fmt.Errorf("expecting error on the second flush")
}
ch <- nil
})
buf := make([]byte, len(firstLine))
n, err := io.ReadFull(r, buf)
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
if n != len(buf) {
t.Fatalf("unexpected number of bytes read: %d. Expecting %d", n, len(buf))
}
if string(buf) != firstLine {
t.Fatalf("unexpected result: %q. Expecting %q", buf, firstLine)
}
if err := r.Close(); err != nil {
t.Fatalf("unexpected error: %s", err)
}
select {
case err := <-ch:
if err != nil {
t.Fatalf("error returned from stream reader: %s", err)
}
case <-time.After(time.Second):
t.Fatalf("timeout when waiting for stream reader")
}
// read trailing data
go func() {
if _, err := ioutil.ReadAll(r); err != nil {
ch <- fmt.Errorf("unexpected error when reading trailing data: %s", err)
return
}
ch <- nil
}()
select {
case err := <-ch:
if err != nil {
t.Fatalf("error returned when reading tail data: %s", err)
}
case <-time.After(time.Second):
t.Fatalf("timeout when reading tail data")
}
}
fasthttp-1.31.0/stream_timing_test.go 0000664 0000000 0000000 00000002527 14130360711 0017663 0 ustar 00root root 0000000 0000000 package fasthttp
import (
"bufio"
"io"
"testing"
"time"
)
func BenchmarkStreamReader1(b *testing.B) {
benchmarkStreamReader(b, 1)
}
func BenchmarkStreamReader10(b *testing.B) {
benchmarkStreamReader(b, 10)
}
func BenchmarkStreamReader100(b *testing.B) {
benchmarkStreamReader(b, 100)
}
func BenchmarkStreamReader1K(b *testing.B) {
benchmarkStreamReader(b, 1000)
}
func BenchmarkStreamReader10K(b *testing.B) {
benchmarkStreamReader(b, 10000)
}
func benchmarkStreamReader(b *testing.B, size int) {
src := createFixedBody(size)
b.SetBytes(int64(size))
b.RunParallel(func(pb *testing.PB) {
dst := make([]byte, size)
ch := make(chan error, 1)
sr := NewStreamReader(func(w *bufio.Writer) {
for pb.Next() {
if _, err := w.Write(src); err != nil {
ch <- err
return
}
if err := w.Flush(); err != nil {
ch <- err
return
}
}
ch <- nil
})
for {
if _, err := sr.Read(dst); err != nil {
if err == io.EOF {
break
}
b.Fatalf("unexpected error when reading from stream reader: %s", err)
}
}
if err := sr.Close(); err != nil {
b.Fatalf("unexpected error when closing stream reader: %s", err)
}
select {
case err := <-ch:
if err != nil {
b.Fatalf("unexpected error from stream reader: %s", err)
}
case <-time.After(time.Second):
b.Fatalf("timeout")
}
})
}
fasthttp-1.31.0/streaming.go 0000664 0000000 0000000 00000004120 14130360711 0015742 0 ustar 00root root 0000000 0000000 package fasthttp
import (
"bufio"
"bytes"
"io"
"sync"
"github.com/valyala/bytebufferpool"
)
type requestStream struct {
prefetchedBytes *bytes.Reader
reader *bufio.Reader
totalBytesRead int
contentLength int
chunkLeft int
}
func (rs *requestStream) Read(p []byte) (int, error) {
var (
n int
err error
)
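// A contentLength of -1 means the request body uses chunked transfer
// encoding: read it chunk by chunk, consuming the CRLF after every chunk.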
if rs.contentLength == -1 {
if rs.chunkLeft == 0 {
chunkSize, err := parseChunkSize(rs.reader)
if err != nil {
return 0, err
}
if chunkSize == 0 {
err = readCrLf(rs.reader)
if err == nil {
err = io.EOF
}
return 0, err
}
rs.chunkLeft = chunkSize
}
bytesToRead := len(p)
if rs.chunkLeft < len(p) {
bytesToRead = rs.chunkLeft
}
n, err = rs.reader.Read(p[:bytesToRead])
rs.totalBytesRead += n
rs.chunkLeft -= n
if err == io.EOF {
err = io.ErrUnexpectedEOF
}
if err == nil && rs.chunkLeft == 0 {
err = readCrLf(rs.reader)
}
return n, err
}
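// Fixed Content-Length body: report EOF once it has been fully consumed,
// serve any prefetched bytes first, then read the remainder from the wire.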
if rs.totalBytesRead == rs.contentLength {
return 0, io.EOF
}
prefetchedSize := int(rs.prefetchedBytes.Size())
if prefetchedSize > rs.totalBytesRead {
left := prefetchedSize - rs.totalBytesRead
if len(p) > left {
p = p[:left]
}
n, err := rs.prefetchedBytes.Read(p)
rs.totalBytesRead += n
if n == rs.contentLength {
return n, io.EOF
}
return n, err
} else {
left := rs.contentLength - rs.totalBytesRead
if len(p) > left {
p = p[:left]
}
n, err = rs.reader.Read(p)
rs.totalBytesRead += n
if err != nil {
return n, err
}
}
if rs.totalBytesRead == rs.contentLength {
err = io.EOF
}
return n, err
}
func acquireRequestStream(b *bytebufferpool.ByteBuffer, r *bufio.Reader, contentLength int) *requestStream {
rs := requestStreamPool.Get().(*requestStream)
rs.prefetchedBytes = bytes.NewReader(b.B)
rs.reader = r
rs.contentLength = contentLength
return rs
}
func releaseRequestStream(rs *requestStream) {
rs.prefetchedBytes = nil
rs.totalBytesRead = 0
rs.chunkLeft = 0
rs.reader = nil
requestStreamPool.Put(rs)
}
var requestStreamPool = sync.Pool{
New: func() interface{} {
return &requestStream{}
},
}
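// Illustrative sketch, not part of the original archive: requestStream backs
// RequestCtx.RequestBodyStream when the server is configured with
// StreamRequestBody, letting a handler read large bodies without buffering
// them fully in memory. The hypothetical example below mirrors the server
// configuration used by the tests that follow; the listen address is made up.
package fasthttp_test

import (
	"io"
	"io/ioutil"
	"log"

	"github.com/valyala/fasthttp"
)

func ExampleServer_streamRequestBody() {
	s := &fasthttp.Server{
		// Stream request bodies instead of reading them up front.
		StreamRequestBody: true,
		Handler: func(ctx *fasthttp.RequestCtx) {
			// Discard the body as it arrives; a real handler could parse it
			// incrementally instead.
			n, err := io.Copy(ioutil.Discard, ctx.RequestBodyStream())
			if err != nil {
				ctx.Error("cannot read body", fasthttp.StatusInternalServerError)
				return
			}
			log.Printf("read %d body bytes", n)
		},
	}
	log.Fatal(s.ListenAndServe(":8080"))
}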
fasthttp-1.31.0/streaming_test.go 0000664 0000000 0000000 00000010143 14130360711 0017003 0 ustar 00root root 0000000 0000000 package fasthttp
import (
"bufio"
"bytes"
"io/ioutil"
"sync"
"testing"
"time"
"github.com/valyala/fasthttp/fasthttputil"
)
func TestStreamingPipeline(t *testing.T) {
t.Parallel()
reqS := `POST /one HTTP/1.1
Host: example.com
Content-Length: 10
aaaaaaaaaa
POST /two HTTP/1.1
Host: example.com
Content-Length: 10
aaaaaaaaaa`
ln := fasthttputil.NewInmemoryListener()
s := &Server{
StreamRequestBody: true,
Handler: func(ctx *RequestCtx) {
body := ""
expected := "aaaaaaaaaa"
if string(ctx.Path()) == "/one" {
body = string(ctx.PostBody())
} else {
all, err := ioutil.ReadAll(ctx.RequestBodyStream())
if err != nil {
t.Error(err)
}
body = string(all)
}
if body != expected {
t.Errorf("expected %q got %q", expected, body)
}
},
}
ch := make(chan struct{})
go func() {
if err := s.Serve(ln); err != nil {
t.Errorf("unexpected error: %s", err)
}
close(ch)
}()
conn, err := ln.Dial()
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
if _, err = conn.Write([]byte(reqS)); err != nil {
t.Fatalf("unexpected error: %s", err)
}
var resp Response
br := bufio.NewReader(conn)
respCh := make(chan struct{})
go func() {
if err := resp.Read(br); err != nil {
t.Errorf("error when reading response: %s", err)
}
if resp.StatusCode() != StatusOK {
t.Errorf("unexpected status code %d. Expecting %d", resp.StatusCode(), StatusOK)
}
if err := resp.Read(br); err != nil {
t.Errorf("error when reading response: %s", err)
}
if resp.StatusCode() != StatusOK {
t.Errorf("unexpected status code %d. Expecting %d", resp.StatusCode(), StatusOK)
}
close(respCh)
}()
select {
case <-respCh:
case <-time.After(time.Second):
t.Fatal("timeout")
}
if err := ln.Close(); err != nil {
t.Fatalf("error when closing listener: %s", err)
}
select {
case <-ch:
case <-time.After(time.Second):
t.Fatal("timeout when waiting for the server to stop")
}
}
func getChunkedTestEnv(t testing.TB) (*fasthttputil.InmemoryListener, []byte) {
body := createFixedBody(128 * 1024)
chunkedBody := createChunkedBody(body)
testHandler := func(ctx *RequestCtx) {
bodyBytes, err := ioutil.ReadAll(ctx.RequestBodyStream())
if err != nil {
t.Logf("ioutil read returned err=%s", err)
t.Error("unexpected error while reading request body stream")
}
if !bytes.Equal(body, bodyBytes) {
t.Errorf("unexpected request body, expected %q, got %q", body, bodyBytes)
}
}
s := &Server{
Handler: testHandler,
StreamRequestBody: true,
MaxRequestBodySize: 1, // easier to test with small limit
}
ln := fasthttputil.NewInmemoryListener()
go func() {
err := s.Serve(ln)
if err != nil {
t.Errorf("could not serve listener: %s", err)
}
}()
req := Request{}
req.SetHost("localhost")
req.Header.SetMethod("POST")
req.Header.Set("transfer-encoding", "chunked")
req.Header.SetContentLength(-1)
formattedRequest := req.Header.Header()
formattedRequest = append(formattedRequest, chunkedBody...)
return ln, formattedRequest
}
func TestRequestStream(t *testing.T) {
t.Parallel()
ln, formattedRequest := getChunkedTestEnv(t)
c, err := ln.Dial()
if err != nil {
t.Errorf("unexpected error while dialing: %s", err)
}
if _, err = c.Write(formattedRequest); err != nil {
t.Errorf("unexpected error while writing request: %s", err)
}
br := bufio.NewReader(c)
var respH ResponseHeader
if err = respH.Read(br); err != nil {
t.Errorf("unexpected error: %s", err)
}
}
func BenchmarkRequestStreamE2E(b *testing.B) {
ln, formattedRequest := getChunkedTestEnv(b)
wg := &sync.WaitGroup{}
wg.Add(4)
for i := 0; i < 4; i++ {
go func(wg *sync.WaitGroup) {
for i := 0; i < b.N/4; i++ {
c, err := ln.Dial()
if err != nil {
b.Errorf("unexpected error while dialing: %s", err)
}
if _, err = c.Write(formattedRequest); err != nil {
b.Errorf("unexpected error while writing request: %s", err)
}
br := bufio.NewReaderSize(c, 128)
var respH ResponseHeader
if err = respH.Read(br); err != nil {
b.Errorf("unexpected error: %s", err)
}
c.Close()
}
wg.Done()
}(wg)
}
wg.Wait()
}
fasthttp-1.31.0/strings.go 0000664 0000000 0000000 00000006157 14130360711 0015456 0 ustar 00root root 0000000 0000000 package fasthttp
var (
defaultServerName = []byte("fasthttp")
defaultUserAgent = []byte("fasthttp")
defaultContentType = []byte("text/plain; charset=utf-8")
)
var (
strSlash = []byte("/")
strSlashSlash = []byte("//")
strSlashDotDot = []byte("/..")
strSlashDotSlash = []byte("/./")
strSlashDotDotSlash = []byte("/../")
strCRLF = []byte("\r\n")
strHTTP = []byte("http")
strHTTPS = []byte("https")
strHTTP10 = []byte("HTTP/1.0")
strHTTP11 = []byte("HTTP/1.1")
strColon = []byte(":")
strColonSlashSlash = []byte("://")
strColonSpace = []byte(": ")
strGMT = []byte("GMT")
strResponseContinue = []byte("HTTP/1.1 100 Continue\r\n\r\n")
strExpect = []byte(HeaderExpect)
strConnection = []byte(HeaderConnection)
strContentLength = []byte(HeaderContentLength)
strContentType = []byte(HeaderContentType)
strDate = []byte(HeaderDate)
strHost = []byte(HeaderHost)
strReferer = []byte(HeaderReferer)
strServer = []byte(HeaderServer)
strTransferEncoding = []byte(HeaderTransferEncoding)
strContentEncoding = []byte(HeaderContentEncoding)
strAcceptEncoding = []byte(HeaderAcceptEncoding)
strUserAgent = []byte(HeaderUserAgent)
strCookie = []byte(HeaderCookie)
strSetCookie = []byte(HeaderSetCookie)
strLocation = []byte(HeaderLocation)
strIfModifiedSince = []byte(HeaderIfModifiedSince)
strLastModified = []byte(HeaderLastModified)
strAcceptRanges = []byte(HeaderAcceptRanges)
strRange = []byte(HeaderRange)
strContentRange = []byte(HeaderContentRange)
strAuthorization = []byte(HeaderAuthorization)
strCookieExpires = []byte("expires")
strCookieDomain = []byte("domain")
strCookiePath = []byte("path")
strCookieHTTPOnly = []byte("HttpOnly")
strCookieSecure = []byte("secure")
strCookieMaxAge = []byte("max-age")
strCookieSameSite = []byte("SameSite")
strCookieSameSiteLax = []byte("Lax")
strCookieSameSiteStrict = []byte("Strict")
strCookieSameSiteNone = []byte("None")
strClose = []byte("close")
strGzip = []byte("gzip")
strBr = []byte("br")
strDeflate = []byte("deflate")
strKeepAlive = []byte("keep-alive")
strUpgrade = []byte("Upgrade")
strChunked = []byte("chunked")
strIdentity = []byte("identity")
str100Continue = []byte("100-continue")
strPostArgsContentType = []byte("application/x-www-form-urlencoded")
strDefaultContentType = []byte("application/octet-stream")
strMultipartFormData = []byte("multipart/form-data")
strBoundary = []byte("boundary")
strBytes = []byte("bytes")
strBasicSpace = []byte("Basic ")
strApplicationSlash = []byte("application/")
strImageSVG = []byte("image/svg")
strImageIcon = []byte("image/x-icon")
strFontSlash = []byte("font/")
strMultipartSlash = []byte("multipart/")
strTextSlash = []byte("text/")
)
fasthttp-1.31.0/tcpdialer.go 0000664 0000000 0000000 00000032276 14130360711 0015735 0 ustar 00root root 0000000 0000000 package fasthttp
import (
"context"
"errors"
"net"
"strconv"
"sync"
"sync/atomic"
"time"
)
// Dial dials the given TCP addr using tcp4.
//
// This function has the following additional features compared to net.Dial:
//
// * It reduces load on DNS resolver by caching resolved TCP addresses
// for DNSCacheDuration.
// * It dials all the resolved TCP addresses in round-robin manner until
// connection is established. This may be useful if certain addresses
// are temporarily unreachable.
// * It returns ErrDialTimeout if connection cannot be established during
// DefaultDialTimeout seconds. Use DialTimeout for customizing dial timeout.
//
// This dialer is intended for custom code wrapping before passing
// to Client.Dial or HostClient.Dial.
//
// For instance, per-host counters and/or limits may be implemented
// by such wrappers.
//
// The addr passed to the function must contain port. Example addr values:
//
// * foobar.baz:443
// * foo.bar:80
// * aaa.com:8080
func Dial(addr string) (net.Conn, error) {
return defaultDialer.Dial(addr)
}
// DialTimeout dials the given TCP addr using tcp4 using the given timeout.
//
// This function has the following additional features compared to net.Dial:
//
// * It reduces load on DNS resolver by caching resolved TCP addresses
// for DNSCacheDuration.
// * It dials all the resolved TCP addresses in round-robin manner until
// connection is established. This may be useful if certain addresses
// are temporarily unreachable.
//
// This dialer is intended for custom code wrapping before passing
// to Client.Dial or HostClient.Dial.
//
// For instance, per-host counters and/or limits may be implemented
// by such wrappers.
//
// The addr passed to the function must contain port. Example addr values:
//
// * foobar.baz:443
// * foo.bar:80
// * aaa.com:8080
func DialTimeout(addr string, timeout time.Duration) (net.Conn, error) {
return defaultDialer.DialTimeout(addr, timeout)
}
// DialDualStack dials the given TCP addr using both tcp4 and tcp6.
//
// This function has the following additional features compared to net.Dial:
//
// * It reduces load on DNS resolver by caching resolved TCP addresses
// for DNSCacheDuration.
// * It dials all the resolved TCP addresses in round-robin manner until
// connection is established. This may be useful if certain addresses
// are temporarily unreachable.
// * It returns ErrDialTimeout if connection cannot be established during
// DefaultDialTimeout seconds. Use DialDualStackTimeout for custom dial
// timeout.
//
// This dialer is intended for custom code wrapping before passing
// to Client.Dial or HostClient.Dial.
//
// For instance, per-host counters and/or limits may be implemented
// by such wrappers.
//
// The addr passed to the function must contain port. Example addr values:
//
// * foobar.baz:443
// * foo.bar:80
// * aaa.com:8080
func DialDualStack(addr string) (net.Conn, error) {
return defaultDialer.DialDualStack(addr)
}
// DialDualStackTimeout dials the given TCP addr using both tcp4 and tcp6
// using the given timeout.
//
// This function has the following additional features compared to net.Dial:
//
// * It reduces load on DNS resolver by caching resolved TCP addresses
// for DNSCacheDuration.
// * It dials all the resolved TCP addresses in round-robin manner until
// connection is established. This may be useful if certain addresses
// are temporarily unreachable.
//
// This dialer is intended for custom code wrapping before passing
// to Client.Dial or HostClient.Dial.
//
// For instance, per-host counters and/or limits may be implemented
// by such wrappers.
//
// The addr passed to the function must contain port. Example addr values:
//
// * foobar.baz:443
// * foo.bar:80
// * aaa.com:8080
func DialDualStackTimeout(addr string, timeout time.Duration) (net.Conn, error) {
return defaultDialer.DialDualStackTimeout(addr, timeout)
}
var (
defaultDialer = &TCPDialer{Concurrency: 1000}
)
// Resolver represents interface of the tcp resolver.
type Resolver interface {
LookupIPAddr(context.Context, string) (names []net.IPAddr, err error)
}
// TCPDialer contains options to control a group of Dial calls.
type TCPDialer struct {
// Concurrency controls the maximum number of concurrent Dials
// that can be performed using this object.
// Setting this to 0 means unlimited.
//
// WARNING: This can only be changed before the first Dial.
// Changes made after the first Dial will not affect anything.
Concurrency int
// LocalAddr is the local address to use when dialing an
// address.
// If nil, a local address is automatically chosen.
LocalAddr *net.TCPAddr
// This may be used to override DNS resolving policy, like this:
// var dialer = &fasthttp.TCPDialer{
// Resolver: &net.Resolver{
// PreferGo: true,
// StrictErrors: false,
// Dial: func (ctx context.Context, network, address string) (net.Conn, error) {
// d := net.Dialer{}
// return d.DialContext(ctx, "udp", "8.8.8.8:53")
// },
// },
// }
Resolver Resolver
// DNSCacheDuration may be used to override the default DNS cache duration (DefaultDNSCacheDuration)
DNSCacheDuration time.Duration
tcpAddrsMap sync.Map
concurrencyCh chan struct{}
once sync.Once
}
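// Illustrative usage sketch, not part of the original archive: a TCPDialer is
// typically configured once and plugged into a client's Dial hook, as the
// comments above suggest. The commented example below assumes the Client type
// from this repository and a hypothetical external caller; the concurrency
// limit, cache duration and timeout values are arbitrary.
//
//	var dialer = &fasthttp.TCPDialer{
//		Concurrency:      1000,            // cap concurrent dials
//		DNSCacheDuration: 5 * time.Minute, // keep resolved addresses longer
//	}
//
//	var client = &fasthttp.Client{
//		// Route all outgoing connections through the shared dialer.
//		Dial: func(addr string) (net.Conn, error) {
//			return dialer.DialTimeout(addr, 3*time.Second)
//		},
//	}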
// Dial dials the given TCP addr using tcp4.
//
// This function has the following additional features compared to net.Dial:
//
// * It reduces load on DNS resolver by caching resolved TCP addresses
// for DNSCacheDuration.
// * It dials all the resolved TCP addresses in round-robin manner until
// connection is established. This may be useful if certain addresses
// are temporarily unreachable.
// * It returns ErrDialTimeout if connection cannot be established during
// DefaultDialTimeout seconds. Use DialTimeout for customizing dial timeout.
//
// This dialer is intended for custom code wrapping before passing
// to Client.Dial or HostClient.Dial.
//
// For instance, per-host counters and/or limits may be implemented
// by such wrappers.
//
// The addr passed to the function must contain port. Example addr values:
//
// * foobar.baz:443
// * foo.bar:80
// * aaa.com:8080
func (d *TCPDialer) Dial(addr string) (net.Conn, error) {
return d.dial(addr, false, DefaultDialTimeout)
}
// DialTimeout dials the given TCP addr using tcp4 using the given timeout.
//
// This function has the following additional features compared to net.Dial:
//
// * It reduces load on DNS resolver by caching resolved TCP addresses
// for DNSCacheDuration.
// * It dials all the resolved TCP addresses in round-robin manner until
// connection is established. This may be useful if certain addresses
// are temporarily unreachable.
//
// This dialer is intended for custom code wrapping before passing
// to Client.Dial or HostClient.Dial.
//
// For instance, per-host counters and/or limits may be implemented
// by such wrappers.
//
// The addr passed to the function must contain port. Example addr values:
//
// * foobar.baz:443
// * foo.bar:80
// * aaa.com:8080
func (d *TCPDialer) DialTimeout(addr string, timeout time.Duration) (net.Conn, error) {
return d.dial(addr, false, timeout)
}
// DialDualStack dials the given TCP addr using both tcp4 and tcp6.
//
// This function has the following additional features compared to net.Dial:
//
// * It reduces load on DNS resolver by caching resolved TCP addresses
// for DNSCacheDuration.
// * It dials all the resolved TCP addresses in round-robin manner until
// connection is established. This may be useful if certain addresses
// are temporarily unreachable.
// * It returns ErrDialTimeout if connection cannot be established during
// DefaultDialTimeout seconds. Use DialDualStackTimeout for custom dial
// timeout.
//
// This dialer is intended for custom code wrapping before passing
// to Client.Dial or HostClient.Dial.
//
// For instance, per-host counters and/or limits may be implemented
// by such wrappers.
//
// The addr passed to the function must contain port. Example addr values:
//
// * foobar.baz:443
// * foo.bar:80
// * aaa.com:8080
func (d *TCPDialer) DialDualStack(addr string) (net.Conn, error) {
return d.dial(addr, true, DefaultDialTimeout)
}
// DialDualStackTimeout dials the given TCP addr using both tcp4 and tcp6
// using the given timeout.
//
// This function has the following additional features compared to net.Dial:
//
// * It reduces load on DNS resolver by caching resolved TCP addresses
// for DNSCacheDuration.
// * It dials all the resolved TCP addresses in round-robin manner until
// connection is established. This may be useful if certain addresses
// are temporarily unreachable.
//
// This dialer is intended for custom code wrapping before passing
// to Client.Dial or HostClient.Dial.
//
// For instance, per-host counters and/or limits may be implemented
// by such wrappers.
//
// The addr passed to the function must contain port. Example addr values:
//
// * foobar.baz:443
// * foo.bar:80
// * aaa.com:8080
func (d *TCPDialer) DialDualStackTimeout(addr string, timeout time.Duration) (net.Conn, error) {
return d.dial(addr, true, timeout)
}
func (d *TCPDialer) dial(addr string, dualStack bool, timeout time.Duration) (net.Conn, error) {
d.once.Do(func() {
if d.Concurrency > 0 {
d.concurrencyCh = make(chan struct{}, d.Concurrency)
}
if d.DNSCacheDuration == 0 {
d.DNSCacheDuration = DefaultDNSCacheDuration
}
go d.tcpAddrsClean()
})
addrs, idx, err := d.getTCPAddrs(addr, dualStack)
if err != nil {
return nil, err
}
network := "tcp4"
if dualStack {
network = "tcp"
}
var conn net.Conn
n := uint32(len(addrs))
deadline := time.Now().Add(timeout)
for n > 0 {
conn, err = d.tryDial(network, &addrs[idx%n], deadline, d.concurrencyCh)
if err == nil {
return conn, nil
}
if err == ErrDialTimeout {
return nil, err
}
idx++
n--
}
return nil, err
}
func (d *TCPDialer) tryDial(network string, addr *net.TCPAddr, deadline time.Time, concurrencyCh chan struct{}) (net.Conn, error) {
timeout := -time.Since(deadline)
if timeout <= 0 {
return nil, ErrDialTimeout
}
if concurrencyCh != nil {
select {
case concurrencyCh <- struct{}{}:
default:
tc := AcquireTimer(timeout)
isTimeout := false
select {
case concurrencyCh <- struct{}{}:
case <-tc.C:
isTimeout = true
}
ReleaseTimer(tc)
if isTimeout {
return nil, ErrDialTimeout
}
}
defer func() { <-concurrencyCh }()
}
dialer := net.Dialer{}
if d.LocalAddr != nil {
dialer.LocalAddr = d.LocalAddr
}
ctx, cancelCtx := context.WithDeadline(context.Background(), deadline)
defer cancelCtx()
conn, err := dialer.DialContext(ctx, network, addr.String())
if err != nil && ctx.Err() == context.DeadlineExceeded {
return nil, ErrDialTimeout
}
return conn, err
}
// ErrDialTimeout is returned when TCP dialing is timed out.
var ErrDialTimeout = errors.New("dialing to the given TCP address timed out")
// DefaultDialTimeout is timeout used by Dial and DialDualStack
// for establishing TCP connections.
const DefaultDialTimeout = 3 * time.Second
type tcpAddrEntry struct {
addrs []net.TCPAddr
addrsIdx uint32
resolveTime time.Time
pending bool
}
// DefaultDNSCacheDuration is the duration for caching resolved TCP addresses
// by Dial* functions.
const DefaultDNSCacheDuration = time.Minute
func (d *TCPDialer) tcpAddrsClean() {
expireDuration := 2 * d.DNSCacheDuration
for {
time.Sleep(time.Second)
t := time.Now()
d.tcpAddrsMap.Range(func(k, v interface{}) bool {
if e, ok := v.(*tcpAddrEntry); ok && t.Sub(e.resolveTime) > expireDuration {
d.tcpAddrsMap.Delete(k)
}
return true
})
}
}
func (d *TCPDialer) getTCPAddrs(addr string, dualStack bool) ([]net.TCPAddr, uint32, error) {
item, exist := d.tcpAddrsMap.Load(addr)
e, ok := item.(*tcpAddrEntry)
if exist && ok && e != nil && !e.pending && time.Since(e.resolveTime) > d.DNSCacheDuration {
e.pending = true
e = nil
}
if e == nil {
addrs, err := resolveTCPAddrs(addr, dualStack, d.Resolver)
if err != nil {
item, exist := d.tcpAddrsMap.Load(addr)
e, ok = item.(*tcpAddrEntry)
if exist && ok && e != nil && e.pending {
e.pending = false
}
return nil, 0, err
}
e = &tcpAddrEntry{
addrs: addrs,
resolveTime: time.Now(),
}
d.tcpAddrsMap.Store(addr, e)
}
idx := atomic.AddUint32(&e.addrsIdx, 1)
return e.addrs, idx, nil
}
func resolveTCPAddrs(addr string, dualStack bool, resolver Resolver) ([]net.TCPAddr, error) {
host, portS, err := net.SplitHostPort(addr)
if err != nil {
return nil, err
}
port, err := strconv.Atoi(portS)
if err != nil {
return nil, err
}
if resolver == nil {
resolver = net.DefaultResolver
}
ctx := context.Background()
ipaddrs, err := resolver.LookupIPAddr(ctx, host)
if err != nil {
return nil, err
}
n := len(ipaddrs)
addrs := make([]net.TCPAddr, 0, n)
for i := 0; i < n; i++ {
ip := ipaddrs[i]
if !dualStack && ip.IP.To4() == nil {
continue
}
addrs = append(addrs, net.TCPAddr{
IP: ip.IP,
Port: port,
Zone: ip.Zone,
})
}
if len(addrs) == 0 {
return nil, errNoDNSEntries
}
return addrs, nil
}
var errNoDNSEntries = errors.New("couldn't find DNS entries for the given domain. Try using DialDualStack")
fasthttp-1.31.0/testdata/ 0000775 0000000 0000000 00000000000 14130360711 0015236 5 ustar 00root root 0000000 0000000 fasthttp-1.31.0/testdata/test.png 0000664 0000000 0000000 00000000001 14130360711 0016712 0 ustar 00root root 0000000 0000000
fasthttp-1.31.0/timer.go 0000664 0000000 0000000 00000002310 14130360711 0015070 0 ustar 00root root 0000000 0000000 package fasthttp
import (
"sync"
"time"
)
func initTimer(t *time.Timer, timeout time.Duration) *time.Timer {
if t == nil {
return time.NewTimer(timeout)
}
if t.Reset(timeout) {
panic("BUG: active timer trapped into initTimer()")
}
return t
}
func stopTimer(t *time.Timer) {
if !t.Stop() {
// Collect possibly added time from the channel
// if the timer has been stopped and nobody collected its value.
select {
case <-t.C:
default:
}
}
}
// AcquireTimer returns a time.Timer from the pool and updates it to
// send the current time on its channel after at least timeout.
//
// The returned Timer may be returned to the pool with ReleaseTimer
// when no longer needed. This allows reducing GC load.
func AcquireTimer(timeout time.Duration) *time.Timer {
v := timerPool.Get()
if v == nil {
return time.NewTimer(timeout)
}
t := v.(*time.Timer)
initTimer(t, timeout)
return t
}
// ReleaseTimer returns the time.Timer acquired via AcquireTimer to the pool
// and prevents the Timer from firing.
//
// Do not access the released time.Timer or read from its channel, otherwise
// data races may occur.
func ReleaseTimer(t *time.Timer) {
stopTimer(t)
timerPool.Put(t)
}
var timerPool sync.Pool
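// Illustrative usage sketch, not part of the original archive: AcquireTimer
// and ReleaseTimer let hot paths reuse timers instead of allocating a new one
// per wait, the same pattern tryDial in tcpdialer.go relies on. The
// hypothetical external example file and the waitWithTimeout helper are
// assumptions made for this example only.
package fasthttp_test

import (
	"errors"
	"time"

	"github.com/valyala/fasthttp"
)

// errWaitTimeout is returned when no value arrives before the timeout fires.
var errWaitTimeout = errors.New("timed out waiting for a value")

// waitWithTimeout blocks until a value is received from ch or timeout expires.
func waitWithTimeout(ch <-chan int, timeout time.Duration) (int, error) {
	t := fasthttp.AcquireTimer(timeout)
	// ReleaseTimer stops the timer and puts it back into the pool.
	defer fasthttp.ReleaseTimer(t)

	select {
	case v := <-ch:
		return v, nil
	case <-t.C:
		return 0, errWaitTimeout
	}
}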
fasthttp-1.31.0/tls.go 0000664 0000000 0000000 00000002647 14130360711 0014567 0 ustar 00root root 0000000 0000000 package fasthttp
import (
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"math/big"
"time"
)
// GenerateTestCertificate generates a test certificate and private key based on the given host.
func GenerateTestCertificate(host string) ([]byte, []byte, error) {
priv, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
return nil, nil, err
}
serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
if err != nil {
return nil, nil, err
}
cert := &x509.Certificate{
SerialNumber: serialNumber,
Subject: pkix.Name{
Organization: []string{"fasthttp test"},
},
NotBefore: time.Now(),
NotAfter: time.Now().Add(365 * 24 * time.Hour),
KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth},
SignatureAlgorithm: x509.SHA256WithRSA,
DNSNames: []string{host},
BasicConstraintsValid: true,
IsCA: true,
}
certBytes, err := x509.CreateCertificate(
rand.Reader, cert, cert, &priv.PublicKey, priv,
)
p := pem.EncodeToMemory(
&pem.Block{
Type: "PRIVATE KEY",
Bytes: x509.MarshalPKCS1PrivateKey(priv),
},
)
b := pem.EncodeToMemory(
&pem.Block{
Type: "CERTIFICATE",
Bytes: certBytes,
},
)
return b, p, err
}
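// Illustrative usage sketch, not part of the original archive:
// GenerateTestCertificate returns PEM-encoded certificate and key bytes, so
// they can be fed directly into crypto/tls when spinning up a throwaway TLS
// server in tests. The hypothetical external example file and the
// testTLSConfig helper are assumptions made for this example only.
package fasthttp_test

import (
	"crypto/tls"

	"github.com/valyala/fasthttp"
)

// testTLSConfig builds a *tls.Config from a freshly generated self-signed
// certificate for the given host.
func testTLSConfig(host string) (*tls.Config, error) {
	certPEM, keyPEM, err := fasthttp.GenerateTestCertificate(host)
	if err != nil {
		return nil, err
	}
	cert, err := tls.X509KeyPair(certPEM, keyPEM)
	if err != nil {
		return nil, err
	}
	return &tls.Config{Certificates: []tls.Certificate{cert}}, nil
}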
fasthttp-1.31.0/uri.go 0000664 0000000 0000000 00000053354 14130360711 0014565 0 ustar 00root root 0000000 0000000 package fasthttp
import (
"bytes"
"errors"
"fmt"
"io"
"strconv"
"sync"
)
// AcquireURI returns an empty URI instance from the pool.
//
// Release the URI with ReleaseURI after the URI is no longer needed.
// This allows reducing GC load.
func AcquireURI() *URI {
return uriPool.Get().(*URI)
}
// ReleaseURI releases the URI acquired via AcquireURI.
//
// The released URI mustn't be used after releasing it, otherwise data races
// may occur.
func ReleaseURI(u *URI) {
u.Reset()
uriPool.Put(u)
}
var uriPool = &sync.Pool{
New: func() interface{} {
return &URI{}
},
}
// URI represents URI :) .
//
// Copying URI instances is forbidden. Create a new instance and use CopyTo
// instead.
//
// URI instance MUST NOT be used from concurrently running goroutines.
type URI struct {
noCopy noCopy //nolint:unused,structcheck
pathOriginal []byte
scheme []byte
path []byte
queryString []byte
hash []byte
host []byte
queryArgs Args
parsedQueryArgs bool
// Path values are sent as-is without normalization
//
// Disabled path normalization may be useful for proxying incoming requests
// to servers that are expecting paths to be forwarded as-is.
//
// By default path values are normalized, i.e.
// extra slashes are removed, special characters are encoded.
DisablePathNormalizing bool
fullURI []byte
requestURI []byte
username []byte
password []byte
}
// CopyTo copies uri contents to dst.
func (u *URI) CopyTo(dst *URI) {
dst.Reset()
dst.pathOriginal = append(dst.pathOriginal[:0], u.pathOriginal...)
dst.scheme = append(dst.scheme[:0], u.scheme...)
dst.path = append(dst.path[:0], u.path...)
dst.queryString = append(dst.queryString[:0], u.queryString...)
dst.hash = append(dst.hash[:0], u.hash...)
dst.host = append(dst.host[:0], u.host...)
dst.username = append(dst.username[:0], u.username...)
dst.password = append(dst.password[:0], u.password...)
u.queryArgs.CopyTo(&dst.queryArgs)
dst.parsedQueryArgs = u.parsedQueryArgs
dst.DisablePathNormalizing = u.DisablePathNormalizing
// fullURI and requestURI shouldn't be copied, since they are created
// from scratch on each FullURI() and RequestURI() call.
}
// Hash returns URI hash, i.e. qwe of http://aaa.com/foo/bar?baz=123#qwe .
//
// The returned bytes are valid until the next URI method call.
func (u *URI) Hash() []byte {
return u.hash
}
// SetHash sets URI hash.
func (u *URI) SetHash(hash string) {
u.hash = append(u.hash[:0], hash...)
}
// SetHashBytes sets URI hash.
func (u *URI) SetHashBytes(hash []byte) {
u.hash = append(u.hash[:0], hash...)
}
// Username returns URI username
//
// The returned bytes are valid until the next URI method call.
func (u *URI) Username() []byte {
return u.username
}
// SetUsername sets URI username.
func (u *URI) SetUsername(username string) {
u.username = append(u.username[:0], username...)
}
// SetUsernameBytes sets URI username.
func (u *URI) SetUsernameBytes(username []byte) {
u.username = append(u.username[:0], username...)
}
// Password returns URI password
//
// The returned bytes are valid until the next URI method call.
func (u *URI) Password() []byte {
return u.password
}
// SetPassword sets URI password.
func (u *URI) SetPassword(password string) {
u.password = append(u.password[:0], password...)
}
// SetPasswordBytes sets URI password.
func (u *URI) SetPasswordBytes(password []byte) {
u.password = append(u.password[:0], password...)
}
// QueryString returns URI query string,
// i.e. baz=123 of http://aaa.com/foo/bar?baz=123#qwe .
//
// The returned bytes are valid until the next URI method call.
func (u *URI) QueryString() []byte {
return u.queryString
}
// SetQueryString sets URI query string.
func (u *URI) SetQueryString(queryString string) {
u.queryString = append(u.queryString[:0], queryString...)
u.parsedQueryArgs = false
}
// SetQueryStringBytes sets URI query string.
func (u *URI) SetQueryStringBytes(queryString []byte) {
u.queryString = append(u.queryString[:0], queryString...)
u.parsedQueryArgs = false
}
// Path returns URI path, i.e. /foo/bar of http://aaa.com/foo/bar?baz=123#qwe .
//
// The returned path is always urldecoded and normalized,
// i.e. '//f%20obar/baz/../zzz' becomes '/f obar/zzz'.
//
// The returned bytes are valid until the next URI method call.
func (u *URI) Path() []byte {
path := u.path
if len(path) == 0 {
path = strSlash
}
return path
}
// SetPath sets URI path.
func (u *URI) SetPath(path string) {
u.pathOriginal = append(u.pathOriginal[:0], path...)
u.path = normalizePath(u.path, u.pathOriginal)
}
// SetPathBytes sets URI path.
func (u *URI) SetPathBytes(path []byte) {
u.pathOriginal = append(u.pathOriginal[:0], path...)
u.path = normalizePath(u.path, u.pathOriginal)
}
// PathOriginal returns the original path from requestURI passed to URI.Parse().
//
// The returned bytes are valid until the next URI method call.
func (u *URI) PathOriginal() []byte {
return u.pathOriginal
}
// Scheme returns URI scheme, i.e. http of http://aaa.com/foo/bar?baz=123#qwe .
//
// Returned scheme is always lowercased.
//
// The returned bytes are valid until the next URI method call.
func (u *URI) Scheme() []byte {
scheme := u.scheme
if len(scheme) == 0 {
scheme = strHTTP
}
return scheme
}
// SetScheme sets URI scheme, i.e. http, https, ftp, etc.
func (u *URI) SetScheme(scheme string) {
u.scheme = append(u.scheme[:0], scheme...)
lowercaseBytes(u.scheme)
}
// SetSchemeBytes sets URI scheme, i.e. http, https, ftp, etc.
func (u *URI) SetSchemeBytes(scheme []byte) {
u.scheme = append(u.scheme[:0], scheme...)
lowercaseBytes(u.scheme)
}
// Reset clears uri.
func (u *URI) Reset() {
u.pathOriginal = u.pathOriginal[:0]
u.scheme = u.scheme[:0]
u.path = u.path[:0]
u.queryString = u.queryString[:0]
u.hash = u.hash[:0]
u.username = u.username[:0]
u.password = u.password[:0]
u.host = u.host[:0]
u.queryArgs.Reset()
u.parsedQueryArgs = false
u.DisablePathNormalizing = false
// There is no need in u.fullURI = u.fullURI[:0], since full uri
// is calculated on each call to FullURI().
// There is no need in u.requestURI = u.requestURI[:0], since requestURI
// is calculated on each call to RequestURI().
}
// Host returns host part, i.e. aaa.com of http://aaa.com/foo/bar?baz=123#qwe .
//
// Host is always lowercased.
//
// The returned bytes are valid until the next URI method call.
func (u *URI) Host() []byte {
return u.host
}
// SetHost sets host for the uri.
func (u *URI) SetHost(host string) {
u.host = append(u.host[:0], host...)
lowercaseBytes(u.host)
}
// SetHostBytes sets host for the uri.
func (u *URI) SetHostBytes(host []byte) {
u.host = append(u.host[:0], host...)
lowercaseBytes(u.host)
}
var (
ErrorInvalidURI = errors.New("invalid uri")
)
// Parse initializes URI from the given host and uri.
//
// host may be nil. In this case uri must contain fully qualified uri,
// i.e. with scheme and host. http is assumed if scheme is omitted.
//
// uri may contain e.g. RequestURI without scheme and host if host is non-empty.
func (u *URI) Parse(host, uri []byte) error {
return u.parse(host, uri, false)
}
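// Illustrative usage sketch, not part of the original archive: parsing an
// absolute URI into a pooled URI instance and reading back its components.
// The commented code below assumes a hypothetical external caller; the
// expected values are shown in the trailing comment.
//
//	u := fasthttp.AcquireURI()
//	defer fasthttp.ReleaseURI(u)
//
//	// host may be nil because the uri below is fully qualified.
//	if err := u.Parse(nil, []byte("http://aaa.com/foo/bar?baz=123#qwe")); err != nil {
//		panic(err)
//	}
//	fmt.Printf("scheme=%s host=%s path=%s query=%s hash=%s\n",
//		u.Scheme(), u.Host(), u.Path(), u.QueryString(), u.Hash())
//	// prints: scheme=http host=aaa.com path=/foo/bar query=baz=123 hash=qwe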
func (u *URI) parse(host, uri []byte, isTLS bool) error {
u.Reset()
if stringContainsCTLByte(uri) {
return ErrorInvalidURI
}
if len(host) == 0 || bytes.Contains(uri, strColonSlashSlash) {
scheme, newHost, newURI := splitHostURI(host, uri)
u.scheme = append(u.scheme, scheme...)
lowercaseBytes(u.scheme)
host = newHost
uri = newURI
}
if isTLS {
u.scheme = append(u.scheme[:0], strHTTPS...)
}
if n := bytes.IndexByte(host, '@'); n >= 0 {
auth := host[:n]
host = host[n+1:]
if n := bytes.IndexByte(auth, ':'); n >= 0 {
u.username = append(u.username[:0], auth[:n]...)
u.password = append(u.password[:0], auth[n+1:]...)
} else {
u.username = append(u.username[:0], auth...)
u.password = u.password[:0]
}
}
u.host = append(u.host, host...)
if parsedHost, err := parseHost(u.host); err != nil {
return err
} else {
u.host = parsedHost
}
lowercaseBytes(u.host)
b := uri
queryIndex := bytes.IndexByte(b, '?')
fragmentIndex := bytes.IndexByte(b, '#')
// Ignore query in fragment part
if fragmentIndex >= 0 && queryIndex > fragmentIndex {
queryIndex = -1
}
if queryIndex < 0 && fragmentIndex < 0 {
u.pathOriginal = append(u.pathOriginal, b...)
u.path = normalizePath(u.path, u.pathOriginal)
return nil
}
if queryIndex >= 0 {
// Path is everything up to the start of the query
u.pathOriginal = append(u.pathOriginal, b[:queryIndex]...)
u.path = normalizePath(u.path, u.pathOriginal)
if fragmentIndex < 0 {
u.queryString = append(u.queryString, b[queryIndex+1:]...)
} else {
u.queryString = append(u.queryString, b[queryIndex+1:fragmentIndex]...)
u.hash = append(u.hash, b[fragmentIndex+1:]...)
}
return nil
}
// fragmentIndex >= 0 && queryIndex < 0
// Path is up to the start of fragment
u.pathOriginal = append(u.pathOriginal, b[:fragmentIndex]...)
u.path = normalizePath(u.path, u.pathOriginal)
u.hash = append(u.hash, b[fragmentIndex+1:]...)
return nil
}
// parseHost parses host as an authority without user
// information. That is, as host[:port].
//
// Based on https://github.com/golang/go/blob/8ac5cbe05d61df0a7a7c9a38ff33305d4dcfea32/src/net/url/url.go#L619
//
// The host is parsed and unescaped in place overwriting the contents of the host parameter.
func parseHost(host []byte) ([]byte, error) {
if len(host) > 0 && host[0] == '[' {
// Parse an IP-Literal in RFC 3986 and RFC 6874.
// E.g., "[fe80::1]", "[fe80::1%25en0]", "[fe80::1]:80".
i := bytes.LastIndexByte(host, ']')
if i < 0 {
return nil, errors.New("missing ']' in host")
}
colonPort := host[i+1:]
if !validOptionalPort(colonPort) {
return nil, fmt.Errorf("invalid port %q after host", colonPort)
}
// RFC 6874 defines that %25 (%-encoded percent) introduces
// the zone identifier, and the zone identifier can use basically
// any %-encoding it likes. That's different from the host, which
// can only %-encode non-ASCII bytes.
// We do impose some restrictions on the zone, to avoid stupidity
// like newlines.
zone := bytes.Index(host[:i], []byte("%25"))
if zone >= 0 {
host1, err := unescape(host[:zone], encodeHost)
if err != nil {
return nil, err
}
host2, err := unescape(host[zone:i], encodeZone)
if err != nil {
return nil, err
}
host3, err := unescape(host[i:], encodeHost)
if err != nil {
return nil, err
}
return append(host1, append(host2, host3...)...), nil
}
} else if i := bytes.LastIndexByte(host, ':'); i != -1 {
colonPort := host[i:]
if !validOptionalPort(colonPort) {
return nil, fmt.Errorf("invalid port %q after host", colonPort)
}
}
var err error
if host, err = unescape(host, encodeHost); err != nil {
return nil, err
}
return host, nil
}
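// exampleParseHost is a hypothetical sketch, not part of the upstream
// sources, showing what parseHost accepts: a bare authority, optionally with
// a port, including an RFC 6874 IPv6 literal whose %25-escaped zone is
// unescaped in place (mirroring the cases in TestURIParse).
func exampleParseHost() ([]byte, error) {
	// Expected result: "[fe80::1%en0]:8080"
	return parseHost([]byte("[fe80::1%25en0]:8080"))
}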
type encoding int
const (
encodeHost encoding = 1 + iota
encodeZone
)
type EscapeError string
func (e EscapeError) Error() string {
return "invalid URL escape " + strconv.Quote(string(e))
}
type InvalidHostError string
func (e InvalidHostError) Error() string {
return "invalid character " + strconv.Quote(string(e)) + " in host name"
}
// unescape unescapes a string; the mode specifies
// which section of the URL string is being unescaped.
//
// Based on https://github.com/golang/go/blob/8ac5cbe05d61df0a7a7c9a38ff33305d4dcfea32/src/net/url/url.go#L199
//
// Unescapes in place overwriting the contents of s and returning it.
func unescape(s []byte, mode encoding) ([]byte, error) {
// Count %, check that they're well-formed.
n := 0
for i := 0; i < len(s); {
switch s[i] {
case '%':
n++
if i+2 >= len(s) || !ishex(s[i+1]) || !ishex(s[i+2]) {
s = s[i:]
if len(s) > 3 {
s = s[:3]
}
return nil, EscapeError(s)
}
// Per https://tools.ietf.org/html/rfc3986#page-21
// in the host component %-encoding can only be used
// for non-ASCII bytes.
// But https://tools.ietf.org/html/rfc6874#section-2
// introduces %25 being allowed to escape a percent sign
// in IPv6 scoped-address literals. Yay.
if mode == encodeHost && unhex(s[i+1]) < 8 && !bytes.Equal(s[i:i+3], []byte("%25")) {
return nil, EscapeError(s[i : i+3])
}
if mode == encodeZone {
// RFC 6874 says basically "anything goes" for zone identifiers
// and that even non-ASCII can be redundantly escaped,
// but it seems prudent to restrict %-escaped bytes here to those
// that are valid host name bytes in their unescaped form.
// That is, you can use escaping in the zone identifier but not
// to introduce bytes you couldn't just write directly.
// But Windows puts spaces here! Yay.
v := unhex(s[i+1])<<4 | unhex(s[i+2])
if !bytes.Equal(s[i:i+3], []byte("%25")) && v != ' ' && shouldEscape(v, encodeHost) {
return nil, EscapeError(s[i : i+3])
}
}
i += 3
default:
if (mode == encodeHost || mode == encodeZone) && s[i] < 0x80 && shouldEscape(s[i], mode) {
return nil, InvalidHostError(s[i : i+1])
}
i++
}
}
if n == 0 {
return s, nil
}
t := s[:0]
for i := 0; i < len(s); i++ {
switch s[i] {
case '%':
t = append(t, unhex(s[i+1])<<4|unhex(s[i+2]))
i += 2
default:
t = append(t, s[i])
}
}
return t, nil
}
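// exampleUnescapeHost is a hypothetical sketch, not part of the upstream
// sources. In encodeHost mode only non-ASCII bytes (and the literal "%25")
// may be %-encoded, so a UTF-8 host name decodes cleanly, while an escaped
// ASCII byte such as "%41" would be rejected with an EscapeError.
func exampleUnescapeHost() ([]byte, error) {
	// Expected result: "hello.世界.com"
	return unescape([]byte("hello.%e4%b8%96%e7%95%8c.com"), encodeHost)
}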
// Return true if the specified character should be escaped when
// appearing in a URL string, according to RFC 3986.
//
// Note that shouldEscape does not currently check all
// reserved characters correctly. See golang.org/issue/5684.
//
// Based on https://github.com/golang/go/blob/8ac5cbe05d61df0a7a7c9a38ff33305d4dcfea32/src/net/url/url.go#L100
func shouldEscape(c byte, mode encoding) bool {
// §2.3 Unreserved characters (alphanum)
if 'a' <= c && c <= 'z' || 'A' <= c && c <= 'Z' || '0' <= c && c <= '9' {
return false
}
if mode == encodeHost || mode == encodeZone {
// §3.2.2 Host allows
// sub-delims = "!" / "$" / "&" / "'" / "(" / ")" / "*" / "+" / "," / ";" / "="
// as part of reg-name.
// We add : because we include :port as part of host.
// We add [ ] because we include [ipv6]:port as part of host.
// We add < > because they're the only characters left that
// we could possibly allow, and Parse will reject them if we
// escape them (because hosts can't use %-encoding for
// ASCII bytes).
switch c {
case '!', '$', '&', '\'', '(', ')', '*', '+', ',', ';', '=', ':', '[', ']', '<', '>', '"':
return false
}
}
if c == '-' || c == '_' || c == '.' || c == '~' { // §2.3 Unreserved characters (mark)
return false
}
// Everything else must be escaped.
return true
}
func ishex(c byte) bool {
switch {
case '0' <= c && c <= '9':
return true
case 'a' <= c && c <= 'f':
return true
case 'A' <= c && c <= 'F':
return true
}
return false
}
func unhex(c byte) byte {
switch {
case '0' <= c && c <= '9':
return c - '0'
case 'a' <= c && c <= 'f':
return c - 'a' + 10
case 'A' <= c && c <= 'F':
return c - 'A' + 10
}
return 0
}
// validOptionalPort reports whether port is either an empty string
// or matches /^:\d*$/
func validOptionalPort(port []byte) bool {
if len(port) == 0 {
return true
}
if port[0] != ':' {
return false
}
for _, b := range port[1:] {
if b < '0' || b > '9' {
return false
}
}
return true
}
func normalizePath(dst, src []byte) []byte {
dst = dst[:0]
dst = addLeadingSlash(dst, src)
dst = decodeArgAppendNoPlus(dst, src)
// remove duplicate slashes
b := dst
bSize := len(b)
for {
n := bytes.Index(b, strSlashSlash)
if n < 0 {
break
}
b = b[n:]
copy(b, b[1:])
b = b[:len(b)-1]
bSize--
}
dst = dst[:bSize]
// remove /./ parts
b = dst
for {
n := bytes.Index(b, strSlashDotSlash)
if n < 0 {
break
}
nn := n + len(strSlashDotSlash) - 1
copy(b[n:], b[nn:])
b = b[:len(b)-nn+n]
}
// remove /foo/../ parts
for {
n := bytes.Index(b, strSlashDotDotSlash)
if n < 0 {
break
}
nn := bytes.LastIndexByte(b[:n], '/')
if nn < 0 {
nn = 0
}
n += len(strSlashDotDotSlash) - 1
copy(b[nn:], b[n:])
b = b[:len(b)-n+nn]
}
// remove trailing /foo/..
n := bytes.LastIndex(b, strSlashDotDot)
if n >= 0 && n+len(strSlashDotDot) == len(b) {
nn := bytes.LastIndexByte(b[:n], '/')
if nn < 0 {
return append(dst[:0], strSlash...)
}
b = b[:nn+1]
}
return b
}
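// exampleNormalizePath is a hypothetical sketch, not part of the upstream
// sources, showing the transformations applied above: percent-decoding,
// duplicate-slash removal and collapsing of "." and ".." segments. The
// expected outputs mirror TestURIPathNormalize in uri_test.go.
func exampleNormalizePath() [][]byte {
	return [][]byte{
		normalizePath(nil, []byte("/aa//bb")),                // -> "/aa/bb"
		normalizePath(nil, []byte("/a/./b/././c/./d.html")),  // -> "/a/b/c/d.html"
		normalizePath(nil, []byte("/aaa/bbb/ccc/../../ddd")), // -> "/aaa/ddd"
	}
}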
// RequestURI returns RequestURI - i.e. URI without Scheme and Host.
func (u *URI) RequestURI() []byte {
var dst []byte
if u.DisablePathNormalizing {
dst = append(u.requestURI[:0], u.PathOriginal()...)
} else {
dst = appendQuotedPath(u.requestURI[:0], u.Path())
}
if u.parsedQueryArgs && u.queryArgs.Len() > 0 {
dst = append(dst, '?')
dst = u.queryArgs.AppendBytes(dst)
} else if len(u.queryString) > 0 {
dst = append(dst, '?')
dst = append(dst, u.queryString...)
}
u.requestURI = dst
return u.requestURI
}
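// exampleRequestURINoNormalization is a hypothetical sketch, not part of the
// upstream sources. With DisablePathNormalizing set, RequestURI returns the
// original escaped path instead of the normalized one, as exercised by
// TestURINoNormalization in uri_test.go.
func exampleRequestURINoNormalization() []byte {
	var u URI
	_ = u.Parse(nil, []byte("/aaa%2Fbbb%2F%2E.%2Fxxx"))
	u.DisablePathNormalizing = true
	// Expected result: "/aaa%2Fbbb%2F%2E.%2Fxxx"
	return u.RequestURI()
}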
// LastPathSegment returns the last part of uri path after '/'.
//
// Examples:
//
// * For /foo/bar/baz.html path returns baz.html.
// * For /foo/bar/ returns empty byte slice.
// * For /foobar.js returns foobar.js.
//
// The returned bytes are valid until the next URI method call.
func (u *URI) LastPathSegment() []byte {
path := u.Path()
n := bytes.LastIndexByte(path, '/')
if n < 0 {
return path
}
return path[n+1:]
}
// Update updates uri.
//
// The following newURI types are accepted:
//
// * Absolute, i.e. http://foobar.com/aaa/bb?cc . In this case the original
// uri is replaced by newURI.
// * Absolute without scheme, i.e. //foobar.com/aaa/bb?cc. In this case
// the original scheme is preserved.
// * Missing host, i.e. /aaa/bb?cc . In this case only RequestURI part
// of the original uri is replaced.
// * Relative path, i.e. xx?yy=abc . In this case the original RequestURI
// is updated according to the new relative path.
func (u *URI) Update(newURI string) {
u.UpdateBytes(s2b(newURI))
}
// UpdateBytes updates uri.
//
// The following newURI types are accepted:
//
// * Absolute, i.e. http://foobar.com/aaa/bb?cc . In this case the original
// uri is replaced by newURI.
// * Absolute without scheme, i.e. //foobar.com/aaa/bb?cc. In this case
// the original scheme is preserved.
// * Missing host, i.e. /aaa/bb?cc . In this case only RequestURI part
// of the original uri is replaced.
// * Relative path, i.e. xx?yy=abc . In this case the original RequestURI
// is updated according to the new relative path.
func (u *URI) UpdateBytes(newURI []byte) {
u.requestURI = u.updateBytes(newURI, u.requestURI)
}
func (u *URI) updateBytes(newURI, buf []byte) []byte {
if len(newURI) == 0 {
return buf
}
n := bytes.Index(newURI, strSlashSlash)
if n >= 0 {
// absolute uri
var b [32]byte
schemeOriginal := b[:0]
if len(u.scheme) > 0 {
schemeOriginal = append([]byte(nil), u.scheme...)
}
if err := u.Parse(nil, newURI); err != nil {
return nil
}
if len(schemeOriginal) > 0 && len(u.scheme) == 0 {
u.scheme = append(u.scheme[:0], schemeOriginal...)
}
return buf
}
if newURI[0] == '/' {
// uri without host
buf = u.appendSchemeHost(buf[:0])
buf = append(buf, newURI...)
if err := u.Parse(nil, buf); err != nil {
return nil
}
return buf
}
// relative path
switch newURI[0] {
case '?':
// query string only update
u.SetQueryStringBytes(newURI[1:])
return append(buf[:0], u.FullURI()...)
case '#':
// update only hash
u.SetHashBytes(newURI[1:])
return append(buf[:0], u.FullURI()...)
default:
// update the last path part after the slash
path := u.Path()
n = bytes.LastIndexByte(path, '/')
if n < 0 {
panic(fmt.Sprintf("BUG: path must contain at least one slash: %s %s", u.Path(), newURI))
}
buf = u.appendSchemeHost(buf[:0])
buf = appendQuotedPath(buf, path[:n+1])
buf = append(buf, newURI...)
if err := u.Parse(nil, buf); err != nil {
return nil
}
return buf
}
}
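// exampleUpdateRelative is a hypothetical sketch, not part of the upstream
// sources, showing the relative-path branch of updateBytes: the last path
// segment is replaced and the URI is re-parsed. The expected result mirrors
// TestURIUpdate in uri_test.go.
func exampleUpdateRelative() string {
	var u URI
	_ = u.Parse(nil, []byte("http://foo.bar/baz/xxx.html?aaa=22#aaa"))
	u.Update("bb.html?xx=12#pp")
	// Expected result: "http://foo.bar/baz/bb.html?xx=12#pp"
	return u.String()
}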
// FullURI returns full uri in the form {Scheme}://{Host}{RequestURI}#{Hash}.
//
// The returned bytes are valid until the next URI method call.
func (u *URI) FullURI() []byte {
u.fullURI = u.AppendBytes(u.fullURI[:0])
return u.fullURI
}
// AppendBytes appends full uri to dst and returns the extended dst.
func (u *URI) AppendBytes(dst []byte) []byte {
dst = u.appendSchemeHost(dst)
dst = append(dst, u.RequestURI()...)
if len(u.hash) > 0 {
dst = append(dst, '#')
dst = append(dst, u.hash...)
}
return dst
}
func (u *URI) appendSchemeHost(dst []byte) []byte {
dst = append(dst, u.Scheme()...)
dst = append(dst, strColonSlashSlash...)
return append(dst, u.Host()...)
}
// WriteTo writes full uri to w.
//
// WriteTo implements io.WriterTo interface.
func (u *URI) WriteTo(w io.Writer) (int64, error) {
n, err := w.Write(u.FullURI())
return int64(n), err
}
// String returns full uri.
func (u *URI) String() string {
return string(u.FullURI())
}
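// exampleBuildFullURI is a hypothetical sketch, not part of the upstream
// sources, assembling a URI from its setters; FullURI renders it as
// {Scheme}://{Host}{RequestURI}#{Hash}. The expected output mirrors
// TestURIFullURI in uri_test.go.
func exampleBuildFullURI() []byte {
	var u URI
	u.SetScheme("https")
	u.SetHost("xx.com")
	u.SetPath("/")
	u.SetHash("aaa")
	// Expected result: "https://xx.com/#aaa"
	return u.FullURI()
}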
func splitHostURI(host, uri []byte) ([]byte, []byte, []byte) {
n := bytes.Index(uri, strSlashSlash)
if n < 0 {
return strHTTP, host, uri
}
scheme := uri[:n]
if bytes.IndexByte(scheme, '/') >= 0 {
return strHTTP, host, uri
}
if len(scheme) > 0 && scheme[len(scheme)-1] == ':' {
scheme = scheme[:len(scheme)-1]
}
n += len(strSlashSlash)
uri = uri[n:]
n = bytes.IndexByte(uri, '/')
nq := bytes.IndexByte(uri, '?')
if nq >= 0 && nq < n {
// A hack for urls like foobar.com?a=b/xyz
n = nq
} else if n < 0 {
// A hack for bogus urls like foobar.com?a=b without
// slash after host.
if nq >= 0 {
return scheme, uri[:nq], uri[nq:]
}
return scheme, uri, strSlash
}
return scheme, uri[:n], uri[n:]
}
// QueryArgs returns query args.
//
// The returned args are valid until the next URI method call.
func (u *URI) QueryArgs() *Args {
u.parseQueryArgs()
return &u.queryArgs
}
func (u *URI) parseQueryArgs() {
if u.parsedQueryArgs {
return
}
u.queryArgs.ParseBytes(u.queryString)
u.parsedQueryArgs = true
}
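// exampleQueryArgs is a hypothetical sketch, not part of the upstream
// sources. QueryArgs lazily parses the query string into the embedded Args
// instance, which can then be read or modified in place.
func exampleQueryArgs() []byte {
	var u URI
	u.SetQueryString("q1=foo&q2=bar")
	// Expected result: "foo"
	return u.QueryArgs().Peek("q1")
}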
// stringContainsCTLByte reports whether s contains any ASCII control character.
func stringContainsCTLByte(s []byte) bool {
for i := 0; i < len(s); i++ {
b := s[i]
if b < ' ' || b == 0x7f {
return true
}
}
return false
}
fasthttp-1.31.0/uri_test.go 0000664 0000000 0000000 00000034234 14130360711 0015620 0 ustar 00root root 0000000 0000000 package fasthttp
import (
"bytes"
"fmt"
"reflect"
"testing"
"time"
)
func TestURICopyToQueryArgs(t *testing.T) {
t.Parallel()
var u URI
a := u.QueryArgs()
a.Set("foo", "bar")
var u1 URI
u.CopyTo(&u1)
a1 := u1.QueryArgs()
if string(a1.Peek("foo")) != "bar" {
t.Fatalf("unexpected query args value %q. Expecting %q", a1.Peek("foo"), "bar")
}
}
func TestURIAcquireReleaseSequential(t *testing.T) {
t.Parallel()
testURIAcquireRelease(t)
}
func TestURIAcquireReleaseConcurrent(t *testing.T) {
t.Parallel()
ch := make(chan struct{}, 10)
for i := 0; i < 10; i++ {
go func() {
testURIAcquireRelease(t)
ch <- struct{}{}
}()
}
for i := 0; i < 10; i++ {
select {
case <-ch:
case <-time.After(time.Second):
t.Fatalf("timeout")
}
}
}
func testURIAcquireRelease(t *testing.T) {
for i := 0; i < 10; i++ {
u := AcquireURI()
host := fmt.Sprintf("host.%d.com", i*23)
path := fmt.Sprintf("/foo/%d/bar", i*17)
queryArgs := "?foo=bar&baz=aass"
u.Parse([]byte(host), []byte(path+queryArgs)) //nolint:errcheck
if string(u.Host()) != host {
t.Fatalf("unexpected host %q. Expecting %q", u.Host(), host)
}
if string(u.Path()) != path {
t.Fatalf("unexpected path %q. Expecting %q", u.Path(), path)
}
ReleaseURI(u)
}
}
func TestURILastPathSegment(t *testing.T) {
t.Parallel()
testURILastPathSegment(t, "", "")
testURILastPathSegment(t, "/", "")
testURILastPathSegment(t, "/foo/bar/", "")
testURILastPathSegment(t, "/foobar.js", "foobar.js")
testURILastPathSegment(t, "/foo/bar/baz.html", "baz.html")
}
func testURILastPathSegment(t *testing.T, path, expectedSegment string) {
var u URI
u.SetPath(path)
segment := u.LastPathSegment()
if string(segment) != expectedSegment {
t.Fatalf("unexpected last path segment for path %q: %q. Expecting %q", path, segment, expectedSegment)
}
}
func TestURIPathEscape(t *testing.T) {
t.Parallel()
testURIPathEscape(t, "/foo/bar", "/foo/bar")
testURIPathEscape(t, "/f_o-o=b:ar,b.c&q", "/f_o-o=b:ar,b.c&q")
testURIPathEscape(t, "/aa?bb.тест~qq", "/aa%3Fbb.%D1%82%D0%B5%D1%81%D1%82~qq")
}
func testURIPathEscape(t *testing.T, path, expectedRequestURI string) {
var u URI
u.SetPath(path)
requestURI := u.RequestURI()
if string(requestURI) != expectedRequestURI {
t.Fatalf("unexpected requestURI %q. Expecting %q. path %q", requestURI, expectedRequestURI, path)
}
}
func TestURIUpdate(t *testing.T) {
t.Parallel()
// full uri
testURIUpdate(t, "http://foo.bar/baz?aaa=22#aaa", "https://aa.com/bb", "https://aa.com/bb")
// empty uri
testURIUpdate(t, "http://aaa.com/aaa.html?234=234#add", "", "http://aaa.com/aaa.html?234=234#add")
// request uri
testURIUpdate(t, "ftp://aaa/xxx/yyy?aaa=bb#aa", "/boo/bar?xx", "ftp://aaa/boo/bar?xx")
// relative uri
testURIUpdate(t, "http://foo.bar/baz/xxx.html?aaa=22#aaa", "bb.html?xx=12#pp", "http://foo.bar/baz/bb.html?xx=12#pp")
testURIUpdate(t, "http://xx/a/b/c/d", "../qwe/p?zx=34", "http://xx/a/b/qwe/p?zx=34")
testURIUpdate(t, "https://qqq/aaa.html?foo=bar", "?baz=434&aaa#xcv", "https://qqq/aaa.html?baz=434&aaa#xcv")
testURIUpdate(t, "http://foo.bar/baz", "~a/%20b=c,тест?йцу=ке", "http://foo.bar/~a/%20b=c,%D1%82%D0%B5%D1%81%D1%82?йцу=ке")
testURIUpdate(t, "http://foo.bar/baz", "/qwe#fragment", "http://foo.bar/qwe#fragment")
testURIUpdate(t, "http://foobar/baz/xxx", "aaa.html#bb?cc=dd&ee=dfd", "http://foobar/baz/aaa.html#bb?cc=dd&ee=dfd")
// hash
testURIUpdate(t, "http://foo.bar/baz#aaa", "#fragment", "http://foo.bar/baz#fragment")
// uri without scheme
testURIUpdate(t, "https://foo.bar/baz", "//aaa.bbb/cc?dd", "https://aaa.bbb/cc?dd")
testURIUpdate(t, "http://foo.bar/baz", "//aaa.bbb/cc?dd", "http://aaa.bbb/cc?dd")
}
func testURIUpdate(t *testing.T, base, update, result string) {
var u URI
u.Parse(nil, []byte(base)) //nolint:errcheck
u.Update(update)
s := u.String()
if s != result {
t.Fatalf("unexpected result %q. Expecting %q. base=%q, update=%q", s, result, base, update)
}
}
func TestURIPathNormalize(t *testing.T) {
t.Parallel()
var u URI
// double slash
testURIPathNormalize(t, &u, "/aa//bb", "/aa/bb")
// triple slash
testURIPathNormalize(t, &u, "/x///y/", "/x/y/")
// multi slashes
testURIPathNormalize(t, &u, "/abc//de///fg////", "/abc/de/fg/")
// encoded slashes
testURIPathNormalize(t, &u, "/xxxx%2fyyy%2f%2F%2F", "/xxxx/yyy/")
// dotdot
testURIPathNormalize(t, &u, "/aaa/..", "/")
// dotdot with trailing slash
testURIPathNormalize(t, &u, "/xxx/yyy/../", "/xxx/")
// multi dotdots
testURIPathNormalize(t, &u, "/aaa/bbb/ccc/../../ddd", "/aaa/ddd")
// dotdots separated by other data
testURIPathNormalize(t, &u, "/a/b/../c/d/../e/..", "/a/c/")
// too many dotdots
testURIPathNormalize(t, &u, "/aaa/../../../../xxx", "/xxx")
testURIPathNormalize(t, &u, "/../../../../../..", "/")
testURIPathNormalize(t, &u, "/../../../../../../", "/")
// encoded dotdots
testURIPathNormalize(t, &u, "/aaa%2Fbbb%2F%2E.%2Fxxx", "/aaa/xxx")
// double slash with dotdots
testURIPathNormalize(t, &u, "/aaa////..//b", "/b")
// fake dotdot
testURIPathNormalize(t, &u, "/aaa/..bbb/ccc/..", "/aaa/..bbb/")
// single dot
testURIPathNormalize(t, &u, "/a/./b/././c/./d.html", "/a/b/c/d.html")
testURIPathNormalize(t, &u, "./foo/", "/foo/")
testURIPathNormalize(t, &u, "./../.././../../aaa/bbb/../../../././../", "/")
testURIPathNormalize(t, &u, "./a/./.././../b/./foo.html", "/b/foo.html")
}
func testURIPathNormalize(t *testing.T, u *URI, requestURI, expectedPath string) {
u.Parse(nil, []byte(requestURI)) //nolint:errcheck
if string(u.Path()) != expectedPath {
t.Fatalf("Unexpected path %q. Expected %q. requestURI=%q", u.Path(), expectedPath, requestURI)
}
}
func TestURINoNormalization(t *testing.T) {
t.Parallel()
var u URI
irregularPath := "/aaa%2Fbbb%2F%2E.%2Fxxx"
u.Parse(nil, []byte(irregularPath)) //nolint:errcheck
u.DisablePathNormalizing = true
if string(u.RequestURI()) != irregularPath {
t.Fatalf("Unexpected request URI %q. Expected %q.", u.RequestURI(), irregularPath)
}
}
func TestURICopyTo(t *testing.T) {
t.Parallel()
var u URI
var copyU URI
u.CopyTo(&copyU)
if !reflect.DeepEqual(u, copyU) { //nolint:govet
t.Fatalf("URICopyTo fail, u: \n%+v\ncopyu: \n%+v\n", u, copyU) //nolint:govet
}
u.UpdateBytes([]byte("https://google.com/foo?bar=baz&baraz#qqqq"))
u.CopyTo(&copyU)
if !reflect.DeepEqual(u, copyU) { //nolint:govet
t.Fatalf("URICopyTo fail, u: \n%+v\ncopyu: \n%+v\n", u, copyU) //nolint:govet
}
}
func TestURIFullURI(t *testing.T) {
t.Parallel()
var args Args
// empty scheme, path and hash
testURIFullURI(t, "", "foobar.com", "", "", &args, "http://foobar.com/")
// empty scheme and hash
testURIFullURI(t, "", "aa.com", "/foo/bar", "", &args, "http://aa.com/foo/bar")
// empty hash
testURIFullURI(t, "fTP", "XXx.com", "/foo", "", &args, "ftp://xxx.com/foo")
// empty args
testURIFullURI(t, "https", "xx.com", "/", "aaa", &args, "https://xx.com/#aaa")
// non-empty args and non-ASCII path
args.Set("foo", "bar")
args.Set("xxx", "йух")
testURIFullURI(t, "", "xxx.com", "/тест123", "2er", &args, "http://xxx.com/%D1%82%D0%B5%D1%81%D1%82123?foo=bar&xxx=%D0%B9%D1%83%D1%85#2er")
// test with empty args and non-empty query string
var u URI
u.Parse([]byte("google.com"), []byte("/foo?bar=baz&baraz#qqqq")) //nolint:errcheck
uri := u.FullURI()
expectedURI := "http://google.com/foo?bar=baz&baraz#qqqq"
if string(uri) != expectedURI {
t.Fatalf("Unexpected URI: %q. Expected %q", uri, expectedURI)
}
}
func testURIFullURI(t *testing.T, scheme, host, path, hash string, args *Args, expectedURI string) {
var u URI
u.SetScheme(scheme)
u.SetHost(host)
u.SetPath(path)
u.SetHash(hash)
args.CopyTo(u.QueryArgs())
uri := u.FullURI()
if string(uri) != expectedURI {
t.Fatalf("Unexpected URI: %q. Expected %q", uri, expectedURI)
}
}
func TestURIParseNilHost(t *testing.T) {
t.Parallel()
testURIParseScheme(t, "http://google.com/foo?bar#baz", "http", "google.com", "/foo?bar", "baz")
testURIParseScheme(t, "HTtP://google.com/", "http", "google.com", "/", "")
testURIParseScheme(t, "://google.com/xyz", "http", "google.com", "/xyz", "")
testURIParseScheme(t, "//google.com/foobar", "http", "google.com", "/foobar", "")
testURIParseScheme(t, "fTP://aaa.com", "ftp", "aaa.com", "/", "")
testURIParseScheme(t, "httPS://aaa.com", "https", "aaa.com", "/", "")
// missing slash after hostname
testURIParseScheme(t, "http://foobar.com?baz=111", "http", "foobar.com", "/?baz=111", "")
// slash in args
testURIParseScheme(t, "http://foobar.com?baz=111/222/xyz", "http", "foobar.com", "/?baz=111/222/xyz", "")
testURIParseScheme(t, "http://foobar.com?111/222/xyz", "http", "foobar.com", "/?111/222/xyz", "")
}
func testURIParseScheme(t *testing.T, uri, expectedScheme, expectedHost, expectedRequestURI, expectedHash string) {
var u URI
u.Parse(nil, []byte(uri)) //nolint:errcheck
if string(u.Scheme()) != expectedScheme {
t.Fatalf("Unexpected scheme %q. Expecting %q for uri %q", u.Scheme(), expectedScheme, uri)
}
if string(u.Host()) != expectedHost {
t.Fatalf("Unexpected host %q. Expecting %q for uri %q", u.Host(), expectedHost, uri)
}
if string(u.RequestURI()) != expectedRequestURI {
t.Fatalf("Unexpected requestURI %q. Expecting %q for uri %q", u.RequestURI(), expectedRequestURI, uri)
}
if string(u.hash) != expectedHash {
t.Fatalf("Unexpected hash %q. Expecting %q for uri %q", u.hash, expectedHash, uri)
}
}
func TestURIParse(t *testing.T) {
t.Parallel()
var u URI
// no args
testURIParse(t, &u, "aaa", "sdfdsf",
"http://aaa/sdfdsf", "aaa", "/sdfdsf", "sdfdsf", "", "")
// args
testURIParse(t, &u, "xx", "/aa?ss",
"http://xx/aa?ss", "xx", "/aa", "/aa", "ss", "")
// args and hash
testURIParse(t, &u, "foobar.com", "/a.b.c?def=gkl#mnop",
"http://foobar.com/a.b.c?def=gkl#mnop", "foobar.com", "/a.b.c", "/a.b.c", "def=gkl", "mnop")
// '?' and '#' in hash
testURIParse(t, &u, "aaa.com", "/foo#bar?baz=aaa#bbb",
"http://aaa.com/foo#bar?baz=aaa#bbb", "aaa.com", "/foo", "/foo", "", "bar?baz=aaa#bbb")
// encoded path
testURIParse(t, &u, "aa.com", "/Test%20+%20%D0%BF%D1%80%D0%B8?asdf=%20%20&s=12#sdf",
"http://aa.com/Test%20+%20%D0%BF%D1%80%D0%B8?asdf=%20%20&s=12#sdf", "aa.com", "/Test + при", "/Test%20+%20%D0%BF%D1%80%D0%B8", "asdf=%20%20&s=12", "sdf")
// host in uppercase
testURIParse(t, &u, "FOObar.COM", "/bC?De=F#Gh",
"http://foobar.com/bC?De=F#Gh", "foobar.com", "/bC", "/bC", "De=F", "Gh")
// uri with hostname
testURIParse(t, &u, "xxx.com", "http://aaa.com/foo/bar?baz=aaa#ddd",
"http://aaa.com/foo/bar?baz=aaa#ddd", "aaa.com", "/foo/bar", "/foo/bar", "baz=aaa", "ddd")
testURIParse(t, &u, "xxx.com", "https://ab.com/f/b%20r?baz=aaa#ddd",
"https://ab.com/f/b%20r?baz=aaa#ddd", "ab.com", "/f/b r", "/f/b%20r", "baz=aaa", "ddd")
// no slash after hostname in uri
testURIParse(t, &u, "aaa.com", "http://google.com",
"http://google.com/", "google.com", "/", "/", "", "")
// uppercase hostname in uri
testURIParse(t, &u, "abc.com", "http://GoGLE.com/aaa",
"http://gogle.com/aaa", "gogle.com", "/aaa", "/aaa", "", "")
// http:// in query params
testURIParse(t, &u, "aaa.com", "/foo?bar=http://google.com",
"http://aaa.com/foo?bar=http://google.com", "aaa.com", "/foo", "/foo", "bar=http://google.com", "")
testURIParse(t, &u, "aaa.com", "//relative",
"http://aaa.com/relative", "aaa.com", "/relative", "//relative", "", "")
testURIParse(t, &u, "", "//aaa.com//absolute",
"http://aaa.com/absolute", "aaa.com", "/absolute", "//absolute", "", "")
testURIParse(t, &u, "", "//aaa.com\r\n\r\nGET x",
"http:///", "", "/", "", "", "")
testURIParse(t, &u, "", "http://[fe80::1%25en0]/",
"http://[fe80::1%en0]/", "[fe80::1%en0]", "/", "/", "", "")
testURIParse(t, &u, "", "http://[fe80::1%25en0]:8080/",
"http://[fe80::1%en0]:8080/", "[fe80::1%en0]:8080", "/", "/", "", "")
testURIParse(t, &u, "", "http://hello.世界.com/foo",
"http://hello.世界.com/foo", "hello.世界.com", "/foo", "/foo", "", "")
testURIParse(t, &u, "", "http://hello.%e4%b8%96%e7%95%8c.com/foo",
"http://hello.世界.com/foo", "hello.世界.com", "/foo", "/foo", "", "")
}
func testURIParse(t *testing.T, u *URI, host, uri,
expectedURI, expectedHost, expectedPath, expectedPathOriginal, expectedArgs, expectedHash string) {
u.Parse([]byte(host), []byte(uri)) //nolint:errcheck
if !bytes.Equal(u.FullURI(), []byte(expectedURI)) {
t.Fatalf("Unexpected uri %q. Expected %q. host=%q, uri=%q", u.FullURI(), expectedURI, host, uri)
}
if !bytes.Equal(u.Host(), []byte(expectedHost)) {
t.Fatalf("Unexpected host %q. Expected %q. host=%q, uri=%q", u.Host(), expectedHost, host, uri)
}
if !bytes.Equal(u.PathOriginal(), []byte(expectedPathOriginal)) {
t.Fatalf("Unexpected original path %q. Expected %q. host=%q, uri=%q", u.PathOriginal(), expectedPathOriginal, host, uri)
}
if !bytes.Equal(u.Path(), []byte(expectedPath)) {
t.Fatalf("Unexpected path %q. Expected %q. host=%q, uri=%q", u.Path(), expectedPath, host, uri)
}
if !bytes.Equal(u.QueryString(), []byte(expectedArgs)) {
t.Fatalf("Unexpected args %q. Expected %q. host=%q, uri=%q", u.QueryString(), expectedArgs, host, uri)
}
if !bytes.Equal(u.Hash(), []byte(expectedHash)) {
t.Fatalf("Unexpected hash %q. Expected %q. host=%q, uri=%q", u.Hash(), expectedHash, host, uri)
}
}
func TestURIWithQuerystringOverride(t *testing.T) {
t.Parallel()
var u URI
u.SetQueryString("q1=foo&q2=bar")
u.QueryArgs().Add("q3", "baz")
u.SetQueryString("q1=foo&q2=bar&q4=quux")
uriString := string(u.RequestURI())
if uriString != "/?q1=foo&q2=bar&q4=quux" {
t.Fatalf("Expected Querystring to be overridden but was %s ", uriString)
}
}
func TestInvalidUrl(t *testing.T) {
url := `https://.çèéà@&~!&:=\\/\"'~<>|+-*()[]{}%$;,¥&&$22|||<>< 4ly8lzjmoNx233AXELDtyaFQiiUH-fd8c-CnXUJVYnGIs4Uwr-bptom5GCnWtsGMQxeM2ZhoKE973eKgs2Sjh6RePnyaLpCi6SiNSLevcMoraARrp88L-SgtKqd-XHAtSI8hiPRiXPQmDIA4BGhSgoc0nfn1PoYuGKKmDcZ04tANRc3iz4aF4-A1UrO8bLHTH7MEJvzx.someqa.fr/A/?&QS_BEGIN<&8{b'Ob=p*f> QS_END`
u := AcquireURI()
defer ReleaseURI(u)
if err := u.Parse(nil, []byte(url)); err == nil {
t.Fail()
}
}
func TestNoOverwriteInput(t *testing.T) {
str := `//%AA`
url := []byte(str)
u := AcquireURI()
defer ReleaseURI(u)
if err := u.Parse(nil, url); err != nil {
t.Error(err)
}
if string(url) != str {
t.Error()
}
if u.String() != "http://\xaa/" {
t.Errorf("%q", u.String())
}
}
fasthttp-1.31.0/uri_timing_test.go 0000664 0000000 0000000 00000002266 14130360711 0017167 0 ustar 00root root 0000000 0000000 package fasthttp
import (
"testing"
)
func BenchmarkURIParsePath(b *testing.B) {
benchmarkURIParse(b, "google.com", "/foo/bar")
}
func BenchmarkURIParsePathQueryString(b *testing.B) {
benchmarkURIParse(b, "google.com", "/foo/bar?query=string&other=value")
}
func BenchmarkURIParsePathQueryStringHash(b *testing.B) {
benchmarkURIParse(b, "google.com", "/foo/bar?query=string&other=value#hashstring")
}
func BenchmarkURIParseHostname(b *testing.B) {
benchmarkURIParse(b, "google.com", "http://foobar.com/foo/bar?query=string&other=value#hashstring")
}
func BenchmarkURIFullURI(b *testing.B) {
host := []byte("foobar.com")
requestURI := []byte("/foobar/baz?aaa=bbb&ccc=ddd")
uriLen := len(host) + len(requestURI) + 7
b.RunParallel(func(pb *testing.PB) {
var u URI
u.Parse(host, requestURI) //nolint:errcheck
for pb.Next() {
uri := u.FullURI()
if len(uri) != uriLen {
b.Fatalf("unexpected uri len %d. Expecting %d", len(uri), uriLen)
}
}
})
}
func benchmarkURIParse(b *testing.B, host, uri string) {
strHost, strURI := []byte(host), []byte(uri)
b.RunParallel(func(pb *testing.PB) {
var u URI
for pb.Next() {
u.Parse(strHost, strURI) //nolint:errcheck
}
})
}
fasthttp-1.31.0/uri_unix.go 0000664 0000000 0000000 00000000336 14130360711 0015620 0 ustar 00root root 0000000 0000000 //go:build !windows
// +build !windows
package fasthttp
func addLeadingSlash(dst, src []byte) []byte {
// add leading slash for unix paths
if len(src) == 0 || src[0] != '/' {
dst = append(dst, '/')
}
return dst
}
fasthttp-1.31.0/uri_windows.go 0000664 0000000 0000000 00000000350 14130360711 0016323 0 ustar 00root root 0000000 0000000 //go:build windows
// +build windows
package fasthttp
func addLeadingSlash(dst, src []byte) []byte {
// zero length and "C:/" case
if len(src) == 0 || (len(src) > 2 && src[1] != ':') {
dst = append(dst, '/')
}
return dst
}
fasthttp-1.31.0/uri_windows_test.go 0000664 0000000 0000000 00000000417 14130360711 0017366 0 ustar 00root root 0000000 0000000 //go:build windows
// +build windows
package fasthttp
import "testing"
func TestURIPathNormalizeIssue86(t *testing.T) {
t.Parallel()
// see https://github.com/valyala/fasthttp/issues/86
var u URI
testURIPathNormalize(t, &u, `C:\a\b\c\fs.go`, `C:\a\b\c\fs.go`)
}
fasthttp-1.31.0/userdata.go 0000664 0000000 0000000 00000002646 14130360711 0015574 0 ustar 00root root 0000000 0000000 package fasthttp
import (
"io"
)
type userDataKV struct {
key []byte
value interface{}
}
type userData []userDataKV
func (d *userData) Set(key string, value interface{}) {
args := *d
n := len(args)
for i := 0; i < n; i++ {
kv := &args[i]
if string(kv.key) == key {
kv.value = value
return
}
}
if value == nil {
return
}
c := cap(args)
if c > n {
args = args[:n+1]
kv := &args[n]
kv.key = append(kv.key[:0], key...)
kv.value = value
*d = args
return
}
kv := userDataKV{}
kv.key = append(kv.key[:0], key...)
kv.value = value
*d = append(args, kv)
}
func (d *userData) SetBytes(key []byte, value interface{}) {
d.Set(b2s(key), value)
}
func (d *userData) Get(key string) interface{} {
args := *d
n := len(args)
for i := 0; i < n; i++ {
kv := &args[i]
if string(kv.key) == key {
return kv.value
}
}
return nil
}
func (d *userData) GetBytes(key []byte) interface{} {
return d.Get(b2s(key))
}
func (d *userData) Reset() {
args := *d
n := len(args)
for i := 0; i < n; i++ {
v := args[i].value
if vc, ok := v.(io.Closer); ok {
vc.Close()
}
}
*d = (*d)[:0]
}
func (d *userData) Remove(key string) {
args := *d
n := len(args)
for i := 0; i < n; i++ {
kv := &args[i]
if string(kv.key) == key {
n--
args[i] = args[n]
args[n].value = nil
args = args[:n]
*d = args
return
}
}
}
func (d *userData) RemoveBytes(key []byte) {
d.Remove(b2s(key))
}
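// exampleUserData is a hypothetical sketch, not part of the upstream
// sources, showing the intended lifecycle: values are stored by key, looked
// up, and dropped on Reset, which additionally calls Close on any stored
// io.Closer value.
func exampleUserData(c io.Closer) interface{} {
	var d userData
	d.Set("conn", c)    // Close is called on this value during Reset
	d.Set("retries", 3) // plain values are simply dropped during Reset
	v := d.Get("retries")
	d.Reset()
	return v // -> 3
}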
fasthttp-1.31.0/userdata_test.go 0000664 0000000 0000000 00000003775 14130360711 0016637 0 ustar 00root root 0000000 0000000 package fasthttp
import (
"fmt"
"reflect"
"testing"
)
func TestUserData(t *testing.T) {
t.Parallel()
var u userData
for i := 0; i < 10; i++ {
key := []byte(fmt.Sprintf("key_%d", i))
u.SetBytes(key, i+5)
testUserDataGet(t, &u, key, i+5)
u.SetBytes(key, i)
testUserDataGet(t, &u, key, i)
}
for i := 0; i < 10; i++ {
key := []byte(fmt.Sprintf("key_%d", i))
testUserDataGet(t, &u, key, i)
}
u.Reset()
for i := 0; i < 10; i++ {
key := []byte(fmt.Sprintf("key_%d", i))
testUserDataGet(t, &u, key, nil)
}
}
func testUserDataGet(t *testing.T, u *userData, key []byte, value interface{}) {
v := u.GetBytes(key)
if v == nil && value != nil {
t.Fatalf("cannot obtain value for key=%q", key)
}
if !reflect.DeepEqual(v, value) {
t.Fatalf("unexpected value for key=%q: %d. Expecting %d", key, v, value)
}
}
func TestUserDataValueClose(t *testing.T) {
t.Parallel()
var u userData
closeCalls := 0
// store values implementing io.Closer
for i := 0; i < 5; i++ {
key := fmt.Sprintf("key_%d", i)
u.Set(key, &closerValue{&closeCalls})
}
// store values without io.Closer
for i := 0; i < 10; i++ {
key := fmt.Sprintf("key_noclose_%d", i)
u.Set(key, i)
}
u.Reset()
if closeCalls != 5 {
t.Fatalf("unexpected number of Close calls: %d. Expecting 5", closeCalls)
}
}
type closerValue struct {
closeCalls *int
}
func (cv *closerValue) Close() error {
(*cv.closeCalls)++
return nil
}
func TestUserDataDelete(t *testing.T) {
t.Parallel()
var u userData
for i := 0; i < 10; i++ {
key := fmt.Sprintf("key_%d", i)
u.Set(key, i)
testUserDataGet(t, &u, []byte(key), i)
}
for i := 0; i < 10; i += 2 {
k := fmt.Sprintf("key_%d", i)
u.Remove(k)
if val := u.Get(k); val != nil {
t.Fatalf("unexpected value %v for key %q. Expecting nil", val, k)
}
kk := fmt.Sprintf("key_%d", i+1)
testUserDataGet(t, &u, []byte(kk), i+1)
}
for i := 0; i < 10; i++ {
key := fmt.Sprintf("key_new_%d", i)
u.Set(key, i)
testUserDataGet(t, &u, []byte(key), i)
}
}
fasthttp-1.31.0/userdata_timing_test.go 0000664 0000000 0000000 00000001666 14130360711 0020203 0 ustar 00root root 0000000 0000000 package fasthttp
import (
"testing"
)
func BenchmarkUserDataCustom(b *testing.B) {
keys := []string{"foobar", "baz", "aaa", "bsdfs"}
b.RunParallel(func(pb *testing.PB) {
var u userData
var v interface{} = u
for pb.Next() {
for _, key := range keys {
u.Set(key, v)
}
for _, key := range keys {
vv := u.Get(key)
if _, ok := vv.(userData); !ok {
b.Fatalf("unexpected value %v for key %q", vv, key)
}
}
u.Reset()
}
})
}
func BenchmarkUserDataStdMap(b *testing.B) {
keys := []string{"foobar", "baz", "aaa", "bsdfs"}
b.RunParallel(func(pb *testing.PB) {
u := make(map[string]interface{})
var v interface{} = u
for pb.Next() {
for _, key := range keys {
u[key] = v
}
for _, key := range keys {
vv := u[key]
if _, ok := vv.(map[string]interface{}); !ok {
b.Fatalf("unexpected value %v for key %q", vv, key)
}
}
for k := range u {
delete(u, k)
}
}
})
}
fasthttp-1.31.0/workerpool.go 0000664 0000000 0000000 00000012004 14130360711 0016154 0 ustar 00root root 0000000 0000000 package fasthttp
import (
"net"
"runtime"
"strings"
"sync"
"time"
)
// workerPool serves incoming connections via a pool of workers
// in LIFO order, i.e. the most recently stopped worker will serve the next
// incoming connection.
//
// Such a scheme keeps CPU caches hot (in theory).
type workerPool struct {
// Function for serving server connections.
// It must leave c unclosed.
WorkerFunc ServeHandler
MaxWorkersCount int
LogAllErrors bool
MaxIdleWorkerDuration time.Duration
Logger Logger
lock sync.Mutex
workersCount int
mustStop bool
ready []*workerChan
stopCh chan struct{}
workerChanPool sync.Pool
connState func(net.Conn, ConnState)
}
type workerChan struct {
lastUseTime time.Time
ch chan net.Conn
}
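// exampleWorkerPoolUsage is a hypothetical sketch, not part of the upstream
// sources, of how the rest of the package drives a workerPool: configure it,
// Start it, hand accepted connections to Serve, and Stop it on shutdown. The
// listener, handler and logger are placeholders supplied by the caller.
func exampleWorkerPoolUsage(ln net.Listener, handler ServeHandler, logger Logger) {
	wp := &workerPool{
		WorkerFunc:      handler, // must leave the connection unclosed
		MaxWorkersCount: 256,
		Logger:          logger,
		connState:       func(net.Conn, ConnState) {},
	}
	wp.Start()
	defer wp.Stop()
	for {
		c, err := ln.Accept()
		if err != nil {
			return
		}
		if !wp.Serve(c) {
			// The pool already runs MaxWorkersCount busy workers.
			_ = c.Close()
		}
	}
}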
func (wp *workerPool) Start() {
if wp.stopCh != nil {
panic("BUG: workerPool already started")
}
wp.stopCh = make(chan struct{})
stopCh := wp.stopCh
wp.workerChanPool.New = func() interface{} {
return &workerChan{
ch: make(chan net.Conn, workerChanCap),
}
}
go func() {
var scratch []*workerChan
for {
wp.clean(&scratch)
select {
case <-stopCh:
return
default:
time.Sleep(wp.getMaxIdleWorkerDuration())
}
}
}()
}
func (wp *workerPool) Stop() {
if wp.stopCh == nil {
panic("BUG: workerPool wasn't started")
}
close(wp.stopCh)
wp.stopCh = nil
// Stop all the workers waiting for incoming connections.
// Do not wait for busy workers - they will stop after
// serving the connection and noticing wp.mustStop = true.
wp.lock.Lock()
ready := wp.ready
for i := range ready {
ready[i].ch <- nil
ready[i] = nil
}
wp.ready = ready[:0]
wp.mustStop = true
wp.lock.Unlock()
}
func (wp *workerPool) getMaxIdleWorkerDuration() time.Duration {
if wp.MaxIdleWorkerDuration <= 0 {
return 10 * time.Second
}
return wp.MaxIdleWorkerDuration
}
func (wp *workerPool) clean(scratch *[]*workerChan) {
maxIdleWorkerDuration := wp.getMaxIdleWorkerDuration()
// Clean least recently used workers if they didn't serve connections
// for more than maxIdleWorkerDuration.
criticalTime := time.Now().Add(-maxIdleWorkerDuration)
wp.lock.Lock()
ready := wp.ready
n := len(ready)
// Use binary search to find the index of the least recently used worker that can be cleaned up.
l, r, mid := 0, n-1, 0
for l <= r {
mid = (l + r) / 2
if criticalTime.After(wp.ready[mid].lastUseTime) {
l = mid + 1
} else {
r = mid - 1
}
}
i := r
if i == -1 {
wp.lock.Unlock()
return
}
*scratch = append((*scratch)[:0], ready[:i+1]...)
m := copy(ready, ready[i+1:])
for i = m; i < n; i++ {
ready[i] = nil
}
wp.ready = ready[:m]
wp.lock.Unlock()
// Notify obsolete workers to stop.
// This notification must be outside the wp.lock, since ch.ch
// may be blocking and may consume a lot of time if many workers
// are located on non-local CPUs.
tmp := *scratch
for i := range tmp {
tmp[i].ch <- nil
tmp[i] = nil
}
}
func (wp *workerPool) Serve(c net.Conn) bool {
ch := wp.getCh()
if ch == nil {
return false
}
ch.ch <- c
return true
}
var workerChanCap = func() int {
// Use blocking workerChan if GOMAXPROCS=1.
// This immediately switches Serve to WorkerFunc, which results
// in higher performance (under go1.5 at least).
if runtime.GOMAXPROCS(0) == 1 {
return 0
}
// Use non-blocking workerChan if GOMAXPROCS>1,
// since otherwise the Serve caller (Acceptor) may lag accepting
// new connections if WorkerFunc is CPU-bound.
return 1
}()
func (wp *workerPool) getCh() *workerChan {
var ch *workerChan
createWorker := false
wp.lock.Lock()
ready := wp.ready
n := len(ready) - 1
if n < 0 {
if wp.workersCount < wp.MaxWorkersCount {
createWorker = true
wp.workersCount++
}
} else {
ch = ready[n]
ready[n] = nil
wp.ready = ready[:n]
}
wp.lock.Unlock()
if ch == nil {
if !createWorker {
return nil
}
vch := wp.workerChanPool.Get()
ch = vch.(*workerChan)
go func() {
wp.workerFunc(ch)
wp.workerChanPool.Put(vch)
}()
}
return ch
}
func (wp *workerPool) release(ch *workerChan) bool {
ch.lastUseTime = time.Now()
wp.lock.Lock()
if wp.mustStop {
wp.lock.Unlock()
return false
}
wp.ready = append(wp.ready, ch)
wp.lock.Unlock()
return true
}
func (wp *workerPool) workerFunc(ch *workerChan) {
var c net.Conn
var err error
for c = range ch.ch {
if c == nil {
break
}
if err = wp.WorkerFunc(c); err != nil && err != errHijacked {
errStr := err.Error()
if wp.LogAllErrors || !(strings.Contains(errStr, "broken pipe") ||
strings.Contains(errStr, "reset by peer") ||
strings.Contains(errStr, "request headers: small read buffer") ||
strings.Contains(errStr, "unexpected EOF") ||
strings.Contains(errStr, "i/o timeout")) {
wp.Logger.Printf("error when serving connection %q<->%q: %s", c.LocalAddr(), c.RemoteAddr(), err)
}
}
if err == errHijacked {
wp.connState(c, StateHijacked)
} else {
_ = c.Close()
wp.connState(c, StateClosed)
}
c = nil
if !wp.release(ch) {
break
}
}
wp.lock.Lock()
wp.workersCount--
wp.lock.Unlock()
}
fasthttp-1.31.0/workerpool_test.go 0000664 0000000 0000000 00000007031 14130360711 0017217 0 ustar 00root root 0000000 0000000 package fasthttp
import (
"io/ioutil"
"net"
"testing"
"time"
"github.com/valyala/fasthttp/fasthttputil"
)
func TestWorkerPoolStartStopSerial(t *testing.T) {
t.Parallel()
testWorkerPoolStartStop(t)
}
func TestWorkerPoolStartStopConcurrent(t *testing.T) {
t.Parallel()
concurrency := 10
ch := make(chan struct{}, concurrency)
for i := 0; i < concurrency; i++ {
go func() {
testWorkerPoolStartStop(t)
ch <- struct{}{}
}()
}
for i := 0; i < concurrency; i++ {
select {
case <-ch:
case <-time.After(time.Second):
t.Fatalf("timeout")
}
}
}
func testWorkerPoolStartStop(t *testing.T) {
wp := &workerPool{
WorkerFunc: func(conn net.Conn) error { return nil },
MaxWorkersCount: 10,
Logger: defaultLogger,
}
for i := 0; i < 10; i++ {
wp.Start()
wp.Stop()
}
}
func TestWorkerPoolMaxWorkersCountSerial(t *testing.T) {
t.Parallel()
testWorkerPoolMaxWorkersCountMulti(t)
}
func TestWorkerPoolMaxWorkersCountConcurrent(t *testing.T) {
t.Parallel()
concurrency := 4
ch := make(chan struct{}, concurrency)
for i := 0; i < concurrency; i++ {
go func() {
testWorkerPoolMaxWorkersCountMulti(t)
ch <- struct{}{}
}()
}
for i := 0; i < concurrency; i++ {
select {
case <-ch:
case <-time.After(time.Second * 2):
t.Fatalf("timeout")
}
}
}
func testWorkerPoolMaxWorkersCountMulti(t *testing.T) {
for i := 0; i < 5; i++ {
testWorkerPoolMaxWorkersCount(t)
}
}
func testWorkerPoolMaxWorkersCount(t *testing.T) {
ready := make(chan struct{})
wp := &workerPool{
WorkerFunc: func(conn net.Conn) error {
buf := make([]byte, 100)
n, err := conn.Read(buf)
if err != nil {
t.Errorf("unexpected error: %s", err)
}
buf = buf[:n]
if string(buf) != "foobar" {
t.Errorf("unexpected data read: %q. Expecting %q", buf, "foobar")
}
if _, err = conn.Write([]byte("baz")); err != nil {
t.Errorf("unexpected error: %s", err)
}
<-ready
return nil
},
MaxWorkersCount: 10,
Logger: defaultLogger,
connState: func(net.Conn, ConnState) {},
}
wp.Start()
ln := fasthttputil.NewInmemoryListener()
clientCh := make(chan struct{}, wp.MaxWorkersCount)
for i := 0; i < wp.MaxWorkersCount; i++ {
go func() {
conn, err := ln.Dial()
if err != nil {
t.Errorf("unexpected error: %s", err)
}
if _, err = conn.Write([]byte("foobar")); err != nil {
t.Errorf("unexpected error: %s", err)
}
data, err := ioutil.ReadAll(conn)
if err != nil {
t.Errorf("unexpected error: %s", err)
}
if string(data) != "baz" {
t.Errorf("unexpected value read: %q. Expecting %q", data, "baz")
}
if err = conn.Close(); err != nil {
t.Errorf("unexpected error: %s", err)
}
clientCh <- struct{}{}
}()
}
for i := 0; i < wp.MaxWorkersCount; i++ {
conn, err := ln.Accept()
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
if !wp.Serve(conn) {
t.Fatalf("worker pool must have enough workers to serve the conn")
}
}
go func() {
if _, err := ln.Dial(); err != nil {
t.Errorf("unexpected error: %s", err)
}
}()
conn, err := ln.Accept()
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
for i := 0; i < 5; i++ {
if wp.Serve(conn) {
t.Fatalf("worker pool must be full")
}
}
if err = conn.Close(); err != nil {
t.Fatalf("unexpected error: %s", err)
}
close(ready)
for i := 0; i < wp.MaxWorkersCount; i++ {
select {
case <-clientCh:
case <-time.After(time.Second):
t.Fatalf("timeout")
}
}
if err := ln.Close(); err != nil {
t.Fatalf("unexpected error: %s", err)
}
wp.Stop()
}