pax_global_header 0000666 0000000 0000000 00000000064 14145546033 0014517 g ustar 00root root 0000000 0000000 52 comment=0316d5a1df8598eceb137f5f77945be56810b564
chi-5.0.7/ 0000775 0000000 0000000 00000000000 14145546033 0012273 5 ustar 00root root 0000000 0000000 chi-5.0.7/.github/ 0000775 0000000 0000000 00000000000 14145546033 0013633 5 ustar 00root root 0000000 0000000 chi-5.0.7/.github/FUNDING.yml 0000664 0000000 0000000 00000001323 14145546033 0015447 0 ustar 00root root 0000000 0000000 # These are supported funding model platforms
github: [pkieltyka] # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2]
patreon: # Replace with a single Patreon username
open_collective: # Replace with a single Open Collective username
ko_fi: # Replace with a single Ko-fi username
tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel
community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry
liberapay: # Replace with a single Liberapay username
issuehunt: # Replace with a single IssueHunt username
otechie: # Replace with a single Otechie username
custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2']
chi-5.0.7/.github/workflows/ 0000775 0000000 0000000 00000000000 14145546033 0015670 5 ustar 00root root 0000000 0000000 chi-5.0.7/.github/workflows/ci.yml 0000664 0000000 0000000 00000001531 14145546033 0017006 0 ustar 00root root 0000000 0000000 on:
push:
branches: '**'
paths-ignore:
- 'docs/**'
pull_request:
branches: '**'
paths-ignore:
- 'docs/**'
name: Test
jobs:
test:
env:
GOPATH: ${{ github.workspace }}
GO111MODULE: off
defaults:
run:
working-directory: ${{ env.GOPATH }}/src/github.com/${{ github.repository }}
strategy:
matrix:
go-version: [1.14.x, 1.15.x, 1.16.x, 1.17.x]
os: [ubuntu-latest, macos-latest, windows-latest]
runs-on: ${{ matrix.os }}
steps:
- name: Install Go
uses: actions/setup-go@v2
with:
go-version: ${{ matrix.go-version }}
- name: Checkout code
uses: actions/checkout@v2
with:
path: ${{ env.GOPATH }}/src/github.com/${{ github.repository }}
- name: Test
run: |
go get -d -t ./...
make test
chi-5.0.7/.gitignore 0000664 0000000 0000000 00000000024 14145546033 0014257 0 ustar 00root root 0000000 0000000 .idea
*.sw?
.vscode
chi-5.0.7/CHANGELOG.md 0000664 0000000 0000000 00000035624 14145546033 0014116 0 ustar 00root root 0000000 0000000 # Changelog
## v5.0.7 (2021-11-18)
- History of changes: see https://github.com/go-chi/chi/compare/v5.0.6...v5.0.7
## v5.0.6 (2021-11-15)
- History of changes: see https://github.com/go-chi/chi/compare/v5.0.5...v5.0.6
## v5.0.5 (2021-10-27)
- History of changes: see https://github.com/go-chi/chi/compare/v5.0.4...v5.0.5
## v5.0.4 (2021-08-29)
- History of changes: see https://github.com/go-chi/chi/compare/v5.0.3...v5.0.4
## v5.0.3 (2021-04-29)
- History of changes: see https://github.com/go-chi/chi/compare/v5.0.2...v5.0.3
## v5.0.2 (2021-03-25)
- History of changes: see https://github.com/go-chi/chi/compare/v5.0.1...v5.0.2
## v5.0.1 (2021-03-10)
- Small improvements
- History of changes: see https://github.com/go-chi/chi/compare/v5.0.0...v5.0.1
## v5.0.0 (2021-02-27)
- chi v5, `github.com/go-chi/chi/v5` introduces the adoption of Go's SIV to adhere to the current state-of-the-tools in Go.
- chi v1.5.x did not work out as planned, as the Go tooling is too powerful and chi's adoption is too wide.
The most responsible thing to do for everyone's benefit is to just release v5 with SIV, so I present to you all,
chi v5 at `github.com/go-chi/chi/v5`. I hope someday the developer experience and ergonomics I've been seeking
will still come to fruition in some form, see https://github.com/golang/go/issues/44550
- History of changes: see https://github.com/go-chi/chi/compare/v1.5.4...v5.0.0
## v1.5.4 (2021-02-27)
- Undo prior retraction in v1.5.3 as we prepare for v5.0.0 release
- History of changes: see https://github.com/go-chi/chi/compare/v1.5.3...v1.5.4
## v1.5.3 (2021-02-21)
- Update go.mod to go 1.16 with new retract directive marking all versions without prior go.mod support
- History of changes: see https://github.com/go-chi/chi/compare/v1.5.2...v1.5.3
## v1.5.2 (2021-02-10)
- Reverting allocation optimization as a precaution as go test -race fails.
- Minor improvements, see history below
- History of changes: see https://github.com/go-chi/chi/compare/v1.5.1...v1.5.2
## v1.5.1 (2020-12-06)
- Performance improvement: removing 1 allocation by foregoing context.WithValue, thank you @bouk for
your contribution (https://github.com/go-chi/chi/pull/555). Note: new benchmarks posted in README.
- `middleware.CleanPath`: new middleware that cleans the request path of double slashes
- deprecate & remove `chi.ServerBaseContext` in favour of stdlib `http.Server#BaseContext`
- plus other tiny improvements, see full commit history below
- History of changes: see https://github.com/go-chi/chi/compare/v4.1.2...v1.5.1
## v1.5.0 (2020-11-12) - now with go.mod support
`chi` dates back to 2016 with its original implementation as one of the first routers to adopt the newly introduced
context.Context api to the stdlib -- set out to design a router that is faster, more modular and simpler than anything
else out there -- while not introducing any custom handler types or dependencies. Today, `chi` still has zero dependencies,
and in many ways is future proofed from changes, given its minimal nature. Between versions, chi's iterations have been very
incremental, with the architecture and api being the same today as it was originally designed in 2016. For this reason it
makes chi a pretty easy project to maintain, thanks as well to the many amazing community contributions over the years
from everyone who has helped make chi better (total of 86 contributors to date -- thanks all!).
Chi has been a labour of love, art and engineering, with the goals to offer beautiful ergonomics, flexibility, performance
and simplicity when building HTTP services with Go. I've strived to keep the router very minimal in surface area / code size,
and always improving the code wherever possible -- and as of today the `chi` package is just 1082 lines of code (not counting
middlewares, which are all optional). As well, I don't have the exact metrics, but from my analysis and email exchanges from
companies and developers, chi is used by thousands of projects around the world -- thank you all as there is no better form of
joy for me than to have art I had started be helpful and enjoyed by others. And of course I use chi in all of my own projects too :)
For me, the aesthetics of chi's code and usage are very important. With the introduction of Go's module support
(which I'm a big fan of), chi's past versioning scheme choice to v2, v3 and v4 would mean I'd require the import path
of "github.com/go-chi/chi/v4", leading to the lengthy discussion at https://github.com/go-chi/chi/issues/462.
Haha, to some, you may be scratching your head why I've spent > 1 year stalling to adopt "/vXX" convention in the import
path -- which isn't horrible in general -- but for chi, I'm unable to accept it as I strive for perfection in its API design,
aesthetics and simplicity. It just doesn't feel good to me given chi's simple nature -- I do not foresee a "v5" or "v6",
and upgrading between versions in the future will also be just incremental.
I do understand versioning is a part of the API design as well, which is why the solution for a while has been to "do nothing",
as Go supports both old and new import paths with/out go.mod. However, now that Go module support has had time to iron out kinks and
is adopted everywhere, it's time for chi to get with the times. Luckily, I've discovered a path forward that will make me happy,
while also not breaking anyone's app who adopted a prior versioning from tags in v2/v3/v4. I've made an experimental release of
v1.5.0 with go.mod silently, and tested it with new and old projects, to ensure the developer experience is preserved, and it's
largely unnoticed. Fortunately, Go's toolchain will check the tags of a repo and consider the "latest" tag the one with go.mod.
However, you can still request a specific older tag such as v4.1.2, and everything will "just work". But new users can just
`go get github.com/go-chi/chi` or `go get github.com/go-chi/chi@latest` and they will get the latest version which contains
go.mod support, which is v1.5.0+. `chi` will not change very much over the years, just like it hasn't changed much from 4 years ago.
Therefore, we will stay on v1.x from here on, starting from v1.5.0. Any breaking changes will bump a "minor" release and
backwards-compatible improvements/fixes will bump a "tiny" release.
For existing projects who want to upgrade to the latest go.mod version, run: `go get -u github.com/go-chi/chi@v1.5.0`,
which will get you on the go.mod version line (as Go's mod cache may still remember v4.x). Brand new systems can run
`go get -u github.com/go-chi/chi` or `go get -u github.com/go-chi/chi@latest` to install chi, which will install v1.5.0+
built with go.mod support.
My apologies to the developers who will disagree with the decisions above, but, hope you'll try it and see it's a very
minor request which is backwards compatible and won't break your existing installations.
Cheers all, happy coding!
---
## v4.1.2 (2020-06-02)
- fix that handles MethodNotAllowed with path variables, thank you @caseyhadden for your contribution
- fix to replace nested wildcards correctly in RoutePattern, thank you @unmultimedio for your contribution
- History of changes: see https://github.com/go-chi/chi/compare/v4.1.1...v4.1.2
## v4.1.1 (2020-04-16)
- fix for issue https://github.com/go-chi/chi/issues/411 which allows for overlapping regexp
route to the correct handler through a recursive tree search, thanks to @Jahaja for the PR/fix!
- new middleware.RouteHeaders as a simple router for request headers with wildcard support
- History of changes: see https://github.com/go-chi/chi/compare/v4.1.0...v4.1.1
## v4.1.0 (2020-04-01)
- middleware.LogEntry: Write method on interface now passes the response header
and an extra interface type useful for custom logger implementations.
- middleware.WrapResponseWriter: minor fix
- middleware.Recoverer: a bit prettier
- History of changes: see https://github.com/go-chi/chi/compare/v4.0.4...v4.1.0
## v4.0.4 (2020-03-24)
- middleware.Recoverer: new pretty stack trace printing (https://github.com/go-chi/chi/pull/496)
- a few minor improvements and fixes
- History of changes: see https://github.com/go-chi/chi/compare/v4.0.3...v4.0.4
## v4.0.3 (2020-01-09)
- core: fix regexp routing to include default value when param is not matched
- middleware: rewrite of middleware.Compress
- middleware: suppress http.ErrAbortHandler in middleware.Recoverer
- History of changes: see https://github.com/go-chi/chi/compare/v4.0.2...v4.0.3
## v4.0.2 (2019-02-26)
- Minor fixes
- History of changes: see https://github.com/go-chi/chi/compare/v4.0.1...v4.0.2
## v4.0.1 (2019-01-21)
- Fixes issue with compress middleware: #382 #385
- History of changes: see https://github.com/go-chi/chi/compare/v4.0.0...v4.0.1
## v4.0.0 (2019-01-10)
- chi v4 requires Go 1.10.3+ (or Go 1.9.7+) - we have deprecated support for Go 1.7 and 1.8
- router: respond with 404 on router with no routes (#362)
- router: additional check to ensure wildcard is at the end of a url pattern (#333)
- middleware: deprecate use of http.CloseNotifier (#347)
- middleware: fix RedirectSlashes to include query params on redirect (#334)
- History of changes: see https://github.com/go-chi/chi/compare/v3.3.4...v4.0.0
## v3.3.4 (2019-01-07)
- Minor middleware improvements. No changes to core library/router. Moving v3 into its
own branch as a version of chi for Go 1.7, 1.8, 1.9, 1.10, 1.11
- History of changes: see https://github.com/go-chi/chi/compare/v3.3.3...v3.3.4
## v3.3.3 (2018-08-27)
- Minor release
- See https://github.com/go-chi/chi/compare/v3.3.2...v3.3.3
## v3.3.2 (2017-12-22)
- Support to route trailing slashes on mounted sub-routers (#281)
- middleware: new `ContentCharset` to check matching charsets. Thank you
@csucu for your community contribution!
## v3.3.1 (2017-11-20)
- middleware: new `AllowContentType` handler for explicit whitelist of accepted request Content-Types
- middleware: new `SetHeader` handler for short-hand middleware to set a response header key/value
- Minor bug fixes
## v3.3.0 (2017-10-10)
- New chi.RegisterMethod(method) to add support for custom HTTP methods, see _examples/custom-method for usage
- Deprecated LINK and UNLINK methods from the default list, please use `chi.RegisterMethod("LINK")` and `chi.RegisterMethod("UNLINK")` in an `init()` function
## v3.2.1 (2017-08-31)
- Add new `Match(rctx *Context, method, path string) bool` method to `Routes` interface
and `Mux`. Match searches the mux's routing tree for a handler that matches the method/path
- Add new `RouteMethod` to `*Context`
- Add new `Routes` pointer to `*Context`
- Add new `middleware.GetHead` to route missing HEAD requests to GET handler
- Updated benchmarks (see README)
## v3.1.5 (2017-08-02)
- Setup golint and go vet for the project
- As per golint, we've redefined `func ServerBaseContext(h http.Handler, baseCtx context.Context) http.Handler`
to `func ServerBaseContext(baseCtx context.Context, h http.Handler) http.Handler`
## v3.1.0 (2017-07-10)
- Fix a few minor issues after v3 release
- Move `docgen` sub-pkg to https://github.com/go-chi/docgen
- Move `render` sub-pkg to https://github.com/go-chi/render
- Add new `URLFormat` handler to chi/middleware sub-pkg to make working with url mime
suffixes easier, ie. parsing `/articles/1.json` and `/articles/1.xml`. See comments in
https://github.com/go-chi/chi/blob/master/middleware/url_format.go for example usage.
## v3.0.0 (2017-06-21)
- Major update to chi library with many exciting updates, but also some *breaking changes*
- URL parameter syntax changed from `/:id` to `/{id}` for even more flexible routing, such as
`/articles/{month}-{day}-{year}-{slug}`, `/articles/{id}`, and `/articles/{id}.{ext}` on the
same router
- Support for regexp for routing patterns, in the form of `/{paramKey:regExp}` for example:
`r.Get("/articles/{name:[a-z]+}", h)` and `chi.URLParam(r, "name")`
- Add `Method` and `MethodFunc` to `chi.Router` to allow routing definitions such as
`r.Method("GET", "/", h)` which provides a cleaner interface for custom handlers like
in `_examples/custom-handler`
- Deprecating `mux#FileServer` helper function. Instead, we encourage users to create their
own using file handler with the stdlib, see `_examples/fileserver` for an example
- Add support for LINK/UNLINK http methods via `r.Method()` and `r.MethodFunc()`
- Moved the chi project to its own organization, to allow chi-related community packages to
be easily discovered and supported, at: https://github.com/go-chi
- *NOTE:* please update your import paths to `"github.com/go-chi/chi"`
- *NOTE:* chi v2 is still available at https://github.com/go-chi/chi/tree/v2
## v2.1.0 (2017-03-30)
- Minor improvements and update to the chi core library
- Introduced a brand new `chi/render` sub-package to complete the story of building
APIs to offer a pattern for managing well-defined request / response payloads. Please
check out the updated `_examples/rest` example for how it works.
- Added `MethodNotAllowed(h http.HandlerFunc)` to chi.Router interface
## v2.0.0 (2017-01-06)
- After many months of v2 being in an RC state with many companies and users running it in
production, and with the inclusion of some improvements to the middlewares, we are very pleased to
announce v2.0.0 of chi.
## v2.0.0-rc1 (2016-07-26)
- Huge update! chi v2 is a large refactor targeting Go 1.7+. As of Go 1.7, the popular
community `"net/context"` package has been included in the standard library as `"context"` and
utilized by `"net/http"` and `http.Request` to manage deadlines, cancelation signals and other
request-scoped values. We're very excited about the new context addition and are proud to
introduce chi v2, a minimal and powerful routing package for building large HTTP services,
with zero external dependencies. Chi focuses on idiomatic design and encourages the use of
stdlib HTTP handlers and middlewares.
- chi v2 deprecates its `chi.Handler` interface and requires `http.Handler` or `http.HandlerFunc`
- chi v2 stores URL routing parameters and patterns in the standard request context: `r.Context()`
- chi v2 lower-level routing context is accessible by `chi.RouteContext(r.Context()) *chi.Context`,
which provides direct access to URL routing parameters, the routing path and the matching
routing patterns.
- Users upgrading from chi v1 to v2, need to:
1. Update the old chi.Handler signature, `func(ctx context.Context, w http.ResponseWriter, r *http.Request)` to
the standard http.Handler: `func(w http.ResponseWriter, r *http.Request)`
2. Use `chi.URLParam(r *http.Request, paramKey string) string`
or `URLParamFromCtx(ctx context.Context, paramKey string) string` to access a url parameter value
## v1.0.0 (2016-07-01)
- Released chi v1 stable https://github.com/go-chi/chi/tree/v1.0.0 for Go 1.6 and older.
## v0.9.0 (2016-03-31)
- Reuse context objects via sync.Pool for zero-allocation routing [#33](https://github.com/go-chi/chi/pull/33)
- BREAKING NOTE: due to subtle API changes, previously `chi.URLParams(ctx)["id"]` used to access url parameters
has changed to: `chi.URLParam(ctx, "id")`
chi-5.0.7/CONTRIBUTING.md 0000664 0000000 0000000 00000002056 14145546033 0014527 0 ustar 00root root 0000000 0000000 # Contributing
## Prerequisites
1. [Install Go][go-install].
2. Download the sources and switch the working directory:
```bash
go get -u -d github.com/go-chi/chi
cd $GOPATH/src/github.com/go-chi/chi
```
## Submitting a Pull Request
A typical workflow is:
1. [Fork the repository.][fork] [This tip may also be helpful.][go-fork-tip]
2. [Create a topic branch.][branch]
3. Add tests for your change.
4. Run `go test`. If your tests pass, return to step 3.
5. Implement the change and ensure the tests from the previous step pass.
6. Run `goimports -w .`, to ensure the new code conforms to Go formatting guidelines.
7. [Add, commit and push your changes.][git-help]
8. [Submit a pull request.][pull-req]
[go-install]: https://golang.org/doc/install
[go-fork-tip]: http://blog.campoy.cat/2014/03/github-and-go-forking-pull-requests-and.html
[fork]: https://help.github.com/articles/fork-a-repo
[branch]: http://learn.github.com/p/branching.html
[git-help]: https://guides.github.com
[pull-req]: https://help.github.com/articles/using-pull-requests
chi-5.0.7/LICENSE 0000664 0000000 0000000 00000002143 14145546033 0013300 0 ustar 00root root 0000000 0000000 Copyright (c) 2015-present Peter Kieltyka (https://github.com/pkieltyka), Google Inc.
MIT License
Permission is hereby granted, free of charge, to any person obtaining a copy of
this software and associated documentation files (the "Software"), to deal in
the Software without restriction, including without limitation the rights to
use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
the Software, and to permit persons to whom the Software is furnished to do so,
subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
chi-5.0.7/Makefile 0000664 0000000 0000000 00000000645 14145546033 0013740 0 ustar 00root root 0000000 0000000 all:
@echo "**********************************************************"
@echo "** chi build tool **"
@echo "**********************************************************"
test:
go clean -testcache && $(MAKE) test-router && $(MAKE) test-middleware
test-router:
go test -race -v .
test-middleware:
go test -race -v ./middleware
.PHONY: docs
docs:
npx docsify-cli serve ./docs
chi-5.0.7/README.md 0000664 0000000 0000000 00000054752 14145546033 0013567 0 ustar 00root root 0000000 0000000 #
[![GoDoc Widget]][GoDoc] [![Travis Widget]][Travis]
`chi` is a lightweight, idiomatic and composable router for building Go HTTP services. It's
especially good at helping you write large REST API services that are kept maintainable as your
project grows and changes. `chi` is built on the new `context` package introduced in Go 1.7 to
handle signaling, cancelation and request-scoped values across a handler chain.
The focus of the project has been to seek out an elegant and comfortable design for writing
REST API servers, written during the development of the Pressly API service that powers our
public API service, which in turn powers all of our client-side applications.
The key considerations of chi's design are: project structure, maintainability, standard http
handlers (stdlib-only), developer productivity, and deconstructing a large system into many small
parts. The core router `github.com/go-chi/chi` is quite small (less than 1000 LOC), but we've also
included some useful/optional subpackages: [middleware](/middleware), [render](https://github.com/go-chi/render)
and [docgen](https://github.com/go-chi/docgen). We hope you enjoy it too!
## Install
`go get -u github.com/go-chi/chi/v5`
## Features
* **Lightweight** - cloc'd in ~1000 LOC for the chi router
* **Fast** - yes, see [benchmarks](#benchmarks)
* **100% compatible with net/http** - use any http or middleware pkg in the ecosystem that is also compatible with `net/http`
* **Designed for modular/composable APIs** - middlewares, inline middlewares, route groups and sub-router mounting
* **Context control** - built on new `context` package, providing value chaining, cancellations and timeouts
* **Robust** - in production at Pressly, CloudFlare, Heroku, 99Designs, and many others (see [discussion](https://github.com/go-chi/chi/issues/91))
* **Doc generation** - `docgen` auto-generates routing documentation from your source to JSON or Markdown
* **Go.mod support** - as of v5, go.mod support (see [CHANGELOG](https://github.com/go-chi/chi/blob/master/CHANGELOG.md))
* **No external dependencies** - plain ol' Go stdlib + net/http
## Examples
See [_examples/](https://github.com/go-chi/chi/blob/master/_examples/) for a variety of examples.
**As easy as:**
```go
package main
import (
"net/http"
"github.com/go-chi/chi/v5"
"github.com/go-chi/chi/v5/middleware"
)
func main() {
r := chi.NewRouter()
r.Use(middleware.Logger)
r.Get("/", func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("welcome"))
})
http.ListenAndServe(":3000", r)
}
```
**REST Preview:**
Here is a little preview of what routing looks like with chi. Also take a look at the generated routing docs
in JSON ([routes.json](https://github.com/go-chi/chi/blob/master/_examples/rest/routes.json)) and in
Markdown ([routes.md](https://github.com/go-chi/chi/blob/master/_examples/rest/routes.md)).
I highly recommend reading the source of the [examples](https://github.com/go-chi/chi/blob/master/_examples/) listed
above, they will show you all the features of chi and serve as a good form of documentation.
```go
import (
//...
"context"
"github.com/go-chi/chi/v5"
"github.com/go-chi/chi/v5/middleware"
)
func main() {
r := chi.NewRouter()
// A good base middleware stack
r.Use(middleware.RequestID)
r.Use(middleware.RealIP)
r.Use(middleware.Logger)
r.Use(middleware.Recoverer)
// Set a timeout value on the request context (ctx), that will signal
// through ctx.Done() that the request has timed out and further
// processing should be stopped.
r.Use(middleware.Timeout(60 * time.Second))
r.Get("/", func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("hi"))
})
// RESTy routes for "articles" resource
r.Route("/articles", func(r chi.Router) {
r.With(paginate).Get("/", listArticles) // GET /articles
r.With(paginate).Get("/{month}-{day}-{year}", listArticlesByDate) // GET /articles/01-16-2017
r.Post("/", createArticle) // POST /articles
r.Get("/search", searchArticles) // GET /articles/search
// Regexp url parameters:
r.Get("/{articleSlug:[a-z-]+}", getArticleBySlug) // GET /articles/home-is-toronto
// Subrouters:
r.Route("/{articleID}", func(r chi.Router) {
r.Use(ArticleCtx)
r.Get("/", getArticle) // GET /articles/123
r.Put("/", updateArticle) // PUT /articles/123
r.Delete("/", deleteArticle) // DELETE /articles/123
})
})
// Mount the admin sub-router
r.Mount("/admin", adminRouter())
http.ListenAndServe(":3333", r)
}
func ArticleCtx(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
articleID := chi.URLParam(r, "articleID")
article, err := dbGetArticle(articleID)
if err != nil {
http.Error(w, http.StatusText(404), 404)
return
}
ctx := context.WithValue(r.Context(), "article", article)
next.ServeHTTP(w, r.WithContext(ctx))
})
}
func getArticle(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
article, ok := ctx.Value("article").(*Article)
if !ok {
http.Error(w, http.StatusText(422), 422)
return
}
w.Write([]byte(fmt.Sprintf("title:%s", article.Title)))
}
// A completely separate router for administrator routes
func adminRouter() http.Handler {
r := chi.NewRouter()
r.Use(AdminOnly)
r.Get("/", adminIndex)
r.Get("/accounts", adminListAccounts)
return r
}
func AdminOnly(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
perm, ok := ctx.Value("acl.permission").(YourPermissionType)
if !ok || !perm.IsAdmin() {
http.Error(w, http.StatusText(403), 403)
return
}
next.ServeHTTP(w, r)
})
}
```
## Router interface
chi's router is based on a kind of [Patricia Radix trie](https://en.wikipedia.org/wiki/Radix_tree).
The router is fully compatible with `net/http`.
Built on top of the tree is the `Router` interface:
```go
// Router consisting of the core routing methods used by chi's Mux,
// using only the standard net/http.
type Router interface {
http.Handler
Routes
// Use appends one or more middlewares onto the Router stack.
Use(middlewares ...func(http.Handler) http.Handler)
// With adds inline middlewares for an endpoint handler.
With(middlewares ...func(http.Handler) http.Handler) Router
// Group adds a new inline-Router along the current routing
// path, with a fresh middleware stack for the inline-Router.
Group(fn func(r Router)) Router
// Route mounts a sub-Router along a `pattern` string.
Route(pattern string, fn func(r Router)) Router
// Mount attaches another http.Handler along ./pattern/*
Mount(pattern string, h http.Handler)
// Handle and HandleFunc adds routes for `pattern` that matches
// all HTTP methods.
Handle(pattern string, h http.Handler)
HandleFunc(pattern string, h http.HandlerFunc)
// Method and MethodFunc adds routes for `pattern` that matches
// the `method` HTTP method.
Method(method, pattern string, h http.Handler)
MethodFunc(method, pattern string, h http.HandlerFunc)
// HTTP-method routing along `pattern`
Connect(pattern string, h http.HandlerFunc)
Delete(pattern string, h http.HandlerFunc)
Get(pattern string, h http.HandlerFunc)
Head(pattern string, h http.HandlerFunc)
Options(pattern string, h http.HandlerFunc)
Patch(pattern string, h http.HandlerFunc)
Post(pattern string, h http.HandlerFunc)
Put(pattern string, h http.HandlerFunc)
Trace(pattern string, h http.HandlerFunc)
// NotFound defines a handler to respond whenever a route could
// not be found.
NotFound(h http.HandlerFunc)
// MethodNotAllowed defines a handler to respond whenever a method is
// not allowed.
MethodNotAllowed(h http.HandlerFunc)
}
// Routes interface adds two methods for router traversal, which is also
// used by the github.com/go-chi/docgen package to generate documentation for Routers.
type Routes interface {
// Routes returns the routing tree in an easily traversable structure.
Routes() []Route
// Middlewares returns the list of middlewares in use by the router.
Middlewares() Middlewares
// Match searches the routing tree for a handler that matches
// the method/path - similar to routing a http request, but without
// executing the handler thereafter.
Match(rctx *Context, method, path string) bool
}
```
Each routing method accepts a URL `pattern` and chain of `handlers`. The URL pattern
supports named params (ie. `/users/{userID}`) and wildcards (ie. `/admin/*`). URL parameters
can be fetched at runtime by calling `chi.URLParam(r, "userID")` for named parameters
and `chi.URLParam(r, "*")` for a wildcard parameter.
### Middleware handlers
chi's middlewares are just stdlib net/http middleware handlers. There is nothing special
about them, which means the router and all the tooling is designed to be compatible and
friendly with any middleware in the community. This offers much better extensibility and reuse
of packages and is at the heart of chi's purpose.
Here is an example of a standard net/http middleware where we assign a context key `"user"`
the value of `"123"`. This middleware sets a hypothetical user identifier on the request
context and calls the next handler in the chain.
```go
// HTTP middleware setting a value on the request context
func MyMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// create new context from `r` request context, and assign key `"user"`
// to value of `"123"`
ctx := context.WithValue(r.Context(), "user", "123")
// call the next handler in the chain, passing the response writer and
// the updated request object with the new context value.
//
// note: context.Context values are nested, so any previously set
// values will be accessible as well, and the new `"user"` key
// will be accessible from this point forward.
next.ServeHTTP(w, r.WithContext(ctx))
})
}
```
### Request handlers
chi uses standard net/http request handlers. This little snippet is an example of a http.Handler
func that reads a user identifier from the request context - hypothetically, identifying
the user sending an authenticated request, validated+set by a previous middleware handler.
```go
// HTTP handler accessing data from the request context.
func MyRequestHandler(w http.ResponseWriter, r *http.Request) {
// here we read from the request context and fetch out `"user"` key set in
// the MyMiddleware example above.
user := r.Context().Value("user").(string)
// respond to the client
w.Write([]byte(fmt.Sprintf("hi %s", user)))
}
```
### URL parameters
chi's router parses and stores URL parameters right onto the request context. Here is
an example of how to access URL params in your net/http handlers. And of course, middlewares
are able to access the same information.
```go
// HTTP handler accessing the url routing parameters.
func MyRequestHandler(w http.ResponseWriter, r *http.Request) {
// fetch the url parameter `"userID"` from the request of a matching
// routing pattern. An example routing pattern could be: /users/{userID}
userID := chi.URLParam(r, "userID")
// fetch `"key"` from the request context
ctx := r.Context()
key := ctx.Value("key").(string)
// respond to the client
w.Write([]byte(fmt.Sprintf("hi %v, %v", userID, key)))
}
```
## Middlewares
chi comes equipped with an optional `middleware` package, providing a suite of standard
`net/http` middlewares. Please note, any middleware in the ecosystem that is also compatible
with `net/http` can be used with chi's mux.
### Core middlewares
----------------------------------------------------------------------------------------------------
| chi/middleware Handler | description |
| :--------------------- | :---------------------------------------------------------------------- |
| [AllowContentEncoding] | Enforces a whitelist of request Content-Encoding headers |
| [AllowContentType] | Explicit whitelist of accepted request Content-Types |
| [BasicAuth] | Basic HTTP authentication |
| [Compress] | Gzip compression for clients that accept compressed responses |
| [ContentCharset] | Ensure charset for Content-Type request headers |
| [CleanPath] | Clean double slashes from request path |
| [GetHead] | Automatically route undefined HEAD requests to GET handlers |
| [Heartbeat] | Monitoring endpoint to check the server's pulse |
| [Logger] | Logs the start and end of each request with the elapsed processing time |
| [NoCache] | Sets response headers to prevent clients from caching |
| [Profiler] | Easily attach net/http/pprof to your routers |
| [RealIP] | Sets a http.Request's RemoteAddr to either X-Real-IP or X-Forwarded-For |
| [Recoverer] | Gracefully absorbs panics and prints the stack trace |
| [RequestID] | Injects a request ID into the context of each request |
| [RedirectSlashes] | Redirect slashes on routing paths |
| [RouteHeaders] | Route handling for request headers |
| [SetHeader] | Short-hand middleware to set a response header key/value |
| [StripSlashes] | Strip slashes on routing paths |
| [Throttle] | Puts a ceiling on the number of concurrent requests |
| [Timeout] | Signals to the request context when the timeout deadline is reached |
| [URLFormat] | Parse extension from url and put it on request context |
| [WithValue] | Short-hand middleware to set a key/value on the request context |
----------------------------------------------------------------------------------------------------
[AllowContentEncoding]: https://pkg.go.dev/github.com/go-chi/chi/middleware#AllowContentEncoding
[AllowContentType]: https://pkg.go.dev/github.com/go-chi/chi/middleware#AllowContentType
[BasicAuth]: https://pkg.go.dev/github.com/go-chi/chi/middleware#BasicAuth
[Compress]: https://pkg.go.dev/github.com/go-chi/chi/middleware#Compress
[ContentCharset]: https://pkg.go.dev/github.com/go-chi/chi/middleware#ContentCharset
[CleanPath]: https://pkg.go.dev/github.com/go-chi/chi/middleware#CleanPath
[GetHead]: https://pkg.go.dev/github.com/go-chi/chi/middleware#GetHead
[GetReqID]: https://pkg.go.dev/github.com/go-chi/chi/middleware#GetReqID
[Heartbeat]: https://pkg.go.dev/github.com/go-chi/chi/middleware#Heartbeat
[Logger]: https://pkg.go.dev/github.com/go-chi/chi/middleware#Logger
[NoCache]: https://pkg.go.dev/github.com/go-chi/chi/middleware#NoCache
[Profiler]: https://pkg.go.dev/github.com/go-chi/chi/middleware#Profiler
[RealIP]: https://pkg.go.dev/github.com/go-chi/chi/middleware#RealIP
[Recoverer]: https://pkg.go.dev/github.com/go-chi/chi/middleware#Recoverer
[RedirectSlashes]: https://pkg.go.dev/github.com/go-chi/chi/middleware#RedirectSlashes
[RequestLogger]: https://pkg.go.dev/github.com/go-chi/chi/middleware#RequestLogger
[RequestID]: https://pkg.go.dev/github.com/go-chi/chi/middleware#RequestID
[RouteHeaders]: https://pkg.go.dev/github.com/go-chi/chi/middleware#RouteHeaders
[SetHeader]: https://pkg.go.dev/github.com/go-chi/chi/middleware#SetHeader
[StripSlashes]: https://pkg.go.dev/github.com/go-chi/chi/middleware#StripSlashes
[Throttle]: https://pkg.go.dev/github.com/go-chi/chi/middleware#Throttle
[ThrottleBacklog]: https://pkg.go.dev/github.com/go-chi/chi/middleware#ThrottleBacklog
[ThrottleWithOpts]: https://pkg.go.dev/github.com/go-chi/chi/middleware#ThrottleWithOpts
[Timeout]: https://pkg.go.dev/github.com/go-chi/chi/middleware#Timeout
[URLFormat]: https://pkg.go.dev/github.com/go-chi/chi/middleware#URLFormat
[WithLogEntry]: https://pkg.go.dev/github.com/go-chi/chi/middleware#WithLogEntry
[WithValue]: https://pkg.go.dev/github.com/go-chi/chi/middleware#WithValue
[Compressor]: https://pkg.go.dev/github.com/go-chi/chi/middleware#Compressor
[DefaultLogFormatter]: https://pkg.go.dev/github.com/go-chi/chi/middleware#DefaultLogFormatter
[EncoderFunc]: https://pkg.go.dev/github.com/go-chi/chi/middleware#EncoderFunc
[HeaderRoute]: https://pkg.go.dev/github.com/go-chi/chi/middleware#HeaderRoute
[HeaderRouter]: https://pkg.go.dev/github.com/go-chi/chi/middleware#HeaderRouter
[LogEntry]: https://pkg.go.dev/github.com/go-chi/chi/middleware#LogEntry
[LogFormatter]: https://pkg.go.dev/github.com/go-chi/chi/middleware#LogFormatter
[LoggerInterface]: https://pkg.go.dev/github.com/go-chi/chi/middleware#LoggerInterface
[ThrottleOpts]: https://pkg.go.dev/github.com/go-chi/chi/middleware#ThrottleOpts
[WrapResponseWriter]: https://pkg.go.dev/github.com/go-chi/chi/middleware#WrapResponseWriter
### Extra middlewares & packages
Please see https://github.com/go-chi for additional packages.
--------------------------------------------------------------------------------------------------------------------
| package | description |
|:---------------------------------------------------|:-------------------------------------------------------------
| [cors](https://github.com/go-chi/cors) | Cross-origin resource sharing (CORS) |
| [docgen](https://github.com/go-chi/docgen) | Print chi.Router routes at runtime |
| [jwtauth](https://github.com/go-chi/jwtauth) | JWT authentication |
| [hostrouter](https://github.com/go-chi/hostrouter) | Domain/host based request routing |
| [httplog](https://github.com/go-chi/httplog) | Small but powerful structured HTTP request logging |
| [httprate](https://github.com/go-chi/httprate) | HTTP request rate limiter |
| [httptracer](https://github.com/go-chi/httptracer) | HTTP request performance tracing library |
| [httpvcr](https://github.com/go-chi/httpvcr) | Write deterministic tests for external sources |
| [stampede](https://github.com/go-chi/stampede) | HTTP request coalescer |
--------------------------------------------------------------------------------------------------------------------
## context?
`context` is a tiny pkg that provides a simple interface to signal context across call stacks
and goroutines. It was originally written by [Sameer Ajmani](https://github.com/Sajmani)
and is available in stdlib since go1.7.
Learn more at https://blog.golang.org/context
and..
* Docs: https://golang.org/pkg/context
* Source: https://github.com/golang/go/tree/master/src/context
## Benchmarks
The benchmark suite: https://github.com/pkieltyka/go-http-routing-benchmark
Results as of Nov 29, 2020 with Go 1.15.5 on Linux AMD 3950x
```shell
BenchmarkChi_Param 3075895 384 ns/op 400 B/op 2 allocs/op
BenchmarkChi_Param5 2116603 566 ns/op 400 B/op 2 allocs/op
BenchmarkChi_Param20 964117 1227 ns/op 400 B/op 2 allocs/op
BenchmarkChi_ParamWrite 2863413 420 ns/op 400 B/op 2 allocs/op
BenchmarkChi_GithubStatic 3045488 395 ns/op 400 B/op 2 allocs/op
BenchmarkChi_GithubParam 2204115 540 ns/op 400 B/op 2 allocs/op
BenchmarkChi_GithubAll 10000 113811 ns/op 81203 B/op 406 allocs/op
BenchmarkChi_GPlusStatic 3337485 359 ns/op 400 B/op 2 allocs/op
BenchmarkChi_GPlusParam 2825853 423 ns/op 400 B/op 2 allocs/op
BenchmarkChi_GPlus2Params 2471697 483 ns/op 400 B/op 2 allocs/op
BenchmarkChi_GPlusAll 194220 5950 ns/op 5200 B/op 26 allocs/op
BenchmarkChi_ParseStatic 3365324 356 ns/op 400 B/op 2 allocs/op
BenchmarkChi_ParseParam 2976614 404 ns/op 400 B/op 2 allocs/op
BenchmarkChi_Parse2Params 2638084 439 ns/op 400 B/op 2 allocs/op
BenchmarkChi_ParseAll 109567 11295 ns/op 10400 B/op 52 allocs/op
BenchmarkChi_StaticAll 16846 71308 ns/op 62802 B/op 314 allocs/op
```
Comparison with other routers: https://gist.github.com/pkieltyka/123032f12052520aaccab752bd3e78cc
NOTE: the allocs in the benchmark above are from the calls to http.Request's
`WithContext(context.Context)` method that clones the http.Request, sets the `Context()`
on the duplicated (alloc'd) request and returns the new request object. This is just
how setting context on a request in Go works.
## Credits
* Carl Jackson for https://github.com/zenazn/goji
* Parts of chi's thinking comes from goji, and chi's middleware package
sources from goji.
* Armon Dadgar for https://github.com/armon/go-radix
* Contributions: [@VojtechVitek](https://github.com/VojtechVitek)
We'll be more than happy to see [your contributions](./CONTRIBUTING.md)!
## Beyond REST
chi is just a http router that lets you decompose request handling into many smaller layers.
Many companies use chi to write REST services for their public APIs. But, REST is just a convention
for managing state via HTTP, and there are a lot of other pieces required to write a complete client-server
system or network of microservices.
Looking beyond REST, I also recommend some newer works in the field:
* [webrpc](https://github.com/webrpc/webrpc) - Web-focused RPC client+server framework with code-gen
* [gRPC](https://github.com/grpc/grpc-go) - Google's RPC framework via protobufs
* [graphql](https://github.com/99designs/gqlgen) - Declarative query language
* [NATS](https://nats.io) - lightweight pub-sub
## License
Copyright (c) 2015-present [Peter Kieltyka](https://github.com/pkieltyka)
Licensed under [MIT License](./LICENSE)
[GoDoc]: https://pkg.go.dev/github.com/go-chi/chi?tab=versions
[GoDoc Widget]: https://godoc.org/github.com/go-chi/chi?status.svg
[Travis]: https://travis-ci.org/go-chi/chi
[Travis Widget]: https://travis-ci.org/go-chi/chi.svg?branch=master
chi-5.0.7/_examples/ 0000775 0000000 0000000 00000000000 14145546033 0014250 5 ustar 00root root 0000000 0000000 chi-5.0.7/_examples/README.md 0000664 0000000 0000000 00000003327 14145546033 0015534 0 ustar 00root root 0000000 0000000 chi examples
============
* [custom-handler](https://github.com/go-chi/chi/blob/master/_examples/custom-handler/main.go) - Use a custom handler function signature
* [custom-method](https://github.com/go-chi/chi/blob/master/_examples/custom-method/main.go) - Add a custom HTTP method
* [fileserver](https://github.com/go-chi/chi/blob/master/_examples/fileserver/main.go) - Easily serve static files
* [graceful](https://github.com/go-chi/chi/blob/master/_examples/graceful/main.go) - Graceful context signaling and server shutdown
* [hello-world](https://github.com/go-chi/chi/blob/master/_examples/hello-world/main.go) - Hello World!
* [limits](https://github.com/go-chi/chi/blob/master/_examples/limits/main.go) - Timeouts and Throttling
* [logging](https://github.com/go-chi/chi/blob/master/_examples/logging/main.go) - Easy structured logging for any backend
* [rest](https://github.com/go-chi/chi/blob/master/_examples/rest/main.go) - REST APIs made easy, productive and maintainable
* [router-walk](https://github.com/go-chi/chi/blob/master/_examples/router-walk/main.go) - Print to stdout a router's routes
* [todos-resource](https://github.com/go-chi/chi/blob/master/_examples/todos-resource/main.go) - Struct routers/handlers, an example of another code layout style
* [versions](https://github.com/go-chi/chi/blob/master/_examples/versions/main.go) - Demo of `chi/render` subpkg
## Usage
1. `go get -v -d -u ./...` - fetch example deps
2. `cd <example>/` ie. `cd rest/`
3. `go run *.go` - note, example services run on port 3333
4. Open another terminal and use curl to send some requests to your example service,
`curl -v http://localhost:3333/`
5. Read `<example>/main.go` source to learn how service works and read comments for usage
chi-5.0.7/_examples/chi.svg 0000664 0000000 0000000 00000005132 14145546033 0015535 0 ustar 00root root 0000000 0000000
chi-5.0.7/_examples/custom-handler/ 0000775 0000000 0000000 00000000000 14145546033 0017175 5 ustar 00root root 0000000 0000000 chi-5.0.7/_examples/custom-handler/main.go 0000664 0000000 0000000 00000001161 14145546033 0020447 0 ustar 00root root 0000000 0000000 package main
import (
"errors"
"net/http"
"github.com/go-chi/chi/v5"
)
// Handler is a handler function variant that reports failure by returning an
// error instead of writing the error response itself.
type Handler func(w http.ResponseWriter, r *http.Request) error

// ServeHTTP makes Handler satisfy http.Handler. It invokes h and, when h
// returns a non-nil error, replies with status 503 and the body "bad".
func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	err := h(w, r)
	if err == nil {
		return
	}

	// handle returned error here.
	w.WriteHeader(http.StatusServiceUnavailable)
	w.Write([]byte("bad"))
}
// main registers customHandler, wrapped in the Handler adapter, for GET / and
// serves the router on port 3333.
func main() {
	router := chi.NewRouter()
	router.Method(http.MethodGet, "/", Handler(customHandler))

	http.ListenAndServe(":3333", router)
}
// customHandler returns an error carrying the value of the "err" query
// parameter when that parameter is non-empty; otherwise it writes "foo" and
// returns nil.
func customHandler(w http.ResponseWriter, r *http.Request) error {
	if msg := r.URL.Query().Get("err"); msg != "" {
		return errors.New(msg)
	}

	w.Write([]byte("foo"))
	return nil
}
chi-5.0.7/_examples/custom-method/ 0000775 0000000 0000000 00000000000 14145546033 0017040 5 ustar 00root root 0000000 0000000 chi-5.0.7/_examples/custom-method/main.go 0000664 0000000 0000000 00000001525 14145546033 0020316 0 ustar 00root root 0000000 0000000 package main
import (
"net/http"
"github.com/go-chi/chi/v5"
"github.com/go-chi/chi/v5/middleware"
)
// init registers the custom HTTP methods used by this example with chi,
// before main builds any router.
func init() {
	for _, method := range []string{"LINK", "UNLINK", "WOOHOO"} {
		chi.RegisterMethod(method)
	}
}
// main wires up routes for the standard GET method and for the custom LINK
// and WOOHOO methods, plus a catch-all route, and serves them on port 3333.
func main() {
	// text builds a handler that replies with a fixed body.
	text := func(body string) http.HandlerFunc {
		return func(w http.ResponseWriter, _ *http.Request) {
			w.Write([]byte(body))
		}
	}

	mux := chi.NewRouter()
	mux.Use(middleware.RequestID)
	mux.Use(middleware.Logger)

	mux.Get("/", text("hello world"))
	mux.MethodFunc("LINK", "/link", text("custom link method"))
	mux.MethodFunc("WOOHOO", "/woo", text("custom woohoo method"))
	mux.HandleFunc("/everything", text("capturing all standard http methods, as well as LINK, UNLINK and WOOHOO"))

	http.ListenAndServe(":3333", mux)
}
chi-5.0.7/_examples/fileserver/ 0000775 0000000 0000000 00000000000 14145546033 0016416 5 ustar 00root root 0000000 0000000 chi-5.0.7/_examples/fileserver/data/ 0000775 0000000 0000000 00000000000 14145546033 0017327 5 ustar 00root root 0000000 0000000 chi-5.0.7/_examples/fileserver/data/notes.txt 0000664 0000000 0000000 00000000013 14145546033 0021212 0 ustar 00root root 0000000 0000000 Notessszzz
chi-5.0.7/_examples/fileserver/main.go 0000664 0000000 0000000 00000003071 14145546033 0017672 0 ustar 00root root 0000000 0000000 //
// FileServer
// ===========
// This example demonstrates how to serve static files from your filesystem.
//
//
// Boot the server:
// ----------------
// $ go run main.go
//
// Client requests:
// ----------------
// $ curl http://localhost:3333/files/
//
// notes.txt
//
//
// $ curl http://localhost:3333/files/notes.txt
// Notessszzz
//
package main
import (
"net/http"
"os"
"path/filepath"
"strings"
"github.com/go-chi/chi/v5"
"github.com/go-chi/chi/v5/middleware"
)
// main serves a small index page at / and the contents of the ./data/
// folder under /files, listening on port 3333.
func main() {
	mux := chi.NewRouter()
	mux.Use(middleware.Logger)

	// Index handler
	mux.Get("/", func(w http.ResponseWriter, _ *http.Request) {
		w.Write([]byte("hi"))
	})

	// Create a route along /files that will serve contents from the data
	// folder inside the current working directory. The error from os.Getwd
	// is ignored here.
	cwd, _ := os.Getwd()
	FileServer(mux, "/files", http.Dir(filepath.Join(cwd, "data")))

	http.ListenAndServe(":3333", mux)
}
// FileServer conveniently sets up a http.FileServer handler to serve
// static files from a http.FileSystem.
//
// path is the routing prefix to serve from. It must be non-empty and must
// not contain URL parameter or wildcard characters ("{", "}" or "*");
// FileServer panics otherwise. When path does not end in a slash, a redirect
// from path to path+"/" is registered as well.
func FileServer(r chi.Router, path string, root http.FileSystem) {
	if path == "" {
		// Fail with a descriptive message rather than an index-out-of-range
		// runtime panic from inspecting the last byte of an empty string.
		panic("FileServer requires a non-empty path.")
	}
	if strings.ContainsAny(path, "{}*") {
		panic("FileServer does not permit any URL parameters.")
	}

	if path != "/" && !strings.HasSuffix(path, "/") {
		r.Get(path, http.RedirectHandler(path+"/", http.StatusMovedPermanently).ServeHTTP)
		path += "/"
	}
	path += "*"

	r.Get(path, func(w http.ResponseWriter, r *http.Request) {
		// Derive the prefix to strip from the matched route pattern by
		// dropping its trailing "/*".
		rctx := chi.RouteContext(r.Context())
		pathPrefix := strings.TrimSuffix(rctx.RoutePattern(), "/*")
		fs := http.StripPrefix(pathPrefix, http.FileServer(root))
		fs.ServeHTTP(w, r)
	})
}
chi-5.0.7/_examples/graceful/ 0000775 0000000 0000000 00000000000 14145546033 0016040 5 ustar 00root root 0000000 0000000 chi-5.0.7/_examples/graceful/main.go 0000664 0000000 0000000 00000003657 14145546033 0017326 0 ustar 00root root 0000000 0000000 package main
import (
"context"
"fmt"
"log"
"net/http"
"os"
"os/signal"
"syscall"
"time"
"github.com/go-chi/chi/v5"
"github.com/go-chi/chi/v5/middleware"
)
// main runs the HTTP server on 0.0.0.0:3333 and shuts it down gracefully
// when the process receives SIGHUP, SIGINT, SIGTERM or SIGQUIT. In-flight
// requests get up to 30 seconds to finish; after that the process is forced
// to exit.
func main() {
	// The HTTP Server
	server := &http.Server{Addr: "0.0.0.0:3333", Handler: service()}

	// Server run context
	serverCtx, serverStopCtx := context.WithCancel(context.Background())

	// Listen for syscall signals for process to interrupt/quit
	sig := make(chan os.Signal, 1)
	signal.Notify(sig, syscall.SIGHUP, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT)
	go func() {
		<-sig

		// Shutdown signal with grace period of 30 seconds. The cancel func
		// is kept and deferred so the timeout's resources are always
		// released instead of lingering until the deadline fires.
		shutdownCtx, cancel := context.WithTimeout(serverCtx, 30*time.Second)
		defer cancel()

		go func() {
			<-shutdownCtx.Done()
			// Only a missed deadline forces the exit; a plain cancellation
			// means the shutdown below already completed.
			if shutdownCtx.Err() == context.DeadlineExceeded {
				log.Fatal("graceful shutdown timed out.. forcing exit.")
			}
		}()

		// Trigger graceful shutdown
		if err := server.Shutdown(shutdownCtx); err != nil {
			log.Fatal(err)
		}
		serverStopCtx()
	}()

	// Run the server
	err := server.ListenAndServe()
	if err != nil && err != http.ErrServerClosed {
		log.Fatal(err)
	}

	// Wait for server context to be stopped
	<-serverCtx.Done()
}
// service builds the example router: a quick handler at / and a
// deliberately slow one at /slow.
func service() http.Handler {
	mux := chi.NewRouter()
	mux.Use(middleware.RequestID)
	mux.Use(middleware.Logger)

	mux.Get("/", func(w http.ResponseWriter, _ *http.Request) {
		w.Write([]byte("sup"))
	})

	mux.Get("/slow", func(w http.ResponseWriter, _ *http.Request) {
		// Simulates some hard work.
		//
		// We want this handler to complete successfully during a shutdown
		// signal, so think of the work here as a long running search that
		// gathers as many results as possible, which we cut short and answer
		// with what we have so far. How a shutdown is handled is entirely up
		// to the developer, as some code blocks are preemptable, and others
		// are not.
		time.Sleep(5 * time.Second)

		fmt.Fprint(w, "all done.\n")
	})

	return mux
}
chi-5.0.7/_examples/hello-world/ 0000775 0000000 0000000 00000000000 14145546033 0016500 5 ustar 00root root 0000000 0000000 chi-5.0.7/_examples/hello-world/main.go 0000664 0000000 0000000 00000000547 14145546033 0017761 0 ustar 00root root 0000000 0000000 package main
import (
"net/http"
"github.com/go-chi/chi/v5"
"github.com/go-chi/chi/v5/middleware"
)
// main serves "hello world" at / on port 3333, behind the RequestID, Logger
// and Recoverer middlewares.
func main() {
	mux := chi.NewRouter()
	mux.Use(middleware.RequestID, middleware.Logger, middleware.Recoverer)

	mux.Get("/", func(w http.ResponseWriter, _ *http.Request) {
		w.Write([]byte("hello world"))
	})

	http.ListenAndServe(":3333", mux)
}
chi-5.0.7/_examples/limits/ 0000775 0000000 0000000 00000000000 14145546033 0015551 5 ustar 00root root 0000000 0000000 chi-5.0.7/_examples/limits/main.go 0000664 0000000 0000000 00000004255 14145546033 0017032 0 ustar 00root root 0000000 0000000 //
// Limits
// ======
// This example demonstrates the use of Timeout, and Throttle middlewares.
//
// Timeout:
// cancel a request if processing takes longer than 2.5 seconds,
// server will respond with a http.StatusGatewayTimeout.
//
// Throttle:
// limit the number of in-flight requests along a particular
// routing path and backlog the others.
//
package main
import (
"context"
"fmt"
"math/rand"
"net/http"
"time"
"github.com/go-chi/chi/v5"
"github.com/go-chi/chi/v5/middleware"
)
// main sets up the example routes and serves them on port 3333:
//
//	/, /ping    trivial handlers
//	/panic      panics, to exercise the Recoverer middleware
//	/slow       random 1-4 second delay behind a 2.5 second Timeout
//	/throttled  5 second job behind a 30 second Timeout and Throttle(1)
func main() {
	// Seed the shared random source once at startup. Re-seeding on every
	// request with second granularity would hand identical delays to all
	// requests arriving within the same second.
	rand.Seed(time.Now().UnixNano())

	r := chi.NewRouter()

	r.Use(middleware.RequestID)
	r.Use(middleware.Logger)
	r.Use(middleware.Recoverer)

	r.Get("/", func(w http.ResponseWriter, r *http.Request) {
		w.Write([]byte("root."))
	})

	r.Get("/ping", func(w http.ResponseWriter, r *http.Request) {
		w.Write([]byte("pong"))
	})

	r.Get("/panic", func(w http.ResponseWriter, r *http.Request) {
		panic("test")
	})

	// Slow handlers/operations.
	r.Group(func(r chi.Router) {
		// Stop processing after 2.5 seconds.
		r.Use(middleware.Timeout(2500 * time.Millisecond))

		r.Get("/slow", func(w http.ResponseWriter, r *http.Request) {
			// Processing will take 1-4 seconds.
			processTime := time.Duration(rand.Intn(4)+1) * time.Second

			select {
			case <-r.Context().Done():
				// The request context is done; stop without writing
				// anything from this handler.
				return

			case <-time.After(processTime):
				// The above channel simulates some hard work.
			}

			// Report the delay as a plain number of seconds; formatting the
			// Duration itself would print e.g. "3s seconds".
			w.Write([]byte(fmt.Sprintf("Processed in %v seconds\n", processTime.Seconds())))
		})
	})

	// Throttle very expensive handlers/operations.
	r.Group(func(r chi.Router) {
		// Stop processing after 30 seconds.
		r.Use(middleware.Timeout(30 * time.Second))

		// Only one request will be processed at a time.
		r.Use(middleware.Throttle(1))

		r.Get("/throttled", func(w http.ResponseWriter, r *http.Request) {
			select {
			case <-r.Context().Done():
				switch r.Context().Err() {
				case context.DeadlineExceeded:
					w.WriteHeader(504)
					w.Write([]byte("Processing too slow\n"))
				default:
					w.Write([]byte("Canceled\n"))
				}
				return

			case <-time.After(5 * time.Second):
				// The above channel simulates some hard work.
			}

			w.Write([]byte("Processed\n"))
		})
	})

	http.ListenAndServe(":3333", r)
}
chi-5.0.7/_examples/logging/ 0000775 0000000 0000000 00000000000 14145546033 0015676 5 ustar 00root root 0000000 0000000 chi-5.0.7/_examples/logging/main.go 0000664 0000000 0000000 00000010122 14145546033 0017145 0 ustar 00root root 0000000 0000000 //
// Custom Structured Logger
// ========================
// This example demonstrates how to use middleware.RequestLogger,
// middleware.LogFormatter and middleware.LogEntry to build a structured
// logger using the amazing sirupsen/logrus package as the logging
// backend.
//
// Also: check out https://github.com/goware/httplog for an improved context
// logger with support for HTTP request logging, based on the example below.
//
package main
import (
"fmt"
"net/http"
"time"
"github.com/go-chi/chi/v5"
"github.com/go-chi/chi/v5/middleware"
"github.com/sirupsen/logrus"
)
// main configures a logrus-backed structured request logger, installs it as
// middleware, and serves a few demo routes on port 3333.
func main() {
	// Setup the logger backend using sirupsen/logrus and configure
	// it to use a custom JSONFormatter. See the logrus docs for how to
	// configure the backend at github.com/sirupsen/logrus
	logger := logrus.New()
	logger.Formatter = &logrus.JSONFormatter{
		// disable, as we set our own
		DisableTimestamp: true,
	}

	// Routes
	mux := chi.NewRouter()
	mux.Use(middleware.RequestID)
	mux.Use(NewStructuredLogger(logger))
	mux.Use(middleware.Recoverer)

	mux.Get("/", func(w http.ResponseWriter, _ *http.Request) {
		w.Write([]byte("welcome"))
	})
	mux.Get("/wait", func(w http.ResponseWriter, req *http.Request) {
		time.Sleep(1 * time.Second)
		// Attach an extra field to this request's log entry.
		LogEntrySetField(req, "wait", true)
		w.Write([]byte("hi"))
	})
	mux.Get("/panic", func(http.ResponseWriter, *http.Request) {
		panic("oops")
	})

	http.ListenAndServe(":3333", mux)
}
// NewStructuredLogger returns a request-logging middleware that formats its
// entries with a StructuredLogger backed by the given logrus logger.
//
// StructuredLogger is a simple, but powerful implementation of a custom
// structured logger backed on logrus. I encourage users to copy it, adapt it
// and make it their own. Also take a look at https://github.com/pressly/lg
// for a dedicated pkg based on this work, designed for context-based http
// routers.
func NewStructuredLogger(logger *logrus.Logger) func(next http.Handler) http.Handler {
	formatter := &StructuredLogger{Logger: logger}
	return middleware.RequestLogger(formatter)
}
// StructuredLogger creates per-request log entries on top of a logrus logger.
type StructuredLogger struct {
	Logger *logrus.Logger
}

// NewLogEntry builds the log entry for request r, pre-populated with the
// timestamp, scheme, protocol, method, remote address, user agent and full
// URI (plus the request ID when one is set on the context), and logs a
// "request started" line before returning it.
func (l *StructuredLogger) NewLogEntry(r *http.Request) middleware.LogEntry {
	// The scheme is inferred from whether the request carries TLS state.
	scheme := "http"
	if r.TLS != nil {
		scheme = "https"
	}

	fields := logrus.Fields{
		"ts":          time.Now().UTC().Format(time.RFC1123),
		"http_scheme": scheme,
		"http_proto":  r.Proto,
		"http_method": r.Method,
		"remote_addr": r.RemoteAddr,
		"user_agent":  r.UserAgent(),
		"uri":         fmt.Sprintf("%s://%s%s", scheme, r.Host, r.RequestURI),
	}
	if reqID := middleware.GetReqID(r.Context()); reqID != "" {
		fields["req_id"] = reqID
	}

	entry := &StructuredLoggerEntry{
		Logger: logrus.NewEntry(l.Logger).WithFields(fields),
	}
	entry.Logger.Infoln("request started")

	return entry
}
// StructuredLoggerEntry is the request-scoped log entry; Logger accumulates
// fields as the request moves through the handler chain.
type StructuredLoggerEntry struct {
	Logger logrus.FieldLogger
}

// Write adds the response status, response size and elapsed time (in
// milliseconds) to the entry and logs a "request complete" line. The header
// and extra arguments are not used.
func (l *StructuredLoggerEntry) Write(status, bytes int, header http.Header, elapsed time.Duration, extra interface{}) {
	fields := logrus.Fields{
		"resp_status":       status,
		"resp_bytes_length": bytes,
		"resp_elapsed_ms":   float64(elapsed.Nanoseconds()) / 1000000.0,
	}
	l.Logger = l.Logger.WithFields(fields)

	l.Logger.Infoln("request complete")
}

// Panic adds the panic value and its stack trace to the entry as fields. It
// does not emit a log line itself.
func (l *StructuredLoggerEntry) Panic(v interface{}, stack []byte) {
	fields := logrus.Fields{
		"stack": string(stack),
		"panic": fmt.Sprintf("%+v", v),
	}
	l.Logger = l.Logger.WithFields(fields)
}
// Helper methods used by the application to get the request-scoped
// logger entry and set additional fields between handlers.
//
// This is a useful pattern to use to set state on the entry as it
// passes through the handler chain, which at any point can be logged
// with a call to .Print(), .Info(), etc.

// GetLogEntry returns the logger of the request's log entry. It panics when
// the entry on the request is not a *StructuredLoggerEntry.
func GetLogEntry(r *http.Request) logrus.FieldLogger {
	return middleware.GetLogEntry(r).(*StructuredLoggerEntry).Logger
}
// LogEntrySetField adds a single key/value field to the request's log entry.
// It does nothing when the request context holds no *StructuredLoggerEntry.
func LogEntrySetField(r *http.Request, key string, value interface{}) {
	entry, ok := r.Context().Value(middleware.LogEntryCtxKey).(*StructuredLoggerEntry)
	if !ok {
		return
	}
	entry.Logger = entry.Logger.WithField(key, value)
}
// LogEntrySetFields adds several fields at once to the request's log entry.
// It does nothing when the request context holds no *StructuredLoggerEntry.
func LogEntrySetFields(r *http.Request, fields map[string]interface{}) {
	entry, ok := r.Context().Value(middleware.LogEntryCtxKey).(*StructuredLoggerEntry)
	if !ok {
		return
	}
	entry.Logger = entry.Logger.WithFields(fields)
}
chi-5.0.7/_examples/rest/ 0000775 0000000 0000000 00000000000 14145546033 0015225 5 ustar 00root root 0000000 0000000 chi-5.0.7/_examples/rest/go.mod 0000664 0000000 0000000 00000000215 14145546033 0016331 0 ustar 00root root 0000000 0000000 module rest-example
go 1.16
require (
github.com/go-chi/chi/v5 v5.0.1
github.com/go-chi/docgen v1.2.0
github.com/go-chi/render v1.0.1
)
chi-5.0.7/_examples/rest/go.sum 0000664 0000000 0000000 00000001260 14145546033 0016357 0 ustar 00root root 0000000 0000000 github.com/go-chi/chi/v5 v5.0.1 h1:ALxjCrTf1aflOlkhMnCUP86MubbWFrzB3gkRPReLpTo=
github.com/go-chi/chi/v5 v5.0.1/go.mod h1:DslCQbL2OYiznFReuXYUmQ2hGd1aDpCnlMNITLSKoi8=
github.com/go-chi/docgen v1.2.0 h1:da0Nq2PKU9W9pSOTUfVrKI1vIgTGpauo9cfh4Iwivek=
github.com/go-chi/docgen v1.2.0/go.mod h1:G9W0G551cs2BFMSn/cnGwX+JBHEloAgo17MBhyrnhPI=
github.com/go-chi/render v1.0.1 h1:4/5tis2cKaNdnv9zFLfXzcquC9HbeZgCnxGnKrltBS8=
github.com/go-chi/render v1.0.1/go.mod h1:pq4Rr7HbnsdaeHagklXub+p6Wd16Af5l9koip1OvJns=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
chi-5.0.7/_examples/rest/main.go 0000664 0000000 0000000 00000036304 14145546033 0016506 0 ustar 00root root 0000000 0000000 //
// REST
// ====
// This example demonstrates a HTTP REST web service with some fixture data.
// Follow along the example and patterns.
//
// Also check routes.json for the generated docs from passing the -routes flag,
// to run yourself do: `go run . -routes`
//
// Boot the server:
// ----------------
// $ go run main.go
//
// Client requests:
// ----------------
// $ curl http://localhost:3333/
// root.
//
// $ curl http://localhost:3333/articles
// [{"id":"1","title":"Hi"},{"id":"2","title":"sup"}]
//
// $ curl http://localhost:3333/articles/1
// {"id":"1","title":"Hi"}
//
// $ curl -X DELETE http://localhost:3333/articles/1
// {"id":"1","title":"Hi"}
//
// $ curl http://localhost:3333/articles/1
// "Not Found"
//
// $ curl -X POST -d '{"id":"will-be-omitted","title":"awesomeness"}' http://localhost:3333/articles
// {"id":"97","title":"awesomeness"}
//
// $ curl http://localhost:3333/articles/97
// {"id":"97","title":"awesomeness"}
//
// $ curl http://localhost:3333/articles
// [{"id":"2","title":"sup"},{"id":"97","title":"awesomeness"}]
//
package main
import (
"context"
"errors"
"flag"
"fmt"
"math/rand"
"net/http"
"strings"
"github.com/go-chi/chi/v5"
"github.com/go-chi/chi/v5/middleware"
"github.com/go-chi/docgen"
"github.com/go-chi/render"
)
// routes is the -routes command-line flag (default false). When set, main
// prints the generated router documentation instead of starting the server.
var routes = flag.Bool("routes", false, "Generate router documentation")
// main builds the REST example router and either prints its generated
// documentation (when -routes is passed) or serves it on port 3333.
func main() {
	flag.Parse()

	r := chi.NewRouter()

	r.Use(
		middleware.RequestID,
		middleware.Logger,
		middleware.Recoverer,
		middleware.URLFormat,
		render.SetContentType(render.ContentTypeJSON),
	)

	r.Get("/", func(w http.ResponseWriter, _ *http.Request) {
		w.Write([]byte("root."))
	})

	r.Get("/ping", func(w http.ResponseWriter, _ *http.Request) {
		w.Write([]byte("pong"))
	})

	r.Get("/panic", func(http.ResponseWriter, *http.Request) {
		panic("test")
	})

	// RESTy routes for "articles" resource
	r.Route("/articles", func(r chi.Router) {
		r.With(paginate).Get("/", ListArticles) // GET /articles
		r.Post("/", CreateArticle)              // POST /articles
		r.Get("/search", SearchArticles)        // GET /articles/search

		r.Route("/{articleID}", func(r chi.Router) {
			r.Use(ArticleCtx)            // Load the *Article on the request context
			r.Get("/", GetArticle)       // GET /articles/123
			r.Put("/", UpdateArticle)    // PUT /articles/123
			r.Delete("/", DeleteArticle) // DELETE /articles/123
		})

		// GET /articles/whats-up
		r.With(ArticleCtx).Get("/{articleSlug:[a-z-]+}", GetArticle)
	})

	// Mount the admin sub-router, which btw is the same as:
	// r.Route("/admin", func(r chi.Router) { admin routes here })
	r.Mount("/admin", adminRouter())

	// Without -routes, just run the server.
	if !*routes {
		http.ListenAndServe(":3333", r)
		return
	}

	// Passing -routes to the program will generate docs for the above
	// router definition. See the `routes.json` file in this folder for
	// the output.
	//
	// JSON alternative: fmt.Println(docgen.JSONRoutesDoc(r))
	fmt.Println(docgen.MarkdownRoutesDoc(r, docgen.MarkdownOpts{
		ProjectPath: "github.com/go-chi/chi/v5",
		Intro:       "Welcome to the chi/_examples/rest generated docs.",
	}))
}
// ListArticles renders the full list of articles, or a render error
// response when rendering the list fails.
func ListArticles(w http.ResponseWriter, r *http.Request) {
	err := render.RenderList(w, r, NewArticleListResponse(articles))
	if err != nil {
		render.Render(w, r, ErrRender(err))
	}
}
// ArticleCtx middleware loads an Article, looked up by the "articleID" or
// "articleSlug" URL parameter of the request, and stores it on the request
// context under the "article" key. When neither parameter is present, or the
// lookup fails, it stops here and renders ErrNotFound.
func ArticleCtx(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		var (
			article *Article
			err     error
		)

		articleID := chi.URLParam(r, "articleID")
		articleSlug := chi.URLParam(r, "articleSlug")

		// The ID takes precedence over the slug when both are present.
		switch {
		case articleID != "":
			article, err = dbGetArticle(articleID)
		case articleSlug != "":
			article, err = dbGetArticleBySlug(articleSlug)
		default:
			render.Render(w, r, ErrNotFound)
			return
		}

		if err != nil {
			render.Render(w, r, ErrNotFound)
			return
		}

		ctx := context.WithValue(r.Context(), "article", article)
		next.ServeHTTP(w, r.WithContext(ctx))
	})
}
// SearchArticles searches the Articles data for a matching article.
// It's just a stub, but you get the idea.
func SearchArticles(w http.ResponseWriter, r *http.Request) {
render.RenderList(w, r, NewArticleListResponse(articles))
}
// CreateArticle binds the posted payload to an ArticleRequest, persists its
// Article, and renders it back to the client with a 201 Created status as an
// acknowledgement. A payload that fails to bind yields an invalid-request
// response.
func CreateArticle(w http.ResponseWriter, r *http.Request) {
	payload := &ArticleRequest{}
	err := render.Bind(r, payload)
	if err != nil {
		render.Render(w, r, ErrInvalidRequest(err))
		return
	}

	created := payload.Article
	dbNewArticle(created)

	render.Status(r, http.StatusCreated)
	render.Render(w, r, NewArticleResponse(created))
}
// GetArticle returns the specific Article. You'll notice it just
// fetches the Article right off the context, as it's understood that
// if we made it this far, the Article must be on the context. In case
// it's not, due to a bug, then it will panic, and our Recoverer will save us.
func GetArticle(w http.ResponseWriter, r *http.Request) {
	// This handler is a child of the ArticleCtx middleware, so the article
	// is expected on the context; the unchecked assertion panics otherwise.
	article := r.Context().Value("article").(*Article)

	err := render.Render(w, r, NewArticleResponse(article))
	if err != nil {
		render.Render(w, r, ErrRender(err))
	}
}
// UpdateArticle updates an existing Article in our persistent store. The
// request payload is bound on top of the Article found on the context, and
// the result is saved and rendered back to the client.
func UpdateArticle(w http.ResponseWriter, r *http.Request) {
	current := r.Context().Value("article").(*Article)

	payload := &ArticleRequest{Article: current}
	err := render.Bind(r, payload)
	if err != nil {
		render.Render(w, r, ErrInvalidRequest(err))
		return
	}

	updated := payload.Article
	dbUpdateArticle(updated.ID, updated)

	render.Render(w, r, NewArticleResponse(updated))
}
// DeleteArticle removes an existing Article from our persistent store and
// renders the removed Article back to the client.
func DeleteArticle(w http.ResponseWriter, r *http.Request) {
	// The article is expected on the context because this handler sits
	// behind the ArticleCtx middleware; a missing value panics here and
	// is left to the recoverer middleware.
	target := r.Context().Value("article").(*Article)

	removed, err := dbRemoveArticle(target.ID)
	if err != nil {
		render.Render(w, r, ErrInvalidRequest(err))
		return
	}

	render.Render(w, r, NewArticleResponse(removed))
}
// adminRouter builds a standalone router for the administrator routes.
// Every route on it is guarded by the AdminOnly middleware.
func adminRouter() chi.Router {
	admin := chi.NewRouter()
	admin.Use(AdminOnly)

	admin.Get("/", func(w http.ResponseWriter, r *http.Request) {
		body := []byte("admin: index")
		w.Write(body)
	})

	admin.Get("/accounts", func(w http.ResponseWriter, r *http.Request) {
		body := []byte("admin: list accounts..")
		w.Write(body)
	})

	admin.Get("/users/{userId}", func(w http.ResponseWriter, r *http.Request) {
		fmt.Fprintf(w, "admin: view user id %v", chi.URLParam(r, "userId"))
	})

	return admin
}
// AdminOnly is a middleware that lets a request through only when the
// request context holds the boolean value true under the "acl.admin" key.
// Any other case (key absent, wrong type, or false) is answered with
// 403 Forbidden and the next handler is not invoked.
func AdminOnly(next http.Handler) http.Handler {
	guard := func(w http.ResponseWriter, r *http.Request) {
		if allowed, isBool := r.Context().Value("acl.admin").(bool); isBool && allowed {
			next.ServeHTTP(w, r)
			return
		}
		const code = http.StatusForbidden
		http.Error(w, http.StatusText(code), code)
	}
	return http.HandlerFunc(guard)
}
// paginate is a placeholder middleware that currently forwards every
// request untouched. A real implementation could read URL query params
// such as the page number or the limit and pass a query cursor down the
// handler chain.
func paginate(next http.Handler) http.Handler {
	passThrough := func(w http.ResponseWriter, r *http.Request) {
		// Stub: nothing is inspected or modified yet.
		next.ServeHTTP(w, r)
	}
	return http.HandlerFunc(passThrough)
}
// This is entirely optional, but I wanted to demonstrate how you could easily
// add your own logic to the render.Respond method.
//
// The replacement responder passes non-error values straight to
// render.DefaultResponder. Error values are logged and replaced by a generic
// {"status": "error"} payload so the underlying message is never sent to the
// client.
func init() {
	render.Respond = func(w http.ResponseWriter, r *http.Request, v interface{}) {
		err, isErr := v.(error)
		if !isErr {
			render.DefaultResponder(w, r, v)
			return
		}

		// Default the status to 400 when no status has been set. It is
		// recorded on the request context rather than written with
		// w.WriteHeader: headers are frozen once WriteHeader is called, so
		// writing it here would discard the Content-Type that the default
		// responder sets afterwards.
		if _, ok := r.Context().Value(render.StatusCtxKey).(int); !ok {
			render.Status(r, http.StatusBadRequest)
		}

		// We log the error
		fmt.Printf("Logging err: %s\n", err.Error())

		// We change the response to not reveal the actual error message,
		// instead we can transform the message something more friendly or mapped
		// to some code / language, etc.
		render.DefaultResponder(w, r, render.M{"status": "error"})
	}
}
//--
// Request and Response payloads for the REST api.
//
// The payloads embed the data model objects and add
// request/response-specific fields around them.
//
// In a real-world project, it would make sense to put these payloads
// in another file, or another sub-package.
//--
// UserPayload is the wire representation of a User: the embedded model
// plus a Role field that is filled in at render time.
type UserPayload struct {
	*User

	Role string `json:"role"`
}

// NewUserPayloadResponse wraps user in a UserPayload for rendering.
func NewUserPayloadResponse(user *User) *UserPayload {
	payload := &UserPayload{}
	payload.User = user
	return payload
}

// Bind runs after the request body has been unmarshalled into the payload,
// which makes it the place for post-decode processing. Nothing is needed
// here, so it always succeeds.
func (u *UserPayload) Bind(r *http.Request) error {
	return nil
}

// Render runs before the payload is marshalled; it stamps the fixed
// "collaborator" role onto the payload.
func (u *UserPayload) Render(w http.ResponseWriter, r *http.Request) error {
	const role = "collaborator"
	u.Role = role
	return nil
}
// ArticleRequest is the request payload for Article data model.
//
// NOTE: It's good practice to have well defined request and response payloads
// so you can manage the specific inputs and outputs for clients. It also gives
// you the opportunity to transform data on input or output: for example, on
// request we'd like to protect certain fields, and on output perhaps we'd
// like to include a computed field based on other values that aren't in the
// data model. Also, check out this awesome blog post on struct composition:
// http://attilaolah.eu/2014/09/10/json-and-struct-composition-in-go/
type ArticleRequest struct {
	*Article

	User *UserPayload `json:"user,omitempty"`

	// ProtectedID claims the 'id' json key so that we have more control
	// over it than the embedded Article's own ID field would give us.
	ProtectedID string `json:"id"`
}

// Bind post-processes a decoded ArticleRequest. It fails when the embedded
// Article is nil; otherwise it clears ProtectedID and lower-cases the title.
func (a *ArticleRequest) Bind(r *http.Request) error {
	// a.User is nil if no UserPayload fields are sent in the request. In this
	// app that won't cause a panic, but checks in this Bind method may be
	// required if a.User or further nested fields like a.User.Name are
	// accessed elsewhere.

	if a.Article != nil {
		// Post-processing after the decode.
		a.ProtectedID = ""                                 // unset the protected ID
		a.Article.Title = strings.ToLower(a.Article.Title) // as an example, we down-case
		return nil
	}

	// a.Article is nil if no Article fields are sent in the request. Return
	// an error to avoid a nil pointer dereference.
	return errors.New("missing required Article fields.")
}
// ArticleResponse is the response payload for the Article data model.
// See NOTE above in ArticleRequest as well.
//
// In the ArticleResponse object, first a Render() is called on itself,
// then the next field, and so on, all the way down the tree.
// Render is called in top-down order, like a http handler middleware chain.
type ArticleResponse struct {
	*Article

	User *UserPayload `json:"user,omitempty"`

	// We add an additional field to the response here.. such as this
	// elapsed computed property
	Elapsed int64 `json:"elapsed"`
}

// NewArticleResponse wraps article in a response payload and attaches the
// author as User when dbGetUser finds a user for article.UserID. A nil
// article yields a payload without any author lookup instead of a nil
// pointer dereference.
func NewArticleResponse(article *Article) *ArticleResponse {
	resp := &ArticleResponse{Article: article}
	if article == nil {
		return resp
	}

	// The lookup error is deliberately dropped: an unknown author simply
	// leaves User unset, and the field is then omitted from the JSON output.
	if user, _ := dbGetUser(article.UserID); user != nil {
		resp.User = NewUserPayloadResponse(user)
	}
	return resp
}

// Render is the pre-processing hook that runs before the response is
// marshalled and sent across the wire; it fills in the Elapsed field.
func (rd *ArticleResponse) Render(w http.ResponseWriter, r *http.Request) error {
	rd.Elapsed = 10
	return nil
}
// NewArticleListResponse converts articles into a list of renderers, one
// ArticleResponse per article and in the same order. The result is always a
// non-nil slice, so an empty input is still a list rather than nil.
func NewArticleListResponse(articles []*Article) []render.Renderer {
	// The final length is known up front, so allocate once.
	list := make([]render.Renderer, 0, len(articles))
	for _, article := range articles {
		list = append(list, NewArticleResponse(article))
	}
	return list
}
// NOTE: as a thought, the request and response payloads for an Article could be the
// same payload type, perhaps will do an example with it as well.
// type ArticlePayload struct {
// *Article
// }
//--
// Error response payloads & renderers
//--
// ErrResponse is the renderer type used for every kind of error response.
//
// In the best case scenario, the excellent github.com/pkg/errors package
// helps reveal information on the error, setting it on Err, and in the Render()
// method, using it to set the application-specific error code in AppCode.
type ErrResponse struct {
	// Not serialized.
	Err            error `json:"-"` // low-level runtime error
	HTTPStatusCode int   `json:"-"` // http response status code

	// Serialized to the client.
	StatusText string `json:"status"`          // user-level status message
	AppCode    int64  `json:"code,omitempty"`  // application-specific error code
	ErrorText  string `json:"error,omitempty"` // application-level error message, for debugging
}

// Render records the payload's HTTP status code on the request via
// render.Status before the response is written. It never fails.
func (e *ErrResponse) Render(w http.ResponseWriter, r *http.Request) error {
	status := e.HTTPStatusCode
	render.Status(r, status)
	return nil
}
// ErrInvalidRequest builds a 400 error payload from err; the error's
// message is exposed as the payload's error text.
func ErrInvalidRequest(err error) render.Renderer {
	resp := &ErrResponse{
		HTTPStatusCode: 400,
		StatusText:     "Invalid request.",
	}
	resp.Err = err
	resp.ErrorText = err.Error()
	return resp
}
// ErrRender builds a 422 error payload for a failure that happened while
// rendering a response; the error's message is exposed as the error text.
func ErrRender(err error) render.Renderer {
	resp := &ErrResponse{
		HTTPStatusCode: 422,
		StatusText:     "Error rendering response.",
	}
	resp.Err = err
	resp.ErrorText = err.Error()
	return resp
}
// ErrNotFound is the shared 404 error payload. It carries only a status
// code and status text; Err and ErrorText are left unset.
var ErrNotFound = &ErrResponse{HTTPStatusCode: 404, StatusText: "Resource not found."}
//--
// Data model objects and persistence mocks:
//--
// User is the user data model.
type User struct {
ID int64 `json:"id"` // identifier that dbGetUser matches against
Name string `json:"name"`
}
// Article is the article data model. I suggest looking at https://upper.io
// for an easy and powerful data persistence adapter.
type Article struct {
ID string `json:"id"` // identifier that dbGetArticle matches against
UserID int64 `json:"user_id"` // the author
Title string `json:"title"`
Slug string `json:"slug"` // alternative lookup key used by dbGetArticleBySlug
}
// Article fixture data. This slice doubles as the in-memory "database" that
// the db* helper functions read and modify. Only UserIDs 100 and 200 have a
// matching entry in the users fixture.
var articles = []*Article{
{ID: "1", UserID: 100, Title: "Hi", Slug: "hi"},
{ID: "2", UserID: 200, Title: "sup", Slug: "sup"},
{ID: "3", UserID: 300, Title: "alo", Slug: "alo"},
{ID: "4", UserID: 400, Title: "bonjour", Slug: "bonjour"},
{ID: "5", UserID: 500, Title: "whats up", Slug: "whats-up"},
}
// User fixture data, searched by dbGetUser.
var users = []*User{
{ID: 100, Name: "Peter"},
{ID: 200, Name: "Julia"},
}
// dbNewArticle assigns a fresh ID to article, appends it to the in-memory
// articles store and returns the assigned ID.
//
// The ID starts from a random number in [10, 110). If that ID is already in
// use the number is incremented until a free one is found, so two stored
// articles never end up sharing an ID. A nil article is rejected with an
// error and the store is left untouched.
func dbNewArticle(article *Article) (string, error) {
	if article == nil {
		return "", errors.New("article is nil")
	}

	// Collect the IDs already taken so the candidate can be checked cheaply.
	inUse := make(map[string]struct{}, len(articles))
	for _, a := range articles {
		inUse[a.ID] = struct{}{}
	}

	n := rand.Intn(100) + 10
	for {
		id := fmt.Sprintf("%d", n)
		if _, taken := inUse[id]; !taken {
			article.ID = id
			break
		}
		// Collision: walk upwards; the integer space guarantees termination.
		n++
	}

	articles = append(articles, article)
	return article.ID, nil
}
// dbGetArticle returns the first stored article whose ID equals id, or an
// error when there is none.
func dbGetArticle(id string) (*Article, error) {
	for idx := range articles {
		if candidate := articles[idx]; candidate.ID == id {
			return candidate, nil
		}
	}
	return nil, errors.New("article not found.")
}
// dbGetArticleBySlug returns the first stored article whose Slug equals
// slug, or an error when there is none.
func dbGetArticleBySlug(slug string) (*Article, error) {
	for idx := range articles {
		if candidate := articles[idx]; candidate.Slug == slug {
			return candidate, nil
		}
	}
	return nil, errors.New("article not found.")
}
// dbUpdateArticle replaces the stored article whose ID equals id with
// article and returns article. It returns an error when no stored article
// has that ID.
func dbUpdateArticle(id string, article *Article) (*Article, error) {
	for pos := range articles {
		if articles[pos].ID != id {
			continue
		}
		articles[pos] = article
		return article, nil
	}
	return nil, errors.New("article not found.")
}
// dbRemoveArticle deletes the first stored article whose ID equals id and
// returns the removed article. It returns an error when no stored article
// has that ID.
func dbRemoveArticle(id string) (*Article, error) {
	for pos := 0; pos < len(articles); pos++ {
		if articles[pos].ID != id {
			continue
		}
		removed := articles[pos]
		// Close the gap by shifting the tail down over the removed slot.
		articles = append(articles[:pos], articles[pos+1:]...)
		return removed, nil
	}
	return nil, errors.New("article not found.")
}
// dbGetUser returns the first user in the fixture data whose ID equals id,
// or an error when there is none.
func dbGetUser(id int64) (*User, error) {
	for idx := range users {
		if candidate := users[idx]; candidate.ID == id {
			return candidate, nil
		}
	}
	return nil, errors.New("user not found.")
}
chi-5.0.7/_examples/rest/routes.json 0000664 0000000 0000000 00000026065 14145546033 0017452 0 ustar 00root root 0000000 0000000 {
"router": {
"middlewares": [
{
"pkg": "github.com/go-chi/chi/v5/middleware",
"func": "RequestID",
"comment": "RequestID is a middleware that injects a request ID into the context of each\nrequest. A request ID is a string of the form \"host.example.com/random-0001\",\nwhere \"random\" is a base62 random string that uniquely identifies this go\nprocess, and where the last number is an atomically incremented request\ncounter.\n",
"file": "github.com/go-chi/chi/middleware/request_id.go",
"line": 63
},
{
"pkg": "github.com/go-chi/chi/v5/middleware",
"func": "Logger",
"comment": "Logger is a middleware that logs the start and end of each request, along\nwith some useful data about what was requested, what the response status was,\nand how long it took to return. When standard output is a TTY, Logger will\nprint in color, otherwise it will print in black and white. Logger prints a\nrequest ID if one is provided.\n\nAlternatively, look at https://github.com/pressly/lg and the `lg.RequestLogger`\nmiddleware pkg.\n",
"file": "github.com/go-chi/chi/middleware/logger.go",
"line": 26
},
{
"pkg": "github.com/go-chi/chi/v5/middleware",
"func": "Recoverer",
"comment": "Recoverer is a middleware that recovers from panics, logs the panic (and a\nbacktrace), and returns a HTTP 500 (Internal Server Error) status if\npossible. Recoverer prints a request ID if one is provided.\n\nAlternatively, look at https://github.com/pressly/lg middleware pkgs.\n",
"file": "github.com/go-chi/chi/middleware/recoverer.go",
"line": 18
},
{
"pkg": "github.com/go-chi/chi/v5/middleware",
"func": "URLFormat",
"comment": "URLFormat is a middleware that parses the url extension from a request path and stores it\non the context as a string under the key `middleware.URLFormatCtxKey`. The middleware will\ntrim the suffix from the routing path and continue routing.\n\nRouters should not include a url parameter for the suffix when using this middleware.\n\nSample usage.. for url paths: `/articles/1`, `/articles/1.json` and `/articles/1.xml`\n\n func routes() http.Handler {\n r := chi.NewRouter()\n r.Use(middleware.URLFormat)\n\n r.Get(\"/articles/{id}\", ListArticles)\n\n return r\n }\n\n func ListArticles(w http.ResponseWriter, r *http.Request) {\n\t urlFormat, _ := r.Context().Value(middleware.URLFormatCtxKey).(string)\n\n\t switch urlFormat {\n\t case \"json\":\n\t \trender.JSON(w, r, articles)\n\t case \"xml:\"\n\t \trender.XML(w, r, articles)\n\t default:\n\t \trender.JSON(w, r, articles)\n\t }\n}\n",
"file": "github.com/go-chi/chi/middleware/url_format.go",
"line": 45
},
{
"pkg": "github.com/go-chi/render",
"func": "SetContentType.func1",
"comment": "",
"file": "github.com/go-chi/render/content_type.go",
"line": 49,
"anonymous": true
}
],
"routes": {
"/": {
"handlers": {
"GET": {
"middlewares": [],
"method": "GET",
"pkg": "",
"func": "main.main.func1",
"comment": "",
"file": "github.com/go-chi/chi/_examples/rest/main.go",
"line": 69,
"anonymous": true
}
}
},
"/admin/*": {
"router": {
"middlewares": [
{
"pkg": "",
"func": "main.AdminOnly",
"comment": "AdminOnly middleware restricts access to just administrators.\n",
"file": "github.com/go-chi/chi/_examples/rest/main.go",
"line": 238
}
],
"routes": {
"/": {
"handlers": {
"GET": {
"middlewares": [],
"method": "GET",
"pkg": "",
"func": "main.adminRouter.func1",
"comment": "",
"file": "github.com/go-chi/chi/_examples/rest/main.go",
"line": 225,
"anonymous": true
}
}
},
"/accounts": {
"handlers": {
"GET": {
"middlewares": [],
"method": "GET",
"pkg": "",
"func": "main.adminRouter.func2",
"comment": "",
"file": "github.com/go-chi/chi/_examples/rest/main.go",
"line": 228,
"anonymous": true
}
}
},
"/users/{userId}": {
"handlers": {
"GET": {
"middlewares": [],
"method": "GET",
"pkg": "",
"func": "main.adminRouter.func3",
"comment": "",
"file": "github.com/go-chi/chi/_examples/rest/main.go",
"line": 231,
"anonymous": true
}
}
}
}
}
},
"/articles/*": {
"router": {
"middlewares": [],
"routes": {
"/": {
"handlers": {
"GET": {
"middlewares": [
{
"pkg": "",
"func": "main.paginate",
"comment": "paginate is a stub, but very possible to implement middleware logic\nto handle the request params for handling a paginated request.\n",
"file": "github.com/go-chi/chi/_examples/rest/main.go",
"line": 251
}
],
"method": "GET",
"pkg": "",
"func": "main.ListArticles",
"comment": "",
"file": "github.com/go-chi/chi/_examples/rest/main.go",
"line": 117
},
"POST": {
"middlewares": [],
"method": "POST",
"pkg": "",
"func": "main.CreateArticle",
"comment": "CreateArticle persists the posted Article and returns it\nback to the client as an acknowledgement.\n",
"file": "github.com/go-chi/chi/_examples/rest/main.go",
"line": 158
}
}
},
"/search": {
"handlers": {
"GET": {
"middlewares": [],
"method": "GET",
"pkg": "",
"func": "main.SearchArticles",
"comment": "SearchArticles searches the Articles data for a matching article.\nIt's just a stub, but you get the idea.\n",
"file": "github.com/go-chi/chi/_examples/rest/main.go",
"line": 152
}
}
},
"/{articleID}/*": {
"router": {
"middlewares": [
{
"pkg": "",
"func": "main.ArticleCtx",
"comment": "ArticleCtx middleware is used to load an Article object from\nthe URL parameters passed through as the request. In case\nthe Article could not be found, we stop here and return a 404.\n",
"file": "github.com/go-chi/chi/_examples/rest/main.go",
"line": 127
}
],
"routes": {
"/": {
"handlers": {
"DELETE": {
"middlewares": [],
"method": "DELETE",
"pkg": "",
"func": "main.DeleteArticle",
"comment": "DeleteArticle removes an existing Article from our persistent store.\n",
"file": "github.com/go-chi/chi/_examples/rest/main.go",
"line": 204
},
"GET": {
"middlewares": [],
"method": "GET",
"pkg": "",
"func": "main.GetArticle",
"comment": "GetArticle returns the specific Article. You'll notice it just\nfetches the Article right off the context, as its understood that\nif we made it this far, the Article must be on the context. In case\nits not due to a bug, then it will panic, and our Recoverer will save us.\n",
"file": "github.com/go-chi/chi/_examples/rest/main.go",
"line": 176
},
"PUT": {
"middlewares": [],
"method": "PUT",
"pkg": "",
"func": "main.UpdateArticle",
"comment": "UpdateArticle updates an existing Article in our persistent store.\n",
"file": "github.com/go-chi/chi/_examples/rest/main.go",
"line": 189
}
}
}
}
}
},
"/{articleSlug:[a-z-]+}": {
"handlers": {
"GET": {
"middlewares": [
{
"pkg": "",
"func": "main.ArticleCtx",
"comment": "ArticleCtx middleware is used to load an Article object from\nthe URL parameters passed through as the request. In case\nthe Article could not be found, we stop here and return a 404.\n",
"file": "github.com/go-chi/chi/_examples/rest/main.go",
"line": 127
}
],
"method": "GET",
"pkg": "",
"func": "main.GetArticle",
"comment": "GetArticle returns the specific Article. You'll notice it just\nfetches the Article right off the context, as its understood that\nif we made it this far, the Article must be on the context. In case\nits not due to a bug, then it will panic, and our Recoverer will save us.\n",
"file": "github.com/go-chi/chi/_examples/rest/main.go",
"line": 176
}
}
}
}
}
},
"/panic": {
"handlers": {
"GET": {
"middlewares": [],
"method": "GET",
"pkg": "",
"func": "main.main.func3",
"comment": "",
"file": "github.com/go-chi/chi/_examples/rest/main.go",
"line": 77,
"anonymous": true
}
}
},
"/ping": {
"handlers": {
"GET": {
"middlewares": [],
"method": "GET",
"pkg": "",
"func": "main.main.func2",
"comment": "",
"file": "github.com/go-chi/chi/_examples/rest/main.go",
"line": 73,
"anonymous": true
}
}
}
}
}
}
chi-5.0.7/_examples/rest/routes.md 0000664 0000000 0000000 00000011071 14145546033 0017070 0 ustar 00root root 0000000 0000000 # github.com/go-chi/chi
Welcome to the chi/_examples/rest generated docs.
## Routes
`/`
- [RequestID](/middleware/request_id.go#L63)
- [Logger](/middleware/logger.go#L26)
- [Recoverer](/middleware/recoverer.go#L18)
- [URLFormat](/middleware/url_format.go#L45)
- [SetContentType.func1](https://github.com/go-chi/render/content_type.go#L49)
- **/**
- _GET_
- [main.main.func1](/_examples/rest/main.go#L69)
`/admin/*`
- [RequestID](/middleware/request_id.go#L63)
- [Logger](/middleware/logger.go#L26)
- [Recoverer](/middleware/recoverer.go#L18)
- [URLFormat](/middleware/url_format.go#L45)
- [SetContentType.func1](https://github.com/go-chi/render/content_type.go#L49)
- **/admin/***
- [main.AdminOnly](/_examples/rest/main.go#L238)
- **/**
- _GET_
- [main.adminRouter.func1](/_examples/rest/main.go#L225)
`/admin/*/accounts`
- [RequestID](/middleware/request_id.go#L63)
- [Logger](/middleware/logger.go#L26)
- [Recoverer](/middleware/recoverer.go#L18)
- [URLFormat](/middleware/url_format.go#L45)
- [SetContentType.func1](https://github.com/go-chi/render/content_type.go#L49)
- **/admin/***
- [main.AdminOnly](/_examples/rest/main.go#L238)
- **/accounts**
- _GET_
- [main.adminRouter.func2](/_examples/rest/main.go#L228)
`/admin/*/users/{userId}`
- [RequestID](/middleware/request_id.go#L63)
- [Logger](/middleware/logger.go#L26)
- [Recoverer](/middleware/recoverer.go#L18)
- [URLFormat](/middleware/url_format.go#L45)
- [SetContentType.func1](https://github.com/go-chi/render/content_type.go#L49)
- **/admin/***
- [main.AdminOnly](/_examples/rest/main.go#L238)
- **/users/{userId}**
- _GET_
- [main.adminRouter.func3](/_examples/rest/main.go#L231)
`/articles/*`
- [RequestID](/middleware/request_id.go#L63)
- [Logger](/middleware/logger.go#L26)
- [Recoverer](/middleware/recoverer.go#L18)
- [URLFormat](/middleware/url_format.go#L45)
- [SetContentType.func1](https://github.com/go-chi/render/content_type.go#L49)
- **/articles/***
- **/**
- _GET_
- [main.paginate](/_examples/rest/main.go#L251)
- [main.ListArticles](/_examples/rest/main.go#L117)
- _POST_
- [main.CreateArticle](/_examples/rest/main.go#L158)
`/articles/*/search`
- [RequestID](/middleware/request_id.go#L63)
- [Logger](/middleware/logger.go#L26)
- [Recoverer](/middleware/recoverer.go#L18)
- [URLFormat](/middleware/url_format.go#L45)
- [SetContentType.func1](https://github.com/go-chi/render/content_type.go#L49)
- **/articles/***
- **/search**
- _GET_
- [main.SearchArticles](/_examples/rest/main.go#L152)
`/articles/*/{articleID}/*`
- [RequestID](/middleware/request_id.go#L63)
- [Logger](/middleware/logger.go#L26)
- [Recoverer](/middleware/recoverer.go#L18)
- [URLFormat](/middleware/url_format.go#L45)
- [SetContentType.func1](https://github.com/go-chi/render/content_type.go#L49)
- **/articles/***
- **/{articleID}/***
- [main.ArticleCtx](/_examples/rest/main.go#L127)
- **/**
- _DELETE_
- [main.DeleteArticle](/_examples/rest/main.go#L204)
- _GET_
- [main.GetArticle](/_examples/rest/main.go#L176)
- _PUT_
- [main.UpdateArticle](/_examples/rest/main.go#L189)
`/articles/*/{articleSlug:[a-z-]+}`
- [RequestID](/middleware/request_id.go#L63)
- [Logger](/middleware/logger.go#L26)
- [Recoverer](/middleware/recoverer.go#L18)
- [URLFormat](/middleware/url_format.go#L45)
- [SetContentType.func1](https://github.com/go-chi/render/content_type.go#L49)
- **/articles/***
- **/{articleSlug:[a-z-]+}**
- _GET_
- [main.ArticleCtx](/_examples/rest/main.go#L127)
- [main.GetArticle](/_examples/rest/main.go#L176)
`/panic`
- [RequestID](/middleware/request_id.go#L63)
- [Logger](/middleware/logger.go#L26)
- [Recoverer](/middleware/recoverer.go#L18)
- [URLFormat](/middleware/url_format.go#L45)
- [SetContentType.func1](https://github.com/go-chi/render/content_type.go#L49)
- **/panic**
- _GET_
- [main.main.func3](/_examples/rest/main.go#L77)
`/ping`
- [RequestID](/middleware/request_id.go#L63)
- [Logger](/middleware/logger.go#L26)
- [Recoverer](/middleware/recoverer.go#L18)
- [URLFormat](/middleware/url_format.go#L45)
- [SetContentType.func1](https://github.com/go-chi/render/content_type.go#L49)
- **/ping**
- _GET_
- [main.main.func2](/_examples/rest/main.go#L73)
Total # of routes: 10
chi-5.0.7/_examples/router-walk/ 0000775 0000000 0000000 00000000000 14145546033 0016524 5 ustar 00root root 0000000 0000000 chi-5.0.7/_examples/router-walk/main.go 0000664 0000000 0000000 00000001622 14145546033 0020000 0 ustar 00root root 0000000 0000000 package main
import (
"fmt"
"net/http"
"strings"
"github.com/go-chi/chi/v5"
)
// main registers a handful of routes and then walks the router, printing
// one "METHOD route" line per registered endpoint.
func main() {
	router := chi.NewRouter()

	router.Get("/", func(w http.ResponseWriter, r *http.Request) {
		w.Write([]byte("root."))
	})

	router.Route("/road", func(road chi.Router) {
		road.Get("/left", func(w http.ResponseWriter, r *http.Request) {
			w.Write([]byte("left road"))
		})
		road.Post("/right", func(w http.ResponseWriter, r *http.Request) {
			w.Write([]byte("right road"))
		})
	})

	router.Put("/ping", Ping)

	// Subrouter mount points show up as "/*/" in walked routes; collapse
	// them to a plain "/" before printing.
	printRoute := func(method string, route string, handler http.Handler, middlewares ...func(http.Handler) http.Handler) error {
		fmt.Printf("%s %s\n", method, strings.ReplaceAll(route, "/*/", "/"))
		return nil
	}

	err := chi.Walk(router, printRoute)
	if err != nil {
		fmt.Printf("Logging err: %s\n", err.Error())
	}
}
// Ping answers every request with the plain body "pong".
func Ping(w http.ResponseWriter, r *http.Request) {
	reply := []byte("pong")
	w.Write(reply)
}
chi-5.0.7/_examples/todos-resource/ 0000775 0000000 0000000 00000000000 14145546033 0017225 5 ustar 00root root 0000000 0000000 chi-5.0.7/_examples/todos-resource/main.go 0000664 0000000 0000000 00000001426 14145546033 0020503 0 ustar 00root root 0000000 0000000 //
// Todos Resource
// ==============
// This example demonstrates a project structure that defines a subrouter and its
// handlers on a struct, and mounting them as subrouters to a parent router.
// See also _examples/rest for an in-depth example of a REST service, and apply
// those same patterns to this structure.
//
package main
import (
"net/http"
"github.com/go-chi/chi/v5"
"github.com/go-chi/chi/v5/middleware"
)
// main wires the shared middleware stack, a root handler and the two
// resource subrouters, then serves on port 3333.
func main() {
	router := chi.NewRouter()

	router.Use(
		middleware.RequestID,
		middleware.RealIP,
		middleware.Logger,
		middleware.Recoverer,
	)

	router.Get("/", func(w http.ResponseWriter, r *http.Request) {
		w.Write([]byte("."))
	})

	// Each resource type builds its own subrouter, mounted under a prefix.
	users := usersResource{}
	todos := todosResource{}
	router.Mount("/users", users.Routes())
	router.Mount("/todos", todos.Routes())

	http.ListenAndServe(":3333", router)
}
chi-5.0.7/_examples/todos-resource/todos.go 0000664 0000000 0000000 00000002655 14145546033 0020714 0 ustar 00root root 0000000 0000000 package main
import (
"net/http"
"github.com/go-chi/chi/v5"
)
// todosResource groups the handlers and routing table of the todos resource.
type todosResource struct{}

// Routes creates a REST router for the todos resource.
func (rs todosResource) Routes() chi.Router {
	router := chi.NewRouter()
	// router.Use() // some middleware..

	// Collection endpoints.
	router.Get("/", rs.List)    // GET /todos - read a list of todos
	router.Post("/", rs.Create) // POST /todos - create a new todo and persist it
	// NOTE(review): PUT /todos is wired to rs.Delete — confirm this is intended.
	router.Put("/", rs.Delete)

	// Single-item endpoints.
	router.Route("/{id}", func(item chi.Router) {
		// A todo-loading middleware could be attached here, backed by a
		// todos map, so the handlers actually load/manipulate data.
		item.Get("/", rs.Get)       // GET /todos/{id} - read a single todo by :id
		item.Put("/", rs.Update)    // PUT /todos/{id} - update a single todo by :id
		item.Delete("/", rs.Delete) // DELETE /todos/{id} - delete a single todo by :id
		item.Get("/sync", rs.Sync)
	})

	return router
}

// List writes a placeholder body for the todo collection.
func (rs todosResource) List(w http.ResponseWriter, r *http.Request) {
	body := []byte("todos list of stuff..")
	w.Write(body)
}

// Create writes a placeholder body for todo creation.
func (rs todosResource) Create(w http.ResponseWriter, r *http.Request) {
	body := []byte("todos create")
	w.Write(body)
}

// Get writes a placeholder body for reading a single todo.
func (rs todosResource) Get(w http.ResponseWriter, r *http.Request) {
	body := []byte("todo get")
	w.Write(body)
}

// Update writes a placeholder body for updating a single todo.
func (rs todosResource) Update(w http.ResponseWriter, r *http.Request) {
	body := []byte("todo update")
	w.Write(body)
}

// Delete writes a placeholder body for deleting a todo.
func (rs todosResource) Delete(w http.ResponseWriter, r *http.Request) {
	body := []byte("todo delete")
	w.Write(body)
}

// Sync writes a placeholder body for the sync endpoint.
func (rs todosResource) Sync(w http.ResponseWriter, r *http.Request) {
	body := []byte("todo sync")
	w.Write(body)
}
chi-5.0.7/_examples/todos-resource/users.go 0000664 0000000 0000000 00000002453 14145546033 0020721 0 ustar 00root root 0000000 0000000 package main
import (
"net/http"
"github.com/go-chi/chi/v5"
)
// usersResource groups the handlers and routing table of the users resource.
type usersResource struct{}

// Routes creates a REST router for the users resource.
func (rs usersResource) Routes() chi.Router {
	router := chi.NewRouter()
	// router.Use() // some middleware..

	// Collection endpoints.
	router.Get("/", rs.List)    // GET /users - read a list of users
	router.Post("/", rs.Create) // POST /users - create a new user and persist it
	// NOTE(review): PUT /users is wired to rs.Delete — confirm this is intended.
	router.Put("/", rs.Delete)

	// Single-item endpoints.
	router.Route("/{id}", func(item chi.Router) {
		// A user-loading middleware could be attached here, backed by a
		// users map, so the handlers actually load/manipulate data.
		item.Get("/", rs.Get)       // GET /users/{id} - read a single user by :id
		item.Put("/", rs.Update)    // PUT /users/{id} - update a single user by :id
		item.Delete("/", rs.Delete) // DELETE /users/{id} - delete a single user by :id
	})

	return router
}

// List writes a placeholder body for the user collection.
func (rs usersResource) List(w http.ResponseWriter, r *http.Request) {
	body := []byte("users list of stuff..")
	w.Write(body)
}

// Create writes a placeholder body for user creation.
func (rs usersResource) Create(w http.ResponseWriter, r *http.Request) {
	body := []byte("users create")
	w.Write(body)
}

// Get writes a placeholder body for reading a single user.
func (rs usersResource) Get(w http.ResponseWriter, r *http.Request) {
	body := []byte("user get")
	w.Write(body)
}

// Update writes a placeholder body for updating a single user.
func (rs usersResource) Update(w http.ResponseWriter, r *http.Request) {
	body := []byte("user update")
	w.Write(body)
}

// Delete writes a placeholder body for deleting a user.
func (rs usersResource) Delete(w http.ResponseWriter, r *http.Request) {
	body := []byte("user delete")
	w.Write(body)
}
chi-5.0.7/_examples/versions/ 0000775 0000000 0000000 00000000000 14145546033 0016120 5 ustar 00root root 0000000 0000000 chi-5.0.7/_examples/versions/data/ 0000775 0000000 0000000 00000000000 14145546033 0017031 5 ustar 00root root 0000000 0000000 chi-5.0.7/_examples/versions/data/article.go 0000664 0000000 0000000 00000000604 14145546033 0021003 0 ustar 00root root 0000000 0000000 package data
// Article is a runtime object that is not meant to be sent via REST as-is;
// the versioned presenter packages wrap it before rendering.
type Article struct {
ID int `db:"id" json:"id" xml:"id"`
Title string `db:"title" json:"title" xml:"title"`
Data []string `db:"data,stringarray" json:"data" xml:"data"`
// CustomDataForAuthUsers is excluded from JSON and XML output by its "-" tags.
CustomDataForAuthUsers string `db:"custom_data" json:"-" xml:"-"`
}
chi-5.0.7/_examples/versions/data/errors.go 0000664 0000000 0000000 00000001036 14145546033 0020674 0 ustar 00root root 0000000 0000000 package data
import (
"errors"
"net/http"
"github.com/go-chi/render"
)
// Sentinel errors that PresentError maps to specific HTTP status codes.
var (
ErrUnauthorized = errors.New("Unauthorized") // presented as 401
ErrForbidden = errors.New("Forbidden") // presented as 403
ErrNotFound = errors.New("Resource not found") // presented as 404
)
// PresentError records the HTTP status matching err on the request and
// returns the request together with a payload of the form
// {"error": <err message>}.
//
// ErrUnauthorized, ErrForbidden and ErrNotFound map to 401, 403 and 404;
// every other error maps to 500. Matching uses errors.Is, so a sentinel that
// has been wrapped with extra context still gets its proper status.
func PresentError(r *http.Request, err error) (*http.Request, interface{}) {
	status := http.StatusInternalServerError
	switch {
	case errors.Is(err, ErrUnauthorized):
		status = http.StatusUnauthorized
	case errors.Is(err, ErrForbidden):
		status = http.StatusForbidden
	case errors.Is(err, ErrNotFound):
		status = http.StatusNotFound
	}
	render.Status(r, status)

	return r, map[string]string{"error": err.Error()}
}
chi-5.0.7/_examples/versions/main.go 0000664 0000000 0000000 00000010006 14145546033 0017370 0 ustar 00root root 0000000 0000000 //
// Versions
// ========
// This example demonstrates the use of the render subpackage, with
// a quick concept for how to support multiple api versions.
//
package main
import (
"context"
"errors"
"fmt"
"math/rand"
"net/http"
"time"
"github.com/go-chi/chi/v5"
"github.com/go-chi/chi/v5/_examples/versions/data"
v1 "github.com/go-chi/chi/v5/_examples/versions/presenter/v1"
v2 "github.com/go-chi/chi/v5/_examples/versions/presenter/v2"
v3 "github.com/go-chi/chi/v5/_examples/versions/presenter/v3"
"github.com/go-chi/chi/v5/middleware"
"github.com/go-chi/render"
)
// main mounts the same article router under three API version prefixes.
// Each prefix tags requests with its version through apiVersionCtx, and
// /v1 additionally runs randomErrorMiddleware first.
func main() {
	root := chi.NewRouter()
	root.Use(middleware.RequestID, middleware.Logger, middleware.Recoverer)

	// API version 3.
	root.Route("/v3", func(api chi.Router) {
		api.Use(apiVersionCtx("v3"))
		api.Mount("/articles", articleRouter())
	})

	// API version 2.
	root.Route("/v2", func(api chi.Router) {
		api.Use(apiVersionCtx("v2"))
		api.Mount("/articles", articleRouter())
	})

	// API version 1.
	root.Route("/v1", func(api chi.Router) {
		// Simulate random error, ie. version 1 is buggy.
		api.Use(randomErrorMiddleware)
		api.Use(apiVersionCtx("v1"))
		api.Mount("/articles", articleRouter())
	})

	http.ListenAndServe(":3333", root)
}
// apiVersionCtx returns a middleware that stores version on the request
// context under the "api.version" key before calling the next handler.
func apiVersionCtx(version string) func(next http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		tagged := func(w http.ResponseWriter, r *http.Request) {
			ctx := context.WithValue(r.Context(), "api.version", version)
			next.ServeHTTP(w, r.WithContext(ctx))
		}
		return http.HandlerFunc(tagged)
	}
}
// articleRouter builds the article routes that are shared by every API
// version: a list endpoint and a single-article read endpoint.
func articleRouter() http.Handler {
	router := chi.NewRouter()

	router.Get("/", listArticles)

	router.Route("/{articleID}", func(single chi.Router) {
		single.Get("/", getArticle)
		// single.Put("/", updateArticle)
		// single.Delete("/", deleteArticle)
	})

	return router
}
// listArticles responds with ten generated articles, presented in the API
// version found on the request context under "api.version" (v3 when the
// value is absent or not one of "v1"/"v2").
//
// Articles are produced by a background goroutine that simulates slow
// storage and are handed to render.Respond through a channel. The producer
// watches the request context, so it exits instead of blocking forever on a
// full channel once the request is cancelled or the handler has returned.
func listArticles(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()

	// Resolve the version once, outside the goroutine: a panic inside the
	// goroutine would not pass through the Recoverer middleware and would
	// take the whole process down. A missing value falls through to the
	// default (v3) presenter.
	apiVersion, _ := ctx.Value("api.version").(string)

	articles := make(chan render.Renderer, 5)

	// Load data asynchronously into the channel (simulate slow storage):
	go func() {
		defer close(articles)

		for i := 1; i <= 10; i++ {
			article := &data.Article{
				ID:                     i,
				Title:                  fmt.Sprintf("Article #%v", i),
				Data:                   []string{"one", "two", "three", "four"},
				CustomDataForAuthUsers: "secret data for auth'd users only",
			}

			var payload render.Renderer
			switch apiVersion {
			case "v1":
				payload = v1.NewArticleResponse(article)
			case "v2":
				payload = v2.NewArticleResponse(article)
			default:
				payload = v3.NewArticleResponse(article)
			}

			// Stop producing as soon as nobody can be listening any more.
			select {
			case articles <- payload:
			case <-ctx.Done():
				return
			}

			// Simulated storage latency, also interruptible.
			select {
			case <-time.After(100 * time.Millisecond):
			case <-ctx.Done():
				return
			}
		}
	}()

	// Start streaming data from the channel.
	render.Respond(w, r, articles)
}
// getArticle renders the single hard-coded article (ID "1") in the API
// version found on the request context; any other articleID is answered
// with data.ErrNotFound.
func getArticle(w http.ResponseWriter, r *http.Request) {
	// Load article.
	if id := chi.URLParam(r, "articleID"); id != "1" {
		render.Respond(w, r, data.ErrNotFound)
		return
	}
	article := &data.Article{
		ID:                     1,
		Title:                  "Article #1",
		Data:                   []string{"one", "two", "three", "four"},
		CustomDataForAuthUsers: "secret data for auth'd users only",
	}

	// Simulate some context values:
	// 1. ?auth=true simulates authenticated session/user.
	// 2. ?error=true simulates random error.
	query := r.URL.Query()
	if query.Get("auth") != "" {
		r = r.WithContext(context.WithValue(r.Context(), "auth", true))
	}
	if query.Get("error") != "" {
		render.Respond(w, r, errors.New("error"))
		return
	}

	// Pick the presenter that matches the requested API version.
	switch version := r.Context().Value("api.version").(string); version {
	case "v1":
		render.Render(w, r, v1.NewArticleResponse(article))
	case "v2":
		render.Render(w, r, v2.NewArticleResponse(article))
	default:
		render.Render(w, r, v3.NewArticleResponse(article))
	}
}
// randomErrorMiddleware simulates a flaky API: roughly one request in three
// is answered with a randomly chosen data.ErrUnauthorized, data.ErrForbidden
// or data.ErrNotFound instead of reaching the next handler.
func randomErrorMiddleware(next http.Handler) http.Handler {
	// Seed once, when the middleware is installed. Reseeding on every request
	// with second-granularity time made all requests that arrive within the
	// same second draw the same "random" outcome.
	rand.Seed(time.Now().UnixNano())

	candidates := []error{data.ErrUnauthorized, data.ErrForbidden, data.ErrNotFound}

	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// One in three chance of random error.
		if rand.Int31n(3) == 0 {
			render.Respond(w, r, candidates[rand.Intn(len(candidates))])
			return
		}
		next.ServeHTTP(w, r)
	})
}
chi-5.0.7/_examples/versions/presenter/ 0000775 0000000 0000000 00000000000 14145546033 0020127 5 ustar 00root root 0000000 0000000 chi-5.0.7/_examples/versions/presenter/v1/ 0000775 0000000 0000000 00000000000 14145546033 0020455 5 ustar 00root root 0000000 0000000 chi-5.0.7/_examples/versions/presenter/v1/article.go 0000664 0000000 0000000 00000000617 14145546033 0022433 0 ustar 00root root 0000000 0000000 package v1
import (
"net/http"
"github.com/go-chi/chi/v5/_examples/versions/data"
)
// Article presented in API version 1.
//
// It embeds *data.Article and declares its own Data field, which takes the
// place of the embedded article's Data field under the "data" key.
type Article struct {
*data.Article
// Data is serialized as "data". Nothing in this file assigns it.
Data map[string]bool `json:"data" xml:"data"`
}
// Render is the pre-render hook for the v1 presenter. It performs no work
// and always returns nil.
func (a *Article) Render(w http.ResponseWriter, r *http.Request) error {
return nil
}
// NewArticleResponse wraps article in the version 1 presenter. The article
// is embedded by pointer, not copied.
func NewArticleResponse(article *data.Article) *Article {
	resp := new(Article)
	resp.Article = article
	return resp
}
chi-5.0.7/_examples/versions/presenter/v2/ 0000775 0000000 0000000 00000000000 14145546033 0020456 5 ustar 00root root 0000000 0000000 chi-5.0.7/_examples/versions/presenter/v2/article.go 0000664 0000000 0000000 00000001161 14145546033 0022427 0 ustar 00root root 0000000 0000000 package v2
import (
"fmt"
"net/http"
"github.com/go-chi/chi/v5/_examples/versions/data"
)
// Article presented in API version 2.
//
// It embeds *data.Article, adds a self_url field, and declares a url field
// that is dropped from the output while it is left unset (omitempty).
type Article struct {
// *v3.Article `json:",inline" xml:",inline"`
*data.Article
// Additional fields.
// SelfURL is filled in by Render.
SelfURL string `json:"self_url" xml:"self_url"`
// Omitted fields.
// URL is not assigned anywhere in this file, so omitempty leaves it out.
URL interface{} `json:"url,omitempty" xml:"url,omitempty"`
}
// Render is the pre-render hook for the v2 presenter. It sets SelfURL from
// the article ID and always returns nil.
func (a *Article) Render(w http.ResponseWriter, r *http.Request) error {
	const selfURLFormat = "http://localhost:3333/v2?id=%v"
	a.SelfURL = fmt.Sprintf(selfURLFormat, a.ID)
	return nil
}
// NewArticleResponse wraps article in the version 2 presenter. The article
// is embedded by pointer, not copied.
func NewArticleResponse(article *data.Article) *Article {
	resp := new(Article)
	resp.Article = article
	return resp
}
chi-5.0.7/_examples/versions/presenter/v3/ 0000775 0000000 0000000 00000000000 14145546033 0020457 5 ustar 00root root 0000000 0000000 chi-5.0.7/_examples/versions/presenter/v3/article.go 0000664 0000000 0000000 00000001752 14145546033 0022436 0 ustar 00root root 0000000 0000000 package v3
import (
"fmt"
"math/rand"
"net/http"
"github.com/go-chi/chi/v5/_examples/versions/data"
)
// Article presented in API version 3.
//
// It embeds *data.Article and adds url, views_count and api_version fields.
type Article struct {
*data.Article `json:",inline" xml:",inline"`
// Additional fields.
// URL and ViewsCount are filled in by Render; APIVersion is not assigned
// anywhere in this file.
URL string `json:"url" xml:"url"`
ViewsCount int64 `json:"views_count" xml:"views_count"`
APIVersion string `json:"api_version" xml:"api_version"`
// Omitted fields.
// Show custom_data explicitly for auth'd users only.
CustomDataForAuthUsers interface{} `json:"custom_data,omitempty" xml:"custom_data,omitempty"`
}
// Render is the pre-render hook for the v3 presenter. It assigns a random
// view count and the article URL, and copies the embedded article's
// CustomDataForAuthUsers into the response only for authenticated requests.
// It always returns nil.
func (a *Article) Render(w http.ResponseWriter, r *http.Request) error {
	a.ViewsCount = rand.Int63n(100000)
	a.URL = fmt.Sprintf("http://localhost:3333/v3/?id=%v", a.ID)

	// Only show to auth'd user. The stored value itself must be true: a
	// context carrying auth=false must not unlock the field merely because
	// the value has type bool.
	if authed, ok := r.Context().Value("auth").(bool); ok && authed {
		a.CustomDataForAuthUsers = a.Article.CustomDataForAuthUsers
	}
	return nil
}
// NewArticleResponse wraps article in the version 3 presenter. The article
// is embedded by pointer, not copied.
func NewArticleResponse(article *data.Article) *Article {
	resp := new(Article)
	resp.Article = article
	return resp
}
chi-5.0.7/chain.go 0000664 0000000 0000000 00000002755 14145546033 0013715 0 ustar 00root root 0000000 0000000 package chi
import "net/http"
// Chain packages the given middleware handlers, in the order supplied, as a
// Middlewares value.
func Chain(middlewares ...func(http.Handler) http.Handler) Middlewares {
	var mws Middlewares = middlewares
	return mws
}
// Handler composes the middlewares around h and returns the result as a
// *ChainHandler, with `h http.Handler` as the final handler.
func (mws Middlewares) Handler(h http.Handler) http.Handler {
	composed := chain(mws, h)
	return &ChainHandler{
		Endpoint:    h,
		chain:       composed,
		Middlewares: mws,
	}
}
// HandlerFunc composes the middlewares around h and returns the result as a
// *ChainHandler, with `h http.HandlerFunc` as the final handler.
func (mws Middlewares) HandlerFunc(h http.HandlerFunc) http.Handler {
	composed := chain(mws, h)
	return &ChainHandler{
		Endpoint:    h,
		chain:       composed,
		Middlewares: mws,
	}
}
// ChainHandler is a http.Handler with support for handler composition and
// execution.
type ChainHandler struct {
// Endpoint is the final handler the middlewares wrap.
Endpoint http.Handler
// chain is the composed handler (middlewares around Endpoint) that
// ServeHTTP invokes.
chain http.Handler
// Middlewares is the middleware stack the chain was built from.
Middlewares Middlewares
}
// ServeHTTP serves the request through the composed middleware chain.
func (c *ChainHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
c.chain.ServeHTTP(w, r)
}
// chain wraps endpoint with the given middlewares so that they execute in
// the order they are passed: middlewares[0] is outermost and runs first.
// With no middlewares, endpoint is returned unchanged.
func chain(middlewares []func(http.Handler) http.Handler, endpoint http.Handler) http.Handler {
	// Wrap from the innermost middleware outwards.
	h := endpoint
	for i := len(middlewares) - 1; i >= 0; i-- {
		h = middlewares[i](h)
	}
	return h
}
chi-5.0.7/chi.go 0000664 0000000 0000000 00000011025 14145546033 0013364 0 ustar 00root root 0000000 0000000 //
// Package chi is a small, idiomatic and composable router for building HTTP services.
//
// chi requires Go 1.10 or newer.
//
// Example:
// package main
//
// import (
// "net/http"
//
// "github.com/go-chi/chi/v5"
// "github.com/go-chi/chi/v5/middleware"
// )
//
// func main() {
// r := chi.NewRouter()
// r.Use(middleware.Logger)
// r.Use(middleware.Recoverer)
//
// r.Get("/", func(w http.ResponseWriter, r *http.Request) {
// w.Write([]byte("root."))
// })
//
// http.ListenAndServe(":3333", r)
// }
//
// See github.com/go-chi/chi/_examples/ for more in-depth examples.
//
// URL patterns allow for easy matching of path components in HTTP
// requests. The matching components can then be accessed using
// chi.URLParam(). All patterns must begin with a slash.
//
// A simple named placeholder {name} matches any sequence of characters
// up to the next / or the end of the URL. Trailing slashes on paths must
// be handled explicitly.
//
// A placeholder with a name followed by a colon allows a regular
// expression match, for example {number:\\d+}. The regular expression
// syntax is Go's normal regexp RE2 syntax, except that regular expressions
// including { or } are not supported, and / will never be
// matched. An anonymous regexp pattern is allowed, using an empty string
// before the colon in the placeholder, such as {:\\d+}
//
// The special placeholder of asterisk matches the rest of the requested
// URL. Any trailing characters in the pattern are ignored. This is the only
// placeholder which will match / characters.
//
// Examples:
// "/user/{name}" matches "/user/jsmith" but not "/user/jsmith/info" or "/user/jsmith/"
// "/user/{name}/info" matches "/user/jsmith/info"
// "/page/*" matches "/page/intro/latest"
// "/page/*/index" also matches "/page/intro/latest"
// "/date/{yyyy:\\d\\d\\d\\d}/{mm:\\d\\d}/{dd:\\d\\d}" matches "/date/2017/04/01"
//
package chi
import "net/http"
// NewRouter returns a new Mux object that implements the Router interface.
// It is a thin alias for NewMux.
func NewRouter() *Mux {
return NewMux()
}
// Router consists of the core routing methods used by chi's Mux,
// using only the standard net/http.
type Router interface {
http.Handler
Routes
// Use appends one or more middlewares onto the Router stack.
Use(middlewares ...func(http.Handler) http.Handler)
// With adds inline middlewares for an endpoint handler.
With(middlewares ...func(http.Handler) http.Handler) Router
// Group adds a new inline-Router along the current routing
// path, with a fresh middleware stack for the inline-Router.
Group(fn func(r Router)) Router
// Route mounts a sub-Router along a `pattern` string.
Route(pattern string, fn func(r Router)) Router
// Mount attaches another http.Handler along ./pattern/*
Mount(pattern string, h http.Handler)
// Handle and HandleFunc add routes for `pattern` that match
// all HTTP methods.
Handle(pattern string, h http.Handler)
HandleFunc(pattern string, h http.HandlerFunc)
// Method and MethodFunc add routes for `pattern` that match
// the `method` HTTP method.
Method(method, pattern string, h http.Handler)
MethodFunc(method, pattern string, h http.HandlerFunc)
// HTTP-method routing along `pattern`
Connect(pattern string, h http.HandlerFunc)
Delete(pattern string, h http.HandlerFunc)
Get(pattern string, h http.HandlerFunc)
Head(pattern string, h http.HandlerFunc)
Options(pattern string, h http.HandlerFunc)
Patch(pattern string, h http.HandlerFunc)
Post(pattern string, h http.HandlerFunc)
Put(pattern string, h http.HandlerFunc)
Trace(pattern string, h http.HandlerFunc)
// NotFound defines a handler to respond whenever a route could
// not be found.
NotFound(h http.HandlerFunc)
// MethodNotAllowed defines a handler to respond whenever a method is
// not allowed.
MethodNotAllowed(h http.HandlerFunc)
}
// Routes interface adds methods for router traversal, which are also
// used by the `docgen` subpackage to generate documentation for Routers.
type Routes interface {
// Routes returns the routing tree in an easily traversable structure.
Routes() []Route
// Middlewares returns the list of middlewares in use by the router.
Middlewares() Middlewares
// Match searches the routing tree for a handler that matches
// the method/path - similar to routing a http request, but without
// executing the handler thereafter.
Match(rctx *Context, method, path string) bool
}
// Middlewares is a slice of standard middleware handlers, each of the form
// func(http.Handler) http.Handler, with methods to compose middleware chains
// and http.Handlers.
type Middlewares []func(http.Handler) http.Handler
chi-5.0.7/context.go 0000664 0000000 0000000 00000010734 14145546033 0014313 0 ustar 00root root 0000000 0000000 package chi
import (
"context"
"net/http"
"strings"
)
// URLParam looks up the URL parameter named key in the routing context
// attached to r. It returns "" when the request carries no routing context.
func URLParam(r *http.Request, key string) string {
	rctx := RouteContext(r.Context())
	if rctx == nil {
		return ""
	}
	return rctx.URLParam(key)
}
// URLParamFromCtx looks up the URL parameter named key in the routing
// context stored in ctx. It returns "" when ctx carries no routing context.
func URLParamFromCtx(ctx context.Context, key string) string {
	rctx := RouteContext(ctx)
	if rctx == nil {
		return ""
	}
	return rctx.URLParam(key)
}
// RouteContext extracts chi's routing Context object from a http.Request
// Context. It returns nil when no *Context is stored under RouteCtxKey.
func RouteContext(ctx context.Context) *Context {
	if rctx, ok := ctx.Value(RouteCtxKey).(*Context); ok {
		return rctx
	}
	return nil
}
// NewRouteContext allocates a fresh, zero-valued routing Context.
func NewRouteContext() *Context {
	return new(Context)
}
var (
// RouteCtxKey is the context.Context key under which the routing Context
// is stored; RouteContext reads it back with this key.
RouteCtxKey = &contextKey{"RouteContext"}
)
// Context is the default routing context set on the root node of a
// request context to track route patterns, URL parameters and
// an optional routing path.
type Context struct {
// Routes is the router associated with this context.
Routes Routes
// parentCtx is the parent of this one, for using Context as a
// context.Context directly. This is an optimization that saves
// 1 allocation.
parentCtx context.Context
// Routing path/method override used during the route search.
// See Mux#routeHTTP method.
RoutePath string
RouteMethod string
// URLParams are the stack of routeParams captured during the
// routing lifecycle across a stack of sub-routers.
URLParams RouteParams
// Route parameters matched for the current sub-router. It is
// intentionally unexported so it can't be tampered with.
routeParams RouteParams
// The endpoint routing pattern that matched the request URI path
// or `RoutePath` of the current sub-router. This value will update
// during the lifecycle of a request passing through a stack of
// sub-routers.
routePattern string
// Routing pattern stack throughout the lifecycle of the request,
// across all connected routers. It is a record of all matching
// patterns across a stack of sub-routers.
RoutePatterns []string
// methodNotAllowed hint
methodNotAllowed bool
}
// Reset returns the routing context to its initial state. Slices are
// truncated to length zero rather than dropped, so their backing arrays
// stay available for reuse.
func (x *Context) Reset() {
	// Scalar and reference fields.
	x.Routes = nil
	x.parentCtx = nil
	x.RoutePath, x.RouteMethod = "", ""
	x.routePattern = ""
	x.methodNotAllowed = false

	// Slices: keep capacity, drop contents.
	x.RoutePatterns = x.RoutePatterns[:0]
	x.URLParams.Keys, x.URLParams.Values = x.URLParams.Keys[:0], x.URLParams.Values[:0]
	x.routeParams.Keys, x.routeParams.Values = x.routeParams.Keys[:0], x.routeParams.Values[:0]
}
// URLParam returns the value recorded for key in the request routing
// context, or "" when the key is absent. The search runs from the most
// recently added parameter backwards, so the latest entry for a key wins.
func (x *Context) URLParam(key string) string {
	keys := x.URLParams.Keys
	for i := len(keys); i > 0; i-- {
		if keys[i-1] == key {
			return x.URLParams.Values[i-1]
		}
	}
	return ""
}
// RoutePattern builds the routing pattern string for the particular
// request, at the particular point during routing: the recorded
// RoutePatterns are concatenated and in-the-middle "/*/" wildcards are
// collapsed. The value changes throughout the execution of a request in a
// router, which is why it is advised to read it only after calling the next
// handler.
//
// For example,
//
//	func Instrument(next http.Handler) http.Handler {
//		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
//			next.ServeHTTP(w, r)
//			routePattern := chi.RouteContext(r.Context()).RoutePattern()
//			measure(w, r, routePattern)
//		})
//	}
func (x *Context) RoutePattern() string {
	return replaceWildcards(strings.Join(x.RoutePatterns, ""))
}
// replaceWildcards takes a route pattern and replaces occurrences of "/*/"
// with "/" repeatedly, until none remain. A trailing "/*" is left in place.
func replaceWildcards(p string) string {
	// One pass can leave new "/*/" sequences behind (e.g. "/*/*/"), so keep
	// going until the pattern is stable.
	for strings.Contains(p, "/*/") {
		p = strings.Replace(p, "/*/", "/", -1)
	}
	return p
}
// RouteParams tracks URL routing parameters as two parallel slices: the
// value for Keys[i] is Values[i].
type RouteParams struct {
	Keys, Values []string
}

// Add appends a URL parameter (key and its value) to the end of the route
// params.
func (s *RouteParams) Add(key, value string) {
	s.Keys, s.Values = append(s.Keys, key), append(s.Values, value)
}
// contextKey is a value for use with context.WithValue. It's used as
// a pointer so it fits in an interface{} without allocation. This technique
// for defining context keys was copied from Go 1.7's new use of context in net/http.
type contextKey struct {
	name string
}

// String returns the key's name prefixed with "chi context value ".
func (k *contextKey) String() string {
	const prefix = "chi context value "
	return prefix + k.name
}
chi-5.0.7/context_test.go 0000664 0000000 0000000 00000003710 14145546033 0015346 0 ustar 00root root 0000000 0000000 package chi
import "testing"
// TestRoutePattern tests correct in-the-middle wildcard removals.
// If user organizes a router like this:
//
// (router.go)
//
//	r.Route("/v1", func(r chi.Router) {
//		r.Mount("/resources", resourcesController{}.Router())
//	}
//
// (resources_controller.go)
//
//	r.Route("/", func(r chi.Router) {
//		r.Get("/{resource_id}", getResource())
//		other routes...
//	}
//
// This test checks how the route pattern is calculated
// "/v1/resources/{resource_id}" (right)
// "/v1/resources/*/{resource_id}" (wrong)
func TestRoutePattern(t *testing.T) {
	cases := []struct {
		patterns []string
		want     string
	}{
		{
			patterns: []string{
				"/v1/*",
				"/resources/*",
				"/{resource_id}",
			},
			want: "/v1/resources/{resource_id}",
		},
		{
			// Correctly removes in-the-middle wildcards instead of
			// "/v1/resources/*/{resource_id}"
			patterns: []string{
				"/v1/*",
				"/resources/*",
				// Additional wildcard, depending on the router structure of the user
				"/*",
				"/{resource_id}",
			},
			want: "/v1/resources/{resource_id}",
		},
		{
			// Correctly removes in-the-middle wildcards instead of
			// "/v1/resources/*/*/{resource_id}/*"
			patterns: []string{
				"/v1/*",
				"/resources/*",
				// Even with many wildcards
				"/*",
				"/*",
				"/*",
				"/{resource_id}/*", // Keeping trailing wildcard
			},
			want: "/v1/resources/{resource_id}/*",
		},
		{
			// Correctly removes in-the-middle wildcards instead of
			// "/v1/resourcesspecial_path/with_asterisks{resource_id}"
			patterns: []string{
				"/v1/*",
				"/resources/*",
				// And respects asterisks as part of the paths
				"/*special_path/*",
				"/with_asterisks*/*",
				"/{resource_id}",
			},
			want: "/v1/resources/*special_path/with_asterisks*/{resource_id}",
		},
	}

	for _, tc := range cases {
		x := &Context{RoutePatterns: tc.patterns}
		if p := x.RoutePattern(); p != tc.want {
			t.Fatal("unexpected route pattern: " + p)
		}
	}
}
chi-5.0.7/go.mod 0000664 0000000 0000000 00000000051 14145546033 0013375 0 ustar 00root root 0000000 0000000 module github.com/go-chi/chi/v5
go 1.14
chi-5.0.7/middleware/ 0000775 0000000 0000000 00000000000 14145546033 0014410 5 ustar 00root root 0000000 0000000 chi-5.0.7/middleware/basic_auth.go 0000664 0000000 0000000 00000001524 14145546033 0017043 0 ustar 00root root 0000000 0000000 package middleware
import (
"crypto/subtle"
"fmt"
"net/http"
)
// BasicAuth returns a middleware that guards a route with HTTP basic auth.
//
// creds maps user names to passwords. A request is passed on to next only
// when it carries basic-auth credentials whose user is present in creds and
// whose password matches (compared in constant time). Every other request is
// answered through basicAuthFailed with the given realm.
func BasicAuth(realm string, creds map[string]string) func(next http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		fn := func(w http.ResponseWriter, r *http.Request) {
			if user, pass, ok := r.BasicAuth(); ok {
				want, known := creds[user]
				if known && subtle.ConstantTimeCompare([]byte(pass), []byte(want)) == 1 {
					next.ServeHTTP(w, r)
					return
				}
			}
			basicAuthFailed(w, realm)
		}
		return http.HandlerFunc(fn)
	}
}
// basicAuthFailed writes a 401 Unauthorized response carrying a
// WWW-Authenticate challenge for the given realm.
func basicAuthFailed(w http.ResponseWriter, realm string) {
	challenge := fmt.Sprintf(`Basic realm="%s"`, realm)
	hdr := w.Header()
	hdr.Add("WWW-Authenticate", challenge)
	w.WriteHeader(http.StatusUnauthorized)
}
chi-5.0.7/middleware/clean_path.go 0000664 0000000 0000000 00000001222 14145546033 0017032 0 ustar 00root root 0000000 0000000 package middleware
import (
"net/http"
"path"
"github.com/go-chi/chi/v5"
)
// CleanPath middleware cleans out double slash mistakes from a user's
// request path. For example, /users//1 and //users////1 are both treated
// as /users/1.
//
// The cleaned path is stored in the routing context's RoutePath, and only
// when RoutePath has not been set already. The URL's RawPath is preferred
// over Path when it is non-empty.
func CleanPath(next http.Handler) http.Handler {
	fn := func(w http.ResponseWriter, r *http.Request) {
		rctx := chi.RouteContext(r.Context())
		if rctx.RoutePath == "" {
			raw := r.URL.RawPath
			if raw == "" {
				raw = r.URL.Path
			}
			rctx.RoutePath = path.Clean(raw)
		}
		next.ServeHTTP(w, r)
	}
	return http.HandlerFunc(fn)
}
chi-5.0.7/middleware/compress.go 0000664 0000000 0000000 00000026561 14145546033 0016604 0 ustar 00root root 0000000 0000000 package middleware
import (
"bufio"
"compress/flate"
"compress/gzip"
"errors"
"fmt"
"io"
"io/ioutil"
"net"
"net/http"
"strings"
"sync"
)
// defaultCompressibleContentTypes lists the content types NewCompressor
// allows to be compressed when it is called without explicit types.
var defaultCompressibleContentTypes = []string{
"text/html",
"text/css",
"text/plain",
"text/javascript",
"application/javascript",
"application/x-javascript",
"application/json",
"application/atom+xml",
"application/rss+xml",
"image/svg+xml",
}
// Compress is a middleware that compresses the response body for the given
// content types, using a data format chosen from the request's
// Accept-Encoding header and the given compression level. It is shorthand
// for NewCompressor(level, types...).Handler.
//
// NOTE: make sure to set the Content-Type header on your response,
// otherwise this middleware will not compress the response body. For
// example, in your handler you should set
// w.Header().Set("Content-Type", http.DetectContentType(yourBody))
// or set it manually.
//
// Passing a compression level of 5 is a sensible value.
func Compress(level int, types ...string) func(next http.Handler) http.Handler {
	return NewCompressor(level, types...).Handler
}
// Compressor represents a set of encoding configurations.
type Compressor struct {
// The mapping of encoder names to encoder functions. Holds only the
// encoders that are not pooled.
encoders map[string]EncoderFunc
// The mapping of pooled encoders to pools.
pooledEncoders map[string]*sync.Pool
// The set of content types allowed to be compressed.
allowedTypes map[string]struct{}
// The set of allowed wildcard types, keyed by the part before "/*".
allowedWildcards map[string]struct{}
// The list of encoders in order of decreasing precedence.
encodingPrecedence []string
level int // The compression level.
}
// NewCompressor creates a new Compressor that will handle encoding responses.
//
// The level should be one of the ones defined in the flate package.
// The types are the content types that are allowed to be compressed; when
// none are given, the default list is used. A type ending in "/*" is a
// wildcard for everything before the slash. Any other use of "*" panics.
func NewCompressor(level int, types ...string) *Compressor {
	exact := make(map[string]struct{})
	wildcards := make(map[string]struct{})

	if len(types) == 0 {
		// Nothing requested explicitly: fall back to the default list.
		for _, ct := range defaultCompressibleContentTypes {
			exact[ct] = struct{}{}
		}
	} else {
		for _, ct := range types {
			prefix := strings.TrimSuffix(ct, "/*")
			switch {
			case strings.Contains(prefix, "*"):
				panic(fmt.Sprintf("middleware/compress: Unsupported content-type wildcard pattern '%s'. Only '/*' supported", ct))
			case strings.HasSuffix(ct, "/*"):
				wildcards[prefix] = struct{}{}
			default:
				exact[ct] = struct{}{}
			}
		}
	}

	c := &Compressor{
		level:            level,
		encoders:         make(map[string]EncoderFunc),
		pooledEncoders:   make(map[string]*sync.Pool),
		allowedTypes:     exact,
		allowedWildcards: wildcards,
	}

	// Set the default encoders. The precedence order uses the reverse
	// ordering that the encoders were added. This means adding new encoders
	// will move them to the front of the order.
	//
	// TODO:
	// lzma: Opera.
	// sdch: Chrome, Android. Gzip output + dictionary header.
	// br:   Brotli, see https://github.com/go-chi/chi/pull/326

	// HTTP 1.1 "deflate" (RFC 2616) stands for DEFLATE data (RFC 1951)
	// wrapped with zlib (RFC 1950). The zlib wrapper uses Adler-32
	// checksum compared to CRC-32 used in "gzip" and thus is faster.
	//
	// But.. some old browsers (MSIE, Safari 5.1) incorrectly expect
	// raw DEFLATE data only, without the mentioned zlib wrapper.
	// Because of this major confusion, most modern browsers try it
	// both ways, first looking for zlib headers.
	// Quote by Mark Adler: http://stackoverflow.com/a/9186091/385548
	//
	// The list of browsers having problems is quite big, see:
	// http://zoompf.com/blog/2012/02/lose-the-wait-http-compression
	// https://web.archive.org/web/20120321182910/http://www.vervestudios.co/projects/compression-tests/results
	//
	// That's why we prefer gzip over deflate. It's just more reliable
	// and not significantly slower than deflate.
	c.SetEncoder("deflate", encoderDeflate)

	// TODO: Exception for old MSIE browsers that can't handle non-HTML?
	// https://zoompf.com/blog/2012/02/lose-the-wait-http-compression
	c.SetEncoder("gzip", encoderGzip)

	// NOTE: Not implemented, intentionally:
	// case "compress": // LZW. Deprecated.
	// case "bzip2":    // Too slow on-the-fly.
	// case "zopfli":   // Too slow on-the-fly.
	// case "xz":       // Too slow on-the-fly.
	return c
}
// SetEncoder can be used to set the implementation of a compression algorithm.
//
// The encoding should be a standardised identifier. See:
// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Accept-Encoding
//
// The name is lower-cased before use. SetEncoder panics when the encoding is
// empty or fn is nil. Registering a name again replaces the previous encoder,
// and the newly set encoding always moves to the front of the precedence
// order.
//
// For example, add the Brotli algorithm:
//
//	import brotli_enc "gopkg.in/kothar/brotli-go.v0/enc"
//
//	compressor := middleware.NewCompressor(5, "text/html")
//	compressor.SetEncoder("br", func(w http.ResponseWriter, level int) io.Writer {
//		params := brotli_enc.NewBrotliParams()
//		params.SetQuality(level)
//		return brotli_enc.NewBrotliWriter(params, w)
//	})
func (c *Compressor) SetEncoder(encoding string, fn EncoderFunc) {
	encoding = strings.ToLower(encoding)
	if encoding == "" {
		panic("the encoding can not be empty")
	}
	if fn == nil {
		panic("attempted to set a nil encoder function")
	}

	// If we are adding an encoder that is already registered, clear that one
	// out first. delete is a no-op for absent keys, so no lookup is needed.
	delete(c.pooledEncoders, encoding)
	delete(c.encoders, encoding)

	// If the encoder supports resetting (ioResetterWriter), then it can be
	// pooled. A throwaway encoder is built to find out.
	if encoder := fn(ioutil.Discard, c.level); encoder != nil {
		if _, ok := encoder.(ioResetterWriter); ok {
			c.pooledEncoders[encoding] = &sync.Pool{
				New: func() interface{} {
					return fn(ioutil.Discard, c.level)
				},
			}
		}
	}

	// If the encoder is not in the pooledEncoders, add it to the normal encoders.
	if _, ok := c.pooledEncoders[encoding]; !ok {
		c.encoders[encoding] = fn
	}

	// Drop any previous precedence entry for this encoding. Names are
	// unique in the list, so stop after the first match rather than keep
	// ranging over a slice that has just been spliced.
	for i, v := range c.encodingPrecedence {
		if v == encoding {
			c.encodingPrecedence = append(c.encodingPrecedence[:i], c.encodingPrecedence[i+1:]...)
			break
		}
	}
	c.encodingPrecedence = append([]string{encoding}, c.encodingPrecedence...)
}
// Handler returns a new middleware that will compress the response based on
// the current Compressor. For each request it selects an encoder from the
// Accept-Encoding header and hands next a wrapping response writer; whether
// the body is actually compressed is decided when the header is written.
func (c *Compressor) Handler(next http.Handler) http.Handler {
	fn := func(w http.ResponseWriter, r *http.Request) {
		encoder, encoding, cleanup := c.selectEncoder(r.Header, w)
		// Re-add the encoder to the pool if applicable. Deferred first, so
		// it runs after the writer has been closed.
		defer cleanup()

		cw := &compressResponseWriter{
			ResponseWriter:   w,
			w:                w,
			contentTypes:     c.allowedTypes,
			contentWildcards: c.allowedWildcards,
			encoding:         encoding,
			// compressable stays false here; determined in post-handler.
		}
		if encoder != nil {
			cw.w = encoder
		}
		defer cw.Close()

		next.ServeHTTP(cw, r)
	}
	return http.HandlerFunc(fn)
}
// selectEncoder returns the encoder, the name of the encoder, and a cleanup
// function. Encoders are tried in precedence order against the request's
// Accept-Encoding header. When none matches it returns a nil encoder, an
// empty name and a no-op cleanup.
func (c *Compressor) selectEncoder(h http.Header, w io.Writer) (io.Writer, string, func()) {
	// Parse the names of all accepted algorithms from the header.
	accepted := strings.Split(strings.ToLower(h.Get("Accept-Encoding")), ",")

	// Find supported encoder by accepted list by precedence
	for _, name := range c.encodingPrecedence {
		if !matchAcceptEncoding(accepted, name) {
			continue
		}
		if pool, ok := c.pooledEncoders[name]; ok {
			enc := pool.Get().(ioResetterWriter)
			enc.Reset(w)
			return enc, name, func() { pool.Put(enc) }
		}
		if fn, ok := c.encoders[name]; ok {
			return fn(w, c.level), name, func() {}
		}
	}

	// No encoder found to match the accepted encoding
	return nil, "", func() {}
}
// matchAcceptEncoding reports whether any entry of accepted contains
// encoding as a substring.
func matchAcceptEncoding(accepted []string, encoding string) bool {
	for i := 0; i < len(accepted); i++ {
		if strings.Contains(accepted[i], encoding) {
			return true
		}
	}
	return false
}
// An EncoderFunc is a function that wraps the provided io.Writer with a
// streaming compression algorithm, at the given compression level, and
// returns it.
//
// In case of failure, the function should return nil.
type EncoderFunc func(w io.Writer, level int) io.Writer
// ioResetterWriter is the interface for writers that can be re-targeted at
// a new io.Writer via Reset. Encoders that implement it are pooled.
type ioResetterWriter interface {
io.Writer
Reset(w io.Writer)
}
// compressResponseWriter wraps a http.ResponseWriter and routes body writes
// either through a compression encoder or straight to the wrapped writer.
type compressResponseWriter struct {
http.ResponseWriter
// The streaming encoder writer to be used if there is one. Otherwise,
// this is just the normal writer.
w io.Writer
// Exact content types allowed to be compressed.
contentTypes map[string]struct{}
// Wildcard content types allowed to be compressed, keyed by the part
// before the slash.
contentWildcards map[string]struct{}
// Name of the selected encoding; empty when none was selected.
encoding string
// Set on the first WriteHeader call.
wroteHeader bool
// Decided in WriteHeader: whether body writes go through the encoder.
compressable bool
}
// isCompressable reports whether the response's Content-Type is allowed to
// be compressed, either as an exact type or through a wildcard on the part
// before the slash. Parameters after ";" are ignored.
func (cw *compressResponseWriter) isCompressable() bool {
	// Parse the first part of the Content-Type response header.
	mediaType := cw.Header().Get("Content-Type")
	if semi := strings.Index(mediaType, ";"); semi >= 0 {
		mediaType = mediaType[:semi]
	}

	// Is the content type compressable?
	if _, exact := cw.contentTypes[mediaType]; exact {
		return true
	}

	slash := strings.Index(mediaType, "/")
	if slash <= 0 {
		return false
	}
	_, wild := cw.contentWildcards[mediaType[:slash]]
	return wild
}
// WriteHeader decides, on its first call, whether the response body will be
// compressed, adjusts the response headers accordingly, and then forwards
// code to the wrapped ResponseWriter. Later calls are forwarded unchanged.
//
// The body is compressed only when no Content-Encoding has been set already,
// the Content-Type is compressable and an encoding was selected for the
// request.
func (cw *compressResponseWriter) WriteHeader(code int) {
	if cw.wroteHeader {
		cw.ResponseWriter.WriteHeader(code) // Allow multiple calls to propagate.
		return
	}
	cw.wroteHeader = true
	// Deferred so the header adjustments below land before the status line
	// is written.
	defer cw.ResponseWriter.WriteHeader(code)

	// Already compressed data?
	if cw.Header().Get("Content-Encoding") != "" {
		return
	}

	if !cw.isCompressable() {
		cw.compressable = false
		return
	}

	if cw.encoding != "" {
		cw.compressable = true
		cw.Header().Set("Content-Encoding", cw.encoding)
		// Add rather than Set: the handler or another middleware may already
		// have declared Vary values (e.g. "Origin") that must be preserved.
		cw.Header().Add("Vary", "Accept-Encoding")

		// The content-length after compression is unknown
		cw.Header().Del("Content-Length")
	}
}
// Write writes p through the encoder or the wrapped ResponseWriter,
// whichever writer() selects. If no header has been written yet it first
// writes a 200 OK header, which is where that selection is made.
func (cw *compressResponseWriter) Write(p []byte) (int, error) {
	if !cw.wroteHeader {
		cw.WriteHeader(http.StatusOK)
	}
	dst := cw.writer()
	return dst.Write(p)
}
// writer returns the destination for body writes: the encoder when the
// response is being compressed, the wrapped ResponseWriter otherwise.
func (cw *compressResponseWriter) writer() io.Writer {
	if cw.compressable {
		return cw.w
	}
	return cw.ResponseWriter
}
// compressFlusher is the flush signature checked for on the active writer
// in Flush: a Flush method that returns an error, unlike http.Flusher's.
type compressFlusher interface {
Flush() error
}
// Flush flushes the active writer. A writer that is a http.Flusher is
// flushed directly. A writer with the compression flush signature is
// flushed and then, if possible, the wrapped ResponseWriter is flushed too.
func (cw *compressResponseWriter) Flush() {
	out := cw.writer()

	if plain, ok := out.(http.Flusher); ok {
		plain.Flush()
	}

	// If the underlying writer has a compression flush signature,
	// call this Flush() method instead
	if enc, ok := out.(compressFlusher); ok {
		enc.Flush()

		// Also flush the underlying response writer
		if under, ok := cw.ResponseWriter.(http.Flusher); ok {
			under.Flush()
		}
	}
}
// Hijack delegates to the active writer when it implements http.Hijacker
// and returns an error otherwise.
func (cw *compressResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
	hj, ok := cw.writer().(http.Hijacker)
	if !ok {
		return nil, nil, errors.New("chi/middleware: http.Hijacker is unavailable on the writer")
	}
	return hj.Hijack()
}
// Push delegates to the active writer when it implements http.Pusher and
// returns an error otherwise.
func (cw *compressResponseWriter) Push(target string, opts *http.PushOptions) error {
	ps, ok := cw.writer().(http.Pusher)
	if !ok {
		return errors.New("chi/middleware: http.Pusher is unavailable on the writer")
	}
	return ps.Push(target, opts)
}
// Close closes the active writer when it implements io.WriteCloser and
// returns an error otherwise.
func (cw *compressResponseWriter) Close() error {
	wc, ok := cw.writer().(io.WriteCloser)
	if !ok {
		return errors.New("chi/middleware: io.WriteCloser is unavailable on the writer")
	}
	return wc.Close()
}
// encoderGzip is the EncoderFunc for "gzip". It returns a gzip writer on
// top of w, or nil when gzip rejects the level.
func encoderGzip(w io.Writer, level int) io.Writer {
	if gw, err := gzip.NewWriterLevel(w, level); err == nil {
		return gw
	}
	// Return an untyped nil so callers' nil checks on the interface work.
	return nil
}
// encoderDeflate is the EncoderFunc for "deflate". It returns a flate
// writer on top of w, or nil when flate rejects the level.
func encoderDeflate(w io.Writer, level int) io.Writer {
	if dw, err := flate.NewWriter(w, level); err == nil {
		return dw
	}
	// Return an untyped nil so callers' nil checks on the interface work.
	return nil
}
chi-5.0.7/middleware/compress_test.go 0000664 0000000 0000000 00000012354 14145546033 0017636 0 ustar 00root root 0000000 0000000 package middleware
import (
"compress/flate"
"compress/gzip"
"fmt"
"io"
"io/ioutil"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/go-chi/chi/v5"
)
// TestCompressor checks encoder registration (pooled vs. plain) and that the
// middleware picks the expected Content-Encoding for a set of routes and
// Accept-Encoding headers. The compressor allows text/html and text/css
// only, so the text/plain route must never be compressed.
func TestCompressor(t *testing.T) {
	r := chi.NewRouter()

	compressor := NewCompressor(5, "text/html", "text/css")
	if len(compressor.encoders) != 0 || len(compressor.pooledEncoders) != 2 {
		t.Errorf("gzip and deflate should be pooled")
	}

	compressor.SetEncoder("nop", func(w io.Writer, _ int) io.Writer {
		return w
	})

	if len(compressor.encoders) != 1 {
		t.Errorf("nop encoder should be stored in the encoders map")
	}

	r.Use(compressor.Handler)

	r.Get("/gethtml", func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "text/html")
		w.Write([]byte("textstring"))
	})

	// Each route serves the content type its name advertises; previously
	// all three sent text/html, so the content-type filter was never
	// exercised.
	r.Get("/getcss", func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "text/css")
		w.Write([]byte("textstring"))
	})

	r.Get("/getplain", func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "text/plain")
		w.Write([]byte("textstring"))
	})

	ts := httptest.NewServer(r)
	defer ts.Close()

	tests := []struct {
		name              string
		path              string
		expectedEncoding  string
		acceptedEncodings []string
	}{
		{
			name:              "no expected encodings due to no accepted encodings",
			path:              "/gethtml",
			acceptedEncodings: nil,
			expectedEncoding:  "",
		},
		{
			// gzip is offered, so only the disallowed content type can
			// explain the absence of an encoding.
			name:              "no expected encodings due to content type",
			path:              "/getplain",
			acceptedEncodings: []string{"gzip"},
			expectedEncoding:  "",
		},
		{
			name:              "gzip is only encoding",
			path:              "/gethtml",
			acceptedEncodings: []string{"gzip"},
			expectedEncoding:  "gzip",
		},
		{
			name:              "gzip is preferred over deflate",
			path:              "/getcss",
			acceptedEncodings: []string{"gzip", "deflate"},
			expectedEncoding:  "gzip",
		},
		{
			name:              "deflate is used",
			path:              "/getcss",
			acceptedEncodings: []string{"deflate"},
			expectedEncoding:  "deflate",
		},
		{
			name:              "nop is preferred",
			path:              "/getcss",
			acceptedEncodings: []string{"nop, gzip, deflate"},
			expectedEncoding:  "nop",
		},
	}

	for _, tc := range tests {
		tc := tc
		t.Run(tc.name, func(t *testing.T) {
			resp, respString := testRequestWithAcceptedEncodings(t, ts, "GET", tc.path, tc.acceptedEncodings...)
			if respString != "textstring" {
				t.Errorf("response text doesn't match; expected:%q, got:%q", "textstring", respString)
			}
			if got := resp.Header.Get("Content-Encoding"); got != tc.expectedEncoding {
				t.Errorf("expected encoding %q but got %q", tc.expectedEncoding, got)
			}
		})
	}
}
// TestCompressorWildcards checks how NewCompressor sorts the given content
// types into exact and wildcard sets, and that unsupported wildcard
// patterns panic with the expected message.
func TestCompressorWildcards(t *testing.T) {
	tests := []struct {
		name       string
		recover    string
		types      []string
		typesCount int
		wcCount    int
	}{
		{
			name:       "defaults",
			typesCount: 10,
		},
		{
			name:       "no wildcard",
			types:      []string{"text/plain", "text/html"},
			typesCount: 2,
		},
		{
			name:    "invalid wildcard #1",
			types:   []string{"audio/*wav"},
			recover: "middleware/compress: Unsupported content-type wildcard pattern 'audio/*wav'. Only '/*' supported",
		},
		{
			name:    "invalid wildcard #2",
			types:   []string{"application*/*"},
			recover: "middleware/compress: Unsupported content-type wildcard pattern 'application*/*'. Only '/*' supported",
		},
		{
			name:    "valid wildcard",
			types:   []string{"text/*"},
			wcCount: 1,
		},
		{
			name:       "mixed",
			types:      []string{"audio/wav", "text/*"},
			typesCount: 1,
			wcCount:    1,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			defer func() {
				// When no panic is expected, recover() yields nil, which
				// formats as "<nil>"; compare against that instead of "".
				if tt.recover == "" {
					tt.recover = "<nil>"
				}
				if r := recover(); tt.recover != fmt.Sprintf("%v", r) {
					t.Errorf("Unexpected value recovered: %v", r)
				}
			}()
			compressor := NewCompressor(5, tt.types...)
			if len(compressor.allowedTypes) != tt.typesCount {
				t.Errorf("expected %d allowedTypes, got %d", tt.typesCount, len(compressor.allowedTypes))
			}
			if len(compressor.allowedWildcards) != tt.wcCount {
				t.Errorf("expected %d allowedWildcards, got %d", tt.wcCount, len(compressor.allowedWildcards))
			}
		})
	}
}
// testRequestWithAcceptedEncodings sends a request with the given method and
// path to ts. When encodings are supplied they are joined with "," and sent as
// the Accept-Encoding header; otherwise the header is left unset. It returns
// the response together with its body as decoded by decodeResponseBody. Any
// failure aborts the calling test via t.Fatal.
func testRequestWithAcceptedEncodings(t *testing.T, ts *httptest.Server, method, path string, encodings ...string) (*http.Response, string) {
	t.Helper()

	req, err := http.NewRequest(method, ts.URL+path, nil)
	if err != nil {
		t.Fatal(err)
	}
	if len(encodings) > 0 {
		req.Header.Set("Accept-Encoding", strings.Join(encodings, ","))
	}

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		t.Fatal(err)
	}
	// Register the close before decoding so the body is released even when
	// decodeResponseBody aborts the test.
	defer resp.Body.Close()

	return resp, decodeResponseBody(t, resp)
}
// decodeResponseBody reads resp.Body to the end and returns it as a string,
// transparently decompressing it when the Content-Encoding response header is
// "gzip" or "deflate". Any other (or missing) encoding is read verbatim. The
// reader is closed on every path, and failures abort the test via t.Fatal.
func decodeResponseBody(t *testing.T, resp *http.Response) string {
	t.Helper()

	var reader io.ReadCloser
	switch resp.Header.Get("Content-Encoding") {
	case "gzip":
		gz, err := gzip.NewReader(resp.Body)
		if err != nil {
			t.Fatal(err)
		}
		reader = gz
	case "deflate":
		reader = flate.NewReader(resp.Body)
	default:
		reader = resp.Body
	}
	// Deferred so the reader is also released when ReadAll fails.
	defer reader.Close()

	respBody, err := ioutil.ReadAll(reader)
	if err != nil {
		t.Fatal(err)
	}
	return string(respBody)
}
chi-5.0.7/middleware/content_charset.go 0000664 0000000 0000000 00000002403 14145546033 0020121 0 ustar 00root root 0000000 0000000 package middleware
import (
"net/http"
"strings"
)
// ContentCharset generates a handler that writes a 415 Unsupported Media Type
// response if the charset of the request's Content-Type header matches none of
// the given charsets. Matching is case-insensitive. Including an empty string
// in charsets allows requests with no Content-Type header or no specified
// charset.
func ContentCharset(charsets ...string) func(next http.Handler) http.Handler {
	// Normalise into a private copy: a caller invoking ContentCharset(list...)
	// shares its backing array with charsets and must not have it rewritten.
	allowed := make([]string, len(charsets))
	for i, c := range charsets {
		allowed[i] = strings.ToLower(c)
	}

	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			if !contentEncoding(r.Header.Get("Content-Type"), allowed...) {
				w.WriteHeader(http.StatusUnsupportedMediaType)
				return
			}

			next.ServeHTTP(w, r)
		})
	}
}
// contentEncoding reports whether the charset named in the Content-Type value
// ce is one of charsets. The value is lower-cased first, so charsets are
// expected in lower case. A value without a charset parameter yields the empty
// string, which matches only if "" is listed in charsets.
func contentEncoding(ce string, charsets ...string) bool {
	// Discard the media type: keep what follows the first ';'.
	_, params := split(strings.ToLower(ce), ";")
	// Keep what follows "charset=", then stop at the next parameter.
	_, afterKey := split(params, "charset=")
	charset, _ := split(afterKey, ";")

	for i := range charsets {
		if charsets[i] == charset {
			return true
		}
	}
	return false
}
// split cuts str at the first occurrence of sep and returns the text before
// and after it, each with surrounding whitespace trimmed. When sep does not
// occur, the second result is the empty string.
func split(str, sep string) (string, string) {
	pieces := strings.SplitN(str, sep, 2)
	head := strings.TrimSpace(pieces[0])
	if len(pieces) < 2 {
		return head, ""
	}
	return head, strings.TrimSpace(pieces[1])
}
chi-5.0.7/middleware/content_charset_test.go 0000664 0000000 0000000 00000005620 14145546033 0021164 0 ustar 00root root 0000000 0000000 package middleware
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/go-chi/chi/v5"
)
// TestContentCharset drives the ContentCharset middleware through a chi router
// and checks the status code produced for various Content-Type header values
// and allowed charset lists.
func TestContentCharset(t *testing.T) {
	t.Parallel()

	type testCase struct {
		name                string
		inputValue          string
		inputContentCharset []string
		want                int
	}

	cases := []testCase{
		{
			name:                "should accept requests with a matching charset",
			inputValue:          "application/json; charset=UTF-8",
			inputContentCharset: []string{"UTF-8"},
			want:                http.StatusOK,
		},
		{
			name:                "should be case-insensitive",
			inputValue:          "application/json; charset=utf-8",
			inputContentCharset: []string{"UTF-8"},
			want:                http.StatusOK,
		},
		{
			name:                "should accept requests with a matching charset with extra values",
			inputValue:          "application/json; foo=bar; charset=UTF-8; spam=eggs",
			inputContentCharset: []string{"UTF-8"},
			want:                http.StatusOK,
		},
		{
			name:                "should accept requests with a matching charset when multiple charsets are supported",
			inputValue:          "text/xml; charset=UTF-8",
			inputContentCharset: []string{"UTF-8", "Latin-1"},
			want:                http.StatusOK,
		},
		{
			name:                "should accept requests with no charset if empty charset headers are allowed",
			inputValue:          "text/xml",
			inputContentCharset: []string{"UTF-8", ""},
			want:                http.StatusOK,
		},
		{
			name:                "should not accept requests with no charset if empty charset headers are not allowed",
			inputValue:          "text/xml",
			inputContentCharset: []string{"UTF-8"},
			want:                http.StatusUnsupportedMediaType,
		},
		{
			name:                "should not accept requests with a mismatching charset",
			inputValue:          "text/plain; charset=Latin-1",
			inputContentCharset: []string{"UTF-8"},
			want:                http.StatusUnsupportedMediaType,
		},
		{
			name:                "should not accept requests with a mismatching charset even if empty charsets are allowed",
			inputValue:          "text/plain; charset=Latin-1",
			inputContentCharset: []string{"UTF-8", ""},
			want:                http.StatusUnsupportedMediaType,
		},
	}

	for i := range cases {
		// Copy per iteration: the subtests run in parallel and outlive the loop body.
		tc := cases[i]
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			router := chi.NewRouter()
			router.Use(ContentCharset(tc.inputContentCharset...))
			router.Get("/", func(w http.ResponseWriter, r *http.Request) {})

			req, _ := http.NewRequest("GET", "/", nil)
			req.Header.Set("Content-Type", tc.inputValue)

			recorder := httptest.NewRecorder()
			router.ServeHTTP(recorder, req)

			if res := recorder.Result(); res.StatusCode != tc.want {
				t.Errorf("response is incorrect, got %d, want %d", recorder.Code, tc.want)
			}
		})
	}
}
// TestSplit checks that split trims whitespace around both halves and returns
// an empty second half when the separator is absent.
func TestSplit(t *testing.T) {
	t.Parallel()

	first, second := split("  type1;type2  ", ";")
	if !(first == "type1" && second == "type2") {
		t.Errorf("Want type1, type2 got %s, %s", first, second)
	}

	first, second = split("type1  ", ";")
	if first != "type1" {
		t.Errorf("Want \"type1\" got \"%s\"", first)
	}
	if second != "" {
		t.Errorf("Want empty string got \"%s\"", second)
	}
}
// TestContentEncoding exercises the contentEncoding helper directly with
// matching, mismatching and upper-case charset values.
func TestContentEncoding(t *testing.T) {
	t.Parallel()

	matched := contentEncoding("application/json; foo=bar; charset=utf-8; spam=eggs", "utf-8")
	if !matched {
		t.Error("Want true, got false")
	}

	matched = contentEncoding("text/plain; charset=latin-1", "utf-8")
	if matched {
		t.Error("Want false, got true")
	}

	matched = contentEncoding("text/xml; charset=UTF-8", "latin-1", "utf-8")
	if !matched {
		t.Error("Want true, got false")
	}
}
chi-5.0.7/middleware/content_encoding.go 0000664 0000000 0000000 00000002077 14145546033 0020265 0 ustar 00root root 0000000 0000000 package middleware
import (
"net/http"
"strings"
)
// AllowContentEncoding enforces a whitelist of request Content-Encoding otherwise responds
// with a 415 Unsupported Media Type status.
//
// Allowed encodings are compared case-insensitively and with surrounding
// whitespace ignored. Every coding named by the request must be allowed,
// whether the codings arrive as repeated Content-Encoding headers or as a
// single comma-separated list such as "gzip, deflate". Requests with an empty
// body (ContentLength == 0) or without a Content-Encoding header are passed
// through unchecked.
func AllowContentEncoding(contentEncoding ...string) func(next http.Handler) http.Handler {
	allowedEncodings := make(map[string]struct{}, len(contentEncoding))
	for _, encoding := range contentEncoding {
		allowedEncodings[strings.TrimSpace(strings.ToLower(encoding))] = struct{}{}
	}
	return func(next http.Handler) http.Handler {
		fn := func(w http.ResponseWriter, r *http.Request) {
			// skip check for empty content body
			if r.ContentLength == 0 {
				next.ServeHTTP(w, r)
				return
			}

			// All encodings in the request must be allowed. A single header
			// value may itself carry a comma-separated list of codings.
			for _, value := range r.Header["Content-Encoding"] {
				for _, encoding := range strings.Split(value, ",") {
					if _, ok := allowedEncodings[strings.TrimSpace(strings.ToLower(encoding))]; !ok {
						w.WriteHeader(http.StatusUnsupportedMediaType)
						return
					}
				}
			}

			next.ServeHTTP(w, r)
		}
		return http.HandlerFunc(fn)
	}
}
chi-5.0.7/middleware/content_encoding_test.go 0000664 0000000 0000000 00000003611 14145546033 0021317 0 ustar 00root root 0000000 0000000 package middleware
import (
"bytes"
"net/http"
"net/http/httptest"
"testing"
"github.com/go-chi/chi/v5"
)
// TestContentEncodingMiddleware checks which request Content-Encoding values
// are accepted by AllowContentEncoding("deflate", "gzip") when the request
// carries a body.
func TestContentEncodingMiddleware(t *testing.T) {
	t.Parallel()

	// support for:
	// Content-Encoding: gzip
	// Content-Encoding: deflate
	// Content-Encoding: gzip, deflate
	// Content-Encoding: deflate, gzip
	middleware := AllowContentEncoding("deflate", "gzip")

	tests := []struct {
		name           string
		encodings      []string
		expectedStatus int
	}{
		{
			name:           "Support no encoding",
			encodings:      []string{},
			expectedStatus: 200,
		},
		{
			name:           "Support gzip encoding",
			encodings:      []string{"gzip"},
			expectedStatus: 200,
		},
		{
			name:           "No support for br encoding",
			encodings:      []string{"br"},
			expectedStatus: 415,
		},
		{
			name:           "Support for gzip and deflate encoding",
			encodings:      []string{"gzip", "deflate"},
			expectedStatus: 200,
		},
		{
			name:           "Support for deflate and gzip encoding",
			encodings:      []string{"deflate", "gzip"},
			expectedStatus: 200,
		},
		{
			name:           "No support for deflate and br encoding",
			encodings:      []string{"deflate", "br"},
			expectedStatus: 415,
		},
	}

	for _, tt := range tests {
		var tt = tt
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			body := []byte("This is my content. There are many like this but this one is mine")
			r := httptest.NewRequest("POST", "/", bytes.NewReader(body))
			for _, encoding := range tt.encodings {
				// Add, not Set: Set would overwrite the previous value so only
				// the last encoding of each case would actually be sent.
				r.Header.Add("Content-Encoding", encoding)
			}

			w := httptest.NewRecorder()
			router := chi.NewRouter()
			router.Use(middleware)
			router.Post("/", func(w http.ResponseWriter, r *http.Request) {})

			router.ServeHTTP(w, r)
			res := w.Result()
			if res.StatusCode != tt.expectedStatus {
				t.Errorf("response is incorrect, got %d, want %d", w.Code, tt.expectedStatus)
			}
		})
	}
}
chi-5.0.7/middleware/content_type.go 0000664 0000000 0000000 00000002457 14145546033 0017462 0 ustar 00root root 0000000 0000000 package middleware
import (
"net/http"
"strings"
)
// SetHeader returns a middleware that sets the response header key to value
// (replacing any existing value for that key) before invoking the next handler.
func SetHeader(key, value string) func(next http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			w.Header().Set(key, value)
			next.ServeHTTP(w, r)
		})
	}
}
// AllowContentType enforces a whitelist of request Content-Types otherwise responds
// with a 415 Unsupported Media Type status.
//
// Only the media type is compared: anything from the first ';' onwards
// (charset and other parameters) is ignored, as are case and surrounding
// whitespace. Requests with an empty body (ContentLength == 0) are passed
// through unchecked.
func AllowContentType(contentTypes ...string) func(next http.Handler) http.Handler {
	allowedContentTypes := make(map[string]struct{}, len(contentTypes))
	for _, ctype := range contentTypes {
		allowedContentTypes[strings.TrimSpace(strings.ToLower(ctype))] = struct{}{}
	}

	return func(next http.Handler) http.Handler {
		fn := func(w http.ResponseWriter, r *http.Request) {
			if r.ContentLength == 0 {
				// skip check for empty content body
				next.ServeHTTP(w, r)
				return
			}

			s := r.Header.Get("Content-Type")
			if i := strings.Index(s, ";"); i > -1 {
				s = s[:i]
			}
			// Trim after cutting off the parameters so that whitespace before
			// the ';' ("application/json ; charset=utf-8") does not defeat the
			// lookup.
			s = strings.ToLower(strings.TrimSpace(s))

			if _, ok := allowedContentTypes[s]; ok {
				next.ServeHTTP(w, r)
				return
			}

			w.WriteHeader(http.StatusUnsupportedMediaType)
		}
		return http.HandlerFunc(fn)
	}
}
chi-5.0.7/middleware/content_type_test.go 0000664 0000000 0000000 00000004057 14145546033 0020517 0 ustar 00root root 0000000 0000000 package middleware
import (
"bytes"
"net/http"
"net/http/httptest"
"testing"
"github.com/go-chi/chi/v5"
)
// TestContentType posts a non-empty body through AllowContentType mounted on a
// chi router and checks the resulting status for matching and mismatching
// Content-Type header values.
func TestContentType(t *testing.T) {
	t.Parallel()

	type testCase struct {
		name                string
		inputValue          string
		allowedContentTypes []string
		want                int
	}

	cases := []testCase{
		{
			name:                "should accept requests with a matching content type",
			inputValue:          "application/json; charset=UTF-8",
			allowedContentTypes: []string{"application/json"},
			want:                http.StatusOK,
		},
		{
			name:                "should accept requests with a matching content type no charset",
			inputValue:          "application/json",
			allowedContentTypes: []string{"application/json"},
			want:                http.StatusOK,
		},
		{
			name:                "should accept requests with a matching content-type with extra values",
			inputValue:          "application/json; foo=bar; charset=UTF-8; spam=eggs",
			allowedContentTypes: []string{"application/json"},
			want:                http.StatusOK,
		},
		{
			name:                "should accept requests with a matching content type when multiple content types are supported",
			inputValue:          "text/xml; charset=UTF-8",
			allowedContentTypes: []string{"application/json", "text/xml"},
			want:                http.StatusOK,
		},
		{
			name:                "should not accept requests with a mismatching content type",
			inputValue:          "text/plain; charset=latin-1",
			allowedContentTypes: []string{"application/json"},
			want:                http.StatusUnsupportedMediaType,
		},
		{
			name:                "should not accept requests with a mismatching content type even if multiple content types are allowed",
			inputValue:          "text/plain; charset=Latin-1",
			allowedContentTypes: []string{"application/json", "text/xml"},
			want:                http.StatusUnsupportedMediaType,
		},
	}

	for i := range cases {
		// Copy per iteration: the subtests run in parallel and outlive the loop body.
		tc := cases[i]
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			router := chi.NewRouter()
			router.Use(AllowContentType(tc.allowedContentTypes...))
			router.Post("/", func(w http.ResponseWriter, r *http.Request) {})

			payload := []byte("This is my content. There are many like this but this one is mine")
			req := httptest.NewRequest("POST", "/", bytes.NewReader(payload))
			req.Header.Set("Content-Type", tc.inputValue)

			recorder := httptest.NewRecorder()
			router.ServeHTTP(recorder, req)

			if res := recorder.Result(); res.StatusCode != tc.want {
				t.Errorf("response is incorrect, got %d, want %d", recorder.Code, tc.want)
			}
		})
	}
}
chi-5.0.7/middleware/get_head.go 0000664 0000000 0000000 00000001721 14145546033 0016500 0 ustar 00root root 0000000 0000000 package middleware
import (
"net/http"
"github.com/go-chi/chi/v5"
)
// GetHead automatically routes HEAD requests that have no HEAD handler of
// their own to the matching GET handler. The request keeps its HEAD method;
// only the routing context is told to route it as GET.
func GetHead(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.Method == "HEAD" {
			rctx := chi.RouteContext(r.Context())

			// Prefer the path already set on the routing context, then the
			// raw (escaped) URL path, then the decoded URL path.
			routePath := rctx.RoutePath
			switch {
			case routePath != "":
			case r.URL.RawPath != "":
				routePath = r.URL.RawPath
			default:
				routePath = r.URL.Path
			}

			// Look ahead with a throwaway routing context so the real one is
			// not disturbed by the probe.
			probe := chi.NewRouteContext()

			// No dedicated HEAD route: traverse the router as though this
			// were a GET, while the request itself stays a HEAD.
			if !rctx.Routes.Match(probe, "HEAD", routePath) {
				rctx.RouteMethod = "GET"
				rctx.RoutePath = routePath
			}
		}

		next.ServeHTTP(w, r)
	})
}
chi-5.0.7/middleware/get_head_test.go 0000664 0000000 0000000 00000003633 14145546033 0017543 0 ustar 00root root 0000000 0000000 package middleware
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/go-chi/chi/v5"
)
// TestGetHead verifies that, with the GetHead middleware installed, HEAD
// requests fall back to the GET handler when no HEAD route exists (empty body,
// GET handler's headers), still yield 404 for unknown paths, and use the
// dedicated HEAD handler when one is registered.
func TestGetHead(t *testing.T) {
	r := chi.NewRouter()
	r.Use(GetHead)
	r.Get("/hi", func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("X-Test", "yes")
		w.Write([]byte("bye"))
	})
	r.Route("/articles", func(r chi.Router) {
		r.Get("/{id}", func(w http.ResponseWriter, r *http.Request) {
			id := chi.URLParam(r, "id")
			w.Header().Set("X-Article", id)
			w.Write([]byte("article:" + id))
		})
	})
	r.Route("/users", func(r chi.Router) {
		r.Head("/{id}", func(w http.ResponseWriter, r *http.Request) {
			w.Header().Set("X-User", "-")
			w.Write([]byte("user"))
		})
		r.Get("/{id}", func(w http.ResponseWriter, r *http.Request) {
			id := chi.URLParam(r, "id")
			w.Header().Set("X-User", id)
			w.Write([]byte("user:" + id))
		})
	})

	ts := httptest.NewServer(r)
	defer ts.Close()

	// The response body is reported with t.Fatal rather than t.Fatalf: it is
	// data, not a format string, and may contain '%'.
	if _, body := testRequest(t, ts, "GET", "/hi", nil); body != "bye" {
		t.Fatal(body)
	}
	if resp, body := testRequest(t, ts, "HEAD", "/hi", nil); body != "" || resp.Header.Get("X-Test") != "yes" {
		t.Fatal(body)
	}
	if _, body := testRequest(t, ts, "GET", "/", nil); body != "404 page not found\n" {
		t.Fatal(body)
	}
	if resp, body := testRequest(t, ts, "HEAD", "/", nil); body != "" || resp.StatusCode != 404 {
		t.Fatal(body)
	}

	if _, body := testRequest(t, ts, "GET", "/articles/5", nil); body != "article:5" {
		t.Fatal(body)
	}
	if resp, body := testRequest(t, ts, "HEAD", "/articles/5", nil); body != "" || resp.Header.Get("X-Article") != "5" {
		t.Fatalf("expecting X-Article header '5' but got '%s'", resp.Header.Get("X-Article"))
	}

	if _, body := testRequest(t, ts, "GET", "/users/1", nil); body != "user:1" {
		t.Fatal(body)
	}
	if resp, body := testRequest(t, ts, "HEAD", "/users/1", nil); body != "" || resp.Header.Get("X-User") != "-" {
		t.Fatalf("expecting X-User header '-' but got '%s'", resp.Header.Get("X-User"))
	}
}
chi-5.0.7/middleware/heartbeat.go 0000664 0000000 0000000 00000001363 14145546033 0016701 0 ustar 00root root 0000000 0000000 package middleware
import (
"net/http"
"strings"
)
// Heartbeat returns a middleware that answers GET and HEAD requests whose URL
// path equals endpoint (compared case-insensitively) with a 200 text/plain
// response containing ".", without calling the wrapped handler. This is handy
// for a path like `/ping` that load balancers or external uptime checks can
// hit before any routes run; it is also convenient to place it above ACL
// middlewares. All other requests are forwarded unchanged.
func Heartbeat(endpoint string) func(http.Handler) http.Handler {
	return func(h http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			probeMethod := r.Method == "GET" || r.Method == "HEAD"
			if !probeMethod || !strings.EqualFold(r.URL.Path, endpoint) {
				h.ServeHTTP(w, r)
				return
			}

			w.Header().Set("Content-Type", "text/plain")
			w.WriteHeader(http.StatusOK)
			w.Write([]byte("."))
		})
	}
}
chi-5.0.7/middleware/logger.go 0000664 0000000 0000000 00000011604 14145546033 0016220 0 ustar 00root root 0000000 0000000 package middleware
import (
"bytes"
"context"
"log"
"net/http"
"os"
"runtime"
"time"
)
var (
// LogEntryCtxKey is the context.Context key to store the request log entry.
LogEntryCtxKey = &contextKey{"LogEntry"}
// DefaultLogger is called by the Logger middleware handler to log each request.
// It's made a package-level variable so that it can be reconfigured for custom
// logging configurations. It is assigned by this file's init function.
DefaultLogger func(next http.Handler) http.Handler
)
// Logger is a middleware that logs the start and end of each request, along
// with some useful data about what was requested, what the response status was,
// and how long it took to return. Logger prints a request ID if one is
// provided. It simply delegates to the package-level DefaultLogger, so
// replacing DefaultLogger changes what Logger does.
//
// Alternatively, look at https://github.com/goware/httplog for a more in-depth
// http logger with structured logging support.
//
// IMPORTANT NOTE: Logger should go before any other middleware that may change
// the response, such as `middleware.Recoverer`. Example:
//
// ```go
// r := chi.NewRouter()
// r.Use(middleware.Logger) // <--<< Logger should come before Recoverer
// r.Use(middleware.Recoverer)
// r.Get("/", handler)
// ```
func Logger(next http.Handler) http.Handler {
return DefaultLogger(next)
}
// RequestLogger builds a logging middleware around a custom LogFormatter. For
// every request it asks f for a new LogEntry, stores that entry in the request
// context, wraps the ResponseWriter to capture status and byte count, and once
// the downstream handler returns (or panics) writes the entry with the status,
// bytes written, response headers and elapsed time.
func RequestLogger(f LogFormatter) func(next http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			entry := f.NewLogEntry(r)
			ww := NewWrapResponseWriter(w, r.ProtoMajor)

			start := time.Now()
			// Deferred closure: the values are read when the handler has
			// finished, not when the defer statement is reached.
			defer func() {
				entry.Write(ww.Status(), ww.BytesWritten(), ww.Header(), time.Since(start), nil)
			}()

			next.ServeHTTP(ww, WithLogEntry(r, entry))
		})
	}
}
// LogFormatter initiates the beginning of a new LogEntry per request.
// RequestLogger calls NewLogEntry once for every incoming request, before the
// next handler runs. See DefaultLogFormatter for an example implementation.
type LogFormatter interface {
NewLogEntry(r *http.Request) LogEntry
}
// LogEntry records the final log when a request completes.
// See defaultLogEntry for an example implementation.
type LogEntry interface {
// Write is called by RequestLogger after the handler returns, with the
// response status, the number of bytes written, the response headers and
// the elapsed time. RequestLogger always passes nil for extra.
Write(status, bytes int, header http.Header, elapsed time.Duration, extra interface{})
// Panic receives a panic value and a stack trace. Its caller is not in this
// file; presumably the Recoverer middleware -- TODO confirm.
Panic(v interface{}, stack []byte)
}
// GetLogEntry fetches the LogEntry stored in the request's context under
// LogEntryCtxKey. It returns nil when no entry (or a value of another type) is
// stored there.
func GetLogEntry(r *http.Request) LogEntry {
	if entry, ok := r.Context().Value(LogEntryCtxKey).(LogEntry); ok {
		return entry
	}
	return nil
}
// WithLogEntry returns a shallow copy of r whose context carries entry under
// LogEntryCtxKey, making it retrievable with GetLogEntry.
func WithLogEntry(r *http.Request, entry LogEntry) *http.Request {
	ctx := context.WithValue(r.Context(), LogEntryCtxKey, entry)
	return r.WithContext(ctx)
}
// LoggerInterface accepts printing to stdlib logger or compatible logger.
// Only Print is required; defaultLogEntry.Write calls it once per request with
// the fully formatted log line.
type LoggerInterface interface {
Print(v ...interface{})
}
// DefaultLogFormatter is a simple logger that implements a LogFormatter.
type DefaultLogFormatter struct {
// Logger receives each finished log line.
Logger LoggerInterface
// NoColor disables coloured output for the entries this formatter creates.
NoColor bool
}
// NewLogEntry starts a log line for r and returns the entry that will finish
// it. The prefix written here holds, in order: the request ID in brackets
// (only when one is present in the context), the quoted method, full URL and
// protocol, and the remote address. Colour is applied unless NoColor is set.
func (l *DefaultLogFormatter) NewLogEntry(r *http.Request) LogEntry {
	colored := !l.NoColor
	buf := &bytes.Buffer{}

	if reqID := GetReqID(r.Context()); reqID != "" {
		cW(buf, colored, nYellow, "[%s] ", reqID)
	}
	cW(buf, colored, nCyan, "\"")
	cW(buf, colored, bMagenta, "%s ", r.Method)

	// A non-nil TLS state means the request arrived over https.
	scheme := "https"
	if r.TLS == nil {
		scheme = "http"
	}
	cW(buf, colored, nCyan, "%s://%s%s %s\" ", scheme, r.Host, r.RequestURI, r.Proto)

	buf.WriteString("from ")
	buf.WriteString(r.RemoteAddr)
	buf.WriteString(" - ")

	return &defaultLogEntry{
		DefaultLogFormatter: l,
		request:             r,
		buf:                 buf,
		useColor:            colored,
	}
}
// defaultLogEntry is the LogEntry produced by DefaultLogFormatter.NewLogEntry.
// It accumulates one log line in buf and emits it through the embedded
// formatter's Logger when Write is called.
type defaultLogEntry struct {
*DefaultLogFormatter
// request is the request this entry was created for.
request *http.Request
// buf holds the log line built so far.
buf *bytes.Buffer
// useColor records whether colour was enabled when the entry was created.
useColor bool
}
// Write completes the log line with the response status, the number of bytes
// written and the elapsed time, each coloured by range, and then prints the
// whole line through the formatter's Logger. The header and extra arguments
// are not used by this implementation.
func (l *defaultLogEntry) Write(status, bytes int, header http.Header, elapsed time.Duration, extra interface{}) {
	// Status colour by class: 1xx blue, 2xx green, 3xx cyan, 4xx yellow,
	// everything else red.
	statusColor := bRed
	switch {
	case status < 200:
		statusColor = bBlue
	case status < 300:
		statusColor = bGreen
	case status < 400:
		statusColor = bCyan
	case status < 500:
		statusColor = bYellow
	}
	cW(l.buf, l.useColor, statusColor, "%03d", status)

	cW(l.buf, l.useColor, bBlue, " %dB", bytes)

	l.buf.WriteString(" in ")

	// Latency colour: under 500ms green, under 5s yellow, otherwise red.
	elapsedColor := nRed
	switch {
	case elapsed < 500*time.Millisecond:
		elapsedColor = nGreen
	case elapsed < 5*time.Second:
		elapsedColor = nYellow
	}
	cW(l.buf, l.useColor, elapsedColor, "%s", elapsed)

	l.Logger.Print(l.buf.String())
}
// Panic implements LogEntry. It hands the panic value v to PrintPrettyStack;
// the stack argument is not used by this implementation.
func (l *defaultLogEntry) Panic(v interface{}, stack []byte) {
PrintPrettyStack(v)
}
func init() {
color := true
if runtime.GOOS == "windows" {
color = false
}
DefaultLogger = RequestLogger(&DefaultLogFormatter{Logger: log.New(os.Stdout, "", log.LstdFlags), NoColor: !color})
}
chi-5.0.7/middleware/logger_test.go 0000664 0000000 0000000 00000002134 14145546033 0017255 0 ustar 00root root 0000000 0000000 package middleware
import (
"bufio"
"bytes"
"net"
"net/http"
"net/http/httptest"
"testing"
"time"
)
// testLoggerWriter is a ResponseRecorder that additionally offers a Hijack
// method, so tests can check that wrappers keep http.Hijacker available.
type testLoggerWriter struct {
*httptest.ResponseRecorder
}
// Hijack makes testLoggerWriter satisfy http.Hijacker. It is a stub: every
// result is nil and no connection is taken over.
func (cw testLoggerWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
return nil, nil, nil
}
// TestRequestLogger checks that the writer DefaultLogger passes to the wrapped
// handler still implements http.Hijacker when the underlying writer does.
func TestRequestLogger(t *testing.T) {
	checkHijacker := func(w http.ResponseWriter, r *http.Request) {
		if _, ok := w.(http.Hijacker); !ok {
			t.Errorf("http.Hijacker is unavailable on the writer. add the interface methods.")
		}
	}

	rec := testLoggerWriter{
		ResponseRecorder: httptest.NewRecorder(),
	}
	req := httptest.NewRequest("GET", "/", nil)

	DefaultLogger(http.HandlerFunc(checkHijacker)).ServeHTTP(rec, req)
}
// TestRequestLoggerReadFrom serves content with http.ServeContent through
// DefaultLogger and checks that the full payload reaches the recorder.
func TestRequestLoggerReadFrom(t *testing.T) {
	payload := []byte("file data")
	serveFile := func(w http.ResponseWriter, r *http.Request) {
		http.ServeContent(w, r, "file", time.Time{}, bytes.NewReader(payload))
	}

	rec := httptest.NewRecorder()
	req := httptest.NewRequest("GET", "/", nil)
	DefaultLogger(http.HandlerFunc(serveFile)).ServeHTTP(rec, req)

	assertEqual(t, payload, rec.Body.Bytes())
}
chi-5.0.7/middleware/maybe.go 0000664 0000000 0000000 00000001176 14145546033 0016041 0 ustar 00root root 0000000 0000000 package middleware
import "net/http"
// Maybe applies the middleware mw only to requests for which maybeFn returns
// true; every other request goes straight to the next handler. This is useful,
// for example, to skip a middleware for requests that do not satisfy some
// condition. Note that mw is applied to next on each matching request.
func Maybe(mw func(http.Handler) http.Handler, maybeFn func(r *http.Request) bool) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			if !maybeFn(r) {
				next.ServeHTTP(w, r)
				return
			}
			mw(next).ServeHTTP(w, r)
		})
	}
}
chi-5.0.7/middleware/middleware.go 0000664 0000000 0000000 00000001254 14145546033 0017056 0 ustar 00root root 0000000 0000000 package middleware
import "net/http"
// New turns a plain http.Handler into a middleware. The resulting middleware
// always serves the request with h; the next handler it is given is never
// invoked.
func New(h http.Handler) func(next http.Handler) http.Handler {
	serve := func(w http.ResponseWriter, r *http.Request) {
		h.ServeHTTP(w, r)
	}
	return func(http.Handler) http.Handler {
		return http.HandlerFunc(serve)
	}
}
// contextKey is a value for use with context.WithValue. It's used as
// a pointer so it fits in an interface{} without allocation. This technique
// for defining context keys was copied from Go 1.7's new use of context in net/http.
type contextKey struct {
// name identifies the key in the text returned by String.
name string
}
// String returns a human-readable description of the key that includes its name.
func (k *contextKey) String() string {
return "chi/middleware context value " + k.name
}
chi-5.0.7/middleware/middleware_test.go 0000664 0000000 0000000 00000006463 14145546033 0020124 0 ustar 00root root 0000000 0000000 package middleware
import (
"crypto/tls"
"io"
"io/ioutil"
"net/http"
"net/http/httptest"
"path"
"reflect"
"runtime"
"testing"
"time"
)
// testdataDir is the path of the "testdata" directory that sits next to the
// directory containing this source file. It is filled in by init.
var testdataDir string

// init derives testdataDir from the location of this source file as reported
// by runtime.Caller.
func init() {
	_, thisFile, _, _ := runtime.Caller(0)
	sourceDir := path.Dir(thisFile)
	testdataDir = path.Join(sourceDir, "/../testdata")
}
// TestWrapWriterHTTP2 starts a TLS server (so requests are served over
// HTTP/2) whose handler is wrapped with NewWrapResponseWriter, and checks
// which optional interfaces the wrapped writer exposes for HTTP/2: Flusher and
// Pusher must be available, Hijacker and io.ReaderFrom must not.
func TestWrapWriterHTTP2(t *testing.T) {
	// The handler runs on a server goroutine, not on the test goroutine, so
	// failures are reported with t.Error/t.Errorf followed by return: FailNow
	// (and therefore t.Fatal) may only be called from the test goroutine.
	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.Proto != "HTTP/2.0" {
			t.Errorf("request proto should be HTTP/2.0 but was %s", r.Proto)
			return
		}
		_, fl := w.(http.Flusher)
		if !fl {
			t.Error("request should have been a http.Flusher")
			return
		}
		_, hj := w.(http.Hijacker)
		if hj {
			t.Error("request should not have been a http.Hijacker")
			return
		}
		_, rf := w.(io.ReaderFrom)
		if rf {
			t.Error("request should not have been a io.ReaderFrom")
			return
		}
		_, ps := w.(http.Pusher)
		if !ps {
			t.Error("request should have been a http.Pusher")
			return
		}

		w.Write([]byte("OK"))
	})

	wmw := func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			next.ServeHTTP(NewWrapResponseWriter(w, r.ProtoMajor), r)
		})
	}

	server := http.Server{
		Addr:    ":7072",
		Handler: wmw(handler),
	}
	// By serving over TLS, we get HTTP2 requests
	go server.ListenAndServeTLS(testdataDir+"/cert.pem", testdataDir+"/key.pem")
	defer server.Close()
	// We need the server to start before making the request
	time.Sleep(100 * time.Millisecond)

	client := &http.Client{
		Transport: &http.Transport{
			TLSClientConfig: &tls.Config{
				// The certificates we are using are self signed
				InsecureSkipVerify: true,
			},
			ForceAttemptHTTP2: true,
		},
	}

	resp, err := client.Get("https://localhost:7072")
	if err != nil {
		t.Fatalf("could not get server: %v", err)
	}
	defer resp.Body.Close()
	if resp.StatusCode != 200 {
		t.Fatalf("non 200 response: %v", resp.StatusCode)
	}
}
// testRequest sends a request with the given method, path and body to ts using
// http.DefaultClient (which follows redirects) and returns the response along
// with its fully read body. The response body is closed before returning. Any
// failure aborts the calling test via t.Fatal.
func testRequest(t *testing.T, ts *httptest.Server, method, path string, body io.Reader) (*http.Response, string) {
	t.Helper()

	req, err := http.NewRequest(method, ts.URL+path, body)
	if err != nil {
		t.Fatal(err)
	}

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		t.Fatal(err)
	}
	// Deferred before reading so the body is closed even if ReadAll fails.
	defer resp.Body.Close()

	respBody, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		t.Fatal(err)
	}

	return resp, string(respBody)
}
// testRequestNoRedirect behaves like testRequest but uses a client that never
// follows redirects: a 3xx response is returned to the caller as-is, together
// with its fully read body. The response body is closed before returning. Any
// failure aborts the calling test via t.Fatal.
func testRequestNoRedirect(t *testing.T, ts *httptest.Server, method, path string, body io.Reader) (*http.Response, string) {
	t.Helper()

	req, err := http.NewRequest(method, ts.URL+path, body)
	if err != nil {
		t.Fatal(err)
	}

	// http client that doesn't redirect
	httpClient := &http.Client{
		CheckRedirect: func(req *http.Request, via []*http.Request) error {
			return http.ErrUseLastResponse
		},
	}

	resp, err := httpClient.Do(req)
	if err != nil {
		t.Fatal(err)
	}
	// Deferred before reading so the body is closed even if ReadAll fails.
	defer resp.Body.Close()

	respBody, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		t.Fatal(err)
	}

	return resp, string(respBody)
}
// assertNoError aborts the calling test when err is non-nil, including the
// unexpected error in the failure message so it can be diagnosed.
func assertNoError(t *testing.T, err error) {
	t.Helper()
	if err != nil {
		t.Fatalf("expecting no error, got: %v", err)
	}
}
// assertError aborts the calling test when err is nil, i.e. when an expected
// error did not occur.
func assertError(t *testing.T, err error) {
	t.Helper()
	if err != nil {
		return
	}
	t.Fatalf("expecting error")
}
// assertEqual aborts the calling test unless a and b are deeply equal as
// defined by reflect.DeepEqual.
func assertEqual(t *testing.T, a, b interface{}) {
	t.Helper()
	if reflect.DeepEqual(a, b) {
		return
	}
	t.Fatalf("expecting values to be equal but got: '%v' and '%v'", a, b)
}
chi-5.0.7/middleware/nocache.go 0000664 0000000 0000000 00000002636 14145546033 0016346 0 ustar 00root root 0000000 0000000 package middleware
// Ported from Goji's middleware, source:
// https://github.com/zenazn/goji/tree/master/web/middleware
import (
"net/http"
"time"
)
// epoch is the Unix epoch rendered as an HTTP-date. The time is converted to
// UTC and formatted with http.TimeFormat so the value is always
// "Thu, 01 Jan 1970 00:00:00 GMT", independent of the host's local time zone.
var epoch = time.Unix(0, 0).UTC().Format(http.TimeFormat)
// noCacheHeaders lists the response headers, and their values, that NoCache
// sets on every response.
// Taken from https://github.com/mytrile/nocache
var noCacheHeaders = map[string]string{
"Expires": epoch,
"Cache-Control": "no-cache, no-store, no-transform, must-revalidate, private, max-age=0",
"Pragma": "no-cache",
"X-Accel-Expires": "0",
}
// etagHeaders lists the header names that NoCache deletes from the incoming
// request before calling the wrapped handler.
var etagHeaders = []string{
"ETag",
"If-Modified-Since",
"If-Match",
"If-None-Match",
"If-Range",
"If-Unmodified-Since",
}
// NoCache is a middleware that prevents responses of a router (or subrouter)
// from being cached by an upstream proxy and/or client.
//
// Before calling the wrapped handler it does two things:
//   - removes the headers listed in etagHeaders from the incoming request;
//   - sets every header in noCacheHeaders on the response: Expires (the Unix
//     epoch), Cache-Control, Pragma (for HTTP/1.0 proxies/clients) and
//     X-Accel-Expires (see http://wiki.nginx.org/HttpProxyModule).
func NoCache(h http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Strip ETag-related headers from the request.
		for _, name := range etagHeaders {
			if r.Header.Get(name) == "" {
				continue
			}
			r.Header.Del(name)
		}

		// Apply the no-cache response headers.
		respHeader := w.Header()
		for name, value := range noCacheHeaders {
			respHeader.Set(name, value)
		}

		h.ServeHTTP(w, r)
	})
}
chi-5.0.7/middleware/page_route.go 0000664 0000000 0000000 00000001007 14145546033 0017067 0 ustar 00root root 0000000 0000000 package middleware
import (
"net/http"
"strings"
)
// PageRoute returns a middleware that serves GET requests whose URL path
// equals path (compared case-insensitively) with handler, directly at the
// middleware stack level. Every other request continues to the next handler.
func PageRoute(path string, handler http.Handler) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			target := next
			if r.Method == "GET" && strings.EqualFold(r.URL.Path, path) {
				target = handler
			}
			target.ServeHTTP(w, r)
		})
	}
}
chi-5.0.7/middleware/path_rewrite.go 0000664 0000000 0000000 00000000646 14145546033 0017442 0 ustar 00root root 0000000 0000000 package middleware
import (
"net/http"
"strings"
)
// PathRewrite returns a middleware that rewrites the request URL path by
// replacing the first occurrence of old with new before passing the request
// on. Paths that do not contain old are left unchanged.
func PathRewrite(old, new string) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		rewrite := func(w http.ResponseWriter, r *http.Request) {
			rewritten := strings.Replace(r.URL.Path, old, new, 1)
			r.URL.Path = rewritten
			next.ServeHTTP(w, r)
		}
		return http.HandlerFunc(rewrite)
	}
}
chi-5.0.7/middleware/profiler.go 0000664 0000000 0000000 00000003200 14145546033 0016554 0 ustar 00root root 0000000 0000000 package middleware
import (
"expvar"
"fmt"
"net/http"
"net/http/pprof"
"github.com/go-chi/chi/v5"
)
// Profiler returns a subrouter that exposes the net/http/pprof handlers under
// "/pprof/..." and the expvar dump under "/vars", with caching disabled via
// NoCache. Requests for "/" and "/pprof" are redirected to the pprof index.
// Typical use:
//
//	func MyService() http.Handler {
//	  r := chi.NewRouter()
//	  // ..middlewares
//	  r.Mount("/debug", middleware.Profiler())
//	  // ..routes
//	  return r
//	}
func Profiler() http.Handler {
	r := chi.NewRouter()
	r.Use(NoCache)

	// redirectTo builds a handler that permanently redirects to the request
	// URI with suffix appended.
	redirectTo := func(suffix string) http.HandlerFunc {
		return func(w http.ResponseWriter, req *http.Request) {
			http.Redirect(w, req, req.RequestURI+suffix, http.StatusMovedPermanently)
		}
	}
	r.Get("/", redirectTo("/pprof/"))
	r.HandleFunc("/pprof", redirectTo("/"))

	r.HandleFunc("/pprof/*", pprof.Index)
	r.HandleFunc("/pprof/cmdline", pprof.Cmdline)
	r.HandleFunc("/pprof/profile", pprof.Profile)
	r.HandleFunc("/pprof/symbol", pprof.Symbol)
	r.HandleFunc("/pprof/trace", pprof.Trace)
	r.HandleFunc("/vars", expVars)

	// Named runtime profiles served through pprof.Handler.
	profiles := []string{"goroutine", "threadcreate", "mutex", "heap", "block", "allocs"}
	for _, name := range profiles {
		r.Handle("/pprof/"+name, pprof.Handler(name))
	}

	return r
}
// expVars writes every published expvar variable as one JSON object, one
// "key": value pair per line. It is replicated from expvar.go, whose handler
// is not exported.
func expVars(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Content-Type", "application/json")
	fmt.Fprintf(w, "{\n")

	written := 0
	expvar.Do(func(kv expvar.KeyValue) {
		// Separate entries with a comma; none before the first one.
		if written > 0 {
			fmt.Fprintf(w, ",\n")
		}
		written++
		fmt.Fprintf(w, "%q: %s", kv.Key, kv.Value)
	})

	fmt.Fprintf(w, "\n}\n")
}
chi-5.0.7/middleware/realip.go 0000664 0000000 0000000 00000003421 14145546033 0016213 0 ustar 00root root 0000000 0000000 package middleware
// Ported from Goji's middleware, source:
// https://github.com/zenazn/goji/tree/master/web/middleware
import (
"net/http"
"strings"
)
// Canonical names of the request headers that realIP consults, in the order
// they are checked: True-Client-IP, then X-Real-IP, then X-Forwarded-For.
var trueClientIP = http.CanonicalHeaderKey("True-Client-IP")
var xForwardedFor = http.CanonicalHeaderKey("X-Forwarded-For")
var xRealIP = http.CanonicalHeaderKey("X-Real-IP")
// RealIP is a middleware that overwrites a http.Request's RemoteAddr with the
// value found in the True-Client-IP, X-Real-IP or X-Forwarded-For header
// (checked in that order). When none of them yields a value, RemoteAddr is
// left untouched.
//
// Insert this middleware fairly early in the stack so that later layers
// (e.g., request loggers) that look at RemoteAddr see the intended value.
//
// Only use this middleware if you can trust the headers passed to you (in
// particular, the headers this middleware reads), for example because a
// reverse proxy such as HAProxy or nginx sits in front of chi. If your reverse
// proxies pass along arbitrary header values from the client, or if you use
// this middleware without a reverse proxy, malicious clients can set
// RemoteAddr to whatever they like, which, depending on how you use
// RemoteAddr, may leave you vulnerable to an attack of some sort.
func RealIP(h http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		rip := realIP(r)
		if rip != "" {
			r.RemoteAddr = rip
		}
		h.ServeHTTP(w, r)
	})
}
// realIP returns the client address advertised by the request headers, or the
// empty string if none of them is set. Headers are consulted in order of
// precedence: True-Client-IP, X-Real-IP, then X-Forwarded-For.
//
// X-Forwarded-For may carry a comma-separated list; only its first entry is
// returned. The list is split on the comma alone, and surrounding whitespace
// is trimmed, so both "a, b" and "a,b" yield "a".
func realIP(r *http.Request) string {
	if tcip := r.Header.Get(trueClientIP); tcip != "" {
		return tcip
	}
	if xrip := r.Header.Get(xRealIP); xrip != "" {
		return xrip
	}

	xff := r.Header.Get(xForwardedFor)
	if i := strings.IndexByte(xff, ','); i >= 0 {
		xff = xff[:i]
	}
	return strings.TrimSpace(xff)
}
chi-5.0.7/middleware/realip_test.go 0000664 0000000 0000000 00000003203 14145546033 0017250 0 ustar 00root root 0000000 0000000 package middleware
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/go-chi/chi/v5"
)
// TestXRealIP checks that RealIP copies the X-Real-IP header value into the
// request's RemoteAddr.
func TestXRealIP(t *testing.T) {
	const wantAddr = "100.100.100.100"

	request, _ := http.NewRequest("GET", "/", nil)
	request.Header.Add("X-Real-IP", wantAddr)

	gotAddr := ""
	router := chi.NewRouter()
	router.Use(RealIP)
	router.Get("/", func(w http.ResponseWriter, r *http.Request) {
		gotAddr = r.RemoteAddr
		w.Write([]byte("Hello World"))
	})

	recorder := httptest.NewRecorder()
	router.ServeHTTP(recorder, request)

	if recorder.Code != http.StatusOK {
		t.Fatal("Response Code should be 200")
	}
	if gotAddr != wantAddr {
		t.Fatal("Test get real IP error.")
	}
}
// TestXForwardForIP checks that RealIP copies a single-entry X-Forwarded-For
// header value into the request's RemoteAddr.
func TestXForwardForIP(t *testing.T) {
	const wantAddr = "100.100.100.100"

	request, _ := http.NewRequest("GET", "/", nil)
	request.Header.Add("X-Forwarded-For", wantAddr)

	gotAddr := ""
	router := chi.NewRouter()
	router.Use(RealIP)
	router.Get("/", func(w http.ResponseWriter, r *http.Request) {
		gotAddr = r.RemoteAddr
		w.Write([]byte("Hello World"))
	})

	recorder := httptest.NewRecorder()
	router.ServeHTTP(recorder, request)

	if recorder.Code != http.StatusOK {
		t.Fatal("Response Code should be 200")
	}
	if gotAddr != wantAddr {
		t.Fatal("Test get real IP error.")
	}
}
// TestXForwardForXRealIPPrecedence checks that X-Real-IP wins over
// X-Forwarded-For when a request carries both headers.
func TestXForwardForXRealIPPrecedence(t *testing.T) {
	const wantAddr = "100.100.100.100"

	request, _ := http.NewRequest("GET", "/", nil)
	request.Header.Add("X-Forwarded-For", "0.0.0.0")
	request.Header.Add("X-Real-IP", wantAddr)

	gotAddr := ""
	router := chi.NewRouter()
	router.Use(RealIP)
	router.Get("/", func(w http.ResponseWriter, r *http.Request) {
		gotAddr = r.RemoteAddr
		w.Write([]byte("Hello World"))
	})

	recorder := httptest.NewRecorder()
	router.ServeHTTP(recorder, request)

	if recorder.Code != http.StatusOK {
		t.Fatal("Response Code should be 200")
	}
	if gotAddr != wantAddr {
		t.Fatal("Test get real IP precedence error.")
	}
}
chi-5.0.7/middleware/recoverer.go 0000664 0000000 0000000 00000011106 14145546033 0016732 0 ustar 00root root 0000000 0000000 package middleware
// The original work was derived from Goji's middleware, source:
// https://github.com/zenazn/goji/tree/master/web/middleware
import (
"bytes"
"errors"
"fmt"
"io"
"net/http"
"os"
"runtime/debug"
"strings"
)
// Recoverer is a middleware that recovers from panics, logs the panic (and a
// backtrace), and returns a HTTP 500 (Internal Server Error) status if
// possible. Recoverer prints a request ID if one is provided.
//
// A panic with http.ErrAbortHandler is the standard library's signal to abort
// the response without logging. Recoverer does not handle it: the value is
// re-panicked so that it still reaches the server and the response is aborted
// as the handler intended.
//
// Alternatively, look at https://github.com/pressly/lg middleware pkgs.
func Recoverer(next http.Handler) http.Handler {
	fn := func(w http.ResponseWriter, r *http.Request) {
		defer func() {
			rvr := recover()
			if rvr == nil {
				return
			}
			if rvr == http.ErrAbortHandler {
				// Swallowing this value would turn a deliberate abort
				// into a normally completed response.
				panic(rvr)
			}

			// Prefer the request's log entry when one is available;
			// otherwise print the formatted stack ourselves.
			if logEntry := GetLogEntry(r); logEntry != nil {
				logEntry.Panic(rvr, debug.Stack())
			} else {
				PrintPrettyStack(rvr)
			}

			w.WriteHeader(http.StatusInternalServerError)
		}()

		next.ServeHTTP(w, r)
	}

	return http.HandlerFunc(fn)
}
// recovererErrorWriter is the destination PrintPrettyStack writes its
// formatted panic report to. It defaults to os.Stderr and is a variable so
// that tests can swap in a buffer and inspect the output.
var recovererErrorWriter io.Writer = os.Stderr
// PrintPrettyStack writes a formatted report of the recovered panic value rvr
// and the current goroutine's stack trace to recovererErrorWriter. If the
// stack trace cannot be formatted, the raw output of debug.Stack is written
// to the same destination instead.
func PrintPrettyStack(rvr interface{}) {
	debugStack := debug.Stack()

	out, err := prettyStack{}.parse(debugStack, rvr)
	if err != nil {
		// Fall back to the stdlib output, but keep using the configured
		// writer so the report is not lost when the writer was replaced.
		out = debugStack
	}
	recovererErrorWriter.Write(out)
}
// prettyStack groups the methods that turn the output of debug.Stack into the
// decorated report printed by PrintPrettyStack. It holds no state.
type prettyStack struct {
}
// parse builds the panic report for the recovered value rvr and the stack
// trace debugStack (as produced by debug.Stack).
//
// The report starts with a "panic: <rvr>" header. It is followed by the
// decorated stack lines that come after the last line starting with "panic("
// in the trace; that line and the line directly after it are runtime
// boilerplate and are left out. If no such line exists, every line except the
// first is kept. An error from decorating a line aborts the report.
func (s prettyStack) parse(debugStack []byte, rvr interface{}) ([]byte, error) {
	useColor := true
	buf := &bytes.Buffer{}

	cW(buf, false, bRed, "\n")
	cW(buf, useColor, bCyan, " panic: ")
	cW(buf, useColor, bBlue, "%v", rvr)
	cW(buf, false, bWhite, "\n \n")

	// Walk the trace bottom-up, collecting lines until the panic line is
	// reached; there may be several when panics are nested.
	stack := strings.Split(string(debugStack), "\n")
	lines := make([]string, 0, len(stack))
	for i := len(stack) - 1; i > 0; i-- {
		lines = append(lines, stack[i])
		if strings.HasPrefix(stack[i], "panic(") {
			// Drop the panic line and the source line that followed it.
			// A truncated trace may not have both, so never drop more
			// than was collected.
			drop := 2
			if len(lines) < drop {
				drop = len(lines)
			}
			lines = lines[:len(lines)-drop]
			break
		}
	}

	// The lines were collected bottom-up; restore the original order.
	for i, j := 0, len(lines)-1; i < j; i, j = i+1, j-1 {
		lines[i], lines[j] = lines[j], lines[i]
	}

	for i, line := range lines {
		decorated, err := s.decorateLine(line, useColor, i)
		if err != nil {
			return nil, err
		}
		buf.WriteString(decorated)
	}
	return buf.Bytes(), nil
}
// decorateLine formats a single stack trace line. After surrounding
// whitespace is trimmed, a line containing ".go:" is treated as a source
// location, a line ending in ")" as a function call, and anything else is
// returned as plain text followed by a newline. num is the line's position in
// the report and is passed through to the specific decorators.
func (s prettyStack) decorateLine(line string, useColor bool, num int) (string, error) {
	// The line is trimmed first, so a leading tab can never be observed
	// below; only the content decides how the line is classified.
	line = strings.TrimSpace(line)

	switch {
	case strings.Contains(line, ".go:"):
		return s.decorateSourceLine(line, useColor, num)
	case strings.HasSuffix(line, ")"):
		return s.decorateFuncCallLine(line, useColor, num)
	default:
		return fmt.Sprintf(" %s\n", line), nil
	}
}
// decorateFuncCallLine formats a function call line of a stack trace, such as
// "net/http.(*conn).serve(0xc000...)". The argument list is dropped and the
// remainder is split into the package path and the method so that each can be
// colored separately. The line at position num == 0 is highlighted with an
// arrow. It returns an error if line contains no "(".
func (s prettyStack) decorateFuncCallLine(line string, useColor bool, num int) (string, error) {
	open := strings.LastIndex(line, "(")
	if open < 0 {
		return "", errors.New("not a func call line")
	}

	// Everything before the argument list is "<import path>.<method>".
	// Import paths are always slash-separated, whatever the host OS, so the
	// package name is the part of the last path element before the first dot.
	pkg, method := line[:open], ""
	slash := strings.LastIndex(pkg, "/")
	lastElem := pkg[slash+1:]
	if dot := strings.Index(lastElem, "."); dot > 0 {
		method = lastElem[dot:]
		pkg = pkg[:slash+1+dot]
	} else if slash >= 0 {
		method = lastElem
		pkg = pkg[:slash+1]
	}

	buf := &bytes.Buffer{}
	pkgColor := nYellow
	methodColor := bGreen

	if num == 0 {
		cW(buf, useColor, bRed, " -> ")
		pkgColor = bMagenta
		methodColor = bRed
	} else {
		cW(buf, useColor, bWhite, " ")
	}
	cW(buf, useColor, pkgColor, "%s", pkg)
	cW(buf, useColor, methodColor, "%s\n", method)
	return buf.String(), nil
}
// decorateSourceLine formats a source location line of a stack trace, such as
// "/app/main.go:10 +0x27". The line is split into directory, file name and
// line number so that each can be colored separately; anything after the
// line number (the program counter offset) is dropped. The line at position
// num == 1 is highlighted with an arrow and followed by an extra blank line.
// It returns an error if line contains no ".go:".
func (s prettyStack) decorateSourceLine(line string, useColor bool, num int) (string, error) {
	ext := strings.LastIndex(line, ".go:")
	if ext < 0 {
		return "", errors.New("not a source line")
	}

	path, lineno := line[:ext+3], line[ext+3:]

	// Accept the forward slash as well as the host separator, so a
	// slash-separated trace path is still split when the host separator
	// differs.
	sep := strings.LastIndexAny(path, "/"+string(os.PathSeparator))
	dir, file := path[:sep+1], path[sep+1:]

	if sp := strings.Index(lineno, " "); sp > 0 {
		lineno = lineno[:sp]
	}

	buf := &bytes.Buffer{}
	fileColor := bCyan
	lineColor := bGreen

	if num == 1 {
		cW(buf, useColor, bRed, " -> ")
		fileColor = bRed
		lineColor = bMagenta
	} else {
		cW(buf, false, bWhite, " ")
	}
	cW(buf, useColor, bWhite, "%s", dir)
	cW(buf, useColor, fileColor, "%s", file)
	cW(buf, useColor, lineColor, "%s", lineno)
	if num == 1 {
		cW(buf, false, bWhite, "\n")
	}
	cW(buf, false, bWhite, "\n")

	return buf.String(), nil
}
chi-5.0.7/middleware/recoverer_test.go 0000664 0000000 0000000 00000001742 14145546033 0017776 0 ustar 00root root 0000000 0000000 package middleware
import (
"bytes"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/go-chi/chi/v5"
)
// panicingHandler is an HTTP handler that always panics with the value "foo".
func panicingHandler(_ http.ResponseWriter, _ *http.Request) {
	panic("foo")
}
func TestRecoverer(t *testing.T) {
r := chi.NewRouter()
oldRecovererErrorWriter := recovererErrorWriter
defer func() { recovererErrorWriter = oldRecovererErrorWriter }()
buf := &bytes.Buffer{}
recovererErrorWriter = buf
r.Use(Recoverer)
r.Get("/", panicingHandler)
ts := httptest.NewServer(r)
defer ts.Close()
res, _ := testRequest(t, ts, "GET", "/", nil)
assertEqual(t, res.StatusCode, http.StatusInternalServerError)
lines := strings.Split(buf.String(), "\n")
for _, line := range lines {
if strings.HasPrefix(strings.TrimSpace(line), "->") {
if !strings.Contains(line, "panicingHandler") {
t.Fatalf("First func call line should refer to panicingHandler, but actual line:\n%v\n", line)
}
return
}
}
t.Fatal("First func call line should start with ->.")
}
chi-5.0.7/middleware/request_id.go 0000664 0000000 0000000 00000005626 14145546033 0017114 0 ustar 00root root 0000000 0000000 package middleware
// Ported from Goji's middleware, source:
// https://github.com/zenazn/goji/tree/master/web/middleware
import (
"context"
"crypto/rand"
"encoding/base64"
"fmt"
"net/http"
"os"
"strings"
"sync/atomic"
)
// ctxKeyRequestID is the type of the context key used when setting the
// request ID.
type ctxKeyRequestID int
// RequestIDKey is the key that holds the unique request ID in a request context.
const RequestIDKey ctxKeyRequestID = 0
// RequestIDHeader is the name of the HTTP Header which contains the request id.
// Exported so that it can be changed by developers
var RequestIDHeader = "X-Request-Id"
// prefix is the "<hostname>/<random>" string that starts every generated
// request ID. It is assigned once, in init.
var prefix string
// reqid is the counter behind generated request IDs. It is only advanced
// through atomic.AddUint64.
var reqid uint64
// A quick note on the statistics here: we're trying to calculate the chance that
// two randomly generated base62 prefixes will collide. We use the formula from
// http://en.wikipedia.org/wiki/Birthday_problem
//
// P[m, n] \approx 1 - e^{-m^2/2n}
//
// We ballpark an upper bound for $m$ by imagining (for whatever reason) a server
// that restarts every second over 10 years, for $m = 86400 * 365 * 10 = 315360000$
//
// For a $k$ character base-62 identifier, we have $n(k) = 62^k$
//
// Plugging this in, we find $P[m, n(10)] \approx 5.75%$, which is good enough for
// our purposes, and is surely more than anyone would ever need in practice -- a
// process that is rebooted a handful of times a day for a hundred years has less
// than a millionth of a percent chance of generating two colliding IDs.
//
// init computes the request ID prefix "<hostname>/<10 random base62 chars>",
// using "localhost" when the hostname is unavailable.
func init() {
	hostname, err := os.Hostname()
	if err != nil || hostname == "" {
		hostname = "localhost"
	}

	// Standard base64 minus '+' and '/' leaves only base62 characters.
	// Stripping can shorten the string, so draw again until at least 10
	// characters remain.
	stripper := strings.NewReplacer("+", "", "/", "")
	var raw [12]byte
	token := ""
	for len(token) < 10 {
		rand.Read(raw[:])
		token = stripper.Replace(base64.StdEncoding.EncodeToString(raw[:]))
	}

	prefix = fmt.Sprintf("%s/%s", hostname, token[0:10])
}
// RequestID is a middleware that injects a request ID into the context of each
// request. If the incoming request carries the RequestIDHeader header, its
// value is used as is. Otherwise an ID of the form
// "host.example.com/random-000001" is generated, where "random" is a base62
// random string that uniquely identifies this go process, and where the last
// number is an atomically incremented request counter.
func RequestID(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		id := r.Header.Get(RequestIDHeader)
		if id == "" {
			seq := atomic.AddUint64(&reqid, 1)
			id = fmt.Sprintf("%s-%06d", prefix, seq)
		}

		ctx := context.WithValue(r.Context(), RequestIDKey, id)
		next.ServeHTTP(w, r.WithContext(ctx))
	})
}
// GetReqID returns the request ID stored in ctx under RequestIDKey. It returns
// the empty string when ctx is nil or holds no string value under that key.
func GetReqID(ctx context.Context) string {
	if ctx == nil {
		return ""
	}
	// A failed type assertion leaves reqID as the empty string.
	reqID, _ := ctx.Value(RequestIDKey).(string)
	return reqID
}
// NextRequestID atomically advances the request counter by one and returns
// the new value, i.e. the next request ID in the sequence.
func NextRequestID() uint64 {
	next := atomic.AddUint64(&reqid, 1)
	return next
}
chi-5.0.7/middleware/request_id_test.go 0000664 0000000 0000000 00000002572 14145546033 0020150 0 ustar 00root root 0000000 0000000 package middleware
import (
"fmt"
"net/http"
"net/http/httptest"
"testing"
"github.com/go-chi/chi/v5"
)
// maintainDefaultRequestID remembers the current value of RequestIDHeader and
// returns a function that restores it. It is meant to be used as
// `defer maintainDefaultRequestID()()`.
func maintainDefaultRequestID() func() {
	saved := RequestIDHeader
	restore := func() {
		RequestIDHeader = saved
	}
	return restore
}
// TestRequestID checks that the RequestID middleware picks the request ID up
// from the header named by RequestIDHeader, both for the default header name
// and for a custom one.
func TestRequestID(t *testing.T) {
	type testCase struct {
		requestIDHeader  string
		request          func() *http.Request
		expectedResponse string
	}

	// newRequest builds a GET request carrying a single header.
	newRequest := func(header, value string) func() *http.Request {
		return func() *http.Request {
			req, _ := http.NewRequest("GET", "/", nil)
			req.Header.Add(header, value)
			return req
		}
	}

	tests := map[string]testCase{
		"Retrieves Request Id from default header": {
			requestIDHeader:  "X-Request-Id",
			request:          newRequest("X-Request-Id", "req-123456"),
			expectedResponse: "RequestID: req-123456",
		},
		"Retrieves Request Id from custom header": {
			requestIDHeader:  "X-Trace-Id",
			request:          newRequest("X-Trace-Id", "trace:abc123"),
			expectedResponse: "RequestID: trace:abc123",
		},
	}

	// The cases overwrite the package-level header name; put it back after.
	defer maintainDefaultRequestID()()

	for _, tc := range tests {
		recorder := httptest.NewRecorder()
		router := chi.NewRouter()

		RequestIDHeader = tc.requestIDHeader

		router.Use(RequestID)
		router.Get("/", func(w http.ResponseWriter, r *http.Request) {
			body := fmt.Sprintf("RequestID: %s", GetReqID(r.Context()))
			w.Write([]byte(body))
		})
		router.ServeHTTP(recorder, tc.request())

		if recorder.Body.String() != tc.expectedResponse {
			t.Fatalf("RequestID was not the expected value")
		}
	}
}
chi-5.0.7/middleware/route_headers.go 0000664 0000000 0000000 00000010354 14145546033 0017573 0 ustar 00root root 0000000 0000000 package middleware
import (
"net/http"
"strings"
)
// RouteHeaders is a neat little header-based router that allows you to direct
// the flow of a request through a middleware stack based on a request header.
//
// For example, lets say you'd like to setup multiple routers depending on the
// request Host header, you could then do something as so:
//
//	r := chi.NewRouter()
//	rSubdomain := chi.NewRouter()
//
//	r.Use(middleware.RouteHeaders().
//		Route("Host", "example.com", middleware.New(r)).
//		Route("Host", "*.example.com", middleware.New(rSubdomain)).
//		Handler)
//
//	r.Get("/", h)
//	rSubdomain.Get("/", h2)
//
// Another example, imagine you want to setup multiple CORS handlers, where for
// your origin servers you allow authorized requests, but for third-party public
// requests, authorization is disabled.
//
//	r := chi.NewRouter()
//
//	r.Use(middleware.RouteHeaders().
//		Route("Origin", "https://app.skyweaver.net", cors.Handler(cors.Options{
//			AllowedOrigins:   []string{"https://api.skyweaver.net"},
//			AllowedMethods:   []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"},
//			AllowedHeaders:   []string{"Accept", "Authorization", "Content-Type"},
//			AllowCredentials: true, // <----------<<< allow credentials
//		})).
//		Route("Origin", "*", cors.Handler(cors.Options{
//			AllowedOrigins:   []string{"*"},
//			AllowedMethods:   []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"},
//			AllowedHeaders:   []string{"Accept", "Content-Type"},
//			AllowCredentials: false, // <----------<<< do not allow credentials
//		})).
//		Handler)
//
// RouteHeaders itself only returns a new, empty HeaderRouter.
func RouteHeaders() HeaderRouter {
	return make(HeaderRouter)
}
// HeaderRouter maps a lower-cased header name to the routes registered for
// that header, in registration order. The special key "*" holds the default
// route installed by RouteDefault.
type HeaderRouter map[string][]HeaderRoute
// Route registers middlewareHandler for requests whose header value matches
// the single pattern match. The header name is stored lower-cased. Route
// returns hr so that calls can be chained.
func (hr HeaderRouter) Route(header, match string, middlewareHandler func(next http.Handler) http.Handler) HeaderRouter {
	key := strings.ToLower(header)
	route := HeaderRoute{
		MatchOne:   NewPattern(match),
		Middleware: middlewareHandler,
	}
	// append allocates the slice on first use of a header.
	hr[key] = append(hr[key], route)
	return hr
}
// RouteAny registers middlewareHandler for requests whose header value
// matches at least one of the patterns in match. The header name is stored
// lower-cased. RouteAny returns hr so that calls can be chained.
func (hr HeaderRouter) RouteAny(header string, match []string, middlewareHandler func(next http.Handler) http.Handler) HeaderRouter {
	key := strings.ToLower(header)

	compiled := make([]Pattern, 0, len(match))
	for _, m := range match {
		compiled = append(compiled, NewPattern(m))
	}

	route := HeaderRoute{MatchAny: compiled, Middleware: middlewareHandler}
	// append allocates the slice on first use of a header.
	hr[key] = append(hr[key], route)
	return hr
}
// RouteDefault installs handler as the default route, stored under the "*"
// key and replacing any default set earlier. RouteDefault returns hr so that
// calls can be chained.
func (hr HeaderRouter) RouteDefault(handler func(next http.Handler) http.Handler) HeaderRouter {
	fallback := HeaderRoute{Middleware: handler}
	hr["*"] = []HeaderRoute{fallback}
	return hr
}
// Handler is the middleware that performs the header routing. For each
// request it runs the middleware of the first registered route whose pattern
// matches the lower-cased value of its header. If no route matches, the
// default route set with RouteDefault is used, and if there is none the
// request is passed straight to next. Exactly one of these handlers is
// invoked per request.
//
// Headers are visited in map iteration order, so when routes on different
// headers match the same request, which one wins is unspecified.
func (hr HeaderRouter) Handler(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if len(hr) == 0 {
			// No routes set: pass through, and stop here so that next is
			// not invoked a second time further down.
			next.ServeHTTP(w, r)
			return
		}

		// Find the first matching header route and hand the request to it.
		for header, matchers := range hr {
			headerValue := r.Header.Get(header)
			if headerValue == "" {
				continue
			}
			headerValue = strings.ToLower(headerValue)
			for _, matcher := range matchers {
				if matcher.IsMatch(headerValue) {
					matcher.Middleware(next).ServeHTTP(w, r)
					return
				}
			}
		}

		// No match: use the "*" default route if a usable one exists.
		defaults, ok := hr["*"]
		if !ok || len(defaults) == 0 || defaults[0].Middleware == nil {
			next.ServeHTTP(w, r)
			return
		}
		defaults[0].Middleware(next).ServeHTTP(w, r)
	})
}
// HeaderRoute pairs a middleware with the header value pattern(s) that select
// it.
type HeaderRoute struct {
	// Middleware is run for requests matched by this route.
	Middleware func(next http.Handler) http.Handler
	// MatchOne is the single pattern consulted when MatchAny is empty.
	MatchOne Pattern
	// MatchAny, when non-empty, is the list of patterns of which at least
	// one has to match; MatchOne is then ignored.
	MatchAny []Pattern
}
// IsMatch reports whether value is matched by the route. If the route has
// MatchAny patterns, value must match at least one of them and MatchOne is
// not consulted; otherwise value must match MatchOne.
func (r HeaderRoute) IsMatch(value string) bool {
	if len(r.MatchAny) == 0 {
		return r.MatchOne.Match(value)
	}
	for _, candidate := range r.MatchAny {
		if candidate.Match(value) {
			return true
		}
	}
	return false
}
// Pattern is a header value pattern: either a literal string or a string with
// a single "*" wildcard.
type Pattern struct {
	// prefix is the text before the wildcard, or the whole literal when the
	// pattern has no wildcard.
	prefix string
	// suffix is the text after the wildcard; empty when there is none.
	suffix string
	// wildcard is true when the pattern contained a "*".
	wildcard bool
}
// NewPattern compiles value into a Pattern. The first "*" in value, if any,
// acts as a wildcard separating a required prefix from a required suffix;
// without a "*" the pattern matches value literally.
func NewPattern(value string) Pattern {
	star := strings.IndexByte(value, '*')
	if star < 0 {
		return Pattern{prefix: value}
	}
	return Pattern{
		prefix:   value[:star],
		suffix:   value[star+1:],
		wildcard: true,
	}
}
// Match reports whether v satisfies the pattern. A literal pattern requires v
// to be equal to it. A wildcard pattern requires v to start with the prefix
// and end with the suffix, and to be long enough that the two do not overlap.
func (p Pattern) Match(v string) bool {
	if !p.wildcard {
		return v == p.prefix
	}
	if len(v) < len(p.prefix)+len(p.suffix) {
		return false
	}
	return strings.HasPrefix(v, p.prefix) && strings.HasSuffix(v, p.suffix)
}
chi-5.0.7/middleware/strip.go 0000664 0000000 0000000 00000003232 14145546033 0016100 0 ustar 00root root 0000000 0000000 package middleware
import (
"fmt"
"net/http"
"github.com/go-chi/chi/v5"
)
// StripSlashes is a middleware that removes a single trailing slash from the
// request path before routing continues through the mux, so that a route
// registered without the slash still serves the request.
//
// The path is taken from the chi route context's RoutePath when one is set,
// and from r.URL.Path otherwise. The stripped path is written to the route
// context when the request has one, and to r.URL.Path when it does not.
func StripSlashes(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		rctx := chi.RouteContext(r.Context())

		path := r.URL.Path
		if rctx != nil && rctx.RoutePath != "" {
			path = rctx.RoutePath
		}

		// A lone "/" is the root path and is left alone.
		if n := len(path); n > 1 && path[n-1] == '/' {
			stripped := path[:n-1]
			if rctx != nil {
				rctx.RoutePath = stripped
			} else {
				r.URL.Path = stripped
			}
		}

		next.ServeHTTP(w, r)
	})
}
// RedirectSlashes is a middleware that answers requests whose path ends in a
// trailing slash with a 301 redirect to the same path, less the trailing
// slash. The query string, if any, is carried over. The redirect target is a
// scheme-relative URL built from the request's Host.
//
// NOTE: RedirectSlashes middleware is *incompatible* with http.FileServer,
// see https://github.com/go-chi/chi/issues/343
func RedirectSlashes(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		path := r.URL.Path
		if rctx := chi.RouteContext(r.Context()); rctx != nil && rctx.RoutePath != "" {
			path = rctx.RoutePath
		}

		// A lone "/" is the root path and is served as is.
		if n := len(path); n > 1 && path[n-1] == '/' {
			target := path[:n-1]
			if r.URL.RawQuery != "" {
				target = fmt.Sprintf("%s?%s", target, r.URL.RawQuery)
			}
			http.Redirect(w, r, fmt.Sprintf("//%s%s", r.Host, target), http.StatusMovedPermanently)
			return
		}

		next.ServeHTTP(w, r)
	})
}
chi-5.0.7/middleware/strip_test.go 0000664 0000000 0000000 00000015457 14145546033 0017153 0 ustar 00root root 0000000 0000000 package middleware
import (
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"github.com/go-chi/chi/v5"
)
// TestStripSlashes checks that, with StripSlashes mounted on the top-level
// router, paths with and without a trailing slash reach the same handler.
func TestStripSlashes(t *testing.T) {
	r := chi.NewRouter()

	// This middleware must be mounted at the top level of the router, not at the end-handler
	// because then it'll be too late and will end up in a 404
	r.Use(StripSlashes)

	r.NotFound(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(404)
		w.Write([]byte("nothing here"))
	})

	r.Get("/", func(w http.ResponseWriter, r *http.Request) {
		w.Write([]byte("root"))
	})

	r.Route("/accounts/{accountID}", func(r chi.Router) {
		r.Get("/", func(w http.ResponseWriter, r *http.Request) {
			accountID := chi.URLParam(r, "accountID")
			w.Write([]byte(accountID))
		})
	})

	ts := httptest.NewServer(r)
	defer ts.Close()

	cases := []struct{ path, want string }{
		{"/", "root"},
		{"//", "root"},
		{"/accounts/admin", "admin"},
		{"/accounts/admin/", "admin"},
		{"/nothing-here", "nothing here"},
	}
	for _, tc := range cases {
		// The body is reported as an argument, never as the format string,
		// so a '%' in the response cannot garble the failure message.
		if _, resp := testRequest(t, ts, "GET", tc.path, nil); resp != tc.want {
			t.Fatalf("GET %s: got %q, want %q", tc.path, resp, tc.want)
		}
	}
}
// TestStripSlashesInRoute checks StripSlashes when it is mounted on a
// sub-router only: trailing slashes are stripped inside that sub-router,
// while routes outside of it still 404 on a trailing slash.
func TestStripSlashesInRoute(t *testing.T) {
	r := chi.NewRouter()

	r.NotFound(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(404)
		w.Write([]byte("nothing here"))
	})

	r.Get("/hi", func(w http.ResponseWriter, r *http.Request) {
		w.Write([]byte("hi"))
	})

	r.Route("/accounts/{accountID}", func(r chi.Router) {
		r.Use(StripSlashes)
		r.Get("/", func(w http.ResponseWriter, r *http.Request) {
			w.Write([]byte("accounts index"))
		})
		r.Get("/query", func(w http.ResponseWriter, r *http.Request) {
			accountID := chi.URLParam(r, "accountID")
			w.Write([]byte(accountID))
		})
	})

	ts := httptest.NewServer(r)
	defer ts.Close()

	cases := []struct{ path, want string }{
		{"/hi", "hi"},
		{"/hi/", "nothing here"},
		{"/accounts/admin", "accounts index"},
		{"/accounts/admin/", "accounts index"},
		{"/accounts/admin/query", "admin"},
		{"/accounts/admin/query/", "admin"},
	}
	for _, tc := range cases {
		// The body is reported as an argument, never as the format string,
		// so a '%' in the response cannot garble the failure message.
		if _, resp := testRequest(t, ts, "GET", tc.path, nil); resp != tc.want {
			t.Fatalf("GET %s: got %q, want %q", tc.path, resp, tc.want)
		}
	}
}
// TestRedirectSlashes checks that RedirectSlashes redirects trailing-slash
// paths to their slash-less form, keeps the query string, and never redirects
// to a different host.
func TestRedirectSlashes(t *testing.T) {
	r := chi.NewRouter()

	// This middleware must be mounted at the top level of the router, not at the end-handler
	// because then it'll be too late and will end up in a 404
	r.Use(RedirectSlashes)

	r.NotFound(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(404)
		w.Write([]byte("nothing here"))
	})

	r.Get("/", func(w http.ResponseWriter, r *http.Request) {
		w.Write([]byte("root"))
	})

	r.Route("/accounts/{accountID}", func(r chi.Router) {
		r.Get("/", func(w http.ResponseWriter, r *http.Request) {
			accountID := chi.URLParam(r, "accountID")
			w.Write([]byte(accountID))
		})
	})

	ts := httptest.NewServer(r)
	defer ts.Close()

	// NOTE: the testRequest client will follow the redirection, so these
	// cases assert on the final response. Body and status must both be
	// right for a case to pass.
	followed := []struct {
		path       string
		wantBody   string
		wantStatus int
	}{
		{"/", "root", 200},
		{"//", "root", 200},
		{"/accounts/admin", "admin", 200},
		{"/accounts/admin/", "admin", 200},
		{"/nothing-here", "nothing here", 404},
	}
	for _, tc := range followed {
		resp, body := testRequest(t, ts, "GET", tc.path, nil)
		if body != tc.wantBody || resp.StatusCode != tc.wantStatus {
			t.Fatalf("GET %s: got %d %q, want %d %q", tc.path, resp.StatusCode, body, tc.wantStatus, tc.wantBody)
		}
	}

	// Ensure redirect Location url is correct
	{
		resp, body := testRequestNoRedirect(t, ts, "GET", "/accounts/someuser/", nil)
		if resp.StatusCode != 301 {
			t.Fatal(body)
		}
		location := resp.Header.Get("Location")
		if !strings.HasPrefix(location, "//") || !strings.HasSuffix(location, "/accounts/someuser") {
			t.Fatalf("invalid redirection, should be /accounts/someuser")
		}
	}

	// Ensure query params are kept in tact upon redirecting a slash
	{
		resp, body := testRequestNoRedirect(t, ts, "GET", "/accounts/someuser/?a=1&b=2", nil)
		if resp.StatusCode != 301 {
			t.Fatal(body)
		}
		location := resp.Header.Get("Location")
		if !strings.HasPrefix(location, "//") || !strings.HasSuffix(location, "/accounts/someuser?a=1&b=2") {
			t.Fatalf("invalid redirection, should be /accounts/someuser?a=1&b=2")
		}
	}

	// Ensure that we don't redirect to 'evil.com', but rather to 'server.url/evil.com/'
	{
		serverURL, err := url.Parse(ts.URL)
		if err != nil {
			t.Fatal(err)
		}

		paths := []string{"//evil.com/", "///evil.com/"}
		for _, p := range paths {
			resp, body := testRequest(t, ts, "GET", p, nil)
			if resp.Request.URL.Host != serverURL.Host {
				t.Fatalf("host should remain the same. got: %q, want: %q", resp.Request.URL.Host, serverURL.Host)
			}
			if body != "nothing here" || resp.StatusCode != 404 {
				t.Fatalf("GET %s: got %d %q, want 404 %q", p, resp.StatusCode, body, "nothing here")
			}
		}
	}
}
// This tests a http.Handler that is not chi.Router
// In these cases, the routeContext is nil
func TestStripSlashesWithNilContext(t *testing.T) {
	r := http.NewServeMux()

	r.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
		w.Write([]byte("root"))
	})

	r.HandleFunc("/accounts", func(w http.ResponseWriter, r *http.Request) {
		w.Write([]byte("accounts"))
	})

	r.HandleFunc("/accounts/admin", func(w http.ResponseWriter, r *http.Request) {
		w.Write([]byte("admin"))
	})

	ts := httptest.NewServer(StripSlashes(r))
	defer ts.Close()

	cases := []struct{ path, want string }{
		{"/", "root"},
		{"//", "root"},
		{"/accounts", "accounts"},
		{"/accounts/", "accounts"},
		{"/accounts/admin", "admin"},
		{"/accounts/admin/", "admin"},
	}
	for _, tc := range cases {
		// The body is reported as an argument, never as the format string,
		// so a '%' in the response cannot garble the failure message.
		if _, resp := testRequest(t, ts, "GET", tc.path, nil); resp != tc.want {
			t.Fatalf("GET %s: got %q, want %q", tc.path, resp, tc.want)
		}
	}
}
chi-5.0.7/middleware/terminal.go 0000664 0000000 0000000 00000003535 14145546033 0016560 0 ustar 00root root 0000000 0000000 package middleware
// Ported from Goji's middleware, source:
// https://github.com/zenazn/goji/tree/master/web/middleware
import (
"fmt"
"io"
"os"
)
// Byte sequences of the terminal escape codes passed as the color argument
// of cW. Each one starts with ESC (octal 033) followed by "[", and ends with
// "m".
var (
	// Normal colors
	nBlack   = []byte{'\033', '[', '3', '0', 'm'}
	nRed     = []byte{'\033', '[', '3', '1', 'm'}
	nGreen   = []byte{'\033', '[', '3', '2', 'm'}
	nYellow  = []byte{'\033', '[', '3', '3', 'm'}
	nBlue    = []byte{'\033', '[', '3', '4', 'm'}
	nMagenta = []byte{'\033', '[', '3', '5', 'm'}
	nCyan    = []byte{'\033', '[', '3', '6', 'm'}
	nWhite   = []byte{'\033', '[', '3', '7', 'm'}
	// Bright colors
	bBlack   = []byte{'\033', '[', '3', '0', ';', '1', 'm'}
	bRed     = []byte{'\033', '[', '3', '1', ';', '1', 'm'}
	bGreen   = []byte{'\033', '[', '3', '2', ';', '1', 'm'}
	bYellow  = []byte{'\033', '[', '3', '3', ';', '1', 'm'}
	bBlue    = []byte{'\033', '[', '3', '4', ';', '1', 'm'}
	bMagenta = []byte{'\033', '[', '3', '5', ';', '1', 'm'}
	bCyan    = []byte{'\033', '[', '3', '6', ';', '1', 'm'}
	bWhite   = []byte{'\033', '[', '3', '7', ';', '1', 'm'}

	// reset is written by cW after colored text.
	reset = []byte{'\033', '[', '0', 'm'}
)
// IsTTY records whether stdout looked like a terminal when the package was
// initialized. Color codes are only emitted when it is true.
var IsTTY bool

func init() {
	// This is sort of cheating: if stdout is a character device, we assume
	// that means it's a TTY. Unfortunately, there are many non-TTY
	// character devices, but fortunately stdout is rarely set to any of
	// them.
	//
	// We could solve this properly by pulling in a dependency on
	// code.google.com/p/go.crypto/ssh/terminal, for instance, but as a
	// heuristic for whether to print in color or in black-and-white, I'd
	// really rather not.
	info, err := os.Stdout.Stat()
	if err != nil {
		// Leave IsTTY at its zero value: no color.
		return
	}

	const charDevice = os.ModeDevice | os.ModeCharDevice
	IsTTY = info.Mode()&charDevice == charDevice
}
// cW ("color write") formats s with args, as fmt.Fprintf does, and writes the
// result to w. When useColor is set and stdout was detected as a TTY, the
// text is preceded by the escape sequence color and followed by the reset
// sequence.
func cW(w io.Writer, useColor bool, color []byte, s string, args ...interface{}) {
	colorize := IsTTY && useColor

	if colorize {
		w.Write(color)
	}
	fmt.Fprintf(w, s, args...)
	if colorize {
		w.Write(reset)
	}
}
chi-5.0.7/middleware/throttle.go 0000664 0000000 0000000 00000007355 14145546033 0016616 0 ustar 00root root 0000000 0000000 package middleware
import (
"net/http"
"strconv"
"time"
)
// Response bodies sent with the 429 (Too Many Requests) status when the
// throttler rejects a request.
const (
	// errCapacityExceeded: no backlog slot was free when the request arrived.
	errCapacityExceeded = "Server capacity exceeded."
	// errTimedOut: the backlog timeout expired before a processing slot
	// became free.
	errTimedOut = "Timed out while waiting for a pending request to complete."
	// errContextCanceled: the request's context was done before the request
	// could be processed.
	errContextCanceled = "Context was canceled."
)
var (
	// defaultBacklogTimeout is the backlog timeout used by Throttle.
	defaultBacklogTimeout = time.Second * 60
)
// ThrottleOpts represents a set of throttling options.
type ThrottleOpts struct {
	// RetryAfterFn, if set, is called whenever a request is rejected; its
	// result, converted to whole seconds, is sent in the Retry-After header.
	// ctxDone is true when the rejection was caused by the request's context
	// being done, false when it was caused by exhausted capacity or a
	// backlog timeout.
	RetryAfterFn func(ctxDone bool) time.Duration
	// Limit is the number of requests that may be processed at the same
	// time. ThrottleWithOpts panics if it is less than 1.
	Limit int
	// BacklogLimit is the number of additional requests that may wait for a
	// processing slot. ThrottleWithOpts panics if it is negative.
	BacklogLimit int
	// BacklogTimeout is how long a waiting request waits for a processing
	// slot before it is rejected.
	BacklogTimeout time.Duration
}
// Throttle is a middleware that limits number of currently processed requests
// at a time across all users. Note: Throttle is not a rate-limiter per user,
// instead it just puts a ceiling on the number of currentl in-flight requests
// being processed from the point from where the Throttle middleware is mounted.
func Throttle(limit int) func(http.Handler) http.Handler {
return ThrottleWithOpts(ThrottleOpts{Limit: limit, BacklogTimeout: defaultBacklogTimeout})
}
// ThrottleBacklog is a middleware that limits number of currently processed
// requests at a time and provides a backlog for holding a finite number of
// pending requests.
func ThrottleBacklog(limit, backlogLimit int, backlogTimeout time.Duration) func(http.Handler) http.Handler {
return ThrottleWithOpts(ThrottleOpts{Limit: limit, BacklogLimit: backlogLimit, BacklogTimeout: backlogTimeout})
}
// ThrottleWithOpts is a middleware that limits number of currently processed requests using passed ThrottleOpts.
//
// It panics if opts.Limit is less than 1 or if opts.BacklogLimit is negative.
//
// Admission works with two pools of tokens, each held in a buffered channel:
// opts.Limit processing tokens and opts.Limit+opts.BacklogLimit backlog
// tokens. A request must first take a backlog token without blocking, then
// wait for a processing token. Requests that cannot be admitted are answered
// with 429 (Too Many Requests); tokens are returned when the handler returns.
func ThrottleWithOpts(opts ThrottleOpts) func(http.Handler) http.Handler {
	if opts.Limit < 1 {
		panic("chi/middleware: Throttle expects limit > 0")
	}
	if opts.BacklogLimit < 0 {
		panic("chi/middleware: Throttle expects backlogLimit to be positive")
	}
	t := throttler{
		tokens:         make(chan token, opts.Limit),
		backlogTokens:  make(chan token, opts.Limit+opts.BacklogLimit),
		backlogTimeout: opts.BacklogTimeout,
		retryAfterFn:   opts.RetryAfterFn,
	}
	// Filling tokens.
	for i := 0; i < opts.Limit+opts.BacklogLimit; i++ {
		if i < opts.Limit {
			t.tokens <- token{}
		}
		t.backlogTokens <- token{}
	}
	return func(next http.Handler) http.Handler {
		fn := func(w http.ResponseWriter, r *http.Request) {
			ctx := r.Context()
			select {
			// The request's context is already done: reject it.
			case <-ctx.Done():
				t.setRetryAfterHeaderIfNeeded(w, true)
				http.Error(w, errContextCanceled, http.StatusTooManyRequests)
				return
			// A backlog token is free: the request may wait for a
			// processing token, for at most backlogTimeout.
			case btok := <-t.backlogTokens:
				timer := time.NewTimer(t.backlogTimeout)
				// Hand the backlog token back on every exit path below.
				defer func() {
					t.backlogTokens <- btok
				}()
				select {
				// Waited too long for a processing token.
				case <-timer.C:
					t.setRetryAfterHeaderIfNeeded(w, false)
					http.Error(w, errTimedOut, http.StatusTooManyRequests)
					return
				// The context became done while waiting.
				case <-ctx.Done():
					timer.Stop()
					t.setRetryAfterHeaderIfNeeded(w, true)
					http.Error(w, errContextCanceled, http.StatusTooManyRequests)
					return
				// Got a processing token: run the handler, then stop the
				// timer and return the token.
				case tok := <-t.tokens:
					defer func() {
						timer.Stop()
						t.tokens <- tok
					}()
					next.ServeHTTP(w, r)
				}
				return
			// No backlog token is free: reject immediately.
			default:
				t.setRetryAfterHeaderIfNeeded(w, false)
				http.Error(w, errCapacityExceeded, http.StatusTooManyRequests)
				return
			}
		}
		return http.HandlerFunc(fn)
	}
}
// token represents a request that is being processed.
type token struct{}
// throttler limits number of currently processed requests at a time.
type throttler struct {
	// tokens holds the processing tokens; a request must receive one before
	// its handler runs and sends it back afterwards.
	tokens chan token
	// backlogTokens holds the tokens a request must receive before it may
	// wait for a processing token.
	backlogTokens chan token
	// retryAfterFn, if non-nil, supplies the Retry-After value for rejected
	// requests.
	retryAfterFn func(ctxDone bool) time.Duration
	// backlogTimeout is how long a request waits for a processing token.
	backlogTimeout time.Duration
}
// setRetryAfterHeaderIfNeeded sets the Retry-After response header to the
// duration returned by the throttler's retryAfterFn, converted to whole
// seconds. It does nothing when no retryAfterFn is configured. ctxDone is
// passed through to retryAfterFn.
func (t throttler) setRetryAfterHeaderIfNeeded(w http.ResponseWriter, ctxDone bool) {
	if fn := t.retryAfterFn; fn != nil {
		seconds := int(fn(ctxDone).Seconds())
		w.Header().Set("Retry-After", strconv.Itoa(seconds))
	}
}
chi-5.0.7/middleware/throttle_test.go 0000664 0000000 0000000 00000012604 14145546033 0017646 0 ustar 00root root 0000000 0000000 package middleware
import (
"io/ioutil"
"net/http"
"net/http/httptest"
"strings"
"sync"
"testing"
"time"
"github.com/go-chi/chi/v5"
)
// testContent is the response body written by the throttle test handlers.
var testContent = []byte("Hello world!")
// TestThrottleBacklog checks that requests beyond the processing limit are
// queued in the backlog and still complete successfully.
func TestThrottleBacklog(t *testing.T) {
	r := chi.NewRouter()

	r.Use(ThrottleBacklog(10, 50, time.Second*10))

	r.Get("/", func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
		time.Sleep(time.Second * 1) // Expensive operation.
		w.Write(testContent)
	})

	server := httptest.NewServer(r)
	defer server.Close()

	client := http.Client{
		Timeout: time.Second * 5, // Maximum waiting time.
	}

	var wg sync.WaitGroup

	// The throttler processes 10 requests at a time, and each one of those
	// requests lasts 1s. The maximum number of requests it can possibly serve
	// before the clients time out (5s) is 40.
	for i := 0; i < 40; i++ {
		wg.Add(1)
		go func(i int) {
			defer wg.Done()

			res, err := client.Get(server.URL)
			assertNoError(t, err)
			// Release the connection once the body has been read.
			defer res.Body.Close()

			assertEqual(t, http.StatusOK, res.StatusCode)
			buf, err := ioutil.ReadAll(res.Body)
			assertNoError(t, err)
			assertEqual(t, testContent, buf)
		}(i)
	}

	wg.Wait()
}
// TestThrottleClientTimeout issues 10 concurrent requests against a handler
// that sleeps for 5s while the client only waits 3s, and expects every
// request to fail on the client side.
func TestThrottleClientTimeout(t *testing.T) {
	router := chi.NewRouter()
	router.Use(ThrottleBacklog(10, 50, time.Second*10))
	router.Get("/", func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusOK)
		time.Sleep(time.Second * 5) // Expensive operation.
		w.Write(testContent)
	})

	srv := httptest.NewServer(router)
	defer srv.Close()

	// Maximum waiting time is shorter than the handler's sleep.
	client := http.Client{Timeout: time.Second * 3}

	const requests = 10

	var wg sync.WaitGroup
	wg.Add(requests)
	for n := 0; n < requests; n++ {
		go func() {
			defer wg.Done()
			_, err := client.Get(srv.URL)
			assertError(t, err)
		}()
	}
	wg.Wait()
}
// TestThrottleTriggerGatewayTimeout saturates ThrottleBacklog(50, 100, 5s)
// with 50 slow requests and then sends 50 more. The second batch is expected
// to be rejected with StatusTooManyRequests and the errTimedOut message,
// while the first batch completes with 200 OK.
func TestThrottleTriggerGatewayTimeout(t *testing.T) {
	r := chi.NewRouter()

	r.Use(ThrottleBacklog(50, 100, time.Second*5))

	r.Get("/", func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
		time.Sleep(time.Second * 10) // Expensive operation.
		w.Write(testContent)
	})

	server := httptest.NewServer(r)
	defer server.Close()

	client := http.Client{
		Timeout: time.Second * 60, // Maximum waiting time.
	}

	var wg sync.WaitGroup

	// These requests will be processed normally until they finish.
	for i := 0; i < 50; i++ {
		wg.Add(1)
		go func(i int) {
			defer wg.Done()

			res, err := client.Get(server.URL)
			assertNoError(t, err)
			if err != nil {
				// res is nil on error; nothing further to inspect.
				return
			}
			// Release the body so the connection is not leaked.
			defer res.Body.Close()

			assertEqual(t, http.StatusOK, res.StatusCode)
		}(i)
	}

	time.Sleep(time.Second * 1)

	// These requests will wait for the first batch to complete but it will take
	// too much time, so they will eventually receive a timeout error.
	for i := 0; i < 50; i++ {
		wg.Add(1)
		go func(i int) {
			defer wg.Done()

			res, err := client.Get(server.URL)
			assertNoError(t, err)
			if err != nil {
				return
			}
			defer res.Body.Close()

			buf, err := ioutil.ReadAll(res.Body)
			assertNoError(t, err)
			assertEqual(t, http.StatusTooManyRequests, res.StatusCode)
			assertEqual(t, errTimedOut, strings.TrimSpace(string(buf)))
		}(i)
	}

	wg.Wait()
}
// TestThrottleMaximum fills ThrottleBacklog(10, 10, 5s) with 20 requests,
// which are all expected to succeed, and then sends 20 more while the server
// is still busy. The second batch is expected to be rejected with
// StatusTooManyRequests and the errCapacityExceeded message.
func TestThrottleMaximum(t *testing.T) {
	r := chi.NewRouter()

	r.Use(ThrottleBacklog(10, 10, time.Second*5))

	r.Get("/", func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
		time.Sleep(time.Second * 3) // Expensive operation.
		w.Write(testContent)
	})

	server := httptest.NewServer(r)
	defer server.Close()

	client := http.Client{
		Timeout: time.Second * 60, // Maximum waiting time.
	}

	var wg sync.WaitGroup

	for i := 0; i < 20; i++ {
		wg.Add(1)
		go func(i int) {
			defer wg.Done()

			res, err := client.Get(server.URL)
			assertNoError(t, err)
			if err != nil {
				// res is nil on error; nothing further to inspect.
				return
			}
			// Release the body so the connection is not leaked.
			defer res.Body.Close()

			assertEqual(t, http.StatusOK, res.StatusCode)

			buf, err := ioutil.ReadAll(res.Body)
			assertNoError(t, err)
			assertEqual(t, testContent, buf)
		}(i)
	}

	// Wait less time than what the server takes to reply.
	time.Sleep(time.Second * 2)

	// At this point the server is still processing, all the following requests
	// will be beyond the server capacity.
	for i := 0; i < 20; i++ {
		wg.Add(1)
		go func(i int) {
			defer wg.Done()

			res, err := client.Get(server.URL)
			assertNoError(t, err)
			if err != nil {
				return
			}
			defer res.Body.Close()

			buf, err := ioutil.ReadAll(res.Body)
			assertNoError(t, err)
			assertEqual(t, http.StatusTooManyRequests, res.StatusCode)
			assertEqual(t, errCapacityExceeded, strings.TrimSpace(string(buf)))
		}(i)
	}

	wg.Wait()
}
// NOTE: test is disabled as it requires some refactoring. It is prone to intermittent failure.
/*func TestThrottleRetryAfter(t *testing.T) {
r := chi.NewRouter()
retryAfterFn := func(ctxDone bool) time.Duration { return time.Hour * 1 }
r.Use(ThrottleWithOpts(ThrottleOpts{Limit: 10, RetryAfterFn: retryAfterFn}))
r.Get("/", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
time.Sleep(time.Second * 4) // Expensive operation.
w.Write(testContent)
})
server := httptest.NewServer(r)
defer server.Close()
client := http.Client{
Timeout: time.Second * 60, // Maximum waiting time.
}
var wg sync.WaitGroup
for i := 0; i < 10; i++ {
wg.Add(1)
go func(i int) {
defer wg.Done()
res, err := client.Get(server.URL)
assertNoError(t, err)
assertEqual(t, http.StatusOK, res.StatusCode)
}(i)
}
time.Sleep(time.Second * 1)
for i := 0; i < 10; i++ {
wg.Add(1)
go func(i int) {
defer wg.Done()
res, err := client.Get(server.URL)
assertNoError(t, err)
assertEqual(t, http.StatusTooManyRequests, res.StatusCode)
assertEqual(t, res.Header.Get("Retry-After"), "3600")
}(i)
}
wg.Wait()
}*/
chi-5.0.7/middleware/timeout.go 0000664 0000000 0000000 00000002325 14145546033 0016427 0 ustar 00root root 0000000 0000000 package middleware
import (
"context"
"net/http"
"time"
)
// Timeout is a middleware that attaches a deadline of `timeout` to the request
// context and, once the downstream handler has returned, writes a
// 504 Gateway Timeout status if that deadline was exceeded.
//
// The handler itself must watch ctx.Done() and return when it fires; a handler
// that never checks the context will simply run to completion and the timeout
// signal is ignored.
//
// ie. a route/handler may look like:
//
//  r.Get("/long", func(w http.ResponseWriter, r *http.Request) {
// 	 ctx := r.Context()
// 	 processTime := time.Duration(rand.Intn(4)+1) * time.Second
//
// 	 select {
// 	 case <-ctx.Done():
// 	 	return
//
// 	 case <-time.After(processTime):
// 	 	 // The above channel simulates some hard work.
// 	 }
//
// 	 w.Write([]byte("done"))
//  })
//
func Timeout(timeout time.Duration) func(next http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			ctx, cancel := context.WithTimeout(r.Context(), timeout)
			defer func() {
				// Release the timer first; an already-expired context keeps
				// reporting DeadlineExceeded after cancel().
				cancel()
				if ctx.Err() == context.DeadlineExceeded {
					w.WriteHeader(http.StatusGatewayTimeout)
				}
			}()

			next.ServeHTTP(w, r.WithContext(ctx))
		})
	}
}
chi-5.0.7/middleware/url_format.go 0000664 0000000 0000000 00000003325 14145546033 0017114 0 ustar 00root root 0000000 0000000 package middleware
import (
"context"
"net/http"
"strings"
"github.com/go-chi/chi/v5"
)
var (
// URLFormatCtxKey is the context.Context key under which the URLFormat
// middleware stores the URL format data (the path extension, as a string)
// for a request.
URLFormatCtxKey = &contextKey{"URLFormat"}
)
// URLFormat is a middleware that parses the url extension from a request path and stores it
// on the context as a string under the key `middleware.URLFormatCtxKey`. The middleware will
// trim the suffix from the routing path and continue routing.
//
// When a chi routing context already carries a non-empty RoutePath (for
// example inside a mounted subrouter) that path is used instead of the full
// request URL path, so routing continues relative to the subrouter. If no chi
// routing context is present the format is still stored but no routing path
// is rewritten.
//
// Routers should not include a url parameter for the suffix when using this middleware.
//
// Sample usage.. for url paths: `/articles/1`, `/articles/1.json` and `/articles/1.xml`
//
//  func routes() http.Handler {
//    r := chi.NewRouter()
//    r.Use(middleware.URLFormat)
//
//    r.Get("/articles/{id}", ListArticles)
//
//    return r
//  }
//
//  func ListArticles(w http.ResponseWriter, r *http.Request) {
// 	  urlFormat, _ := r.Context().Value(middleware.URLFormatCtxKey).(string)
//
// 	  switch urlFormat {
// 	  case "json":
// 	  	render.JSON(w, r, articles)
// 	  case "xml":
// 	  	render.XML(w, r, articles)
// 	  default:
// 	  	render.JSON(w, r, articles)
// 	  }
// }
//
func URLFormat(next http.Handler) http.Handler {
	fn := func(w http.ResponseWriter, r *http.Request) {
		ctx := r.Context()

		var format string
		path := r.URL.Path

		// Prefer the routing path a parent router left for us; falling back to
		// the full URL path would make a subrouter match against the wrong path.
		rctx := chi.RouteContext(ctx)
		if rctx != nil && rctx.RoutePath != "" {
			path = rctx.RoutePath
		}

		if strings.Index(path, ".") > 0 {
			// Only the last path segment may carry the extension.
			base := strings.LastIndex(path, "/")
			if base < 0 {
				// No slash at all: treat the whole path as the last segment
				// rather than slicing with a negative index.
				base = 0
			}
			idx := strings.LastIndex(path[base:], ".")

			if idx > 0 {
				idx += base
				format = path[idx+1:]

				if rctx != nil {
					rctx.RoutePath = path[:idx]
				}
			}
		}

		r = r.WithContext(context.WithValue(ctx, URLFormatCtxKey, format))

		next.ServeHTTP(w, r)
	}
	return http.HandlerFunc(fn)
}
chi-5.0.7/middleware/url_format_test.go 0000664 0000000 0000000 00000002261 14145546033 0020151 0 ustar 00root root 0000000 0000000 package middleware
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/go-chi/chi/v5"
)
// TestURLFormat verifies that URLFormat strips the trailing extension from the
// routing path so that the {articleID} parameter resolves to "1" for both
// plain and dotted route patterns.
func TestURLFormat(t *testing.T) {
	r := chi.NewRouter()

	r.Use(URLFormat)

	r.NotFound(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(404)
		w.Write([]byte("nothing here"))
	})

	r.Route("/samples/articles/samples.{articleID}", func(r chi.Router) {
		r.Get("/", func(w http.ResponseWriter, r *http.Request) {
			articleID := chi.URLParam(r, "articleID")
			w.Write([]byte(articleID))
		})
	})

	r.Route("/articles/{articleID}", func(r chi.Router) {
		r.Get("/", func(w http.ResponseWriter, r *http.Request) {
			articleID := chi.URLParam(r, "articleID")
			w.Write([]byte(articleID))
		})
	})

	ts := httptest.NewServer(r)
	defer ts.Close()

	paths := []string{
		"/articles/1.json",
		"/articles/1.xml",
		"/samples/articles/samples.1.json",
		"/samples/articles/samples.1.xml",
	}
	for _, path := range paths {
		if _, resp := testRequest(t, ts, "GET", path, nil); resp != "1" {
			// Never pass the response body as the format string: it may
			// contain '%' verbs.
			t.Fatalf("GET %s: got body %q, want %q", path, resp, "1")
		}
	}
}
chi-5.0.7/middleware/value.go 0000664 0000000 0000000 00000000664 14145546033 0016061 0 ustar 00root root 0000000 0000000 package middleware
import (
"context"
"net/http"
)
// WithValue returns a middleware that stores val under key in the request
// context before handing the request to the next handler.
func WithValue(key, val interface{}) func(next http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			ctx := context.WithValue(r.Context(), key, val)
			next.ServeHTTP(w, r.WithContext(ctx))
		})
	}
}
chi-5.0.7/middleware/wrap_writer.go 0000664 0000000 0000000 00000012501 14145546033 0017303 0 ustar 00root root 0000000 0000000 package middleware
// The original work was derived from Goji's middleware, source:
// https://github.com/zenazn/goji/tree/master/web/middleware
import (
"bufio"
"io"
"net"
"net/http"
)
// NewWrapResponseWriter wraps an http.ResponseWriter, returning a proxy that allows you to
// hook into various parts of the response process.
//
// The concrete proxy is chosen from the optional interfaces w implements
// (http.Flusher, http.Pusher, http.Hijacker, io.ReaderFrom) so that the proxy
// exposes as many of them as it can. protoMajor selects between the HTTP/2
// variant (protoMajor == 2) and the HTTP/1.x variants.
func NewWrapResponseWriter(w http.ResponseWriter, protoMajor int) WrapResponseWriter {
	_, canFlush := w.(http.Flusher)

	base := basicWriter{ResponseWriter: w}

	if protoMajor == 2 {
		if _, canPush := w.(http.Pusher); canFlush && canPush {
			return &http2FancyWriter{base}
		}
	} else {
		_, canHijack := w.(http.Hijacker)
		_, canReadFrom := w.(io.ReaderFrom)
		switch {
		case canFlush && canHijack && canReadFrom:
			return &httpFancyWriter{base}
		case canFlush && canHijack:
			return &flushHijackWriter{base}
		case canHijack:
			return &hijackWriter{base}
		}
	}

	if canFlush {
		return &flushWriter{base}
	}

	return &base
}
// WrapResponseWriter is a proxy around an http.ResponseWriter that allows you to hook
// into various parts of the response process. Values are obtained from
// NewWrapResponseWriter.
type WrapResponseWriter interface {
http.ResponseWriter
// Status returns the HTTP status of the response, or 0 if one has not
// yet been sent.
Status() int
// BytesWritten returns the total number of bytes sent to the client.
BytesWritten() int
// Tee causes the response body to be written to the given io.Writer in
// addition to proxying the writes through. Only one io.Writer can be
// tee'd to at once: setting a second one will overwrite the first.
// Writes will be sent to the proxy before being written to this
// io.Writer. It is illegal for the tee'd writer to be modified
// concurrently with writes.
Tee(io.Writer)
// Unwrap returns the original proxied target.
Unwrap() http.ResponseWriter
}
// basicWriter wraps a http.ResponseWriter that implements the minimal
// http.ResponseWriter interface. It records the status code and the number of
// body bytes written, and can mirror the body to a tee writer.
type basicWriter struct {
http.ResponseWriter
// wroteHeader is set once WriteHeader has run (or a Flush has happened on
// one of the flushing wrappers); later WriteHeader calls are ignored.
wroteHeader bool
// code is the status passed to the first WriteHeader call; 0 until then.
code int
// bytes accumulates the byte counts returned by the wrapped writer.
bytes int
// tee, when non-nil, receives a copy of every successful body write.
tee io.Writer
}
// WriteHeader records code and forwards it to the wrapped writer. Only the
// first call has any effect; subsequent calls are silently dropped.
func (b *basicWriter) WriteHeader(code int) {
	if b.wroteHeader {
		return
	}
	b.code, b.wroteHeader = code, true
	b.ResponseWriter.WriteHeader(code)
}
// Write sends buf to the wrapped writer (emitting a 200 header first if none
// was written), mirrors the bytes that were accepted to the tee writer when
// one is set, and adds the accepted count to the running byte total.
//
// The returned error is the wrapped writer's error if there was one,
// otherwise the tee writer's error.
func (b *basicWriter) Write(buf []byte) (int, error) {
	b.maybeWriteHeader()
	written, err := b.ResponseWriter.Write(buf)
	if tee := b.tee; tee != nil {
		// Prefer errors generated by the proxied writer.
		if _, teeErr := tee.Write(buf[:written]); err == nil {
			err = teeErr
		}
	}
	b.bytes += written
	return written, err
}
// maybeWriteHeader emits an implicit 200 OK header when no header has been
// written yet.
func (b *basicWriter) maybeWriteHeader() {
	if b.wroteHeader {
		return
	}
	b.WriteHeader(http.StatusOK)
}
// Status returns the status code recorded by WriteHeader, or 0 if WriteHeader
// has not been called yet.
func (b *basicWriter) Status() int {
return b.code
}
// BytesWritten returns the running total of body bytes accounted for by this
// writer.
func (b *basicWriter) BytesWritten() int {
return b.bytes
}
// Tee sets w as the writer that receives a copy of the response body,
// replacing any previously set tee writer.
func (b *basicWriter) Tee(w io.Writer) {
b.tee = w
}
// Unwrap returns the http.ResponseWriter this proxy was created around.
func (b *basicWriter) Unwrap() http.ResponseWriter {
return b.ResponseWriter
}
// flushWriter is a basicWriter that additionally exposes http.Flusher. It is
// selected by NewWrapResponseWriter when the wrapped writer can flush but none
// of the richer variants apply.
type flushWriter struct {
	basicWriter
}

// Flush marks the header as written and flushes the wrapped writer.
func (f *flushWriter) Flush() {
	f.wroteHeader = true
	f.basicWriter.ResponseWriter.(http.Flusher).Flush()
}

var _ http.Flusher = &flushWriter{}
// hijackWriter is a basicWriter that additionally exposes http.Hijacker.
type hijackWriter struct {
	basicWriter
}

// Hijack delegates to the wrapped writer's http.Hijacker implementation.
func (f *hijackWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
	return f.basicWriter.ResponseWriter.(http.Hijacker).Hijack()
}

var _ http.Hijacker = &hijackWriter{}
// flushHijackWriter is a basicWriter that additionally exposes both
// http.Flusher and http.Hijacker.
type flushHijackWriter struct {
	basicWriter
}

// Flush marks the header as written and flushes the wrapped writer.
func (f *flushHijackWriter) Flush() {
	f.wroteHeader = true
	f.basicWriter.ResponseWriter.(http.Flusher).Flush()
}

// Hijack delegates to the wrapped writer's http.Hijacker implementation.
func (f *flushHijackWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
	return f.basicWriter.ResponseWriter.(http.Hijacker).Hijack()
}

var _ http.Flusher = &flushHijackWriter{}
var _ http.Hijacker = &flushHijackWriter{}
// httpFancyWriter is a HTTP writer that additionally satisfies
// http.Flusher, http.Hijacker, and io.ReaderFrom. It exists for the common case
// of wrapping the http.ResponseWriter that package http gives you, in order to
// make the proxy support the full method set of the proxied object.
type httpFancyWriter struct {
basicWriter
}
// Flush marks the header as written and flushes the wrapped writer.
func (f *httpFancyWriter) Flush() {
	f.wroteHeader = true
	f.basicWriter.ResponseWriter.(http.Flusher).Flush()
}
// Hijack delegates to the wrapped writer's http.Hijacker implementation.
func (f *httpFancyWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
	return f.basicWriter.ResponseWriter.(http.Hijacker).Hijack()
}
// Push delegates to the wrapped writer's http.Pusher implementation.
func (f *http2FancyWriter) Push(target string, opts *http.PushOptions) error {
	pusher := f.basicWriter.ResponseWriter.(http.Pusher)
	return pusher.Push(target, opts)
}
// ReadFrom copies r into the response and returns the number of bytes copied
// together with any error encountered.
//
// When a tee writer is set, the data is routed through basicWriter.Write so
// the tee receives its copy. Write already adds every chunk to the byte
// counter, so the total must not be added a second time here. Without a tee,
// the wrapped writer's own io.ReaderFrom is used and the counter is updated
// from its result.
func (f *httpFancyWriter) ReadFrom(r io.Reader) (int64, error) {
	if f.basicWriter.tee != nil {
		// basicWriter.Write tracks bytes itself; adding n again would
		// double-count what BytesWritten reports.
		return io.Copy(&f.basicWriter, r)
	}
	rf := f.basicWriter.ResponseWriter.(io.ReaderFrom)
	// The wrapped ReadFrom bypasses basicWriter.Write, so make sure the
	// implicit 200 header is recorded first.
	f.basicWriter.maybeWriteHeader()
	n, err := rf.ReadFrom(r)
	f.basicWriter.bytes += int(n)
	return n, err
}
// Compile-time checks that the fancy writers implement the optional
// interfaces they are meant to expose.
var _ http.Flusher = &httpFancyWriter{}
var _ http.Hijacker = &httpFancyWriter{}
var _ http.Pusher = &http2FancyWriter{}
var _ io.ReaderFrom = &httpFancyWriter{}
// http2FancyWriter is a HTTP2 writer that additionally satisfies
// http.Flusher and http.Pusher. It exists for the common case
// of wrapping the http.ResponseWriter that package http gives you, in order to
// make the proxy support the full method set of the proxied object.
type http2FancyWriter struct {
	basicWriter
}

// Flush marks the header as written and flushes the wrapped writer.
func (f *http2FancyWriter) Flush() {
	f.wroteHeader = true
	f.basicWriter.ResponseWriter.(http.Flusher).Flush()
}

var _ http.Flusher = &http2FancyWriter{}
chi-5.0.7/middleware/wrap_writer_test.go 0000664 0000000 0000000 00000001056 14145546033 0020345 0 ustar 00root root 0000000 0000000 package middleware
import (
"net/http/httptest"
"testing"
)
// TestHttpFancyWriterRemembersWroteHeaderWhenFlushed checks that flushing an
// httpFancyWriter marks its header as written.
func TestHttpFancyWriterRemembersWroteHeaderWhenFlushed(t *testing.T) {
	fw := &httpFancyWriter{
		basicWriter: basicWriter{ResponseWriter: httptest.NewRecorder()},
	}

	fw.Flush()

	if fw.wroteHeader {
		return
	}
	t.Fatal("want Flush to have set wroteHeader=true")
}
// TestHttp2FancyWriterRemembersWroteHeaderWhenFlushed checks that flushing an
// http2FancyWriter marks its header as written.
func TestHttp2FancyWriterRemembersWroteHeaderWhenFlushed(t *testing.T) {
	fw := &http2FancyWriter{
		basicWriter{ResponseWriter: httptest.NewRecorder()},
	}

	fw.Flush()

	if fw.wroteHeader {
		return
	}
	t.Fatal("want Flush to have set wroteHeader=true")
}
chi-5.0.7/mux.go 0000664 0000000 0000000 00000036205 14145546033 0013441 0 ustar 00root root 0000000 0000000 package chi
import (
"context"
"fmt"
"net/http"
"strings"
"sync"
)
// Compile-time check that *Mux implements the Router interface.
var _ Router = &Mux{}
// Mux is a simple HTTP route multiplexer that parses a request path,
// records any URL params, and executes an end handler. It implements
// the http.Handler interface and is friendly with the standard library.
//
// Mux is designed to be fast, minimal and offer a powerful API for building
// modular and composable HTTP services with a large set of handlers. It's
// particularly useful for writing large REST API services that break a handler
// into many smaller parts composed of middlewares and end handlers.
type Mux struct {
// The computed mux handler made of the chained middleware stack and
// the tree router
handler http.Handler
// The radix trie router
tree *node
// Custom method not allowed handler
methodNotAllowedHandler http.HandlerFunc
// Controls the behaviour of middleware chain generation when a mux
// is registered as an inline group inside another mux.
parent *Mux
// Routing context pool
pool *sync.Pool
// Custom route not found handler
notFoundHandler http.HandlerFunc
// The middleware stack
middlewares []func(http.Handler) http.Handler
// inline is true for muxes created by With (and therefore Group); such
// muxes share the parent's tree and wrap each endpoint with their own
// middlewares at registration time.
inline bool
}
// NewMux returns a newly initialized Mux object that implements the Router
// interface. The mux starts with an empty routing tree and a pool that creates
// fresh routing contexts on demand.
func NewMux() *Mux {
	return &Mux{
		tree: &node{},
		pool: &sync.Pool{
			New: func() interface{} { return NewRouteContext() },
		},
	}
}
// ServeHTTP is the single method of the http.Handler interface that makes
// Mux interoperable with the standard library. It uses a sync.Pool to get and
// reuse routing contexts for each request.
func (mx *Mux) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	// No computed handler means no routes were registered on this mux.
	if mx.handler == nil {
		mx.NotFoundHandler().ServeHTTP(w, r)
		return
	}

	// A parent router already attached a routing context: serve with it as is.
	if inherited, _ := r.Context().Value(RouteCtxKey).(*Context); inherited != nil {
		mx.handler.ServeHTTP(w, r)
		return
	}

	// Otherwise take a routing context from the pool, attach it to the request
	// and hand it back once the request has been served.
	parent := r.Context()

	rctx := mx.pool.Get().(*Context)
	rctx.Reset()
	rctx.Routes = mx
	rctx.parentCtx = parent

	// NOTE: r.WithContext() causes 2 allocations and context.WithValue() causes 1 allocation
	r = r.WithContext(context.WithValue(parent, RouteCtxKey, rctx))

	mx.handler.ServeHTTP(w, r)
	mx.pool.Put(rctx)
}
// Use appends a middleware handler to the Mux middleware stack.
//
// The middleware stack for any Mux will execute before searching for a matching
// route to a specific handler, which provides opportunity to respond early,
// change the course of the request execution, or set request-scoped values for
// the next http.Handler.
//
// Use panics if the mux handler has already been computed, i.e. once routes
// have been registered on this mux.
func (mx *Mux) Use(middlewares ...func(http.Handler) http.Handler) {
if mx.handler != nil {
panic("chi: all middlewares must be defined before routes on a mux")
}
mx.middlewares = append(mx.middlewares, middlewares...)
}
// Handle adds the route `pattern` that matches any http method to
// execute the `handler` http.Handler. It panics if `pattern` does not
// begin with '/'.
func (mx *Mux) Handle(pattern string, handler http.Handler) {
mx.handle(mALL, pattern, handler)
}
// HandleFunc adds the route `pattern` that matches any http method to
// execute the `handlerFn` http.HandlerFunc. It panics if `pattern` does not
// begin with '/'.
func (mx *Mux) HandleFunc(pattern string, handlerFn http.HandlerFunc) {
mx.handle(mALL, pattern, handlerFn)
}
// Method adds the route `pattern` that matches `method` http method to
// execute the `handler` http.Handler. The method name is matched
// case-insensitively; an unsupported method causes a panic.
func (mx *Mux) Method(method, pattern string, handler http.Handler) {
	if m, ok := methodMap[strings.ToUpper(method)]; ok {
		mx.handle(m, pattern, handler)
		return
	}
	panic(fmt.Sprintf("chi: '%s' http method is not supported.", method))
}
// MethodFunc adds the route `pattern` that matches `method` http method to
// execute the `handlerFn` http.HandlerFunc. It is the http.HandlerFunc
// counterpart of Method and panics under the same conditions.
func (mx *Mux) MethodFunc(method, pattern string, handlerFn http.HandlerFunc) {
mx.Method(method, pattern, handlerFn)
}
// Connect adds the route `pattern` that matches a CONNECT http method to
// execute the `handlerFn` http.HandlerFunc. It panics if `pattern` does not
// begin with '/'.
func (mx *Mux) Connect(pattern string, handlerFn http.HandlerFunc) {
mx.handle(mCONNECT, pattern, handlerFn)
}
// Delete adds the route `pattern` that matches a DELETE http method to
// execute the `handlerFn` http.HandlerFunc. It panics if `pattern` does not
// begin with '/'.
func (mx *Mux) Delete(pattern string, handlerFn http.HandlerFunc) {
mx.handle(mDELETE, pattern, handlerFn)
}
// Get adds the route `pattern` that matches a GET http method to
// execute the `handlerFn` http.HandlerFunc. It panics if `pattern` does not
// begin with '/'.
func (mx *Mux) Get(pattern string, handlerFn http.HandlerFunc) {
mx.handle(mGET, pattern, handlerFn)
}
// Head adds the route `pattern` that matches a HEAD http method to
// execute the `handlerFn` http.HandlerFunc. It panics if `pattern` does not
// begin with '/'.
func (mx *Mux) Head(pattern string, handlerFn http.HandlerFunc) {
mx.handle(mHEAD, pattern, handlerFn)
}
// Options adds the route `pattern` that matches an OPTIONS http method to
// execute the `handlerFn` http.HandlerFunc. It panics if `pattern` does not
// begin with '/'.
func (mx *Mux) Options(pattern string, handlerFn http.HandlerFunc) {
mx.handle(mOPTIONS, pattern, handlerFn)
}
// Patch adds the route `pattern` that matches a PATCH http method to
// execute the `handlerFn` http.HandlerFunc. It panics if `pattern` does not
// begin with '/'.
func (mx *Mux) Patch(pattern string, handlerFn http.HandlerFunc) {
mx.handle(mPATCH, pattern, handlerFn)
}
// Post adds the route `pattern` that matches a POST http method to
// execute the `handlerFn` http.HandlerFunc. It panics if `pattern` does not
// begin with '/'.
func (mx *Mux) Post(pattern string, handlerFn http.HandlerFunc) {
mx.handle(mPOST, pattern, handlerFn)
}
// Put adds the route `pattern` that matches a PUT http method to
// execute the `handlerFn` http.HandlerFunc. It panics if `pattern` does not
// begin with '/'.
func (mx *Mux) Put(pattern string, handlerFn http.HandlerFunc) {
mx.handle(mPUT, pattern, handlerFn)
}
// Trace adds the route `pattern` that matches a TRACE http method to
// execute the `handlerFn` http.HandlerFunc. It panics if `pattern` does not
// begin with '/'.
func (mx *Mux) Trace(pattern string, handlerFn http.HandlerFunc) {
mx.handle(mTRACE, pattern, handlerFn)
}
// NotFound sets a custom http.HandlerFunc for routing paths that could
// not be found. The default 404 handler is `http.NotFound`.
//
// On an inline mux that has a parent, the handler is wrapped with the inline
// middlewares and installed on the parent instead. Subrouters that have no
// not-found handler of their own receive the same handler.
func (mx *Mux) NotFound(handlerFn http.HandlerFunc) {
	target, notFound := mx, handlerFn
	if mx.inline && mx.parent != nil {
		target = mx.parent
		notFound = Chain(mx.middlewares...).HandlerFunc(handlerFn).ServeHTTP
	}

	// From here on this is the handler used by the target mux...
	target.notFoundHandler = notFound

	// ...and by every subrouter that has not set one explicitly.
	target.updateSubRoutes(func(sub *Mux) {
		if sub.notFoundHandler == nil {
			sub.NotFound(notFound)
		}
	})
}
// MethodNotAllowed sets a custom http.HandlerFunc for routing paths where the
// method is unresolved. The default handler returns a 405 with an empty body.
//
// On an inline mux that has a parent, the handler is wrapped with the inline
// middlewares and installed on the parent instead. Subrouters that have no
// method-not-allowed handler of their own receive the same handler.
func (mx *Mux) MethodNotAllowed(handlerFn http.HandlerFunc) {
	target, notAllowed := mx, handlerFn
	if mx.inline && mx.parent != nil {
		target = mx.parent
		notAllowed = Chain(mx.middlewares...).HandlerFunc(handlerFn).ServeHTTP
	}

	// From here on this is the handler used by the target mux...
	target.methodNotAllowedHandler = notAllowed

	// ...and by every subrouter that has not set one explicitly.
	target.updateSubRoutes(func(sub *Mux) {
		if sub.methodNotAllowedHandler == nil {
			sub.MethodNotAllowed(notAllowed)
		}
	})
}
// With adds inline middlewares for an endpoint handler. It returns a new
// inline mux that shares this mux's routing tree and context pool and carries
// the given middlewares (preceded by this mux's own when it is itself inline).
func (mx *Mux) With(middlewares ...func(http.Handler) http.Handler) Router {
	// As in handle(), the mux handler must be built now: once an inline mux
	// exists, no further middleware may be registered on this stack.
	if mx.handler == nil && !mx.inline {
		mx.updateRouteHandler()
	}

	// An inline parent passes its middlewares down; copy them so the child
	// never aliases the parent's slice.
	var stack Middlewares
	if mx.inline {
		stack = make(Middlewares, len(mx.middlewares))
		copy(stack, mx.middlewares)
	}
	stack = append(stack, middlewares...)

	return &Mux{
		pool:                    mx.pool,
		inline:                  true,
		parent:                  mx,
		tree:                    mx.tree,
		middlewares:             stack,
		notFoundHandler:         mx.notFoundHandler,
		methodNotAllowedHandler: mx.methodNotAllowedHandler,
	}
}
// Group creates a new inline-Mux with a fresh middleware stack. It's useful
// for a group of handlers along the same routing path that use an additional
// set of middlewares. See _examples/.
//
// The inline mux is created via With() with no extra middlewares, passed to
// fn when fn is non-nil, and returned.
func (mx *Mux) Group(fn func(r Router)) Router {
im := mx.With().(*Mux)
if fn != nil {
fn(im)
}
return im
}
// Route creates a new Mux with a fresh middleware stack and mounts it
// along the `pattern` as a subrouter. Effectively, this is a short-hand
// call to Mount. See _examples/.
//
// Route panics if fn is nil. fn runs before the subrouter is mounted, so all
// of its routes are registered by the time Mount sees it.
func (mx *Mux) Route(pattern string, fn func(r Router)) Router {
if fn == nil {
panic(fmt.Sprintf("chi: attempting to Route() a nil subrouter on '%s'", pattern))
}
subRouter := NewRouter()
fn(subRouter)
mx.Mount(pattern, subRouter)
return subRouter
}
// Mount attaches another http.Handler or chi Router as a subrouter along a routing
// path. It's very useful to split up a large API as many independent routers and
// compose them as a single service using Mount. See _examples/.
//
// Note that Mount() simply sets a wildcard along the `pattern` that will continue
// routing at the `handler`, which in most cases is another chi.Router. As a result,
// if you define two Mount() routes on the exact same pattern the mount will panic.
//
// Mount also panics when `handler` is nil.
func (mx *Mux) Mount(pattern string, handler http.Handler) {
if handler == nil {
panic(fmt.Sprintf("chi: attempting to Mount() a nil handler on '%s'", pattern))
}
// Provide runtime safety for ensuring a pattern isn't mounted on an existing
// routing pattern.
if mx.tree.findPattern(pattern+"*") || mx.tree.findPattern(pattern+"/*") {
panic(fmt.Sprintf("chi: attempting to Mount() a handler on an existing path, '%s'", pattern))
}
// Assign sub-Router's with the parent not found & method not allowed handler if not specified.
subr, ok := handler.(*Mux)
if ok && subr.notFoundHandler == nil && mx.notFoundHandler != nil {
subr.NotFound(mx.notFoundHandler)
}
if ok && subr.methodNotAllowedHandler == nil && mx.methodNotAllowedHandler != nil {
subr.MethodNotAllowed(mx.methodNotAllowedHandler)
}
// mountHandler is what actually gets registered in this mux's tree: it
// rewrites the routing path for the mounted handler and then delegates.
mountHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
rctx := RouteContext(r.Context())
// shift the url path past the previous subrouter
rctx.RoutePath = mx.nextRoutePath(rctx)
// reset the wildcard URLParam which connects the subrouter
n := len(rctx.URLParams.Keys) - 1
if n >= 0 && rctx.URLParams.Keys[n] == "*" && len(rctx.URLParams.Values) > n {
rctx.URLParams.Values[n] = ""
}
handler.ServeHTTP(w, r)
})
// Without a trailing slash, register both the bare pattern and its
// slash-terminated form, then continue with the slash-terminated form for
// the wildcard route below.
if pattern == "" || pattern[len(pattern)-1] != '/' {
mx.handle(mALL|mSTUB, pattern, mountHandler)
mx.handle(mALL|mSTUB, pattern+"/", mountHandler)
pattern += "/"
}
// Handlers that implement Routes are additionally flagged with mSTUB and
// recorded on the wildcard node as its subroutes.
method := mALL
subroutes, _ := handler.(Routes)
if subroutes != nil {
method |= mSTUB
}
n := mx.handle(method, pattern+"*", mountHandler)
if subroutes != nil {
n.subroutes = subroutes
}
}
// Routes returns a slice of routing information from the tree,
// useful for traversing available routes of a router. The slice is produced
// by the routing tree's routes() method.
func (mx *Mux) Routes() []Route {
return mx.tree.routes()
}
// Middlewares returns a slice of middleware handler functions. The returned
// value is the mux's own middleware slice, not a copy.
func (mx *Mux) Middlewares() Middlewares {
return mx.middlewares
}
// Match searches the routing tree for a handler that matches the method/path.
// It's similar to routing a http request, but without executing the handler
// thereafter. An unknown method never matches. When the matched node carries
// subroutes, the search continues there with the remaining route path.
//
// Note: the *Context state is updated during execution, so manage
// the state carefully or make a NewRouteContext().
func (mx *Mux) Match(rctx *Context, method, path string) bool {
	m, known := methodMap[method]
	if !known {
		return false
	}

	found, _, h := mx.tree.FindRoute(rctx, m, path)
	if found == nil || found.subroutes == nil {
		return h != nil
	}

	rctx.RoutePath = mx.nextRoutePath(rctx)
	return found.subroutes.Match(rctx, method, rctx.RoutePath)
}
// NotFoundHandler returns the handler used whenever a route cannot be found:
// the custom one when set, otherwise http.NotFound.
func (mx *Mux) NotFoundHandler() http.HandlerFunc {
	if custom := mx.notFoundHandler; custom != nil {
		return custom
	}
	return http.NotFound
}
// MethodNotAllowedHandler returns the handler used whenever a method cannot be
// resolved for a route: the custom one when set, otherwise the package's
// default 405 responder.
func (mx *Mux) MethodNotAllowedHandler() http.HandlerFunc {
	if custom := mx.methodNotAllowedHandler; custom != nil {
		return custom
	}
	return methodNotAllowedHandler
}
// handle registers a http.Handler in the routing tree for a particular http method
// and routing pattern, and returns the tree node the route was inserted at.
// It panics when the pattern does not begin with '/'.
func (mx *Mux) handle(method methodTyp, pattern string, handler http.Handler) *node {
	if pattern == "" || pattern[0] != '/' {
		panic(fmt.Sprintf("chi: routing pattern must begin with '/' in '%s'", pattern))
	}

	// First route on a regular mux: freeze the middleware stack into the
	// computed handler.
	if mx.handler == nil && !mx.inline {
		mx.updateRouteHandler()
	}

	// Inline muxes wrap each endpoint with their own middlewares; regular
	// muxes insert the handler untouched.
	endpoint := handler
	if mx.inline {
		mx.handler = http.HandlerFunc(mx.routeHTTP)
		endpoint = Chain(mx.middlewares...).Handler(handler)
	}

	return mx.tree.InsertRoute(method, pattern, endpoint)
}
// routeHTTP routes a http.Request through the Mux routing tree to serve
// the matching handler for a particular http method. Unsupported methods and
// method mismatches go to the method-not-allowed handler; everything else
// that fails to match goes to the not-found handler.
func (mx *Mux) routeHTTP(w http.ResponseWriter, r *http.Request) {
	rctx := r.Context().Value(RouteCtxKey).(*Context)

	// Resolve the routing path: an explicit RoutePath wins, then the raw
	// (escaped) URL path, then the decoded URL path, and finally "/".
	routePath := rctx.RoutePath
	if routePath == "" {
		routePath = r.URL.RawPath
		if routePath == "" {
			routePath = r.URL.Path
		}
		if routePath == "" {
			routePath = "/"
		}
	}

	// Resolve the routing method and make sure chi knows it.
	if rctx.RouteMethod == "" {
		rctx.RouteMethod = r.Method
	}
	method, supported := methodMap[rctx.RouteMethod]
	if !supported {
		mx.MethodNotAllowedHandler().ServeHTTP(w, r)
		return
	}

	_, _, h := mx.tree.FindRoute(rctx, method, routePath)
	switch {
	case h != nil:
		h.ServeHTTP(w, r)
	case rctx.methodNotAllowed:
		mx.MethodNotAllowedHandler().ServeHTTP(w, r)
	default:
		mx.NotFoundHandler().ServeHTTP(w, r)
	}
}
// nextRoutePath returns the routing path to continue with in a subrouter:
// "/" followed by the value of the trailing "*" route param, or just "/" when
// the last route param is not a wildcard (or has no value recorded).
func (mx *Mux) nextRoutePath(rctx *Context) string {
	keys, values := rctx.routeParams.Keys, rctx.routeParams.Values

	last := len(keys) - 1 // index of last param in list
	if last < 0 || keys[last] != "*" || len(values) <= last {
		return "/"
	}
	return "/" + values[last]
}
// updateSubRoutes calls fn for every route of this mux whose subroutes are a
// *Mux. Recursion into deeper levels happens when fn itself triggers another
// update on the sub-mux.
func (mx *Mux) updateSubRoutes(fn func(subMux *Mux)) {
	for _, route := range mx.tree.routes() {
		if sub, isMux := route.SubRoutes.(*Mux); isMux {
			fn(sub)
		}
	}
}
// updateRouteHandler builds the single mux handler that is a chain of the middleware
// stack, as defined by calls to Use(), and the tree router (Mux) itself. After this
// point, no other middlewares can be registered on this Mux's stack (Use panics
// once mx.handler is set). But you can still compose additional middlewares via
// Group()'s or using a chained middleware handler.
func (mx *Mux) updateRouteHandler() {
mx.handler = chain(mx.middlewares, http.HandlerFunc(mx.routeHTTP))
}
// methodNotAllowedHandler is the default responder for unresolved methods: it
// replies with 405 Method Not Allowed and an empty body.
func methodNotAllowedHandler(w http.ResponseWriter, _ *http.Request) {
	w.WriteHeader(http.StatusMethodNotAllowed)
	w.Write(nil)
}
chi-5.0.7/mux_test.go 0000664 0000000 0000000 00000144310 14145546033 0014475 0 ustar 00root root 0000000 0000000 package chi
import (
"bytes"
"context"
"fmt"
"io"
"io/ioutil"
"net"
"net/http"
"net/http/httptest"
"os"
"sync"
"testing"
"time"
)
// TestMuxBasic exercises the core routing surface of a Mux behind a real
// httptest server: a stack of middlewares, method-specific routes, URL
// parameters, re-registering a pattern, a catch-all route, and the 405 reply
// for a request using an unregistered HTTP method.
func TestMuxBasic(t *testing.T) {
	var count uint64
	countermw := func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			count++
			next.ServeHTTP(w, r)
		})
	}

	// usermw stores the user name that cxindex reads back from the context.
	usermw := func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			ctx := r.Context()
			ctx = context.WithValue(ctx, ctxKey{"user"}, "peter")
			r = r.WithContext(ctx)
			next.ServeHTTP(w, r)
		})
	}

	exmw := func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			ctx := context.WithValue(r.Context(), ctxKey{"ex"}, "a")
			r = r.WithContext(ctx)
			next.ServeHTTP(w, r)
		})
	}

	// logmw appends logmsg to logbuf on every request so the test can verify
	// that the middleware actually ran.
	logbuf := bytes.NewBufferString("")
	logmsg := "logmw test"
	logmw := func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			logbuf.WriteString(logmsg)
			next.ServeHTTP(w, r)
		})
	}

	cxindex := func(w http.ResponseWriter, r *http.Request) {
		ctx := r.Context()
		user := ctx.Value(ctxKey{"user"}).(string)
		w.WriteHeader(200)
		w.Write([]byte(fmt.Sprintf("hi %s", user)))
	}

	ping := func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(200)
		w.Write([]byte("."))
	}

	headPing := func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("X-Ping", "1")
		w.WriteHeader(200)
	}

	createPing := func(w http.ResponseWriter, r *http.Request) {
		// create ....
		w.WriteHeader(201)
	}

	pingAll := func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(200)
		w.Write([]byte("ping all"))
	}

	pingAll2 := func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(200)
		w.Write([]byte("ping all2"))
	}

	pingOne := func(w http.ResponseWriter, r *http.Request) {
		idParam := URLParam(r, "id")
		w.WriteHeader(200)
		w.Write([]byte(fmt.Sprintf("ping one id: %s", idParam)))
	}

	pingWoop := func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(200)
		w.Write([]byte("woop." + URLParam(r, "iidd")))
	}

	catchAll := func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(200)
		w.Write([]byte("catchall"))
	}

	m := NewRouter()
	m.Use(countermw)
	m.Use(usermw)
	m.Use(exmw)
	m.Use(logmw)
	m.Get("/", cxindex)
	m.Method("GET", "/ping", http.HandlerFunc(ping))
	m.MethodFunc("GET", "/pingall", pingAll)
	m.MethodFunc("get", "/ping/all", pingAll)
	m.Get("/ping/all2", pingAll2)

	m.Head("/ping", headPing)
	m.Post("/ping", createPing)
	m.Get("/ping/{id}", pingWoop)
	m.Get("/ping/{id}", pingOne) // expected to overwrite to pingOne handler
	m.Get("/ping/{iidd}/woop", pingWoop)
	m.HandleFunc("/admin/*", catchAll)
	// m.Post("/admin/*", catchAll)

	ts := httptest.NewServer(m)
	defer ts.Close()

	// Response bodies are reported with t.Fatal rather than t.Fatalf so that a
	// body containing '%' is never interpreted as a format string.

	// GET /
	if _, body := testRequest(t, ts, "GET", "/", nil); body != "hi peter" {
		t.Fatal(body)
	}
	tlogmsg, _ := logbuf.ReadString(0)
	if tlogmsg != logmsg {
		t.Error("expecting log message from middleware:", logmsg)
	}

	// GET /ping
	if _, body := testRequest(t, ts, "GET", "/ping", nil); body != "." {
		t.Fatal(body)
	}

	// GET /pingall
	if _, body := testRequest(t, ts, "GET", "/pingall", nil); body != "ping all" {
		t.Fatal(body)
	}

	// GET /ping/all
	if _, body := testRequest(t, ts, "GET", "/ping/all", nil); body != "ping all" {
		t.Fatal(body)
	}

	// GET /ping/all2
	if _, body := testRequest(t, ts, "GET", "/ping/all2", nil); body != "ping all2" {
		t.Fatal(body)
	}

	// GET /ping/123
	if _, body := testRequest(t, ts, "GET", "/ping/123", nil); body != "ping one id: 123" {
		t.Fatal(body)
	}

	// GET /ping/allan
	if _, body := testRequest(t, ts, "GET", "/ping/allan", nil); body != "ping one id: allan" {
		t.Fatal(body)
	}

	// GET /ping/1/woop
	if _, body := testRequest(t, ts, "GET", "/ping/1/woop", nil); body != "woop.1" {
		t.Fatal(body)
	}

	// HEAD /ping
	resp, err := http.Head(ts.URL + "/ping")
	if err != nil {
		t.Fatal(err)
	}
	// Only the status and headers are inspected; close the body right away so
	// the connection is released. resp is reassigned further down.
	resp.Body.Close()
	if resp.StatusCode != 200 {
		t.Error("head failed, should be 200")
	}
	if resp.Header.Get("X-Ping") == "" {
		t.Error("expecting X-Ping header")
	}

	// GET /admin/catch-this
	if _, body := testRequest(t, ts, "GET", "/admin/catch-thazzzzz", nil); body != "catchall" {
		t.Fatal(body)
	}

	// POST /admin/catch-this
	resp, err = http.Post(ts.URL+"/admin/casdfsadfs", "text/plain", bytes.NewReader([]byte{}))
	if err != nil {
		t.Fatal(err)
	}
	// Register the close before reading so it also runs if the read fails.
	defer resp.Body.Close()

	body, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		t.Fatal(err)
	}

	if resp.StatusCode != 200 {
		t.Error("POST failed, should be 200")
	}

	if string(body) != "catchall" {
		t.Error("expecting response body: 'catchall'")
	}

	// Custom http method DIE /ping/1/woop
	if resp, body := testRequest(t, ts, "DIE", "/ping/1/woop", nil); body != "" || resp.StatusCode != 405 {
		t.Fatalf("expecting 405 status and empty body, got %d '%s'", resp.StatusCode, body)
	}
}
// TestMuxMounts mounts a router that uses URL parameters and a nested Route
// under "/sharing" and checks that each level resolves to the expected
// handler with the expected parameter values.
func TestMuxMounts(t *testing.T) {
	r := NewRouter()

	r.Get("/{hash}", func(w http.ResponseWriter, r *http.Request) {
		v := URLParam(r, "hash")
		w.Write([]byte(fmt.Sprintf("/%s", v)))
	})

	r.Route("/{hash}/share", func(r Router) {
		r.Get("/", func(w http.ResponseWriter, r *http.Request) {
			v := URLParam(r, "hash")
			w.Write([]byte(fmt.Sprintf("/%s/share", v)))
		})
		r.Get("/{network}", func(w http.ResponseWriter, r *http.Request) {
			v := URLParam(r, "hash")
			n := URLParam(r, "network")
			w.Write([]byte(fmt.Sprintf("/%s/share/%s", v, n)))
		})
	})

	m := NewRouter()
	m.Mount("/sharing", r)

	ts := httptest.NewServer(m)
	defer ts.Close()

	// t.Fatal is used so the response body is never treated as a format string.
	if _, body := testRequest(t, ts, "GET", "/sharing/aBc", nil); body != "/aBc" {
		t.Fatal(body)
	}
	if _, body := testRequest(t, ts, "GET", "/sharing/aBc/share", nil); body != "/aBc/share" {
		t.Fatal(body)
	}
	if _, body := testRequest(t, ts, "GET", "/sharing/aBc/share/twitter", nil); body != "/aBc/share/twitter" {
		t.Fatal(body)
	}
}
// TestMuxPlain checks a router with a single GET route and a custom NotFound
// handler: the route is served and an unknown path gets the custom 404 body.
func TestMuxPlain(t *testing.T) {
	r := NewRouter()
	r.Get("/hi", func(w http.ResponseWriter, r *http.Request) {
		w.Write([]byte("bye"))
	})
	r.NotFound(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(404)
		w.Write([]byte("nothing here"))
	})

	ts := httptest.NewServer(r)
	defer ts.Close()

	// t.Fatal is used so the response body is never treated as a format string.
	if _, body := testRequest(t, ts, "GET", "/hi", nil); body != "bye" {
		t.Fatal(body)
	}
	if _, body := testRequest(t, ts, "GET", "/nothing-here", nil); body != "nothing here" {
		t.Fatal(body)
	}
}
// TestMuxEmptyRoutes verifies that routers without any registered route
// handlers answer with the default "404 page not found" body, both for the
// outer mux and for the empty sub-router it forwards "/api*" to.
func TestMuxEmptyRoutes(t *testing.T) {
	mux := NewRouter()

	apiRouter := NewRouter()
	// oops, we forgot to declare any route handlers

	mux.Handle("/api*", apiRouter)

	// t.Fatal is used so the response body is never treated as a format string.
	if _, body := testHandler(t, mux, "GET", "/", nil); body != "404 page not found\n" {
		t.Fatal(body)
	}

	if _, body := testHandler(t, apiRouter, "GET", "/", nil); body != "404 page not found\n" {
		t.Fatal(body)
	}
}
// Test a mux that routes a trailing slash, see also middleware/strip_test.go
// for an example of using a middleware to handle trailing slashes.
//
// The same handler is reachable through a mounted sub-router
// ("/accounts/{accountID}") and through an explicit trailing-slash route
// ("/accounts/{accountID}/"); both must see the accountID URL parameter.
func TestMuxTrailingSlash(t *testing.T) {
	r := NewRouter()
	r.NotFound(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(404)
		w.Write([]byte("nothing here"))
	})

	subRoutes := NewRouter()
	indexHandler := func(w http.ResponseWriter, r *http.Request) {
		accountID := URLParam(r, "accountID")
		w.Write([]byte(accountID))
	}
	subRoutes.Get("/", indexHandler)

	r.Mount("/accounts/{accountID}", subRoutes)
	r.Get("/accounts/{accountID}/", indexHandler)

	ts := httptest.NewServer(r)
	defer ts.Close()

	// t.Fatal is used so the response body is never treated as a format string.
	if _, body := testRequest(t, ts, "GET", "/accounts/admin", nil); body != "admin" {
		t.Fatal(body)
	}
	if _, body := testRequest(t, ts, "GET", "/accounts/admin/", nil); body != "admin" {
		t.Fatal(body)
	}
	if _, body := testRequest(t, ts, "GET", "/nothing-here", nil); body != "nothing here" {
		t.Fatal(body)
	}
}
// TestMuxNestedNotFound checks how NotFound handlers behave across mounted
// sub-routers: a sub-router with its own NotFound handler uses it, a
// sub-router without one falls back to the root's handler, and the
// middlewares attached via Use/With/Group are visible to the NotFound
// handlers through the request context.
func TestMuxNestedNotFound(t *testing.T) {
	r := NewRouter()

	r.Use(func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			r = r.WithContext(context.WithValue(r.Context(), ctxKey{"mw"}, "mw"))
			next.ServeHTTP(w, r)
		})
	})

	r.Get("/hi", func(w http.ResponseWriter, r *http.Request) {
		w.Write([]byte("bye"))
	})

	r.With(func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			r = r.WithContext(context.WithValue(r.Context(), ctxKey{"with"}, "with"))
			next.ServeHTTP(w, r)
		})
	}).NotFound(func(w http.ResponseWriter, r *http.Request) {
		chkMw := r.Context().Value(ctxKey{"mw"}).(string)
		chkWith := r.Context().Value(ctxKey{"with"}).(string)
		w.WriteHeader(404)
		w.Write([]byte(fmt.Sprintf("root 404 %s %s", chkMw, chkWith)))
	})

	sr1 := NewRouter()

	sr1.Get("/sub", func(w http.ResponseWriter, r *http.Request) {
		w.Write([]byte("sub"))
	})
	sr1.Group(func(sr1 Router) {
		sr1.Use(func(next http.Handler) http.Handler {
			return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				r = r.WithContext(context.WithValue(r.Context(), ctxKey{"mw2"}, "mw2"))
				next.ServeHTTP(w, r)
			})
		})
		sr1.NotFound(func(w http.ResponseWriter, r *http.Request) {
			chkMw2 := r.Context().Value(ctxKey{"mw2"}).(string)
			w.WriteHeader(404)
			w.Write([]byte(fmt.Sprintf("sub 404 %s", chkMw2)))
		})
	})

	sr2 := NewRouter()
	sr2.Get("/sub", func(w http.ResponseWriter, r *http.Request) {
		w.Write([]byte("sub2"))
	})

	r.Mount("/admin1", sr1)
	r.Mount("/admin2", sr2)

	ts := httptest.NewServer(r)
	defer ts.Close()

	// t.Fatal is used so the response body is never treated as a format string.
	if _, body := testRequest(t, ts, "GET", "/hi", nil); body != "bye" {
		t.Fatal(body)
	}
	if _, body := testRequest(t, ts, "GET", "/nothing-here", nil); body != "root 404 mw with" {
		t.Fatal(body)
	}
	if _, body := testRequest(t, ts, "GET", "/admin1/sub", nil); body != "sub" {
		t.Fatal(body)
	}
	if _, body := testRequest(t, ts, "GET", "/admin1/nope", nil); body != "sub 404 mw2" {
		t.Fatal(body)
	}
	if _, body := testRequest(t, ts, "GET", "/admin2/sub", nil); body != "sub2" {
		t.Fatal(body)
	}

	// Not found pages should bubble up to the root.
	if _, body := testRequest(t, ts, "GET", "/admin2/nope", nil); body != "root 404 mw with" {
		t.Fatal(body)
	}
}
// TestMuxNestedMethodNotAllowed checks MethodNotAllowed handlers across
// mounted sub-routers: sub-routers with their own handler use it, while a
// sub-router without one (sr2) is expected to answer with the root's handler.
func TestMuxNestedMethodNotAllowed(t *testing.T) {
	r := NewRouter()
	r.Get("/root", func(w http.ResponseWriter, r *http.Request) {
		w.Write([]byte("root"))
	})
	r.MethodNotAllowed(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(405)
		w.Write([]byte("root 405"))
	})

	sr1 := NewRouter()
	sr1.Get("/sub1", func(w http.ResponseWriter, r *http.Request) {
		w.Write([]byte("sub1"))
	})
	sr1.MethodNotAllowed(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(405)
		w.Write([]byte("sub1 405"))
	})

	sr2 := NewRouter()
	sr2.Get("/sub2", func(w http.ResponseWriter, r *http.Request) {
		w.Write([]byte("sub2"))
	})

	pathVar := NewRouter()
	pathVar.Get("/{var}", func(w http.ResponseWriter, r *http.Request) {
		w.Write([]byte("pv"))
	})
	pathVar.MethodNotAllowed(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(405)
		w.Write([]byte("pv 405"))
	})

	r.Mount("/prefix1", sr1)
	r.Mount("/prefix2", sr2)
	r.Mount("/pathVar", pathVar)

	ts := httptest.NewServer(r)
	defer ts.Close()

	// t.Fatal is used so the response body is never treated as a format string.
	if _, body := testRequest(t, ts, "GET", "/root", nil); body != "root" {
		t.Fatal(body)
	}
	if _, body := testRequest(t, ts, "PUT", "/root", nil); body != "root 405" {
		t.Fatal(body)
	}
	if _, body := testRequest(t, ts, "GET", "/prefix1/sub1", nil); body != "sub1" {
		t.Fatal(body)
	}
	if _, body := testRequest(t, ts, "PUT", "/prefix1/sub1", nil); body != "sub1 405" {
		t.Fatal(body)
	}
	if _, body := testRequest(t, ts, "GET", "/prefix2/sub2", nil); body != "sub2" {
		t.Fatal(body)
	}
	if _, body := testRequest(t, ts, "PUT", "/prefix2/sub2", nil); body != "root 405" {
		t.Fatal(body)
	}
	if _, body := testRequest(t, ts, "GET", "/pathVar/myvar", nil); body != "pv" {
		t.Fatal(body)
	}
	if _, body := testRequest(t, ts, "DELETE", "/pathVar/myvar", nil); body != "pv 405" {
		t.Fatal(body)
	}
}
// TestMuxComplicatedNotFound checks that a custom NotFound handler set on the
// root router is used at every nesting level (Route groups, mounted
// sub-routers, and a sub-router mounted through With), regardless of whether
// NotFound is registered before ("pre") or after ("post") the routes.
func TestMuxComplicatedNotFound(t *testing.T) {
	// decorateRouter registers the routes and sub-routers shared by both
	// sub-tests on r.
	decorateRouter := func(r *Mux) {
		// Root router with groups
		r.Get("/auth", func(w http.ResponseWriter, r *http.Request) {
			w.Write([]byte("auth get"))
		})
		r.Route("/public", func(r Router) {
			r.Get("/", func(w http.ResponseWriter, r *http.Request) {
				w.Write([]byte("public get"))
			})
		})

		// sub router with groups
		sub0 := NewRouter()
		sub0.Route("/resource", func(r Router) {
			r.Get("/", func(w http.ResponseWriter, r *http.Request) {
				w.Write([]byte("private get"))
			})
		})
		r.Mount("/private", sub0)

		// sub router with groups
		sub1 := NewRouter()
		sub1.Route("/resource", func(r Router) {
			r.Get("/", func(w http.ResponseWriter, r *http.Request) {
				w.Write([]byte("private get"))
			})
		})
		r.With(func(next http.Handler) http.Handler { return next }).Mount("/private_mw", sub1)
	}

	// testNotFound serves r and asserts both the regular routes and the custom
	// not-found responses. t.Fatal is used so the response body is never
	// treated as a format string.
	testNotFound := func(t *testing.T, r *Mux) {
		ts := httptest.NewServer(r)
		defer ts.Close()

		// check that we didn't break correct routes
		if _, body := testRequest(t, ts, "GET", "/auth", nil); body != "auth get" {
			t.Fatal(body)
		}
		if _, body := testRequest(t, ts, "GET", "/public", nil); body != "public get" {
			t.Fatal(body)
		}
		if _, body := testRequest(t, ts, "GET", "/public/", nil); body != "public get" {
			t.Fatal(body)
		}
		if _, body := testRequest(t, ts, "GET", "/private/resource", nil); body != "private get" {
			t.Fatal(body)
		}
		// check custom not-found on all levels
		if _, body := testRequest(t, ts, "GET", "/nope", nil); body != "custom not-found" {
			t.Fatal(body)
		}
		if _, body := testRequest(t, ts, "GET", "/public/nope", nil); body != "custom not-found" {
			t.Fatal(body)
		}
		if _, body := testRequest(t, ts, "GET", "/private/nope", nil); body != "custom not-found" {
			t.Fatal(body)
		}
		if _, body := testRequest(t, ts, "GET", "/private/resource/nope", nil); body != "custom not-found" {
			t.Fatal(body)
		}
		if _, body := testRequest(t, ts, "GET", "/private_mw/nope", nil); body != "custom not-found" {
			t.Fatal(body)
		}
		if _, body := testRequest(t, ts, "GET", "/private_mw/resource/nope", nil); body != "custom not-found" {
			t.Fatal(body)
		}
		// check custom not-found on trailing slash routes
		if _, body := testRequest(t, ts, "GET", "/auth/", nil); body != "custom not-found" {
			t.Fatal(body)
		}
	}

	t.Run("pre", func(t *testing.T) {
		r := NewRouter()
		r.NotFound(func(w http.ResponseWriter, r *http.Request) {
			w.Write([]byte("custom not-found"))
		})
		decorateRouter(r)
		testNotFound(t, r)
	})

	t.Run("post", func(t *testing.T) {
		r := NewRouter()
		decorateRouter(r)
		r.NotFound(func(w http.ResponseWriter, r *http.Request) {
			w.Write([]byte("custom not-found"))
		})
		testNotFound(t, r)
	})
}
// TestMuxWith checks inline middlewares attached with chained With calls:
// both middlewares must run for the inline route, each must be constructed
// exactly once, and each must handle exactly one request.
func TestMuxWith(t *testing.T) {
	var cmwInit1, cmwHandler1 uint64
	var cmwInit2, cmwHandler2 uint64
	mw1 := func(next http.Handler) http.Handler {
		cmwInit1++
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			cmwHandler1++
			r = r.WithContext(context.WithValue(r.Context(), ctxKey{"inline1"}, "yes"))
			next.ServeHTTP(w, r)
		})
	}
	mw2 := func(next http.Handler) http.Handler {
		cmwInit2++
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			cmwHandler2++
			r = r.WithContext(context.WithValue(r.Context(), ctxKey{"inline2"}, "yes"))
			next.ServeHTTP(w, r)
		})
	}

	r := NewRouter()
	r.Get("/hi", func(w http.ResponseWriter, r *http.Request) {
		w.Write([]byte("bye"))
	})
	r.With(mw1).With(mw2).Get("/inline", func(w http.ResponseWriter, r *http.Request) {
		v1 := r.Context().Value(ctxKey{"inline1"}).(string)
		v2 := r.Context().Value(ctxKey{"inline2"}).(string)
		w.Write([]byte(fmt.Sprintf("inline %s %s", v1, v2)))
	})

	ts := httptest.NewServer(r)
	defer ts.Close()

	// t.Fatal is used so the response body is never treated as a format string.
	if _, body := testRequest(t, ts, "GET", "/hi", nil); body != "bye" {
		t.Fatal(body)
	}
	if _, body := testRequest(t, ts, "GET", "/inline", nil); body != "inline yes yes" {
		t.Fatal(body)
	}
	if cmwInit1 != 1 {
		t.Fatalf("expecting cmwInit1 to be 1, got %d", cmwInit1)
	}
	if cmwHandler1 != 1 {
		t.Fatalf("expecting cmwHandler1 to be 1, got %d", cmwHandler1)
	}
	if cmwInit2 != 1 {
		t.Fatalf("expecting cmwInit2 to be 1, got %d", cmwInit2)
	}
	if cmwHandler2 != 1 {
		t.Fatalf("expecting cmwHandler2 to be 1, got %d", cmwHandler2)
	}
}
// TestRouterFromMuxWith serves the Router returned by With directly (rather
// than the Mux it was derived from) and issues one request against it.
func TestRouterFromMuxWith(t *testing.T) {
	t.Parallel()

	passthrough := func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			next.ServeHTTP(w, r)
		})
	}
	noop := func(w http.ResponseWriter, r *http.Request) {}

	inline := NewRouter().With(passthrough)
	inline.Get("/with_middleware", noop)

	srv := httptest.NewServer(inline)
	defer srv.Close()

	// Without the fix this test was committed with, this causes a panic.
	testRequest(t, srv, http.MethodGet, "/with_middleware", nil)
}
// TestMuxMiddlewareStack registers a stack of Use middlewares plus an inline
// With middleware and checks, after three requests to "/", how often the
// context middleware was constructed and invoked. It also checks that a
// middleware may answer a request ("/ping") itself without calling next.
func TestMuxMiddlewareStack(t *testing.T) {
	var stdmwInit, stdmwHandler uint64
	stdmw := func(next http.Handler) http.Handler {
		stdmwInit++
		return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
			stdmwHandler++
			next.ServeHTTP(w, req)
		})
	}

	// ctxmw publishes its own invocation count through the request context.
	var ctxmwInit, ctxmwHandler uint64
	ctxmw := func(next http.Handler) http.Handler {
		ctxmwInit++
		return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
			ctxmwHandler++
			withCount := context.WithValue(req.Context(), ctxKey{"count.ctxmwHandler"}, ctxmwHandler)
			next.ServeHTTP(w, req.WithContext(withCount))
		})
	}

	var inCtxmwInit, inCtxmwHandler uint64
	inCtxmw := func(next http.Handler) http.Handler {
		inCtxmwInit++
		return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
			inCtxmwHandler++
			next.ServeHTTP(w, req)
		})
	}

	// pongmw short-circuits requests for "/ping".
	pongmw := func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
			if req.URL.Path != "/ping" {
				next.ServeHTTP(w, req)
				return
			}
			w.Write([]byte("pong"))
		})
	}

	r := NewRouter()
	r.Use(stdmw)
	r.Use(ctxmw)
	r.Use(pongmw)

	var handlerCount uint64

	r.With(inCtxmw).Get("/", func(w http.ResponseWriter, req *http.Request) {
		handlerCount++
		seen := req.Context().Value(ctxKey{"count.ctxmwHandler"}).(uint64)
		w.Write([]byte(fmt.Sprintf("inits:%d reqs:%d ctxValue:%d", ctxmwInit, handlerCount, seen)))
	})

	r.Get("/hi", func(w http.ResponseWriter, req *http.Request) {
		w.Write([]byte("wooot"))
	})

	ts := httptest.NewServer(r)
	defer ts.Close()

	// Two warm-up requests; the third one is asserted.
	testRequest(t, ts, "GET", "/", nil)
	testRequest(t, ts, "GET", "/", nil)
	if _, body := testRequest(t, ts, "GET", "/", nil); body != "inits:1 reqs:3 ctxValue:3" {
		t.Fatalf("got: '%s'", body)
	}

	if _, body := testRequest(t, ts, "GET", "/ping", nil); body != "pong" {
		t.Fatalf("got: '%s'", body)
	}
}
// TestMuxRouteGroups registers two sibling Groups on the same router, each
// with its own middleware, and checks that each route is served and that each
// group's middleware was constructed once and invoked once.
func TestMuxRouteGroups(t *testing.T) {
	var stdmwInit, stdmwHandler uint64

	stdmw := func(next http.Handler) http.Handler {
		stdmwInit++
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			stdmwHandler++
			next.ServeHTTP(w, r)
		})
	}

	var stdmwInit2, stdmwHandler2 uint64
	stdmw2 := func(next http.Handler) http.Handler {
		stdmwInit2++
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			stdmwHandler2++
			next.ServeHTTP(w, r)
		})
	}

	r := NewRouter()
	r.Group(func(r Router) {
		r.Use(stdmw)
		r.Get("/group", func(w http.ResponseWriter, r *http.Request) {
			w.Write([]byte("root group"))
		})
	})
	r.Group(func(r Router) {
		r.Use(stdmw2)
		r.Get("/group2", func(w http.ResponseWriter, r *http.Request) {
			w.Write([]byte("root group2"))
		})
	})

	ts := httptest.NewServer(r)
	defer ts.Close()

	// GET /group
	_, body := testRequest(t, ts, "GET", "/group", nil)
	if body != "root group" {
		t.Fatalf("got: '%s'", body)
	}
	// Fail the test on a counter mismatch, as the stdmw2 check below does;
	// merely logging it would let a regression pass silently.
	if stdmwInit != 1 || stdmwHandler != 1 {
		t.Fatalf("stdmw counters failed, should be 1:1, got %d:%d", stdmwInit, stdmwHandler)
	}

	// GET /group2
	_, body = testRequest(t, ts, "GET", "/group2", nil)
	if body != "root group2" {
		t.Fatalf("got: '%s'", body)
	}
	if stdmwInit2 != 1 || stdmwHandler2 != 1 {
		t.Fatalf("stdmw2 counters failed, should be 1:1, got %d:%d", stdmwInit2, stdmwHandler2)
	}
}
// TestMuxBig serves the router built by bigMux and walks through its routes,
// asserting the exact response body of each request in turn.
func TestMuxBig(t *testing.T) {
	ts := httptest.NewServer(bigMux())
	defer ts.Close()

	// Routes registered in the "anonymous" session group.
	if _, got := testRequest(t, ts, "GET", "/favicon.ico", nil); got != "fav" {
		t.Fatalf("got '%s'", got)
	}
	if _, got := testRequest(t, ts, "GET", "/hubs/4/view", nil); got != "/hubs/4/view reqid:1 session:anonymous" {
		t.Fatalf("got '%v'", got)
	}
	if _, got := testRequest(t, ts, "GET", "/hubs/4/view/index.html", nil); got != "/hubs/4/view/index.html reqid:1 session:anonymous" {
		t.Fatalf("got '%s'", got)
	}
	if _, got := testRequest(t, ts, "POST", "/hubs/ethereumhub/view/index.html", nil); got != "/hubs/ethereumhub/view/index.html reqid:1 session:anonymous" {
		t.Fatalf("got '%s'", got)
	}

	// Routes registered in the "elvis" session group.
	if _, got := testRequest(t, ts, "GET", "/", nil); got != "/ reqid:1 session:elvis" {
		t.Fatalf("got '%s'", got)
	}
	if _, got := testRequest(t, ts, "GET", "/suggestions", nil); got != "/suggestions reqid:1 session:elvis" {
		t.Fatalf("got '%s'", got)
	}
	if _, got := testRequest(t, ts, "GET", "/woot/444/hiiii", nil); got != "/woot/444/hiiii" {
		t.Fatalf("got '%s'", got)
	}
	if _, got := testRequest(t, ts, "GET", "/hubs/123", nil); got != "/hubs/123 reqid:1 session:elvis" {
		t.Fatalf("expected:%s got:%s", "/hubs/123 reqid:1 session:elvis", got)
	}
	if _, got := testRequest(t, ts, "GET", "/hubs/123/touch", nil); got != "/hubs/123/touch reqid:1 session:elvis" {
		t.Fatalf("got '%s'", got)
	}
	if _, got := testRequest(t, ts, "GET", "/hubs/123/webhooks", nil); got != "/hubs/123/webhooks reqid:1 session:elvis" {
		t.Fatalf("got '%s'", got)
	}
	if _, got := testRequest(t, ts, "GET", "/hubs/123/posts", nil); got != "/hubs/123/posts reqid:1 session:elvis" {
		t.Fatalf("got '%s'", got)
	}

	// "/folders/" is routed with a trailing slash, so "/folders" is a 404.
	if _, got := testRequest(t, ts, "GET", "/folders", nil); got != "404 page not found\n" {
		t.Fatalf("got '%s'", got)
	}
	if _, got := testRequest(t, ts, "GET", "/folders/", nil); got != "/folders/ reqid:1 session:elvis" {
		t.Fatalf("got '%s'", got)
	}
	if _, got := testRequest(t, ts, "GET", "/folders/public", nil); got != "/folders/public reqid:1 session:elvis" {
		t.Fatalf("got '%s'", got)
	}
	if _, got := testRequest(t, ts, "GET", "/folders/nothing", nil); got != "404 page not found\n" {
		t.Fatalf("got '%s'", got)
	}
}
// bigMux builds the routing fixture served by TestMuxBig: a root router with
// a request-ID middleware, one group whose session user is "anonymous", and
// one group whose session user is "elvis" containing nested Route and Mount
// sub-routers. Most handlers echo their path followed by the request ID and
// session user read from the request context.
func bigMux() Router {
	// withValue returns a middleware that stores val under key in the request
	// context before handing the request to the next handler.
	withValue := func(key, val interface{}) func(http.Handler) http.Handler {
		return func(next http.Handler) http.Handler {
			return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				next.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), key, val)))
			})
		}
	}

	// tail renders the " reqid:... session:..." suffix shared by most replies.
	tail := func(r *http.Request) string {
		ctx := r.Context()
		return fmt.Sprintf(" reqid:%s session:%s", ctx.Value(ctxKey{"requestID"}), ctx.Value(ctxKey{"session.user"}))
	}

	// reply wraps a function that computes the response text into a handler
	// that writes that text.
	reply := func(render func(r *http.Request) string) http.HandlerFunc {
		return func(w http.ResponseWriter, r *http.Request) {
			w.Write([]byte(render(r)))
		}
	}

	root := NewRouter()
	root.Use(withValue(ctxKey{"requestID"}, "1"))
	root.Use(func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			next.ServeHTTP(w, r)
		})
	})

	// Routes served with session user "anonymous".
	root.Group(func(r Router) {
		r.Use(withValue(ctxKey{"session.user"}, "anonymous"))

		r.Get("/favicon.ico", reply(func(r *http.Request) string {
			return "fav"
		}))
		r.Get("/hubs/{hubID}/view", reply(func(r *http.Request) string {
			return "/hubs/" + URLParam(r, "hubID") + "/view" + tail(r)
		}))
		r.Get("/hubs/{hubID}/view/*", reply(func(r *http.Request) string {
			return fmt.Sprintf("/hubs/%s/view/%s", URLParamFromCtx(r.Context(), "hubID"), URLParam(r, "*")) + tail(r)
		}))
		r.Post("/hubs/{hubSlug}/view/*", reply(func(r *http.Request) string {
			return fmt.Sprintf("/hubs/%s/view/%s", URLParamFromCtx(r.Context(), "hubSlug"), URLParam(r, "*")) + tail(r)
		}))
	})

	// Routes served with session user "elvis".
	root.Group(func(r Router) {
		r.Use(withValue(ctxKey{"session.user"}, "elvis"))

		r.Get("/", reply(func(r *http.Request) string {
			return "/" + tail(r)
		}))
		r.Get("/suggestions", reply(func(r *http.Request) string {
			return "/suggestions" + tail(r)
		}))
		r.Get("/woot/{wootID}/*", reply(func(r *http.Request) string {
			return "/woot/" + URLParam(r, "wootID") + "/" + URLParam(r, "*")
		}))

		r.Route("/hubs", func(r Router) {
			_ = r.(*Mux) // sr1
			r.Route("/{hubID}", func(r Router) {
				_ = r.(*Mux) // sr2
				r.Get("/", reply(func(r *http.Request) string {
					return "/hubs/" + URLParam(r, "hubID") + tail(r)
				}))
				r.Get("/touch", reply(func(r *http.Request) string {
					return "/hubs/" + URLParam(r, "hubID") + "/touch" + tail(r)
				}))

				// The webhooks sub-router is a standalone Mux mounted behind
				// a middleware chain that flags the request as a hook.
				sr3 := NewRouter()
				sr3.Get("/", reply(func(r *http.Request) string {
					return "/hubs/" + URLParam(r, "hubID") + "/webhooks" + tail(r)
				}))
				sr3.Route("/{webhookID}", func(r Router) {
					_ = r.(*Mux) // sr4
					r.Get("/", reply(func(r *http.Request) string {
						return "/hubs/" + URLParam(r, "hubID") + "/webhooks/" + URLParam(r, "webhookID") + tail(r)
					}))
				})
				r.Mount("/webhooks", Chain(withValue(ctxKey{"hook"}, true)).Handler(sr3))

				r.Route("/posts", func(r Router) {
					_ = r.(*Mux) // sr5
					r.Get("/", reply(func(r *http.Request) string {
						return "/hubs/" + URLParam(r, "hubID") + "/posts" + tail(r)
					}))
				})
			})
		})

		r.Route("/folders/", func(r Router) {
			_ = r.(*Mux) // sr6
			r.Get("/", reply(func(r *http.Request) string {
				return "/folders/" + tail(r)
			}))
			r.Get("/public", reply(func(r *http.Request) string {
				return "/folders/public" + tail(r)
			}))
		})
	})

	return root
}
func TestMuxSubroutesBasic(t *testing.T) {
hIndex := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("index"))
})
hArticlesList := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("articles-list"))
})
hSearchArticles := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("search-articles"))
})
hGetArticle := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(fmt.Sprintf("get-article:%s", URLParam(r, "id"))))
})
hSyncArticle := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(fmt.Sprintf("sync-article:%s", URLParam(r, "id"))))
})
r := NewRouter()
// var rr1, rr2 *Mux
r.Get("/", hIndex)
r.Route("/articles", func(r Router) {
// rr1 = r.(*Mux)
r.Get("/", hArticlesList)
r.Get("/search", hSearchArticles)
r.Route("/{id}", func(r Router) {
// rr2 = r.(*Mux)
r.Get("/", hGetArticle)
r.Get("/sync", hSyncArticle)
})
})
// log.Println("~~~~~~~~~")
// log.Println("~~~~~~~~~")
// debugPrintTree(0, 0, r.tree, 0)
// log.Println("~~~~~~~~~")
// log.Println("~~~~~~~~~")
// log.Println("~~~~~~~~~")
// log.Println("~~~~~~~~~")
// debugPrintTree(0, 0, rr1.tree, 0)
// log.Println("~~~~~~~~~")
// log.Println("~~~~~~~~~")
// log.Println("~~~~~~~~~")
// log.Println("~~~~~~~~~")
// debugPrintTree(0, 0, rr2.tree, 0)
// log.Println("~~~~~~~~~")
// log.Println("~~~~~~~~~")
ts := httptest.NewServer(r)
defer ts.Close()
var body, expected string
_, body = testRequest(t, ts, "GET", "/", nil)
expected = "index"
if body != expected {
t.Fatalf("expected:%s got:%s", expected, body)
}
_, body = testRequest(t, ts, "GET", "/articles", nil)
expected = "articles-list"
if body != expected {
t.Fatalf("expected:%s got:%s", expected, body)
}
_, body = testRequest(t, ts, "GET", "/articles/search", nil)
expected = "search-articles"
if body != expected {
t.Fatalf("expected:%s got:%s", expected, body)
}
_, body = testRequest(t, ts, "GET", "/articles/123", nil)
expected = "get-article:123"
if body != expected {
t.Fatalf("expected:%s got:%s", expected, body)
}
_, body = testRequest(t, ts, "GET", "/articles/123/sync", nil)
expected = "sync-article:123"
if body != expected {
t.Fatalf("expected:%s got:%s", expected, body)
}
}
// TestMuxSubroutes combines plain routes, Mount and Route on patterns that
// contain URL parameters, verifies which handler answers each path, and then
// serves one request directly (without a server) to inspect the
// RoutePatterns recorded on the route context.
func TestMuxSubroutes(t *testing.T) {
	// say returns a handler that always writes text.
	say := func(text string) http.HandlerFunc {
		return func(w http.ResponseWriter, r *http.Request) {
			w.Write([]byte(text))
		}
	}

	r := NewRouter()
	r.Get("/hubs/{hubID}/view", say("hub1"))
	r.Get("/hubs/{hubID}/view/*", say("hub2"))

	users := NewRouter()
	users.Get("/", say("hub3"))
	r.Mount("/hubs/{hubID}/users", users)
	r.Get("/hubs/{hubID}/users/", func(w http.ResponseWriter, r *http.Request) {
		w.Write([]byte("hub3 override"))
	})

	accounts := NewRouter()
	accounts.Get("/", say("account1"))
	accounts.Get("/hi", say("account2"))

	// Equivalent to building a router sr2, calling sr2.Mount("/", accounts)
	// and then r.Mount("/accounts/{accountID}", sr2).
	r.Route("/accounts/{accountID}", func(r Router) {
		_ = r.(*Mux) // sr2
		r.Mount("/", accounts)
	})

	ts := httptest.NewServer(r)
	defer ts.Close()

	// get issues a GET against the test server and returns the response body.
	get := func(path string) string {
		_, body := testRequest(t, ts, "GET", path, nil)
		return body
	}

	if body, expected := get("/hubs/123/view"), "hub1"; body != expected {
		t.Fatalf("expected:%s got:%s", expected, body)
	}
	if body, expected := get("/hubs/123/view/index.html"), "hub2"; body != expected {
		t.Fatalf("expected:%s got:%s", expected, body)
	}
	if body, expected := get("/hubs/123/users"), "hub3"; body != expected {
		t.Fatalf("expected:%s got:%s", expected, body)
	}
	if body, expected := get("/hubs/123/users/"), "hub3 override"; body != expected {
		t.Fatalf("expected:%s got:%s", expected, body)
	}
	if body, expected := get("/accounts/44"), "account1"; body != expected {
		t.Fatalf("request:%s expected:%s got:%s", "GET /accounts/44", expected, body)
	}
	if body, expected := get("/accounts/44/hi"), "account2"; body != expected {
		t.Fatalf("expected:%s got:%s", expected, body)
	}

	// Test that we're building the routingPatterns properly
	rctx := NewRouteContext()
	req, _ := http.NewRequest("GET", "/accounts/44/hi", nil)
	req = req.WithContext(context.WithValue(req.Context(), RouteCtxKey, rctx))

	rec := httptest.NewRecorder()
	r.ServeHTTP(rec, req)

	if body, expected := rec.Body.String(), "account2"; body != expected {
		t.Fatalf("expected:%s got:%s", expected, body)
	}

	patterns := rctx.RoutePatterns
	if len(patterns) != 3 {
		t.Fatalf("expected 3 routing patterns, got:%d", len(patterns))
	}
	for i, expected := range []string{"/accounts/{accountID}/*", "/*", "/hi"} {
		if patterns[i] != expected {
			t.Fatalf("routePattern, expected:%s got:%s", expected, patterns[i])
		}
	}
}
// TestSingleHandler calls a bare handler, without any router, after placing a
// hand-built route context carrying a "name" URL parameter on the request,
// and checks that URLParam reads the parameter back.
func TestSingleHandler(t *testing.T) {
	greet := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Write([]byte("hi " + URLParam(r, "name")))
	})

	rctx := NewRouteContext()
	rctx.URLParams.Add("name", "joe")

	req, _ := http.NewRequest("GET", "/", nil)
	req = req.WithContext(context.WithValue(req.Context(), RouteCtxKey, rctx))

	rec := httptest.NewRecorder()
	greet.ServeHTTP(rec, req)

	if body, expected := rec.Body.String(), "hi joe"; body != expected {
		t.Fatalf("expected:%s got:%s", expected, body)
	}
}
// TODO: a Router wrapper test..
//
// type ACLMux struct {
// *Mux
// XX string
// }
//
// func NewACLMux() *ACLMux {
// return &ACLMux{Mux: NewRouter(), XX: "hihi"}
// }
//
// // TODO: this should be supported...
// func TestWoot(t *testing.T) {
// var r Router = NewRouter()
//
// var r2 Router = NewACLMux() //NewRouter()
// r2.Get("/hi", func(w http.ResponseWriter, r *http.Request) {
// w.Write([]byte("hi"))
// })
//
// r.Mount("/", r2)
// }
func TestServeHTTPExistingContext(t *testing.T) {
r := NewRouter()
r.Get("/hi", func(w http.ResponseWriter, r *http.Request) {
s, _ := r.Context().Value(ctxKey{"testCtx"}).(string)
w.Write([]byte(s))
})
r.NotFound(func(w http.ResponseWriter, r *http.Request) {
s, _ := r.Context().Value(ctxKey{"testCtx"}).(string)
w.WriteHeader(404)
w.Write([]byte(s))
})
testcases := []struct {
Ctx context.Context
Method string
Path string
ExpectedBody string
ExpectedStatus int
}{
{
Method: "GET",
Path: "/hi",
Ctx: context.WithValue(context.Background(), ctxKey{"testCtx"}, "hi ctx"),
ExpectedStatus: 200,
ExpectedBody: "hi ctx",
},
{
Method: "GET",
Path: "/hello",
Ctx: context.WithValue(context.Background(), ctxKey{"testCtx"}, "nothing here ctx"),
ExpectedStatus: 404,
ExpectedBody: "nothing here ctx",
},
}
for _, tc := range testcases {
resp := httptest.NewRecorder()
req, err := http.NewRequest(tc.Method, tc.Path, nil)
if err != nil {
t.Fatalf("%v", err)
}
req = req.WithContext(tc.Ctx)
r.ServeHTTP(resp, req)
b, err := ioutil.ReadAll(resp.Body)
if err != nil {
t.Fatalf("%v", err)
}
if resp.Code != tc.ExpectedStatus {
t.Fatalf("%v != %v", tc.ExpectedStatus, resp.Code)
}
if string(b) != tc.ExpectedBody {
t.Fatalf("%s != %s", tc.ExpectedBody, b)
}
}
}
// TestNestedGroups registers routes at several nesting levels of Group, Route
// and With, where every level stacks one or more counting middlewares. Each
// route is named after the number of middlewares expected to run before its
// handler, and the handler echoes the observed count.
func TestNestedGroups(t *testing.T) {
	// printCounter writes the counter found on the request context.
	printCounter := func(w http.ResponseWriter, req *http.Request) {
		count, _ := req.Context().Value(ctxKey{"counter"}).(int)
		fmt.Fprintf(w, "%v", count)
	}

	// bump increments the context counter by one before calling next.
	bump := func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
			count, _ := req.Context().Value(ctxKey{"counter"}).(int)
			ctx := context.WithValue(req.Context(), ctxKey{"counter"}, count+1)
			next.ServeHTTP(w, req.WithContext(ctx))
		})
	}

	// Each route represents value of its counter (number of applied middlewares).
	root := NewRouter() // counter == 0
	root.Get("/0", printCounter)
	root.Group(func(g1 Router) {
		g1.Use(bump) // counter == 1
		g1.Get("/1", printCounter)

		g1.With(bump).Get("/2", printCounter)

		g1.Group(func(g2 Router) {
			g2.Use(bump, bump) // counter == 3
			g2.Get("/3", printCounter)
		})

		g1.Route("/", func(sub Router) {
			sub.Use(bump, bump) // counter == 3

			sub.With(bump).Get("/4", printCounter)

			sub.Group(func(g3 Router) {
				g3.Use(bump, bump) // counter == 5
				g3.Get("/5", printCounter)

				g3.With(bump).Get("/6", printCounter)
			})
		})
	})

	ts := httptest.NewServer(root)
	defer ts.Close()

	expected := []string{"0", "1", "2", "3", "4", "5", "6"}
	for _, route := range expected {
		_, body := testRequest(t, ts, "GET", "/"+route, nil)
		if body != route {
			t.Errorf("expected %v, got %v", route, body)
		}
	}
}
// TestMiddlewarePanicOnLateUse expects Use to panic when it is called after a
// route has already been registered on the router.
func TestMiddlewarePanicOnLateUse(t *testing.T) {
	defer func() {
		if got := recover(); got == nil {
			t.Error("expected panic()")
		}
	}()

	// passthrough is a middleware that simply forwards to the next handler.
	passthrough := func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
			next.ServeHTTP(w, req)
		})
	}

	mux := NewRouter()
	mux.Get("/", func(w http.ResponseWriter, req *http.Request) {
		w.Write([]byte("hello\n"))
	})
	mux.Use(passthrough) // Too late to apply middleware, we're expecting panic().
}
// TestMountingExistingPath expects the second Mount on the same pattern to
// panic.
func TestMountingExistingPath(t *testing.T) {
	defer func() {
		if got := recover(); got == nil {
			t.Error("expected panic()")
		}
	}()

	noop := http.HandlerFunc(func(http.ResponseWriter, *http.Request) {})

	mux := NewRouter()
	mux.Get("/", noop)
	mux.Mount("/hi", noop)
	mux.Mount("/hi", noop)
}
// TestMountingSimilarPattern mounts two sub-routers whose patterns share a
// prefix ("/foobar" and "/foo") and checks that an unrelated route on the
// parent router is still served.
func TestMountingSimilarPattern(t *testing.T) {
	r := NewRouter()
	r.Get("/hi", func(w http.ResponseWriter, r *http.Request) {
		w.Write([]byte("bye"))
	})

	r2 := NewRouter()
	r2.Get("/", func(w http.ResponseWriter, r *http.Request) {
		w.Write([]byte("foobar"))
	})

	r3 := NewRouter()
	r3.Get("/", func(w http.ResponseWriter, r *http.Request) {
		w.Write([]byte("foo"))
	})

	r.Mount("/foobar", r2)
	r.Mount("/foo", r3)

	ts := httptest.NewServer(r)
	defer ts.Close()

	// The body is reported through a constant format string: passing it as
	// the format itself would misrender any '%' it contains.
	if _, body := testRequest(t, ts, "GET", "/hi", nil); body != "bye" {
		t.Fatalf("GET /hi: expected %q, got %q", "bye", body)
	}
}
// TestMuxEmptyParams checks that a route made of three plain URL params
// matches both fully populated paths and paths whose leading params are
// empty.
func TestMuxEmptyParams(t *testing.T) {
	r := NewRouter()
	r.Get(`/users/{x}/{y}/{z}`, func(w http.ResponseWriter, r *http.Request) {
		x := URLParam(r, "x")
		y := URLParam(r, "y")
		z := URLParam(r, "z")
		w.Write([]byte(fmt.Sprintf("%s-%s-%s", x, y, z)))
	})

	ts := httptest.NewServer(r)
	defer ts.Close()

	// Bodies are reported through a constant format string so that a '%' in
	// the response cannot be misread as a formatting verb.
	if _, body := testRequest(t, ts, "GET", "/users/a/b/c", nil); body != "a-b-c" {
		t.Fatalf("GET /users/a/b/c: expected %q, got %q", "a-b-c", body)
	}
	if _, body := testRequest(t, ts, "GET", "/users///c", nil); body != "--c" {
		t.Fatalf("GET /users///c: expected %q, got %q", "--c", body)
	}
}
// TestMuxMissingParams checks that a regexp param route matches when the
// param is present and falls through to the NotFound handler when the param
// segment is empty.
func TestMuxMissingParams(t *testing.T) {
	r := NewRouter()
	r.Get(`/user/{userId:\d+}`, func(w http.ResponseWriter, r *http.Request) {
		userID := URLParam(r, "userId")
		w.Write([]byte(fmt.Sprintf("userId = '%s'", userID)))
	})
	r.NotFound(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(404)
		w.Write([]byte("nothing here"))
	})

	ts := httptest.NewServer(r)
	defer ts.Close()

	// Bodies are reported through a constant format string so that a '%' in
	// the response cannot be misread as a formatting verb.
	if _, body := testRequest(t, ts, "GET", "/user/123", nil); body != "userId = '123'" {
		t.Fatalf("GET /user/123: expected %q, got %q", "userId = '123'", body)
	}
	if _, body := testRequest(t, ts, "GET", "/user/", nil); body != "nothing here" {
		t.Fatalf("GET /user/: expected %q, got %q", "nothing here", body)
	}
}
// TestMuxWildcardRoute expects route registration to panic when a wildcard
// '*' is followed by further static path text.
func TestMuxWildcardRoute(t *testing.T) {
	defer func() {
		if got := recover(); got == nil {
			t.Error("expected panic()")
		}
	}()

	mux := NewRouter()
	mux.Get("/*/wildcard/must/be/at/end", func(http.ResponseWriter, *http.Request) {})
}
// TestMuxWildcardRouteCheckTwo expects route registration to panic when a
// wildcard '*' appears before a '{param}' segment.
func TestMuxWildcardRouteCheckTwo(t *testing.T) {
	defer func() {
		if got := recover(); got == nil {
			t.Error("expected panic()")
		}
	}()

	mux := NewRouter()
	mux.Get("/*/wildcard/{must}/be/at/end", func(http.ResponseWriter, *http.Request) {})
}
// TestMuxRegexp checks that a regexp param which may match the empty string
// ("[0-9]*") on a Route prefix is matched for the path "//test" and yields
// an empty param value.
func TestMuxRegexp(t *testing.T) {
	r := NewRouter()
	r.Route("/{param:[0-9]*}/test", func(r Router) {
		r.Get("/", func(w http.ResponseWriter, r *http.Request) {
			w.Write([]byte(fmt.Sprintf("Hi: %s", URLParam(r, "param"))))
		})
	})

	ts := httptest.NewServer(r)
	defer ts.Close()

	// The body is reported through a constant format string so that a '%' in
	// the response cannot be misread as a formatting verb.
	if _, body := testRequest(t, ts, "GET", "//test", nil); body != "Hi: " {
		t.Fatalf("GET //test: expected %q, got %q", "Hi: ", body)
	}
}
// TestMuxRegexp2 checks a regexp param that itself contains braces
// ("[a-z]{2,3}") and sits between static text inside a single path segment.
func TestMuxRegexp2(t *testing.T) {
	r := NewRouter()
	r.Get("/foo-{suffix:[a-z]{2,3}}.json", func(w http.ResponseWriter, r *http.Request) {
		w.Write([]byte(URLParam(r, "suffix")))
	})

	ts := httptest.NewServer(r)
	defer ts.Close()

	// Bodies are reported through a constant format string so that a '%' in
	// the response cannot be misread as a formatting verb.
	if _, body := testRequest(t, ts, "GET", "/foo-.json", nil); body != "" {
		t.Fatalf("GET /foo-.json: expected %q, got %q", "", body)
	}
	if _, body := testRequest(t, ts, "GET", "/foo-abc.json", nil); body != "abc" {
		t.Fatalf("GET /foo-abc.json: expected %q, got %q", "abc", body)
	}
}
// TestMuxRegexp3 registers several routes under "/one" whose regexp params
// differ only in their expressions or HTTP method, and checks that each
// request reaches the intended handler.
func TestMuxRegexp3(t *testing.T) {
	r := NewRouter()
	r.Get("/one/{firstId:[a-z0-9-]+}/{secondId:[a-z]+}/first", func(w http.ResponseWriter, r *http.Request) {
		w.Write([]byte("first"))
	})
	r.Get("/one/{firstId:[a-z0-9-_]+}/{secondId:[0-9]+}/second", func(w http.ResponseWriter, r *http.Request) {
		w.Write([]byte("second"))
	})
	r.Delete("/one/{firstId:[a-z0-9-_]+}/{secondId:[0-9]+}/second", func(w http.ResponseWriter, r *http.Request) {
		w.Write([]byte("third"))
	})

	r.Route("/one", func(r Router) {
		r.Get("/{dns:[a-z-0-9_]+}", func(writer http.ResponseWriter, request *http.Request) {
			writer.Write([]byte("_"))
		})
		r.Get("/{dns:[a-z-0-9_]+}/info", func(writer http.ResponseWriter, request *http.Request) {
			writer.Write([]byte("_"))
		})
		r.Delete("/{id:[0-9]+}", func(writer http.ResponseWriter, request *http.Request) {
			writer.Write([]byte("forth"))
		})
	})

	ts := httptest.NewServer(r)
	defer ts.Close()

	requests := []struct {
		method, path, want string
	}{
		{"GET", "/one/hello/peter/first", "first"},
		{"GET", "/one/hithere/123/second", "second"},
		{"DELETE", "/one/hithere/123/second", "third"},
		{"DELETE", "/one/123", "forth"},
	}
	for _, rq := range requests {
		// Report through a constant format string; the body is data, not a format.
		if _, body := testRequest(t, ts, rq.method, rq.path, nil); body != rq.want {
			t.Fatalf("%s %s: expected %q, got %q", rq.method, rq.path, rq.want, body)
		}
	}
}
// TestMuxSubrouterWildcardParam checks that the "param" and "*" URL params
// are reported identically whether the "{param}" and "{param}/*" routes are
// registered directly on the router or inside a Route sub-router.
func TestMuxSubrouterWildcardParam(t *testing.T) {
	h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		fmt.Fprintf(w, "param:%v *:%v", URLParam(r, "param"), URLParam(r, "*"))
	})

	r := NewRouter()

	r.Get("/bare/{param}", h)
	r.Get("/bare/{param}/*", h)

	r.Route("/case0", func(r Router) {
		r.Get("/{param}", h)
		r.Get("/{param}/*", h)
	})

	ts := httptest.NewServer(r)
	defer ts.Close()

	requests := []struct {
		path, want string
	}{
		{"/bare/hi", "param:hi *:"},
		{"/bare/hi/yes", "param:hi *:yes"},
		{"/case0/hi", "param:hi *:"},
		{"/case0/hi/yes", "param:hi *:yes"},
	}
	for _, rq := range requests {
		// Report through a constant format string; the body is data, not a format.
		if _, body := testRequest(t, ts, "GET", rq.path, nil); body != rq.want {
			t.Fatalf("GET %s: expected %q, got %q", rq.path, rq.want, body)
		}
	}
}
// TestMuxContextIsThreadSafe hammers the router from 100 goroutines, each
// serving 10000 requests whose contexts are cancelled concurrently with
// ServeHTTP. It makes no assertions of its own; it exists to surface data
// races or panics (run with -race).
func TestMuxContextIsThreadSafe(t *testing.T) {
	router := NewRouter()
	router.Get("/{id}", func(w http.ResponseWriter, r *http.Request) {
		ctx, cancel := context.WithTimeout(r.Context(), 1*time.Millisecond)
		defer cancel()

		<-ctx.Done()
	})

	var wg sync.WaitGroup

	for i := 0; i < 100; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			for j := 0; j < 10000; j++ {
				w := httptest.NewRecorder()
				r, err := http.NewRequest("GET", "/ok", nil)
				if err != nil {
					// t.Fatal (FailNow) must only be called from the goroutine
					// running the test, so record the error and stop this worker.
					t.Error(err)
					return
				}

				ctx, cancel := context.WithCancel(r.Context())
				r = r.WithContext(ctx)

				// Cancel from another goroutine so it races with ServeHTTP.
				go func() {
					cancel()
				}()
				router.ServeHTTP(w, r)
			}
		}()
	}
	wg.Wait()
}
// TestEscapedURLParams checks that a percent-encoded path segment
// ("http:%2f%2fexample.com%2fimage.png") is delivered to the handler as a
// single, still-escaped URL param, and that the params after it are
// extracted correctly.
func TestEscapedURLParams(t *testing.T) {
	m := NewRouter()
	m.Get("/api/{identifier}/{region}/{size}/{rotation}/*", func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(200)
		rctx := RouteContext(r.Context())
		if rctx == nil {
			t.Error("no context")
			return
		}
		identifier := URLParam(r, "identifier")
		if identifier != "http:%2f%2fexample.com%2fimage.png" {
			t.Errorf("identifier path parameter incorrect %s", identifier)
			return
		}
		region := URLParam(r, "region")
		if region != "full" {
			t.Errorf("region path parameter incorrect %s", region)
			return
		}
		size := URLParam(r, "size")
		if size != "max" {
			t.Errorf("size path parameter incorrect %s", size)
			return
		}
		rotation := URLParam(r, "rotation")
		if rotation != "0" {
			t.Errorf("rotation path parameter incorrect %s", rotation)
			return
		}
		w.Write([]byte("success"))
	})

	ts := httptest.NewServer(m)
	defer ts.Close()

	// The body is reported through a constant format string so that a '%' in
	// the response cannot be misread as a formatting verb.
	if _, body := testRequest(t, ts, "GET", "/api/http:%2f%2fexample.com%2fimage.png/full/max/0/color.png", nil); body != "success" {
		t.Fatalf("expected %q, got %q", "success", body)
	}
}
// TestCustomHTTPMethod registers the non-standard method "BOO" and checks
// that a route defined for it via MethodFunc is served alongside a regular
// GET route.
func TestCustomHTTPMethod(t *testing.T) {
	// first we must register this method to be accepted, then we
	// can define method handlers on the router below
	RegisterMethod("BOO")

	r := NewRouter()
	r.Get("/", func(w http.ResponseWriter, r *http.Request) {
		w.Write([]byte("."))
	})

	// note the custom BOO method for route /hi
	r.MethodFunc("BOO", "/hi", func(w http.ResponseWriter, r *http.Request) {
		w.Write([]byte("custom method"))
	})

	ts := httptest.NewServer(r)
	defer ts.Close()

	// Bodies are reported through a constant format string so that a '%' in
	// the response cannot be misread as a formatting verb.
	if _, body := testRequest(t, ts, "GET", "/", nil); body != "." {
		t.Fatalf("GET /: expected %q, got %q", ".", body)
	}
	if _, body := testRequest(t, ts, "BOO", "/hi", nil); body != "custom method" {
		t.Fatalf("BOO /hi: expected %q, got %q", "custom method", body)
	}
}
// TestMuxMatch exercises Router.Match: a GET on a sub-router route that
// defines GET must match, while a HEAD on a route that only defines GET must
// not.
func TestMuxMatch(t *testing.T) {
	mux := NewRouter()

	mux.Get("/hi", func(w http.ResponseWriter, req *http.Request) {
		w.Header().Set("X-Test", "yes")
		w.Write([]byte("bye"))
	})

	mux.Route("/articles", func(sub Router) {
		sub.Get("/{id}", func(w http.ResponseWriter, req *http.Request) {
			articleID := URLParam(req, "id")
			w.Header().Set("X-Article", articleID)
			w.Write([]byte("article:" + articleID))
		})
	})

	mux.Route("/users", func(sub Router) {
		sub.Head("/{id}", func(w http.ResponseWriter, req *http.Request) {
			w.Header().Set("X-User", "-")
			w.Write([]byte("user"))
		})
		sub.Get("/{id}", func(w http.ResponseWriter, req *http.Request) {
			userID := URLParam(req, "id")
			w.Header().Set("X-User", userID)
			w.Write([]byte("user:" + userID))
		})
	})

	rctx := NewRouteContext()

	rctx.Reset()
	if !mux.Match(rctx, "GET", "/users/1") {
		t.Fatal("expecting to find match for route:", "GET", "/users/1")
	}

	rctx.Reset()
	if mux.Match(rctx, "HEAD", "/articles/10") {
		t.Fatal("not expecting to find match for route:", "HEAD", "/articles/10")
	}
}
// TestServerBaseContext checks that values from http.Server.BaseContext, as
// well as the server and local-address values net/http adds itself, remain
// visible on the request context inside a routed handler.
func TestServerBaseContext(t *testing.T) {
	r := NewRouter()
	r.Get("/", func(w http.ResponseWriter, r *http.Request) {
		baseYes := r.Context().Value(ctxKey{"base"}).(string)
		if _, ok := r.Context().Value(http.ServerContextKey).(*http.Server); !ok {
			panic("missing server context")
		}
		if _, ok := r.Context().Value(http.LocalAddrContextKey).(net.Addr); !ok {
			panic("missing local addr context")
		}
		w.Write([]byte(baseYes))
	})

	// Setup http Server with a base context
	ctx := context.WithValue(context.Background(), ctxKey{"base"}, "yes")
	ts := httptest.NewUnstartedServer(r)
	ts.Config.BaseContext = func(_ net.Listener) context.Context {
		return ctx
	}
	ts.Start()

	defer ts.Close()

	// The body is reported through a constant format string so that a '%' in
	// the response cannot be misread as a formatting verb.
	if _, body := testRequest(t, ts, "GET", "/", nil); body != "yes" {
		t.Fatalf("GET /: expected %q, got %q", "yes", body)
	}
}
// testRequest sends a method request for path to the test server ts using
// http.DefaultClient and returns the response together with its fully read
// body. Any failure to build, send or read the request aborts the calling
// test via t.Fatal.
//
// The returned response's Body has already been read and closed.
func testRequest(t *testing.T, ts *httptest.Server, method, path string, body io.Reader) (*http.Response, string) {
	t.Helper()

	req, err := http.NewRequest(method, ts.URL+path, body)
	if err != nil {
		t.Fatal(err)
	}

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		t.Fatal(err)
	}
	// Register the close before reading so the body is released even when
	// ReadAll fails and t.Fatal unwinds this goroutine.
	defer resp.Body.Close()

	respBody, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		t.Fatal(err)
	}

	return resp, string(respBody)
}
// testHandler serves a single in-memory request for method and path directly
// through h (no network involved) and returns the recorded response along
// with its body as a string. A request that cannot be constructed aborts the
// calling test via t.Fatal instead of passing a nil request to the handler.
func testHandler(t *testing.T, h http.Handler, method, path string, body io.Reader) (*http.Response, string) {
	t.Helper()

	r, err := http.NewRequest(method, path, body)
	if err != nil {
		t.Fatal(err)
	}

	w := httptest.NewRecorder()
	h.ServeHTTP(w, r)
	return w.Result(), w.Body.String()
}
// testFileSystem is a stub http.FileSystem whose behaviour is supplied by the
// test through the open callback.
type testFileSystem struct {
	open func(name string) (http.File, error)
}

// Open satisfies http.FileSystem by forwarding name to the open callback and
// returning whatever it returns.
func (fs *testFileSystem) Open(name string) (http.File, error) {
	file, err := fs.open(name)
	return file, err
}
// testFile is a minimal in-memory file stub carrying a name and fixed
// contents.
type testFile struct {
	name     string
	contents []byte
}

// Close is a no-op that always succeeds.
func (tf *testFile) Close() error {
	return nil
}

// Read copies the start of contents into p. Note that it always reports
// len(p) bytes and a nil error, regardless of how many bytes were actually
// copied; it keeps no read position and never returns io.EOF.
func (tf *testFile) Read(p []byte) (n int, err error) {
	_ = copy(p, tf.contents)
	n = len(p)
	return n, nil
}

// Seek ignores its arguments and always reports offset 0.
func (tf *testFile) Seek(offset int64, whence int) (int64, error) {
	return 0, nil
}

// Readdir ignores count and returns a single entry: the file's own Stat info.
func (tf *testFile) Readdir(count int) ([]os.FileInfo, error) {
	info, _ := tf.Stat()
	entries := []os.FileInfo{info}
	return entries, nil
}

// Stat returns a testFileInfo describing the file's name and content length.
func (tf *testFile) Stat() (os.FileInfo, error) {
	info := &testFileInfo{
		name: tf.name,
		size: int64(len(tf.contents)),
	}
	return info, nil
}
// testFileInfo is a fixed os.FileInfo stub holding only a name and a size.
type testFileInfo struct {
	name string
	size int64
}

// Name returns the stored file name.
func (tfi *testFileInfo) Name() string {
	return tfi.name
}

// Size returns the stored size in bytes.
func (tfi *testFileInfo) Size() int64 {
	return tfi.size
}

// Mode always reports permission bits 0755.
func (tfi *testFileInfo) Mode() os.FileMode {
	return os.FileMode(0755)
}

// ModTime reports the current time on every call.
func (tfi *testFileInfo) ModTime() time.Time {
	return time.Now()
}

// IsDir always reports false.
func (tfi *testFileInfo) IsDir() bool {
	return false
}

// Sys always returns nil.
func (tfi *testFileInfo) Sys() interface{} {
	return nil
}
// ctxKey is the context key type used by the tests; distinct names give
// distinct keys.
type ctxKey struct {
	name string
}

// String renders the key as "context value <name>".
func (k ctxKey) String() string {
	const prefix = "context value "
	return prefix + k.name
}
// BenchmarkMux measures ServeHTTP on a router with static, param, regexp and
// nested sub-router routes. Each listed path runs as its own sub-benchmark
// and reuses one recorder and one request across iterations.
func BenchmarkMux(b *testing.B) {
	// newNoop returns a fresh handler that does nothing.
	newNoop := func() http.HandlerFunc {
		return func(w http.ResponseWriter, r *http.Request) {}
	}
	h1, h2, h3 := newNoop(), newNoop(), newNoop()
	h4, h5, h6 := newNoop(), newNoop(), newNoop()

	mx := NewRouter()
	mx.Get("/", h1)
	mx.Get("/hi", h2)
	mx.Get("/sup/{id}/and/{this}", h3)
	mx.Get("/sup/{id}/{bar:foo}/{this}", h3)

	mx.Route("/sharing/{x}/{hash}", func(sub Router) {
		sub.Get("/", h4)          // subrouter-1
		sub.Get("/{network}", h5) // subrouter-1
		sub.Get("/twitter", h5)
		sub.Route("/direct", func(direct Router) {
			direct.Get("/", h6) // subrouter-2
			direct.Get("/download", h6)
		})
	})

	paths := []string{
		"/",
		"/hi",
		"/sup/123/and/this",
		"/sup/123/foo/this",
		"/sharing/z/aBc",                 // subrouter-1
		"/sharing/z/aBc/twitter",         // subrouter-1
		"/sharing/z/aBc/direct",          // subrouter-2
		"/sharing/z/aBc/direct/download", // subrouter-2
	}

	for _, path := range paths {
		b.Run("route:"+path, func(b *testing.B) {
			rec := httptest.NewRecorder()
			req, _ := http.NewRequest("GET", path, nil)

			b.ReportAllocs()
			b.ResetTimer()

			for n := 0; n < b.N; n++ {
				mx.ServeHTTP(rec, req)
			}
		})
	}
}
chi-5.0.7/testdata/ 0000775 0000000 0000000 00000000000 14145546033 0014104 5 ustar 00root root 0000000 0000000 chi-5.0.7/testdata/cert.pem 0000664 0000000 0000000 00000002112 14145546033 0015540 0 ustar 00root root 0000000 0000000 -----BEGIN CERTIFICATE-----
MIIC/zCCAeegAwIBAgIRANioW0Re7DtpT4qZpJU1iK8wDQYJKoZIhvcNAQELBQAw
EjEQMA4GA1UEChMHQWNtZSBDbzAeFw0xNjEyMzExNDU0MzBaFw0xNzEyMzExNDU0
MzBaMBIxEDAOBgNVBAoTB0FjbWUgQ28wggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAw
ggEKAoIBAQDpFfOsaXDYlL+ektfsqGYrSAsoTbe7zqjpow9nqUU4PmLRu2YMaaW8
fAoneUnJxsJw7ql38+VMpphZUOmOWvsO7uV/lfnTIQfTwllHDdgAR5A11d84Zy/y
TiNIFJduuaPtEhQs1dxPhU7TG8sEfFRhBoUDPv473akeGPNkVU756RVBYM6rUc3b
YygD0PXGsQ2obrImbYUyyHH5YClCvGl1No57n3ugLqSSfwbgR3/Gw7kkGKy0PMOu
TuHuJnTEmofJPkqEyFRVMlIAtfqFqJUfDHTOuQGWIUPnjDg+fqTI9EPJ+pElBqDQ
IqW93BY5XePMdrTQc1h6xkduDfuLeA7TAgMBAAGjUDBOMA4GA1UdDwEB/wQEAwIF
oDATBgNVHSUEDDAKBggrBgEFBQcDATAMBgNVHRMBAf8EAjAAMBkGA1UdEQQSMBCC
DmxvY2FsaG9zdDo3MDcyMA0GCSqGSIb3DQEBCwUAA4IBAQDnsWmZdf7209A/XHUe
xoONCbU8jaYFVoA+CN9J+3CASzrzTQ4fh9RJdm2FZuv4sWnb5c5hDN7H/M/nLcb0
+uu7ACBGhd7yACYCQm/z3Pm3CY2BRIo0vCCRioGx+6J3CPGWFm0vHwNBge0iBOKC
Wn+/YOlTDth/M3auHYlr7hdFmf57U4V/5iTr4wiKxwM9yMPcVRQF/1XpPd7A0VqM
nFSEfDpFjrA7MvT3DrRqQGqF/ZXxDbro2nyki3YG8FwgKlFNVN9w55zNiriQ+WNA
uz86lKg1FTc+m/R/0CD//7+7mme28N813EPVdV83TgxWNrfvAIRazkHE7YxETry0
BJDg
-----END CERTIFICATE----- chi-5.0.7/testdata/key.pem 0000664 0000000 0000000 00000003216 14145546033 0015401 0 ustar 00root root 0000000 0000000 -----BEGIN RSA PRIVATE KEY-----
MIIEpAIBAAKCAQEA6RXzrGlw2JS/npLX7KhmK0gLKE23u86o6aMPZ6lFOD5i0btm
DGmlvHwKJ3lJycbCcO6pd/PlTKaYWVDpjlr7Du7lf5X50yEH08JZRw3YAEeQNdXf
OGcv8k4jSBSXbrmj7RIULNXcT4VO0xvLBHxUYQaFAz7+O92pHhjzZFVO+ekVQWDO
q1HN22MoA9D1xrENqG6yJm2FMshx+WApQrxpdTaOe597oC6kkn8G4Ed/xsO5JBis
tDzDrk7h7iZ0xJqHyT5KhMhUVTJSALX6haiVHwx0zrkBliFD54w4Pn6kyPRDyfqR
JQag0CKlvdwWOV3jzHa00HNYesZHbg37i3gO0wIDAQABAoIBAFvqYDE5U1rVLctm
tOeKcN/YhS3bl/zjvhCEUOrcAYPwdh+m+tMiRk1RzN9MISEE1GCcfQ/kiiPz/lga
ZD/S+PYmlzH8/ouXlvKWzYYLm4ZgsinIsUIYzvuKfLdMB3uOkWpHmtUjcMGbHD57
009tiAjK/WEOUkthWfOYe0KxsXczBn3PTAWZuiIkuA3RVWa7pCCFHUENkViP58wl
Ky1hYKnunKPApRwuiC6qIT5ZOCSukdCCbkmRnj/x+P8+nsosu+1d85MNZb8uLRi0
RzMmuOfOK2poDsrNHQX7itKlu7rzMJQc3+RauqIZovNe/BmSq+tYBLboXvUp18g/
+VqKeEECgYEA/LaD1tJepzD/1lhgunFcnDjxsDJqLUpfR5eDMX1qhGJphuPBLOXS
ushmVVjbVIn25Wxeoe4RYrZ6Tuu0FEJJgV44Lt42OOFgK2gyrCJpYmlxpRaw+7jc
Dbp1Sh3/9VqMZjR/mQIzTnfOtS2n4Fk1Q53hdJn5Pn+uPMmMO4hF87sCgYEA7B4V
BACsd6eqVxKkEMc72VLeYb0Ri0bl0FwbvIKXImppwA0tbMDmeA+6yhcRm23dhd5v
cfNhJepRIzkM2CkhnazlsAbDoJPqb7/sbNzodtW1P0op7YIFYbrkcX4yOu9O1DNI
Ij4PR8H1WcpPjhvr3q+iNO5agQX7bMQ1BnnJg8kCgYBA1tdm090DSrgpl81hqNpZ
HucsDRNfAXkG1mIL3aDpzJJE0MTsrx7tW6Od/ElyHF/jp3V0WK/PQwCIpUMz+3n+
nl0N8We6GmFhYb+2mLGvVVyaPgM04s5bG18ioCXfHtdtFcUzTfQ6CtVXeRpcnqbi
7Ww+TY88sOfUouW/FIzWJwKBgQCsLauJhaw+fOc8I328NmywJzu+7g5TD9oZvHEF
X/0xvYNr5rAPNANb3ayKHZRbURxOuEtwPtfCvEF6e+mf3y6COkgrumMBP5ue7cdM
AzMJJQHMKxqz9TJTd+OJ10ptq4BCQTsCrVqbKxbs6RhmOnofoteX3Y/lsiULxXAd
TsXh8QKBgQDQHosH8VoL7vIK+SqY5uoHAhMytSVNx4IaZZg4ho8oyjw12QXcidgV
QJZQMdPEv8cAK78WcQdSthop+O/tu2cKLHyAmWmO3oU7gIQECui0aMXSqraO6Vde
C5tqYlyLa7bHZS3AqrjRv9BRfwPKVkmBoYdA652rN/tE/K4UWsghnA==
-----END RSA PRIVATE KEY----- chi-5.0.7/tree.go 0000664 0000000 0000000 00000047122 14145546033 0013567 0 ustar 00root root 0000000 0000000 package chi
// Radix tree implementation below is based on the original work by
// Armon Dadgar in https://github.com/armon/go-radix/blob/master/radix.go
// (MIT licensed). It's been heavily modified for use as an HTTP routing tree.
import (
"fmt"
"net/http"
"regexp"
"sort"
"strconv"
"strings"
)
// methodTyp is a bit-flag type for HTTP methods. Every constant below
// occupies its own bit, so several methods can be combined with | and tested
// with &.
type methodTyp uint
const (
// mSTUB takes bit 0 (1 << 0); the standard HTTP methods that follow take
// bits 1 through 9 in declaration order.
mSTUB methodTyp = 1 << iota
mCONNECT
mDELETE
mGET
mHEAD
mOPTIONS
mPATCH
mPOST
mPUT
mTRACE
)
// mALL is the union of all standard method flags (mSTUB is not included).
// RegisterMethod ORs the flag of every newly registered custom method into it.
var mALL = mCONNECT | mDELETE | mGET | mHEAD |
mOPTIONS | mPATCH | mPOST | mPUT | mTRACE
// methodMap maps an HTTP method name to its methodTyp flag. It starts with
// the net/http method constants below; RegisterMethod adds custom entries.
var methodMap = map[string]methodTyp{
http.MethodConnect: mCONNECT,
http.MethodDelete: mDELETE,
http.MethodGet: mGET,
http.MethodHead: mHEAD,
http.MethodOptions: mOPTIONS,
http.MethodPatch: mPATCH,
http.MethodPost: mPOST,
http.MethodPut: mPUT,
http.MethodTrace: mTRACE,
}
// RegisterMethod adds support for custom HTTP method handlers, available
// via Router#Method and Router#MethodFunc.
//
// The method name is upper-cased before use. An empty name or a name that is
// already known is ignored. Each new method is assigned the next free bit of
// methodTyp; once no bit is left the function panics. It writes to the
// package-level methodMap and mALL without any synchronization.
func RegisterMethod(method string) {
	name := strings.ToUpper(method)
	if name == "" {
		return
	}
	if _, known := methodMap[name]; known {
		return
	}

	registered := len(methodMap)
	if registered > strconv.IntSize-2 {
		panic(fmt.Sprintf("chi: max number of methods reached (%d)", strconv.IntSize))
	}

	// Bit 0 belongs to mSTUB, so the entry after `registered` existing
	// methods lands on bit registered+1, i.e. 2 << registered.
	flag := methodTyp(2 << registered)
	methodMap[name] = flag
	mALL |= flag
}
// nodeTyp classifies a tree node by the kind of pattern segment it holds.
// The values are also used as indexes into node.children, so their numeric
// order is the order in which findRoute visits the child groups.
type nodeTyp uint8
const (
ntStatic nodeTyp = iota // /home
ntRegexp // /{id:[0-9]+}
ntParam // /{user}
ntCatchAll // /api/v1/*
)
// node is a single node of the routing tree. A node holds one pattern
// segment (static text, a param, a regexp param or the catch-all) and keeps
// its children grouped by nodeTyp.
type node struct {
// subroutes on the leaf node
subroutes Routes
// regexp matcher for regexp nodes
rex *regexp.Regexp
// HTTP handler endpoints on the leaf node
endpoints endpoints
// prefix is the common prefix we ignore
prefix string
// child nodes should be stored in-order for iteration,
// in groups of the node type.
children [ntCatchAll + 1]nodes
// first byte of the child prefix
tail byte
// node type: static, regexp, param, catchAll
typ nodeTyp
// first byte of the prefix
label byte
}
// endpoints is a mapping of http method constants to handlers
// for a given route.
type endpoints map[methodTyp]*endpoint
// endpoint is the per-method record stored on a node: the handler together
// with the full routing pattern it was registered under and the names of
// the URL params that pattern declares.
type endpoint struct {
// endpoint handler
handler http.Handler
// pattern is the routing pattern for handler nodes
pattern string
// parameter keys recorded on handler nodes
paramKeys []string
}
// Value returns the endpoint stored for method. If the map has no entry for
// method, an empty endpoint is created, stored and returned.
func (s endpoints) Value(method methodTyp) *endpoint {
	if ep, found := s[method]; found {
		return ep
	}
	ep := &endpoint{}
	s[method] = ep
	return ep
}
// InsertRoute registers handler for method and pattern in the tree rooted at
// n and returns the node that received the endpoint. It walks the tree while
// consuming `search` (the part of pattern not yet placed); depending on what
// it finds it reuses an existing node, appends a new child chain, or splits
// an existing static node at the longest common prefix.
//
// It panics when pattern is malformed (see patNextSegment) or contains an
// invalid regexp (see addChild).
func (n *node) InsertRoute(method methodTyp, pattern string, handler http.Handler) *node {
var parent *node
search := pattern
for {
// Handle key exhaustion
if len(search) == 0 {
// Insert or update the node's leaf handler
n.setEndpoint(method, handler, pattern)
return n
}
// We're going to be searching for a wild node next,
// in this case, we need to get the tail
var label = search[0]
var segTail byte
var segEndIdx int
var segTyp nodeTyp
var segRexpat string
// For any other label the seg* variables keep their zero values, so
// segTyp stays ntStatic and the edge lookup below uses static children.
if label == '{' || label == '*' {
segTyp, _, segRexpat, segTail, _, segEndIdx = patNextSegment(search)
}
// Regexp edges are additionally distinguished by their expression
// (getEdge compares it against the node prefix).
var prefix string
if segTyp == ntRegexp {
prefix = segRexpat
}
// Look for the edge to attach to
parent = n
n = n.getEdge(segTyp, label, segTail, prefix)
// No edge, create one
if n == nil {
child := &node{label: label, tail: segTail, prefix: search}
hn := parent.addChild(child, search)
hn.setEndpoint(method, handler, pattern)
return hn
}
// Found an edge to match the pattern
if n.typ > ntStatic {
// We found a param node, trim the param from the search path and continue.
// This param/wild pattern segment would already be on the tree from a previous
// call to addChild when creating a new node.
search = search[segEndIdx:]
continue
}
// Static nodes fall below here.
// Determine longest prefix of the search key on match.
commonPrefix := longestPrefix(search, n.prefix)
if commonPrefix == len(n.prefix) {
// the common prefix is as long as the current node's prefix we're attempting to insert.
// keep the search going.
search = search[commonPrefix:]
continue
}
// Split the node
child := &node{
typ: ntStatic,
prefix: search[:commonPrefix],
}
// The new node holding the shared prefix takes the existing node's slot
// in the parent (replaceChild copies label and tail onto it).
parent.replaceChild(search[0], segTail, child)
// Restore the existing node
n.label = n.prefix[commonPrefix]
n.prefix = n.prefix[commonPrefix:]
child.addChild(n, n.prefix)
// If the new key is a subset, set the method/handler on this node and finish.
search = search[commonPrefix:]
if len(search) == 0 {
child.setEndpoint(method, handler, pattern)
return child
}
// Create a new edge for the node
subchild := &node{
typ: ntStatic,
label: search[0],
prefix: search,
}
hn := child.addChild(subchild, search)
hn.setEndpoint(method, handler, pattern)
return hn
}
}
// addChild appends the new `child` node to the tree using the `pattern` as the trie key.
// For a URL router like chi's, we split the static, param, regexp and wildcard segments
// into different nodes. In addition, addChild will recursively call itself until every
// pattern segment is added to the url pattern tree as individual nodes, depending on type.
//
// The returned node is the last node of the chain that was created, i.e. the
// node on which the caller should set the endpoint. addChild panics when a
// regexp segment does not compile.
func (n *node) addChild(child *node, prefix string) *node {
search := prefix
// handler leaf node added to the tree is the child.
// this may be overridden later down the flow
hn := child
// Parse next segment
segTyp, _, segRexpat, segTail, segStartIdx, segEndIdx := patNextSegment(search)
// Add child depending on next up segment
switch segTyp {
case ntStatic:
// Search prefix is all static (that is, has no params in path)
// noop
default:
// Search prefix contains a param, regexp or wildcard
if segTyp == ntRegexp {
rex, err := regexp.Compile(segRexpat)
if err != nil {
panic(fmt.Sprintf("chi: invalid regexp pattern '%s' in route param", segRexpat))
}
child.prefix = segRexpat
child.rex = rex
}
if segStartIdx == 0 {
// Route starts with a param
child.typ = segTyp
// A catch-all consumes the rest of the pattern; any other param ends
// at segEndIdx. After this block segStartIdx marks where the
// remaining (static) text begins.
if segTyp == ntCatchAll {
segStartIdx = -1
} else {
segStartIdx = segEndIdx
}
if segStartIdx < 0 {
segStartIdx = len(search)
}
child.tail = segTail // for params, we set the tail
if segStartIdx != len(search) {
// add static edge for the remaining part, split the end.
// its not possible to have adjacent param nodes, so its certainly
// going to be a static node next.
search = search[segStartIdx:] // advance search position
nn := &node{
typ: ntStatic,
label: search[0],
prefix: search,
}
hn = child.addChild(nn, search)
}
} else if segStartIdx > 0 {
// Route has some param
// starts with a static segment
child.typ = ntStatic
child.prefix = search[:segStartIdx]
// Undo the regexp settings applied above: this node only holds the
// static text in front of the param.
child.rex = nil
// add the param edge node
search = search[segStartIdx:]
nn := &node{
typ: segTyp,
label: search[0],
tail: segTail,
}
hn = child.addChild(nn, search)
}
}
// Attach child to the group matching its (possibly just updated) type and
// re-sort that group.
n.children[child.typ] = append(n.children[child.typ], child)
n.children[child.typ].Sort()
return hn
}
// replaceChild swaps the child of n that has the given label and tail (within
// the group of child.typ) for child, copying label and tail onto child. It
// panics if no such child exists.
func (n *node) replaceChild(label, tail byte, child *node) {
	siblings := n.children[child.typ]
	for i, existing := range siblings {
		if existing.label != label || existing.tail != tail {
			continue
		}
		child.label = label
		child.tail = tail
		siblings[i] = child
		return
	}
	panic("chi: replacing missing child")
}
// getEdge returns the first child of type ntyp whose label and tail match.
// For regexp children the node prefix (the expression) must equal prefix as
// well. It returns nil when no child qualifies.
func (n *node) getEdge(ntyp nodeTyp, label, tail byte, prefix string) *node {
	for _, candidate := range n.children[ntyp] {
		if candidate.label != label || candidate.tail != tail {
			continue
		}
		if ntyp == ntRegexp && candidate.prefix != prefix {
			continue
		}
		return candidate
	}
	return nil
}
// setEndpoint records handler and pattern on n for the given method flags.
//
// When method carries the mSTUB bit, only the handler of the mSTUB endpoint
// is set at first. When method covers all of mALL, the mALL endpoint and the
// endpoint of every method in methodMap are filled in; otherwise only the
// endpoint stored under method itself is. It panics if pattern declares a
// duplicate param key (see patParamKeys).
func (n *node) setEndpoint(method methodTyp, handler http.Handler, pattern string) {
	// Set the handler for the method type on the node
	if n.endpoints == nil {
		n.endpoints = make(endpoints)
	}

	paramKeys := patParamKeys(pattern)

	// fill writes the complete endpoint record for one method key.
	fill := func(m methodTyp) {
		ep := n.endpoints.Value(m)
		ep.handler = handler
		ep.pattern = pattern
		ep.paramKeys = paramKeys
	}

	if method&mSTUB == mSTUB {
		n.endpoints.Value(mSTUB).handler = handler
	}

	if method&mALL != mALL {
		fill(method)
		return
	}

	fill(mALL)
	for _, m := range methodMap {
		fill(m)
	}
}
// FindRoute looks up path for method in the tree rooted at n.
//
// On a match it appends the collected param keys and values to
// rctx.URLParams, records the matched pattern (if non-empty) on
// rctx.routePattern and rctx.RoutePatterns, and returns the matched node, its
// endpoints and the handler for method. When nothing matches it returns
// three nils.
func (n *node) FindRoute(rctx *Context, method methodTyp, path string) (*node, endpoints, http.Handler) {
	// Start from a clean slate: no pattern, and empty (but reusable) param slices.
	rctx.routePattern = ""
	rctx.routeParams.Keys = rctx.routeParams.Keys[:0]
	rctx.routeParams.Values = rctx.routeParams.Values[:0]

	// Find the routing handlers for the path
	match := n.findRoute(rctx, method, path)
	if match == nil {
		return nil, nil, nil
	}

	// Record the routing params in the request lifecycle
	rctx.URLParams.Keys = append(rctx.URLParams.Keys, rctx.routeParams.Keys...)
	rctx.URLParams.Values = append(rctx.URLParams.Values, rctx.routeParams.Values...)

	// Record the routing pattern in the request lifecycle
	ep := match.endpoints[method]
	if ep.pattern != "" {
		rctx.routePattern = ep.pattern
		rctx.RoutePatterns = append(rctx.RoutePatterns, rctx.routePattern)
	}

	return match, match.endpoints, ep.handler
}
// Recursive edge traversal by checking all nodeTyp groups along the way.
// It's like searching through a multi-dimensional radix trie.
//
// findRoute returns the node that has a handler for method and matches path,
// or nil if there is none. Param values are pushed onto
// rctx.routeParams.Values while descending and removed again when a branch
// fails; the param keys are appended only once a matching endpoint is found.
// rctx.methodNotAllowed is set when a leaf matches the path but has no
// handler for method.
func (n *node) findRoute(rctx *Context, method methodTyp, path string) *node {
nn := n
search := path
// The array index t doubles as the node type of the group.
for t, nds := range nn.children {
ntyp := nodeTyp(t)
if len(nds) == 0 {
continue
}
var xn *node
xsearch := search
var label byte
if search != "" {
label = search[0]
}
switch ntyp {
case ntStatic:
xn = nds.findEdge(label)
if xn == nil || !strings.HasPrefix(xsearch, xn.prefix) {
continue
}
xsearch = xsearch[len(xn.prefix):]
case ntParam, ntRegexp:
// short-circuit and return no matching route for empty param values
if xsearch == "" {
continue
}
// serially loop through each node grouped by the tail delimiter
for idx := 0; idx < len(nds); idx++ {
xn = nds[idx]
// label for param nodes is the delimiter byte
p := strings.IndexByte(xsearch, xn.tail)
if p < 0 {
// No delimiter left: a '/'-tailed param may take the rest of the path.
if xn.tail == '/' {
p = len(xsearch)
} else {
continue
}
} else if ntyp == ntRegexp && p == 0 {
continue
}
if ntyp == ntRegexp && xn.rex != nil {
if !xn.rex.MatchString(xsearch[:p]) {
continue
}
} else if strings.IndexByte(xsearch[:p], '/') != -1 {
// avoid a match across path segments
continue
}
// Remember the stack depth so the value can be dropped if this
// branch does not lead to a handler.
prevlen := len(rctx.routeParams.Values)
rctx.routeParams.Values = append(rctx.routeParams.Values, xsearch[:p])
xsearch = xsearch[p:]
if len(xsearch) == 0 {
if xn.isLeaf() {
h := xn.endpoints[method]
if h != nil && h.handler != nil {
rctx.routeParams.Keys = append(rctx.routeParams.Keys, h.paramKeys...)
return xn
}
// flag that the routing context found a route, but not a corresponding
// supported method
rctx.methodNotAllowed = true
}
}
// recursively find the next node on this branch
fin := xn.findRoute(rctx, method, xsearch)
if fin != nil {
return fin
}
// not found on this branch, reset vars
rctx.routeParams.Values = rctx.routeParams.Values[:prevlen]
xsearch = search
}
// No candidate in this group matched. Push an empty placeholder value;
// the shared code after the switch continues with the last candidate
// (xn) and pops one value again if that attempt fails too.
rctx.routeParams.Values = append(rctx.routeParams.Values, "")
default:
// catch-all nodes
rctx.routeParams.Values = append(rctx.routeParams.Values, search)
xn = nds[0]
xsearch = ""
}
if xn == nil {
continue
}
// did we find it yet?
if len(xsearch) == 0 {
if xn.isLeaf() {
h := xn.endpoints[method]
if h != nil && h.handler != nil {
rctx.routeParams.Keys = append(rctx.routeParams.Keys, h.paramKeys...)
return xn
}
// flag that the routing context found a route, but not a corresponding
// supported method
rctx.methodNotAllowed = true
}
}
// recursively find the next node..
fin := xn.findRoute(rctx, method, xsearch)
if fin != nil {
return fin
}
// Did not find final handler, let's remove the param here if it was set
if xn.typ > ntStatic {
if len(rctx.routeParams.Values) > 0 {
rctx.routeParams.Values = rctx.routeParams.Values[:len(rctx.routeParams.Values)-1]
}
}
}
return nil
}
// findEdge returns the child of n in the group ntyp whose label equals
// label, or nil if there is none. For the catch-all group the label is
// ignored and the first child is returned.
//
// Static, param and regexp groups are searched with a binary search on
// label, which relies on the group being kept sorted by label (addChild
// sorts a group after every insertion). An empty group yields nil instead of
// indexing out of range.
func (n *node) findEdge(ntyp nodeTyp, label byte) *node {
	nds := n.children[ntyp]
	num := len(nds)
	if num == 0 {
		return nil
	}

	switch ntyp {
	case ntStatic, ntParam, ntRegexp:
		i, j := 0, num-1
		for i <= j {
			idx := i + (j-i)/2
			switch {
			case label > nds[idx].label:
				i = idx + 1
			case label < nds[idx].label:
				j = idx - 1
			default:
				return nds[idx]
			}
		}
		return nil
	default: // catch all
		return nds[0]
	}
}
// isLeaf reports whether n has an endpoints map, i.e. whether n.endpoints is
// non-nil.
func (n *node) isLeaf() bool {
return n.endpoints != nil
}
// findPattern reports whether the routing pattern is already present in the
// tree below n. It follows the first child edge (trying the child groups in
// node-type order) whose label matches the start of pattern, consumes the
// matching part and recurses on the remainder. pattern must be non-empty.
func (n *node) findPattern(pattern string) bool {
	for _, group := range n.children {
		if len(group) == 0 {
			continue
		}

		edge := n.findEdge(group[0].typ, pattern[0])
		if edge == nil {
			continue
		}

		// consumed is the number of leading bytes of pattern covered by edge.
		var consumed int
		switch edge.typ {
		case ntStatic:
			consumed = longestPrefix(pattern, edge.prefix)
			if consumed < len(edge.prefix) {
				continue
			}

		case ntParam, ntRegexp:
			consumed = strings.IndexByte(pattern, '}') + 1

		case ntCatchAll:
			consumed = longestPrefix(pattern, "*")

		default:
			panic("chi: unknown node type")
		}

		rest := pattern[consumed:]
		if len(rest) == 0 {
			return true
		}

		return edge.findPattern(rest)
	}
	return false
}
// routes walks the tree below n and returns one Route per distinct pattern
// found on each visited node. A Route carries the node's subroutes and a map
// from method name ("*" for the mALL endpoint) to handler. Nodes that only
// hold a stub handler without subroutes are left out. The result is never
// nil.
func (n *node) routes() []Route {
	rts := []Route{}

	n.walk(func(eps endpoints, subroutes Routes) bool {
		if stub := eps[mSTUB]; stub != nil && stub.handler != nil && subroutes == nil {
			return false
		}

		// Group methodHandlers by unique patterns
		byPattern := make(map[string]endpoints)
		for mt, ep := range eps {
			if ep.pattern == "" {
				continue
			}
			if _, seen := byPattern[ep.pattern]; !seen {
				byPattern[ep.pattern] = endpoints{}
			}
			byPattern[ep.pattern][mt] = ep
		}

		for pattern, group := range byPattern {
			handlers := make(map[string]http.Handler)
			if all := group[mALL]; all != nil && all.handler != nil {
				handlers["*"] = all.handler
			}

			for mt, ep := range group {
				if ep.handler == nil {
					continue
				}
				// Flags without a name in methodMap (such as mALL) are skipped.
				if name := methodTypString(mt); name != "" {
					handlers[name] = ep.handler
				}
			}

			rts = append(rts, Route{subroutes, handlers, pattern})
		}

		// Never stop early: every node is visited.
		return false
	})

	return rts
}
// walk visits n and then all of its descendants depth-first, calling fn for
// every node that has endpoints or subroutes. The walk stops as soon as fn
// returns true, and walk itself then returns true; otherwise it returns
// false.
func (n *node) walk(fn func(eps endpoints, subroutes Routes) bool) bool {
	// Visit the leaf values if any
	if n.endpoints != nil || n.subroutes != nil {
		if fn(n.endpoints, n.subroutes) {
			return true
		}
	}

	// Recurse on the children
	for _, group := range n.children {
		for _, child := range group {
			if child.walk(fn) {
				return true
			}
		}
	}
	return false
}
// patNextSegment returns the next segment details from a pattern:
// node type, param key, regexp string, param tail byte, param starting index, param ending index
//
// A pattern without '{' or '*' is reported as one static segment spanning the
// whole string. The function panics when a wildcard precedes a param, when a
// param is missing its closing '}', or when text follows the wildcard.
func patNextSegment(pattern string) (nodeTyp, string, string, byte, int, int) {
	paramStart := strings.Index(pattern, "{")
	wildStart := strings.Index(pattern, "*")
	hasParam, hasWild := paramStart >= 0, wildStart >= 0

	if !hasParam && !hasWild {
		return ntStatic, "", "", 0, 0, len(pattern) // we return the entire thing
	}

	// Sanity check
	if hasParam && hasWild && wildStart < paramStart {
		panic("chi: wildcard '*' must be the last pattern in a route, otherwise use a '{param}'")
	}

	if !hasParam {
		// Wildcard pattern as finale
		if wildStart < len(pattern)-1 {
			panic("chi: wildcard '*' must be the last value in a route. trim trailing text or use a '{param}' instead")
		}
		return ntCatchAll, "*", "", 0, wildStart, len(pattern)
	}

	// Param/Regexp pattern is next.
	// Find the '}' closing the param, tracking nesting depth so that braces
	// inside a regexp (e.g. "{2,3}") are skipped.
	depth := 0
	paramEnd := paramStart
	for offset, ch := range pattern[paramStart:] {
		if ch == '{' {
			depth++
			continue
		}
		if ch != '}' {
			continue
		}
		depth--
		if depth == 0 {
			paramEnd = paramStart + offset
			break
		}
	}
	if paramEnd == paramStart {
		panic("chi: route param closing delimiter '}' is missing")
	}

	key := pattern[paramStart+1 : paramEnd]
	paramEnd++ // set end to next position

	// The tail is the byte right after the param; default to '/' at the end
	// of the pattern.
	tail := byte('/')
	if paramEnd < len(pattern) {
		tail = pattern[paramEnd]
	}

	segTyp := ntParam
	var rexpat string
	if colon := strings.Index(key, ":"); colon >= 0 {
		segTyp = ntRegexp
		rexpat = key[colon+1:]
		key = key[:colon]
	}

	// Anchor a non-empty expression on both ends.
	if rexpat != "" {
		if rexpat[0] != '^' {
			rexpat = "^" + rexpat
		}
		if rexpat[len(rexpat)-1] != '$' {
			rexpat += "$"
		}
	}

	return segTyp, key, rexpat, tail, paramStart, paramEnd
}
// patParamKeys collects the param keys of pattern in the order they appear,
// consuming one segment at a time via patNextSegment until only static text
// remains. It panics if the same key occurs more than once. The returned
// slice is never nil.
func patParamKeys(pattern string) []string {
	keys := []string{}
	rest := pattern
	for {
		typ, key, _, _, _, end := patNextSegment(rest)
		if typ == ntStatic {
			return keys
		}

		// Reject a key that has already been seen in this pattern.
		for _, seen := range keys {
			if seen == key {
				panic(fmt.Sprintf("chi: routing pattern '%s' contains duplicate param key, '%s'", pattern, key))
			}
		}

		keys = append(keys, key)
		rest = rest[end:]
	}
}
// longestPrefix returns the number of leading bytes that k1 and k2 have in
// common. The comparison is byte-wise and stops at the shorter string's end.
func longestPrefix(k1, k2 string) int {
	limit := len(k2)
	if len(k1) < limit {
		limit = len(k1)
	}

	n := 0
	for n < limit && k1[n] == k2[n] {
		n++
	}
	return n
}
// methodTypString performs a reverse lookup in methodMap and returns the key
// whose value equals method, or "" when no entry matches.
func methodTypString(method methodTyp) string {
	for name := range methodMap {
		if methodMap[name] == method {
			return name
		}
	}
	return ""
}
// nodes is a list of tree nodes. It implements sort.Interface, ordering its
// elements by label.
type nodes []*node

// Sort orders the nodes by ascending label and then applies tailSort.
func (ns nodes) Sort() {
	sort.Sort(ns)
	ns.tailSort()
}

// Len reports the number of nodes in the list.
func (ns nodes) Len() int {
	return len(ns)
}

// Swap exchanges the nodes at positions i and j.
func (ns nodes) Swap(i, j int) {
	tmp := ns[i]
	ns[i] = ns[j]
	ns[j] = tmp
}

// Less reports whether the node at i has a smaller label than the node at j.
func (ns nodes) Less(i, j int) bool {
	return ns[i].label < ns[j].label
}
// tailSort scans the list from the back for the first non-static node whose
// tail byte is '/' and swaps it into the last position. At most one node is
// moved. The list order determines the traversal order.
func (ns nodes) tailSort() {
	last := len(ns) - 1
	for i := last; i >= 0; i-- {
		if cur := ns[i]; cur.typ > ntStatic && cur.tail == '/' {
			ns.Swap(i, last)
			return
		}
	}
}
// findEdge returns the node whose label equals label, or nil when there is
// none. It binary-searches ns, so the list is expected to be ordered by
// ascending label (the order defined by Less).
//
// An empty list yields nil; previously the final lookup indexed ns[0]
// unconditionally and panicked on an empty list.
func (ns nodes) findEdge(label byte) *node {
	lo, hi := 0, len(ns)-1
	for lo <= hi {
		mid := lo + (hi-lo)/2 // midpoint without risking overflow
		switch {
		case label > ns[mid].label:
			lo = mid + 1
		case label < ns[mid].label:
			hi = mid - 1
		default:
			return ns[mid]
		}
	}
	return nil
}
// Route describes the details of a routing handler: the pattern it is
// registered under, the handlers per HTTP method, and any nested routes.
type Route struct {
// SubRoutes holds the nested routes found on the tree node this Route was
// built from. It is nil when the node has no subroutes.
SubRoutes Routes
// Handlers maps a method name to its handler. A handler registered for all
// methods is additionally stored under the "*" key.
Handlers map[string]http.Handler
// Pattern is the routing pattern the handlers were registered with.
Pattern string
}
// WalkFunc is the type of the function called for each method and route visited by Walk.
//
// method is the key of the route's Handlers map, route is the full pattern
// (parent patterns included), handler is the endpoint handler, and
// middlewares are the middlewares collected from the enclosing routers along
// the way, followed by those of a ChainHandler when the handler is one.
// Returning a non-nil error stops the walk; Walk returns that error.
type WalkFunc func(method string, route string, handler http.Handler, middlewares ...func(http.Handler) http.Handler) error
// Walk walks any router tree that implements Routes interface, calling walkFn
// for every method of every route. The traversal starts with an empty parent
// route and no inherited middlewares. It returns the first non-nil error
// produced by walkFn, or nil once the whole tree has been visited.
func Walk(r Routes, walkFn WalkFunc) error {
return walk(r, walkFn, "")
}
// walk recursively visits every route of r and reports each (method, route)
// pair to walkFn.
//
// parentRoute is the pattern prefix accumulated from enclosing routers and
// parentMw the middlewares inherited from them. Routes with sub-routers are
// descended into instead of being reported. The "*" entry of a route's
// Handlers is skipped because the specific methods are reported individually.
// When a handler is a *ChainHandler, walkFn receives its Endpoint together
// with the chain's middlewares appended to the inherited ones.
//
// The first non-nil error returned by walkFn aborts the walk and is returned.
func walk(r Routes, walkFn WalkFunc, parentRoute string, parentMw ...func(http.Handler) http.Handler) error {
	for _, route := range r.Routes() {
		mws := make([]func(http.Handler) http.Handler, len(parentMw))
		copy(mws, parentMw)
		mws = append(mws, r.Middlewares()...)
		// Clip the capacity so that appending chain middlewares below always
		// allocates a fresh slice. Without this, the slices handed to walkFn
		// for different methods could share, and overwrite, the same spare
		// backing storage.
		mws = mws[:len(mws):len(mws)]

		fullRoute := parentRoute + route.Pattern

		if route.SubRoutes != nil {
			if err := walk(route.SubRoutes, walkFn, fullRoute, mws...); err != nil {
				return err
			}
			continue
		}

		// The reported route is the same for every method, so build it once.
		fullRoute = strings.ReplaceAll(fullRoute, "/*/", "/")

		for method, handler := range route.Handlers {
			if method == "*" {
				// Ignore a "catchAll" method, since we pass down all the specific methods for each route.
				continue
			}

			endpoint, routeMws := handler, mws
			if chain, ok := handler.(*ChainHandler); ok {
				endpoint = chain.Endpoint
				routeMws = append(mws, chain.Middlewares...)
			}
			if err := walkFn(method, fullRoute, endpoint, routeMws...); err != nil {
				return err
			}
		}
	}
	return nil
}
chi-5.0.7/tree_test.go 0000664 0000000 0000000 00000054221 14145546033 0014624 0 ustar 00root root 0000000 0000000 package chi
import (
"fmt"
"log"
"net/http"
"testing"
)
// TestTree inserts a mix of static, param and catch-all GET routes into a
// fresh tree and checks, for each request path in the table, which handler
// FindRoute resolves and which param keys and values end up on the route
// context.
func TestTree(t *testing.T) {
// Each handler is a separate function literal so that they can be told
// apart when the resolved handler is compared further down.
hStub := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
hIndex := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
hFavicon := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
hArticleList := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
hArticleNear := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
hArticleShow := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
hArticleShowRelated := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
hArticleShowOpts := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
hArticleSlug := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
hArticleByUser := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
hUserList := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
hUserShow := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
hAdminCatchall := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
hAdminAppShow := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
hAdminAppShowCatchall := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
hUserProfile := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
hUserSuper := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
hUserAll := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
hHubView1 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
hHubView2 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
hHubView3 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
// The insertion order below is significant: several routes are registered
// more than once and the expectations rely on the latest registration.
tr := &node{}
tr.InsertRoute(mGET, "/", hIndex)
tr.InsertRoute(mGET, "/favicon.ico", hFavicon)
tr.InsertRoute(mGET, "/pages/*", hStub)
tr.InsertRoute(mGET, "/article", hArticleList)
tr.InsertRoute(mGET, "/article/", hArticleList)
tr.InsertRoute(mGET, "/article/near", hArticleNear)
tr.InsertRoute(mGET, "/article/{id}", hStub)
tr.InsertRoute(mGET, "/article/{id}", hArticleShow)
tr.InsertRoute(mGET, "/article/{id}", hArticleShow) // duplicate will have no effect
tr.InsertRoute(mGET, "/article/@{user}", hArticleByUser)
tr.InsertRoute(mGET, "/article/{sup}/{opts}", hArticleShowOpts)
tr.InsertRoute(mGET, "/article/{id}/{opts}", hArticleShowOpts) // overwrite above route, latest wins
tr.InsertRoute(mGET, "/article/{iffd}/edit", hStub)
tr.InsertRoute(mGET, "/article/{id}//related", hArticleShowRelated)
tr.InsertRoute(mGET, "/article/slug/{month}/-/{day}/{year}", hArticleSlug)
tr.InsertRoute(mGET, "/admin/user", hUserList)
tr.InsertRoute(mGET, "/admin/user/", hStub) // will get replaced by next route
tr.InsertRoute(mGET, "/admin/user/", hUserList)
tr.InsertRoute(mGET, "/admin/user//{id}", hUserShow)
tr.InsertRoute(mGET, "/admin/user/{id}", hUserShow)
tr.InsertRoute(mGET, "/admin/apps/{id}", hAdminAppShow)
tr.InsertRoute(mGET, "/admin/apps/{id}/*", hAdminAppShowCatchall)
tr.InsertRoute(mGET, "/admin/*", hStub) // catchall segment will get replaced by next route
tr.InsertRoute(mGET, "/admin/*", hAdminCatchall)
tr.InsertRoute(mGET, "/users/{userID}/profile", hUserProfile)
tr.InsertRoute(mGET, "/users/super/*", hUserSuper)
tr.InsertRoute(mGET, "/users/*", hUserAll)
tr.InsertRoute(mGET, "/hubs/{hubID}/view", hHubView1)
tr.InsertRoute(mGET, "/hubs/{hubID}/view/*", hHubView2)
// A sub-router is also inserted as the handler of a catch-all pattern.
sr := NewRouter()
sr.Get("/users", hHubView3)
tr.InsertRoute(mGET, "/hubs/{hubID}/*", sr)
tr.InsertRoute(mGET, "/hubs/{hubID}/users", hHubView3)
tests := []struct {
r string // input request path
h http.Handler // output matched handler
k []string // output param keys
v []string // output param values
}{
{r: "/", h: hIndex, k: []string{}, v: []string{}},
{r: "/favicon.ico", h: hFavicon, k: []string{}, v: []string{}},
{r: "/pages", h: nil, k: []string{}, v: []string{}},
{r: "/pages/", h: hStub, k: []string{"*"}, v: []string{""}},
{r: "/pages/yes", h: hStub, k: []string{"*"}, v: []string{"yes"}},
{r: "/article", h: hArticleList, k: []string{}, v: []string{}},
{r: "/article/", h: hArticleList, k: []string{}, v: []string{}},
{r: "/article/near", h: hArticleNear, k: []string{}, v: []string{}},
{r: "/article/neard", h: hArticleShow, k: []string{"id"}, v: []string{"neard"}},
{r: "/article/123", h: hArticleShow, k: []string{"id"}, v: []string{"123"}},
{r: "/article/123/456", h: hArticleShowOpts, k: []string{"id", "opts"}, v: []string{"123", "456"}},
{r: "/article/@peter", h: hArticleByUser, k: []string{"user"}, v: []string{"peter"}},
{r: "/article/22//related", h: hArticleShowRelated, k: []string{"id"}, v: []string{"22"}},
{r: "/article/111/edit", h: hStub, k: []string{"iffd"}, v: []string{"111"}},
{r: "/article/slug/sept/-/4/2015", h: hArticleSlug, k: []string{"month", "day", "year"}, v: []string{"sept", "4", "2015"}},
{r: "/article/:id", h: hArticleShow, k: []string{"id"}, v: []string{":id"}},
{r: "/admin/user", h: hUserList, k: []string{}, v: []string{}},
{r: "/admin/user/", h: hUserList, k: []string{}, v: []string{}},
{r: "/admin/user/1", h: hUserShow, k: []string{"id"}, v: []string{"1"}},
{r: "/admin/user//1", h: hUserShow, k: []string{"id"}, v: []string{"1"}},
{r: "/admin/hi", h: hAdminCatchall, k: []string{"*"}, v: []string{"hi"}},
{r: "/admin/lots/of/:fun", h: hAdminCatchall, k: []string{"*"}, v: []string{"lots/of/:fun"}},
{r: "/admin/apps/333", h: hAdminAppShow, k: []string{"id"}, v: []string{"333"}},
{r: "/admin/apps/333/woot", h: hAdminAppShowCatchall, k: []string{"id", "*"}, v: []string{"333", "woot"}},
{r: "/hubs/123/view", h: hHubView1, k: []string{"hubID"}, v: []string{"123"}},
{r: "/hubs/123/view/index.html", h: hHubView2, k: []string{"hubID", "*"}, v: []string{"123", "index.html"}},
{r: "/hubs/123/users", h: hHubView3, k: []string{"hubID"}, v: []string{"123"}},
{r: "/users/123/profile", h: hUserProfile, k: []string{"userID"}, v: []string{"123"}},
{r: "/users/super/123/okay/yes", h: hUserSuper, k: []string{"*"}, v: []string{"123/okay/yes"}},
{r: "/users/123/okay/yes", h: hUserAll, k: []string{"*"}, v: []string{"123/okay/yes"}},
}
// Uncomment to dump the tree while debugging:
// log.Println("~~~~~~~~~")
// log.Println("~~~~~~~~~")
// debugPrintTree(0, 0, tr, 0)
// log.Println("~~~~~~~~~")
// log.Println("~~~~~~~~~")
for i, tt := range tests {
// A fresh route context per case keeps the recorded params independent.
rctx := NewRouteContext()
_, handlers, _ := tr.FindRoute(rctx, mGET, tt.r)
var handler http.Handler
if methodHandler, ok := handlers[mGET]; ok {
handler = methodHandler.handler
}
paramKeys := rctx.routeParams.Keys
paramValues := rctx.routeParams.Values
// Func values cannot be compared with ==, so the handlers are compared
// through their %v formatting instead.
if fmt.Sprintf("%v", tt.h) != fmt.Sprintf("%v", handler) {
t.Errorf("input [%d]: find '%s' expecting handler:%v , got:%v", i, tt.r, tt.h, handler)
}
if !stringSliceEqual(tt.k, paramKeys) {
t.Errorf("input [%d]: find '%s' expecting paramKeys:(%d)%v , got:(%d)%v", i, tt.r, len(tt.k), tt.k, len(paramKeys), paramKeys)
}
if !stringSliceEqual(tt.v, paramValues) {
t.Errorf("input [%d]: find '%s' expecting paramValues:(%d)%v , got:(%d)%v", i, tt.r, len(tt.v), tt.v, len(paramValues), paramValues)
}
}
}
// TestTreeMoar covers further routing cases: params followed by tail bytes
// other than '/' (':', '!', '.'), several params in one path segment,
// per-method handlers on the same path, and catch-all routes. Each table
// entry names the method to look up along with the expected handler and
// params.
func TestTreeMoar(t *testing.T) {
// Each handler is a separate function literal so that they can be told
// apart when the resolved handler is compared further down.
hStub := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
hStub1 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
hStub2 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
hStub3 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
hStub4 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
hStub5 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
hStub6 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
hStub7 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
hStub8 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
hStub9 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
hStub10 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
hStub11 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
hStub12 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
hStub13 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
hStub14 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
hStub15 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
hStub16 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
// TODO: panic if we see {id}{x} because we're missing a delimiter, it's not possible.
// also {:id}* is not possible.
tr := &node{}
tr.InsertRoute(mGET, "/articlefun", hStub5)
tr.InsertRoute(mGET, "/articles/{id}", hStub)
tr.InsertRoute(mDELETE, "/articles/{slug}", hStub8)
tr.InsertRoute(mGET, "/articles/search", hStub1)
tr.InsertRoute(mGET, "/articles/{id}:delete", hStub8)
tr.InsertRoute(mGET, "/articles/{iidd}!sup", hStub4)
tr.InsertRoute(mGET, "/articles/{id}:{op}", hStub3)
tr.InsertRoute(mGET, "/articles/{id}:{op}", hStub2) // this route sets a new handler for the above route
tr.InsertRoute(mGET, "/articles/{slug:^[a-z]+}/posts", hStub) // up to tail '/' will only match if contents match the rex
tr.InsertRoute(mGET, "/articles/{id}/posts/{pid}", hStub6) // /articles/123/posts/1
tr.InsertRoute(mGET, "/articles/{id}/posts/{month}/{day}/{year}/{slug}", hStub7) // /articles/123/posts/09/04/1984/juice
tr.InsertRoute(mGET, "/articles/{id}.json", hStub10)
tr.InsertRoute(mGET, "/articles/{id}/data.json", hStub11)
tr.InsertRoute(mGET, "/articles/files/{file}.{ext}", hStub12)
tr.InsertRoute(mPUT, "/articles/me", hStub13)
// TODO: make a separate test case for this one..
// tr.InsertRoute(mGET, "/articles/{id}/{id}", hStub1) // panic expected, we're duplicating param keys
tr.InsertRoute(mGET, "/pages/*", hStub)
tr.InsertRoute(mGET, "/pages/*", hStub9)
tr.InsertRoute(mGET, "/users/{id}", hStub14)
tr.InsertRoute(mGET, "/users/{id}/settings/{key}", hStub15)
tr.InsertRoute(mGET, "/users/{id}/settings/*", hStub16)
// h: expected handler, r: request path, k/v: expected param keys/values,
// m: method used for the lookup.
tests := []struct {
h http.Handler
r string
k []string
v []string
m methodTyp
}{
{m: mGET, r: "/articles/search", h: hStub1, k: []string{}, v: []string{}},
{m: mGET, r: "/articlefun", h: hStub5, k: []string{}, v: []string{}},
{m: mGET, r: "/articles/123", h: hStub, k: []string{"id"}, v: []string{"123"}},
{m: mDELETE, r: "/articles/123mm", h: hStub8, k: []string{"slug"}, v: []string{"123mm"}},
{m: mGET, r: "/articles/789:delete", h: hStub8, k: []string{"id"}, v: []string{"789"}},
{m: mGET, r: "/articles/789!sup", h: hStub4, k: []string{"iidd"}, v: []string{"789"}},
{m: mGET, r: "/articles/123:sync", h: hStub2, k: []string{"id", "op"}, v: []string{"123", "sync"}},
{m: mGET, r: "/articles/456/posts/1", h: hStub6, k: []string{"id", "pid"}, v: []string{"456", "1"}},
{m: mGET, r: "/articles/456/posts/09/04/1984/juice", h: hStub7, k: []string{"id", "month", "day", "year", "slug"}, v: []string{"456", "09", "04", "1984", "juice"}},
{m: mGET, r: "/articles/456.json", h: hStub10, k: []string{"id"}, v: []string{"456"}},
{m: mGET, r: "/articles/456/data.json", h: hStub11, k: []string{"id"}, v: []string{"456"}},
{m: mGET, r: "/articles/files/file.zip", h: hStub12, k: []string{"file", "ext"}, v: []string{"file", "zip"}},
{m: mGET, r: "/articles/files/photos.tar.gz", h: hStub12, k: []string{"file", "ext"}, v: []string{"photos", "tar.gz"}},
{m: mGET, r: "/articles/files/photos.tar.gz", h: hStub12, k: []string{"file", "ext"}, v: []string{"photos", "tar.gz"}},
{m: mPUT, r: "/articles/me", h: hStub13, k: []string{}, v: []string{}},
{m: mGET, r: "/articles/me", h: hStub, k: []string{"id"}, v: []string{"me"}},
{m: mGET, r: "/pages", h: nil, k: []string{}, v: []string{}},
{m: mGET, r: "/pages/", h: hStub9, k: []string{"*"}, v: []string{""}},
{m: mGET, r: "/pages/yes", h: hStub9, k: []string{"*"}, v: []string{"yes"}},
{m: mGET, r: "/users/1", h: hStub14, k: []string{"id"}, v: []string{"1"}},
{m: mGET, r: "/users/", h: nil, k: []string{}, v: []string{}},
{m: mGET, r: "/users/2/settings/password", h: hStub15, k: []string{"id", "key"}, v: []string{"2", "password"}},
{m: mGET, r: "/users/2/settings/", h: hStub16, k: []string{"id", "*"}, v: []string{"2", ""}},
}
// Uncomment to dump the tree while debugging:
// log.Println("~~~~~~~~~")
// log.Println("~~~~~~~~~")
// debugPrintTree(0, 0, tr, 0)
// log.Println("~~~~~~~~~")
// log.Println("~~~~~~~~~")
for i, tt := range tests {
// A fresh route context per case keeps the recorded params independent.
rctx := NewRouteContext()
_, handlers, _ := tr.FindRoute(rctx, tt.m, tt.r)
var handler http.Handler
if methodHandler, ok := handlers[tt.m]; ok {
handler = methodHandler.handler
}
paramKeys := rctx.routeParams.Keys
paramValues := rctx.routeParams.Values
// Func values cannot be compared with ==, so the handlers are compared
// through their %v formatting instead.
if fmt.Sprintf("%v", tt.h) != fmt.Sprintf("%v", handler) {
t.Errorf("input [%d]: find '%s' expecting handler:%v , got:%v", i, tt.r, tt.h, handler)
}
if !stringSliceEqual(tt.k, paramKeys) {
t.Errorf("input [%d]: find '%s' expecting paramKeys:(%d)%v , got:(%d)%v", i, tt.r, len(tt.k), tt.k, len(paramKeys), paramKeys)
}
if !stringSliceEqual(tt.v, paramValues) {
t.Errorf("input [%d]: find '%s' expecting paramValues:(%d)%v , got:(%d)%v", i, tt.r, len(tt.v), tt.v, len(paramValues), paramValues)
}
}
}
// TestTreeRegexp registers several regexp-constrained params (plus one plain
// param) under the same "/articles/" prefix and checks which handler and
// params FindRoute resolves for each request path in the table.
func TestTreeRegexp(t *testing.T) {
// Each handler is a separate function literal so that they can be told
// apart when the resolved handler is compared further down.
hStub1 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
hStub2 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
hStub3 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
hStub4 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
hStub5 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
hStub6 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
hStub7 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
tr := &node{}
tr.InsertRoute(mGET, "/articles/{rid:^[0-9]{5,6}}", hStub7)
tr.InsertRoute(mGET, "/articles/{zid:^0[0-9]+}", hStub3)
tr.InsertRoute(mGET, "/articles/{name:^@[a-z]+}/posts", hStub4)
tr.InsertRoute(mGET, "/articles/{op:^[0-9]+}/run", hStub5)
tr.InsertRoute(mGET, "/articles/{id:^[0-9]+}", hStub1)
tr.InsertRoute(mGET, "/articles/{id:^[1-9]+}-{aux}", hStub6)
tr.InsertRoute(mGET, "/articles/{slug}", hStub2)
// Uncomment to dump the tree while debugging:
// log.Println("~~~~~~~~~")
// log.Println("~~~~~~~~~")
// debugPrintTree(0, 0, tr, 0)
// log.Println("~~~~~~~~~")
// log.Println("~~~~~~~~~")
tests := []struct {
r string // input request path
h http.Handler // output matched handler
k []string // output param keys
v []string // output param values
}{
{r: "/articles", h: nil, k: []string{}, v: []string{}},
{r: "/articles/12345", h: hStub7, k: []string{"rid"}, v: []string{"12345"}},
{r: "/articles/123", h: hStub1, k: []string{"id"}, v: []string{"123"}},
{r: "/articles/how-to-build-a-router", h: hStub2, k: []string{"slug"}, v: []string{"how-to-build-a-router"}},
{r: "/articles/0456", h: hStub3, k: []string{"zid"}, v: []string{"0456"}},
{r: "/articles/@pk/posts", h: hStub4, k: []string{"name"}, v: []string{"@pk"}},
{r: "/articles/1/run", h: hStub5, k: []string{"op"}, v: []string{"1"}},
{r: "/articles/1122", h: hStub1, k: []string{"id"}, v: []string{"1122"}},
{r: "/articles/1122-yes", h: hStub6, k: []string{"id", "aux"}, v: []string{"1122", "yes"}},
}
for i, tt := range tests {
// A fresh route context per case keeps the recorded params independent.
rctx := NewRouteContext()
_, handlers, _ := tr.FindRoute(rctx, mGET, tt.r)
var handler http.Handler
if methodHandler, ok := handlers[mGET]; ok {
handler = methodHandler.handler
}
paramKeys := rctx.routeParams.Keys
paramValues := rctx.routeParams.Values
// Func values cannot be compared with ==, so the handlers are compared
// through their %v formatting instead.
if fmt.Sprintf("%v", tt.h) != fmt.Sprintf("%v", handler) {
t.Errorf("input [%d]: find '%s' expecting handler:%v , got:%v", i, tt.r, tt.h, handler)
}
if !stringSliceEqual(tt.k, paramKeys) {
t.Errorf("input [%d]: find '%s' expecting paramKeys:(%d)%v , got:(%d)%v", i, tt.r, len(tt.k), tt.k, len(paramKeys), paramKeys)
}
if !stringSliceEqual(tt.v, paramValues) {
t.Errorf("input [%d]: find '%s' expecting paramValues:(%d)%v , got:(%d)%v", i, tt.r, len(tt.v), tt.v, len(paramValues), paramValues)
}
}
}
// TestTreeRegexpRecursive registers two routes that each chain two regexp
// params and differ in their character classes and final static segment, then
// checks the handler and params FindRoute resolves for each request path.
func TestTreeRegexpRecursive(t *testing.T) {
// Separate function literals so the two handlers can be told apart.
hStub1 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
hStub2 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
tr := &node{}
tr.InsertRoute(mGET, "/one/{firstId:[a-z0-9-]+}/{secondId:[a-z0-9-]+}/first", hStub1)
tr.InsertRoute(mGET, "/one/{firstId:[a-z0-9-_]+}/{secondId:[a-z0-9-_]+}/second", hStub2)
// Uncomment to dump the tree while debugging:
// log.Println("~~~~~~~~~")
// log.Println("~~~~~~~~~")
// debugPrintTree(0, 0, tr, 0)
// log.Println("~~~~~~~~~")
// log.Println("~~~~~~~~~")
tests := []struct {
r string // input request path
h http.Handler // output matched handler
k []string // output param keys
v []string // output param values
}{
{r: "/one/hello/world/first", h: hStub1, k: []string{"firstId", "secondId"}, v: []string{"hello", "world"}},
{r: "/one/hi_there/ok/second", h: hStub2, k: []string{"firstId", "secondId"}, v: []string{"hi_there", "ok"}},
{r: "/one///first", h: nil, k: []string{}, v: []string{}},
{r: "/one/hi/123/second", h: hStub2, k: []string{"firstId", "secondId"}, v: []string{"hi", "123"}},
}
for i, tt := range tests {
// A fresh route context per case keeps the recorded params independent.
rctx := NewRouteContext()
_, handlers, _ := tr.FindRoute(rctx, mGET, tt.r)
var handler http.Handler
if methodHandler, ok := handlers[mGET]; ok {
handler = methodHandler.handler
}
paramKeys := rctx.routeParams.Keys
paramValues := rctx.routeParams.Values
// Func values cannot be compared with ==, so the handlers are compared
// through their %v formatting instead.
if fmt.Sprintf("%v", tt.h) != fmt.Sprintf("%v", handler) {
t.Errorf("input [%d]: find '%s' expecting handler:%v , got:%v", i, tt.r, tt.h, handler)
}
if !stringSliceEqual(tt.k, paramKeys) {
t.Errorf("input [%d]: find '%s' expecting paramKeys:(%d)%v , got:(%d)%v", i, tt.r, len(tt.k), tt.k, len(paramKeys), paramKeys)
}
if !stringSliceEqual(tt.v, paramValues) {
t.Errorf("input [%d]: find '%s' expecting paramValues:(%d)%v , got:(%d)%v", i, tt.r, len(tt.v), tt.v, len(paramValues), paramValues)
}
}
}
// TestTreeRegexMatchWholeParam checks that a regexp param only matches when
// the expression covers the entire param value: a partial match such as
// "/a13" or "/13.jpg" against "[0-9]+" must not resolve to a handler.
func TestTreeRegexMatchWholeParam(t *testing.T) {
	hStub1 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})

	// A single route context is shared by all lookups below.
	rctx := NewRouteContext()

	tr := &node{}
	for _, pattern := range []string{
		"/{id:[0-9]+}",
		"/{x:.+}/foo",
		"/{param:[0-9]*}/test",
	} {
		tr.InsertRoute(mGET, pattern, hStub1)
	}

	type lookup struct {
		url             string
		expectedHandler http.Handler
	}
	tests := []lookup{
		{url: "/13", expectedHandler: hStub1},
		{url: "/a13", expectedHandler: nil},
		{url: "/13.jpg", expectedHandler: nil},
		{url: "/a13.jpg", expectedHandler: nil},
		{url: "/a/foo", expectedHandler: hStub1},
		{url: "//foo", expectedHandler: nil},
		{url: "//test", expectedHandler: hStub1},
	}

	for _, tc := range tests {
		_, _, handler := tr.FindRoute(rctx, mGET, tc.url)

		// Func values cannot be compared with ==; compare their formatting.
		want := fmt.Sprintf("%v", tc.expectedHandler)
		got := fmt.Sprintf("%v", handler)
		if want != got {
			t.Errorf("url %v: expecting handler:%v , got:%v", tc.url, tc.expectedHandler, handler)
		}
	}
}
// TestTreeFindPattern checks findPattern against a tree holding three
// catch-all routes: patterns that were never registered must not be found,
// registered ones must be, and the name of a param key in the queried pattern
// must not matter.
func TestTreeFindPattern(t *testing.T) {
	hStub1 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
	hStub2 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
	hStub3 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})

	tr := &node{}
	tr.InsertRoute(mGET, "/pages/*", hStub1)
	tr.InsertRoute(mGET, "/articles/{id}/*", hStub2)
	tr.InsertRoute(mGET, "/articles/{slug}/{uid}/*", hStub3)

	checks := []struct {
		pattern string
		found   bool
		failure string
	}{
		{"/pages", false, "find /pages failed"},
		{"/pages*", false, "find /pages* failed - should be nil"},
		{"/pages/*", true, "find /pages/* failed"},
		{"/articles/{id}/*", true, "find /articles/{id}/* failed"},
		{"/articles/{something}/*", true, "find /articles/{something}/* failed"},
		{"/articles/{slug}/{uid}/*", true, "find /articles/{slug}/{uid}/* failed"},
	}
	for _, c := range checks {
		if tr.findPattern(c.pattern) != c.found {
			t.Error(c.failure)
		}
	}
}
// debugPrintTree logs one line for node n (its number i, the number of its
// parent, type, prefix, label, tail, edge count, leaf flag and, when present,
// its endpoints) and then recurses into every child. Children are numbered by
// incrementing i locally, so numbers are not unique across sibling subtrees.
// With the current body the function always returns false.
func debugPrintTree(parent int, i int, n *node, label byte) bool {
	edgeCount := 0
	for _, group := range n.children {
		edgeCount += len(group)
	}

	if n.endpoints == nil {
		log.Printf("[node %d parent:%d] typ:%d prefix:%s label:%s tail:%s numEdges:%d isLeaf:%v\n", i, parent, n.typ, n.prefix, string(label), string(n.tail), edgeCount, n.isLeaf())
	} else {
		log.Printf("[node %d parent:%d] typ:%d prefix:%s label:%s tail:%s numEdges:%d isLeaf:%v handler:%v\n", i, parent, n.typ, n.prefix, string(label), string(n.tail), edgeCount, n.isLeaf(), n.endpoints)
	}

	// This node becomes the parent of everything printed below.
	self := i
	for _, group := range n.children {
		for _, child := range group {
			i++
			if debugPrintTree(self, i, child, child.label) {
				return true
			}
		}
	}
	return false
}
// stringSliceEqual reports whether a and b have the same length and hold
// equal strings at every index. A nil slice and an empty slice are equal.
func stringSliceEqual(a, b []string) bool {
	mismatch := len(a) != len(b)
	for i := 0; !mismatch && i < len(a); i++ {
		mismatch = a[i] != b[i]
	}
	return !mismatch
}
// BenchmarkTreeGet measures FindRoute for a GET of "/ping/123/456" against a
// small tree of static and param routes, reusing one route context that is
// reset before every lookup.
func BenchmarkTreeGet(b *testing.B) {
	h1 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
	h2 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})

	// Routes are inserted in exactly this order.
	routes := []struct {
		pattern string
		handler http.Handler
	}{
		{"/", h1},
		{"/ping", h2},
		{"/pingall", h2},
		{"/ping/{id}", h2},
		{"/ping/{id}/woop", h2},
		{"/ping/{id}/{opt}", h2},
		{"/pinggggg", h2},
		{"/hello", h1},
	}

	tr := &node{}
	for _, rt := range routes {
		tr.InsertRoute(mGET, rt.pattern, rt.handler)
	}

	rctx := NewRouteContext()

	b.ReportAllocs()
	b.ResetTimer()
	for iter := 0; iter < b.N; iter++ {
		rctx.Reset()
		tr.FindRoute(rctx, mGET, "/ping/123/456")
	}
}
// TestWalker walks the router returned by bigMux, logging every method and
// route it is handed, and fails if Walk reports an error.
func TestWalker(t *testing.T) {
	logRoute := func(method string, route string, handler http.Handler, middlewares ...func(http.Handler) http.Handler) error {
		t.Logf("%v %v", method, route)
		return nil
	}

	// Walk the bigMux router tree.
	router := bigMux()
	err := Walk(router, logRoute)
	if err != nil {
		t.Error(err)
	}
}