pax_global_header 0000666 0000000 0000000 00000000064 14033353662 0014517 g ustar 00root root 0000000 0000000 52 comment=3604edcb857415cb2c1213d63328cdcd738f2328 websocket-1.8.7/ 0000775 0000000 0000000 00000000000 14033353662 0013522 5 ustar 00root root 0000000 0000000 websocket-1.8.7/.github/ 0000775 0000000 0000000 00000000000 14033353662 0015062 5 ustar 00root root 0000000 0000000 websocket-1.8.7/.github/CODEOWNERS 0000664 0000000 0000000 00000000012 14033353662 0016446 0 ustar 00root root 0000000 0000000 * @nhooyr websocket-1.8.7/.github/FUNDING.yml 0000664 0000000 0000000 00000000017 14033353662 0016675 0 ustar 00root root 0000000 0000000 github: nhooyr websocket-1.8.7/.github/workflows/ 0000775 0000000 0000000 00000000000 14033353662 0017117 5 ustar 00root root 0000000 0000000 websocket-1.8.7/.github/workflows/ci.yaml 0000664 0000000 0000000 00000001572 14033353662 0020403 0 ustar 00root root 0000000 0000000 name: ci on: [push, pull_request] jobs: fmt: runs-on: ubuntu-latest steps: - uses: actions/checkout@v1 - name: Run ./ci/fmt.sh uses: ./ci/container with: args: ./ci/fmt.sh lint: runs-on: ubuntu-latest steps: - uses: actions/checkout@v1 - name: Run ./ci/lint.sh uses: ./ci/container with: args: ./ci/lint.sh test: runs-on: ubuntu-latest steps: - uses: actions/checkout@v1 - name: Run ./ci/test.sh uses: ./ci/container with: args: ./ci/test.sh env: NETLIFY_AUTH_TOKEN: ${{ secrets.NETLIFY_AUTH_TOKEN }} NETLIFY_SITE_ID: 9b3ee4dc-8297-4774-b4b9-a61561fbbce7 - name: Upload coverage.html uses: actions/upload-artifact@v2 with: name: coverage.html path: ./ci/out/coverage.html websocket-1.8.7/.gitignore 0000664 0000000 0000000 00000000017 14033353662 0015510 0 ustar 00root root 0000000 0000000 websocket.test websocket-1.8.7/LICENSE.txt 0000664 0000000 0000000 00000002054 14033353662 0015346 0 ustar 00root root 0000000 0000000 MIT License Copyright (c) 2018 Anmol Sethi Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. websocket-1.8.7/README.md 0000664 0000000 0000000 00000012216 14033353662 0015003 0 ustar 00root root 0000000 0000000 # websocket [](https://pkg.go.dev/nhooyr.io/websocket) [](https://nhooyrio-websocket-coverage.netlify.app) websocket is a minimal and idiomatic WebSocket library for Go. ## Install ```bash go get nhooyr.io/websocket ``` ## Highlights - Minimal and idiomatic API - First class [context.Context](https://blog.golang.org/context) support - Fully passes the WebSocket [autobahn-testsuite](https://github.com/crossbario/autobahn-testsuite) - [Single dependency](https://pkg.go.dev/nhooyr.io/websocket?tab=imports) - JSON and protobuf helpers in the [wsjson](https://pkg.go.dev/nhooyr.io/websocket/wsjson) and [wspb](https://pkg.go.dev/nhooyr.io/websocket/wspb) subpackages - Zero alloc reads and writes - Concurrent writes - [Close handshake](https://pkg.go.dev/nhooyr.io/websocket#Conn.Close) - [net.Conn](https://pkg.go.dev/nhooyr.io/websocket#NetConn) wrapper - [Ping pong](https://pkg.go.dev/nhooyr.io/websocket#Conn.Ping) API - [RFC 7692](https://tools.ietf.org/html/rfc7692) permessage-deflate compression - Compile to [Wasm](https://pkg.go.dev/nhooyr.io/websocket#hdr-Wasm) ## Roadmap - [ ] HTTP/2 [#4](https://github.com/nhooyr/websocket/issues/4) ## Examples For a production quality example that demonstrates the complete API, see the [echo example](./examples/echo). For a full stack example, see the [chat example](./examples/chat). ### Server ```go http.HandlerFunc(func (w http.ResponseWriter, r *http.Request) { c, err := websocket.Accept(w, r, nil) if err != nil { // ... } defer c.Close(websocket.StatusInternalError, "the sky is falling") ctx, cancel := context.WithTimeout(r.Context(), time.Second*10) defer cancel() var v interface{} err = wsjson.Read(ctx, c, &v) if err != nil { // ... } log.Printf("received: %v", v) c.Close(websocket.StatusNormalClosure, "") }) ``` ### Client ```go ctx, cancel := context.WithTimeout(context.Background(), time.Minute) defer cancel() c, _, err := websocket.Dial(ctx, "ws://localhost:8080", nil) if err != nil { // ... } defer c.Close(websocket.StatusInternalError, "the sky is falling") err = wsjson.Write(ctx, c, "hi") if err != nil { // ... } c.Close(websocket.StatusNormalClosure, "") ``` ## Comparison ### gorilla/websocket Advantages of [gorilla/websocket](https://github.com/gorilla/websocket): - Mature and widely used - [Prepared writes](https://pkg.go.dev/github.com/gorilla/websocket#PreparedMessage) - Configurable [buffer sizes](https://pkg.go.dev/github.com/gorilla/websocket#hdr-Buffers) Advantages of nhooyr.io/websocket: - Minimal and idiomatic API - Compare godoc of [nhooyr.io/websocket](https://pkg.go.dev/nhooyr.io/websocket) with [gorilla/websocket](https://pkg.go.dev/github.com/gorilla/websocket) side by side. - [net.Conn](https://pkg.go.dev/nhooyr.io/websocket#NetConn) wrapper - Zero alloc reads and writes ([gorilla/websocket#535](https://github.com/gorilla/websocket/issues/535)) - Full [context.Context](https://blog.golang.org/context) support - Dial uses [net/http.Client](https://golang.org/pkg/net/http/#Client) - Will enable easy HTTP/2 support in the future - Gorilla writes directly to a net.Conn and so duplicates features of net/http.Client. - Concurrent writes - Close handshake ([gorilla/websocket#448](https://github.com/gorilla/websocket/issues/448)) - Idiomatic [ping pong](https://pkg.go.dev/nhooyr.io/websocket#Conn.Ping) API - Gorilla requires registering a pong callback before sending a Ping - Can target Wasm ([gorilla/websocket#432](https://github.com/gorilla/websocket/issues/432)) - Transparent message buffer reuse with [wsjson](https://pkg.go.dev/nhooyr.io/websocket/wsjson) and [wspb](https://pkg.go.dev/nhooyr.io/websocket/wspb) subpackages - [1.75x](https://github.com/nhooyr/websocket/releases/tag/v1.7.4) faster WebSocket masking implementation in pure Go - Gorilla's implementation is slower and uses [unsafe](https://golang.org/pkg/unsafe/). - Full [permessage-deflate](https://tools.ietf.org/html/rfc7692) compression extension support - Gorilla only supports no context takeover mode - We use [klauspost/compress](https://github.com/klauspost/compress) for much lower memory usage ([gorilla/websocket#203](https://github.com/gorilla/websocket/issues/203)) - [CloseRead](https://pkg.go.dev/nhooyr.io/websocket#Conn.CloseRead) helper ([gorilla/websocket#492](https://github.com/gorilla/websocket/issues/492)) - Actively maintained ([gorilla/websocket#370](https://github.com/gorilla/websocket/issues/370)) #### golang.org/x/net/websocket [golang.org/x/net/websocket](https://pkg.go.dev/golang.org/x/net/websocket) is deprecated. See [golang/go/issues/18152](https://github.com/golang/go/issues/18152). The [net.Conn](https://pkg.go.dev/nhooyr.io/websocket#NetConn) can help in transitioning to nhooyr.io/websocket. #### gobwas/ws [gobwas/ws](https://github.com/gobwas/ws) has an extremely flexible API that allows it to be used in an event driven style for performance. See the author's [blog post](https://medium.freecodecamp.org/million-websockets-and-go-cc58418460bb). However when writing idiomatic Go, nhooyr.io/websocket will be faster and easier to use. websocket-1.8.7/accept.go 0000664 0000000 0000000 00000025262 14033353662 0015317 0 ustar 00root root 0000000 0000000 // +build !js package websocket import ( "bytes" "crypto/sha1" "encoding/base64" "errors" "fmt" "io" "log" "net/http" "net/textproto" "net/url" "path/filepath" "strings" "nhooyr.io/websocket/internal/errd" ) // AcceptOptions represents Accept's options. type AcceptOptions struct { // Subprotocols lists the WebSocket subprotocols that Accept will negotiate with the client. // The empty subprotocol will always be negotiated as per RFC 6455. If you would like to // reject it, close the connection when c.Subprotocol() == "". Subprotocols []string // InsecureSkipVerify is used to disable Accept's origin verification behaviour. // // You probably want to use OriginPatterns instead. InsecureSkipVerify bool // OriginPatterns lists the host patterns for authorized origins. // The request host is always authorized. // Use this to enable cross origin WebSockets. // // i.e javascript running on example.com wants to access a WebSocket server at chat.example.com. // In such a case, example.com is the origin and chat.example.com is the request host. // One would set this field to []string{"example.com"} to authorize example.com to connect. // // Each pattern is matched case insensitively against the request origin host // with filepath.Match. // See https://golang.org/pkg/path/filepath/#Match // // Please ensure you understand the ramifications of enabling this. // If used incorrectly your WebSocket server will be open to CSRF attacks. // // Do not use * as a pattern to allow any origin, prefer to use InsecureSkipVerify instead // to bring attention to the danger of such a setting. OriginPatterns []string // CompressionMode controls the compression mode. // Defaults to CompressionNoContextTakeover. // // See docs on CompressionMode for details. CompressionMode CompressionMode // CompressionThreshold controls the minimum size of a message before compression is applied. // // Defaults to 512 bytes for CompressionNoContextTakeover and 128 bytes // for CompressionContextTakeover. CompressionThreshold int } // Accept accepts a WebSocket handshake from a client and upgrades the // the connection to a WebSocket. // // Accept will not allow cross origin requests by default. // See the InsecureSkipVerify and OriginPatterns options to allow cross origin requests. // // Accept will write a response to w on all errors. func Accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn, error) { return accept(w, r, opts) } func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Conn, err error) { defer errd.Wrap(&err, "failed to accept WebSocket connection") if opts == nil { opts = &AcceptOptions{} } opts = &*opts errCode, err := verifyClientRequest(w, r) if err != nil { http.Error(w, err.Error(), errCode) return nil, err } if !opts.InsecureSkipVerify { err = authenticateOrigin(r, opts.OriginPatterns) if err != nil { if errors.Is(err, filepath.ErrBadPattern) { log.Printf("websocket: %v", err) err = errors.New(http.StatusText(http.StatusForbidden)) } http.Error(w, err.Error(), http.StatusForbidden) return nil, err } } hj, ok := w.(http.Hijacker) if !ok { err = errors.New("http.ResponseWriter does not implement http.Hijacker") http.Error(w, http.StatusText(http.StatusNotImplemented), http.StatusNotImplemented) return nil, err } w.Header().Set("Upgrade", "websocket") w.Header().Set("Connection", "Upgrade") key := r.Header.Get("Sec-WebSocket-Key") w.Header().Set("Sec-WebSocket-Accept", secWebSocketAccept(key)) subproto := selectSubprotocol(r, opts.Subprotocols) if subproto != "" { w.Header().Set("Sec-WebSocket-Protocol", subproto) } copts, err := acceptCompression(r, w, opts.CompressionMode) if err != nil { return nil, err } w.WriteHeader(http.StatusSwitchingProtocols) // See https://github.com/nhooyr/websocket/issues/166 if ginWriter, ok := w.(interface { WriteHeaderNow() }); ok { ginWriter.WriteHeaderNow() } netConn, brw, err := hj.Hijack() if err != nil { err = fmt.Errorf("failed to hijack connection: %w", err) http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) return nil, err } // https://github.com/golang/go/issues/32314 b, _ := brw.Reader.Peek(brw.Reader.Buffered()) brw.Reader.Reset(io.MultiReader(bytes.NewReader(b), netConn)) return newConn(connConfig{ subprotocol: w.Header().Get("Sec-WebSocket-Protocol"), rwc: netConn, client: false, copts: copts, flateThreshold: opts.CompressionThreshold, br: brw.Reader, bw: brw.Writer, }), nil } func verifyClientRequest(w http.ResponseWriter, r *http.Request) (errCode int, _ error) { if !r.ProtoAtLeast(1, 1) { return http.StatusUpgradeRequired, fmt.Errorf("WebSocket protocol violation: handshake request must be at least HTTP/1.1: %q", r.Proto) } if !headerContainsTokenIgnoreCase(r.Header, "Connection", "Upgrade") { w.Header().Set("Connection", "Upgrade") w.Header().Set("Upgrade", "websocket") return http.StatusUpgradeRequired, fmt.Errorf("WebSocket protocol violation: Connection header %q does not contain Upgrade", r.Header.Get("Connection")) } if !headerContainsTokenIgnoreCase(r.Header, "Upgrade", "websocket") { w.Header().Set("Connection", "Upgrade") w.Header().Set("Upgrade", "websocket") return http.StatusUpgradeRequired, fmt.Errorf("WebSocket protocol violation: Upgrade header %q does not contain websocket", r.Header.Get("Upgrade")) } if r.Method != "GET" { return http.StatusMethodNotAllowed, fmt.Errorf("WebSocket protocol violation: handshake request method is not GET but %q", r.Method) } if r.Header.Get("Sec-WebSocket-Version") != "13" { w.Header().Set("Sec-WebSocket-Version", "13") return http.StatusBadRequest, fmt.Errorf("unsupported WebSocket protocol version (only 13 is supported): %q", r.Header.Get("Sec-WebSocket-Version")) } if r.Header.Get("Sec-WebSocket-Key") == "" { return http.StatusBadRequest, errors.New("WebSocket protocol violation: missing Sec-WebSocket-Key") } return 0, nil } func authenticateOrigin(r *http.Request, originHosts []string) error { origin := r.Header.Get("Origin") if origin == "" { return nil } u, err := url.Parse(origin) if err != nil { return fmt.Errorf("failed to parse Origin header %q: %w", origin, err) } if strings.EqualFold(r.Host, u.Host) { return nil } for _, hostPattern := range originHosts { matched, err := match(hostPattern, u.Host) if err != nil { return fmt.Errorf("failed to parse filepath pattern %q: %w", hostPattern, err) } if matched { return nil } } return fmt.Errorf("request Origin %q is not authorized for Host %q", origin, r.Host) } func match(pattern, s string) (bool, error) { return filepath.Match(strings.ToLower(pattern), strings.ToLower(s)) } func selectSubprotocol(r *http.Request, subprotocols []string) string { cps := headerTokens(r.Header, "Sec-WebSocket-Protocol") for _, sp := range subprotocols { for _, cp := range cps { if strings.EqualFold(sp, cp) { return cp } } } return "" } func acceptCompression(r *http.Request, w http.ResponseWriter, mode CompressionMode) (*compressionOptions, error) { if mode == CompressionDisabled { return nil, nil } for _, ext := range websocketExtensions(r.Header) { switch ext.name { case "permessage-deflate": return acceptDeflate(w, ext, mode) // Disabled for now, see https://github.com/nhooyr/websocket/issues/218 // case "x-webkit-deflate-frame": // return acceptWebkitDeflate(w, ext, mode) } } return nil, nil } func acceptDeflate(w http.ResponseWriter, ext websocketExtension, mode CompressionMode) (*compressionOptions, error) { copts := mode.opts() for _, p := range ext.params { switch p { case "client_no_context_takeover": copts.clientNoContextTakeover = true continue case "server_no_context_takeover": copts.serverNoContextTakeover = true continue } if strings.HasPrefix(p, "client_max_window_bits") { // We cannot adjust the read sliding window so cannot make use of this. continue } err := fmt.Errorf("unsupported permessage-deflate parameter: %q", p) http.Error(w, err.Error(), http.StatusBadRequest) return nil, err } copts.setHeader(w.Header()) return copts, nil } func acceptWebkitDeflate(w http.ResponseWriter, ext websocketExtension, mode CompressionMode) (*compressionOptions, error) { copts := mode.opts() // The peer must explicitly request it. copts.serverNoContextTakeover = false for _, p := range ext.params { if p == "no_context_takeover" { copts.serverNoContextTakeover = true continue } // We explicitly fail on x-webkit-deflate-frame's max_window_bits parameter instead // of ignoring it as the draft spec is unclear. It says the server can ignore it // but the server has no way of signalling to the client it was ignored as the parameters // are set one way. // Thus us ignoring it would make the client think we understood it which would cause issues. // See https://tools.ietf.org/html/draft-tyoshino-hybi-websocket-perframe-deflate-06#section-4.1 // // Either way, we're only implementing this for webkit which never sends the max_window_bits // parameter so we don't need to worry about it. err := fmt.Errorf("unsupported x-webkit-deflate-frame parameter: %q", p) http.Error(w, err.Error(), http.StatusBadRequest) return nil, err } s := "x-webkit-deflate-frame" if copts.clientNoContextTakeover { s += "; no_context_takeover" } w.Header().Set("Sec-WebSocket-Extensions", s) return copts, nil } func headerContainsTokenIgnoreCase(h http.Header, key, token string) bool { for _, t := range headerTokens(h, key) { if strings.EqualFold(t, token) { return true } } return false } type websocketExtension struct { name string params []string } func websocketExtensions(h http.Header) []websocketExtension { var exts []websocketExtension extStrs := headerTokens(h, "Sec-WebSocket-Extensions") for _, extStr := range extStrs { if extStr == "" { continue } vals := strings.Split(extStr, ";") for i := range vals { vals[i] = strings.TrimSpace(vals[i]) } e := websocketExtension{ name: vals[0], params: vals[1:], } exts = append(exts, e) } return exts } func headerTokens(h http.Header, key string) []string { key = textproto.CanonicalMIMEHeaderKey(key) var tokens []string for _, v := range h[key] { v = strings.TrimSpace(v) for _, t := range strings.Split(v, ",") { t = strings.TrimSpace(t) tokens = append(tokens, t) } } return tokens } var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11") func secWebSocketAccept(secWebSocketKey string) string { h := sha1.New() h.Write([]byte(secWebSocketKey)) h.Write(keyGUID) return base64.StdEncoding.EncodeToString(h.Sum(nil)) } websocket-1.8.7/accept_js.go 0000664 0000000 0000000 00000000703 14033353662 0016004 0 ustar 00root root 0000000 0000000 package websocket import ( "errors" "net/http" ) // AcceptOptions represents Accept's options. type AcceptOptions struct { Subprotocols []string InsecureSkipVerify bool OriginPatterns []string CompressionMode CompressionMode CompressionThreshold int } // Accept is stubbed out for Wasm. func Accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn, error) { return nil, errors.New("unimplemented") } websocket-1.8.7/accept_test.go 0000664 0000000 0000000 00000023117 14033353662 0016353 0 ustar 00root root 0000000 0000000 // +build !js package websocket import ( "bufio" "errors" "net" "net/http" "net/http/httptest" "strings" "testing" "nhooyr.io/websocket/internal/test/assert" ) func TestAccept(t *testing.T) { t.Parallel() t.Run("badClientHandshake", func(t *testing.T) { t.Parallel() w := httptest.NewRecorder() r := httptest.NewRequest("GET", "/", nil) _, err := Accept(w, r, nil) assert.Contains(t, err, "protocol violation") }) t.Run("badOrigin", func(t *testing.T) { t.Parallel() w := httptest.NewRecorder() r := httptest.NewRequest("GET", "/", nil) r.Header.Set("Connection", "Upgrade") r.Header.Set("Upgrade", "websocket") r.Header.Set("Sec-WebSocket-Version", "13") r.Header.Set("Sec-WebSocket-Key", "meow123") r.Header.Set("Origin", "harhar.com") _, err := Accept(w, r, nil) assert.Contains(t, err, `request Origin "harhar.com" is not authorized for Host`) }) t.Run("badCompression", func(t *testing.T) { t.Parallel() w := mockHijacker{ ResponseWriter: httptest.NewRecorder(), } r := httptest.NewRequest("GET", "/", nil) r.Header.Set("Connection", "Upgrade") r.Header.Set("Upgrade", "websocket") r.Header.Set("Sec-WebSocket-Version", "13") r.Header.Set("Sec-WebSocket-Key", "meow123") r.Header.Set("Sec-WebSocket-Extensions", "permessage-deflate; harharhar") _, err := Accept(w, r, nil) assert.Contains(t, err, `unsupported permessage-deflate parameter`) }) t.Run("requireHttpHijacker", func(t *testing.T) { t.Parallel() w := httptest.NewRecorder() r := httptest.NewRequest("GET", "/", nil) r.Header.Set("Connection", "Upgrade") r.Header.Set("Upgrade", "websocket") r.Header.Set("Sec-WebSocket-Version", "13") r.Header.Set("Sec-WebSocket-Key", "meow123") _, err := Accept(w, r, nil) assert.Contains(t, err, `http.ResponseWriter does not implement http.Hijacker`) }) t.Run("badHijack", func(t *testing.T) { t.Parallel() w := mockHijacker{ ResponseWriter: httptest.NewRecorder(), hijack: func() (conn net.Conn, writer *bufio.ReadWriter, err error) { return nil, nil, errors.New("haha") }, } r := httptest.NewRequest("GET", "/", nil) r.Header.Set("Connection", "Upgrade") r.Header.Set("Upgrade", "websocket") r.Header.Set("Sec-WebSocket-Version", "13") r.Header.Set("Sec-WebSocket-Key", "meow123") _, err := Accept(w, r, nil) assert.Contains(t, err, `failed to hijack connection`) }) } func Test_verifyClientHandshake(t *testing.T) { t.Parallel() testCases := []struct { name string method string http1 bool h map[string]string success bool }{ { name: "badConnection", h: map[string]string{ "Connection": "notUpgrade", }, }, { name: "badUpgrade", h: map[string]string{ "Connection": "Upgrade", "Upgrade": "notWebSocket", }, }, { name: "badMethod", method: "POST", h: map[string]string{ "Connection": "Upgrade", "Upgrade": "websocket", }, }, { name: "badWebSocketVersion", h: map[string]string{ "Connection": "Upgrade", "Upgrade": "websocket", "Sec-WebSocket-Version": "14", }, }, { name: "badWebSocketKey", h: map[string]string{ "Connection": "Upgrade", "Upgrade": "websocket", "Sec-WebSocket-Version": "13", "Sec-WebSocket-Key": "", }, }, { name: "badHTTPVersion", h: map[string]string{ "Connection": "Upgrade", "Upgrade": "websocket", "Sec-WebSocket-Version": "13", "Sec-WebSocket-Key": "meow123", }, http1: true, }, { name: "success", h: map[string]string{ "Connection": "keep-alive, Upgrade", "Upgrade": "websocket", "Sec-WebSocket-Version": "13", "Sec-WebSocket-Key": "meow123", }, success: true, }, } for _, tc := range testCases { tc := tc t.Run(tc.name, func(t *testing.T) { t.Parallel() r := httptest.NewRequest(tc.method, "/", nil) r.ProtoMajor = 1 r.ProtoMinor = 1 if tc.http1 { r.ProtoMinor = 0 } for k, v := range tc.h { r.Header.Set(k, v) } _, err := verifyClientRequest(httptest.NewRecorder(), r) if tc.success { assert.Success(t, err) } else { assert.Error(t, err) } }) } } func Test_selectSubprotocol(t *testing.T) { t.Parallel() testCases := []struct { name string clientProtocols []string serverProtocols []string negotiated string }{ { name: "empty", clientProtocols: nil, serverProtocols: nil, negotiated: "", }, { name: "basic", clientProtocols: []string{"echo", "echo2"}, serverProtocols: []string{"echo2", "echo"}, negotiated: "echo2", }, { name: "none", clientProtocols: []string{"echo", "echo3"}, serverProtocols: []string{"echo2", "echo4"}, negotiated: "", }, { name: "fallback", clientProtocols: []string{"echo", "echo3"}, serverProtocols: []string{"echo2", "echo3"}, negotiated: "echo3", }, { name: "clientCasePresered", clientProtocols: []string{"Echo1"}, serverProtocols: []string{"echo1"}, negotiated: "Echo1", }, } for _, tc := range testCases { tc := tc t.Run(tc.name, func(t *testing.T) { t.Parallel() r := httptest.NewRequest("GET", "/", nil) r.Header.Set("Sec-WebSocket-Protocol", strings.Join(tc.clientProtocols, ",")) negotiated := selectSubprotocol(r, tc.serverProtocols) assert.Equal(t, "negotiated", tc.negotiated, negotiated) }) } } func Test_authenticateOrigin(t *testing.T) { t.Parallel() testCases := []struct { name string origin string host string originPatterns []string success bool }{ { name: "none", success: true, host: "example.com", }, { name: "invalid", origin: "$#)(*)$#@*$(#@*$)#@*%)#(@*%)#(@%#@$#@$#$#@$#@}{}{}", host: "example.com", success: false, }, { name: "unauthorized", origin: "https://example.com", host: "example1.com", success: false, }, { name: "authorized", origin: "https://example.com", host: "example.com", success: true, }, { name: "authorizedCaseInsensitive", origin: "https://examplE.com", host: "example.com", success: true, }, { name: "originPatterns", origin: "https://two.examplE.com", host: "example.com", originPatterns: []string{ "*.example.com", "bar.com", }, success: true, }, { name: "originPatternsUnauthorized", origin: "https://two.examplE.com", host: "example.com", originPatterns: []string{ "exam3.com", "bar.com", }, success: false, }, } for _, tc := range testCases { tc := tc t.Run(tc.name, func(t *testing.T) { t.Parallel() r := httptest.NewRequest("GET", "http://"+tc.host+"/", nil) r.Header.Set("Origin", tc.origin) err := authenticateOrigin(r, tc.originPatterns) if tc.success { assert.Success(t, err) } else { assert.Error(t, err) } }) } } func Test_acceptCompression(t *testing.T) { t.Parallel() testCases := []struct { name string mode CompressionMode reqSecWebSocketExtensions string respSecWebSocketExtensions string expCopts *compressionOptions error bool }{ { name: "disabled", mode: CompressionDisabled, expCopts: nil, }, { name: "noClientSupport", mode: CompressionNoContextTakeover, expCopts: nil, }, { name: "permessage-deflate", mode: CompressionNoContextTakeover, reqSecWebSocketExtensions: "permessage-deflate; client_max_window_bits", respSecWebSocketExtensions: "permessage-deflate; client_no_context_takeover; server_no_context_takeover", expCopts: &compressionOptions{ clientNoContextTakeover: true, serverNoContextTakeover: true, }, }, { name: "permessage-deflate/error", mode: CompressionNoContextTakeover, reqSecWebSocketExtensions: "permessage-deflate; meow", error: true, }, // { // name: "x-webkit-deflate-frame", // mode: CompressionNoContextTakeover, // reqSecWebSocketExtensions: "x-webkit-deflate-frame; no_context_takeover", // respSecWebSocketExtensions: "x-webkit-deflate-frame; no_context_takeover", // expCopts: &compressionOptions{ // clientNoContextTakeover: true, // serverNoContextTakeover: true, // }, // }, // { // name: "x-webkit-deflate/error", // mode: CompressionNoContextTakeover, // reqSecWebSocketExtensions: "x-webkit-deflate-frame; max_window_bits", // error: true, // }, } for _, tc := range testCases { tc := tc t.Run(tc.name, func(t *testing.T) { t.Parallel() r := httptest.NewRequest(http.MethodGet, "/", nil) r.Header.Set("Sec-WebSocket-Extensions", tc.reqSecWebSocketExtensions) w := httptest.NewRecorder() copts, err := acceptCompression(r, w, tc.mode) if tc.error { assert.Error(t, err) return } assert.Success(t, err) assert.Equal(t, "compression options", tc.expCopts, copts) assert.Equal(t, "Sec-WebSocket-Extensions", tc.respSecWebSocketExtensions, w.Header().Get("Sec-WebSocket-Extensions")) }) } } type mockHijacker struct { http.ResponseWriter hijack func() (net.Conn, *bufio.ReadWriter, error) } var _ http.Hijacker = mockHijacker{} func (mj mockHijacker) Hijack() (net.Conn, *bufio.ReadWriter, error) { return mj.hijack() } websocket-1.8.7/autobahn_test.go 0000664 0000000 0000000 00000012407 14033353662 0016715 0 ustar 00root root 0000000 0000000 // +build !js package websocket_test import ( "context" "encoding/json" "fmt" "io/ioutil" "net" "os" "os/exec" "strconv" "strings" "testing" "time" "nhooyr.io/websocket" "nhooyr.io/websocket/internal/errd" "nhooyr.io/websocket/internal/test/assert" "nhooyr.io/websocket/internal/test/wstest" ) var excludedAutobahnCases = []string{ // We skip the UTF-8 handling tests as there isn't any reason to reject invalid UTF-8, just // more performance overhead. "6.*", "7.5.1", // We skip the tests related to requestMaxWindowBits as that is unimplemented due // to limitations in compress/flate. See https://github.com/golang/go/issues/3155 // Same with klauspost/compress which doesn't allow adjusting the sliding window size. "13.3.*", "13.4.*", "13.5.*", "13.6.*", } var autobahnCases = []string{"*"} func TestAutobahn(t *testing.T) { t.Parallel() if os.Getenv("AUTOBAHN_TEST") == "" { t.SkipNow() } ctx, cancel := context.WithTimeout(context.Background(), time.Minute*15) defer cancel() wstestURL, closeFn, err := wstestClientServer(ctx) assert.Success(t, err) defer closeFn() err = waitWS(ctx, wstestURL) assert.Success(t, err) cases, err := wstestCaseCount(ctx, wstestURL) assert.Success(t, err) t.Run("cases", func(t *testing.T) { for i := 1; i <= cases; i++ { i := i t.Run("", func(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), time.Minute*5) defer cancel() c, _, err := websocket.Dial(ctx, fmt.Sprintf(wstestURL+"/runCase?case=%v&agent=main", i), nil) assert.Success(t, err) err = wstest.EchoLoop(ctx, c) t.Logf("echoLoop: %v", err) }) } }) c, _, err := websocket.Dial(ctx, fmt.Sprintf(wstestURL+"/updateReports?agent=main"), nil) assert.Success(t, err) c.Close(websocket.StatusNormalClosure, "") checkWSTestIndex(t, "./ci/out/wstestClientReports/index.json") } func waitWS(ctx context.Context, url string) error { ctx, cancel := context.WithTimeout(ctx, time.Second*5) defer cancel() for ctx.Err() == nil { c, _, err := websocket.Dial(ctx, url, nil) if err != nil { continue } c.Close(websocket.StatusNormalClosure, "") return nil } return ctx.Err() } func wstestClientServer(ctx context.Context) (url string, closeFn func(), err error) { serverAddr, err := unusedListenAddr() if err != nil { return "", nil, err } url = "ws://" + serverAddr specFile, err := tempJSONFile(map[string]interface{}{ "url": url, "outdir": "ci/out/wstestClientReports", "cases": autobahnCases, "exclude-cases": excludedAutobahnCases, }) if err != nil { return "", nil, fmt.Errorf("failed to write spec: %w", err) } ctx, cancel := context.WithTimeout(context.Background(), time.Minute*15) defer func() { if err != nil { cancel() } }() args := []string{"--mode", "fuzzingserver", "--spec", specFile, // Disables some server that runs as part of fuzzingserver mode. // See https://github.com/crossbario/autobahn-testsuite/blob/058db3a36b7c3a1edf68c282307c6b899ca4857f/autobahntestsuite/autobahntestsuite/wstest.py#L124 "--webport=0", } wstest := exec.CommandContext(ctx, "wstest", args...) err = wstest.Start() if err != nil { return "", nil, fmt.Errorf("failed to start wstest: %w", err) } return url, func() { wstest.Process.Kill() }, nil } func wstestCaseCount(ctx context.Context, url string) (cases int, err error) { defer errd.Wrap(&err, "failed to get case count") c, _, err := websocket.Dial(ctx, url+"/getCaseCount", nil) if err != nil { return 0, err } defer c.Close(websocket.StatusInternalError, "") _, r, err := c.Reader(ctx) if err != nil { return 0, err } b, err := ioutil.ReadAll(r) if err != nil { return 0, err } cases, err = strconv.Atoi(string(b)) if err != nil { return 0, err } c.Close(websocket.StatusNormalClosure, "") return cases, nil } func checkWSTestIndex(t *testing.T, path string) { wstestOut, err := ioutil.ReadFile(path) assert.Success(t, err) var indexJSON map[string]map[string]struct { Behavior string `json:"behavior"` BehaviorClose string `json:"behaviorClose"` } err = json.Unmarshal(wstestOut, &indexJSON) assert.Success(t, err) for _, tests := range indexJSON { for test, result := range tests { t.Run(test, func(t *testing.T) { switch result.BehaviorClose { case "OK", "INFORMATIONAL": default: t.Errorf("bad close behaviour") } switch result.Behavior { case "OK", "NON-STRICT", "INFORMATIONAL": default: t.Errorf("failed") } }) } } if t.Failed() { htmlPath := strings.Replace(path, ".json", ".html", 1) t.Errorf("detected autobahn violation, see %q", htmlPath) } } func unusedListenAddr() (_ string, err error) { defer errd.Wrap(&err, "failed to get unused listen address") l, err := net.Listen("tcp", "localhost:0") if err != nil { return "", err } l.Close() return l.Addr().String(), nil } func tempJSONFile(v interface{}) (string, error) { f, err := ioutil.TempFile("", "temp.json") if err != nil { return "", fmt.Errorf("temp file: %w", err) } defer f.Close() e := json.NewEncoder(f) e.SetIndent("", "\t") err = e.Encode(v) if err != nil { return "", fmt.Errorf("json encode: %w", err) } err = f.Close() if err != nil { return "", fmt.Errorf("close temp file: %w", err) } return f.Name(), nil } websocket-1.8.7/ci/ 0000775 0000000 0000000 00000000000 14033353662 0014115 5 ustar 00root root 0000000 0000000 websocket-1.8.7/ci/all.sh 0000775 0000000 0000000 00000000211 14033353662 0015216 0 ustar 00root root 0000000 0000000 #!/usr/bin/env bash set -euo pipefail main() { cd "$(dirname "$0")/.." ./ci/fmt.sh ./ci/lint.sh ./ci/test.sh "$@" } main "$@" websocket-1.8.7/ci/container/ 0000775 0000000 0000000 00000000000 14033353662 0016077 5 ustar 00root root 0000000 0000000 websocket-1.8.7/ci/container/Dockerfile 0000664 0000000 0000000 00000000623 14033353662 0020072 0 ustar 00root root 0000000 0000000 FROM golang RUN apt-get update RUN apt-get install -y npm shellcheck chromium ENV GO111MODULE=on RUN go get golang.org/x/tools/cmd/goimports RUN go get mvdan.cc/sh/v3/cmd/shfmt RUN go get golang.org/x/tools/cmd/stringer RUN go get golang.org/x/lint/golint RUN go get github.com/agnivade/wasmbrowsertest RUN npm --unsafe-perm=true install -g prettier RUN npm --unsafe-perm=true install -g netlify-cli websocket-1.8.7/ci/fmt.sh 0000775 0000000 0000000 00000001353 14033353662 0015244 0 ustar 00root root 0000000 0000000 #!/usr/bin/env bash set -euo pipefail main() { cd "$(dirname "$0")/.." go mod tidy gofmt -w -s . goimports -w "-local=$(go list -m)" . prettier \ --write \ --print-width=120 \ --no-semi \ --trailing-comma=all \ --loglevel=warn \ --arrow-parens=avoid \ $(git ls-files "*.yml" "*.md" "*.js" "*.css" "*.html") shfmt -i 2 -w -s -sr $(git ls-files "*.sh") stringer -type=opcode,MessageType,StatusCode -output=stringer.go if [[ ${CI-} ]]; then ensure_fmt fi } ensure_fmt() { if [[ $(git ls-files --other --modified --exclude-standard) ]]; then git -c color.ui=always --no-pager diff echo echo "Please run the following locally:" echo " ./ci/fmt.sh" exit 1 fi } main "$@" websocket-1.8.7/ci/lint.sh 0000775 0000000 0000000 00000000425 14033353662 0015423 0 ustar 00root root 0000000 0000000 #!/usr/bin/env bash set -euo pipefail main() { cd "$(dirname "$0")/.." go vet ./... GOOS=js GOARCH=wasm go vet ./... golint -set_exit_status ./... GOOS=js GOARCH=wasm golint -set_exit_status ./... shellcheck --exclude=SC2046 $(git ls-files "*.sh") } main "$@" websocket-1.8.7/ci/out/ 0000775 0000000 0000000 00000000000 14033353662 0014724 5 ustar 00root root 0000000 0000000 websocket-1.8.7/ci/out/.gitignore 0000664 0000000 0000000 00000000002 14033353662 0016704 0 ustar 00root root 0000000 0000000 * websocket-1.8.7/ci/test.sh 0000775 0000000 0000000 00000001320 14033353662 0015427 0 ustar 00root root 0000000 0000000 #!/usr/bin/env bash set -euo pipefail main() { cd "$(dirname "$0")/.." go test -timeout=30m -covermode=atomic -coverprofile=ci/out/coverage.prof -coverpkg=./... "$@" ./... sed -i '/stringer\.go/d' ci/out/coverage.prof sed -i '/nhooyr.io\/websocket\/internal\/test/d' ci/out/coverage.prof sed -i '/examples/d' ci/out/coverage.prof # Last line is the total coverage. go tool cover -func ci/out/coverage.prof | tail -n1 go tool cover -html=ci/out/coverage.prof -o=ci/out/coverage.html if [[ ${CI-} && ${GITHUB_REF-} == *master ]]; then local deployDir deployDir="$(mktemp -d)" cp ci/out/coverage.html "$deployDir/index.html" netlify deploy --prod "--dir=$deployDir" fi } main "$@" websocket-1.8.7/close.go 0000664 0000000 0000000 00000004526 14033353662 0015165 0 ustar 00root root 0000000 0000000 package websocket import ( "errors" "fmt" ) // StatusCode represents a WebSocket status code. // https://tools.ietf.org/html/rfc6455#section-7.4 type StatusCode int // https://www.iana.org/assignments/websocket/websocket.xhtml#close-code-number // // These are only the status codes defined by the protocol. // // You can define custom codes in the 3000-4999 range. // The 3000-3999 range is reserved for use by libraries, frameworks and applications. // The 4000-4999 range is reserved for private use. const ( StatusNormalClosure StatusCode = 1000 StatusGoingAway StatusCode = 1001 StatusProtocolError StatusCode = 1002 StatusUnsupportedData StatusCode = 1003 // 1004 is reserved and so unexported. statusReserved StatusCode = 1004 // StatusNoStatusRcvd cannot be sent in a close message. // It is reserved for when a close message is received without // a status code. StatusNoStatusRcvd StatusCode = 1005 // StatusAbnormalClosure is exported for use only with Wasm. // In non Wasm Go, the returned error will indicate whether the // connection was closed abnormally. StatusAbnormalClosure StatusCode = 1006 StatusInvalidFramePayloadData StatusCode = 1007 StatusPolicyViolation StatusCode = 1008 StatusMessageTooBig StatusCode = 1009 StatusMandatoryExtension StatusCode = 1010 StatusInternalError StatusCode = 1011 StatusServiceRestart StatusCode = 1012 StatusTryAgainLater StatusCode = 1013 StatusBadGateway StatusCode = 1014 // StatusTLSHandshake is only exported for use with Wasm. // In non Wasm Go, the returned error will indicate whether there was // a TLS handshake failure. StatusTLSHandshake StatusCode = 1015 ) // CloseError is returned when the connection is closed with a status and reason. // // Use Go 1.13's errors.As to check for this error. // Also see the CloseStatus helper. type CloseError struct { Code StatusCode Reason string } func (ce CloseError) Error() string { return fmt.Sprintf("status = %v and reason = %q", ce.Code, ce.Reason) } // CloseStatus is a convenience wrapper around Go 1.13's errors.As to grab // the status code from a CloseError. // // -1 will be returned if the passed error is nil or not a CloseError. func CloseStatus(err error) StatusCode { var ce CloseError if errors.As(err, &ce) { return ce.Code } return -1 } websocket-1.8.7/close_notjs.go 0000664 0000000 0000000 00000010750 14033353662 0016376 0 ustar 00root root 0000000 0000000 // +build !js package websocket import ( "context" "encoding/binary" "errors" "fmt" "log" "time" "nhooyr.io/websocket/internal/errd" ) // Close performs the WebSocket close handshake with the given status code and reason. // // It will write a WebSocket close frame with a timeout of 5s and then wait 5s for // the peer to send a close frame. // All data messages received from the peer during the close handshake will be discarded. // // The connection can only be closed once. Additional calls to Close // are no-ops. // // The maximum length of reason must be 125 bytes. Avoid // sending a dynamic reason. // // Close will unblock all goroutines interacting with the connection once // complete. func (c *Conn) Close(code StatusCode, reason string) error { return c.closeHandshake(code, reason) } func (c *Conn) closeHandshake(code StatusCode, reason string) (err error) { defer errd.Wrap(&err, "failed to close WebSocket") writeErr := c.writeClose(code, reason) closeHandshakeErr := c.waitCloseHandshake() if writeErr != nil { return writeErr } if CloseStatus(closeHandshakeErr) == -1 { return closeHandshakeErr } return nil } var errAlreadyWroteClose = errors.New("already wrote close") func (c *Conn) writeClose(code StatusCode, reason string) error { c.closeMu.Lock() wroteClose := c.wroteClose c.wroteClose = true c.closeMu.Unlock() if wroteClose { return errAlreadyWroteClose } ce := CloseError{ Code: code, Reason: reason, } var p []byte var marshalErr error if ce.Code != StatusNoStatusRcvd { p, marshalErr = ce.bytes() if marshalErr != nil { log.Printf("websocket: %v", marshalErr) } } writeErr := c.writeControl(context.Background(), opClose, p) if CloseStatus(writeErr) != -1 { // Not a real error if it's due to a close frame being received. writeErr = nil } // We do this after in case there was an error writing the close frame. c.setCloseErr(fmt.Errorf("sent close frame: %w", ce)) if marshalErr != nil { return marshalErr } return writeErr } func (c *Conn) waitCloseHandshake() error { defer c.close(nil) ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) defer cancel() err := c.readMu.lock(ctx) if err != nil { return err } defer c.readMu.unlock() if c.readCloseFrameErr != nil { return c.readCloseFrameErr } for { h, err := c.readLoop(ctx) if err != nil { return err } for i := int64(0); i < h.payloadLength; i++ { _, err := c.br.ReadByte() if err != nil { return err } } } } func parseClosePayload(p []byte) (CloseError, error) { if len(p) == 0 { return CloseError{ Code: StatusNoStatusRcvd, }, nil } if len(p) < 2 { return CloseError{}, fmt.Errorf("close payload %q too small, cannot even contain the 2 byte status code", p) } ce := CloseError{ Code: StatusCode(binary.BigEndian.Uint16(p)), Reason: string(p[2:]), } if !validWireCloseCode(ce.Code) { return CloseError{}, fmt.Errorf("invalid status code %v", ce.Code) } return ce, nil } // See http://www.iana.org/assignments/websocket/websocket.xhtml#close-code-number // and https://tools.ietf.org/html/rfc6455#section-7.4.1 func validWireCloseCode(code StatusCode) bool { switch code { case statusReserved, StatusNoStatusRcvd, StatusAbnormalClosure, StatusTLSHandshake: return false } if code >= StatusNormalClosure && code <= StatusBadGateway { return true } if code >= 3000 && code <= 4999 { return true } return false } func (ce CloseError) bytes() ([]byte, error) { p, err := ce.bytesErr() if err != nil { err = fmt.Errorf("failed to marshal close frame: %w", err) ce = CloseError{ Code: StatusInternalError, } p, _ = ce.bytesErr() } return p, err } const maxCloseReason = maxControlPayload - 2 func (ce CloseError) bytesErr() ([]byte, error) { if len(ce.Reason) > maxCloseReason { return nil, fmt.Errorf("reason string max is %v but got %q with length %v", maxCloseReason, ce.Reason, len(ce.Reason)) } if !validWireCloseCode(ce.Code) { return nil, fmt.Errorf("status code %v cannot be set", ce.Code) } buf := make([]byte, 2+len(ce.Reason)) binary.BigEndian.PutUint16(buf, uint16(ce.Code)) copy(buf[2:], ce.Reason) return buf, nil } func (c *Conn) setCloseErr(err error) { c.closeMu.Lock() c.setCloseErrLocked(err) c.closeMu.Unlock() } func (c *Conn) setCloseErrLocked(err error) { if c.closeErr == nil { c.closeErr = fmt.Errorf("WebSocket closed: %w", err) } } func (c *Conn) isClosed() bool { select { case <-c.closed: return true default: return false } } websocket-1.8.7/close_test.go 0000664 0000000 0000000 00000006363 14033353662 0016225 0 ustar 00root root 0000000 0000000 // +build !js package websocket import ( "io" "math" "strings" "testing" "nhooyr.io/websocket/internal/test/assert" ) func TestCloseError(t *testing.T) { t.Parallel() testCases := []struct { name string ce CloseError success bool }{ { name: "normal", ce: CloseError{ Code: StatusNormalClosure, Reason: strings.Repeat("x", maxCloseReason), }, success: true, }, { name: "bigReason", ce: CloseError{ Code: StatusNormalClosure, Reason: strings.Repeat("x", maxCloseReason+1), }, success: false, }, { name: "bigCode", ce: CloseError{ Code: math.MaxUint16, Reason: strings.Repeat("x", maxCloseReason), }, success: false, }, } for _, tc := range testCases { tc := tc t.Run(tc.name, func(t *testing.T) { t.Parallel() _, err := tc.ce.bytesErr() if tc.success { assert.Success(t, err) } else { assert.Error(t, err) } }) } t.Run("Error", func(t *testing.T) { exp := `status = StatusInternalError and reason = "meow"` act := CloseError{ Code: StatusInternalError, Reason: "meow", }.Error() assert.Equal(t, "CloseError.Error()", exp, act) }) } func Test_parseClosePayload(t *testing.T) { t.Parallel() testCases := []struct { name string p []byte success bool ce CloseError }{ { name: "normal", p: append([]byte{0x3, 0xE8}, []byte("hello")...), success: true, ce: CloseError{ Code: StatusNormalClosure, Reason: "hello", }, }, { name: "nothing", success: true, ce: CloseError{ Code: StatusNoStatusRcvd, }, }, { name: "oneByte", p: []byte{0}, success: false, }, { name: "badStatusCode", p: []byte{0x17, 0x70}, success: false, }, } for _, tc := range testCases { tc := tc t.Run(tc.name, func(t *testing.T) { t.Parallel() ce, err := parseClosePayload(tc.p) if tc.success { assert.Success(t, err) assert.Equal(t, "close payload", tc.ce, ce) } else { assert.Error(t, err) } }) } } func Test_validWireCloseCode(t *testing.T) { t.Parallel() testCases := []struct { name string code StatusCode valid bool }{ { name: "normal", code: StatusNormalClosure, valid: true, }, { name: "noStatus", code: StatusNoStatusRcvd, valid: false, }, { name: "3000", code: 3000, valid: true, }, { name: "4999", code: 4999, valid: true, }, { name: "unknown", code: 5000, valid: false, }, } for _, tc := range testCases { tc := tc t.Run(tc.name, func(t *testing.T) { t.Parallel() act := validWireCloseCode(tc.code) assert.Equal(t, "wire close code", tc.valid, act) }) } } func TestCloseStatus(t *testing.T) { t.Parallel() testCases := []struct { name string in error exp StatusCode }{ { name: "nil", in: nil, exp: -1, }, { name: "io.EOF", in: io.EOF, exp: -1, }, { name: "StatusInternalError", in: CloseError{ Code: StatusInternalError, }, exp: StatusInternalError, }, } for _, tc := range testCases { tc := tc t.Run(tc.name, func(t *testing.T) { t.Parallel() act := CloseStatus(tc.in) assert.Equal(t, "close status", tc.exp, act) }) } } websocket-1.8.7/compress.go 0000664 0000000 0000000 00000003504 14033353662 0015706 0 ustar 00root root 0000000 0000000 package websocket // CompressionMode represents the modes available to the deflate extension. // See https://tools.ietf.org/html/rfc7692 // // A compatibility layer is implemented for the older deflate-frame extension used // by safari. See https://tools.ietf.org/html/draft-tyoshino-hybi-websocket-perframe-deflate-06 // It will work the same in every way except that we cannot signal to the peer we // want to use no context takeover on our side, we can only signal that they should. // It is however currently disabled due to Safari bugs. See https://github.com/nhooyr/websocket/issues/218 type CompressionMode int const ( // CompressionNoContextTakeover grabs a new flate.Reader and flate.Writer as needed // for every message. This applies to both server and client side. // // This means less efficient compression as the sliding window from previous messages // will not be used but the memory overhead will be lower if the connections // are long lived and seldom used. // // The message will only be compressed if greater than 512 bytes. CompressionNoContextTakeover CompressionMode = iota // CompressionContextTakeover uses a flate.Reader and flate.Writer per connection. // This enables reusing the sliding window from previous messages. // As most WebSocket protocols are repetitive, this can be very efficient. // It carries an overhead of 8 kB for every connection compared to CompressionNoContextTakeover. // // If the peer negotiates NoContextTakeover on the client or server side, it will be // used instead as this is required by the RFC. CompressionContextTakeover // CompressionDisabled disables the deflate extension. // // Use this if you are using a predominantly binary protocol with very // little duplication in between messages or CPU and memory are more // important than bandwidth. CompressionDisabled ) websocket-1.8.7/compress_notjs.go 0000664 0000000 0000000 00000006707 14033353662 0017133 0 ustar 00root root 0000000 0000000 // +build !js package websocket import ( "io" "net/http" "sync" "github.com/klauspost/compress/flate" ) func (m CompressionMode) opts() *compressionOptions { return &compressionOptions{ clientNoContextTakeover: m == CompressionNoContextTakeover, serverNoContextTakeover: m == CompressionNoContextTakeover, } } type compressionOptions struct { clientNoContextTakeover bool serverNoContextTakeover bool } func (copts *compressionOptions) setHeader(h http.Header) { s := "permessage-deflate" if copts.clientNoContextTakeover { s += "; client_no_context_takeover" } if copts.serverNoContextTakeover { s += "; server_no_context_takeover" } h.Set("Sec-WebSocket-Extensions", s) } // These bytes are required to get flate.Reader to return. // They are removed when sending to avoid the overhead as // WebSocket framing tell's when the message has ended but then // we need to add them back otherwise flate.Reader keeps // trying to return more bytes. const deflateMessageTail = "\x00\x00\xff\xff" type trimLastFourBytesWriter struct { w io.Writer tail []byte } func (tw *trimLastFourBytesWriter) reset() { if tw != nil && tw.tail != nil { tw.tail = tw.tail[:0] } } func (tw *trimLastFourBytesWriter) Write(p []byte) (int, error) { if tw.tail == nil { tw.tail = make([]byte, 0, 4) } extra := len(tw.tail) + len(p) - 4 if extra <= 0 { tw.tail = append(tw.tail, p...) return len(p), nil } // Now we need to write as many extra bytes as we can from the previous tail. if extra > len(tw.tail) { extra = len(tw.tail) } if extra > 0 { _, err := tw.w.Write(tw.tail[:extra]) if err != nil { return 0, err } // Shift remaining bytes in tail over. n := copy(tw.tail, tw.tail[extra:]) tw.tail = tw.tail[:n] } // If p is less than or equal to 4 bytes, // all of it is is part of the tail. if len(p) <= 4 { tw.tail = append(tw.tail, p...) return len(p), nil } // Otherwise, only the last 4 bytes are. tw.tail = append(tw.tail, p[len(p)-4:]...) p = p[:len(p)-4] n, err := tw.w.Write(p) return n + 4, err } var flateReaderPool sync.Pool func getFlateReader(r io.Reader, dict []byte) io.Reader { fr, ok := flateReaderPool.Get().(io.Reader) if !ok { return flate.NewReaderDict(r, dict) } fr.(flate.Resetter).Reset(r, dict) return fr } func putFlateReader(fr io.Reader) { flateReaderPool.Put(fr) } type slidingWindow struct { buf []byte } var swPoolMu sync.RWMutex var swPool = map[int]*sync.Pool{} func slidingWindowPool(n int) *sync.Pool { swPoolMu.RLock() p, ok := swPool[n] swPoolMu.RUnlock() if ok { return p } p = &sync.Pool{} swPoolMu.Lock() swPool[n] = p swPoolMu.Unlock() return p } func (sw *slidingWindow) init(n int) { if sw.buf != nil { return } if n == 0 { n = 32768 } p := slidingWindowPool(n) buf, ok := p.Get().([]byte) if ok { sw.buf = buf[:0] } else { sw.buf = make([]byte, 0, n) } } func (sw *slidingWindow) close() { if sw.buf == nil { return } swPoolMu.Lock() swPool[cap(sw.buf)].Put(sw.buf) swPoolMu.Unlock() sw.buf = nil } func (sw *slidingWindow) write(p []byte) { if len(p) >= cap(sw.buf) { sw.buf = sw.buf[:cap(sw.buf)] p = p[len(p)-cap(sw.buf):] copy(sw.buf, p) return } left := cap(sw.buf) - len(sw.buf) if left < len(p) { // We need to shift spaceNeeded bytes from the end to make room for p at the end. spaceNeeded := len(p) - left copy(sw.buf, sw.buf[spaceNeeded:]) sw.buf = sw.buf[:len(sw.buf)-spaceNeeded] } sw.buf = append(sw.buf, p...) } websocket-1.8.7/compress_test.go 0000664 0000000 0000000 00000001270 14033353662 0016743 0 ustar 00root root 0000000 0000000 // +build !js package websocket import ( "strings" "testing" "nhooyr.io/websocket/internal/test/assert" "nhooyr.io/websocket/internal/test/xrand" ) func Test_slidingWindow(t *testing.T) { t.Parallel() const testCount = 99 const maxWindow = 99999 for i := 0; i < testCount; i++ { t.Run("", func(t *testing.T) { t.Parallel() input := xrand.String(maxWindow) windowLength := xrand.Int(maxWindow) var sw slidingWindow sw.init(windowLength) sw.write([]byte(input)) assert.Equal(t, "window length", windowLength, cap(sw.buf)) if !strings.HasSuffix(input, string(sw.buf)) { t.Fatalf("r.buf is not a suffix of input: %q and %q", input, sw.buf) } }) } } websocket-1.8.7/conn.go 0000664 0000000 0000000 00000000551 14033353662 0015007 0 ustar 00root root 0000000 0000000 package websocket // MessageType represents the type of a WebSocket message. // See https://tools.ietf.org/html/rfc6455#section-5.6 type MessageType int // MessageType constants. const ( // MessageText is for UTF-8 encoded text messages like JSON. MessageText MessageType = iota + 1 // MessageBinary is for binary messages like protobufs. MessageBinary ) websocket-1.8.7/conn_notjs.go 0000664 0000000 0000000 00000012444 14033353662 0016230 0 ustar 00root root 0000000 0000000 // +build !js package websocket import ( "bufio" "context" "errors" "fmt" "io" "runtime" "strconv" "sync" "sync/atomic" ) // Conn represents a WebSocket connection. // All methods may be called concurrently except for Reader and Read. // // You must always read from the connection. Otherwise control // frames will not be handled. See Reader and CloseRead. // // Be sure to call Close on the connection when you // are finished with it to release associated resources. // // On any error from any method, the connection is closed // with an appropriate reason. type Conn struct { subprotocol string rwc io.ReadWriteCloser client bool copts *compressionOptions flateThreshold int br *bufio.Reader bw *bufio.Writer readTimeout chan context.Context writeTimeout chan context.Context // Read state. readMu *mu readHeaderBuf [8]byte readControlBuf [maxControlPayload]byte msgReader *msgReader readCloseFrameErr error // Write state. msgWriterState *msgWriterState writeFrameMu *mu writeBuf []byte writeHeaderBuf [8]byte writeHeader header closed chan struct{} closeMu sync.Mutex closeErr error wroteClose bool pingCounter int32 activePingsMu sync.Mutex activePings map[string]chan<- struct{} } type connConfig struct { subprotocol string rwc io.ReadWriteCloser client bool copts *compressionOptions flateThreshold int br *bufio.Reader bw *bufio.Writer } func newConn(cfg connConfig) *Conn { c := &Conn{ subprotocol: cfg.subprotocol, rwc: cfg.rwc, client: cfg.client, copts: cfg.copts, flateThreshold: cfg.flateThreshold, br: cfg.br, bw: cfg.bw, readTimeout: make(chan context.Context), writeTimeout: make(chan context.Context), closed: make(chan struct{}), activePings: make(map[string]chan<- struct{}), } c.readMu = newMu(c) c.writeFrameMu = newMu(c) c.msgReader = newMsgReader(c) c.msgWriterState = newMsgWriterState(c) if c.client { c.writeBuf = extractBufioWriterBuf(c.bw, c.rwc) } if c.flate() && c.flateThreshold == 0 { c.flateThreshold = 128 if !c.msgWriterState.flateContextTakeover() { c.flateThreshold = 512 } } runtime.SetFinalizer(c, func(c *Conn) { c.close(errors.New("connection garbage collected")) }) go c.timeoutLoop() return c } // Subprotocol returns the negotiated subprotocol. // An empty string means the default protocol. func (c *Conn) Subprotocol() string { return c.subprotocol } func (c *Conn) close(err error) { c.closeMu.Lock() defer c.closeMu.Unlock() if c.isClosed() { return } c.setCloseErrLocked(err) close(c.closed) runtime.SetFinalizer(c, nil) // Have to close after c.closed is closed to ensure any goroutine that wakes up // from the connection being closed also sees that c.closed is closed and returns // closeErr. c.rwc.Close() go func() { c.msgWriterState.close() c.msgReader.close() }() } func (c *Conn) timeoutLoop() { readCtx := context.Background() writeCtx := context.Background() for { select { case <-c.closed: return case writeCtx = <-c.writeTimeout: case readCtx = <-c.readTimeout: case <-readCtx.Done(): c.setCloseErr(fmt.Errorf("read timed out: %w", readCtx.Err())) go c.writeError(StatusPolicyViolation, errors.New("timed out")) case <-writeCtx.Done(): c.close(fmt.Errorf("write timed out: %w", writeCtx.Err())) return } } } func (c *Conn) flate() bool { return c.copts != nil } // Ping sends a ping to the peer and waits for a pong. // Use this to measure latency or ensure the peer is responsive. // Ping must be called concurrently with Reader as it does // not read from the connection but instead waits for a Reader call // to read the pong. // // TCP Keepalives should suffice for most use cases. func (c *Conn) Ping(ctx context.Context) error { p := atomic.AddInt32(&c.pingCounter, 1) err := c.ping(ctx, strconv.Itoa(int(p))) if err != nil { return fmt.Errorf("failed to ping: %w", err) } return nil } func (c *Conn) ping(ctx context.Context, p string) error { pong := make(chan struct{}, 1) c.activePingsMu.Lock() c.activePings[p] = pong c.activePingsMu.Unlock() defer func() { c.activePingsMu.Lock() delete(c.activePings, p) c.activePingsMu.Unlock() }() err := c.writeControl(ctx, opPing, []byte(p)) if err != nil { return err } select { case <-c.closed: return c.closeErr case <-ctx.Done(): err := fmt.Errorf("failed to wait for pong: %w", ctx.Err()) c.close(err) return err case <-pong: return nil } } type mu struct { c *Conn ch chan struct{} } func newMu(c *Conn) *mu { return &mu{ c: c, ch: make(chan struct{}, 1), } } func (m *mu) forceLock() { m.ch <- struct{}{} } func (m *mu) lock(ctx context.Context) error { select { case <-m.c.closed: return m.c.closeErr case <-ctx.Done(): err := fmt.Errorf("failed to acquire lock: %w", ctx.Err()) m.c.close(err) return err case m.ch <- struct{}{}: // To make sure the connection is certainly alive. // As it's possible the send on m.ch was selected // over the receive on closed. select { case <-m.c.closed: // Make sure to release. m.unlock() return m.c.closeErr default: } return nil } } func (m *mu) unlock() { select { case <-m.ch: default: } } websocket-1.8.7/conn_test.go 0000664 0000000 0000000 00000027305 14033353662 0016054 0 ustar 00root root 0000000 0000000 // +build !js package websocket_test import ( "bytes" "context" "fmt" "io" "io/ioutil" "net/http" "net/http/httptest" "os" "os/exec" "strings" "testing" "time" "github.com/gin-gonic/gin" "github.com/golang/protobuf/ptypes" "github.com/golang/protobuf/ptypes/duration" "nhooyr.io/websocket" "nhooyr.io/websocket/internal/errd" "nhooyr.io/websocket/internal/test/assert" "nhooyr.io/websocket/internal/test/wstest" "nhooyr.io/websocket/internal/test/xrand" "nhooyr.io/websocket/internal/xsync" "nhooyr.io/websocket/wsjson" "nhooyr.io/websocket/wspb" ) func TestConn(t *testing.T) { t.Parallel() t.Run("fuzzData", func(t *testing.T) { t.Parallel() compressionMode := func() websocket.CompressionMode { return websocket.CompressionMode(xrand.Int(int(websocket.CompressionDisabled) + 1)) } for i := 0; i < 5; i++ { t.Run("", func(t *testing.T) { tt, c1, c2 := newConnTest(t, &websocket.DialOptions{ CompressionMode: compressionMode(), CompressionThreshold: xrand.Int(9999), }, &websocket.AcceptOptions{ CompressionMode: compressionMode(), CompressionThreshold: xrand.Int(9999), }) defer tt.cleanup() tt.goEchoLoop(c2) c1.SetReadLimit(131072) for i := 0; i < 5; i++ { err := wstest.Echo(tt.ctx, c1, 131072) assert.Success(t, err) } err := c1.Close(websocket.StatusNormalClosure, "") assert.Success(t, err) }) } }) t.Run("badClose", func(t *testing.T) { tt, c1, _ := newConnTest(t, nil, nil) defer tt.cleanup() err := c1.Close(-1, "") assert.Contains(t, err, "failed to marshal close frame: status code StatusCode(-1) cannot be set") }) t.Run("ping", func(t *testing.T) { tt, c1, c2 := newConnTest(t, nil, nil) defer tt.cleanup() c1.CloseRead(tt.ctx) c2.CloseRead(tt.ctx) for i := 0; i < 10; i++ { err := c1.Ping(tt.ctx) assert.Success(t, err) } err := c1.Close(websocket.StatusNormalClosure, "") assert.Success(t, err) }) t.Run("badPing", func(t *testing.T) { tt, c1, c2 := newConnTest(t, nil, nil) defer tt.cleanup() c2.CloseRead(tt.ctx) ctx, cancel := context.WithTimeout(tt.ctx, time.Millisecond*100) defer cancel() err := c1.Ping(ctx) assert.Contains(t, err, "failed to wait for pong") }) t.Run("concurrentWrite", func(t *testing.T) { tt, c1, c2 := newConnTest(t, nil, nil) defer tt.cleanup() tt.goDiscardLoop(c2) msg := xrand.Bytes(xrand.Int(9999)) const count = 100 errs := make(chan error, count) for i := 0; i < count; i++ { go func() { select { case errs <- c1.Write(tt.ctx, websocket.MessageBinary, msg): case <-tt.ctx.Done(): return } }() } for i := 0; i < count; i++ { select { case err := <-errs: assert.Success(t, err) case <-tt.ctx.Done(): t.Fatal(tt.ctx.Err()) } } err := c1.Close(websocket.StatusNormalClosure, "") assert.Success(t, err) }) t.Run("concurrentWriteError", func(t *testing.T) { tt, c1, _ := newConnTest(t, nil, nil) defer tt.cleanup() _, err := c1.Writer(tt.ctx, websocket.MessageText) assert.Success(t, err) ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*100) defer cancel() err = c1.Write(ctx, websocket.MessageText, []byte("x")) assert.Equal(t, "write error", context.DeadlineExceeded, err) }) t.Run("netConn", func(t *testing.T) { tt, c1, c2 := newConnTest(t, nil, nil) defer tt.cleanup() n1 := websocket.NetConn(tt.ctx, c1, websocket.MessageBinary) n2 := websocket.NetConn(tt.ctx, c2, websocket.MessageBinary) // Does not give any confidence but at least ensures no crashes. d, _ := tt.ctx.Deadline() n1.SetDeadline(d) n1.SetDeadline(time.Time{}) assert.Equal(t, "remote addr", n1.RemoteAddr(), n1.LocalAddr()) assert.Equal(t, "remote addr string", "websocket/unknown-addr", n1.RemoteAddr().String()) assert.Equal(t, "remote addr network", "websocket", n1.RemoteAddr().Network()) errs := xsync.Go(func() error { _, err := n2.Write([]byte("hello")) if err != nil { return err } return n2.Close() }) b, err := ioutil.ReadAll(n1) assert.Success(t, err) _, err = n1.Read(nil) assert.Equal(t, "read error", err, io.EOF) select { case err := <-errs: assert.Success(t, err) case <-tt.ctx.Done(): t.Fatal(tt.ctx.Err()) } assert.Equal(t, "read msg", []byte("hello"), b) }) t.Run("netConn/BadMsg", func(t *testing.T) { tt, c1, c2 := newConnTest(t, nil, nil) defer tt.cleanup() n1 := websocket.NetConn(tt.ctx, c1, websocket.MessageBinary) n2 := websocket.NetConn(tt.ctx, c2, websocket.MessageText) errs := xsync.Go(func() error { _, err := n2.Write([]byte("hello")) if err != nil { return err } return nil }) _, err := ioutil.ReadAll(n1) assert.Contains(t, err, `unexpected frame type read (expected MessageBinary): MessageText`) select { case err := <-errs: assert.Success(t, err) case <-tt.ctx.Done(): t.Fatal(tt.ctx.Err()) } }) t.Run("wsjson", func(t *testing.T) { tt, c1, c2 := newConnTest(t, nil, nil) defer tt.cleanup() tt.goEchoLoop(c2) c1.SetReadLimit(1 << 30) exp := xrand.String(xrand.Int(131072)) werr := xsync.Go(func() error { return wsjson.Write(tt.ctx, c1, exp) }) var act interface{} err := wsjson.Read(tt.ctx, c1, &act) assert.Success(t, err) assert.Equal(t, "read msg", exp, act) select { case err := <-werr: assert.Success(t, err) case <-tt.ctx.Done(): t.Fatal(tt.ctx.Err()) } err = c1.Close(websocket.StatusNormalClosure, "") assert.Success(t, err) }) t.Run("wspb", func(t *testing.T) { tt, c1, c2 := newConnTest(t, nil, nil) defer tt.cleanup() tt.goEchoLoop(c2) exp := ptypes.DurationProto(100) err := wspb.Write(tt.ctx, c1, exp) assert.Success(t, err) act := &duration.Duration{} err = wspb.Read(tt.ctx, c1, act) assert.Success(t, err) assert.Equal(t, "read msg", exp, act) err = c1.Close(websocket.StatusNormalClosure, "") assert.Success(t, err) }) } func TestWasm(t *testing.T) { t.Parallel() s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { err := echoServer(w, r, &websocket.AcceptOptions{ Subprotocols: []string{"echo"}, InsecureSkipVerify: true, }) if err != nil { t.Error(err) } })) defer s.Close() ctx, cancel := context.WithTimeout(context.Background(), time.Minute) defer cancel() cmd := exec.CommandContext(ctx, "go", "test", "-exec=wasmbrowsertest", ".") cmd.Env = append(os.Environ(), "GOOS=js", "GOARCH=wasm", fmt.Sprintf("WS_ECHO_SERVER_URL=%v", s.URL)) b, err := cmd.CombinedOutput() if err != nil { t.Fatalf("wasm test binary failed: %v:\n%s", err, b) } } func assertCloseStatus(exp websocket.StatusCode, err error) error { if websocket.CloseStatus(err) == -1 { return fmt.Errorf("expected websocket.CloseError: %T %v", err, err) } if websocket.CloseStatus(err) != exp { return fmt.Errorf("expected close status %v but got %v", exp, err) } return nil } type connTest struct { t testing.TB ctx context.Context doneFuncs []func() } func newConnTest(t testing.TB, dialOpts *websocket.DialOptions, acceptOpts *websocket.AcceptOptions) (tt *connTest, c1, c2 *websocket.Conn) { if t, ok := t.(*testing.T); ok { t.Parallel() } t.Helper() ctx, cancel := context.WithTimeout(context.Background(), time.Second*30) tt = &connTest{t: t, ctx: ctx} tt.appendDone(cancel) c1, c2 = wstest.Pipe(dialOpts, acceptOpts) if xrand.Bool() { c1, c2 = c2, c1 } tt.appendDone(func() { c2.Close(websocket.StatusInternalError, "") c1.Close(websocket.StatusInternalError, "") }) return tt, c1, c2 } func (tt *connTest) appendDone(f func()) { tt.doneFuncs = append(tt.doneFuncs, f) } func (tt *connTest) cleanup() { for i := len(tt.doneFuncs) - 1; i >= 0; i-- { tt.doneFuncs[i]() } } func (tt *connTest) goEchoLoop(c *websocket.Conn) { ctx, cancel := context.WithCancel(tt.ctx) echoLoopErr := xsync.Go(func() error { err := wstest.EchoLoop(ctx, c) return assertCloseStatus(websocket.StatusNormalClosure, err) }) tt.appendDone(func() { cancel() err := <-echoLoopErr if err != nil { tt.t.Errorf("echo loop error: %v", err) } }) } func (tt *connTest) goDiscardLoop(c *websocket.Conn) { ctx, cancel := context.WithCancel(tt.ctx) discardLoopErr := xsync.Go(func() error { defer c.Close(websocket.StatusInternalError, "") for { _, _, err := c.Read(ctx) if err != nil { return assertCloseStatus(websocket.StatusNormalClosure, err) } } }) tt.appendDone(func() { cancel() err := <-discardLoopErr if err != nil { tt.t.Errorf("discard loop error: %v", err) } }) } func BenchmarkConn(b *testing.B) { var benchCases = []struct { name string mode websocket.CompressionMode }{ { name: "disabledCompress", mode: websocket.CompressionDisabled, }, { name: "compress", mode: websocket.CompressionContextTakeover, }, { name: "compressNoContext", mode: websocket.CompressionNoContextTakeover, }, } for _, bc := range benchCases { b.Run(bc.name, func(b *testing.B) { bb, c1, c2 := newConnTest(b, &websocket.DialOptions{ CompressionMode: bc.mode, }, &websocket.AcceptOptions{ CompressionMode: bc.mode, }) defer bb.cleanup() bb.goEchoLoop(c2) bytesWritten := c1.RecordBytesWritten() bytesRead := c1.RecordBytesRead() msg := []byte(strings.Repeat("1234", 128)) readBuf := make([]byte, len(msg)) writes := make(chan struct{}) defer close(writes) werrs := make(chan error) go func() { for range writes { select { case werrs <- c1.Write(bb.ctx, websocket.MessageText, msg): case <-bb.ctx.Done(): return } } }() b.SetBytes(int64(len(msg))) b.ReportAllocs() b.ResetTimer() for i := 0; i < b.N; i++ { select { case writes <- struct{}{}: case <-bb.ctx.Done(): b.Fatal(bb.ctx.Err()) } typ, r, err := c1.Reader(bb.ctx) if err != nil { b.Fatal(err) } if websocket.MessageText != typ { assert.Equal(b, "data type", websocket.MessageText, typ) } _, err = io.ReadFull(r, readBuf) if err != nil { b.Fatal(err) } n2, err := r.Read(readBuf) if err != io.EOF { assert.Equal(b, "read err", io.EOF, err) } if n2 != 0 { assert.Equal(b, "n2", 0, n2) } if !bytes.Equal(msg, readBuf) { assert.Equal(b, "msg", msg, readBuf) } select { case err = <-werrs: case <-bb.ctx.Done(): b.Fatal(bb.ctx.Err()) } if err != nil { b.Fatal(err) } } b.StopTimer() b.ReportMetric(float64(*bytesWritten/b.N), "written/op") b.ReportMetric(float64(*bytesRead/b.N), "read/op") err := c1.Close(websocket.StatusNormalClosure, "") assert.Success(b, err) }) } } func echoServer(w http.ResponseWriter, r *http.Request, opts *websocket.AcceptOptions) (err error) { defer errd.Wrap(&err, "echo server failed") c, err := websocket.Accept(w, r, opts) if err != nil { return err } defer c.Close(websocket.StatusInternalError, "") err = wstest.EchoLoop(r.Context(), c) return assertCloseStatus(websocket.StatusNormalClosure, err) } func TestGin(t *testing.T) { t.Parallel() gin.SetMode(gin.ReleaseMode) r := gin.New() r.GET("/", func(ginCtx *gin.Context) { err := echoServer(ginCtx.Writer, ginCtx.Request, nil) if err != nil { t.Error(err) } }) s := httptest.NewServer(r) defer s.Close() ctx, cancel := context.WithTimeout(context.Background(), time.Second*30) defer cancel() c, _, err := websocket.Dial(ctx, s.URL, nil) assert.Success(t, err) defer c.Close(websocket.StatusInternalError, "") err = wsjson.Write(ctx, c, "hello") assert.Success(t, err) var v interface{} err = wsjson.Read(ctx, c, &v) assert.Success(t, err) assert.Equal(t, "read msg", "hello", v) err = c.Close(websocket.StatusNormalClosure, "") assert.Success(t, err) } websocket-1.8.7/dial.go 0000664 0000000 0000000 00000017325 14033353662 0014772 0 ustar 00root root 0000000 0000000 // +build !js package websocket import ( "bufio" "bytes" "context" "crypto/rand" "encoding/base64" "fmt" "io" "io/ioutil" "net/http" "net/url" "strings" "sync" "time" "nhooyr.io/websocket/internal/errd" ) // DialOptions represents Dial's options. type DialOptions struct { // HTTPClient is used for the connection. // Its Transport must return writable bodies for WebSocket handshakes. // http.Transport does beginning with Go 1.12. HTTPClient *http.Client // HTTPHeader specifies the HTTP headers included in the handshake request. HTTPHeader http.Header // Subprotocols lists the WebSocket subprotocols to negotiate with the server. Subprotocols []string // CompressionMode controls the compression mode. // Defaults to CompressionNoContextTakeover. // // See docs on CompressionMode for details. CompressionMode CompressionMode // CompressionThreshold controls the minimum size of a message before compression is applied. // // Defaults to 512 bytes for CompressionNoContextTakeover and 128 bytes // for CompressionContextTakeover. CompressionThreshold int } // Dial performs a WebSocket handshake on url. // // The response is the WebSocket handshake response from the server. // You never need to close resp.Body yourself. // // If an error occurs, the returned response may be non nil. // However, you can only read the first 1024 bytes of the body. // // This function requires at least Go 1.12 as it uses a new feature // in net/http to perform WebSocket handshakes. // See docs on the HTTPClient option and https://github.com/golang/go/issues/26937#issuecomment-415855861 // // URLs with http/https schemes will work and are interpreted as ws/wss. func Dial(ctx context.Context, u string, opts *DialOptions) (*Conn, *http.Response, error) { return dial(ctx, u, opts, nil) } func dial(ctx context.Context, urls string, opts *DialOptions, rand io.Reader) (_ *Conn, _ *http.Response, err error) { defer errd.Wrap(&err, "failed to WebSocket dial") if opts == nil { opts = &DialOptions{} } opts = &*opts if opts.HTTPClient == nil { opts.HTTPClient = http.DefaultClient } else if opts.HTTPClient.Timeout > 0 { var cancel context.CancelFunc ctx, cancel = context.WithTimeout(ctx, opts.HTTPClient.Timeout) defer cancel() newClient := *opts.HTTPClient newClient.Timeout = 0 opts.HTTPClient = &newClient } if opts.HTTPHeader == nil { opts.HTTPHeader = http.Header{} } secWebSocketKey, err := secWebSocketKey(rand) if err != nil { return nil, nil, fmt.Errorf("failed to generate Sec-WebSocket-Key: %w", err) } var copts *compressionOptions if opts.CompressionMode != CompressionDisabled { copts = opts.CompressionMode.opts() } resp, err := handshakeRequest(ctx, urls, opts, copts, secWebSocketKey) if err != nil { return nil, resp, err } respBody := resp.Body resp.Body = nil defer func() { if err != nil { // We read a bit of the body for easier debugging. r := io.LimitReader(respBody, 1024) timer := time.AfterFunc(time.Second*3, func() { respBody.Close() }) defer timer.Stop() b, _ := ioutil.ReadAll(r) respBody.Close() resp.Body = ioutil.NopCloser(bytes.NewReader(b)) } }() copts, err = verifyServerResponse(opts, copts, secWebSocketKey, resp) if err != nil { return nil, resp, err } rwc, ok := respBody.(io.ReadWriteCloser) if !ok { return nil, resp, fmt.Errorf("response body is not a io.ReadWriteCloser: %T", respBody) } return newConn(connConfig{ subprotocol: resp.Header.Get("Sec-WebSocket-Protocol"), rwc: rwc, client: true, copts: copts, flateThreshold: opts.CompressionThreshold, br: getBufioReader(rwc), bw: getBufioWriter(rwc), }), resp, nil } func handshakeRequest(ctx context.Context, urls string, opts *DialOptions, copts *compressionOptions, secWebSocketKey string) (*http.Response, error) { u, err := url.Parse(urls) if err != nil { return nil, fmt.Errorf("failed to parse url: %w", err) } switch u.Scheme { case "ws": u.Scheme = "http" case "wss": u.Scheme = "https" case "http", "https": default: return nil, fmt.Errorf("unexpected url scheme: %q", u.Scheme) } req, _ := http.NewRequestWithContext(ctx, "GET", u.String(), nil) req.Header = opts.HTTPHeader.Clone() req.Header.Set("Connection", "Upgrade") req.Header.Set("Upgrade", "websocket") req.Header.Set("Sec-WebSocket-Version", "13") req.Header.Set("Sec-WebSocket-Key", secWebSocketKey) if len(opts.Subprotocols) > 0 { req.Header.Set("Sec-WebSocket-Protocol", strings.Join(opts.Subprotocols, ",")) } if copts != nil { copts.setHeader(req.Header) } resp, err := opts.HTTPClient.Do(req) if err != nil { return nil, fmt.Errorf("failed to send handshake request: %w", err) } return resp, nil } func secWebSocketKey(rr io.Reader) (string, error) { if rr == nil { rr = rand.Reader } b := make([]byte, 16) _, err := io.ReadFull(rr, b) if err != nil { return "", fmt.Errorf("failed to read random data from rand.Reader: %w", err) } return base64.StdEncoding.EncodeToString(b), nil } func verifyServerResponse(opts *DialOptions, copts *compressionOptions, secWebSocketKey string, resp *http.Response) (*compressionOptions, error) { if resp.StatusCode != http.StatusSwitchingProtocols { return nil, fmt.Errorf("expected handshake response status code %v but got %v", http.StatusSwitchingProtocols, resp.StatusCode) } if !headerContainsTokenIgnoreCase(resp.Header, "Connection", "Upgrade") { return nil, fmt.Errorf("WebSocket protocol violation: Connection header %q does not contain Upgrade", resp.Header.Get("Connection")) } if !headerContainsTokenIgnoreCase(resp.Header, "Upgrade", "WebSocket") { return nil, fmt.Errorf("WebSocket protocol violation: Upgrade header %q does not contain websocket", resp.Header.Get("Upgrade")) } if resp.Header.Get("Sec-WebSocket-Accept") != secWebSocketAccept(secWebSocketKey) { return nil, fmt.Errorf("WebSocket protocol violation: invalid Sec-WebSocket-Accept %q, key %q", resp.Header.Get("Sec-WebSocket-Accept"), secWebSocketKey, ) } err := verifySubprotocol(opts.Subprotocols, resp) if err != nil { return nil, err } return verifyServerExtensions(copts, resp.Header) } func verifySubprotocol(subprotos []string, resp *http.Response) error { proto := resp.Header.Get("Sec-WebSocket-Protocol") if proto == "" { return nil } for _, sp2 := range subprotos { if strings.EqualFold(sp2, proto) { return nil } } return fmt.Errorf("WebSocket protocol violation: unexpected Sec-WebSocket-Protocol from server: %q", proto) } func verifyServerExtensions(copts *compressionOptions, h http.Header) (*compressionOptions, error) { exts := websocketExtensions(h) if len(exts) == 0 { return nil, nil } ext := exts[0] if ext.name != "permessage-deflate" || len(exts) > 1 || copts == nil { return nil, fmt.Errorf("WebSocket protcol violation: unsupported extensions from server: %+v", exts[1:]) } copts = &*copts for _, p := range ext.params { switch p { case "client_no_context_takeover": copts.clientNoContextTakeover = true continue case "server_no_context_takeover": copts.serverNoContextTakeover = true continue } return nil, fmt.Errorf("unsupported permessage-deflate parameter: %q", p) } return copts, nil } var bufioReaderPool sync.Pool func getBufioReader(r io.Reader) *bufio.Reader { br, ok := bufioReaderPool.Get().(*bufio.Reader) if !ok { return bufio.NewReader(r) } br.Reset(r) return br } func putBufioReader(br *bufio.Reader) { bufioReaderPool.Put(br) } var bufioWriterPool sync.Pool func getBufioWriter(w io.Writer) *bufio.Writer { bw, ok := bufioWriterPool.Get().(*bufio.Writer) if !ok { return bufio.NewWriter(w) } bw.Reset(w) return bw } func putBufioWriter(bw *bufio.Writer) { bufioWriterPool.Put(bw) } websocket-1.8.7/dial_test.go 0000664 0000000 0000000 00000012534 14033353662 0016026 0 ustar 00root root 0000000 0000000 // +build !js package websocket import ( "context" "crypto/rand" "io" "io/ioutil" "net/http" "net/http/httptest" "strings" "testing" "time" "nhooyr.io/websocket/internal/test/assert" ) func TestBadDials(t *testing.T) { t.Parallel() t.Run("badReq", func(t *testing.T) { t.Parallel() testCases := []struct { name string url string opts *DialOptions rand readerFunc }{ { name: "badURL", url: "://noscheme", }, { name: "badURLScheme", url: "ftp://nhooyr.io", }, { name: "badTLS", url: "wss://totallyfake.nhooyr.io", }, { name: "badReader", rand: func(p []byte) (int, error) { return 0, io.EOF }, }, } for _, tc := range testCases { tc := tc t.Run(tc.name, func(t *testing.T) { t.Parallel() ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) defer cancel() if tc.rand == nil { tc.rand = rand.Reader.Read } _, _, err := dial(ctx, tc.url, tc.opts, tc.rand) assert.Error(t, err) }) } }) t.Run("badResponse", func(t *testing.T) { t.Parallel() ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) defer cancel() _, _, err := Dial(ctx, "ws://example.com", &DialOptions{ HTTPClient: mockHTTPClient(func(*http.Request) (*http.Response, error) { return &http.Response{ Body: ioutil.NopCloser(strings.NewReader("hi")), }, nil }), }) assert.Contains(t, err, "failed to WebSocket dial: expected handshake response status code 101 but got 0") }) t.Run("badBody", func(t *testing.T) { t.Parallel() ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) defer cancel() rt := func(r *http.Request) (*http.Response, error) { h := http.Header{} h.Set("Connection", "Upgrade") h.Set("Upgrade", "websocket") h.Set("Sec-WebSocket-Accept", secWebSocketAccept(r.Header.Get("Sec-WebSocket-Key"))) return &http.Response{ StatusCode: http.StatusSwitchingProtocols, Header: h, Body: ioutil.NopCloser(strings.NewReader("hi")), }, nil } _, _, err := Dial(ctx, "ws://example.com", &DialOptions{ HTTPClient: mockHTTPClient(rt), }) assert.Contains(t, err, "response body is not a io.ReadWriteCloser") }) } func Test_verifyServerHandshake(t *testing.T) { t.Parallel() testCases := []struct { name string response func(w http.ResponseWriter) success bool }{ { name: "badStatus", response: func(w http.ResponseWriter) { w.WriteHeader(http.StatusOK) }, success: false, }, { name: "badConnection", response: func(w http.ResponseWriter) { w.Header().Set("Connection", "???") w.WriteHeader(http.StatusSwitchingProtocols) }, success: false, }, { name: "badUpgrade", response: func(w http.ResponseWriter) { w.Header().Set("Connection", "Upgrade") w.Header().Set("Upgrade", "???") w.WriteHeader(http.StatusSwitchingProtocols) }, success: false, }, { name: "badSecWebSocketAccept", response: func(w http.ResponseWriter) { w.Header().Set("Connection", "Upgrade") w.Header().Set("Upgrade", "websocket") w.Header().Set("Sec-WebSocket-Accept", "xd") w.WriteHeader(http.StatusSwitchingProtocols) }, success: false, }, { name: "badSecWebSocketProtocol", response: func(w http.ResponseWriter) { w.Header().Set("Connection", "Upgrade") w.Header().Set("Upgrade", "websocket") w.Header().Set("Sec-WebSocket-Protocol", "xd") w.WriteHeader(http.StatusSwitchingProtocols) }, success: false, }, { name: "unsupportedExtension", response: func(w http.ResponseWriter) { w.Header().Set("Connection", "Upgrade") w.Header().Set("Upgrade", "websocket") w.Header().Set("Sec-WebSocket-Extensions", "meow") w.WriteHeader(http.StatusSwitchingProtocols) }, success: false, }, { name: "unsupportedDeflateParam", response: func(w http.ResponseWriter) { w.Header().Set("Connection", "Upgrade") w.Header().Set("Upgrade", "websocket") w.Header().Set("Sec-WebSocket-Extensions", "permessage-deflate; meow") w.WriteHeader(http.StatusSwitchingProtocols) }, success: false, }, { name: "success", response: func(w http.ResponseWriter) { w.Header().Set("Connection", "Upgrade") w.Header().Set("Upgrade", "websocket") w.WriteHeader(http.StatusSwitchingProtocols) }, success: true, }, } for _, tc := range testCases { tc := tc t.Run(tc.name, func(t *testing.T) { t.Parallel() w := httptest.NewRecorder() tc.response(w) resp := w.Result() r := httptest.NewRequest("GET", "/", nil) key, err := secWebSocketKey(rand.Reader) assert.Success(t, err) r.Header.Set("Sec-WebSocket-Key", key) if resp.Header.Get("Sec-WebSocket-Accept") == "" { resp.Header.Set("Sec-WebSocket-Accept", secWebSocketAccept(key)) } opts := &DialOptions{ Subprotocols: strings.Split(r.Header.Get("Sec-WebSocket-Protocol"), ","), } _, err = verifyServerResponse(opts, opts.CompressionMode.opts(), key, resp) if tc.success { assert.Success(t, err) } else { assert.Error(t, err) } }) } } func mockHTTPClient(fn roundTripperFunc) *http.Client { return &http.Client{ Transport: fn, } } type roundTripperFunc func(*http.Request) (*http.Response, error) func (f roundTripperFunc) RoundTrip(r *http.Request) (*http.Response, error) { return f(r) } websocket-1.8.7/doc.go 0000664 0000000 0000000 00000001723 14033353662 0014621 0 ustar 00root root 0000000 0000000 // +build !js // Package websocket implements the RFC 6455 WebSocket protocol. // // https://tools.ietf.org/html/rfc6455 // // Use Dial to dial a WebSocket server. // // Use Accept to accept a WebSocket client. // // Conn represents the resulting WebSocket connection. // // The examples are the best way to understand how to correctly use the library. // // The wsjson and wspb subpackages contain helpers for JSON and protobuf messages. // // More documentation at https://nhooyr.io/websocket. // // Wasm // // The client side supports compiling to Wasm. // It wraps the WebSocket browser API. // // See https://developer.mozilla.org/en-US/docs/Web/API/WebSocket // // Some important caveats to be aware of: // // - Accept always errors out // - Conn.Ping is no-op // - HTTPClient, HTTPHeader and CompressionMode in DialOptions are no-op // - *http.Response from Dial is &http.Response{} with a 101 status code on success package websocket // import "nhooyr.io/websocket" websocket-1.8.7/example_test.go 0000664 0000000 0000000 00000011640 14033353662 0016545 0 ustar 00root root 0000000 0000000 package websocket_test import ( "context" "log" "net/http" "time" "nhooyr.io/websocket" "nhooyr.io/websocket/wsjson" ) func ExampleAccept() { // This handler accepts a WebSocket connection, reads a single JSON // message from the client and then closes the connection. fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { c, err := websocket.Accept(w, r, nil) if err != nil { log.Println(err) return } defer c.Close(websocket.StatusInternalError, "the sky is falling") ctx, cancel := context.WithTimeout(r.Context(), time.Second*10) defer cancel() var v interface{} err = wsjson.Read(ctx, c, &v) if err != nil { log.Println(err) return } c.Close(websocket.StatusNormalClosure, "") }) err := http.ListenAndServe("localhost:8080", fn) log.Fatal(err) } func ExampleDial() { // Dials a server, writes a single JSON message and then // closes the connection. ctx, cancel := context.WithTimeout(context.Background(), time.Minute) defer cancel() c, _, err := websocket.Dial(ctx, "ws://localhost:8080", nil) if err != nil { log.Fatal(err) } defer c.Close(websocket.StatusInternalError, "the sky is falling") err = wsjson.Write(ctx, c, "hi") if err != nil { log.Fatal(err) } c.Close(websocket.StatusNormalClosure, "") } func ExampleCloseStatus() { // Dials a server and then expects to be disconnected with status code // websocket.StatusNormalClosure. ctx, cancel := context.WithTimeout(context.Background(), time.Minute) defer cancel() c, _, err := websocket.Dial(ctx, "ws://localhost:8080", nil) if err != nil { log.Fatal(err) } defer c.Close(websocket.StatusInternalError, "the sky is falling") _, _, err = c.Reader(ctx) if websocket.CloseStatus(err) != websocket.StatusNormalClosure { log.Fatalf("expected to be disconnected with StatusNormalClosure but got: %v", err) } } func Example_writeOnly() { // This handler demonstrates how to correctly handle a write only WebSocket connection. // i.e you only expect to write messages and do not expect to read any messages. fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { c, err := websocket.Accept(w, r, nil) if err != nil { log.Println(err) return } defer c.Close(websocket.StatusInternalError, "the sky is falling") ctx, cancel := context.WithTimeout(r.Context(), time.Minute*10) defer cancel() ctx = c.CloseRead(ctx) t := time.NewTicker(time.Second * 30) defer t.Stop() for { select { case <-ctx.Done(): c.Close(websocket.StatusNormalClosure, "") return case <-t.C: err = wsjson.Write(ctx, c, "hi") if err != nil { log.Println(err) return } } } }) err := http.ListenAndServe("localhost:8080", fn) log.Fatal(err) } func Example_crossOrigin() { // This handler demonstrates how to safely accept cross origin WebSockets // from the origin example.com. fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { c, err := websocket.Accept(w, r, &websocket.AcceptOptions{ OriginPatterns: []string{"example.com"}, }) if err != nil { log.Println(err) return } c.Close(websocket.StatusNormalClosure, "cross origin WebSocket accepted") }) err := http.ListenAndServe("localhost:8080", fn) log.Fatal(err) } // This example demonstrates how to create a WebSocket server // that gracefully exits when sent a signal. // // It starts a WebSocket server that keeps every connection open // for 10 seconds. // If you CTRL+C while a connection is open, it will wait at most 30s // for all connections to terminate before shutting down. // func ExampleGrace() { // fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // c, err := websocket.Accept(w, r, nil) // if err != nil { // log.Println(err) // return // } // defer c.Close(websocket.StatusInternalError, "the sky is falling") // // ctx := c.CloseRead(r.Context()) // select { // case <-ctx.Done(): // case <-time.After(time.Second * 10): // } // // c.Close(websocket.StatusNormalClosure, "") // }) // // var g websocket.Grace // s := &http.Server{ // Handler: g.Handler(fn), // ReadTimeout: time.Second * 15, // WriteTimeout: time.Second * 15, // } // // errc := make(chan error, 1) // go func() { // errc <- s.ListenAndServe() // }() // // sigs := make(chan os.Signal, 1) // signal.Notify(sigs, os.Interrupt) // select { // case err := <-errc: // log.Printf("failed to listen and serve: %v", err) // case sig := <-sigs: // log.Printf("terminating: %v", sig) // } // // ctx, cancel := context.WithTimeout(context.Background(), time.Second*30) // defer cancel() // s.Shutdown(ctx) // g.Shutdown(ctx) // } // This example demonstrates full stack chat with an automated test. func Example_fullStackChat() { // https://github.com/nhooyr/websocket/tree/master/examples/chat } // This example demonstrates a echo server. func Example_echo() { // https://github.com/nhooyr/websocket/tree/master/examples/echo } websocket-1.8.7/examples/ 0000775 0000000 0000000 00000000000 14033353662 0015340 5 ustar 00root root 0000000 0000000 websocket-1.8.7/examples/README.md 0000664 0000000 0000000 00000000136 14033353662 0016617 0 ustar 00root root 0000000 0000000 # Examples This directory contains more involved examples unsuitable for display with godoc. websocket-1.8.7/examples/chat/ 0000775 0000000 0000000 00000000000 14033353662 0016257 5 ustar 00root root 0000000 0000000 websocket-1.8.7/examples/chat/README.md 0000664 0000000 0000000 00000002672 14033353662 0017545 0 ustar 00root root 0000000 0000000 # Chat Example This directory contains a full stack example of a simple chat webapp using nhooyr.io/websocket. ```bash $ cd examples/chat $ go run . localhost:0 listening on http://127.0.0.1:51055 ``` Visit the printed URL to submit and view broadcasted messages in a browser.  ## Structure The frontend is contained in `index.html`, `index.js` and `index.css`. It sets up the DOM with a scrollable div at the top that is populated with new messages as they are broadcast. At the bottom it adds a form to submit messages. The messages are received via the WebSocket `/subscribe` endpoint and published via the HTTP POST `/publish` endpoint. The reason for not publishing messages over the WebSocket is so that you can easily publish a message with curl. The server portion is `main.go` and `chat.go` and implements serving the static frontend assets, the `/subscribe` WebSocket endpoint and the HTTP POST `/publish` endpoint. The code is well commented. I would recommend starting in `main.go` and then `chat.go` followed by `index.html` and then `index.js`. There are two automated tests for the server included in `chat_test.go`. The first is a simple one client echo test. It publishes a single message and ensures it's received. The second is a complex concurrency test where 10 clients send 128 unique messages of max 128 bytes concurrently. The test ensures all messages are seen by every client. websocket-1.8.7/examples/chat/chat.go 0000664 0000000 0000000 00000011436 14033353662 0017532 0 ustar 00root root 0000000 0000000 package main import ( "context" "errors" "io/ioutil" "log" "net/http" "sync" "time" "golang.org/x/time/rate" "nhooyr.io/websocket" ) // chatServer enables broadcasting to a set of subscribers. type chatServer struct { // subscriberMessageBuffer controls the max number // of messages that can be queued for a subscriber // before it is kicked. // // Defaults to 16. subscriberMessageBuffer int // publishLimiter controls the rate limit applied to the publish endpoint. // // Defaults to one publish every 100ms with a burst of 8. publishLimiter *rate.Limiter // logf controls where logs are sent. // Defaults to log.Printf. logf func(f string, v ...interface{}) // serveMux routes the various endpoints to the appropriate handler. serveMux http.ServeMux subscribersMu sync.Mutex subscribers map[*subscriber]struct{} } // newChatServer constructs a chatServer with the defaults. func newChatServer() *chatServer { cs := &chatServer{ subscriberMessageBuffer: 16, logf: log.Printf, subscribers: make(map[*subscriber]struct{}), publishLimiter: rate.NewLimiter(rate.Every(time.Millisecond*100), 8), } cs.serveMux.Handle("/", http.FileServer(http.Dir("."))) cs.serveMux.HandleFunc("/subscribe", cs.subscribeHandler) cs.serveMux.HandleFunc("/publish", cs.publishHandler) return cs } // subscriber represents a subscriber. // Messages are sent on the msgs channel and if the client // cannot keep up with the messages, closeSlow is called. type subscriber struct { msgs chan []byte closeSlow func() } func (cs *chatServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { cs.serveMux.ServeHTTP(w, r) } // subscribeHandler accepts the WebSocket connection and then subscribes // it to all future messages. func (cs *chatServer) subscribeHandler(w http.ResponseWriter, r *http.Request) { c, err := websocket.Accept(w, r, nil) if err != nil { cs.logf("%v", err) return } defer c.Close(websocket.StatusInternalError, "") err = cs.subscribe(r.Context(), c) if errors.Is(err, context.Canceled) { return } if websocket.CloseStatus(err) == websocket.StatusNormalClosure || websocket.CloseStatus(err) == websocket.StatusGoingAway { return } if err != nil { cs.logf("%v", err) return } } // publishHandler reads the request body with a limit of 8192 bytes and then publishes // the received message. func (cs *chatServer) publishHandler(w http.ResponseWriter, r *http.Request) { if r.Method != "POST" { http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed) return } body := http.MaxBytesReader(w, r.Body, 8192) msg, err := ioutil.ReadAll(body) if err != nil { http.Error(w, http.StatusText(http.StatusRequestEntityTooLarge), http.StatusRequestEntityTooLarge) return } cs.publish(msg) w.WriteHeader(http.StatusAccepted) } // subscribe subscribes the given WebSocket to all broadcast messages. // It creates a subscriber with a buffered msgs chan to give some room to slower // connections and then registers the subscriber. It then listens for all messages // and writes them to the WebSocket. If the context is cancelled or // an error occurs, it returns and deletes the subscription. // // It uses CloseRead to keep reading from the connection to process control // messages and cancel the context if the connection drops. func (cs *chatServer) subscribe(ctx context.Context, c *websocket.Conn) error { ctx = c.CloseRead(ctx) s := &subscriber{ msgs: make(chan []byte, cs.subscriberMessageBuffer), closeSlow: func() { c.Close(websocket.StatusPolicyViolation, "connection too slow to keep up with messages") }, } cs.addSubscriber(s) defer cs.deleteSubscriber(s) for { select { case msg := <-s.msgs: err := writeTimeout(ctx, time.Second*5, c, msg) if err != nil { return err } case <-ctx.Done(): return ctx.Err() } } } // publish publishes the msg to all subscribers. // It never blocks and so messages to slow subscribers // are dropped. func (cs *chatServer) publish(msg []byte) { cs.subscribersMu.Lock() defer cs.subscribersMu.Unlock() cs.publishLimiter.Wait(context.Background()) for s := range cs.subscribers { select { case s.msgs <- msg: default: go s.closeSlow() } } } // addSubscriber registers a subscriber. func (cs *chatServer) addSubscriber(s *subscriber) { cs.subscribersMu.Lock() cs.subscribers[s] = struct{}{} cs.subscribersMu.Unlock() } // deleteSubscriber deletes the given subscriber. func (cs *chatServer) deleteSubscriber(s *subscriber) { cs.subscribersMu.Lock() delete(cs.subscribers, s) cs.subscribersMu.Unlock() } func writeTimeout(ctx context.Context, timeout time.Duration, c *websocket.Conn, msg []byte) error { ctx, cancel := context.WithTimeout(ctx, timeout) defer cancel() return c.Write(ctx, websocket.MessageText, msg) } websocket-1.8.7/examples/chat/chat_test.go 0000664 0000000 0000000 00000013463 14033353662 0020573 0 ustar 00root root 0000000 0000000 package main import ( "context" "crypto/rand" "fmt" "math/big" "net/http" "net/http/httptest" "strings" "sync" "testing" "time" "golang.org/x/time/rate" "nhooyr.io/websocket" ) func Test_chatServer(t *testing.T) { t.Parallel() // This is a simple echo test with a single client. // The client sends a message and ensures it receives // it on its WebSocket. t.Run("simple", func(t *testing.T) { t.Parallel() url, closeFn := setupTest(t) defer closeFn() ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) defer cancel() cl, err := newClient(ctx, url) assertSuccess(t, err) defer cl.Close() expMsg := randString(512) err = cl.publish(ctx, expMsg) assertSuccess(t, err) msg, err := cl.nextMessage() assertSuccess(t, err) if expMsg != msg { t.Fatalf("expected %v but got %v", expMsg, msg) } }) // This test is a complex concurrency test. // 10 clients are started that send 128 different // messages of max 128 bytes concurrently. // // The test verifies that every message is seen by ever client // and no errors occur anywhere. t.Run("concurrency", func(t *testing.T) { t.Parallel() const nmessages = 128 const maxMessageSize = 128 const nclients = 16 url, closeFn := setupTest(t) defer closeFn() ctx, cancel := context.WithTimeout(context.Background(), time.Minute) defer cancel() var clients []*client var clientMsgs []map[string]struct{} for i := 0; i < nclients; i++ { cl, err := newClient(ctx, url) assertSuccess(t, err) defer cl.Close() clients = append(clients, cl) clientMsgs = append(clientMsgs, randMessages(nmessages, maxMessageSize)) } allMessages := make(map[string]struct{}) for _, msgs := range clientMsgs { for m := range msgs { allMessages[m] = struct{}{} } } var wg sync.WaitGroup for i, cl := range clients { i := i cl := cl wg.Add(1) go func() { defer wg.Done() err := cl.publishMsgs(ctx, clientMsgs[i]) if err != nil { t.Errorf("client %d failed to publish all messages: %v", i, err) } }() wg.Add(1) go func() { defer wg.Done() err := testAllMessagesReceived(cl, nclients*nmessages, allMessages) if err != nil { t.Errorf("client %d failed to receive all messages: %v", i, err) } }() } wg.Wait() }) } // setupTest sets up chatServer that can be used // via the returned url. // // Defer closeFn to ensure everything is cleaned up at // the end of the test. // // chatServer logs will be logged via t.Logf. func setupTest(t *testing.T) (url string, closeFn func()) { cs := newChatServer() cs.logf = t.Logf // To ensure tests run quickly under even -race. cs.subscriberMessageBuffer = 4096 cs.publishLimiter.SetLimit(rate.Inf) s := httptest.NewServer(cs) return s.URL, func() { s.Close() } } // testAllMessagesReceived ensures that after n reads, all msgs in msgs // have been read. func testAllMessagesReceived(cl *client, n int, msgs map[string]struct{}) error { msgs = cloneMessages(msgs) for i := 0; i < n; i++ { msg, err := cl.nextMessage() if err != nil { return err } delete(msgs, msg) } if len(msgs) != 0 { return fmt.Errorf("did not receive all expected messages: %q", msgs) } return nil } func cloneMessages(msgs map[string]struct{}) map[string]struct{} { msgs2 := make(map[string]struct{}, len(msgs)) for m := range msgs { msgs2[m] = struct{}{} } return msgs2 } func randMessages(n, maxMessageLength int) map[string]struct{} { msgs := make(map[string]struct{}) for i := 0; i < n; i++ { m := randString(randInt(maxMessageLength)) if _, ok := msgs[m]; ok { i-- continue } msgs[m] = struct{}{} } return msgs } func assertSuccess(t *testing.T, err error) { t.Helper() if err != nil { t.Fatal(err) } } type client struct { url string c *websocket.Conn } func newClient(ctx context.Context, url string) (*client, error) { c, _, err := websocket.Dial(ctx, url+"/subscribe", nil) if err != nil { return nil, err } cl := &client{ url: url, c: c, } return cl, nil } func (cl *client) publish(ctx context.Context, msg string) (err error) { defer func() { if err != nil { cl.c.Close(websocket.StatusInternalError, "publish failed") } }() req, _ := http.NewRequestWithContext(ctx, http.MethodPost, cl.url+"/publish", strings.NewReader(msg)) resp, err := http.DefaultClient.Do(req) if err != nil { return err } defer resp.Body.Close() if resp.StatusCode != http.StatusAccepted { return fmt.Errorf("publish request failed: %v", resp.StatusCode) } return nil } func (cl *client) publishMsgs(ctx context.Context, msgs map[string]struct{}) error { for m := range msgs { err := cl.publish(ctx, m) if err != nil { return err } } return nil } func (cl *client) nextMessage() (string, error) { typ, b, err := cl.c.Read(context.Background()) if err != nil { return "", err } if typ != websocket.MessageText { cl.c.Close(websocket.StatusUnsupportedData, "expected text message") return "", fmt.Errorf("expected text message but got %v", typ) } return string(b), nil } func (cl *client) Close() error { return cl.c.Close(websocket.StatusNormalClosure, "") } // randString generates a random string with length n. func randString(n int) string { b := make([]byte, n) _, err := rand.Reader.Read(b) if err != nil { panic(fmt.Sprintf("failed to generate rand bytes: %v", err)) } s := strings.ToValidUTF8(string(b), "_") s = strings.ReplaceAll(s, "\x00", "_") if len(s) > n { return s[:n] } if len(s) < n { // Pad with = extra := n - len(s) return s + strings.Repeat("=", extra) } return s } // randInt returns a randomly generated integer between [0, max). func randInt(max int) int { x, err := rand.Int(rand.Reader, big.NewInt(int64(max))) if err != nil { panic(fmt.Sprintf("failed to get random int: %v", err)) } return int(x.Int64()) } websocket-1.8.7/examples/chat/index.css 0000664 0000000 0000000 00000002243 14033353662 0020101 0 ustar 00root root 0000000 0000000 body { width: 100vw; min-width: 320px; } #root { padding: 40px 20px; max-width: 600px; margin: auto; height: 100vh; display: flex; flex-direction: column; align-items: center; justify-content: center; } #root > * + * { margin: 20px 0 0 0; } /* 100vh on safari does not include the bottom bar. */ @supports (-webkit-overflow-scrolling: touch) { #root { height: 85vh; } } #message-log { width: 100%; flex-grow: 1; overflow: auto; } #message-log p:first-child { margin: 0; } #message-log > * + * { margin: 10px 0 0 0; } #publish-form-container { width: 100%; } #publish-form { width: 100%; display: flex; height: 40px; } #publish-form > * + * { margin: 0 0 0 10px; } #publish-form input[type="text"] { flex-grow: 1; -moz-appearance: none; -webkit-appearance: none; word-break: normal; border-radius: 5px; border: 1px solid #ccc; } #publish-form input[type="submit"] { color: white; background-color: black; border-radius: 5px; padding: 5px 10px; border: none; } #publish-form input[type="submit"]:hover { background-color: red; } #publish-form input[type="submit"]:active { background-color: red; } websocket-1.8.7/examples/chat/index.html 0000664 0000000 0000000 00000001515 14033353662 0020256 0 ustar 00root root 0000000 0000000