pax_global_header00006660000000000000000000000064146554641730014531gustar00rootroot0000000000000052 comment=e4379472fe1dfe70032ecc68fec08b1b3a8fc996 websocket-1.8.12/000077500000000000000000000000001465546417300136105ustar00rootroot00000000000000websocket-1.8.12/.github/000077500000000000000000000000001465546417300151505ustar00rootroot00000000000000websocket-1.8.12/.github/FUNDING.yml000066400000000000000000000000171465546417300167630ustar00rootroot00000000000000github: nhooyr websocket-1.8.12/.github/workflows/000077500000000000000000000000001465546417300172055ustar00rootroot00000000000000websocket-1.8.12/.github/workflows/ci.yml000066400000000000000000000020041465546417300203170ustar00rootroot00000000000000name: ci on: [push, pull_request] concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }} cancel-in-progress: true jobs: fmt: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - uses: actions/setup-go@v5 with: go-version-file: ./go.mod - run: ./ci/fmt.sh lint: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - run: go version - uses: actions/setup-go@v5 - run: ./ci/lint.sh test: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - uses: actions/setup-go@v5 with: go-version-file: ./go.mod - run: ./ci/test.sh - uses: actions/upload-artifact@v4 with: name: coverage.html path: ./ci/out/coverage.html bench: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - uses: actions/setup-go@v5 with: go-version-file: ./go.mod - run: ./ci/bench.sh websocket-1.8.12/.github/workflows/daily.yml000066400000000000000000000024301465546417300210310ustar00rootroot00000000000000name: daily on: workflow_dispatch: schedule: - cron: '42 0 * * *' # daily at 00:42 concurrency: group: ${{ github.workflow }} cancel-in-progress: true jobs: bench: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - uses: actions/setup-go@v5 with: go-version-file: ./go.mod - run: AUTOBAHN=1 ./ci/bench.sh test: runs-on: ubuntu-latest 
steps: - uses: actions/checkout@v4 - uses: actions/setup-go@v5 with: go-version-file: ./go.mod - run: AUTOBAHN=1 ./ci/test.sh - uses: actions/upload-artifact@v4 with: name: coverage.html path: ./ci/out/coverage.html bench-dev: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 with: ref: dev - uses: actions/setup-go@v5 with: go-version-file: ./go.mod - run: AUTOBAHN=1 ./ci/bench.sh test-dev: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 with: ref: dev - uses: actions/setup-go@v5 with: go-version-file: ./go.mod - run: AUTOBAHN=1 ./ci/test.sh - uses: actions/upload-artifact@v4 with: name: coverage-dev.html path: ./ci/out/coverage.html websocket-1.8.12/LICENSE.txt000066400000000000000000000013451465546417300154360ustar00rootroot00000000000000Copyright (c) 2023 Anmol Sethi Permission to use, copy, modify, and distribute this software for any purpose with or without fee is hereby granted, provided that the above copyright notice and this permission notice appear in all copies. THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. websocket-1.8.12/README.md000066400000000000000000000153061465546417300150740ustar00rootroot00000000000000# websocket [![Go Reference](https://pkg.go.dev/badge/github.com/coder/websocket.svg)](https://pkg.go.dev/github.com/coder/websocket) [![Go Coverage](https://img.shields.io/badge/coverage-91%25-success)](https://github.com/coder/websocket/coverage.html) websocket is a minimal and idiomatic WebSocket library for Go. 
## Install ```sh go get github.com/coder/websocket ``` > [!NOTE] > Coder now maintains this project as explained in [this blog post](https://coder.com/blog/websocket). > We're grateful to [nhooyr](https://github.com/nhooyr) for authoring and maintaining this project from > 2019 to 2024. ## Highlights - Minimal and idiomatic API - First class [context.Context](https://blog.golang.org/context) support - Fully passes the WebSocket [autobahn-testsuite](https://github.com/crossbario/autobahn-testsuite) - [Zero dependencies](https://pkg.go.dev/github.com/coder/websocket?tab=imports) - JSON helpers in the [wsjson](https://pkg.go.dev/github.com/coder/websocket/wsjson) subpackage - Zero alloc reads and writes - Concurrent writes - [Close handshake](https://pkg.go.dev/github.com/coder/websocket#Conn.Close) - [net.Conn](https://pkg.go.dev/github.com/coder/websocket#NetConn) wrapper - [Ping pong](https://pkg.go.dev/github.com/coder/websocket#Conn.Ping) API - [RFC 7692](https://tools.ietf.org/html/rfc7692) permessage-deflate compression - [CloseRead](https://pkg.go.dev/github.com/coder/websocket#Conn.CloseRead) helper for write only connections - Compile to [Wasm](https://pkg.go.dev/github.com/coder/websocket#hdr-Wasm) ## Roadmap See GitHub issues for minor issues but the major future enhancements are: - [ ] Perfect examples [#217](https://github.com/nhooyr/websocket/issues/217) - [ ] wstest.Pipe for in memory testing [#340](https://github.com/nhooyr/websocket/issues/340) - [ ] Ping pong heartbeat helper [#267](https://github.com/nhooyr/websocket/issues/267) - [ ] Ping pong instrumentation callbacks [#246](https://github.com/nhooyr/websocket/issues/246) - [ ] Graceful shutdown helpers [#209](https://github.com/nhooyr/websocket/issues/209) - [ ] Assembly for WebSocket masking [#16](https://github.com/nhooyr/websocket/issues/16) - WIP at [#326](https://github.com/nhooyr/websocket/pull/326), about 3x faster - [ ] HTTP/2 [#4](https://github.com/nhooyr/websocket/issues/4) - [ ] The 
holy grail [#402](https://github.com/nhooyr/websocket/issues/402) ## Examples For a production quality example that demonstrates the complete API, see the [echo example](./internal/examples/echo). For a full stack example, see the [chat example](./internal/examples/chat). ### Server ```go http.HandlerFunc(func (w http.ResponseWriter, r *http.Request) { c, err := websocket.Accept(w, r, nil) if err != nil { // ... } defer c.CloseNow() ctx, cancel := context.WithTimeout(r.Context(), time.Second*10) defer cancel() var v interface{} err = wsjson.Read(ctx, c, &v) if err != nil { // ... } log.Printf("received: %v", v) c.Close(websocket.StatusNormalClosure, "") }) ``` ### Client ```go ctx, cancel := context.WithTimeout(context.Background(), time.Minute) defer cancel() c, _, err := websocket.Dial(ctx, "ws://localhost:8080", nil) if err != nil { // ... } defer c.CloseNow() err = wsjson.Write(ctx, c, "hi") if err != nil { // ... } c.Close(websocket.StatusNormalClosure, "") ``` ## Comparison ### gorilla/websocket Advantages of [gorilla/websocket](https://github.com/gorilla/websocket): - Mature and widely used - [Prepared writes](https://pkg.go.dev/github.com/gorilla/websocket#PreparedMessage) - Configurable [buffer sizes](https://pkg.go.dev/github.com/gorilla/websocket#hdr-Buffers) - No extra goroutine per connection to support cancellation with context.Context. This costs github.com/coder/websocket 2 KB of memory per connection. - Will be removed soon with [context.AfterFunc](https://github.com/golang/go/issues/57928). See [#411](https://github.com/nhooyr/websocket/issues/411) Advantages of github.com/coder/websocket: - Minimal and idiomatic API - Compare godoc of [github.com/coder/websocket](https://pkg.go.dev/github.com/coder/websocket) with [gorilla/websocket](https://pkg.go.dev/github.com/gorilla/websocket) side by side. 
- [net.Conn](https://pkg.go.dev/github.com/coder/websocket#NetConn) wrapper - Zero alloc reads and writes ([gorilla/websocket#535](https://github.com/gorilla/websocket/issues/535)) - Full [context.Context](https://blog.golang.org/context) support - Dial uses [net/http.Client](https://golang.org/pkg/net/http/#Client) - Will enable easy HTTP/2 support in the future - Gorilla writes directly to a net.Conn and so duplicates features of net/http.Client. - Concurrent writes - Close handshake ([gorilla/websocket#448](https://github.com/gorilla/websocket/issues/448)) - Idiomatic [ping pong](https://pkg.go.dev/github.com/coder/websocket#Conn.Ping) API - Gorilla requires registering a pong callback before sending a Ping - Can target Wasm ([gorilla/websocket#432](https://github.com/gorilla/websocket/issues/432)) - Transparent message buffer reuse with [wsjson](https://pkg.go.dev/github.com/coder/websocket/wsjson) subpackage - [1.75x](https://github.com/nhooyr/websocket/releases/tag/v1.7.4) faster WebSocket masking implementation in pure Go - Gorilla's implementation is slower and uses [unsafe](https://golang.org/pkg/unsafe/). Soon we'll have assembly and be 3x faster [#326](https://github.com/nhooyr/websocket/pull/326) - Full [permessage-deflate](https://tools.ietf.org/html/rfc7692) compression extension support - Gorilla only supports no context takeover mode - [CloseRead](https://pkg.go.dev/github.com/coder/websocket#Conn.CloseRead) helper for write only connections ([gorilla/websocket#492](https://github.com/gorilla/websocket/issues/492)) #### golang.org/x/net/websocket [golang.org/x/net/websocket](https://pkg.go.dev/golang.org/x/net/websocket) is deprecated. See [golang/go/issues/18152](https://github.com/golang/go/issues/18152). The [net.Conn](https://pkg.go.dev/github.com/coder/websocket#NetConn) can help in transitioning to github.com/coder/websocket. 
#### gobwas/ws [gobwas/ws](https://github.com/gobwas/ws) has an extremely flexible API that allows it to be used in an event driven style for performance. See the author's [blog post](https://medium.freecodecamp.org/million-websockets-and-go-cc58418460bb). However it is quite bloated. See https://pkg.go.dev/github.com/gobwas/ws When writing idiomatic Go, github.com/coder/websocket will be faster and easier to use. #### lesismal/nbio [lesismal/nbio](https://github.com/lesismal/nbio) is similar to gobwas/ws in that the API is event driven for performance reasons. However it is quite bloated. See https://pkg.go.dev/github.com/lesismal/nbio When writing idiomatic Go, github.com/coder/websocket will be faster and easier to use. websocket-1.8.12/accept.go000066400000000000000000000242301465546417300153770ustar00rootroot00000000000000//go:build !js // +build !js package websocket import ( "bytes" "crypto/sha1" "encoding/base64" "errors" "fmt" "io" "log" "net/http" "net/textproto" "net/url" "path/filepath" "strings" "github.com/coder/websocket/internal/errd" ) // AcceptOptions represents Accept's options. type AcceptOptions struct { // Subprotocols lists the WebSocket subprotocols that Accept will negotiate with the client. // The empty subprotocol will always be negotiated as per RFC 6455. If you would like to // reject it, close the connection when c.Subprotocol() == "". Subprotocols []string // InsecureSkipVerify is used to disable Accept's origin verification behaviour. // // You probably want to use OriginPatterns instead. InsecureSkipVerify bool // OriginPatterns lists the host patterns for authorized origins. // The request host is always authorized. // Use this to enable cross origin WebSockets. // // i.e javascript running on example.com wants to access a WebSocket server at chat.example.com. // In such a case, example.com is the origin and chat.example.com is the request host. 
// One would set this field to []string{"example.com"} to authorize example.com to connect. // // Each pattern is matched case insensitively against the request origin host // with filepath.Match. // See https://golang.org/pkg/path/filepath/#Match // // Please ensure you understand the ramifications of enabling this. // If used incorrectly your WebSocket server will be open to CSRF attacks. // // Do not use * as a pattern to allow any origin, prefer to use InsecureSkipVerify instead // to bring attention to the danger of such a setting. OriginPatterns []string // CompressionMode controls the compression mode. // Defaults to CompressionDisabled. // // See docs on CompressionMode for details. CompressionMode CompressionMode // CompressionThreshold controls the minimum size of a message before compression is applied. // // Defaults to 512 bytes for CompressionNoContextTakeover and 128 bytes // for CompressionContextTakeover. CompressionThreshold int } func (opts *AcceptOptions) cloneWithDefaults() *AcceptOptions { var o AcceptOptions if opts != nil { o = *opts } return &o } // Accept accepts a WebSocket handshake from a client and upgrades the // the connection to a WebSocket. // // Accept will not allow cross origin requests by default. // See the InsecureSkipVerify and OriginPatterns options to allow cross origin requests. // // Accept will write a response to w on all errors. 
func Accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn, error) { return accept(w, r, opts) } func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Conn, err error) { defer errd.Wrap(&err, "failed to accept WebSocket connection") errCode, err := verifyClientRequest(w, r) if err != nil { http.Error(w, err.Error(), errCode) return nil, err } opts = opts.cloneWithDefaults() if !opts.InsecureSkipVerify { err = authenticateOrigin(r, opts.OriginPatterns) if err != nil { if errors.Is(err, filepath.ErrBadPattern) { log.Printf("websocket: %v", err) err = errors.New(http.StatusText(http.StatusForbidden)) } http.Error(w, err.Error(), http.StatusForbidden) return nil, err } } hj, ok := w.(http.Hijacker) if !ok { err = errors.New("http.ResponseWriter does not implement http.Hijacker") http.Error(w, http.StatusText(http.StatusNotImplemented), http.StatusNotImplemented) return nil, err } w.Header().Set("Upgrade", "websocket") w.Header().Set("Connection", "Upgrade") key := r.Header.Get("Sec-WebSocket-Key") w.Header().Set("Sec-WebSocket-Accept", secWebSocketAccept(key)) subproto := selectSubprotocol(r, opts.Subprotocols) if subproto != "" { w.Header().Set("Sec-WebSocket-Protocol", subproto) } copts, ok := selectDeflate(websocketExtensions(r.Header), opts.CompressionMode) if ok { w.Header().Set("Sec-WebSocket-Extensions", copts.String()) } w.WriteHeader(http.StatusSwitchingProtocols) // See https://github.com/nhooyr/websocket/issues/166 if ginWriter, ok := w.(interface { WriteHeaderNow() }); ok { ginWriter.WriteHeaderNow() } netConn, brw, err := hj.Hijack() if err != nil { err = fmt.Errorf("failed to hijack connection: %w", err) http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) return nil, err } // https://github.com/golang/go/issues/32314 b, _ := brw.Reader.Peek(brw.Reader.Buffered()) brw.Reader.Reset(io.MultiReader(bytes.NewReader(b), netConn)) return newConn(connConfig{ subprotocol: 
w.Header().Get("Sec-WebSocket-Protocol"), rwc: netConn, client: false, copts: copts, flateThreshold: opts.CompressionThreshold, br: brw.Reader, bw: brw.Writer, }), nil } func verifyClientRequest(w http.ResponseWriter, r *http.Request) (errCode int, _ error) { if !r.ProtoAtLeast(1, 1) { return http.StatusUpgradeRequired, fmt.Errorf("WebSocket protocol violation: handshake request must be at least HTTP/1.1: %q", r.Proto) } if !headerContainsTokenIgnoreCase(r.Header, "Connection", "Upgrade") { w.Header().Set("Connection", "Upgrade") w.Header().Set("Upgrade", "websocket") return http.StatusUpgradeRequired, fmt.Errorf("WebSocket protocol violation: Connection header %q does not contain Upgrade", r.Header.Get("Connection")) } if !headerContainsTokenIgnoreCase(r.Header, "Upgrade", "websocket") { w.Header().Set("Connection", "Upgrade") w.Header().Set("Upgrade", "websocket") return http.StatusUpgradeRequired, fmt.Errorf("WebSocket protocol violation: Upgrade header %q does not contain websocket", r.Header.Get("Upgrade")) } if r.Method != "GET" { return http.StatusMethodNotAllowed, fmt.Errorf("WebSocket protocol violation: handshake request method is not GET but %q", r.Method) } if r.Header.Get("Sec-WebSocket-Version") != "13" { w.Header().Set("Sec-WebSocket-Version", "13") return http.StatusBadRequest, fmt.Errorf("unsupported WebSocket protocol version (only 13 is supported): %q", r.Header.Get("Sec-WebSocket-Version")) } websocketSecKeys := r.Header.Values("Sec-WebSocket-Key") if len(websocketSecKeys) == 0 { return http.StatusBadRequest, errors.New("WebSocket protocol violation: missing Sec-WebSocket-Key") } if len(websocketSecKeys) > 1 { return http.StatusBadRequest, errors.New("WebSocket protocol violation: multiple Sec-WebSocket-Key headers") } // The RFC states to remove any leading or trailing whitespace. 
websocketSecKey := strings.TrimSpace(websocketSecKeys[0]) if v, err := base64.StdEncoding.DecodeString(websocketSecKey); err != nil || len(v) != 16 { return http.StatusBadRequest, fmt.Errorf("WebSocket protocol violation: invalid Sec-WebSocket-Key %q, must be a 16 byte base64 encoded string", websocketSecKey) } return 0, nil } func authenticateOrigin(r *http.Request, originHosts []string) error { origin := r.Header.Get("Origin") if origin == "" { return nil } u, err := url.Parse(origin) if err != nil { return fmt.Errorf("failed to parse Origin header %q: %w", origin, err) } if strings.EqualFold(r.Host, u.Host) { return nil } for _, hostPattern := range originHosts { matched, err := match(hostPattern, u.Host) if err != nil { return fmt.Errorf("failed to parse filepath pattern %q: %w", hostPattern, err) } if matched { return nil } } if u.Host == "" { return fmt.Errorf("request Origin %q is not a valid URL with a host", origin) } return fmt.Errorf("request Origin %q is not authorized for Host %q", u.Host, r.Host) } func match(pattern, s string) (bool, error) { return filepath.Match(strings.ToLower(pattern), strings.ToLower(s)) } func selectSubprotocol(r *http.Request, subprotocols []string) string { cps := headerTokens(r.Header, "Sec-WebSocket-Protocol") for _, sp := range subprotocols { for _, cp := range cps { if strings.EqualFold(sp, cp) { return cp } } } return "" } func selectDeflate(extensions []websocketExtension, mode CompressionMode) (*compressionOptions, bool) { if mode == CompressionDisabled { return nil, false } for _, ext := range extensions { switch ext.name { // We used to implement x-webkit-deflate-frame too for Safari but Safari has bugs... 
// See https://github.com/nhooyr/websocket/issues/218 case "permessage-deflate": copts, ok := acceptDeflate(ext, mode) if ok { return copts, true } } } return nil, false } func acceptDeflate(ext websocketExtension, mode CompressionMode) (*compressionOptions, bool) { copts := mode.opts() for _, p := range ext.params { switch p { case "client_no_context_takeover": copts.clientNoContextTakeover = true continue case "server_no_context_takeover": copts.serverNoContextTakeover = true continue case "client_max_window_bits", "server_max_window_bits=15": continue } if strings.HasPrefix(p, "client_max_window_bits=") { // We can't adjust the deflate window, but decoding with a larger window is acceptable. continue } return nil, false } return copts, true } func headerContainsTokenIgnoreCase(h http.Header, key, token string) bool { for _, t := range headerTokens(h, key) { if strings.EqualFold(t, token) { return true } } return false } type websocketExtension struct { name string params []string } func websocketExtensions(h http.Header) []websocketExtension { var exts []websocketExtension extStrs := headerTokens(h, "Sec-WebSocket-Extensions") for _, extStr := range extStrs { if extStr == "" { continue } vals := strings.Split(extStr, ";") for i := range vals { vals[i] = strings.TrimSpace(vals[i]) } e := websocketExtension{ name: vals[0], params: vals[1:], } exts = append(exts, e) } return exts } func headerTokens(h http.Header, key string) []string { key = textproto.CanonicalMIMEHeaderKey(key) var tokens []string for _, v := range h[key] { v = strings.TrimSpace(v) for _, t := range strings.Split(v, ",") { t = strings.TrimSpace(t) tokens = append(tokens, t) } } return tokens } var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11") func secWebSocketAccept(secWebSocketKey string) string { h := sha1.New() h.Write([]byte(secWebSocketKey)) h.Write(keyGUID) return base64.StdEncoding.EncodeToString(h.Sum(nil)) } 
websocket-1.8.12/accept_test.go000066400000000000000000000307111465546417300164370ustar00rootroot00000000000000//go:build !js // +build !js package websocket import ( "bufio" "errors" "net" "net/http" "net/http/httptest" "strings" "sync" "testing" "github.com/coder/websocket/internal/test/assert" "github.com/coder/websocket/internal/test/xrand" ) func TestAccept(t *testing.T) { t.Parallel() t.Run("badClientHandshake", func(t *testing.T) { t.Parallel() w := httptest.NewRecorder() r := httptest.NewRequest("GET", "/", nil) _, err := Accept(w, r, nil) assert.Contains(t, err, "protocol violation") }) t.Run("badOrigin", func(t *testing.T) { t.Parallel() w := httptest.NewRecorder() r := httptest.NewRequest("GET", "/", nil) r.Header.Set("Connection", "Upgrade") r.Header.Set("Upgrade", "websocket") r.Header.Set("Sec-WebSocket-Version", "13") r.Header.Set("Sec-WebSocket-Key", xrand.Base64(16)) r.Header.Set("Origin", "harhar.com") _, err := Accept(w, r, nil) assert.Contains(t, err, `request Origin "harhar.com" is not a valid URL with a host`) }) // #247 t.Run("unauthorizedOriginErrorMessage", func(t *testing.T) { t.Parallel() w := httptest.NewRecorder() r := httptest.NewRequest("GET", "/", nil) r.Header.Set("Connection", "Upgrade") r.Header.Set("Upgrade", "websocket") r.Header.Set("Sec-WebSocket-Version", "13") r.Header.Set("Sec-WebSocket-Key", xrand.Base64(16)) r.Header.Set("Origin", "https://harhar.com") _, err := Accept(w, r, nil) assert.Contains(t, err, `request Origin "harhar.com" is not authorized for Host "example.com"`) }) t.Run("badCompression", func(t *testing.T) { t.Parallel() newRequest := func(extensions string) *http.Request { r := httptest.NewRequest("GET", "/", nil) r.Header.Set("Connection", "Upgrade") r.Header.Set("Upgrade", "websocket") r.Header.Set("Sec-WebSocket-Version", "13") r.Header.Set("Sec-WebSocket-Key", xrand.Base64(16)) r.Header.Set("Sec-WebSocket-Extensions", extensions) return r } errHijack := errors.New("hijack error") newResponseWriter := 
func() http.ResponseWriter { return mockHijacker{ ResponseWriter: httptest.NewRecorder(), hijack: func() (net.Conn, *bufio.ReadWriter, error) { return nil, nil, errHijack }, } } t.Run("withoutFallback", func(t *testing.T) { t.Parallel() w := newResponseWriter() r := newRequest("permessage-deflate; harharhar") _, err := Accept(w, r, &AcceptOptions{ CompressionMode: CompressionNoContextTakeover, }) assert.ErrorIs(t, errHijack, err) assert.Equal(t, "extension header", w.Header().Get("Sec-WebSocket-Extensions"), "") }) t.Run("withFallback", func(t *testing.T) { t.Parallel() w := newResponseWriter() r := newRequest("permessage-deflate; harharhar, permessage-deflate") _, err := Accept(w, r, &AcceptOptions{ CompressionMode: CompressionNoContextTakeover, }) assert.ErrorIs(t, errHijack, err) assert.Equal(t, "extension header", w.Header().Get("Sec-WebSocket-Extensions"), CompressionNoContextTakeover.opts().String(), ) }) }) t.Run("requireHttpHijacker", func(t *testing.T) { t.Parallel() w := httptest.NewRecorder() r := httptest.NewRequest("GET", "/", nil) r.Header.Set("Connection", "Upgrade") r.Header.Set("Upgrade", "websocket") r.Header.Set("Sec-WebSocket-Version", "13") r.Header.Set("Sec-WebSocket-Key", xrand.Base64(16)) _, err := Accept(w, r, nil) assert.Contains(t, err, `http.ResponseWriter does not implement http.Hijacker`) }) t.Run("badHijack", func(t *testing.T) { t.Parallel() w := mockHijacker{ ResponseWriter: httptest.NewRecorder(), hijack: func() (conn net.Conn, writer *bufio.ReadWriter, err error) { return nil, nil, errors.New("haha") }, } r := httptest.NewRequest("GET", "/", nil) r.Header.Set("Connection", "Upgrade") r.Header.Set("Upgrade", "websocket") r.Header.Set("Sec-WebSocket-Version", "13") r.Header.Set("Sec-WebSocket-Key", xrand.Base64(16)) _, err := Accept(w, r, nil) assert.Contains(t, err, `failed to hijack connection`) }) t.Run("closeRace", func(t *testing.T) { t.Parallel() server, _ := net.Pipe() rw := bufio.NewReadWriter(bufio.NewReader(server), 
bufio.NewWriter(server)) newResponseWriter := func() http.ResponseWriter { return mockHijacker{ ResponseWriter: httptest.NewRecorder(), hijack: func() (net.Conn, *bufio.ReadWriter, error) { return server, rw, nil }, } } w := newResponseWriter() r := httptest.NewRequest("GET", "/", nil) r.Header.Set("Connection", "Upgrade") r.Header.Set("Upgrade", "websocket") r.Header.Set("Sec-WebSocket-Version", "13") r.Header.Set("Sec-WebSocket-Key", xrand.Base64(16)) c, err := Accept(w, r, nil) wg := &sync.WaitGroup{} wg.Add(2) go func() { c.Close(StatusInternalError, "the sky is falling") wg.Done() }() go func() { c.CloseNow() wg.Done() }() wg.Wait() assert.Success(t, err) }) } func Test_verifyClientHandshake(t *testing.T) { t.Parallel() testCases := []struct { name string method string http1 bool h map[string]string success bool }{ { name: "badConnection", h: map[string]string{ "Connection": "notUpgrade", }, }, { name: "badUpgrade", h: map[string]string{ "Connection": "Upgrade", "Upgrade": "notWebSocket", }, }, { name: "badMethod", method: "POST", h: map[string]string{ "Connection": "Upgrade", "Upgrade": "websocket", }, }, { name: "badWebSocketVersion", h: map[string]string{ "Connection": "Upgrade", "Upgrade": "websocket", "Sec-WebSocket-Version": "14", }, }, { name: "missingWebSocketKey", h: map[string]string{ "Connection": "Upgrade", "Upgrade": "websocket", "Sec-WebSocket-Version": "13", }, }, { name: "emptyWebSocketKey", h: map[string]string{ "Connection": "Upgrade", "Upgrade": "websocket", "Sec-WebSocket-Version": "13", "Sec-WebSocket-Key": "", }, }, { name: "shortWebSocketKey", h: map[string]string{ "Connection": "Upgrade", "Upgrade": "websocket", "Sec-WebSocket-Version": "13", "Sec-WebSocket-Key": xrand.Base64(15), }, }, { name: "invalidWebSocketKey", h: map[string]string{ "Connection": "Upgrade", "Upgrade": "websocket", "Sec-WebSocket-Version": "13", "Sec-WebSocket-Key": "notbase64", }, }, { name: "extraWebSocketKey", h: map[string]string{ "Connection": "Upgrade", 
"Upgrade": "websocket", "Sec-WebSocket-Version": "13", // Kinda cheeky, but http headers are case-insensitive. // If 2 sec keys are present, this is a failure condition. "Sec-WebSocket-Key": xrand.Base64(16), "sec-webSocket-key": xrand.Base64(16), }, }, { name: "badHTTPVersion", h: map[string]string{ "Connection": "Upgrade", "Upgrade": "websocket", "Sec-WebSocket-Version": "13", "Sec-WebSocket-Key": xrand.Base64(16), }, http1: true, }, { name: "success", h: map[string]string{ "Connection": "keep-alive, Upgrade", "Upgrade": "websocket", "Sec-WebSocket-Version": "13", "Sec-WebSocket-Key": xrand.Base64(16), }, success: true, }, { name: "successSecKeyExtraSpace", h: map[string]string{ "Connection": "keep-alive, Upgrade", "Upgrade": "websocket", "Sec-WebSocket-Version": "13", "Sec-WebSocket-Key": " " + xrand.Base64(16) + " ", }, success: true, }, } for _, tc := range testCases { tc := tc t.Run(tc.name, func(t *testing.T) { t.Parallel() r := httptest.NewRequest(tc.method, "/", nil) r.ProtoMajor = 1 r.ProtoMinor = 1 if tc.http1 { r.ProtoMinor = 0 } for k, v := range tc.h { r.Header.Add(k, v) } _, err := verifyClientRequest(httptest.NewRecorder(), r) if tc.success { assert.Success(t, err) } else { assert.Error(t, err) } }) } } func Test_selectSubprotocol(t *testing.T) { t.Parallel() testCases := []struct { name string clientProtocols []string serverProtocols []string negotiated string }{ { name: "empty", clientProtocols: nil, serverProtocols: nil, negotiated: "", }, { name: "basic", clientProtocols: []string{"echo", "echo2"}, serverProtocols: []string{"echo2", "echo"}, negotiated: "echo2", }, { name: "none", clientProtocols: []string{"echo", "echo3"}, serverProtocols: []string{"echo2", "echo4"}, negotiated: "", }, { name: "fallback", clientProtocols: []string{"echo", "echo3"}, serverProtocols: []string{"echo2", "echo3"}, negotiated: "echo3", }, { name: "clientCasePresered", clientProtocols: []string{"Echo1"}, serverProtocols: []string{"echo1"}, negotiated: "Echo1", }, } 
for _, tc := range testCases { tc := tc t.Run(tc.name, func(t *testing.T) { t.Parallel() r := httptest.NewRequest("GET", "/", nil) r.Header.Set("Sec-WebSocket-Protocol", strings.Join(tc.clientProtocols, ",")) negotiated := selectSubprotocol(r, tc.serverProtocols) assert.Equal(t, "negotiated", tc.negotiated, negotiated) }) } } func Test_authenticateOrigin(t *testing.T) { t.Parallel() testCases := []struct { name string origin string host string originPatterns []string success bool }{ { name: "none", success: true, host: "example.com", }, { name: "invalid", origin: "$#)(*)$#@*$(#@*$)#@*%)#(@*%)#(@%#@$#@$#$#@$#@}{}{}", host: "example.com", success: false, }, { name: "unauthorized", origin: "https://example.com", host: "example1.com", success: false, }, { name: "authorized", origin: "https://example.com", host: "example.com", success: true, }, { name: "authorizedCaseInsensitive", origin: "https://examplE.com", host: "example.com", success: true, }, { name: "originPatterns", origin: "https://two.examplE.com", host: "example.com", originPatterns: []string{ "*.example.com", "bar.com", }, success: true, }, { name: "originPatternsUnauthorized", origin: "https://two.examplE.com", host: "example.com", originPatterns: []string{ "exam3.com", "bar.com", }, success: false, }, } for _, tc := range testCases { tc := tc t.Run(tc.name, func(t *testing.T) { t.Parallel() r := httptest.NewRequest("GET", "http://"+tc.host+"/", nil) r.Header.Set("Origin", tc.origin) err := authenticateOrigin(r, tc.originPatterns) if tc.success { assert.Success(t, err) } else { assert.Error(t, err) } }) } } func Test_selectDeflate(t *testing.T) { t.Parallel() testCases := []struct { name string mode CompressionMode header string expCopts *compressionOptions expOK bool }{ { name: "disabled", mode: CompressionDisabled, expCopts: nil, expOK: false, }, { name: "noClientSupport", mode: CompressionNoContextTakeover, expCopts: nil, expOK: false, }, { name: "permessage-deflate", mode: CompressionNoContextTakeover, 
header: "permessage-deflate; client_max_window_bits", expCopts: &compressionOptions{ clientNoContextTakeover: true, serverNoContextTakeover: true, }, expOK: true, }, { name: "permessage-deflate/unknown-parameter", mode: CompressionNoContextTakeover, header: "permessage-deflate; meow", expOK: false, }, { name: "permessage-deflate/unknown-parameter", mode: CompressionNoContextTakeover, header: "permessage-deflate; meow, permessage-deflate; client_max_window_bits", expCopts: &compressionOptions{ clientNoContextTakeover: true, serverNoContextTakeover: true, }, expOK: true, }, } for _, tc := range testCases { tc := tc t.Run(tc.name, func(t *testing.T) { t.Parallel() h := http.Header{} h.Set("Sec-WebSocket-Extensions", tc.header) copts, ok := selectDeflate(websocketExtensions(h), tc.mode) assert.Equal(t, "selected options", tc.expOK, ok) assert.Equal(t, "compression options", tc.expCopts, copts) }) } } type mockHijacker struct { http.ResponseWriter hijack func() (net.Conn, *bufio.ReadWriter, error) } var _ http.Hijacker = mockHijacker{} func (mj mockHijacker) Hijack() (net.Conn, *bufio.ReadWriter, error) { return mj.hijack() } websocket-1.8.12/autobahn_test.go000066400000000000000000000162701465546417300170050ustar00rootroot00000000000000//go:build !js // +build !js package websocket_test import ( "context" "encoding/json" "errors" "fmt" "io" "net" "os" "os/exec" "strconv" "strings" "testing" "time" "github.com/coder/websocket" "github.com/coder/websocket/internal/errd" "github.com/coder/websocket/internal/test/assert" "github.com/coder/websocket/internal/test/wstest" "github.com/coder/websocket/internal/util" ) var excludedAutobahnCases = []string{ // We skip the UTF-8 handling tests as there isn't any reason to reject invalid UTF-8, just // more performance overhead. "6.*", "7.5.1", // We skip the tests related to requestMaxWindowBits as that is unimplemented due // to limitations in compress/flate. 
See https://github.com/golang/go/issues/3155 "13.3.*", "13.4.*", "13.5.*", "13.6.*", } var autobahnCases = []string{"*"} // Used to run individual test cases. autobahnCases runs only those cases matched // and not excluded by excludedAutobahnCases. Adding cases here means excludedAutobahnCases // is niled. var onlyAutobahnCases = []string{} func TestAutobahn(t *testing.T) { t.Parallel() if os.Getenv("AUTOBAHN") == "" { t.SkipNow() } if os.Getenv("AUTOBAHN") == "fast" { // These are the slow tests. excludedAutobahnCases = append(excludedAutobahnCases, "9.*", "12.*", "13.*", ) } if len(onlyAutobahnCases) > 0 { excludedAutobahnCases = []string{} autobahnCases = onlyAutobahnCases } ctx, cancel := context.WithTimeout(context.Background(), time.Hour) defer cancel() wstestURL, closeFn, err := wstestServer(t, ctx) assert.Success(t, err) defer func() { assert.Success(t, closeFn()) }() err = waitWS(ctx, wstestURL) assert.Success(t, err) cases, err := wstestCaseCount(ctx, wstestURL) assert.Success(t, err) t.Run("cases", func(t *testing.T) { for i := 1; i <= cases; i++ { i := i t.Run("", func(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), time.Minute*5) defer cancel() c, _, err := websocket.Dial(ctx, fmt.Sprintf(wstestURL+"/runCase?case=%v&agent=main", i), &websocket.DialOptions{ CompressionMode: websocket.CompressionContextTakeover, }) assert.Success(t, err) err = wstest.EchoLoop(ctx, c) t.Logf("echoLoop: %v", err) }) } }) c, _, err := websocket.Dial(ctx, fmt.Sprintf(wstestURL+"/updateReports?agent=main"), nil) assert.Success(t, err) c.Close(websocket.StatusNormalClosure, "") checkWSTestIndex(t, "./ci/out/autobahn-report/index.json") } func waitWS(ctx context.Context, url string) error { ctx, cancel := context.WithTimeout(ctx, time.Second*5) defer cancel() for ctx.Err() == nil { c, _, err := websocket.Dial(ctx, url, nil) if err != nil { continue } c.Close(websocket.StatusNormalClosure, "") return nil } return ctx.Err() } func wstestServer(tb 
testing.TB, ctx context.Context) (url string, closeFn func() error, err error) { defer errd.Wrap(&err, "failed to start autobahn wstest server") serverAddr, err := unusedListenAddr() if err != nil { return "", nil, err } _, serverPort, err := net.SplitHostPort(serverAddr) if err != nil { return "", nil, err } url = "ws://" + serverAddr const outDir = "ci/out/autobahn-report" specFile, err := tempJSONFile(map[string]interface{}{ "url": url, "outdir": outDir, "cases": autobahnCases, "exclude-cases": excludedAutobahnCases, }) if err != nil { return "", nil, fmt.Errorf("failed to write spec: %w", err) } ctx, cancel := context.WithTimeout(ctx, time.Hour) defer func() { if err != nil { cancel() } }() dockerPull := exec.CommandContext(ctx, "docker", "pull", "crossbario/autobahn-testsuite") dockerPull.Stdout = util.WriterFunc(func(p []byte) (int, error) { tb.Log(string(p)) return len(p), nil }) dockerPull.Stderr = util.WriterFunc(func(p []byte) (int, error) { tb.Log(string(p)) return len(p), nil }) tb.Log(dockerPull) err = dockerPull.Run() if err != nil { return "", nil, fmt.Errorf("failed to pull docker image: %w", err) } wd, err := os.Getwd() if err != nil { return "", nil, err } var args []string args = append(args, "run", "-i", "--rm", "-v", fmt.Sprintf("%s:%[1]s", specFile), "-v", fmt.Sprintf("%s/ci:/ci", wd), fmt.Sprintf("-p=%s:%s", serverAddr, serverPort), "crossbario/autobahn-testsuite", ) args = append(args, "wstest", "--mode", "fuzzingserver", "--spec", specFile, // Disables some server that runs as part of fuzzingserver mode. // See https://github.com/crossbario/autobahn-testsuite/blob/058db3a36b7c3a1edf68c282307c6b899ca4857f/autobahntestsuite/autobahntestsuite/wstest.py#L124 "--webport=0", ) wstest := exec.CommandContext(ctx, "docker", args...) 
wstest.Stdout = util.WriterFunc(func(p []byte) (int, error) { tb.Log(string(p)) return len(p), nil }) wstest.Stderr = util.WriterFunc(func(p []byte) (int, error) { tb.Log(string(p)) return len(p), nil }) tb.Log(wstest) err = wstest.Start() if err != nil { return "", nil, fmt.Errorf("failed to start wstest: %w", err) } return url, func() error { err = wstest.Process.Kill() if err != nil { return fmt.Errorf("failed to kill wstest: %w", err) } err = wstest.Wait() var ee *exec.ExitError if errors.As(err, &ee) && ee.ExitCode() == -1 { return nil } return err }, nil } func wstestCaseCount(ctx context.Context, url string) (cases int, err error) { defer errd.Wrap(&err, "failed to get case count") c, _, err := websocket.Dial(ctx, url+"/getCaseCount", nil) if err != nil { return 0, err } defer c.Close(websocket.StatusInternalError, "") _, r, err := c.Reader(ctx) if err != nil { return 0, err } b, err := io.ReadAll(r) if err != nil { return 0, err } cases, err = strconv.Atoi(string(b)) if err != nil { return 0, err } c.Close(websocket.StatusNormalClosure, "") return cases, nil } func checkWSTestIndex(t *testing.T, path string) { wstestOut, err := os.ReadFile(path) assert.Success(t, err) var indexJSON map[string]map[string]struct { Behavior string `json:"behavior"` BehaviorClose string `json:"behaviorClose"` } err = json.Unmarshal(wstestOut, &indexJSON) assert.Success(t, err) for _, tests := range indexJSON { for test, result := range tests { t.Run(test, func(t *testing.T) { switch result.BehaviorClose { case "OK", "INFORMATIONAL": default: t.Errorf("bad close behaviour") } switch result.Behavior { case "OK", "NON-STRICT", "INFORMATIONAL": default: t.Errorf("failed") } }) } } if t.Failed() { htmlPath := strings.Replace(path, ".json", ".html", 1) t.Errorf("detected autobahn violation, see %q", htmlPath) } } func unusedListenAddr() (_ string, err error) { defer errd.Wrap(&err, "failed to get unused listen address") l, err := net.Listen("tcp", "localhost:0") if err != nil { 
return "", err } l.Close() return l.Addr().String(), nil } func tempJSONFile(v interface{}) (string, error) { f, err := os.CreateTemp("", "temp.json") if err != nil { return "", fmt.Errorf("temp file: %w", err) } defer f.Close() e := json.NewEncoder(f) e.SetIndent("", "\t") err = e.Encode(v) if err != nil { return "", fmt.Errorf("json encode: %w", err) } err = f.Close() if err != nil { return "", fmt.Errorf("close temp file: %w", err) } return f.Name(), nil } websocket-1.8.12/ci/000077500000000000000000000000001465546417300142035ustar00rootroot00000000000000websocket-1.8.12/ci/bench.sh000077500000000000000000000012261465546417300156220ustar00rootroot00000000000000#!/bin/sh set -eu cd -- "$(dirname "$0")/.." go test --run=^$ --bench=. --benchmem "$@" ./... # For profiling add: --memprofile ci/out/prof.mem --cpuprofile ci/out/prof.cpu -o ci/out/websocket.test ( cd ./internal/thirdparty go test --run=^$ --bench=. --benchmem "$@" . GOARCH=arm64 go test -c -o ../../ci/out/thirdparty-arm64.test "$@" . if [ "$#" -eq 0 ]; then if [ "${CI-}" ]; then sudo apt-get update sudo apt-get install -y qemu-user-static ln -s /usr/bin/qemu-aarch64-static /usr/local/bin/qemu-aarch64 fi qemu-aarch64 ../../ci/out/thirdparty-arm64.test --test.run=^$ --test.bench=Benchmark_mask --test.benchmem fi ) websocket-1.8.12/ci/fmt.sh000077500000000000000000000010741465546417300153320ustar00rootroot00000000000000#!/bin/sh set -eu cd -- "$(dirname "$0")/.." go mod tidy (cd ./internal/thirdparty && go mod tidy) (cd ./internal/examples && go mod tidy) gofmt -w -s . go run golang.org/x/tools/cmd/goimports@latest -w "-local=$(go list -m)" . 
npx prettier@3.0.3 \ --write \ --log-level=warn \ --print-width=90 \ --no-semi \ --single-quote \ --arrow-parens=avoid \ $(git ls-files "*.yml" "*.md" "*.js" "*.css" "*.html") go run golang.org/x/tools/cmd/stringer@latest -type=opcode,MessageType,StatusCode -output=stringer.go if [ "${CI-}" ]; then git diff --exit-code fi websocket-1.8.12/ci/lint.sh000077500000000000000000000011111465546417300155020ustar00rootroot00000000000000#!/bin/sh set -eu cd -- "$(dirname "$0")/.." go vet ./... GOOS=js GOARCH=wasm go vet ./... go install honnef.co/go/tools/cmd/staticcheck@latest staticcheck ./... GOOS=js GOARCH=wasm staticcheck ./... govulncheck() { tmpf=$(mktemp) if ! command govulncheck "$@" >"$tmpf" 2>&1; then cat "$tmpf" fi } go install golang.org/x/vuln/cmd/govulncheck@latest govulncheck ./... GOOS=js GOARCH=wasm govulncheck ./... ( cd ./internal/examples go vet ./... staticcheck ./... govulncheck ./... ) ( cd ./internal/thirdparty go vet ./... staticcheck ./... govulncheck ./... ) websocket-1.8.12/ci/out/000077500000000000000000000000001465546417300150125ustar00rootroot00000000000000websocket-1.8.12/ci/out/.gitignore000066400000000000000000000000021465546417300167720ustar00rootroot00000000000000* websocket-1.8.12/ci/test.sh000077500000000000000000000017351465546417300155270ustar00rootroot00000000000000#!/bin/sh set -eu cd -- "$(dirname "$0")/.." ( cd ./internal/examples go test "$@" ./... ) ( cd ./internal/thirdparty go test "$@" ./... ) ( GOARCH=arm64 go test -c -o ./ci/out/websocket-arm64.test "$@" . if [ "$#" -eq 0 ]; then if [ "${CI-}" ]; then sudo apt-get update sudo apt-get install -y qemu-user-static ln -s /usr/bin/qemu-aarch64-static /usr/local/bin/qemu-aarch64 fi qemu-aarch64 ./ci/out/websocket-arm64.test -test.run=TestMask fi ) go install github.com/agnivade/wasmbrowsertest@latest go test --race --bench=. --timeout=1h --covermode=atomic --coverprofile=ci/out/coverage.prof --coverpkg=./... "$@" ./... 
sed -i.bak '/stringer\.go/d' ci/out/coverage.prof sed -i.bak '/nhooyr.io\/websocket\/internal\/test/d' ci/out/coverage.prof sed -i.bak '/examples/d' ci/out/coverage.prof # Last line is the total coverage. go tool cover -func ci/out/coverage.prof | tail -n1 go tool cover -html=ci/out/coverage.prof -o=ci/out/coverage.html websocket-1.8.12/close.go000066400000000000000000000176151465546417300152560ustar00rootroot00000000000000//go:build !js // +build !js package websocket import ( "context" "encoding/binary" "errors" "fmt" "net" "time" "github.com/coder/websocket/internal/errd" ) // StatusCode represents a WebSocket status code. // https://tools.ietf.org/html/rfc6455#section-7.4 type StatusCode int // https://www.iana.org/assignments/websocket/websocket.xhtml#close-code-number // // These are only the status codes defined by the protocol. // // You can define custom codes in the 3000-4999 range. // The 3000-3999 range is reserved for use by libraries, frameworks and applications. // The 4000-4999 range is reserved for private use. const ( StatusNormalClosure StatusCode = 1000 StatusGoingAway StatusCode = 1001 StatusProtocolError StatusCode = 1002 StatusUnsupportedData StatusCode = 1003 // 1004 is reserved and so unexported. statusReserved StatusCode = 1004 // StatusNoStatusRcvd cannot be sent in a close message. // It is reserved for when a close message is received without // a status code. StatusNoStatusRcvd StatusCode = 1005 // StatusAbnormalClosure is exported for use only with Wasm. // In non Wasm Go, the returned error will indicate whether the // connection was closed abnormally. 
StatusAbnormalClosure StatusCode = 1006 StatusInvalidFramePayloadData StatusCode = 1007 StatusPolicyViolation StatusCode = 1008 StatusMessageTooBig StatusCode = 1009 StatusMandatoryExtension StatusCode = 1010 StatusInternalError StatusCode = 1011 StatusServiceRestart StatusCode = 1012 StatusTryAgainLater StatusCode = 1013 StatusBadGateway StatusCode = 1014 // StatusTLSHandshake is only exported for use with Wasm. // In non Wasm Go, the returned error will indicate whether there was // a TLS handshake failure. StatusTLSHandshake StatusCode = 1015 ) // CloseError is returned when the connection is closed with a status and reason. // // Use Go 1.13's errors.As to check for this error. // Also see the CloseStatus helper. type CloseError struct { Code StatusCode Reason string } func (ce CloseError) Error() string { return fmt.Sprintf("status = %v and reason = %q", ce.Code, ce.Reason) } // CloseStatus is a convenience wrapper around Go 1.13's errors.As to grab // the status code from a CloseError. // // -1 will be returned if the passed error is nil or not a CloseError. func CloseStatus(err error) StatusCode { var ce CloseError if errors.As(err, &ce) { return ce.Code } return -1 } // Close performs the WebSocket close handshake with the given status code and reason. // // It will write a WebSocket close frame with a timeout of 5s and then wait 5s for // the peer to send a close frame. // All data messages received from the peer during the close handshake will be discarded. // // The connection can only be closed once. Additional calls to Close // are no-ops. // // The maximum length of reason must be 125 bytes. Avoid sending a dynamic reason. // // Close will unblock all goroutines interacting with the connection once // complete. 
func (c *Conn) Close(code StatusCode, reason string) (err error) { defer errd.Wrap(&err, "failed to close WebSocket") if !c.casClosing() { err = c.waitGoroutines() if err != nil { return err } return net.ErrClosed } defer func() { if errors.Is(err, net.ErrClosed) { err = nil } }() err = c.closeHandshake(code, reason) err2 := c.close() if err == nil && err2 != nil { err = err2 } err2 = c.waitGoroutines() if err == nil && err2 != nil { err = err2 } return err } // CloseNow closes the WebSocket connection without attempting a close handshake. // Use when you do not want the overhead of the close handshake. func (c *Conn) CloseNow() (err error) { defer errd.Wrap(&err, "failed to immediately close WebSocket") if !c.casClosing() { err = c.waitGoroutines() if err != nil { return err } return net.ErrClosed } defer func() { if errors.Is(err, net.ErrClosed) { err = nil } }() err = c.close() err2 := c.waitGoroutines() if err == nil && err2 != nil { err = err2 } return err } func (c *Conn) closeHandshake(code StatusCode, reason string) error { err := c.writeClose(code, reason) if err != nil { return err } err = c.waitCloseHandshake() if CloseStatus(err) != code { return err } return nil } func (c *Conn) writeClose(code StatusCode, reason string) error { ce := CloseError{ Code: code, Reason: reason, } var p []byte var err error if ce.Code != StatusNoStatusRcvd { p, err = ce.bytes() if err != nil { return err } } ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) defer cancel() err = c.writeControl(ctx, opClose, p) // If the connection closed as we're writing we ignore the error as we might // have written the close frame, the peer responded and then someone else read it // and closed the connection. 
if err != nil && !errors.Is(err, net.ErrClosed) { return err } return nil } func (c *Conn) waitCloseHandshake() error { ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) defer cancel() err := c.readMu.lock(ctx) if err != nil { return err } defer c.readMu.unlock() for i := int64(0); i < c.msgReader.payloadLength; i++ { _, err := c.br.ReadByte() if err != nil { return err } } for { h, err := c.readLoop(ctx) if err != nil { return err } for i := int64(0); i < h.payloadLength; i++ { _, err := c.br.ReadByte() if err != nil { return err } } } } func (c *Conn) waitGoroutines() error { t := time.NewTimer(time.Second * 15) defer t.Stop() select { case <-c.timeoutLoopDone: case <-t.C: return errors.New("failed to wait for timeoutLoop goroutine to exit") } c.closeReadMu.Lock() closeRead := c.closeReadCtx != nil c.closeReadMu.Unlock() if closeRead { select { case <-c.closeReadDone: case <-t.C: return errors.New("failed to wait for close read goroutine to exit") } } select { case <-c.closed: case <-t.C: return errors.New("failed to wait for connection to be closed") } return nil } func parseClosePayload(p []byte) (CloseError, error) { if len(p) == 0 { return CloseError{ Code: StatusNoStatusRcvd, }, nil } if len(p) < 2 { return CloseError{}, fmt.Errorf("close payload %q too small, cannot even contain the 2 byte status code", p) } ce := CloseError{ Code: StatusCode(binary.BigEndian.Uint16(p)), Reason: string(p[2:]), } if !validWireCloseCode(ce.Code) { return CloseError{}, fmt.Errorf("invalid status code %v", ce.Code) } return ce, nil } // See http://www.iana.org/assignments/websocket/websocket.xhtml#close-code-number // and https://tools.ietf.org/html/rfc6455#section-7.4.1 func validWireCloseCode(code StatusCode) bool { switch code { case statusReserved, StatusNoStatusRcvd, StatusAbnormalClosure, StatusTLSHandshake: return false } if code >= StatusNormalClosure && code <= StatusBadGateway { return true } if code >= 3000 && code <= 4999 { return true } return 
false } func (ce CloseError) bytes() ([]byte, error) { p, err := ce.bytesErr() if err != nil { err = fmt.Errorf("failed to marshal close frame: %w", err) ce = CloseError{ Code: StatusInternalError, } p, _ = ce.bytesErr() } return p, err } const maxCloseReason = maxControlPayload - 2 func (ce CloseError) bytesErr() ([]byte, error) { if len(ce.Reason) > maxCloseReason { return nil, fmt.Errorf("reason string max is %v but got %q with length %v", maxCloseReason, ce.Reason, len(ce.Reason)) } if !validWireCloseCode(ce.Code) { return nil, fmt.Errorf("status code %v cannot be set", ce.Code) } buf := make([]byte, 2+len(ce.Reason)) binary.BigEndian.PutUint16(buf, uint16(ce.Code)) copy(buf[2:], ce.Reason) return buf, nil } func (c *Conn) casClosing() bool { c.closeMu.Lock() defer c.closeMu.Unlock() if !c.closing { c.closing = true return true } return false } func (c *Conn) isClosed() bool { select { case <-c.closed: return true default: return false } } websocket-1.8.12/close_test.go000066400000000000000000000064111465546417300163050ustar00rootroot00000000000000//go:build !js // +build !js package websocket import ( "io" "math" "strings" "testing" "github.com/coder/websocket/internal/test/assert" ) func TestCloseError(t *testing.T) { t.Parallel() testCases := []struct { name string ce CloseError success bool }{ { name: "normal", ce: CloseError{ Code: StatusNormalClosure, Reason: strings.Repeat("x", maxCloseReason), }, success: true, }, { name: "bigReason", ce: CloseError{ Code: StatusNormalClosure, Reason: strings.Repeat("x", maxCloseReason+1), }, success: false, }, { name: "bigCode", ce: CloseError{ Code: math.MaxUint16, Reason: strings.Repeat("x", maxCloseReason), }, success: false, }, } for _, tc := range testCases { tc := tc t.Run(tc.name, func(t *testing.T) { t.Parallel() _, err := tc.ce.bytesErr() if tc.success { assert.Success(t, err) } else { assert.Error(t, err) } }) } t.Run("Error", func(t *testing.T) { exp := `status = StatusInternalError and reason = "meow"` act 
:= CloseError{ Code: StatusInternalError, Reason: "meow", }.Error() assert.Equal(t, "CloseError.Error()", exp, act) }) } func Test_parseClosePayload(t *testing.T) { t.Parallel() testCases := []struct { name string p []byte success bool ce CloseError }{ { name: "normal", p: append([]byte{0x3, 0xE8}, []byte("hello")...), success: true, ce: CloseError{ Code: StatusNormalClosure, Reason: "hello", }, }, { name: "nothing", success: true, ce: CloseError{ Code: StatusNoStatusRcvd, }, }, { name: "oneByte", p: []byte{0}, success: false, }, { name: "badStatusCode", p: []byte{0x17, 0x70}, success: false, }, } for _, tc := range testCases { tc := tc t.Run(tc.name, func(t *testing.T) { t.Parallel() ce, err := parseClosePayload(tc.p) if tc.success { assert.Success(t, err) assert.Equal(t, "close payload", tc.ce, ce) } else { assert.Error(t, err) } }) } } func Test_validWireCloseCode(t *testing.T) { t.Parallel() testCases := []struct { name string code StatusCode valid bool }{ { name: "normal", code: StatusNormalClosure, valid: true, }, { name: "noStatus", code: StatusNoStatusRcvd, valid: false, }, { name: "3000", code: 3000, valid: true, }, { name: "4999", code: 4999, valid: true, }, { name: "unknown", code: 5000, valid: false, }, } for _, tc := range testCases { tc := tc t.Run(tc.name, func(t *testing.T) { t.Parallel() act := validWireCloseCode(tc.code) assert.Equal(t, "wire close code", tc.valid, act) }) } } func TestCloseStatus(t *testing.T) { t.Parallel() testCases := []struct { name string in error exp StatusCode }{ { name: "nil", in: nil, exp: -1, }, { name: "io.EOF", in: io.EOF, exp: -1, }, { name: "StatusInternalError", in: CloseError{ Code: StatusInternalError, }, exp: StatusInternalError, }, } for _, tc := range testCases { tc := tc t.Run(tc.name, func(t *testing.T) { t.Parallel() act := CloseStatus(tc.in) assert.Equal(t, "close status", tc.exp, act) }) } } 
websocket-1.8.12/compress.go000066400000000000000000000133271465546417300160000ustar00rootroot00000000000000//go:build !js // +build !js package websocket import ( "compress/flate" "io" "sync" ) // CompressionMode represents the modes available to the permessage-deflate extension. // See https://tools.ietf.org/html/rfc7692 // // Works in all modern browsers except Safari which does not implement the permessage-deflate extension. // // Compression is only used if the peer supports the mode selected. type CompressionMode int const ( // CompressionDisabled disables the negotiation of the permessage-deflate extension. // // This is the default. Do not enable compression without benchmarking for your particular use case first. CompressionDisabled CompressionMode = iota // CompressionContextTakeover compresses each message greater than 128 bytes reusing the 32 KB sliding window from // previous messages. i.e compression context across messages is preserved. // // As most WebSocket protocols are text based and repetitive, this compression mode can be very efficient. // // The memory overhead is a fixed 32 KB sliding window, a fixed 1.2 MB flate.Writer and a sync.Pool of 40 KB flate.Reader's // that are used when reading and then returned. // // Thus, it uses more memory than CompressionNoContextTakeover but compresses more efficiently. // // If the peer does not support CompressionContextTakeover then we will fall back to CompressionNoContextTakeover. CompressionContextTakeover // CompressionNoContextTakeover compresses each message greater than 512 bytes. Each message is compressed with // a new 1.2 MB flate.Writer pulled from a sync.Pool. Each message is read with a 40 KB flate.Reader pulled from // a sync.Pool. // // This means less efficient compression as the sliding window from previous messages will not be used but the // memory overhead will be lower as there will be no fixed cost for the flate.Writer nor the 32 KB sliding window. 
// Especially if the connections are long lived and seldom written to. // // Thus, it uses less memory than CompressionContextTakeover but compresses less efficiently. // // If the peer does not support CompressionNoContextTakeover then we will fall back to CompressionDisabled. CompressionNoContextTakeover ) func (m CompressionMode) opts() *compressionOptions { return &compressionOptions{ clientNoContextTakeover: m == CompressionNoContextTakeover, serverNoContextTakeover: m == CompressionNoContextTakeover, } } type compressionOptions struct { clientNoContextTakeover bool serverNoContextTakeover bool } func (copts *compressionOptions) String() string { s := "permessage-deflate" if copts.clientNoContextTakeover { s += "; client_no_context_takeover" } if copts.serverNoContextTakeover { s += "; server_no_context_takeover" } return s } // These bytes are required to get flate.Reader to return. // They are removed when sending to avoid the overhead as // WebSocket framing tell's when the message has ended but then // we need to add them back otherwise flate.Reader keeps // trying to read more bytes. const deflateMessageTail = "\x00\x00\xff\xff" type trimLastFourBytesWriter struct { w io.Writer tail []byte } func (tw *trimLastFourBytesWriter) reset() { if tw != nil && tw.tail != nil { tw.tail = tw.tail[:0] } } func (tw *trimLastFourBytesWriter) Write(p []byte) (int, error) { if tw.tail == nil { tw.tail = make([]byte, 0, 4) } extra := len(tw.tail) + len(p) - 4 if extra <= 0 { tw.tail = append(tw.tail, p...) return len(p), nil } // Now we need to write as many extra bytes as we can from the previous tail. if extra > len(tw.tail) { extra = len(tw.tail) } if extra > 0 { _, err := tw.w.Write(tw.tail[:extra]) if err != nil { return 0, err } // Shift remaining bytes in tail over. n := copy(tw.tail, tw.tail[extra:]) tw.tail = tw.tail[:n] } // If p is less than or equal to 4 bytes, // all of it is is part of the tail. if len(p) <= 4 { tw.tail = append(tw.tail, p...) 
return len(p), nil } // Otherwise, only the last 4 bytes are. tw.tail = append(tw.tail, p[len(p)-4:]...) p = p[:len(p)-4] n, err := tw.w.Write(p) return n + 4, err } var flateReaderPool sync.Pool func getFlateReader(r io.Reader, dict []byte) io.Reader { fr, ok := flateReaderPool.Get().(io.Reader) if !ok { return flate.NewReaderDict(r, dict) } fr.(flate.Resetter).Reset(r, dict) return fr } func putFlateReader(fr io.Reader) { flateReaderPool.Put(fr) } var flateWriterPool sync.Pool func getFlateWriter(w io.Writer) *flate.Writer { fw, ok := flateWriterPool.Get().(*flate.Writer) if !ok { fw, _ = flate.NewWriter(w, flate.BestSpeed) return fw } fw.Reset(w) return fw } func putFlateWriter(w *flate.Writer) { flateWriterPool.Put(w) } type slidingWindow struct { buf []byte } var swPoolMu sync.RWMutex var swPool = map[int]*sync.Pool{} func slidingWindowPool(n int) *sync.Pool { swPoolMu.RLock() p, ok := swPool[n] swPoolMu.RUnlock() if ok { return p } p = &sync.Pool{} swPoolMu.Lock() swPool[n] = p swPoolMu.Unlock() return p } func (sw *slidingWindow) init(n int) { if sw.buf != nil { return } if n == 0 { n = 32768 } p := slidingWindowPool(n) sw2, ok := p.Get().(*slidingWindow) if ok { *sw = *sw2 } else { sw.buf = make([]byte, 0, n) } } func (sw *slidingWindow) close() { sw.buf = sw.buf[:0] swPoolMu.Lock() swPool[cap(sw.buf)].Put(sw) swPoolMu.Unlock() } func (sw *slidingWindow) write(p []byte) { if len(p) >= cap(sw.buf) { sw.buf = sw.buf[:cap(sw.buf)] p = p[len(p)-cap(sw.buf):] copy(sw.buf, p) return } left := cap(sw.buf) - len(sw.buf) if left < len(p) { // We need to shift spaceNeeded bytes from the end to make room for p at the end. spaceNeeded := len(p) - left copy(sw.buf, sw.buf[spaceNeeded:]) sw.buf = sw.buf[:len(sw.buf)-spaceNeeded] } sw.buf = append(sw.buf, p...) 
} websocket-1.8.12/compress_test.go000066400000000000000000000024111465546417300170270ustar00rootroot00000000000000//go:build !js // +build !js package websocket import ( "bytes" "compress/flate" "io" "strings" "testing" "github.com/coder/websocket/internal/test/assert" "github.com/coder/websocket/internal/test/xrand" ) func Test_slidingWindow(t *testing.T) { t.Parallel() const testCount = 99 const maxWindow = 99999 for i := 0; i < testCount; i++ { t.Run("", func(t *testing.T) { t.Parallel() input := xrand.String(maxWindow) windowLength := xrand.Int(maxWindow) var sw slidingWindow sw.init(windowLength) sw.write([]byte(input)) assert.Equal(t, "window length", windowLength, cap(sw.buf)) if !strings.HasSuffix(input, string(sw.buf)) { t.Fatalf("r.buf is not a suffix of input: %q and %q", input, sw.buf) } }) } } func BenchmarkFlateWriter(b *testing.B) { b.ReportAllocs() for i := 0; i < b.N; i++ { w, _ := flate.NewWriter(io.Discard, flate.BestSpeed) // We have to write a byte to get the writer to allocate to its full extent. w.Write([]byte{'a'}) w.Flush() } } func BenchmarkFlateReader(b *testing.B) { b.ReportAllocs() var buf bytes.Buffer w, _ := flate.NewWriter(&buf, flate.BestSpeed) w.Write([]byte{'a'}) w.Flush() for i := 0; i < b.N; i++ { r := flate.NewReader(bytes.NewReader(buf.Bytes())) io.ReadAll(r) } } websocket-1.8.12/conn.go000066400000000000000000000135351465546417300151030ustar00rootroot00000000000000//go:build !js // +build !js package websocket import ( "bufio" "context" "fmt" "io" "net" "runtime" "strconv" "sync" "sync/atomic" ) // MessageType represents the type of a WebSocket message. // See https://tools.ietf.org/html/rfc6455#section-5.6 type MessageType int // MessageType constants. const ( // MessageText is for UTF-8 encoded text messages like JSON. MessageText MessageType = iota + 1 // MessageBinary is for binary messages like protobufs. MessageBinary ) // Conn represents a WebSocket connection. 
// All methods may be called concurrently except for Reader and Read. // // You must always read from the connection. Otherwise control // frames will not be handled. See Reader and CloseRead. // // Be sure to call Close on the connection when you // are finished with it to release associated resources. // // On any error from any method, the connection is closed // with an appropriate reason. // // This applies to context expirations as well unfortunately. // See https://github.com/nhooyr/websocket/issues/242#issuecomment-633182220 type Conn struct { noCopy noCopy subprotocol string rwc io.ReadWriteCloser client bool copts *compressionOptions flateThreshold int br *bufio.Reader bw *bufio.Writer readTimeout chan context.Context writeTimeout chan context.Context timeoutLoopDone chan struct{} // Read state. readMu *mu readHeaderBuf [8]byte readControlBuf [maxControlPayload]byte msgReader *msgReader // Write state. msgWriter *msgWriter writeFrameMu *mu writeBuf []byte writeHeaderBuf [8]byte writeHeader header closeReadMu sync.Mutex closeReadCtx context.Context closeReadDone chan struct{} closed chan struct{} closeMu sync.Mutex closing bool pingCounter int32 activePingsMu sync.Mutex activePings map[string]chan<- struct{} } type connConfig struct { subprotocol string rwc io.ReadWriteCloser client bool copts *compressionOptions flateThreshold int br *bufio.Reader bw *bufio.Writer } func newConn(cfg connConfig) *Conn { c := &Conn{ subprotocol: cfg.subprotocol, rwc: cfg.rwc, client: cfg.client, copts: cfg.copts, flateThreshold: cfg.flateThreshold, br: cfg.br, bw: cfg.bw, readTimeout: make(chan context.Context), writeTimeout: make(chan context.Context), timeoutLoopDone: make(chan struct{}), closed: make(chan struct{}), activePings: make(map[string]chan<- struct{}), } c.readMu = newMu(c) c.writeFrameMu = newMu(c) c.msgReader = newMsgReader(c) c.msgWriter = newMsgWriter(c) if c.client { c.writeBuf = extractBufioWriterBuf(c.bw, c.rwc) } if c.flate() && c.flateThreshold == 0 { 
c.flateThreshold = 128 if !c.msgWriter.flateContextTakeover() { c.flateThreshold = 512 } } runtime.SetFinalizer(c, func(c *Conn) { c.close() }) go c.timeoutLoop() return c } // Subprotocol returns the negotiated subprotocol. // An empty string means the default protocol. func (c *Conn) Subprotocol() string { return c.subprotocol } func (c *Conn) close() error { c.closeMu.Lock() defer c.closeMu.Unlock() if c.isClosed() { return net.ErrClosed } runtime.SetFinalizer(c, nil) close(c.closed) // Have to close after c.closed is closed to ensure any goroutine that wakes up // from the connection being closed also sees that c.closed is closed and returns // closeErr. err := c.rwc.Close() // With the close of rwc, these become safe to close. c.msgWriter.close() c.msgReader.close() return err } func (c *Conn) timeoutLoop() { defer close(c.timeoutLoopDone) readCtx := context.Background() writeCtx := context.Background() for { select { case <-c.closed: return case writeCtx = <-c.writeTimeout: case readCtx = <-c.readTimeout: case <-readCtx.Done(): c.close() return case <-writeCtx.Done(): c.close() return } } } func (c *Conn) flate() bool { return c.copts != nil } // Ping sends a ping to the peer and waits for a pong. // Use this to measure latency or ensure the peer is responsive. // Ping must be called concurrently with Reader as it does // not read from the connection but instead waits for a Reader call // to read the pong. // // TCP Keepalives should suffice for most use cases. 
func (c *Conn) Ping(ctx context.Context) error { p := atomic.AddInt32(&c.pingCounter, 1) err := c.ping(ctx, strconv.Itoa(int(p))) if err != nil { return fmt.Errorf("failed to ping: %w", err) } return nil } func (c *Conn) ping(ctx context.Context, p string) error { pong := make(chan struct{}, 1) c.activePingsMu.Lock() c.activePings[p] = pong c.activePingsMu.Unlock() defer func() { c.activePingsMu.Lock() delete(c.activePings, p) c.activePingsMu.Unlock() }() err := c.writeControl(ctx, opPing, []byte(p)) if err != nil { return err } select { case <-c.closed: return net.ErrClosed case <-ctx.Done(): return fmt.Errorf("failed to wait for pong: %w", ctx.Err()) case <-pong: return nil } } type mu struct { c *Conn ch chan struct{} } func newMu(c *Conn) *mu { return &mu{ c: c, ch: make(chan struct{}, 1), } } func (m *mu) forceLock() { m.ch <- struct{}{} } func (m *mu) tryLock() bool { select { case m.ch <- struct{}{}: return true default: return false } } func (m *mu) lock(ctx context.Context) error { select { case <-m.c.closed: return net.ErrClosed case <-ctx.Done(): return fmt.Errorf("failed to acquire lock: %w", ctx.Err()) case m.ch <- struct{}{}: // To make sure the connection is certainly alive. // As it's possible the send on m.ch was selected // over the receive on closed. select { case <-m.c.closed: // Make sure to release. 
m.unlock() return net.ErrClosed default: } return nil } } func (m *mu) unlock() { select { case <-m.ch: default: } } type noCopy struct{} func (*noCopy) Lock() {} websocket-1.8.12/conn_test.go000066400000000000000000000331321465546417300161350ustar00rootroot00000000000000//go:build !js package websocket_test import ( "bytes" "context" "errors" "fmt" "io" "net/http" "net/http/httptest" "os" "os/exec" "strings" "testing" "time" "github.com/coder/websocket" "github.com/coder/websocket/internal/errd" "github.com/coder/websocket/internal/test/assert" "github.com/coder/websocket/internal/test/wstest" "github.com/coder/websocket/internal/test/xrand" "github.com/coder/websocket/internal/xsync" "github.com/coder/websocket/wsjson" ) func TestConn(t *testing.T) { t.Parallel() t.Run("fuzzData", func(t *testing.T) { t.Parallel() compressionMode := func() websocket.CompressionMode { return websocket.CompressionMode(xrand.Int(int(websocket.CompressionContextTakeover) + 1)) } for i := 0; i < 5; i++ { t.Run("", func(t *testing.T) { tt, c1, c2 := newConnTest(t, &websocket.DialOptions{ CompressionMode: compressionMode(), CompressionThreshold: xrand.Int(9999), }, &websocket.AcceptOptions{ CompressionMode: compressionMode(), CompressionThreshold: xrand.Int(9999), }) tt.goEchoLoop(c2) c1.SetReadLimit(131072) for i := 0; i < 5; i++ { err := wstest.Echo(tt.ctx, c1, 131072) assert.Success(t, err) } err := c1.Close(websocket.StatusNormalClosure, "") assert.Success(t, err) }) } }) t.Run("badClose", func(t *testing.T) { tt, c1, c2 := newConnTest(t, nil, nil) c2.CloseRead(tt.ctx) err := c1.Close(-1, "") assert.Contains(t, err, "failed to marshal close frame: status code StatusCode(-1) cannot be set") }) t.Run("ping", func(t *testing.T) { tt, c1, c2 := newConnTest(t, nil, nil) c1.CloseRead(tt.ctx) c2.CloseRead(tt.ctx) for i := 0; i < 10; i++ { err := c1.Ping(tt.ctx) assert.Success(t, err) } err := c1.Close(websocket.StatusNormalClosure, "") assert.Success(t, err) }) t.Run("badPing", func(t 
*testing.T) { tt, c1, c2 := newConnTest(t, nil, nil) c2.CloseRead(tt.ctx) ctx, cancel := context.WithTimeout(tt.ctx, time.Millisecond*100) defer cancel() err := c1.Ping(ctx) assert.Contains(t, err, "failed to wait for pong") }) t.Run("concurrentWrite", func(t *testing.T) { tt, c1, c2 := newConnTest(t, nil, nil) tt.goDiscardLoop(c2) msg := xrand.Bytes(xrand.Int(9999)) const count = 100 errs := make(chan error, count) for i := 0; i < count; i++ { go func() { select { case errs <- c1.Write(tt.ctx, websocket.MessageBinary, msg): case <-tt.ctx.Done(): return } }() } for i := 0; i < count; i++ { select { case err := <-errs: assert.Success(t, err) case <-tt.ctx.Done(): t.Fatal(tt.ctx.Err()) } } err := c1.Close(websocket.StatusNormalClosure, "") assert.Success(t, err) }) t.Run("concurrentWriteError", func(t *testing.T) { tt, c1, _ := newConnTest(t, nil, nil) _, err := c1.Writer(tt.ctx, websocket.MessageText) assert.Success(t, err) ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*100) defer cancel() err = c1.Write(ctx, websocket.MessageText, []byte("x")) if !errors.Is(err, context.DeadlineExceeded) { t.Fatalf("unexpected error: %#v", err) } }) t.Run("netConn", func(t *testing.T) { tt, c1, c2 := newConnTest(t, nil, nil) n1 := websocket.NetConn(tt.ctx, c1, websocket.MessageBinary) n2 := websocket.NetConn(tt.ctx, c2, websocket.MessageBinary) // Does not give any confidence but at least ensures no crashes. 
d, _ := tt.ctx.Deadline() n1.SetDeadline(d) n1.SetDeadline(time.Time{}) assert.Equal(t, "remote addr", n1.RemoteAddr(), n1.LocalAddr()) assert.Equal(t, "remote addr string", "pipe", n1.RemoteAddr().String()) assert.Equal(t, "remote addr network", "pipe", n1.RemoteAddr().Network()) errs := xsync.Go(func() error { _, err := n2.Write([]byte("hello")) if err != nil { return err } return n2.Close() }) b, err := io.ReadAll(n1) assert.Success(t, err) _, err = n1.Read(nil) assert.Equal(t, "read error", err, io.EOF) select { case err := <-errs: assert.Success(t, err) case <-tt.ctx.Done(): t.Fatal(tt.ctx.Err()) } assert.Equal(t, "read msg", []byte("hello"), b) }) t.Run("netConn/BadMsg", func(t *testing.T) { tt, c1, c2 := newConnTest(t, nil, nil) n1 := websocket.NetConn(tt.ctx, c1, websocket.MessageBinary) n2 := websocket.NetConn(tt.ctx, c2, websocket.MessageText) c2.CloseRead(tt.ctx) errs := xsync.Go(func() error { _, err := n2.Write([]byte("hello")) return err }) _, err := io.ReadAll(n1) assert.Contains(t, err, `unexpected frame type read (expected MessageBinary): MessageText`) select { case err := <-errs: assert.Success(t, err) case <-tt.ctx.Done(): t.Fatal(tt.ctx.Err()) } }) t.Run("netConn/readLimit", func(t *testing.T) { tt, c1, c2 := newConnTest(t, nil, nil) n1 := websocket.NetConn(tt.ctx, c1, websocket.MessageBinary) n2 := websocket.NetConn(tt.ctx, c2, websocket.MessageBinary) s := strings.Repeat("papa", 1<<20) errs := xsync.Go(func() error { _, err := n2.Write([]byte(s)) if err != nil { return err } return n2.Close() }) b, err := io.ReadAll(n1) assert.Success(t, err) _, err = n1.Read(nil) assert.Equal(t, "read error", err, io.EOF) select { case err := <-errs: assert.Success(t, err) case <-tt.ctx.Done(): t.Fatal(tt.ctx.Err()) } assert.Equal(t, "read msg", s, string(b)) }) t.Run("netConn/pastDeadline", func(t *testing.T) { tt, c1, c2 := newConnTest(t, nil, nil) n1 := websocket.NetConn(tt.ctx, c1, websocket.MessageBinary) n2 := websocket.NetConn(tt.ctx, c2, 
websocket.MessageBinary) n1.SetDeadline(time.Now().Add(-time.Minute)) n2.SetDeadline(time.Now().Add(-time.Minute)) // No panic we're good. }) t.Run("wsjson", func(t *testing.T) { tt, c1, c2 := newConnTest(t, nil, nil) tt.goEchoLoop(c2) c1.SetReadLimit(1 << 30) exp := xrand.String(xrand.Int(131072)) werr := xsync.Go(func() error { return wsjson.Write(tt.ctx, c1, exp) }) var act interface{} err := wsjson.Read(tt.ctx, c1, &act) assert.Success(t, err) assert.Equal(t, "read msg", exp, act) select { case err := <-werr: assert.Success(t, err) case <-tt.ctx.Done(): t.Fatal(tt.ctx.Err()) } err = c1.Close(websocket.StatusNormalClosure, "") assert.Success(t, err) }) t.Run("HTTPClient.Timeout", func(t *testing.T) { tt, c1, c2 := newConnTest(t, &websocket.DialOptions{ HTTPClient: &http.Client{Timeout: time.Second * 5}, }, nil) tt.goEchoLoop(c2) c1.SetReadLimit(1 << 30) exp := xrand.String(xrand.Int(131072)) werr := xsync.Go(func() error { return wsjson.Write(tt.ctx, c1, exp) }) var act interface{} err := wsjson.Read(tt.ctx, c1, &act) assert.Success(t, err) assert.Equal(t, "read msg", exp, act) select { case err := <-werr: assert.Success(t, err) case <-tt.ctx.Done(): t.Fatal(tt.ctx.Err()) } err = c1.Close(websocket.StatusNormalClosure, "") assert.Success(t, err) }) t.Run("CloseNow", func(t *testing.T) { _, c1, c2 := newConnTest(t, nil, nil) err1 := c1.CloseNow() err2 := c2.CloseNow() assert.Success(t, err1) assert.Success(t, err2) err1 = c1.CloseNow() err2 = c2.CloseNow() assert.ErrorIs(t, websocket.ErrClosed, err1) assert.ErrorIs(t, websocket.ErrClosed, err2) }) t.Run("MidReadClose", func(t *testing.T) { tt, c1, c2 := newConnTest(t, nil, nil) tt.goEchoLoop(c2) c1.SetReadLimit(131072) for i := 0; i < 5; i++ { err := wstest.Echo(tt.ctx, c1, 131072) assert.Success(t, err) } err := wsjson.Write(tt.ctx, c1, "four") assert.Success(t, err) _, _, err = c1.Reader(tt.ctx) assert.Success(t, err) err = c1.Close(websocket.StatusNormalClosure, "") assert.Success(t, err) }) } func TestWasm(t 
*testing.T) { t.Parallel() if os.Getenv("CI") == "" { t.SkipNow() } s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { err := echoServer(w, r, &websocket.AcceptOptions{ Subprotocols: []string{"echo"}, InsecureSkipVerify: true, }) if err != nil { t.Error(err) } })) defer s.Close() ctx, cancel := context.WithTimeout(context.Background(), time.Minute) defer cancel() cmd := exec.CommandContext(ctx, "go", "test", "-exec=wasmbrowsertest", ".", "-v") cmd.Env = append(os.Environ(), "GOOS=js", "GOARCH=wasm", fmt.Sprintf("WS_ECHO_SERVER_URL=%v", s.URL)) b, err := cmd.CombinedOutput() if err != nil { t.Fatalf("wasm test binary failed: %v:\n%s", err, b) } } func assertCloseStatus(exp websocket.StatusCode, err error) error { if websocket.CloseStatus(err) == -1 { return fmt.Errorf("expected websocket.CloseError: %T %v", err, err) } if websocket.CloseStatus(err) != exp { return fmt.Errorf("expected close status %v but got %v", exp, err) } return nil } type connTest struct { t testing.TB ctx context.Context } func newConnTest(t testing.TB, dialOpts *websocket.DialOptions, acceptOpts *websocket.AcceptOptions) (tt *connTest, c1, c2 *websocket.Conn) { if t, ok := t.(*testing.T); ok { t.Parallel() } t.Helper() ctx, cancel := context.WithTimeout(context.Background(), time.Second*30) tt = &connTest{t: t, ctx: ctx} t.Cleanup(cancel) c1, c2 = wstest.Pipe(dialOpts, acceptOpts) if xrand.Bool() { c1, c2 = c2, c1 } t.Cleanup(func() { c2.CloseNow() c1.CloseNow() }) return tt, c1, c2 } func (tt *connTest) goEchoLoop(c *websocket.Conn) { ctx, cancel := context.WithCancel(tt.ctx) echoLoopErr := xsync.Go(func() error { err := wstest.EchoLoop(ctx, c) return assertCloseStatus(websocket.StatusNormalClosure, err) }) tt.t.Cleanup(func() { cancel() err := <-echoLoopErr if err != nil { tt.t.Errorf("echo loop error: %v", err) } }) } func (tt *connTest) goDiscardLoop(c *websocket.Conn) { ctx, cancel := context.WithCancel(tt.ctx) discardLoopErr := xsync.Go(func() error { 
defer c.Close(websocket.StatusInternalError, "") for { _, _, err := c.Read(ctx) if err != nil { return assertCloseStatus(websocket.StatusNormalClosure, err) } } }) tt.t.Cleanup(func() { cancel() err := <-discardLoopErr if err != nil { tt.t.Errorf("discard loop error: %v", err) } }) } func BenchmarkConn(b *testing.B) { var benchCases = []struct { name string mode websocket.CompressionMode }{ { name: "disabledCompress", mode: websocket.CompressionDisabled, }, { name: "compressContextTakeover", mode: websocket.CompressionContextTakeover, }, { name: "compressNoContext", mode: websocket.CompressionNoContextTakeover, }, } for _, bc := range benchCases { b.Run(bc.name, func(b *testing.B) { bb, c1, c2 := newConnTest(b, &websocket.DialOptions{ CompressionMode: bc.mode, }, &websocket.AcceptOptions{ CompressionMode: bc.mode, }) bb.goEchoLoop(c2) bytesWritten := c1.RecordBytesWritten() bytesRead := c1.RecordBytesRead() msg := []byte(strings.Repeat("1234", 128)) readBuf := make([]byte, len(msg)) writes := make(chan struct{}) defer close(writes) werrs := make(chan error) go func() { for range writes { select { case werrs <- c1.Write(bb.ctx, websocket.MessageText, msg): case <-bb.ctx.Done(): return } } }() b.SetBytes(int64(len(msg))) b.ReportAllocs() b.ResetTimer() for i := 0; i < b.N; i++ { select { case writes <- struct{}{}: case <-bb.ctx.Done(): b.Fatal(bb.ctx.Err()) } typ, r, err := c1.Reader(bb.ctx) if err != nil { b.Fatal(i, err) } if websocket.MessageText != typ { assert.Equal(b, "data type", websocket.MessageText, typ) } _, err = io.ReadFull(r, readBuf) if err != nil { b.Fatal(err) } n2, err := r.Read(readBuf) if err != io.EOF { assert.Equal(b, "read err", io.EOF, err) } if n2 != 0 { assert.Equal(b, "n2", 0, n2) } if !bytes.Equal(msg, readBuf) { assert.Equal(b, "msg", msg, readBuf) } select { case err = <-werrs: case <-bb.ctx.Done(): b.Fatal(bb.ctx.Err()) } if err != nil { b.Fatal(err) } } b.StopTimer() b.ReportMetric(float64(*bytesWritten/b.N), "written/op") 
b.ReportMetric(float64(*bytesRead/b.N), "read/op") err := c1.Close(websocket.StatusNormalClosure, "") assert.Success(b, err) }) } } func echoServer(w http.ResponseWriter, r *http.Request, opts *websocket.AcceptOptions) (err error) { defer errd.Wrap(&err, "echo server failed") c, err := websocket.Accept(w, r, opts) if err != nil { return err } defer c.Close(websocket.StatusInternalError, "") err = wstest.EchoLoop(r.Context(), c) return assertCloseStatus(websocket.StatusNormalClosure, err) } func assertEcho(tb testing.TB, ctx context.Context, c *websocket.Conn) { exp := xrand.String(xrand.Int(131072)) werr := xsync.Go(func() error { return wsjson.Write(ctx, c, exp) }) var act interface{} c.SetReadLimit(1 << 30) err := wsjson.Read(ctx, c, &act) assert.Success(tb, err) assert.Equal(tb, "read msg", exp, act) select { case err := <-werr: assert.Success(tb, err) case <-ctx.Done(): tb.Fatal(ctx.Err()) } } func assertClose(tb testing.TB, c *websocket.Conn) { tb.Helper() err := c.Close(websocket.StatusNormalClosure, "") assert.Success(tb, err) } func TestConcurrentClosePing(t *testing.T) { t.Parallel() for i := 0; i < 64; i++ { func() { c1, c2 := wstest.Pipe(nil, nil) defer c1.CloseNow() defer c2.CloseNow() c1.CloseRead(context.Background()) c2.CloseRead(context.Background()) errc := xsync.Go(func() error { for range time.Tick(time.Millisecond) { err := c1.Ping(context.Background()) if err != nil { return err } } panic("unreachable") }) time.Sleep(10 * time.Millisecond) assert.Success(t, c1.Close(websocket.StatusNormalClosure, "")) <-errc }() } } websocket-1.8.12/dial.go000066400000000000000000000213751465546417300150600ustar00rootroot00000000000000//go:build !js // +build !js package websocket import ( "bufio" "bytes" "context" "crypto/rand" "encoding/base64" "fmt" "io" "net/http" "net/url" "strings" "sync" "time" "github.com/coder/websocket/internal/errd" ) // DialOptions represents Dial's options. type DialOptions struct { // HTTPClient is used for the connection. 
// Its Transport must return writable bodies for WebSocket handshakes. // http.Transport does beginning with Go 1.12. HTTPClient *http.Client // HTTPHeader specifies the HTTP headers included in the handshake request. HTTPHeader http.Header // Host optionally overrides the Host HTTP header to send. If empty, the value // of URL.Host will be used. Host string // Subprotocols lists the WebSocket subprotocols to negotiate with the server. Subprotocols []string // CompressionMode controls the compression mode. // Defaults to CompressionDisabled. // // See docs on CompressionMode for details. CompressionMode CompressionMode // CompressionThreshold controls the minimum size of a message before compression is applied. // // Defaults to 512 bytes for CompressionNoContextTakeover and 128 bytes // for CompressionContextTakeover. CompressionThreshold int } func (opts *DialOptions) cloneWithDefaults(ctx context.Context) (context.Context, context.CancelFunc, *DialOptions) { var cancel context.CancelFunc var o DialOptions if opts != nil { o = *opts } if o.HTTPClient == nil { o.HTTPClient = http.DefaultClient } if o.HTTPClient.Timeout > 0 { ctx, cancel = context.WithTimeout(ctx, o.HTTPClient.Timeout) newClient := *o.HTTPClient newClient.Timeout = 0 o.HTTPClient = &newClient } if o.HTTPHeader == nil { o.HTTPHeader = http.Header{} } newClient := *o.HTTPClient oldCheckRedirect := o.HTTPClient.CheckRedirect newClient.CheckRedirect = func(req *http.Request, via []*http.Request) error { switch req.URL.Scheme { case "ws": req.URL.Scheme = "http" case "wss": req.URL.Scheme = "https" } if oldCheckRedirect != nil { return oldCheckRedirect(req, via) } return nil } o.HTTPClient = &newClient return ctx, cancel, &o } // Dial performs a WebSocket handshake on url. // // The response is the WebSocket handshake response from the server. // You never need to close resp.Body yourself. // // If an error occurs, the returned response may be non nil. 
// However, you can only read the first 1024 bytes of the body. // // This function requires at least Go 1.12 as it uses a new feature // in net/http to perform WebSocket handshakes. // See docs on the HTTPClient option and https://github.com/golang/go/issues/26937#issuecomment-415855861 // // URLs with http/https schemes will work and are interpreted as ws/wss. func Dial(ctx context.Context, u string, opts *DialOptions) (*Conn, *http.Response, error) { return dial(ctx, u, opts, nil) } func dial(ctx context.Context, urls string, opts *DialOptions, rand io.Reader) (_ *Conn, _ *http.Response, err error) { defer errd.Wrap(&err, "failed to WebSocket dial") var cancel context.CancelFunc ctx, cancel, opts = opts.cloneWithDefaults(ctx) if cancel != nil { defer cancel() } secWebSocketKey, err := secWebSocketKey(rand) if err != nil { return nil, nil, fmt.Errorf("failed to generate Sec-WebSocket-Key: %w", err) } var copts *compressionOptions if opts.CompressionMode != CompressionDisabled { copts = opts.CompressionMode.opts() } resp, err := handshakeRequest(ctx, urls, opts, copts, secWebSocketKey) if err != nil { return nil, resp, err } respBody := resp.Body resp.Body = nil defer func() { if err != nil { // We read a bit of the body for easier debugging. 
r := io.LimitReader(respBody, 1024) timer := time.AfterFunc(time.Second*3, func() { respBody.Close() }) defer timer.Stop() b, _ := io.ReadAll(r) respBody.Close() resp.Body = io.NopCloser(bytes.NewReader(b)) } }() copts, err = verifyServerResponse(opts, copts, secWebSocketKey, resp) if err != nil { return nil, resp, err } rwc, ok := respBody.(io.ReadWriteCloser) if !ok { return nil, resp, fmt.Errorf("response body is not a io.ReadWriteCloser: %T", respBody) } return newConn(connConfig{ subprotocol: resp.Header.Get("Sec-WebSocket-Protocol"), rwc: rwc, client: true, copts: copts, flateThreshold: opts.CompressionThreshold, br: getBufioReader(rwc), bw: getBufioWriter(rwc), }), resp, nil } func handshakeRequest(ctx context.Context, urls string, opts *DialOptions, copts *compressionOptions, secWebSocketKey string) (*http.Response, error) { u, err := url.Parse(urls) if err != nil { return nil, fmt.Errorf("failed to parse url: %w", err) } switch u.Scheme { case "ws": u.Scheme = "http" case "wss": u.Scheme = "https" case "http", "https": default: return nil, fmt.Errorf("unexpected url scheme: %q", u.Scheme) } req, err := http.NewRequestWithContext(ctx, "GET", u.String(), nil) if err != nil { return nil, fmt.Errorf("failed to create new http request: %w", err) } if len(opts.Host) > 0 { req.Host = opts.Host } req.Header = opts.HTTPHeader.Clone() req.Header.Set("Connection", "Upgrade") req.Header.Set("Upgrade", "websocket") req.Header.Set("Sec-WebSocket-Version", "13") req.Header.Set("Sec-WebSocket-Key", secWebSocketKey) if len(opts.Subprotocols) > 0 { req.Header.Set("Sec-WebSocket-Protocol", strings.Join(opts.Subprotocols, ",")) } if copts != nil { req.Header.Set("Sec-WebSocket-Extensions", copts.String()) } resp, err := opts.HTTPClient.Do(req) if err != nil { return nil, fmt.Errorf("failed to send handshake request: %w", err) } return resp, nil } func secWebSocketKey(rr io.Reader) (string, error) { if rr == nil { rr = rand.Reader } b := make([]byte, 16) _, err := 
io.ReadFull(rr, b) if err != nil { return "", fmt.Errorf("failed to read random data from rand.Reader: %w", err) } return base64.StdEncoding.EncodeToString(b), nil } func verifyServerResponse(opts *DialOptions, copts *compressionOptions, secWebSocketKey string, resp *http.Response) (*compressionOptions, error) { if resp.StatusCode != http.StatusSwitchingProtocols { return nil, fmt.Errorf("expected handshake response status code %v but got %v", http.StatusSwitchingProtocols, resp.StatusCode) } if !headerContainsTokenIgnoreCase(resp.Header, "Connection", "Upgrade") { return nil, fmt.Errorf("WebSocket protocol violation: Connection header %q does not contain Upgrade", resp.Header.Get("Connection")) } if !headerContainsTokenIgnoreCase(resp.Header, "Upgrade", "WebSocket") { return nil, fmt.Errorf("WebSocket protocol violation: Upgrade header %q does not contain websocket", resp.Header.Get("Upgrade")) } if resp.Header.Get("Sec-WebSocket-Accept") != secWebSocketAccept(secWebSocketKey) { return nil, fmt.Errorf("WebSocket protocol violation: invalid Sec-WebSocket-Accept %q, key %q", resp.Header.Get("Sec-WebSocket-Accept"), secWebSocketKey, ) } err := verifySubprotocol(opts.Subprotocols, resp) if err != nil { return nil, err } return verifyServerExtensions(copts, resp.Header) } func verifySubprotocol(subprotos []string, resp *http.Response) error { proto := resp.Header.Get("Sec-WebSocket-Protocol") if proto == "" { return nil } for _, sp2 := range subprotos { if strings.EqualFold(sp2, proto) { return nil } } return fmt.Errorf("WebSocket protocol violation: unexpected Sec-WebSocket-Protocol from server: %q", proto) } func verifyServerExtensions(copts *compressionOptions, h http.Header) (*compressionOptions, error) { exts := websocketExtensions(h) if len(exts) == 0 { return nil, nil } ext := exts[0] if ext.name != "permessage-deflate" || len(exts) > 1 || copts == nil { return nil, fmt.Errorf("WebSocket protcol violation: unsupported extensions from server: %+v", exts[1:]) } 
_copts := *copts copts = &_copts for _, p := range ext.params { switch p { case "client_no_context_takeover": copts.clientNoContextTakeover = true continue case "server_no_context_takeover": copts.serverNoContextTakeover = true continue } if strings.HasPrefix(p, "server_max_window_bits=") { // We can't adjust the deflate window, but decoding with a larger window is acceptable. continue } return nil, fmt.Errorf("unsupported permessage-deflate parameter: %q", p) } return copts, nil } var bufioReaderPool sync.Pool func getBufioReader(r io.Reader) *bufio.Reader { br, ok := bufioReaderPool.Get().(*bufio.Reader) if !ok { return bufio.NewReader(r) } br.Reset(r) return br } func putBufioReader(br *bufio.Reader) { bufioReaderPool.Put(br) } var bufioWriterPool sync.Pool func getBufioWriter(w io.Writer) *bufio.Writer { bw, ok := bufioWriterPool.Get().(*bufio.Writer) if !ok { return bufio.NewWriter(w) } bw.Reset(w) return bw } func putBufioWriter(bw *bufio.Writer) { bufioWriterPool.Put(bw) } websocket-1.8.12/dial_test.go000066400000000000000000000232501465546417300161110ustar00rootroot00000000000000//go:build !js // +build !js package websocket_test import ( "bytes" "context" "crypto/rand" "io" "net/http" "net/http/httptest" "net/url" "strings" "testing" "time" "github.com/coder/websocket" "github.com/coder/websocket/internal/test/assert" "github.com/coder/websocket/internal/util" "github.com/coder/websocket/internal/xsync" ) func TestBadDials(t *testing.T) { t.Parallel() t.Run("badReq", func(t *testing.T) { t.Parallel() testCases := []struct { name string url string opts *websocket.DialOptions rand util.ReaderFunc nilCtx bool }{ { name: "badURL", url: "://noscheme", }, { name: "badURLScheme", url: "ftp://nhooyr.io", }, { name: "badTLS", url: "wss://totallyfake.nhooyr.io", }, { name: "badReader", rand: func(p []byte) (int, error) { return 0, io.EOF }, }, { name: "nilContext", url: "http://localhost", nilCtx: true, }, } for _, tc := range testCases { tc := tc t.Run(tc.name, 
func(t *testing.T) { t.Parallel() var ctx context.Context var cancel func() if !tc.nilCtx { ctx, cancel = context.WithTimeout(context.Background(), time.Second*5) defer cancel() } if tc.rand == nil { tc.rand = rand.Reader.Read } _, _, err := websocket.ExportedDial(ctx, tc.url, tc.opts, tc.rand) assert.Error(t, err) }) } }) t.Run("badResponse", func(t *testing.T) { t.Parallel() ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) defer cancel() _, _, err := websocket.Dial(ctx, "ws://example.com", &websocket.DialOptions{ HTTPClient: mockHTTPClient(func(*http.Request) (*http.Response, error) { return &http.Response{ Body: io.NopCloser(strings.NewReader("hi")), }, nil }), }) assert.Contains(t, err, "failed to WebSocket dial: expected handshake response status code 101 but got 0") }) t.Run("badBody", func(t *testing.T) { t.Parallel() ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) defer cancel() rt := func(r *http.Request) (*http.Response, error) { h := http.Header{} h.Set("Connection", "Upgrade") h.Set("Upgrade", "websocket") h.Set("Sec-WebSocket-Accept", websocket.SecWebSocketAccept(r.Header.Get("Sec-WebSocket-Key"))) return &http.Response{ StatusCode: http.StatusSwitchingProtocols, Header: h, Body: io.NopCloser(strings.NewReader("hi")), }, nil } _, _, err := websocket.Dial(ctx, "ws://example.com", &websocket.DialOptions{ HTTPClient: mockHTTPClient(rt), }) assert.Contains(t, err, "response body is not a io.ReadWriteCloser") }) } func Test_verifyHostOverride(t *testing.T) { testCases := []struct { name string host string exp string }{ { name: "noOverride", host: "", exp: "example.com", }, { name: "hostOverride", host: "example.net", exp: "example.net", }, } for _, tc := range testCases { tc := tc t.Run(tc.name, func(t *testing.T) { t.Parallel() ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) defer cancel() rt := func(r *http.Request) (*http.Response, error) { assert.Equal(t, "Host", tc.exp, r.Host) h := 
http.Header{} h.Set("Connection", "Upgrade") h.Set("Upgrade", "websocket") h.Set("Sec-WebSocket-Accept", websocket.SecWebSocketAccept(r.Header.Get("Sec-WebSocket-Key"))) return &http.Response{ StatusCode: http.StatusSwitchingProtocols, Header: h, Body: mockBody{bytes.NewBufferString("hi")}, }, nil } c, _, err := websocket.Dial(ctx, "ws://example.com", &websocket.DialOptions{ HTTPClient: mockHTTPClient(rt), Host: tc.host, }) assert.Success(t, err) c.CloseNow() }) } } type mockBody struct { *bytes.Buffer } func (mb mockBody) Close() error { return nil } func Test_verifyServerHandshake(t *testing.T) { t.Parallel() testCases := []struct { name string response func(w http.ResponseWriter) success bool }{ { name: "badStatus", response: func(w http.ResponseWriter) { w.WriteHeader(http.StatusOK) }, success: false, }, { name: "badConnection", response: func(w http.ResponseWriter) { w.Header().Set("Connection", "???") w.WriteHeader(http.StatusSwitchingProtocols) }, success: false, }, { name: "badUpgrade", response: func(w http.ResponseWriter) { w.Header().Set("Connection", "Upgrade") w.Header().Set("Upgrade", "???") w.WriteHeader(http.StatusSwitchingProtocols) }, success: false, }, { name: "badSecWebSocketAccept", response: func(w http.ResponseWriter) { w.Header().Set("Connection", "Upgrade") w.Header().Set("Upgrade", "websocket") w.Header().Set("Sec-WebSocket-Accept", "xd") w.WriteHeader(http.StatusSwitchingProtocols) }, success: false, }, { name: "badSecWebSocketProtocol", response: func(w http.ResponseWriter) { w.Header().Set("Connection", "Upgrade") w.Header().Set("Upgrade", "websocket") w.Header().Set("Sec-WebSocket-Protocol", "xd") w.WriteHeader(http.StatusSwitchingProtocols) }, success: false, }, { name: "unsupportedExtension", response: func(w http.ResponseWriter) { w.Header().Set("Connection", "Upgrade") w.Header().Set("Upgrade", "websocket") w.Header().Set("Sec-WebSocket-Extensions", "meow") w.WriteHeader(http.StatusSwitchingProtocols) }, success: false, }, { name: 
"unsupportedDeflateParam", response: func(w http.ResponseWriter) { w.Header().Set("Connection", "Upgrade") w.Header().Set("Upgrade", "websocket") w.Header().Set("Sec-WebSocket-Extensions", "permessage-deflate; meow") w.WriteHeader(http.StatusSwitchingProtocols) }, success: false, }, { name: "success", response: func(w http.ResponseWriter) { w.Header().Set("Connection", "Upgrade") w.Header().Set("Upgrade", "websocket") w.WriteHeader(http.StatusSwitchingProtocols) }, success: true, }, } for _, tc := range testCases { tc := tc t.Run(tc.name, func(t *testing.T) { t.Parallel() w := httptest.NewRecorder() tc.response(w) resp := w.Result() r := httptest.NewRequest("GET", "/", nil) key, err := websocket.SecWebSocketKey(rand.Reader) assert.Success(t, err) r.Header.Set("Sec-WebSocket-Key", key) if resp.Header.Get("Sec-WebSocket-Accept") == "" { resp.Header.Set("Sec-WebSocket-Accept", websocket.SecWebSocketAccept(key)) } opts := &websocket.DialOptions{ Subprotocols: strings.Split(r.Header.Get("Sec-WebSocket-Protocol"), ","), } _, err = websocket.VerifyServerResponse(opts, websocket.CompressionModeOpts(opts.CompressionMode), key, resp) if tc.success { assert.Success(t, err) } else { assert.Error(t, err) } }) } } func mockHTTPClient(fn roundTripperFunc) *http.Client { return &http.Client{ Transport: fn, } } type roundTripperFunc func(*http.Request) (*http.Response, error) func (f roundTripperFunc) RoundTrip(r *http.Request) (*http.Response, error) { return f(r) } func TestDialRedirect(t *testing.T) { t.Parallel() ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) defer cancel() _, _, err := websocket.Dial(ctx, "ws://example.com", &websocket.DialOptions{ HTTPClient: mockHTTPClient(func(r *http.Request) (*http.Response, error) { resp := &http.Response{ Header: http.Header{}, } if r.URL.Scheme != "https" { resp.Header.Set("Location", "wss://example.com") resp.StatusCode = http.StatusFound return resp, nil } resp.Header.Set("Connection", "Upgrade") 
resp.Header.Set("Upgrade", "meow") resp.StatusCode = http.StatusSwitchingProtocols return resp, nil }), }) assert.Contains(t, err, "failed to WebSocket dial: WebSocket protocol violation: Upgrade header \"meow\" does not contain websocket") } type forwardProxy struct { hc *http.Client } func newForwardProxy() *forwardProxy { return &forwardProxy{ hc: &http.Client{}, } } func (fc *forwardProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { ctx, cancel := context.WithTimeout(r.Context(), time.Second*10) defer cancel() r = r.WithContext(ctx) r.RequestURI = "" resp, err := fc.hc.Do(r) if err != nil { http.Error(w, err.Error(), http.StatusBadRequest) return } defer resp.Body.Close() for k, v := range resp.Header { w.Header()[k] = v } w.Header().Set("PROXIED", "true") w.WriteHeader(resp.StatusCode) if resprw, ok := resp.Body.(io.ReadWriter); ok { c, brw, err := w.(http.Hijacker).Hijack() if err != nil { http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) return } brw.Flush() errc1 := xsync.Go(func() error { _, err := io.Copy(c, resprw) return err }) errc2 := xsync.Go(func() error { _, err := io.Copy(resprw, c) return err }) select { case <-errc1: case <-errc2: case <-r.Context().Done(): } } else { io.Copy(w, resp.Body) } } func TestDialViaProxy(t *testing.T) { t.Parallel() ps := httptest.NewServer(newForwardProxy()) defer ps.Close() s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { err := echoServer(w, r, nil) assert.Success(t, err) })) defer s.Close() psu, err := url.Parse(ps.URL) assert.Success(t, err) proxyTransport := http.DefaultTransport.(*http.Transport).Clone() proxyTransport.Proxy = http.ProxyURL(psu) ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) defer cancel() c, resp, err := websocket.Dial(ctx, s.URL, &websocket.DialOptions{ HTTPClient: &http.Client{ Transport: proxyTransport, }, }) assert.Success(t, err) assert.Equal(t, "", "true", 
resp.Header.Get("PROXIED")) assertEcho(t, ctx, c) assertClose(t, c) } websocket-1.8.12/doc.go000066400000000000000000000020371465546417300147060ustar00rootroot00000000000000//go:build !js // +build !js // Package websocket implements the RFC 6455 WebSocket protocol. // // https://tools.ietf.org/html/rfc6455 // // Use Dial to dial a WebSocket server. // // Use Accept to accept a WebSocket client. // // Conn represents the resulting WebSocket connection. // // The examples are the best way to understand how to correctly use the library. // // The wsjson subpackage contain helpers for JSON and protobuf messages. // // More documentation at https://github.com/coder/websocket. // // # Wasm // // The client side supports compiling to Wasm. // It wraps the WebSocket browser API. // // See https://developer.mozilla.org/en-US/docs/Web/API/WebSocket // // Some important caveats to be aware of: // // - Accept always errors out // - Conn.Ping is no-op // - Conn.CloseNow is Close(StatusGoingAway, "") // - HTTPClient, HTTPHeader and CompressionMode in DialOptions are no-op // - *http.Response from Dial is &http.Response{} with a 101 status code on success package websocket // import "github.com/coder/websocket" websocket-1.8.12/example_test.go000066400000000000000000000075241465546417300166410ustar00rootroot00000000000000package websocket_test import ( "context" "log" "net/http" "time" "github.com/coder/websocket" "github.com/coder/websocket/wsjson" ) func ExampleAccept() { // This handler accepts a WebSocket connection, reads a single JSON // message from the client and then closes the connection. 
fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { c, err := websocket.Accept(w, r, nil) if err != nil { log.Println(err) return } defer c.CloseNow() ctx, cancel := context.WithTimeout(r.Context(), time.Second*10) defer cancel() var v interface{} err = wsjson.Read(ctx, c, &v) if err != nil { log.Println(err) return } c.Close(websocket.StatusNormalClosure, "") }) err := http.ListenAndServe("localhost:8080", fn) log.Fatal(err) } func ExampleDial() { // Dials a server, writes a single JSON message and then // closes the connection. ctx, cancel := context.WithTimeout(context.Background(), time.Minute) defer cancel() c, _, err := websocket.Dial(ctx, "ws://localhost:8080", nil) if err != nil { log.Fatal(err) } defer c.CloseNow() err = wsjson.Write(ctx, c, "hi") if err != nil { log.Fatal(err) } c.Close(websocket.StatusNormalClosure, "") } func ExampleCloseStatus() { // Dials a server and then expects to be disconnected with status code // websocket.StatusNormalClosure. ctx, cancel := context.WithTimeout(context.Background(), time.Minute) defer cancel() c, _, err := websocket.Dial(ctx, "ws://localhost:8080", nil) if err != nil { log.Fatal(err) } defer c.CloseNow() _, _, err = c.Reader(ctx) if websocket.CloseStatus(err) != websocket.StatusNormalClosure { log.Fatalf("expected to be disconnected with StatusNormalClosure but got: %v", err) } } func Example_writeOnly() { // This handler demonstrates how to correctly handle a write only WebSocket connection. // i.e you only expect to write messages and do not expect to read any messages. 
fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { c, err := websocket.Accept(w, r, nil) if err != nil { log.Println(err) return } defer c.CloseNow() ctx, cancel := context.WithTimeout(r.Context(), time.Minute*10) defer cancel() ctx = c.CloseRead(ctx) t := time.NewTicker(time.Second * 30) defer t.Stop() for { select { case <-ctx.Done(): c.Close(websocket.StatusNormalClosure, "") return case <-t.C: err = wsjson.Write(ctx, c, "hi") if err != nil { log.Println(err) return } } } }) err := http.ListenAndServe("localhost:8080", fn) log.Fatal(err) } func Example_crossOrigin() { // This handler demonstrates how to safely accept cross origin WebSockets // from the origin example.com. fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { c, err := websocket.Accept(w, r, &websocket.AcceptOptions{ OriginPatterns: []string{"example.com"}, }) if err != nil { log.Println(err) return } c.Close(websocket.StatusNormalClosure, "cross origin WebSocket accepted") }) err := http.ListenAndServe("localhost:8080", fn) log.Fatal(err) } func ExampleConn_Ping() { // Dials a server and pings it 5 times. ctx, cancel := context.WithTimeout(context.Background(), time.Minute) defer cancel() c, _, err := websocket.Dial(ctx, "ws://localhost:8080", nil) if err != nil { log.Fatal(err) } defer c.CloseNow() // Required to read the Pongs from the server. ctx = c.CloseRead(ctx) for i := 0; i < 5; i++ { err = c.Ping(ctx) if err != nil { log.Fatal(err) } } c.Close(websocket.StatusNormalClosure, "") } // This example demonstrates full stack chat with an automated test. func Example_fullStackChat() { // https://github.com/nhooyr/websocket/tree/master/internal/examples/chat } // This example demonstrates a echo server. 
func Example_echo() {
	// https://github.com/coder/websocket/tree/master/internal/examples/echo
}
websocket-1.8.12/export_test.go000066400000000000000000000013571465546417300165250ustar00rootroot00000000000000//go:build !js
// +build !js

package websocket

import (
	"net"

	"github.com/coder/websocket/internal/util"
)

// RecordBytesWritten resets c.bw onto a writer that forwards every write to
// c.rwc and adds len(p) to a running counter. It returns a pointer to that
// counter so tests can inspect how many bytes reached the connection.
func (c *Conn) RecordBytesWritten() *int {
	var bytesWritten int
	c.bw.Reset(util.WriterFunc(func(p []byte) (int, error) {
		bytesWritten += len(p)
		return c.rwc.Write(p)
	}))
	return &bytesWritten
}

// RecordBytesRead resets c.br onto a reader that reads from c.rwc and adds the
// number of bytes actually read to a running counter. It returns a pointer to
// that counter so tests can inspect how many bytes came off the connection.
func (c *Conn) RecordBytesRead() *int {
	var bytesRead int
	c.br.Reset(util.ReaderFunc(func(p []byte) (int, error) {
		n, err := c.rwc.Read(p)
		bytesRead += n
		return n, err
	}))
	return &bytesRead
}

// The declarations below re-export unexported identifiers for use by tests.
var ErrClosed = net.ErrClosed

var ExportedDial = dial
var SecWebSocketAccept = secWebSocketAccept
var SecWebSocketKey = secWebSocketKey
var VerifyServerResponse = verifyServerResponse

var CompressionModeOpts = CompressionMode.opts
websocket-1.8.12/frame.go000066400000000000000000000063271465546417300152410ustar00rootroot00000000000000//go:build !js

package websocket

import (
	"bufio"
	"encoding/binary"
	"fmt"
	"io"
	"math"

	"github.com/coder/websocket/internal/errd"
)

// opcode represents a WebSocket opcode.
type opcode int

// https://tools.ietf.org/html/rfc6455#section-11.8.
const (
	opContinuation opcode = iota
	opText
	opBinary
	// 3 - 7 are reserved for further non-control frames.
	_
	_
	_
	_
	_
	opClose
	opPing
	opPong
	// 11-15 are reserved for further control frames
	// (the opcode field is 4 bits wide, so 15 is the largest value).
)

// header represents a WebSocket frame header.
// See https://tools.ietf.org/html/rfc6455#section-5.2.
type header struct {
	fin    bool
	rsv1   bool
	rsv2   bool
	rsv3   bool
	opcode opcode

	// payloadLength is the decoded payload length; it is never negative
	// once returned by readFrameHeader.
	payloadLength int64

	masked bool
	// maskKey holds the 4 masking-key bytes, decoded little-endian.
	maskKey uint32
}

// readFrameHeader reads a header from the reader.
// readBuf is scratch space and must be at least 8 bytes long, because the
// 64-bit extended payload length is read into it in full.
// See https://tools.ietf.org/html/rfc6455#section-5.2.
func readFrameHeader(r *bufio.Reader, readBuf []byte) (h header, err error) { defer errd.Wrap(&err, "failed to read frame header") b, err := r.ReadByte() if err != nil { return header{}, err } h.fin = b&(1<<7) != 0 h.rsv1 = b&(1<<6) != 0 h.rsv2 = b&(1<<5) != 0 h.rsv3 = b&(1<<4) != 0 h.opcode = opcode(b & 0xf) b, err = r.ReadByte() if err != nil { return header{}, err } h.masked = b&(1<<7) != 0 payloadLength := b &^ (1 << 7) switch { case payloadLength < 126: h.payloadLength = int64(payloadLength) case payloadLength == 126: _, err = io.ReadFull(r, readBuf[:2]) h.payloadLength = int64(binary.BigEndian.Uint16(readBuf)) case payloadLength == 127: _, err = io.ReadFull(r, readBuf) h.payloadLength = int64(binary.BigEndian.Uint64(readBuf)) } if err != nil { return header{}, err } if h.payloadLength < 0 { return header{}, fmt.Errorf("received negative payload length: %v", h.payloadLength) } if h.masked { _, err = io.ReadFull(r, readBuf[:4]) if err != nil { return header{}, err } h.maskKey = binary.LittleEndian.Uint32(readBuf) } return h, nil } // maxControlPayload is the maximum length of a control frame payload. // See https://tools.ietf.org/html/rfc6455#section-5.5. const maxControlPayload = 125 // writeFrameHeader writes the bytes of the header to w. 
// See https://tools.ietf.org/html/rfc6455#section-5.2 func writeFrameHeader(h header, w *bufio.Writer, buf []byte) (err error) { defer errd.Wrap(&err, "failed to write frame header") var b byte if h.fin { b |= 1 << 7 } if h.rsv1 { b |= 1 << 6 } if h.rsv2 { b |= 1 << 5 } if h.rsv3 { b |= 1 << 4 } b |= byte(h.opcode) err = w.WriteByte(b) if err != nil { return err } lengthByte := byte(0) if h.masked { lengthByte |= 1 << 7 } switch { case h.payloadLength > math.MaxUint16: lengthByte |= 127 case h.payloadLength > 125: lengthByte |= 126 case h.payloadLength >= 0: lengthByte |= byte(h.payloadLength) } err = w.WriteByte(lengthByte) if err != nil { return err } switch { case h.payloadLength > math.MaxUint16: binary.BigEndian.PutUint64(buf, uint64(h.payloadLength)) _, err = w.Write(buf) case h.payloadLength > 125: binary.BigEndian.PutUint16(buf, uint16(h.payloadLength)) _, err = w.Write(buf[:2]) } if err != nil { return err } if h.masked { binary.LittleEndian.PutUint32(buf, h.maskKey) _, err = w.Write(buf[:4]) if err != nil { return err } } return nil } websocket-1.8.12/frame_test.go000066400000000000000000000033541465546417300162750ustar00rootroot00000000000000//go:build !js // +build !js package websocket import ( "bufio" "bytes" "encoding/binary" "math/bits" "math/rand" "strconv" "testing" "time" "github.com/coder/websocket/internal/test/assert" ) func TestHeader(t *testing.T) { t.Parallel() t.Run("lengths", func(t *testing.T) { t.Parallel() lengths := []int{ 124, 125, 126, 127, 65534, 65535, 65536, 65537, } for _, n := range lengths { n := n t.Run(strconv.Itoa(n), func(t *testing.T) { t.Parallel() testHeader(t, header{ payloadLength: int64(n), }) }) } }) t.Run("fuzz", func(t *testing.T) { t.Parallel() r := rand.New(rand.NewSource(time.Now().UnixNano())) randBool := func() bool { return r.Intn(2) == 0 } for i := 0; i < 10000; i++ { h := header{ fin: randBool(), rsv1: randBool(), rsv2: randBool(), rsv3: randBool(), opcode: opcode(r.Intn(16)), masked: randBool(), 
payloadLength: r.Int63(), } if h.masked { h.maskKey = r.Uint32() } testHeader(t, h) } }) } func testHeader(t *testing.T, h header) { b := &bytes.Buffer{} w := bufio.NewWriter(b) r := bufio.NewReader(b) err := writeFrameHeader(h, w, make([]byte, 8)) assert.Success(t, err) err = w.Flush() assert.Success(t, err) h2, err := readFrameHeader(r, make([]byte, 8)) assert.Success(t, err) assert.Equal(t, "read header", h, h2) } func Test_mask(t *testing.T) { t.Parallel() key := []byte{0xa, 0xb, 0xc, 0xff} key32 := binary.LittleEndian.Uint32(key) p := []byte{0xa, 0xb, 0xc, 0xf2, 0xc} gotKey32 := mask(p, key32) expP := []byte{0, 0, 0, 0x0d, 0x6} assert.Equal(t, "p", expP, p) expKey32 := bits.RotateLeft32(key32, -8) assert.Equal(t, "key32", expKey32, gotKey32) } websocket-1.8.12/go.mod000066400000000000000000000000531465546417300147140ustar00rootroot00000000000000module github.com/coder/websocket go 1.19 websocket-1.8.12/go.sum000066400000000000000000000000001465546417300147310ustar00rootroot00000000000000websocket-1.8.12/internal/000077500000000000000000000000001465546417300154245ustar00rootroot00000000000000websocket-1.8.12/internal/bpool/000077500000000000000000000000001465546417300165375ustar00rootroot00000000000000websocket-1.8.12/internal/bpool/bpool.go000066400000000000000000000005501465546417300202010ustar00rootroot00000000000000package bpool import ( "bytes" "sync" ) var bpool sync.Pool // Get returns a buffer from the pool or creates a new one if // the pool is empty. func Get() *bytes.Buffer { b := bpool.Get() if b == nil { return &bytes.Buffer{} } return b.(*bytes.Buffer) } // Put returns a buffer into the pool. 
func Put(b *bytes.Buffer) { b.Reset() bpool.Put(b) } websocket-1.8.12/internal/errd/000077500000000000000000000000001465546417300163605ustar00rootroot00000000000000websocket-1.8.12/internal/errd/wrap.go000066400000000000000000000005061465546417300176610ustar00rootroot00000000000000package errd import ( "fmt" ) // Wrap wraps err with fmt.Errorf if err is non nil. // Intended for use with defer and a named error return. // Inspired by https://github.com/golang/go/issues/32676. func Wrap(err *error, f string, v ...interface{}) { if *err != nil { *err = fmt.Errorf(f+": %w", append(v, *err)...) } } websocket-1.8.12/internal/examples/000077500000000000000000000000001465546417300172425ustar00rootroot00000000000000websocket-1.8.12/internal/examples/README.md000066400000000000000000000001361465546417300205210ustar00rootroot00000000000000# Examples This directory contains more involved examples unsuitable for display with godoc. websocket-1.8.12/internal/examples/chat/000077500000000000000000000000001465546417300201615ustar00rootroot00000000000000websocket-1.8.12/internal/examples/chat/README.md000066400000000000000000000026771465546417300214540ustar00rootroot00000000000000# Chat Example This directory contains a full stack example of a simple chat webapp using github.com/coder/websocket. ```bash $ cd examples/chat $ go run . localhost:0 listening on ws://127.0.0.1:51055 ``` Visit the printed URL to submit and view broadcasted messages in a browser. ![Image of Example](https://i.imgur.com/VwJl9Bh.png) ## Structure The frontend is contained in `index.html`, `index.js` and `index.css`. It sets up the DOM with a scrollable div at the top that is populated with new messages as they are broadcast. At the bottom it adds a form to submit messages. The messages are received via the WebSocket `/subscribe` endpoint and published via the HTTP POST `/publish` endpoint. The reason for not publishing messages over the WebSocket is so that you can easily publish a message with curl. 
The server portion is `main.go` and `chat.go` and implements serving the static frontend assets, the `/subscribe` WebSocket endpoint and the HTTP POST `/publish` endpoint.

The code is well commented. I would recommend starting in `main.go` and then `chat.go` followed by `index.html` and then `index.js`.

There are two automated tests for the server included in `chat_test.go`. The first is a simple one client echo test. It publishes a single message and ensures it's received.

The second is a complex concurrency test where 16 clients send 128 unique messages of max 128 bytes concurrently. The test ensures all messages are seen by every client.
websocket-1.8.12/internal/examples/chat/chat.go000066400000000000000000000117461465546417300214400ustar00rootroot00000000000000package main

import (
	"context"
	"errors"
	"io"
	"log"
	"net"
	"net/http"
	"sync"
	"time"

	"golang.org/x/time/rate"

	"github.com/coder/websocket"
)

// chatServer enables broadcasting to a set of subscribers.
type chatServer struct {
	// subscriberMessageBuffer controls the max number
	// of messages that can be queued for a subscriber
	// before it is kicked.
	//
	// Defaults to 16.
	subscriberMessageBuffer int

	// publishLimiter controls the rate limit applied to the publish endpoint.
	//
	// Defaults to one publish every 100ms with a burst of 8.
	publishLimiter *rate.Limiter

	// logf controls where logs are sent.
	// Defaults to log.Printf.
	logf func(f string, v ...interface{})

	// serveMux routes the various endpoints to the appropriate handler.
	serveMux http.ServeMux

	// subscribersMu guards subscribers; every read or write of the map
	// happens with it held.
	subscribersMu sync.Mutex
	subscribers   map[*subscriber]struct{}
}

// newChatServer constructs a chatServer with the defaults.
func newChatServer() *chatServer { cs := &chatServer{ subscriberMessageBuffer: 16, logf: log.Printf, subscribers: make(map[*subscriber]struct{}), publishLimiter: rate.NewLimiter(rate.Every(time.Millisecond*100), 8), } cs.serveMux.Handle("/", http.FileServer(http.Dir("."))) cs.serveMux.HandleFunc("/subscribe", cs.subscribeHandler) cs.serveMux.HandleFunc("/publish", cs.publishHandler) return cs } // subscriber represents a subscriber. // Messages are sent on the msgs channel and if the client // cannot keep up with the messages, closeSlow is called. type subscriber struct { msgs chan []byte closeSlow func() } func (cs *chatServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { cs.serveMux.ServeHTTP(w, r) } // subscribeHandler accepts the WebSocket connection and then subscribes // it to all future messages. func (cs *chatServer) subscribeHandler(w http.ResponseWriter, r *http.Request) { err := cs.subscribe(r.Context(), w, r) if errors.Is(err, context.Canceled) { return } if websocket.CloseStatus(err) == websocket.StatusNormalClosure || websocket.CloseStatus(err) == websocket.StatusGoingAway { return } if err != nil { cs.logf("%v", err) return } } // publishHandler reads the request body with a limit of 8192 bytes and then publishes // the received message. func (cs *chatServer) publishHandler(w http.ResponseWriter, r *http.Request) { if r.Method != "POST" { http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed) return } body := http.MaxBytesReader(w, r.Body, 8192) msg, err := io.ReadAll(body) if err != nil { http.Error(w, http.StatusText(http.StatusRequestEntityTooLarge), http.StatusRequestEntityTooLarge) return } cs.publish(msg) w.WriteHeader(http.StatusAccepted) } // subscribe subscribes the given WebSocket to all broadcast messages. // It creates a subscriber with a buffered msgs chan to give some room to slower // connections and then registers the subscriber. 
It then listens for all messages // and writes them to the WebSocket. If the context is cancelled or // an error occurs, it returns and deletes the subscription. // // It uses CloseRead to keep reading from the connection to process control // messages and cancel the context if the connection drops. func (cs *chatServer) subscribe(ctx context.Context, w http.ResponseWriter, r *http.Request) error { var mu sync.Mutex var c *websocket.Conn var closed bool s := &subscriber{ msgs: make(chan []byte, cs.subscriberMessageBuffer), closeSlow: func() { mu.Lock() defer mu.Unlock() closed = true if c != nil { c.Close(websocket.StatusPolicyViolation, "connection too slow to keep up with messages") } }, } cs.addSubscriber(s) defer cs.deleteSubscriber(s) c2, err := websocket.Accept(w, r, nil) if err != nil { return err } mu.Lock() if closed { mu.Unlock() return net.ErrClosed } c = c2 mu.Unlock() defer c.CloseNow() ctx = c.CloseRead(ctx) for { select { case msg := <-s.msgs: err := writeTimeout(ctx, time.Second*5, c, msg) if err != nil { return err } case <-ctx.Done(): return ctx.Err() } } } // publish publishes the msg to all subscribers. // It never blocks and so messages to slow subscribers // are dropped. func (cs *chatServer) publish(msg []byte) { cs.subscribersMu.Lock() defer cs.subscribersMu.Unlock() cs.publishLimiter.Wait(context.Background()) for s := range cs.subscribers { select { case s.msgs <- msg: default: go s.closeSlow() } } } // addSubscriber registers a subscriber. func (cs *chatServer) addSubscriber(s *subscriber) { cs.subscribersMu.Lock() cs.subscribers[s] = struct{}{} cs.subscribersMu.Unlock() } // deleteSubscriber deletes the given subscriber. 
func (cs *chatServer) deleteSubscriber(s *subscriber) {
	cs.subscribersMu.Lock()
	delete(cs.subscribers, s)
	cs.subscribersMu.Unlock()
}

// writeTimeout writes msg to c as a text message, bounding the write by
// timeout in addition to any deadline or cancellation already on ctx.
func writeTimeout(ctx context.Context, timeout time.Duration, c *websocket.Conn, msg []byte) error {
	ctx, cancel := context.WithTimeout(ctx, timeout)
	defer cancel()

	return c.Write(ctx, websocket.MessageText, msg)
}
websocket-1.8.12/internal/examples/chat/chat_test.go000066400000000000000000000134721465546417300224750ustar00rootroot00000000000000package main

import (
	"context"
	"crypto/rand"
	"fmt"
	"math/big"
	"net/http"
	"net/http/httptest"
	"strings"
	"sync"
	"testing"
	"time"

	"golang.org/x/time/rate"

	"github.com/coder/websocket"
)

func Test_chatServer(t *testing.T) {
	t.Parallel()

	// This is a simple echo test with a single client.
	// The client sends a message and ensures it receives
	// it on its WebSocket.
	t.Run("simple", func(t *testing.T) {
		t.Parallel()

		url, closeFn := setupTest(t)
		defer closeFn()

		ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
		defer cancel()

		cl, err := newClient(ctx, url)
		assertSuccess(t, err)
		defer cl.Close()

		expMsg := randString(512)
		err = cl.publish(ctx, expMsg)
		assertSuccess(t, err)

		msg, err := cl.nextMessage()
		assertSuccess(t, err)

		if expMsg != msg {
			t.Fatalf("expected %v but got %v", expMsg, msg)
		}
	})

	// This test is a complex concurrency test.
	// nclients (16) clients are started that each send nmessages (128)
	// different messages concurrently. Message lengths are drawn from
	// [0, maxMessageSize), i.e. at most 127 bytes.
	//
	// The test verifies that every message is seen by every client
	// and no errors occur anywhere.
	t.Run("concurrency", func(t *testing.T) {
		t.Parallel()

		const nmessages = 128
		const maxMessageSize = 128
		const nclients = 16

		url, closeFn := setupTest(t)
		defer closeFn()

		ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
		defer cancel()

		var clients []*client
		var clientMsgs []map[string]struct{}
		for i := 0; i < nclients; i++ {
			cl, err := newClient(ctx, url)
			assertSuccess(t, err)
			defer cl.Close()

			clients = append(clients, cl)
			clientMsgs = append(clientMsgs, randMessages(nmessages, maxMessageSize))
		}

		allMessages := make(map[string]struct{})
		for _, msgs := range clientMsgs {
			for m := range msgs {
				allMessages[m] = struct{}{}
			}
		}

		var wg sync.WaitGroup
		for i, cl := range clients {
			i := i
			cl := cl

			wg.Add(1)
			go func() {
				defer wg.Done()
				err := cl.publishMsgs(ctx, clientMsgs[i])
				if err != nil {
					t.Errorf("client %d failed to publish all messages: %v", i, err)
				}
			}()

			wg.Add(1)
			go func() {
				defer wg.Done()
				// publish fans out to every subscriber, so each client should
				// read every client's messages, including its own.
				err := testAllMessagesReceived(cl, nclients*nmessages, allMessages)
				if err != nil {
					t.Errorf("client %d failed to receive all messages: %v", i, err)
				}
			}()
		}

		wg.Wait()
	})
}

// setupTest sets up chatServer that can be used
// via the returned url.
//
// Defer closeFn to ensure everything is cleaned up at
// the end of the test.
//
// chatServer logs will be logged via t.Logf.
func setupTest(t *testing.T) (url string, closeFn func()) {
	cs := newChatServer()
	cs.logf = t.Logf

	// To ensure tests run quickly under even -race.
	cs.subscriberMessageBuffer = 4096
	cs.publishLimiter.SetLimit(rate.Inf)

	s := httptest.NewServer(cs)
	return s.URL, func() {
		s.Close()
	}
}

// testAllMessagesReceived ensures that after n reads, all msgs in msgs
// have been read.
func testAllMessagesReceived(cl *client, n int, msgs map[string]struct{}) error { msgs = cloneMessages(msgs) for i := 0; i < n; i++ { msg, err := cl.nextMessage() if err != nil { return err } delete(msgs, msg) } if len(msgs) != 0 { return fmt.Errorf("did not receive all expected messages: %q", msgs) } return nil } func cloneMessages(msgs map[string]struct{}) map[string]struct{} { msgs2 := make(map[string]struct{}, len(msgs)) for m := range msgs { msgs2[m] = struct{}{} } return msgs2 } func randMessages(n, maxMessageLength int) map[string]struct{} { msgs := make(map[string]struct{}) for i := 0; i < n; i++ { m := randString(randInt(maxMessageLength)) if _, ok := msgs[m]; ok { i-- continue } msgs[m] = struct{}{} } return msgs } func assertSuccess(t *testing.T, err error) { t.Helper() if err != nil { t.Fatal(err) } } type client struct { url string c *websocket.Conn } func newClient(ctx context.Context, url string) (*client, error) { c, _, err := websocket.Dial(ctx, url+"/subscribe", nil) if err != nil { return nil, err } cl := &client{ url: url, c: c, } return cl, nil } func (cl *client) publish(ctx context.Context, msg string) (err error) { defer func() { if err != nil { cl.c.Close(websocket.StatusInternalError, "publish failed") } }() req, _ := http.NewRequestWithContext(ctx, http.MethodPost, cl.url+"/publish", strings.NewReader(msg)) resp, err := http.DefaultClient.Do(req) if err != nil { return err } defer resp.Body.Close() if resp.StatusCode != http.StatusAccepted { return fmt.Errorf("publish request failed: %v", resp.StatusCode) } return nil } func (cl *client) publishMsgs(ctx context.Context, msgs map[string]struct{}) error { for m := range msgs { err := cl.publish(ctx, m) if err != nil { return err } } return nil } func (cl *client) nextMessage() (string, error) { typ, b, err := cl.c.Read(context.Background()) if err != nil { return "", err } if typ != websocket.MessageText { cl.c.Close(websocket.StatusUnsupportedData, "expected text message") return "", 
fmt.Errorf("expected text message but got %v", typ) } return string(b), nil } func (cl *client) Close() error { return cl.c.Close(websocket.StatusNormalClosure, "") } // randString generates a random string with length n. func randString(n int) string { b := make([]byte, n) _, err := rand.Reader.Read(b) if err != nil { panic(fmt.Sprintf("failed to generate rand bytes: %v", err)) } s := strings.ToValidUTF8(string(b), "_") s = strings.ReplaceAll(s, "\x00", "_") if len(s) > n { return s[:n] } if len(s) < n { // Pad with = extra := n - len(s) return s + strings.Repeat("=", extra) } return s } // randInt returns a randomly generated integer between [0, max). func randInt(max int) int { x, err := rand.Int(rand.Reader, big.NewInt(int64(max))) if err != nil { panic(fmt.Sprintf("failed to get random int: %v", err)) } return int(x.Int64()) } websocket-1.8.12/internal/examples/chat/index.css000066400000000000000000000022431465546417300220030ustar00rootroot00000000000000body { width: 100vw; min-width: 320px; } #root { padding: 40px 20px; max-width: 600px; margin: auto; height: 100vh; display: flex; flex-direction: column; align-items: center; justify-content: center; } #root > * + * { margin: 20px 0 0 0; } /* 100vh on safari does not include the bottom bar. 
*/ @supports (-webkit-overflow-scrolling: touch) { #root { height: 85vh; } } #message-log { width: 100%; flex-grow: 1; overflow: auto; } #message-log p:first-child { margin: 0; } #message-log > * + * { margin: 10px 0 0 0; } #publish-form-container { width: 100%; } #publish-form { width: 100%; display: flex; height: 40px; } #publish-form > * + * { margin: 0 0 0 10px; } #publish-form input[type='text'] { flex-grow: 1; -moz-appearance: none; -webkit-appearance: none; word-break: normal; border-radius: 5px; border: 1px solid #ccc; } #publish-form input[type='submit'] { color: white; background-color: black; border-radius: 5px; padding: 5px 10px; border: none; } #publish-form input[type='submit']:hover { background-color: red; } #publish-form input[type='submit']:active { background-color: red; } websocket-1.8.12/internal/examples/chat/index.html000066400000000000000000000015241465546417300221600ustar00rootroot00000000000000 github.com/coder/websocket - Chat Example
websocket-1.8.12/internal/examples/chat/index.js000066400000000000000000000042701465546417300216310ustar00rootroot00000000000000;(() => { // expectingMessage is set to true // if the user has just submitted a message // and so we should scroll the next message into view when received. let expectingMessage = false function dial() { const conn = new WebSocket(`ws://${location.host}/subscribe`) conn.addEventListener('close', ev => { appendLog(`WebSocket Disconnected code: ${ev.code}, reason: ${ev.reason}`, true) if (ev.code !== 1001) { appendLog('Reconnecting in 1s', true) setTimeout(dial, 1000) } }) conn.addEventListener('open', ev => { console.info('websocket connected') }) // This is where we handle messages received. conn.addEventListener('message', ev => { if (typeof ev.data !== 'string') { console.error('unexpected message type', typeof ev.data) return } const p = appendLog(ev.data) if (expectingMessage) { p.scrollIntoView() expectingMessage = false } }) } dial() const messageLog = document.getElementById('message-log') const publishForm = document.getElementById('publish-form') const messageInput = document.getElementById('message-input') // appendLog appends the passed text to messageLog. function appendLog(text, error) { const p = document.createElement('p') // Adding a timestamp to each message makes the log easier to read. p.innerText = `${new Date().toLocaleTimeString()}: ${text}` if (error) { p.style.color = 'red' p.style.fontStyle = 'bold' } messageLog.append(p) return p } appendLog('Submit a message to get started!') // onsubmit publishes the message from the user when the form is submitted. 
publishForm.onsubmit = async ev => { ev.preventDefault() const msg = messageInput.value if (msg === '') { return } messageInput.value = '' expectingMessage = true try { const resp = await fetch('/publish', { method: 'POST', body: msg, }) if (resp.status !== 202) { throw new Error(`Unexpected HTTP Status ${resp.status} ${resp.statusText}`) } } catch (err) { appendLog(`Publish failed: ${err.message}`, true) } } })() websocket-1.8.12/internal/examples/chat/main.go000066400000000000000000000020351465546417300214340ustar00rootroot00000000000000package main import ( "context" "errors" "log" "net" "net/http" "os" "os/signal" "time" ) func main() { log.SetFlags(0) err := run() if err != nil { log.Fatal(err) } } // run initializes the chatServer and then // starts a http.Server for the passed in address. func run() error { if len(os.Args) < 2 { return errors.New("please provide an address to listen on as the first argument") } l, err := net.Listen("tcp", os.Args[1]) if err != nil { return err } log.Printf("listening on ws://%v", l.Addr()) cs := newChatServer() s := &http.Server{ Handler: cs, ReadTimeout: time.Second * 10, WriteTimeout: time.Second * 10, } errc := make(chan error, 1) go func() { errc <- s.Serve(l) }() sigs := make(chan os.Signal, 1) signal.Notify(sigs, os.Interrupt) select { case err := <-errc: log.Printf("failed to serve: %v", err) case sig := <-sigs: log.Printf("terminating: %v", sig) } ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) defer cancel() return s.Shutdown(ctx) } websocket-1.8.12/internal/examples/echo/000077500000000000000000000000001465546417300201605ustar00rootroot00000000000000websocket-1.8.12/internal/examples/echo/README.md000066400000000000000000000012261465546417300214400ustar00rootroot00000000000000# Echo Example This directory contains a echo server example using github.com/coder/websocket. ```bash $ cd examples/echo $ go run . 
localhost:0 listening on ws://127.0.0.1:51055 ``` You can use a WebSocket client like https://github.com/hashrocket/ws to connect. All messages written will be echoed back. ## Structure The server is in `server.go` and is implemented as a `http.HandlerFunc` that accepts the WebSocket and then reads all messages and writes them exactly as is back to the connection. `server_test.go` contains a small unit test to verify it works correctly. `main.go` brings it all together so that you can run it and play around with it. websocket-1.8.12/internal/examples/echo/main.go000066400000000000000000000020471465546417300214360ustar00rootroot00000000000000package main import ( "context" "errors" "log" "net" "net/http" "os" "os/signal" "time" ) func main() { log.SetFlags(0) err := run() if err != nil { log.Fatal(err) } } // run starts a http.Server for the passed in address // with all requests handled by echoServer. func run() error { if len(os.Args) < 2 { return errors.New("please provide an address to listen on as the first argument") } l, err := net.Listen("tcp", os.Args[1]) if err != nil { return err } log.Printf("listening on ws://%v", l.Addr()) s := &http.Server{ Handler: echoServer{ logf: log.Printf, }, ReadTimeout: time.Second * 10, WriteTimeout: time.Second * 10, } errc := make(chan error, 1) go func() { errc <- s.Serve(l) }() sigs := make(chan os.Signal, 1) signal.Notify(sigs, os.Interrupt) select { case err := <-errc: log.Printf("failed to serve: %v", err) case sig := <-sigs: log.Printf("terminating: %v", sig) } ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) defer cancel() return s.Shutdown(ctx) } websocket-1.8.12/internal/examples/echo/server.go000066400000000000000000000031721465546417300220200ustar00rootroot00000000000000package main import ( "context" "fmt" "io" "net/http" "time" "golang.org/x/time/rate" "github.com/coder/websocket" ) // echoServer is the WebSocket echo server implementation. 
// It ensures the client speaks the echo subprotocol and // only allows one message every 100ms with a 10 message burst. type echoServer struct { // logf controls where logs are sent. logf func(f string, v ...interface{}) } func (s echoServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { c, err := websocket.Accept(w, r, &websocket.AcceptOptions{ Subprotocols: []string{"echo"}, }) if err != nil { s.logf("%v", err) return } defer c.CloseNow() if c.Subprotocol() != "echo" { c.Close(websocket.StatusPolicyViolation, "client must speak the echo subprotocol") return } l := rate.NewLimiter(rate.Every(time.Millisecond*100), 10) for { err = echo(r.Context(), c, l) if websocket.CloseStatus(err) == websocket.StatusNormalClosure { return } if err != nil { s.logf("failed to echo with %v: %v", r.RemoteAddr, err) return } } } // echo reads from the WebSocket connection and then writes // the received message back to it. // The entire function has 10s to complete. func echo(ctx context.Context, c *websocket.Conn, l *rate.Limiter) error { ctx, cancel := context.WithTimeout(ctx, time.Second*10) defer cancel() err := l.Wait(ctx) if err != nil { return err } typ, r, err := c.Reader(ctx) if err != nil { return err } w, err := c.Writer(ctx, typ) if err != nil { return err } _, err = io.Copy(w, r) if err != nil { return fmt.Errorf("failed to io.Copy: %w", err) } err = w.Close() return err } websocket-1.8.12/internal/examples/echo/server_test.go000066400000000000000000000017751465546417300230660ustar00rootroot00000000000000package main import ( "context" "net/http/httptest" "testing" "time" "github.com/coder/websocket" "github.com/coder/websocket/wsjson" ) // Test_echoServer tests the echoServer by sending it 5 different messages // and ensuring the responses all match. 
func Test_echoServer(t *testing.T) { t.Parallel() s := httptest.NewServer(echoServer{ logf: t.Logf, }) defer s.Close() ctx, cancel := context.WithTimeout(context.Background(), time.Minute) defer cancel() c, _, err := websocket.Dial(ctx, s.URL, &websocket.DialOptions{ Subprotocols: []string{"echo"}, }) if err != nil { t.Fatal(err) } defer c.Close(websocket.StatusInternalError, "the sky is falling") for i := 0; i < 5; i++ { err = wsjson.Write(ctx, c, map[string]int{ "i": i, }) if err != nil { t.Fatal(err) } v := map[string]int{} err = wsjson.Read(ctx, c, &v) if err != nil { t.Fatal(err) } if v["i"] != i { t.Fatalf("expected %v but got %v", i, v) } } c.Close(websocket.StatusNormalClosure, "") } websocket-1.8.12/internal/examples/go.mod000066400000000000000000000003071465546417300203500ustar00rootroot00000000000000module github.com/coder/websocket/examples go 1.19 replace github.com/coder/websocket => ../.. require ( github.com/coder/websocket v0.0.0-00010101000000-000000000000 golang.org/x/time v0.3.0 ) websocket-1.8.12/internal/examples/go.sum000066400000000000000000000002311465546417300203710ustar00rootroot00000000000000golang.org/x/time v0.3.0 h1:rg5rLMjNzMS1RkNLzCG38eapWhnYLFYXDXj2gOlr8j4= golang.org/x/time v0.3.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= websocket-1.8.12/internal/test/000077500000000000000000000000001465546417300164035ustar00rootroot00000000000000websocket-1.8.12/internal/test/assert/000077500000000000000000000000001465546417300177045ustar00rootroot00000000000000websocket-1.8.12/internal/test/assert/assert.go000066400000000000000000000016661465546417300215450ustar00rootroot00000000000000package assert import ( "errors" "fmt" "reflect" "strings" "testing" ) // Equal asserts exp == act. func Equal(t testing.TB, name string, exp, got interface{}) { t.Helper() if !reflect.DeepEqual(exp, got) { t.Fatalf("unexpected %v: expected %#v but got %#v", name, exp, got) } } // Success asserts err == nil. 
func Success(t testing.TB, err error) { t.Helper() if err != nil { t.Fatal(err) } } // Error asserts err != nil. func Error(t testing.TB, err error) { t.Helper() if err == nil { t.Fatal("expected error") } } // Contains asserts the fmt.Sprint(v) contains sub. func Contains(t testing.TB, v interface{}, sub string) { t.Helper() s := fmt.Sprint(v) if !strings.Contains(s, sub) { t.Fatalf("expected %q to contain %q", s, sub) } } // ErrorIs asserts errors.Is(got, exp) func ErrorIs(t testing.TB, exp, got error) { t.Helper() if !errors.Is(got, exp) { t.Fatalf("expected %v but got %v", exp, got) } } websocket-1.8.12/internal/test/doc.go000066400000000000000000000001061465546417300174740ustar00rootroot00000000000000// Package test contains subpackages only used in tests. package test websocket-1.8.12/internal/test/wstest/000077500000000000000000000000001465546417300177345ustar00rootroot00000000000000websocket-1.8.12/internal/test/wstest/echo.go000066400000000000000000000032041465546417300212000ustar00rootroot00000000000000package wstest import ( "bytes" "context" "fmt" "io" "time" "github.com/coder/websocket" "github.com/coder/websocket/internal/test/xrand" "github.com/coder/websocket/internal/xsync" ) // EchoLoop echos every msg received from c until an error // occurs or the context expires. // The read limit is set to 1 << 30. func EchoLoop(ctx context.Context, c *websocket.Conn) error { defer c.Close(websocket.StatusInternalError, "") c.SetReadLimit(1 << 30) ctx, cancel := context.WithTimeout(ctx, time.Minute*5) defer cancel() b := make([]byte, 32<<10) for { typ, r, err := c.Reader(ctx) if err != nil { return err } w, err := c.Writer(ctx, typ) if err != nil { return err } _, err = io.CopyBuffer(w, r, b) if err != nil { return err } err = w.Close() if err != nil { return err } } } // Echo writes a message and ensures the same is sent back on c. 
func Echo(ctx context.Context, c *websocket.Conn, max int) error { expType := websocket.MessageBinary if xrand.Bool() { expType = websocket.MessageText } msg := randMessage(expType, xrand.Int(max)) writeErr := xsync.Go(func() error { return c.Write(ctx, expType, msg) }) actType, act, err := c.Read(ctx) if err != nil { return err } err = <-writeErr if err != nil { return err } if expType != actType { return fmt.Errorf("unexpected message typ (%v): %v", expType, actType) } if !bytes.Equal(msg, act) { return fmt.Errorf("unexpected msg read: %#v", act) } return nil } func randMessage(typ websocket.MessageType, n int) []byte { if typ == websocket.MessageBinary { return xrand.Bytes(n) } return []byte(xrand.String(n)) } websocket-1.8.12/internal/test/wstest/pipe.go000066400000000000000000000027731465546417300212310ustar00rootroot00000000000000//go:build !js // +build !js package wstest import ( "bufio" "context" "net" "net/http" "net/http/httptest" "github.com/coder/websocket" ) // Pipe is used to create an in memory connection // between two websockets analogous to net.Pipe. 
func Pipe(dialOpts *websocket.DialOptions, acceptOpts *websocket.AcceptOptions) (clientConn, serverConn *websocket.Conn) { tt := fakeTransport{ h: func(w http.ResponseWriter, r *http.Request) { serverConn, _ = websocket.Accept(w, r, acceptOpts) }, } if dialOpts == nil { dialOpts = &websocket.DialOptions{} } _dialOpts := *dialOpts dialOpts = &_dialOpts dialOpts.HTTPClient = &http.Client{ Transport: tt, } clientConn, _, _ = websocket.Dial(context.Background(), "ws://example.com", dialOpts) return clientConn, serverConn } type fakeTransport struct { h http.HandlerFunc } func (t fakeTransport) RoundTrip(r *http.Request) (*http.Response, error) { clientConn, serverConn := net.Pipe() hj := testHijacker{ ResponseRecorder: httptest.NewRecorder(), serverConn: serverConn, } t.h.ServeHTTP(hj, r) resp := hj.ResponseRecorder.Result() if resp.StatusCode == http.StatusSwitchingProtocols { resp.Body = clientConn } return resp, nil } type testHijacker struct { *httptest.ResponseRecorder serverConn net.Conn } var _ http.Hijacker = testHijacker{} func (hj testHijacker) Hijack() (net.Conn, *bufio.ReadWriter, error) { return hj.serverConn, bufio.NewReadWriter(bufio.NewReader(hj.serverConn), bufio.NewWriter(hj.serverConn)), nil } websocket-1.8.12/internal/test/xrand/000077500000000000000000000000001465546417300175175ustar00rootroot00000000000000websocket-1.8.12/internal/test/xrand/xrand.go000066400000000000000000000021271465546417300211640ustar00rootroot00000000000000package xrand import ( "crypto/rand" "encoding/base64" "fmt" "math/big" "strings" ) // Bytes generates random bytes with length n. func Bytes(n int) []byte { b := make([]byte, n) _, err := rand.Reader.Read(b) if err != nil { panic(fmt.Sprintf("failed to generate rand bytes: %v", err)) } return b } // String generates a random string with length n. 
func String(n int) string { s := strings.ToValidUTF8(string(Bytes(n)), "_") s = strings.ReplaceAll(s, "\x00", "_") if len(s) > n { return s[:n] } if len(s) < n { // Pad with = extra := n - len(s) return s + strings.Repeat("=", extra) } return s } // Bool returns a randomly generated boolean. func Bool() bool { return Int(2) == 1 } // Int returns a randomly generated integer between [0, max). func Int(max int) int { x, err := rand.Int(rand.Reader, big.NewInt(int64(max))) if err != nil { panic(fmt.Sprintf("failed to get random int: %v", err)) } return int(x.Int64()) } // Base64 returns a randomly generated base64 string of length n. func Base64(n int) string { return base64.StdEncoding.EncodeToString(Bytes(n)) } websocket-1.8.12/internal/thirdparty/000077500000000000000000000000001465546417300176165ustar00rootroot00000000000000websocket-1.8.12/internal/thirdparty/doc.go000066400000000000000000000001241465546417300207070ustar00rootroot00000000000000// Package thirdparty contains third party benchmarks and tests. 
package thirdparty websocket-1.8.12/internal/thirdparty/frame_test.go000066400000000000000000000046461465546417300223100ustar00rootroot00000000000000package thirdparty import ( "encoding/binary" "runtime" "strconv" "testing" _ "unsafe" "github.com/gobwas/ws" _ "github.com/gorilla/websocket" _ "github.com/lesismal/nbio/nbhttp/websocket" _ "github.com/coder/websocket" ) func basicMask(b []byte, maskKey [4]byte, pos int) int { for i := range b { b[i] ^= maskKey[pos&3] pos++ } return pos & 3 } //go:linkname maskGo github.com/coder/websocket.maskGo func maskGo(b []byte, key32 uint32) int //go:linkname maskAsm github.com/coder/websocket.maskAsm func maskAsm(b *byte, len int, key32 uint32) uint32 //go:linkname nbioMaskBytes github.com/lesismal/nbio/nbhttp/websocket.maskXOR func nbioMaskBytes(b, key []byte) int //go:linkname gorillaMaskBytes github.com/gorilla/websocket.maskBytes func gorillaMaskBytes(key [4]byte, pos int, b []byte) int func Benchmark_mask(b *testing.B) { b.Run(runtime.GOARCH, benchmark_mask) } func benchmark_mask(b *testing.B) { sizes := []int{ 8, 16, 32, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, } fns := []struct { name string fn func(b *testing.B, key [4]byte, p []byte) }{ { name: "basic", fn: func(b *testing.B, key [4]byte, p []byte) { for i := 0; i < b.N; i++ { basicMask(p, key, 0) } }, }, { name: "nhooyr-go", fn: func(b *testing.B, key [4]byte, p []byte) { key32 := binary.LittleEndian.Uint32(key[:]) b.ResetTimer() for i := 0; i < b.N; i++ { maskGo(p, key32) } }, }, { name: "wdvxdr1123-asm", fn: func(b *testing.B, key [4]byte, p []byte) { key32 := binary.LittleEndian.Uint32(key[:]) b.ResetTimer() for i := 0; i < b.N; i++ { maskAsm(&p[0], len(p), key32) } }, }, { name: "gorilla", fn: func(b *testing.B, key [4]byte, p []byte) { for i := 0; i < b.N; i++ { gorillaMaskBytes(key, 0, p) } }, }, { name: "gobwas", fn: func(b *testing.B, key [4]byte, p []byte) { for i := 0; i < b.N; i++ { ws.Cipher(p, key, 0) } }, }, { name: "nbio", fn: func(b *testing.B, 
key [4]byte, p []byte) { keyb := key[:] for i := 0; i < b.N; i++ { nbioMaskBytes(p, keyb) } }, }, } key := [4]byte{1, 2, 3, 4} for _, fn := range fns { b.Run(fn.name, func(b *testing.B) { for _, size := range sizes { p := make([]byte, size) b.Run(strconv.Itoa(size), func(b *testing.B) { b.SetBytes(int64(size)) fn.fn(b, key, p) }) } }) } } websocket-1.8.12/internal/thirdparty/gin_test.go000066400000000000000000000033501465546417300217620ustar00rootroot00000000000000package thirdparty import ( "context" "fmt" "net/http" "net/http/httptest" "testing" "time" "github.com/gin-gonic/gin" "github.com/coder/websocket" "github.com/coder/websocket/internal/errd" "github.com/coder/websocket/internal/test/assert" "github.com/coder/websocket/internal/test/wstest" "github.com/coder/websocket/wsjson" ) func TestGin(t *testing.T) { t.Parallel() gin.SetMode(gin.ReleaseMode) r := gin.New() r.GET("/", func(ginCtx *gin.Context) { err := echoServer(ginCtx.Writer, ginCtx.Request, nil) if err != nil { t.Error(err) } }) s := httptest.NewServer(r) defer s.Close() ctx, cancel := context.WithTimeout(context.Background(), time.Second*30) defer cancel() c, _, err := websocket.Dial(ctx, s.URL, nil) assert.Success(t, err) defer c.Close(websocket.StatusInternalError, "") err = wsjson.Write(ctx, c, "hello") assert.Success(t, err) var v interface{} err = wsjson.Read(ctx, c, &v) assert.Success(t, err) assert.Equal(t, "read msg", "hello", v) err = c.Close(websocket.StatusNormalClosure, "") assert.Success(t, err) } func echoServer(w http.ResponseWriter, r *http.Request, opts *websocket.AcceptOptions) (err error) { defer errd.Wrap(&err, "echo server failed") c, err := websocket.Accept(w, r, opts) if err != nil { return err } defer c.Close(websocket.StatusInternalError, "") err = wstest.EchoLoop(r.Context(), c) return assertCloseStatus(websocket.StatusNormalClosure, err) } func assertCloseStatus(exp websocket.StatusCode, err error) error { if websocket.CloseStatus(err) == -1 { return fmt.Errorf("expected 
websocket.CloseError: %T %v", err, err) } if websocket.CloseStatus(err) != exp { return fmt.Errorf("expected close status %v but got %v", exp, err) } return nil } websocket-1.8.12/internal/thirdparty/go.mod000066400000000000000000000032201465546417300207210ustar00rootroot00000000000000module github.com/coder/websocket/internal/thirdparty go 1.19 replace github.com/coder/websocket => ../.. require ( github.com/coder/websocket v0.0.0-00010101000000-000000000000 github.com/gin-gonic/gin v1.9.1 github.com/gobwas/ws v1.3.0 github.com/gorilla/websocket v1.5.0 github.com/lesismal/nbio v1.3.18 ) require ( github.com/bytedance/sonic v1.9.1 // indirect github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect github.com/gabriel-vasile/mimetype v1.4.2 // indirect github.com/gin-contrib/sse v0.1.0 // indirect github.com/go-playground/locales v0.14.1 // indirect github.com/go-playground/universal-translator v0.18.1 // indirect github.com/go-playground/validator/v10 v10.14.0 // indirect github.com/gobwas/httphead v0.1.0 // indirect github.com/gobwas/pool v0.2.1 // indirect github.com/goccy/go-json v0.10.2 // indirect github.com/json-iterator/go v1.1.12 // indirect github.com/klauspost/cpuid/v2 v2.2.4 // indirect github.com/leodido/go-urn v1.2.4 // indirect github.com/lesismal/llib v1.1.12 // indirect github.com/mattn/go-isatty v0.0.19 // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect github.com/pelletier/go-toml/v2 v2.0.8 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/ugorji/go/codec v1.2.11 // indirect golang.org/x/arch v0.3.0 // indirect golang.org/x/crypto v0.9.0 // indirect golang.org/x/net v0.10.0 // indirect golang.org/x/sys v0.17.0 // indirect golang.org/x/text v0.9.0 // indirect google.golang.org/protobuf v1.30.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) 
websocket-1.8.12/internal/thirdparty/go.sum000066400000000000000000000271411465546417300207560ustar00rootroot00000000000000github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM= github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s= github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U= github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY= github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams= github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/gabriel-vasile/mimetype v1.4.2 h1:w5qFW6JKBz9Y393Y4q372O9A7cUSequkh1Q7OhCmWKU= github.com/gabriel-vasile/mimetype v1.4.2/go.mod h1:zApsH/mKG4w07erKIaJPFiX0Tsq9BFQgN3qGY5GnNgA= github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE= github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI= github.com/gin-gonic/gin v1.9.1 h1:4idEAncQnU5cB7BeOkPtxjfCSye0AAm1R0RVIqJ+Jmg= github.com/gin-gonic/gin v1.9.1/go.mod h1:hPrL7YrpYKXt5YId3A/Tnip5kqbEAP+KLuI3SUcPTeU= github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY= github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY= github.com/go-playground/universal-translator v0.18.1/go.mod 
h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= github.com/go-playground/validator/v10 v10.14.0 h1:vgvQWe3XCz3gIeFDm/HnTIbj6UGmg/+t63MyGU2n5js= github.com/go-playground/validator/v10 v10.14.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU= github.com/gobwas/httphead v0.1.0 h1:exrUm0f4YX0L7EBwZHuCF4GDp8aJfVeBrlLQrs6NqWU= github.com/gobwas/httphead v0.1.0/go.mod h1:O/RXo79gxV8G+RqlR/otEwx4Q36zl9rqC5u12GKvMCM= github.com/gobwas/pool v0.2.1 h1:xfeeEhW7pwmX8nuLVlqbzVc7udMDrwetjEv+TZIz1og= github.com/gobwas/pool v0.2.1/go.mod h1:q8bcK0KcYlCgd9e7WYLm9LpyS+YeLd8JVDW6WezmKEw= github.com/gobwas/ws v1.3.0 h1:sbeU3Y4Qzlb+MOzIe6mQGf7QR4Hkv6ZD0qhGkBFL2O0= github.com/gobwas/ws v1.3.0/go.mod h1:hRKAFb8wOxFROYNsT1bqfWnhX+b5MFeJM9r2ZSwg/KY= github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU= github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc= github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= github.com/klauspost/cpuid/v2 v2.2.4 h1:acbojRNwl3o09bUq+yDCtZFc1aiwaAAxtcn8YkZXnvk= github.com/klauspost/cpuid/v2 v2.2.4/go.mod h1:RVVoqg1df56z8g3pUjL/3lE5UfnlrJX8tyFgg4nqhuY= github.com/leodido/go-urn v1.2.4 h1:XlAE/cm/ms7TE/VMVoduSpNBoyc2dOxHs5MZSwAN63Q= github.com/leodido/go-urn v1.2.4/go.mod 
h1:7ZrI8mTSeBSHl/UaRyKQW1qZeMgak41ANeCNaVckg+4= github.com/lesismal/llib v1.1.12 h1:KJFB8bL02V+QGIvILEw/w7s6bKj9Ps9Px97MZP2EOk0= github.com/lesismal/llib v1.1.12/go.mod h1:70tFXXe7P1FZ02AU9l8LgSOK7d7sRrpnkUr3rd3gKSg= github.com/lesismal/nbio v1.3.18 h1:kmJZlxjQpVfuCPYcXdv0Biv9LHVViJZet5K99Xs3RAs= github.com/lesismal/nbio v1.3.18/go.mod h1:KWlouFT5cgDdW5sMX8RsHASUMGniea9X0XIellZ0B38= github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA= github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= github.com/pelletier/go-toml/v2 v2.0.8 h1:0ctb6s9mE31h0/lhu+J6OPmVeDxJn+kYnJc2jZR9tGQ= github.com/pelletier/go-toml/v2 v2.0.8/go.mod h1:vuYfssBdrU2XDZ9bYydBu6t+6a6PYNcZljzZR9VXg+4= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= 
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/stretchr/testify v1.8.3 h1:RP3t2pwF7cMEbC1dqtB6poj3niw/9gnV4Cjg5oW5gtY= github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= github.com/ugorji/go/codec v1.2.11 h1:BMaWp1Bb6fHwEtbplGBGJ498wD+LKlNSl25MjdZY4dU= github.com/ugorji/go/codec v1.2.11/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k= golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20210513122933-cd7d49e622d5/go.mod h1:P+XmwS30IXTQdn5tA2iutPOUgjI07+tq3H3K9MVA1s8= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.6.0/go.mod h1:OFC/31mSvZgRz0V1QTNCzfAI1aIRzbiufJtkMIlEp58= golang.org/x/crypto v0.9.0 h1:LF6fAI+IutBocDJ2OT0Q1g8plpYljMZ4+lty+dsqw3g= golang.org/x/crypto v0.9.0/go.mod h1:yrmDGqONDYtNj3tH8X9dzUun2m2lzPa9ngI6/RUPGR0= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net 
v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20210510120150-4163338589ed/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M= golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.17.0 h1:25cE3gD+tdBA7lp7QfhuV+rJiE9YXTcS3VG1SqssI/Y= golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod 
h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE= golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng= google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod 
h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4= websocket-1.8.12/internal/util/000077500000000000000000000000001465546417300164015ustar00rootroot00000000000000websocket-1.8.12/internal/util/util.go000066400000000000000000000005321465546417300177050ustar00rootroot00000000000000package util // WriterFunc is used to implement one off io.Writers. type WriterFunc func(p []byte) (int, error) func (f WriterFunc) Write(p []byte) (int, error) { return f(p) } // ReaderFunc is used to implement one off io.Readers. type ReaderFunc func(p []byte) (int, error) func (f ReaderFunc) Read(p []byte) (int, error) { return f(p) } websocket-1.8.12/internal/wsjs/000077500000000000000000000000001465546417300164125ustar00rootroot00000000000000websocket-1.8.12/internal/wsjs/wsjs_js.go000066400000000000000000000077241465546417300204350ustar00rootroot00000000000000//go:build js // +build js // Package wsjs implements typed access to the browser javascript WebSocket API. // // https://developer.mozilla.org/en-US/docs/Web/API/WebSocket package wsjs import ( "syscall/js" ) func handleJSError(err *error, onErr func()) { r := recover() if jsErr, ok := r.(js.Error); ok { *err = jsErr if onErr != nil { onErr() } return } if r != nil { panic(r) } } // New is a wrapper around the javascript WebSocket constructor. func New(url string, protocols []string) (c WebSocket, err error) { defer handleJSError(&err, func() { c = WebSocket{} }) jsProtocols := make([]interface{}, len(protocols)) for i, p := range protocols { jsProtocols[i] = p } c = WebSocket{ v: js.Global().Get("WebSocket").New(url, jsProtocols), } c.setBinaryType("arraybuffer") return c, nil } // WebSocket is a wrapper around a javascript WebSocket object. 
type WebSocket struct { v js.Value } func (c WebSocket) setBinaryType(typ string) { c.v.Set("binaryType", string(typ)) } func (c WebSocket) addEventListener(eventType string, fn func(e js.Value)) func() { f := js.FuncOf(func(this js.Value, args []js.Value) interface{} { fn(args[0]) return nil }) c.v.Call("addEventListener", eventType, f) return func() { c.v.Call("removeEventListener", eventType, f) f.Release() } } // CloseEvent is the type passed to a WebSocket close handler. type CloseEvent struct { Code uint16 Reason string WasClean bool } // OnClose registers a function to be called when the WebSocket is closed. func (c WebSocket) OnClose(fn func(CloseEvent)) (remove func()) { return c.addEventListener("close", func(e js.Value) { ce := CloseEvent{ Code: uint16(e.Get("code").Int()), Reason: e.Get("reason").String(), WasClean: e.Get("wasClean").Bool(), } fn(ce) }) } // OnError registers a function to be called when there is an error // with the WebSocket. func (c WebSocket) OnError(fn func(e js.Value)) (remove func()) { return c.addEventListener("error", fn) } // MessageEvent is the type passed to a message handler. type MessageEvent struct { // string or []byte. Data interface{} // There are more fields to the interface but we don't use them. // See https://developer.mozilla.org/en-US/docs/Web/API/MessageEvent } // OnMessage registers a function to be called when the WebSocket receives a message. func (c WebSocket) OnMessage(fn func(m MessageEvent)) (remove func()) { return c.addEventListener("message", func(e js.Value) { var data interface{} arrayBuffer := e.Get("data") if arrayBuffer.Type() == js.TypeString { data = arrayBuffer.String() } else { data = extractArrayBuffer(arrayBuffer) } me := MessageEvent{ Data: data, } fn(me) }) } // Subprotocol returns the WebSocket subprotocol in use. func (c WebSocket) Subprotocol() string { return c.v.Get("protocol").String() } // OnOpen registers a function to be called when the WebSocket is opened. 
func (c WebSocket) OnOpen(fn func(e js.Value)) (remove func()) { return c.addEventListener("open", fn) } // Close closes the WebSocket with the given code and reason. func (c WebSocket) Close(code int, reason string) (err error) { defer handleJSError(&err, nil) c.v.Call("close", code, reason) return err } // SendText sends the given string as a text message // on the WebSocket. func (c WebSocket) SendText(v string) (err error) { defer handleJSError(&err, nil) c.v.Call("send", v) return err } // SendBytes sends the given message as a binary message // on the WebSocket. func (c WebSocket) SendBytes(v []byte) (err error) { defer handleJSError(&err, nil) c.v.Call("send", uint8Array(v)) return err } func extractArrayBuffer(arrayBuffer js.Value) []byte { uint8Array := js.Global().Get("Uint8Array").New(arrayBuffer) dst := make([]byte, uint8Array.Length()) js.CopyBytesToGo(dst, uint8Array) return dst } func uint8Array(src []byte) js.Value { uint8Array := js.Global().Get("Uint8Array").New(len(src)) js.CopyBytesToJS(uint8Array, src) return uint8Array } websocket-1.8.12/internal/xsync/000077500000000000000000000000001465546417300165705ustar00rootroot00000000000000websocket-1.8.12/internal/xsync/go.go000066400000000000000000000006441465546417300175300ustar00rootroot00000000000000package xsync import ( "fmt" "runtime/debug" ) // Go allows running a function in another goroutine // and waiting for its error. 
func Go(fn func() error) <-chan error { errs := make(chan error, 1) go func() { defer func() { r := recover() if r != nil { select { case errs <- fmt.Errorf("panic in go fn: %v, %s", r, debug.Stack()): default: } } }() errs <- fn() }() return errs } websocket-1.8.12/internal/xsync/go_test.go000066400000000000000000000003601465546417300205620ustar00rootroot00000000000000package xsync import ( "testing" "github.com/coder/websocket/internal/test/assert" ) func TestGoRecover(t *testing.T) { t.Parallel() errs := Go(func() error { panic("anmol") }) err := <-errs assert.Contains(t, err, "anmol") } websocket-1.8.12/internal/xsync/int64.go000066400000000000000000000006301465546417300200620ustar00rootroot00000000000000package xsync import ( "sync/atomic" ) // Int64 represents an atomic int64. type Int64 struct { // We do not use atomic.Load/StoreInt64 since it does not // work on 32 bit computers but we need 64 bit integers. i atomic.Value } // Load loads the int64. func (v *Int64) Load() int64 { i, _ := v.i.Load().(int64) return i } // Store stores the int64. 
func (v *Int64) Store(i int64) { v.i.Store(i) } websocket-1.8.12/main_test.go000066400000000000000000000011401465546417300161160ustar00rootroot00000000000000package websocket_test import ( "fmt" "os" "runtime" "testing" ) func goroutineStacks() []byte { buf := make([]byte, 512) for { m := runtime.Stack(buf, true) if m < len(buf) { return buf[:m] } buf = make([]byte, len(buf)*2) } } func TestMain(m *testing.M) { code := m.Run() if runtime.GOOS != "js" && runtime.NumGoroutine() != 1 || runtime.GOOS == "js" && runtime.NumGoroutine() != 2 { fmt.Fprintf(os.Stderr, "goroutine leak detected, expected 1 but got %d goroutines\n", runtime.NumGoroutine()) fmt.Fprintf(os.Stderr, "%s\n", goroutineStacks()) os.Exit(1) } os.Exit(code) } websocket-1.8.12/make.sh000077500000000000000000000002601465546417300150620ustar00rootroot00000000000000#!/bin/sh set -eu cd -- "$(dirname "$0")" echo "=== fmt.sh" ./ci/fmt.sh echo "=== lint.sh" ./ci/lint.sh echo "=== test.sh" ./ci/test.sh "$@" echo "=== bench.sh" ./ci/bench.sh websocket-1.8.12/mask.go000066400000000000000000000103411465546417300150710ustar00rootroot00000000000000package websocket import ( "encoding/binary" "math/bits" ) // maskGo applies the WebSocket masking algorithm to p // with the given key. // See https://tools.ietf.org/html/rfc6455#section-5.3 // // The returned value is the correctly rotated key to // to continue to mask/unmask the message. // // It is optimized for LittleEndian and expects the key // to be in little endian. // // See https://github.com/golang/go/issues/31586 func maskGo(b []byte, key uint32) uint32 { if len(b) >= 8 { key64 := uint64(key)<<32 | uint64(key) // At some point in the future we can clean these unrolled loops up. // See https://github.com/golang/go/issues/31586#issuecomment-487436401 // Then we xor until b is less than 128 bytes. 
for len(b) >= 128 { v := binary.LittleEndian.Uint64(b) binary.LittleEndian.PutUint64(b, v^key64) v = binary.LittleEndian.Uint64(b[8:16]) binary.LittleEndian.PutUint64(b[8:16], v^key64) v = binary.LittleEndian.Uint64(b[16:24]) binary.LittleEndian.PutUint64(b[16:24], v^key64) v = binary.LittleEndian.Uint64(b[24:32]) binary.LittleEndian.PutUint64(b[24:32], v^key64) v = binary.LittleEndian.Uint64(b[32:40]) binary.LittleEndian.PutUint64(b[32:40], v^key64) v = binary.LittleEndian.Uint64(b[40:48]) binary.LittleEndian.PutUint64(b[40:48], v^key64) v = binary.LittleEndian.Uint64(b[48:56]) binary.LittleEndian.PutUint64(b[48:56], v^key64) v = binary.LittleEndian.Uint64(b[56:64]) binary.LittleEndian.PutUint64(b[56:64], v^key64) v = binary.LittleEndian.Uint64(b[64:72]) binary.LittleEndian.PutUint64(b[64:72], v^key64) v = binary.LittleEndian.Uint64(b[72:80]) binary.LittleEndian.PutUint64(b[72:80], v^key64) v = binary.LittleEndian.Uint64(b[80:88]) binary.LittleEndian.PutUint64(b[80:88], v^key64) v = binary.LittleEndian.Uint64(b[88:96]) binary.LittleEndian.PutUint64(b[88:96], v^key64) v = binary.LittleEndian.Uint64(b[96:104]) binary.LittleEndian.PutUint64(b[96:104], v^key64) v = binary.LittleEndian.Uint64(b[104:112]) binary.LittleEndian.PutUint64(b[104:112], v^key64) v = binary.LittleEndian.Uint64(b[112:120]) binary.LittleEndian.PutUint64(b[112:120], v^key64) v = binary.LittleEndian.Uint64(b[120:128]) binary.LittleEndian.PutUint64(b[120:128], v^key64) b = b[128:] } // Then we xor until b is less than 64 bytes. 
for len(b) >= 64 { v := binary.LittleEndian.Uint64(b) binary.LittleEndian.PutUint64(b, v^key64) v = binary.LittleEndian.Uint64(b[8:16]) binary.LittleEndian.PutUint64(b[8:16], v^key64) v = binary.LittleEndian.Uint64(b[16:24]) binary.LittleEndian.PutUint64(b[16:24], v^key64) v = binary.LittleEndian.Uint64(b[24:32]) binary.LittleEndian.PutUint64(b[24:32], v^key64) v = binary.LittleEndian.Uint64(b[32:40]) binary.LittleEndian.PutUint64(b[32:40], v^key64) v = binary.LittleEndian.Uint64(b[40:48]) binary.LittleEndian.PutUint64(b[40:48], v^key64) v = binary.LittleEndian.Uint64(b[48:56]) binary.LittleEndian.PutUint64(b[48:56], v^key64) v = binary.LittleEndian.Uint64(b[56:64]) binary.LittleEndian.PutUint64(b[56:64], v^key64) b = b[64:] } // Then we xor until b is less than 32 bytes. for len(b) >= 32 { v := binary.LittleEndian.Uint64(b) binary.LittleEndian.PutUint64(b, v^key64) v = binary.LittleEndian.Uint64(b[8:16]) binary.LittleEndian.PutUint64(b[8:16], v^key64) v = binary.LittleEndian.Uint64(b[16:24]) binary.LittleEndian.PutUint64(b[16:24], v^key64) v = binary.LittleEndian.Uint64(b[24:32]) binary.LittleEndian.PutUint64(b[24:32], v^key64) b = b[32:] } // Then we xor until b is less than 16 bytes. for len(b) >= 16 { v := binary.LittleEndian.Uint64(b) binary.LittleEndian.PutUint64(b, v^key64) v = binary.LittleEndian.Uint64(b[8:16]) binary.LittleEndian.PutUint64(b[8:16], v^key64) b = b[16:] } // Then we xor until b is less than 8 bytes. for len(b) >= 8 { v := binary.LittleEndian.Uint64(b) binary.LittleEndian.PutUint64(b, v^key64) b = b[8:] } } // Then we xor until b is less than 4 bytes. for len(b) >= 4 { v := binary.LittleEndian.Uint32(b) binary.LittleEndian.PutUint32(b, v^key) b = b[4:] } // xor remaining bytes. 
for i := range b { b[i] ^= byte(key) key = bits.RotateLeft32(key, -8) } return key } websocket-1.8.12/mask_amd64.s000066400000000000000000000037231465546417300157270ustar00rootroot00000000000000#include "textflag.h" // func maskAsm(b *byte, len int, key uint32) TEXT ·maskAsm(SB), NOSPLIT, $0-28 // AX = b // CX = len (left length) // SI = key (uint32) // DI = uint64(SI) | uint64(SI)<<32 MOVQ b+0(FP), AX MOVQ len+8(FP), CX MOVL key+16(FP), SI // calculate the DI // DI = SI<<32 | SI MOVL SI, DI MOVQ DI, DX SHLQ $32, DI ORQ DX, DI CMPQ CX, $15 JLE less_than_16 CMPQ CX, $63 JLE less_than_64 CMPQ CX, $128 JLE sse TESTQ $31, AX JNZ unaligned unaligned_loop_1byte: XORB SI, (AX) INCQ AX DECQ CX ROLL $24, SI TESTQ $7, AX JNZ unaligned_loop_1byte // calculate DI again since SI was modified // DI = SI<<32 | SI MOVL SI, DI MOVQ DI, DX SHLQ $32, DI ORQ DX, DI TESTQ $31, AX JZ sse unaligned: TESTQ $7, AX // AND $7 & len, if not zero jump to loop_1b. JNZ unaligned_loop_1byte unaligned_loop: // we don't need to check the CX since we know it's above 128 XORQ DI, (AX) ADDQ $8, AX SUBQ $8, CX TESTQ $31, AX JNZ unaligned_loop JMP sse sse: CMPQ CX, $0x40 JL less_than_64 MOVQ DI, X0 PUNPCKLQDQ X0, X0 sse_loop: MOVOU 0*16(AX), X1 MOVOU 1*16(AX), X2 MOVOU 2*16(AX), X3 MOVOU 3*16(AX), X4 PXOR X0, X1 PXOR X0, X2 PXOR X0, X3 PXOR X0, X4 MOVOU X1, 0*16(AX) MOVOU X2, 1*16(AX) MOVOU X3, 2*16(AX) MOVOU X4, 3*16(AX) ADDQ $0x40, AX SUBQ $0x40, CX CMPQ CX, $0x40 JAE sse_loop less_than_64: TESTQ $32, CX JZ less_than_32 XORQ DI, (AX) XORQ DI, 8(AX) XORQ DI, 16(AX) XORQ DI, 24(AX) ADDQ $32, AX less_than_32: TESTQ $16, CX JZ less_than_16 XORQ DI, (AX) XORQ DI, 8(AX) ADDQ $16, AX less_than_16: TESTQ $8, CX JZ less_than_8 XORQ DI, (AX) ADDQ $8, AX less_than_8: TESTQ $4, CX JZ less_than_4 XORL SI, (AX) ADDQ $4, AX less_than_4: TESTQ $2, CX JZ less_than_2 XORW SI, (AX) ROLL $16, SI ADDQ $2, AX less_than_2: TESTQ $1, CX JZ done XORB SI, (AX) ROLL $24, SI done: MOVL SI, ret+24(FP) RET 
websocket-1.8.12/mask_arm64.s000066400000000000000000000026051465546417300157430ustar00rootroot00000000000000#include "textflag.h" // func maskAsm(b *byte, len int, key uint32) TEXT ·maskAsm(SB), NOSPLIT, $0-28 // R0 = b // R1 = len // R3 = key (uint32) // R2 = uint64(key)<<32 | uint64(key) MOVD b_ptr+0(FP), R0 MOVD b_len+8(FP), R1 MOVWU key+16(FP), R3 MOVD R3, R2 ORR R2<<32, R2, R2 VDUP R2, V0.D2 CMP $64, R1 BLT less_than_64 loop_64: VLD1 (R0), [V1.B16, V2.B16, V3.B16, V4.B16] VEOR V1.B16, V0.B16, V1.B16 VEOR V2.B16, V0.B16, V2.B16 VEOR V3.B16, V0.B16, V3.B16 VEOR V4.B16, V0.B16, V4.B16 VST1.P [V1.B16, V2.B16, V3.B16, V4.B16], 64(R0) SUBS $64, R1 CMP $64, R1 BGE loop_64 less_than_64: CBZ R1, end TBZ $5, R1, less_than_32 VLD1 (R0), [V1.B16, V2.B16] VEOR V1.B16, V0.B16, V1.B16 VEOR V2.B16, V0.B16, V2.B16 VST1.P [V1.B16, V2.B16], 32(R0) less_than_32: TBZ $4, R1, less_than_16 LDP (R0), (R11, R12) EOR R11, R2, R11 EOR R12, R2, R12 STP.P (R11, R12), 16(R0) less_than_16: TBZ $3, R1, less_than_8 MOVD (R0), R11 EOR R2, R11, R11 MOVD.P R11, 8(R0) less_than_8: TBZ $2, R1, less_than_4 MOVWU (R0), R11 EORW R2, R11, R11 MOVWU.P R11, 4(R0) less_than_4: TBZ $1, R1, less_than_2 MOVHU (R0), R11 EORW R3, R11, R11 MOVHU.P R11, 2(R0) RORW $16, R3 less_than_2: TBZ $0, R1, end MOVBU (R0), R11 EORW R3, R11, R11 MOVBU.P R11, 1(R0) RORW $8, R3 end: MOVWU R3, ret+24(FP) RET websocket-1.8.12/mask_asm.go000066400000000000000000000015521465546417300157350ustar00rootroot00000000000000//go:build amd64 || arm64 package websocket func mask(b []byte, key uint32) uint32 { // TODO: Will enable in v1.9.0. return maskGo(b, key) /* if len(b) > 0 { return maskAsm(&b[0], len(b), key) } return key */ } // @nhooyr: I am not confident that the amd64 or the arm64 implementations of this // function are perfect. There are almost certainly missing optimizations or // opportunities for simplification. I'm confident there are no bugs though. 
// For example, the arm64 implementation doesn't align memory like the amd64. // Or the amd64 implementation could use AVX512 instead of just AVX2. // The AVX2 code I had to disable anyway as it wasn't performing as expected. // See https://github.com/nhooyr/websocket/pull/326#issuecomment-1771138049 // //go:noescape //lint:ignore U1000 disabled till v1.9.0 func maskAsm(b *byte, len int, key uint32) uint32 websocket-1.8.12/mask_asm_test.go000066400000000000000000000002201465546417300167630ustar00rootroot00000000000000//go:build amd64 || arm64 package websocket import "testing" func TestMaskASM(t *testing.T) { t.Parallel() testMask(t, "maskASM", mask) } websocket-1.8.12/mask_go.go000066400000000000000000000001711465546417300155560ustar00rootroot00000000000000//go:build !amd64 && !arm64 && !js package websocket func mask(b []byte, key uint32) uint32 { return maskGo(b, key) } websocket-1.8.12/mask_test.go000066400000000000000000000026171465546417300161370ustar00rootroot00000000000000package websocket import ( "bytes" "crypto/rand" "encoding/binary" "math/big" "math/bits" "testing" "github.com/coder/websocket/internal/test/assert" ) func basicMask(b []byte, key uint32) uint32 { for i := range b { b[i] ^= byte(key) key = bits.RotateLeft32(key, -8) } return key } func basicMask2(b []byte, key uint32) uint32 { keyb := binary.LittleEndian.AppendUint32(nil, key) pos := 0 for i := range b { b[i] ^= keyb[pos&3] pos++ } return bits.RotateLeft32(key, (pos&3)*-8) } func TestMask(t *testing.T) { t.Parallel() testMask(t, "basicMask", basicMask) testMask(t, "maskGo", maskGo) testMask(t, "basicMask2", basicMask2) } func testMask(t *testing.T, name string, fn func(b []byte, key uint32) uint32) { t.Run(name, func(t *testing.T) { t.Parallel() for i := 0; i < 9999; i++ { keyb := make([]byte, 4) _, err := rand.Read(keyb) assert.Success(t, err) key := binary.LittleEndian.Uint32(keyb) n, err := rand.Int(rand.Reader, big.NewInt(1<<16)) assert.Success(t, err) b := make([]byte, 1+n.Int64()) 
_, err = rand.Read(b) assert.Success(t, err) b2 := make([]byte, len(b)) copy(b2, b) b3 := make([]byte, len(b)) copy(b3, b) key2 := basicMask(b2, key) key3 := fn(b3, key) if key2 != key3 { t.Errorf("expected key %X but got %X", key2, key3) } if !bytes.Equal(b2, b3) { t.Error("bad bytes") return } } }) } websocket-1.8.12/netconn.go000066400000000000000000000130121465546417300156000ustar00rootroot00000000000000package websocket import ( "context" "fmt" "io" "math" "net" "sync/atomic" "time" ) // NetConn converts a *websocket.Conn into a net.Conn. // // It's for tunneling arbitrary protocols over WebSockets. // Few users of the library will need this but it's tricky to implement // correctly and so provided in the library. // See https://github.com/nhooyr/websocket/issues/100. // // Every Write to the net.Conn will correspond to a message write of // the given type on *websocket.Conn. // // The passed ctx bounds the lifetime of the net.Conn. If cancelled, // all reads and writes on the net.Conn will be cancelled. // // If a message is read that is not of the correct type, the connection // will be closed with StatusUnsupportedData and an error will be returned. // // Close will close the *websocket.Conn with StatusNormalClosure. // // When a deadline is hit and there is an active read or write goroutine, the // connection will be closed. This is different from most net.Conn implementations // where only the reading/writing goroutines are interrupted but the connection // is kept alive. // // The Addr methods will return the real addresses for connections obtained // from websocket.Accept. But for connections obtained from websocket.Dial, a mock net.Addr // will be returned that gives "websocket" for Network() and "websocket/unknown-addr" for // String(). This is because websocket.Dial only exposes a io.ReadWriteCloser instead of the // full net.Conn to us. // // When running as WASM, the Addr methods will always return the mock address described above. 
// // A received StatusNormalClosure or StatusGoingAway close frame will be translated to // io.EOF when reading. // // Furthermore, the ReadLimit is set to -1 to disable it. func NetConn(ctx context.Context, c *Conn, msgType MessageType) net.Conn { c.SetReadLimit(-1) nc := &netConn{ c: c, msgType: msgType, readMu: newMu(c), writeMu: newMu(c), } nc.writeCtx, nc.writeCancel = context.WithCancel(ctx) nc.readCtx, nc.readCancel = context.WithCancel(ctx) nc.writeTimer = time.AfterFunc(math.MaxInt64, func() { if !nc.writeMu.tryLock() { // If the lock cannot be acquired, then there is an // active write goroutine and so we should cancel the context. nc.writeCancel() return } defer nc.writeMu.unlock() // Prevents future writes from writing until the deadline is reset. atomic.StoreInt64(&nc.writeExpired, 1) }) if !nc.writeTimer.Stop() { <-nc.writeTimer.C } nc.readTimer = time.AfterFunc(math.MaxInt64, func() { if !nc.readMu.tryLock() { // If the lock cannot be acquired, then there is an // active read goroutine and so we should cancel the context. nc.readCancel() return } defer nc.readMu.unlock() // Prevents future reads from reading until the deadline is reset. atomic.StoreInt64(&nc.readExpired, 1) }) if !nc.readTimer.Stop() { <-nc.readTimer.C } return nc } type netConn struct { // These must be first to be aligned on 32 bit platforms. 
// https://github.com/nhooyr/websocket/pull/438 readExpired int64 writeExpired int64 c *Conn msgType MessageType writeTimer *time.Timer writeMu *mu writeCtx context.Context writeCancel context.CancelFunc readTimer *time.Timer readMu *mu readCtx context.Context readCancel context.CancelFunc readEOFed bool reader io.Reader } var _ net.Conn = &netConn{} func (nc *netConn) Close() error { nc.writeTimer.Stop() nc.writeCancel() nc.readTimer.Stop() nc.readCancel() return nc.c.Close(StatusNormalClosure, "") } func (nc *netConn) Write(p []byte) (int, error) { nc.writeMu.forceLock() defer nc.writeMu.unlock() if atomic.LoadInt64(&nc.writeExpired) == 1 { return 0, fmt.Errorf("failed to write: %w", context.DeadlineExceeded) } err := nc.c.Write(nc.writeCtx, nc.msgType, p) if err != nil { return 0, err } return len(p), nil } func (nc *netConn) Read(p []byte) (int, error) { nc.readMu.forceLock() defer nc.readMu.unlock() for { n, err := nc.read(p) if err != nil { return n, err } if n == 0 { continue } return n, nil } } func (nc *netConn) read(p []byte) (int, error) { if atomic.LoadInt64(&nc.readExpired) == 1 { return 0, fmt.Errorf("failed to read: %w", context.DeadlineExceeded) } if nc.readEOFed { return 0, io.EOF } if nc.reader == nil { typ, r, err := nc.c.Reader(nc.readCtx) if err != nil { switch CloseStatus(err) { case StatusNormalClosure, StatusGoingAway: nc.readEOFed = true return 0, io.EOF } return 0, err } if typ != nc.msgType { err := fmt.Errorf("unexpected frame type read (expected %v): %v", nc.msgType, typ) nc.c.Close(StatusUnsupportedData, err.Error()) return 0, err } nc.reader = r } n, err := nc.reader.Read(p) if err == io.EOF { nc.reader = nil err = nil } return n, err } type websocketAddr struct { } func (a websocketAddr) Network() string { return "websocket" } func (a websocketAddr) String() string { return "websocket/unknown-addr" } func (nc *netConn) SetDeadline(t time.Time) error { nc.SetWriteDeadline(t) nc.SetReadDeadline(t) return nil } func (nc *netConn) 
SetWriteDeadline(t time.Time) error { atomic.StoreInt64(&nc.writeExpired, 0) if t.IsZero() { nc.writeTimer.Stop() } else { dur := time.Until(t) if dur <= 0 { dur = 1 } nc.writeTimer.Reset(dur) } return nil } func (nc *netConn) SetReadDeadline(t time.Time) error { atomic.StoreInt64(&nc.readExpired, 0) if t.IsZero() { nc.readTimer.Stop() } else { dur := time.Until(t) if dur <= 0 { dur = 1 } nc.readTimer.Reset(dur) } return nil } websocket-1.8.12/netconn_js.go000066400000000000000000000002531465546417300162770ustar00rootroot00000000000000package websocket import "net" func (nc *netConn) RemoteAddr() net.Addr { return websocketAddr{} } func (nc *netConn) LocalAddr() net.Addr { return websocketAddr{} } websocket-1.8.12/netconn_notjs.go000066400000000000000000000005241465546417300170210ustar00rootroot00000000000000//go:build !js // +build !js package websocket import "net" func (nc *netConn) RemoteAddr() net.Addr { if unc, ok := nc.c.rwc.(net.Conn); ok { return unc.RemoteAddr() } return websocketAddr{} } func (nc *netConn) LocalAddr() net.Addr { if unc, ok := nc.c.rwc.(net.Conn); ok { return unc.LocalAddr() } return websocketAddr{} } websocket-1.8.12/read.go000066400000000000000000000256011465546417300150560ustar00rootroot00000000000000//go:build !js // +build !js package websocket import ( "bufio" "context" "errors" "fmt" "io" "net" "strings" "time" "github.com/coder/websocket/internal/errd" "github.com/coder/websocket/internal/util" "github.com/coder/websocket/internal/xsync" ) // Reader reads from the connection until there is a WebSocket // data message to be read. It will handle ping, pong and close frames as appropriate. // // It returns the type of the message and an io.Reader to read it. // The passed context will also bound the reader. // Ensure you read to EOF otherwise the connection will hang. // // Call CloseRead if you do not expect any data messages from the peer. // // Only one Reader may be open at a time. 
// // If you need a separate timeout on the Reader call and the Read itself, // use time.AfterFunc to cancel the context passed in. // See https://github.com/nhooyr/websocket/issues/87#issue-451703332 // Most users should not need this. func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) { return c.reader(ctx) } // Read is a convenience method around Reader to read a single message // from the connection. func (c *Conn) Read(ctx context.Context) (MessageType, []byte, error) { typ, r, err := c.Reader(ctx) if err != nil { return 0, nil, err } b, err := io.ReadAll(r) return typ, b, err } // CloseRead starts a goroutine to read from the connection until it is closed // or a data message is received. // // Once CloseRead is called you cannot read any messages from the connection. // The returned context will be cancelled when the connection is closed. // // If a data message is received, the connection will be closed with StatusPolicyViolation. // // Call CloseRead when you do not expect to read any more messages. // Since it actively reads from the connection, it will ensure that ping, pong and close // frames are responded to. This means c.Ping and c.Close will still work as expected. // // This function is idempotent. func (c *Conn) CloseRead(ctx context.Context) context.Context { c.closeReadMu.Lock() ctx2 := c.closeReadCtx if ctx2 != nil { c.closeReadMu.Unlock() return ctx2 } ctx, cancel := context.WithCancel(ctx) c.closeReadCtx = ctx c.closeReadDone = make(chan struct{}) c.closeReadMu.Unlock() go func() { defer close(c.closeReadDone) defer cancel() defer c.close() _, _, err := c.Reader(ctx) if err == nil { c.Close(StatusPolicyViolation, "unexpected data message") } }() return ctx } // SetReadLimit sets the max number of bytes to read for a single message. // It applies to the Reader and Read methods. // // By default, the connection has a message read limit of 32768 bytes. 
// // When the limit is hit, the connection will be closed with StatusMessageTooBig. // // Set to -1 to disable. func (c *Conn) SetReadLimit(n int64) { if n >= 0 { // We read one more byte than the limit in case // there is a fin frame that needs to be read. n++ } c.msgReader.limitReader.limit.Store(n) } const defaultReadLimit = 32768 func newMsgReader(c *Conn) *msgReader { mr := &msgReader{ c: c, fin: true, } mr.readFunc = mr.read mr.limitReader = newLimitReader(c, mr.readFunc, defaultReadLimit+1) return mr } func (mr *msgReader) resetFlate() { if mr.flateContextTakeover() { if mr.dict == nil { mr.dict = &slidingWindow{} } mr.dict.init(32768) } if mr.flateBufio == nil { mr.flateBufio = getBufioReader(mr.readFunc) } if mr.flateContextTakeover() { mr.flateReader = getFlateReader(mr.flateBufio, mr.dict.buf) } else { mr.flateReader = getFlateReader(mr.flateBufio, nil) } mr.limitReader.r = mr.flateReader mr.flateTail.Reset(deflateMessageTail) } func (mr *msgReader) putFlateReader() { if mr.flateReader != nil { putFlateReader(mr.flateReader) mr.flateReader = nil } } func (mr *msgReader) close() { mr.c.readMu.forceLock() mr.putFlateReader() if mr.dict != nil { mr.dict.close() mr.dict = nil } if mr.flateBufio != nil { putBufioReader(mr.flateBufio) } if mr.c.client { putBufioReader(mr.c.br) mr.c.br = nil } } func (mr *msgReader) flateContextTakeover() bool { if mr.c.client { return !mr.c.copts.serverNoContextTakeover } return !mr.c.copts.clientNoContextTakeover } func (c *Conn) readRSV1Illegal(h header) bool { // If compression is disabled, rsv1 is illegal. if !c.flate() { return true } // rsv1 is only allowed on data frames beginning messages. 
if h.opcode != opText && h.opcode != opBinary { return true } return false } func (c *Conn) readLoop(ctx context.Context) (header, error) { for { h, err := c.readFrameHeader(ctx) if err != nil { return header{}, err } if h.rsv1 && c.readRSV1Illegal(h) || h.rsv2 || h.rsv3 { err := fmt.Errorf("received header with unexpected rsv bits set: %v:%v:%v", h.rsv1, h.rsv2, h.rsv3) c.writeError(StatusProtocolError, err) return header{}, err } if !c.client && !h.masked { return header{}, errors.New("received unmasked frame from client") } switch h.opcode { case opClose, opPing, opPong: err = c.handleControl(ctx, h) if err != nil { // Pass through CloseErrors when receiving a close frame. if h.opcode == opClose && CloseStatus(err) != -1 { return header{}, err } return header{}, fmt.Errorf("failed to handle control frame %v: %w", h.opcode, err) } case opContinuation, opText, opBinary: return h, nil default: err := fmt.Errorf("received unknown opcode %v", h.opcode) c.writeError(StatusProtocolError, err) return header{}, err } } } func (c *Conn) readFrameHeader(ctx context.Context) (header, error) { select { case <-c.closed: return header{}, net.ErrClosed case c.readTimeout <- ctx: } h, err := readFrameHeader(c.br, c.readHeaderBuf[:]) if err != nil { select { case <-c.closed: return header{}, net.ErrClosed case <-ctx.Done(): return header{}, ctx.Err() default: return header{}, err } } select { case <-c.closed: return header{}, net.ErrClosed case c.readTimeout <- context.Background(): } return h, nil } func (c *Conn) readFramePayload(ctx context.Context, p []byte) (int, error) { select { case <-c.closed: return 0, net.ErrClosed case c.readTimeout <- ctx: } n, err := io.ReadFull(c.br, p) if err != nil { select { case <-c.closed: return n, net.ErrClosed case <-ctx.Done(): return n, ctx.Err() default: return n, fmt.Errorf("failed to read frame payload: %w", err) } } select { case <-c.closed: return n, net.ErrClosed case c.readTimeout <- context.Background(): } return n, err } func (c 
*Conn) handleControl(ctx context.Context, h header) (err error) { if h.payloadLength < 0 || h.payloadLength > maxControlPayload { err := fmt.Errorf("received control frame payload with invalid length: %d", h.payloadLength) c.writeError(StatusProtocolError, err) return err } if !h.fin { err := errors.New("received fragmented control frame") c.writeError(StatusProtocolError, err) return err } ctx, cancel := context.WithTimeout(ctx, time.Second*5) defer cancel() b := c.readControlBuf[:h.payloadLength] _, err = c.readFramePayload(ctx, b) if err != nil { return err } if h.masked { mask(b, h.maskKey) } switch h.opcode { case opPing: return c.writeControl(ctx, opPong, b) case opPong: c.activePingsMu.Lock() pong, ok := c.activePings[string(b)] c.activePingsMu.Unlock() if ok { select { case pong <- struct{}{}: default: } } return nil } // opClose ce, err := parseClosePayload(b) if err != nil { err = fmt.Errorf("received invalid close payload: %w", err) c.writeError(StatusProtocolError, err) return err } err = fmt.Errorf("received close frame: %w", ce) c.writeClose(ce.Code, ce.Reason) c.readMu.unlock() c.close() return err } func (c *Conn) reader(ctx context.Context) (_ MessageType, _ io.Reader, err error) { defer errd.Wrap(&err, "failed to get reader") err = c.readMu.lock(ctx) if err != nil { return 0, nil, err } defer c.readMu.unlock() if !c.msgReader.fin { return 0, nil, errors.New("previous message not read to completion") } h, err := c.readLoop(ctx) if err != nil { return 0, nil, err } if h.opcode == opContinuation { err := errors.New("received continuation frame without text or binary frame") c.writeError(StatusProtocolError, err) return 0, nil, err } c.msgReader.reset(ctx, h) return MessageType(h.opcode), c.msgReader, nil } type msgReader struct { c *Conn ctx context.Context flate bool flateReader io.Reader flateBufio *bufio.Reader flateTail strings.Reader limitReader *limitReader dict *slidingWindow fin bool payloadLength int64 maskKey uint32 // 
util.ReaderFunc(mr.Read) to avoid continuous allocations. readFunc util.ReaderFunc } func (mr *msgReader) reset(ctx context.Context, h header) { mr.ctx = ctx mr.flate = h.rsv1 mr.limitReader.reset(mr.readFunc) if mr.flate { mr.resetFlate() } mr.setFrame(h) } func (mr *msgReader) setFrame(h header) { mr.fin = h.fin mr.payloadLength = h.payloadLength mr.maskKey = h.maskKey } func (mr *msgReader) Read(p []byte) (n int, err error) { err = mr.c.readMu.lock(mr.ctx) if err != nil { return 0, fmt.Errorf("failed to read: %w", err) } defer mr.c.readMu.unlock() n, err = mr.limitReader.Read(p) if mr.flate && mr.flateContextTakeover() { p = p[:n] mr.dict.write(p) } if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) && mr.fin && mr.flate { mr.putFlateReader() return n, io.EOF } if err != nil { return n, fmt.Errorf("failed to read: %w", err) } return n, nil } func (mr *msgReader) read(p []byte) (int, error) { for { if mr.payloadLength == 0 { if mr.fin { if mr.flate { return mr.flateTail.Read(p) } return 0, io.EOF } h, err := mr.c.readLoop(mr.ctx) if err != nil { return 0, err } if h.opcode != opContinuation { err := errors.New("received new data message without finishing the previous message") mr.c.writeError(StatusProtocolError, err) return 0, err } mr.setFrame(h) continue } if int64(len(p)) > mr.payloadLength { p = p[:mr.payloadLength] } n, err := mr.c.readFramePayload(mr.ctx, p) if err != nil { return n, err } mr.payloadLength -= int64(n) if !mr.c.client { mr.maskKey = mask(p, mr.maskKey) } return n, nil } } type limitReader struct { c *Conn r io.Reader limit xsync.Int64 n int64 } func newLimitReader(c *Conn, r io.Reader, limit int64) *limitReader { lr := &limitReader{ c: c, } lr.limit.Store(limit) lr.reset(r) return lr } func (lr *limitReader) reset(r io.Reader) { lr.n = lr.limit.Load() lr.r = r } func (lr *limitReader) Read(p []byte) (int, error) { if lr.n < 0 { return lr.r.Read(p) } if lr.n == 0 { err := fmt.Errorf("read limited at %v bytes", lr.limit.Load()) 
lr.c.writeError(StatusMessageTooBig, err) return 0, err } if int64(len(p)) > lr.n { p = p[:lr.n] } n, err := lr.r.Read(p) lr.n -= int64(n) if lr.n < 0 { lr.n = 0 } return n, err } websocket-1.8.12/stringer.go000066400000000000000000000055771465546417300160120ustar00rootroot00000000000000// Code generated by "stringer -type=opcode,MessageType,StatusCode -output=stringer.go"; DO NOT EDIT. package websocket import "strconv" func _() { // An "invalid array index" compiler error signifies that the constant values have changed. // Re-run the stringer command to generate them again. var x [1]struct{} _ = x[opContinuation-0] _ = x[opText-1] _ = x[opBinary-2] _ = x[opClose-8] _ = x[opPing-9] _ = x[opPong-10] } const ( _opcode_name_0 = "opContinuationopTextopBinary" _opcode_name_1 = "opCloseopPingopPong" ) var ( _opcode_index_0 = [...]uint8{0, 14, 20, 28} _opcode_index_1 = [...]uint8{0, 7, 13, 19} ) func (i opcode) String() string { switch { case 0 <= i && i <= 2: return _opcode_name_0[_opcode_index_0[i]:_opcode_index_0[i+1]] case 8 <= i && i <= 10: i -= 8 return _opcode_name_1[_opcode_index_1[i]:_opcode_index_1[i+1]] default: return "opcode(" + strconv.FormatInt(int64(i), 10) + ")" } } func _() { // An "invalid array index" compiler error signifies that the constant values have changed. // Re-run the stringer command to generate them again. var x [1]struct{} _ = x[MessageText-1] _ = x[MessageBinary-2] } const _MessageType_name = "MessageTextMessageBinary" var _MessageType_index = [...]uint8{0, 11, 24} func (i MessageType) String() string { i -= 1 if i < 0 || i >= MessageType(len(_MessageType_index)-1) { return "MessageType(" + strconv.FormatInt(int64(i+1), 10) + ")" } return _MessageType_name[_MessageType_index[i]:_MessageType_index[i+1]] } func _() { // An "invalid array index" compiler error signifies that the constant values have changed. // Re-run the stringer command to generate them again. 
var x [1]struct{} _ = x[StatusNormalClosure-1000] _ = x[StatusGoingAway-1001] _ = x[StatusProtocolError-1002] _ = x[StatusUnsupportedData-1003] _ = x[statusReserved-1004] _ = x[StatusNoStatusRcvd-1005] _ = x[StatusAbnormalClosure-1006] _ = x[StatusInvalidFramePayloadData-1007] _ = x[StatusPolicyViolation-1008] _ = x[StatusMessageTooBig-1009] _ = x[StatusMandatoryExtension-1010] _ = x[StatusInternalError-1011] _ = x[StatusServiceRestart-1012] _ = x[StatusTryAgainLater-1013] _ = x[StatusBadGateway-1014] _ = x[StatusTLSHandshake-1015] } const _StatusCode_name = "StatusNormalClosureStatusGoingAwayStatusProtocolErrorStatusUnsupportedDatastatusReservedStatusNoStatusRcvdStatusAbnormalClosureStatusInvalidFramePayloadDataStatusPolicyViolationStatusMessageTooBigStatusMandatoryExtensionStatusInternalErrorStatusServiceRestartStatusTryAgainLaterStatusBadGatewayStatusTLSHandshake" var _StatusCode_index = [...]uint16{0, 19, 34, 53, 74, 88, 106, 127, 156, 177, 196, 220, 239, 259, 278, 294, 312} func (i StatusCode) String() string { i -= 1000 if i < 0 || i >= StatusCode(len(_StatusCode_index)-1) { return "StatusCode(" + strconv.FormatInt(int64(i+1000), 10) + ")" } return _StatusCode_name[_StatusCode_index[i]:_StatusCode_index[i+1]] } websocket-1.8.12/write.go000066400000000000000000000166751465546417300153100ustar00rootroot00000000000000//go:build !js // +build !js package websocket import ( "bufio" "context" "crypto/rand" "encoding/binary" "errors" "fmt" "io" "net" "time" "compress/flate" "github.com/coder/websocket/internal/errd" "github.com/coder/websocket/internal/util" ) // Writer returns a writer bounded by the context that will write // a WebSocket message of type dataType to the connection. // // You must close the writer once you have written the entire message. // // Only one writer can be open at a time, multiple calls will block until the previous writer // is closed. 
func (c *Conn) Writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) { w, err := c.writer(ctx, typ) if err != nil { return nil, fmt.Errorf("failed to get writer: %w", err) } return w, nil } // Write writes a message to the connection. // // See the Writer method if you want to stream a message. // // If compression is disabled or the compression threshold is not met, then it // will write the message in a single frame. func (c *Conn) Write(ctx context.Context, typ MessageType, p []byte) error { _, err := c.write(ctx, typ, p) if err != nil { return fmt.Errorf("failed to write msg: %w", err) } return nil } type msgWriter struct { c *Conn mu *mu writeMu *mu closed bool ctx context.Context opcode opcode flate bool trimWriter *trimLastFourBytesWriter flateWriter *flate.Writer } func newMsgWriter(c *Conn) *msgWriter { mw := &msgWriter{ c: c, mu: newMu(c), writeMu: newMu(c), } return mw } func (mw *msgWriter) ensureFlate() { if mw.trimWriter == nil { mw.trimWriter = &trimLastFourBytesWriter{ w: util.WriterFunc(mw.write), } } if mw.flateWriter == nil { mw.flateWriter = getFlateWriter(mw.trimWriter) } mw.flate = true } func (mw *msgWriter) flateContextTakeover() bool { if mw.c.client { return !mw.c.copts.clientNoContextTakeover } return !mw.c.copts.serverNoContextTakeover } func (c *Conn) writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) { err := c.msgWriter.reset(ctx, typ) if err != nil { return nil, err } return c.msgWriter, nil } func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) (int, error) { mw, err := c.writer(ctx, typ) if err != nil { return 0, err } if !c.flate() { defer c.msgWriter.mu.unlock() return c.writeFrame(ctx, true, false, c.msgWriter.opcode, p) } n, err := mw.Write(p) if err != nil { return n, err } err = mw.Close() return n, err } func (mw *msgWriter) reset(ctx context.Context, typ MessageType) error { err := mw.mu.lock(ctx) if err != nil { return err } mw.ctx = ctx mw.opcode = opcode(typ) mw.flate = 
false mw.closed = false mw.trimWriter.reset() return nil } func (mw *msgWriter) putFlateWriter() { if mw.flateWriter != nil { putFlateWriter(mw.flateWriter) mw.flateWriter = nil } } // Write writes the given bytes to the WebSocket connection. func (mw *msgWriter) Write(p []byte) (_ int, err error) { err = mw.writeMu.lock(mw.ctx) if err != nil { return 0, fmt.Errorf("failed to write: %w", err) } defer mw.writeMu.unlock() if mw.closed { return 0, errors.New("cannot use closed writer") } defer func() { if err != nil { err = fmt.Errorf("failed to write: %w", err) } }() if mw.c.flate() { // Only enables flate if the length crosses the // threshold on the first frame if mw.opcode != opContinuation && len(p) >= mw.c.flateThreshold { mw.ensureFlate() } } if mw.flate { return mw.flateWriter.Write(p) } return mw.write(p) } func (mw *msgWriter) write(p []byte) (int, error) { n, err := mw.c.writeFrame(mw.ctx, false, mw.flate, mw.opcode, p) if err != nil { return n, fmt.Errorf("failed to write data frame: %w", err) } mw.opcode = opContinuation return n, nil } // Close flushes the frame to the connection. 
func (mw *msgWriter) Close() (err error) {
	defer errd.Wrap(&err, "failed to close writer")

	// writeMu serializes Close with concurrent Write calls on this writer.
	err = mw.writeMu.lock(mw.ctx)
	if err != nil {
		return err
	}
	defer mw.writeMu.unlock()

	if mw.closed {
		return errors.New("writer already closed")
	}
	// Marked closed before any I/O, so a Close that fails below cannot be
	// retried; a second call reports "writer already closed".
	mw.closed = true

	if mw.flate {
		// Emit whatever the flate writer still buffers (it writes through
		// trimWriter into mw.write) before the final frame.
		err = mw.flateWriter.Flush()
		if err != nil {
			return fmt.Errorf("failed to flush flate: %w", err)
		}
	}

	// An empty frame with FIN set terminates the message.
	_, err = mw.c.writeFrame(mw.ctx, true, mw.flate, mw.opcode, nil)
	if err != nil {
		return fmt.Errorf("failed to write fin frame: %w", err)
	}

	if mw.flate && !mw.flateContextTakeover() {
		// Without context takeover, drop the flate writer so that the next
		// compressed message starts from a newly acquired one.
		mw.putFlateWriter()
	}
	// Release the per-message lock taken in reset.
	// NOTE(review): the error returns above leave mu held. This is presumably
	// deliberate, so that no new message starts after a partially written one;
	// confirm.
	mw.mu.unlock()
	return nil
}

// close hands back the writer's pooled resources: the client's bufio.Writer
// via putBufioWriter and any held flate writer via putFlateWriter. It takes
// writeMu (and, for clients, writeFrameMu) with forceLock and never releases
// them.
// NOTE(review): keeping those locks held presumably stops any later writer
// from touching the released resources; confirm against mu.forceLock.
func (mw *msgWriter) close() {
	if mw.c.client {
		mw.c.writeFrameMu.forceLock()
		putBufioWriter(mw.c.bw)
	}

	mw.writeMu.forceLock()
	mw.putFlateWriter()
}

// writeControl writes one frame with the given control opcode and payload p.
// The frame always has FIN set and is never marked compressed. The write is
// bounded by ctx and by a 5 second timeout.
func (c *Conn) writeControl(ctx context.Context, opcode opcode, p []byte) error {
	ctx, cancel := context.WithTimeout(ctx, time.Second*5)
	defer cancel()

	_, err := c.writeFrame(ctx, true, false, opcode, p)
	if err != nil {
		return fmt.Errorf("failed to write control frame %v: %w", opcode, err)
	}
	return nil
}

// writeFrame handles all writes to the connection.
func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opcode, p []byte) (_ int, err error) { err = c.writeFrameMu.lock(ctx) if err != nil { return 0, err } defer c.writeFrameMu.unlock() select { case <-c.closed: return 0, net.ErrClosed case c.writeTimeout <- ctx: } defer func() { if err != nil { select { case <-c.closed: err = net.ErrClosed case <-ctx.Done(): err = ctx.Err() default: } err = fmt.Errorf("failed to write frame: %w", err) } }() c.writeHeader.fin = fin c.writeHeader.opcode = opcode c.writeHeader.payloadLength = int64(len(p)) if c.client { c.writeHeader.masked = true _, err = io.ReadFull(rand.Reader, c.writeHeaderBuf[:4]) if err != nil { return 0, fmt.Errorf("failed to generate masking key: %w", err) } c.writeHeader.maskKey = binary.LittleEndian.Uint32(c.writeHeaderBuf[:]) } c.writeHeader.rsv1 = false if flate && (opcode == opText || opcode == opBinary) { c.writeHeader.rsv1 = true } err = writeFrameHeader(c.writeHeader, c.bw, c.writeHeaderBuf[:]) if err != nil { return 0, err } n, err := c.writeFramePayload(p) if err != nil { return n, err } if c.writeHeader.fin { err = c.bw.Flush() if err != nil { return n, fmt.Errorf("failed to flush: %w", err) } } select { case <-c.closed: if opcode == opClose { return n, nil } return n, net.ErrClosed case c.writeTimeout <- context.Background(): } return n, nil } func (c *Conn) writeFramePayload(p []byte) (n int, err error) { defer errd.Wrap(&err, "failed to write frame payload") if !c.writeHeader.masked { return c.bw.Write(p) } maskKey := c.writeHeader.maskKey for len(p) > 0 { // If the buffer is full, we need to flush. if c.bw.Available() == 0 { err = c.bw.Flush() if err != nil { return n, err } } // Start of next write in the buffer. 
i := c.bw.Buffered() j := len(p) if j > c.bw.Available() { j = c.bw.Available() } _, err := c.bw.Write(p[:j]) if err != nil { return n, err } maskKey = mask(c.writeBuf[i:c.bw.Buffered()], maskKey) p = p[j:] n += j } return n, nil } // extractBufioWriterBuf grabs the []byte backing a *bufio.Writer // and returns it. func extractBufioWriterBuf(bw *bufio.Writer, w io.Writer) []byte { var writeBuf []byte bw.Reset(util.WriterFunc(func(p2 []byte) (int, error) { writeBuf = p2[:cap(p2)] return len(p2), nil })) bw.WriteByte(0) bw.Flush() bw.Reset(w) return writeBuf } func (c *Conn) writeError(code StatusCode, err error) { c.writeClose(code, err.Error()) } websocket-1.8.12/ws_js.go000066400000000000000000000352611465546417300152730ustar00rootroot00000000000000package websocket // import "github.com/coder/websocket" import ( "bytes" "context" "errors" "fmt" "io" "net" "net/http" "reflect" "runtime" "strings" "sync" "syscall/js" "github.com/coder/websocket/internal/bpool" "github.com/coder/websocket/internal/wsjs" "github.com/coder/websocket/internal/xsync" ) // opcode represents a WebSocket opcode. type opcode int // https://tools.ietf.org/html/rfc6455#section-11.8. const ( opContinuation opcode = iota opText opBinary // 3 - 7 are reserved for further non-control frames. _ _ _ _ _ opClose opPing opPong // 11-16 are reserved for further control frames. ) // Conn provides a wrapper around the browser WebSocket API. type Conn struct { noCopy noCopy ws wsjs.WebSocket // read limit for a message in bytes. 
msgReadLimit xsync.Int64 closeReadMu sync.Mutex closeReadCtx context.Context closingMu sync.Mutex closeOnce sync.Once closed chan struct{} closeErrOnce sync.Once closeErr error closeWasClean bool releaseOnClose func() releaseOnError func() releaseOnMessage func() readSignal chan struct{} readBufMu sync.Mutex readBuf []wsjs.MessageEvent } func (c *Conn) close(err error, wasClean bool) { c.closeOnce.Do(func() { runtime.SetFinalizer(c, nil) if !wasClean { err = fmt.Errorf("unclean connection close: %w", err) } c.setCloseErr(err) c.closeWasClean = wasClean close(c.closed) }) } func (c *Conn) init() { c.closed = make(chan struct{}) c.readSignal = make(chan struct{}, 1) c.msgReadLimit.Store(32768) c.releaseOnClose = c.ws.OnClose(func(e wsjs.CloseEvent) { err := CloseError{ Code: StatusCode(e.Code), Reason: e.Reason, } // We do not know if we sent or received this close as // its possible the browser triggered it without us // explicitly sending it. c.close(err, e.WasClean) c.releaseOnClose() c.releaseOnError() c.releaseOnMessage() }) c.releaseOnError = c.ws.OnError(func(v js.Value) { c.setCloseErr(errors.New(v.Get("message").String())) c.closeWithInternal() }) c.releaseOnMessage = c.ws.OnMessage(func(e wsjs.MessageEvent) { c.readBufMu.Lock() defer c.readBufMu.Unlock() c.readBuf = append(c.readBuf, e) // Lets the read goroutine know there is definitely something in readBuf. select { case c.readSignal <- struct{}{}: default: } }) runtime.SetFinalizer(c, func(c *Conn) { c.setCloseErr(errors.New("connection garbage collected")) c.closeWithInternal() }) } func (c *Conn) closeWithInternal() { c.Close(StatusInternalError, "something went wrong") } // Read attempts to read a message from the connection. // The maximum time spent waiting is bounded by the context. 
func (c *Conn) Read(ctx context.Context) (MessageType, []byte, error) { c.closeReadMu.Lock() closedRead := c.closeReadCtx != nil c.closeReadMu.Unlock() if closedRead { return 0, nil, errors.New("WebSocket connection read closed") } typ, p, err := c.read(ctx) if err != nil { return 0, nil, fmt.Errorf("failed to read: %w", err) } readLimit := c.msgReadLimit.Load() if readLimit >= 0 && int64(len(p)) > readLimit { err := fmt.Errorf("read limited at %v bytes", c.msgReadLimit.Load()) c.Close(StatusMessageTooBig, err.Error()) return 0, nil, err } return typ, p, nil } func (c *Conn) read(ctx context.Context) (MessageType, []byte, error) { select { case <-ctx.Done(): c.Close(StatusPolicyViolation, "read timed out") return 0, nil, ctx.Err() case <-c.readSignal: case <-c.closed: return 0, nil, net.ErrClosed } c.readBufMu.Lock() defer c.readBufMu.Unlock() me := c.readBuf[0] // We copy the messages forward and decrease the size // of the slice to avoid reallocating. copy(c.readBuf, c.readBuf[1:]) c.readBuf = c.readBuf[:len(c.readBuf)-1] if len(c.readBuf) > 0 { // Next time we read, we'll grab the message. select { case c.readSignal <- struct{}{}: default: } } switch p := me.Data.(type) { case string: return MessageText, []byte(p), nil case []byte: return MessageBinary, p, nil default: panic("websocket: unexpected data type from wsjs OnMessage: " + reflect.TypeOf(me.Data).String()) } } // Ping is mocked out for Wasm. func (c *Conn) Ping(ctx context.Context) error { return nil } // Write writes a message of the given type to the connection. // Always non blocking. func (c *Conn) Write(ctx context.Context, typ MessageType, p []byte) error { err := c.write(ctx, typ, p) if err != nil { // Have to ensure the WebSocket is closed after a write error // to match the Go API. It can only error if the message type // is unexpected or the passed bytes contain invalid UTF-8 for // MessageText. 
err := fmt.Errorf("failed to write: %w", err) c.setCloseErr(err) c.closeWithInternal() return err } return nil } func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) error { if c.isClosed() { return net.ErrClosed } switch typ { case MessageBinary: return c.ws.SendBytes(p) case MessageText: return c.ws.SendText(string(p)) default: return fmt.Errorf("unexpected message type: %v", typ) } } // Close closes the WebSocket with the given code and reason. // It will wait until the peer responds with a close frame // or the connection is closed. // It thus performs the full WebSocket close handshake. func (c *Conn) Close(code StatusCode, reason string) error { err := c.exportedClose(code, reason) if err != nil { return fmt.Errorf("failed to close WebSocket: %w", err) } return nil } // CloseNow closes the WebSocket connection without attempting a close handshake. // Use when you do not want the overhead of the close handshake. // // note: No different from Close(StatusGoingAway, "") in WASM as there is no way to close // a WebSocket without the close handshake. func (c *Conn) CloseNow() error { return c.Close(StatusGoingAway, "") } func (c *Conn) exportedClose(code StatusCode, reason string) error { c.closingMu.Lock() defer c.closingMu.Unlock() if c.isClosed() { return net.ErrClosed } ce := fmt.Errorf("sent close: %w", CloseError{ Code: code, Reason: reason, }) c.setCloseErr(ce) err := c.ws.Close(int(code), reason) if err != nil { return err } <-c.closed if !c.closeWasClean { return c.closeErr } return nil } // Subprotocol returns the negotiated subprotocol. // An empty string means the default protocol. func (c *Conn) Subprotocol() string { return c.ws.Subprotocol() } // DialOptions represents the options available to pass to Dial. type DialOptions struct { // Subprotocols lists the subprotocols to negotiate with the server. Subprotocols []string } // Dial creates a new WebSocket connection to the given url with the given options. 
// The passed context bounds the maximum time spent waiting for the connection to open. // The returned *http.Response is always nil or a mock. It's only in the signature // to match the core API. func Dial(ctx context.Context, url string, opts *DialOptions) (*Conn, *http.Response, error) { c, resp, err := dial(ctx, url, opts) if err != nil { return nil, nil, fmt.Errorf("failed to WebSocket dial %q: %w", url, err) } return c, resp, nil } func dial(ctx context.Context, url string, opts *DialOptions) (*Conn, *http.Response, error) { if opts == nil { opts = &DialOptions{} } url = strings.Replace(url, "http://", "ws://", 1) url = strings.Replace(url, "https://", "wss://", 1) ws, err := wsjs.New(url, opts.Subprotocols) if err != nil { return nil, nil, err } c := &Conn{ ws: ws, } c.init() opench := make(chan struct{}) releaseOpen := ws.OnOpen(func(e js.Value) { close(opench) }) defer releaseOpen() select { case <-ctx.Done(): c.Close(StatusPolicyViolation, "dial timed out") return nil, nil, ctx.Err() case <-opench: return c, &http.Response{ StatusCode: http.StatusSwitchingProtocols, }, nil case <-c.closed: return nil, nil, net.ErrClosed } } // Reader attempts to read a message from the connection. // The maximum time spent waiting is bounded by the context. func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) { typ, p, err := c.Read(ctx) if err != nil { return 0, nil, err } return typ, bytes.NewReader(p), nil } // Writer returns a writer to write a WebSocket data message to the connection. // It buffers the entire message in memory and then sends it when the writer // is closed. 
func (c *Conn) Writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) { return &writer{ c: c, ctx: ctx, typ: typ, b: bpool.Get(), }, nil } type writer struct { closed bool c *Conn ctx context.Context typ MessageType b *bytes.Buffer } func (w *writer) Write(p []byte) (int, error) { if w.closed { return 0, errors.New("cannot write to closed writer") } n, err := w.b.Write(p) if err != nil { return n, fmt.Errorf("failed to write message: %w", err) } return n, nil } func (w *writer) Close() error { if w.closed { return errors.New("cannot close closed writer") } w.closed = true defer bpool.Put(w.b) err := w.c.Write(w.ctx, w.typ, w.b.Bytes()) if err != nil { return fmt.Errorf("failed to close writer: %w", err) } return nil } // CloseRead implements *Conn.CloseRead for wasm. func (c *Conn) CloseRead(ctx context.Context) context.Context { c.closeReadMu.Lock() ctx2 := c.closeReadCtx if ctx2 != nil { c.closeReadMu.Unlock() return ctx2 } ctx, cancel := context.WithCancel(ctx) c.closeReadCtx = ctx c.closeReadMu.Unlock() go func() { defer cancel() defer c.CloseNow() _, _, err := c.read(ctx) if err != nil { c.Close(StatusPolicyViolation, "unexpected data message") } }() return ctx } // SetReadLimit implements *Conn.SetReadLimit for wasm. func (c *Conn) SetReadLimit(n int64) { c.msgReadLimit.Store(n) } func (c *Conn) setCloseErr(err error) { c.closeErrOnce.Do(func() { c.closeErr = fmt.Errorf("WebSocket closed: %w", err) }) } func (c *Conn) isClosed() bool { select { case <-c.closed: return true default: return false } } // AcceptOptions represents Accept's options. type AcceptOptions struct { Subprotocols []string InsecureSkipVerify bool OriginPatterns []string CompressionMode CompressionMode CompressionThreshold int } // Accept is stubbed out for Wasm. func Accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn, error) { return nil, errors.New("unimplemented") } // StatusCode represents a WebSocket status code. 
// https://tools.ietf.org/html/rfc6455#section-7.4
type StatusCode int

// https://www.iana.org/assignments/websocket/websocket.xhtml#close-code-number
//
// These are only the status codes defined by the protocol.
//
// You can define custom codes in the 3000-4999 range.
// The 3000-3999 range is reserved for use by libraries, frameworks and applications.
// The 4000-4999 range is reserved for private use.
//
// The values below are consecutive from 1000 to 1015. The generated
// stringer.go has a compile-time check on them, so re-run stringer if they
// change.
const (
	StatusNormalClosure   StatusCode = 1000
	StatusGoingAway       StatusCode = 1001
	StatusProtocolError   StatusCode = 1002
	StatusUnsupportedData StatusCode = 1003

	// 1004 is reserved and so unexported.
	statusReserved StatusCode = 1004

	// StatusNoStatusRcvd cannot be sent in a close message.
	// It is reserved for when a close message is received without
	// a status code.
	StatusNoStatusRcvd StatusCode = 1005

	// StatusAbnormalClosure is exported for use only with Wasm.
	// In non Wasm Go, the returned error will indicate whether the
	// connection was closed abnormally.
	StatusAbnormalClosure StatusCode = 1006

	StatusInvalidFramePayloadData StatusCode = 1007
	StatusPolicyViolation         StatusCode = 1008
	StatusMessageTooBig           StatusCode = 1009
	StatusMandatoryExtension      StatusCode = 1010
	StatusInternalError           StatusCode = 1011
	StatusServiceRestart          StatusCode = 1012
	StatusTryAgainLater           StatusCode = 1013
	StatusBadGateway              StatusCode = 1014

	// StatusTLSHandshake is only exported for use with Wasm.
	// In non Wasm Go, the returned error will indicate whether there was
	// a TLS handshake failure.
	StatusTLSHandshake StatusCode = 1015
)

// CloseError is returned when the connection is closed with a status and reason.
//
// Use Go 1.13's errors.As to check for this error.
// Also see the CloseStatus helper.
type CloseError struct {
	Code   StatusCode
	Reason string
}

// Error formats the status code with %v and the reason as a quoted string.
func (ce CloseError) Error() string {
	return fmt.Sprintf("status = %v and reason = %q", ce.Code, ce.Reason)
}

// CloseStatus is a convenience wrapper around Go 1.13's errors.As to grab
// the status code from a CloseError.
// // -1 will be returned if the passed error is nil or not a CloseError. func CloseStatus(err error) StatusCode { var ce CloseError if errors.As(err, &ce) { return ce.Code } return -1 } // CompressionMode represents the modes available to the deflate extension. // See https://tools.ietf.org/html/rfc7692 // Works in all browsers except Safari which does not implement the deflate extension. type CompressionMode int const ( // CompressionNoContextTakeover grabs a new flate.Reader and flate.Writer as needed // for every message. This applies to both server and client side. // // This means less efficient compression as the sliding window from previous messages // will not be used but the memory overhead will be lower if the connections // are long lived and seldom used. // // The message will only be compressed if greater than 512 bytes. CompressionNoContextTakeover CompressionMode = iota // CompressionContextTakeover uses a flate.Reader and flate.Writer per connection. // This enables reusing the sliding window from previous messages. // As most WebSocket protocols are repetitive, this can be very efficient. // It carries an overhead of 8 kB for every connection compared to CompressionNoContextTakeover. // // If the peer negotiates NoContextTakeover on the client or server side, it will be // used instead as this is required by the RFC. CompressionContextTakeover // CompressionDisabled disables the deflate extension. // // Use this if you are using a predominantly binary protocol with very // little duplication in between messages or CPU and memory are more // important than bandwidth. CompressionDisabled ) // MessageType represents the type of a WebSocket message. // See https://tools.ietf.org/html/rfc6455#section-5.6 type MessageType int // MessageType constants. const ( // MessageText is for UTF-8 encoded text messages like JSON. MessageText MessageType = iota + 1 // MessageBinary is for binary messages like protobufs. 
MessageBinary ) type mu struct { c *Conn ch chan struct{} } func newMu(c *Conn) *mu { return &mu{ c: c, ch: make(chan struct{}, 1), } } func (m *mu) forceLock() { m.ch <- struct{}{} } func (m *mu) tryLock() bool { select { case m.ch <- struct{}{}: return true default: return false } } func (m *mu) unlock() { select { case <-m.ch: default: } } type noCopy struct{} func (*noCopy) Lock() {} websocket-1.8.12/ws_js_test.go000066400000000000000000000024651465546417300163320ustar00rootroot00000000000000package websocket_test import ( "context" "net/http" "os" "testing" "time" "github.com/coder/websocket" "github.com/coder/websocket/internal/test/assert" "github.com/coder/websocket/internal/test/wstest" ) func TestWasm(t *testing.T) { t.Parallel() ctx, cancel := context.WithTimeout(context.Background(), time.Minute) defer cancel() c, resp, err := websocket.Dial(ctx, os.Getenv("WS_ECHO_SERVER_URL"), &websocket.DialOptions{ Subprotocols: []string{"echo"}, }) assert.Success(t, err) defer c.Close(websocket.StatusInternalError, "") assert.Equal(t, "subprotocol", "echo", c.Subprotocol()) assert.Equal(t, "response code", http.StatusSwitchingProtocols, resp.StatusCode) c.SetReadLimit(65536) for i := 0; i < 10; i++ { err = wstest.Echo(ctx, c, 65536) assert.Success(t, err) } err = c.Close(websocket.StatusNormalClosure, "") assert.Success(t, err) } func TestWasmDialTimeout(t *testing.T) { t.Parallel() ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond) defer cancel() beforeDial := time.Now() _, _, err := websocket.Dial(ctx, "ws://example.com:9893", &websocket.DialOptions{ Subprotocols: []string{"echo"}, }) assert.Error(t, err) if time.Since(beforeDial) >= time.Second { t.Fatal("wasm context dial timeout is not working", time.Since(beforeDial)) } } 
websocket-1.8.12/wsjson/000077500000000000000000000000001465546417300151335ustar00rootroot00000000000000websocket-1.8.12/wsjson/wsjson.go000066400000000000000000000034421465546417300170100ustar00rootroot00000000000000// Package wsjson provides helpers for reading and writing JSON messages. package wsjson // import "github.com/coder/websocket/wsjson" import ( "context" "encoding/json" "fmt" "github.com/coder/websocket" "github.com/coder/websocket/internal/bpool" "github.com/coder/websocket/internal/errd" "github.com/coder/websocket/internal/util" ) // Read reads a JSON message from c into v. // It will reuse buffers in between calls to avoid allocations. func Read(ctx context.Context, c *websocket.Conn, v interface{}) error { return read(ctx, c, v) } func read(ctx context.Context, c *websocket.Conn, v interface{}) (err error) { defer errd.Wrap(&err, "failed to read JSON message") _, r, err := c.Reader(ctx) if err != nil { return err } b := bpool.Get() defer bpool.Put(b) _, err = b.ReadFrom(r) if err != nil { return err } err = json.Unmarshal(b.Bytes(), v) if err != nil { c.Close(websocket.StatusInvalidFramePayloadData, "failed to unmarshal JSON") return fmt.Errorf("failed to unmarshal JSON: %w", err) } return nil } // Write writes the JSON message v to c. // It will reuse buffers in between calls to avoid allocations. func Write(ctx context.Context, c *websocket.Conn, v interface{}) error { return write(ctx, c, v) } func write(ctx context.Context, c *websocket.Conn, v interface{}) (err error) { defer errd.Wrap(&err, "failed to write JSON message") // json.Marshal cannot reuse buffers between calls as it has to return // a copy of the byte slice but Encoder does as it directly writes to w. 
err = json.NewEncoder(util.WriterFunc(func(p []byte) (int, error) { err := c.Write(ctx, websocket.MessageText, p) if err != nil { return 0, err } return len(p), nil })).Encode(v) if err != nil { return fmt.Errorf("failed to marshal JSON: %w", err) } return nil } websocket-1.8.12/wsjson/wsjson_test.go000066400000000000000000000015621465546417300200500ustar00rootroot00000000000000package wsjson_test import ( "encoding/json" "io" "strconv" "testing" "github.com/coder/websocket/internal/test/xrand" ) func BenchmarkJSON(b *testing.B) { sizes := []int{ 8, 16, 32, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, } b.Run("json.Encoder", func(b *testing.B) { for _, size := range sizes { b.Run(strconv.Itoa(size), func(b *testing.B) { msg := xrand.String(size) b.SetBytes(int64(size)) b.ReportAllocs() b.ResetTimer() for i := 0; i < b.N; i++ { json.NewEncoder(io.Discard).Encode(msg) } }) } }) b.Run("json.Marshal", func(b *testing.B) { for _, size := range sizes { b.Run(strconv.Itoa(size), func(b *testing.B) { msg := xrand.String(size) b.SetBytes(int64(size)) b.ReportAllocs() b.ResetTimer() for i := 0; i < b.N; i++ { json.Marshal(msg) } }) } }) }