pax_global_header00006660000000000000000000000064140333536620014517gustar00rootroot0000000000000052 comment=3604edcb857415cb2c1213d63328cdcd738f2328 websocket-1.8.7/000077500000000000000000000000001403335366200135225ustar00rootroot00000000000000websocket-1.8.7/.github/000077500000000000000000000000001403335366200150625ustar00rootroot00000000000000websocket-1.8.7/.github/CODEOWNERS000066400000000000000000000000121403335366200164460ustar00rootroot00000000000000* @nhooyr websocket-1.8.7/.github/FUNDING.yml000066400000000000000000000000171403335366200166750ustar00rootroot00000000000000github: nhooyr websocket-1.8.7/.github/workflows/000077500000000000000000000000001403335366200171175ustar00rootroot00000000000000websocket-1.8.7/.github/workflows/ci.yaml000066400000000000000000000015721403335366200204030ustar00rootroot00000000000000name: ci on: [push, pull_request] jobs: fmt: runs-on: ubuntu-latest steps: - uses: actions/checkout@v1 - name: Run ./ci/fmt.sh uses: ./ci/container with: args: ./ci/fmt.sh lint: runs-on: ubuntu-latest steps: - uses: actions/checkout@v1 - name: Run ./ci/lint.sh uses: ./ci/container with: args: ./ci/lint.sh test: runs-on: ubuntu-latest steps: - uses: actions/checkout@v1 - name: Run ./ci/test.sh uses: ./ci/container with: args: ./ci/test.sh env: NETLIFY_AUTH_TOKEN: ${{ secrets.NETLIFY_AUTH_TOKEN }} NETLIFY_SITE_ID: 9b3ee4dc-8297-4774-b4b9-a61561fbbce7 - name: Upload coverage.html uses: actions/upload-artifact@v2 with: name: coverage.html path: ./ci/out/coverage.html websocket-1.8.7/.gitignore000066400000000000000000000000171403335366200155100ustar00rootroot00000000000000websocket.test websocket-1.8.7/LICENSE.txt000066400000000000000000000020541403335366200153460ustar00rootroot00000000000000MIT License Copyright (c) 2018 Anmol Sethi Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without 
limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. websocket-1.8.7/README.md000066400000000000000000000122161403335366200150030ustar00rootroot00000000000000# websocket [![godoc](https://godoc.org/nhooyr.io/websocket?status.svg)](https://pkg.go.dev/nhooyr.io/websocket) [![coverage](https://img.shields.io/badge/coverage-88%25-success)](https://nhooyrio-websocket-coverage.netlify.app) websocket is a minimal and idiomatic WebSocket library for Go. 
## Install ```bash go get nhooyr.io/websocket ``` ## Highlights - Minimal and idiomatic API - First class [context.Context](https://blog.golang.org/context) support - Fully passes the WebSocket [autobahn-testsuite](https://github.com/crossbario/autobahn-testsuite) - [Single dependency](https://pkg.go.dev/nhooyr.io/websocket?tab=imports) - JSON and protobuf helpers in the [wsjson](https://pkg.go.dev/nhooyr.io/websocket/wsjson) and [wspb](https://pkg.go.dev/nhooyr.io/websocket/wspb) subpackages - Zero alloc reads and writes - Concurrent writes - [Close handshake](https://pkg.go.dev/nhooyr.io/websocket#Conn.Close) - [net.Conn](https://pkg.go.dev/nhooyr.io/websocket#NetConn) wrapper - [Ping pong](https://pkg.go.dev/nhooyr.io/websocket#Conn.Ping) API - [RFC 7692](https://tools.ietf.org/html/rfc7692) permessage-deflate compression - Compile to [Wasm](https://pkg.go.dev/nhooyr.io/websocket#hdr-Wasm) ## Roadmap - [ ] HTTP/2 [#4](https://github.com/nhooyr/websocket/issues/4) ## Examples For a production quality example that demonstrates the complete API, see the [echo example](./examples/echo). For a full stack example, see the [chat example](./examples/chat). ### Server ```go http.HandlerFunc(func (w http.ResponseWriter, r *http.Request) { c, err := websocket.Accept(w, r, nil) if err != nil { // ... } defer c.Close(websocket.StatusInternalError, "the sky is falling") ctx, cancel := context.WithTimeout(r.Context(), time.Second*10) defer cancel() var v interface{} err = wsjson.Read(ctx, c, &v) if err != nil { // ... } log.Printf("received: %v", v) c.Close(websocket.StatusNormalClosure, "") }) ``` ### Client ```go ctx, cancel := context.WithTimeout(context.Background(), time.Minute) defer cancel() c, _, err := websocket.Dial(ctx, "ws://localhost:8080", nil) if err != nil { // ... } defer c.Close(websocket.StatusInternalError, "the sky is falling") err = wsjson.Write(ctx, c, "hi") if err != nil { // ... 
} c.Close(websocket.StatusNormalClosure, "") ``` ## Comparison ### gorilla/websocket Advantages of [gorilla/websocket](https://github.com/gorilla/websocket): - Mature and widely used - [Prepared writes](https://pkg.go.dev/github.com/gorilla/websocket#PreparedMessage) - Configurable [buffer sizes](https://pkg.go.dev/github.com/gorilla/websocket#hdr-Buffers) Advantages of nhooyr.io/websocket: - Minimal and idiomatic API - Compare godoc of [nhooyr.io/websocket](https://pkg.go.dev/nhooyr.io/websocket) with [gorilla/websocket](https://pkg.go.dev/github.com/gorilla/websocket) side by side. - [net.Conn](https://pkg.go.dev/nhooyr.io/websocket#NetConn) wrapper - Zero alloc reads and writes ([gorilla/websocket#535](https://github.com/gorilla/websocket/issues/535)) - Full [context.Context](https://blog.golang.org/context) support - Dial uses [net/http.Client](https://golang.org/pkg/net/http/#Client) - Will enable easy HTTP/2 support in the future - Gorilla writes directly to a net.Conn and so duplicates features of net/http.Client. - Concurrent writes - Close handshake ([gorilla/websocket#448](https://github.com/gorilla/websocket/issues/448)) - Idiomatic [ping pong](https://pkg.go.dev/nhooyr.io/websocket#Conn.Ping) API - Gorilla requires registering a pong callback before sending a Ping - Can target Wasm ([gorilla/websocket#432](https://github.com/gorilla/websocket/issues/432)) - Transparent message buffer reuse with [wsjson](https://pkg.go.dev/nhooyr.io/websocket/wsjson) and [wspb](https://pkg.go.dev/nhooyr.io/websocket/wspb) subpackages - [1.75x](https://github.com/nhooyr/websocket/releases/tag/v1.7.4) faster WebSocket masking implementation in pure Go - Gorilla's implementation is slower and uses [unsafe](https://golang.org/pkg/unsafe/). 
- Full [permessage-deflate](https://tools.ietf.org/html/rfc7692) compression extension support - Gorilla only supports no context takeover mode - We use [klauspost/compress](https://github.com/klauspost/compress) for much lower memory usage ([gorilla/websocket#203](https://github.com/gorilla/websocket/issues/203)) - [CloseRead](https://pkg.go.dev/nhooyr.io/websocket#Conn.CloseRead) helper ([gorilla/websocket#492](https://github.com/gorilla/websocket/issues/492)) - Actively maintained ([gorilla/websocket#370](https://github.com/gorilla/websocket/issues/370)) #### golang.org/x/net/websocket [golang.org/x/net/websocket](https://pkg.go.dev/golang.org/x/net/websocket) is deprecated. See [golang/go/issues/18152](https://github.com/golang/go/issues/18152). The [net.Conn](https://pkg.go.dev/nhooyr.io/websocket#NetConn) can help in transitioning to nhooyr.io/websocket. #### gobwas/ws [gobwas/ws](https://github.com/gobwas/ws) has an extremely flexible API that allows it to be used in an event driven style for performance. See the author's [blog post](https://medium.freecodecamp.org/million-websockets-and-go-cc58418460bb). However when writing idiomatic Go, nhooyr.io/websocket will be faster and easier to use. websocket-1.8.7/accept.go000066400000000000000000000252621403335366200153170ustar00rootroot00000000000000// +build !js package websocket import ( "bytes" "crypto/sha1" "encoding/base64" "errors" "fmt" "io" "log" "net/http" "net/textproto" "net/url" "path/filepath" "strings" "nhooyr.io/websocket/internal/errd" ) // AcceptOptions represents Accept's options. type AcceptOptions struct { // Subprotocols lists the WebSocket subprotocols that Accept will negotiate with the client. // The empty subprotocol will always be negotiated as per RFC 6455. If you would like to // reject it, close the connection when c.Subprotocol() == "". Subprotocols []string // InsecureSkipVerify is used to disable Accept's origin verification behaviour. 
// // You probably want to use OriginPatterns instead. InsecureSkipVerify bool // OriginPatterns lists the host patterns for authorized origins. // The request host is always authorized. // Use this to enable cross origin WebSockets. // // For example, JavaScript running on example.com wants to access a WebSocket server at chat.example.com. // In such a case, example.com is the origin and chat.example.com is the request host. // One would set this field to []string{"example.com"} to authorize example.com to connect. // // Each pattern is matched case insensitively against the request origin host // with filepath.Match. // See https://golang.org/pkg/path/filepath/#Match // // Please ensure you understand the ramifications of enabling this. // If used incorrectly your WebSocket server will be open to CSRF attacks. // // Do not use * as a pattern to allow any origin; prefer InsecureSkipVerify instead // to bring attention to the danger of such a setting. OriginPatterns []string // CompressionMode controls the compression mode. // Defaults to CompressionNoContextTakeover. // // See docs on CompressionMode for details. CompressionMode CompressionMode // CompressionThreshold controls the minimum size of a message before compression is applied. // // Defaults to 512 bytes for CompressionNoContextTakeover and 128 bytes // for CompressionContextTakeover. CompressionThreshold int } // Accept accepts a WebSocket handshake from a client and upgrades the // connection to a WebSocket. // // Accept will not allow cross origin requests by default. // See the InsecureSkipVerify and OriginPatterns options to allow cross origin requests. // // opts may be nil, in which case the zero AcceptOptions is used. // // Accept will write a response to w on all errors. 
func Accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn, error) { return accept(w, r, opts) } func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Conn, err error) { defer errd.Wrap(&err, "failed to accept WebSocket connection") if opts == nil { opts = &AcceptOptions{} } opts = &*opts errCode, err := verifyClientRequest(w, r) if err != nil { http.Error(w, err.Error(), errCode) return nil, err } if !opts.InsecureSkipVerify { err = authenticateOrigin(r, opts.OriginPatterns) if err != nil { if errors.Is(err, filepath.ErrBadPattern) { log.Printf("websocket: %v", err) err = errors.New(http.StatusText(http.StatusForbidden)) } http.Error(w, err.Error(), http.StatusForbidden) return nil, err } } hj, ok := w.(http.Hijacker) if !ok { err = errors.New("http.ResponseWriter does not implement http.Hijacker") http.Error(w, http.StatusText(http.StatusNotImplemented), http.StatusNotImplemented) return nil, err } w.Header().Set("Upgrade", "websocket") w.Header().Set("Connection", "Upgrade") key := r.Header.Get("Sec-WebSocket-Key") w.Header().Set("Sec-WebSocket-Accept", secWebSocketAccept(key)) subproto := selectSubprotocol(r, opts.Subprotocols) if subproto != "" { w.Header().Set("Sec-WebSocket-Protocol", subproto) } copts, err := acceptCompression(r, w, opts.CompressionMode) if err != nil { return nil, err } w.WriteHeader(http.StatusSwitchingProtocols) // See https://github.com/nhooyr/websocket/issues/166 if ginWriter, ok := w.(interface { WriteHeaderNow() }); ok { ginWriter.WriteHeaderNow() } netConn, brw, err := hj.Hijack() if err != nil { err = fmt.Errorf("failed to hijack connection: %w", err) http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) return nil, err } // https://github.com/golang/go/issues/32314 b, _ := brw.Reader.Peek(brw.Reader.Buffered()) brw.Reader.Reset(io.MultiReader(bytes.NewReader(b), netConn)) return newConn(connConfig{ subprotocol: 
w.Header().Get("Sec-WebSocket-Protocol"), rwc: netConn, client: false, copts: copts, flateThreshold: opts.CompressionThreshold, br: brw.Reader, bw: brw.Writer, }), nil } func verifyClientRequest(w http.ResponseWriter, r *http.Request) (errCode int, _ error) { if !r.ProtoAtLeast(1, 1) { return http.StatusUpgradeRequired, fmt.Errorf("WebSocket protocol violation: handshake request must be at least HTTP/1.1: %q", r.Proto) } if !headerContainsTokenIgnoreCase(r.Header, "Connection", "Upgrade") { w.Header().Set("Connection", "Upgrade") w.Header().Set("Upgrade", "websocket") return http.StatusUpgradeRequired, fmt.Errorf("WebSocket protocol violation: Connection header %q does not contain Upgrade", r.Header.Get("Connection")) } if !headerContainsTokenIgnoreCase(r.Header, "Upgrade", "websocket") { w.Header().Set("Connection", "Upgrade") w.Header().Set("Upgrade", "websocket") return http.StatusUpgradeRequired, fmt.Errorf("WebSocket protocol violation: Upgrade header %q does not contain websocket", r.Header.Get("Upgrade")) } if r.Method != "GET" { return http.StatusMethodNotAllowed, fmt.Errorf("WebSocket protocol violation: handshake request method is not GET but %q", r.Method) } if r.Header.Get("Sec-WebSocket-Version") != "13" { w.Header().Set("Sec-WebSocket-Version", "13") return http.StatusBadRequest, fmt.Errorf("unsupported WebSocket protocol version (only 13 is supported): %q", r.Header.Get("Sec-WebSocket-Version")) } if r.Header.Get("Sec-WebSocket-Key") == "" { return http.StatusBadRequest, errors.New("WebSocket protocol violation: missing Sec-WebSocket-Key") } return 0, nil } func authenticateOrigin(r *http.Request, originHosts []string) error { origin := r.Header.Get("Origin") if origin == "" { return nil } u, err := url.Parse(origin) if err != nil { return fmt.Errorf("failed to parse Origin header %q: %w", origin, err) } if strings.EqualFold(r.Host, u.Host) { return nil } for _, hostPattern := range originHosts { matched, err := match(hostPattern, u.Host) if err != 
nil { return fmt.Errorf("failed to parse filepath pattern %q: %w", hostPattern, err) } if matched { return nil } } return fmt.Errorf("request Origin %q is not authorized for Host %q", origin, r.Host) } func match(pattern, s string) (bool, error) { return filepath.Match(strings.ToLower(pattern), strings.ToLower(s)) } func selectSubprotocol(r *http.Request, subprotocols []string) string { cps := headerTokens(r.Header, "Sec-WebSocket-Protocol") for _, sp := range subprotocols { for _, cp := range cps { if strings.EqualFold(sp, cp) { return cp } } } return "" } func acceptCompression(r *http.Request, w http.ResponseWriter, mode CompressionMode) (*compressionOptions, error) { if mode == CompressionDisabled { return nil, nil } for _, ext := range websocketExtensions(r.Header) { switch ext.name { case "permessage-deflate": return acceptDeflate(w, ext, mode) // Disabled for now, see https://github.com/nhooyr/websocket/issues/218 // case "x-webkit-deflate-frame": // return acceptWebkitDeflate(w, ext, mode) } } return nil, nil } func acceptDeflate(w http.ResponseWriter, ext websocketExtension, mode CompressionMode) (*compressionOptions, error) { copts := mode.opts() for _, p := range ext.params { switch p { case "client_no_context_takeover": copts.clientNoContextTakeover = true continue case "server_no_context_takeover": copts.serverNoContextTakeover = true continue } if strings.HasPrefix(p, "client_max_window_bits") { // We cannot adjust the read sliding window so cannot make use of this. continue } err := fmt.Errorf("unsupported permessage-deflate parameter: %q", p) http.Error(w, err.Error(), http.StatusBadRequest) return nil, err } copts.setHeader(w.Header()) return copts, nil } func acceptWebkitDeflate(w http.ResponseWriter, ext websocketExtension, mode CompressionMode) (*compressionOptions, error) { copts := mode.opts() // The peer must explicitly request it. 
copts.serverNoContextTakeover = false for _, p := range ext.params { if p == "no_context_takeover" { copts.serverNoContextTakeover = true continue } // We explicitly fail on x-webkit-deflate-frame's max_window_bits parameter instead // of ignoring it as the draft spec is unclear. It says the server can ignore it // but the server has no way of signalling to the client it was ignored as the parameters // are set one way. // Thus us ignoring it would make the client think we understood it which would cause issues. // See https://tools.ietf.org/html/draft-tyoshino-hybi-websocket-perframe-deflate-06#section-4.1 // // Either way, we're only implementing this for webkit which never sends the max_window_bits // parameter so we don't need to worry about it. err := fmt.Errorf("unsupported x-webkit-deflate-frame parameter: %q", p) http.Error(w, err.Error(), http.StatusBadRequest) return nil, err } s := "x-webkit-deflate-frame" if copts.clientNoContextTakeover { s += "; no_context_takeover" } w.Header().Set("Sec-WebSocket-Extensions", s) return copts, nil } func headerContainsTokenIgnoreCase(h http.Header, key, token string) bool { for _, t := range headerTokens(h, key) { if strings.EqualFold(t, token) { return true } } return false } type websocketExtension struct { name string params []string } func websocketExtensions(h http.Header) []websocketExtension { var exts []websocketExtension extStrs := headerTokens(h, "Sec-WebSocket-Extensions") for _, extStr := range extStrs { if extStr == "" { continue } vals := strings.Split(extStr, ";") for i := range vals { vals[i] = strings.TrimSpace(vals[i]) } e := websocketExtension{ name: vals[0], params: vals[1:], } exts = append(exts, e) } return exts } func headerTokens(h http.Header, key string) []string { key = textproto.CanonicalMIMEHeaderKey(key) var tokens []string for _, v := range h[key] { v = strings.TrimSpace(v) for _, t := range strings.Split(v, ",") { t = strings.TrimSpace(t) tokens = append(tokens, t) } } return tokens } 
var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11") func secWebSocketAccept(secWebSocketKey string) string { h := sha1.New() h.Write([]byte(secWebSocketKey)) h.Write(keyGUID) return base64.StdEncoding.EncodeToString(h.Sum(nil)) } websocket-1.8.7/accept_js.go000066400000000000000000000007031403335366200160040ustar00rootroot00000000000000package websocket import ( "errors" "net/http" ) // AcceptOptions represents Accept's options. On Wasm the fields mirror the non-Wasm AcceptOptions but are not read, because Accept is not implemented here. type AcceptOptions struct { Subprotocols []string InsecureSkipVerify bool OriginPatterns []string CompressionMode CompressionMode CompressionThreshold int } // Accept is stubbed out for Wasm: it always returns a nil *Conn and an "unimplemented" error. func Accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn, error) { return nil, errors.New("unimplemented") } websocket-1.8.7/accept_test.go000066400000000000000000000231171403335366200163530ustar00rootroot00000000000000// +build !js package websocket import ( "bufio" "errors" "net" "net/http" "net/http/httptest" "strings" "testing" "nhooyr.io/websocket/internal/test/assert" ) func TestAccept(t *testing.T) { t.Parallel() t.Run("badClientHandshake", func(t *testing.T) { t.Parallel() w := httptest.NewRecorder() r := httptest.NewRequest("GET", "/", nil) _, err := Accept(w, r, nil) assert.Contains(t, err, "protocol violation") }) t.Run("badOrigin", func(t *testing.T) { t.Parallel() w := httptest.NewRecorder() r := httptest.NewRequest("GET", "/", nil) r.Header.Set("Connection", "Upgrade") r.Header.Set("Upgrade", "websocket") r.Header.Set("Sec-WebSocket-Version", "13") r.Header.Set("Sec-WebSocket-Key", "meow123") r.Header.Set("Origin", "harhar.com") _, err := Accept(w, r, nil) assert.Contains(t, err, `request Origin "harhar.com" is not authorized for Host`) }) t.Run("badCompression", func(t *testing.T) { t.Parallel() w := mockHijacker{ ResponseWriter: httptest.NewRecorder(), } r := httptest.NewRequest("GET", "/", nil) r.Header.Set("Connection", "Upgrade") r.Header.Set("Upgrade", "websocket") 
r.Header.Set("Sec-WebSocket-Version", "13") r.Header.Set("Sec-WebSocket-Key", "meow123") r.Header.Set("Sec-WebSocket-Extensions", "permessage-deflate; harharhar") _, err := Accept(w, r, nil) assert.Contains(t, err, `unsupported permessage-deflate parameter`) }) t.Run("requireHttpHijacker", func(t *testing.T) { t.Parallel() w := httptest.NewRecorder() r := httptest.NewRequest("GET", "/", nil) r.Header.Set("Connection", "Upgrade") r.Header.Set("Upgrade", "websocket") r.Header.Set("Sec-WebSocket-Version", "13") r.Header.Set("Sec-WebSocket-Key", "meow123") _, err := Accept(w, r, nil) assert.Contains(t, err, `http.ResponseWriter does not implement http.Hijacker`) }) t.Run("badHijack", func(t *testing.T) { t.Parallel() w := mockHijacker{ ResponseWriter: httptest.NewRecorder(), hijack: func() (conn net.Conn, writer *bufio.ReadWriter, err error) { return nil, nil, errors.New("haha") }, } r := httptest.NewRequest("GET", "/", nil) r.Header.Set("Connection", "Upgrade") r.Header.Set("Upgrade", "websocket") r.Header.Set("Sec-WebSocket-Version", "13") r.Header.Set("Sec-WebSocket-Key", "meow123") _, err := Accept(w, r, nil) assert.Contains(t, err, `failed to hijack connection`) }) } func Test_verifyClientHandshake(t *testing.T) { t.Parallel() testCases := []struct { name string method string http1 bool h map[string]string success bool }{ { name: "badConnection", h: map[string]string{ "Connection": "notUpgrade", }, }, { name: "badUpgrade", h: map[string]string{ "Connection": "Upgrade", "Upgrade": "notWebSocket", }, }, { name: "badMethod", method: "POST", h: map[string]string{ "Connection": "Upgrade", "Upgrade": "websocket", }, }, { name: "badWebSocketVersion", h: map[string]string{ "Connection": "Upgrade", "Upgrade": "websocket", "Sec-WebSocket-Version": "14", }, }, { name: "badWebSocketKey", h: map[string]string{ "Connection": "Upgrade", "Upgrade": "websocket", "Sec-WebSocket-Version": "13", "Sec-WebSocket-Key": "", }, }, { name: "badHTTPVersion", h: map[string]string{ 
"Connection": "Upgrade", "Upgrade": "websocket", "Sec-WebSocket-Version": "13", "Sec-WebSocket-Key": "meow123", }, http1: true, }, { name: "success", h: map[string]string{ "Connection": "keep-alive, Upgrade", "Upgrade": "websocket", "Sec-WebSocket-Version": "13", "Sec-WebSocket-Key": "meow123", }, success: true, }, } for _, tc := range testCases { tc := tc t.Run(tc.name, func(t *testing.T) { t.Parallel() r := httptest.NewRequest(tc.method, "/", nil) r.ProtoMajor = 1 r.ProtoMinor = 1 if tc.http1 { r.ProtoMinor = 0 } for k, v := range tc.h { r.Header.Set(k, v) } _, err := verifyClientRequest(httptest.NewRecorder(), r) if tc.success { assert.Success(t, err) } else { assert.Error(t, err) } }) } } func Test_selectSubprotocol(t *testing.T) { t.Parallel() testCases := []struct { name string clientProtocols []string serverProtocols []string negotiated string }{ { name: "empty", clientProtocols: nil, serverProtocols: nil, negotiated: "", }, { name: "basic", clientProtocols: []string{"echo", "echo2"}, serverProtocols: []string{"echo2", "echo"}, negotiated: "echo2", }, { name: "none", clientProtocols: []string{"echo", "echo3"}, serverProtocols: []string{"echo2", "echo4"}, negotiated: "", }, { name: "fallback", clientProtocols: []string{"echo", "echo3"}, serverProtocols: []string{"echo2", "echo3"}, negotiated: "echo3", }, { name: "clientCasePresered", clientProtocols: []string{"Echo1"}, serverProtocols: []string{"echo1"}, negotiated: "Echo1", }, } for _, tc := range testCases { tc := tc t.Run(tc.name, func(t *testing.T) { t.Parallel() r := httptest.NewRequest("GET", "/", nil) r.Header.Set("Sec-WebSocket-Protocol", strings.Join(tc.clientProtocols, ",")) negotiated := selectSubprotocol(r, tc.serverProtocols) assert.Equal(t, "negotiated", tc.negotiated, negotiated) }) } } func Test_authenticateOrigin(t *testing.T) { t.Parallel() testCases := []struct { name string origin string host string originPatterns []string success bool }{ { name: "none", success: true, host: "example.com", 
}, { name: "invalid", origin: "$#)(*)$#@*$(#@*$)#@*%)#(@*%)#(@%#@$#@$#$#@$#@}{}{}", host: "example.com", success: false, }, { name: "unauthorized", origin: "https://example.com", host: "example1.com", success: false, }, { name: "authorized", origin: "https://example.com", host: "example.com", success: true, }, { name: "authorizedCaseInsensitive", origin: "https://examplE.com", host: "example.com", success: true, }, { name: "originPatterns", origin: "https://two.examplE.com", host: "example.com", originPatterns: []string{ "*.example.com", "bar.com", }, success: true, }, { name: "originPatternsUnauthorized", origin: "https://two.examplE.com", host: "example.com", originPatterns: []string{ "exam3.com", "bar.com", }, success: false, }, } for _, tc := range testCases { tc := tc t.Run(tc.name, func(t *testing.T) { t.Parallel() r := httptest.NewRequest("GET", "http://"+tc.host+"/", nil) r.Header.Set("Origin", tc.origin) err := authenticateOrigin(r, tc.originPatterns) if tc.success { assert.Success(t, err) } else { assert.Error(t, err) } }) } } func Test_acceptCompression(t *testing.T) { t.Parallel() testCases := []struct { name string mode CompressionMode reqSecWebSocketExtensions string respSecWebSocketExtensions string expCopts *compressionOptions error bool }{ { name: "disabled", mode: CompressionDisabled, expCopts: nil, }, { name: "noClientSupport", mode: CompressionNoContextTakeover, expCopts: nil, }, { name: "permessage-deflate", mode: CompressionNoContextTakeover, reqSecWebSocketExtensions: "permessage-deflate; client_max_window_bits", respSecWebSocketExtensions: "permessage-deflate; client_no_context_takeover; server_no_context_takeover", expCopts: &compressionOptions{ clientNoContextTakeover: true, serverNoContextTakeover: true, }, }, { name: "permessage-deflate/error", mode: CompressionNoContextTakeover, reqSecWebSocketExtensions: "permessage-deflate; meow", error: true, }, // The x-webkit-deflate-frame cases below are disabled along with x-webkit-deflate-frame negotiation in acceptCompression, see https://github.com/nhooyr/websocket/issues/218. // { // name: "x-webkit-deflate-frame", // mode: CompressionNoContextTakeover, // 
reqSecWebSocketExtensions: "x-webkit-deflate-frame; no_context_takeover", // respSecWebSocketExtensions: "x-webkit-deflate-frame; no_context_takeover", // expCopts: &compressionOptions{ // clientNoContextTakeover: true, // serverNoContextTakeover: true, // }, // }, // { // name: "x-webkit-deflate/error", // mode: CompressionNoContextTakeover, // reqSecWebSocketExtensions: "x-webkit-deflate-frame; max_window_bits", // error: true, // }, } for _, tc := range testCases { tc := tc t.Run(tc.name, func(t *testing.T) { t.Parallel() r := httptest.NewRequest(http.MethodGet, "/", nil) r.Header.Set("Sec-WebSocket-Extensions", tc.reqSecWebSocketExtensions) w := httptest.NewRecorder() copts, err := acceptCompression(r, w, tc.mode) if tc.error { assert.Error(t, err) return } assert.Success(t, err) assert.Equal(t, "compression options", tc.expCopts, copts) assert.Equal(t, "Sec-WebSocket-Extensions", tc.respSecWebSocketExtensions, w.Header().Get("Sec-WebSocket-Extensions")) }) } } // mockHijacker is an http.ResponseWriter whose Hijack method delegates to the hijack field. type mockHijacker struct { http.ResponseWriter hijack func() (net.Conn, *bufio.ReadWriter, error) } var _ http.Hijacker = mockHijacker{} func (mj mockHijacker) Hijack() (net.Conn, *bufio.ReadWriter, error) { return mj.hijack() } websocket-1.8.7/autobahn_test.go000066400000000000000000000124071403335366200167150ustar00rootroot00000000000000// +build !js package websocket_test import ( "context" "encoding/json" "fmt" "io/ioutil" "net" "os" "os/exec" "strconv" "strings" "testing" "time" "nhooyr.io/websocket" "nhooyr.io/websocket/internal/errd" "nhooyr.io/websocket/internal/test/assert" "nhooyr.io/websocket/internal/test/wstest" ) var excludedAutobahnCases = []string{ // We skip the UTF-8 handling tests: there is no reason to reject invalid UTF-8, and validating it would only add // performance overhead. "6.*", "7.5.1", // We skip the tests related to requestMaxWindowBits as that is unimplemented due // to limitations in compress/flate. 
See https://github.com/golang/go/issues/3155 // Same with klauspost/compress which doesn't allow adjusting the sliding window size. "13.3.*", "13.4.*", "13.5.*", "13.6.*", } var autobahnCases = []string{"*"} // TestAutobahn runs the autobahn fuzzingserver cases against the client; it only runs when AUTOBAHN_TEST is set. NOTE(review): the fmt.Sprintf calls below use wstestURL as part of the format string, so a '%' in the URL would be misinterpreted. func TestAutobahn(t *testing.T) { t.Parallel() if os.Getenv("AUTOBAHN_TEST") == "" { t.SkipNow() } ctx, cancel := context.WithTimeout(context.Background(), time.Minute*15) defer cancel() wstestURL, closeFn, err := wstestClientServer(ctx) assert.Success(t, err) defer closeFn() err = waitWS(ctx, wstestURL) assert.Success(t, err) cases, err := wstestCaseCount(ctx, wstestURL) assert.Success(t, err) t.Run("cases", func(t *testing.T) { for i := 1; i <= cases; i++ { i := i t.Run("", func(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), time.Minute*5) defer cancel() c, _, err := websocket.Dial(ctx, fmt.Sprintf(wstestURL+"/runCase?case=%v&agent=main", i), nil) assert.Success(t, err) err = wstest.EchoLoop(ctx, c) t.Logf("echoLoop: %v", err) }) } }) c, _, err := websocket.Dial(ctx, fmt.Sprintf(wstestURL+"/updateReports?agent=main"), nil) assert.Success(t, err) c.Close(websocket.StatusNormalClosure, "") checkWSTestIndex(t, "./ci/out/wstestClientReports/index.json") } // waitWS dials url until a connection succeeds or 5 seconds elapse. NOTE(review): failed dials are retried immediately, with no backoff. func waitWS(ctx context.Context, url string) error { ctx, cancel := context.WithTimeout(ctx, time.Second*5) defer cancel() for ctx.Err() == nil { c, _, err := websocket.Dial(ctx, url, nil) if err != nil { continue } c.Close(websocket.StatusNormalClosure, "") return nil } return ctx.Err() } // wstestClientServer starts wstest in fuzzingserver mode on an unused local address and returns its ws:// URL plus a func that kills the process. NOTE(review): the ctx parameter is overwritten below by a context derived from context.Background(), so the caller's ctx is ignored, and cancel is only called on error, never by the returned closeFn. func wstestClientServer(ctx context.Context) (url string, closeFn func(), err error) { serverAddr, err := unusedListenAddr() if err != nil { return "", nil, err } url = "ws://" + serverAddr specFile, err := tempJSONFile(map[string]interface{}{ "url": url, "outdir": "ci/out/wstestClientReports", "cases": autobahnCases, "exclude-cases": excludedAutobahnCases, }) if err != nil { return "", nil, fmt.Errorf("failed to write spec: %w", err) } ctx, cancel := context.WithTimeout(context.Background(), 
time.Minute*15) defer func() { if err != nil { cancel() } }() args := []string{"--mode", "fuzzingserver", "--spec", specFile, // --webport=0 disables an extra server that fuzzingserver mode would otherwise start. // See https://github.com/crossbario/autobahn-testsuite/blob/058db3a36b7c3a1edf68c282307c6b899ca4857f/autobahntestsuite/autobahntestsuite/wstest.py#L124 "--webport=0", } wstest := exec.CommandContext(ctx, "wstest", args...) err = wstest.Start() if err != nil { return "", nil, fmt.Errorf("failed to start wstest: %w", err) } return url, func() { wstest.Process.Kill() }, nil } func wstestCaseCount(ctx context.Context, url string) (cases int, err error) { defer errd.Wrap(&err, "failed to get case count") c, _, err := websocket.Dial(ctx, url+"/getCaseCount", nil) if err != nil { return 0, err } defer c.Close(websocket.StatusInternalError, "") _, r, err := c.Reader(ctx) if err != nil { return 0, err } b, err := ioutil.ReadAll(r) if err != nil { return 0, err } cases, err = strconv.Atoi(string(b)) if err != nil { return 0, err } c.Close(websocket.StatusNormalClosure, "") return cases, nil } // checkWSTestIndex fails t for every case in the wstest index JSON at path whose behavior or close behavior is not an accepted result. func checkWSTestIndex(t *testing.T, path string) { wstestOut, err := ioutil.ReadFile(path) assert.Success(t, err) var indexJSON map[string]map[string]struct { Behavior string `json:"behavior"` BehaviorClose string `json:"behaviorClose"` } err = json.Unmarshal(wstestOut, &indexJSON) assert.Success(t, err) for _, tests := range indexJSON { for test, result := range tests { t.Run(test, func(t *testing.T) { switch result.BehaviorClose { case "OK", "INFORMATIONAL": default: t.Errorf("bad close behaviour") } switch result.Behavior { case "OK", "NON-STRICT", "INFORMATIONAL": default: t.Errorf("failed") } }) } } if t.Failed() { htmlPath := strings.Replace(path, ".json", ".html", 1) t.Errorf("detected autobahn violation, see %q", htmlPath) } } // unusedListenAddr returns a localhost address that was free when checked. NOTE(review): the listener is closed before returning, so another process could claim the port before wstest binds it. func unusedListenAddr() (_ string, err error) { defer errd.Wrap(&err, "failed to get unused listen address") l, err := net.Listen("tcp", "localhost:0") if err != nil 
{ return "", err } l.Close() return l.Addr().String(), nil } func tempJSONFile(v interface{}) (string, error) { f, err := ioutil.TempFile("", "temp.json") if err != nil { return "", fmt.Errorf("temp file: %w", err) } defer f.Close() e := json.NewEncoder(f) e.SetIndent("", "\t") err = e.Encode(v) if err != nil { return "", fmt.Errorf("json encode: %w", err) } err = f.Close() if err != nil { return "", fmt.Errorf("close temp file: %w", err) } return f.Name(), nil } websocket-1.8.7/ci/000077500000000000000000000000001403335366200141155ustar00rootroot00000000000000websocket-1.8.7/ci/all.sh000077500000000000000000000002111403335366200152160ustar00rootroot00000000000000#!/usr/bin/env bash set -euo pipefail main() { cd "$(dirname "$0")/.." ./ci/fmt.sh ./ci/lint.sh ./ci/test.sh "$@" } main "$@" websocket-1.8.7/ci/container/000077500000000000000000000000001403335366200160775ustar00rootroot00000000000000websocket-1.8.7/ci/container/Dockerfile000066400000000000000000000006231403335366200200720ustar00rootroot00000000000000FROM golang RUN apt-get update RUN apt-get install -y npm shellcheck chromium ENV GO111MODULE=on RUN go get golang.org/x/tools/cmd/goimports RUN go get mvdan.cc/sh/v3/cmd/shfmt RUN go get golang.org/x/tools/cmd/stringer RUN go get golang.org/x/lint/golint RUN go get github.com/agnivade/wasmbrowsertest RUN npm --unsafe-perm=true install -g prettier RUN npm --unsafe-perm=true install -g netlify-cli websocket-1.8.7/ci/fmt.sh000077500000000000000000000013531403335366200152440ustar00rootroot00000000000000#!/usr/bin/env bash set -euo pipefail main() { cd "$(dirname "$0")/.." go mod tidy gofmt -w -s . goimports -w "-local=$(go list -m)" . 
prettier \ --write \ --print-width=120 \ --no-semi \ --trailing-comma=all \ --loglevel=warn \ --arrow-parens=avoid \ $(git ls-files "*.yml" "*.md" "*.js" "*.css" "*.html") shfmt -i 2 -w -s -sr $(git ls-files "*.sh") stringer -type=opcode,MessageType,StatusCode -output=stringer.go if [[ ${CI-} ]]; then ensure_fmt fi } ensure_fmt() { if [[ $(git ls-files --other --modified --exclude-standard) ]]; then git -c color.ui=always --no-pager diff echo echo "Please run the following locally:" echo " ./ci/fmt.sh" exit 1 fi } main "$@" websocket-1.8.7/ci/lint.sh000077500000000000000000000004251403335366200154230ustar00rootroot00000000000000#!/usr/bin/env bash set -euo pipefail main() { cd "$(dirname "$0")/.." go vet ./... GOOS=js GOARCH=wasm go vet ./... golint -set_exit_status ./... GOOS=js GOARCH=wasm golint -set_exit_status ./... shellcheck --exclude=SC2046 $(git ls-files "*.sh") } main "$@" websocket-1.8.7/ci/out/000077500000000000000000000000001403335366200147245ustar00rootroot00000000000000websocket-1.8.7/ci/out/.gitignore000066400000000000000000000000021403335366200167040ustar00rootroot00000000000000* websocket-1.8.7/ci/test.sh000077500000000000000000000013201403335366200154270ustar00rootroot00000000000000#!/usr/bin/env bash set -euo pipefail main() { cd "$(dirname "$0")/.." go test -timeout=30m -covermode=atomic -coverprofile=ci/out/coverage.prof -coverpkg=./... "$@" ./... sed -i '/stringer\.go/d' ci/out/coverage.prof sed -i '/nhooyr.io\/websocket\/internal\/test/d' ci/out/coverage.prof sed -i '/examples/d' ci/out/coverage.prof # Last line is the total coverage. 
go tool cover -func ci/out/coverage.prof | tail -n1 go tool cover -html=ci/out/coverage.prof -o=ci/out/coverage.html if [[ ${CI-} && ${GITHUB_REF-} == *master ]]; then local deployDir deployDir="$(mktemp -d)" cp ci/out/coverage.html "$deployDir/index.html" netlify deploy --prod "--dir=$deployDir" fi } main "$@" websocket-1.8.7/close.go000066400000000000000000000045261403335366200151650ustar00rootroot00000000000000package websocket import ( "errors" "fmt" ) // StatusCode represents a WebSocket status code. // https://tools.ietf.org/html/rfc6455#section-7.4 type StatusCode int // https://www.iana.org/assignments/websocket/websocket.xhtml#close-code-number // // These are only the status codes defined by the protocol. // // You can define custom codes in the 3000-4999 range. // The 3000-3999 range is reserved for use by libraries, frameworks and applications. // The 4000-4999 range is reserved for private use. const ( StatusNormalClosure StatusCode = 1000 StatusGoingAway StatusCode = 1001 StatusProtocolError StatusCode = 1002 StatusUnsupportedData StatusCode = 1003 // 1004 is reserved and so unexported. statusReserved StatusCode = 1004 // StatusNoStatusRcvd cannot be sent in a close message. // It is reserved for when a close message is received without // a status code. StatusNoStatusRcvd StatusCode = 1005 // StatusAbnormalClosure is exported for use only with Wasm. // In non Wasm Go, the returned error will indicate whether the // connection was closed abnormally. StatusAbnormalClosure StatusCode = 1006 StatusInvalidFramePayloadData StatusCode = 1007 StatusPolicyViolation StatusCode = 1008 StatusMessageTooBig StatusCode = 1009 StatusMandatoryExtension StatusCode = 1010 StatusInternalError StatusCode = 1011 StatusServiceRestart StatusCode = 1012 StatusTryAgainLater StatusCode = 1013 StatusBadGateway StatusCode = 1014 // StatusTLSHandshake is only exported for use with Wasm. 
// In non Wasm Go, the returned error will indicate whether there was // a TLS handshake failure. StatusTLSHandshake StatusCode = 1015 ) // CloseError is returned when the connection is closed with a status and reason. // // Use Go 1.13's errors.As to check for this error. // Also see the CloseStatus helper. type CloseError struct { Code StatusCode Reason string } func (ce CloseError) Error() string { return fmt.Sprintf("status = %v and reason = %q", ce.Code, ce.Reason) } // CloseStatus is a convenience wrapper around Go 1.13's errors.As to grab // the status code from a CloseError. // // -1 will be returned if the passed error is nil or not a CloseError. func CloseStatus(err error) StatusCode { var ce CloseError if errors.As(err, &ce) { return ce.Code } return -1 } websocket-1.8.7/close_notjs.go000066400000000000000000000107501403335366200163760ustar00rootroot00000000000000// +build !js package websocket import ( "context" "encoding/binary" "errors" "fmt" "log" "time" "nhooyr.io/websocket/internal/errd" ) // Close performs the WebSocket close handshake with the given status code and reason. // // It will write a WebSocket close frame with a timeout of 5s and then wait 5s for // the peer to send a close frame. // All data messages received from the peer during the close handshake will be discarded. // // The connection can only be closed once. Additional calls to Close // are no-ops. // // The maximum length of reason must be 125 bytes. Avoid // sending a dynamic reason. // // Close will unblock all goroutines interacting with the connection once // complete. 
func (c *Conn) Close(code StatusCode, reason string) error { return c.closeHandshake(code, reason) } func (c *Conn) closeHandshake(code StatusCode, reason string) (err error) { defer errd.Wrap(&err, "failed to close WebSocket") writeErr := c.writeClose(code, reason) closeHandshakeErr := c.waitCloseHandshake() if writeErr != nil { return writeErr } if CloseStatus(closeHandshakeErr) == -1 { return closeHandshakeErr } return nil } var errAlreadyWroteClose = errors.New("already wrote close") func (c *Conn) writeClose(code StatusCode, reason string) error { c.closeMu.Lock() wroteClose := c.wroteClose c.wroteClose = true c.closeMu.Unlock() if wroteClose { return errAlreadyWroteClose } ce := CloseError{ Code: code, Reason: reason, } var p []byte var marshalErr error if ce.Code != StatusNoStatusRcvd { p, marshalErr = ce.bytes() if marshalErr != nil { log.Printf("websocket: %v", marshalErr) } } writeErr := c.writeControl(context.Background(), opClose, p) if CloseStatus(writeErr) != -1 { // Not a real error if it's due to a close frame being received. writeErr = nil } // We do this after in case there was an error writing the close frame. 
c.setCloseErr(fmt.Errorf("sent close frame: %w", ce)) if marshalErr != nil { return marshalErr } return writeErr } func (c *Conn) waitCloseHandshake() error { defer c.close(nil) ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) defer cancel() err := c.readMu.lock(ctx) if err != nil { return err } defer c.readMu.unlock() if c.readCloseFrameErr != nil { return c.readCloseFrameErr } for { h, err := c.readLoop(ctx) if err != nil { return err } for i := int64(0); i < h.payloadLength; i++ { _, err := c.br.ReadByte() if err != nil { return err } } } } func parseClosePayload(p []byte) (CloseError, error) { if len(p) == 0 { return CloseError{ Code: StatusNoStatusRcvd, }, nil } if len(p) < 2 { return CloseError{}, fmt.Errorf("close payload %q too small, cannot even contain the 2 byte status code", p) } ce := CloseError{ Code: StatusCode(binary.BigEndian.Uint16(p)), Reason: string(p[2:]), } if !validWireCloseCode(ce.Code) { return CloseError{}, fmt.Errorf("invalid status code %v", ce.Code) } return ce, nil } // See http://www.iana.org/assignments/websocket/websocket.xhtml#close-code-number // and https://tools.ietf.org/html/rfc6455#section-7.4.1 func validWireCloseCode(code StatusCode) bool { switch code { case statusReserved, StatusNoStatusRcvd, StatusAbnormalClosure, StatusTLSHandshake: return false } if code >= StatusNormalClosure && code <= StatusBadGateway { return true } if code >= 3000 && code <= 4999 { return true } return false } func (ce CloseError) bytes() ([]byte, error) { p, err := ce.bytesErr() if err != nil { err = fmt.Errorf("failed to marshal close frame: %w", err) ce = CloseError{ Code: StatusInternalError, } p, _ = ce.bytesErr() } return p, err } const maxCloseReason = maxControlPayload - 2 func (ce CloseError) bytesErr() ([]byte, error) { if len(ce.Reason) > maxCloseReason { return nil, fmt.Errorf("reason string max is %v but got %q with length %v", maxCloseReason, ce.Reason, len(ce.Reason)) } if !validWireCloseCode(ce.Code) { return 
nil, fmt.Errorf("status code %v cannot be set", ce.Code) } buf := make([]byte, 2+len(ce.Reason)) binary.BigEndian.PutUint16(buf, uint16(ce.Code)) copy(buf[2:], ce.Reason) return buf, nil } func (c *Conn) setCloseErr(err error) { c.closeMu.Lock() c.setCloseErrLocked(err) c.closeMu.Unlock() } func (c *Conn) setCloseErrLocked(err error) { if c.closeErr == nil { c.closeErr = fmt.Errorf("WebSocket closed: %w", err) } } func (c *Conn) isClosed() bool { select { case <-c.closed: return true default: return false } } websocket-1.8.7/close_test.go000066400000000000000000000063631403335366200162250ustar00rootroot00000000000000// +build !js package websocket import ( "io" "math" "strings" "testing" "nhooyr.io/websocket/internal/test/assert" ) func TestCloseError(t *testing.T) { t.Parallel() testCases := []struct { name string ce CloseError success bool }{ { name: "normal", ce: CloseError{ Code: StatusNormalClosure, Reason: strings.Repeat("x", maxCloseReason), }, success: true, }, { name: "bigReason", ce: CloseError{ Code: StatusNormalClosure, Reason: strings.Repeat("x", maxCloseReason+1), }, success: false, }, { name: "bigCode", ce: CloseError{ Code: math.MaxUint16, Reason: strings.Repeat("x", maxCloseReason), }, success: false, }, } for _, tc := range testCases { tc := tc t.Run(tc.name, func(t *testing.T) { t.Parallel() _, err := tc.ce.bytesErr() if tc.success { assert.Success(t, err) } else { assert.Error(t, err) } }) } t.Run("Error", func(t *testing.T) { exp := `status = StatusInternalError and reason = "meow"` act := CloseError{ Code: StatusInternalError, Reason: "meow", }.Error() assert.Equal(t, "CloseError.Error()", exp, act) }) } func Test_parseClosePayload(t *testing.T) { t.Parallel() testCases := []struct { name string p []byte success bool ce CloseError }{ { name: "normal", p: append([]byte{0x3, 0xE8}, []byte("hello")...), success: true, ce: CloseError{ Code: StatusNormalClosure, Reason: "hello", }, }, { name: "nothing", success: true, ce: CloseError{ Code: 
StatusNoStatusRcvd, }, }, { name: "oneByte", p: []byte{0}, success: false, }, { name: "badStatusCode", p: []byte{0x17, 0x70}, success: false, }, } for _, tc := range testCases { tc := tc t.Run(tc.name, func(t *testing.T) { t.Parallel() ce, err := parseClosePayload(tc.p) if tc.success { assert.Success(t, err) assert.Equal(t, "close payload", tc.ce, ce) } else { assert.Error(t, err) } }) } } func Test_validWireCloseCode(t *testing.T) { t.Parallel() testCases := []struct { name string code StatusCode valid bool }{ { name: "normal", code: StatusNormalClosure, valid: true, }, { name: "noStatus", code: StatusNoStatusRcvd, valid: false, }, { name: "3000", code: 3000, valid: true, }, { name: "4999", code: 4999, valid: true, }, { name: "unknown", code: 5000, valid: false, }, } for _, tc := range testCases { tc := tc t.Run(tc.name, func(t *testing.T) { t.Parallel() act := validWireCloseCode(tc.code) assert.Equal(t, "wire close code", tc.valid, act) }) } } func TestCloseStatus(t *testing.T) { t.Parallel() testCases := []struct { name string in error exp StatusCode }{ { name: "nil", in: nil, exp: -1, }, { name: "io.EOF", in: io.EOF, exp: -1, }, { name: "StatusInternalError", in: CloseError{ Code: StatusInternalError, }, exp: StatusInternalError, }, } for _, tc := range testCases { tc := tc t.Run(tc.name, func(t *testing.T) { t.Parallel() act := CloseStatus(tc.in) assert.Equal(t, "close status", tc.exp, act) }) } } websocket-1.8.7/compress.go000066400000000000000000000035041403335366200157060ustar00rootroot00000000000000package websocket // CompressionMode represents the modes available to the deflate extension. // See https://tools.ietf.org/html/rfc7692 // // A compatibility layer is implemented for the older deflate-frame extension used // by safari. 
See https://tools.ietf.org/html/draft-tyoshino-hybi-websocket-perframe-deflate-06 // It will work the same in every way except that we cannot signal to the peer we // want to use no context takeover on our side, we can only signal that they should. // It is however currently disabled due to Safari bugs. See https://github.com/nhooyr/websocket/issues/218 type CompressionMode int const ( // CompressionNoContextTakeover grabs a new flate.Reader and flate.Writer as needed // for every message. This applies to both server and client side. // // This means less efficient compression as the sliding window from previous messages // will not be used but the memory overhead will be lower if the connections // are long lived and seldom used. // // The message will only be compressed if greater than 512 bytes. CompressionNoContextTakeover CompressionMode = iota // CompressionContextTakeover uses a flate.Reader and flate.Writer per connection. // This enables reusing the sliding window from previous messages. // As most WebSocket protocols are repetitive, this can be very efficient. // It carries an overhead of 8 kB for every connection compared to CompressionNoContextTakeover. // // If the peer negotiates NoContextTakeover on the client or server side, it will be // used instead as this is required by the RFC. CompressionContextTakeover // CompressionDisabled disables the deflate extension. // // Use this if you are using a predominantly binary protocol with very // little duplication in between messages or CPU and memory are more // important than bandwidth. 
CompressionDisabled ) websocket-1.8.7/compress_notjs.go000066400000000000000000000067071403335366200171330ustar00rootroot00000000000000// +build !js package websocket import ( "io" "net/http" "sync" "github.com/klauspost/compress/flate" ) func (m CompressionMode) opts() *compressionOptions { return &compressionOptions{ clientNoContextTakeover: m == CompressionNoContextTakeover, serverNoContextTakeover: m == CompressionNoContextTakeover, } } type compressionOptions struct { clientNoContextTakeover bool serverNoContextTakeover bool } func (copts *compressionOptions) setHeader(h http.Header) { s := "permessage-deflate" if copts.clientNoContextTakeover { s += "; client_no_context_takeover" } if copts.serverNoContextTakeover { s += "; server_no_context_takeover" } h.Set("Sec-WebSocket-Extensions", s) } // These bytes are required to get flate.Reader to return. // They are removed when sending to avoid the overhead as // WebSocket framing tell's when the message has ended but then // we need to add them back otherwise flate.Reader keeps // trying to return more bytes. const deflateMessageTail = "\x00\x00\xff\xff" type trimLastFourBytesWriter struct { w io.Writer tail []byte } func (tw *trimLastFourBytesWriter) reset() { if tw != nil && tw.tail != nil { tw.tail = tw.tail[:0] } } func (tw *trimLastFourBytesWriter) Write(p []byte) (int, error) { if tw.tail == nil { tw.tail = make([]byte, 0, 4) } extra := len(tw.tail) + len(p) - 4 if extra <= 0 { tw.tail = append(tw.tail, p...) return len(p), nil } // Now we need to write as many extra bytes as we can from the previous tail. if extra > len(tw.tail) { extra = len(tw.tail) } if extra > 0 { _, err := tw.w.Write(tw.tail[:extra]) if err != nil { return 0, err } // Shift remaining bytes in tail over. n := copy(tw.tail, tw.tail[extra:]) tw.tail = tw.tail[:n] } // If p is less than or equal to 4 bytes, // all of it is is part of the tail. if len(p) <= 4 { tw.tail = append(tw.tail, p...) 
return len(p), nil } // Otherwise, only the last 4 bytes are. tw.tail = append(tw.tail, p[len(p)-4:]...) p = p[:len(p)-4] n, err := tw.w.Write(p) return n + 4, err } var flateReaderPool sync.Pool func getFlateReader(r io.Reader, dict []byte) io.Reader { fr, ok := flateReaderPool.Get().(io.Reader) if !ok { return flate.NewReaderDict(r, dict) } fr.(flate.Resetter).Reset(r, dict) return fr } func putFlateReader(fr io.Reader) { flateReaderPool.Put(fr) } type slidingWindow struct { buf []byte } var swPoolMu sync.RWMutex var swPool = map[int]*sync.Pool{} func slidingWindowPool(n int) *sync.Pool { swPoolMu.RLock() p, ok := swPool[n] swPoolMu.RUnlock() if ok { return p } p = &sync.Pool{} swPoolMu.Lock() swPool[n] = p swPoolMu.Unlock() return p } func (sw *slidingWindow) init(n int) { if sw.buf != nil { return } if n == 0 { n = 32768 } p := slidingWindowPool(n) buf, ok := p.Get().([]byte) if ok { sw.buf = buf[:0] } else { sw.buf = make([]byte, 0, n) } } func (sw *slidingWindow) close() { if sw.buf == nil { return } swPoolMu.Lock() swPool[cap(sw.buf)].Put(sw.buf) swPoolMu.Unlock() sw.buf = nil } func (sw *slidingWindow) write(p []byte) { if len(p) >= cap(sw.buf) { sw.buf = sw.buf[:cap(sw.buf)] p = p[len(p)-cap(sw.buf):] copy(sw.buf, p) return } left := cap(sw.buf) - len(sw.buf) if left < len(p) { // We need to shift spaceNeeded bytes from the end to make room for p at the end. spaceNeeded := len(p) - left copy(sw.buf, sw.buf[spaceNeeded:]) sw.buf = sw.buf[:len(sw.buf)-spaceNeeded] } sw.buf = append(sw.buf, p...) 
} websocket-1.8.7/compress_test.go000066400000000000000000000012701403335366200167430ustar00rootroot00000000000000// +build !js package websocket import ( "strings" "testing" "nhooyr.io/websocket/internal/test/assert" "nhooyr.io/websocket/internal/test/xrand" ) func Test_slidingWindow(t *testing.T) { t.Parallel() const testCount = 99 const maxWindow = 99999 for i := 0; i < testCount; i++ { t.Run("", func(t *testing.T) { t.Parallel() input := xrand.String(maxWindow) windowLength := xrand.Int(maxWindow) var sw slidingWindow sw.init(windowLength) sw.write([]byte(input)) assert.Equal(t, "window length", windowLength, cap(sw.buf)) if !strings.HasSuffix(input, string(sw.buf)) { t.Fatalf("r.buf is not a suffix of input: %q and %q", input, sw.buf) } }) } } websocket-1.8.7/conn.go000066400000000000000000000005511403335366200150070ustar00rootroot00000000000000package websocket // MessageType represents the type of a WebSocket message. // See https://tools.ietf.org/html/rfc6455#section-5.6 type MessageType int // MessageType constants. const ( // MessageText is for UTF-8 encoded text messages like JSON. MessageText MessageType = iota + 1 // MessageBinary is for binary messages like protobufs. MessageBinary ) websocket-1.8.7/conn_notjs.go000066400000000000000000000124441403335366200162300ustar00rootroot00000000000000// +build !js package websocket import ( "bufio" "context" "errors" "fmt" "io" "runtime" "strconv" "sync" "sync/atomic" ) // Conn represents a WebSocket connection. // All methods may be called concurrently except for Reader and Read. // // You must always read from the connection. Otherwise control // frames will not be handled. See Reader and CloseRead. // // Be sure to call Close on the connection when you // are finished with it to release associated resources. // // On any error from any method, the connection is closed // with an appropriate reason. 
type Conn struct { subprotocol string rwc io.ReadWriteCloser client bool copts *compressionOptions flateThreshold int br *bufio.Reader bw *bufio.Writer readTimeout chan context.Context writeTimeout chan context.Context // Read state. readMu *mu readHeaderBuf [8]byte readControlBuf [maxControlPayload]byte msgReader *msgReader readCloseFrameErr error // Write state. msgWriterState *msgWriterState writeFrameMu *mu writeBuf []byte writeHeaderBuf [8]byte writeHeader header closed chan struct{} closeMu sync.Mutex closeErr error wroteClose bool pingCounter int32 activePingsMu sync.Mutex activePings map[string]chan<- struct{} } type connConfig struct { subprotocol string rwc io.ReadWriteCloser client bool copts *compressionOptions flateThreshold int br *bufio.Reader bw *bufio.Writer } func newConn(cfg connConfig) *Conn { c := &Conn{ subprotocol: cfg.subprotocol, rwc: cfg.rwc, client: cfg.client, copts: cfg.copts, flateThreshold: cfg.flateThreshold, br: cfg.br, bw: cfg.bw, readTimeout: make(chan context.Context), writeTimeout: make(chan context.Context), closed: make(chan struct{}), activePings: make(map[string]chan<- struct{}), } c.readMu = newMu(c) c.writeFrameMu = newMu(c) c.msgReader = newMsgReader(c) c.msgWriterState = newMsgWriterState(c) if c.client { c.writeBuf = extractBufioWriterBuf(c.bw, c.rwc) } if c.flate() && c.flateThreshold == 0 { c.flateThreshold = 128 if !c.msgWriterState.flateContextTakeover() { c.flateThreshold = 512 } } runtime.SetFinalizer(c, func(c *Conn) { c.close(errors.New("connection garbage collected")) }) go c.timeoutLoop() return c } // Subprotocol returns the negotiated subprotocol. // An empty string means the default protocol. 
func (c *Conn) Subprotocol() string { return c.subprotocol } func (c *Conn) close(err error) { c.closeMu.Lock() defer c.closeMu.Unlock() if c.isClosed() { return } c.setCloseErrLocked(err) close(c.closed) runtime.SetFinalizer(c, nil) // Have to close after c.closed is closed to ensure any goroutine that wakes up // from the connection being closed also sees that c.closed is closed and returns // closeErr. c.rwc.Close() go func() { c.msgWriterState.close() c.msgReader.close() }() } func (c *Conn) timeoutLoop() { readCtx := context.Background() writeCtx := context.Background() for { select { case <-c.closed: return case writeCtx = <-c.writeTimeout: case readCtx = <-c.readTimeout: case <-readCtx.Done(): c.setCloseErr(fmt.Errorf("read timed out: %w", readCtx.Err())) go c.writeError(StatusPolicyViolation, errors.New("timed out")) case <-writeCtx.Done(): c.close(fmt.Errorf("write timed out: %w", writeCtx.Err())) return } } } func (c *Conn) flate() bool { return c.copts != nil } // Ping sends a ping to the peer and waits for a pong. // Use this to measure latency or ensure the peer is responsive. // Ping must be called concurrently with Reader as it does // not read from the connection but instead waits for a Reader call // to read the pong. // // TCP Keepalives should suffice for most use cases. 
func (c *Conn) Ping(ctx context.Context) error { p := atomic.AddInt32(&c.pingCounter, 1) err := c.ping(ctx, strconv.Itoa(int(p))) if err != nil { return fmt.Errorf("failed to ping: %w", err) } return nil } func (c *Conn) ping(ctx context.Context, p string) error { pong := make(chan struct{}, 1) c.activePingsMu.Lock() c.activePings[p] = pong c.activePingsMu.Unlock() defer func() { c.activePingsMu.Lock() delete(c.activePings, p) c.activePingsMu.Unlock() }() err := c.writeControl(ctx, opPing, []byte(p)) if err != nil { return err } select { case <-c.closed: return c.closeErr case <-ctx.Done(): err := fmt.Errorf("failed to wait for pong: %w", ctx.Err()) c.close(err) return err case <-pong: return nil } } type mu struct { c *Conn ch chan struct{} } func newMu(c *Conn) *mu { return &mu{ c: c, ch: make(chan struct{}, 1), } } func (m *mu) forceLock() { m.ch <- struct{}{} } func (m *mu) lock(ctx context.Context) error { select { case <-m.c.closed: return m.c.closeErr case <-ctx.Done(): err := fmt.Errorf("failed to acquire lock: %w", ctx.Err()) m.c.close(err) return err case m.ch <- struct{}{}: // To make sure the connection is certainly alive. // As it's possible the send on m.ch was selected // over the receive on closed. select { case <-m.c.closed: // Make sure to release. 
m.unlock() return m.c.closeErr default: } return nil } } func (m *mu) unlock() { select { case <-m.ch: default: } } websocket-1.8.7/conn_test.go000066400000000000000000000273051403335366200160540ustar00rootroot00000000000000// +build !js package websocket_test import ( "bytes" "context" "fmt" "io" "io/ioutil" "net/http" "net/http/httptest" "os" "os/exec" "strings" "testing" "time" "github.com/gin-gonic/gin" "github.com/golang/protobuf/ptypes" "github.com/golang/protobuf/ptypes/duration" "nhooyr.io/websocket" "nhooyr.io/websocket/internal/errd" "nhooyr.io/websocket/internal/test/assert" "nhooyr.io/websocket/internal/test/wstest" "nhooyr.io/websocket/internal/test/xrand" "nhooyr.io/websocket/internal/xsync" "nhooyr.io/websocket/wsjson" "nhooyr.io/websocket/wspb" ) func TestConn(t *testing.T) { t.Parallel() t.Run("fuzzData", func(t *testing.T) { t.Parallel() compressionMode := func() websocket.CompressionMode { return websocket.CompressionMode(xrand.Int(int(websocket.CompressionDisabled) + 1)) } for i := 0; i < 5; i++ { t.Run("", func(t *testing.T) { tt, c1, c2 := newConnTest(t, &websocket.DialOptions{ CompressionMode: compressionMode(), CompressionThreshold: xrand.Int(9999), }, &websocket.AcceptOptions{ CompressionMode: compressionMode(), CompressionThreshold: xrand.Int(9999), }) defer tt.cleanup() tt.goEchoLoop(c2) c1.SetReadLimit(131072) for i := 0; i < 5; i++ { err := wstest.Echo(tt.ctx, c1, 131072) assert.Success(t, err) } err := c1.Close(websocket.StatusNormalClosure, "") assert.Success(t, err) }) } }) t.Run("badClose", func(t *testing.T) { tt, c1, _ := newConnTest(t, nil, nil) defer tt.cleanup() err := c1.Close(-1, "") assert.Contains(t, err, "failed to marshal close frame: status code StatusCode(-1) cannot be set") }) t.Run("ping", func(t *testing.T) { tt, c1, c2 := newConnTest(t, nil, nil) defer tt.cleanup() c1.CloseRead(tt.ctx) c2.CloseRead(tt.ctx) for i := 0; i < 10; i++ { err := c1.Ping(tt.ctx) assert.Success(t, err) } err := 
c1.Close(websocket.StatusNormalClosure, "") assert.Success(t, err) }) t.Run("badPing", func(t *testing.T) { tt, c1, c2 := newConnTest(t, nil, nil) defer tt.cleanup() c2.CloseRead(tt.ctx) ctx, cancel := context.WithTimeout(tt.ctx, time.Millisecond*100) defer cancel() err := c1.Ping(ctx) assert.Contains(t, err, "failed to wait for pong") }) t.Run("concurrentWrite", func(t *testing.T) { tt, c1, c2 := newConnTest(t, nil, nil) defer tt.cleanup() tt.goDiscardLoop(c2) msg := xrand.Bytes(xrand.Int(9999)) const count = 100 errs := make(chan error, count) for i := 0; i < count; i++ { go func() { select { case errs <- c1.Write(tt.ctx, websocket.MessageBinary, msg): case <-tt.ctx.Done(): return } }() } for i := 0; i < count; i++ { select { case err := <-errs: assert.Success(t, err) case <-tt.ctx.Done(): t.Fatal(tt.ctx.Err()) } } err := c1.Close(websocket.StatusNormalClosure, "") assert.Success(t, err) }) t.Run("concurrentWriteError", func(t *testing.T) { tt, c1, _ := newConnTest(t, nil, nil) defer tt.cleanup() _, err := c1.Writer(tt.ctx, websocket.MessageText) assert.Success(t, err) ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*100) defer cancel() err = c1.Write(ctx, websocket.MessageText, []byte("x")) assert.Equal(t, "write error", context.DeadlineExceeded, err) }) t.Run("netConn", func(t *testing.T) { tt, c1, c2 := newConnTest(t, nil, nil) defer tt.cleanup() n1 := websocket.NetConn(tt.ctx, c1, websocket.MessageBinary) n2 := websocket.NetConn(tt.ctx, c2, websocket.MessageBinary) // Does not give any confidence but at least ensures no crashes. 
d, _ := tt.ctx.Deadline() n1.SetDeadline(d) n1.SetDeadline(time.Time{}) assert.Equal(t, "remote addr", n1.RemoteAddr(), n1.LocalAddr()) assert.Equal(t, "remote addr string", "websocket/unknown-addr", n1.RemoteAddr().String()) assert.Equal(t, "remote addr network", "websocket", n1.RemoteAddr().Network()) errs := xsync.Go(func() error { _, err := n2.Write([]byte("hello")) if err != nil { return err } return n2.Close() }) b, err := ioutil.ReadAll(n1) assert.Success(t, err) _, err = n1.Read(nil) assert.Equal(t, "read error", err, io.EOF) select { case err := <-errs: assert.Success(t, err) case <-tt.ctx.Done(): t.Fatal(tt.ctx.Err()) } assert.Equal(t, "read msg", []byte("hello"), b) }) t.Run("netConn/BadMsg", func(t *testing.T) { tt, c1, c2 := newConnTest(t, nil, nil) defer tt.cleanup() n1 := websocket.NetConn(tt.ctx, c1, websocket.MessageBinary) n2 := websocket.NetConn(tt.ctx, c2, websocket.MessageText) errs := xsync.Go(func() error { _, err := n2.Write([]byte("hello")) if err != nil { return err } return nil }) _, err := ioutil.ReadAll(n1) assert.Contains(t, err, `unexpected frame type read (expected MessageBinary): MessageText`) select { case err := <-errs: assert.Success(t, err) case <-tt.ctx.Done(): t.Fatal(tt.ctx.Err()) } }) t.Run("wsjson", func(t *testing.T) { tt, c1, c2 := newConnTest(t, nil, nil) defer tt.cleanup() tt.goEchoLoop(c2) c1.SetReadLimit(1 << 30) exp := xrand.String(xrand.Int(131072)) werr := xsync.Go(func() error { return wsjson.Write(tt.ctx, c1, exp) }) var act interface{} err := wsjson.Read(tt.ctx, c1, &act) assert.Success(t, err) assert.Equal(t, "read msg", exp, act) select { case err := <-werr: assert.Success(t, err) case <-tt.ctx.Done(): t.Fatal(tt.ctx.Err()) } err = c1.Close(websocket.StatusNormalClosure, "") assert.Success(t, err) }) t.Run("wspb", func(t *testing.T) { tt, c1, c2 := newConnTest(t, nil, nil) defer tt.cleanup() tt.goEchoLoop(c2) exp := ptypes.DurationProto(100) err := wspb.Write(tt.ctx, c1, exp) assert.Success(t, err) act := 
&duration.Duration{} err = wspb.Read(tt.ctx, c1, act) assert.Success(t, err) assert.Equal(t, "read msg", exp, act) err = c1.Close(websocket.StatusNormalClosure, "") assert.Success(t, err) }) } func TestWasm(t *testing.T) { t.Parallel() s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { err := echoServer(w, r, &websocket.AcceptOptions{ Subprotocols: []string{"echo"}, InsecureSkipVerify: true, }) if err != nil { t.Error(err) } })) defer s.Close() ctx, cancel := context.WithTimeout(context.Background(), time.Minute) defer cancel() cmd := exec.CommandContext(ctx, "go", "test", "-exec=wasmbrowsertest", ".") cmd.Env = append(os.Environ(), "GOOS=js", "GOARCH=wasm", fmt.Sprintf("WS_ECHO_SERVER_URL=%v", s.URL)) b, err := cmd.CombinedOutput() if err != nil { t.Fatalf("wasm test binary failed: %v:\n%s", err, b) } } func assertCloseStatus(exp websocket.StatusCode, err error) error { if websocket.CloseStatus(err) == -1 { return fmt.Errorf("expected websocket.CloseError: %T %v", err, err) } if websocket.CloseStatus(err) != exp { return fmt.Errorf("expected close status %v but got %v", exp, err) } return nil } type connTest struct { t testing.TB ctx context.Context doneFuncs []func() } func newConnTest(t testing.TB, dialOpts *websocket.DialOptions, acceptOpts *websocket.AcceptOptions) (tt *connTest, c1, c2 *websocket.Conn) { if t, ok := t.(*testing.T); ok { t.Parallel() } t.Helper() ctx, cancel := context.WithTimeout(context.Background(), time.Second*30) tt = &connTest{t: t, ctx: ctx} tt.appendDone(cancel) c1, c2 = wstest.Pipe(dialOpts, acceptOpts) if xrand.Bool() { c1, c2 = c2, c1 } tt.appendDone(func() { c2.Close(websocket.StatusInternalError, "") c1.Close(websocket.StatusInternalError, "") }) return tt, c1, c2 } func (tt *connTest) appendDone(f func()) { tt.doneFuncs = append(tt.doneFuncs, f) } func (tt *connTest) cleanup() { for i := len(tt.doneFuncs) - 1; i >= 0; i-- { tt.doneFuncs[i]() } } func (tt *connTest) goEchoLoop(c 
*websocket.Conn) { ctx, cancel := context.WithCancel(tt.ctx) echoLoopErr := xsync.Go(func() error { err := wstest.EchoLoop(ctx, c) return assertCloseStatus(websocket.StatusNormalClosure, err) }) tt.appendDone(func() { cancel() err := <-echoLoopErr if err != nil { tt.t.Errorf("echo loop error: %v", err) } }) } func (tt *connTest) goDiscardLoop(c *websocket.Conn) { ctx, cancel := context.WithCancel(tt.ctx) discardLoopErr := xsync.Go(func() error { defer c.Close(websocket.StatusInternalError, "") for { _, _, err := c.Read(ctx) if err != nil { return assertCloseStatus(websocket.StatusNormalClosure, err) } } }) tt.appendDone(func() { cancel() err := <-discardLoopErr if err != nil { tt.t.Errorf("discard loop error: %v", err) } }) } func BenchmarkConn(b *testing.B) { var benchCases = []struct { name string mode websocket.CompressionMode }{ { name: "disabledCompress", mode: websocket.CompressionDisabled, }, { name: "compress", mode: websocket.CompressionContextTakeover, }, { name: "compressNoContext", mode: websocket.CompressionNoContextTakeover, }, } for _, bc := range benchCases { b.Run(bc.name, func(b *testing.B) { bb, c1, c2 := newConnTest(b, &websocket.DialOptions{ CompressionMode: bc.mode, }, &websocket.AcceptOptions{ CompressionMode: bc.mode, }) defer bb.cleanup() bb.goEchoLoop(c2) bytesWritten := c1.RecordBytesWritten() bytesRead := c1.RecordBytesRead() msg := []byte(strings.Repeat("1234", 128)) readBuf := make([]byte, len(msg)) writes := make(chan struct{}) defer close(writes) werrs := make(chan error) go func() { for range writes { select { case werrs <- c1.Write(bb.ctx, websocket.MessageText, msg): case <-bb.ctx.Done(): return } } }() b.SetBytes(int64(len(msg))) b.ReportAllocs() b.ResetTimer() for i := 0; i < b.N; i++ { select { case writes <- struct{}{}: case <-bb.ctx.Done(): b.Fatal(bb.ctx.Err()) } typ, r, err := c1.Reader(bb.ctx) if err != nil { b.Fatal(err) } if websocket.MessageText != typ { assert.Equal(b, "data type", websocket.MessageText, typ) } _, err 
= io.ReadFull(r, readBuf) if err != nil { b.Fatal(err) } n2, err := r.Read(readBuf) if err != io.EOF { assert.Equal(b, "read err", io.EOF, err) } if n2 != 0 { assert.Equal(b, "n2", 0, n2) } if !bytes.Equal(msg, readBuf) { assert.Equal(b, "msg", msg, readBuf) } select { case err = <-werrs: case <-bb.ctx.Done(): b.Fatal(bb.ctx.Err()) } if err != nil { b.Fatal(err) } } b.StopTimer() b.ReportMetric(float64(*bytesWritten/b.N), "written/op") b.ReportMetric(float64(*bytesRead/b.N), "read/op") err := c1.Close(websocket.StatusNormalClosure, "") assert.Success(b, err) }) } } func echoServer(w http.ResponseWriter, r *http.Request, opts *websocket.AcceptOptions) (err error) { defer errd.Wrap(&err, "echo server failed") c, err := websocket.Accept(w, r, opts) if err != nil { return err } defer c.Close(websocket.StatusInternalError, "") err = wstest.EchoLoop(r.Context(), c) return assertCloseStatus(websocket.StatusNormalClosure, err) } func TestGin(t *testing.T) { t.Parallel() gin.SetMode(gin.ReleaseMode) r := gin.New() r.GET("/", func(ginCtx *gin.Context) { err := echoServer(ginCtx.Writer, ginCtx.Request, nil) if err != nil { t.Error(err) } }) s := httptest.NewServer(r) defer s.Close() ctx, cancel := context.WithTimeout(context.Background(), time.Second*30) defer cancel() c, _, err := websocket.Dial(ctx, s.URL, nil) assert.Success(t, err) defer c.Close(websocket.StatusInternalError, "") err = wsjson.Write(ctx, c, "hello") assert.Success(t, err) var v interface{} err = wsjson.Read(ctx, c, &v) assert.Success(t, err) assert.Equal(t, "read msg", "hello", v) err = c.Close(websocket.StatusNormalClosure, "") assert.Success(t, err) } websocket-1.8.7/dial.go000066400000000000000000000173251403335366200147720ustar00rootroot00000000000000// +build !js package websocket import ( "bufio" "bytes" "context" "crypto/rand" "encoding/base64" "fmt" "io" "io/ioutil" "net/http" "net/url" "strings" "sync" "time" "nhooyr.io/websocket/internal/errd" ) // DialOptions represents Dial's options. 
type DialOptions struct {
	// HTTPClient is used for the connection.
	// Its Transport must return writable bodies for WebSocket handshakes.
	// http.Transport does so beginning with Go 1.12.
	//
	// If nil, http.DefaultClient is used. If HTTPClient.Timeout is non-zero,
	// Dial applies it to the handshake through the dial context and performs
	// the request with a copy of the client whose Timeout is cleared.
	HTTPClient *http.Client

	// HTTPHeader specifies the HTTP headers included in the handshake request.
	// A nil header is treated as empty. Dial always sets Connection, Upgrade,
	// Sec-WebSocket-Version and Sec-WebSocket-Key itself, overriding any
	// values given here.
	HTTPHeader http.Header

	// Subprotocols lists the WebSocket subprotocols to negotiate with the server.
	// They are sent comma-separated in the Sec-WebSocket-Protocol header; a
	// subprotocol chosen by the server must match one of them (case-insensitively).
	Subprotocols []string

	// CompressionMode controls the compression mode.
	// Defaults to CompressionNoContextTakeover.
	// With CompressionDisabled no compression extension is offered.
	//
	// See docs on CompressionMode for details.
	CompressionMode CompressionMode

	// CompressionThreshold controls the minimum size of a message before compression is applied.
	//
	// Defaults to 512 bytes for CompressionNoContextTakeover and 128 bytes
	// for CompressionContextTakeover.
	CompressionThreshold int
}

// Dial performs a WebSocket handshake on url.
//
// The response is the WebSocket handshake response from the server.
// You never need to close resp.Body yourself.
//
// If an error occurs, the returned response may be non-nil.
// However, you can only read the first 1024 bytes of its body.
//
// This function requires at least Go 1.12 as it uses a new feature
// in net/http to perform WebSocket handshakes.
// See docs on the HTTPClient option and https://github.com/golang/go/issues/26937#issuecomment-415855861
//
// URLs with http/https schemes will work and are interpreted as ws/wss.
// Any other scheme is rejected.
func Dial(ctx context.Context, u string, opts *DialOptions) (*Conn, *http.Response, error) { return dial(ctx, u, opts, nil) } func dial(ctx context.Context, urls string, opts *DialOptions, rand io.Reader) (_ *Conn, _ *http.Response, err error) { defer errd.Wrap(&err, "failed to WebSocket dial") if opts == nil { opts = &DialOptions{} } opts = &*opts if opts.HTTPClient == nil { opts.HTTPClient = http.DefaultClient } else if opts.HTTPClient.Timeout > 0 { var cancel context.CancelFunc ctx, cancel = context.WithTimeout(ctx, opts.HTTPClient.Timeout) defer cancel() newClient := *opts.HTTPClient newClient.Timeout = 0 opts.HTTPClient = &newClient } if opts.HTTPHeader == nil { opts.HTTPHeader = http.Header{} } secWebSocketKey, err := secWebSocketKey(rand) if err != nil { return nil, nil, fmt.Errorf("failed to generate Sec-WebSocket-Key: %w", err) } var copts *compressionOptions if opts.CompressionMode != CompressionDisabled { copts = opts.CompressionMode.opts() } resp, err := handshakeRequest(ctx, urls, opts, copts, secWebSocketKey) if err != nil { return nil, resp, err } respBody := resp.Body resp.Body = nil defer func() { if err != nil { // We read a bit of the body for easier debugging. 
r := io.LimitReader(respBody, 1024) timer := time.AfterFunc(time.Second*3, func() { respBody.Close() }) defer timer.Stop() b, _ := ioutil.ReadAll(r) respBody.Close() resp.Body = ioutil.NopCloser(bytes.NewReader(b)) } }() copts, err = verifyServerResponse(opts, copts, secWebSocketKey, resp) if err != nil { return nil, resp, err } rwc, ok := respBody.(io.ReadWriteCloser) if !ok { return nil, resp, fmt.Errorf("response body is not a io.ReadWriteCloser: %T", respBody) } return newConn(connConfig{ subprotocol: resp.Header.Get("Sec-WebSocket-Protocol"), rwc: rwc, client: true, copts: copts, flateThreshold: opts.CompressionThreshold, br: getBufioReader(rwc), bw: getBufioWriter(rwc), }), resp, nil } func handshakeRequest(ctx context.Context, urls string, opts *DialOptions, copts *compressionOptions, secWebSocketKey string) (*http.Response, error) { u, err := url.Parse(urls) if err != nil { return nil, fmt.Errorf("failed to parse url: %w", err) } switch u.Scheme { case "ws": u.Scheme = "http" case "wss": u.Scheme = "https" case "http", "https": default: return nil, fmt.Errorf("unexpected url scheme: %q", u.Scheme) } req, _ := http.NewRequestWithContext(ctx, "GET", u.String(), nil) req.Header = opts.HTTPHeader.Clone() req.Header.Set("Connection", "Upgrade") req.Header.Set("Upgrade", "websocket") req.Header.Set("Sec-WebSocket-Version", "13") req.Header.Set("Sec-WebSocket-Key", secWebSocketKey) if len(opts.Subprotocols) > 0 { req.Header.Set("Sec-WebSocket-Protocol", strings.Join(opts.Subprotocols, ",")) } if copts != nil { copts.setHeader(req.Header) } resp, err := opts.HTTPClient.Do(req) if err != nil { return nil, fmt.Errorf("failed to send handshake request: %w", err) } return resp, nil } func secWebSocketKey(rr io.Reader) (string, error) { if rr == nil { rr = rand.Reader } b := make([]byte, 16) _, err := io.ReadFull(rr, b) if err != nil { return "", fmt.Errorf("failed to read random data from rand.Reader: %w", err) } return base64.StdEncoding.EncodeToString(b), nil } func 
verifyServerResponse(opts *DialOptions, copts *compressionOptions, secWebSocketKey string, resp *http.Response) (*compressionOptions, error) { if resp.StatusCode != http.StatusSwitchingProtocols { return nil, fmt.Errorf("expected handshake response status code %v but got %v", http.StatusSwitchingProtocols, resp.StatusCode) } if !headerContainsTokenIgnoreCase(resp.Header, "Connection", "Upgrade") { return nil, fmt.Errorf("WebSocket protocol violation: Connection header %q does not contain Upgrade", resp.Header.Get("Connection")) } if !headerContainsTokenIgnoreCase(resp.Header, "Upgrade", "WebSocket") { return nil, fmt.Errorf("WebSocket protocol violation: Upgrade header %q does not contain websocket", resp.Header.Get("Upgrade")) } if resp.Header.Get("Sec-WebSocket-Accept") != secWebSocketAccept(secWebSocketKey) { return nil, fmt.Errorf("WebSocket protocol violation: invalid Sec-WebSocket-Accept %q, key %q", resp.Header.Get("Sec-WebSocket-Accept"), secWebSocketKey, ) } err := verifySubprotocol(opts.Subprotocols, resp) if err != nil { return nil, err } return verifyServerExtensions(copts, resp.Header) } func verifySubprotocol(subprotos []string, resp *http.Response) error { proto := resp.Header.Get("Sec-WebSocket-Protocol") if proto == "" { return nil } for _, sp2 := range subprotos { if strings.EqualFold(sp2, proto) { return nil } } return fmt.Errorf("WebSocket protocol violation: unexpected Sec-WebSocket-Protocol from server: %q", proto) } func verifyServerExtensions(copts *compressionOptions, h http.Header) (*compressionOptions, error) { exts := websocketExtensions(h) if len(exts) == 0 { return nil, nil } ext := exts[0] if ext.name != "permessage-deflate" || len(exts) > 1 || copts == nil { return nil, fmt.Errorf("WebSocket protcol violation: unsupported extensions from server: %+v", exts[1:]) } copts = &*copts for _, p := range ext.params { switch p { case "client_no_context_takeover": copts.clientNoContextTakeover = true continue case "server_no_context_takeover": 
copts.serverNoContextTakeover = true continue } return nil, fmt.Errorf("unsupported permessage-deflate parameter: %q", p) } return copts, nil } var bufioReaderPool sync.Pool func getBufioReader(r io.Reader) *bufio.Reader { br, ok := bufioReaderPool.Get().(*bufio.Reader) if !ok { return bufio.NewReader(r) } br.Reset(r) return br } func putBufioReader(br *bufio.Reader) { bufioReaderPool.Put(br) } var bufioWriterPool sync.Pool func getBufioWriter(w io.Writer) *bufio.Writer { bw, ok := bufioWriterPool.Get().(*bufio.Writer) if !ok { return bufio.NewWriter(w) } bw.Reset(w) return bw } func putBufioWriter(bw *bufio.Writer) { bufioWriterPool.Put(bw) } websocket-1.8.7/dial_test.go000066400000000000000000000125341403335366200160260ustar00rootroot00000000000000// +build !js package websocket import ( "context" "crypto/rand" "io" "io/ioutil" "net/http" "net/http/httptest" "strings" "testing" "time" "nhooyr.io/websocket/internal/test/assert" ) func TestBadDials(t *testing.T) { t.Parallel() t.Run("badReq", func(t *testing.T) { t.Parallel() testCases := []struct { name string url string opts *DialOptions rand readerFunc }{ { name: "badURL", url: "://noscheme", }, { name: "badURLScheme", url: "ftp://nhooyr.io", }, { name: "badTLS", url: "wss://totallyfake.nhooyr.io", }, { name: "badReader", rand: func(p []byte) (int, error) { return 0, io.EOF }, }, } for _, tc := range testCases { tc := tc t.Run(tc.name, func(t *testing.T) { t.Parallel() ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) defer cancel() if tc.rand == nil { tc.rand = rand.Reader.Read } _, _, err := dial(ctx, tc.url, tc.opts, tc.rand) assert.Error(t, err) }) } }) t.Run("badResponse", func(t *testing.T) { t.Parallel() ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) defer cancel() _, _, err := Dial(ctx, "ws://example.com", &DialOptions{ HTTPClient: mockHTTPClient(func(*http.Request) (*http.Response, error) { return &http.Response{ Body: 
ioutil.NopCloser(strings.NewReader("hi")), }, nil }), }) assert.Contains(t, err, "failed to WebSocket dial: expected handshake response status code 101 but got 0") }) t.Run("badBody", func(t *testing.T) { t.Parallel() ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) defer cancel() rt := func(r *http.Request) (*http.Response, error) { h := http.Header{} h.Set("Connection", "Upgrade") h.Set("Upgrade", "websocket") h.Set("Sec-WebSocket-Accept", secWebSocketAccept(r.Header.Get("Sec-WebSocket-Key"))) return &http.Response{ StatusCode: http.StatusSwitchingProtocols, Header: h, Body: ioutil.NopCloser(strings.NewReader("hi")), }, nil } _, _, err := Dial(ctx, "ws://example.com", &DialOptions{ HTTPClient: mockHTTPClient(rt), }) assert.Contains(t, err, "response body is not a io.ReadWriteCloser") }) } func Test_verifyServerHandshake(t *testing.T) { t.Parallel() testCases := []struct { name string response func(w http.ResponseWriter) success bool }{ { name: "badStatus", response: func(w http.ResponseWriter) { w.WriteHeader(http.StatusOK) }, success: false, }, { name: "badConnection", response: func(w http.ResponseWriter) { w.Header().Set("Connection", "???") w.WriteHeader(http.StatusSwitchingProtocols) }, success: false, }, { name: "badUpgrade", response: func(w http.ResponseWriter) { w.Header().Set("Connection", "Upgrade") w.Header().Set("Upgrade", "???") w.WriteHeader(http.StatusSwitchingProtocols) }, success: false, }, { name: "badSecWebSocketAccept", response: func(w http.ResponseWriter) { w.Header().Set("Connection", "Upgrade") w.Header().Set("Upgrade", "websocket") w.Header().Set("Sec-WebSocket-Accept", "xd") w.WriteHeader(http.StatusSwitchingProtocols) }, success: false, }, { name: "badSecWebSocketProtocol", response: func(w http.ResponseWriter) { w.Header().Set("Connection", "Upgrade") w.Header().Set("Upgrade", "websocket") w.Header().Set("Sec-WebSocket-Protocol", "xd") w.WriteHeader(http.StatusSwitchingProtocols) }, success: false, }, { name: 
"unsupportedExtension", response: func(w http.ResponseWriter) { w.Header().Set("Connection", "Upgrade") w.Header().Set("Upgrade", "websocket") w.Header().Set("Sec-WebSocket-Extensions", "meow") w.WriteHeader(http.StatusSwitchingProtocols) }, success: false, }, { name: "unsupportedDeflateParam", response: func(w http.ResponseWriter) { w.Header().Set("Connection", "Upgrade") w.Header().Set("Upgrade", "websocket") w.Header().Set("Sec-WebSocket-Extensions", "permessage-deflate; meow") w.WriteHeader(http.StatusSwitchingProtocols) }, success: false, }, { name: "success", response: func(w http.ResponseWriter) { w.Header().Set("Connection", "Upgrade") w.Header().Set("Upgrade", "websocket") w.WriteHeader(http.StatusSwitchingProtocols) }, success: true, }, } for _, tc := range testCases { tc := tc t.Run(tc.name, func(t *testing.T) { t.Parallel() w := httptest.NewRecorder() tc.response(w) resp := w.Result() r := httptest.NewRequest("GET", "/", nil) key, err := secWebSocketKey(rand.Reader) assert.Success(t, err) r.Header.Set("Sec-WebSocket-Key", key) if resp.Header.Get("Sec-WebSocket-Accept") == "" { resp.Header.Set("Sec-WebSocket-Accept", secWebSocketAccept(key)) } opts := &DialOptions{ Subprotocols: strings.Split(r.Header.Get("Sec-WebSocket-Protocol"), ","), } _, err = verifyServerResponse(opts, opts.CompressionMode.opts(), key, resp) if tc.success { assert.Success(t, err) } else { assert.Error(t, err) } }) } } func mockHTTPClient(fn roundTripperFunc) *http.Client { return &http.Client{ Transport: fn, } } type roundTripperFunc func(*http.Request) (*http.Response, error) func (f roundTripperFunc) RoundTrip(r *http.Request) (*http.Response, error) { return f(r) } websocket-1.8.7/doc.go000066400000000000000000000017231403335366200146210ustar00rootroot00000000000000// +build !js // Package websocket implements the RFC 6455 WebSocket protocol. // // https://tools.ietf.org/html/rfc6455 // // Use Dial to dial a WebSocket server. // // Use Accept to accept a WebSocket client. 
//
// Conn represents the resulting WebSocket connection.
//
// The examples are the best way to understand how to correctly use the library.
//
// The wsjson and wspb subpackages contain helpers for JSON and protobuf messages.
//
// More documentation is available at https://nhooyr.io/websocket.
//
// Wasm
//
// The client side supports compiling to Wasm.
// It wraps the WebSocket browser API.
//
// See https://developer.mozilla.org/en-US/docs/Web/API/WebSocket
//
// Some important caveats to be aware of:
//
//   - Accept always errors out
//   - Conn.Ping is a no-op
//   - HTTPClient, HTTPHeader and CompressionMode in DialOptions are no-ops
//   - *http.Response from Dial is &http.Response{} with a 101 status code on success
package websocket // import "nhooyr.io/websocket"
websocket-1.8.7/example_test.go000066400000000000000000000116401403335366200165450ustar00rootroot00000000000000
package websocket_test

import (
	"context"
	"log"
	"net/http"
	"time"

	"nhooyr.io/websocket"
	"nhooyr.io/websocket/wsjson"
)

func ExampleAccept() {
	// This handler accepts a WebSocket connection, reads a single JSON
	// message from the client and then closes the connection.

	fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		c, err := websocket.Accept(w, r, nil)
		if err != nil {
			log.Println(err)
			return
		}
		// Fallback for the early-return paths below; the normal closure at
		// the end of the handler is the intended exit.
		defer c.Close(websocket.StatusInternalError, "the sky is falling")

		// Bound the read so a silent client cannot hold the handler forever.
		ctx, cancel := context.WithTimeout(r.Context(), time.Second*10)
		defer cancel()

		var v interface{}
		err = wsjson.Read(ctx, c, &v)
		if err != nil {
			log.Println(err)
			return
		}

		c.Close(websocket.StatusNormalClosure, "")
	})

	err := http.ListenAndServe("localhost:8080", fn)
	log.Fatal(err)
}

func ExampleDial() {
	// Dials a server, writes a single JSON message and then
	// closes the connection.

	ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
	defer cancel()

	c, _, err := websocket.Dial(ctx, "ws://localhost:8080", nil)
	if err != nil {
		log.Fatal(err)
	}
	// Fallback close in case we exit before the normal closure below.
	defer c.Close(websocket.StatusInternalError, "the sky is falling")

	err = wsjson.Write(ctx, c, "hi")
	if err != nil {
		log.Fatal(err)
	}

	c.Close(websocket.StatusNormalClosure, "")
}

func ExampleCloseStatus() {
	// Dials a server and then expects to be disconnected with status code
	// websocket.StatusNormalClosure.

	ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
	defer cancel()

	c, _, err := websocket.Dial(ctx, "ws://localhost:8080", nil)
	if err != nil {
		log.Fatal(err)
	}
	defer c.Close(websocket.StatusInternalError, "the sky is falling")

	// The read is expected to fail with the peer's close; CloseStatus
	// extracts the status code from that error.
	_, _, err = c.Reader(ctx)
	if websocket.CloseStatus(err) != websocket.StatusNormalClosure {
		log.Fatalf("expected to be disconnected with StatusNormalClosure but got: %v", err)
	}
}

func Example_writeOnly() {
	// This handler demonstrates how to correctly handle a write-only WebSocket connection,
	// i.e. you only expect to write messages and do not expect to read any messages.
	fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		c, err := websocket.Accept(w, r, nil)
		if err != nil {
			log.Println(err)
			return
		}
		defer c.Close(websocket.StatusInternalError, "the sky is falling")

		ctx, cancel := context.WithTimeout(r.Context(), time.Minute*10)
		defer cancel()

		// CloseRead keeps reading in the background so control frames are
		// processed; the returned ctx is used below to notice when the
		// connection goes away.
		ctx = c.CloseRead(ctx)

		t := time.NewTicker(time.Second * 30)
		defer t.Stop()

		for {
			select {
			case <-ctx.Done():
				c.Close(websocket.StatusNormalClosure, "")
				return
			case <-t.C:
				err = wsjson.Write(ctx, c, "hi")
				if err != nil {
					log.Println(err)
					return
				}
			}
		}
	})

	err := http.ListenAndServe("localhost:8080", fn)
	log.Fatal(err)
}

func Example_crossOrigin() {
	// This handler demonstrates how to safely accept cross origin WebSockets
	// from the origin example.com.
	fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		c, err := websocket.Accept(w, r, &websocket.AcceptOptions{
			OriginPatterns: []string{"example.com"},
		})
		if err != nil {
			log.Println(err)
			return
		}
		c.Close(websocket.StatusNormalClosure, "cross origin WebSocket accepted")
	})

	err := http.ListenAndServe("localhost:8080", fn)
	log.Fatal(err)
}

// This example demonstrates how to create a WebSocket server
// that gracefully exits when sent a signal.
//
// It starts a WebSocket server that keeps every connection open
// for 10 seconds.
// If you CTRL+C while a connection is open, it will wait at most 30s
// for all connections to terminate before shutting down.
//
// NOTE: ExampleGrace is commented out below, so it is neither compiled nor run.
// func ExampleGrace() {
// 	fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// 		c, err := websocket.Accept(w, r, nil)
// 		if err != nil {
// 			log.Println(err)
// 			return
// 		}
// 		defer c.Close(websocket.StatusInternalError, "the sky is falling")
//
// 		ctx := c.CloseRead(r.Context())
// 		select {
// 		case <-ctx.Done():
// 		case <-time.After(time.Second * 10):
// 		}
//
// 		c.Close(websocket.StatusNormalClosure, "")
// 	})
//
// 	var g websocket.Grace
// 	s := &http.Server{
// 		Handler:      g.Handler(fn),
// 		ReadTimeout:  time.Second * 15,
// 		WriteTimeout: time.Second * 15,
// 	}
//
// 	errc := make(chan error, 1)
// 	go func() {
// 		errc <- s.ListenAndServe()
// 	}()
//
// 	sigs := make(chan os.Signal, 1)
// 	signal.Notify(sigs, os.Interrupt)
// 	select {
// 	case err := <-errc:
// 		log.Printf("failed to listen and serve: %v", err)
// 	case sig := <-sigs:
// 		log.Printf("terminating: %v", sig)
// 	}
//
// 	ctx, cancel := context.WithTimeout(context.Background(), time.Second*30)
// 	defer cancel()
// 	s.Shutdown(ctx)
// 	g.Shutdown(ctx)
// }

// This example demonstrates a full-stack chat app with an automated test.
func Example_fullStackChat() {
	// https://github.com/nhooyr/websocket/tree/master/examples/chat
}

// This example demonstrates an echo server.
func Example_echo() {
	// The echo example is too involved for godoc; its code lives at:
	// https://github.com/nhooyr/websocket/tree/master/examples/echo
}
websocket-1.8.7/examples/000077500000000000000000000000001403335366200153405ustar00rootroot00000000000000websocket-1.8.7/examples/README.md000066400000000000000000000001361403335366200166170ustar00rootroot00000000000000# Examples

This directory contains more involved examples unsuitable for display with godoc.
websocket-1.8.7/examples/chat/000077500000000000000000000000001403335366200162575ustar00rootroot00000000000000websocket-1.8.7/examples/chat/README.md000066400000000000000000000026721403335366200175450ustar00rootroot00000000000000# Chat Example

This directory contains a full stack example of a simple chat webapp using nhooyr.io/websocket.

```bash
$ cd examples/chat
$ go run . localhost:0
listening on http://127.0.0.1:51055
```

Visit the printed URL to submit and view broadcasted messages in a browser.

![Image of Example](https://i.imgur.com/VwJl9Bh.png)

## Structure

The frontend is contained in `index.html`, `index.js` and `index.css`. It sets up the DOM with a scrollable div at the top that is populated with new messages as they are broadcast. At the bottom it adds a form to submit messages.

The messages are received via the WebSocket `/subscribe` endpoint and published via the HTTP POST `/publish` endpoint. The reason for not publishing messages over the WebSocket is so that you can easily publish a message with curl.

The server portion is `main.go` and `chat.go` and implements serving the static frontend assets, the `/subscribe` WebSocket endpoint and the HTTP POST `/publish` endpoint.

The code is well commented. I would recommend starting in `main.go` and then `chat.go` followed by `index.html` and then `index.js`.

There are two automated tests for the server included in `chat_test.go`. The first is a simple one client echo test. It publishes a single message and ensures it's received.
The second is a complex concurrency test where 10 clients send 128 unique messages of max 128 bytes concurrently. The test ensures all messages are seen by every client. websocket-1.8.7/examples/chat/chat.go000066400000000000000000000114361403335366200175320ustar00rootroot00000000000000package main import ( "context" "errors" "io/ioutil" "log" "net/http" "sync" "time" "golang.org/x/time/rate" "nhooyr.io/websocket" ) // chatServer enables broadcasting to a set of subscribers. type chatServer struct { // subscriberMessageBuffer controls the max number // of messages that can be queued for a subscriber // before it is kicked. // // Defaults to 16. subscriberMessageBuffer int // publishLimiter controls the rate limit applied to the publish endpoint. // // Defaults to one publish every 100ms with a burst of 8. publishLimiter *rate.Limiter // logf controls where logs are sent. // Defaults to log.Printf. logf func(f string, v ...interface{}) // serveMux routes the various endpoints to the appropriate handler. serveMux http.ServeMux subscribersMu sync.Mutex subscribers map[*subscriber]struct{} } // newChatServer constructs a chatServer with the defaults. func newChatServer() *chatServer { cs := &chatServer{ subscriberMessageBuffer: 16, logf: log.Printf, subscribers: make(map[*subscriber]struct{}), publishLimiter: rate.NewLimiter(rate.Every(time.Millisecond*100), 8), } cs.serveMux.Handle("/", http.FileServer(http.Dir("."))) cs.serveMux.HandleFunc("/subscribe", cs.subscribeHandler) cs.serveMux.HandleFunc("/publish", cs.publishHandler) return cs } // subscriber represents a subscriber. // Messages are sent on the msgs channel and if the client // cannot keep up with the messages, closeSlow is called. type subscriber struct { msgs chan []byte closeSlow func() } func (cs *chatServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { cs.serveMux.ServeHTTP(w, r) } // subscribeHandler accepts the WebSocket connection and then subscribes // it to all future messages. 
func (cs *chatServer) subscribeHandler(w http.ResponseWriter, r *http.Request) { c, err := websocket.Accept(w, r, nil) if err != nil { cs.logf("%v", err) return } defer c.Close(websocket.StatusInternalError, "") err = cs.subscribe(r.Context(), c) if errors.Is(err, context.Canceled) { return } if websocket.CloseStatus(err) == websocket.StatusNormalClosure || websocket.CloseStatus(err) == websocket.StatusGoingAway { return } if err != nil { cs.logf("%v", err) return } } // publishHandler reads the request body with a limit of 8192 bytes and then publishes // the received message. func (cs *chatServer) publishHandler(w http.ResponseWriter, r *http.Request) { if r.Method != "POST" { http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed) return } body := http.MaxBytesReader(w, r.Body, 8192) msg, err := ioutil.ReadAll(body) if err != nil { http.Error(w, http.StatusText(http.StatusRequestEntityTooLarge), http.StatusRequestEntityTooLarge) return } cs.publish(msg) w.WriteHeader(http.StatusAccepted) } // subscribe subscribes the given WebSocket to all broadcast messages. // It creates a subscriber with a buffered msgs chan to give some room to slower // connections and then registers the subscriber. It then listens for all messages // and writes them to the WebSocket. If the context is cancelled or // an error occurs, it returns and deletes the subscription. // // It uses CloseRead to keep reading from the connection to process control // messages and cancel the context if the connection drops. 
func (cs *chatServer) subscribe(ctx context.Context, c *websocket.Conn) error { ctx = c.CloseRead(ctx) s := &subscriber{ msgs: make(chan []byte, cs.subscriberMessageBuffer), closeSlow: func() { c.Close(websocket.StatusPolicyViolation, "connection too slow to keep up with messages") }, } cs.addSubscriber(s) defer cs.deleteSubscriber(s) for { select { case msg := <-s.msgs: err := writeTimeout(ctx, time.Second*5, c, msg) if err != nil { return err } case <-ctx.Done(): return ctx.Err() } } } // publish publishes the msg to all subscribers. // It never blocks and so messages to slow subscribers // are dropped. func (cs *chatServer) publish(msg []byte) { cs.subscribersMu.Lock() defer cs.subscribersMu.Unlock() cs.publishLimiter.Wait(context.Background()) for s := range cs.subscribers { select { case s.msgs <- msg: default: go s.closeSlow() } } } // addSubscriber registers a subscriber. func (cs *chatServer) addSubscriber(s *subscriber) { cs.subscribersMu.Lock() cs.subscribers[s] = struct{}{} cs.subscribersMu.Unlock() } // deleteSubscriber deletes the given subscriber. func (cs *chatServer) deleteSubscriber(s *subscriber) { cs.subscribersMu.Lock() delete(cs.subscribers, s) cs.subscribersMu.Unlock() } func writeTimeout(ctx context.Context, timeout time.Duration, c *websocket.Conn, msg []byte) error { ctx, cancel := context.WithTimeout(ctx, timeout) defer cancel() return c.Write(ctx, websocket.MessageText, msg) } websocket-1.8.7/examples/chat/chat_test.go000066400000000000000000000134631403335366200205730ustar00rootroot00000000000000package main import ( "context" "crypto/rand" "fmt" "math/big" "net/http" "net/http/httptest" "strings" "sync" "testing" "time" "golang.org/x/time/rate" "nhooyr.io/websocket" ) func Test_chatServer(t *testing.T) { t.Parallel() // This is a simple echo test with a single client. // The client sends a message and ensures it receives // it on its WebSocket. 
t.Run("simple", func(t *testing.T) { t.Parallel() url, closeFn := setupTest(t) defer closeFn() ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) defer cancel() cl, err := newClient(ctx, url) assertSuccess(t, err) defer cl.Close() expMsg := randString(512) err = cl.publish(ctx, expMsg) assertSuccess(t, err) msg, err := cl.nextMessage() assertSuccess(t, err) if expMsg != msg { t.Fatalf("expected %v but got %v", expMsg, msg) } }) // This test is a complex concurrency test. // 10 clients are started that send 128 different // messages of max 128 bytes concurrently. // // The test verifies that every message is seen by ever client // and no errors occur anywhere. t.Run("concurrency", func(t *testing.T) { t.Parallel() const nmessages = 128 const maxMessageSize = 128 const nclients = 16 url, closeFn := setupTest(t) defer closeFn() ctx, cancel := context.WithTimeout(context.Background(), time.Minute) defer cancel() var clients []*client var clientMsgs []map[string]struct{} for i := 0; i < nclients; i++ { cl, err := newClient(ctx, url) assertSuccess(t, err) defer cl.Close() clients = append(clients, cl) clientMsgs = append(clientMsgs, randMessages(nmessages, maxMessageSize)) } allMessages := make(map[string]struct{}) for _, msgs := range clientMsgs { for m := range msgs { allMessages[m] = struct{}{} } } var wg sync.WaitGroup for i, cl := range clients { i := i cl := cl wg.Add(1) go func() { defer wg.Done() err := cl.publishMsgs(ctx, clientMsgs[i]) if err != nil { t.Errorf("client %d failed to publish all messages: %v", i, err) } }() wg.Add(1) go func() { defer wg.Done() err := testAllMessagesReceived(cl, nclients*nmessages, allMessages) if err != nil { t.Errorf("client %d failed to receive all messages: %v", i, err) } }() } wg.Wait() }) } // setupTest sets up chatServer that can be used // via the returned url. // // Defer closeFn to ensure everything is cleaned up at // the end of the test. // // chatServer logs will be logged via t.Logf. 
func setupTest(t *testing.T) (url string, closeFn func()) { cs := newChatServer() cs.logf = t.Logf // To ensure tests run quickly under even -race. cs.subscriberMessageBuffer = 4096 cs.publishLimiter.SetLimit(rate.Inf) s := httptest.NewServer(cs) return s.URL, func() { s.Close() } } // testAllMessagesReceived ensures that after n reads, all msgs in msgs // have been read. func testAllMessagesReceived(cl *client, n int, msgs map[string]struct{}) error { msgs = cloneMessages(msgs) for i := 0; i < n; i++ { msg, err := cl.nextMessage() if err != nil { return err } delete(msgs, msg) } if len(msgs) != 0 { return fmt.Errorf("did not receive all expected messages: %q", msgs) } return nil } func cloneMessages(msgs map[string]struct{}) map[string]struct{} { msgs2 := make(map[string]struct{}, len(msgs)) for m := range msgs { msgs2[m] = struct{}{} } return msgs2 } func randMessages(n, maxMessageLength int) map[string]struct{} { msgs := make(map[string]struct{}) for i := 0; i < n; i++ { m := randString(randInt(maxMessageLength)) if _, ok := msgs[m]; ok { i-- continue } msgs[m] = struct{}{} } return msgs } func assertSuccess(t *testing.T, err error) { t.Helper() if err != nil { t.Fatal(err) } } type client struct { url string c *websocket.Conn } func newClient(ctx context.Context, url string) (*client, error) { c, _, err := websocket.Dial(ctx, url+"/subscribe", nil) if err != nil { return nil, err } cl := &client{ url: url, c: c, } return cl, nil } func (cl *client) publish(ctx context.Context, msg string) (err error) { defer func() { if err != nil { cl.c.Close(websocket.StatusInternalError, "publish failed") } }() req, _ := http.NewRequestWithContext(ctx, http.MethodPost, cl.url+"/publish", strings.NewReader(msg)) resp, err := http.DefaultClient.Do(req) if err != nil { return err } defer resp.Body.Close() if resp.StatusCode != http.StatusAccepted { return fmt.Errorf("publish request failed: %v", resp.StatusCode) } return nil } func (cl *client) publishMsgs(ctx context.Context, 
msgs map[string]struct{}) error { for m := range msgs { err := cl.publish(ctx, m) if err != nil { return err } } return nil } func (cl *client) nextMessage() (string, error) { typ, b, err := cl.c.Read(context.Background()) if err != nil { return "", err } if typ != websocket.MessageText { cl.c.Close(websocket.StatusUnsupportedData, "expected text message") return "", fmt.Errorf("expected text message but got %v", typ) } return string(b), nil } func (cl *client) Close() error { return cl.c.Close(websocket.StatusNormalClosure, "") } // randString generates a random string with length n. func randString(n int) string { b := make([]byte, n) _, err := rand.Reader.Read(b) if err != nil { panic(fmt.Sprintf("failed to generate rand bytes: %v", err)) } s := strings.ToValidUTF8(string(b), "_") s = strings.ReplaceAll(s, "\x00", "_") if len(s) > n { return s[:n] } if len(s) < n { // Pad with = extra := n - len(s) return s + strings.Repeat("=", extra) } return s } // randInt returns a randomly generated integer between [0, max). func randInt(max int) int { x, err := rand.Int(rand.Reader, big.NewInt(int64(max))) if err != nil { panic(fmt.Sprintf("failed to get random int: %v", err)) } return int(x.Int64()) } websocket-1.8.7/examples/chat/index.css000066400000000000000000000022431403335366200201010ustar00rootroot00000000000000body { width: 100vw; min-width: 320px; } #root { padding: 40px 20px; max-width: 600px; margin: auto; height: 100vh; display: flex; flex-direction: column; align-items: center; justify-content: center; } #root > * + * { margin: 20px 0 0 0; } /* 100vh on safari does not include the bottom bar. 
*/ @supports (-webkit-overflow-scrolling: touch) { #root { height: 85vh; } } #message-log { width: 100%; flex-grow: 1; overflow: auto; } #message-log p:first-child { margin: 0; } #message-log > * + * { margin: 10px 0 0 0; } #publish-form-container { width: 100%; } #publish-form { width: 100%; display: flex; height: 40px; } #publish-form > * + * { margin: 0 0 0 10px; } #publish-form input[type="text"] { flex-grow: 1; -moz-appearance: none; -webkit-appearance: none; word-break: normal; border-radius: 5px; border: 1px solid #ccc; } #publish-form input[type="submit"] { color: white; background-color: black; border-radius: 5px; padding: 5px 10px; border: none; } #publish-form input[type="submit"]:hover { background-color: red; } #publish-form input[type="submit"]:active { background-color: red; } websocket-1.8.7/examples/chat/index.html000066400000000000000000000015151403335366200202560ustar00rootroot00000000000000 nhooyr.io/websocket - Chat Example
websocket-1.8.7/examples/chat/index.js000066400000000000000000000042701403335366200177270ustar00rootroot00000000000000;(() => { // expectingMessage is set to true // if the user has just submitted a message // and so we should scroll the next message into view when received. let expectingMessage = false function dial() { const conn = new WebSocket(`ws://${location.host}/subscribe`) conn.addEventListener("close", ev => { appendLog(`WebSocket Disconnected code: ${ev.code}, reason: ${ev.reason}`, true) if (ev.code !== 1001) { appendLog("Reconnecting in 1s", true) setTimeout(dial, 1000) } }) conn.addEventListener("open", ev => { console.info("websocket connected") }) // This is where we handle messages received. conn.addEventListener("message", ev => { if (typeof ev.data !== "string") { console.error("unexpected message type", typeof ev.data) return } const p = appendLog(ev.data) if (expectingMessage) { p.scrollIntoView() expectingMessage = false } }) } dial() const messageLog = document.getElementById("message-log") const publishForm = document.getElementById("publish-form") const messageInput = document.getElementById("message-input") // appendLog appends the passed text to messageLog. function appendLog(text, error) { const p = document.createElement("p") // Adding a timestamp to each message makes the log easier to read. p.innerText = `${new Date().toLocaleTimeString()}: ${text}` if (error) { p.style.color = "red" p.style.fontStyle = "bold" } messageLog.append(p) return p } appendLog("Submit a message to get started!") // onsubmit publishes the message from the user when the form is submitted. 
publishForm.onsubmit = async ev => { ev.preventDefault() const msg = messageInput.value if (msg === "") { return } messageInput.value = "" expectingMessage = true try { const resp = await fetch("/publish", { method: "POST", body: msg, }) if (resp.status !== 202) { throw new Error(`Unexpected HTTP Status ${resp.status} ${resp.statusText}`) } } catch (err) { appendLog(`Publish failed: ${err.message}`, true) } } })() websocket-1.8.7/examples/chat/main.go000066400000000000000000000020371403335366200175340ustar00rootroot00000000000000package main import ( "context" "errors" "log" "net" "net/http" "os" "os/signal" "time" ) func main() { log.SetFlags(0) err := run() if err != nil { log.Fatal(err) } } // run initializes the chatServer and then // starts a http.Server for the passed in address. func run() error { if len(os.Args) < 2 { return errors.New("please provide an address to listen on as the first argument") } l, err := net.Listen("tcp", os.Args[1]) if err != nil { return err } log.Printf("listening on http://%v", l.Addr()) cs := newChatServer() s := &http.Server{ Handler: cs, ReadTimeout: time.Second * 10, WriteTimeout: time.Second * 10, } errc := make(chan error, 1) go func() { errc <- s.Serve(l) }() sigs := make(chan os.Signal, 1) signal.Notify(sigs, os.Interrupt) select { case err := <-errc: log.Printf("failed to serve: %v", err) case sig := <-sigs: log.Printf("terminating: %v", sig) } ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) defer cancel() return s.Shutdown(ctx) } websocket-1.8.7/examples/echo/000077500000000000000000000000001403335366200162565ustar00rootroot00000000000000websocket-1.8.7/examples/echo/README.md000066400000000000000000000012211403335366200175310ustar00rootroot00000000000000# Echo Example This directory contains a echo server example using nhooyr.io/websocket. ```bash $ cd examples/echo $ go run . 
localhost:0 listening on http://127.0.0.1:51055 ``` You can use a WebSocket client like https://github.com/hashrocket/ws to connect. All messages written will be echoed back. ## Structure The server is in `server.go` and is implemented as a `http.HandlerFunc` that accepts the WebSocket and then reads all messages and writes them exactly as is back to the connection. `server_test.go` contains a small unit test to verify it works correctly. `main.go` brings it all together so that you can run it and play around with it. websocket-1.8.7/examples/echo/main.go000066400000000000000000000020511403335366200175270ustar00rootroot00000000000000package main import ( "context" "errors" "log" "net" "net/http" "os" "os/signal" "time" ) func main() { log.SetFlags(0) err := run() if err != nil { log.Fatal(err) } } // run starts a http.Server for the passed in address // with all requests handled by echoServer. func run() error { if len(os.Args) < 2 { return errors.New("please provide an address to listen on as the first argument") } l, err := net.Listen("tcp", os.Args[1]) if err != nil { return err } log.Printf("listening on http://%v", l.Addr()) s := &http.Server{ Handler: echoServer{ logf: log.Printf, }, ReadTimeout: time.Second * 10, WriteTimeout: time.Second * 10, } errc := make(chan error, 1) go func() { errc <- s.Serve(l) }() sigs := make(chan os.Signal, 1) signal.Notify(sigs, os.Interrupt) select { case err := <-errc: log.Printf("failed to serve: %v", err) case sig := <-sigs: log.Printf("terminating: %v", sig) } ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) defer cancel() return s.Shutdown(ctx) } websocket-1.8.7/examples/echo/server.go000066400000000000000000000032431403335366200201150ustar00rootroot00000000000000package main import ( "context" "fmt" "io" "net/http" "time" "golang.org/x/time/rate" "nhooyr.io/websocket" ) // echoServer is the WebSocket echo server implementation. 
// It ensures the client speaks the echo subprotocol and // only allows one message every 100ms with a 10 message burst. type echoServer struct { // logf controls where logs are sent. logf func(f string, v ...interface{}) } func (s echoServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { c, err := websocket.Accept(w, r, &websocket.AcceptOptions{ Subprotocols: []string{"echo"}, }) if err != nil { s.logf("%v", err) return } defer c.Close(websocket.StatusInternalError, "the sky is falling") if c.Subprotocol() != "echo" { c.Close(websocket.StatusPolicyViolation, "client must speak the echo subprotocol") return } l := rate.NewLimiter(rate.Every(time.Millisecond*100), 10) for { err = echo(r.Context(), c, l) if websocket.CloseStatus(err) == websocket.StatusNormalClosure { return } if err != nil { s.logf("failed to echo with %v: %v", r.RemoteAddr, err) return } } } // echo reads from the WebSocket connection and then writes // the received message back to it. // The entire function has 10s to complete. func echo(ctx context.Context, c *websocket.Conn, l *rate.Limiter) error { ctx, cancel := context.WithTimeout(ctx, time.Second*10) defer cancel() err := l.Wait(ctx) if err != nil { return err } typ, r, err := c.Reader(ctx) if err != nil { return err } w, err := c.Writer(ctx, typ) if err != nil { return err } _, err = io.Copy(w, r) if err != nil { return fmt.Errorf("failed to io.Copy: %w", err) } err = w.Close() return err } websocket-1.8.7/examples/echo/server_test.go000066400000000000000000000017571403335366200211640ustar00rootroot00000000000000package main import ( "context" "net/http/httptest" "testing" "time" "nhooyr.io/websocket" "nhooyr.io/websocket/wsjson" ) // Test_echoServer tests the echoServer by sending it 5 different messages // and ensuring the responses all match. 
func Test_echoServer(t *testing.T) { t.Parallel() s := httptest.NewServer(echoServer{ logf: t.Logf, }) defer s.Close() ctx, cancel := context.WithTimeout(context.Background(), time.Minute) defer cancel() c, _, err := websocket.Dial(ctx, s.URL, &websocket.DialOptions{ Subprotocols: []string{"echo"}, }) if err != nil { t.Fatal(err) } defer c.Close(websocket.StatusInternalError, "the sky is falling") for i := 0; i < 5; i++ { err = wsjson.Write(ctx, c, map[string]int{ "i": i, }) if err != nil { t.Fatal(err) } v := map[string]int{} err = wsjson.Read(ctx, c, &v) if err != nil { t.Fatal(err) } if v["i"] != i { t.Fatalf("expected %v but got %v", i, v) } } c.Close(websocket.StatusNormalClosure, "") } websocket-1.8.7/export_test.go000066400000000000000000000006541403335366200164360ustar00rootroot00000000000000// +build !js package websocket func (c *Conn) RecordBytesWritten() *int { var bytesWritten int c.bw.Reset(writerFunc(func(p []byte) (int, error) { bytesWritten += len(p) return c.rwc.Write(p) })) return &bytesWritten } func (c *Conn) RecordBytesRead() *int { var bytesRead int c.br.Reset(readerFunc(func(p []byte) (int, error) { n, err := c.rwc.Read(p) bytesRead += n return n, err })) return &bytesRead } websocket-1.8.7/frame.go000066400000000000000000000165541403335366200151560ustar00rootroot00000000000000package websocket import ( "bufio" "encoding/binary" "fmt" "io" "math" "math/bits" "nhooyr.io/websocket/internal/errd" ) // opcode represents a WebSocket opcode. type opcode int // https://tools.ietf.org/html/rfc6455#section-11.8. const ( opContinuation opcode = iota opText opBinary // 3 - 7 are reserved for further non-control frames. _ _ _ _ _ opClose opPing opPong // 11-16 are reserved for further control frames. ) // header represents a WebSocket frame header. // See https://tools.ietf.org/html/rfc6455#section-5.2. 
type header struct { fin bool rsv1 bool rsv2 bool rsv3 bool opcode opcode payloadLength int64 masked bool maskKey uint32 } // readFrameHeader reads a header from the reader. // See https://tools.ietf.org/html/rfc6455#section-5.2. func readFrameHeader(r *bufio.Reader, readBuf []byte) (h header, err error) { defer errd.Wrap(&err, "failed to read frame header") b, err := r.ReadByte() if err != nil { return header{}, err } h.fin = b&(1<<7) != 0 h.rsv1 = b&(1<<6) != 0 h.rsv2 = b&(1<<5) != 0 h.rsv3 = b&(1<<4) != 0 h.opcode = opcode(b & 0xf) b, err = r.ReadByte() if err != nil { return header{}, err } h.masked = b&(1<<7) != 0 payloadLength := b &^ (1 << 7) switch { case payloadLength < 126: h.payloadLength = int64(payloadLength) case payloadLength == 126: _, err = io.ReadFull(r, readBuf[:2]) h.payloadLength = int64(binary.BigEndian.Uint16(readBuf)) case payloadLength == 127: _, err = io.ReadFull(r, readBuf) h.payloadLength = int64(binary.BigEndian.Uint64(readBuf)) } if err != nil { return header{}, err } if h.payloadLength < 0 { return header{}, fmt.Errorf("received negative payload length: %v", h.payloadLength) } if h.masked { _, err = io.ReadFull(r, readBuf[:4]) if err != nil { return header{}, err } h.maskKey = binary.LittleEndian.Uint32(readBuf) } return h, nil } // maxControlPayload is the maximum length of a control frame payload. // See https://tools.ietf.org/html/rfc6455#section-5.5. const maxControlPayload = 125 // writeFrameHeader writes the bytes of the header to w. 
// See https://tools.ietf.org/html/rfc6455#section-5.2 func writeFrameHeader(h header, w *bufio.Writer, buf []byte) (err error) { defer errd.Wrap(&err, "failed to write frame header") var b byte if h.fin { b |= 1 << 7 } if h.rsv1 { b |= 1 << 6 } if h.rsv2 { b |= 1 << 5 } if h.rsv3 { b |= 1 << 4 } b |= byte(h.opcode) err = w.WriteByte(b) if err != nil { return err } lengthByte := byte(0) if h.masked { lengthByte |= 1 << 7 } switch { case h.payloadLength > math.MaxUint16: lengthByte |= 127 case h.payloadLength > 125: lengthByte |= 126 case h.payloadLength >= 0: lengthByte |= byte(h.payloadLength) } err = w.WriteByte(lengthByte) if err != nil { return err } switch { case h.payloadLength > math.MaxUint16: binary.BigEndian.PutUint64(buf, uint64(h.payloadLength)) _, err = w.Write(buf) case h.payloadLength > 125: binary.BigEndian.PutUint16(buf, uint16(h.payloadLength)) _, err = w.Write(buf[:2]) } if err != nil { return err } if h.masked { binary.LittleEndian.PutUint32(buf, h.maskKey) _, err = w.Write(buf[:4]) if err != nil { return err } } return nil } // mask applies the WebSocket masking algorithm to p // with the given key. // See https://tools.ietf.org/html/rfc6455#section-5.3 // // The returned value is the correctly rotated key to // to continue to mask/unmask the message. // // It is optimized for LittleEndian and expects the key // to be in little endian. // // See https://github.com/golang/go/issues/31586 func mask(key uint32, b []byte) uint32 { if len(b) >= 8 { key64 := uint64(key)<<32 | uint64(key) // At some point in the future we can clean these unrolled loops up. // See https://github.com/golang/go/issues/31586#issuecomment-487436401 // Then we xor until b is less than 128 bytes. 
for len(b) >= 128 { v := binary.LittleEndian.Uint64(b) binary.LittleEndian.PutUint64(b, v^key64) v = binary.LittleEndian.Uint64(b[8:16]) binary.LittleEndian.PutUint64(b[8:16], v^key64) v = binary.LittleEndian.Uint64(b[16:24]) binary.LittleEndian.PutUint64(b[16:24], v^key64) v = binary.LittleEndian.Uint64(b[24:32]) binary.LittleEndian.PutUint64(b[24:32], v^key64) v = binary.LittleEndian.Uint64(b[32:40]) binary.LittleEndian.PutUint64(b[32:40], v^key64) v = binary.LittleEndian.Uint64(b[40:48]) binary.LittleEndian.PutUint64(b[40:48], v^key64) v = binary.LittleEndian.Uint64(b[48:56]) binary.LittleEndian.PutUint64(b[48:56], v^key64) v = binary.LittleEndian.Uint64(b[56:64]) binary.LittleEndian.PutUint64(b[56:64], v^key64) v = binary.LittleEndian.Uint64(b[64:72]) binary.LittleEndian.PutUint64(b[64:72], v^key64) v = binary.LittleEndian.Uint64(b[72:80]) binary.LittleEndian.PutUint64(b[72:80], v^key64) v = binary.LittleEndian.Uint64(b[80:88]) binary.LittleEndian.PutUint64(b[80:88], v^key64) v = binary.LittleEndian.Uint64(b[88:96]) binary.LittleEndian.PutUint64(b[88:96], v^key64) v = binary.LittleEndian.Uint64(b[96:104]) binary.LittleEndian.PutUint64(b[96:104], v^key64) v = binary.LittleEndian.Uint64(b[104:112]) binary.LittleEndian.PutUint64(b[104:112], v^key64) v = binary.LittleEndian.Uint64(b[112:120]) binary.LittleEndian.PutUint64(b[112:120], v^key64) v = binary.LittleEndian.Uint64(b[120:128]) binary.LittleEndian.PutUint64(b[120:128], v^key64) b = b[128:] } // Then we xor until b is less than 64 bytes. 
for len(b) >= 64 { v := binary.LittleEndian.Uint64(b) binary.LittleEndian.PutUint64(b, v^key64) v = binary.LittleEndian.Uint64(b[8:16]) binary.LittleEndian.PutUint64(b[8:16], v^key64) v = binary.LittleEndian.Uint64(b[16:24]) binary.LittleEndian.PutUint64(b[16:24], v^key64) v = binary.LittleEndian.Uint64(b[24:32]) binary.LittleEndian.PutUint64(b[24:32], v^key64) v = binary.LittleEndian.Uint64(b[32:40]) binary.LittleEndian.PutUint64(b[32:40], v^key64) v = binary.LittleEndian.Uint64(b[40:48]) binary.LittleEndian.PutUint64(b[40:48], v^key64) v = binary.LittleEndian.Uint64(b[48:56]) binary.LittleEndian.PutUint64(b[48:56], v^key64) v = binary.LittleEndian.Uint64(b[56:64]) binary.LittleEndian.PutUint64(b[56:64], v^key64) b = b[64:] } // Then we xor until b is less than 32 bytes. for len(b) >= 32 { v := binary.LittleEndian.Uint64(b) binary.LittleEndian.PutUint64(b, v^key64) v = binary.LittleEndian.Uint64(b[8:16]) binary.LittleEndian.PutUint64(b[8:16], v^key64) v = binary.LittleEndian.Uint64(b[16:24]) binary.LittleEndian.PutUint64(b[16:24], v^key64) v = binary.LittleEndian.Uint64(b[24:32]) binary.LittleEndian.PutUint64(b[24:32], v^key64) b = b[32:] } // Then we xor until b is less than 16 bytes. for len(b) >= 16 { v := binary.LittleEndian.Uint64(b) binary.LittleEndian.PutUint64(b, v^key64) v = binary.LittleEndian.Uint64(b[8:16]) binary.LittleEndian.PutUint64(b[8:16], v^key64) b = b[16:] } // Then we xor until b is less than 8 bytes. for len(b) >= 8 { v := binary.LittleEndian.Uint64(b) binary.LittleEndian.PutUint64(b, v^key64) b = b[8:] } } // Then we xor until b is less than 4 bytes. for len(b) >= 4 { v := binary.LittleEndian.Uint32(b) binary.LittleEndian.PutUint32(b, v^key) b = b[4:] } // xor remaining bytes. 
for i := range b { b[i] ^= byte(key) key = bits.RotateLeft32(key, -8) } return key } websocket-1.8.7/frame_test.go000066400000000000000000000062161403335366200162070ustar00rootroot00000000000000// +build !js package websocket import ( "bufio" "bytes" "encoding/binary" "math/bits" "math/rand" "strconv" "testing" "time" _ "unsafe" "github.com/gobwas/ws" _ "github.com/gorilla/websocket" "nhooyr.io/websocket/internal/test/assert" ) func TestHeader(t *testing.T) { t.Parallel() t.Run("lengths", func(t *testing.T) { t.Parallel() lengths := []int{ 124, 125, 126, 127, 65534, 65535, 65536, 65537, } for _, n := range lengths { n := n t.Run(strconv.Itoa(n), func(t *testing.T) { t.Parallel() testHeader(t, header{ payloadLength: int64(n), }) }) } }) t.Run("fuzz", func(t *testing.T) { t.Parallel() r := rand.New(rand.NewSource(time.Now().UnixNano())) randBool := func() bool { return r.Intn(1) == 0 } for i := 0; i < 10000; i++ { h := header{ fin: randBool(), rsv1: randBool(), rsv2: randBool(), rsv3: randBool(), opcode: opcode(r.Intn(16)), masked: randBool(), maskKey: r.Uint32(), payloadLength: r.Int63(), } testHeader(t, h) } }) } func testHeader(t *testing.T, h header) { b := &bytes.Buffer{} w := bufio.NewWriter(b) r := bufio.NewReader(b) err := writeFrameHeader(h, w, make([]byte, 8)) assert.Success(t, err) err = w.Flush() assert.Success(t, err) h2, err := readFrameHeader(r, make([]byte, 8)) assert.Success(t, err) assert.Equal(t, "read header", h, h2) } func Test_mask(t *testing.T) { t.Parallel() key := []byte{0xa, 0xb, 0xc, 0xff} key32 := binary.LittleEndian.Uint32(key) p := []byte{0xa, 0xb, 0xc, 0xf2, 0xc} gotKey32 := mask(key32, p) expP := []byte{0, 0, 0, 0x0d, 0x6} assert.Equal(t, "p", expP, p) expKey32 := bits.RotateLeft32(key32, -8) assert.Equal(t, "key32", expKey32, gotKey32) } func basicMask(maskKey [4]byte, pos int, b []byte) int { for i := range b { b[i] ^= maskKey[pos&3] pos++ } return pos & 3 } //go:linkname gorillaMaskBytes github.com/gorilla/websocket.maskBytes func 
gorillaMaskBytes(key [4]byte, pos int, b []byte) int func Benchmark_mask(b *testing.B) { sizes := []int{ 2, 3, 4, 8, 16, 32, 128, 512, 4096, 16384, } fns := []struct { name string fn func(b *testing.B, key [4]byte, p []byte) }{ { name: "basic", fn: func(b *testing.B, key [4]byte, p []byte) { for i := 0; i < b.N; i++ { basicMask(key, 0, p) } }, }, { name: "nhooyr", fn: func(b *testing.B, key [4]byte, p []byte) { key32 := binary.LittleEndian.Uint32(key[:]) b.ResetTimer() for i := 0; i < b.N; i++ { mask(key32, p) } }, }, { name: "gorilla", fn: func(b *testing.B, key [4]byte, p []byte) { for i := 0; i < b.N; i++ { gorillaMaskBytes(key, 0, p) } }, }, { name: "gobwas", fn: func(b *testing.B, key [4]byte, p []byte) { for i := 0; i < b.N; i++ { ws.Cipher(p, key, 0) } }, }, } key := [4]byte{1, 2, 3, 4} for _, size := range sizes { p := make([]byte, size) b.Run(strconv.Itoa(size), func(b *testing.B) { for _, fn := range fns { b.Run(fn.name, func(b *testing.B) { b.SetBytes(int64(size)) fn.fn(b, key, p) }) } }) } } websocket-1.8.7/go.mod000066400000000000000000000006531403335366200146340ustar00rootroot00000000000000module nhooyr.io/websocket go 1.13 require ( github.com/gin-gonic/gin v1.6.3 github.com/gobwas/httphead v0.0.0-20180130184737-2c6c146eadee // indirect github.com/gobwas/pool v0.2.0 // indirect github.com/gobwas/ws v1.0.2 github.com/golang/protobuf v1.3.5 github.com/google/go-cmp v0.4.0 github.com/gorilla/websocket v1.4.1 github.com/klauspost/compress v1.10.3 golang.org/x/time v0.0.0-20191024005414-555d28b269f0 ) websocket-1.8.7/go.sum000066400000000000000000000134101403335366200146540ustar00rootroot00000000000000github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE= 
github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI= github.com/gin-gonic/gin v1.6.3 h1:ahKqKTFpO5KTPHxWZjEdPScmYaGtLo8Y4DMHoEsnp14= github.com/gin-gonic/gin v1.6.3/go.mod h1:75u5sXoLsGZoRN5Sgbi1eraJ4GU3++wFwWzhwvtwp4M= github.com/go-playground/assert/v2 v2.0.1 h1:MsBgLAaY856+nPRTKrp3/OZK38U/wa0CcBYNjji3q3A= github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= github.com/go-playground/locales v0.13.0 h1:HyWk6mgj5qFqCT5fjGBuRArbVDfE4hi8+e8ceBS/t7Q= github.com/go-playground/locales v0.13.0/go.mod h1:taPMhCMXrRLJO55olJkUXHZBHCxTMfnGwq/HNwmWNS8= github.com/go-playground/universal-translator v0.17.0 h1:icxd5fm+REJzpZx7ZfpaD876Lmtgy7VtROAbHHXk8no= github.com/go-playground/universal-translator v0.17.0/go.mod h1:UkSxE5sNxxRwHyU+Scu5vgOQjsIJAF8j9muTVoKLVtA= github.com/go-playground/validator/v10 v10.2.0 h1:KgJ0snyC2R9VXYN2rneOtQcw5aHQB1Vv0sFl1UcHBOY= github.com/go-playground/validator/v10 v10.2.0/go.mod h1:uOYAAleCW8F/7oMFd6aG0GOhaH6EGOAJShg8Id5JGkI= github.com/gobwas/httphead v0.0.0-20180130184737-2c6c146eadee h1:s+21KNqlpePfkah2I+gwHF8xmJWRjooY+5248k6m4A0= github.com/gobwas/httphead v0.0.0-20180130184737-2c6c146eadee/go.mod h1:L0fX3K22YWvt/FAX9NnzrNzcI4wNYi9Yku4O0LKYflo= github.com/gobwas/pool v0.2.0 h1:QEmUOlnSjWtnpRGHF3SauEiOsy82Cup83Vf2LcMlnc8= github.com/gobwas/pool v0.2.0/go.mod h1:q8bcK0KcYlCgd9e7WYLm9LpyS+YeLd8JVDW6WezmKEw= github.com/gobwas/ws v1.0.2 h1:CoAavW/wd/kulfZmSIBt6p24n4j7tHgNVCjsfHVNUbo= github.com/gobwas/ws v1.0.2/go.mod h1:szmBTxLgaFppYjEmNtny/v3w89xOydFnnZMcgRRu/EM= github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= github.com/golang/protobuf v1.3.5 h1:F768QJ1E9tib+q5Sc8MkdJi1RxLTbRcTf8LJV56aRls= github.com/golang/protobuf v1.3.5/go.mod h1:6O5/vntMXwX2lRkT1hjjk0nAC1IDOTvTlVgjlRvqsdk= github.com/google/go-cmp v0.4.0 h1:xsAVV57WRhGj6kEIi8ReJzQlHHqcBYCElAvkovg3B/4= github.com/google/go-cmp v0.4.0/go.mod 
h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/gorilla/websocket v1.4.1 h1:q7AeDBpnBk8AogcD4DSag/Ukw/KV+YhzLj2bP5HvKCM= github.com/gorilla/websocket v1.4.1/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/json-iterator/go v1.1.9 h1:9yzud/Ht36ygwatGx56VwCZtlI/2AD15T1X2sjSuGns= github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= github.com/klauspost/compress v1.10.3 h1:OP96hzwJVBIHYU52pVTI6CczrxPvrGfgqF9N5eTO0Q8= github.com/klauspost/compress v1.10.3/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs= github.com/leodido/go-urn v1.2.0 h1:hpXL4XnriNwQ/ABnpepYM/1vCLWNDfUNts8dX3xTG6Y= github.com/leodido/go-urn v1.2.0/go.mod h1:+8+nEpDfqqsY+g338gtMEUOtuK+4dEMhiQEgxpxOKII= github.com/mattn/go-isatty v0.0.12 h1:wuysRhFDzyxgEmMf5xjvJ2M9dZoWAXNNr5LSBS7uHXY= github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421 h1:ZqeYNhU3OHLH3mGKHDcjJRFFRrJa6eAM5H+CtDdOsPc= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742 h1:Esafd1046DLDQ0W1YjYsBW+p8U2u7vzgW2SQVmlNazg= github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= github.com/stretchr/testify v1.4.0/go.mod 
h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/ugorji/go v1.1.7 h1:/68gy2h+1mWMrwZFeD1kQialdSzAb432dtpeJ42ovdo= github.com/ugorji/go v1.1.7/go.mod h1:kZn38zHttfInRq0xu/PH0az30d+z6vm202qpg1oXVMw= github.com/ugorji/go/codec v1.1.7 h1:2SvQaVZ1ouYrrKKwoSk2pzd4A9evlKJb9oTL+OaLUSs= github.com/ugorji/go/codec v1.1.7/go.mod h1:Ax+UKWsSmolVDwsd+7N3ZtXu+yMGCf907BLYF3GoBXY= golang.org/x/sys v0.0.0-20200116001909-b77594299b42 h1:vEOn+mP2zCOVzKckCZy6YsCtDblrpj/w7B9nxGNELpg= golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/time v0.0.0-20191024005414-555d28b269f0 h1:/5xXl8Y5W96D+TtHSlonuFqGHIWVuyCkGJLwGh9JJFs= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.8 h1:obN1ZagJSUGI0Ek/LBmuj4SNLPfIny3KsKFopxRdj10= gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= websocket-1.8.7/internal/000077500000000000000000000000001403335366200153365ustar00rootroot00000000000000websocket-1.8.7/internal/bpool/000077500000000000000000000000001403335366200164515ustar00rootroot00000000000000websocket-1.8.7/internal/bpool/bpool.go000066400000000000000000000005501403335366200201130ustar00rootroot00000000000000package bpool import ( 
"bytes" "sync" ) var bpool sync.Pool // Get returns a buffer from the pool or creates a new one if // the pool is empty. func Get() *bytes.Buffer { b := bpool.Get() if b == nil { return &bytes.Buffer{} } return b.(*bytes.Buffer) } // Put returns a buffer into the pool. func Put(b *bytes.Buffer) { b.Reset() bpool.Put(b) } websocket-1.8.7/internal/errd/000077500000000000000000000000001403335366200162725ustar00rootroot00000000000000websocket-1.8.7/internal/errd/wrap.go000066400000000000000000000005061403335366200175730ustar00rootroot00000000000000package errd import ( "fmt" ) // Wrap wraps err with fmt.Errorf if err is non nil. // Intended for use with defer and a named error return. // Inspired by https://github.com/golang/go/issues/32676. func Wrap(err *error, f string, v ...interface{}) { if *err != nil { *err = fmt.Errorf(f+": %w", append(v, *err)...) } } websocket-1.8.7/internal/test/000077500000000000000000000000001403335366200163155ustar00rootroot00000000000000websocket-1.8.7/internal/test/assert/000077500000000000000000000000001403335366200176165ustar00rootroot00000000000000websocket-1.8.7/internal/test/assert/assert.go000066400000000000000000000020761403335366200214530ustar00rootroot00000000000000package assert import ( "fmt" "reflect" "strings" "testing" "github.com/golang/protobuf/proto" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" ) // Diff returns a human readable diff between v1 and v2 func Diff(v1, v2 interface{}) string { return cmp.Diff(v1, v2, cmpopts.EquateErrors(), cmp.Exporter(func(r reflect.Type) bool { return true }), cmp.Comparer(proto.Equal)) } // Equal asserts exp == act. func Equal(t testing.TB, name string, exp, act interface{}) { t.Helper() if diff := Diff(exp, act); diff != "" { t.Fatalf("unexpected %v: %v", name, diff) } } // Success asserts err == nil. func Success(t testing.TB, err error) { t.Helper() if err != nil { t.Fatal(err) } } // Error asserts err != nil. 
func Error(t testing.TB, err error) { t.Helper() if err == nil { t.Fatal("expected error") } } // Contains asserts the fmt.Sprint(v) contains sub. func Contains(t testing.TB, v interface{}, sub string) { t.Helper() s := fmt.Sprint(v) if !strings.Contains(s, sub) { t.Fatalf("expected %q to contain %q", s, sub) } } websocket-1.8.7/internal/test/doc.go000066400000000000000000000001061403335366200174060ustar00rootroot00000000000000// Package test contains subpackages only used in tests. package test websocket-1.8.7/internal/test/wstest/000077500000000000000000000000001403335366200176465ustar00rootroot00000000000000websocket-1.8.7/internal/test/wstest/echo.go000066400000000000000000000032521403335366200211150ustar00rootroot00000000000000package wstest import ( "bytes" "context" "fmt" "io" "time" "nhooyr.io/websocket" "nhooyr.io/websocket/internal/test/assert" "nhooyr.io/websocket/internal/test/xrand" "nhooyr.io/websocket/internal/xsync" ) // EchoLoop echos every msg received from c until an error // occurs or the context expires. // The read limit is set to 1 << 30. func EchoLoop(ctx context.Context, c *websocket.Conn) error { defer c.Close(websocket.StatusInternalError, "") c.SetReadLimit(1 << 30) ctx, cancel := context.WithTimeout(ctx, time.Minute) defer cancel() b := make([]byte, 32<<10) for { typ, r, err := c.Reader(ctx) if err != nil { return err } w, err := c.Writer(ctx, typ) if err != nil { return err } _, err = io.CopyBuffer(w, r, b) if err != nil { return err } err = w.Close() if err != nil { return err } } } // Echo writes a message and ensures the same is sent back on c. 
func Echo(ctx context.Context, c *websocket.Conn, max int) error { expType := websocket.MessageBinary if xrand.Bool() { expType = websocket.MessageText } msg := randMessage(expType, xrand.Int(max)) writeErr := xsync.Go(func() error { return c.Write(ctx, expType, msg) }) actType, act, err := c.Read(ctx) if err != nil { return err } err = <-writeErr if err != nil { return err } if expType != actType { return fmt.Errorf("unexpected message typ (%v): %v", expType, actType) } if !bytes.Equal(msg, act) { return fmt.Errorf("unexpected msg read: %v", assert.Diff(msg, act)) } return nil } func randMessage(typ websocket.MessageType, n int) []byte { if typ == websocket.MessageBinary { return xrand.Bytes(n) } return []byte(xrand.String(n)) } websocket-1.8.7/internal/test/wstest/pipe.go000066400000000000000000000027151403335366200211370ustar00rootroot00000000000000// +build !js package wstest import ( "bufio" "context" "net" "net/http" "net/http/httptest" "nhooyr.io/websocket" ) // Pipe is used to create an in memory connection // between two websockets analogous to net.Pipe. 
func Pipe(dialOpts *websocket.DialOptions, acceptOpts *websocket.AcceptOptions) (clientConn, serverConn *websocket.Conn) { tt := fakeTransport{ h: func(w http.ResponseWriter, r *http.Request) { serverConn, _ = websocket.Accept(w, r, acceptOpts) }, } if dialOpts == nil { dialOpts = &websocket.DialOptions{} } dialOpts = &*dialOpts dialOpts.HTTPClient = &http.Client{ Transport: tt, } clientConn, _, _ = websocket.Dial(context.Background(), "ws://example.com", dialOpts) return clientConn, serverConn } type fakeTransport struct { h http.HandlerFunc } func (t fakeTransport) RoundTrip(r *http.Request) (*http.Response, error) { clientConn, serverConn := net.Pipe() hj := testHijacker{ ResponseRecorder: httptest.NewRecorder(), serverConn: serverConn, } t.h.ServeHTTP(hj, r) resp := hj.ResponseRecorder.Result() if resp.StatusCode == http.StatusSwitchingProtocols { resp.Body = clientConn } return resp, nil } type testHijacker struct { *httptest.ResponseRecorder serverConn net.Conn } var _ http.Hijacker = testHijacker{} func (hj testHijacker) Hijack() (net.Conn, *bufio.ReadWriter, error) { return hj.serverConn, bufio.NewReadWriter(bufio.NewReader(hj.serverConn), bufio.NewWriter(hj.serverConn)), nil } websocket-1.8.7/internal/test/xrand/000077500000000000000000000000001403335366200174315ustar00rootroot00000000000000websocket-1.8.7/internal/test/xrand/xrand.go000066400000000000000000000016571403335366200211050ustar00rootroot00000000000000package xrand import ( "crypto/rand" "fmt" "math/big" "strings" ) // Bytes generates random bytes with length n. func Bytes(n int) []byte { b := make([]byte, n) _, err := rand.Reader.Read(b) if err != nil { panic(fmt.Sprintf("failed to generate rand bytes: %v", err)) } return b } // String generates a random string with length n. 
func String(n int) string { s := strings.ToValidUTF8(string(Bytes(n)), "_") s = strings.ReplaceAll(s, "\x00", "_") if len(s) > n { return s[:n] } if len(s) < n { // Pad with = extra := n - len(s) return s + strings.Repeat("=", extra) } return s } // Bool returns a randomly generated boolean. func Bool() bool { return Int(2) == 1 } // Int returns a randomly generated integer between [0, max). func Int(max int) int { x, err := rand.Int(rand.Reader, big.NewInt(int64(max))) if err != nil { panic(fmt.Sprintf("failed to get random int: %v", err)) } return int(x.Int64()) } websocket-1.8.7/internal/wsjs/000077500000000000000000000000001403335366200163245ustar00rootroot00000000000000websocket-1.8.7/internal/wsjs/wsjs_js.go000066400000000000000000000077201403335366200203430ustar00rootroot00000000000000// +build js // Package wsjs implements typed access to the browser javascript WebSocket API. // // https://developer.mozilla.org/en-US/docs/Web/API/WebSocket package wsjs import ( "syscall/js" ) func handleJSError(err *error, onErr func()) { r := recover() if jsErr, ok := r.(js.Error); ok { *err = jsErr if onErr != nil { onErr() } return } if r != nil { panic(r) } } // New is a wrapper around the javascript WebSocket constructor. func New(url string, protocols []string) (c WebSocket, err error) { defer handleJSError(&err, func() { c = WebSocket{} }) jsProtocols := make([]interface{}, len(protocols)) for i, p := range protocols { jsProtocols[i] = p } c = WebSocket{ v: js.Global().Get("WebSocket").New(url, jsProtocols), } c.setBinaryType("arraybuffer") return c, nil } // WebSocket is a wrapper around a javascript WebSocket object. 
type WebSocket struct { v js.Value } func (c WebSocket) setBinaryType(typ string) { c.v.Set("binaryType", string(typ)) } func (c WebSocket) addEventListener(eventType string, fn func(e js.Value)) func() { f := js.FuncOf(func(this js.Value, args []js.Value) interface{} { fn(args[0]) return nil }) c.v.Call("addEventListener", eventType, f) return func() { c.v.Call("removeEventListener", eventType, f) f.Release() } } // CloseEvent is the type passed to a WebSocket close handler. type CloseEvent struct { Code uint16 Reason string WasClean bool } // OnClose registers a function to be called when the WebSocket is closed. func (c WebSocket) OnClose(fn func(CloseEvent)) (remove func()) { return c.addEventListener("close", func(e js.Value) { ce := CloseEvent{ Code: uint16(e.Get("code").Int()), Reason: e.Get("reason").String(), WasClean: e.Get("wasClean").Bool(), } fn(ce) }) } // OnError registers a function to be called when there is an error // with the WebSocket. func (c WebSocket) OnError(fn func(e js.Value)) (remove func()) { return c.addEventListener("error", fn) } // MessageEvent is the type passed to a message handler. type MessageEvent struct { // string or []byte. Data interface{} // There are more fields to the interface but we don't use them. // See https://developer.mozilla.org/en-US/docs/Web/API/MessageEvent } // OnMessage registers a function to be called when the WebSocket receives a message. func (c WebSocket) OnMessage(fn func(m MessageEvent)) (remove func()) { return c.addEventListener("message", func(e js.Value) { var data interface{} arrayBuffer := e.Get("data") if arrayBuffer.Type() == js.TypeString { data = arrayBuffer.String() } else { data = extractArrayBuffer(arrayBuffer) } me := MessageEvent{ Data: data, } fn(me) return }) } // Subprotocol returns the WebSocket subprotocol in use. func (c WebSocket) Subprotocol() string { return c.v.Get("protocol").String() } // OnOpen registers a function to be called when the WebSocket is opened. 
func (c WebSocket) OnOpen(fn func(e js.Value)) (remove func()) { return c.addEventListener("open", fn) } // Close closes the WebSocket with the given code and reason. func (c WebSocket) Close(code int, reason string) (err error) { defer handleJSError(&err, nil) c.v.Call("close", code, reason) return err } // SendText sends the given string as a text message // on the WebSocket. func (c WebSocket) SendText(v string) (err error) { defer handleJSError(&err, nil) c.v.Call("send", v) return err } // SendBytes sends the given message as a binary message // on the WebSocket. func (c WebSocket) SendBytes(v []byte) (err error) { defer handleJSError(&err, nil) c.v.Call("send", uint8Array(v)) return err } func extractArrayBuffer(arrayBuffer js.Value) []byte { uint8Array := js.Global().Get("Uint8Array").New(arrayBuffer) dst := make([]byte, uint8Array.Length()) js.CopyBytesToGo(dst, uint8Array) return dst } func uint8Array(src []byte) js.Value { uint8Array := js.Global().Get("Uint8Array").New(len(src)) js.CopyBytesToJS(uint8Array, src) return uint8Array } websocket-1.8.7/internal/xsync/000077500000000000000000000000001403335366200165025ustar00rootroot00000000000000websocket-1.8.7/internal/xsync/go.go000066400000000000000000000006001403335366200174320ustar00rootroot00000000000000package xsync import ( "fmt" ) // Go allows running a function in another goroutine // and waiting for its error. 
func Go(fn func() error) <-chan error { errs := make(chan error, 1) go func() { defer func() { r := recover() if r != nil { select { case errs <- fmt.Errorf("panic in go fn: %v", r): default: } } }() errs <- fn() }() return errs } websocket-1.8.7/internal/xsync/go_test.go000066400000000000000000000003511403335366200204740ustar00rootroot00000000000000package xsync import ( "testing" "nhooyr.io/websocket/internal/test/assert" ) func TestGoRecover(t *testing.T) { t.Parallel() errs := Go(func() error { panic("anmol") }) err := <-errs assert.Contains(t, err, "anmol") } websocket-1.8.7/internal/xsync/int64.go000066400000000000000000000006301403335366200177740ustar00rootroot00000000000000package xsync import ( "sync/atomic" ) // Int64 represents an atomic int64. type Int64 struct { // We do not use atomic.Load/StoreInt64 since it does not // work on 32 bit computers but we need 64 bit integers. i atomic.Value } // Load loads the int64. func (v *Int64) Load() int64 { i, _ := v.i.Load().(int64) return i } // Store stores the int64. func (v *Int64) Store(i int64) { v.i.Store(i) } websocket-1.8.7/netconn.go000066400000000000000000000071151403335366200155210ustar00rootroot00000000000000package websocket import ( "context" "fmt" "io" "math" "net" "sync" "time" ) // NetConn converts a *websocket.Conn into a net.Conn. // // It's for tunneling arbitrary protocols over WebSockets. // Few users of the library will need this but it's tricky to implement // correctly and so provided in the library. // See https://github.com/nhooyr/websocket/issues/100. // // Every Write to the net.Conn will correspond to a message write of // the given type on *websocket.Conn. // // The passed ctx bounds the lifetime of the net.Conn. If cancelled, // all reads and writes on the net.Conn will be cancelled. // // If a message is read that is not of the correct type, the connection // will be closed with StatusUnsupportedData and an error will be returned. 
// // Close will close the *websocket.Conn with StatusNormalClosure. // // When a deadline is hit, the connection will be closed. This is // different from most net.Conn implementations where only the // reading/writing goroutines are interrupted but the connection is kept alive. // // The Addr methods will return a mock net.Addr that returns "websocket" for Network // and "websocket/unknown-addr" for String. // // A received StatusNormalClosure or StatusGoingAway close frame will be translated to // io.EOF when reading. func NetConn(ctx context.Context, c *Conn, msgType MessageType) net.Conn { nc := &netConn{ c: c, msgType: msgType, } var cancel context.CancelFunc nc.writeContext, cancel = context.WithCancel(ctx) nc.writeTimer = time.AfterFunc(math.MaxInt64, cancel) if !nc.writeTimer.Stop() { <-nc.writeTimer.C } nc.readContext, cancel = context.WithCancel(ctx) nc.readTimer = time.AfterFunc(math.MaxInt64, cancel) if !nc.readTimer.Stop() { <-nc.readTimer.C } return nc } type netConn struct { c *Conn msgType MessageType writeTimer *time.Timer writeContext context.Context readTimer *time.Timer readContext context.Context readMu sync.Mutex eofed bool reader io.Reader } var _ net.Conn = &netConn{} func (c *netConn) Close() error { return c.c.Close(StatusNormalClosure, "") } func (c *netConn) Write(p []byte) (int, error) { err := c.c.Write(c.writeContext, c.msgType, p) if err != nil { return 0, err } return len(p), nil } func (c *netConn) Read(p []byte) (int, error) { c.readMu.Lock() defer c.readMu.Unlock() if c.eofed { return 0, io.EOF } if c.reader == nil { typ, r, err := c.c.Reader(c.readContext) if err != nil { switch CloseStatus(err) { case StatusNormalClosure, StatusGoingAway: c.eofed = true return 0, io.EOF } return 0, err } if typ != c.msgType { err := fmt.Errorf("unexpected frame type read (expected %v): %v", c.msgType, typ) c.c.Close(StatusUnsupportedData, err.Error()) return 0, err } c.reader = r } n, err := c.reader.Read(p) if err == io.EOF { c.reader = nil 
err = nil } return n, err } type websocketAddr struct { } func (a websocketAddr) Network() string { return "websocket" } func (a websocketAddr) String() string { return "websocket/unknown-addr" } func (c *netConn) RemoteAddr() net.Addr { return websocketAddr{} } func (c *netConn) LocalAddr() net.Addr { return websocketAddr{} } func (c *netConn) SetDeadline(t time.Time) error { c.SetWriteDeadline(t) c.SetReadDeadline(t) return nil } func (c *netConn) SetWriteDeadline(t time.Time) error { if t.IsZero() { c.writeTimer.Stop() } else { c.writeTimer.Reset(t.Sub(time.Now())) } return nil } func (c *netConn) SetReadDeadline(t time.Time) error { if t.IsZero() { c.readTimer.Stop() } else { c.readTimer.Reset(t.Sub(time.Now())) } return nil } websocket-1.8.7/read.go000066400000000000000000000243441403335366200147730ustar00rootroot00000000000000// +build !js package websocket import ( "bufio" "context" "errors" "fmt" "io" "io/ioutil" "strings" "time" "nhooyr.io/websocket/internal/errd" "nhooyr.io/websocket/internal/xsync" ) // Reader reads from the connection until until there is a WebSocket // data message to be read. It will handle ping, pong and close frames as appropriate. // // It returns the type of the message and an io.Reader to read it. // The passed context will also bound the reader. // Ensure you read to EOF otherwise the connection will hang. // // Call CloseRead if you do not expect any data messages from the peer. // // Only one Reader may be open at a time. func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) { return c.reader(ctx) } // Read is a convenience method around Reader to read a single message // from the connection. func (c *Conn) Read(ctx context.Context) (MessageType, []byte, error) { typ, r, err := c.Reader(ctx) if err != nil { return 0, nil, err } b, err := ioutil.ReadAll(r) return typ, b, err } // CloseRead starts a goroutine to read from the connection until it is closed // or a data message is received. 
// // Once CloseRead is called you cannot read any messages from the connection. // The returned context will be cancelled when the connection is closed. // // If a data message is received, the connection will be closed with StatusPolicyViolation. // // Call CloseRead when you do not expect to read any more messages. // Since it actively reads from the connection, it will ensure that ping, pong and close // frames are responded to. This means c.Ping and c.Close will still work as expected. func (c *Conn) CloseRead(ctx context.Context) context.Context { ctx, cancel := context.WithCancel(ctx) go func() { defer cancel() c.Reader(ctx) c.Close(StatusPolicyViolation, "unexpected data message") }() return ctx } // SetReadLimit sets the max number of bytes to read for a single message. // It applies to the Reader and Read methods. // // By default, the connection has a message read limit of 32768 bytes. // // When the limit is hit, the connection will be closed with StatusMessageTooBig. func (c *Conn) SetReadLimit(n int64) { // We add read one more byte than the limit in case // there is a fin frame that needs to be read. 
c.msgReader.limitReader.limit.Store(n + 1) } const defaultReadLimit = 32768 func newMsgReader(c *Conn) *msgReader { mr := &msgReader{ c: c, fin: true, } mr.readFunc = mr.read mr.limitReader = newLimitReader(c, mr.readFunc, defaultReadLimit+1) return mr } func (mr *msgReader) resetFlate() { if mr.flateContextTakeover() { mr.dict.init(32768) } if mr.flateBufio == nil { mr.flateBufio = getBufioReader(mr.readFunc) } mr.flateReader = getFlateReader(mr.flateBufio, mr.dict.buf) mr.limitReader.r = mr.flateReader mr.flateTail.Reset(deflateMessageTail) } func (mr *msgReader) putFlateReader() { if mr.flateReader != nil { putFlateReader(mr.flateReader) mr.flateReader = nil } } func (mr *msgReader) close() { mr.c.readMu.forceLock() mr.putFlateReader() mr.dict.close() if mr.flateBufio != nil { putBufioReader(mr.flateBufio) } if mr.c.client { putBufioReader(mr.c.br) mr.c.br = nil } } func (mr *msgReader) flateContextTakeover() bool { if mr.c.client { return !mr.c.copts.serverNoContextTakeover } return !mr.c.copts.clientNoContextTakeover } func (c *Conn) readRSV1Illegal(h header) bool { // If compression is disabled, rsv1 is illegal. if !c.flate() { return true } // rsv1 is only allowed on data frames beginning messages. if h.opcode != opText && h.opcode != opBinary { return true } return false } func (c *Conn) readLoop(ctx context.Context) (header, error) { for { h, err := c.readFrameHeader(ctx) if err != nil { return header{}, err } if h.rsv1 && c.readRSV1Illegal(h) || h.rsv2 || h.rsv3 { err := fmt.Errorf("received header with unexpected rsv bits set: %v:%v:%v", h.rsv1, h.rsv2, h.rsv3) c.writeError(StatusProtocolError, err) return header{}, err } if !c.client && !h.masked { return header{}, errors.New("received unmasked frame from client") } switch h.opcode { case opClose, opPing, opPong: err = c.handleControl(ctx, h) if err != nil { // Pass through CloseErrors when receiving a close frame. 
if h.opcode == opClose && CloseStatus(err) != -1 { return header{}, err } return header{}, fmt.Errorf("failed to handle control frame %v: %w", h.opcode, err) } case opContinuation, opText, opBinary: return h, nil default: err := fmt.Errorf("received unknown opcode %v", h.opcode) c.writeError(StatusProtocolError, err) return header{}, err } } } func (c *Conn) readFrameHeader(ctx context.Context) (header, error) { select { case <-c.closed: return header{}, c.closeErr case c.readTimeout <- ctx: } h, err := readFrameHeader(c.br, c.readHeaderBuf[:]) if err != nil { select { case <-c.closed: return header{}, c.closeErr case <-ctx.Done(): return header{}, ctx.Err() default: c.close(err) return header{}, err } } select { case <-c.closed: return header{}, c.closeErr case c.readTimeout <- context.Background(): } return h, nil } func (c *Conn) readFramePayload(ctx context.Context, p []byte) (int, error) { select { case <-c.closed: return 0, c.closeErr case c.readTimeout <- ctx: } n, err := io.ReadFull(c.br, p) if err != nil { select { case <-c.closed: return n, c.closeErr case <-ctx.Done(): return n, ctx.Err() default: err = fmt.Errorf("failed to read frame payload: %w", err) c.close(err) return n, err } } select { case <-c.closed: return n, c.closeErr case c.readTimeout <- context.Background(): } return n, err } func (c *Conn) handleControl(ctx context.Context, h header) (err error) { if h.payloadLength < 0 || h.payloadLength > maxControlPayload { err := fmt.Errorf("received control frame payload with invalid length: %d", h.payloadLength) c.writeError(StatusProtocolError, err) return err } if !h.fin { err := errors.New("received fragmented control frame") c.writeError(StatusProtocolError, err) return err } ctx, cancel := context.WithTimeout(ctx, time.Second*5) defer cancel() b := c.readControlBuf[:h.payloadLength] _, err = c.readFramePayload(ctx, b) if err != nil { return err } if h.masked { mask(h.maskKey, b) } switch h.opcode { case opPing: return c.writeControl(ctx, 
opPong, b) case opPong: c.activePingsMu.Lock() pong, ok := c.activePings[string(b)] c.activePingsMu.Unlock() if ok { select { case pong <- struct{}{}: default: } } return nil } defer func() { c.readCloseFrameErr = err }() ce, err := parseClosePayload(b) if err != nil { err = fmt.Errorf("received invalid close payload: %w", err) c.writeError(StatusProtocolError, err) return err } err = fmt.Errorf("received close frame: %w", ce) c.setCloseErr(err) c.writeClose(ce.Code, ce.Reason) c.close(err) return err } func (c *Conn) reader(ctx context.Context) (_ MessageType, _ io.Reader, err error) { defer errd.Wrap(&err, "failed to get reader") err = c.readMu.lock(ctx) if err != nil { return 0, nil, err } defer c.readMu.unlock() if !c.msgReader.fin { err = errors.New("previous message not read to completion") c.close(fmt.Errorf("failed to get reader: %w", err)) return 0, nil, err } h, err := c.readLoop(ctx) if err != nil { return 0, nil, err } if h.opcode == opContinuation { err := errors.New("received continuation frame without text or binary frame") c.writeError(StatusProtocolError, err) return 0, nil, err } c.msgReader.reset(ctx, h) return MessageType(h.opcode), c.msgReader, nil } type msgReader struct { c *Conn ctx context.Context flate bool flateReader io.Reader flateBufio *bufio.Reader flateTail strings.Reader limitReader *limitReader dict slidingWindow fin bool payloadLength int64 maskKey uint32 // readerFunc(mr.Read) to avoid continuous allocations. 
readFunc readerFunc } func (mr *msgReader) reset(ctx context.Context, h header) { mr.ctx = ctx mr.flate = h.rsv1 mr.limitReader.reset(mr.readFunc) if mr.flate { mr.resetFlate() } mr.setFrame(h) } func (mr *msgReader) setFrame(h header) { mr.fin = h.fin mr.payloadLength = h.payloadLength mr.maskKey = h.maskKey } func (mr *msgReader) Read(p []byte) (n int, err error) { err = mr.c.readMu.lock(mr.ctx) if err != nil { return 0, fmt.Errorf("failed to read: %w", err) } defer mr.c.readMu.unlock() n, err = mr.limitReader.Read(p) if mr.flate && mr.flateContextTakeover() { p = p[:n] mr.dict.write(p) } if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) && mr.fin && mr.flate { mr.putFlateReader() return n, io.EOF } if err != nil { err = fmt.Errorf("failed to read: %w", err) mr.c.close(err) } return n, err } func (mr *msgReader) read(p []byte) (int, error) { for { if mr.payloadLength == 0 { if mr.fin { if mr.flate { return mr.flateTail.Read(p) } return 0, io.EOF } h, err := mr.c.readLoop(mr.ctx) if err != nil { return 0, err } if h.opcode != opContinuation { err := errors.New("received new data message without finishing the previous message") mr.c.writeError(StatusProtocolError, err) return 0, err } mr.setFrame(h) continue } if int64(len(p)) > mr.payloadLength { p = p[:mr.payloadLength] } n, err := mr.c.readFramePayload(mr.ctx, p) if err != nil { return n, err } mr.payloadLength -= int64(n) if !mr.c.client { mr.maskKey = mask(mr.maskKey, p) } return n, nil } } type limitReader struct { c *Conn r io.Reader limit xsync.Int64 n int64 } func newLimitReader(c *Conn, r io.Reader, limit int64) *limitReader { lr := &limitReader{ c: c, } lr.limit.Store(limit) lr.reset(r) return lr } func (lr *limitReader) reset(r io.Reader) { lr.n = lr.limit.Load() lr.r = r } func (lr *limitReader) Read(p []byte) (int, error) { if lr.n <= 0 { err := fmt.Errorf("read limited at %v bytes", lr.limit.Load()) lr.c.writeError(StatusMessageTooBig, err) return 0, err } if int64(len(p)) > lr.n { p = 
p[:lr.n] } n, err := lr.r.Read(p) lr.n -= int64(n) return n, err } type readerFunc func(p []byte) (int, error) func (f readerFunc) Read(p []byte) (int, error) { return f(p) } websocket-1.8.7/stringer.go000066400000000000000000000055771403335366200157240ustar00rootroot00000000000000// Code generated by "stringer -type=opcode,MessageType,StatusCode -output=stringer.go"; DO NOT EDIT. package websocket import "strconv" func _() { // An "invalid array index" compiler error signifies that the constant values have changed. // Re-run the stringer command to generate them again. var x [1]struct{} _ = x[opContinuation-0] _ = x[opText-1] _ = x[opBinary-2] _ = x[opClose-8] _ = x[opPing-9] _ = x[opPong-10] } const ( _opcode_name_0 = "opContinuationopTextopBinary" _opcode_name_1 = "opCloseopPingopPong" ) var ( _opcode_index_0 = [...]uint8{0, 14, 20, 28} _opcode_index_1 = [...]uint8{0, 7, 13, 19} ) func (i opcode) String() string { switch { case 0 <= i && i <= 2: return _opcode_name_0[_opcode_index_0[i]:_opcode_index_0[i+1]] case 8 <= i && i <= 10: i -= 8 return _opcode_name_1[_opcode_index_1[i]:_opcode_index_1[i+1]] default: return "opcode(" + strconv.FormatInt(int64(i), 10) + ")" } } func _() { // An "invalid array index" compiler error signifies that the constant values have changed. // Re-run the stringer command to generate them again. var x [1]struct{} _ = x[MessageText-1] _ = x[MessageBinary-2] } const _MessageType_name = "MessageTextMessageBinary" var _MessageType_index = [...]uint8{0, 11, 24} func (i MessageType) String() string { i -= 1 if i < 0 || i >= MessageType(len(_MessageType_index)-1) { return "MessageType(" + strconv.FormatInt(int64(i+1), 10) + ")" } return _MessageType_name[_MessageType_index[i]:_MessageType_index[i+1]] } func _() { // An "invalid array index" compiler error signifies that the constant values have changed. // Re-run the stringer command to generate them again. 
var x [1]struct{} _ = x[StatusNormalClosure-1000] _ = x[StatusGoingAway-1001] _ = x[StatusProtocolError-1002] _ = x[StatusUnsupportedData-1003] _ = x[statusReserved-1004] _ = x[StatusNoStatusRcvd-1005] _ = x[StatusAbnormalClosure-1006] _ = x[StatusInvalidFramePayloadData-1007] _ = x[StatusPolicyViolation-1008] _ = x[StatusMessageTooBig-1009] _ = x[StatusMandatoryExtension-1010] _ = x[StatusInternalError-1011] _ = x[StatusServiceRestart-1012] _ = x[StatusTryAgainLater-1013] _ = x[StatusBadGateway-1014] _ = x[StatusTLSHandshake-1015] } const _StatusCode_name = "StatusNormalClosureStatusGoingAwayStatusProtocolErrorStatusUnsupportedDatastatusReservedStatusNoStatusRcvdStatusAbnormalClosureStatusInvalidFramePayloadDataStatusPolicyViolationStatusMessageTooBigStatusMandatoryExtensionStatusInternalErrorStatusServiceRestartStatusTryAgainLaterStatusBadGatewayStatusTLSHandshake" var _StatusCode_index = [...]uint16{0, 19, 34, 53, 74, 88, 106, 127, 156, 177, 196, 220, 239, 259, 278, 294, 312} func (i StatusCode) String() string { i -= 1000 if i < 0 || i >= StatusCode(len(_StatusCode_index)-1) { return "StatusCode(" + strconv.FormatInt(int64(i+1000), 10) + ")" } return _StatusCode_name[_StatusCode_index[i]:_StatusCode_index[i+1]] } websocket-1.8.7/write.go000066400000000000000000000200011403335366200151740ustar00rootroot00000000000000// +build !js package websocket import ( "bufio" "context" "crypto/rand" "encoding/binary" "errors" "fmt" "io" "time" "github.com/klauspost/compress/flate" "nhooyr.io/websocket/internal/errd" ) // Writer returns a writer bounded by the context that will write // a WebSocket message of type dataType to the connection. // // You must close the writer once you have written the entire message. // // Only one writer can be open at a time, multiple calls will block until the previous writer // is closed. 
func (c *Conn) Writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) { w, err := c.writer(ctx, typ) if err != nil { return nil, fmt.Errorf("failed to get writer: %w", err) } return w, nil } // Write writes a message to the connection. // // See the Writer method if you want to stream a message. // // If compression is disabled or the threshold is not met, then it // will write the message in a single frame. func (c *Conn) Write(ctx context.Context, typ MessageType, p []byte) error { _, err := c.write(ctx, typ, p) if err != nil { return fmt.Errorf("failed to write msg: %w", err) } return nil } type msgWriter struct { mw *msgWriterState closed bool } func (mw *msgWriter) Write(p []byte) (int, error) { if mw.closed { return 0, errors.New("cannot use closed writer") } return mw.mw.Write(p) } func (mw *msgWriter) Close() error { if mw.closed { return errors.New("cannot use closed writer") } mw.closed = true return mw.mw.Close() } type msgWriterState struct { c *Conn mu *mu writeMu *mu ctx context.Context opcode opcode flate bool trimWriter *trimLastFourBytesWriter dict slidingWindow } func newMsgWriterState(c *Conn) *msgWriterState { mw := &msgWriterState{ c: c, mu: newMu(c), writeMu: newMu(c), } return mw } func (mw *msgWriterState) ensureFlate() { if mw.trimWriter == nil { mw.trimWriter = &trimLastFourBytesWriter{ w: writerFunc(mw.write), } } mw.dict.init(8192) mw.flate = true } func (mw *msgWriterState) flateContextTakeover() bool { if mw.c.client { return !mw.c.copts.clientNoContextTakeover } return !mw.c.copts.serverNoContextTakeover } func (c *Conn) writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) { err := c.msgWriterState.reset(ctx, typ) if err != nil { return nil, err } return &msgWriter{ mw: c.msgWriterState, closed: false, }, nil } func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) (int, error) { mw, err := c.writer(ctx, typ) if err != nil { return 0, err } if !c.flate() { defer c.msgWriterState.mu.unlock() 
return c.writeFrame(ctx, true, false, c.msgWriterState.opcode, p) } n, err := mw.Write(p) if err != nil { return n, err } err = mw.Close() return n, err } func (mw *msgWriterState) reset(ctx context.Context, typ MessageType) error { err := mw.mu.lock(ctx) if err != nil { return err } mw.ctx = ctx mw.opcode = opcode(typ) mw.flate = false mw.trimWriter.reset() return nil } // Write writes the given bytes to the WebSocket connection. func (mw *msgWriterState) Write(p []byte) (_ int, err error) { err = mw.writeMu.lock(mw.ctx) if err != nil { return 0, fmt.Errorf("failed to write: %w", err) } defer mw.writeMu.unlock() defer func() { if err != nil { err = fmt.Errorf("failed to write: %w", err) mw.c.close(err) } }() if mw.c.flate() { // Only enables flate if the length crosses the // threshold on the first frame if mw.opcode != opContinuation && len(p) >= mw.c.flateThreshold { mw.ensureFlate() } } if mw.flate { err = flate.StatelessDeflate(mw.trimWriter, p, false, mw.dict.buf) if err != nil { return 0, err } mw.dict.write(p) return len(p), nil } return mw.write(p) } func (mw *msgWriterState) write(p []byte) (int, error) { n, err := mw.c.writeFrame(mw.ctx, false, mw.flate, mw.opcode, p) if err != nil { return n, fmt.Errorf("failed to write data frame: %w", err) } mw.opcode = opContinuation return n, nil } // Close flushes the frame to the connection. 
func (mw *msgWriterState) Close() (err error) { defer errd.Wrap(&err, "failed to close writer") err = mw.writeMu.lock(mw.ctx) if err != nil { return err } defer mw.writeMu.unlock() _, err = mw.c.writeFrame(mw.ctx, true, mw.flate, mw.opcode, nil) if err != nil { return fmt.Errorf("failed to write fin frame: %w", err) } if mw.flate && !mw.flateContextTakeover() { mw.dict.close() } mw.mu.unlock() return nil } func (mw *msgWriterState) close() { if mw.c.client { mw.c.writeFrameMu.forceLock() putBufioWriter(mw.c.bw) } mw.writeMu.forceLock() mw.dict.close() } func (c *Conn) writeControl(ctx context.Context, opcode opcode, p []byte) error { ctx, cancel := context.WithTimeout(ctx, time.Second*5) defer cancel() _, err := c.writeFrame(ctx, true, false, opcode, p) if err != nil { return fmt.Errorf("failed to write control frame %v: %w", opcode, err) } return nil } // frame handles all writes to the connection. func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opcode, p []byte) (_ int, err error) { err = c.writeFrameMu.lock(ctx) if err != nil { return 0, err } defer c.writeFrameMu.unlock() // If the state says a close has already been written, we wait until // the connection is closed and return that error. // // However, if the frame being written is a close, that means its the close from // the state being set so we let it go through. 
c.closeMu.Lock() wroteClose := c.wroteClose c.closeMu.Unlock() if wroteClose && opcode != opClose { select { case <-ctx.Done(): return 0, ctx.Err() case <-c.closed: return 0, c.closeErr } } select { case <-c.closed: return 0, c.closeErr case c.writeTimeout <- ctx: } defer func() { if err != nil { select { case <-c.closed: err = c.closeErr case <-ctx.Done(): err = ctx.Err() } c.close(err) err = fmt.Errorf("failed to write frame: %w", err) } }() c.writeHeader.fin = fin c.writeHeader.opcode = opcode c.writeHeader.payloadLength = int64(len(p)) if c.client { c.writeHeader.masked = true _, err = io.ReadFull(rand.Reader, c.writeHeaderBuf[:4]) if err != nil { return 0, fmt.Errorf("failed to generate masking key: %w", err) } c.writeHeader.maskKey = binary.LittleEndian.Uint32(c.writeHeaderBuf[:]) } c.writeHeader.rsv1 = false if flate && (opcode == opText || opcode == opBinary) { c.writeHeader.rsv1 = true } err = writeFrameHeader(c.writeHeader, c.bw, c.writeHeaderBuf[:]) if err != nil { return 0, err } n, err := c.writeFramePayload(p) if err != nil { return n, err } if c.writeHeader.fin { err = c.bw.Flush() if err != nil { return n, fmt.Errorf("failed to flush: %w", err) } } select { case <-c.closed: return n, c.closeErr case c.writeTimeout <- context.Background(): } return n, nil } func (c *Conn) writeFramePayload(p []byte) (n int, err error) { defer errd.Wrap(&err, "failed to write frame payload") if !c.writeHeader.masked { return c.bw.Write(p) } maskKey := c.writeHeader.maskKey for len(p) > 0 { // If the buffer is full, we need to flush. if c.bw.Available() == 0 { err = c.bw.Flush() if err != nil { return n, err } } // Start of next write in the buffer. 
i := c.bw.Buffered() j := len(p) if j > c.bw.Available() { j = c.bw.Available() } _, err := c.bw.Write(p[:j]) if err != nil { return n, err } maskKey = mask(maskKey, c.writeBuf[i:c.bw.Buffered()]) p = p[j:] n += j } return n, nil } type writerFunc func(p []byte) (int, error) func (f writerFunc) Write(p []byte) (int, error) { return f(p) } // extractBufioWriterBuf grabs the []byte backing a *bufio.Writer // and returns it. func extractBufioWriterBuf(bw *bufio.Writer, w io.Writer) []byte { var writeBuf []byte bw.Reset(writerFunc(func(p2 []byte) (int, error) { writeBuf = p2[:cap(p2)] return len(p2), nil })) bw.WriteByte(0) bw.Flush() bw.Reset(w) return writeBuf } func (c *Conn) writeError(code StatusCode, err error) { c.setCloseErr(err) c.writeClose(code, err.Error()) c.close(nil) } websocket-1.8.7/ws_js.go000066400000000000000000000212731403335366200152030ustar00rootroot00000000000000package websocket // import "nhooyr.io/websocket" import ( "bytes" "context" "errors" "fmt" "io" "net/http" "reflect" "runtime" "strings" "sync" "syscall/js" "nhooyr.io/websocket/internal/bpool" "nhooyr.io/websocket/internal/wsjs" "nhooyr.io/websocket/internal/xsync" ) // Conn provides a wrapper around the browser WebSocket API. type Conn struct { ws wsjs.WebSocket // read limit for a message in bytes. 
msgReadLimit xsync.Int64 closingMu sync.Mutex isReadClosed xsync.Int64 closeOnce sync.Once closed chan struct{} closeErrOnce sync.Once closeErr error closeWasClean bool releaseOnClose func() releaseOnMessage func() readSignal chan struct{} readBufMu sync.Mutex readBuf []wsjs.MessageEvent } func (c *Conn) close(err error, wasClean bool) { c.closeOnce.Do(func() { runtime.SetFinalizer(c, nil) if !wasClean { err = fmt.Errorf("unclean connection close: %w", err) } c.setCloseErr(err) c.closeWasClean = wasClean close(c.closed) }) } func (c *Conn) init() { c.closed = make(chan struct{}) c.readSignal = make(chan struct{}, 1) c.msgReadLimit.Store(32768) c.releaseOnClose = c.ws.OnClose(func(e wsjs.CloseEvent) { err := CloseError{ Code: StatusCode(e.Code), Reason: e.Reason, } // We do not know if we sent or received this close as // its possible the browser triggered it without us // explicitly sending it. c.close(err, e.WasClean) c.releaseOnClose() c.releaseOnMessage() }) c.releaseOnMessage = c.ws.OnMessage(func(e wsjs.MessageEvent) { c.readBufMu.Lock() defer c.readBufMu.Unlock() c.readBuf = append(c.readBuf, e) // Lets the read goroutine know there is definitely something in readBuf. select { case c.readSignal <- struct{}{}: default: } }) runtime.SetFinalizer(c, func(c *Conn) { c.setCloseErr(errors.New("connection garbage collected")) c.closeWithInternal() }) } func (c *Conn) closeWithInternal() { c.Close(StatusInternalError, "something went wrong") } // Read attempts to read a message from the connection. // The maximum time spent waiting is bounded by the context. 
func (c *Conn) Read(ctx context.Context) (MessageType, []byte, error) { if c.isReadClosed.Load() == 1 { return 0, nil, errors.New("WebSocket connection read closed") } typ, p, err := c.read(ctx) if err != nil { return 0, nil, fmt.Errorf("failed to read: %w", err) } if int64(len(p)) > c.msgReadLimit.Load() { err := fmt.Errorf("read limited at %v bytes", c.msgReadLimit.Load()) c.Close(StatusMessageTooBig, err.Error()) return 0, nil, err } return typ, p, nil } func (c *Conn) read(ctx context.Context) (MessageType, []byte, error) { select { case <-ctx.Done(): c.Close(StatusPolicyViolation, "read timed out") return 0, nil, ctx.Err() case <-c.readSignal: case <-c.closed: return 0, nil, c.closeErr } c.readBufMu.Lock() defer c.readBufMu.Unlock() me := c.readBuf[0] // We copy the messages forward and decrease the size // of the slice to avoid reallocating. copy(c.readBuf, c.readBuf[1:]) c.readBuf = c.readBuf[:len(c.readBuf)-1] if len(c.readBuf) > 0 { // Next time we read, we'll grab the message. select { case c.readSignal <- struct{}{}: default: } } switch p := me.Data.(type) { case string: return MessageText, []byte(p), nil case []byte: return MessageBinary, p, nil default: panic("websocket: unexpected data type from wsjs OnMessage: " + reflect.TypeOf(me.Data).String()) } } // Ping is mocked out for Wasm. func (c *Conn) Ping(ctx context.Context) error { return nil } // Write writes a message of the given type to the connection. // Always non blocking. func (c *Conn) Write(ctx context.Context, typ MessageType, p []byte) error { err := c.write(ctx, typ, p) if err != nil { // Have to ensure the WebSocket is closed after a write error // to match the Go API. It can only error if the message type // is unexpected or the passed bytes contain invalid UTF-8 for // MessageText. 
err := fmt.Errorf("failed to write: %w", err) c.setCloseErr(err) c.closeWithInternal() return err } return nil } func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) error { if c.isClosed() { return c.closeErr } switch typ { case MessageBinary: return c.ws.SendBytes(p) case MessageText: return c.ws.SendText(string(p)) default: return fmt.Errorf("unexpected message type: %v", typ) } } // Close closes the WebSocket with the given code and reason. // It will wait until the peer responds with a close frame // or the connection is closed. // It thus performs the full WebSocket close handshake. func (c *Conn) Close(code StatusCode, reason string) error { err := c.exportedClose(code, reason) if err != nil { return fmt.Errorf("failed to close WebSocket: %w", err) } return nil } func (c *Conn) exportedClose(code StatusCode, reason string) error { c.closingMu.Lock() defer c.closingMu.Unlock() ce := fmt.Errorf("sent close: %w", CloseError{ Code: code, Reason: reason, }) if c.isClosed() { return fmt.Errorf("tried to close with %q but connection already closed: %w", ce, c.closeErr) } c.setCloseErr(ce) err := c.ws.Close(int(code), reason) if err != nil { return err } <-c.closed if !c.closeWasClean { return c.closeErr } return nil } // Subprotocol returns the negotiated subprotocol. // An empty string means the default protocol. func (c *Conn) Subprotocol() string { return c.ws.Subprotocol() } // DialOptions represents the options available to pass to Dial. type DialOptions struct { // Subprotocols lists the subprotocols to negotiate with the server. Subprotocols []string } // Dial creates a new WebSocket connection to the given url with the given options. // The passed context bounds the maximum time spent waiting for the connection to open. // The returned *http.Response is always nil or a mock. It's only in the signature // to match the core API. 
func Dial(ctx context.Context, url string, opts *DialOptions) (*Conn, *http.Response, error) { c, resp, err := dial(ctx, url, opts) if err != nil { return nil, nil, fmt.Errorf("failed to WebSocket dial %q: %w", url, err) } return c, resp, nil } func dial(ctx context.Context, url string, opts *DialOptions) (*Conn, *http.Response, error) { if opts == nil { opts = &DialOptions{} } url = strings.Replace(url, "http://", "ws://", 1) url = strings.Replace(url, "https://", "wss://", 1) ws, err := wsjs.New(url, opts.Subprotocols) if err != nil { return nil, nil, err } c := &Conn{ ws: ws, } c.init() opench := make(chan struct{}) releaseOpen := ws.OnOpen(func(e js.Value) { close(opench) }) defer releaseOpen() select { case <-ctx.Done(): c.Close(StatusPolicyViolation, "dial timed out") return nil, nil, ctx.Err() case <-opench: return c, &http.Response{ StatusCode: http.StatusSwitchingProtocols, }, nil case <-c.closed: return nil, nil, c.closeErr } } // Reader attempts to read a message from the connection. // The maximum time spent waiting is bounded by the context. func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) { typ, p, err := c.Read(ctx) if err != nil { return 0, nil, err } return typ, bytes.NewReader(p), nil } // Writer returns a writer to write a WebSocket data message to the connection. // It buffers the entire message in memory and then sends it when the writer // is closed. 
func (c *Conn) Writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) { return writer{ c: c, ctx: ctx, typ: typ, b: bpool.Get(), }, nil } type writer struct { closed bool c *Conn ctx context.Context typ MessageType b *bytes.Buffer } func (w writer) Write(p []byte) (int, error) { if w.closed { return 0, errors.New("cannot write to closed writer") } n, err := w.b.Write(p) if err != nil { return n, fmt.Errorf("failed to write message: %w", err) } return n, nil } func (w writer) Close() error { if w.closed { return errors.New("cannot close closed writer") } w.closed = true defer bpool.Put(w.b) err := w.c.Write(w.ctx, w.typ, w.b.Bytes()) if err != nil { return fmt.Errorf("failed to close writer: %w", err) } return nil } // CloseRead implements *Conn.CloseRead for wasm. func (c *Conn) CloseRead(ctx context.Context) context.Context { c.isReadClosed.Store(1) ctx, cancel := context.WithCancel(ctx) go func() { defer cancel() c.read(ctx) c.Close(StatusPolicyViolation, "unexpected data message") }() return ctx } // SetReadLimit implements *Conn.SetReadLimit for wasm. 
func (c *Conn) SetReadLimit(n int64) { c.msgReadLimit.Store(n) } func (c *Conn) setCloseErr(err error) { c.closeErrOnce.Do(func() { c.closeErr = fmt.Errorf("WebSocket closed: %w", err) }) } func (c *Conn) isClosed() bool { select { case <-c.closed: return true default: return false } } websocket-1.8.7/ws_js_test.go000066400000000000000000000015421403335366200162370ustar00rootroot00000000000000package websocket_test import ( "context" "net/http" "os" "testing" "time" "nhooyr.io/websocket" "nhooyr.io/websocket/internal/test/assert" "nhooyr.io/websocket/internal/test/wstest" ) func TestWasm(t *testing.T) { t.Parallel() ctx, cancel := context.WithTimeout(context.Background(), time.Minute) defer cancel() c, resp, err := websocket.Dial(ctx, os.Getenv("WS_ECHO_SERVER_URL"), &websocket.DialOptions{ Subprotocols: []string{"echo"}, }) assert.Success(t, err) defer c.Close(websocket.StatusInternalError, "") assert.Equal(t, "subprotocol", "echo", c.Subprotocol()) assert.Equal(t, "response code", http.StatusSwitchingProtocols, resp.StatusCode) c.SetReadLimit(65536) for i := 0; i < 10; i++ { err = wstest.Echo(ctx, c, 65536) assert.Success(t, err) } err = c.Close(websocket.StatusNormalClosure, "") assert.Success(t, err) } websocket-1.8.7/wsjson/000077500000000000000000000000001403335366200150455ustar00rootroot00000000000000websocket-1.8.7/wsjson/wsjson.go000066400000000000000000000032271403335366200167230ustar00rootroot00000000000000// Package wsjson provides helpers for reading and writing JSON messages. package wsjson // import "nhooyr.io/websocket/wsjson" import ( "context" "encoding/json" "fmt" "nhooyr.io/websocket" "nhooyr.io/websocket/internal/bpool" "nhooyr.io/websocket/internal/errd" ) // Read reads a JSON message from c into v. // It will reuse buffers in between calls to avoid allocations. 
func Read(ctx context.Context, c *websocket.Conn, v interface{}) error { return read(ctx, c, v) } func read(ctx context.Context, c *websocket.Conn, v interface{}) (err error) { defer errd.Wrap(&err, "failed to read JSON message") _, r, err := c.Reader(ctx) if err != nil { return err } b := bpool.Get() defer bpool.Put(b) _, err = b.ReadFrom(r) if err != nil { return err } err = json.Unmarshal(b.Bytes(), v) if err != nil { c.Close(websocket.StatusInvalidFramePayloadData, "failed to unmarshal JSON") return fmt.Errorf("failed to unmarshal JSON: %w", err) } return nil } // Write writes the JSON message v to c. // It will reuse buffers in between calls to avoid allocations. func Write(ctx context.Context, c *websocket.Conn, v interface{}) error { return write(ctx, c, v) } func write(ctx context.Context, c *websocket.Conn, v interface{}) (err error) { defer errd.Wrap(&err, "failed to write JSON message") w, err := c.Writer(ctx, websocket.MessageText) if err != nil { return err } // json.Marshal cannot reuse buffers between calls as it has to return // a copy of the byte slice but Encoder does as it directly writes to w. err = json.NewEncoder(w).Encode(v) if err != nil { return fmt.Errorf("failed to marshal JSON: %w", err) } return w.Close() } websocket-1.8.7/wspb/000077500000000000000000000000001403335366200144755ustar00rootroot00000000000000websocket-1.8.7/wspb/wspb.go000066400000000000000000000034751403335366200160100ustar00rootroot00000000000000// Package wspb provides helpers for reading and writing protobuf messages. package wspb // import "nhooyr.io/websocket/wspb" import ( "bytes" "context" "fmt" "github.com/golang/protobuf/proto" "nhooyr.io/websocket" "nhooyr.io/websocket/internal/bpool" "nhooyr.io/websocket/internal/errd" ) // Read reads a protobuf message from c into v. // It will reuse buffers in between calls to avoid allocations. 
func Read(ctx context.Context, c *websocket.Conn, v proto.Message) error { return read(ctx, c, v) } func read(ctx context.Context, c *websocket.Conn, v proto.Message) (err error) { defer errd.Wrap(&err, "failed to read protobuf message") typ, r, err := c.Reader(ctx) if err != nil { return err } if typ != websocket.MessageBinary { c.Close(websocket.StatusUnsupportedData, "expected binary message") return fmt.Errorf("expected binary message for protobuf but got: %v", typ) } b := bpool.Get() defer bpool.Put(b) _, err = b.ReadFrom(r) if err != nil { return err } err = proto.Unmarshal(b.Bytes(), v) if err != nil { c.Close(websocket.StatusInvalidFramePayloadData, "failed to unmarshal protobuf") return fmt.Errorf("failed to unmarshal protobuf: %w", err) } return nil } // Write writes the protobuf message v to c. // It will reuse buffers in between calls to avoid allocations. func Write(ctx context.Context, c *websocket.Conn, v proto.Message) error { return write(ctx, c, v) } func write(ctx context.Context, c *websocket.Conn, v proto.Message) (err error) { defer errd.Wrap(&err, "failed to write protobuf message") b := bpool.Get() pb := proto.NewBuffer(b.Bytes()) defer func() { bpool.Put(bytes.NewBuffer(pb.Bytes())) }() err = pb.Marshal(v) if err != nil { return fmt.Errorf("failed to marshal protobuf: %w", err) } return c.Write(ctx, websocket.MessageBinary, pb.Bytes()) }