pax_global_header00006660000000000000000000000064147621245030014516gustar00rootroot0000000000000052 comment=37c9bc0781f0cc5af7c729947ef1833c1e12b70d websockets-15.0.1/000077500000000000000000000000001476212450300137535ustar00rootroot00000000000000websockets-15.0.1/.github/000077500000000000000000000000001476212450300153135ustar00rootroot00000000000000websockets-15.0.1/.github/FUNDING.yml000066400000000000000000000001201476212450300171210ustar00rootroot00000000000000github: python-websockets open_collective: websockets tidelift: pypi/websockets websockets-15.0.1/.github/ISSUE_TEMPLATE/000077500000000000000000000000001476212450300174765ustar00rootroot00000000000000websockets-15.0.1/.github/ISSUE_TEMPLATE/config.yml000066400000000000000000000000341476212450300214630ustar00rootroot00000000000000blank_issues_enabled: false websockets-15.0.1/.github/ISSUE_TEMPLATE/issue.md000066400000000000000000000015721476212450300211550ustar00rootroot00000000000000--- name: Report an issue about: Let us know about a problem with websockets title: '' labels: '' assignees: '' --- websockets-15.0.1/.github/dependabot.yml000066400000000000000000000002771476212450300201510ustar00rootroot00000000000000version: 2 updates: - package-ecosystem: "github-actions" directory: "/" schedule: interval: "weekly" day: "saturday" time: "07:00" timezone: "Europe/Paris" websockets-15.0.1/.github/workflows/000077500000000000000000000000001476212450300173505ustar00rootroot00000000000000websockets-15.0.1/.github/workflows/release.yml000066400000000000000000000046751476212450300215270ustar00rootroot00000000000000name: Make release on: push: tags: - '*' workflow_dispatch: jobs: sdist: name: Build source distribution and architecture-independent wheel runs-on: ubuntu-latest steps: - name: Check out repository uses: actions/checkout@v4 - name: Install Python 3.x uses: actions/setup-python@v5 with: python-version: 3.x - name: Install build run: pip install build - name: Build sdist & wheel run: python -m build 
env: BUILD_EXTENSION: no - name: Save sdist & wheel uses: actions/upload-artifact@v4 with: name: dist-architecture-independent path: | dist/*.tar.gz dist/*.whl wheels: name: Build architecture-specific wheels on ${{ matrix.os }} runs-on: ${{ matrix.os }} strategy: matrix: os: - ubuntu-latest - windows-latest - macOS-latest steps: - name: Check out repository uses: actions/checkout@v4 - name: Install Python 3.x uses: actions/setup-python@v5 with: python-version: 3.x - name: Set up QEMU if: runner.os == 'Linux' uses: docker/setup-qemu-action@v3 with: platforms: all - name: Build wheels uses: pypa/cibuildwheel@v2.22.0 env: BUILD_EXTENSION: yes - name: Save wheels uses: actions/upload-artifact@v4 with: name: dist-${{ matrix.os }} path: wheelhouse/*.whl upload: name: Upload needs: - sdist - wheels runs-on: ubuntu-latest # Don't release when running the workflow manually from GitHub's UI. if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags/') permissions: id-token: write attestations: write contents: write steps: - name: Download artifacts uses: actions/download-artifact@v4 with: pattern: dist-* merge-multiple: true path: dist - name: Attest provenance uses: actions/attest-build-provenance@v2 with: subject-path: dist/* - name: Upload to PyPI uses: pypa/gh-action-pypi-publish@release/v1 - name: Create GitHub release env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} run: gh release -R python-websockets/websockets create ${{ github.ref_name }} --notes "See https://websockets.readthedocs.io/en/stable/project/changelog.html for details." 
websockets-15.0.1/.github/workflows/tests.yml000066400000000000000000000034621476212450300212420ustar00rootroot00000000000000name: Run tests on: push: branches: - main pull_request: branches: - main env: WEBSOCKETS_TESTS_TIMEOUT_FACTOR: 10 jobs: coverage: name: Run test coverage checks runs-on: ubuntu-latest steps: - name: Check out repository uses: actions/checkout@v4 - name: Install Python 3.x uses: actions/setup-python@v5 with: python-version: "3.x" - name: Install tox run: pip install tox - name: Run tests with coverage run: tox -e coverage - name: Run tests with per-module coverage run: tox -e maxi_cov quality: name: Run code quality checks runs-on: ubuntu-latest steps: - name: Check out repository uses: actions/checkout@v4 - name: Install Python 3.x uses: actions/setup-python@v5 with: python-version: "3.x" - name: Install tox run: pip install tox - name: Check code formatting & style run: tox -e ruff - name: Check types statically run: tox -e mypy matrix: name: Run tests on Python ${{ matrix.python }} needs: - coverage - quality runs-on: ubuntu-latest strategy: matrix: python: - "3.9" - "3.10" - "3.11" - "3.12" - "3.13" - "pypy-3.10" is_main: - ${{ github.event_name == 'push' && github.ref == 'refs/heads/main' }} exclude: - python: "pypy-3.10" is_main: false steps: - name: Check out repository uses: actions/checkout@v4 - name: Install Python ${{ matrix.python }} uses: actions/setup-python@v5 with: python-version: ${{ matrix.python }} allow-prereleases: true - name: Install tox run: pip install tox - name: Run tests run: tox -e py websockets-15.0.1/.gitignore000066400000000000000000000002701476212450300157420ustar00rootroot00000000000000*.pyc *.so .coverage .direnv/ .envrc .idea/ .mypy_cache/ .tox/ .vscode/ build/ compliance/reports/ dist/ docs/_build/ experiments/compression/corpus/ htmlcov/ src/websockets.egg-info/ websockets-15.0.1/.readthedocs.yml000066400000000000000000000003401476212450300170360ustar00rootroot00000000000000version: 2 build: os: 
ubuntu-20.04 tools: python: "3.10" jobs: post_checkout: - git fetch --unshallow sphinx: configuration: docs/conf.py python: install: - requirements: docs/requirements.txt websockets-15.0.1/CODE_OF_CONDUCT.md000066400000000000000000000062511476212450300165560ustar00rootroot00000000000000# Contributor Covenant Code of Conduct ## Our Pledge In the interest of fostering an open and welcoming environment, we as contributors and maintainers pledge to making participation in our project and our community a harassment-free experience for everyone, regardless of age, body size, disability, ethnicity, gender identity and expression, level of experience, nationality, personal appearance, race, religion, or sexual identity and orientation. ## Our Standards Examples of behavior that contributes to creating a positive environment include: * Using welcoming and inclusive language * Being respectful of differing viewpoints and experiences * Gracefully accepting constructive criticism * Focusing on what is best for the community * Showing empathy towards other community members Examples of unacceptable behavior by participants include: * The use of sexualized language or imagery and unwelcome sexual attention or advances * Trolling, insulting/derogatory comments, and personal or political attacks * Public or private harassment * Publishing others' private information, such as a physical or electronic address, without explicit permission * Other conduct which could reasonably be considered inappropriate in a professional setting ## Our Responsibilities Project maintainers are responsible for clarifying the standards of acceptable behavior and are expected to take appropriate and fair corrective action in response to any instances of unacceptable behavior. 
Project maintainers have the right and responsibility to remove, edit, or reject comments, commits, code, wiki edits, issues, and other contributions that are not aligned to this Code of Conduct, or to ban temporarily or permanently any contributor for other behaviors that they deem inappropriate, threatening, offensive, or harmful. ## Scope This Code of Conduct applies both within project spaces and in public spaces when an individual is representing the project or its community. Examples of representing a project or community include using an official project e-mail address, posting via an official social media account, or acting as an appointed representative at an online or offline event. Representation of a project may be further defined and clarified by project maintainers. ## Enforcement Instances of abusive, harassing, or otherwise unacceptable behavior may be reported by contacting the project team at aymeric DOT augustin AT fractalideas DOT com. The project team will review and investigate all complaints, and will respond in a way that it deems appropriate to the circumstances. The project team is obligated to maintain confidentiality with regard to the reporter of an incident. Further details of specific enforcement policies may be posted separately. Project maintainers who do not follow or enforce the Code of Conduct in good faith may face temporary or permanent repercussions as determined by other members of the project's leadership. 
## Attribution This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, available at [http://contributor-covenant.org/version/1/4][version] [homepage]: http://contributor-covenant.org [version]: http://contributor-covenant.org/version/1/4/ websockets-15.0.1/LICENSE000066400000000000000000000027521476212450300147660ustar00rootroot00000000000000Copyright (c) Aymeric Augustin and contributors Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. * Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
websockets-15.0.1/MANIFEST.in000066400000000000000000000001661476212450300155140ustar00rootroot00000000000000include LICENSE include src/websockets/py.typed include src/websockets/speedups.c # required when BUILD_EXTENSION=no websockets-15.0.1/Makefile000066400000000000000000000014061476212450300154140ustar00rootroot00000000000000.PHONY: default style types tests coverage maxi_cov build clean export PYTHONASYNCIODEBUG=1 export PYTHONPATH=src export PYTHONWARNINGS=default build: python setup.py build_ext --inplace style: ruff format compliance src tests ruff check --fix compliance src tests types: mypy --strict src tests: python -m unittest coverage: coverage run --source src/websockets,tests -m unittest coverage html coverage report --show-missing --fail-under=100 maxi_cov: python tests/maxi_cov.py coverage html coverage report --show-missing --fail-under=100 clean: find src -name '*.so' -delete find . -name '*.pyc' -delete find . -name __pycache__ -delete rm -rf .coverage .mypy_cache build compliance/reports dist docs/_build htmlcov MANIFEST src/websockets.egg-info websockets-15.0.1/README.rst000066400000000000000000000143771476212450300154560ustar00rootroot00000000000000.. image:: logo/horizontal.svg :width: 480px :alt: websockets |licence| |version| |pyversions| |tests| |docs| |openssf| .. |licence| image:: https://img.shields.io/pypi/l/websockets.svg :target: https://pypi.python.org/pypi/websockets .. |version| image:: https://img.shields.io/pypi/v/websockets.svg :target: https://pypi.python.org/pypi/websockets .. |pyversions| image:: https://img.shields.io/pypi/pyversions/websockets.svg :target: https://pypi.python.org/pypi/websockets .. |tests| image:: https://img.shields.io/github/checks-status/python-websockets/websockets/main?label=tests :target: https://github.com/python-websockets/websockets/actions/workflows/tests.yml .. |docs| image:: https://img.shields.io/readthedocs/websockets.svg :target: https://websockets.readthedocs.io/ .. 
|openssf| image:: https://bestpractices.coreinfrastructure.org/projects/6475/badge :target: https://bestpractices.coreinfrastructure.org/projects/6475 What is ``websockets``? ----------------------- websockets is a library for building WebSocket_ servers and clients in Python with a focus on correctness, simplicity, robustness, and performance. .. _WebSocket: https://developer.mozilla.org/en-US/docs/Web/API/WebSockets_API Built on top of ``asyncio``, Python's standard asynchronous I/O framework, the default implementation provides an elegant coroutine-based API. An implementation on top of ``threading`` and a Sans-I/O implementation are also available. `Documentation is available on Read the Docs. `_ .. copy-pasted because GitHub doesn't support the include directive Here's an echo server with the ``asyncio`` API: .. code:: python #!/usr/bin/env python import asyncio from websockets.asyncio.server import serve async def echo(websocket): async for message in websocket: await websocket.send(message) async def main(): async with serve(echo, "localhost", 8765) as server: await server.serve_forever() asyncio.run(main()) Here's how a client sends and receives messages with the ``threading`` API: .. code:: python #!/usr/bin/env python from websockets.sync.client import connect def hello(): with connect("ws://localhost:8765") as websocket: websocket.send("Hello world!") message = websocket.recv() print(f"Received: {message}") hello() Does that look good? `Get started with the tutorial! `_ .. raw:: html

websockets for enterprise

Available as part of the Tidelift Subscription

The maintainers of websockets and thousands of other packages are working with Tidelift to deliver commercial support and maintenance for the open source dependencies you use to build your applications. Save time, reduce risk, and improve code health, while paying the maintainers of the exact dependencies you use. Learn more.


(If you contribute to websockets and would like to become an official support provider, let me know.)

Why should I use ``websockets``? -------------------------------- The development of ``websockets`` is shaped by four principles: 1. **Correctness**: ``websockets`` is heavily tested for compliance with :rfc:`6455`. Continuous integration fails under 100% branch coverage. 2. **Simplicity**: all you need to understand is ``msg = await ws.recv()`` and ``await ws.send(msg)``. ``websockets`` takes care of managing connections so you can focus on your application. 3. **Robustness**: ``websockets`` is built for production. For example, it was the only library to `handle backpressure correctly`_ before the issue became widely known in the Python community. 4. **Performance**: memory usage is optimized and configurable. A C extension accelerates expensive operations. It's pre-compiled for Linux, macOS and Windows and packaged in the wheel format for each system and Python version. Documentation is a first class concern in the project. Head over to `Read the Docs`_ and see for yourself. .. _Read the Docs: https://websockets.readthedocs.io/ .. _handle backpressure correctly: https://vorpus.org/blog/some-thoughts-on-asynchronous-api-design-in-a-post-asyncawait-world/#websocket-servers Why shouldn't I use ``websockets``? ----------------------------------- * If you prefer callbacks over coroutines: ``websockets`` was created to provide the best coroutine-based API to manage WebSocket connections in Python. Pick another library for a callback-based API. * If you're looking for a mixed HTTP / WebSocket library: ``websockets`` aims at being an excellent implementation of :rfc:`6455`: The WebSocket Protocol and :rfc:`7692`: Compression Extensions for WebSocket. Its support for HTTP is minimal — just enough for an HTTP health check. If you want to do both in the same server, look at HTTP + WebSocket servers that build on top of ``websockets`` to support WebSocket connections, like uvicorn_ or Sanic_. .. _uvicorn: https://www.uvicorn.org/ .. _Sanic: https://sanic.dev/en/ What else? 
---------- Bug reports, patches and suggestions are welcome! To report a security vulnerability, please use the `Tidelift security contact`_. Tidelift will coordinate the fix and disclosure. .. _Tidelift security contact: https://tidelift.com/security For anything else, please open an issue_ or send a `pull request`_. .. _issue: https://github.com/python-websockets/websockets/issues/new .. _pull request: https://github.com/python-websockets/websockets/compare/ Participants must uphold the `Contributor Covenant code of conduct`_. .. _Contributor Covenant code of conduct: https://github.com/python-websockets/websockets/blob/main/CODE_OF_CONDUCT.md ``websockets`` is released under the `BSD license`_. .. _BSD license: https://github.com/python-websockets/websockets/blob/main/LICENSE websockets-15.0.1/SECURITY.md000066400000000000000000000003741476212450300155500ustar00rootroot00000000000000# Security ## Policy Only the latest version receives security updates. ## Contact information Please report security vulnerabilities to the [Tidelift security team](https://tidelift.com/security). Tidelift will coordinate the fix and disclosure. websockets-15.0.1/compliance/000077500000000000000000000000001476212450300160655ustar00rootroot00000000000000websockets-15.0.1/compliance/README.rst000066400000000000000000000055241476212450300175620ustar00rootroot00000000000000Autobahn Testsuite ================== General information and installation instructions are available at https://github.com/crossbario/autobahn-testsuite. Running the test suite ---------------------- All commands below must be run from the root directory of the repository. To get acceptable performance, compile the C extension first: .. code-block:: console $ python setup.py build_ext --inplace Run each command in a different shell. Testing takes several minutes to complete — wstest is the bottleneck. When clients finish, stop servers with Ctrl-C. 
You can exclude slow tests by modifying the configuration files as follows:: "exclude-cases": ["9.*", "12.*", "13.*"] The test server and client applications shouldn't display any exceptions. To test the servers: .. code-block:: console $ PYTHONPATH=src python compliance/asyncio/server.py $ PYTHONPATH=src python compliance/sync/server.py $ docker run --interactive --tty --rm \ --volume "${PWD}/compliance/config:/config" \ --volume "${PWD}/compliance/reports:/reports" \ --name fuzzingclient \ crossbario/autobahn-testsuite \ wstest --mode fuzzingclient --spec /config/fuzzingclient.json $ open compliance/reports/servers/index.html To test the clients: .. code-block:: console $ docker run --interactive --tty --rm \ --volume "${PWD}/compliance/config:/config" \ --volume "${PWD}/compliance/reports:/reports" \ --publish 9001:9001 \ --name fuzzingserver \ crossbario/autobahn-testsuite \ wstest --mode fuzzingserver --spec /config/fuzzingserver.json $ PYTHONPATH=src python compliance/asyncio/client.py $ PYTHONPATH=src python compliance/sync/client.py $ open compliance/reports/clients/index.html Conformance notes ----------------- Some test cases are more strict than the RFC. Given the implementation of the library and the test client and server applications, websockets passes with a "Non-Strict" result in these cases. In 3.2, 3.3, 4.1.3, 4.1.4, 4.2.3, 4.2.4, and 5.15 websockets notices the protocol error and closes the connection at the library level before the application gets a chance to echo the previous frame. In 6.4.1, 6.4.2, 6.4.3, and 6.4.4, even though it uses an incremental decoder, websockets doesn't notice the invalid utf-8 fast enough to get a "Strict" pass. These tests are more strict than the RFC. Test case 7.1.5 fails because websockets treats closing the connection in the middle of a fragmented message as a protocol error. As a consequence, it sends a close frame with code 1002. 
The test suite expects a close frame with code 1000, echoing the close code that it sent. This isn't required. RFC 6455 states that "the endpoint typically echos the status code it received", which leaves the possibility to send a close frame with a different status code. websockets-15.0.1/compliance/asyncio/000077500000000000000000000000001476212450300175325ustar00rootroot00000000000000websockets-15.0.1/compliance/asyncio/client.py000066400000000000000000000025501476212450300213640ustar00rootroot00000000000000import asyncio import json import logging from websockets.asyncio.client import connect from websockets.exceptions import WebSocketException logging.basicConfig(level=logging.WARNING) SERVER = "ws://localhost:9001" AGENT = "websockets.asyncio" async def get_case_count(): async with connect(f"{SERVER}/getCaseCount") as ws: return json.loads(await ws.recv()) async def run_case(case): async with connect( f"{SERVER}/runCase?case={case}&agent={AGENT}", max_size=2**25, ) as ws: try: async for msg in ws: await ws.send(msg) except WebSocketException: pass async def update_reports(): async with connect( f"{SERVER}/updateReports?agent={AGENT}", open_timeout=60, ): pass async def main(): cases = await get_case_count() for case in range(1, cases + 1): print(f"Running test case {case:03d} / {cases}... 
", end="\t") try: await run_case(case) except WebSocketException as exc: print(f"ERROR: {type(exc).__name__}: {exc}") except Exception as exc: print(f"FAIL: {type(exc).__name__}: {exc}") else: print("OK") print(f"Ran {cases} test cases") await update_reports() print("Updated reports") if __name__ == "__main__": asyncio.run(main()) websockets-15.0.1/compliance/asyncio/server.py000066400000000000000000000012261476212450300214130ustar00rootroot00000000000000import asyncio import logging from websockets.asyncio.server import serve from websockets.exceptions import WebSocketException logging.basicConfig(level=logging.WARNING) HOST, PORT = "0.0.0.0", 9002 async def echo(ws): try: async for msg in ws: await ws.send(msg) except WebSocketException: pass async def main(): async with serve( echo, HOST, PORT, server_header="websockets.sync", max_size=2**25, ) as server: try: await server.serve_forever() except KeyboardInterrupt: pass if __name__ == "__main__": asyncio.run(main()) websockets-15.0.1/compliance/config/000077500000000000000000000000001476212450300173325ustar00rootroot00000000000000websockets-15.0.1/compliance/config/fuzzingclient.json000066400000000000000000000003141476212450300231160ustar00rootroot00000000000000 { "servers": [{ "url": "ws://host.docker.internal:9002" }, { "url": "ws://host.docker.internal:9003" }], "outdir": "/reports/servers", "cases": ["*"], "exclude-cases": [] } websockets-15.0.1/compliance/config/fuzzingserver.json000066400000000000000000000001611476212450300231460ustar00rootroot00000000000000 { "url": "ws://localhost:9001", "outdir": "/reports/clients", "cases": ["*"], "exclude-cases": [] } websockets-15.0.1/compliance/sync/000077500000000000000000000000001476212450300170415ustar00rootroot00000000000000websockets-15.0.1/compliance/sync/client.py000066400000000000000000000023701476212450300206730ustar00rootroot00000000000000import json import logging from websockets.exceptions import WebSocketException from websockets.sync.client import 
connect logging.basicConfig(level=logging.WARNING) SERVER = "ws://localhost:9001" AGENT = "websockets.sync" def get_case_count(): with connect(f"{SERVER}/getCaseCount") as ws: return json.loads(ws.recv()) def run_case(case): with connect( f"{SERVER}/runCase?case={case}&agent={AGENT}", max_size=2**25, ) as ws: try: for msg in ws: ws.send(msg) except WebSocketException: pass def update_reports(): with connect( f"{SERVER}/updateReports?agent={AGENT}", open_timeout=60, ): pass def main(): cases = get_case_count() for case in range(1, cases + 1): print(f"Running test case {case:03d} / {cases}... ", end="\t") try: run_case(case) except WebSocketException as exc: print(f"ERROR: {type(exc).__name__}: {exc}") except Exception as exc: print(f"FAIL: {type(exc).__name__}: {exc}") else: print("OK") print(f"Ran {cases} test cases") update_reports() print("Updated reports") if __name__ == "__main__": main() websockets-15.0.1/compliance/sync/server.py000066400000000000000000000011261476212450300207210ustar00rootroot00000000000000import logging from websockets.exceptions import WebSocketException from websockets.sync.server import serve logging.basicConfig(level=logging.WARNING) HOST, PORT = "0.0.0.0", 9003 def echo(ws): try: for msg in ws: ws.send(msg) except WebSocketException: pass def main(): with serve( echo, HOST, PORT, server_header="websockets.asyncio", max_size=2**25, ) as server: try: server.serve_forever() except KeyboardInterrupt: pass if __name__ == "__main__": main() websockets-15.0.1/docs/000077500000000000000000000000001476212450300147035ustar00rootroot00000000000000websockets-15.0.1/docs/Makefile000066400000000000000000000013451476212450300163460ustar00rootroot00000000000000# Minimal makefile for Sphinx documentation # # You can set these variables from the command line, and also # from the environment for the first two. SPHINXOPTS ?= SPHINXBUILD ?= sphinx-build SOURCEDIR = . BUILDDIR = _build # Put it first so that "make" without argument is like "make help". 
help: @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) .PHONY: help Makefile # Catch-all target: route all unknown targets to Sphinx using the new # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). %: Makefile @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) livehtml: sphinx-autobuild --watch "$(SOURCEDIR)/../src" "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) websockets-15.0.1/docs/_static/000077500000000000000000000000001476212450300163315ustar00rootroot00000000000000websockets-15.0.1/docs/_static/favicon.ico000077700000000000000000000000001476212450300241542../../logo/favicon.icoustar00rootroot00000000000000websockets-15.0.1/docs/_static/tidelift.png000077700000000000000000000000001476212450300245362../../logo/tidelift.pngustar00rootroot00000000000000websockets-15.0.1/docs/_static/websockets.svg000077700000000000000000000000001476212450300251362../../logo/vertical.svgustar00rootroot00000000000000websockets-15.0.1/docs/conf.py000066400000000000000000000141071476212450300162050ustar00rootroot00000000000000# Configuration file for the Sphinx documentation builder. # # This file only contains a selection of the most common options. For a full # list see the documentation: # https://www.sphinx-doc.org/en/master/usage/configuration.html import datetime import importlib import inspect import os import subprocess import sys # -- Path setup -------------------------------------------------------------- # If extensions (or modules to document with autodoc) are in another directory, # add these directories to sys.path here. If the directory is relative to the # documentation root, use os.path.abspath to make it absolute, like shown here. 
sys.path.insert(0, os.path.join(os.path.abspath(".."), "src")) # -- Project information ----------------------------------------------------- project = "websockets" copyright = f"2013-{datetime.date.today().year}, Aymeric Augustin and contributors" author = "Aymeric Augustin" from websockets.version import tag as version, version as release # -- General configuration --------------------------------------------------- nitpicky = True nitpick_ignore = [ # topics/design.rst discusses undocumented APIs ("py:meth", "client.WebSocketClientProtocol.handshake"), ("py:meth", "server.WebSocketServerProtocol.handshake"), ("py:attr", "protocol.WebSocketCommonProtocol.is_client"), ("py:attr", "protocol.WebSocketCommonProtocol.messages"), ("py:meth", "protocol.WebSocketCommonProtocol.close_connection"), ("py:attr", "protocol.WebSocketCommonProtocol.close_connection_task"), ("py:meth", "protocol.WebSocketCommonProtocol.keepalive_ping"), ("py:attr", "protocol.WebSocketCommonProtocol.keepalive_ping_task"), ("py:meth", "protocol.WebSocketCommonProtocol.transfer_data"), ("py:attr", "protocol.WebSocketCommonProtocol.transfer_data_task"), ("py:meth", "protocol.WebSocketCommonProtocol.connection_open"), ("py:meth", "protocol.WebSocketCommonProtocol.ensure_open"), ("py:meth", "protocol.WebSocketCommonProtocol.fail_connection"), ("py:meth", "protocol.WebSocketCommonProtocol.connection_lost"), ("py:meth", "protocol.WebSocketCommonProtocol.read_message"), ("py:meth", "protocol.WebSocketCommonProtocol.write_frame"), ] # Add any Sphinx extension module names here, as strings. They can be # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. extensions = [ "sphinx.ext.autodoc", "sphinx.ext.intersphinx", "sphinx.ext.linkcode", "sphinx.ext.napoleon", "sphinx_copybutton", "sphinx_inline_tabs", "sphinxcontrib.spelling", "sphinxcontrib_trio", "sphinxext.opengraph", ] # It is currently inconvenient to install PyEnchant on Apple Silicon. 
try: import sphinxcontrib.spelling except ImportError: extensions.remove("sphinxcontrib.spelling") autodoc_typehints = "description" autodoc_typehints_description_target = "documented" # Workaround for https://github.com/sphinx-doc/sphinx/issues/9560 from sphinx.domains.python import PythonDomain assert PythonDomain.object_types["data"].roles == ("data", "obj") PythonDomain.object_types["data"].roles = ("data", "class", "obj") intersphinx_mapping = { "python": ("https://docs.python.org/3", None), "sesame": ("https://django-sesame.readthedocs.io/en/stable/", None), "werkzeug": ("https://werkzeug.palletsprojects.com/en/stable/", None), } spelling_show_suggestions = True # Add any paths that contain templates here, relative to this directory. templates_path = ["_templates"] # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. # This pattern also affects html_static_path and html_extra_path. exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] # Configure viewcode extension. from websockets.version import commit code_url = f"https://github.com/python-websockets/websockets/blob/{commit}" def linkcode_resolve(domain, info): # Non-linkable objects from the starter kit in the tutorial. if domain == "js" or info["module"] == "connect4": return assert domain == "py", "expected only Python objects" mod = importlib.import_module(info["module"]) if "." in info["fullname"]: objname, attrname = info["fullname"].split(".") obj = getattr(mod, objname) try: # object is a method of a class obj = getattr(obj, attrname) except AttributeError: # object is an attribute of a class return None else: obj = getattr(mod, info["fullname"]) try: file = inspect.getsourcefile(obj) lines = inspect.getsourcelines(obj) except TypeError: # e.g. object is a typing.Union return None file = os.path.relpath(file, os.path.abspath("..")) if not file.startswith("src/websockets"): # e.g. 
object is a typing.NewType return None start, end = lines[1], lines[1] + len(lines[0]) - 1 return f"{code_url}/{file}#L{start}-L{end}" # Configure opengraph extension # Social cards don't support the SVG logo. Also, the text preview looks bad. ogp_social_cards = {"enable": False} # -- Options for HTML output ------------------------------------------------- # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. html_theme = "furo" html_theme_options = { "light_css_variables": { "color-brand-primary": "#306998", # blue from logo "color-brand-content": "#0b487a", # blue more saturated and less dark }, "dark_css_variables": { "color-brand-primary": "#ffd43bcc", # yellow from logo, more muted than content "color-brand-content": "#ffd43bd9", # yellow from logo, transparent like text }, "sidebar_hide_name": True, } html_logo = "_static/websockets.svg" html_favicon = "_static/favicon.ico" # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". html_static_path = ["_static"] html_copy_source = False html_show_sphinx = False websockets-15.0.1/docs/deploy/000077500000000000000000000000001476212450300161775ustar00rootroot00000000000000websockets-15.0.1/docs/deploy/architecture.svg000066400000000000000000000355361476212450300214160ustar00rootroot00000000000000Internetwebsocketswebsocketswebsocketsroutingwebsockets-15.0.1/docs/deploy/fly.rst000066400000000000000000000116341476212450300175300ustar00rootroot00000000000000Deploy to Fly ============= This guide describes how to deploy a websockets server to Fly_. .. _Fly: https://fly.io/ .. admonition:: The free tier of Fly is sufficient for trying this guide. :class: tip The `free tier`__ include up to three small VMs. This guide uses only one. 
__ https://fly.io/docs/about/pricing/ We're going to deploy a very simple app. The process would be identical for a more realistic app. Create application ------------------ Here's the implementation of the app, an echo server. Save it in a file called ``app.py``: .. literalinclude:: ../../example/deployment/fly/app.py :language: python This app implements typical requirements for running on a Platform as a Service: * it provides a health check at ``/healthz``; * it closes connections and exits cleanly when it receives a ``SIGTERM`` signal. Create a ``requirements.txt`` file containing this line to declare a dependency on websockets: .. literalinclude:: ../../example/deployment/fly/requirements.txt :language: text The app is ready. Let's deploy it! Deploy application ------------------ Follow the instructions__ to install the Fly CLI, if you haven't done that yet. __ https://fly.io/docs/hands-on/install-flyctl/ Sign up or log in to Fly. Launch the app — you'll have to pick a different name because I'm already using ``websockets-echo``: .. code-block:: console $ fly launch Creating app in ... Scanning source code Detected a Python app Using the following build configuration: Builder: paketobuildpacks/builder:base ? App Name (leave blank to use an auto-generated name): websockets-echo ? Select organization: ... ? Select region: ... Created app websockets-echo in organization ... Wrote config file fly.toml ? Would you like to set up a Postgresql database now? No We have generated a simple Procfile for you. Modify it to fit your needs and run "fly deploy" to deploy your application. .. admonition:: This will build the image with a generic buildpack. :class: tip Fly can `build images`__ with a Dockerfile or a buildpack. Here, ``fly launch`` configures a generic Paketo buildpack. If you'd rather package the app with a Dockerfile, check out the guide to :ref:`containerize an application `. 
__ https://fly.io/docs/reference/builders/ Replace the auto-generated ``fly.toml`` with: .. literalinclude:: ../../example/deployment/fly/fly.toml :language: toml This configuration: * listens on port 443, terminates TLS, and forwards to the app on port 8080; * declares a health check at ``/healthz``; * requests a ``SIGTERM`` for terminating the app. Replace the auto-generated ``Procfile`` with: .. literalinclude:: ../../example/deployment/fly/Procfile :language: text This tells Fly how to run the app. Now you can deploy it: .. code-block:: console $ fly deploy ... lots of output... ==> Monitoring deployment 1 desired, 1 placed, 1 healthy, 0 unhealthy [health checks: 1 total, 1 passing] --> v0 deployed successfully Validate deployment ------------------- Let's confirm that your application is running as expected. Since it's a WebSocket server, you need a WebSocket client, such as the interactive client that comes with websockets. If you're currently building a websockets server, perhaps you're already in a virtualenv where websockets is installed. If not, you can install it in a new virtualenv as follows: .. code-block:: console $ python -m venv websockets-client $ . websockets-client/bin/activate $ pip install websockets Connect the interactive client — you must replace ``websockets-echo`` with the name of your Fly app in this command: .. code-block:: console $ websockets wss://websockets-echo.fly.dev/ Connected to wss://websockets-echo.fly.dev/. > Great! Your app is running! Once you're connected, you can send any message and the server will echo it, or press Ctrl-D to terminate the connection: .. code-block:: console > Hello! < Hello! Connection closed: 1000 (OK). You can also confirm that your application shuts down gracefully. Connect an interactive client again — remember to replace ``websockets-echo`` with your app: .. code-block:: console $ websockets wss://websockets-echo.fly.dev/ Connected to wss://websockets-echo.fly.dev/. 
> In another shell, restart the app — again, replace ``websockets-echo`` with your app: .. code-block:: console $ fly restart websockets-echo websockets-echo is being restarted Go back to the first shell. The connection is closed with code 1001 (going away). .. code-block:: console $ websockets wss://websockets-echo.fly.dev/ Connected to wss://websockets-echo.fly.dev/. Connection closed: 1001 (going away). If graceful shutdown wasn't working, the server wouldn't perform a closing handshake and the connection would be closed with code 1006 (abnormal closure). websockets-15.0.1/docs/deploy/haproxy.rst000066400000000000000000000031421476212450300204230ustar00rootroot00000000000000Deploy behind HAProxy ===================== This guide demonstrates a way to load balance connections across multiple websockets server processes running on the same machine with HAProxy_. We'll run server processes with Supervisor as described in :doc:`this guide `. .. _HAProxy: https://www.haproxy.org/ Run server processes -------------------- Save this app to ``app.py``: .. literalinclude:: ../../example/deployment/haproxy/app.py :language: python Each server process listens on a different port by extracting an incremental index from an environment variable set by Supervisor. Save this configuration to ``supervisord.conf``: .. literalinclude:: ../../example/deployment/haproxy/supervisord.conf This configuration runs four instances of the app. Install Supervisor and run it: .. code-block:: console $ supervisord -c supervisord.conf -n Configure and run HAProxy ------------------------- Here's a simple HAProxy configuration to load balance connections across four processes: .. literalinclude:: ../../example/deployment/haproxy/haproxy.cfg In the backend configuration, we set the load balancing method to ``leastconn`` in order to balance the number of active connections across servers. This is best for long running connections. 
Save the configuration to ``haproxy.cfg``, install HAProxy, and run it: .. code-block:: console $ haproxy -f haproxy.cfg You can confirm that HAProxy proxies connections properly: .. code-block:: console $ websockets ws://localhost:8080/ Connected to ws://localhost:8080/. > Hello! < Hello! Connection closed: 1000 (OK). websockets-15.0.1/docs/deploy/heroku.rst000066400000000000000000000125511476212450300202320ustar00rootroot00000000000000Deploy to Heroku ================ This guide describes how to deploy a websockets server to Heroku_. The same principles should apply to other Platform as a Service providers. .. _Heroku: https://www.heroku.com/ .. admonition:: Heroku no longer offers a free tier. :class: attention When this tutorial was written, in September 2021, Heroku offered a free tier where a websockets app could run at no cost. In November 2022, Heroku removed the free tier, making it impossible to maintain this document. As a consequence, it isn't updated anymore and may be removed in the future. We're going to deploy a very simple app. The process would be identical for a more realistic app. Create repository ----------------- Deploying to Heroku requires a git repository. Let's initialize one: .. code-block:: console $ mkdir websockets-echo $ cd websockets-echo $ git init -b main Initialized empty Git repository in websockets-echo/.git/ $ git commit --allow-empty -m "Initial commit." [main (root-commit) 1e7947d] Initial commit. Create application ------------------ Here's the implementation of the app, an echo server. Save it in a file called ``app.py``: .. literalinclude:: ../../example/deployment/heroku/app.py :language: python Heroku expects the server to `listen on a specific port`_, which is provided in the ``$PORT`` environment variable. The app reads it and passes it to :func:`~websockets.asyncio.server.serve`. .. 
_listen on a specific port: https://devcenter.heroku.com/articles/preparing-a-codebase-for-heroku-deployment#4-listen-on-the-correct-port Heroku sends a ``SIGTERM`` signal to all processes when `shutting down a dyno`_. When the app receives this signal, it closes connections and exits cleanly. .. _shutting down a dyno: https://devcenter.heroku.com/articles/dynos#shutdown Create a ``requirements.txt`` file containing this line to declare a dependency on websockets: .. literalinclude:: ../../example/deployment/heroku/requirements.txt :language: text Create a ``Procfile`` to tell Heroku how to run the app. .. literalinclude:: ../../example/deployment/heroku/Procfile Confirm that you created the correct files and commit them to git: .. code-block:: console $ ls Procfile app.py requirements.txt $ git add . $ git commit -m "Initial implementation." [main 8418c62] Initial implementation.  3 files changed, 32 insertions(+)  create mode 100644 Procfile  create mode 100644 app.py  create mode 100644 requirements.txt The app is ready. Let's deploy it! Deploy application ------------------ Follow the instructions_ to install the Heroku CLI, if you haven't done that yet. .. _instructions: https://devcenter.heroku.com/articles/getting-started-with-python#set-up Sign up or log in to Heroku. Create a Heroku app — you'll have to pick a different name because I'm already using ``websockets-echo``: .. code-block:: console $ heroku create websockets-echo Creating ⬢ websockets-echo... done https://websockets-echo.herokuapp.com/ | https://git.heroku.com/websockets-echo.git .. code-block:: console $ git push heroku ... lots of output... remote: -----> Launching... remote: Released v1 remote: https://websockets-echo.herokuapp.com/ deployed to Heroku remote: remote: Verifying deploy... done. To https://git.heroku.com/websockets-echo.git  * [new branch] main -> main Validate deployment ------------------- Let's confirm that your application is running as expected. 
Since it's a WebSocket server, you need a WebSocket client, such as the interactive client that comes with websockets. If you're currently building a websockets server, perhaps you're already in a virtualenv where websockets is installed. If not, you can install it in a new virtualenv as follows: .. code-block:: console $ python -m venv websockets-client $ . websockets-client/bin/activate $ pip install websockets Connect the interactive client — you must replace ``websockets-echo`` with the name of your Heroku app in this command: .. code-block:: console $ websockets wss://websockets-echo.herokuapp.com/ Connected to wss://websockets-echo.herokuapp.com/. > Great! Your app is running! Once you're connected, you can send any message and the server will echo it, or press Ctrl-D to terminate the connection: .. code-block:: console > Hello! < Hello! Connection closed: 1000 (OK). You can also confirm that your application shuts down gracefully. Connect an interactive client again — remember to replace ``websockets-echo`` with your app: .. code-block:: console $ websockets wss://websockets-echo.herokuapp.com/ Connected to wss://websockets-echo.herokuapp.com/. > In another shell, restart the app — again, replace ``websockets-echo`` with your app: .. code-block:: console $ heroku dyno:restart -a websockets-echo Restarting dynos on ⬢ websockets-echo... done Go back to the first shell. The connection is closed with code 1001 (going away). .. code-block:: console $ websockets wss://websockets-echo.herokuapp.com/ Connected to wss://websockets-echo.herokuapp.com/. Connection closed: 1001 (going away). If graceful shutdown wasn't working, the server wouldn't perform a closing handshake and the connection would be closed with code 1006 (abnormal closure). websockets-15.0.1/docs/deploy/index.rst000066400000000000000000000150761476212450300200510ustar00rootroot00000000000000Deployment ========== .. 
currentmodule:: websockets Architecture decisions ---------------------- When you deploy your websockets server to production, at a high level, your architecture will almost certainly look like the following diagram: .. image:: architecture.svg The basic unit for scaling a websockets server is "one server process". Each blue box in the diagram represents one server process. There's more variation in routing connections to processes. While the routing layer is shown as one big box, it is likely to involve several subsystems. As a consequence, when you design a deployment, you must answer two questions: 1. How will I run the appropriate number of server processes? 2. How will I route incoming connections to these processes? These questions are interrelated. There's a wide range of valid answers, depending on your goals and your constraints. Platforms-as-a-Service ...................... Platforms-as-a-Service are the easiest option. They provide end-to-end, integrated solutions and they require little configuration. Here's how to deploy on some popular PaaS providers. Since all PaaS use similar patterns, the concepts translate to other providers. .. toctree:: :titlesonly: render koyeb fly heroku Self-hosted infrastructure .......................... If you need more control over your infrastructure, you can deploy on your own infrastructure. This requires more configuration. Here's how to configure some components mentioned in this guide. .. toctree:: :titlesonly: kubernetes supervisor nginx haproxy Running server processes ------------------------ How many processes do I need? ............................. Typically, one server process will manage a few hundreds or thousands connections, depending on the frequency of messages and the amount of work they require. CPU and memory usage increase with the number of connections to the server. Often CPU is the limiting factor. If a server process goes to 100% CPU, then you reached the limit. 
How much headroom you want to keep is up to you. Once you know how many connections a server process can manage and how many connections you need to handle, you can calculate how many processes to run. You can also automate this calculation by configuring an autoscaler to keep CPU usage or connection count within acceptable limits. .. admonition:: Don't scale with threads. Scale only with processes. :class: tip Threads don't make sense for a server built with :mod:`asyncio`. How do I run processes? ....................... Most solutions for running multiple instances of a server process fall into one of these three buckets: 1. Running N processes on a platform: * a Kubernetes Deployment * its equivalent on a Platform as a Service provider 2. Running N servers: * an AWS Auto Scaling group, a GCP Managed instance group, etc. * a fixed set of long-lived servers 3. Running N processes on a server: * preferably via a process manager or supervisor Option 1 is easiest if you have access to such a platform. Option 2 usually combines with option 3. How do I start a process? ......................... Run a Python program that invokes :func:`~asyncio.server.serve` or :func:`~asyncio.router.route`. That's it! Don't run an ASGI server such as Uvicorn, Hypercorn, or Daphne. They're alternatives to websockets, not complements. Don't run a WSGI server such as Gunicorn, Waitress, or mod_wsgi. They aren't designed to run WebSocket applications. Applications servers handle network connections and expose a Python API. You don't need one because websockets handles network connections directly. How do I stop a process? ........................ Process managers send the SIGTERM signal to terminate processes. Catch this signal and exit the server to ensure a graceful shutdown. Here's an example: .. literalinclude:: ../../example/faq/shutdown_server.py :emphasize-lines: 14-16 When exiting the context manager, :func:`~asyncio.server.serve` closes all connections with code 1001 (going away). 
As a consequence: * If the connection handler is awaiting :meth:`~asyncio.server.ServerConnection.recv`, it receives a :exc:`~exceptions.ConnectionClosedOK` exception. It can catch the exception and clean up before exiting. * Otherwise, it should be waiting on :meth:`~asyncio.server.ServerConnection.wait_closed`, so it can receive the :exc:`~exceptions.ConnectionClosedOK` exception and exit. This example is easily adapted to handle other signals. If you override the default signal handler for SIGINT, which raises :exc:`KeyboardInterrupt`, be aware that you won't be able to interrupt a program with Ctrl-C anymore when it's stuck in a loop. Routing connections to processes -------------------------------- What does routing involve? .......................... Since the routing layer is directly exposed to the Internet, it should provide appropriate protection against threats ranging from Internet background noise to targeted attacks. You should always secure WebSocket connections with TLS. Since the routing layer carries the public domain name, it should terminate TLS connections. Finally, it must route connections to the server processes, balancing new connections across them. How do I route connections? ........................... Here are typical solutions for load balancing, matched to ways of running processes: 1. If you're running on a platform, it comes with a routing layer: * a Kubernetes Ingress and Service * a service mesh: Istio, Consul, Linkerd, etc. * the routing mesh of a Platform as a Service 2. If you're running N servers, you may load balance with: * a cloud load balancer: AWS Elastic Load Balancing, GCP Cloud Load Balancing, etc. * A software load balancer: HAProxy, NGINX, etc. 3. If you're running N processes on a server, you may load balance with: * A software load balancer: HAProxy, NGINX, etc. * The operating system — all processes listen on the same port You may trust the load balancer to handle encryption and to provide security. 
You may add another layer in front of the load balancer for these purposes. There are many possibilities. Don't add layers that you don't need, though. How do I implement a health check? .................................. Load balancers need a way to check whether server processes are up and running to avoid routing connections to a non-functional backend. websockets provide minimal support for responding to HTTP requests with the ``process_request`` hook. Here's an example: .. literalinclude:: ../../example/faq/health_check_server.py :emphasize-lines: 7-9,16 websockets-15.0.1/docs/deploy/koyeb.rst000066400000000000000000000120521476212450300200420ustar00rootroot00000000000000Deploy to Koyeb ================ This guide describes how to deploy a websockets server to Koyeb_. .. _Koyeb: https://www.koyeb.com .. admonition:: The free tier of Koyeb is sufficient for trying this guide. :class: tip The `free tier`__ include one web service, which this guide uses. __ https://www.koyeb.com/pricing We’re going to deploy a very simple app. The process would be identical to a more realistic app. Create repository ----------------- Koyeb supports multiple deployment methods. Its quick start guides recommend git-driven deployment as the first option. Let's initialize a git repository: .. code-block:: console $ mkdir websockets-echo $ cd websockets-echo $ git init -b main Initialized empty Git repository in websockets-echo/.git/ $ git commit --allow-empty -m "Initial commit." [main (root-commit) 740f699] Initial commit. Render requires the git repository to be hosted at GitHub. Sign up or log in to GitHub. Create a new repository named ``websockets-echo``. Don't enable any of the initialization options offered by GitHub. Then, follow instructions for pushing an existing repository from the command line. After pushing, refresh your repository's homepage on GitHub. You should see an empty repository with an empty initial commit. 
Create application ------------------ Here’s the implementation of the app, an echo server. Save it in a file called ``app.py``: .. literalinclude:: ../../example/deployment/koyeb/app.py :language: python This app implements typical requirements for running on a Platform as a Service: * it listens on the port provided in the ``$PORT`` environment variable; * it provides a health check at ``/healthz``; * it closes connections and exits cleanly when it receives a ``SIGTERM`` signal; while not documented, this is how Koyeb terminates apps. Create a ``requirements.txt`` file containing this line to declare a dependency on websockets: .. literalinclude:: ../../example/deployment/koyeb/requirements.txt :language: text Create a ``Procfile`` to tell Koyeb how to run the app. .. literalinclude:: ../../example/deployment/koyeb/Procfile Confirm that you created the correct files and commit them to git: .. code-block:: console $ ls Procfile app.py requirements.txt $ git add . $ git commit -m "Initial implementation." [main f634b8b] Initial implementation.  3 files changed, 39 insertions(+)  create mode 100644 Procfile  create mode 100644 app.py  create mode 100644 requirements.txt The app is ready. Let's deploy it! Deploy application ------------------ Sign up or log in to Koyeb. In the Koyeb control panel, create a web service with GitHub as the deployment method. Install and authorize Koyeb's GitHub app if you haven't done that yet. Follow the steps to create a new service: 1. Select the ``websockets-echo`` repository in the list of your repositories. 2. Confirm that the **Free** instance type is selected. Click **Next**. 3. Configure health checks: change the protocol from TCP to HTTP and set the path to ``/healthz``. Review other settings; defaults should be correct. Click **Deploy**. Koyeb builds the app, deploys it, verifies that the health checks passes, and makes the deployment active. 
Validate deployment ------------------- Let's confirm that your application is running as expected. Since it's a WebSocket server, you need a WebSocket client, such as the interactive client that comes with websockets. If you're currently building a websockets server, perhaps you're already in a virtualenv where websockets is installed. If not, you can install it in a new virtualenv as follows: .. code-block:: console $ python -m venv websockets-client $ . websockets-client/bin/activate $ pip install websockets Look for the URL of your app in the Koyeb control panel. It looks like ``https://--.koyeb.app/``. Connect the interactive client — you must replace ``https`` with ``wss`` in the URL: .. code-block:: console $ websockets wss://--.koyeb.app/ Connected to wss://--.koyeb.app/. > Great! Your app is running! Once you're connected, you can send any message and the server will echo it, or press Ctrl-D to terminate the connection: .. code-block:: console > Hello! < Hello! Connection closed: 1000 (OK). You can also confirm that your application shuts down gracefully. Connect an interactive client again: .. code-block:: console $ websockets wss://--.koyeb.app/ Connected to wss://--.koyeb.app/. > In the Koyeb control panel, go to the **Settings** tab, click **Pause**, and confirm. Eventually, the connection gets closed with code 1001 (going away). .. code-block:: console $ websockets wss://--.koyeb.app/ Connected to wss://--.koyeb.app/. Connection closed: 1001 (going away). If graceful shutdown wasn't working, the server wouldn't perform a closing handshake and the connection would be closed with code 1006 (abnormal closure). websockets-15.0.1/docs/deploy/kubernetes.rst000066400000000000000000000151611476212450300211040ustar00rootroot00000000000000Deploy to Kubernetes ==================== This guide describes how to deploy a websockets server to Kubernetes_. It assumes familiarity with Docker and Kubernetes. 
We're going to deploy a simple app to a local Kubernetes cluster and to ensure that it scales as expected. In a more realistic context, you would follow your organization's practices for deploying to Kubernetes, but you would apply the same principles as far as websockets is concerned. .. _Kubernetes: https://kubernetes.io/ .. _containerize-application: Containerize application ------------------------ Here's the app we're going to deploy. Save it in a file called ``app.py``: .. literalinclude:: ../../example/deployment/kubernetes/app.py This is an echo server with one twist: every message blocks the server for 100ms, which creates artificial starvation of CPU time. This makes it easier to saturate the server for load testing. The app exposes a health check on ``/healthz``. It also provides two other endpoints for testing purposes: ``/inemuri`` will make the app unresponsive for 10 seconds and ``/seppuku`` will terminate it. The quest for the perfect Python container image is out of scope of this guide, so we'll go for the simplest possible configuration instead: .. literalinclude:: ../../example/deployment/kubernetes/Dockerfile After saving this ``Dockerfile``, build the image: .. code-block:: console $ docker build -t websockets-test:1.0 . Test your image by running: .. code-block:: console $ docker run --name run-websockets-test --publish 32080:80 --rm \ websockets-test:1.0 Then, in another shell, in a virtualenv where websockets is installed, connect to the app and check that it echoes anything you send: .. code-block:: console $ websockets ws://localhost:32080/ Connected to ws://localhost:32080/. > Hey there! < Hey there! > Now, in yet another shell, stop the app with: .. code-block:: console $ docker kill -s TERM run-websockets-test Going to the shell where you connected to the app, you can confirm that it shut down gracefully: .. code-block:: console $ websockets ws://localhost:32080/ Connected to ws://localhost:32080/. > Hey there! < Hey there! 
Connection closed: 1001 (going away). If it didn't, you'd get code 1006 (abnormal closure). Deploy application ------------------ Configuring Kubernetes is even further beyond the scope of this guide, so we'll use a basic configuration for testing, with just one Service_ and one Deployment_: .. literalinclude:: ../../example/deployment/kubernetes/deployment.yaml For local testing, a service of type NodePort_ is good enough. For deploying to production, you would configure an Ingress_. .. _Service: https://kubernetes.io/docs/concepts/services-networking/service/ .. _Deployment: https://kubernetes.io/docs/concepts/workloads/controllers/deployment/ .. _NodePort: https://kubernetes.io/docs/concepts/services-networking/service/#nodeport .. _Ingress: https://kubernetes.io/docs/concepts/services-networking/ingress/ After saving this to a file called ``deployment.yaml``, you can deploy: .. code-block:: console $ kubectl apply -f deployment.yaml service/websockets-test created deployment.apps/websockets-test created Now you have a deployment with one pod running: .. code-block:: console $ kubectl get deployment websockets-test NAME READY UP-TO-DATE AVAILABLE AGE websockets-test 1/1 1 1 10s $ kubectl get pods -l app=websockets-test NAME READY STATUS RESTARTS AGE websockets-test-86b48f4bb7-nltfh 1/1 Running 0 10s You can connect to the service — press Ctrl-D to exit: .. code-block:: console $ websockets ws://localhost:32080/ Connected to ws://localhost:32080/. Connection closed: 1000 (OK). Validate deployment ------------------- First, let's ensure the liveness probe works by making the app unresponsive: .. code-block:: console $ curl http://localhost:32080/inemuri Sleeping for 10s Since we have only one pod, we know that this pod will go to sleep. The liveness probe is configured to run every second. By default, liveness probes time out after one second and have a threshold of three failures. Therefore Kubernetes should restart the pod after at most 5 seconds. 
Indeed, after a few seconds, the pod reports a restart: .. code-block:: console $ kubectl get pods -l app=websockets-test NAME READY STATUS RESTARTS AGE websockets-test-86b48f4bb7-nltfh 1/1 Running 1 42s Next, let's take it one step further and crash the app: .. code-block:: console $ curl http://localhost:32080/seppuku Terminating The pod reports a second restart: .. code-block:: console $ kubectl get pods -l app=websockets-test NAME READY STATUS RESTARTS AGE websockets-test-86b48f4bb7-nltfh 1/1 Running 2 72s All good — Kubernetes delivers on its promise to keep our app alive! Scale deployment ---------------- Of course, Kubernetes is for scaling. Let's scale — modestly — to 10 pods: .. code-block:: console $ kubectl scale deployment.apps/websockets-test --replicas=10 deployment.apps/websockets-test scaled After a few seconds, we have 10 pods running: .. code-block:: console $ kubectl get deployment websockets-test NAME READY UP-TO-DATE AVAILABLE AGE websockets-test 10/10 10 10 10m Now let's generate load. We'll use this script: .. literalinclude:: ../../example/deployment/kubernetes/benchmark.py We'll connect 500 clients in parallel, meaning 50 clients per pod, and have each client send 6 messages. Since the app blocks for 100ms before responding, if connections are perfectly distributed, we expect a total run time slightly over 50 * 6 * 0.1 = 30 seconds. Let's try it: .. code-block:: console $ ulimit -n 512 $ time python benchmark.py 500 6 python benchmark.py 500 6 2.40s user 0.51s system 7% cpu 36.471 total A total runtime of 36 seconds is in the right ballpark. Repeating this experiment with other parameters shows roughly consistent results, with the high variability you'd expect from a quick benchmark without any effort to stabilize the test setup. Finally, we can scale back to one pod. .. 
code-block:: console $ kubectl scale deployment.apps/websockets-test --replicas=1 deployment.apps/websockets-test scaled $ kubectl get deployment websockets-test NAME READY UP-TO-DATE AVAILABLE AGE websockets-test 1/1 1 1 15m websockets-15.0.1/docs/deploy/nginx.rst000066400000000000000000000051371476212450300200620ustar00rootroot00000000000000Deploy behind nginx =================== This guide demonstrates a way to load balance connections across multiple websockets server processes running on the same machine with nginx_. We'll run server processes with Supervisor as described in :doc:`this guide `. .. _nginx: https://nginx.org/ Run server processes -------------------- Save this app to ``app.py``: .. literalinclude:: ../../example/deployment/nginx/app.py :language: python We'd like nginx to connect to websockets servers via Unix sockets in order to avoid the overhead of TCP for communicating between processes running in the same OS. We start the app with :func:`~websockets.asyncio.server.unix_serve`. Each server process listens on a different socket thanks to an environment variable set by Supervisor to a different value. Save this configuration to ``supervisord.conf``: .. literalinclude:: ../../example/deployment/nginx/supervisord.conf This configuration runs four instances of the app. Install Supervisor and run it: .. code-block:: console $ supervisord -c supervisord.conf -n Configure and run nginx ----------------------- Here's a simple nginx configuration to load balance connections across four processes: .. literalinclude:: ../../example/deployment/nginx/nginx.conf We set ``daemon off`` so we can run nginx in the foreground for testing. Then we combine the `WebSocket proxying`_ and `load balancing`_ guides: * The WebSocket protocol requires HTTP/1.1. We must set the HTTP protocol version to 1.1, else nginx defaults to HTTP/1.0 for proxying. * The WebSocket handshake involves the ``Connection`` and ``Upgrade`` HTTP headers. 
We must pass them to the upstream explicitly, else nginx drops them because they're hop-by-hop headers. We deviate from the `WebSocket proxying`_ guide because its example adds a ``Connection: Upgrade`` header to every upstream request, even if the original request didn't contain that header. * In the upstream configuration, we set the load balancing method to ``least_conn`` in order to balance the number of active connections across servers. This is best for long running connections. .. _WebSocket proxying: http://nginx.org/en/docs/http/websocket.html .. _load balancing: http://nginx.org/en/docs/http/load_balancing.html Save the configuration to ``nginx.conf``, install nginx, and run it: .. code-block:: console $ nginx -c nginx.conf -p . You can confirm that nginx proxies connections properly: .. code-block:: console $ websockets ws://localhost:8080/ Connected to ws://localhost:8080/. > Hello! < Hello! Connection closed: 1000 (OK). websockets-15.0.1/docs/deploy/render.rst000066400000000000000000000117161476212450300202160ustar00rootroot00000000000000Deploy to Render ================ This guide describes how to deploy a websockets server to Render_. .. _Render: https://render.com/ .. admonition:: The free plan of Render is sufficient for trying this guide. :class: tip However, on a `free plan`__, connections are dropped after five minutes, which is quite short for WebSocket application. __ https://render.com/docs/free We're going to deploy a very simple app. The process would be identical for a more realistic app. Create repository ----------------- Deploying to Render requires a git repository. Let's initialize one: .. code-block:: console $ mkdir websockets-echo $ cd websockets-echo $ git init -b main Initialized empty Git repository in websockets-echo/.git/ $ git commit --allow-empty -m "Initial commit." [main (root-commit) 816c3b1] Initial commit. Render requires the git repository to be hosted at GitHub or GitLab. Sign up or log in to GitHub. 
Create a new repository named ``websockets-echo``. Don't enable any of the initialization options offered by GitHub. Then, follow instructions for pushing an existing repository from the command line. After pushing, refresh your repository's homepage on GitHub. You should see an empty repository with an empty initial commit. Create application ------------------ Here's the implementation of the app, an echo server. Save it in a file called ``app.py``: .. literalinclude:: ../../example/deployment/render/app.py :language: python This app implements requirements for `zero downtime deploys`_: * it provides a health check at ``/healthz``; * it closes connections and exits cleanly when it receives a ``SIGTERM`` signal. .. _zero downtime deploys: https://render.com/docs/deploys#zero-downtime-deploys Create a ``requirements.txt`` file containing this line to declare a dependency on websockets: .. literalinclude:: ../../example/deployment/render/requirements.txt :language: text Confirm that you created the correct files and commit them to git: .. code-block:: console $ ls app.py requirements.txt $ git add . $ git commit -m "Initial implementation." [main f26bf7f] Initial implementation. 2 files changed, 37 insertions(+) create mode 100644 app.py create mode 100644 requirements.txt Push the changes to GitHub: .. code-block:: console $ git push ... To github.com:/websockets-echo.git 816c3b1..f26bf7f main -> main The app is ready. Let's deploy it! Deploy application ------------------ Sign up or log in to Render. Create a new web service. Connect the git repository that you just created. Then, finalize the configuration of your app as follows: * **Name**: websockets-echo * **Start Command**: ``python app.py`` If you're just experimenting, select the free plan. Create the web service. To configure the health check, go to Settings, scroll down to Health & Alerts, and set: * **Health Check Path**: /healthz This triggers a new deployment. 
Validate deployment ------------------- Let's confirm that your application is running as expected. Since it's a WebSocket server, you need a WebSocket client, such as the interactive client that comes with websockets. If you're currently building a websockets server, perhaps you're already in a virtualenv where websockets is installed. If not, you can install it in a new virtualenv as follows: .. code-block:: console $ python -m venv websockets-client $ . websockets-client/bin/activate $ pip install websockets Connect the interactive client — you must replace ``websockets-echo`` with the name of your Render app in this command: .. code-block:: console $ websockets wss://websockets-echo.onrender.com/ Connected to wss://websockets-echo.onrender.com/. > Great! Your app is running! Once you're connected, you can send any message and the server will echo it, or press Ctrl-D to terminate the connection: .. code-block:: console > Hello! < Hello! Connection closed: 1000 (OK). You can also confirm that your application shuts down gracefully when you deploy a new version. Due to limitations of Render's free plan, you must upgrade to a paid plan before you perform this test. Connect an interactive client again — remember to replace ``websockets-echo`` with your app: .. code-block:: console $ websockets wss://websockets-echo.onrender.com/ Connected to wss://websockets-echo.onrender.com/. > Trigger a new deployment with Manual Deploy > Deploy latest commit. When the deployment completes, the connection is closed with code 1001 (going away). .. code-block:: console $ websockets wss://websockets-echo.onrender.com/ Connected to wss://websockets-echo.onrender.com/. Connection closed: 1001 (going away). If graceful shutdown wasn't working, the server wouldn't perform a closing handshake and the connection would be closed with code 1006 (abnormal closure). Remember to downgrade to a free plan if you upgraded just for testing this feature. 
websockets-15.0.1/docs/deploy/supervisor.rst000066400000000000000000000102751476212450300211570ustar00rootroot00000000000000Deploy with Supervisor ====================== This guide proposes a simple way to deploy a websockets server directly on a Linux or BSD operating system. We'll configure Supervisor_ to run several server processes and to restart them if needed. .. _Supervisor: http://supervisord.org/ We'll bind all servers to the same port. The OS will take care of balancing connections. Create and activate a virtualenv: .. code-block:: console $ python -m venv supervisor-websockets $ . supervisor-websockets/bin/activate Install websockets and Supervisor: .. code-block:: console $ pip install websockets $ pip install supervisor Save this app to a file called ``app.py``: .. literalinclude:: ../../example/deployment/supervisor/app.py This is an echo server with two features added for the purpose of this guide: * It shuts down gracefully when receiving a ``SIGTERM`` signal; * It enables the ``reuse_port`` option of :meth:`~asyncio.loop.create_server`, which in turn sets ``SO_REUSEPORT`` on the accept socket. Save this Supervisor configuration to ``supervisord.conf``: .. literalinclude:: ../../example/deployment/supervisor/supervisord.conf This is the minimal configuration required to keep four instances of the app running, restarting them if they exit. Now start Supervisor in the foreground: ..
code-block:: console $ supervisord -c supervisord.conf -n INFO Increased RLIMIT_NOFILE limit to 1024 INFO supervisord started with pid 43596 INFO spawned: 'websockets-test_00' with pid 43597 INFO spawned: 'websockets-test_01' with pid 43598 INFO spawned: 'websockets-test_02' with pid 43599 INFO spawned: 'websockets-test_03' with pid 43600 INFO success: websockets-test_00 entered RUNNING state, process has stayed up for > than 1 seconds (startsecs) INFO success: websockets-test_01 entered RUNNING state, process has stayed up for > than 1 seconds (startsecs) INFO success: websockets-test_02 entered RUNNING state, process has stayed up for > than 1 seconds (startsecs) INFO success: websockets-test_03 entered RUNNING state, process has stayed up for > than 1 seconds (startsecs) In another shell, after activating the virtualenv, we can connect to the app — press Ctrl-D to exit: .. code-block:: console $ websockets ws://localhost:8080/ Connected to ws://localhost:8080/. > Hello! < Hello! Connection closed: 1000 (OK). Look at the pid of an instance of the app in the logs and terminate it: .. code-block:: console $ kill -TERM 43597 The logs show that Supervisor restarted this instance: .. code-block:: console INFO exited: websockets-test_00 (exit status 0; expected) INFO spawned: 'websockets-test_00' with pid 43629 INFO success: websockets-test_00 entered RUNNING state, process has stayed up for > than 1 seconds (startsecs) Now let's check what happens when we shut down Supervisor, but first let's establish a connection and leave it open: .. code-block:: console $ websockets ws://localhost:8080/ Connected to ws://localhost:8080/. > Look at the pid of supervisord itself in the logs and terminate it: .. code-block:: console $ kill -TERM 43596 The logs show that Supervisor terminated all instances of the app before exiting: .. 
code-block:: console WARN received SIGTERM indicating exit request INFO waiting for websockets-test_00, websockets-test_01, websockets-test_02, websockets-test_03 to die INFO stopped: websockets-test_02 (exit status 0) INFO stopped: websockets-test_03 (exit status 0) INFO stopped: websockets-test_01 (exit status 0) INFO stopped: websockets-test_00 (exit status 0) And you can see that the connection to the app was closed gracefully: .. code-block:: console $ websockets ws://localhost:8080/ Connected to ws://localhost:8080/. Connection closed: 1001 (going away). In this example, we've been sharing the same virtualenv for supervisor and websockets. In a real deployment, you would likely: * Install Supervisor with the package manager of the OS. * Create a virtualenv dedicated to your application. * Add ``environment=PATH="path/to/your/virtualenv/bin"`` in the Supervisor configuration. Then ``python app.py`` runs in that virtualenv. websockets-15.0.1/docs/faq/000077500000000000000000000000001476212450300154525ustar00rootroot00000000000000websockets-15.0.1/docs/faq/asyncio.rst000066400000000000000000000055001476212450300176510ustar00rootroot00000000000000Using asyncio ============= .. currentmodule:: websockets.asyncio.connection .. admonition:: This FAQ is written for the new :mod:`asyncio` implementation. :class: tip Answers are also valid for the legacy :mod:`asyncio` implementation. How do I run two coroutines in parallel? ---------------------------------------- You must start two tasks, which the event loop will run concurrently. You can achieve this with :func:`asyncio.gather` or :func:`asyncio.create_task`. Keep track of the tasks and make sure that they terminate or that you cancel them when the connection terminates. Why does my program never receive any messages? ----------------------------------------------- Your program runs a coroutine that never yields control to the event loop. The coroutine that receives messages never gets a chance to run. 
Putting an ``await`` statement in a ``for`` or a ``while`` loop isn't enough to yield control. Awaiting a coroutine may yield control, but there's no guarantee that it will. For example, :meth:`~Connection.send` only yields control when send buffers are full, which never happens in most practical cases. If you run a loop that contains only synchronous operations and a :meth:`~Connection.send` call, you must yield control explicitly with :func:`asyncio.sleep`:: async def producer(websocket): message = generate_next_message() await websocket.send(message) await asyncio.sleep(0) # yield control to the event loop :func:`asyncio.sleep` always suspends the current task, allowing other tasks to run. This behavior is documented precisely because it isn't expected from every coroutine. See `issue 867`_. .. _issue 867: https://github.com/python-websockets/websockets/issues/867 Why am I having problems with threads? -------------------------------------- If you choose websockets' :mod:`asyncio` implementation, then you shouldn't use threads. Indeed, choosing :mod:`asyncio` to handle concurrency is mutually exclusive with :mod:`threading`. If you believe that you need to run websockets in a thread and some logic in another thread, you should run that logic in a :class:`~asyncio.Task` instead. If it has to run in another thread because it would block the event loop, :func:`~asyncio.to_thread` or :meth:`~asyncio.loop.run_in_executor` is the way to go. Please review the advice about :ref:`asyncio-multithreading` in the Python documentation. Why does my simple program misbehave mysteriously? -------------------------------------------------- You are using :func:`time.sleep` instead of :func:`asyncio.sleep`, which blocks the event loop and prevents asyncio from operating normally. This may lead to messages getting sent but not received, to connection timeouts, and to unexpected results of shotgun debugging, e.g. adding an unnecessary call to a coroutine makes the program functional.
websockets-15.0.1/docs/faq/client.rst000066400000000000000000000067661476212450300175010ustar00rootroot00000000000000Client ====== .. currentmodule:: websockets.asyncio.client .. admonition:: This FAQ is written for the new :mod:`asyncio` implementation. :class: tip Answers are also valid for the legacy :mod:`asyncio` implementation. They translate to the :mod:`threading` implementation by removing ``await`` and ``async`` keywords and by using a :class:`~threading.Thread` instead of a :class:`~asyncio.Task` for concurrent execution. Why does the client close the connection prematurely? ----------------------------------------------------- You're exiting the context manager prematurely. Wait for the work to be finished before exiting. For example, if your code has a structure similar to:: async with connect(...) as websocket: asyncio.create_task(do_some_work()) change it to:: async with connect(...) as websocket: await do_some_work() How do I access HTTP headers? ----------------------------- Once the connection is established, HTTP headers are available in the :attr:`~ClientConnection.request` and :attr:`~ClientConnection.response` objects:: async with connect(...) as websocket: websocket.request.headers websocket.response.headers How do I set HTTP headers? -------------------------- To set the ``Origin``, ``Sec-WebSocket-Extensions``, or ``Sec-WebSocket-Protocol`` headers in the WebSocket handshake request, use the ``origin``, ``extensions``, or ``subprotocols`` arguments of :func:`~connect`. To override the ``User-Agent`` header, use the ``user_agent_header`` argument. Set it to :obj:`None` to remove the header. To set other HTTP headers, for example the ``Authorization`` header, use the ``additional_headers`` argument:: async with connect(..., additional_headers={"Authorization": ...}) as websocket: ... In the legacy :mod:`asyncio` API, this argument is named ``extra_headers``. How do I force the IP address that the client connects to? 
---------------------------------------------------------- Use the ``host`` argument of :func:`~connect`:: async with connect(..., host="192.168.0.1") as websocket: ... :func:`~connect` accepts the same arguments as :meth:`~asyncio.loop.create_connection` and passes them through. How do I close a connection? ---------------------------- The easiest way is to use :func:`~connect` as a context manager:: async with connect(...) as websocket: ... The connection is closed when exiting the context manager. How do I reconnect when the connection drops? --------------------------------------------- Use :func:`connect` as an asynchronous iterator:: from websockets.asyncio.client import connect from websockets.exceptions import ConnectionClosed async for websocket in connect(...): try: ... except ConnectionClosed: continue Make sure you handle exceptions in the ``async for`` loop. Uncaught exceptions will break out of the loop. How do I stop a client that is processing messages in a loop? ------------------------------------------------------------- You can close the connection. Here's an example that terminates cleanly when it receives SIGTERM on Unix: .. literalinclude:: ../../example/faq/shutdown_client.py :emphasize-lines: 10-12 How do I disable TLS/SSL certificate verification? -------------------------------------------------- Look at the ``ssl`` argument of :meth:`~asyncio.loop.create_connection`. :func:`~connect` accepts the same arguments as :meth:`~asyncio.loop.create_connection` and passes them through. websockets-15.0.1/docs/faq/common.rst000066400000000000000000000116251476212450300175010ustar00rootroot00000000000000Both sides ========== .. currentmodule:: websockets.asyncio.connection What does ``ConnectionClosedError: no close frame received or sent`` mean? -------------------------------------------------------------------------- If you're seeing this traceback in the logs of a server: ..
code-block:: pytb connection handler failed Traceback (most recent call last): ... websockets.exceptions.ConnectionClosedError: no close frame received or sent or if a client crashes with this traceback: .. code-block:: pytb Traceback (most recent call last): ... websockets.exceptions.ConnectionClosedError: no close frame received or sent it means that the TCP connection was lost. As a consequence, the WebSocket connection was closed without receiving and sending a close frame, which is abnormal. You can catch and handle :exc:`~websockets.exceptions.ConnectionClosed` to prevent it from being logged. There are several reasons why long-lived connections may be lost: * End-user devices tend to lose network connectivity often and unpredictably because they can move out of wireless network coverage, get unplugged from a wired network, enter airplane mode, be put to sleep, etc. * HTTP load balancers or proxies that aren't configured for long-lived connections may terminate connections after a short amount of time, usually 30 seconds, despite websockets' keepalive mechanism. If you're facing a reproducible issue, :doc:`enable debug logs <../howto/debugging>` to see when and how connections are closed. What does ``ConnectionClosedError: sent 1011 (internal error) keepalive ping timeout; no close frame received`` mean? --------------------------------------------------------------------------------------------------------------------- If you're seeing this traceback in the logs of a server: .. code-block:: pytb connection handler failed Traceback (most recent call last): ... websockets.exceptions.ConnectionClosedError: sent 1011 (internal error) keepalive ping timeout; no close frame received or if a client crashes with this traceback: .. code-block:: pytb Traceback (most recent call last): ... 
websockets.exceptions.ConnectionClosedError: sent 1011 (internal error) keepalive ping timeout; no close frame received it means that the WebSocket connection suffered from excessive latency and was closed after reaching the timeout of websockets' keepalive mechanism. You can catch and handle :exc:`~websockets.exceptions.ConnectionClosed` to prevent it from being logged. There are two main reasons why latency may increase: * Poor network connectivity. * More traffic than the recipient can handle. See the discussion of :doc:`keepalive <../topics/keepalive>` for details. If websockets' default timeout of 20 seconds is too short for your use case, you can adjust it with the ``ping_timeout`` argument. How do I set a timeout on :meth:`~Connection.recv`? --------------------------------------------------- On Python ≥ 3.11, use :func:`asyncio.timeout`:: async with asyncio.timeout(timeout=10): message = await websocket.recv() On older versions of Python, use :func:`asyncio.wait_for`:: message = await asyncio.wait_for(websocket.recv(), timeout=10) This technique works for most APIs. When it doesn't, for example with asynchronous context managers, websockets provides an ``open_timeout`` argument. How can I pass arguments to a custom connection subclass? --------------------------------------------------------- You can bind additional arguments to the connection factory with :func:`functools.partial`:: import asyncio import functools from websockets.asyncio.server import ServerConnection, serve class MyServerConnection(ServerConnection): def __init__(self, *args, extra_argument=None, **kwargs): super().__init__(*args, **kwargs) # do something with extra_argument create_connection = functools.partial(ServerConnection, extra_argument=42) async with serve(..., create_connection=create_connection): ... This example was for a server. The same pattern applies on a client. How do I keep idle connections open? 
------------------------------------ websockets sends pings at 20 seconds intervals to keep the connection open. It closes the connection if it doesn't get a pong within 20 seconds. You can adjust this behavior with ``ping_interval`` and ``ping_timeout``. See :doc:`../topics/keepalive` for details. How do I respond to pings? -------------------------- If you are referring to Ping_ and Pong_ frames defined in the WebSocket protocol, don't bother, because websockets handles them for you. .. _Ping: https://datatracker.ietf.org/doc/html/rfc6455.html#section-5.5.2 .. _Pong: https://datatracker.ietf.org/doc/html/rfc6455.html#section-5.5.3 If you are connecting to a server that defines its own heartbeat at the application level, then you need to build that logic into your application. websockets-15.0.1/docs/faq/index.rst000066400000000000000000000012311476212450300173100ustar00rootroot00000000000000Frequently asked questions ========================== .. currentmodule:: websockets .. admonition:: Many questions asked in websockets' issue tracker are really about :mod:`asyncio`. :class: seealso If you're new to ``asyncio``, you will certainly encounter issues that are related to asynchronous programming in general rather than to websockets in particular. Fortunately, Python's official documentation provides advice to `develop with asyncio`_. Check it out: it's invaluable! .. _develop with asyncio: https://docs.python.org/3/library/asyncio-dev.html .. toctree:: server client common asyncio misc websockets-15.0.1/docs/faq/misc.rst000066400000000000000000000030331476212450300171360ustar00rootroot00000000000000Miscellaneous ============= .. currentmodule:: websockets .. Remove this question when dropping Python < 3.13, which provides natively .. a good error message in this case. Why do I get the error: ``module 'websockets' has no attribute '...'``? ....................................................................... 
Often, this is because you created a script called ``websockets.py`` in your current working directory. Then ``import websockets`` imports this module instead of the websockets library. Why is websockets slower than another library in my benchmark? .............................................................. Not all libraries are as feature-complete as websockets. For a fair benchmark, you should disable features that the other library doesn't provide. Typically, you must disable: * Compression: set ``compression=None`` * Keepalive: set ``ping_interval=None`` * UTF-8 decoding: send ``bytes`` rather than ``str`` Then, please consider whether websockets is the bottleneck of the performance of your application. Usually, in real-world applications, CPU time spent in websockets is negligible compared to time spent in the application logic. Are there ``onopen``, ``onmessage``, ``onerror``, and ``onclose`` callbacks? ............................................................................ No, there aren't. websockets provides high-level, coroutine-based APIs. Compared to callbacks, coroutines make it easier to manage control flow in concurrent code. If you prefer callback-based APIs, you should use another library. websockets-15.0.1/docs/faq/server.rst000066400000000000000000000260141476212450300175150ustar00rootroot00000000000000Server ====== .. currentmodule:: websockets.asyncio.server .. admonition:: This FAQ is written for the new :mod:`asyncio` implementation. :class: tip Answers are also valid for the legacy :mod:`asyncio` implementation. They translate to the :mod:`threading` implementation by removing ``await`` and ``async`` keywords and by using a :class:`~threading.Thread` instead of a :class:`~asyncio.Task` for concurrent execution. Why does the server close the connection prematurely? ----------------------------------------------------- Your connection handler exits prematurely. Wait for the work to be finished before returning. 
For example, if your handler has a structure similar to:: async def handler(websocket): asyncio.create_task(do_some_work()) change it to:: async def handler(websocket): await do_some_work() Why does the server close the connection after one message? ----------------------------------------------------------- Your connection handler exits after processing one message. Write a loop to process multiple messages. For example, if your handler looks like this:: async def handler(websocket): print(websocket.recv()) change it like this:: async def handler(websocket): async for message in websocket: print(message) If you have prior experience with an API that relies on callbacks, you may assume that ``handler()`` is executed every time a message is received. The API of websockets relies on coroutines instead. The handler coroutine is started when a new connection is established. Then, it is responsible for receiving or sending messages throughout the lifetime of that connection. Why can only one client connect at a time? ------------------------------------------ Your connection handler blocks the event loop. Look for blocking calls. Any call that may take some time must be asynchronous. For example, this connection handler prevents the event loop from running during one second:: async def handler(websocket): time.sleep(1) ... Change it to:: async def handler(websocket): await asyncio.sleep(1) ... In addition, calling a coroutine doesn't guarantee that it will yield control to the event loop. For example, this connection handler blocks the event loop by sending messages continuously:: async def handler(websocket): while True: await websocket.send("firehose!") :meth:`~ServerConnection.send` completes synchronously as long as there's space in send buffers. The event loop never runs. (This pattern is uncommon in real-world applications. It occurs mostly in toy programs.) 
You can avoid the issue by yielding control to the event loop explicitly:: async def handler(websocket): while True: await websocket.send("firehose!") await asyncio.sleep(0) All this is part of learning asyncio. It isn't specific to websockets. See also Python's documentation about `running blocking code`_. .. _running blocking code: https://docs.python.org/3/library/asyncio-dev.html#running-blocking-code .. _send-message-to-all-users: How do I send a message to all users? ------------------------------------- Record all connections in a global variable:: CONNECTIONS = set() async def handler(websocket): CONNECTIONS.add(websocket) try: await websocket.wait_closed() finally: CONNECTIONS.remove(websocket) Then, call :func:`broadcast`:: from websockets.asyncio.server import broadcast def message_all(message): broadcast(CONNECTIONS, message) If you're running multiple server processes, make sure you call ``message_all`` in each process. .. _send-message-to-single-user: How do I send a message to a single user? ----------------------------------------- Record connections in a global variable, keyed by user identifier:: CONNECTIONS = {} async def handler(websocket): user_id = ... # identify user in your app's context CONNECTIONS[user_id] = websocket try: await websocket.wait_closed() finally: del CONNECTIONS[user_id] Then, call :meth:`~ServerConnection.send`:: async def message_user(user_id, message): websocket = CONNECTIONS[user_id] # raises KeyError if user disconnected await websocket.send(message) # may raise websockets.exceptions.ConnectionClosed Add error handling according to the behavior you want if the user disconnected before the message could be sent. This example supports only one connection per user. To support concurrent connections by the same user, you can change ``CONNECTIONS`` to store a set of connections for each user. If you're running multiple server processes, call ``message_user`` in each process. 
The process managing the user's connection sends the message; other processes do nothing. When you reach a scale where server processes cannot keep up with the stream of all messages, you need a better architecture. For example, you could deploy an external publish / subscribe system such as Redis_. Server processes would subscribe their clients. Then, they would receive messages only for the connections that they're managing. .. _Redis: https://redis.io/ How do I send a message to a channel, a topic, or some users? ------------------------------------------------------------- websockets doesn't provide built-in publish / subscribe functionality. Record connections in a global variable, keyed by user identifier, as shown in :ref:`How do I send a message to a single user?` Then, build the set of recipients and broadcast the message to them, as shown in :ref:`How do I send a message to all users?` :doc:`../howto/django` contains a complete implementation of this pattern. Again, as you scale, you may reach the performance limits of a basic in-process implementation. You may need an external publish / subscribe system like Redis_. .. _Redis: https://redis.io/ How do I pass arguments to the connection handler? -------------------------------------------------- You can bind additional arguments to the connection handler with :func:`functools.partial`:: import functools async def handler(websocket, extra_argument): ... bound_handler = functools.partial(handler, extra_argument=42) Another way to achieve this result is to define the ``handler`` coroutine in a scope where the ``extra_argument`` variable exists instead of injecting it through an argument. How do I access the request path? --------------------------------- It is available in the :attr:`~ServerConnection.request` object. Refer to the :doc:`routing guide <../topics/routing>` for details on how to route connections to different handlers depending on the request path. How do I access HTTP headers? 
----------------------------- You can access HTTP headers during the WebSocket handshake by providing a ``process_request`` callable or coroutine:: def process_request(connection, request): authorization = request.headers["Authorization"] ... async with serve(handler, process_request=process_request): ... Once the connection is established, HTTP headers are available in the :attr:`~ServerConnection.request` and :attr:`~ServerConnection.response` objects:: async def handler(websocket): authorization = websocket.request.headers["Authorization"] How do I set HTTP headers? -------------------------- To set the ``Sec-WebSocket-Extensions`` or ``Sec-WebSocket-Protocol`` headers in the WebSocket handshake response, use the ``extensions`` or ``subprotocols`` arguments of :func:`~serve`. To override the ``Server`` header, use the ``server_header`` argument. Set it to :obj:`None` to remove the header. To set other HTTP headers, provide a ``process_response`` callable or coroutine:: def process_response(connection, request, response): response.headers["X-Blessing"] = "May the network be with you" async with serve(handler, process_response=process_response): ... How do I get the IP address of the client? ------------------------------------------ It's available in :attr:`~ServerConnection.remote_address`:: async def handler(websocket): remote_ip = websocket.remote_address[0] How do I set the IP addresses that my server listens on? -------------------------------------------------------- Use the ``host`` argument of :meth:`~serve`:: async with serve(handler, host="192.168.0.1", port=8080): ... :func:`~serve` accepts the same arguments as :meth:`~asyncio.loop.create_server` and passes them through. What does ``OSError: [Errno 99] error while attempting to bind on address ('::1', 80, 0, 0): address not available`` mean? 
-------------------------------------------------------------------------------------------------------------------------- You are calling :func:`~serve` without a ``host`` argument in a context where IPv6 isn't available. To listen only on IPv4, specify ``host="0.0.0.0"`` or ``family=socket.AF_INET``. Refer to the documentation of :meth:`~asyncio.loop.create_server` for details. How do I close a connection? ---------------------------- websockets takes care of closing the connection when the handler exits. How do I stop a server? ----------------------- Exit the :func:`~serve` context manager. Here's an example that terminates cleanly when it receives SIGTERM on Unix: .. literalinclude:: ../../example/faq/shutdown_server.py :emphasize-lines: 14-16 How do I stop a server while keeping existing connections open? --------------------------------------------------------------- Call the server's :meth:`~Server.close` method with ``close_connections=False``. Here's how to adapt the example just above:: async def server(): ... server = await serve(echo, "localhost", 8765) await stop server.close(close_connections=False) await server.wait_closed() How do I implement a health check? ---------------------------------- Intercept requests with the ``process_request`` hook. When a request is sent to the health check endpoint, treat it as an HTTP request and return a response: .. literalinclude:: ../../example/faq/health_check_server.py :emphasize-lines: 7-9,16 :meth:`~ServerConnection.respond` makes it easy to send a plain text response. You can also construct a :class:`~websockets.http11.Response` object directly. How do I run HTTP and WebSocket servers on the same port? --------------------------------------------------------- You don't. HTTP and WebSocket have widely different operational characteristics. Running them with the same server becomes inconvenient when you scale. Providing an HTTP server is out of scope for websockets.
It only aims at providing a WebSocket server. There's limited support for returning HTTP responses with the ``process_request`` hook. If you need more, pick an HTTP server and run it separately. Alternatively, pick an HTTP framework that builds on top of ``websockets`` to support WebSocket connections, like Sanic_. .. _Sanic: https://sanicframework.org/en/ websockets-15.0.1/docs/howto/000077500000000000000000000000001476212450300160435ustar00rootroot00000000000000websockets-15.0.1/docs/howto/autoreload.rst000066400000000000000000000021401476212450300207310ustar00rootroot00000000000000Reload on code changes ====================== When developing a websockets server, you are likely to run it locally to test changes. Unfortunately, whenever you want to try a new version of the code, you must stop the server and restart it, which slows down your development process. Web frameworks such as Django or Flask provide a development server that reloads the application automatically when you make code changes. There is no equivalent functionality in websockets because it's designed only for production. However, you can achieve the same result easily with a third-party library and a shell command. Install watchdog_ with the ``watchmedo`` shell utility: .. code-block:: console $ pip install 'watchdog[watchmedo]' .. _watchdog: https://pypi.org/project/watchdog/ Run your server with ``watchmedo auto-restart``: .. code-block:: console $ watchmedo auto-restart --pattern "*.py" --recursive --signal SIGTERM \ python app.py This example assumes that the server is defined in a script called ``app.py`` and exits cleanly when receiving the ``SIGTERM`` signal. Adapt as necessary. websockets-15.0.1/docs/howto/debugging.rst000066400000000000000000000016261476212450300205350ustar00rootroot00000000000000Enable debug logs ================== websockets logs events with the :mod:`logging` module from the standard library. 
It emits logs in the ``"websockets.server"`` and ``"websockets.client"`` loggers. You can enable logs at the ``DEBUG`` level to see exactly what websockets does. If logging isn't configured in your application:: import logging logging.basicConfig( format="%(asctime)s %(message)s", level=logging.DEBUG, ) If logging is already configured:: import logging logger = logging.getLogger("websockets") logger.setLevel(logging.DEBUG) logger.addHandler(logging.StreamHandler()) Refer to the :doc:`logging guide <../topics/logging>` for more information about logging in websockets. You may also enable asyncio's `debug mode`_ to get warnings about classic pitfalls. .. _debug mode: https://docs.python.org/3/library/asyncio-dev.html#asyncio-debug-mode websockets-15.0.1/docs/howto/django.rst000066400000000000000000000252021476212450300200400ustar00rootroot00000000000000Integrate with Django ===================== If you're looking at adding real-time capabilities to a Django project with WebSocket, you have two main options. 1. Using Django Channels_, a project adding WebSocket to Django, among other features. This approach is fully supported by Django. However, it requires switching to a new deployment architecture. 2. Deploying a separate WebSocket server next to your Django project. This technique is well suited when you need to add a small set of real-time features — maybe a notification service — to an HTTP application. .. _Channels: https://channels.readthedocs.io/ This guide shows how to implement the second technique with websockets. It assumes familiarity with Django. Authenticate connections ------------------------ Since the websockets server runs outside of Django, we need to integrate it with ``django.contrib.auth``. We will generate authentication tokens in the Django project. Then we will send them to the websockets server, where they will authenticate the user. Generating a token for the current user and making it available in the browser is up to you. 
You could render the token in a template or fetch it with an API call. Refer to the topic guide on :doc:`authentication <../topics/authentication>` for details on this design. Generate tokens ............... We want secure, short-lived tokens containing the user ID. We'll rely on `django-sesame`_, a small library designed exactly for this purpose. .. _django-sesame: https://github.com/aaugustin/django-sesame Add django-sesame to the dependencies of your Django project, install it, and configure it in the settings of the project: .. code-block:: python AUTHENTICATION_BACKENDS = [ "django.contrib.auth.backends.ModelBackend", "sesame.backends.ModelBackend", ] (If your project already uses another authentication backend than the default ``"django.contrib.auth.backends.ModelBackend"``, adjust accordingly.) You don't need ``"sesame.middleware.AuthenticationMiddleware"``. It is for authenticating users in the Django server, while we're authenticating them in the websockets server. We'd like our tokens to be valid for 30 seconds. We expect web pages to load and to establish the WebSocket connection within this delay. Configure django-sesame accordingly in the settings of your Django project: .. code-block:: python SESAME_MAX_AGE = 30 If you expect your web site to load faster for all clients, a shorter lifespan is possible. However, in the context of this document, it would make manual testing more difficult. You could also enable single-use tokens. However, this would update the last login date of the user every time a WebSocket connection is established. This doesn't seem like a good idea, both in terms of behavior and in terms of performance. Now you can generate tokens in a ``django-admin shell`` as follows: .. 
code-block:: pycon >>> from django.contrib.auth import get_user_model >>> User = get_user_model() >>> user = User.objects.get(username="") >>> from sesame.utils import get_token >>> get_token(user) '' Keep this console open: since tokens expire after 30 seconds, you'll have to generate a new token every time you want to test connecting to the server. Validate tokens ............... Let's move on to the websockets server. Add websockets to the dependencies of your Django project and install it. Indeed, we're going to reuse the environment of the Django project, so we can call its APIs in the websockets server. Now here's how to implement authentication. .. literalinclude:: ../../example/django/authentication.py :caption: authentication.py Let's unpack this code. We're calling ``django.setup()`` before doing anything with Django because we're using Django in a `standalone script`_. This assumes that the ``DJANGO_SETTINGS_MODULE`` environment variable is set to the Python path to your settings module. .. _standalone script: https://docs.djangoproject.com/en/stable/topics/settings/#calling-django-setup-is-required-for-standalone-django-usage The connection handler reads the first message received from the client, which is expected to contain a django-sesame token. Then it authenticates the user with :func:`~sesame.utils.get_user`, the API provided by django-sesame for `authentication outside a view`_. .. _authentication outside a view: https://django-sesame.readthedocs.io/en/stable/howto.html#outside-a-view If authentication fails, it closes the connection and exits. When we call an API that makes a database query such as :func:`~sesame.utils.get_user`, we wrap the call in :func:`~asyncio.to_thread`. Indeed, the Django ORM doesn't support asynchronous I/O. It would block the event loop if it didn't run in a separate thread. Finally, we start a server with :func:`~websockets.asyncio.server.serve`. We're ready to test! 
Download :download:`authentication.py <../../example/django/authentication.py>`, make sure the ``DJANGO_SETTINGS_MODULE`` environment variable is set properly, and start the websockets server: .. code-block:: console $ python authentication.py Generate a new token — remember, they're only valid for 30 seconds — and use it to connect to your server. Paste your token and press Enter when you get a prompt: .. code-block:: console $ websockets ws://localhost:8888/ Connected to ws://localhost:8888/ > < Hello ! Connection closed: 1000 (OK). It works! If you enter an expired or invalid token, authentication fails and the server closes the connection: .. code-block:: console $ websockets ws://localhost:8888/ Connected to ws://localhost:8888. > not a token Connection closed: 1011 (internal error) authentication failed. You can also test from a browser by generating a new token and running the following code in the JavaScript console of the browser: .. code-block:: javascript websocket = new WebSocket("ws://localhost:8888/"); websocket.onopen = (event) => websocket.send(""); websocket.onmessage = (event) => console.log(event.data); If you don't want to import your entire Django project into the websockets server, you can create a simpler Django project with ``django.contrib.auth``, ``django-sesame``, a suitable ``User`` model, and a subset of the settings of the main project. Stream events ------------- We can connect and authenticate but our server doesn't do anything useful yet! Let's send a message every time a user makes an action in the admin. This message will be broadcast to all users who can access the model on which the action was made. This may be used for showing notifications to other users. Many use cases for WebSocket with Django follow a similar pattern. Set up event stream ................... We need an event stream to enable communications between Django and websockets. Both sides connect permanently to the stream. 
Then Django writes events and websockets reads them. For the sake of simplicity, we'll rely on `Redis Pub/Sub`_. .. _Redis Pub/Sub: https://redis.io/topics/pubsub The easiest way to add Redis to a Django project is by configuring a cache backend with `django-redis`_. This library manages connections to Redis efficiently, persisting them between requests, and provides an API to access the Redis connection directly. .. _django-redis: https://github.com/jazzband/django-redis Install Redis, add django-redis to the dependencies of your Django project, install it, and configure it in the settings of the project: .. code-block:: python CACHES = { "default": { "BACKEND": "django_redis.cache.RedisCache", "LOCATION": "redis://127.0.0.1:6379/1", }, } If you already have a default cache, add a new one with a different name and change ``get_redis_connection("default")`` in the code below to the same name. Publish events .............. Now let's write events to the stream. Add the following code to a module that is imported when your Django project starts. Typically, you would put it in a :download:`signals.py <../../example/django/signals.py>` module, which you would import in the ``AppConfig.ready()`` method of one of your apps: .. literalinclude:: ../../example/django/signals.py :caption: signals.py This code runs every time the admin saves a ``LogEntry`` object to keep track of a change. It extracts interesting data, serializes it to JSON, and writes an event to Redis. Let's check that it works: .. code-block:: console $ redis-cli 127.0.0.1:6379> SELECT 1 OK 127.0.0.1:6379[1]> SUBSCRIBE events Reading messages... (press Ctrl-C to quit) 1) "subscribe" 2) "events" 3) (integer) 1 Leave this command running, start the Django development server and make changes in the admin: add, modify, or delete objects. You should see corresponding events published to the ``"events"`` stream. Broadcast events ................ 
Now let's turn to reading events and broadcasting them to connected clients. We need to add several features: * Keep track of connected clients so we can broadcast messages. * Tell which content types the user has permission to view or to change. * Connect to the message stream and read events. * Broadcast these events to users who have corresponding permissions. Here's a complete implementation. .. literalinclude:: ../../example/django/notifications.py :caption: notifications.py Since the ``get_content_types()`` function makes a database query, it is wrapped inside :func:`asyncio.to_thread()`. It runs once when each WebSocket connection is open; then its result is cached for the lifetime of the connection. Indeed, running it for each message would trigger database queries for all connected users at the same time, which would hurt the database. The connection handler merely registers the connection in a global variable, associated to the list of content types for which events should be sent to that connection, and waits until the client disconnects. The ``process_events()`` function reads events from Redis and broadcasts them to all connections that should receive them. We don't care much if sending a notification fails. This happens when a connection drops between the moment we iterate on connections and the moment the corresponding message is sent. Since Redis can publish a message to multiple subscribers, multiple instances of this server can safely run in parallel. Does it scale? -------------- In theory, given enough servers, this design can scale to a hundred million clients, since Redis can handle ten thousand servers and each server can handle ten thousand clients. In practice, you would need a more scalable message stream before reaching that scale, due to the volume of messages. websockets-15.0.1/docs/howto/encryption.rst000066400000000000000000000042131476212450300207670ustar00rootroot00000000000000Encrypt connections ==================== .. 
currentmodule:: websockets You should always secure WebSocket connections with TLS_ (Transport Layer Security). .. admonition:: TLS vs. SSL :class: tip TLS is sometimes referred to as SSL (Secure Sockets Layer). SSL was an earlier encryption protocol; the name stuck. The ``wss`` protocol is to ``ws`` what ``https`` is to ``http``. Secure WebSocket connections require certificates just like HTTPS. .. _TLS: https://developer.mozilla.org/en-US/docs/Web/Security/Transport_Layer_Security .. admonition:: Configure the TLS context securely :class: attention The examples below demonstrate the ``ssl`` argument with a TLS certificate shared between the client and the server. This is a simplistic setup. Please review the advice and security considerations in the documentation of the :mod:`ssl` module to configure the TLS context appropriately. Servers ------- In a typical :doc:`deployment <../deploy/index>`, the server is behind a reverse proxy that terminates TLS. The client connects to the reverse proxy with TLS and the reverse proxy connects to the server without TLS. In that case, you don't need to configure TLS in websockets. If needed in your setup, you can terminate TLS in the server. In the example below, :func:`~asyncio.server.serve` is configured to receive secure connections. Before running this server, download :download:`localhost.pem <../../example/tls/localhost.pem>` and save it in the same directory as ``server.py``. .. literalinclude:: ../../example/tls/server.py :caption: server.py Receiving both plain and TLS connections on the same port isn't supported. Clients ------- :func:`~asyncio.client.connect` enables TLS automatically when connecting to a ``wss://...`` URI. This works out of the box when the TLS certificate of the server is valid, meaning it's signed by a certificate authority that your Python installation trusts. In the example above, since the server uses a self-signed certificate, the client needs to be configured to trust the certificate. 
Here's how to do so. .. literalinclude:: ../../example/tls/client.py :caption: client.py websockets-15.0.1/docs/howto/extensions.rst000066400000000000000000000025441476212450300210010ustar00rootroot00000000000000Write an extension ================== .. currentmodule:: websockets During the opening handshake, WebSocket clients and servers negotiate which extensions_ will be used and with which parameters. .. _extensions: https://datatracker.ietf.org/doc/html/rfc6455.html#section-9 Then, each frame is processed before being sent and after being received according to the extensions that were negotiated. Writing an extension requires implementing at least two classes, an extension factory and an extension. They inherit from base classes provided by websockets. Extension factory ----------------- An extension factory negotiates parameters and instantiates the extension. Clients and servers require separate extension factories with distinct APIs. Base classes are :class:`~extensions.ClientExtensionFactory` and :class:`~extensions.ServerExtensionFactory`. Extension factories are the public API of an extension. Extensions are enabled with the ``extensions`` parameter of :func:`~asyncio.client.connect` or :func:`~asyncio.server.serve`. Extension --------- An extension decodes incoming frames and encodes outgoing frames. If the extension is symmetrical, clients and servers can use the same class. The base class is :class:`~extensions.Extension`. Since extensions are initialized by extension factories, they don't need to be part of the public API of an extension. websockets-15.0.1/docs/howto/index.rst000066400000000000000000000014011476212450300177000ustar00rootroot00000000000000How-to guides ============= Set up your development environment comfortably. .. toctree:: autoreload debugging Configure websockets securely in production. .. toctree:: encryption These guides will help you design and build your application. .. 
toctree:: :maxdepth: 2 patterns django Upgrading from the legacy :mod:`asyncio` implementation to the new one? Read this. .. toctree:: :maxdepth: 2 upgrade If you're integrating the Sans-I/O layer of websockets into a library, rather than building an application with websockets, follow this guide. .. toctree:: :maxdepth: 2 sansio The WebSocket protocol makes provisions for extending or specializing its features, which websockets supports fully. .. toctree:: extensions websockets-15.0.1/docs/howto/patterns.rst000066400000000000000000000077561476212450300204540ustar00rootroot00000000000000Design a WebSocket application ============================== .. currentmodule:: websockets WebSocket server or client applications follow common patterns. This guide describes patterns that you're likely to implement in your application. All examples are connection handlers for a server. However, they would also apply to a client, assuming that ``websocket`` is a connection created with :func:`~asyncio.client.connect`. .. admonition:: WebSocket connections are long-lived. :class: tip You need a loop to process several messages during the lifetime of a connection. Consumer pattern ---------------- To receive messages from the WebSocket connection:: async def consumer_handler(websocket): async for message in websocket: await consume(message) In this example, ``consume()`` is a coroutine implementing your business logic for processing a message received on the WebSocket connection. Iteration terminates when the client disconnects. Producer pattern ---------------- To send messages to the WebSocket connection:: from websockets.exceptions import ConnectionClosed async def producer_handler(websocket): while True: try: message = await produce() await websocket.send(message) except ConnectionClosed: break In this example, ``produce()`` is a coroutine implementing your business logic for generating the next message to send on the WebSocket connection. 
Iteration terminates when the client disconnects because :meth:`~asyncio.server.ServerConnection.send` raises a :exc:`~exceptions.ConnectionClosed` exception, which breaks out of the ``while True`` loop. Consumer and producer --------------------- You can receive and send messages on the same WebSocket connection by combining the consumer and producer patterns. This requires running two tasks in parallel. The simplest option offered by :mod:`asyncio` is:: import asyncio async def handler(websocket): await asyncio.gather( consumer_handler(websocket), producer_handler(websocket), ) If a task terminates, :func:`~asyncio.gather` doesn't cancel the other task. This can result in a situation where the producer keeps running after the consumer finished, which may leak resources. Here's a way to exit and close the WebSocket connection as soon as a task terminates, after canceling the other task:: async def handler(websocket): consumer_task = asyncio.create_task(consumer_handler(websocket)) producer_task = asyncio.create_task(producer_handler(websocket)) done, pending = await asyncio.wait( [consumer_task, producer_task], return_when=asyncio.FIRST_COMPLETED, ) for task in pending: task.cancel() Registration ------------ To keep track of currently connected clients, you can register them when they connect and unregister them when they disconnect:: connected = set() async def handler(websocket): # Register. connected.add(websocket) try: # Broadcast a message to all connected clients. broadcast(connected, "Hello!") await asyncio.sleep(10) finally: # Unregister. connected.remove(websocket) This example maintains the set of connected clients in memory. This works as long as you run a single process. It doesn't scale to multiple processes. If you just need the set of connected clients, as in this example, use the :attr:`~asyncio.server.Server.connections` property of the server. This pattern is needed only when recording additional information about each client. 
Publish–subscribe ----------------- If you plan to run multiple processes and you want to communicate updates between processes, then you must deploy a messaging system. You may find publish-subscribe functionality useful. A complete implementation of this idea with Redis is described in the :doc:`Django integration guide <../howto/django>`. websockets-15.0.1/docs/howto/sansio.rst000066400000000000000000000260741476212450300201020ustar00rootroot00000000000000Integrate the Sans-I/O layer ============================ .. currentmodule:: websockets This guide explains how to integrate the `Sans-I/O`_ layer of websockets to add support for WebSocket in another library. .. _Sans-I/O: https://sans-io.readthedocs.io/ As a prerequisite, you should decide how you will handle network I/O and asynchronous control flow. Your integration layer will provide an API for the application on one side, will talk to the network on the other side, and will rely on websockets to implement the protocol in the middle. .. image:: ../topics/data-flow.svg :align: center Opening a connection -------------------- Client-side ........... If you're building a client, parse the URI you'd like to connect to:: from websockets.uri import parse_uri uri = parse_uri("ws://example.com/") Open a TCP connection to ``(uri.host, uri.port)`` and perform a TLS handshake if ``uri.secure`` is :obj:`True`. Initialize a :class:`~client.ClientProtocol`:: from websockets.client import ClientProtocol protocol = ClientProtocol(uri) Create a WebSocket handshake request with :meth:`~client.ClientProtocol.connect` and send it with :meth:`~client.ClientProtocol.send_request`:: request = protocol.connect() protocol.send_request(request) Then, call :meth:`~protocol.Protocol.data_to_send` and send its output to the network, as described in `Send data`_ below. 
Once you receive enough data, as explained in `Receive data`_ below, the first event returned by :meth:`~protocol.Protocol.events_received` is the WebSocket handshake response. When the handshake fails, the reason is available in :attr:`~client.ClientProtocol.handshake_exc`:: if protocol.handshake_exc is not None: raise protocol.handshake_exc Else, the WebSocket connection is open. A WebSocket client API usually performs the handshake then returns a wrapper around the network socket and the :class:`~client.ClientProtocol`. Server-side ........... If you're building a server, accept network connections from clients and perform a TLS handshake if desired. For each connection, initialize a :class:`~server.ServerProtocol`:: from websockets.server import ServerProtocol protocol = ServerProtocol() Once you receive enough data, as explained in `Receive data`_ below, the first event returned by :meth:`~protocol.Protocol.events_received` is the WebSocket handshake request. Create a WebSocket handshake response with :meth:`~server.ServerProtocol.accept` and send it with :meth:`~server.ServerProtocol.send_response`:: response = protocol.accept(request) protocol.send_response(response) Alternatively, you may reject the WebSocket handshake and return an HTTP response with :meth:`~server.ServerProtocol.reject`:: response = protocol.reject(status, explanation) protocol.send_response(response) Then, call :meth:`~protocol.Protocol.data_to_send` and send its output to the network, as described in `Send data`_ below. Even when you call :meth:`~server.ServerProtocol.accept`, the WebSocket handshake may fail if the request is incorrect or unsupported. When the handshake fails, the reason is available in :attr:`~server.ServerProtocol.handshake_exc`:: if protocol.handshake_exc is not None: raise protocol.handshake_exc Else, the WebSocket connection is open. A WebSocket server API usually builds a wrapper around the network socket and the :class:`~server.ServerProtocol`. 
Then it invokes a connection handler that accepts the wrapper in argument. It may also provide a way to close all connections and to shut down the server gracefully. Going forwards, this guide focuses on handling an individual connection. From the network to the application ----------------------------------- Go through the five steps below until you reach the end of the data stream. Receive data ............ When receiving data from the network, feed it to the protocol's :meth:`~protocol.Protocol.receive_data` method. When reaching the end of the data stream, call the protocol's :meth:`~protocol.Protocol.receive_eof` method. For example, if ``sock`` is a :obj:`~socket.socket`:: try: data = sock.recv(65536) except OSError: # socket closed data = b"" if data: protocol.receive_data(data) else: protocol.receive_eof() These methods aren't expected to raise exceptions — unless you call them again after calling :meth:`~protocol.Protocol.receive_eof`, which is an error. (If you get an exception, please file a bug!) Send data ......... Then, call :meth:`~protocol.Protocol.data_to_send` and send its output to the network:: for data in protocol.data_to_send(): if data: sock.sendall(data) else: sock.shutdown(socket.SHUT_WR) The empty bytestring signals the end of the data stream. When you see it, you must half-close the TCP connection. Sending data right after receiving data is necessary because websockets responds to ping frames, close frames, and incorrect inputs automatically. Expect TCP connection to close .............................. Closing a WebSocket connection normally involves a two-way WebSocket closing handshake. Then, regardless of whether the closure is normal or abnormal, the server starts the four-way TCP closing handshake. If the network fails at the wrong point, you can end up waiting until the TCP timeout, which is very long. 
To prevent dangling TCP connections when you expect the end of the data stream but you never reach it, call :meth:`~protocol.Protocol.close_expected` and, if it returns :obj:`True`, schedule closing the TCP connection after a short timeout:: # start a new execution thread to run this code sleep(10) sock.close() # does nothing if the socket is already closed If the connection is still open when the timeout elapses, closing the socket makes the execution thread that reads from the socket reach the end of the data stream, possibly with an exception. Close TCP connection .................... If you called :meth:`~protocol.Protocol.receive_eof`, close the TCP connection now. This is a clean closure because the receive buffer is empty. After :meth:`~protocol.Protocol.receive_eof` signals the end of the read stream, :meth:`~protocol.Protocol.data_to_send` always signals the end of the write stream, unless it already ended. So, at this point, the TCP connection is already half-closed. The only reason for closing it now is to release resources related to the socket. Now you can exit the loop relaying data from the network to the application. Receive events .............. Finally, call :meth:`~protocol.Protocol.events_received` to obtain events parsed from the data provided to :meth:`~protocol.Protocol.receive_data`:: events = connection.events_received() The first event will be the WebSocket opening handshake request or response. See `Opening a connection`_ above for details. All later events are WebSocket frames. There are two types of frames: * Data frames contain messages transferred over the WebSocket connections. You should provide them to the application. See `Fragmentation`_ below for how to reassemble messages from frames. * Control frames provide information about the connection's state. The main use case is to expose an abstraction over ping and pong to the application. Keep in mind that websockets responds to ping frames and close frames automatically. 
Don't duplicate this functionality! From the application to the network ----------------------------------- The connection object provides one method for each type of WebSocket frame. For sending a data frame: * :meth:`~protocol.Protocol.send_continuation` * :meth:`~protocol.Protocol.send_text` * :meth:`~protocol.Protocol.send_binary` These methods raise :exc:`~exceptions.ProtocolError` if you don't set the :attr:`FIN ` bit correctly in fragmented messages. For sending a control frame: * :meth:`~protocol.Protocol.send_close` * :meth:`~protocol.Protocol.send_ping` * :meth:`~protocol.Protocol.send_pong` :meth:`~protocol.Protocol.send_close` initiates the closing handshake. See `Closing a connection`_ below for details. If you encounter an unrecoverable error and you must fail the WebSocket connection, call :meth:`~protocol.Protocol.fail`. After any of the above, call :meth:`~protocol.Protocol.data_to_send` and send its output to the network, as shown in `Send data`_ above. If you called :meth:`~protocol.Protocol.send_close` or :meth:`~protocol.Protocol.fail`, you expect the end of the data stream. You should follow the process described in `Close TCP connection`_ above in order to prevent dangling TCP connections. Closing a connection -------------------- Under normal circumstances, when a server wants to close the TCP connection: * it closes the write side; * it reads until the end of the stream, because it expects the client to close the read side; * it closes the socket. When a client wants to close the TCP connection: * it reads until the end of the stream, because it expects the server to close the read side; * it closes the write side; * it closes the socket. Applying the rules described earlier in this document gives the intended result. As a reminder, the rules are: * When :meth:`~protocol.Protocol.data_to_send` returns the empty bytestring, close the write side of the TCP connection. * When you reach the end of the read stream, close the TCP connection. 
* When :meth:`~protocol.Protocol.close_expected` returns :obj:`True`, if you don't reach the end of the read stream quickly, close the TCP connection. Fragmentation ------------- WebSocket messages may be fragmented. Since this is a protocol-level concern, you may choose to reassemble fragmented messages before handing them over to the application. To reassemble a message, read data frames until you get a frame where the :attr:`FIN ` bit is set, then concatenate the payloads of all frames. You will never receive an inconsistent sequence of frames because websockets raises a :exc:`~exceptions.ProtocolError` and fails the connection when this happens. However, you may receive an incomplete sequence if the connection drops in the middle of a fragmented message. Tips ---- Serialize operations .................... The Sans-I/O layer is designed to run sequentially. If you interact with it from multiple threads or coroutines, you must ensure correct serialization. Usually, this comes for free in a cooperative multitasking environment. In a preemptive multitasking environment, it requires mutual exclusion. Furthermore, you must serialize writes to the network. When :meth:`~protocol.Protocol.data_to_send` returns several values, you must write them all before starting the next write. Minimize buffers ................ The Sans-I/O layer doesn't perform any buffering. It makes events available in :meth:`~protocol.Protocol.events_received` as soon as they're received. You should make incoming messages available to the application immediately. A small buffer of incoming messages will usually result in the best performance. It will reduce context switching between the library and the application while ensuring that backpressure is propagated. websockets-15.0.1/docs/howto/upgrade.rst000066400000000000000000000562771476212450300202450ustar00rootroot00000000000000Upgrade to the new :mod:`asyncio` implementation ================================================ .. 
currentmodule:: websockets The new :mod:`asyncio` implementation, which is now the default, is a rewrite of the original implementation of websockets. It provides a very similar API. However, there are a few differences. The recommended upgrade process is: #. Make sure that your code doesn't use any `deprecated APIs`_. If it doesn't raise warnings, you're fine. #. `Update import paths`_. For straightforward use cases, this could be the only step you need to take. #. Check out `new features and improvements`_. Consider taking advantage of them in your code. #. Review `API changes`_. If needed, update your application to preserve its current behavior. In the interest of brevity, only :func:`~asyncio.client.connect` and :func:`~asyncio.server.serve` are discussed below but everything also applies to :func:`~asyncio.client.unix_connect` and :func:`~asyncio.server.unix_serve` respectively. .. admonition:: What will happen to the original implementation? :class: hint The original implementation is deprecated. It will be maintained for five years after deprecation according to the :ref:`backwards-compatibility policy `. Then, by 2030, it will be removed. .. _deprecated APIs: Deprecated APIs --------------- Here's the list of deprecated behaviors that the original implementation still supports and that the new implementation doesn't reproduce. If you're seeing a :class:`DeprecationWarning`, follow upgrade instructions from the release notes of the version in which the feature was deprecated. * The ``path`` argument of connection handlers — unnecessary since :ref:`10.1` and deprecated in :ref:`13.0`. * The ``loop`` and ``legacy_recv`` arguments of :func:`~legacy.client.connect` and :func:`~legacy.server.serve`, which were removed — deprecated in :ref:`10.0`. 
* The ``timeout`` and ``klass`` arguments of :func:`~legacy.client.connect` and :func:`~legacy.server.serve`, which were renamed to ``close_timeout`` and ``create_protocol`` — deprecated in :ref:`7.0` and :ref:`3.4` respectively. * An empty string in the ``origins`` argument of :func:`~legacy.server.serve` — deprecated in :ref:`7.0`. * The ``host``, ``port``, and ``secure`` attributes of connections — deprecated in :ref:`8.0`. .. _Update import paths: Import paths ------------ For context, the ``websockets`` package is structured as follows: * The new implementation is found in the ``websockets.asyncio`` package. * The original implementation was moved to the ``websockets.legacy`` package and deprecated. * The ``websockets`` package provides aliases for convenience. They were switched to the new implementation in version 14.0 or deprecated when there wasn't an equivalent API. * The ``websockets.client`` and ``websockets.server`` packages provide aliases for backwards-compatibility with earlier versions of websockets. They were deprecated. To upgrade to the new :mod:`asyncio` implementation, change import paths as shown in the tables below. .. |br| raw:: html
Client APIs ........... +-------------------------------------------------------------------+-----------------------------------------------------+ | Legacy :mod:`asyncio` implementation | New :mod:`asyncio` implementation | +===================================================================+=====================================================+ | ``websockets.connect()`` *(before 14.0)* |br| | ``websockets.connect()`` *(since 14.0)* |br| | | ``websockets.client.connect()`` |br| | :func:`websockets.asyncio.client.connect` | | :func:`websockets.legacy.client.connect` | | +-------------------------------------------------------------------+-----------------------------------------------------+ | ``websockets.unix_connect()`` *(before 14.0)* |br| | ``websockets.unix_connect()`` *(since 14.0)* |br| | | ``websockets.client.unix_connect()`` |br| | :func:`websockets.asyncio.client.unix_connect` | | :func:`websockets.legacy.client.unix_connect` | | +-------------------------------------------------------------------+-----------------------------------------------------+ | ``websockets.WebSocketClientProtocol`` |br| | ``websockets.ClientConnection`` *(since 14.2)* |br| | | ``websockets.client.WebSocketClientProtocol`` |br| | :class:`websockets.asyncio.client.ClientConnection` | | :class:`websockets.legacy.client.WebSocketClientProtocol` | | +-------------------------------------------------------------------+-----------------------------------------------------+ Server APIs ........... 
+-------------------------------------------------------------------+-----------------------------------------------------+ | Legacy :mod:`asyncio` implementation | New :mod:`asyncio` implementation | +===================================================================+=====================================================+ | ``websockets.serve()`` *(before 14.0)* |br| | ``websockets.serve()`` *(since 14.0)* |br| | | ``websockets.server.serve()`` |br| | :func:`websockets.asyncio.server.serve` | | :func:`websockets.legacy.server.serve` | | +-------------------------------------------------------------------+-----------------------------------------------------+ | ``websockets.unix_serve()`` *(before 14.0)* |br| | ``websockets.unix_serve()`` *(since 14.0)* |br| | | ``websockets.server.unix_serve()`` |br| | :func:`websockets.asyncio.server.unix_serve` | | :func:`websockets.legacy.server.unix_serve` | | +-------------------------------------------------------------------+-----------------------------------------------------+ | ``websockets.WebSocketServer`` |br| | ``websockets.Server`` *(since 14.2)* |br| | | ``websockets.server.WebSocketServer`` |br| | :class:`websockets.asyncio.server.Server` | | :class:`websockets.legacy.server.WebSocketServer` | | +-------------------------------------------------------------------+-----------------------------------------------------+ | ``websockets.WebSocketServerProtocol`` |br| | ``websockets.ServerConnection`` *(since 14.2)* |br| | | ``websockets.server.WebSocketServerProtocol`` |br| | :class:`websockets.asyncio.server.ServerConnection` | | :class:`websockets.legacy.server.WebSocketServerProtocol` | | +-------------------------------------------------------------------+-----------------------------------------------------+ | ``websockets.broadcast()`` *(before 14.0)* |br| | ``websockets.broadcast()`` *(since 14.0)* |br| | | :func:`websockets.legacy.server.broadcast()` | :func:`websockets.asyncio.server.broadcast` | 
+-------------------------------------------------------------------+-----------------------------------------------------+ | ``websockets.BasicAuthWebSocketServerProtocol`` |br| | See below :ref:`how to migrate ` to | | ``websockets.auth.BasicAuthWebSocketServerProtocol`` |br| | :func:`websockets.asyncio.server.basic_auth`. | | :class:`websockets.legacy.auth.BasicAuthWebSocketServerProtocol` | | +-------------------------------------------------------------------+-----------------------------------------------------+ | ``websockets.basic_auth_protocol_factory()`` |br| | See below :ref:`how to migrate ` to | | ``websockets.auth.basic_auth_protocol_factory()`` |br| | :func:`websockets.asyncio.server.basic_auth`. | | :func:`websockets.legacy.auth.basic_auth_protocol_factory` | | +-------------------------------------------------------------------+-----------------------------------------------------+ .. _new features and improvements: New features and improvements ----------------------------- Customizing the opening handshake ................................. On the server side, if you're customizing how :func:`~legacy.server.serve` processes the opening handshake with ``process_request``, ``extra_headers``, or ``select_subprotocol``, you must update your code. Probably you can simplify it! ``process_request`` and ``select_subprotocol`` have new signatures. ``process_response`` replaces ``extra_headers`` and provides more flexibility. See process_request_, select_subprotocol_, and process_response_ below. Customizing automatic reconnection .................................. On the client side, if you're reconnecting automatically with ``async for ... in connect(...)``, the behavior when a connection attempt fails was enhanced and made configurable. The original implementation retried on any error. The new implementation uses an heuristic to determine whether an error is retryable or fatal. 
By default, only network errors and server errors (HTTP 500, 502, 503, or 504) are considered retryable. You can customize this behavior with the ``process_exception`` argument of :func:`~asyncio.client.connect`. See :func:`~asyncio.client.process_exception` for more information. Here's how to revert to the behavior of the original implementation:: async for ... in connect(..., process_exception=lambda exc: exc): ... Tracking open connections ......................... The new implementation of :class:`~asyncio.server.Server` provides a :attr:`~asyncio.server.Server.connections` property, which is a set of all open connections. This didn't exist in the original implementation. If you're keeping track of open connections in order to broadcast messages to all of them, you can simplify your code by using this property. Controlling UTF-8 decoding .......................... The new implementation of the :meth:`~asyncio.connection.Connection.recv` method provides the ``decode`` argument to control UTF-8 decoding of messages. This didn't exist in the original implementation. If you're calling :meth:`~str.encode` on a :class:`str` object returned by :meth:`~asyncio.connection.Connection.recv`, using ``decode=False`` and removing :meth:`~str.encode` saves a round-trip of UTF-8 decoding and encoding for text messages. You can also force UTF-8 decoding of binary messages with ``decode=True``. This is rarely useful and has no performance benefits over decoding a :class:`bytes` object returned by :meth:`~asyncio.connection.Connection.recv`. Receiving fragmented messages ............................. The new implementation provides the :meth:`~asyncio.connection.Connection.recv_streaming` method for receiving a fragmented message frame by frame. There was no way to do this in the original implementation. Depending on your use case, adopting this method may improve performance when streaming large messages. Specifically, it could reduce memory usage. .. 
_API changes: API changes ----------- Attributes of connection objects ................................ ``path``, ``request_headers``, and ``response_headers`` ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ The :attr:`~legacy.protocol.WebSocketCommonProtocol.path`, :attr:`~legacy.protocol.WebSocketCommonProtocol.request_headers` and :attr:`~legacy.protocol.WebSocketCommonProtocol.response_headers` properties are replaced by :attr:`~asyncio.connection.Connection.request` and :attr:`~asyncio.connection.Connection.response`. If your code uses them, you can update it as follows. ========================================== ========================================== Legacy :mod:`asyncio` implementation New :mod:`asyncio` implementation ========================================== ========================================== ``connection.path`` ``connection.request.path`` ``connection.request_headers`` ``connection.request.headers`` ``connection.response_headers`` ``connection.response.headers`` ========================================== ========================================== ``open`` and ``closed`` ~~~~~~~~~~~~~~~~~~~~~~~ The :attr:`~legacy.protocol.WebSocketCommonProtocol.open` and :attr:`~legacy.protocol.WebSocketCommonProtocol.closed` properties are removed. Using them was discouraged. Instead, you should call :meth:`~asyncio.connection.Connection.recv` or :meth:`~asyncio.connection.Connection.send` and handle :exc:`~exceptions.ConnectionClosed` exceptions. If your code uses them, you can update it as follows. ========================================== ========================================== Legacy :mod:`asyncio` implementation New :mod:`asyncio` implementation ========================================== ========================================== .. 
``from websockets.protocol import State`` ``connection.open`` ``connection.state is State.OPEN`` ``connection.closed`` ``connection.state is State.CLOSED`` ========================================== ========================================== Arguments of :func:`~asyncio.client.connect` ............................................ ``extra_headers`` → ``additional_headers`` ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ If you're setting the ``User-Agent`` header with the ``extra_headers`` argument, you should set it with ``user_agent_header`` instead. If you're adding other headers to the handshake request sent by :func:`~legacy.client.connect` with ``extra_headers``, you must rename it to ``additional_headers``. Arguments of :func:`~asyncio.server.serve` .......................................... ``ws_handler`` → ``handler`` ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ The first argument of :func:`~asyncio.server.serve` is now called ``handler`` instead of ``ws_handler``. It's usually passed as a positional argument, making this change transparent. If you're passing it as a keyword argument, you must update its name. .. _process_request: ``process_request`` ~~~~~~~~~~~~~~~~~~~ The signature of ``process_request`` changed. This is easiest to illustrate with an example:: import http # Original implementation def process_request(path, request_headers): return http.HTTPStatus.OK, [], b"OK\n" # New implementation def process_request(connection, request): return connection.respond(http.HTTPStatus.OK, "OK\n") serve(..., process_request=process_request, ...) ``connection`` is always available in ``process_request``. In the original implementation, if you wanted to make the connection object available in a ``process_request`` method, you had to write a subclass of :class:`~legacy.server.WebSocketServerProtocol` and pass it in the ``create_protocol`` argument. This pattern isn't useful anymore; you can replace it with a ``process_request`` function or coroutine. 
``path`` and ``headers`` are available as attributes of the ``request`` object. .. _process_response: ``extra_headers`` → ``process_response`` ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ If you're setting the ``Server`` header with ``extra_headers``, you should set it with the ``server_header`` argument instead. If you're adding other headers to the handshake response sent by :func:`~legacy.server.serve` with the ``extra_headers`` argument, you must write a ``process_response`` callable instead. ``process_request`` replaces ``extra_headers`` and provides more flexibility. In the most basic case, you would adapt your code as follows:: # Original implementation serve(..., extra_headers=HEADERS, ...) # New implementation def process_response(connection, request, response): response.headers.update(HEADERS) return response serve(..., process_response=process_response, ...) ``connection`` is always available in ``process_response``, similar to ``process_request``. In the original implementation, there was no way to make the connection object available. In addition, the ``request`` and ``response`` objects are available, which enables a broader range of use cases (e.g., logging) and makes ``process_response`` more useful than ``extra_headers``. .. _select_subprotocol: ``select_subprotocol`` ~~~~~~~~~~~~~~~~~~~~~~ If you're selecting a subprotocol, you must update your code because the signature of ``select_subprotocol`` changed. Here's an example:: # Original implementation def select_subprotocol(client_subprotocols, server_subprotocols): if "chat" in client_subprotocols: return "chat" # New implementation def select_subprotocol(connection, subprotocols): if "chat" in subprotocols return "chat" serve(..., select_subprotocol=select_subprotocol, ...) ``connection`` is always available in ``select_subprotocol``. This brings the same benefits as in ``process_request``. It may remove the need to subclass :class:`~legacy.server.WebSocketServerProtocol`. 
The ``subprotocols`` argument contains the list of subprotocols offered by the client. The list of subprotocols supported by the server was removed because ``select_subprotocols`` has to know which subprotocols it may select and under which conditions. Furthermore, the default behavior when ``select_subprotocol`` isn't provided changed in two ways: 1. In the original implementation, a server with a list of subprotocols accepted to continue without a subprotocol. In the new implementation, a server that is configured with subprotocols rejects connections that don't support any. 2. In the original implementation, when several subprotocols were available, the server averaged the client's preferences with its own preferences. In the new implementation, the server just picks the first subprotocol from its list. If you had a ``select_subprotocol`` for the sole purpose of rejecting connections without a subprotocol, you can remove it and keep only the ``subprotocols`` argument. Arguments of :func:`~asyncio.client.connect` and :func:`~asyncio.server.serve` .............................................................................. ``max_queue`` ~~~~~~~~~~~~~ The ``max_queue`` argument of :func:`~asyncio.client.connect` and :func:`~asyncio.server.serve` has a new meaning but achieves a similar effect. It is now the high-water mark of a buffer of incoming frames. It defaults to 16 frames. It used to be the size of a buffer of incoming messages that refilled as soon as a message was read. It used to default to 32 messages. This can make a difference when messages are fragmented in several frames. In that case, you may want to increase ``max_queue``. If you're writing a high performance server and you know that you're receiving fragmented messages, probably you should adopt :meth:`~asyncio.connection.Connection.recv_streaming` and optimize the performance of reads again. In all other cases, given how uncommon fragmentation is, you shouldn't worry about this change. 
``read_limit`` ~~~~~~~~~~~~~~ The ``read_limit`` argument doesn't exist in the new implementation because it doesn't buffer data received from the network in a :class:`~asyncio.StreamReader`. With a better design, this buffer could be removed. The buffer of incoming frames configured by ``max_queue`` is the only read buffer now. ``write_limit`` ~~~~~~~~~~~~~~~ The ``write_limit`` argument of :func:`~asyncio.client.connect` and :func:`~asyncio.server.serve` defaults to 32 KiB instead of 64 KiB. ``create_protocol`` → ``create_connection`` ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ The keyword argument of :func:`~asyncio.server.serve` for customizing the creation of the connection object is now called ``create_connection`` instead of ``create_protocol``. It must return a :class:`~asyncio.server.ServerConnection` instead of a :class:`~legacy.server.WebSocketServerProtocol`. If you were customizing connection objects, probably you need to redo your customization. Consider switching to ``process_request`` and ``select_subprotocol`` as their new design removes most use cases for ``create_connection``. .. _basic-auth: Performing HTTP Basic Authentication .................................... .. admonition:: This section applies only to servers. :class: tip On the client side, :func:`~asyncio.client.connect` performs HTTP Basic Authentication automatically when the URI contains credentials. In the original implementation, the recommended way to add HTTP Basic Authentication to a server was to set the ``create_protocol`` argument of :func:`~legacy.server.serve` to a factory function generated by :func:`~legacy.auth.basic_auth_protocol_factory`:: from websockets.legacy.auth import basic_auth_protocol_factory from websockets.legacy.server import serve async with serve(..., create_protocol=basic_auth_protocol_factory(...)): ... 
In the new implementation, the :func:`~asyncio.server.basic_auth` function generates a ``process_request`` coroutine that performs HTTP Basic Authentication:: from websockets.asyncio.server import basic_auth, serve async with serve(..., process_request=basic_auth(...)): ... :func:`~asyncio.server.basic_auth` accepts either hard coded ``credentials`` or a ``check_credentials`` coroutine as well as an optional ``realm`` just like :func:`~legacy.auth.basic_auth_protocol_factory`. Furthermore, ``check_credentials`` may be a function instead of a coroutine. This new API has more obvious semantics. That makes it easier to understand and also easier to extend. In the original implementation, overriding ``create_protocol`` changes the type of connection objects to :class:`~legacy.auth.BasicAuthWebSocketServerProtocol`, a subclass of :class:`~legacy.server.WebSocketServerProtocol` that performs HTTP Basic Authentication in its ``process_request`` method. To customize ``process_request`` further, you had only bad options: * the ill-defined option: add a ``process_request`` argument to :func:`~legacy.server.serve`; to tell which one would run first, you had to experiment or read the code; * the cumbersome option: subclass :class:`~legacy.auth.BasicAuthWebSocketServerProtocol`, then pass that subclass in the ``create_protocol`` argument of :func:`~legacy.auth.basic_auth_protocol_factory`. In the new implementation, you just write a ``process_request`` coroutine:: from websockets.asyncio.server import basic_auth, serve process_basic_auth = basic_auth(...) async def process_request(connection, request): ... # some logic here response = await process_basic_auth(connection, request) if response is not None: return response ... # more logic here async with serve(..., process_request=process_request): ... 
websockets-15.0.1/docs/index.rst000066400000000000000000000071461476212450300165540ustar00rootroot00000000000000websockets ========== |licence| |version| |pyversions| |tests| |docs| |openssf| .. |licence| image:: https://img.shields.io/pypi/l/websockets.svg :target: https://pypi.python.org/pypi/websockets .. |version| image:: https://img.shields.io/pypi/v/websockets.svg :target: https://pypi.python.org/pypi/websockets .. |pyversions| image:: https://img.shields.io/pypi/pyversions/websockets.svg :target: https://pypi.python.org/pypi/websockets .. |tests| image:: https://img.shields.io/github/checks-status/python-websockets/websockets/main?label=tests :target: https://github.com/python-websockets/websockets/actions/workflows/tests.yml .. |docs| image:: https://img.shields.io/readthedocs/websockets.svg :target: https://websockets.readthedocs.io/ .. |openssf| image:: https://bestpractices.coreinfrastructure.org/projects/6475/badge :target: https://bestpractices.coreinfrastructure.org/projects/6475 websockets is a library for building WebSocket_ servers and clients in Python with a focus on correctness, simplicity, robustness, and performance. .. _WebSocket: https://developer.mozilla.org/en-US/docs/Web/API/WebSockets_API It supports several network I/O and control flow paradigms. 1. The default implementation builds upon :mod:`asyncio`, Python's built-in asynchronous I/O library. It provides an elegant coroutine-based API. It's ideal for servers that handle many client connections. 2. The :mod:`threading` implementation is a good alternative for clients, especially if you aren't familiar with :mod:`asyncio`. It may also be used for servers that handle few client connections. 3. The `Sans-I/O`_ implementation is designed for integrating in third-party libraries, typically application servers, in addition being used internally by websockets. .. 
_Sans-I/O: https://sans-io.readthedocs.io/ Refer to the :doc:`feature support matrices ` for the full list of features provided by each implementation. .. admonition:: The :mod:`asyncio` implementation was rewritten. :class: tip The new implementation in ``websockets.asyncio`` builds upon the Sans-I/O implementation. It adds features that were impossible to provide in the original design. It was introduced in version 13.0. The historical implementation in ``websockets.legacy`` traces its roots to early versions of websockets. While it's stable and robust, it was deprecated in version 14.0 and it will be removed by 2030. The new implementation provides the same features as the historical implementation, and then some. If you're using the historical implementation, you should :doc:`ugrade to the new implementation `. Here's an echo server and corresponding client. .. tab:: asyncio .. literalinclude:: ../example/asyncio/echo.py .. tab:: threading .. literalinclude:: ../example/sync/echo.py .. tab:: asyncio :new-set: .. literalinclude:: ../example/asyncio/hello.py .. tab:: threading .. literalinclude:: ../example/sync/hello.py Don't worry about the opening and closing handshakes, pings and pongs, or any other behavior described in the WebSocket specification. websockets takes care of this under the hood so you can focus on your application! Also, websockets provides an interactive client: .. code-block:: console $ websockets ws://localhost:8765/ Connected to ws://localhost:8765/. > Hello world! < Hello world! Connection closed: 1000 (OK). Do you like it? :doc:`Let's dive in! ` .. toctree:: :hidden: intro/index howto/index deploy/index faq/index reference/index topics/index project/index websockets-15.0.1/docs/intro/000077500000000000000000000000001476212450300160365ustar00rootroot00000000000000websockets-15.0.1/docs/intro/examples.rst000066400000000000000000000057301476212450300204130ustar00rootroot00000000000000Quick examples ============== .. 
currentmodule:: websockets Start a server -------------- This WebSocket server receives a name from the client, sends a greeting, and closes the connection. .. literalinclude:: ../../example/quick/server.py :caption: server.py :language: python :func:`~asyncio.server.serve` executes the connection handler coroutine ``hello()`` once for each WebSocket connection. It closes the WebSocket connection when the handler returns. Connect a client ---------------- This WebSocket client sends a name to the server, receives a greeting, and closes the connection. .. literalinclude:: ../../example/quick/client.py :caption: client.py :language: python Using :func:`~sync.client.connect` as a context manager ensures that the WebSocket connection is closed. Connect a browser ----------------- The WebSocket protocol was invented for the web — as the name says! Here's how to connect a browser to a WebSocket server. Run this script in a console: .. literalinclude:: ../../example/quick/show_time.py :caption: show_time.py :language: python Save this file as ``show_time.html``: .. literalinclude:: ../../example/quick/show_time.html :caption: show_time.html :language: html Save this file as ``show_time.js``: .. literalinclude:: ../../example/quick/show_time.js :caption: show_time.js :language: js Then, open ``show_time.html`` in several browsers or tabs. Clocks tick irregularly. Broadcast messages ------------------ Let's send the same timestamps to everyone instead of generating independent sequences for each connection. Stop the previous script if it's still running and run this script in a console: .. literalinclude:: ../../example/quick/sync_time.py :caption: sync_time.py :language: python Refresh ``show_time.html`` in all browsers or tabs. Clocks tick in sync. Manage application state ------------------------ A WebSocket server can receive events from clients, process them to update the application state, and broadcast the updated state to all connected clients. 
Here's an example where any client can increment or decrement a counter. The concurrency model of :mod:`asyncio` guarantees that updates are serialized. This example keep tracks of connected users explicitly in ``USERS`` instead of relying on :attr:`server.connections `. The result is the same. Run this script in a console: .. literalinclude:: ../../example/quick/counter.py :caption: counter.py :language: python Save this file as ``counter.html``: .. literalinclude:: ../../example/quick/counter.html :caption: counter.html :language: html Save this file as ``counter.css``: .. literalinclude:: ../../example/quick/counter.css :caption: counter.css :language: css Save this file as ``counter.js``: .. literalinclude:: ../../example/quick/counter.js :caption: counter.js :language: js Then open ``counter.html`` file in several browsers and play with [+] and [-]. websockets-15.0.1/docs/intro/index.rst000066400000000000000000000014371476212450300177040ustar00rootroot00000000000000Getting started =============== .. currentmodule:: websockets Requirements ------------ websockets requires Python ≥ 3.9. .. admonition:: Use the most recent Python release :class: tip For each minor version (3.x), only the latest bugfix or security release (3.x.y) is officially supported. It doesn't have any dependencies. .. _install: Installation ------------ Install websockets with: .. code-block:: console $ pip install websockets Wheels are available for all platforms. Tutorial -------- Learn how to build an real-time web application with websockets. .. toctree:: :maxdepth: 2 tutorial1 tutorial2 tutorial3 In a hurry? ----------- These examples will get you started quickly with websockets. .. toctree:: :maxdepth: 2 examples websockets-15.0.1/docs/intro/tutorial1.rst000066400000000000000000000443031476212450300205200ustar00rootroot00000000000000Part 1 - Send & receive ======================= .. currentmodule:: websockets In this tutorial, you're going to build a web-based `Connect Four`_ game. .. 
_Connect Four: https://en.wikipedia.org/wiki/Connect_Four The web removes the constraint of being in the same room for playing a game. Two players can connect over of the Internet, regardless of where they are, and play in their browsers. When a player makes a move, it should be reflected immediately on both sides. This is difficult to implement over HTTP due to the request-response style of the protocol. Indeed, there is no good way to be notified when the other player makes a move. Workarounds such as polling or long-polling introduce significant overhead. Enter `WebSocket `_. The WebSocket protocol provides two-way communication between a browser and a server over a persistent connection. That's exactly what you need to exchange moves between players, via a server. .. admonition:: This is the first part of the tutorial. * In this :doc:`first part `, you will create a server and connect one browser; you can play if you share the same browser. * In the :doc:`second part `, you will connect a second browser; you can play from different browsers on a local network. * In the :doc:`third part `, you will deploy the game to the web; you can play from any browser connected to the Internet. Prerequisites ------------- This tutorial assumes basic knowledge of Python and JavaScript. If you're comfortable with :doc:`virtual environments `, you can use one for this tutorial. Else, don't worry: websockets doesn't have any dependencies; it shouldn't create trouble in the default environment. If you haven't installed websockets yet, do it now: .. code-block:: console $ pip install websockets Confirm that websockets is installed: .. code-block:: console $ websockets --version .. admonition:: This tutorial is written for websockets |release|. :class: tip If you installed another version, you should switch to the corresponding version of the documentation. 
Download the starter kit ------------------------ Create a directory and download these three files: :download:`connect4.js <../../example/tutorial/start/connect4.js>`, :download:`connect4.css <../../example/tutorial/start/connect4.css>`, and :download:`connect4.py <../../example/tutorial/start/connect4.py>`. The JavaScript module, along with the CSS file, provides a web-based user interface. Here's its API. .. js:module:: connect4 .. js:data:: PLAYER1 Color of the first player. .. js:data:: PLAYER2 Color of the second player. .. js:function:: createBoard(board) Draw a board. :param board: DOM element containing the board; must be initially empty. .. js:function:: playMove(board, player, column, row) Play a move. :param board: DOM element containing the board. :param player: :js:data:`PLAYER1` or :js:data:`PLAYER2`. :param column: between ``0`` and ``6``. :param row: between ``0`` and ``5``. The Python module provides a class to record moves and tell when a player wins. Here's its API. .. module:: connect4 .. data:: PLAYER1 :value: "red" Color of the first player. .. data:: PLAYER2 :value: "yellow" Color of the second player. .. class:: Connect4 A Connect Four game. .. method:: play(player, column) Play a move. :param player: :data:`~connect4.PLAYER1` or :data:`~connect4.PLAYER2`. :param column: between ``0`` and ``6``. :returns: Row where the checker lands, between ``0`` and ``5``. :raises ValueError: if the move is illegal. .. attribute:: moves List of moves played during this game, as ``(player, column, row)`` tuples. .. attribute:: winner :data:`~connect4.PLAYER1` or :data:`~connect4.PLAYER2` if they won; :obj:`None` if the game is still ongoing. .. currentmodule:: websockets Bootstrap the web UI -------------------- Create an ``index.html`` file next to ``connect4.js`` and ``connect4.css`` with this content: .. literalinclude:: ../../example/tutorial/step1/index.html :language: html This HTML page contains an empty ``
`` element where you will draw the Connect Four board. It loads a ``main.js`` script where you will write all your JavaScript code. Create a ``main.js`` file next to ``index.html``. In this script, when the page loads, draw the board: .. code-block:: javascript import { createBoard, playMove } from "./connect4.js"; window.addEventListener("DOMContentLoaded", () => { // Initialize the UI. const board = document.querySelector(".board"); createBoard(board); }); Open a shell, navigate to the directory containing these files, and start an HTTP server: .. code-block:: console $ python -m http.server Open http://localhost:8000/ in a web browser. The page displays an empty board with seven columns and six rows. You will play moves in this board later. Bootstrap the server -------------------- Create an ``app.py`` file next to ``connect4.py`` with this content: .. code-block:: python #!/usr/bin/env python import asyncio from websockets.asyncio.server import serve async def handler(websocket): while True: message = await websocket.recv() print(message) async def main(): async with serve(handler, "", 8001) as server: await server.serve_forever() if __name__ == "__main__": asyncio.run(main()) The entry point of this program is ``asyncio.run(main())``. It creates an asyncio event loop, runs the ``main()`` coroutine, and shuts down the loop. The ``main()`` coroutine calls :func:`~asyncio.server.serve` to start a websockets server. :func:`~asyncio.server.serve` takes three positional arguments: * ``handler`` is a coroutine that manages a connection. When a client connects, websockets calls ``handler`` with the connection in argument. When ``handler`` terminates, websockets closes the connection. * The second argument defines the network interfaces where the server can be reached. Here, the server listens on all interfaces, so that other devices on the same local network can connect. * The third argument is the port on which the server listens. 
Invoking :func:`~asyncio.server.serve` as an asynchronous context manager, in an ``async with`` block, ensures that the server shuts down properly when terminating the program. For each connection, the ``handler()`` coroutine runs an infinite loop that receives messages from the browser and prints them. Open a shell, navigate to the directory containing ``app.py``, and start the server: .. code-block:: console $ python app.py This doesn't display anything. Hopefully the WebSocket server is running. Let's make sure that it works. You cannot test the WebSocket server with a web browser like you tested the HTTP server. However, you can test it with websockets' interactive client. Open another shell and run this command: .. code-block:: console $ websockets ws://localhost:8001/ You get a prompt. Type a message and press "Enter". Switch to the shell where the server is running and check that the server received the message. Good! Exit the interactive client with Ctrl-C or Ctrl-D. Now, if you look at the console where you started the server, you can see the stack trace of an exception: .. code-block:: pytb connection handler failed Traceback (most recent call last): ... File "app.py", line 22, in handler message = await websocket.recv() ... websockets.exceptions.ConnectionClosedOK: received 1000 (OK); then sent 1000 (OK) Indeed, the server was waiting for the next message with :meth:`~asyncio.server.ServerConnection.recv` when the client disconnected. When this happens, websockets raises a :exc:`~exceptions.ConnectionClosedOK` exception to let you know that you won't receive another message on this connection. This exception creates noise in the server logs, making it more difficult to spot real errors when you add functionality to the server. Catch it in the ``handler()`` coroutine: .. 
code-block:: python from websockets.exceptions import ConnectionClosedOK async def handler(websocket): while True: try: message = await websocket.recv() except ConnectionClosedOK: break print(message) Stop the server with Ctrl-C and start it again: .. code-block:: console $ python app.py .. admonition:: You must restart the WebSocket server when you make changes. :class: tip The WebSocket server loads the Python code in ``app.py`` then serves every WebSocket request with this version of the code. As a consequence, changes to ``app.py`` aren't visible until you restart the server. This is unlike the HTTP server that you started earlier with ``python -m http.server``. For every request, this HTTP server reads the target file and sends it. That's why changes are immediately visible. It is possible to :doc:`restart the WebSocket server automatically <../howto/autoreload>` but this isn't necessary for this tutorial. Try connecting and disconnecting the interactive client again. The :exc:`~exceptions.ConnectionClosedOK` exception doesn't appear anymore. This pattern is so common that websockets provides a shortcut for iterating over messages received on the connection until the client disconnects: .. code-block:: python async def handler(websocket): async for message in websocket: print(message) Restart the server and check with the interactive client that its behavior didn't change. At this point, you bootstrapped a web application and a WebSocket server. Let's connect them. Transmit from browser to server ------------------------------- In JavaScript, you open a WebSocket connection as follows: .. code-block:: javascript const websocket = new WebSocket("ws://localhost:8001/"); Before you exchange messages with the server, you need to decide their format. There is no universal convention for this. Let's use JSON objects with a ``type`` key identifying the type of the event and the rest of the object containing properties of the event. 
Here's an event describing a move in the middle slot of the board: .. code-block:: javascript const event = {type: "play", column: 3}; Here's how to serialize this event to JSON and send it to the server: .. code-block:: javascript websocket.send(JSON.stringify(event)); Now you have all the building blocks to send moves to the server. Add this function to ``main.js``: .. literalinclude:: ../../example/tutorial/step1/main.js :language: js :start-at: function sendMoves :end-before: window.addEventListener ``sendMoves()`` registers a listener for ``click`` events on the board. The listener figures out which column was clicked, builds a event of type ``"play"``, serializes it, and sends it to the server. Modify the initialization to open the WebSocket connection and call the ``sendMoves()`` function: .. code-block:: javascript window.addEventListener("DOMContentLoaded", () => { // Initialize the UI. const board = document.querySelector(".board"); createBoard(board); // Open the WebSocket connection and register event handlers. const websocket = new WebSocket("ws://localhost:8001/"); sendMoves(board, websocket); }); Check that the HTTP server and the WebSocket server are still running. If you stopped them, here are the commands to start them again: .. code-block:: console $ python -m http.server .. code-block:: console $ python app.py Refresh http://localhost:8000/ in your web browser. Click various columns in the board. The server receives messages with the expected column number. There isn't any feedback in the board because you haven't implemented that yet. Let's do it. Transmit from server to browser ------------------------------- In JavaScript, you receive WebSocket messages by listening to ``message`` events. Here's how to receive a message from the server and deserialize it from JSON: .. 
code-block:: javascript websocket.addEventListener("message", ({ data }) => { const event = JSON.parse(data); // do something with event }); You're going to need three types of messages from the server to the browser: .. code-block:: javascript {type: "play", player: "red", column: 3, row: 0} {type: "win", player: "red"} {type: "error", message: "This slot is full."} The JavaScript code receiving these messages will dispatch events depending on their type and take appropriate action. For example, it will react to an event of type ``"play"`` by displaying the move on the board with the :js:func:`~connect4.playMove` function. Add this function to ``main.js``: .. literalinclude:: ../../example/tutorial/step1/main.js :language: js :start-at: function showMessage :end-before: function sendMoves .. admonition:: Why does ``showMessage`` use ``window.setTimeout``? :class: hint When :js:func:`playMove` modifies the state of the board, the browser renders changes asynchronously. Conversely, ``window.alert()`` runs synchronously and blocks rendering while the alert is visible. If you called ``window.alert()`` immediately after :js:func:`playMove`, the browser could display the alert before rendering the move. You could get a "Player red wins!" alert without seeing red's last move. We're using ``window.alert()`` for simplicity in this tutorial. A real application would display these messages in the user interface instead. It wouldn't be vulnerable to this problem. Modify the initialization to call the ``receiveMoves()`` function: .. literalinclude:: ../../example/tutorial/step1/main.js :language: js :start-at: window.addEventListener At this point, the user interface should receive events properly. Let's test it by modifying the server to send some events. Sending an event from Python is quite similar to JavaScript: .. code-block:: python event = {"type": "play", "player": "red", "column": 3, "row": 0} await websocket.send(json.dumps(event)) .. 
admonition:: Don't forget to serialize the event with :func:`json.dumps`. :class: tip Else, websockets raises ``TypeError: data is a dict-like object``. Modify the ``handler()`` coroutine in ``app.py`` as follows: .. code-block:: python import json from connect4 import PLAYER1, PLAYER2 async def handler(websocket): for player, column, row in [ (PLAYER1, 3, 0), (PLAYER2, 3, 1), (PLAYER1, 4, 0), (PLAYER2, 4, 1), (PLAYER1, 2, 0), (PLAYER2, 1, 0), (PLAYER1, 5, 0), ]: event = { "type": "play", "player": player, "column": column, "row": row, } await websocket.send(json.dumps(event)) await asyncio.sleep(0.5) event = { "type": "win", "player": PLAYER1, } await websocket.send(json.dumps(event)) Restart the WebSocket server and refresh http://localhost:8000/ in your web browser. Seven moves appear at 0.5 second intervals. Then an alert announces the winner. Good! Now you know how to communicate both ways. Once you plug the game engine to process moves, you will have a fully functional game. Add the game logic ------------------ In the ``handler()`` coroutine, you're going to initialize a game: .. code-block:: python from connect4 import Connect4 async def handler(websocket): # Initialize a Connect Four game. game = Connect4() ... Then, you're going to iterate over incoming messages and take these steps: * parse an event of type ``"play"``, the only type of event that the user interface sends; * play the move in the board with the :meth:`~connect4.Connect4.play` method, alternating between the two players; * if :meth:`~connect4.Connect4.play` raises :exc:`ValueError` because the move is illegal, send an event of type ``"error"``; * else, send an event of type ``"play"`` to tell the user interface where the checker lands; * if the move won the game, send an event of type ``"win"``. Try to implement this by yourself! Keep in mind that you must restart the WebSocket server and reload the page in the browser when you make changes. 
When it works, you can play the game from a single browser, with players taking alternate turns. .. admonition:: Enable debug logs to see all messages sent and received. :class: tip Here's how to enable debug logs: .. code-block:: python import logging logging.basicConfig( format="%(asctime)s %(message)s", level=logging.DEBUG, ) If you're stuck, a solution is available at the bottom of this document. Summary ------- In this first part of the tutorial, you learned how to: * build and run a WebSocket server in Python with :func:`~asyncio.server.serve`; * receive a message in a connection handler with :meth:`~asyncio.server.ServerConnection.recv`; * send a message in a connection handler with :meth:`~asyncio.server.ServerConnection.send`; * iterate over incoming messages with ``async for message in websocket: ...``; * open a WebSocket connection in JavaScript with the ``WebSocket`` API; * send messages in a browser with ``WebSocket.send()``; * receive messages in a browser by listening to ``message`` events; * design a set of events to be exchanged between the browser and the server. You can now play a Connect Four game in a browser, communicating over a WebSocket connection with a server where the game logic resides! However, the two players share a browser, so the constraint of being in the same room still applies. Move on to the :doc:`second part ` of the tutorial to break this constraint and play from separate browsers. Solution -------- .. literalinclude:: ../../example/tutorial/step1/app.py :caption: app.py :language: python :linenos: .. literalinclude:: ../../example/tutorial/step1/index.html :caption: index.html :language: html :linenos: .. literalinclude:: ../../example/tutorial/step1/main.js :caption: main.js :language: js :linenos: websockets-15.0.1/docs/intro/tutorial2.rst000066400000000000000000000446401476212450300205250ustar00rootroot00000000000000Part 2 - Route & broadcast ========================== .. currentmodule:: websockets .. 
admonition:: This is the second part of the tutorial. * In the :doc:`first part `, you created a server and connected one browser; you could play if you shared the same browser. * In this :doc:`second part `, you will connect a second browser; you can play from different browsers on a local network. * In the :doc:`third part `, you will deploy the game to the web; you can play from any browser connected to the Internet. In the first part of the tutorial, you opened a WebSocket connection from a browser to a server and exchanged events to play moves. The state of the game was stored in an instance of the :class:`~connect4.Connect4` class, referenced as a local variable in the connection handler coroutine. Now you want to open two WebSocket connections from two separate browsers, one for each player, to the same server in order to play the same game. This requires moving the state of the game to a place where both connections can access it. Share game state ---------------- As long as you're running a single server process, you can share state by storing it in a global variable. .. admonition:: What if you need to scale to multiple server processes? :class: hint In that case, you must design a way for the process that handles a given connection to be aware of relevant events for that client. This is often achieved with a publish / subscribe mechanism. How can you make two connection handlers agree on which game they're playing? When the first player starts a game, you give it an identifier. Then, you communicate the identifier to the second player. When the second player joins the game, you look it up with the identifier. In addition to the game itself, you need to keep track of the WebSocket connections of the two players. Since both players receive the same events, you don't need to treat the two connections differently; you can store both in the same set. Let's sketch this in code. A module-level :class:`dict` enables lookups by identifier: .. 
code-block:: python JOIN = {} When the first player starts the game, initialize and store it: .. code-block:: python import secrets async def handler(websocket): ... # Initialize a Connect Four game, the set of WebSocket connections # receiving moves from this game, and secret access token. game = Connect4() connected = {websocket} join_key = secrets.token_urlsafe(12) JOIN[join_key] = game, connected try: ... finally: del JOIN[join_key] When the second player joins the game, look it up: .. code-block:: python async def handler(websocket): ... join_key = ... # Find the Connect Four game. game, connected = JOIN[join_key] # Register to receive moves from this game. connected.add(websocket) try: ... finally: connected.remove(websocket) Notice how we're carefully cleaning up global state with ``try: ... finally: ...`` blocks. Else, we could leave references to games or connections in global state, which would cause a memory leak. In both connection handlers, you have a ``game`` pointing to the same :class:`~connect4.Connect4` instance, so you can interact with the game, and a ``connected`` set of connections, so you can send game events to both players as follows: .. code-block:: python async def handler(websocket): ... for connection in connected: await connection.send(json.dumps(event)) ... Perhaps you spotted a major piece missing from the puzzle. How does the second player obtain ``join_key``? Let's design new events to carry this information. To start a game, the first player sends an ``"init"`` event: .. code-block:: javascript {type: "init"} The connection handler for the first player creates a game as shown above and responds with: .. code-block:: javascript {type: "init", join: ""} With this information, the user interface of the first player can create a link to ``http://localhost:8000/?join=``. 
For the sake of simplicity, we will assume that the first player shares this link with the second player outside of the application, for example via an instant messaging service. To join the game, the second player sends a different ``"init"`` event: .. code-block:: javascript {type: "init", join: ""} The connection handler for the second player can look up the game with the join key as shown above. There is no need to respond. Let's dive into the details of implementing this design. Start a game ------------ We'll start with the initialization sequence for the first player. In ``main.js``, define a function to send an initialization event when the WebSocket connection is established, which triggers an ``open`` event: .. code-block:: javascript function initGame(websocket) { websocket.addEventListener("open", () => { // Send an "init" event for the first player. const event = { type: "init" }; websocket.send(JSON.stringify(event)); }); } Update the initialization sequence to call ``initGame()``: .. literalinclude:: ../../example/tutorial/step2/main.js :language: js :start-at: window.addEventListener In ``app.py``, define a new ``handler`` coroutine — keep a copy of the previous one to reuse it later: .. code-block:: python import secrets JOIN = {} async def start(websocket): # Initialize a Connect Four game, the set of WebSocket connections # receiving moves from this game, and secret access token. game = Connect4() connected = {websocket} join_key = secrets.token_urlsafe(12) JOIN[join_key] = game, connected try: # Send the secret access token to the browser of the first player, # where it'll be used for building a "join" link. event = { "type": "init", "join": join_key, } await websocket.send(json.dumps(event)) # Temporary - for testing. print("first player started game", id(game)) async for message in websocket: print("first player sent", message) finally: del JOIN[join_key] async def handler(websocket): # Receive and parse the "init" event from the UI. 
message = await websocket.recv() event = json.loads(message) assert event["type"] == "init" # First player starts a new game. await start(websocket) In ``index.html``, add an ```` element to display the link to share with the other player. .. code-block:: html In ``main.js``, modify ``receiveMoves()`` to handle the ``"init"`` message and set the target of that link: .. code-block:: javascript switch (event.type) { case "init": // Create link for inviting the second player. document.querySelector(".join").href = "?join=" + event.join; break; // ... } Restart the WebSocket server and reload http://localhost:8000/ in the browser. There's a link labeled JOIN below the board with a target that looks like http://localhost:8000/?join=95ftAaU5DJVP1zvb. The server logs say ``first player started game ...``. If you click the board, you see ``"play"`` events. There is no feedback in the UI, though, because you haven't restored the game logic yet. Before we get there, let's handle links with a ``join`` query parameter. Join a game ----------- We'll now update the initialization sequence to account for the second player. In ``main.js``, update ``initGame()`` to send the join key in the ``"init"`` message when it's in the URL: .. code-block:: javascript function initGame(websocket) { websocket.addEventListener("open", () => { // Send an "init" event according to who is connecting. const params = new URLSearchParams(window.location.search); let event = { type: "init" }; if (params.has("join")) { // Second player joins an existing game. event.join = params.get("join"); } else { // First player starts a new game. } websocket.send(JSON.stringify(event)); }); } In ``app.py``, update the ``handler`` coroutine to look for the join key in the ``"init"`` message, then load that game: .. 
code-block:: python async def error(websocket, message): event = { "type": "error", "message": message, } await websocket.send(json.dumps(event)) async def join(websocket, join_key): # Find the Connect Four game. try: game, connected = JOIN[join_key] except KeyError: await error(websocket, "Game not found.") return # Register to receive moves from this game. connected.add(websocket) try: # Temporary - for testing. print("second player joined game", id(game)) async for message in websocket: print("second player sent", message) finally: connected.remove(websocket) async def handler(websocket): # Receive and parse the "init" event from the UI. message = await websocket.recv() event = json.loads(message) assert event["type"] == "init" if "join" in event: # Second player joins an existing game. await join(websocket, event["join"]) else: # First player starts a new game. await start(websocket) Restart the WebSocket server and reload http://localhost:8000/ in the browser. Copy the link labeled JOIN and open it in another browser. You may also open it in another tab or another window of the same browser; however, that makes it a bit tricky to remember which one is the first or second player. .. admonition:: You must start a new game when you restart the server. :class: tip Since games are stored in the memory of the Python process, they're lost when you stop the server. Whenever you make changes to ``app.py``, you must restart the server, create a new game in a browser, and join it in another browser. The server logs say ``first player started game ...`` and ``second player joined game ...``. The numbers match, proving that the ``game`` local variable in both connection handlers points to same object in the memory of the Python process. Click the board in either browser. The server receives ``"play"`` events from the corresponding player. 
In the initialization sequence, you're routing connections to ``start()`` or ``join()`` depending on the first message received by the server. This is a common pattern in servers that handle different clients. .. admonition:: Why not use different URIs for ``start()`` and ``join()``? :class: hint Instead of sending an initialization event, you could encode the join key in the WebSocket URI e.g. ``ws://localhost:8001/join/``. The WebSocket server would parse ``websocket.path`` and route the connection, similar to how HTTP servers route requests. When you need to send sensitive data like authentication credentials to the server, sending it an event is considered more secure than encoding it in the URI because URIs end up in logs. For the purposes of this tutorial, both approaches are equivalent because the join key comes from an HTTP URL. There isn't much at risk anyway! Now you can restore the logic for playing moves and you'll have a fully functional two-player game. Add the game logic ------------------ Once the initialization is done, the game is symmetrical, so you can write a single coroutine to process the moves of both players: .. code-block:: python async def play(websocket, game, player, connected): ... With such a coroutine, you can replace the temporary code for testing in ``start()`` by: .. code-block:: python await play(websocket, game, PLAYER1, connected) and in ``join()`` by: .. code-block:: python await play(websocket, game, PLAYER2, connected) The ``play()`` coroutine will reuse much of the code you wrote in the first part of the tutorial. Try to implement this by yourself! Keep in mind that you must restart the WebSocket server, reload the page to start a new game with the first player, copy the JOIN link, and join the game with the second player when you make changes. When ``play()`` works, you can play the game from two separate browsers, possibly running on separate computers on the same local network. 
A complete solution is available at the bottom of this document. Watch a game ------------ Let's add one more feature: allow spectators to watch the game. The process for inviting a spectator can be the same as for inviting the second player. You will have to duplicate all the initialization logic: - declare a ``WATCH`` global variable similar to ``JOIN``; - generate a watch key when creating a game; it must be different from the join key, or else a spectator could hijack a game by tweaking the URL; - include the watch key in the ``"init"`` event sent to the first player; - generate a WATCH link in the UI with a ``watch`` query parameter; - update the ``initGame()`` function to handle such links; - update the ``handler()`` coroutine to invoke a ``watch()`` coroutine for spectators; - prevent ``sendMoves()`` from sending ``"play"`` events for spectators. Once the initialization sequence is done, watching a game is as simple as registering the WebSocket connection in the ``connected`` set in order to receive game events and doing nothing until the spectator disconnects. You can wait for a connection to terminate with :meth:`~asyncio.server.ServerConnection.wait_closed`: .. code-block:: python async def watch(websocket, watch_key): ... connected.add(websocket) try: await websocket.wait_closed() finally: connected.remove(websocket) The connection can terminate because the ``receiveMoves()`` function closed it explicitly after receiving a ``"win"`` event, because the spectator closed their browser, or because the network failed. Again, try to implement this by yourself. When ``watch()`` works, you can invite spectators to watch the game from other browsers, as long as they're on the same local network. As a further improvement, you may support adding spectators while a game is already in progress. This requires replaying moves that were played before the spectator was added to the ``connected`` set. 
Past moves are available in the :attr:`~connect4.Connect4.moves` attribute of the game. This feature is included in the solution proposed below. Broadcast --------- When you need to send a message to the two players and to all spectators, you're using this pattern: .. code-block:: python async def handler(websocket): ... for connection in connected: await connection.send(json.dumps(event)) ... Since this is a very common pattern in WebSocket servers, websockets provides the :func:`~asyncio.server.broadcast` helper for this purpose: .. code-block:: python from websockets.asyncio.server import broadcast async def handler(websocket): ... broadcast(connected, json.dumps(event)) ... Calling :func:`~asyncio.server.broadcast` once is more efficient than calling :meth:`~asyncio.server.ServerConnection.send` in a loop. However, there's a subtle difference in behavior. Did you notice that there's no ``await`` in the second version? Indeed, :func:`~asyncio.server.broadcast` is a function, not a coroutine like :meth:`~asyncio.server.ServerConnection.send` or :meth:`~asyncio.server.ServerConnection.recv`. It's quite obvious why :meth:`~asyncio.server.ServerConnection.recv` is a coroutine. When you want to receive the next message, you have to wait until the client sends it and the network transmits it. It's less obvious why :meth:`~asyncio.server.ServerConnection.send` is a coroutine. If you send many messages or large messages, you could write data faster than the network can transmit it or the client can read it. Then, outgoing data will pile up in buffers, which will consume memory and may crash your application. To avoid this problem, :meth:`~asyncio.server.ServerConnection.send` waits until the write buffer drains. By slowing down the application as necessary, this ensures that the server doesn't send data too quickly. This is called backpressure and it's useful for building robust systems. 
That said, when you're sending the same messages to many clients in a loop, applying backpressure in this way can become counterproductive. When you're broadcasting, you don't want to slow down everyone to the pace of the slowest clients; you want to drop clients that cannot keep up with the data stream. That's why :func:`~asyncio.server.broadcast` doesn't wait until write buffers drain and therefore doesn't need to be a coroutine. For our Connect Four game, there's no difference in practice. The total amount of data sent on a connection for a game of Connect Four is so small that the write buffer cannot fill up. As a consequence, backpressure never kicks in. Summary ------- In this second part of the tutorial, you learned how to: * configure a connection by exchanging initialization messages; * keep track of connections within a single server process; * wait until a client disconnects in a connection handler; * broadcast a message to many connections efficiently. You can now play a Connect Four game from separate browser, communicating over WebSocket connections with a server that synchronizes the game logic! However, the two players have to be on the same local network as the server, so the constraint of being in the same place still mostly applies. Head over to the :doc:`third part ` of the tutorial to deploy the game to the web and remove this constraint. Solution -------- .. literalinclude:: ../../example/tutorial/step2/app.py :caption: app.py :language: python :linenos: .. literalinclude:: ../../example/tutorial/step2/index.html :caption: index.html :language: html :linenos: .. literalinclude:: ../../example/tutorial/step2/main.js :caption: main.js :language: js :linenos: websockets-15.0.1/docs/intro/tutorial3.rst000066400000000000000000000217511476212450300205240ustar00rootroot00000000000000Part 3 - Deploy to the web ========================== .. currentmodule:: websockets .. admonition:: This is the third part of the tutorial. 
* In the :doc:`first part `, you created a server and connected one browser; you could play if you shared the same browser. * In this :doc:`second part `, you connected a second browser; you could play from different browsers on a local network. * In this :doc:`third part `, you will deploy the game to the web; you can play from any browser connected to the Internet. In the first and second parts of the tutorial, for local development, you ran an HTTP server on ``http://localhost:8000/`` with: .. code-block:: console $ python -m http.server and a WebSocket server on ``ws://localhost:8001/`` with: .. code-block:: console $ python app.py Now you want to deploy these servers on the Internet. There's a vast range of hosting providers to choose from. For the sake of simplicity, we'll rely on: * `GitHub Pages`_ for the HTTP server; * Koyeb_ for the WebSocket server. .. _GitHub Pages: https://pages.github.com/ .. _Koyeb: https://www.koyeb.com/ Koyeb is a modern Platform as a Service provider whose free tier allows you to run a web application, including a WebSocket server. Commit project to git --------------------- Perhaps you committed your work to git while you were progressing through the tutorial. If you didn't, now is a good time, because GitHub and Koyeb offer git-based deployment workflows. Initialize a git repository: .. code-block:: console $ git init -b main Initialized empty Git repository in websockets-tutorial/.git/ $ git commit --allow-empty -m "Initial commit." [main (root-commit) 8195c1d] Initial commit. Add all files and commit: .. code-block:: console $ git add . $ git commit -m "Initial implementation of Connect Four game." [main 7f0b2c4] Initial implementation of Connect Four game. 6 files changed, 500 insertions(+) create mode 100644 app.py create mode 100644 connect4.css create mode 100644 connect4.js create mode 100644 connect4.py create mode 100644 index.html create mode 100644 main.js Sign up or log in to GitHub. Create a new repository. 
Set the repository name to ``websockets-tutorial``, the visibility to Public, and click **Create repository**. Push your code to this repository. You must replace ``python-websockets`` by your GitHub username in the following command: .. code-block:: console $ git remote add origin git@github.com:python-websockets/websockets-tutorial.git $ git branch -M main $ git push -u origin main ... To github.com:python-websockets/websockets-tutorial.git * [new branch] main -> main Branch 'main' set up to track remote branch 'main' from 'origin'. Adapt the WebSocket server -------------------------- Before you deploy the server, you must adapt it for Koyeb's environment. This involves three small changes: 1. Koyeb provides the port on which the server should listen in the ``$PORT`` environment variable. 2. Koyeb requires a health check to verify that the server is running. We'll add a HTTP health check. 3. Koyeb sends a ``SIGTERM`` signal when terminating the server. We'll catch it and trigger a clean exit. Adapt the ``main()`` coroutine accordingly: .. code-block:: python import http import os import signal .. literalinclude:: ../../example/tutorial/step3/app.py :pyobject: health_check .. literalinclude:: ../../example/tutorial/step3/app.py :pyobject: main The ``process_request`` parameter of :func:`~asyncio.server.serve` is a callback that runs for each request. When it returns an HTTP response, websockets sends that response instead of opening a WebSocket connection. Here, requests to ``/healthz`` return an HTTP 200 status code. ``main()`` registers a signal handler that closes the server when receiving the ``SIGTERM`` signal. Then, it waits for the server to be closed. Additionally, using :func:`~asyncio.server.serve` as a context manager ensures that the server will always be closed cleanly, even if the program crashes. 
Deploy the WebSocket server --------------------------- Create a ``requirements.txt`` file with this content to install ``websockets`` when building the image: .. literalinclude:: ../../example/tutorial/step3/requirements.txt :language: text .. admonition:: Koyeb treats ``requirements.txt`` as a signal to `detect a Python app`__. :class: tip That's why you don't need to declare that you need a Python runtime. __ https://www.koyeb.com/docs/build-and-deploy/build-from-git/python#detection Create a ``Procfile`` file with this content to configure the command for running the server: .. literalinclude:: ../../example/tutorial/step3/Procfile :language: text Commit and push your changes: .. code-block:: console $ git add . $ git commit -m "Deploy to Koyeb." [main 4a4b6e9] Deploy to Koyeb. 3 files changed, 15 insertions(+), 2 deletions(-) create mode 100644 Procfile create mode 100644 requirements.txt $ git push ... To github.com:python-websockets/websockets-tutorial.git + 6bd6032...4a4b6e9 main -> main Sign up or log in to Koyeb. In the Koyeb control panel, create a web service with GitHub as the deployment method. `Install and authorize Koyeb's GitHub app`__ if you haven't done that yet. __ https://www.koyeb.com/docs/build-and-deploy/deploy-with-git#connect-your-github-account-to-koyeb Follow the steps to create a new service: 1. Select the ``websockets-tutorial`` repository in the list of your repositories. 2. Confirm that the **Free** instance type is selected. Click **Next**. 3. Configure health checks: change the protocol from TCP to HTTP and set the path to ``/healthz``. Review other settings; defaults should be correct. Click **Deploy**. Koyeb builds the app, deploys it, verifies that the health checks passes, and makes the deployment active. You can test the WebSocket server with the interactive client exactly like you did in the first part of the tutorial. The Koyeb control panel provides the URL of your app in the format: ``https://--.koyeb.app/``. 
Replace ``https`` with ``wss`` in the URL and connect the interactive client: .. code-block:: console $ websockets wss://--.koyeb.app/ Connected to wss://--.koyeb.app/. > {"type": "init"} < {"type": "init", "join": "54ICxFae_Ip7TJE2", "watch": "634w44TblL5Dbd9a"} Press Ctrl-D to terminate the connection. It works! Prepare the web application --------------------------- Before you deploy the web application, perhaps you're wondering how it will locate the WebSocket server? Indeed, at this point, its address is hard-coded in ``main.js``: .. code-block:: javascript const websocket = new WebSocket("ws://localhost:8001/"); You can take this strategy one step further by checking the address of the HTTP server and determining the address of the WebSocket server accordingly. Add this function to ``main.js``; replace ``python-websockets`` by your GitHub username and ``websockets-tutorial`` by the name of your app on Koyeb: .. literalinclude:: ../../example/tutorial/step3/main.js :language: js :start-at: function getWebSocketServer :end-before: function initGame Then, update the initialization to connect to this address instead: .. code-block:: javascript const websocket = new WebSocket(getWebSocketServer()); Commit your changes: .. code-block:: console $ git add . $ git commit -m "Configure WebSocket server address." [main 0903526] Configure WebSocket server address. 1 file changed, 11 insertions(+), 1 deletion(-) $ git push ... To github.com:python-websockets/websockets-tutorial.git + 4a4b6e9...968eaaa main -> main Deploy the web application -------------------------- Go back to GitHub, open the Settings tab of the repository and select Pages in the menu. Select the main branch as source and click Save. GitHub tells you that your site is published. Open https://.github.io/websockets-tutorial/ and start a game! Summary ------- In this third part of the tutorial, you learned how to deploy a WebSocket application with Koyeb. 
You can start a Connect Four game, send the JOIN link to a friend, and play over the Internet! Congratulations for completing the tutorial. Enjoy building real-time web applications with websockets! Solution -------- .. literalinclude:: ../../example/tutorial/step3/app.py :caption: app.py :language: python :linenos: .. literalinclude:: ../../example/tutorial/step3/index.html :caption: index.html :language: html :linenos: .. literalinclude:: ../../example/tutorial/step3/main.js :caption: main.js :language: js :linenos: .. literalinclude:: ../../example/tutorial/step3/Procfile :caption: Procfile :language: text :linenos: .. literalinclude:: ../../example/tutorial/step3/requirements.txt :caption: requirements.txt :language: text :linenos: websockets-15.0.1/docs/make.bat000066400000000000000000000014331476212450300163110ustar00rootroot00000000000000@ECHO OFF pushd %~dp0 REM Command file for Sphinx documentation if "%SPHINXBUILD%" == "" ( set SPHINXBUILD=sphinx-build ) set SOURCEDIR=. set BUILDDIR=_build if "%1" == "" goto help %SPHINXBUILD% >NUL 2>NUL if errorlevel 9009 ( echo. echo.The 'sphinx-build' command was not found. Make sure you have Sphinx echo.installed, then set the SPHINXBUILD environment variable to point echo.to the full path of the 'sphinx-build' executable. Alternatively you echo.may add the Sphinx directory to PATH. echo. echo.If you don't have Sphinx installed, grab it from echo.http://sphinx-doc.org/ exit /b 1 ) %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% goto end :help %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% :end popd websockets-15.0.1/docs/project/000077500000000000000000000000001476212450300163515ustar00rootroot00000000000000websockets-15.0.1/docs/project/changelog.rst000066400000000000000000001252471476212450300210450ustar00rootroot00000000000000Changelog ========= .. currentmodule:: websockets .. 
_backwards-compatibility policy: Backwards-compatibility policy ------------------------------ websockets is intended for production use. Therefore, stability is a goal. websockets also aims at providing the best API for WebSocket in Python. While we value stability, we value progress more. When an improvement requires changing a public API, we make the change and document it in this changelog. When possible with reasonable effort, we preserve backwards-compatibility for five years after the release that introduced the change. When a release contains backwards-incompatible API changes, the major version is increased, else the minor version is increased. Patch versions are only for fixing regressions shortly after a release. Only documented APIs are public. Undocumented, private APIs may change without notice. 15.0.1 ------ *March 5, 2025* Bug fixes ......... * Prevented an exception when exiting the interactive client. .. _15.0: 15.0 ---- *February 16, 2025* Backwards-incompatible changes .............................. .. admonition:: Client connections use SOCKS and HTTP proxies automatically. :class: important If a proxy is configured in the operating system or with an environment variable, websockets uses it automatically when connecting to a server. SOCKS proxies require installing the third-party library `python-socks`_. If you want to disable the proxy, add ``proxy=None`` when calling :func:`~asyncio.client.connect`. See :doc:`proxies <../topics/proxies>` for details. .. _python-socks: https://github.com/romis2012/python-socks .. admonition:: Keepalive is enabled in the :mod:`threading` implementation. :class: important The :mod:`threading` implementation now sends Ping frames at regular intervals and closes the connection if it doesn't receive a matching Pong frame just like the :mod:`asyncio` implementation. See :doc:`keepalive and latency <../topics/keepalive>` for details. New features ............ 
* Added :func:`~asyncio.router.route` and :func:`~asyncio.router.unix_route` to dispatch connections to handlers based on the request path. Read more about routing in :doc:`routing <../topics/routing>`. Improvements ............ * Refreshed several how-to guides and topic guides. * Added type overloads for the ``decode`` argument of :meth:`~asyncio.connection.Connection.recv`. This may simplify static typing. .. _14.2: 14.2 ---- *January 19, 2025* New features ............ * Added support for regular expressions in the ``origins`` argument of :func:`~asyncio.server.serve`. Bug fixes ......... * Wrapped errors when reading the opening handshake request or response in :exc:`~exceptions.InvalidMessage` so that :func:`~asyncio.client.connect` raises :exc:`~exceptions.InvalidHandshake` or a subclass when the opening handshake fails. * Fixed :meth:`~sync.connection.Connection.recv` with ``timeout=0`` in the :mod:`threading` implementation. If a message is already received, it is returned. Previously, :exc:`TimeoutError` was raised incorrectly. * Fixed a crash in the :mod:`asyncio` implementation when canceling a ping then receiving the corresponding pong. * Prevented :meth:`~asyncio.connection.Connection.close` from blocking when the network becomes unavailable or when receive buffers are saturated in the :mod:`asyncio` and :mod:`threading` implementations. .. _14.1: 14.1 ---- *November 13, 2024* Improvements ............ * Supported ``max_queue=None`` in the :mod:`asyncio` and :mod:`threading` implementations for consistency with the legacy implementation, even though this is never a good idea. * Added ``close_code`` and ``close_reason`` attributes in the :mod:`asyncio` and :mod:`threading` implementations for consistency with the legacy implementation. Bug fixes ......... * Once the connection is closed, messages previously received and buffered can be read in the :mod:`asyncio` and :mod:`threading` implementations, just like in the legacy implementation. .. 
_14.0: 14.0 ---- *November 9, 2024* Backwards-incompatible changes .............................. .. admonition:: websockets 14.0 requires Python ≥ 3.9. :class: tip websockets 13.1 is the last version supporting Python 3.8. .. admonition:: The new :mod:`asyncio` implementation is now the default. :class: attention The following aliases in the ``websockets`` package were switched to the new :mod:`asyncio` implementation:: from websockets import connect, unix_connext from websockets import broadcast, serve, unix_serve If you're using any of them, then you must follow the :doc:`upgrade guide <../howto/upgrade>` immediately. Alternatively, you may stick to the legacy :mod:`asyncio` implementation for now by importing it explicitly:: from websockets.legacy.client import connect, unix_connect from websockets.legacy.server import broadcast, serve, unix_serve .. admonition:: The legacy :mod:`asyncio` implementation is now deprecated. :class: caution The :doc:`upgrade guide <../howto/upgrade>` provides complete instructions to migrate your application. Aliases for deprecated API were removed from ``websockets.__all__``, meaning that they cannot be imported with ``from websockets import *`` anymore. .. admonition:: Several API raise :exc:`ValueError` instead of :exc:`TypeError` on invalid arguments. :class: note :func:`~asyncio.client.connect`, :func:`~asyncio.client.unix_connect`, and :func:`~asyncio.server.basic_auth` in the :mod:`asyncio` implementation as well as :func:`~sync.client.connect`, :func:`~sync.client.unix_connect`, :func:`~sync.server.serve`, :func:`~sync.server.unix_serve`, and :func:`~sync.server.basic_auth` in the :mod:`threading` implementation now raise :exc:`ValueError` when a required argument isn't provided or an argument that is incompatible with others is provided. .. admonition:: :attr:`Frame.data ` is now a bytes-like object. :class: note In addition to :class:`bytes`, it may be a :class:`bytearray` or a :class:`memoryview`. 
If you wrote an :class:`~extensions.Extension` that relies on methods not provided by these types, you must update your code. .. admonition:: The signature of :exc:`~exceptions.PayloadTooBig` changed. :class: note If you wrote an extension that raises :exc:`~exceptions.PayloadTooBig` in :meth:`~extensions.Extension.decode`, for example, you must replace ``PayloadTooBig(f"over size limit ({size} > {max_size} bytes)")`` with ``PayloadTooBig(size, max_size)``. New features ............ * Added an option to receive text frames as :class:`bytes`, without decoding, in the :mod:`threading` implementation; also binary frames as :class:`str`. * Added an option to send :class:`bytes` in a text frame in the :mod:`asyncio` and :mod:`threading` implementations; also :class:`str` in a binary frame. Improvements ............ * The :mod:`threading` implementation receives messages faster. * Sending or receiving large compressed messages is now faster. * Errors when a fragmented message is too large are clearer. * Log messages at the :data:`~logging.WARNING` and :data:`~logging.INFO` levels no longer include stack traces. Bug fixes ......... * Clients no longer crash when the server rejects the opening handshake and the HTTP response doesn't Include a ``Content-Length`` header. * Returning an HTTP response in ``process_request`` or ``process_response`` doesn't generate a log message at the :data:`~logging.ERROR` level anymore. * Connections are closed with code 1007 (invalid data) when receiving invalid UTF-8 in a text frame. .. _13.1: 13.1 ---- *September 21, 2024* Backwards-incompatible changes .............................. .. admonition:: The ``code`` and ``reason`` attributes of :exc:`~exceptions.ConnectionClosed` are deprecated. :class: note They were removed from the documentation in version 10.0, due to their spec-compliant but counter-intuitive behavior, but they were kept in the code for backwards compatibility. They're now formally deprecated. New features ............ 
* Added support for reconnecting automatically by using :func:`~asyncio.client.connect` as an asynchronous iterator to the new :mod:`asyncio` implementation. * :func:`~asyncio.client.connect` now follows redirects in the new :mod:`asyncio` implementation. * Added HTTP Basic Auth to the new :mod:`asyncio` and :mod:`threading` implementations of servers. * Made the set of active connections available in the :attr:`Server.connections ` property. Improvements ............ * Improved reporting of errors during the opening handshake. * Raised :exc:`~exceptions.ConcurrencyError` on unsupported concurrent calls. Previously, :exc:`RuntimeError` was raised. For backwards compatibility, :exc:`~exceptions.ConcurrencyError` is a subclass of :exc:`RuntimeError`. Bug fixes ......... * The new :mod:`asyncio` and :mod:`threading` implementations of servers don't start the connection handler anymore when ``process_request`` or ``process_response`` returns an HTTP response. * Fixed a bug in the :mod:`threading` implementation that could lead to incorrect error reporting when closing a connection while :meth:`~sync.connection.Connection.recv` is running. 13.0.1 ------ *August 28, 2024* Bug fixes ......... * Restored the C extension in the source distribution. .. _13.0: 13.0 ---- *August 20, 2024* Backwards-incompatible changes .............................. .. admonition:: Receiving the request path in the second parameter of connection handlers is deprecated. :class: note If you implemented the connection handler of a server as:: async def handler(request, path): ... You should switch to the pattern recommended since version 10.1:: async def handler(request): path = request.path # only if handler() uses the path argument ... .. admonition:: The ``ssl_context`` argument of :func:`~sync.client.connect` and :func:`~sync.server.serve` in the :mod:`threading` implementation is renamed to ``ssl``. 
:class: note This aligns the API of the :mod:`threading` implementation with the :mod:`asyncio` implementation. For backwards compatibility, ``ssl_context`` is still supported. .. admonition:: The ``WebSocketServer`` class in the :mod:`threading` implementation is renamed to :class:`~sync.server.Server`. :class: note This change should be transparent because this class shouldn't be instantiated directly; :func:`~sync.server.serve` returns an instance. Regardless, an alias provides backwards compatibility. New features ............ .. admonition:: websockets 11.0 introduces a new :mod:`asyncio` implementation. :class: important This new implementation is intended to be a drop-in replacement for the current implementation. It will become the default in a future release. Please try it and report any issue that you encounter! The :doc:`upgrade guide <../howto/upgrade>` explains everything you need to know about the upgrade process. * Validated compatibility with Python 3.12 and 3.13. * Added an option to receive text frames as :class:`bytes`, without decoding, in the :mod:`asyncio` implementation; also binary frames as :class:`str`. * Added :doc:`environment variables <../reference/variables>` to configure debug logs, the ``Server`` and ``User-Agent`` headers, as well as security limits. If you were monkey-patching constants, be aware that they were renamed, which will break your configuration. You must switch to the environment variables. Improvements ............ * The error message in server logs when a header is too long is more explicit. Bug fixes ......... * Fixed a bug in the :mod:`threading` implementation that could prevent the program from exiting when a connection wasn't closed properly. * Redirecting from a ``ws://`` URI to a ``wss://`` URI now works. * ``broadcast(raise_exceptions=True)`` no longer crashes when there isn't any exception. .. _12.0: 12.0 ---- *October 21, 2023* Backwards-incompatible changes .............................. .. 
admonition:: websockets 12.0 requires Python ≥ 3.8. :class: tip websockets 11.0 is the last version supporting Python 3.7. Improvements ............ * Made convenience imports from ``websockets`` compatible with static code analysis tools such as auto-completion in an IDE or type checking with mypy_. .. _mypy: https://github.com/python/mypy * Accepted a plain :class:`int` where an :class:`~http.HTTPStatus` is expected. * Added :class:`~frames.CloseCode`. 11.0.3 ------ *May 7, 2023* Bug fixes ......... * Fixed the :mod:`threading` implementation of servers on Windows. 11.0.2 ------ *April 18, 2023* Bug fixes ......... * Fixed a deadlock in the :mod:`threading` implementation when closing a connection without reading all messages. 11.0.1 ------ *April 6, 2023* Bug fixes ......... * Restored the C extension in the source distribution. .. _11.0: 11.0 ---- *April 2, 2023* Backwards-incompatible changes .............................. .. admonition:: The Sans-I/O implementation was moved. :class: caution Aliases provide compatibility for all previously public APIs according to the `backwards-compatibility policy`_. * The ``connection`` module was renamed to ``protocol``. * The ``connection.Connection``, ``server.ServerConnection``, and ``client.ClientConnection`` classes were renamed to ``protocol.Protocol``, ``server.ServerProtocol``, and ``client.ClientProtocol``. .. admonition:: Sans-I/O protocol constructors now use keyword-only arguments. :class: caution If you instantiate :class:`~server.ServerProtocol` or :class:`~client.ClientProtocol` directly, make sure you are using keyword arguments. .. admonition:: Closing a connection without an empty close frame is OK. :class: note Receiving an empty close frame now results in :exc:`~exceptions.ConnectionClosedOK` instead of :exc:`~exceptions.ConnectionClosedError`. As a consequence, calling ``WebSocket.close()`` without arguments in a browser isn't reported as an error anymore. .. 
admonition:: :func:`~legacy.server.serve` times out on the opening handshake after 10 seconds by default. :class: note You can adjust the timeout with the ``open_timeout`` parameter. Set it to :obj:`None` to disable the timeout entirely. New features ............ .. admonition:: websockets 11.0 introduces a :mod:`threading` implementation. :class: important It may be more convenient if you don't need to manage many connections and you're more comfortable with :mod:`threading` than :mod:`asyncio`. It is particularly suited to client applications that establish only one connection. It may be used for servers handling few connections. See :func:`websockets.sync.client.connect` and :func:`websockets.sync.server.serve` for details. * Added ``open_timeout`` to :func:`~legacy.server.serve`. * Made it possible to close a server without closing existing connections. * Added :attr:`~server.ServerProtocol.select_subprotocol` to customize negotiation of subprotocols in the Sans-I/O layer. Improvements ............ * Added platform-independent wheels. * Improved error handling in :func:`~legacy.server.broadcast`. * Set ``server_hostname`` automatically on TLS connections when providing a ``sock`` argument to :func:`~sync.client.connect`. .. _10.4: 10.4 ---- *October 25, 2022* New features ............ * Validated compatibility with Python 3.11. * Added the :attr:`~legacy.protocol.WebSocketCommonProtocol.latency` property to protocols. * Changed :attr:`~legacy.protocol.WebSocketCommonProtocol.ping` to return the latency of the connection. * Supported overriding or removing the ``User-Agent`` header in clients and the ``Server`` header in servers. * Added deployment guides for more Platform as a Service providers. Improvements ............ * Improved FAQ. .. _10.3: 10.3 ---- *April 17, 2022* Backwards-incompatible changes .............................. .. admonition:: The ``exception`` attribute of :class:`~http11.Request` and :class:`~http11.Response` is deprecated. 
:class: note Use the ``handshake_exc`` attribute of :class:`~server.ServerProtocol` and :class:`~client.ClientProtocol` instead. See :doc:`../howto/sansio` for details. Improvements ............ * Reduced noise in logs when :mod:`ssl` or :mod:`zlib` raise exceptions. .. _10.2: 10.2 ---- *February 21, 2022* Improvements ............ * Made compression negotiation more lax for compatibility with Firefox. * Improved FAQ and quick start guide. Bug fixes ......... * Fixed backwards-incompatibility in 10.1 for connection handlers created with :func:`functools.partial`. * Avoided leaking open sockets when :func:`~legacy.client.connect` is canceled. .. _10.1: 10.1 ---- *November 14, 2021* New features ............ * Added a tutorial. * Made the second parameter of connection handlers optional. The request path is available in the :attr:`~legacy.protocol.WebSocketCommonProtocol.path` attribute of the first argument. If you implemented the connection handler of a server as:: async def handler(request, path): ... You should replace it with:: async def handler(request): path = request.path # only if handler() uses the path argument ... * Added ``python -m websockets --version``. Improvements ............ * Added wheels for Python 3.10, PyPy 3.7, and for more platforms. * Reverted optimization of default compression settings for clients, mainly to avoid triggering bugs in poorly implemented servers like `AWS API Gateway`_. .. _AWS API Gateway: https://github.com/python-websockets/websockets/issues/1065 * Mirrored the entire :class:`~asyncio.Server` API in :class:`~legacy.server.WebSocketServer`. * Improved performance for large messages on ARM processors. * Documented how to auto-reload on code changes in development. Bug fixes ......... * Avoided half-closing TCP connections that are already closed. .. _10.0: 10.0 ---- *September 9, 2021* Backwards-incompatible changes .............................. .. admonition:: websockets 10.0 requires Python ≥ 3.7. 
:class: tip websockets 9.1 is the last version supporting Python 3.6. .. admonition:: The ``loop`` parameter is deprecated from all APIs. :class: caution This reflects a decision made in Python 3.8. See the release notes of Python 3.10 for details. The ``loop`` parameter is also removed from :class:`~legacy.server.WebSocketServer`. This should be transparent. .. admonition:: :func:`~legacy.client.connect` times out after 10 seconds by default. :class: note You can adjust the timeout with the ``open_timeout`` parameter. Set it to :obj:`None` to disable the timeout entirely. .. admonition:: The ``legacy_recv`` option is deprecated. :class: note See the release notes of websockets 3.0 for details. .. admonition:: The signature of :exc:`~exceptions.ConnectionClosed` changed. :class: note If you raise :exc:`~exceptions.ConnectionClosed` or a subclass, rather than catch them when websockets raises them, you must change your code. .. admonition:: A ``msg`` parameter was added to :exc:`~exceptions.InvalidURI`. :class: note If you raise :exc:`~exceptions.InvalidURI`, rather than catch it when websockets raises it, you must change your code. New features ............ .. admonition:: websockets 10.0 introduces a `Sans-I/O API `_ for easier integration in third-party libraries. :class: important If you're integrating websockets in a library, rather than just using it, look at the :doc:`Sans-I/O integration guide <../howto/sansio>`. * Added compatibility with Python 3.10. * Added :func:`~legacy.server.broadcast` to send a message to many clients. * Added support for reconnecting automatically by using :func:`~legacy.client.connect` as an asynchronous iterator. * Added ``open_timeout`` to :func:`~legacy.client.connect`. * Documented how to integrate with `Django `_. * Documented how to deploy websockets in production, with several options. * Documented how to authenticate connections. * Documented how to broadcast messages to many connections. Improvements ............ 
* Improved logging. See the :doc:`logging guide <../topics/logging>`. * Optimized default compression settings to reduce memory usage. * Optimized processing of client-to-server messages when the C extension isn't available. * Supported relative redirects in :func:`~legacy.client.connect`. * Handled TCP connection drops during the opening handshake. * Made it easier to customize authentication with :meth:`~legacy.auth.BasicAuthWebSocketServerProtocol.check_credentials`. * Provided additional information in :exc:`~exceptions.ConnectionClosed` exceptions. * Clarified several exceptions or log messages. * Restructured documentation. * Improved API documentation. * Extended FAQ. Bug fixes ......... * Avoided a crash when receiving a ping while the connection is closing. .. _9.1: 9.1 --- *May 27, 2021* Security fix ............ .. admonition:: websockets 9.1 fixes a security issue introduced in 8.0. :class: danger Version 8.0 was vulnerable to timing attacks on HTTP Basic Auth passwords (`CVE-2021-33880`_). .. _CVE-2021-33880: https://nvd.nist.gov/vuln/detail/CVE-2021-33880 9.0.2 ----- *May 15, 2021* Bug fixes ......... * Restored compatibility of ``python -m websockets`` with Python < 3.9. * Restored compatibility with mypy. 9.0.1 ----- *May 2, 2021* Bug fixes ......... * Fixed issues with the packaging of the 9.0 release. .. _9.0: 9.0 --- *May 1, 2021* Backwards-incompatible changes .............................. .. admonition:: Several modules are moved or deprecated. :class: caution Aliases provide compatibility for all previously public APIs according to the `backwards-compatibility policy`_ * :class:`~datastructures.Headers` and :exc:`~datastructures.MultipleValuesError` are moved from ``websockets.http`` to :mod:`websockets.datastructures`. If you're using them, you should adjust the import path. * The ``client``, ``server``, ``protocol``, and ``auth`` modules were moved from the ``websockets`` package to a ``websockets.legacy`` sub-package. 
Despite the name, they're still fully supported. * The ``framing``, ``handshake``, ``headers``, ``http``, and ``uri`` modules in the ``websockets`` package are deprecated. These modules provided low-level APIs for reuse by other projects, but they didn't reach that goal. Keeping these APIs public makes it more difficult to improve websockets. These changes pave the path for a refactoring that should be a transparent upgrade for most uses and facilitate integration by other projects. .. admonition:: Convenience imports from ``websockets`` are performed lazily. :class: note While Python supports this, tools relying on static code analysis don't. This breaks auto-completion in an IDE or type checking with mypy_. .. _mypy: https://github.com/python/mypy If you depend on such tools, use the real import paths, which can be found in the API documentation, for example:: from websockets.client import connect from websockets.server import serve New features ............ * Added compatibility with Python 3.9. Improvements ............ * Added support for IRIs in addition to URIs. * Added close codes 1012, 1013, and 1014. * Raised an error when passing a :class:`dict` to :meth:`~legacy.protocol.WebSocketCommonProtocol.send`. * Improved error reporting. Bug fixes ......... * Fixed sending fragmented, compressed messages. * Fixed ``Host`` header sent when connecting to an IPv6 address. * Fixed creating a client or a server with an existing Unix socket. * Aligned maximum cookie size with popular web browsers. * Ensured cancellation always propagates, even on Python versions where :exc:`~asyncio.CancelledError` inherits from :exc:`Exception`. .. _8.1: 8.1 --- *November 1, 2019* New features ............ * Added compatibility with Python 3.8. 8.0.2 ----- *July 31, 2019* Bug fixes ......... * Restored the ability to pass a socket with the ``sock`` parameter of :func:`~legacy.server.serve`. * Removed an incorrect assertion when a connection drops. 
8.0.1 ----- *July 21, 2019* Bug fixes ......... * Restored the ability to import ``WebSocketProtocolError`` from ``websockets``. .. _8.0: 8.0 --- *July 7, 2019* Backwards-incompatible changes .............................. .. admonition:: websockets 8.0 requires Python ≥ 3.6. :class: tip websockets 7.0 is the last version supporting Python 3.4 and 3.5. .. admonition:: ``process_request`` is now expected to be a coroutine. :class: note If you're passing a ``process_request`` argument to :func:`~legacy.server.serve` or :class:`~legacy.server.WebSocketServerProtocol`, or if you're overriding :meth:`~legacy.server.WebSocketServerProtocol.process_request` in a subclass, define it with ``async def`` instead of ``def``. Previously, both were supported. For backwards compatibility, functions are still accepted, but mixing functions and coroutines won't work in some inheritance scenarios. .. admonition:: ``max_queue`` must be :obj:`None` to disable the limit. :class: note If you were setting ``max_queue=0`` to make the queue of incoming messages unbounded, change it to ``max_queue=None``. .. admonition:: The ``host``, ``port``, and ``secure`` attributes of :class:`~legacy.protocol.WebSocketCommonProtocol` are deprecated. :class: note Use :attr:`~legacy.protocol.WebSocketCommonProtocol.local_address` in servers and :attr:`~legacy.protocol.WebSocketCommonProtocol.remote_address` in clients instead of ``host`` and ``port``. .. admonition:: ``WebSocketProtocolError`` is renamed to :exc:`~exceptions.ProtocolError`. :class: note An alias provides backwards compatibility. .. admonition:: ``read_response()`` now returns the reason phrase. :class: note If you're using this low-level API, you must change your code. New features ............ * Added :func:`~legacy.auth.basic_auth_protocol_factory` to enforce HTTP Basic Auth on the server side. * :func:`~legacy.client.connect` handles redirects from the server during the handshake. 
* :func:`~legacy.client.connect` supports overriding ``host`` and ``port``. * Added :func:`~legacy.client.unix_connect` for connecting to Unix sockets. * Added support for asynchronous generators in :meth:`~legacy.protocol.WebSocketCommonProtocol.send` to generate fragmented messages incrementally. * Enabled readline in the interactive client. * Added type hints (:pep:`484`). * Added a FAQ to the documentation. * Added documentation for extensions. * Documented how to optimize memory usage. Improvements ............ * :meth:`~legacy.protocol.WebSocketCommonProtocol.send`, :meth:`~legacy.protocol.WebSocketCommonProtocol.ping`, and :meth:`~legacy.protocol.WebSocketCommonProtocol.pong` support bytes-like types :class:`bytearray` and :class:`memoryview` in addition to :class:`bytes`. * Added :exc:`~exceptions.ConnectionClosedOK` and :exc:`~exceptions.ConnectionClosedError` subclasses of :exc:`~exceptions.ConnectionClosed` to tell apart normal connection termination from errors. * Changed :meth:`WebSocketServer.close() <legacy.server.WebSocketServer.close>` to perform a proper closing handshake instead of failing the connection. * Improved error messages when HTTP parsing fails. * Improved API documentation. Bug fixes ......... * Prevented spurious log messages about :exc:`~exceptions.ConnectionClosed` exceptions in keepalive ping task. If you were using ``ping_timeout=None`` as a workaround, you can remove it. * Avoided a crash when an ``extra_headers`` callable returns :obj:`None`. .. _7.0: 7.0 --- *November 1, 2018* Backwards-incompatible changes .............................. .. admonition:: Keepalive is enabled by default. :class: important websockets now sends Ping frames at regular intervals and closes the connection if it doesn't receive a matching Pong frame. See :class:`~legacy.protocol.WebSocketCommonProtocol` for details. .. admonition:: Termination of connections by :meth:`WebSocketServer.close() <legacy.server.WebSocketServer.close>` changes. :class: caution Previously, connection handlers were canceled. 
Now, connections are closed with close code 1001 (going away). From the perspective of the connection handler, this is the same as if the remote endpoint was disconnecting. This removes the need to prepare for :exc:`~asyncio.CancelledError` in connection handlers. You can restore the previous behavior by adding the following line at the beginning of connection handlers:: async def handler(websocket, path): closed = asyncio.ensure_future(websocket.wait_closed()) closed.add_done_callback(lambda task: task.cancel()) .. admonition:: Calling :meth:`~legacy.protocol.WebSocketCommonProtocol.recv` concurrently raises a :exc:`RuntimeError`. :class: note Concurrent calls lead to non-deterministic behavior because there are no guarantees about which coroutine will receive which message. .. admonition:: The ``timeout`` argument of :func:`~legacy.server.serve` and :func:`~legacy.client.connect` is renamed to ``close_timeout``. :class: note This prevents confusion with ``ping_timeout``. For backwards compatibility, ``timeout`` is still supported. .. admonition:: The ``origins`` argument of :func:`~legacy.server.serve` changes. :class: note Include :obj:`None` in the list rather than ``''`` to allow requests that don't contain an Origin header. .. admonition:: Pending pings aren't canceled when the connection is closed. :class: note A ping — as in ``ping = await websocket.ping()`` — for which no pong was received yet used to be canceled when the connection was closed, so that ``await ping`` raised :exc:`~asyncio.CancelledError`. Now ``await ping`` raises :exc:`~exceptions.ConnectionClosed` like other public APIs. New features ............ * Added ``process_request`` and ``select_subprotocol`` arguments to :func:`~legacy.server.serve` and :class:`~legacy.server.WebSocketServerProtocol` to facilitate customization of :meth:`~legacy.server.WebSocketServerProtocol.process_request` and :meth:`~legacy.server.WebSocketServerProtocol.select_subprotocol`. 
* Added support for sending fragmented messages. * Added the :meth:`~legacy.protocol.WebSocketCommonProtocol.wait_closed` method to protocols. * Added an interactive client: ``python -m websockets <uri>``. Improvements ............ * Improved handling of multiple HTTP headers with the same name. * Improved error messages when a required HTTP header is missing. Bug fixes ......... * Fixed a data loss bug in :meth:`~legacy.protocol.WebSocketCommonProtocol.recv`: canceling it at the wrong time could result in messages being dropped. .. _6.0: 6.0 --- *July 16, 2018* Backwards-incompatible changes .............................. .. admonition:: The :class:`~datastructures.Headers` class is introduced and several APIs are updated to use it. :class: caution * The ``request_headers`` argument of :meth:`~legacy.server.WebSocketServerProtocol.process_request` is now a :class:`~datastructures.Headers` instead of an ``http.client.HTTPMessage``. * The ``request_headers`` and ``response_headers`` attributes of :class:`~legacy.protocol.WebSocketCommonProtocol` are now :class:`~datastructures.Headers` instead of ``http.client.HTTPMessage``. * The ``raw_request_headers`` and ``raw_response_headers`` attributes of :class:`~legacy.protocol.WebSocketCommonProtocol` are removed. Use :meth:`~datastructures.Headers.raw_items` instead. * Functions defined in the ``handshake`` module now receive :class:`~datastructures.Headers` in argument instead of ``get_header`` or ``set_header`` functions. This affects libraries that rely on low-level APIs. * Functions defined in the ``http`` module now return HTTP headers as :class:`~datastructures.Headers` instead of lists of ``(name, value)`` pairs. Since :class:`~datastructures.Headers` and ``http.client.HTTPMessage`` provide similar APIs, much of the code dealing with HTTP headers won't require changes. New features ............ * Added compatibility with Python 3.7. 5.0.1 ----- *May 24, 2018* Bug fixes ......... 
* Fixed a regression in 5.0 that broke some invocations of :func:`~legacy.server.serve` and :func:`~legacy.client.connect`. .. _5.0: 5.0 --- *May 22, 2018* Security fix ............ .. admonition:: websockets 5.0 fixes a security issue introduced in 4.0. :class: danger Version 4.0 was vulnerable to denial of service by memory exhaustion because it didn't enforce ``max_size`` when decompressing compressed messages (`CVE-2018-1000518`_). .. _CVE-2018-1000518: https://nvd.nist.gov/vuln/detail/CVE-2018-1000518 Backwards-incompatible changes .............................. .. admonition:: A ``user_info`` field is added to the return value of ``parse_uri`` and ``WebSocketURI``. :class: note If you're unpacking ``WebSocketURI`` into four variables, adjust your code to account for that fifth field. New features ............ * :func:`~legacy.client.connect` performs HTTP Basic Auth when the URI contains credentials. * :func:`~legacy.server.unix_serve` can be used as an asynchronous context manager on Python ≥ 3.5.1. * Added the :attr:`~legacy.protocol.WebSocketCommonProtocol.closed` property to protocols. * Added new examples in the documentation. Improvements ............ * Iterating on incoming messages no longer raises an exception when the connection terminates with close code 1001 (going away). * A plain HTTP request now receives a 426 Upgrade Required response and doesn't log a stack trace. * If a :meth:`~legacy.protocol.WebSocketCommonProtocol.ping` doesn't receive a pong, it's canceled when the connection is closed. * Reported the cause of :exc:`~exceptions.ConnectionClosed` exceptions. * Stopped logging stack traces when the TCP connection dies prematurely. * Prevented writing to a closing TCP connection during unclean shutdowns. * Made connection termination more robust to network congestion. * Prevented processing of incoming frames after failing the connection. * Updated documentation with new features from Python 3.6. 
* Improved several sections of the documentation. Bug fixes ......... * Prevented :exc:`TypeError` due to missing close code on connection close. * Fixed a race condition in the closing handshake that raised :exc:`~exceptions.InvalidState`. 4.0.1 ----- *November 2, 2017* Bug fixes ......... * Fixed issues with the packaging of the 4.0 release. .. _4.0: 4.0 --- *November 2, 2017* Backwards-incompatible changes .............................. .. admonition:: websockets 4.0 requires Python ≥ 3.4. :class: tip websockets 3.4 is the last version supporting Python 3.3. .. admonition:: Compression is enabled by default. :class: important In August 2017, Firefox and Chrome support the permessage-deflate extension, but not Safari and IE. Compression should improve performance but it increases RAM and CPU use. If you want to disable compression, add ``compression=None`` when calling :func:`~legacy.server.serve` or :func:`~legacy.client.connect`. .. admonition:: The ``state_name`` attribute of protocols is deprecated. :class: note Use ``protocol.state.name`` instead of ``protocol.state_name``. New features ............ * :class:`~legacy.protocol.WebSocketCommonProtocol` instances can be used as asynchronous iterators on Python ≥ 3.6. They yield incoming messages. * Added :func:`~legacy.server.unix_serve` for listening on Unix sockets. * Added the :attr:`~legacy.server.WebSocketServer.sockets` attribute to the return value of :func:`~legacy.server.serve`. * Allowed ``extra_headers`` to override ``Server`` and ``User-Agent`` headers. Improvements ............ * Reorganized and extended documentation. * Rewrote connection termination to increase robustness in edge cases. * Reduced verbosity of "Failing the WebSocket connection" logs. Bug fixes ......... * Aborted connections if they don't close within the configured ``timeout``. * Stopped leaking pending tasks when :meth:`~asyncio.Task.cancel` is called on a connection while it's being closed. .. 
_3.4: 3.4 --- *August 20, 2017* Backwards-incompatible changes .............................. .. admonition:: ``InvalidStatus`` is replaced by :class:`~exceptions.InvalidStatusCode`. :class: note This exception is raised when :func:`~legacy.client.connect` receives an invalid response status code from the server. New features ............ * :func:`~legacy.server.serve` can be used as an asynchronous context manager on Python ≥ 3.5.1. * Added support for customizing handling of incoming connections with :meth:`~legacy.server.WebSocketServerProtocol.process_request`. * Made read and write buffer sizes configurable. Improvements ............ * Renamed :func:`~legacy.server.serve` and :func:`~legacy.client.connect`'s ``klass`` argument to ``create_protocol`` to reflect that it can also be a callable. For backwards compatibility, ``klass`` is still supported. * Rewrote HTTP handling for simplicity and performance. * Added an optional C extension to speed up low-level operations. Bug fixes ......... * Providing a ``sock`` argument to :func:`~legacy.client.connect` no longer crashes. .. _3.3: 3.3 --- *March 29, 2017* New features ............ * Ensured compatibility with Python 3.6. Improvements ............ * Reduced noise in logs caused by connection resets. Bug fixes ......... * Avoided crashing on concurrent writes on slow connections. .. _3.2: 3.2 --- *August 17, 2016* New features ............ * Added ``timeout``, ``max_size``, and ``max_queue`` arguments to :func:`~legacy.client.connect` and :func:`~legacy.server.serve`. Improvements ............ * Made server shutdown more robust. .. _3.1: 3.1 --- *April 21, 2016* New features ............ * Added flow control for incoming data. Bug fixes ......... * Avoided a warning when closing a connection before the opening handshake. .. _3.0: 3.0 --- *December 25, 2015* Backwards-incompatible changes .............................. .. 
admonition:: :meth:`~legacy.protocol.WebSocketCommonProtocol.recv` now raises an exception when the connection is closed. :class: caution :meth:`~legacy.protocol.WebSocketCommonProtocol.recv` used to return :obj:`None` when the connection was closed. This required checking the return value of every call:: message = await websocket.recv() if message is None: return Now it raises a :exc:`~exceptions.ConnectionClosed` exception instead. This is more Pythonic. The previous code can be simplified to:: message = await websocket.recv() When implementing a server, there's no strong reason to handle such exceptions. Let them bubble up, terminate the handler coroutine, and the server will simply ignore them. In order to avoid stranding projects built upon an earlier version, the previous behavior can be restored by passing ``legacy_recv=True`` to :func:`~legacy.server.serve`, :func:`~legacy.client.connect`, :class:`~legacy.server.WebSocketServerProtocol`, or :class:`~legacy.client.WebSocketClientProtocol`. New features ............ * :func:`~legacy.client.connect` can be used as an asynchronous context manager on Python ≥ 3.5.1. * :meth:`~legacy.protocol.WebSocketCommonProtocol.ping` and :meth:`~legacy.protocol.WebSocketCommonProtocol.pong` support data passed as :class:`str` in addition to :class:`bytes`. * Made ``state_name`` attribute on protocols a public API. Improvements ............ * Updated documentation with ``await`` and ``async`` syntax from Python 3.5. * Worked around an :mod:`asyncio` bug affecting connection termination under load. * Improved documentation. .. _2.7: 2.7 --- *November 18, 2015* New features ............ * Added compatibility with Python 3.5. Improvements ............ * Refreshed documentation. .. _2.6: 2.6 --- *August 18, 2015* New features ............ * Added ``local_address`` and ``remote_address`` attributes on protocols. * Closed open connections with code 1001 when a server shuts down. Bug fixes ......... 
* Avoided TCP fragmentation of small frames. .. _2.5: 2.5 --- *July 28, 2015* New features ............ * Provided access to handshake request and response HTTP headers. * Allowed customizing handshake request and response HTTP headers. * Added support for running on a non-default event loop. Improvements ............ * Improved documentation. * Sent a 403 status code instead of 400 when request Origin isn't allowed. * Clarified that the closing handshake can be initiated by the client. * Set the close code and reason more consistently. * Strengthened connection termination. Bug fixes ......... * Canceling :meth:`~legacy.protocol.WebSocketCommonProtocol.recv` no longer drops the next message. .. _2.4: 2.4 --- *January 31, 2015* New features ............ * Added support for subprotocols. * Added ``loop`` argument to :func:`~legacy.client.connect` and :func:`~legacy.server.serve`. .. _2.3: 2.3 --- *November 3, 2014* Improvements ............ * Improved compliance of close codes. .. _2.2: 2.2 --- *July 28, 2014* New features ............ * Added support for limiting message size. .. _2.1: 2.1 --- *April 26, 2014* New features ............ * Added ``host``, ``port`` and ``secure`` attributes on protocols. * Added support for providing and checking Origin_. .. _Origin: https://datatracker.ietf.org/doc/html/rfc6455.html#section-10.2 .. _2.0: 2.0 --- *February 16, 2014* Backwards-incompatible changes .............................. .. admonition:: :meth:`~legacy.protocol.WebSocketCommonProtocol.send`, :meth:`~legacy.protocol.WebSocketCommonProtocol.ping`, and :meth:`~legacy.protocol.WebSocketCommonProtocol.pong` are now coroutines. :class: caution They used to be functions. Instead of:: websocket.send(message) you must write:: await websocket.send(message) New features ............ * Added flow control for outgoing data. .. _1.0: 1.0 --- *November 14, 2013* New features ............ * Initial public release. 
websockets-15.0.1/docs/project/contributing.rst000066400000000000000000000043241476212450300216150ustar00rootroot00000000000000Contributing ============ Thanks for taking the time to contribute to websockets! Code of Conduct --------------- This project and everyone participating in it is governed by the `Code of Conduct`_. By participating, you are expected to uphold this code. Please report inappropriate behavior to aymeric DOT augustin AT fractalideas DOT com. .. _Code of Conduct: https://github.com/python-websockets/websockets/blob/main/CODE_OF_CONDUCT.md *(If I'm the person with the inappropriate behavior, please accept my apologies. I know I can mess up. I can't expect you to tell me, but if you choose to do so, I'll do my best to handle criticism constructively. -- Aymeric)* Contributing ------------ Bug reports, patches and suggestions are welcome! Please open an issue_ or send a `pull request`_. Feedback about the documentation is especially valuable, as the primary author feels more confident about writing code than writing docs :-) If you're wondering why things are done in a certain way, the :doc:`design document <../topics/design>` provides lots of details about the internals of websockets. .. _issue: https://github.com/python-websockets/websockets/issues/new .. _pull request: https://github.com/python-websockets/websockets/compare/ Packaging --------- Some distributions package websockets so that it can be installed with the system package manager rather than with pip, possibly in a virtualenv. If you're packaging websockets for a distribution, you must use `releases published on PyPI`_ as input. You may check `SLSA attestations on GitHub`_. .. _releases published on PyPI: https://pypi.org/project/websockets/#files .. _SLSA attestations on GitHub: https://github.com/python-websockets/websockets/attestations You mustn't rely on the git repository as input. Specifically, you mustn't attempt to run the main test suite. 
It isn't treated as a deliverable of the project. It doesn't do what you think it does. It's designed for the needs of developers, not packagers. On a typical build farm for a distribution, tests that exercise timeouts will fail randomly. Indeed, the test suite is optimized for running very fast, with a tolerable level of flakiness, on a high-end laptop without noisy neighbors. This isn't your context. websockets-15.0.1/docs/project/index.rst000066400000000000000000000003521476212450300202120ustar00rootroot00000000000000About websockets ================ This is about websockets-the-project rather than websockets-the-software. .. toctree:: :titlesonly: changelog contributing sponsoring For enterprise support license websockets-15.0.1/docs/project/license.rst000066400000000000000000000000541476212450300205240ustar00rootroot00000000000000License ======= .. include:: ../../LICENSE websockets-15.0.1/docs/project/sponsoring.rst000066400000000000000000000004251476212450300213050ustar00rootroot00000000000000Sponsoring ========== You may sponsor the development of websockets through: * `GitHub Sponsors`_ * `Open Collective`_ * :doc:`Tidelift ` .. _GitHub Sponsors: https://github.com/sponsors/python-websockets .. _Open Collective: https://opencollective.com/websockets websockets-15.0.1/docs/project/support.rst000066400000000000000000000032361476212450300206230ustar00rootroot00000000000000Getting support =============== .. admonition:: There are no free support channels. :class: tip websockets is an open-source project. It's primarily maintained by one person as a hobby. For this reason, the focus is on flawless code and self-service documentation, not support. Enterprise ---------- websockets is maintained with high standards, making it suitable for enterprise use cases. Additional guarantees are available via :doc:`Tidelift `. If you're using it in a professional setting, consider subscribing. Questions --------- GitHub issues aren't a good medium for handling questions. 
There are better places to ask questions, for example Stack Overflow. If you want to ask a question anyway, please make sure that: - it's a question about websockets and not about :mod:`asyncio`; - it isn't answered in the documentation; - it wasn't asked already. A good question can be written as a suggestion to improve the documentation. Cryptocurrency users -------------------- websockets appears to be quite popular for interfacing with Bitcoin or other cryptocurrency trackers. I'm strongly opposed to Bitcoin's carbon footprint. I'm aware of efforts to build proof-of-stake models. I'll care once the total energy consumption of all cryptocurrencies drops to a non-bullshit level. You already negated all of humanity's efforts to develop renewable energy. Please stop heating the planet where my children will have to live. Since websockets is released under an open-source license, you can use it for any purpose you like. However, I won't spend any of my time to help you. I will summarily close issues related to cryptocurrency in any way. websockets-15.0.1/docs/project/tidelift.rst000066400000000000000000000103101476212450300207020ustar00rootroot00000000000000websockets for enterprise ========================= Available as part of the Tidelift Subscription ---------------------------------------------- .. image:: ../_static/tidelift.png :height: 150px :width: 150px :align: left Tidelift is working with the maintainers of websockets and thousands of other open source projects to deliver commercial support and maintenance for the open source dependencies you use to build your applications. Save time, reduce risk, and improve code health, while paying the maintainers of the exact dependencies you use. .. 
raw:: html Enterprise-ready open source software—managed for you ----------------------------------------------------- The Tidelift Subscription is a managed open source subscription for application dependencies covering millions of open source projects across JavaScript, Python, Java, PHP, Ruby, .NET, and more. Your subscription includes: * **Security updates** * Tidelift’s security response team coordinates patches for new breaking security vulnerabilities and alerts immediately through a private channel, so your software supply chain is always secure. * **Licensing verification and indemnification** * Tidelift verifies license information to enable easy policy enforcement and adds intellectual property indemnification to cover creators and users in case something goes wrong. You always have a 100% up-to-date bill of materials for your dependencies to share with your legal team, customers, or partners. * **Maintenance and code improvement** * Tidelift ensures the software you rely on keeps working as long as you need it to work. Your managed dependencies are actively maintained and we recruit additional maintainers where required. * **Package selection and version guidance** * We help you choose the best open source packages from the start—and then guide you through updates to stay on the best releases as new issues arise. * **Roadmap input** * Take a seat at the table with the creators behind the software you use. Tidelift’s participating maintainers earn more income as their software is used by more subscribers, so they’re interested in knowing what you need. * **Tooling and cloud integration** * Tidelift works with GitHub, GitLab, BitBucket, and more. We support every cloud platform (and other deployment targets, too). The end result? All of the capabilities you expect from commercial-grade software, for the full breadth of open source you use. 
That means less time grappling with esoteric open source trivia, and more time building your own applications—and your business. .. raw:: html websockets-15.0.1/docs/reference/000077500000000000000000000000001476212450300166415ustar00rootroot00000000000000websockets-15.0.1/docs/reference/asyncio/000077500000000000000000000000001476212450300203065ustar00rootroot00000000000000websockets-15.0.1/docs/reference/asyncio/client.rst000066400000000000000000000023041476212450300223150ustar00rootroot00000000000000Client (:mod:`asyncio`) ======================= .. automodule:: websockets.asyncio.client Opening a connection -------------------- .. autofunction:: connect :async: .. autofunction:: unix_connect :async: .. autofunction:: process_exception Using a connection ------------------ .. autoclass:: ClientConnection .. automethod:: __aiter__ .. automethod:: recv .. automethod:: recv_streaming .. automethod:: send .. automethod:: close .. automethod:: wait_closed .. automethod:: ping .. automethod:: pong WebSocket connection objects also provide these attributes: .. autoattribute:: id .. autoattribute:: logger .. autoproperty:: local_address .. autoproperty:: remote_address .. autoattribute:: latency .. autoproperty:: state The following attributes are available after the opening handshake, once the WebSocket connection is open: .. autoattribute:: request .. autoattribute:: response .. autoproperty:: subprotocol The following attributes are available after the closing handshake, once the WebSocket connection is closed: .. autoproperty:: close_code .. autoproperty:: close_reason websockets-15.0.1/docs/reference/asyncio/common.rst000066400000000000000000000020121476212450300223230ustar00rootroot00000000000000:orphan: Both sides (:mod:`asyncio`) =========================== .. automodule:: websockets.asyncio.connection .. autoclass:: Connection .. automethod:: __aiter__ .. automethod:: recv .. automethod:: recv_streaming .. automethod:: send .. automethod:: close .. 
automethod:: wait_closed .. automethod:: ping .. automethod:: pong WebSocket connection objects also provide these attributes: .. autoattribute:: id .. autoattribute:: logger .. autoproperty:: local_address .. autoproperty:: remote_address .. autoattribute:: latency .. autoproperty:: state The following attributes are available after the opening handshake, once the WebSocket connection is open: .. autoattribute:: request .. autoattribute:: response .. autoproperty:: subprotocol The following attributes are available after the closing handshake, once the WebSocket connection is closed: .. autoproperty:: close_code .. autoproperty:: close_reason websockets-15.0.1/docs/reference/asyncio/server.rst000066400000000000000000000036661476212450300223610ustar00rootroot00000000000000Server (:mod:`asyncio`) ======================= .. automodule:: websockets.asyncio.server Creating a server ----------------- .. autofunction:: serve :async: .. autofunction:: unix_serve :async: Routing connections ------------------- .. automodule:: websockets.asyncio.router .. autofunction:: route :async: .. autofunction:: unix_route :async: .. autoclass:: Router .. currentmodule:: websockets.asyncio.server Running a server ---------------- .. autoclass:: Server .. autoattribute:: connections .. automethod:: close .. automethod:: wait_closed .. automethod:: get_loop .. automethod:: is_serving .. automethod:: start_serving .. automethod:: serve_forever .. autoattribute:: sockets Using a connection ------------------ .. autoclass:: ServerConnection .. automethod:: __aiter__ .. automethod:: recv .. automethod:: recv_streaming .. automethod:: send .. automethod:: close .. automethod:: wait_closed .. automethod:: ping .. automethod:: pong .. automethod:: respond WebSocket connection objects also provide these attributes: .. autoattribute:: id .. autoattribute:: logger .. autoproperty:: local_address .. autoproperty:: remote_address .. autoattribute:: latency .. 
autoproperty:: state The following attributes are available after the opening handshake, once the WebSocket connection is open: .. autoattribute:: request .. autoattribute:: response .. autoproperty:: subprotocol The following attributes are available after the closing handshake, once the WebSocket connection is closed: .. autoproperty:: close_code .. autoproperty:: close_reason Broadcast --------- .. autofunction:: broadcast HTTP Basic Authentication ------------------------- websockets supports HTTP Basic Authentication according to :rfc:`7235` and :rfc:`7617`. .. autofunction:: basic_auth websockets-15.0.1/docs/reference/datastructures.rst000066400000000000000000000024471476212450300224570ustar00rootroot00000000000000Data structures =============== WebSocket events ---------------- .. automodule:: websockets.frames .. autoclass:: Frame .. autoclass:: Opcode .. autoattribute:: CONT .. autoattribute:: TEXT .. autoattribute:: BINARY .. autoattribute:: CLOSE .. autoattribute:: PING .. autoattribute:: PONG .. autoclass:: Close .. autoclass:: CloseCode .. autoattribute:: NORMAL_CLOSURE .. autoattribute:: GOING_AWAY .. autoattribute:: PROTOCOL_ERROR .. autoattribute:: UNSUPPORTED_DATA .. autoattribute:: NO_STATUS_RCVD .. autoattribute:: ABNORMAL_CLOSURE .. autoattribute:: INVALID_DATA .. autoattribute:: POLICY_VIOLATION .. autoattribute:: MESSAGE_TOO_BIG .. autoattribute:: MANDATORY_EXTENSION .. autoattribute:: INTERNAL_ERROR .. autoattribute:: SERVICE_RESTART .. autoattribute:: TRY_AGAIN_LATER .. autoattribute:: BAD_GATEWAY .. autoattribute:: TLS_HANDSHAKE HTTP events ----------- .. automodule:: websockets.http11 .. autoclass:: Request .. autoclass:: Response .. automodule:: websockets.datastructures .. autoclass:: Headers .. automethod:: get_all .. automethod:: raw_items .. autoexception:: MultipleValuesError URIs ---- .. automodule:: websockets.uri .. autofunction:: parse_uri .. 
autoclass:: WebSocketURI websockets-15.0.1/docs/reference/exceptions.rst000066400000000000000000000037421476212450300215620ustar00rootroot00000000000000Exceptions ========== .. automodule:: websockets.exceptions .. autoexception:: WebSocketException Connection closed ----------------- :meth:`~websockets.asyncio.connection.Connection.recv`, :meth:`~websockets.asyncio.connection.Connection.send`, and similar methods raise the exceptions below when the connection is closed. This is the expected way to detect disconnections. .. autoexception:: ConnectionClosed .. autoexception:: ConnectionClosedOK .. autoexception:: ConnectionClosedError Connection failed ----------------- These exceptions are raised by :func:`~websockets.asyncio.client.connect` when the opening handshake fails and the connection cannot be established. They are also reported by :func:`~websockets.asyncio.server.serve` in logs. .. autoexception:: InvalidURI .. autoexception:: InvalidProxy .. autoexception:: InvalidHandshake .. autoexception:: SecurityError .. autoexception:: ProxyError .. autoexception:: InvalidProxyMessage .. autoexception:: InvalidProxyStatus .. autoexception:: InvalidMessage .. autoexception:: InvalidStatus .. autoexception:: InvalidHeader .. autoexception:: InvalidHeaderFormat .. autoexception:: InvalidHeaderValue .. autoexception:: InvalidOrigin .. autoexception:: InvalidUpgrade .. autoexception:: NegotiationError .. autoexception:: DuplicateParameter .. autoexception:: InvalidParameterName .. autoexception:: InvalidParameterValue Sans-I/O exceptions ------------------- These exceptions are only raised by the Sans-I/O implementation. They are translated to :exc:`ConnectionClosedError` in the other implementations. .. autoexception:: ProtocolError .. autoexception:: PayloadTooBig .. autoexception:: InvalidState Miscellaneous exceptions ------------------------ .. 
autoexception:: ConcurrencyError Legacy exceptions ----------------- These exceptions are only used by the legacy :mod:`asyncio` implementation. .. autoexception:: InvalidStatusCode .. autoexception:: AbortHandshake .. autoexception:: RedirectHandshake websockets-15.0.1/docs/reference/extensions.rst000066400000000000000000000026751476212450300216040ustar00rootroot00000000000000Extensions ========== .. currentmodule:: websockets.extensions The WebSocket protocol supports extensions_. At the time of writing, there's only one `registered extension`_ with a public specification, WebSocket Per-Message Deflate. .. _extensions: https://datatracker.ietf.org/doc/html/rfc6455.html#section-9 .. _registered extension: https://www.iana.org/assignments/websocket/websocket.xhtml#extension-name Per-Message Deflate ------------------- .. automodule:: websockets.extensions.permessage_deflate :mod:`websockets.extensions.permessage_deflate` implements WebSocket Per-Message Deflate. This extension is specified in :rfc:`7692`. Refer to the :doc:`topic guide on compression <../topics/compression>` to learn more about tuning compression settings. .. autoclass:: ServerPerMessageDeflateFactory .. autoclass:: ClientPerMessageDeflateFactory Base classes ------------ .. automodule:: websockets.extensions :mod:`websockets.extensions` defines base classes for implementing extensions. Refer to the :doc:`how-to guide on extensions <../howto/extensions>` to learn more about writing an extension. .. autoclass:: Extension .. autoattribute:: name .. automethod:: decode .. automethod:: encode .. autoclass:: ServerExtensionFactory .. automethod:: process_request_params .. autoclass:: ClientExtensionFactory .. autoattribute:: name .. automethod:: get_request_params .. automethod:: process_response_params websockets-15.0.1/docs/reference/features.rst000066400000000000000000000301451476212450300212140ustar00rootroot00000000000000Features ======== .. 
currentmodule:: websockets Feature support matrices summarize which implementations support which features. .. raw:: html .. |aio| replace:: :mod:`asyncio` (new) .. |sync| replace:: :mod:`threading` .. |sans| replace:: `Sans-I/O`_ .. |leg| replace:: :mod:`asyncio` (legacy) .. _Sans-I/O: https://sans-io.readthedocs.io/ Both sides ---------- .. table:: :class: support-matrix-table +------------------------------------+--------+--------+--------+--------+ | | |aio| | |sync| | |sans| | |leg| | +====================================+========+========+========+========+ | Perform the opening handshake | ✅ | ✅ | ✅ | ✅ | +------------------------------------+--------+--------+--------+--------+ | Enforce opening timeout | ✅ | ✅ | — | ✅ | +------------------------------------+--------+--------+--------+--------+ | Send a message | ✅ | ✅ | ✅ | ✅ | +------------------------------------+--------+--------+--------+--------+ | Broadcast a message | ✅ | ❌ | — | ✅ | +------------------------------------+--------+--------+--------+--------+ | Receive a message | ✅ | ✅ | ✅ | ✅ | +------------------------------------+--------+--------+--------+--------+ | Iterate over received messages | ✅ | ✅ | — | ✅ | +------------------------------------+--------+--------+--------+--------+ | Send a fragmented message | ✅ | ✅ | ✅ | ✅ | +------------------------------------+--------+--------+--------+--------+ | Receive a fragmented message frame | ✅ | ✅ | — | ❌ | | by frame | | | | | +------------------------------------+--------+--------+--------+--------+ | Receive a fragmented message after | ✅ | ✅ | — | ✅ | | reassembly | | | | | +------------------------------------+--------+--------+--------+--------+ | Force sending a message as Text or | ✅ | ✅ | — | ❌ | | Binary | | | | | +------------------------------------+--------+--------+--------+--------+ | Force receiving a message as | ✅ | ✅ | — | ❌ | | :class:`bytes` or :class:`str` | | | | | 
+------------------------------------+--------+--------+--------+--------+ | Send a ping | ✅ | ✅ | ✅ | ✅ | +------------------------------------+--------+--------+--------+--------+ | Respond to pings automatically | ✅ | ✅ | ✅ | ✅ | +------------------------------------+--------+--------+--------+--------+ | Send a pong | ✅ | ✅ | ✅ | ✅ | +------------------------------------+--------+--------+--------+--------+ | Keepalive | ✅ | ✅ | — | ✅ | +------------------------------------+--------+--------+--------+--------+ | Heartbeat | ✅ | ✅ | — | ✅ | +------------------------------------+--------+--------+--------+--------+ | Measure latency | ✅ | ✅ | — | ✅ | +------------------------------------+--------+--------+--------+--------+ | Perform the closing handshake | ✅ | ✅ | ✅ | ✅ | +------------------------------------+--------+--------+--------+--------+ | Enforce closing timeout | ✅ | ✅ | — | ✅ | +------------------------------------+--------+--------+--------+--------+ | Report close codes and reasons | ✅ | ✅ | ✅ | ❌ | | from both sides | | | | | +------------------------------------+--------+--------+--------+--------+ | Compress messages (:rfc:`7692`) | ✅ | ✅ | ✅ | ✅ | +------------------------------------+--------+--------+--------+--------+ | Tune memory usage for compression | ✅ | ✅ | ✅ | ✅ | +------------------------------------+--------+--------+--------+--------+ | Negotiate extensions | ✅ | ✅ | ✅ | ✅ | +------------------------------------+--------+--------+--------+--------+ | Implement custom extensions | ✅ | ✅ | ✅ | ✅ | +------------------------------------+--------+--------+--------+--------+ | Negotiate a subprotocol | ✅ | ✅ | ✅ | ✅ | +------------------------------------+--------+--------+--------+--------+ | Enforce security limits | ✅ | ✅ | ✅ | ✅ | +------------------------------------+--------+--------+--------+--------+ | Log events | ✅ | ✅ | ✅ | ✅ | +------------------------------------+--------+--------+--------+--------+ Server ------ .. 
table:: :class: support-matrix-table +------------------------------------+--------+--------+--------+--------+ | | |aio| | |sync| | |sans| | |leg| | +====================================+========+========+========+========+ | Listen on a TCP socket | ✅ | ✅ | — | ✅ | +------------------------------------+--------+--------+--------+--------+ | Listen on a Unix socket | ✅ | ✅ | — | ✅ | +------------------------------------+--------+--------+--------+--------+ | Listen using a preexisting socket | ✅ | ✅ | — | ✅ | +------------------------------------+--------+--------+--------+--------+ | Encrypt connection with TLS | ✅ | ✅ | — | ✅ | +------------------------------------+--------+--------+--------+--------+ | Close server on context exit | ✅ | ✅ | — | ✅ | +------------------------------------+--------+--------+--------+--------+ | Close connection on handler exit | ✅ | ✅ | — | ✅ | +------------------------------------+--------+--------+--------+--------+ | Shut down server gracefully | ✅ | ✅ | — | ✅ | +------------------------------------+--------+--------+--------+--------+ | Check ``Origin`` header | ✅ | ✅ | ✅ | ✅ | +------------------------------------+--------+--------+--------+--------+ | Customize subprotocol selection | ✅ | ✅ | ✅ | ✅ | +------------------------------------+--------+--------+--------+--------+ | Configure ``Server`` header | ✅ | ✅ | ✅ | ✅ | +------------------------------------+--------+--------+--------+--------+ | Alter opening handshake request | ✅ | ✅ | ✅ | ❌ | +------------------------------------+--------+--------+--------+--------+ | Alter opening handshake response | ✅ | ✅ | ✅ | ❌ | +------------------------------------+--------+--------+--------+--------+ | Force an HTTP response | ✅ | ✅ | ✅ | ✅ | +------------------------------------+--------+--------+--------+--------+ | Perform HTTP Basic Authentication | ✅ | ✅ | ❌ | ✅ | +------------------------------------+--------+--------+--------+--------+ | Dispatch connections to handlers | ✅ 
| ✅ | — | ❌ | +------------------------------------+--------+--------+--------+--------+ Client ------ .. table:: :class: support-matrix-table +------------------------------------+--------+--------+--------+--------+ | | |aio| | |sync| | |sans| | |leg| | +====================================+========+========+========+========+ | Connect to a TCP socket | ✅ | ✅ | — | ✅ | +------------------------------------+--------+--------+--------+--------+ | Connect to a Unix socket | ✅ | ✅ | — | ✅ | +------------------------------------+--------+--------+--------+--------+ | Connect using a preexisting socket | ✅ | ✅ | — | ✅ | +------------------------------------+--------+--------+--------+--------+ | Encrypt connection with TLS | ✅ | ✅ | — | ✅ | +------------------------------------+--------+--------+--------+--------+ | Close connection on context exit | ✅ | ✅ | — | ✅ | +------------------------------------+--------+--------+--------+--------+ | Reconnect automatically | ✅ | ❌ | — | ✅ | +------------------------------------+--------+--------+--------+--------+ | Configure ``Origin`` header | ✅ | ✅ | ✅ | ✅ | +------------------------------------+--------+--------+--------+--------+ | Configure ``User-Agent`` header | ✅ | ✅ | ✅ | ✅ | +------------------------------------+--------+--------+--------+--------+ | Modify opening handshake request | ✅ | ✅ | ✅ | ✅ | +------------------------------------+--------+--------+--------+--------+ | Modify opening handshake response | ✅ | ✅ | ✅ | ❌ | +------------------------------------+--------+--------+--------+--------+ | Connect to non-ASCII IRIs | ✅ | ✅ | ✅ | ✅ | +------------------------------------+--------+--------+--------+--------+ | Follow HTTP redirects | ✅ | ❌ | — | ✅ | +------------------------------------+--------+--------+--------+--------+ | Perform HTTP Basic Authentication | ✅ | ✅ | ✅ | ✅ | +------------------------------------+--------+--------+--------+--------+ | Connect via HTTP proxy | ✅ | ✅ | — | ❌ | 
+------------------------------------+--------+--------+--------+--------+ | Connect via SOCKS5 proxy | ✅ | ✅ | — | ❌ | +------------------------------------+--------+--------+--------+--------+ Known limitations ----------------- There is no way to control compression of outgoing frames on a per-frame basis (`#538`_). If compression is enabled, all frames are compressed. .. _#538: https://github.com/python-websockets/websockets/issues/538 The server doesn't check the Host header and doesn't respond with HTTP 400 Bad Request if it is missing or invalid (`#1246`_). .. _#1246: https://github.com/python-websockets/websockets/issues/1246 The client doesn't support HTTP Digest Authentication (`#784`_). .. _#784: https://github.com/python-websockets/websockets/issues/784 The client API doesn't attempt to guarantee that there is no more than one connection to a given IP address in a CONNECTING state. This behavior is mandated by :rfc:`6455`, section 4.1. However, :func:`~asyncio.client.connect()` isn't the right layer for enforcing this constraint. It's the caller's responsibility. It is possible to send or receive a text message containing invalid UTF-8 with ``send(not_utf8_bytes, text=True)`` and ``not_utf8_bytes = recv(decode=False)`` respectively. As a side effect of disabling UTF-8 encoding and decoding, these options also disable UTF-8 validation. websockets-15.0.1/docs/reference/index.rst000066400000000000000000000032631476212450300205060ustar00rootroot00000000000000API reference ============= .. currentmodule:: websockets Features -------- Check which implementations support which features and known limitations. .. toctree:: :titlesonly: features :mod:`asyncio` -------------- It's ideal for servers that handle many clients concurrently. This is the default implementation. .. toctree:: :titlesonly: asyncio/server asyncio/client :mod:`threading` ---------------- This alternative implementation can be a good choice for clients. .. 
toctree:: :titlesonly: sync/server sync/client `Sans-I/O`_ ----------- This layer is designed for integrating in third-party libraries, typically application servers. .. _Sans-I/O: https://sans-io.readthedocs.io/ .. toctree:: :titlesonly: sansio/server sansio/client Legacy ------ This is the historical implementation. It is deprecated. It will be removed by 2030. .. toctree:: :titlesonly: legacy/server legacy/client Extensions ---------- The Per-Message Deflate extension is built-in. You may also define custom extensions. .. toctree:: :titlesonly: extensions Shared ------ These low-level APIs are shared by all implementations. .. toctree:: :titlesonly: datastructures exceptions types variables API stability ------------- Public APIs documented in this API reference are subject to the :ref:`backwards-compatibility policy `. Anything that isn't listed in the API reference is a private API. There's no guarantees of behavior or backwards-compatibility for private APIs. Convenience imports ------------------- For convenience, some public APIs can be imported directly from the ``websockets`` package. websockets-15.0.1/docs/reference/legacy/000077500000000000000000000000001476212450300201055ustar00rootroot00000000000000websockets-15.0.1/docs/reference/legacy/client.rst000066400000000000000000000044641476212450300221250ustar00rootroot00000000000000Client (legacy) =============== .. admonition:: The legacy :mod:`asyncio` implementation is deprecated. :class: caution The :doc:`upgrade guide <../../howto/upgrade>` provides complete instructions to migrate your application. .. automodule:: websockets.legacy.client Opening a connection -------------------- .. 
autofunction:: connect(uri, *, create_protocol=None, logger=None, compression="deflate", origin=None, extensions=None, subprotocols=None, extra_headers=None, user_agent_header="Python/x.y.z websockets/X.Y", open_timeout=10, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, **kwds) :async: .. autofunction:: unix_connect(path, uri="ws://localhost/", *, create_protocol=None, logger=None, compression="deflate", origin=None, extensions=None, subprotocols=None, extra_headers=None, user_agent_header="Python/x.y.z websockets/X.Y", open_timeout=10, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, **kwds) :async: Using a connection ------------------ .. autoclass:: WebSocketClientProtocol(*, logger=None, origin=None, extensions=None, subprotocols=None, extra_headers=None, user_agent_header="Python/x.y.z websockets/X.Y", ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16) .. automethod:: recv .. automethod:: send .. automethod:: close .. automethod:: wait_closed .. automethod:: ping .. automethod:: pong WebSocket connection objects also provide these attributes: .. autoattribute:: id .. autoattribute:: logger .. autoproperty:: local_address .. autoproperty:: remote_address .. autoproperty:: open .. autoproperty:: closed .. autoattribute:: latency The following attributes are available after the opening handshake, once the WebSocket connection is open: .. autoattribute:: path .. autoattribute:: request_headers .. autoattribute:: response_headers .. autoattribute:: subprotocol The following attributes are available after the closing handshake, once the WebSocket connection is closed: .. autoproperty:: close_code .. 
autoproperty:: close_reason websockets-15.0.1/docs/reference/legacy/common.rst000066400000000000000000000025571476212450300221400ustar00rootroot00000000000000:orphan: Both sides (legacy) =================== .. admonition:: The legacy :mod:`asyncio` implementation is deprecated. :class: caution The :doc:`upgrade guide <../../howto/upgrade>` provides complete instructions to migrate your application. .. automodule:: websockets.legacy.protocol .. autoclass:: WebSocketCommonProtocol(*, logger=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16) .. automethod:: recv .. automethod:: send .. automethod:: close .. automethod:: wait_closed .. automethod:: ping .. automethod:: pong WebSocket connection objects also provide these attributes: .. autoattribute:: id .. autoattribute:: logger .. autoproperty:: local_address .. autoproperty:: remote_address .. autoproperty:: open .. autoproperty:: closed .. autoattribute:: latency The following attributes are available after the opening handshake, once the WebSocket connection is open: .. autoattribute:: path .. autoattribute:: request_headers .. autoattribute:: response_headers .. autoattribute:: subprotocol The following attributes are available after the closing handshake, once the WebSocket connection is closed: .. autoproperty:: close_code .. autoproperty:: close_reason websockets-15.0.1/docs/reference/legacy/server.rst000066400000000000000000000065641476212450300221600ustar00rootroot00000000000000Server (legacy) =============== .. admonition:: The legacy :mod:`asyncio` implementation is deprecated. :class: caution The :doc:`upgrade guide <../../howto/upgrade>` provides complete instructions to migrate your application. .. automodule:: websockets.legacy.server Starting a server ----------------- .. 
autofunction:: serve(ws_handler, host=None, port=None, *, create_protocol=None, logger=None, compression="deflate", origins=None, extensions=None, subprotocols=None, extra_headers=None, server_header="Python/x.y.z websockets/X.Y", process_request=None, select_subprotocol=None, open_timeout=10, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, **kwds) :async: .. autofunction:: unix_serve(ws_handler, path=None, *, create_protocol=None, logger=None, compression="deflate", origins=None, extensions=None, subprotocols=None, extra_headers=None, server_header="Python/x.y.z websockets/X.Y", process_request=None, select_subprotocol=None, open_timeout=10, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, **kwds) :async: Stopping a server ----------------- .. autoclass:: WebSocketServer .. automethod:: close .. automethod:: wait_closed .. automethod:: get_loop .. automethod:: is_serving .. automethod:: start_serving .. automethod:: serve_forever .. autoattribute:: sockets Using a connection ------------------ .. autoclass:: WebSocketServerProtocol(ws_handler, ws_server, *, logger=None, origins=None, extensions=None, subprotocols=None, extra_headers=None, server_header="Python/x.y.z websockets/X.Y", process_request=None, select_subprotocol=None, open_timeout=10, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16) .. automethod:: recv .. automethod:: send .. automethod:: close .. automethod:: wait_closed .. automethod:: ping .. automethod:: pong You can customize the opening handshake in a subclass by overriding these methods: .. automethod:: process_request .. automethod:: select_subprotocol WebSocket connection objects also provide these attributes: .. autoattribute:: id .. autoattribute:: logger .. autoproperty:: local_address .. 
autoproperty:: remote_address .. autoproperty:: open .. autoproperty:: closed .. autoattribute:: latency The following attributes are available after the opening handshake, once the WebSocket connection is open: .. autoattribute:: path .. autoattribute:: request_headers .. autoattribute:: response_headers .. autoattribute:: subprotocol The following attributes are available after the closing handshake, once the WebSocket connection is closed: .. autoproperty:: close_code .. autoproperty:: close_reason Broadcast --------- .. autofunction:: websockets.legacy.server.broadcast Basic authentication -------------------- .. automodule:: websockets.legacy.auth websockets supports HTTP Basic Authentication according to :rfc:`7235` and :rfc:`7617`. .. autofunction:: basic_auth_protocol_factory .. autoclass:: BasicAuthWebSocketServerProtocol .. autoattribute:: realm .. autoattribute:: username .. automethod:: check_credentials websockets-15.0.1/docs/reference/sansio/000077500000000000000000000000001476212450300201355ustar00rootroot00000000000000websockets-15.0.1/docs/reference/sansio/client.rst000066400000000000000000000021671476212450300221530ustar00rootroot00000000000000Client (`Sans-I/O`_) ==================== .. _Sans-I/O: https://sans-io.readthedocs.io/ .. currentmodule:: websockets.client .. autoclass:: ClientProtocol .. automethod:: receive_data .. automethod:: receive_eof .. automethod:: connect .. automethod:: send_request .. automethod:: send_continuation .. automethod:: send_text .. automethod:: send_binary .. automethod:: send_close .. automethod:: send_ping .. automethod:: send_pong .. automethod:: fail .. automethod:: events_received .. automethod:: data_to_send .. automethod:: close_expected WebSocket protocol objects also provide these attributes: .. autoattribute:: id .. autoattribute:: logger .. autoproperty:: state The following attributes are available after the opening handshake, once the WebSocket connection is open: .. 
autoattribute:: handshake_exc The following attributes are available after the closing handshake, once the WebSocket connection is closed: .. autoproperty:: close_code .. autoproperty:: close_reason .. autoproperty:: close_exc websockets-15.0.1/docs/reference/sansio/common.rst000066400000000000000000000017661476212450300221710ustar00rootroot00000000000000:orphan: Both sides (`Sans-I/O`_) ========================= .. _Sans-I/O: https://sans-io.readthedocs.io/ .. automodule:: websockets.protocol .. autoclass:: Protocol .. automethod:: receive_data .. automethod:: receive_eof .. automethod:: send_continuation .. automethod:: send_text .. automethod:: send_binary .. automethod:: send_close .. automethod:: send_ping .. automethod:: send_pong .. automethod:: fail .. automethod:: events_received .. automethod:: data_to_send .. automethod:: close_expected .. autoattribute:: id .. autoattribute:: logger .. autoproperty:: state .. autoproperty:: close_code .. autoproperty:: close_reason .. autoproperty:: close_exc .. autoclass:: Side .. autoattribute:: SERVER .. autoattribute:: CLIENT .. autoclass:: State .. autoattribute:: CONNECTING .. autoattribute:: OPEN .. autoattribute:: CLOSING .. autoattribute:: CLOSED .. autodata:: SEND_EOF websockets-15.0.1/docs/reference/sansio/server.rst000066400000000000000000000022731476212450300222010ustar00rootroot00000000000000Server (`Sans-I/O`_) ==================== .. _Sans-I/O: https://sans-io.readthedocs.io/ .. currentmodule:: websockets.server .. autoclass:: ServerProtocol .. automethod:: receive_data .. automethod:: receive_eof .. automethod:: accept .. automethod:: select_subprotocol .. automethod:: reject .. automethod:: send_response .. automethod:: send_continuation .. automethod:: send_text .. automethod:: send_binary .. automethod:: send_close .. automethod:: send_ping .. automethod:: send_pong .. automethod:: fail .. automethod:: events_received .. automethod:: data_to_send .. 
automethod:: close_expected WebSocket protocol objects also provide these attributes: .. autoattribute:: id .. autoattribute:: logger .. autoproperty:: state The following attributes are available after the opening handshake, once the WebSocket connection is open: .. autoattribute:: handshake_exc The following attributes are available after the closing handshake, once the WebSocket connection is closed: .. autoproperty:: close_code .. autoproperty:: close_reason .. autoproperty:: close_exc websockets-15.0.1/docs/reference/sync/000077500000000000000000000000001476212450300176155ustar00rootroot00000000000000websockets-15.0.1/docs/reference/sync/client.rst000066400000000000000000000021451476212450300216270ustar00rootroot00000000000000Client (:mod:`threading`) ========================= .. automodule:: websockets.sync.client Opening a connection -------------------- .. autofunction:: connect .. autofunction:: unix_connect Using a connection ------------------ .. autoclass:: ClientConnection .. automethod:: __iter__ .. automethod:: recv .. automethod:: recv_streaming .. automethod:: send .. automethod:: close .. automethod:: ping .. automethod:: pong WebSocket connection objects also provide these attributes: .. autoattribute:: id .. autoattribute:: logger .. autoproperty:: local_address .. autoproperty:: remote_address .. autoproperty:: latency .. autoproperty:: state The following attributes are available after the opening handshake, once the WebSocket connection is open: .. autoattribute:: request .. autoattribute:: response .. autoproperty:: subprotocol The following attributes are available after the closing handshake, once the WebSocket connection is closed: .. autoproperty:: close_code .. autoproperty:: close_reason websockets-15.0.1/docs/reference/sync/common.rst000066400000000000000000000017511476212450300216430ustar00rootroot00000000000000:orphan: Both sides (:mod:`threading`) ============================= .. automodule:: websockets.sync.connection .. 
autoclass:: Connection .. automethod:: __iter__ .. automethod:: recv .. automethod:: recv_streaming .. automethod:: send .. automethod:: close .. automethod:: ping .. automethod:: pong WebSocket connection objects also provide these attributes: .. autoattribute:: id .. autoattribute:: logger .. autoproperty:: local_address .. autoproperty:: remote_address .. autoattribute:: latency .. autoproperty:: state The following attributes are available after the opening handshake, once the WebSocket connection is open: .. autoattribute:: request .. autoattribute:: response .. autoproperty:: subprotocol The following attributes are available after the closing handshake, once the WebSocket connection is closed: .. autoproperty:: close_code .. autoproperty:: close_reason websockets-15.0.1/docs/reference/sync/server.rst000066400000000000000000000032051476212450300216550ustar00rootroot00000000000000Server (:mod:`threading`) ========================= .. automodule:: websockets.sync.server Creating a server ----------------- .. autofunction:: serve .. autofunction:: unix_serve Routing connections ------------------- .. automodule:: websockets.sync.router .. autofunction:: route .. autofunction:: unix_route .. autoclass:: Router .. currentmodule:: websockets.sync.server Running a server ---------------- .. autoclass:: Server .. automethod:: serve_forever .. automethod:: shutdown .. automethod:: fileno Using a connection ------------------ .. autoclass:: ServerConnection .. automethod:: __iter__ .. automethod:: recv .. automethod:: recv_streaming .. automethod:: send .. automethod:: close .. automethod:: ping .. automethod:: pong .. automethod:: respond WebSocket connection objects also provide these attributes: .. autoattribute:: id .. autoattribute:: logger .. autoproperty:: local_address .. autoproperty:: remote_address .. autoproperty:: latency .. autoproperty:: state The following attributes are available after the opening handshake, once the WebSocket connection is open: .. 
autoattribute:: request .. autoattribute:: response .. autoproperty:: subprotocol The following attributes are available after the closing handshake, once the WebSocket connection is closed: .. autoproperty:: close_code .. autoproperty:: close_reason HTTP Basic Authentication ------------------------- websockets supports HTTP Basic Authentication according to :rfc:`7235` and :rfc:`7617`. .. autofunction:: basic_auth websockets-15.0.1/docs/reference/types.rst000066400000000000000000000006051476212450300205400ustar00rootroot00000000000000Types ===== .. automodule:: websockets.typing .. autodata:: Data .. autodata:: LoggerLike .. autodata:: StatusLike .. autodata:: Origin .. autodata:: Subprotocol .. autodata:: ExtensionName .. autodata:: ExtensionParameter .. autodata:: websockets.protocol.Event .. autodata:: websockets.datastructures.HeadersLike .. autodata:: websockets.datastructures.SupportsKeysAndGetItem websockets-15.0.1/docs/reference/variables.rst000066400000000000000000000041461476212450300213500ustar00rootroot00000000000000Environment variables ===================== .. currentmodule:: websockets Logging ------- .. envvar:: WEBSOCKETS_MAX_LOG_SIZE How much of each frame to show in debug logs. The default value is ``75``. See the :doc:`logging guide <../topics/logging>` for details. Security -------- .. envvar:: WEBSOCKETS_SERVER Server header sent by websockets. The default value uses the format ``"Python/x.y.z websockets/X.Y"``. .. envvar:: WEBSOCKETS_USER_AGENT User-Agent header sent by websockets. The default value uses the format ``"Python/x.y.z websockets/X.Y"``. .. envvar:: WEBSOCKETS_MAX_LINE_LENGTH Maximum length of the request or status line in the opening handshake. The default value is ``8192`` bytes. .. envvar:: WEBSOCKETS_MAX_NUM_HEADERS Maximum number of HTTP headers in the opening handshake. The default value is ``128``. .. envvar:: WEBSOCKETS_MAX_BODY_SIZE Maximum size of the body of an HTTP response in the opening handshake. 
The default value is ``1_048_576`` bytes (1 MiB). See the :doc:`security guide <../topics/security>` for details. Reconnection ------------ Reconnection attempts are spaced out with truncated exponential backoff. .. envvar:: WEBSOCKETS_BACKOFF_INITIAL_DELAY The first attempt is delayed by a random amount of time between ``0`` and ``WEBSOCKETS_BACKOFF_INITIAL_DELAY`` seconds. The default value is ``5.0`` seconds. .. envvar:: WEBSOCKETS_BACKOFF_MIN_DELAY The second attempt is delayed by ``WEBSOCKETS_BACKOFF_MIN_DELAY`` seconds. The default value is ``3.1`` seconds. .. envvar:: WEBSOCKETS_BACKOFF_FACTOR After the second attempt, the delay is multiplied by ``WEBSOCKETS_BACKOFF_FACTOR`` between each attempt. The default value is ``1.618``. .. envvar:: WEBSOCKETS_BACKOFF_MAX_DELAY The delay between attempts is capped at ``WEBSOCKETS_BACKOFF_MAX_DELAY`` seconds. The default value is ``90.0`` seconds. Redirects --------- .. envvar:: WEBSOCKETS_MAX_REDIRECTS Maximum number of redirects that :func:`~asyncio.client.connect` follows. The default value is ``10``. 
websockets-15.0.1/docs/requirements.txt000066400000000000000000000002111476212450300201610ustar00rootroot00000000000000furo sphinx sphinx-autobuild sphinx-copybutton sphinx-inline-tabs sphinxcontrib-spelling sphinxcontrib-trio sphinxext-opengraph werkzeug websockets-15.0.1/docs/spelling_wordlist.txt000066400000000000000000000014101476212450300212040ustar00rootroot00000000000000augustin auth autoscaler aymeric backend backoff backpressure balancer balancers bottlenecked bufferbloat bugfix buildpack bytestring bytestrings changelog coroutine coroutines cryptocurrencies cryptocurrency css ctrl deserialize dev django Dockerfile dyno formatter fractalideas github gunicorn healthz html hypercorn iframe io IPv istio iterable js keepalive KiB koyeb kubernetes lifecycle linkerd liveness lookups MiB middleware mutex mypy nginx PaaS Paketo permessage pid procfile proxying py pythonic reconnection redis redistributions retransmit retryable runtime scalable stateful subclasses subclassing submodule subpackages subprotocol subprotocols supervisord tidelift tls tox txt unregister uple uvicorn uvloop virtualenv websocket WebSocket websockets ws wsgi www websockets-15.0.1/docs/topics/000077500000000000000000000000001476212450300162045ustar00rootroot00000000000000websockets-15.0.1/docs/topics/authentication.rst000066400000000000000000000261551476212450300217660ustar00rootroot00000000000000Authentication ============== The WebSocket protocol is designed for creating web applications that require bidirectional communication between browsers and servers. In most practical use cases, WebSocket servers need to authenticate clients in order to route communications appropriately and securely. :rfc:`6455` remains elusive when it comes to authentication: This protocol doesn't prescribe any particular way that servers can authenticate clients during the WebSocket handshake. 
The WebSocket server can use any client authentication mechanism available to a generic HTTP server, such as cookies, HTTP authentication, or TLS authentication. None of these three mechanisms works well in practice. Using cookies is cumbersome, HTTP authentication isn't supported by all mainstream browsers, and TLS authentication in a browser is an esoteric user experience. Fortunately, there are better alternatives! Let's discuss them. System design ------------- Consider a setup where the WebSocket server is separate from the HTTP server. Most servers built with websockets adopt this design because they're a component in a web application and websockets doesn't aim at supporting HTTP. The following diagram illustrates the authentication flow. .. image:: authentication.svg Assuming the current user is authenticated with the HTTP server (1), the application needs to obtain credentials from the HTTP server (2) in order to send them to the WebSocket server (3), who can check them against the database of user accounts (4). Usernames and passwords aren't a good choice of credentials here, if only because passwords aren't available in clear text in the database. Tokens linked to user accounts are a better choice. These tokens must be impossible to forge by an attacker. For additional security, they can be short-lived or even single-use. Sending credentials ------------------- Assume the web application obtained authentication credentials, likely a token, from the HTTP server. There's four options for passing them to the WebSocket server. 1. **Sending credentials as the first message in the WebSocket connection.** This is fully reliable and the most secure mechanism in this discussion. It has two minor downsides: * Authentication is performed at the application layer. Ideally, it would be managed at the protocol layer. * Authentication is performed after the WebSocket handshake, making it impossible to monitor authentication failures with HTTP response codes. 2. 
**Adding credentials to the WebSocket URI in a query parameter.** This is also fully reliable but less secure. Indeed, it has a major downside: * URIs end up in logs, which leaks credentials. Even if that risk could be lowered with single-use tokens, it is usually considered unacceptable. Authentication is still performed at the application layer but it can happen before the WebSocket handshake, which improves separation of concerns and enables responding to authentication failures with HTTP 401. 3. **Setting a cookie on the domain of the WebSocket URI.** Cookies are undoubtedly the most common and hardened mechanism for sending credentials from a web application to a server. In an HTTP application, credentials would be a session identifier or a serialized, signed session. Unfortunately, when the WebSocket server runs on a different domain from the web application, this idea hits the wall of the `Same-Origin Policy`_. For security reasons, setting a cookie on a different origin is impossible. The proper workaround consists in: * creating a hidden iframe_ served from the domain of the WebSocket server * sending the token to the iframe with postMessage_ * setting the cookie in the iframe before opening the WebSocket connection. Sharing a parent domain (e.g. example.com) between the HTTP server (e.g. www.example.com) and the WebSocket server (e.g. ws.example.com) and setting the cookie on that parent domain would work too. However, the cookie would be shared with all subdomains of the parent domain. For a cookie containing credentials, this is unacceptable. .. _Same-Origin Policy: https://developer.mozilla.org/en-US/docs/Web/Security/Same-origin_policy .. _iframe: https://developer.mozilla.org/en-US/docs/Web/HTML/Element/iframe .. _postMessage: https://developer.mozilla.org/en-US/docs/Web/API/MessagePort/postMessage 4. **Adding credentials to the WebSocket URI in user information.** Letting the browser perform HTTP Basic Auth is a nice idea in theory. 
In practice it doesn't work due to browser support limitations: * Chrome behaves as expected. * Firefox caches credentials too aggressively. When connecting again to the same server with new credentials, it reuses the old credentials, which may be expired, resulting in an HTTP 401. Then the next connection succeeds. Perhaps errors clear the cache. When tokens are short-lived or single-use, this bug produces an interesting effect: every other WebSocket connection fails. * Safari behaves as expected. Two other options are off the table: 1. **Setting a custom HTTP header** This would be the most elegant mechanism, solving all issues with the options discussed above. Unfortunately, it doesn't work because the `WebSocket API`_ doesn't support `setting custom headers`_. .. _WebSocket API: https://developer.mozilla.org/en-US/docs/Web/API/WebSockets_API .. _setting custom headers: https://github.com/whatwg/html/issues/3062 2. **Authenticating with a TLS certificate** While this is suggested by the RFC, installing a TLS certificate is too far from the mainstream experience of browser users. This could make sense in high security contexts. I hope that developers working on projects in this category don't take security advice from the documentation of random open source projects :-) Let's experiment! ----------------- The `experiments/authentication`_ directory demonstrates these techniques. Run the experiment in an environment where websockets is installed: .. _experiments/authentication: https://github.com/python-websockets/websockets/tree/main/experiments/authentication .. code-block:: console $ python experiments/authentication/app.py Running on http://localhost:8000/ When you browse to the HTTP server at http://localhost:8000/ and you submit a username, the server creates a token and returns a testing web page. This page opens WebSocket connections to four WebSocket servers running on four different origins. 
It attempts to authenticate with the token in four different ways. First message ............. As soon as the connection is open, the client sends a message containing the token: .. code-block:: javascript const websocket = new WebSocket("ws://.../"); websocket.onopen = () => websocket.send(token); // ... At the beginning of the connection handler, the server receives this message and authenticates the user. If authentication fails, the server closes the connection: .. code-block:: python from websockets.frames import CloseCode async def first_message_handler(websocket): token = await websocket.recv() user = get_user(token) if user is None: await websocket.close(CloseCode.INTERNAL_ERROR, "authentication failed") return ... Query parameter ............... The client adds the token to the WebSocket URI in a query parameter before opening the connection: .. code-block:: javascript const uri = `ws://.../?token=${token}`; const websocket = new WebSocket(uri); // ... The server intercepts the HTTP request, extracts the token and authenticates the user. If authentication fails, it returns an HTTP 401: .. code-block:: python async def query_param_auth(connection, request): token = get_query_param(request.path, "token") if token is None: return connection.respond(http.HTTPStatus.UNAUTHORIZED, "Missing token\n") user = get_user(token) if user is None: return connection.respond(http.HTTPStatus.UNAUTHORIZED, "Invalid token\n") connection.username = user Cookie ...... The client sets a cookie containing the token before opening the connection. The cookie must be set by an iframe loaded from the same origin as the WebSocket server. This requires passing the token to this iframe. .. code-block:: javascript // in main window iframe.contentWindow.postMessage(token, "http://..."); // in iframe document.cookie = `token=${data}; SameSite=Strict`; // in main window const websocket = new WebSocket("ws://.../"); // ... 
This sequence must be synchronized between the main window and the iframe. This involves several events. Look at the full implementation for details. The server intercepts the HTTP request, extracts the token and authenticates the user. If authentication fails, it returns an HTTP 401: .. code-block:: python async def cookie_auth(connection, request): # Serve iframe on non-WebSocket requests ... token = get_cookie(request.headers.get("Cookie", ""), "token") if token is None: return connection.respond(http.HTTPStatus.UNAUTHORIZED, "Missing token\n") user = get_user(token) if user is None: return connection.respond(http.HTTPStatus.UNAUTHORIZED, "Invalid token\n") connection.username = user User information ................ The client adds the token to the WebSocket URI in user information before opening the connection: .. code-block:: javascript const uri = `ws://token:${token}@.../`; const websocket = new WebSocket(uri); // ... Since HTTP Basic Auth is designed to accept a username and a password rather than a token, we send ``token`` as username and the token as password. The server intercepts the HTTP request, extracts the token and authenticates the user. If authentication fails, it returns an HTTP 401: .. code-block:: python from websockets.asyncio.server import basic_auth as websockets_basic_auth def check_credentials(username, password): return username == get_user(password) basic_auth = websockets_basic_auth(check_credentials=check_credentials) Machine-to-machine authentication --------------------------------- When the WebSocket client is a standalone program rather than a script running in a browser, there are far fewer constraints. HTTP Authentication is the best solution in this scenario. To authenticate a websockets client with HTTP Basic Authentication (:rfc:`7617`), include the credentials in the URI: .. code-block:: python from websockets.asyncio.client import connect async with connect(f"wss://{username}:{password}@.../") as websocket: ... 
You must :func:`~urllib.parse.quote` ``username`` and ``password`` if they contain unsafe characters. To authenticate a websockets client with HTTP Bearer Authentication (:rfc:`6750`), add a suitable ``Authorization`` header: .. code-block:: python from websockets.asyncio.client import connect headers = {"Authorization": f"Bearer {token}"} async with connect("wss://.../", additional_headers=headers) as websocket: ... websockets-15.0.1/docs/topics/authentication.svg000066400000000000000000000274171476212450300217570ustar00rootroot00000000000000HTTPserverWebSocketserverweb appin browseruser accounts(1) authenticate user(2) obtain credentials(3) send credentials(4) authenticate userwebsockets-15.0.1/docs/topics/broadcast.rst000066400000000000000000000323061476212450300207040ustar00rootroot00000000000000Broadcasting ============ .. currentmodule:: websockets .. admonition:: If you want to send a message to all connected clients, use :func:`~asyncio.server.broadcast`. :class: tip If you want to learn about its design, continue reading this document. For the legacy :mod:`asyncio` implementation, use :func:`~legacy.server.broadcast`. WebSocket servers often send the same message to all connected clients or to a subset of clients for which the message is relevant. Let's explore options for broadcasting a message, explain the design of :func:`~asyncio.server.broadcast`, and discuss alternatives. For each option, we'll provide a connection handler called ``handler()`` and a function or coroutine called ``broadcast()`` that sends a message to all connected clients. Integrating them is left as an exercise for the reader. You could start with:: import asyncio from websockets.asyncio.server import serve async def handler(websocket): ... async def broadcast(message): ... async def broadcast_messages(): while True: await asyncio.sleep(1) message = ... 
# your application logic goes here await broadcast(message) async def main(): async with serve(handler, "localhost", 8765): await broadcast_messages() # runs forever if __name__ == "__main__": asyncio.run(main()) ``broadcast_messages()`` must yield control to the event loop between each message, or else it will never let the server run. That's why it includes ``await asyncio.sleep(1)``. A complete example is available in the `experiments/broadcast`_ directory. .. _experiments/broadcast: https://github.com/python-websockets/websockets/tree/main/experiments/broadcast The naive way ------------- The most obvious way to send a message to all connected clients consists in keeping track of them and sending the message to each of them. Here's a connection handler that registers clients in a global variable:: CLIENTS = set() async def handler(websocket): CLIENTS.add(websocket) try: await websocket.wait_closed() finally: CLIENTS.remove(websocket) This implementation assumes that the client will never send any messages. If you'd rather not make this assumption, you can change:: await websocket.wait_closed() to:: async for _ in websocket: pass Here's a coroutine that broadcasts a message to all clients:: from websockets.exceptions import ConnectionClosed async def broadcast(message): for websocket in CLIENTS.copy(): try: await websocket.send(message) except ConnectionClosed: pass There are two tricks in this version of ``broadcast()``. First, it makes a copy of ``CLIENTS`` before iterating it. Else, if a client connects or disconnects while ``broadcast()`` is running, the loop would fail with:: RuntimeError: Set changed size during iteration Second, it ignores :exc:`~exceptions.ConnectionClosed` exceptions because a client could disconnect between the moment ``broadcast()`` makes a copy of ``CLIENTS`` and the moment it sends a message to this client. This is fine: a client that disconnected doesn't belongs to "all connected clients" anymore. The naive way can be very fast. 
Indeed, if all connections have enough free space in their write buffers, ``await websocket.send(message)`` writes the message and returns immediately, as it doesn't need to wait for the buffer to drain. In this case, ``broadcast()`` doesn't yield control to the event loop, which minimizes overhead. The naive way can also fail badly. If the write buffer of a connection reaches ``write_limit``, ``broadcast()`` waits for the buffer to drain before sending the message to other clients. This can cause a massive drop in performance. As a consequence, this pattern works only when write buffers never fill up, which is usually outside of the control of the server. If you know for sure that you will never write more than ``write_limit`` bytes within ``ping_interval + ping_timeout``, then websockets will terminate slow connections before the write buffer can fill up. Don't set extreme values of ``write_limit``, ``ping_interval``, or ``ping_timeout`` to ensure that this condition holds! Instead, set reasonable values and use the built-in :func:`~asyncio.server.broadcast` function. The concurrent way ------------------ The naive way didn't work well because it serialized writes, while the whole point of asynchronous I/O is to perform I/O concurrently. Let's modify ``broadcast()`` to send messages concurrently:: async def send(websocket, message): try: await websocket.send(message) except ConnectionClosed: pass def broadcast(message): for websocket in CLIENTS: asyncio.create_task(send(websocket, message)) We move the error handling logic in a new coroutine and we schedule a :class:`~asyncio.Task` to run it instead of executing it immediately. Since ``broadcast()`` no longer awaits coroutines, we can make it a function rather than a coroutine and do away with the copy of ``CLIENTS``. This version of ``broadcast()`` makes clients independent from one another: a slow client won't block others. As a side effect, it makes messages independent from one another. 
If you broadcast several messages, there is no strong guarantee that they will be sent in the expected order. Fortunately, the event loop runs tasks in the order in which they are created, so the order is correct in practice. Technically, this is an implementation detail of the event loop. However, it seems unlikely for an event loop to run tasks in an order other than FIFO. If you wanted to enforce the order without relying this implementation detail, you could be tempted to wait until all clients have received the message:: async def broadcast(message): if CLIENTS: # asyncio.wait doesn't accept an empty list await asyncio.wait([ asyncio.create_task(send(websocket, message)) for websocket in CLIENTS ]) However, this doesn't really work in practice. Quite often, it will block until the slowest client times out. Backpressure meets broadcast ---------------------------- At this point, it becomes apparent that backpressure, usually a good practice, doesn't work well when broadcasting a message to thousands of clients. When you're sending messages to a single client, you don't want to send them faster than the network can transfer them and the client accept them. This is why :meth:`~asyncio.server.ServerConnection.send` checks if the write buffer is above the high-water mark and, if it is, waits until it drains, giving the network and the client time to catch up. This provides backpressure. Without backpressure, you could pile up data in the write buffer until the server process runs out of memory and the operating system kills it. The :meth:`~asyncio.server.ServerConnection.send` API is designed to enforce backpressure by default. This helps users of websockets write robust programs even if they never heard about backpressure. For comparison, :class:`asyncio.StreamWriter` requires users to understand backpressure and to await :meth:`~asyncio.StreamWriter.drain` after each :meth:`~asyncio.StreamWriter.write` — or at least sufficiently frequently. 
When broadcasting messages, backpressure consists in slowing down all clients in an attempt to let the slowest client catch up. With thousands of clients, the slowest one is probably timing out and isn't going to receive the message anyway. So it doesn't make sense to synchronize with the slowest client. How do we avoid running out of memory when slow clients can't keep up with the broadcast rate, then? The most straightforward option is to disconnect them. If a client gets too far behind, eventually it reaches the limit defined by ``ping_timeout`` and websockets terminates the connection. You can refer to the discussion of :doc:`keepalive ` for details. How :func:`~asyncio.server.broadcast` works ------------------------------------------- The built-in :func:`~asyncio.server.broadcast` function is similar to the naive way. The main difference is that it doesn't apply backpressure. This provides the best performance by avoiding the overhead of scheduling and running one task per client. Also, when sending text messages, encoding to UTF-8 happens only once rather than once per client, providing a small performance gain. Per-client queues ----------------- At this point, we deal with slow clients rather brutally: we disconnect then. Can we do better? For example, we could decide to skip or to batch messages, depending on how far behind a client is. To implement this logic, we can create a queue of messages for each client and run a task that gets messages from the queue and sends them to the client:: import asyncio CLIENTS = set() async def relay(queue, websocket): while True: # Implement custom logic based on queue.qsize() and # websocket.transport.get_write_buffer_size() here. 
message = await queue.get() await websocket.send(message) async def handler(websocket): queue = asyncio.Queue() relay_task = asyncio.create_task(relay(queue, websocket)) CLIENTS.add(queue) try: await websocket.wait_closed() finally: CLIENTS.remove(queue) relay_task.cancel() Then we can broadcast a message by pushing it to all queues:: def broadcast(message): for queue in CLIENTS: queue.put_nowait(message) The queues provide an additional buffer between the ``broadcast()`` function and clients. This makes it easier to support slow clients without excessive memory usage because queued messages aren't duplicated to write buffers until ``relay()`` processes them. Publish–subscribe ----------------- Can we avoid centralizing the list of connected clients in a global variable? If each client subscribes to a stream a messages, then broadcasting becomes as simple as publishing a message to the stream. Here's a message stream that supports multiple consumers:: class PubSub: def __init__(self): self.waiter = asyncio.get_running_loop().create_future() def publish(self, value): waiter = self.waiter self.waiter = asyncio.get_running_loop().create_future() waiter.set_result((value, self.waiter)) async def subscribe(self): waiter = self.waiter while True: value, waiter = await waiter yield value __aiter__ = subscribe PUBSUB = PubSub() The stream is implemented as a linked list of futures. It isn't necessary to synchronize consumers. They can read the stream at their own pace, independently from one another. Once all consumers read a message, there are no references left, therefore the garbage collector deletes it. The connection handler subscribes to the stream and sends messages:: async def handler(websocket): async for message in PUBSUB: await websocket.send(message) The broadcast function publishes to the stream:: def broadcast(message): PUBSUB.publish(message) Like per-client queues, this version supports slow clients with limited memory usage. 
Unlike per-client queues, it makes it difficult to tell how far behind a client is. The ``PubSub`` class could be extended or refactored to provide this information. The ``for`` loop is gone from this version of the ``broadcast()`` function. However, there's still a ``for`` loop iterating on all clients hidden deep inside :mod:`asyncio`. When ``publish()`` sets the result of the ``waiter`` future, :mod:`asyncio` loops on callbacks registered with this future and schedules them. This is how connection handlers receive the next value from the asynchronous iterator returned by ``subscribe()``. Performance considerations -------------------------- The built-in :func:`~asyncio.server.broadcast` function sends all messages without yielding control to the event loop. So does the naive way when the network and clients are fast and reliable. For each client, a WebSocket frame is prepared and sent to the network. This is the minimum amount of work required to broadcast a message. It would be tempting to prepare a frame and reuse it for all connections. However, this isn't possible in general for two reasons: * Clients can negotiate different extensions. You would have to enforce the same extensions with the same parameters. For example, you would have to select some compression settings and reject clients that cannot support these settings. * Extensions can be stateful, producing different encodings of the same message depending on previous messages. For example, you would have to disable context takeover to make compression stateless, resulting in poor compression rates. All other patterns discussed above yield control to the event loop once per client because messages are sent by different tasks. This makes them slower than the built-in :func:`~asyncio.server.broadcast` function. There is no major difference between the performance of per-client queues and publish–subscribe. 
websockets-15.0.1/docs/topics/compression.rst000066400000000000000000000227021476212450300213020ustar00rootroot00000000000000Compression =========== .. currentmodule:: websockets.extensions.permessage_deflate Most WebSocket servers exchange JSON messages because they're convenient to parse and serialize in a browser. These messages contain text data and tend to be repetitive. This makes the stream of messages highly compressible. Compressing messages can reduce network traffic by more than 80%. websockets implements WebSocket Per-Message Deflate, a compression extension based on the Deflate_ algorithm specified in :rfc:`7692`. .. _Deflate: https://en.wikipedia.org/wiki/Deflate :func:`~websockets.asyncio.client.connect` and :func:`~websockets.asyncio.server.serve` enable compression by default because the reduction in network bandwidth is usually worth the additional memory and CPU cost. Configuring compression ----------------------- To disable compression, set ``compression=None``:: connect(..., compression=None, ...) serve(..., compression=None, ...) To customize compression settings, enable the Per-Message Deflate extension explicitly with :class:`ClientPerMessageDeflateFactory` or :class:`ServerPerMessageDeflateFactory`:: from websockets.extensions import permessage_deflate connect( ..., extensions=[ permessage_deflate.ClientPerMessageDeflateFactory( server_max_window_bits=11, client_max_window_bits=11, compress_settings={"memLevel": 4}, ), ], ) serve( ..., extensions=[ permessage_deflate.ServerPerMessageDeflateFactory( server_max_window_bits=11, client_max_window_bits=11, compress_settings={"memLevel": 4}, ), ], ) The Window Bits and Memory Level values in these examples reduce memory usage at the expense of compression rate. Compression parameters ---------------------- When a client and a server enable the Per-Message Deflate extension, they negotiate two parameters to guarantee compatibility between compression and decompression. 
These parameters affect the trade-off between compression rate and memory usage for both sides. * **Context Takeover** means that the compression context is retained between messages. In other words, compression is applied to the stream of messages rather than to each message individually. Context takeover should remain enabled to get good performance on applications that send a stream of messages with similar structure, that is, most applications. This requires retaining the compression context and state between messages, which increases the memory footprint of a connection. * **Window Bits** controls the size of the compression context. It must be an integer between 9 (lowest memory usage) and 15 (best compression). Setting it to 8 is possible but rejected by some versions of zlib and not very useful. On the server side, websockets defaults to 12. Specifically, the compression window size (server to client) is always 12 while the decompression window (client to server) size may be 12 or 15 depending on whether the client supports configuring it. On the client side, websockets lets the server pick a suitable value, which has the same effect as defaulting to 15. :mod:`zlib` offers additional parameters for tuning compression. They control the trade-off between compression rate, memory usage, and CPU usage for compressing. They're transparent for decompressing. * **Memory Level** controls the size of the compression state. It must be an integer between 1 (lowest memory usage) and 9 (best compression). websockets defaults to 5. This is lower than zlib's default of 8. Not only does a lower memory level reduce memory usage, but it can also increase speed thanks to memory locality. * **Compression Level** controls the effort to optimize compression. It must be an integer between 1 (lowest CPU usage) and 9 (best compression). websockets relies on the default value chosen by :func:`~zlib.compressobj`, ``Z_DEFAULT_COMPRESSION``. 
* **Strategy** selects the compression strategy. The best choice depends on the type of data being compressed. websockets relies on the default value chosen by :func:`~zlib.compressobj`, ``Z_DEFAULT_STRATEGY``. To customize these parameters, add keyword arguments for :func:`~zlib.compressobj` in ``compress_settings``. Default settings for servers ---------------------------- By default, websockets enables compression with conservative settings that optimize memory usage at the cost of a slightly worse compression rate: Window Bits = 12 and Memory Level = 5. This strikes a good balance for small messages that are typical of WebSocket servers. Here's an example of how compression settings affect memory usage per connection, compressed size, and compression time for a corpus of JSON documents. =========== ============ ============ ================ ================ Window Bits Memory Level Memory usage Size vs. default Time vs. default =========== ============ ============ ================ ================ 15 8 316 KiB -10% +10% 14 7 172 KiB -7% +5% 13 6 100 KiB -3% +2% **12** **5** **64 KiB** **=** **=** 11 4 46 KiB +10% +4% 10 3 37 KiB +70% +40% 9 2 33 KiB +130% +90% — — 14 KiB +350% — =========== ============ ============ ================ ================ Window Bits and Memory Level don't have to move in lockstep. However, other combinations don't yield significantly better results than those shown above. websockets defaults to Window Bits = 12 and Memory Level = 5 to stay away from Window Bits = 10 or Memory Level = 3 where performance craters, raising doubts on what could happen at Window Bits = 11 and Memory Level = 4 on a different corpus. Defaults must be safe for all applications, hence a more conservative choice. Optimizing settings ------------------- Compressed size and compression time depend on the structure of messages exchanged by your application. As a consequence, default settings may not be optimal for your use case. 
To compare how various compression settings perform for your use case: 1. Create a corpus of typical messages in a directory, one message per file. 2. Run the `compression/benchmark.py`_ script, passing the directory in argument. The script measures compressed size and compression time for all combinations of Window Bits and Memory Level. It outputs two tables with absolute values and two tables with values relative to websockets' default settings. Pick your favorite settings in these tables and configure them as shown above. .. _compression/benchmark.py: https://github.com/python-websockets/websockets/blob/main/experiments/compression/benchmark.py Default settings for clients ---------------------------- By default, websockets enables compression with Memory Level = 5 but leaves the Window Bits setting up to the server. There's two good reasons and one bad reason for not optimizing Window Bits on the client side as on the server side: 1. If the maintainers of a server configured some optimized settings, we don't want to override them with more restrictive settings. 2. Optimizing memory usage doesn't matter very much for clients because it's uncommon to open thousands of client connections in a program. 3. On a more pragmatic and annoying note, some servers misbehave badly when a client configures compression settings. `AWS API Gateway`_ is the worst offender. .. _AWS API Gateway: https://github.com/python-websockets/websockets/issues/1065 Unfortunately, even though websockets is right and AWS is wrong, many users jump to the conclusion that websockets doesn't work. Until the ecosystem levels up, interoperability with buggy servers seems more valuable than optimizing memory usage. Decompression ------------- The discussion above focuses on compression because it's more expensive than decompression. 
Indeed, leaving aside small allocations, theoretical memory usage is: * ``(1 << (windowBits + 2)) + (1 << (memLevel + 9))`` for compression; * ``1 << windowBits`` for decompression. CPU usage is also higher for compression than decompression. While it's always possible for a server to use a smaller window size for compressing outgoing messages, using a smaller window size for decompressing incoming messages requires collaboration from clients. When a client doesn't support configuring the size of its compression window, websockets enables compression with the largest possible decompression window. In most use cases, this is more efficient than disabling compression both ways. If you are very sensitive to memory usage, you can reverse this behavior by setting the ``require_client_max_window_bits`` parameter of :class:`ServerPerMessageDeflateFactory` to ``True``. Further reading --------------- This `blog post by Ilya Grigorik`_ provides more details about how compression settings affect memory usage and how to optimize them. .. _blog post by Ilya Grigorik: https://www.igvita.com/2013/11/27/configuring-and-optimizing-websocket-compression/ This `experiment by Peter Thorson`_ recommends Window Bits = 11 and Memory Level = 4 for optimizing memory usage. .. _experiment by Peter Thorson: https://mailarchive.ietf.org/arch/msg/hybi/F9t4uPufVEy8KBLuL36cZjCmM_Y/ websockets-15.0.1/docs/topics/data-flow.svg000066400000000000000000000400001476212450300205750ustar00rootroot00000000000000Integration layerSans-I/O layerApplicationreceivemessagessendmessagesNetworksenddatareceivedatareceivebytessendbytessendeventsreceiveeventswebsockets-15.0.1/docs/topics/design.rst000066400000000000000000000542561476212450300202230ustar00rootroot00000000000000:orphan: Design (legacy) =============== .. currentmodule:: websockets.legacy This document describes the design of the legacy implementation of websockets. 
It assumes familiarity with the specification of the WebSocket protocol in :rfc:`6455`. It's primarily intended at maintainers. It may also be useful for users who wish to understand what happens under the hood. .. warning:: Internals described in this document may change at any time. Backwards compatibility is only guaranteed for :doc:`public APIs <../reference/index>`. Lifecycle --------- State ..... WebSocket connections go through a trivial state machine: - ``CONNECTING``: initial state, - ``OPEN``: when the opening handshake is complete, - ``CLOSING``: when the closing handshake is started, - ``CLOSED``: when the TCP connection is closed. Transitions happen in the following places: - ``CONNECTING -> OPEN``: in :meth:`~protocol.WebSocketCommonProtocol.connection_open` which runs when the :ref:`opening handshake ` completes and the WebSocket connection is established — not to be confused with :meth:`~asyncio.BaseProtocol.connection_made` which runs when the TCP connection is established; - ``OPEN -> CLOSING``: in :meth:`~protocol.WebSocketCommonProtocol.write_frame` immediately before sending a close frame; since receiving a close frame triggers sending a close frame, this does the right thing regardless of which side started the :ref:`closing handshake `; also in :meth:`~protocol.WebSocketCommonProtocol.fail_connection` which duplicates a few lines of code from ``write_close_frame()`` and ``write_frame()``; - ``* -> CLOSED``: in :meth:`~protocol.WebSocketCommonProtocol.connection_lost` which is always called exactly once when the TCP connection is closed. Coroutines .......... The following diagram shows which coroutines are running at each stage of the connection lifecycle on the client side. .. image:: lifecycle.svg :target: _images/lifecycle.svg The lifecycle is identical on the server side, except inversion of control makes the equivalent of :meth:`~client.connect` implicit. Coroutines shown in green are called by the application. 
Multiple coroutines may interact with the WebSocket connection concurrently. Coroutines shown in gray manage the connection. When the opening handshake succeeds, :meth:`~protocol.WebSocketCommonProtocol.connection_open` starts two tasks: - :attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` runs :meth:`~protocol.WebSocketCommonProtocol.transfer_data` which handles incoming data and lets :meth:`~protocol.WebSocketCommonProtocol.recv` consume it. It may be canceled to terminate the connection. It never exits with an exception other than :exc:`~asyncio.CancelledError`. See :ref:`data transfer ` below. - :attr:`~protocol.WebSocketCommonProtocol.keepalive_ping_task` runs :meth:`~protocol.WebSocketCommonProtocol.keepalive_ping` which sends Ping frames at regular intervals and ensures that corresponding Pong frames are received. It is canceled when the connection terminates. It never exits with an exception other than :exc:`~asyncio.CancelledError`. - :attr:`~protocol.WebSocketCommonProtocol.close_connection_task` runs :meth:`~protocol.WebSocketCommonProtocol.close_connection` which waits for the data transfer to terminate, then takes care of closing the TCP connection. It must not be canceled. It never exits with an exception. See :ref:`connection termination ` below. Besides, :meth:`~protocol.WebSocketCommonProtocol.fail_connection` starts the same :attr:`~protocol.WebSocketCommonProtocol.close_connection_task` when the opening handshake fails, in order to close the TCP connection. Splitting the responsibilities between two tasks makes it easier to guarantee that websockets can terminate connections: - within a fixed timeout, - without leaking pending tasks, - without leaking open TCP connections, regardless of whether the connection terminates normally or abnormally. :attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` completes when no more data will be received on the connection. Under normal circumstances, it exits after exchanging close frames. 
:attr:`~protocol.WebSocketCommonProtocol.close_connection_task` completes when the TCP connection is closed. .. _opening-handshake: Opening handshake ----------------- websockets performs the opening handshake when establishing a WebSocket connection. On the client side, :meth:`~client.connect` executes it before returning the protocol to the caller. On the server side, it's executed before passing the protocol to the ``ws_handler`` coroutine handling the connection. While the opening handshake is asymmetrical — the client sends an HTTP Upgrade request and the server replies with an HTTP Switching Protocols response — websockets aims at keeping the implementation of both sides consistent with one another. On the client side, :meth:`~client.WebSocketClientProtocol.handshake`: - builds an HTTP request based on the ``uri`` and parameters passed to :meth:`~client.connect`; - writes the HTTP request to the network; - reads an HTTP response from the network; - checks the HTTP response, validates ``extensions`` and ``subprotocol``, and configures the protocol accordingly; - moves to the ``OPEN`` state. On the server side, :meth:`~server.WebSocketServerProtocol.handshake`: - reads an HTTP request from the network; - calls :meth:`~server.WebSocketServerProtocol.process_request` which may abort the WebSocket handshake and return an HTTP response instead; this hook only makes sense on the server side; - checks the HTTP request, negotiates ``extensions`` and ``subprotocol``, and configures the protocol accordingly; - builds an HTTP response based on the above and parameters passed to :meth:`~server.serve`; - writes the HTTP response to the network; - moves to the ``OPEN`` state; - returns the ``path`` part of the ``uri``. The most significant asymmetry between the two sides of the opening handshake lies in the negotiation of extensions and, to a lesser extent, of the subprotocol. 
The server knows everything about both sides and decides what the parameters should be for the connection. The client merely applies them. If anything goes wrong during the opening handshake, websockets :ref:`fails the connection `. .. _data-transfer: Data transfer ------------- Symmetry ........ Once the opening handshake has completed, the WebSocket protocol enters the data transfer phase. This part is almost symmetrical. There are only two differences between a server and a client: - `client-to-server masking`_: the client masks outgoing frames; the server unmasks incoming frames; - `closing the TCP connection`_: the server closes the connection immediately; the client waits for the server to do it. .. _client-to-server masking: https://datatracker.ietf.org/doc/html/rfc6455.html#section-5.3 .. _closing the TCP connection: https://datatracker.ietf.org/doc/html/rfc6455.html#section-5.5.1 These differences are so minor that all the logic for `data framing`_, for `sending and receiving data`_ and for `closing the connection`_ is implemented in the same class, :class:`~protocol.WebSocketCommonProtocol`. .. _data framing: https://datatracker.ietf.org/doc/html/rfc6455.html#section-5 .. _sending and receiving data: https://datatracker.ietf.org/doc/html/rfc6455.html#section-6 .. _closing the connection: https://datatracker.ietf.org/doc/html/rfc6455.html#section-7 The :attr:`~protocol.WebSocketCommonProtocol.is_client` attribute tells which side a protocol instance is managing. This attribute is defined on the :attr:`~server.WebSocketServerProtocol` and :attr:`~client.WebSocketClientProtocol` classes. Data flow ......... The following diagram shows how data flows between an application built on top of websockets and a remote endpoint. It applies regardless of which side is the server or the client. .. image:: protocol.svg :target: _images/protocol.svg Public methods are shown in green, private methods in yellow, and buffers in orange. 
Methods related to connection termination are omitted; connection termination is discussed in another section below. Receiving data .............. The left side of the diagram shows how websockets receives data. Incoming data is written to a :class:`~asyncio.StreamReader` in order to implement flow control and provide backpressure on the TCP connection. :attr:`~protocol.WebSocketCommonProtocol.transfer_data_task`, which is started when the WebSocket connection is established, processes this data. When it receives data frames, it reassembles fragments and puts the resulting messages in the :attr:`~protocol.WebSocketCommonProtocol.messages` queue. When it encounters a control frame: - if it's a close frame, it starts the closing handshake; - if it's a ping frame, it answers with a pong frame; - if it's a pong frame, it acknowledges the corresponding ping (unless it's an unsolicited pong). Running this process in a task guarantees that control frames are processed promptly. Without such a task, websockets would depend on the application to drive the connection by having exactly one coroutine awaiting :meth:`~protocol.WebSocketCommonProtocol.recv` at any time. While this happens naturally in many use cases, it cannot be relied upon. Then :meth:`~protocol.WebSocketCommonProtocol.recv` fetches the next message from the :attr:`~protocol.WebSocketCommonProtocol.messages` queue, with some complexity added for handling backpressure and termination correctly. Sending data ............ The right side of the diagram shows how websockets sends data. :meth:`~protocol.WebSocketCommonProtocol.send` writes one or several data frames containing the message. While sending a fragmented message, concurrent calls to :meth:`~protocol.WebSocketCommonProtocol.send` are put on hold until all fragments are sent. This makes concurrent calls safe. 
:meth:`~protocol.WebSocketCommonProtocol.ping` writes a ping frame and yields a :class:`~asyncio.Future` which will be completed when a matching pong frame is received. :meth:`~protocol.WebSocketCommonProtocol.pong` writes a pong frame. :meth:`~protocol.WebSocketCommonProtocol.close` writes a close frame and waits for the TCP connection to terminate. Outgoing data is written to a :class:`~asyncio.StreamWriter` in order to implement flow control and provide backpressure from the TCP connection. .. _closing-handshake: Closing handshake ................. When the other side of the connection initiates the closing handshake, :meth:`~protocol.WebSocketCommonProtocol.read_message` receives a close frame while in the ``OPEN`` state. It moves to the ``CLOSING`` state, sends a close frame, and returns :obj:`None`, causing :attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` to terminate. When this side of the connection initiates the closing handshake with :meth:`~protocol.WebSocketCommonProtocol.close`, it moves to the ``CLOSING`` state and sends a close frame. When the other side sends a close frame, :meth:`~protocol.WebSocketCommonProtocol.read_message` receives it in the ``CLOSING`` state and returns :obj:`None`, also causing :attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` to terminate. If the other side doesn't send a close frame within the connection's close timeout, websockets :ref:`fails the connection `. The closing handshake can take up to ``2 * close_timeout``: one ``close_timeout`` to write a close frame and one ``close_timeout`` to receive a close frame. Then websockets terminates the TCP connection. .. _connection-termination: Connection termination ---------------------- :attr:`~protocol.WebSocketCommonProtocol.close_connection_task`, which is started when the WebSocket connection is established, is responsible for eventually closing the TCP connection. 
First :attr:`~protocol.WebSocketCommonProtocol.close_connection_task` waits for :attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` to terminate, which may happen as a result of: - a successful closing handshake: as explained above, this exits the infinite loop in :attr:`~protocol.WebSocketCommonProtocol.transfer_data_task`; - a timeout while waiting for the closing handshake to complete: this cancels :attr:`~protocol.WebSocketCommonProtocol.transfer_data_task`; - a protocol error, including connection errors: depending on the exception, :attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` :ref:`fails the connection ` with a suitable code and exits. :attr:`~protocol.WebSocketCommonProtocol.close_connection_task` is separate from :attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` to make it easier to implement the timeout on the closing handshake. Canceling :attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` creates no risk of canceling :attr:`~protocol.WebSocketCommonProtocol.close_connection_task` and failing to close the TCP connection, thus leaking resources. Then :attr:`~protocol.WebSocketCommonProtocol.close_connection_task` cancels :meth:`~protocol.WebSocketCommonProtocol.keepalive_ping`. This task has no protocol compliance responsibilities. Terminating it to avoid leaking it is the only concern. Terminating the TCP connection can take up to ``2 * close_timeout`` on the server side and ``3 * close_timeout`` on the client side. Clients start by waiting for the server to close the connection, hence the extra ``close_timeout``. Then both sides go through the following steps until the TCP connection is lost: half-closing the connection (only for non-TLS connections), closing the connection, aborting the connection. At this point the connection drops regardless of what happens on the network. .. 
_connection-failure:

Connection failure
------------------

If the opening handshake doesn't complete successfully, websockets fails the
connection by closing the TCP connection.

Once the opening handshake has completed, websockets fails the connection by
canceling :attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` and
sending a close frame if appropriate.

:attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` exits,
unblocking :attr:`~protocol.WebSocketCommonProtocol.close_connection_task`,
which closes the TCP connection.

.. _server-shutdown:

Server shutdown
---------------

:class:`~server.WebSocketServer` closes asynchronously like
:class:`asyncio.Server`. The shutdown happens in two steps:

1. Stop listening and accepting new connections;
2. Close established connections with close code 1001 (going away) or, if
   the opening handshake is still in progress, with HTTP status code 503
   (Service Unavailable).

The first call to :class:`~server.WebSocketServer.close` starts a task that
performs this sequence. Further calls are ignored. This is the easiest way
to make :class:`~server.WebSocketServer.close` and
:class:`~server.WebSocketServer.wait_closed` idempotent.

.. _cancellation:

Cancellation
------------

User code
.........

websockets provides a WebSocket application server. It manages connections
and passes them to user-provided connection handlers. This is an *inversion
of control* scenario: library code calls user code.

If a connection drops, the corresponding handler should terminate. If the
server shuts down, all connection handlers must terminate. Canceling
connection handlers would terminate them.

However, using cancellation for this purpose would require all connection
handlers to handle it properly. For example, if a connection handler starts
some tasks, it should catch :exc:`~asyncio.CancelledError`, terminate or
cancel these tasks, and then re-raise the exception.
Cancellation is tricky in :mod:`asyncio` applications, especially when it interacts with finalization logic. In the example above, what if a handler gets interrupted with :exc:`~asyncio.CancelledError` while it's finalizing the tasks it started, after detecting that the connection dropped? websockets considers that cancellation may only be triggered by the caller of a coroutine when it doesn't care about the results of that coroutine anymore. (Source: `Guido van Rossum `_). Since connection handlers run arbitrary user code, websockets has no way of deciding whether that code is still doing something worth caring about. For these reasons, websockets never cancels connection handlers. Instead it expects them to detect when the connection is closed, execute finalization logic if needed, and exit. Conversely, cancellation isn't a concern for WebSocket clients because they don't involve inversion of control. Library ....... Most :doc:`public APIs <../reference/index>` of websockets are coroutines. They may be canceled, for example if the user starts a task that calls these coroutines and cancels the task later. websockets must handle this situation. Cancellation during the opening handshake is handled like any other exception: the TCP connection is closed and the exception is re-raised. This can only happen on the client side. On the server side, the opening handshake is managed by websockets and nothing results in a cancellation. Once the WebSocket connection is established, internal tasks :attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` and :attr:`~protocol.WebSocketCommonProtocol.close_connection_task` mustn't get accidentally canceled if a coroutine that awaits them is canceled. In other words, they must be shielded from cancellation. :meth:`~protocol.WebSocketCommonProtocol.recv` waits for the next message in the queue or for :attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` to terminate, whichever comes first. 
It relies on :func:`~asyncio.wait` for waiting on two futures in parallel. As a consequence, even though it's waiting on a :class:`~asyncio.Future` signaling the next message and on :attr:`~protocol.WebSocketCommonProtocol.transfer_data_task`, it doesn't propagate cancellation to them. :meth:`~protocol.WebSocketCommonProtocol.ensure_open` is called by :meth:`~protocol.WebSocketCommonProtocol.send`, :meth:`~protocol.WebSocketCommonProtocol.ping`, and :meth:`~protocol.WebSocketCommonProtocol.pong`. When the connection state is ``CLOSING``, it waits for :attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` but shields it to prevent cancellation. :meth:`~protocol.WebSocketCommonProtocol.close` waits for the data transfer task to terminate with :func:`~asyncio.timeout`. If it's canceled or if the timeout elapses, :attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` is canceled, which is correct at this point. :meth:`~protocol.WebSocketCommonProtocol.close` then waits for :attr:`~protocol.WebSocketCommonProtocol.close_connection_task` but shields it to prevent cancellation. :meth:`~protocol.WebSocketCommonProtocol.close` and :meth:`~protocol.WebSocketCommonProtocol.fail_connection` are the only places where :attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` may be canceled. :attr:`~protocol.WebSocketCommonProtocol.close_connection_task` starts by waiting for :attr:`~protocol.WebSocketCommonProtocol.transfer_data_task`. It catches :exc:`~asyncio.CancelledError` to prevent a cancellation of :attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` from propagating to :attr:`~protocol.WebSocketCommonProtocol.close_connection_task`. .. _backpressure: Backpressure ------------ .. note:: This section discusses backpressure from the perspective of a server but the concept applies to clients symmetrically. 
With a naive implementation, if a server receives inputs faster than it can
process them, or if it generates outputs faster than it can send them, data
accumulates in buffers, eventually causing the server to run out of memory
and crash.

The solution to this problem is backpressure. Any part of the server that
receives inputs faster than it can process them and send the outputs must
propagate that information back to the previous part in the chain.

websockets is designed to make it easy to get backpressure right.

For incoming data, websockets builds upon :class:`~asyncio.StreamReader`
which propagates backpressure to its own buffer and to the TCP stream.
Frames are parsed from the input stream and added to a bounded queue. If the
queue fills up, parsing halts until the application reads a frame.

For outgoing data, websockets builds upon :class:`~asyncio.StreamWriter`
which implements flow control. If the output buffers grow too large, it
waits until they're drained. That's why all APIs that write frames are
asynchronous.

Of course, it's still possible for an application to create its own
unbounded buffers and break the backpressure. Be careful with queues.


Concurrency
-----------

Awaiting any combination of :meth:`~protocol.WebSocketCommonProtocol.recv`,
:meth:`~protocol.WebSocketCommonProtocol.send`,
:meth:`~protocol.WebSocketCommonProtocol.close`,
:meth:`~protocol.WebSocketCommonProtocol.ping`, or
:meth:`~protocol.WebSocketCommonProtocol.pong` concurrently is safe,
including multiple calls to the same method, with one exception and one
limitation.

* **Only one coroutine can receive messages at a time.** This constraint
  avoids non-deterministic behavior (and simplifies the implementation). If
  a coroutine is awaiting :meth:`~protocol.WebSocketCommonProtocol.recv`,
  awaiting it again in another coroutine raises :exc:`RuntimeError`.

* **Sending a fragmented message forces serialization.** Indeed, the
  WebSocket protocol doesn't support multiplexing messages.
If a coroutine is awaiting :meth:`~protocol.WebSocketCommonProtocol.send` to send a fragmented message, awaiting it again in another coroutine waits until the first call completes. This will be transparent in many cases. It may be a concern if the fragmented message is generated slowly by an asynchronous iterator. Receiving frames is independent from sending frames. This isolates :meth:`~protocol.WebSocketCommonProtocol.recv`, which receives frames, from the other methods, which send frames. While the connection is open, each frame is sent with a single write. Combined with the concurrency model of :mod:`asyncio`, this enforces serialization. The only other requirement is to prevent interleaving other data frames in the middle of a fragmented message. After the connection is closed, sending a frame raises :exc:`~websockets.exceptions.ConnectionClosed`, which is safe. websockets-15.0.1/docs/topics/index.rst000066400000000000000000000007071476212450300200510ustar00rootroot00000000000000Topic guides ============ These documents discuss how websockets is designed and how to make the best of its features when building applications. .. toctree:: :maxdepth: 2 authentication broadcast logging proxies routing These guides describe how to optimize the configuration of websockets applications for performance and reliability. .. toctree:: :maxdepth: 2 compression keepalive memory security performance websockets-15.0.1/docs/topics/keepalive.rst000066400000000000000000000146631476212450300207150ustar00rootroot00000000000000Keepalive and latency ===================== .. currentmodule:: websockets Long-lived connections ---------------------- Since the WebSocket protocol is intended for real-time communications over long-lived connections, it is desirable to ensure that connections don't break, and if they do, to report the problem quickly. Connections can drop as a consequence of temporary network connectivity issues, which are very common, even within data centers. 
Furthermore, WebSocket builds on top of HTTP/1.1 where connections are short-lived, even with ``Connection: keep-alive``. Typically, HTTP/1.1 infrastructure closes idle connections after 30 to 120 seconds. As a consequence, proxies may terminate WebSocket connections prematurely when no message was exchanged in 30 seconds. .. _keepalive: Keepalive in websockets ----------------------- To avoid these problems, websockets runs a keepalive and heartbeat mechanism based on WebSocket Ping_ and Pong_ frames, which are designed for this purpose. .. _Ping: https://datatracker.ietf.org/doc/html/rfc6455.html#section-5.5.2 .. _Pong: https://datatracker.ietf.org/doc/html/rfc6455.html#section-5.5.3 It sends a Ping frame every 20 seconds. It expects a Pong frame in return within 20 seconds. Else, it considers the connection broken and terminates it. This mechanism serves three purposes: 1. It creates a trickle of traffic so that the TCP connection isn't idle and network infrastructure along the path keeps it open ("keepalive"). 2. It detects if the connection drops or becomes so slow that it's unusable in practice ("heartbeat"). In that case, it terminates the connection and your application gets a :exc:`~exceptions.ConnectionClosed` exception. 3. It measures the :attr:`~asyncio.connection.Connection.latency` of the connection. The time between sending a Ping frame and receiving a matching Pong frame approximates the round-trip time. Timings are configurable with the ``ping_interval`` and ``ping_timeout`` arguments of :func:`~asyncio.client.connect` and :func:`~asyncio.server.serve`. Shorter values will detect connection drops faster but they will increase network traffic and they will be more sensitive to latency. Setting ``ping_interval`` to :obj:`None` disables the whole keepalive and heartbeat mechanism, including measurement of latency. Setting ``ping_timeout`` to :obj:`None` disables only timeouts. 
This enables keepalive, to keep idle connections open, and disables heartbeat, to support large latency spikes. .. admonition:: Why doesn't websockets rely on TCP keepalive? :class: hint TCP keepalive is disabled by default on most operating systems. When enabled, the default interval is two hours or more, which is far too much. Keepalive in browsers --------------------- Browsers don't enable a keepalive mechanism like websockets by default. As a consequence, they can fail to notice that a WebSocket connection is broken for an extended period of time, until the TCP connection times out. In this scenario, the ``WebSocket`` object in the browser doesn't fire a ``close`` event. If you have a reconnection mechanism, it doesn't kick in because it believes that the connection is still working. If your browser-based app mysteriously and randomly fails to receive events, this is a likely cause. You need a keepalive mechanism in the browser to avoid this scenario. Unfortunately, the WebSocket API in browsers doesn't expose the native Ping and Pong functionality in the WebSocket protocol. You have to roll your own in the application layer. Read this `blog post `_ for a complete walk-through of this issue. Application-level keepalive --------------------------- Some servers require clients to send a keepalive message with a specific content at regular intervals. Usually they expect Text_ frames rather than Ping_ frames, meaning that you must send them with :attr:`~asyncio.connection.Connection.send` rather than :attr:`~asyncio.connection.Connection.ping`. .. _Text: https://datatracker.ietf.org/doc/html/rfc6455.html#section-5.6 In websockets, such keepalive mechanisms are considered as application-level because they rely on data frames. That's unlike the protocol-level keepalive based on control frames. Therefore, it's your responsibility to implement the required behavior. You can run a task in the background to send keepalive messages: .. 
code-block:: python

    import asyncio
    import itertools
    import json

    from websockets.exceptions import ConnectionClosed

    async def keepalive(websocket, ping_interval=30):
        for ping in itertools.count():
            await asyncio.sleep(ping_interval)
            try:
                await websocket.send(json.dumps({"ping": ping}))
            except ConnectionClosed:
                break

    async def main():
        async with connect(...) as websocket:
            keepalive_task = asyncio.create_task(keepalive(websocket))
            try:
                ...  # your application logic goes here
            finally:
                keepalive_task.cancel()

Latency issues
--------------

The :attr:`~asyncio.connection.Connection.latency` attribute stores latency
measured during the last exchange of Ping and Pong frames::

    latency = websocket.latency

Alternatively, you can measure the latency at any time by calling
:attr:`~asyncio.connection.Connection.ping` and awaiting its result::

    pong_waiter = await websocket.ping()
    latency = await pong_waiter

Latency between a client and a server may increase for two reasons:

* Network connectivity is poor. When network packets are lost, TCP attempts
  to retransmit them, which manifests as latency. Excessive packet loss
  makes the connection unusable in practice. At some point, timing out is a
  reasonable choice.

* Traffic is high. For example, if a client sends messages on the connection
  faster than a server can process them, this manifests as latency as well,
  because data is waiting in :doc:`buffers <memory>`.

  If the server is more than 20 seconds behind, it doesn't see the Pong
  before the default timeout elapses. As a consequence, it closes the
  connection. This is a reasonable choice to prevent overload.

  If traffic spikes cause unwanted timeouts and you're confident that the
  server will catch up eventually, you can increase ``ping_timeout`` or you
  can set it to :obj:`None` to disable heartbeat entirely.

  The same reasoning applies to situations where the server sends more
  traffic than the client can accept.
websockets-15.0.1/docs/topics/lifecycle.graffle000066400000000000000000000060761476212450300215040ustar00rootroot00000000000000][W8~~}G%a.0KCNJiiߒs5ncUI咾}틌b/ ~Y:Zd ~Y}uZ?vj#ߋjG#_Z;` VkpU[=OћVRCEp$2[MU(f{AzNƅl;EW2\uB$r Mњ]XlϱtɎ"[]'<&HDx}7ZS5B'OL9ɑlfYOD'as'h,[-۹PE.7:9٧d" izM&_L, *mߎ2K,;L$ Y'eF9xYYj.yٗX`Gyq@sE̮c1 Km;bG70Se|Ef/YkaAP WeJ;#ہ,M"=0KV0c?mIE<:o(~cp*0*󷯆2=@U*xȷ:W¡9Cm~]WϗݫQif: 2%NwM <^ k`2}0H~+F'/2Ts\kS*WO7<&&{vs(yF[xJ%&Eq­ՑjXX:% ƅ!f`[#-D-$y]݌&)Lr;wfkM[E톗yMem}pVVݎafu1M͗3a=۹*m?NҜk{Q]<79/POZrͶlf`=LkM#[Y&%8 *ܤsX)a0Ma0 62BJKb@ii"!8u;Hpy2:ǝwݥnz=(PE֫wbQlC)&$gnn1*AЖ{sF,؊٤ O7Xb Ģ&C 04ʽۻ0 dtb8 BpX h&Q~M7K%*]%Sd/Jg8& m8g#K>+w( d}+wݟQ`{t+8)տz\~odGn/8_Iz:Bӓ0RC$$ $N$Pc(׋;G78OdзUC++1zPv!/Tx}[ :GQrH Og=uiV})3tF Q %«c/ugOOt>^ǂ.͏awLęRD#FI@JdfE>sXt( C0"V)8aA+hkL˄&LsMc\bb!LxAnDz`D|W'OH"Neځ#,np֒ӗQopvCӹw'Øb#bppO 2 YcYF:j` :AV nzD܂S8XBqk6mpqd0Oʦ^Ls*7h=0i>'ak(қkZE;rKW>àG֌$ʑ 2lQ]hnRaLMb*_|Ee(*d2%4csvz)H:_ӊR44ҁ/=bb u`]jʔ] n\]] rWWӸ׸b:Rfg E:f&5-Hݯ7 Ͳ[4z Ց-rOo&Pub~-9H+Pz1[57q=iVt5[I-{ {[^ۂC'M.sSU_o{OIO{#ǰtN̤p<rj4yAENieYwANw9/9z̽GNw97Vq",A *M~N"}!,Ŀa&:;8Zywy9>Z/o1] 0FjnhԗԣDŽ)q9T N_G=:<.TKW<7?ZSܓZ-VYWy0ewQne*{Y۩la Q /n'v2> &.{g}KŤ#!@ƕ[#܍ERaq7dDEmߺ z,ɭShJ~վZi#8.j m8N7v sk/$%u yNnXUj)mGvK5KRz?s7xAU7L*jOo5{ AnjVl) xODR)L.Ñ}l "T6]u9UT?GŻO,FZG:fo{cm';~&r':$`C&/2M[PM ͋qs $VY}?"YI {a^W;Tޭ-h#-^ئo=oG }w=z?mHe.a 4~/,'Q2*7 1mT5ON/ZLa4o7aj5S[ UDH?QAS7=EEVQ XwH= mIxG3wj&(*!ԩ[n%Yv;c5ox /If5^]Ŧ|{ET8&Lߝ>MT2; )*})k4kjzly Produced by OmniGraffle 6.6.2 2018-07-29 15:25:34 +0000Canvas 1Layer 1CONNECTINGOPENCLOSINGCLOSEDtransfer_dataclose_connectionconnectrecv / send / ping / pong / close opening handshakeconnectionterminationdata transfer& closing handshakekeepalive_ping websockets-15.0.1/docs/topics/logging.rst000066400000000000000000000207221476212450300203670ustar00rootroot00000000000000Logging ======= .. currentmodule:: websockets Logs contents ------------- When you run a WebSocket client, your code calls coroutines provided by websockets. If an error occurs, websockets tells you by raising an exception. 
For example, it raises a :exc:`~exceptions.ConnectionClosed` exception if the other side closes the connection. When you run a WebSocket server, websockets accepts connections, performs the opening handshake, runs the connection handler coroutine that you provided, and performs the closing handshake. Given this `inversion of control`_, if an error happens in the opening handshake or if the connection handler crashes, there is no way to raise an exception that you can handle. .. _inversion of control: https://en.wikipedia.org/wiki/Inversion_of_control Logs tell you about these errors. Besides errors, you may want to record the activity of the server. In a request/response protocol such as HTTP, there's an obvious way to record activity: log one event per request/response. Unfortunately, this solution doesn't work well for a bidirectional protocol such as WebSocket. Instead, when running as a server, websockets logs one event when a `connection is established`_ and another event when a `connection is closed`_. .. _connection is established: https://datatracker.ietf.org/doc/html/rfc6455.html#section-4 .. _connection is closed: https://datatracker.ietf.org/doc/html/rfc6455.html#section-7.1.4 By default, websockets doesn't log an event for every message. That would be excessive for many applications exchanging small messages at a fast rate. If you need this level of detail, you could add logging in your own code. Finally, you can enable debug logs to get details about everything websockets is doing. This can be useful when developing clients as well as servers. See :ref:`log levels ` below for a list of events logged by websockets logs at each log level. Configure logging ----------------- websockets relies on the :mod:`logging` module from the standard library in order to maximize compatibility and integrate nicely with other libraries:: import logging websockets logs to the ``"websockets.client"`` and ``"websockets.server"`` loggers. 
websockets doesn't provide a default logging configuration because requirements vary a lot depending on the environment. Here's a basic configuration for a server in production:: logging.basicConfig( format="%(asctime)s %(message)s", level=logging.INFO, ) Here's how to enable debug logs for development:: logging.basicConfig( format="%(asctime)s %(message)s", level=logging.DEBUG, ) By default, websockets elides the content of messages to improve readability. If you want to see more, you can increase the :envvar:`WEBSOCKETS_MAX_LOG_SIZE` environment variable. The default value is 75. Furthermore, websockets adds a ``websocket`` attribute to log records, so you can include additional information about the current connection in logs. You could attempt to add information with a formatter:: # this doesn't work! logging.basicConfig( format="{asctime} {websocket.id} {websocket.remote_address[0]} {message}", level=logging.INFO, style="{", ) However, this technique runs into two problems: * The formatter applies to all records. It will crash if it receives a record without a ``websocket`` attribute. For example, this happens when logging that the server starts because there is no current connection. * Even with :meth:`str.format` style, you're restricted to attribute and index lookups, which isn't enough to implement some fairly simple requirements. There's a better way. :func:`~asyncio.client.connect` and :func:`~asyncio.server.serve` accept a ``logger`` argument to override the default :class:`~logging.Logger`. You can set ``logger`` to a :class:`~logging.LoggerAdapter` that enriches logs. For example, if the server is behind a reverse proxy, :attr:`~legacy.protocol.WebSocketCommonProtocol.remote_address` gives the IP address of the proxy, which isn't useful. IP addresses of clients are provided in an HTTP header set by the proxy. 
Here's how to include them in logs, assuming they're in the ``X-Forwarded-For`` header:: logging.basicConfig( format="%(asctime)s %(message)s", level=logging.INFO, ) class LoggerAdapter(logging.LoggerAdapter): """Add connection ID and client IP address to websockets logs.""" def process(self, msg, kwargs): try: websocket = kwargs["extra"]["websocket"] except KeyError: # log entry not coming from a connection return msg, kwargs if websocket.request is None: # opening handshake not complete return msg, kwargs xff = websocket.request.headers.get("X-Forwarded-For") return f"{websocket.id} {xff} {msg}", kwargs async with serve( ..., # Python < 3.10 requires passing None as the second argument. logger=LoggerAdapter(logging.getLogger("websockets.server"), None), ): ... Logging to JSON --------------- Even though :mod:`logging` predates structured logging, it's still possible to output logs as JSON with a bit of effort. First, we need a :class:`~logging.Formatter` that renders JSON: .. literalinclude:: ../../experiments/json_log_formatter.py Then, we configure logging to apply this formatter:: handler = logging.StreamHandler() handler.setFormatter(formatter) logger = logging.getLogger() logger.addHandler(handler) logger.setLevel(logging.INFO) Finally, we populate the ``event_data`` custom attribute in log records with a :class:`~logging.LoggerAdapter`:: class LoggerAdapter(logging.LoggerAdapter): """Add connection ID and client IP address to websockets logs.""" def process(self, msg, kwargs): try: websocket = kwargs["extra"]["websocket"] except KeyError: return msg, kwargs event_data = {"connection_id": str(websocket.id)} if websocket.request is not None: # opening handshake complete headers = websocket.request.headers event_data["remote_addr"] = headers.get("X-Forwarded-For") kwargs["extra"]["event_data"] = event_data return msg, kwargs async with serve( ..., # Python < 3.10 requires passing None as the second argument. 
logger=LoggerAdapter(logging.getLogger("websockets.server"), None), ): ... Disable logging --------------- If your application doesn't configure :mod:`logging`, Python outputs messages of severity ``WARNING`` and higher to :data:`~sys.stderr`. As a consequence, you will see a message and a stack trace if a connection handler coroutine crashes or if you hit a bug in websockets. If you want to disable this behavior for websockets, you can add a :class:`~logging.NullHandler`:: logging.getLogger("websockets").addHandler(logging.NullHandler()) Additionally, if your application configures :mod:`logging`, you must disable propagation to the root logger, or else its handlers could output logs:: logging.getLogger("websockets").propagate = False Alternatively, you could set the log level to ``CRITICAL`` for the ``"websockets"`` logger, as the highest level currently used is ``ERROR``:: logging.getLogger("websockets").setLevel(logging.CRITICAL) Or you could configure a filter to drop all messages:: logging.getLogger("websockets").addFilter(lambda record: None) .. _log-levels: Log levels ---------- Here's what websockets logs at each level. ``ERROR`` ......... * Exceptions raised by your code in servers * connection handler coroutines * ``select_subprotocol`` callbacks * ``process_request`` and ``process_response`` callbacks * Exceptions resulting from bugs in websockets ``WARNING`` ........... * Failures in :func:`~asyncio.server.broadcast` ``INFO`` ........ * Server starting and stopping * Server establishing and closing connections * Client reconnecting automatically ``DEBUG`` ......... 
* Changes to the state of connections * Handshake requests and responses * All frames sent and received * Steps to close a connection * Keepalive pings and pongs * Errors handled transparently Debug messages have cute prefixes that make logs easier to scan: * ``>`` - send something * ``<`` - receive something * ``=`` - set connection state * ``x`` - shut down connection * ``%`` - manage pings and pongs * ``-`` - timeout * ``!`` - error, with a traceback websockets-15.0.1/docs/topics/memory.rst000066400000000000000000000152161476212450300202530ustar00rootroot00000000000000Memory and buffers ================== .. currentmodule:: websockets In most cases, memory usage of a WebSocket server is proportional to the number of open connections. When a server handles thousands of connections, memory usage can become a bottleneck. Memory usage of a single connection is the sum of: 1. the baseline amount of memory that websockets uses for each connection; 2. the amount of memory needed by your application code; 3. the amount of data held in buffers. Connection ---------- Compression settings are the primary factor affecting how much memory each connection uses. The :mod:`asyncio` implementation with default settings uses 64 KiB of memory for each connection. You can reduce memory usage to 14 KiB per connection if you disable compression entirely. Refer to the :doc:`topic guide on compression <../topics/compression>` to learn more about tuning compression settings. Application ----------- Your application will allocate memory for its data structures. Memory usage depends on your use case and your implementation. Make sure that you don't keep references to data that you don't need anymore because this prevents garbage collection. Buffers ------- Typical WebSocket applications exchange small messages at a rate that doesn't saturate the CPU or the network. Buffers are almost always empty. This is the optimal situation. 
Buffers absorb bursts of incoming or outgoing messages without having to pause reading or writing. If the application receives messages faster than it can process them, receive buffers will fill up. If the application sends messages faster than the network can transmit them, send buffers will fill up. When buffers are almost always full, not only does the additional memory usage fail to bring any benefit, but latency degrades as well. This problem is called bufferbloat_. If it cannot be resolved by adding capacity, typically because the system is bottlenecked by its output and constantly regulated by :ref:`backpressure <backpressure>`, then buffers should be kept small to ensure that backpressure kicks in quickly. .. _bufferbloat: https://en.wikipedia.org/wiki/Bufferbloat To sum up, buffers should be sized to absorb bursts of messages. Making them larger than necessary often causes more harm than good. There are three levels of buffering in an application built with websockets. TCP buffers ........... The operating system allocates buffers for each TCP connection. The receive buffer stores data received from the network until the application reads it. The send buffer stores data written by the application until it's sent to the network and acknowledged by the recipient. Modern operating systems adjust the size of TCP buffers automatically to match network conditions. Overall, you shouldn't worry about TCP buffers. Just be aware that they exist. In very high throughput scenarios, TCP buffers may grow to several megabytes to store the data in flight. Then, they can make up the bulk of the memory usage of a connection. I/O library buffers ................... I/O libraries like :mod:`asyncio` may provide read and write buffers to reduce the frequency of system calls or the need to pause reading or writing. You should keep these buffers small. Increasing them can help with spiky workloads but it can also backfire because it delays backpressure. 
* In the new :mod:`asyncio` implementation, there is no library-level read buffer. There is a write buffer. The ``write_limit`` argument of :func:`~asyncio.client.connect` and :func:`~asyncio.server.serve` controls its size. When the write buffer grows above the high-water mark, :meth:`~asyncio.connection.Connection.send` waits until it drains under the low-water mark to return. This creates backpressure on coroutines that send messages. * In the legacy :mod:`asyncio` implementation, there is a library-level read buffer. The ``read_limit`` argument of :func:`~legacy.client.connect` and :func:`~legacy.server.serve` controls its size. When the read buffer grows above the high-water mark, the connection stops reading from the network until it drains under the low-water mark. This creates backpressure on the TCP connection. There is a write buffer. It is controlled by ``write_limit``. It behaves like the new :mod:`asyncio` implementation described above. * In the :mod:`threading` implementation, there are no library-level buffers. All I/O operations are performed directly on the :class:`~socket.socket`. websockets' buffers ................... Incoming messages are queued in a buffer after they have been received from the network and parsed. A larger buffer may help a slow application handle bursts of messages while remaining responsive to control frames. The memory footprint of this buffer is bounded by the product of ``max_size``, which controls the size of items in the queue, and ``max_queue``, which controls the number of items. The ``max_size`` argument of :func:`~asyncio.client.connect` and :func:`~asyncio.server.serve` defaults to 1 MiB. Most applications never receive such large messages. Configuring a smaller value puts a tighter boundary on memory usage. This can make your application more resilient to denial of service attacks. 
The behavior of the ``max_queue`` argument of :func:`~asyncio.client.connect` and :func:`~asyncio.server.serve` varies across implementations. * In the new :mod:`asyncio` implementation, ``max_queue`` is the high-water mark of a queue of incoming frames. It defaults to 16 frames. If the queue grows larger, the connection stops reading from the network until the application consumes messages and the queue goes below the low-water mark. This creates backpressure on the TCP connection. Each item in the queue is a frame. A frame can be a message or a message fragment. Either way, it must be smaller than ``max_size``, the maximum size of a message. The queue may use up to ``max_size * max_queue`` bytes of memory. By default, this is 16 MiB. * In the legacy :mod:`asyncio` implementation, ``max_queue`` is the maximum size of a queue of incoming messages. It defaults to 32 messages. If the queue fills up, the connection stops reading from the library-level read buffer described above. If that buffer fills up as well, it will create backpressure on the TCP connection. Text messages are decoded before they're added to the queue. Since Python can use up to 4 bytes of memory per character, the queue may use up to ``4 * max_size * max_queue`` bytes of memory. By default, this is 128 MiB. * In the :mod:`threading` implementation, there is no queue of incoming messages. The ``max_queue`` argument doesn't exist. The connection keeps at most one message in memory at a time. websockets-15.0.1/docs/topics/performance.rst000066400000000000000000000006721476212450300212440ustar00rootroot00000000000000Performance =========== .. currentmodule:: websockets Here are tips to optimize performance. uvloop ------ You can make a websockets application faster by running it with uvloop_. (This advice isn't specific to websockets. It applies to any :mod:`asyncio` application.) .. 
_uvloop: https://github.com/MagicStack/uvloop broadcast --------- :func:`~asyncio.server.broadcast` is the most efficient way to send a message to many clients. websockets-15.0.1/docs/topics/protocol.graffle000066400000000000000000000112041476212450300213730ustar00rootroot00000000000000]ks6 [Ӹqtnxc'v<ӡ%X"UfJx-JۉlA\q_FQN0|=~2Ꮟ#'ѯ/z(g_Ǒ}~wQ}qY=??C!vd5t A>x ^|;G=/~\U<_obzӧ寗Ma?ȡӞ獂4 /=)KFq8L{ + NO#-HJX(P>Dd&y7$Ӊ޽,2pY؟͞}Ff(ka|ob=2`b? f$ w5y|A9}bN zJ'\*dp0y7HwA @E~5$iHa75Vm猯J|Wrtw_.CR$U)Qʐ|I)Xq}J ǂbd\} 5#JH+S /oN+4f5; yyv|i\ڑK Ҵ2[4ZϢigJoW<AB|DY!CDGAن1g! ϯΚ@\g* *Ʃ1+*1=)/12#$UT(.a()aQ&>RJoߚ¾gX6$ uFrԙ\,,\;N(&޸Ug5>+%o劀/2T}c3  +hqAi~8 !&~K=z| +7?O(4>>isԫ|~%)4qaul~9&> /?81&Ň U c M[ 7L AdYӠZ zxqN`o~"+7tJK"|`Rt:t"v{p@4?KIDςxlF`SK/]66œ'9o) pyԽ|\~fAdǕd*K|fՍB_#5 n MPkZmNᵲa .f̸=*v,٤?a!BVc_?gYky"2PJ&4f/<DǏnU tS&5fQS{aH FX>v=fĩv6-p+5L8·*j]%l- )@Gmȅeqΰ5|ykǍ8Y;BM"&S˜֎;F\M +#.i Xu:)ڛggnqS淍``mb)}čJ6m$W0H" 1_P, W!{ }mlj{ɖ@1,i1@n;\|!_Ta!trH;ϒh=TŤNhl cWV֋O|myqUY$7 ($'sՄ鏘}V_CKKlKwӎ1V_h5[DfwwCHbE'&ZfѢ(e[և\4 wrX'VA1@vmpF3 ֐I'K:bAp?%B.xxbhXI¤s):rb'XW, {>%DA%gJ:L2A<`LʀB|ݐ?SC(e0 r OTdA-KzIM@w&u64+u,c1]XpJ|̨X,n3$c{+ZY#šԝ-4hl uOϖ.+yɱ}t'יuNui}x\殺 |A}_ N);V.Q"38ج [1YpmKҟW߅7\_n] E\tEpt-][K!7{{FT38fG:Ew`|G+0[[GZ:sK;|P 0_V=/=LFjJjME=VqYر x;͗u4ejFZ3mf:1tS{HgY0MB6,||Rh9Lu,A\fQ`N}*AA|y ZH*mB7#杫*IR0u8i]u]ǹsm:LXv=aBu,&7]CYEݺD#SX_xh`[UzOXY\_$֊{\H ORM 偙!1CmtmRpv5 :oc= L:riy$P0<HR%WBZ6='8J"d, RИ3./X0׿'smm[N ,u»ݩypC' wǻ; awut=Y9s@瀮#Н\䷈s}܄m`G\P&tn‡/1>KޠЊͤVm8v;u8p47Am=iHK:uH&gNaʠ,ԫ"*PJh xAׇjٕ%Ʒd[5y8_#f #jN+8vi s.0W1\a`>Gpjbk|mk8xzJy~cE\wJ$#4)@Y;׽!*GŴ IXn曳7߶ 7L|SHh+vHh2,ۥp Զn"ugQ,|eJտjRwr?-aXT`HܽJY-OXA8ɚ 3u胪Llcɒ'0 G}rX|FI|&9atI5\]T8KΨp=Ns*}BT\t<l 8<ʠ31dՖ],Ijm>֐|ܓ̴Z۳q0Ύ/jtڈyC+?YxX9CP;BP۴Az|(LB;5h|odJ&YmzI~ZuXYa|n86K|/GA +2WaO~ aء|ւPV]VZA|ø9A7 Oވ:69'nX_ay2;W {6Ѭ@R`g\K2OgFқ vYnfq.]t{{$ø s'z9vӹsM5}:nn7ϒ Produced by OmniGraffle 6.6.2 2019-07-07 08:38:24 +0000Canvas 1Layer 1remote endpointwebsocketsWebSocketCommonProtocolapplication 
logicreaderStreamReaderwriterStreamWriterpingsdicttransfer_data_taskTasknetworkread_frameread_data_frameread_messagebytesframesdataframeswrite_framemessagesdequerecvsendpingpongclosecontrolframesbytesframes websockets-15.0.1/docs/topics/proxies.rst000066400000000000000000000064071476212450300204360ustar00rootroot00000000000000Proxies ======= .. currentmodule:: websockets If a proxy is configured in the operating system or with an environment variable, websockets uses it automatically when connecting to a server. Configuration ------------- First, if the server is in the proxy bypass list of the operating system or in the ``no_proxy`` environment variable, websockets connects directly. Then, it looks for a proxy in the following locations: 1. The ``wss_proxy`` or ``ws_proxy`` environment variables for ``wss://`` and ``ws://`` connections respectively. They allow configuring a specific proxy for WebSocket connections. 2. A SOCKS proxy configured in the operating system. 3. An HTTP proxy configured in the operating system or in the ``https_proxy`` environment variable, for both ``wss://`` and ``ws://`` connections. 4. An HTTP proxy configured in the operating system or in the ``http_proxy`` environment variable, only for ``ws://`` connections. Finally, if no proxy is found, websockets connects directly. While environment variables are case-insensitive, the lower-case spelling is the most common, for `historical reasons`_, and recommended. .. _historical reasons: https://unix.stackexchange.com/questions/212894/ websockets authenticates automatically when the address of the proxy includes credentials e.g. ``http://user:password@proxy:8080/``. .. admonition:: Any environment variable can configure a SOCKS proxy or an HTTP proxy. :class: tip For example, ``https_proxy=socks5h://proxy:1080/`` configures a SOCKS proxy for all WebSocket connections. Likewise, ``wss_proxy=http://proxy:8080/`` configures an HTTP proxy only for ``wss://`` connections. .. 
admonition:: What if websockets doesn't select the right proxy? :class: hint websockets relies on :func:`~urllib.request.getproxies()` to read the proxy configuration. Check that it returns what you expect. If it doesn't, review your proxy configuration. You can override the default configuration and configure a proxy explicitly with the ``proxy`` argument of :func:`~asyncio.client.connect`. Set ``proxy=None`` to disable the proxy. SOCKS proxies ------------- Connecting through a SOCKS proxy requires installing the third-party library `python-socks`_: .. code-block:: console $ pip install python-socks\[asyncio\] .. _python-socks: https://github.com/romis2012/python-socks python-socks supports SOCKS4, SOCKS4a, SOCKS5, and SOCKS5h. The protocol version is configured in the address of the proxy e.g. ``socks5h://proxy:1080/``. When a SOCKS proxy is configured in the operating system, python-socks uses SOCKS5h. python-socks supports username/password authentication for SOCKS5 (:rfc:`1929`) but does not support other authentication methods such as GSSAPI (:rfc:`1961`). HTTP proxies ------------ When the address of the proxy starts with ``https://``, websockets secures the connection to the proxy with TLS. When the address of the server starts with ``wss://``, websockets secures the connection from the proxy to the server with TLS. These two options are compatible. TLS-in-TLS is supported. The documentation of :func:`~asyncio.client.connect` describes how to configure TLS from websockets to the proxy and from the proxy to the server. websockets supports proxy authentication with Basic Auth. websockets-15.0.1/docs/topics/routing.rst000066400000000000000000000056021476212450300204300ustar00rootroot00000000000000Routing ======= .. currentmodule:: websockets Many WebSocket servers provide just one endpoint. That's why :func:`~asyncio.server.serve` accepts a single connection handler as its first argument. This may come as a surprise to you if you're used to HTTP servers. 
In a standard HTTP application, each request gets dispatched to a handler based on the request path. Clients know which path to use for which operation. In a WebSocket application, clients open a persistent connection then they send all messages over that unique connection. When different messages correspond to different operations, they must be dispatched based on the message content. Simple routing -------------- If you need different handlers for different clients or different use cases, you may route each connection to the right handler based on the request path. Since WebSocket servers typically provide fewer routes than HTTP servers, you can keep it simple:: async def handler(websocket): match websocket.request.path: case "/blue": await blue_handler(websocket) case "/green": await green_handler(websocket) case _: # No handler for this path. Close the connection. return You may also route connections based on the first message received from the client, as demonstrated in the :doc:`tutorial <../intro/tutorial2>`:: import json async def handler(websocket): message = await websocket.recv() settings = json.loads(message) match settings["color"]: case "blue": await blue_handler(websocket) case "green": await green_handler(websocket) case _: # No handler for this message. Close the connection. return When you need to authenticate the connection before routing it, this pattern is more convenient. Complex routing --------------- If you have outgrown these simple patterns, websockets provides full-fledged routing based on the request path with :func:`~asyncio.router.route`. This feature builds upon Flask_'s router. To use it, you must install the third-party library `werkzeug`_: .. code-block:: console $ pip install werkzeug .. _Flask: https://flask.palletsprojects.com/ .. _werkzeug: https://werkzeug.palletsprojects.com/ :func:`~asyncio.router.route` expects a :class:`werkzeug.routing.Map` as its first argument to declare which URL patterns map to which handlers. 
Review the documentation of :mod:`werkzeug.routing` to learn about its functionality. To give you a sense of what's possible, here's the URL map of the example in `experiments/routing.py`_: .. _experiments/routing.py: https://github.com/python-websockets/websockets/blob/main/experiments/routing.py .. literalinclude:: ../../experiments/routing.py :start-at: url_map = Map( :end-at: await server.serve_forever() websockets-15.0.1/docs/topics/security.rst000066400000000000000000000045671476212450300206210ustar00rootroot00000000000000Security ======== .. currentmodule:: websockets Encryption ---------- In production, you should always secure WebSocket connections with TLS. Secure WebSocket connections provide confidentiality and integrity, as well as better reliability because they reduce the risk of interference by bad proxies. WebSocket servers are usually deployed behind a reverse proxy that terminates TLS. Else, you can :doc:`configure TLS <../howto/encryption>` for the server. Memory usage ------------ .. warning:: An attacker who can open an arbitrary number of connections will be able to perform a denial of service by memory exhaustion. If you're concerned by denial of service attacks, you must reject suspicious connections before they reach websockets, typically in a reverse proxy. With the default settings, opening a connection uses 70 KiB of memory. Sending some highly compressed messages could use up to 128 MiB of memory with an amplification factor of 1000 between network traffic and memory usage. Configuring a server to :doc:`optimize memory usage ` will improve security in addition to improving performance. HTTP limits ----------- In the opening handshake, websockets applies limits to the amount of data that it accepts in order to minimize exposure to denial of service attacks. The request or status line is limited to 8192 bytes. Each header line, including the name and value, is limited to 8192 bytes too. No more than 128 HTTP headers are allowed. 
When the HTTP response includes a body, it is limited to 1 MiB. You may change these limits by setting the :envvar:`WEBSOCKETS_MAX_LINE_LENGTH`, :envvar:`WEBSOCKETS_MAX_NUM_HEADERS`, and :envvar:`WEBSOCKETS_MAX_BODY_SIZE` environment variables respectively. Identification -------------- By default, websockets identifies itself with a ``Server`` or ``User-Agent`` header in the format ``"Python/x.y.z websockets/X.Y"``. You can set the ``server_header`` argument of :func:`~asyncio.server.serve` or the ``user_agent_header`` argument of :func:`~asyncio.client.connect` to configure another value. Setting them to :obj:`None` removes the header. Alternatively, you can set the :envvar:`WEBSOCKETS_SERVER` and :envvar:`WEBSOCKETS_USER_AGENT` environment variables respectively. Setting them to an empty string removes the header. If both the argument and the environment variable are set, the argument takes precedence. websockets-15.0.1/example/000077500000000000000000000000001476212450300154065ustar00rootroot00000000000000websockets-15.0.1/example/asyncio/000077500000000000000000000000001476212450300170535ustar00rootroot00000000000000websockets-15.0.1/example/asyncio/client.py000066400000000000000000000007011476212450300207010ustar00rootroot00000000000000#!/usr/bin/env python """Client example using the asyncio API.""" import asyncio from websockets.asyncio.client import connect async def hello(): async with connect("ws://localhost:8765") as websocket: name = input("What's your name? 
") await websocket.send(name) print(f">>> {name}") greeting = await websocket.recv() print(f"<<< {greeting}") if __name__ == "__main__": asyncio.run(hello()) websockets-15.0.1/example/asyncio/echo.py000077500000000000000000000006121476212450300203450ustar00rootroot00000000000000#!/usr/bin/env python """Echo server using the asyncio API.""" import asyncio from websockets.asyncio.server import serve async def echo(websocket): async for message in websocket: await websocket.send(message) async def main(): async with serve(echo, "localhost", 8765) as server: await server.serve_forever() if __name__ == "__main__": asyncio.run(main()) websockets-15.0.1/example/asyncio/hello.py000077500000000000000000000005551476212450300205400ustar00rootroot00000000000000#!/usr/bin/env python """Client using the asyncio API.""" import asyncio from websockets.asyncio.client import connect async def hello(): async with connect("ws://localhost:8765") as websocket: await websocket.send("Hello world!") message = await websocket.recv() print(message) if __name__ == "__main__": asyncio.run(hello()) websockets-15.0.1/example/asyncio/server.py000066400000000000000000000007421476212450300207360ustar00rootroot00000000000000#!/usr/bin/env python """Server example using the asyncio API.""" import asyncio from websockets.asyncio.server import serve async def hello(websocket): name = await websocket.recv() print(f"<<< {name}") greeting = f"Hello {name}!" 
await websocket.send(greeting) print(f">>> {greeting}") async def main(): async with serve(hello, "localhost", 8765) as server: await server.serve_forever() if __name__ == "__main__": asyncio.run(main()) websockets-15.0.1/example/deployment/000077500000000000000000000000001476212450300175665ustar00rootroot00000000000000websockets-15.0.1/example/deployment/fly/000077500000000000000000000000001476212450300203605ustar00rootroot00000000000000websockets-15.0.1/example/deployment/fly/Procfile000066400000000000000000000000231476212450300220410ustar00rootroot00000000000000web: python app.py websockets-15.0.1/example/deployment/fly/app.py000066400000000000000000000012001476212450300215030ustar00rootroot00000000000000#!/usr/bin/env python import asyncio import http import signal from websockets.asyncio.server import serve async def echo(websocket): async for message in websocket: await websocket.send(message) def health_check(connection, request): if request.path == "/healthz": return connection.respond(http.HTTPStatus.OK, "OK\n") async def main(): async with serve(echo, "", 8080, process_request=health_check) as server: loop = asyncio.get_running_loop() loop.add_signal_handler(signal.SIGTERM, server.close) await server.wait_closed() if __name__ == "__main__": asyncio.run(main()) websockets-15.0.1/example/deployment/fly/fly.toml000066400000000000000000000004231476212450300220460ustar00rootroot00000000000000app = "websockets-echo" kill_signal = "SIGTERM" [build] builder = "paketobuildpacks/builder:base" [[services]] internal_port = 8080 protocol = "tcp" [[services.http_checks]] path = "/healthz" [[services.ports]] handlers = ["tls", "http"] port = 443 websockets-15.0.1/example/deployment/fly/requirements.txt000066400000000000000000000000131476212450300236360ustar00rootroot00000000000000websockets 
websockets-15.0.1/example/deployment/haproxy/000077500000000000000000000000001476212450300212605ustar00rootroot00000000000000websockets-15.0.1/example/deployment/haproxy/app.py000066400000000000000000000010411476212450300224060ustar00rootroot00000000000000#!/usr/bin/env python import asyncio import os import signal from websockets.asyncio.server import serve async def echo(websocket): async for message in websocket: await websocket.send(message) async def main(): port = 8000 + int(os.environ["SUPERVISOR_PROCESS_NAME"][-2:]) async with serve(echo, "localhost", port) as server: loop = asyncio.get_running_loop() loop.add_signal_handler(signal.SIGTERM, server.close) await server.wait_closed() if __name__ == "__main__": asyncio.run(main()) websockets-15.0.1/example/deployment/haproxy/haproxy.cfg000066400000000000000000000006051476212450300234340ustar00rootroot00000000000000defaults mode http timeout connect 10s timeout client 30s timeout server 30s frontend websocket bind localhost:8080 default_backend websocket backend websocket balance leastconn server websockets-test_00 localhost:8000 server websockets-test_01 localhost:8001 server websockets-test_02 localhost:8002 server websockets-test_03 localhost:8003 websockets-15.0.1/example/deployment/haproxy/supervisord.conf000066400000000000000000000002231476212450300245110ustar00rootroot00000000000000[supervisord] [program:websockets-test] command = python app.py process_name = %(program_name)s_%(process_num)02d numprocs = 4 autorestart = true websockets-15.0.1/example/deployment/heroku/000077500000000000000000000000001476212450300210635ustar00rootroot00000000000000websockets-15.0.1/example/deployment/heroku/Procfile000066400000000000000000000000231476212450300225440ustar00rootroot00000000000000web: python app.py websockets-15.0.1/example/deployment/heroku/app.py000066400000000000000000000010021476212450300222060ustar00rootroot00000000000000#!/usr/bin/env python import asyncio import signal import os from 
websockets.asyncio.server import serve async def echo(websocket): async for message in websocket: await websocket.send(message) async def main(): port = int(os.environ["PORT"]) async with serve(echo, "localhost", port) as server: loop = asyncio.get_running_loop() loop.add_signal_handler(signal.SIGTERM, server.close) await server.wait_closed() if __name__ == "__main__": asyncio.run(main()) websockets-15.0.1/example/deployment/heroku/requirements.txt000066400000000000000000000000131476212450300243410ustar00rootroot00000000000000websockets websockets-15.0.1/example/deployment/koyeb/000077500000000000000000000000001476212450300206775ustar00rootroot00000000000000websockets-15.0.1/example/deployment/koyeb/Procfile000066400000000000000000000000231476212450300223600ustar00rootroot00000000000000web: python app.py websockets-15.0.1/example/deployment/koyeb/app.py000066400000000000000000000012551476212450300220340ustar00rootroot00000000000000#!/usr/bin/env python import asyncio import http import os import signal from websockets.asyncio.server import serve async def echo(websocket): async for message in websocket: await websocket.send(message) def health_check(connection, request): if request.path == "/healthz": return connection.respond(http.HTTPStatus.OK, "OK\n") async def main(): port = int(os.environ["PORT"]) async with serve(echo, "", port, process_request=health_check) as server: loop = asyncio.get_running_loop() loop.add_signal_handler(signal.SIGTERM, server.close) await server.wait_closed() if __name__ == "__main__": asyncio.run(main()) websockets-15.0.1/example/deployment/koyeb/requirements.txt000066400000000000000000000000131476212450300241550ustar00rootroot00000000000000websockets websockets-15.0.1/example/deployment/kubernetes/000077500000000000000000000000001476212450300217355ustar00rootroot00000000000000websockets-15.0.1/example/deployment/kubernetes/Dockerfile000066400000000000000000000001341476212450300237250ustar00rootroot00000000000000FROM python:3.9-alpine 
RUN pip install websockets COPY app.py . CMD ["python", "app.py"] websockets-15.0.1/example/deployment/kubernetes/app.py000077500000000000000000000023211476212450300230700ustar00rootroot00000000000000#!/usr/bin/env python import asyncio import http import signal import sys import time from websockets.asyncio.server import serve async def slow_echo(websocket): async for message in websocket: # Block the event loop! This allows saturating a single asyncio # process without opening an impractical number of connections. time.sleep(0.1) # 100ms await websocket.send(message) def health_check(connection, request): if request.path == "/healthz": return connection.respond(http.HTTPStatus.OK, "OK\n") if request.path == "/inemuri": loop = asyncio.get_running_loop() loop.call_later(1, time.sleep, 10) return connection.respond(http.HTTPStatus.OK, "Sleeping for 10s\n") if request.path == "/seppuku": loop = asyncio.get_running_loop() loop.call_later(1, sys.exit, 69) return connection.respond(http.HTTPStatus.OK, "Terminating\n") async def main(): async with serve(slow_echo, "", 80, process_request=health_check) as server: loop = asyncio.get_running_loop() loop.add_signal_handler(signal.SIGTERM, server.close) await server.wait_closed() if __name__ == "__main__": asyncio.run(main()) websockets-15.0.1/example/deployment/kubernetes/benchmark.py000077500000000000000000000012111476212450300242370ustar00rootroot00000000000000#!/usr/bin/env python import asyncio import sys from websockets.asyncio.client import connect URI = "ws://localhost:32080" async def run(client_id, messages): async with connect(URI) as websocket: for message_id in range(messages): await websocket.send(f"{client_id}:{message_id}") await websocket.recv() async def benchmark(clients, messages): await asyncio.wait([ asyncio.create_task(run(client_id, messages)) for client_id in range(clients) ]) if __name__ == "__main__": clients, messages = int(sys.argv[1]), int(sys.argv[2]) asyncio.run(benchmark(clients, messages)) 
websockets-15.0.1/example/deployment/kubernetes/deployment.yaml000066400000000000000000000011641476212450300250030ustar00rootroot00000000000000apiVersion: v1 kind: Service metadata: name: websockets-test spec: type: NodePort ports: - port: 80 nodePort: 32080 selector: app: websockets-test --- apiVersion: apps/v1 kind: Deployment metadata: name: websockets-test spec: selector: matchLabels: app: websockets-test template: metadata: labels: app: websockets-test spec: containers: - name: websockets-test image: websockets-test:1.0 livenessProbe: httpGet: path: /healthz port: 80 periodSeconds: 1 ports: - containerPort: 80 websockets-15.0.1/example/deployment/nginx/000077500000000000000000000000001476212450300207115ustar00rootroot00000000000000websockets-15.0.1/example/deployment/nginx/app.py000066400000000000000000000010271476212450300220430ustar00rootroot00000000000000#!/usr/bin/env python import asyncio import os import signal from websockets.asyncio.server import unix_serve async def echo(websocket): async for message in websocket: await websocket.send(message) async def main(): path = f"{os.environ['SUPERVISOR_PROCESS_NAME']}.sock" async with unix_serve(echo, path) as server: loop = asyncio.get_running_loop() loop.add_signal_handler(signal.SIGTERM, server.close) await server.wait_closed() if __name__ == "__main__": asyncio.run(main()) websockets-15.0.1/example/deployment/nginx/nginx.conf000066400000000000000000000010271476212450300227030ustar00rootroot00000000000000daemon off; events { } http { server { listen localhost:8080; location / { proxy_http_version 1.1; proxy_pass http://websocket; proxy_set_header Connection $http_connection; proxy_set_header Upgrade $http_upgrade; } } upstream websocket { least_conn; server unix:websockets-test_00.sock; server unix:websockets-test_01.sock; server unix:websockets-test_02.sock; server unix:websockets-test_03.sock; } } 
websockets-15.0.1/example/deployment/nginx/supervisord.conf000066400000000000000000000002231476212450300241420ustar00rootroot00000000000000[supervisord] [program:websockets-test] command = python app.py process_name = %(program_name)s_%(process_num)02d numprocs = 4 autorestart = true websockets-15.0.1/example/deployment/render/000077500000000000000000000000001476212450300210455ustar00rootroot00000000000000websockets-15.0.1/example/deployment/render/app.py000066400000000000000000000012001476212450300221700ustar00rootroot00000000000000#!/usr/bin/env python import asyncio import http import signal from websockets.asyncio.server import serve async def echo(websocket): async for message in websocket: await websocket.send(message) def health_check(connection, request): if request.path == "/healthz": return connection.respond(http.HTTPStatus.OK, "OK\n") async def main(): async with serve(echo, "", 8080, process_request=health_check) as server: loop = asyncio.get_running_loop() loop.add_signal_handler(signal.SIGTERM, server.close) await server.wait_closed() if __name__ == "__main__": asyncio.run(main()) websockets-15.0.1/example/deployment/render/requirements.txt000066400000000000000000000000131476212450300243230ustar00rootroot00000000000000websockets websockets-15.0.1/example/deployment/supervisor/000077500000000000000000000000001476212450300220075ustar00rootroot00000000000000websockets-15.0.1/example/deployment/supervisor/app.py000066400000000000000000000007351476212450300231460ustar00rootroot00000000000000#!/usr/bin/env python import asyncio import signal from websockets.asyncio.server import serve async def echo(websocket): async for message in websocket: await websocket.send(message) async def main(): async with serve(echo, "", 8080, reuse_port=True) as server: loop = asyncio.get_running_loop() loop.add_signal_handler(signal.SIGTERM, server.close) await server.wait_closed() if __name__ == "__main__": asyncio.run(main()) 
websockets-15.0.1/example/deployment/supervisor/supervisord.conf000066400000000000000000000002231476212450300252400ustar00rootroot00000000000000[supervisord] [program:websockets-test] command = python app.py process_name = %(program_name)s_%(process_num)02d numprocs = 4 autorestart = true websockets-15.0.1/example/django/000077500000000000000000000000001476212450300166505ustar00rootroot00000000000000websockets-15.0.1/example/django/authentication.py000066400000000000000000000012001476212450300222320ustar00rootroot00000000000000#!/usr/bin/env python import asyncio import django django.setup() from sesame.utils import get_user from websockets.asyncio.server import serve from websockets.frames import CloseCode async def handler(websocket): sesame = await websocket.recv() user = await asyncio.to_thread(get_user, sesame) if user is None: await websocket.close(CloseCode.INTERNAL_ERROR, "authentication failed") return await websocket.send(f"Hello {user}!") async def main(): async with serve(handler, "localhost", 8888) as server: await server.serve_forever() if __name__ == "__main__": asyncio.run(main()) websockets-15.0.1/example/django/notifications.py000066400000000000000000000041001476212450300220660ustar00rootroot00000000000000#!/usr/bin/env python import asyncio import json import aioredis import django django.setup() from django.contrib.contenttypes.models import ContentType from sesame.utils import get_user from websockets.asyncio.server import broadcast, serve from websockets.frames import CloseCode CONNECTIONS = {} def get_content_types(user): """Return the set of IDs of content types visible by user.""" # This does only three database queries because Django caches # all permissions on the first call to user.has_perm(...). 
return { ct.id for ct in ContentType.objects.all() if user.has_perm(f"{ct.app_label}.view_{ct.model}") or user.has_perm(f"{ct.app_label}.change_{ct.model}") } async def handler(websocket): """Authenticate user and register connection in CONNECTIONS.""" sesame = await websocket.recv() user = await asyncio.to_thread(get_user, sesame) if user is None: await websocket.close(CloseCode.INTERNAL_ERROR, "authentication failed") return ct_ids = await asyncio.to_thread(get_content_types, user) CONNECTIONS[websocket] = {"content_type_ids": ct_ids} try: await websocket.wait_closed() finally: del CONNECTIONS[websocket] async def process_events(): """Listen to events in Redis and process them.""" redis = aioredis.from_url("redis://127.0.0.1:6379/1") pubsub = redis.pubsub() await pubsub.subscribe("events") async for message in pubsub.listen(): if message["type"] != "message": continue payload = message["data"].decode() # Broadcast event to all users who have permissions to see it. event = json.loads(payload) recipients = ( websocket for websocket, connection in CONNECTIONS.items() if event["content_type_id"] in connection["content_type_ids"] ) broadcast(recipients, payload) async def main(): async with serve(handler, "localhost", 8888): await process_events() # runs forever if __name__ == "__main__": asyncio.run(main()) websockets-15.0.1/example/django/signals.py000066400000000000000000000013371476212450300206660ustar00rootroot00000000000000import json from django.contrib.admin.models import LogEntry from django.db.models.signals import post_save from django.dispatch import receiver from django_redis import get_redis_connection @receiver(post_save, sender=LogEntry) def publish_event(instance, **kwargs): event = { "model": instance.content_type.name, "object": instance.object_repr, "message": instance.get_change_message(), "timestamp": instance.action_time.isoformat(), "user": str(instance.user), "content_type_id": instance.content_type_id, "object_id": instance.object_id, } 
connection = get_redis_connection("default") payload = json.dumps(event) connection.publish("events", payload) websockets-15.0.1/example/faq/000077500000000000000000000000001476212450300161555ustar00rootroot00000000000000websockets-15.0.1/example/faq/health_check_server.py000077500000000000000000000007741476212450300225320ustar00rootroot00000000000000#!/usr/bin/env python import asyncio from http import HTTPStatus from websockets.asyncio.server import serve def health_check(connection, request): if request.path == "/healthz": return connection.respond(HTTPStatus.OK, "OK\n") async def echo(websocket): async for message in websocket: await websocket.send(message) async def main(): async with serve(echo, "localhost", 8765, process_request=health_check) as server: await server.serve_forever() asyncio.run(main()) websockets-15.0.1/example/faq/shutdown_client.py000077500000000000000000000007611476212450300217470ustar00rootroot00000000000000#!/usr/bin/env python import asyncio import signal from websockets.asyncio.client import connect async def client(): async with connect("ws://localhost:8765") as websocket: # Close the connection when receiving SIGTERM. loop = asyncio.get_running_loop() loop.add_signal_handler(signal.SIGTERM, loop.create_task, websocket.close()) # Process messages received on the connection. async for message in websocket: ... asyncio.run(client()) websockets-15.0.1/example/faq/shutdown_server.py000077500000000000000000000007261476212450300220000ustar00rootroot00000000000000#!/usr/bin/env python import asyncio import signal from websockets.asyncio.server import serve async def handler(websocket): async for message in websocket: ... async def server(): async with serve(handler, "localhost", 8765) as server: # Close the server when receiving SIGTERM. 
loop = asyncio.get_running_loop() loop.add_signal_handler(signal.SIGTERM, server.close) await server.wait_closed() asyncio.run(server()) websockets-15.0.1/example/legacy/000077500000000000000000000000001476212450300166525ustar00rootroot00000000000000websockets-15.0.1/example/legacy/basic_auth_client.py000077500000000000000000000005141476212450300226670ustar00rootroot00000000000000#!/usr/bin/env python # WS client example with HTTP Basic Authentication import asyncio from websockets.legacy.client import connect async def hello(): uri = "ws://mary:p@ssw0rd@localhost:8765" async with connect(uri) as websocket: greeting = await websocket.recv() print(greeting) asyncio.run(hello()) websockets-15.0.1/example/legacy/basic_auth_server.py000077500000000000000000000011461476212450300227210ustar00rootroot00000000000000#!/usr/bin/env python # Server example with HTTP Basic Authentication over TLS import asyncio from websockets.legacy.auth import basic_auth_protocol_factory from websockets.legacy.server import serve async def hello(websocket): greeting = f"Hello {websocket.username}!" await websocket.send(greeting) async def main(): async with serve( hello, "localhost", 8765, create_protocol=basic_auth_protocol_factory( realm="example", credentials=("mary", "p@ssw0rd") ), ): await asyncio.get_running_loop().create_future() # run forever asyncio.run(main()) websockets-15.0.1/example/legacy/unix_client.py000077500000000000000000000007651476212450300215600ustar00rootroot00000000000000#!/usr/bin/env python # WS client example connecting to a Unix socket import asyncio import os.path from websockets.legacy.client import unix_connect async def hello(): socket_path = os.path.join(os.path.dirname(__file__), "socket") async with unix_connect(socket_path) as websocket: name = input("What's your name? 
") await websocket.send(name) print(f">>> {name}") greeting = await websocket.recv() print(f"<<< {greeting}") asyncio.run(hello()) websockets-15.0.1/example/legacy/unix_server.py000077500000000000000000000010631476212450300216000ustar00rootroot00000000000000#!/usr/bin/env python # WS server example listening on a Unix socket import asyncio import os.path from websockets.legacy.server import unix_serve async def hello(websocket): name = await websocket.recv() print(f"<<< {name}") greeting = f"Hello {name}!" await websocket.send(greeting) print(f">>> {greeting}") async def main(): socket_path = os.path.join(os.path.dirname(__file__), "socket") async with unix_serve(hello, socket_path): await asyncio.get_running_loop().create_future() # run forever asyncio.run(main()) websockets-15.0.1/example/quick/000077500000000000000000000000001476212450300165225ustar00rootroot00000000000000websockets-15.0.1/example/quick/client.py000077500000000000000000000005501476212450300203550ustar00rootroot00000000000000#!/usr/bin/env python from websockets.sync.client import connect def hello(): uri = "ws://localhost:8765" with connect(uri) as websocket: name = input("What's your name? ") websocket.send(name) print(f">>> {name}") greeting = websocket.recv() print(f"<<< {greeting}") if __name__ == "__main__": hello() websockets-15.0.1/example/quick/counter.css000066400000000000000000000007331476212450300207160ustar00rootroot00000000000000body { font-family: "Courier New", sans-serif; text-align: center; } .buttons { font-size: 4em; display: flex; justify-content: center; } .button, .value { line-height: 1; padding: 2rem; margin: 2rem; border: medium solid; min-height: 1em; min-width: 1em; } .button { cursor: pointer; user-select: none; } .minus { color: red; } .plus { color: green; } .value { min-width: 2em; } .state { font-size: 2em; } websockets-15.0.1/example/quick/counter.html000066400000000000000000000007451476212450300210750ustar00rootroot00000000000000 WebSocket demo
-
?
+
? online
websockets-15.0.1/example/quick/counter.js000066400000000000000000000015021476212450300205350ustar00rootroot00000000000000window.addEventListener("DOMContentLoaded", () => { const websocket = new WebSocket("ws://localhost:6789/"); document.querySelector(".minus").addEventListener("click", () => { websocket.send(JSON.stringify({ action: "minus" })); }); document.querySelector(".plus").addEventListener("click", () => { websocket.send(JSON.stringify({ action: "plus" })); }); websocket.onmessage = ({ data }) => { const event = JSON.parse(data); switch (event.type) { case "value": document.querySelector(".value").textContent = event.value; break; case "users": const users = `${event.count} user${event.count == 1 ? "" : "s"}`; document.querySelector(".users").textContent = users; break; default: console.error("unsupported event", event); } }; }); websockets-15.0.1/example/quick/counter.py000077500000000000000000000023721476212450300205620ustar00rootroot00000000000000#!/usr/bin/env python import asyncio import json import logging from websockets.asyncio.server import broadcast, serve logging.basicConfig() USERS = set() VALUE = 0 def users_event(): return json.dumps({"type": "users", "count": len(USERS)}) def value_event(): return json.dumps({"type": "value", "value": VALUE}) async def counter(websocket): global USERS, VALUE try: # Register user USERS.add(websocket) broadcast(USERS, users_event()) # Send current state to user await websocket.send(value_event()) # Manage state changes async for message in websocket: event = json.loads(message) if event["action"] == "minus": VALUE -= 1 broadcast(USERS, value_event()) elif event["action"] == "plus": VALUE += 1 broadcast(USERS, value_event()) else: logging.error("unsupported event: %s", event) finally: # Unregister user USERS.remove(websocket) broadcast(USERS, users_event()) async def main(): async with serve(counter, "localhost", 6789) as server: await server.serve_forever() if __name__ == "__main__": asyncio.run(main()) 
websockets-15.0.1/example/quick/server.py000077500000000000000000000006631476212450300204120ustar00rootroot00000000000000#!/usr/bin/env python import asyncio from websockets.asyncio.server import serve async def hello(websocket): name = await websocket.recv() print(f"<<< {name}") greeting = f"Hello {name}!" await websocket.send(greeting) print(f">>> {greeting}") async def main(): async with serve(hello, "localhost", 8765) as server: await server.serve_forever() if __name__ == "__main__": asyncio.run(main()) websockets-15.0.1/example/quick/show_time.html000066400000000000000000000002521476212450300214050ustar00rootroot00000000000000 WebSocket demo websockets-15.0.1/example/quick/show_time.js000066400000000000000000000006431476212450300210610ustar00rootroot00000000000000window.addEventListener("DOMContentLoaded", () => { const messages = document.createElement("ul"); document.body.appendChild(messages); const websocket = new WebSocket("ws://localhost:5678/"); websocket.onmessage = ({ data }) => { const message = document.createElement("li"); const content = document.createTextNode(data); message.appendChild(content); messages.appendChild(message); }; }); websockets-15.0.1/example/quick/show_time.py000077500000000000000000000007651476212450300211050ustar00rootroot00000000000000#!/usr/bin/env python import asyncio import datetime import random from websockets.asyncio.server import serve async def show_time(websocket): while True: message = datetime.datetime.now(tz=datetime.timezone.utc).isoformat() await websocket.send(message) await asyncio.sleep(random.random() * 2 + 1) async def main(): async with serve(show_time, "localhost", 5678) as server: await server.serve_forever() if __name__ == "__main__": asyncio.run(main()) websockets-15.0.1/example/quick/sync_time.py000077500000000000000000000010721476212450300210710ustar00rootroot00000000000000#!/usr/bin/env python import asyncio import datetime import random from websockets.asyncio.server import broadcast, serve async 
def noop(websocket): await websocket.wait_closed() async def show_time(server): while True: message = datetime.datetime.now(tz=datetime.timezone.utc).isoformat() broadcast(server.connections, message) await asyncio.sleep(random.random() * 2 + 1) async def main(): async with serve(noop, "localhost", 5678) as server: await show_time(server) if __name__ == "__main__": asyncio.run(main()) websockets-15.0.1/example/ruff.toml000066400000000000000000000000401476212450300172370ustar00rootroot00000000000000[lint.isort] no-sections = true websockets-15.0.1/example/sync/000077500000000000000000000000001476212450300163625ustar00rootroot00000000000000websockets-15.0.1/example/sync/client.py000066400000000000000000000006131476212450300202120ustar00rootroot00000000000000#!/usr/bin/env python """Client example using the threading API.""" from websockets.sync.client import connect def hello(): with connect("ws://localhost:8765") as websocket: name = input("What's your name? ") websocket.send(name) print(f">>> {name}") greeting = websocket.recv() print(f"<<< {greeting}") if __name__ == "__main__": hello() websockets-15.0.1/example/sync/echo.py000077500000000000000000000005111476212450300176520ustar00rootroot00000000000000#!/usr/bin/env python """Echo server using the threading API.""" from websockets.sync.server import serve def echo(websocket): for message in websocket: websocket.send(message) def main(): with serve(echo, "localhost", 8765) as server: server.serve_forever() if __name__ == "__main__": main() websockets-15.0.1/example/sync/hello.py000077500000000000000000000004701476212450300200430ustar00rootroot00000000000000#!/usr/bin/env python """Client using the threading API.""" from websockets.sync.client import connect def hello(): with connect("ws://localhost:8765") as websocket: websocket.send("Hello world!") message = websocket.recv() print(message) if __name__ == "__main__": hello() 
websockets-15.0.1/example/sync/server.py000066400000000000000000000006411476212450300202430ustar00rootroot00000000000000#!/usr/bin/env python """Server example using the threading API.""" from websockets.sync.server import serve def hello(websocket): name = websocket.recv() print(f"<<< {name}") greeting = f"Hello {name}!" websocket.send(greeting) print(f">>> {greeting}") def main(): with serve(hello, "localhost", 8765) as server: server.serve_forever() if __name__ == "__main__": main() websockets-15.0.1/example/tls/000077500000000000000000000000001476212450300162105ustar00rootroot00000000000000websockets-15.0.1/example/tls/client.py000077500000000000000000000010771476212450300200500ustar00rootroot00000000000000#!/usr/bin/env python import pathlib import ssl from websockets.sync.client import connect ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) localhost_pem = pathlib.Path(__file__).with_name("localhost.pem") ssl_context.load_verify_locations(localhost_pem) def hello(): uri = "wss://localhost:8765" with connect(uri, ssl=ssl_context) as websocket: name = input("What's your name? 
") websocket.send(name) print(f">>> {name}") greeting = websocket.recv() print(f"<<< {greeting}") if __name__ == "__main__": hello() websockets-15.0.1/example/tls/localhost.pem000066400000000000000000000055341476212450300207120ustar00rootroot00000000000000-----BEGIN PRIVATE KEY----- MIIEvAIBADANBgkqhkiG9w0BAQEFAASCBKYwggSiAgEAAoIBAQDG8iDak4UBpurI TWjSfqJ0YVG/S56nhswehupCaIzu0xQ8wqPSs36h5t1jMexJPZfvwyvFjcV+hYpj LMM0wMJPx9oBQEe0bsmlC66e8aF0UpSQw1aVfYoxA9BejgEyrFNE7cRbQNYFEb/5 3HfqZKdEQA2fgQSlZ0RTRmLrD+l72iO5o2xl5bttXpqYZB2XOkyO79j/xWdu9zFE sgZJ5ysWbqoRAGgnxjdYYr9DARd8bIE/hN3SW7mDt5v4LqCIhGn1VmrwtT3d5AuG QPz4YEbm0t6GOlmFjIMYH5Y7pALRVfoJKRj6DGNIR1JicL+wqLV66kcVnj8WKbla 20i7fR7NAgMBAAECggEAG5yvgqbG5xvLqlFUIyMAWTbIqcxNEONcoUAIc38fUGZr gKNjKXNQOBha0dG0AdZSqCxmftzWdGEEfA9SaJf4YCpUz6ekTB60Tfv5GIZg6kwr 4ou6ELWD4Jmu6fC7qdTRGdgGUMQG8F0uT/eRjS67KHXbbi/x/SMAEK7MO+PRfCbj +JGzS9Ym9mUweINPotgjHdDGwwd039VWYS+9A+QuNK27p3zq4hrWRb4wshSC8fKy oLoe4OQt81aowpX9k6mAU6N8vOmP8/EcQHYC+yFIIDZB2EmDP07R1LUEH3KJnzo7 plCK1/kYPhX0a05cEdTpXdKa74AlvSRkS11sGqfUAQKBgQDj1SRv0AUGsHSA0LWx a0NT1ZLEXCG0uqgdgh0sTqIeirQsPROw3ky4lH5MbjkfReArFkhHu3M6KoywEPxE wanSRh/t1qcNjNNZUvFoUzAKVpb33RLkJppOTVEWPt+wtyDlfz1ZAXzMV66tACrx H2a3v0ZWUz6J+x/dESH5TTNL4QKBgQDfirmknp408pwBE+bulngKy0QvU09En8H0 uvqr8q4jCXqJ1tXon4wsHg2yF4Fa37SCpSmvONIDwJvVWkkYLyBHKOns/fWCkW3n hIcYx0q2jgcoOLU0uoaM9ArRXhIxoWqV/KGkQzN+3xXC1/MxZ5OhyxBxfPCPIYIN YN3M1t/QbQKBgDImhsC+D30rdlmsl3IYZFed2ZKznQ/FTqBANd+8517FtWdPgnga VtUCitKUKKrDnNafLwXrMzAIkbNn6b/QyWrp2Lln2JnY9+TfpxgJx7de3BhvZ2sl PC4kQsccy+yAQxOBcKWY+Dmay251bP5qpRepWPhDlq6UwqzMyqev4KzBAoGAWDMi IEO9ZGK9DufNXCHeZ1PgKVQTmJ34JxmHQkTUVFqvEKfFaq1Y3ydUfAouLa7KSCnm ko42vuhGFB41bOdbMvh/o9RoBAZheNGfhDVN002ioUoOpSlbYU4A3q7hOtfXeCpf lLI3JT3cFi6ic8HMTDAU4tJLEA5GhATOPr4hPNkCgYB8jTYGcLvoeFaLEveg0kS2 cz6ZXGLJx5m1AOQy5g9FwGaW+10lr8TF2k3AldwoiwX0R6sHAf/945aGU83ms5v9 PB9/x66AYtSRUos9MwB4y1ur4g6FiXZUBgTJUqzz2nehPCyGjYhh49WucjszqcjX chS1bKZOY+1knWq8xj5Qyg== -----END PRIVATE KEY----- -----BEGIN CERTIFICATE----- 
MIIDTTCCAjWgAwIBAgIJAOjte6l+03jvMA0GCSqGSIb3DQEBCwUAMEwxCzAJBgNV BAYTAkZSMQ4wDAYDVQQHDAVQYXJpczEZMBcGA1UECgwQQXltZXJpYyBBdWd1c3Rp bjESMBAGA1UEAwwJbG9jYWxob3N0MCAXDTE4MDUwNTE2NTkyOVoYDzIwNjAwNTA0 MTY1OTI5WjBMMQswCQYDVQQGEwJGUjEOMAwGA1UEBwwFUGFyaXMxGTAXBgNVBAoM EEF5bWVyaWMgQXVndXN0aW4xEjAQBgNVBAMMCWxvY2FsaG9zdDCCASIwDQYJKoZI hvcNAQEBBQADggEPADCCAQoCggEBAMbyINqThQGm6shNaNJ+onRhUb9LnqeGzB6G 6kJojO7TFDzCo9KzfqHm3WMx7Ek9l+/DK8WNxX6FimMswzTAwk/H2gFAR7RuyaUL rp7xoXRSlJDDVpV9ijED0F6OATKsU0TtxFtA1gURv/ncd+pkp0RADZ+BBKVnRFNG YusP6XvaI7mjbGXlu21emphkHZc6TI7v2P/FZ273MUSyBknnKxZuqhEAaCfGN1hi v0MBF3xsgT+E3dJbuYO3m/guoIiEafVWavC1Pd3kC4ZA/PhgRubS3oY6WYWMgxgf ljukAtFV+gkpGPoMY0hHUmJwv7CotXrqRxWePxYpuVrbSLt9Hs0CAwEAAaMwMC4w LAYDVR0RBCUwI4IJbG9jYWxob3N0hwR/AAABhxAAAAAAAAAAAAAAAAAAAAABMA0G CSqGSIb3DQEBCwUAA4IBAQC9TsTxTEvqHPUS6sfvF77eG0D6HLOONVN91J+L7LiX v3bFeS1xbUS6/wIxZi5EnAt/te5vaHk/5Q1UvznQP4j2gNoM6lH/DRkSARvRitVc H0qN4Xp2Yk1R9VEx4ZgArcyMpI+GhE4vJRx1LE/hsuAzw7BAdsTt9zicscNg2fxO 3ao/eBcdaC6n9aFYdE6CADMpB1lCX2oWNVdj6IavQLu7VMc+WJ3RKncwC9th+5OP ISPvkVZWf25rR2STmvvb0qEm3CZjk4Xd7N+gxbKKUvzEgPjrLSWzKKJAWHjCLugI /kQqhpjWVlTbtKzWz5bViqCjSbrIPpU2MgG9AUV9y3iV -----END CERTIFICATE----- websockets-15.0.1/example/tls/server.py000077500000000000000000000012021476212450300200660ustar00rootroot00000000000000#!/usr/bin/env python import asyncio import pathlib import ssl from websockets.asyncio.server import serve async def hello(websocket): name = await websocket.recv() print(f"<<< {name}") greeting = f"Hello {name}!" 
await websocket.send(greeting) print(f">>> {greeting}") ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) localhost_pem = pathlib.Path(__file__).with_name("localhost.pem") ssl_context.load_cert_chain(localhost_pem) async def main(): async with serve(hello, "localhost", 8765, ssl=ssl_context) as server: await server.serve_forever() if __name__ == "__main__": asyncio.run(main()) websockets-15.0.1/example/tutorial/000077500000000000000000000000001476212450300172515ustar00rootroot00000000000000websockets-15.0.1/example/tutorial/start/000077500000000000000000000000001476212450300204065ustar00rootroot00000000000000websockets-15.0.1/example/tutorial/start/connect4.css000066400000000000000000000027001476212450300226340ustar00rootroot00000000000000/* General layout */ body { background-color: white; display: flex; flex-direction: column-reverse; justify-content: center; align-items: center; margin: 0; min-height: 100vh; } /* Action buttons */ .actions { display: flex; flex-direction: row; justify-content: space-evenly; align-items: flex-end; width: 720px; height: 100px; } .action { color: darkgray; font-family: "Helvetica Neue", sans-serif; font-size: 20px; line-height: 20px; font-weight: 300; text-align: center; text-decoration: none; text-transform: uppercase; padding: 20px; width: 120px; } .action:hover { background-color: darkgray; color: white; font-weight: 700; } .action[href=""] { display: none; } /* Connect Four board */ .board { background-color: blue; display: flex; flex-direction: row; padding: 0 10px; position: relative; } .board::before, .board::after { background-color: blue; content: ""; height: 720px; width: 20px; position: absolute; } .board::before { left: -20px; } .board::after { right: -20px; } .column { display: flex; flex-direction: column-reverse; padding: 10px; } .cell { border-radius: 50%; width: 80px; height: 80px; margin: 10px 0; } .empty { background-color: white; } .column:hover .empty { background-color: lightgray; } .column:hover .empty ~ 
.empty { background-color: white; } .red { background-color: red; } .yellow { background-color: yellow; } websockets-15.0.1/example/tutorial/start/connect4.js000066400000000000000000000027231476212450300224650ustar00rootroot00000000000000const PLAYER1 = "red"; const PLAYER2 = "yellow"; function createBoard(board) { // Inject stylesheet. const linkElement = document.createElement("link"); linkElement.href = import.meta.url.replace(".js", ".css"); linkElement.rel = "stylesheet"; document.head.append(linkElement); // Generate board. for (let column = 0; column < 7; column++) { const columnElement = document.createElement("div"); columnElement.className = "column"; columnElement.dataset.column = column; for (let row = 0; row < 6; row++) { const cellElement = document.createElement("div"); cellElement.className = "cell empty"; cellElement.dataset.column = column; columnElement.append(cellElement); } board.append(columnElement); } } function playMove(board, player, column, row) { // Check values of arguments. if (player !== PLAYER1 && player !== PLAYER2) { throw new Error(`player must be ${PLAYER1} or ${PLAYER2}.`); } const columnElement = board.querySelectorAll(".column")[column]; if (columnElement === undefined) { throw new RangeError("column must be between 0 and 6."); } const cellElement = columnElement.querySelectorAll(".cell")[row]; if (cellElement === undefined) { throw new RangeError("row must be between 0 and 5."); } // Place checker in cell. if (!cellElement.classList.replace("empty", player)) { throw new Error("cell must be empty."); } } export { PLAYER1, PLAYER2, createBoard, playMove }; websockets-15.0.1/example/tutorial/start/connect4.py000066400000000000000000000026111476212450300224750ustar00rootroot00000000000000__all__ = ["PLAYER1", "PLAYER2", "Connect4"] PLAYER1, PLAYER2 = "red", "yellow" class Connect4: """ A Connect Four game. Play moves with :meth:`play`. Get past moves with :attr:`moves`. Check for a victory with :attr:`winner`. 
""" def __init__(self): self.moves = [] self.top = [0 for _ in range(7)] self.winner = None @property def last_player(self): """ Player who played the last move. """ return PLAYER1 if len(self.moves) % 2 else PLAYER2 @property def last_player_won(self): """ Whether the last move is winning. """ b = sum(1 << (8 * column + row) for _, column, row in self.moves[::-2]) return any(b & b >> v & b >> 2 * v & b >> 3 * v for v in [1, 7, 8, 9]) def play(self, player, column): """ Play a move in a column. Returns the row where the checker lands. Raises :exc:`ValueError` if the move is illegal. """ if player == self.last_player: raise ValueError("It isn't your turn.") row = self.top[column] if row == 6: raise ValueError("This slot is full.") self.moves.append((player, column, row)) self.top[column] += 1 if self.winner is None and self.last_player_won: self.winner = self.last_player return row websockets-15.0.1/example/tutorial/start/favicon.ico000066400000000000000000000124661476212450300225400ustar00rootroot00000000000000 h&  (  >-=v=`@BU=;;;B#JVEA=<;<RWNIEE%>;;;=*\VUQMN$KcEA=D`q9 ^]ZU$RdMIFgjj1i0k1Cfb_#[eUQNRt8wk2i0k19dJ]YUR|>sv:r7k3k2uj50fda]^Gk~@z>w<Ҧs@l2qzEh;pēNICBѳ@xɪwDʖSŒPMJCŻD͗UBƔPyÏMY̙f( @ <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<UUUUUUUUUB<<<;<>>>>>>F3A?>;;;;;;<<<<<<<<<<<<<<<======J4FDB?><;;;;;>->>>>>>>>>>>>><<<<<M5JGECB@=<;;;;==============UUUUR5NKIGECBBA'<;;;;;<;;;;xt9s8q6m3l3k1k19k1k1k1s5s5s5s5s5s5s5s5s5s5s5s5ÐOMJHGG^GG}?z=yo5o5o5o5u;u;u;u;u;u;u;u;u;u;u;u;ƑOÐOLJIG2GCB|?z=v:u9t7s5>s5s5s5s5s5}?}?}?}?}?}?}?}?}?}?}?}?ʕTtőOÐNMII귆GʵDC~@|?z=x Connect Four
websockets-15.0.1/example/tutorial/step1/main.js000066400000000000000000000030541476212450300215710ustar00rootroot00000000000000import { createBoard, playMove } from "./connect4.js"; function showMessage(message) { window.setTimeout(() => window.alert(message), 50); } function receiveMoves(board, websocket) { websocket.addEventListener("message", ({ data }) => { const event = JSON.parse(data); switch (event.type) { case "play": // Update the UI with the move. playMove(board, event.player, event.column, event.row); break; case "win": showMessage(`Player ${event.player} wins!`); // No further messages are expected; close the WebSocket connection. websocket.close(1000); break; case "error": showMessage(event.message); break; default: throw new Error(`Unsupported event type: ${event.type}.`); } }); } function sendMoves(board, websocket) { // When clicking a column, send a "play" event for a move in that column. board.addEventListener("click", ({ target }) => { const column = target.dataset.column; // Ignore clicks outside a column. if (column === undefined) { return; } const event = { type: "play", column: parseInt(column, 10), }; websocket.send(JSON.stringify(event)); }); } window.addEventListener("DOMContentLoaded", () => { // Initialize the UI. const board = document.querySelector(".board"); createBoard(board); // Open the WebSocket connection and register event handlers. 
const websocket = new WebSocket("ws://localhost:8001/"); receiveMoves(board, websocket); sendMoves(board, websocket); }); websockets-15.0.1/example/tutorial/step2/000077500000000000000000000000001476212450300203065ustar00rootroot00000000000000websockets-15.0.1/example/tutorial/step2/app.py000066400000000000000000000116551476212450300214500ustar00rootroot00000000000000#!/usr/bin/env python import asyncio import json import secrets from websockets.asyncio.server import broadcast, serve from connect4 import PLAYER1, PLAYER2, Connect4 JOIN = {} WATCH = {} async def error(websocket, message): """ Send an error message. """ event = { "type": "error", "message": message, } await websocket.send(json.dumps(event)) async def replay(websocket, game): """ Send previous moves. """ # Make a copy to avoid an exception if game.moves changes while iteration # is in progress. If a move is played while replay is running, moves will # be sent out of order but each move will be sent once and eventually the # UI will be consistent. for player, column, row in game.moves.copy(): event = { "type": "play", "player": player, "column": column, "row": row, } await websocket.send(json.dumps(event)) async def play(websocket, game, player, connected): """ Receive and process moves from a player. """ async for message in websocket: # Parse a "play" event from the UI. event = json.loads(message) assert event["type"] == "play" column = event["column"] try: # Play the move. row = game.play(player, column) except ValueError as exc: # Send an "error" event if the move was illegal. await error(websocket, str(exc)) continue # Send a "play" event to update the UI. event = { "type": "play", "player": player, "column": column, "row": row, } broadcast(connected, json.dumps(event)) # If move is winning, send a "win" event. 
if game.winner is not None: event = { "type": "win", "player": game.winner, } broadcast(connected, json.dumps(event)) async def start(websocket): """ Handle a connection from the first player: start a new game. """ # Initialize a Connect Four game, the set of WebSocket connections # receiving moves from this game, and secret access tokens. game = Connect4() connected = {websocket} join_key = secrets.token_urlsafe(12) JOIN[join_key] = game, connected watch_key = secrets.token_urlsafe(12) WATCH[watch_key] = game, connected try: # Send the secret access tokens to the browser of the first player, # where they'll be used for building "join" and "watch" links. event = { "type": "init", "join": join_key, "watch": watch_key, } await websocket.send(json.dumps(event)) # Receive and process moves from the first player. await play(websocket, game, PLAYER1, connected) finally: del JOIN[join_key] del WATCH[watch_key] async def join(websocket, join_key): """ Handle a connection from the second player: join an existing game. """ # Find the Connect Four game. try: game, connected = JOIN[join_key] except KeyError: await error(websocket, "Game not found.") return # Register to receive moves from this game. connected.add(websocket) try: # Send the first move, in case the first player already played it. await replay(websocket, game) # Receive and process moves from the second player. await play(websocket, game, PLAYER2, connected) finally: connected.remove(websocket) async def watch(websocket, watch_key): """ Handle a connection from a spectator: watch an existing game. """ # Find the Connect Four game. try: game, connected = WATCH[watch_key] except KeyError: await error(websocket, "Game not found.") return # Register to receive moves from this game. connected.add(websocket) try: # Send previous moves, in case the game already started. await replay(websocket, game) # Keep the connection open, but don't receive any messages. 
await websocket.wait_closed() finally: connected.remove(websocket) async def handler(websocket): """ Handle a connection and dispatch it according to who is connecting. """ # Receive and parse the "init" event from the UI. message = await websocket.recv() event = json.loads(message) assert event["type"] == "init" if "join" in event: # Second player joins an existing game. await join(websocket, event["join"]) elif "watch" in event: # Spectator watches an existing game. await watch(websocket, event["watch"]) else: # First player starts a new game. await start(websocket) async def main(): async with serve(handler, "", 8001) as server: await server.serve_forever() if __name__ == "__main__": asyncio.run(main()) websockets-15.0.1/example/tutorial/step2/connect4.css000077700000000000000000000000001476212450300263272../start/connect4.cssustar00rootroot00000000000000websockets-15.0.1/example/tutorial/step2/connect4.js000077700000000000000000000000001476212450300257772../start/connect4.jsustar00rootroot00000000000000websockets-15.0.1/example/tutorial/step2/connect4.py000077700000000000000000000000001476212450300260272../start/connect4.pyustar00rootroot00000000000000websockets-15.0.1/example/tutorial/step2/favicon.ico000077700000000000000000000000001476212450300263442../../../logo/favicon.icoustar00rootroot00000000000000websockets-15.0.1/example/tutorial/step2/index.html000066400000000000000000000005571476212450300223120ustar00rootroot00000000000000 Connect Four
websockets-15.0.1/example/tutorial/step2/main.js000066400000000000000000000050501476212450300215700ustar00rootroot00000000000000import { createBoard, playMove } from "./connect4.js"; function initGame(websocket) { websocket.addEventListener("open", () => { // Send an "init" event according to who is connecting. const params = new URLSearchParams(window.location.search); let event = { type: "init" }; if (params.has("join")) { // Second player joins an existing game. event.join = params.get("join"); } else if (params.has("watch")) { // Spectator watches an existing game. event.watch = params.get("watch"); } else { // First player starts a new game. } websocket.send(JSON.stringify(event)); }); } function showMessage(message) { window.setTimeout(() => window.alert(message), 50); } function receiveMoves(board, websocket) { websocket.addEventListener("message", ({ data }) => { const event = JSON.parse(data); switch (event.type) { case "init": // Create links for inviting the second player and spectators. document.querySelector(".join").href = "?join=" + event.join; document.querySelector(".watch").href = "?watch=" + event.watch; break; case "play": // Update the UI with the move. playMove(board, event.player, event.column, event.row); break; case "win": showMessage(`Player ${event.player} wins!`); // No further messages are expected; close the WebSocket connection. websocket.close(1000); break; case "error": showMessage(event.message); break; default: throw new Error(`Unsupported event type: ${event.type}.`); } }); } function sendMoves(board, websocket) { // Don't send moves for a spectator watching a game. const params = new URLSearchParams(window.location.search); if (params.has("watch")) { return; } // When clicking a column, send a "play" event for a move in that column. board.addEventListener("click", ({ target }) => { const column = target.dataset.column; // Ignore clicks outside a column. 
if (column === undefined) { return; } const event = { type: "play", column: parseInt(column, 10), }; websocket.send(JSON.stringify(event)); }); } window.addEventListener("DOMContentLoaded", () => { // Initialize the UI. const board = document.querySelector(".board"); createBoard(board); // Open the WebSocket connection and register event handlers. const websocket = new WebSocket("ws://localhost:8001/"); initGame(websocket); receiveMoves(board, websocket); sendMoves(board, websocket); }); websockets-15.0.1/example/tutorial/step3/000077500000000000000000000000001476212450300203075ustar00rootroot00000000000000websockets-15.0.1/example/tutorial/step3/Procfile000066400000000000000000000000231476212450300217700ustar00rootroot00000000000000web: python app.py websockets-15.0.1/example/tutorial/step3/app.py000066400000000000000000000124161476212450300214450ustar00rootroot00000000000000#!/usr/bin/env python import asyncio import http import json import os import secrets import signal from websockets.asyncio.server import broadcast, serve from connect4 import PLAYER1, PLAYER2, Connect4 JOIN = {} WATCH = {} async def error(websocket, message): """ Send an error message. """ event = { "type": "error", "message": message, } await websocket.send(json.dumps(event)) async def replay(websocket, game): """ Send previous moves. """ # Make a copy to avoid an exception if game.moves changes while iteration # is in progress. If a move is played while replay is running, moves will # be sent out of order but each move will be sent once and eventually the # UI will be consistent. for player, column, row in game.moves.copy(): event = { "type": "play", "player": player, "column": column, "row": row, } await websocket.send(json.dumps(event)) async def play(websocket, game, player, connected): """ Receive and process moves from a player. """ async for message in websocket: # Parse a "play" event from the UI. 
event = json.loads(message) assert event["type"] == "play" column = event["column"] try: # Play the move. row = game.play(player, column) except ValueError as exc: # Send an "error" event if the move was illegal. await error(websocket, str(exc)) continue # Send a "play" event to update the UI. event = { "type": "play", "player": player, "column": column, "row": row, } broadcast(connected, json.dumps(event)) # If move is winning, send a "win" event. if game.winner is not None: event = { "type": "win", "player": game.winner, } broadcast(connected, json.dumps(event)) async def start(websocket): """ Handle a connection from the first player: start a new game. """ # Initialize a Connect Four game, the set of WebSocket connections # receiving moves from this game, and secret access tokens. game = Connect4() connected = {websocket} join_key = secrets.token_urlsafe(12) JOIN[join_key] = game, connected watch_key = secrets.token_urlsafe(12) WATCH[watch_key] = game, connected try: # Send the secret access tokens to the browser of the first player, # where they'll be used for building "join" and "watch" links. event = { "type": "init", "join": join_key, "watch": watch_key, } await websocket.send(json.dumps(event)) # Receive and process moves from the first player. await play(websocket, game, PLAYER1, connected) finally: del JOIN[join_key] del WATCH[watch_key] async def join(websocket, join_key): """ Handle a connection from the second player: join an existing game. """ # Find the Connect Four game. try: game, connected = JOIN[join_key] except KeyError: await error(websocket, "Game not found.") return # Register to receive moves from this game. connected.add(websocket) try: # Send the first move, in case the first player already played it. await replay(websocket, game) # Receive and process moves from the second player. 
await play(websocket, game, PLAYER2, connected) finally: connected.remove(websocket) async def watch(websocket, watch_key): """ Handle a connection from a spectator: watch an existing game. """ # Find the Connect Four game. try: game, connected = WATCH[watch_key] except KeyError: await error(websocket, "Game not found.") return # Register to receive moves from this game. connected.add(websocket) try: # Send previous moves, in case the game already started. await replay(websocket, game) # Keep the connection open, but don't receive any messages. await websocket.wait_closed() finally: connected.remove(websocket) async def handler(websocket): """ Handle a connection and dispatch it according to who is connecting. """ # Receive and parse the "init" event from the UI. message = await websocket.recv() event = json.loads(message) assert event["type"] == "init" if "join" in event: # Second player joins an existing game. await join(websocket, event["join"]) elif "watch" in event: # Spectator watches an existing game. await watch(websocket, event["watch"]) else: # First player starts a new game. 
await start(websocket) def health_check(connection, request): if request.path == "/healthz": return connection.respond(http.HTTPStatus.OK, "OK\n") async def main(): port = int(os.environ.get("PORT", "8001")) async with serve(handler, "", port, process_request=health_check) as server: loop = asyncio.get_running_loop() loop.add_signal_handler(signal.SIGTERM, server.close) await server.wait_closed() if __name__ == "__main__": asyncio.run(main()) websockets-15.0.1/example/tutorial/step3/connect4.css000077700000000000000000000000001476212450300263302../start/connect4.cssustar00rootroot00000000000000websockets-15.0.1/example/tutorial/step3/connect4.js000077700000000000000000000000001476212450300260002../start/connect4.jsustar00rootroot00000000000000websockets-15.0.1/example/tutorial/step3/connect4.py000077700000000000000000000000001476212450300260302../start/connect4.pyustar00rootroot00000000000000websockets-15.0.1/example/tutorial/step3/favicon.ico000077700000000000000000000000001476212450300263452../../../logo/favicon.icoustar00rootroot00000000000000websockets-15.0.1/example/tutorial/step3/index.html000066400000000000000000000005571476212450300223130ustar00rootroot00000000000000 Connect Four
websockets-15.0.1/example/tutorial/step3/main.js000066400000000000000000000055521476212450300216000ustar00rootroot00000000000000import { createBoard, playMove } from "./connect4.js"; function getWebSocketServer() { if (window.location.host === "python-websockets.github.io") { return "wss://websockets-tutorial.koyeb.app/"; } else if (window.location.host === "localhost:8000") { return "ws://localhost:8001/"; } else { throw new Error(`Unsupported host: ${window.location.host}`); } } function initGame(websocket) { websocket.addEventListener("open", () => { // Send an "init" event according to who is connecting. const params = new URLSearchParams(window.location.search); let event = { type: "init" }; if (params.has("join")) { // Second player joins an existing game. event.join = params.get("join"); } else if (params.has("watch")) { // Spectator watches an existing game. event.watch = params.get("watch"); } else { // First player starts a new game. } websocket.send(JSON.stringify(event)); }); } function showMessage(message) { window.setTimeout(() => window.alert(message), 50); } function receiveMoves(board, websocket) { websocket.addEventListener("message", ({ data }) => { const event = JSON.parse(data); switch (event.type) { case "init": // Create links for inviting the second player and spectators. document.querySelector(".join").href = "?join=" + event.join; document.querySelector(".watch").href = "?watch=" + event.watch; break; case "play": // Update the UI with the move. playMove(board, event.player, event.column, event.row); break; case "win": showMessage(`Player ${event.player} wins!`); // No further messages are expected; close the WebSocket connection. websocket.close(1000); break; case "error": showMessage(event.message); break; default: throw new Error(`Unsupported event type: ${event.type}.`); } }); } function sendMoves(board, websocket) { // Don't send moves for a spectator watching a game. 
const params = new URLSearchParams(window.location.search); if (params.has("watch")) { return; } // When clicking a column, send a "play" event for a move in that column. board.addEventListener("click", ({ target }) => { const column = target.dataset.column; // Ignore clicks outside a column. if (column === undefined) { return; } const event = { type: "play", column: parseInt(column, 10), }; websocket.send(JSON.stringify(event)); }); } window.addEventListener("DOMContentLoaded", () => { // Initialize the UI. const board = document.querySelector(".board"); createBoard(board); // Open the WebSocket connection and register event handlers. const websocket = new WebSocket(getWebSocketServer()); initGame(websocket); receiveMoves(board, websocket); sendMoves(board, websocket); }); websockets-15.0.1/example/tutorial/step3/requirements.txt000066400000000000000000000000131476212450300235650ustar00rootroot00000000000000websockets websockets-15.0.1/experiments/000077500000000000000000000000001476212450300163165ustar00rootroot00000000000000websockets-15.0.1/experiments/authentication/000077500000000000000000000000001476212450300213355ustar00rootroot00000000000000websockets-15.0.1/experiments/authentication/app.py000066400000000000000000000125141476212450300224720ustar00rootroot00000000000000#!/usr/bin/env python import asyncio import email.utils import http import http.cookies import pathlib import signal import urllib.parse import uuid from websockets.asyncio.server import basic_auth as websockets_basic_auth, serve from websockets.datastructures import Headers from websockets.frames import CloseCode from websockets.http11 import Response # User accounts database USERS = {} def create_token(user, lifetime=1): """Create token for user and delete it once its lifetime is over.""" token = uuid.uuid4().hex USERS[token] = user asyncio.get_running_loop().call_later(lifetime, USERS.pop, token) return token def get_user(token): """Find user authenticated by token or return None.""" 
return USERS.get(token) # Utilities def get_cookie(raw, key): cookie = http.cookies.SimpleCookie(raw) morsel = cookie.get(key) if morsel is not None: return morsel.value def get_query_param(path, key): query = urllib.parse.urlparse(path).query params = urllib.parse.parse_qs(query) values = params.get(key, []) if len(values) == 1: return values[0] # WebSocket handler async def handler(websocket): try: user = websocket.username except AttributeError: return await websocket.send(f"Hello {user}!") message = await websocket.recv() assert message == f"Goodbye {user}." CONTENT_TYPES = { ".css": "text/css", ".html": "text/html; charset=utf-8", ".ico": "image/x-icon", ".js": "text/javascript", } async def serve_html(connection, request): """Basic HTTP server implemented as a process_request hook.""" user = get_query_param(request.path, "user") path = urllib.parse.urlparse(request.path).path if path == "/": if user is None: page = "index.html" else: page = "test.html" else: page = path[1:] try: template = pathlib.Path(__file__).with_name(page) except ValueError: pass else: if template.is_file(): body = template.read_bytes() if user is not None: token = create_token(user) body = body.replace(b"TOKEN", token.encode()) headers = Headers( { "Date": email.utils.formatdate(usegmt=True), "Connection": "close", "Content-Length": str(len(body)), "Content-Type": CONTENT_TYPES[template.suffix], } ) return Response(200, "OK", headers, body) return connection.respond(http.HTTPStatus.NOT_FOUND, "Not found\n") async def first_message_handler(websocket): """Handler that sends credentials in the first WebSocket message.""" token = await websocket.recv() user = get_user(token) if user is None: await websocket.close(CloseCode.INTERNAL_ERROR, "authentication failed") return websocket.username = user await handler(websocket) async def query_param_auth(connection, request): """Authenticate user from token in query parameter.""" token = get_query_param(request.path, "token") if token is None: 
return connection.respond(http.HTTPStatus.UNAUTHORIZED, "Missing token\n") user = get_user(token) if user is None: return connection.respond(http.HTTPStatus.UNAUTHORIZED, "Invalid token\n") connection.username = user async def cookie_auth(connection, request): """Authenticate user from token in cookie.""" if "Upgrade" not in request.headers: template = pathlib.Path(__file__).with_name(request.path[1:]) body = template.read_bytes() headers = Headers( { "Date": email.utils.formatdate(usegmt=True), "Connection": "close", "Content-Length": str(len(body)), "Content-Type": CONTENT_TYPES[template.suffix], } ) return Response(200, "OK", headers, body) token = get_cookie(request.headers.get("Cookie", ""), "token") if token is None: return connection.respond(http.HTTPStatus.UNAUTHORIZED, "Missing token\n") user = get_user(token) if user is None: return connection.respond(http.HTTPStatus.UNAUTHORIZED, "Invalid token\n") connection.username = user def check_credentials(username, password): """Authenticate user with HTTP Basic Auth.""" return username == get_user(password) basic_auth = websockets_basic_auth(check_credentials=check_credentials) async def main(): """Start one HTTP server and four WebSocket servers.""" # Set the stop condition when receiving SIGINT or SIGTERM. 
loop = asyncio.get_running_loop() stop = loop.create_future() loop.add_signal_handler(signal.SIGINT, stop.set_result, None) loop.add_signal_handler(signal.SIGTERM, stop.set_result, None) async with ( serve(handler, host="", port=8000, process_request=serve_html), serve(first_message_handler, host="", port=8001), serve(handler, host="", port=8002, process_request=query_param_auth), serve(handler, host="", port=8003, process_request=cookie_auth), serve(handler, host="", port=8004, process_request=basic_auth), ): print("Running on http://localhost:8000/") await stop print("\rExiting") if __name__ == "__main__": asyncio.run(main()) websockets-15.0.1/experiments/authentication/cookie.html000066400000000000000000000007661476212450300235050ustar00rootroot00000000000000 Cookie | WebSocket Authentication

[??] Cookie

[OK] Cookie

[KO] Cookie

websockets-15.0.1/experiments/authentication/cookie.js000066400000000000000000000012641476212450300231470ustar00rootroot00000000000000// send token to iframe window.addEventListener("DOMContentLoaded", () => { const iframe = document.querySelector("iframe"); iframe.addEventListener("load", () => { iframe.contentWindow.postMessage(token, "http://localhost:8003"); }); }); // once iframe has set cookie, open WebSocket connection window.addEventListener("message", ({ origin }) => { if (origin !== "http://localhost:8003") { return; } const websocket = new WebSocket("ws://localhost:8003/"); websocket.onmessage = ({ data }) => { // event.data is expected to be "Hello !" websocket.send(`Goodbye ${data.slice(6, -1)}.`); }; runTest(websocket); }); websockets-15.0.1/experiments/authentication/cookie_iframe.html000066400000000000000000000003101476212450300250110ustar00rootroot00000000000000 Cookie iframe | WebSocket Authentication websockets-15.0.1/experiments/authentication/cookie_iframe.js000066400000000000000000000004761476212450300244760ustar00rootroot00000000000000// receive token from the parent window, set cookie and notify parent window.addEventListener("message", ({ origin, data }) => { if (origin !== "http://localhost:8000") { return; } document.cookie = `token=${data}; SameSite=Strict`; window.parent.postMessage("", "http://localhost:8000"); }); websockets-15.0.1/experiments/authentication/favicon.ico000077700000000000000000000000001476212450300271602../../logo/favicon.icoustar00rootroot00000000000000websockets-15.0.1/experiments/authentication/first_message.html000066400000000000000000000006711476212450300250620ustar00rootroot00000000000000 First message | WebSocket Authentication

[??] First message

[OK] First message

[KO] First message

websockets-15.0.1/experiments/authentication/first_message.js000066400000000000000000000005451476212450300245320ustar00rootroot00000000000000window.addEventListener("DOMContentLoaded", () => { const websocket = new WebSocket("ws://localhost:8001/"); websocket.onopen = () => websocket.send(token); websocket.onmessage = ({ data }) => { // event.data is expected to be "Hello !" websocket.send(`Goodbye ${data.slice(6, -1)}.`); }; runTest(websocket); }); websockets-15.0.1/experiments/authentication/index.html000066400000000000000000000004331476212450300233320ustar00rootroot00000000000000 WebSocket Authentication
websockets-15.0.1/experiments/authentication/query_param.html000066400000000000000000000006771476212450300245620ustar00rootroot00000000000000 Query parameter | WebSocket Authentication

[??] Query parameter

[OK] Query parameter

[KO] Query parameter

websockets-15.0.1/experiments/authentication/query_param.js000066400000000000000000000005251476212450300242220ustar00rootroot00000000000000window.addEventListener("DOMContentLoaded", () => { const uri = `ws://localhost:8002/?token=${token}`; const websocket = new WebSocket(uri); websocket.onmessage = ({ data }) => { // event.data is expected to be "Hello !" websocket.send(`Goodbye ${data.slice(6, -1)}.`); }; runTest(websocket); }); websockets-15.0.1/experiments/authentication/script.js000066400000000000000000000025161476212450300232030ustar00rootroot00000000000000var token = window.parent.token, user = window.parent.user; function getExpectedEvents() { return [ { type: "open", }, { type: "message", data: `Hello ${user}!`, }, { type: "close", code: 1000, reason: "", wasClean: true, }, ]; } function isEqual(expected, actual) { // good enough for our purposes here! return JSON.stringify(expected) === JSON.stringify(actual); } function testStep(expected, actual) { if (isEqual(expected, actual)) { document.body.className = "ok"; } else if (isEqual(expected.slice(0, actual.length), actual)) { document.body.className = "test"; } else { document.body.className = "ko"; } } function runTest(websocket) { const expected = getExpectedEvents(); var actual = []; websocket.addEventListener("open", ({ type }) => { actual.push({ type }); testStep(expected, actual); }); websocket.addEventListener("message", ({ type, data }) => { actual.push({ type, data }); testStep(expected, actual); }); websocket.addEventListener("close", ({ type, code, reason, wasClean }) => { actual.push({ type, code, reason, wasClean }); testStep(expected, actual); }); } websockets-15.0.1/experiments/authentication/style.css000066400000000000000000000017311476212450300232110ustar00rootroot00000000000000/* page layout */ body { display: flex; flex-direction: column; justify-content: center; align-items: center; margin: 0; height: 100vh; } div.title, iframe { width: 100vw; height: 20vh; border: none; } div.title { 
display: flex; flex-direction: column; justify-content: center; align-items: center; } h1, p { margin: 0; width: 24em; } /* text style */ h1, input, p { font-family: monospace; font-size: 3em; } input { color: #333; border: 3px solid #999; padding: 1em; } input:focus { border-color: #333; outline: none; } input::placeholder { color: #999; opacity: 1; } /* test results */ body.test { background-color: #666; color: #fff; } body.ok { background-color: #090; color: #fff; } body.ko { background-color: #900; color: #fff; } body > p { display: none; } body > p.title, body.test > p.test, body.ok > p.ok, body.ko > p.ko { display: block; } websockets-15.0.1/experiments/authentication/test.html000066400000000000000000000007651476212450300232120ustar00rootroot00000000000000 WebSocket Authentication

WebSocket Authentication

websockets-15.0.1/experiments/authentication/test.js000066400000000000000000000002051476212450300226470ustar00rootroot00000000000000var token = document.body.dataset.token; const params = new URLSearchParams(window.location.search); var user = params.get("user"); websockets-15.0.1/experiments/authentication/user_info.html000066400000000000000000000007011476212450300242120ustar00rootroot00000000000000 User information | WebSocket Authentication

[??] User information

[OK] User information

[KO] User information

websockets-15.0.1/experiments/authentication/user_info.js000066400000000000000000000005271476212450300236700ustar00rootroot00000000000000window.addEventListener("DOMContentLoaded", () => { const uri = `ws://${user}:${token}@localhost:8004/`; const websocket = new WebSocket(uri); websocket.onmessage = ({ data }) => { // event.data is expected to be "Hello !" websocket.send(`Goodbye ${data.slice(6, -1)}.`); }; runTest(websocket); }); websockets-15.0.1/experiments/broadcast/000077500000000000000000000000001476212450300202605ustar00rootroot00000000000000websockets-15.0.1/experiments/broadcast/clients.py000066400000000000000000000030621476212450300222740ustar00rootroot00000000000000#!/usr/bin/env python import asyncio import statistics import sys import time from websockets.asyncio.client import connect LATENCIES = {} async def log_latency(interval): while True: await asyncio.sleep(interval) p = statistics.quantiles(LATENCIES.values(), n=100) print(f"clients = {len(LATENCIES)}") print( f"p50 = {p[49] / 1e6:.1f}ms, " f"p95 = {p[94] / 1e6:.1f}ms, " f"p99 = {p[98] / 1e6:.1f}ms" ) print() async def client(): try: async with connect( "ws://localhost:8765", ping_timeout=None, ) as websocket: async for msg in websocket: client_time = time.time_ns() server_time = int(msg[:19].decode()) LATENCIES[websocket] = client_time - server_time except Exception as exc: print(exc) async def main(count, interval): asyncio.create_task(log_latency(interval)) clients = [] for _ in range(count): clients.append(asyncio.create_task(client())) await asyncio.sleep(0.001) # 1ms between each connection await asyncio.wait(clients) if __name__ == "__main__": try: count = int(sys.argv[1]) interval = float(sys.argv[2]) except Exception as exc: print(f"Usage: {sys.argv[0]} count interval") print(" Connect clients e.g. 1000") print(" Report latency every seconds e.g. 
#!/usr/bin/env python

# experiments/broadcast/server.py — compare strategies for broadcasting a
# message to many connected WebSocket clients.

import asyncio
import functools
import os
import sys
import time

from websockets.asyncio.server import broadcast, serve
from websockets.exceptions import ConnectionClosed

# Registry of connected clients: websocket objects for most methods,
# asyncio.Queue objects for the "queue" method.
CLIENTS = set()


async def send(websocket, message):
    """Send ``message`` to ``websocket``, ignoring closed connections."""
    try:
        await websocket.send(message)
    except ConnectionClosed:
        pass


async def relay(queue, websocket):
    """Forward messages from ``queue`` to ``websocket`` forever."""
    while True:
        message = await queue.get()
        await websocket.send(message)


class PubSub:
    """Minimal publish/subscribe primitive built on chained futures.

    Each published value resolves the current future with
    ``(value, next_future)``; subscribers follow the chain.
    """

    def __init__(self):
        self.waiter = asyncio.get_running_loop().create_future()

    def publish(self, value):
        # Swap in a fresh future before resolving the old one so
        # subscribers can keep following the chain.
        waiter = self.waiter
        self.waiter = asyncio.get_running_loop().create_future()
        waiter.set_result((value, self.waiter))

    async def subscribe(self):
        """Async-iterate over published values."""
        waiter = self.waiter
        while True:
            value, waiter = await waiter
            yield value

    __aiter__ = subscribe


async def handler(websocket, method=None):
    """Register the connection according to the broadcast ``method``."""
    if method in ["default", "naive", "task", "wait"]:
        CLIENTS.add(websocket)
        try:
            await websocket.wait_closed()
        finally:
            CLIENTS.remove(websocket)
    elif method == "queue":
        queue = asyncio.Queue()
        relay_task = asyncio.create_task(relay(queue, websocket))
        CLIENTS.add(queue)
        try:
            await websocket.wait_closed()
        finally:
            CLIENTS.remove(queue)
            relay_task.cancel()
    elif method == "pubsub":
        global PUBSUB
        async for message in PUBSUB:
            await websocket.send(message)
    else:
        raise NotImplementedError(f"unsupported method: {method}")


async def broadcast_messages(method, size, delay):
    """Broadcast messages at regular intervals."""
    if method == "pubsub":
        global PUBSUB
        PUBSUB = PubSub()
    load_average = 0
    time_average = 0
    pc1, pt1 = time.perf_counter_ns(), time.process_time_ns()
    await asyncio.sleep(delay)
    while True:
        print(f"clients = {len(CLIENTS)}")
        pc0, pt0 = time.perf_counter_ns(), time.process_time_ns()
        # Exponential moving average of CPU load (process time / wall time).
        load_average = 0.9 * load_average + 0.1 * (pt0 - pt1) / (pc0 - pc1)
        print(
            f"load = {(pt0 - pt1) / (pc0 - pc1) * 100:.1f}% / "
            f"average = {load_average * 100:.1f}%, "
            f"late = {(pc0 - pc1 - delay * 1e9) / 1e6:.1f} ms"
        )
        pc1, pt1 = pc0, pt0
        # Payload = nanosecond timestamp + space + random padding.
        assert size > 20
        message = str(time.time_ns()).encode() + b" " + os.urandom(size - 20)
        if method == "default":
            broadcast(CLIENTS, message)
        elif method == "naive":
            # Since the loop can yield control, make a copy of CLIENTS
            # to avoid: RuntimeError: Set changed size during iteration
            for websocket in CLIENTS.copy():
                await send(websocket, message)
        elif method == "task":
            for websocket in CLIENTS:
                asyncio.create_task(send(websocket, message))
        elif method == "wait":
            if CLIENTS:  # asyncio.wait doesn't accept an empty list
                await asyncio.wait(
                    [
                        asyncio.create_task(send(websocket, message))
                        for websocket in CLIENTS
                    ]
                )
        elif method == "queue":
            for queue in CLIENTS:
                queue.put_nowait(message)
        elif method == "pubsub":
            PUBSUB.publish(message)
        else:
            raise NotImplementedError(f"unsupported method: {method}")
        pc2 = time.perf_counter_ns()
        # Sleep for the remainder of the interval, compensating for the
        # time spent broadcasting.
        wait = delay + (pc1 - pc2) / 1e9
        time_average = 0.9 * time_average + 0.1 * (pc2 - pc1)
        print(
            f"broadcast = {(pc2 - pc1) / 1e6:.1f}ms / "
            f"average = {time_average / 1e6:.1f}ms, "
            f"wait = {wait * 1e3:.1f}ms"
        )
        await asyncio.sleep(wait)
    print()  # NOTE(review): unreachable — the loop above never exits


async def main(method, size, delay):
    """Serve on localhost:8765 and broadcast forever."""
    async with serve(
        functools.partial(handler, method=method),
        "localhost",
        8765,
        compression=None,
        ping_timeout=None,
    ):
        await broadcast_messages(method, size, delay)


if __name__ == "__main__":
    try:
        method = sys.argv[1]
        assert method in ["default", "naive", "task", "wait", "queue", "pubsub"]
        size = int(sys.argv[2])
        delay = float(sys.argv[3])
    except Exception as exc:
        # NOTE(review): angle-bracket placeholders (e.g. <method> <size>
        # <delay>) appear stripped from these strings by extraction —
        # verify against the upstream file before relying on them.
        print(f"Usage: {sys.argv[0]} method size delay")
        print(" Start a server broadcasting messages with e.g. naive")
        print(" Send a payload of bytes every seconds")
        print()
        print(exc)
    else:
        asyncio.run(main(method, size, delay))
# experiments/compression/benchmark.py (tail)

def main(corpus):
    """Benchmark compression over every file in the ``corpus`` directory."""
    data = [file.read_bytes() for file in corpus.iterdir()]
    benchmark(data)


if __name__ == "__main__":
    if len(sys.argv) < 2:
        print(f"Usage: {sys.argv[0]} [directory]")
        sys.exit(2)
    main(pathlib.Path(sys.argv[1]))


# experiments/compression/client.py — measure per-connection memory usage
# of permessage-deflate on the client side.

#!/usr/bin/env python

import asyncio
import statistics
import tracemalloc

from websockets.asyncio.client import connect
from websockets.extensions.permessage_deflate import ClientPerMessageDeflateFactory

CLIENTS = 20           # number of measured connections
INTERVAL = 1 / 10  # seconds

WB, ML = 12, 5         # window bits / memory level under test

MEM_SIZE = []          # traced memory per connection, in bytes


async def client(num):
    """Open one connection, echo two messages, and record traced memory."""
    # Space out connections to make them sequential.
    await asyncio.sleep(num * INTERVAL)
    tracemalloc.start()
    async with connect(
        "ws://localhost:8765",
        extensions=[
            ClientPerMessageDeflateFactory(
                server_max_window_bits=WB,
                client_max_window_bits=WB,
                compress_settings={"memLevel": ML},
            )
        ],
    ) as ws:
        await ws.send("hello")
        await ws.recv()
        await ws.send(b"hello")
        await ws.recv()
        MEM_SIZE.append(tracemalloc.get_traced_memory()[0])
        tracemalloc.stop()
        # Hold connection open until the end of the test.
        await asyncio.sleep((CLIENTS + 1 - num) * INTERVAL)


async def clients():
    # Start one more client than necessary because we will ignore
    # non-representative results from the first connection.
    await asyncio.gather(*[client(num) for num in range(CLIENTS + 1)])


asyncio.run(clients())

# First connection incurs non-representative setup costs.
del MEM_SIZE[0]

print(f"µ = {statistics.mean(MEM_SIZE) / 1024:.1f} KiB")
print(f"σ = {statistics.stdev(MEM_SIZE) / 1024:.1f} KiB")


# experiments/compression/corpus.py — download a corpus of GitHub commit
# payloads to benchmark compression against.

#!/usr/bin/env python

import getpass
import json
import pathlib
import subprocess
import sys
import time


def github_commits():
    """Fetch every commit object reachable from origin/main via the GitHub API."""
    OAUTH_TOKEN = getpass.getpass("OAuth Token? ")
    # Template command; ":sha" is substituted per commit below.
    COMMIT_API = (
        f'curl -H "Authorization: token {OAUTH_TOKEN}" '
        f"https://api.github.com/repos/python-websockets/websockets/git/commits/:sha"
    )

    commits = []

    head = subprocess.check_output(
        "git rev-parse origin/main",
        shell=True,
        text=True,
    ).strip()
    # Breadth-first walk of the commit graph starting at HEAD.
    todo = [head]
    seen = set()

    while todo:
        sha = todo.pop(0)
        commit = subprocess.check_output(COMMIT_API.replace(":sha", sha), shell=True)
        commits.append(commit)
        seen.add(sha)
        for parent in json.loads(commit)["parents"]:
            sha = parent["sha"]
            if sha not in seen and sha not in todo:
                todo.append(sha)
        time.sleep(1)  # rate throttling

    return commits


def main(corpus):
    """Write commits as numbered JSON files, oldest first, into ``corpus``."""
    data = github_commits()
    for num, content in enumerate(reversed(data)):
        (corpus / f"{num:04d}.json").write_bytes(content)


if __name__ == "__main__":
    if len(sys.argv) < 2:
        # NOTE(review): the directory placeholder (e.g. <corpus>) appears
        # stripped from this string by extraction — verify upstream.
        print(f"Usage: {sys.argv[0]} ")
        sys.exit(2)
    main(pathlib.Path(sys.argv[1]))
# experiments/json_log_formatter.py

import datetime
import json
import logging


class JSONFormatter(logging.Formatter):
    """
    Render logs as JSON.

    To add details to a log record, store them in a ``event_data``
    custom attribute. This dict is merged into the event.

    """

    def __init__(self):
        pass  # override logging.Formatter constructor

    def format(self, record):
        """Serialize ``record`` as a single-line JSON object.

        The event always carries ``timestamp``, ``message``, ``level``,
        and ``logger``; any ``event_data`` dict on the record is merged
        in, and exception/stack info is rendered when present.
        """
        event = {
            "timestamp": self.getTimestamp(record.created),
            "message": record.getMessage(),
            "level": record.levelname,
            "logger": record.name,
        }
        event_data = getattr(record, "event_data", None)
        if event_data:
            event.update(event_data)
        if record.exc_info:
            event["exc_info"] = self.formatException(record.exc_info)
        if record.stack_info:
            event["stack_info"] = self.formatStack(record.stack_info)
        return json.dumps(event)

    def getTimestamp(self, created):
        """Return the record's creation time as a naive UTC ISO 8601 string."""
        # datetime.utcfromtimestamp() is deprecated since Python 3.12.
        # Build an aware UTC datetime instead, then drop tzinfo to keep
        # the original output format (no "+00:00" suffix).
        return (
            datetime.datetime.fromtimestamp(created, datetime.timezone.utc)
            .replace(tzinfo=None)
            .isoformat()
        )
text = subprocess.check_output(["git", "log", "8dd8e410"], text=True) def get_frame(size): repeat, remainder = divmod(size, 256 * 1024) payload = repeat * text + text[:remainder] return Frame(Opcode.TEXT, payload.encode(), True) def parse_frame(data, count, mask, extensions): reader = StreamReader() for _ in range(count): reader.feed_data(data) parser = Frame.parse( reader.read_exact, mask=mask, extensions=extensions, ) try: next(parser) except StopIteration: pass else: raise AssertionError("parser should return frame") reader.feed_eof() assert reader.at_eof(), "parser should consume all data" def run_benchmark(size, count, compression=False, number=100): if compression: extensions = [PerMessageDeflate(True, True, 12, 12, {"memLevel": 5})] else: extensions = [] globals = { "get_frame": get_frame, "parse_frame": parse_frame, "extensions": extensions, } sppf = ( min( timeit.repeat( f"parse_frame(data, {count}, mask=True, extensions=extensions)", f"data = get_frame({size})" f".serialize(mask=True, extensions=extensions)", number=number, globals=globals, ) ) / number / count * 1_000_000 ) cppf = ( min( timeit.repeat( f"parse_frame(data, {count}, mask=False, extensions=extensions)", f"data = get_frame({size})" f".serialize(mask=False, extensions=extensions)", number=number, globals=globals, ) ) / number / count * 1_000_000 ) print(f"{size}\t{compression}\t{sppf:.2f}\t{cppf:.2f}") if __name__ == "__main__": print("Sizes are in bytes. Times are in µs per frame.", file=sys.stderr) print("Run `tabs -16` for clean output. 
"""Benchmark parsing WebSocket handshake requests."""

# The parser for responses is designed similarly and should perform similarly.

import sys
import timeit

from websockets.http11 import Request
from websockets.streams import StreamReader

# Captured handshake requests from real clients, used as fixtures.

CHROME_HANDSHAKE = (
    b"GET / HTTP/1.1\r\n"
    b"Host: localhost:5678\r\n"
    b"Connection: Upgrade\r\n"
    b"Pragma: no-cache\r\n"
    b"Cache-Control: no-cache\r\n"
    b"User-Agent: Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) "
    b"AppleWebKit/537.36 (KHTML, like Gecko) Chrome/111.0.0.0 Safari/537.36\r\n"
    b"Upgrade: websocket\r\n"
    b"Origin: null\r\n"
    b"Sec-WebSocket-Version: 13\r\n"
    b"Accept-Encoding: gzip, deflate, br\r\n"
    b"Accept-Language: en-GB,en;q=0.9,en-US;q=0.8,fr;q=0.7\r\n"
    b"Sec-WebSocket-Key: ebkySAl+8+e6l5pRKTMkyQ==\r\n"
    b"Sec-WebSocket-Extensions: permessage-deflate; client_max_window_bits\r\n"
    b"\r\n"
)

FIREFOX_HANDSHAKE = (
    b"GET / HTTP/1.1\r\n"
    b"Host: localhost:5678\r\n"
    b"User-Agent: Mozilla/5.0 (Macintosh; Intel Mac OS X 10.15; rv:109.0) "
    b"Gecko/20100101 Firefox/111.0\r\n"
    b"Accept: */*\r\n"
    b"Accept-Language: en-US,en;q=0.7,fr-FR;q=0.3\r\n"
    b"Accept-Encoding: gzip, deflate, br\r\n"
    b"Sec-WebSocket-Version: 13\r\n"
    b"Origin: null\r\n"
    b"Sec-WebSocket-Extensions: permessage-deflate\r\n"
    b"Sec-WebSocket-Key: 1PuS+hnb+0AXsL7z2hNAhw==\r\n"
    b"Connection: keep-alive, Upgrade\r\n"
    b"Sec-Fetch-Dest: websocket\r\n"
    b"Sec-Fetch-Mode: websocket\r\n"
    b"Sec-Fetch-Site: cross-site\r\n"
    b"Pragma: no-cache\r\n"
    b"Cache-Control: no-cache\r\n"
    b"Upgrade: websocket\r\n"
    b"\r\n"
)

WEBSOCKETS_HANDSHAKE = (
    b"GET / HTTP/1.1\r\n"
    b"Host: localhost:8765\r\n"
    b"Upgrade: websocket\r\n"
    b"Connection: Upgrade\r\n"
    b"Sec-WebSocket-Key: 9c55e0/siQ6tJPCs/QR8ZA==\r\n"
    b"Sec-WebSocket-Version: 13\r\n"
    b"Sec-WebSocket-Extensions: permessage-deflate; client_max_window_bits\r\n"
    b"User-Agent: Python/3.11 websockets/11.0\r\n"
    b"\r\n"
)


def parse_handshake(handshake):
    """Parse one serialized handshake request with the incremental parser.

    The whole request is already buffered, so the generator must finish
    in a single step.
    """
    reader = StreamReader()
    reader.feed_data(handshake)
    parser = Request.parse(
        reader.read_line,
    )
    try:
        next(parser)
    except StopIteration:
        pass
    else:
        raise AssertionError("parser should return request")
    reader.feed_eof()
    assert reader.at_eof(), "parser should consume all data"


def run_benchmark(name, handshake, number=10000):
    """Time ``parse_handshake`` and print a TSV row (name, size, µs)."""
    ph = (
        min(
            timeit.repeat(
                "parse_handshake(handshake)",
                number=number,
                globals={"parse_handshake": parse_handshake, "handshake": handshake},
            )
        )
        / number
        * 1_000_000
    )
    print(f"{name}\t{len(handshake)}\t{ph:.1f}")


if __name__ == "__main__":
    print("Sizes are in bytes. Times are in µs per frame.", file=sys.stderr)
    # NOTE(review): unlike the surrounding lines, this hint goes to stdout
    # and therefore ends up in the TSV — possibly unintended; confirm.
    print("Run `tabs -16` for clean output. Pipe stdout to TSV for saving.")
    print(file=sys.stderr)
    print("client\tsize\ttime")
    run_benchmark("Chrome", CHROME_HANDSHAKE)
    run_benchmark("Firefox", FIREFOX_HANDSHAKE)
    run_benchmark("websockets", WEBSOCKETS_HANDSHAKE)
""" import collections import os import timeit # Implementations class ByteArrayStreamReader: def __init__(self): self.buffer = bytearray() self.eof = False def readline(self): n = 0 # number of bytes to read p = 0 # number of bytes without a newline while True: n = self.buffer.find(b"\n", p) + 1 if n > 0: break p = len(self.buffer) yield r = self.buffer[:n] del self.buffer[:n] return r def readexactly(self, n): assert n >= 0 while len(self.buffer) < n: yield r = self.buffer[:n] del self.buffer[:n] return r def feed_data(self, data): self.buffer += data def feed_eof(self): self.eof = True def at_eof(self): return self.eof and not self.buffer class BytesDequeStreamReader: def __init__(self): self.buffer = collections.deque() self.eof = False def readline(self): b = [] while True: # Read next chunk while True: try: c = self.buffer.popleft() except IndexError: yield else: break # Handle chunk n = c.find(b"\n") + 1 if n == len(c): # Read exactly enough data b.append(c) break elif n > 0: # Read too much data b.append(c[:n]) self.buffer.appendleft(c[n:]) break else: # n == 0 # Need to read more data b.append(c) return b"".join(b) def readexactly(self, n): if n == 0: return b"" b = [] while True: # Read next chunk while True: try: c = self.buffer.popleft() except IndexError: yield else: break # Handle chunk n -= len(c) if n == 0: # Read exactly enough data b.append(c) break elif n < 0: # Read too much data b.append(c[:n]) self.buffer.appendleft(c[n:]) break else: # n >= 0 # Need to read more data b.append(c) return b"".join(b) def feed_data(self, data): self.buffer.append(data) def feed_eof(self): self.eof = True def at_eof(self): return self.eof and not self.buffer # Tests class Protocol: def __init__(self, StreamReader): self.reader = StreamReader() self.events = [] # Start parser coroutine self.parser = self.run_parser() next(self.parser) def run_parser(self): while True: frame = yield from self.reader.readexactly(2) self.events.append(frame) frame = yield from 
self.reader.readline() self.events.append(frame) def data_received(self, data): self.reader.feed_data(data) next(self.parser) # run parser until more data is needed events, self.events = self.events, [] return events def run_test(StreamReader): proto = Protocol(StreamReader) actual = proto.data_received(b"a") expected = [] assert actual == expected, f"{actual} != {expected}" actual = proto.data_received(b"b") expected = [b"ab"] assert actual == expected, f"{actual} != {expected}" actual = proto.data_received(b"c") expected = [] assert actual == expected, f"{actual} != {expected}" actual = proto.data_received(b"\n") expected = [b"c\n"] assert actual == expected, f"{actual} != {expected}" actual = proto.data_received(b"efghi\njklmn") expected = [b"ef", b"ghi\n", b"jk"] assert actual == expected, f"{actual} != {expected}" # Benchmarks def get_frame_packets(size, packet_size=None): if size < 126: frame = bytes([138, size]) elif size < 65536: frame = bytes([138, 126]) + bytes(divmod(size, 256)) else: size1, size2 = divmod(size, 65536) frame = ( bytes([138, 127]) + bytes(divmod(size1, 256)) + bytes(divmod(size2, 256)) ) frame += os.urandom(size) if packet_size is None: return [frame] else: packets = [] while frame: packets.append(frame[:packet_size]) frame = frame[packet_size:] return packets def benchmark_stream(StreamReader, packets, size, count): reader = StreamReader() for _ in range(count): for packet in packets: reader.feed_data(packet) yield from reader.readexactly(2) if size >= 65536: yield from reader.readexactly(4) elif size >= 126: yield from reader.readexactly(2) yield from reader.readexactly(size) reader.feed_eof() assert reader.at_eof() def benchmark_burst(StreamReader, packets, size, count): reader = StreamReader() for _ in range(count): for packet in packets: reader.feed_data(packet) reader.feed_eof() for _ in range(count): yield from reader.readexactly(2) if size >= 65536: yield from reader.readexactly(4) elif size >= 126: yield from 
reader.readexactly(2) yield from reader.readexactly(size) assert reader.at_eof() def run_benchmark(size, count, packet_size=None, number=1000): stmt = f"list(benchmark(StreamReader, packets, {size}, {count}))" setup = f"packets = get_frame_packets({size}, {packet_size})" context = globals() context["StreamReader"] = context["ByteArrayStreamReader"] context["benchmark"] = context["benchmark_stream"] bas = min(timeit.repeat(stmt, setup, number=number, globals=context)) context["benchmark"] = context["benchmark_burst"] bab = min(timeit.repeat(stmt, setup, number=number, globals=context)) context["StreamReader"] = context["BytesDequeStreamReader"] context["benchmark"] = context["benchmark_stream"] bds = min(timeit.repeat(stmt, setup, number=number, globals=context)) context["benchmark"] = context["benchmark_burst"] bdb = min(timeit.repeat(stmt, setup, number=number, globals=context)) print( f"Frame size = {size} bytes, " f"frame count = {count}, " f"packet size = {packet_size}" ) print(f"* ByteArrayStreamReader (stream): {bas / number * 1_000_000:.1f}µs") print( f"* BytesDequeStreamReader (stream): " f"{bds / number * 1_000_000:.1f}µs ({(bds / bas - 1) * 100:+.1f}%)" ) print(f"* ByteArrayStreamReader (burst): {bab / number * 1_000_000:.1f}µs") print( f"* BytesDequeStreamReader (burst): " f"{bdb / number * 1_000_000:.1f}µs ({(bdb / bab - 1) * 100:+.1f}%)" ) print() if __name__ == "__main__": run_test(ByteArrayStreamReader) run_test(BytesDequeStreamReader) run_benchmark(size=8, count=1000) run_benchmark(size=60, count=1000) run_benchmark(size=500, count=500) run_benchmark(size=4_000, count=200) run_benchmark(size=30_000, count=100) run_benchmark(size=250_000, count=50) run_benchmark(size=2_000_000, count=20) run_benchmark(size=4_000, count=200, packet_size=1024) run_benchmark(size=30_000, count=100, packet_size=1024) run_benchmark(size=250_000, count=50, packet_size=1024) run_benchmark(size=2_000_000, count=20, packet_size=1024) run_benchmark(size=30_000, count=100, 
#!/usr/bin/env python

"""
Profile the permessage-deflate extension.

Usage::

    $ pip install line_profiler
    $ python experiments/compression/corpus.py experiments/compression/corpus
    $ PYTHONPATH=src python -m kernprof \
        --line-by-line \
        --prof-mod src/websockets/extensions/permessage_deflate.py \
        --view \
        experiments/profiling/compression.py experiments/compression/corpus 12 5 6

"""

import pathlib
import sys

from websockets.extensions.permessage_deflate import PerMessageDeflate
from websockets.frames import OP_TEXT, Frame


def compress_and_decompress(corpus, max_window_bits, memory_level, level):
    """Round-trip every corpus item through one permessage-deflate context.

    A single extension instance is reused so the compression context
    carries over between messages (no context takeover disabled).
    """
    extension = PerMessageDeflate(
        remote_no_context_takeover=False,
        local_no_context_takeover=False,
        remote_max_window_bits=max_window_bits,
        local_max_window_bits=max_window_bits,
        compress_settings={"memLevel": memory_level, "level": level},
    )
    for data in corpus:
        frame = Frame(OP_TEXT, data)
        frame = extension.encode(frame)
        frame = extension.decode(frame)


if __name__ == "__main__":
    if len(sys.argv) < 2 or not pathlib.Path(sys.argv[1]).is_dir():
        print(
            f"Usage: {sys.argv[0]} <corpus> "
            f"[<max_window_bits>] [<memory_level>] [<level>]"
        )
        # Bug fix: exit after printing usage instead of falling through
        # and crashing on sys.argv[1]; matches the sibling scripts
        # (corpus.py, benchmark.py) which call sys.exit(2) here.
        sys.exit(2)
    corpus = [file.read_bytes() for file in pathlib.Path(sys.argv[1]).iterdir()]
    max_window_bits = int(sys.argv[2]) if len(sys.argv) > 2 else 12
    memory_level = int(sys.argv[3]) if len(sys.argv) > 3 else 5
    level = int(sys.argv[4]) if len(sys.argv) > 4 else 6
    compress_and_decompress(corpus, max_window_bits, memory_level, level)
alarm_after += datetime.timedelta( seconds=1, microseconds=-alarm_after.microseconds, ) await websocket.send(format_timedelta(alarm_after)) except ConnectionClosed: return class ZoneInfoConverter(BaseConverter): regex = r"[A-Za-z0-9_/+-]+" def to_python(self, value): try: return zoneinfo.ZoneInfo(value) except zoneinfo.ZoneInfoNotFoundError: raise ValidationError def to_url(self, value): return value.key class DateTimeConverter(BaseConverter): regex = r"[0-9]{4}-[0-9]{2}-[0-9]{2}T[0-9]{2}:[0-9]{2}:[0-9]{2}(?:\.[0-9]{3})?" def to_python(self, value): try: return datetime.datetime.fromisoformat(value) except ValueError: raise ValidationError def to_url(self, value): return value.isoformat() class TimeDeltaConverter(BaseConverter): regex = r"[0-9]{2}:[0-9]{2}:[0-9]{2}(?:\.[0-9]{3}(?:[0-9]{3})?)?" def to_python(self, value): return datetime.timedelta( hours=int(value[0:2]), minutes=int(value[3:5]), seconds=int(value[6:8]), milliseconds=int(value[9:12]) if len(value) == 12 else 0, microseconds=int(value[9:15]) if len(value) == 15 else 0, ) def to_url(self, value): return format_timedelta(value) def format_timedelta(delta): assert 0 <= delta.seconds < 86400 hours = delta.seconds // 3600 minutes = (delta.seconds % 3600) // 60 seconds = delta.seconds % 60 if delta.microseconds: return f"{hours:02d}:{minutes:02d}:{seconds:02d}.{delta.microseconds:06d}" else: return f"{hours:02d}:{minutes:02d}:{seconds:02d}" url_map = Map( [ Rule( "/", redirect_to="/clock", ), Rule( "/clock", defaults={"tzinfo": datetime.timezone.utc}, endpoint=clock, ), Rule( "/clock/", endpoint=clock, ), Rule( "/alarm//", endpoint=alarm, ), Rule( "/timer/", endpoint=timer, ), ], converters={ "tzinfo": ZoneInfoConverter, "datetime": DateTimeConverter, "timedelta": TimeDeltaConverter, }, ) async def main(): async with route(url_map, "localhost", 8888) as server: await server.serve_forever() if __name__ == "__main__": asyncio.run(main()) 
websockets-15.0.1/fuzzing/000077500000000000000000000000001476212450300154475ustar00rootroot00000000000000websockets-15.0.1/fuzzing/fuzz_http11_request_parser.py000066400000000000000000000016611476212450300233500ustar00rootroot00000000000000import sys import atheris with atheris.instrument_imports(): from websockets.exceptions import SecurityError from websockets.http11 import Request from websockets.streams import StreamReader def test_one_input(data): reader = StreamReader() reader.feed_data(data) reader.feed_eof() parser = Request.parse( reader.read_line, ) try: next(parser) except StopIteration as exc: assert isinstance(exc.value, Request) return # input accepted except ( EOFError, # connection is closed without a full HTTP request SecurityError, # request exceeds a security limit ValueError, # request isn't well formatted ): return # input rejected with a documented exception raise RuntimeError("parsing didn't complete") def main(): atheris.Setup(sys.argv, test_one_input) atheris.Fuzz() if __name__ == "__main__": main() websockets-15.0.1/fuzzing/fuzz_http11_response_parser.py000066400000000000000000000020431476212450300235110ustar00rootroot00000000000000import sys import atheris with atheris.instrument_imports(): from websockets.exceptions import SecurityError from websockets.http11 import Response from websockets.streams import StreamReader def test_one_input(data): reader = StreamReader() reader.feed_data(data) reader.feed_eof() parser = Response.parse( reader.read_line, reader.read_exact, reader.read_to_eof, ) try: next(parser) except StopIteration as exc: assert isinstance(exc.value, Response) return # input accepted except ( EOFError, # connection is closed without a full HTTP response SecurityError, # response exceeds a security limit LookupError, # response isn't well formatted ValueError, # response isn't well formatted ): return # input rejected with a documented exception raise RuntimeError("parsing didn't complete") def main(): 
atheris.Setup(sys.argv, test_one_input) atheris.Fuzz() if __name__ == "__main__": main() websockets-15.0.1/fuzzing/fuzz_websocket_parser.py000066400000000000000000000024411476212450300224420ustar00rootroot00000000000000import sys import atheris with atheris.instrument_imports(): from websockets.exceptions import PayloadTooBig, ProtocolError from websockets.frames import Frame from websockets.streams import StreamReader def test_one_input(data): fdp = atheris.FuzzedDataProvider(data) mask = fdp.ConsumeBool() max_size_enabled = fdp.ConsumeBool() max_size = fdp.ConsumeInt(4) payload = fdp.ConsumeBytes(atheris.ALL_REMAINING) reader = StreamReader() reader.feed_data(payload) reader.feed_eof() parser = Frame.parse( reader.read_exact, mask=mask, max_size=max_size if max_size_enabled else None, ) try: next(parser) except StopIteration as exc: assert isinstance(exc.value, Frame) return # input accepted except ( EOFError, # connection is closed without a full WebSocket frame UnicodeDecodeError, # frame contains invalid UTF-8 PayloadTooBig, # frame's payload size exceeds ``max_size`` ProtocolError, # frame contains incorrect values ): return # input rejected with a documented exception raise RuntimeError("parsing didn't complete") def main(): atheris.Setup(sys.argv, test_one_input) atheris.Fuzz() if __name__ == "__main__": main() websockets-15.0.1/logo/000077500000000000000000000000001476212450300147135ustar00rootroot00000000000000websockets-15.0.1/logo/favicon.ico000066400000000000000000000124661476212450300170450ustar00rootroot00000000000000 h&  (  >-=v=`@BU=;;;B#JVEA=<;<RWNIEE%>;;;=*\VUQMN$KcEA=D`q9 ^]ZU$RdMIFgjj1i0k1Cfb_#[eUQNRt8wk2i0k19dJ]YUR|>sv:r7k3k2uj50fda]^Gk~@z>w<Ҧs@l2qzEh;pēNICBѳ@xɪwDʖSŒPMJCŻD͗UBƔPyÏMY̙f( @ 
<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<UUUUUUUUUB<<<;<>>>>>>F3A?>;;;;;;<<<<<<<<<<<<<<<======J4FDB?><;;;;;>->>>>>>>>>>>>><<<<<M5JGECB@=<;;;;==============UUUUR5NKIGECBBA'<;;;;;<;;;;xt9s8q6m3l3k1k19k1k1k1s5s5s5s5s5s5s5s5s5s5s5s5ÐOMJHGG^GG}?z=yo5o5o5o5u;u;u;u;u;u;u;u;u;u;u;u;ƑOÐOLJIG2GCB|?z=v:u9t7s5>s5s5s5s5s5}?}?}?}?}?}?}?}?}?}?}?}?ʕTtőOÐNMII귆GʵDC~@|?z=x GitHub social preview

Take a screenshot of this DOM node to make a PNG.

For 2x DPI screens.

preview @ 2x

For regular screens.

preview

websockets-15.0.1/logo/github-social-preview.png000066400000000000000000000672721476212450300216500ustar00rootroot00000000000000PNG  IHDRLBnIDATx /}W@@@@@@@@@@`````````@@@@@@@@@````` T?9VoֽH RG -n7@p$@3+Н3#0oϽs-U6v\Ӵr͎uG˖17+,<兯9ƳS]tϥ KVxYnִw(!ͻOU,=p\xnVsʟgWS~Oexڟ})0yjik ?qa^Սy Jk[;Uh@==תZVJ/崩:W~,rIK?y!!湲:(+z^?WGW `bg|݅:%(;rҚ/>k;g_D=?l?xNMsM 0h>pvvfk}T_~ϞY)gԷv*( ꣷG?šs/q=goC5WT_S}rė,8# 2ַ)(0J[wE_Ց}9~9 O|U@~'o6~rM^owegϦ։wܡM5@4X^S}<8s?rpN6m`[<y؇A{Cwu(7.ٳj!]9<ۯ|_2ve˕ 0nCל: gO_p멛 0iϙ+~X}uۯɪ+ 0}ݻ*p; ͝DUc @8S2CwccެUP_]=닩we$shiOo]} {W~M>_>Z֩FnAޕ:b~M״9Hie @X(!]+d37j`;t]H߻r8;~۟>PT V{Wv_{WG$]~߯Sؗ)t`r5,{wp5$yt`5wH]c8{WYV{W^@e3KtuciBБ{W_S}П)뎶uv+{ 0H=_ewCGm&|3yݑ.e@͝;;gϮ?tdbn]9T_L~u5V[ 0zkwޕԇ,>{P}ul(:~{W{od/V (^]9~%]='sqЁ`s{WV_S}Mt 0(hὫ0ػ#pkg:1+C`˲GU[ݾЁ 08u!r]W~ut P]9~]e ߻r:˗L.x0@4wUZtw`ɧ;]K 0pMa}\ z+bxT_޹(ZUG{WGٻޕsik_>βӁ d\ Ͻ~߻r8n:(#ޕWUND] h; o{YNK{Dy ٻ pw|5%t` b6w8Jۻr>߻Jݻ8>C=Ѷ7gс z 0ZU{W]Տbzi:~řt` 5B{W_{W3]8W_>с 0+jwL5:0D['GU+ï:B*`qng{W·_{W:[ןЁ j 0x}Ù@={ejotٟWf4Ӂ 0hlgï+ï+ïiO« ^ٻzv*_6:0DnKsKhB]Y~m{.H@`pnkwwh98~u8:0Db͝Б{W?tdb]>g?ƥҁ 08X\mqݻ {W_뽫?X]ࡣ˃}"Gow%侱ԍ>6k:rg{&S}݁[; 2 0R,f1'wtu(h\x-,{WZ]9~uS_Mj@`X EawYEՍmj\*uhۻ ~}:0DHw5yC~MS ނ_{W_~MFK `pU :xˉn5,ꞈtcj{W_(xim0{WG]}봲sy\xjɨylb@`NޕN ]=ZMSˋrt17Lձ8΋sr,`iNɀïܽ]vn[('ާ{@`0;83W6(W4^+ï{Wkb&p+Eػr>Z];+]kn8~cdS5/N(wʴ߿ʼ]ۢFFm֟=w|P}MҮaQ3{W<ػrh$tjޕWgM_5Yⷻ@x`S^]~ח768;l⇷wugϾ/ǯ~av8t`ѫ q_}M*[h0> ǽ+ï}Yˋwс@z-sgj~ߙ:=*t~eav]h:V=?k'o# qj`Mz>B;l]ze]9~-it~1n//Uܦ`hq8޻2ן?LuARBgebV>?s(]ïn}y>.G:.˻M]9ߕ?$Ɨ?Y w~ufI~9.T_{W~u.D`<~~e]Ou7uKB/ +ïc]ϳ3Ӂ@VW+i(ͫޕ7{WۯYW[ 7֜Б{Wև_~%5.NCKǨQwշЁ@T1a{ן)Oշ;fI{W6k·  ,\~ޕS'e)}գ{W:#y k`3]b*ï):;6]#w<.4x^ċ!|ȵϞM|dc=$;+qBzS}-,@\ٽo]޻xQJek-ҥ>tW[<@ɧoޕsU_?%JZokAՃ]9W_ WB ٻ7_&sVꥲ:8Nw~9(fMKLqkگ?ZTV+4?u>{^/pJwv߻ |q}FwFw$߁ &o믾?tQcwQw% `pjX{W]$~o.ޒс~e] 0ob*ͭ+S}گ/gM22{Wk]@WޕWW__'Q-Kٻr+M{W·_~}ˮ?GH~e~JϞ' ܻ  `TX꽫o߻r<ꫳ7}{^с%Gޕ `pT_{W~WS_vr:+sҒv{WB䏟:prlWޕ `,{Wɶ{WڗOvwTJ'J{Wy-Mu@yjwO;N!w&ۻgڡ\`0A]9~%] T_??| ֎Wb] =u@+r#jT_۟Oqjo%Y]>ZdC]@E2,;G|":ps[yI]@cvi*U W}u\p\yc̼$ï+ؔ5U-E@ GJd]usJ_zVRn]- 2K(}T_@_ 
}d>&;+9trP@\]<ȇ[t>ۏZ.״_yyڪ\`p]ydLuk&w$tW8{](-'>tdwt5t[B3/]\ޕ=wP輥ݪ.AծQKUUX5RM9@_Qժ_뽫߸wS5Ix2#i#[=]z]G~:BCR Sug3))fj>zZl= ޕCG{W߄dt~MnVAܩi5G'q_P^7t \}M8UNv]{W_S}iɟ$t1{]Tت뛄N(P|UtuwQ埩,%Ք|+ï+kگδM)xϞ?Zӫ.+kwH 09WU|2~cs~ISZ{W·חF)xork[gWeY{ @OOkrgj9$ih]%޻ |5חx_tQsxQqYM'Rn 0TW_N;~dG^G{W;Cwe&SMxsԜݞi+2酖?ferPh>A5^bWޕW'WW__l0<6փ;n zuj qMe (v7&o 6#5J׹hC6([jK-R ~x\}MLS ;J<޻8t{WwurIB9_2PVV/W_tnOVeJ6Eޕ״_uMYszTuUi%Km~I.8zӾV_iL~M&Ձw)R.9}qT(a:b״`ɭ7U$!r,:~oڬSw1{n?`pZ<_}uT;p %swix{W~LOfaI4t`E%٧Jl~*h@RݝYq*:+uvL^wb+o*dͯ&}K%| #Wg$T.ِ}lwV (Tņi_kR6^s(ػ ~}WgB2N^Ё_.q ?ݤ\d_zX(4WRun&e$t֧߻>+S} &}o wή(kMmYn3Δ uwj(:,Fux߲ޭ7{W;pdO/7au܁?]-I{^mo_VrJ~ wݨW7xӢϿdx/܊Q%k_{Wn~Mq+#8kgGWqKR-@cӾU_[kSFޕ_}M "[3԰H=i9Z`PO7r:|My-voဇBw{W <'ƟJ7ܿ&G##fHӾ6WGTK:_Ϟqo]Vp%ROil螹+蹻U`@zOWgj"l?tdb]}?ܗ_Eejwj'ʹn 9gU`i"&ޖp^z&G-d~Mٻ]-tO683uhfpoQ 0`:n<+MXymiYF_+s{M7c߁gИ)>3y]]V^[z&+ï}{Wo-H \Z^j׊{׾|]xyƹ{Ϟ+ܮEJE)~  )i_GH^zFޕ~JGW_+^w&oVP'}׾oVdUB] A5yBF.p>R"(t`wZ}uT}M$tം 8zTM@,Nh|+y{a y׃~GHp8K+`8W:/cVDkՙ됡V_:>S@vЁ?k/ՐM*(( QOj=UUVU/UUUY궩}Hu+*CyQ؁sN((˺77p7oU. xWF5ןd:p, ۣZN_[o:T_!U)To.OT_4 BGU~`ʶ.B*ʽk2VTzZ Q񝜧}ET.;pWT@tEġ8V_}\-wo#i_U}M:Y\l(g.xʶ>/ 3333; e0333örNq쬛ݣgy_v|֟M%}]^_\.Hsڷo8SGp=@dʘwh|R,ؽ8Ӿ7~5}uZWpPb `@[vϩ#^髻@@^8Ӿ;Cg>[ u@H׭cz-4p }/ \mzg:$3;}*J_ݵM{ik.i7zR}п_2oBkIf6wK.I_7p6aP;}|V:}u2m4NFu~t;p 0TO_ 洯n?8 07H]-q&caP}K-$ xPti+n|X{Iot۩eߢϪ{şnFށ<[O_ށ&֢g>+j Ww*IiW矾n_ ԇ}C|`@7O_8s?},$mzZ|`yVWW*}uWcxA{z%黭8 {'*k=}#Wg<둾_u.¨eoRxE i_5-WPfoH Bѧ}N_= Ox^*s$wU5o PK=7Ӿ+#},V:둾~:2}c7Wwo`С78k$}u6> i_ukG~ZϷ> kM:}uInNo3AH߈Fp3J z@R #}uWwJN_ՕLr@ot⃉-|^99 M<}=O[7~zӾ1Je0z=둾;Wo2VOJ߈}k燙U|`@DyՕJ_[o`i_59듾n7hy`ȇ7Q']'Pɦy՝y4( 9aM`oY~Q#Vu]@@ui_ukIr6 ié_|ܩ4WӾ+k;liVOޛ|ypfa֛:둾8=;Ӿvk{9 >+Pe$p@Fׯ,]m:/W>L׿`kC;듾ni`5pZa.t @2:둾n;δ*3iͫ}N{b1pjdh}=;,ҷ^] Hҳ"i__y}WL\Eϙ*ϒ 4Wg2}46}ڗu~wo)Ͼ-\zupҍ7,tь OXwUGXu$@ vGm3}\A4p@}}Ww!w`m N-'IV 6pX}#mץV `{V 1h@Nz) lzyu& UfvƲ^g@P۪ͶYZטM;@$o[\@nJۿmX_@ho黹S4^@hҷ zWט e վiouWpUh69vJl;N2k`0@ ܼ͖vMs~@6Hl;k mvf+4pQV 5ٚ4pa6`T0@#}Hu=r諔ꎬr(\mk~/hj >Wїg?  
\@vmNo^;ת֦mm m,}m1}7 \~T6g_ ` <| }I\G^"_Z׹1 5p mݬߚ뾸 2&bkcy @P~Įm$}Y聻v6 Njz7 x[yn}`* Qڛվob'vROӷ|W`vվ =uN,~k6p>MhҷGr-x\sՙ轶{֘|g.*8ޓվM;W}/Gu>`,t5x3:&YN?8`roub[xZ{o퍽6KY-W}F"@h=z$}-OxwOwAu_<?/@ҳ57}7FiodH]T}> @e1g֯? $wWl H\ e{U 0x@Z#ߋ ׎OE +B( 0~Hwo9 vz5 l@bg߉GbD {ap˩Y: F@me$ lϋ@[ @mO7&0lOE@m%:h`;3OF@m8h`:/0ԧy?; ^@[׹2 /Es@:t[07?h& Օk`u0ئ@ XrX;@ sby  lgH<v$h`;`4=`4M5 lj)0f>@ 3H9u~@zdgFrP]WnX 7 ՘|oj `z.?c?γN~VC>`sYWH'`u6Jlٱ:@j 땘Ht!}8R`XkE߱\0ثFuG/8`GW'ω0Du)Dt_Z2zqwCm1X;]t|7}:ʏ;szp18}K=1追{_04HF6nlśfw~!]`X;sD#zڑ̞m>(?եh=@@u9z/b3Y:b3ޔEsItǘr=Jb1xy\m 0`a  0` @  0` @@ 0+~3#ߋ'Z6L_œ3x;ɪ 1ݿf_<Qե'ܓ8>8T{V>3 @C_VW`#kkPṣS[u ئl_Ʈ.{E/|6#/ ǿ5<<}l~i% // `}{77~]m `ޔem=s30삍G|GAWOvޫ൧a;E`A@ A# `0 0 F@ ` a# 00``@O  `0`0 00`@`@# `0 0 Fq+1otE롎lZOo9n%2effe:Au1bfR ,xnZ竳x]FvӁFsY*w;Ve85MݫV}XL}k)~NMoP3TL OR[Զ\hi O%>JCGg3/25O_gh_@/ɂ-ienݼb6 r}K% O)w'XgbOQ F撤snziG:Ew^Ǘ۝# ~fǀ/̀swmw¢}8MҏkM1/~#5^U멥GůL,1>T$9ֶ=Ny/ԴʺSZgjy^ֆ axah N {=&.G;zDgԷ&A}4t:*@t^<=?V؅&қr:p] +C@S]' b]OgRLz`2IC9Ȥ:c:^&q0BN>S4Y@o Hð3hQL97z/ڭF" 'p"Z ph ʌ!NŤm}?%9yE?HNY?=YeT Lnc`lZ#>J+rK?6L3JWk}%F|3Gņ%v,﹘0 0ùH@n3mv|ɪ Jcg&\+ VW;d͜Wn%~՟ӃN4~9=# XYX0ـ?>&-Q^\/6+R|yK 0 0ð;,c?/)4|;bh6`H$Gavt㝴 99A}%9xgZ5`> S/Y"2: 0nkEtNz~zdG7@i_̜z3C,Ig_FP=uDw [/(2ٔZyQܻ2;e`afQ5E?QvLӄ1b4z6SglӺSZA}{h8/d gP~Bmh- -Bikk(9 QKNBo?6@?x\-طIJ58r_v^7_wl~%6U¦gZӊp,^[(gv7`o_+'Eԗumэ}a*_T'f,5&d_/q3 0 X4u| 3є;~ڿv@^'{5ٹ#a?MR&j|N[>& #amLK-͇"%xCʲD ;"mz)*zv` vµJnv vq(^_'J3JɵOj˯t~-0{aXYg'1Ƈh <+x_lSLcDdKȟY)2'۳>>(!=07I]xt2u kbR.N* YaH\a)7 ^7 e ߜZ*N+;0.`af^8z?"-O 7@D!"Lcg)8J٧Fm MN)HCSydԢ:c$Vߣ^7FV\;Q\1HbWb6N?v4&a\& mn}=$(~"o\La?I{zPdN٣`VT‰D̃ ]Uqa(c&W Y_|r# !u".~Ctbu0,,5!EMtAхM8Gv0BH#cfCs6zDI< -dj>8kp ]_ϗntn%8_y,/lq9w 0*- q3 0 Slۡ dh$:&o4G1QL7̻M;dMk_QlWdqU&D%lck+^6/}ңcdeu]/JM0aXYo5 8@۩A FTZ=VHv3uMhD } /r#߁œiݿI]Q~)*U]J6PP-^ͫt~Aue è`afYx# I.f쑋{a 0ðCn2J@N >Ӗ:B=/SGij}ҷA8 Dpo>BZW֮F`kYn[S}wvs1]$n2_~7f0ðFn M@"?aq K8yRKc4/2=YrdԷF6X}N߿NHpn#3 333:{0s2[I%Y>Wzݐԧv_iW HW(=?C֤R 'zq-Ai:OQL ǜ{G>#Ϛ<1aiƓչ]PjiC57R lh^SdK>چSri pFn>c` 
pBjb{'02{c)uY,%befls*߻!e6H@/o`Fg.wannZ{ߌX/BS*7x>pAEXsU9Co9_^eP̫\&gi*v#XQ?8r4*:/RacQ"frU4\:d1!ZdP:ocr`F @Ƌ$D׻aAX``NduN~&Z!g>v1v)#/:MIjXߕۆ:&hc֟ts`0;50:6Ь4KRM_~ؒ}hAC WϹ$vghbK{ r<Atڧ ]I 5EOQU#]A] v wv=ՍpF}]Pog?]?PPJ螊*A2 ŵpk ) Lo"{(nl O]&aRC5lJ"wq̈n^7 CoSv aJebcàWH ,""8)5?9uɟGJ)xOf҄ 0jDG DASF?pf`bBە33/C73w+M+Ou[A`I^?5/Fʏ o/xt 9q2cN;>|S9"w9|1V]afj__.>j~b z~+% $"`1""Zoًz4T;*'*y#0uTcx} 7 ϹlA`:DƏGsEn|i};=t0}?s_#Ħp|g`A``襦[6uĺvC?a"E;C oջtWKSoV9jHe bGV`3kZmhΎN2rKwF^;IDJDES/Q5~^JɆ7ܶtvb9p =!;um]WMEo5^uFW DI06xZ'gpu4 ?VjǕΏ<5Ʒ=t'}hzNҨ* IZkw2c4^pܳа kxcb:2_;z\{)i[H_ ^Yhc)y{: nQ0HE;.2> `n^4X΅;{~޴h"LfX\ 0a>'m=qAu`x\?NU}>%k2I$~*K?F[^1 0>ۥR.6H@'%Lw-~֮Q KgmG?aTn:^Z}2O+0nt`x<|3& 4Q Z:"n[7dM8{hNI}+EGq6ܞM頌Շ}LJL_Wk- *:N01 鸡H +SOT5w2FxA~qG]`:^;)\\;LJi! ٯ~#8+ȼ,ei38gERjF5c2J,2tc?$g:ZlCLjY$" Ð8Ќ(fo4ǜm~Nνfkv8 >&[f<*QFkd'/}~ޟ䜿c\q`~PTbJXBj~2/cV5^^z߅ai F/Ti–6@fOX7&=ȭT3Ȃ&>a6*ޣD1bO=S vֆ۵?Wp#Wif< H>?"ؕ1?Tvzyw7+($R#_-m$ k`?>HxdsOn(El*6z tCDe+3) _ ?q_RkGZ$ゆ*q?u<߳JS;Բȝ)F"Ak  O:3 BB> ̮ s9>Pl[gK6أcʾpH@,tN] 8]e[ebdeM5P`~m?UlϫXx7d; n D[V -$]Da-|4#g`Wwy`WqTLmQW~D K ʩ(17:ZBJDܳ +Z7rNH&l9{~~Bu{& ͯt?P!-d;!,iʅXͣ2N &=]Tx/qFR`Fm9ř.߇oez,zcҒb$zA[1*}ޚ X7je/B 2VSL'sB W~^ȇyק2b-ؔ>|%64ֈ,lՀY/LJri雩# ;'@]&%2%vJ[5xgf;$F(+"l;Y۷-v"")`0r(COi'ĽLʋ/dϚ P/HVZϴGÑ(Y ] p)9Y##7*^AXd&d-٩µd9 )yBl7 k< 0S?{F-ȇ'ݏښ톣Ā ʥD c=D)-`A7aU65w*KLPb>₶+Bpa"+;E?^53455qJ79zIz8n]?  
gBZ|I)@*9pi' sPHA%m 90B֐v&[ߛ* olh,'X߆tXN_AXH12h<6')p6R+\6,,r`r.xU?:k=%xdԥt?8L9;ݦ ,qv@5S~9xi-{#ۈv)Q́Ϛ<̾@ RցazrP :-qB'D韇n(E|F%E +ak #ҩJh"z;/ir߾6uGES{߬tAXd6c٤[xf9o&@=k[L2v){ў㪏X4"ҤkύT#2h`d5cfg-T6u>5fC@5`A;e9}mIE>a^ҍp3ݡKMH4ب- #e~J ں^r@>Y3۽N9CPL6ݖwNAX$O%C)sj|[7ڳOyxHgO^qET0!ˑI*̽_9|/RωfjJQ*/R|_hom]gfΙrp$:ai+ӌoCl>×ַv`A"n8* v2(TflIKbs}e5E[)5&Wr+Ժ$!LƳ ("ȗ"ɫhs~S1xUn_$f$yimR ZDO:UGLNJ08(6{Ow_, HMC}y8B :Uגg-MC q Qe;?Cw#d7*yt0%=[, ZU/Wiikm 9~cq|[o~Êaa;9;.j Gbdӕ(7ʛ>| =﹊lv C;J.X%x,T\ߑg)=_T)-^Lw^jFΏ+@$B(i B"b{8FF}[7LxȚ"hޞe!h< ]W.:a7~wϥR{r6};5ZmC}%_mؾCxù{M8ߟu1o ,$ ,4įVşa l_WcuΖ?Dz" SJ Dps> 4&%oqq?>t}KQ٫FV͕CF=5y GV((>m'>jҞ0=BxӃBw_ᏡTҔbÜ,jP1oIWF`xTI}98w88'  Z#}2unJُzO#n䂜9WSʠpGflLBi\YnմA,3sweJ{U!lױ Y*~ 000000000          000000000        s^@$IIENDB`websockets-15.0.1/logo/horizontal.svg000066400000000000000000000350551476212450300176350ustar00rootroot00000000000000 websockets-15.0.1/logo/icon.html000066400000000000000000000020771476212450300165370ustar00rootroot00000000000000 Icon

Take a screenshot of these DOM nodes to2x make a PNG.

8x8 / 16x16 @ 2x

16x16 / 32x32 @ 2x

32x32 / 32x32 @ 2x

32x32 / 64x64 @ 2x

64x64 / 128x128 @ 2x

128x128 / 256x256 @ 2x

256x256 / 512x512 @ 2x

512x512 / 1024x1024 @ 2x

websockets-15.0.1/logo/icon.svg000066400000000000000000000062701476212450300163710ustar00rootroot00000000000000 websockets-15.0.1/logo/old.svg000066400000000000000000000013041476212450300162100ustar00rootroot00000000000000 websockets-15.0.1/logo/tidelift.png000066400000000000000000000077451476212450300172420ustar00rootroot00000000000000PNG  IHDRXXf pHYs.#.#x?vIDATx1n7nRI;0 Ĭ@t/в6`T+0lXC PKb@kM EGC~9`sc @`,X, X @`,X, X @` X, X @` @`, X @` @`,X @` @`,X @` @`,X  @`,X   @`,X   @`,X   @`,X   @`,X,  @`,X, X @`,X, X @` X, X @` X, X @` @`, X @` @`,X @` @`3#`N.:I:vJkѢ!^V[_INFs#A`AuYԱ۵w^$wI> a$M؁$N.n hxj`q5W@'b\EqIhIEQ ߘ`>4F]!Xg@!F~#Jw+'g=c@`f ,lXX4ځy( 6̹QX_E`,X   >ca,ؠd$X X4?XA1 =<P (B8ڀo566H`!(t2'1 \տ 25IG\!(9fI:IǣݹM6t2Z OđL2H289$_IMr`@1g;g?I/z hul-z>eN.LaoP.X, X @` X, X @` @`, X @` @`,X WI~NF]Ou,8wIzhܤNF$$7^"@`%Iҭcqr:|R (dԛNF˦_t2L*n h$ѠNFIzIn6-/KX@0o>u䭗X.NF%кJ"e ز$?1:sq. X&4 t2ꦺ 'q:utKue,`+o.񛯟FKtz|4p"k\;u\͍+vޛ o~:yA(m":,+ofFZ:Q>+ovYQV ۶fu8,{wVތb#5O҉sYu꼕G0l6,/Q-' X`/'y+ovYT2tR7FȚבiʛFr:ƊX@k15H"e^yseY6IʛF"V^xf֫[hvd-]̒|2h'`AS1BkjΝi.էFQddS˺1 X@3XyӎZN'^{+oZZI^ţ@`[^y30VFul]ʛGzŎ[PL` @`,V`5t=׫a+,İ:\ $vi/V~׫,zqC/񨎬s <,_AOaʫ,q5N򩎗RY s4-W[Va:ɇ~S}ʰ˖GIaԫ,)괎=^}@`OWWI~K[=ֻհ?s. 
Xj͞8-Cr]q7e1 X@YqՉ7Mv\V(@`ei)z(@`͎AUIY hfXY); U7-c(zιQvW?7V8H(+v@`;qi7a\,`q5G0Qe/wD`@`, X @` @`,X @` @`,X @` @`,X  @`,X   @`,X   @`,X   @`,X6涰7ZJ܏1"I׻X ؂yAzkv -X.ŵ V^$ ͺ*o~叹ٵjvVtZ7;@`A{ϖ:M3?iv伩-?<<<j5_zy`dа oNnuXߎo_uy v,  @`,X,  @`,X, X @`,X, X @` X, X @` @`, X @` @`, X @` @`,X @` @`,X  @`,X   @`,X   @`,X   @`,X   @`,X,  @`,X, X @`,X, XO }qIENDB`websockets-15.0.1/logo/vertical.svg000066400000000000000000000347021476212450300172530ustar00rootroot00000000000000 websockets-15.0.1/pyproject.toml000066400000000000000000000053061476212450300166730ustar00rootroot00000000000000[build-system] requires = ["setuptools"] build-backend = "setuptools.build_meta" [project] name = "websockets" description = "An implementation of the WebSocket Protocol (RFC 6455 & 7692)" requires-python = ">=3.9" license = { text = "BSD-3-Clause" } authors = [ { name = "Aymeric Augustin", email = "aymeric.augustin@m4x.org" }, ] keywords = ["WebSocket"] classifiers = [ "Development Status :: 5 - Production/Stable", "Environment :: Web Environment", "Intended Audience :: Developers", "License :: OSI Approved :: BSD License", "Operating System :: OS Independent", "Programming Language :: Python", "Programming Language :: Python :: 3", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", "Programming Language :: Python :: 3.13", ] dynamic = ["version", "readme"] [project.urls] Homepage = "https://github.com/python-websockets/websockets" Changelog = "https://websockets.readthedocs.io/en/stable/project/changelog.html" Documentation = "https://websockets.readthedocs.io/" Funding = "https://tidelift.com/subscription/pkg/pypi-websockets?utm_source=pypi-websockets&utm_medium=referral&utm_campaign=readme" Tracker = "https://github.com/python-websockets/websockets/issues" [project.scripts] websockets = "websockets.cli:main" [tool.cibuildwheel] enable = ["pypy"] # On a macOS runner, build Intel, 
Universal, and Apple Silicon wheels. [tool.cibuildwheel.macos] archs = ["x86_64", "universal2", "arm64"] # On an Linux Intel runner with QEMU installed, build Intel and ARM wheels. [tool.cibuildwheel.linux] archs = ["auto", "aarch64"] [tool.coverage.run] branch = true omit = [ # */websockets matches src/websockets and .tox/**/site-packages/websockets "*/websockets/__main__.py", "*/websockets/asyncio/async_timeout.py", "*/websockets/asyncio/compatibility.py", "tests/maxi_cov.py", ] [tool.coverage.paths] source = [ "src/websockets", ".tox/*/lib/python*/site-packages/websockets", ] [tool.coverage.report] exclude_lines = [ "pragma: no cover", "except ImportError:", "if self.debug:", "if sys.platform == \"win32\":", "if sys.platform != \"win32\":", "if TYPE_CHECKING:", "raise AssertionError", "self.fail\\(\".*\"\\)", "@overload", "@unittest.skip", ] partial_branches = [ "pragma: no branch", "with self.assertRaises\\(.*\\)", ] [tool.ruff] target-version = "py312" [tool.ruff.lint] select = [ "E", # pycodestyle "F", # Pyflakes "W", # pycodestyle "I", # isort ] ignore = [ "F403", "F405", ] [tool.ruff.lint.isort] combine-as-imports = true lines-after-imports = 2 websockets-15.0.1/setup.py000066400000000000000000000020501476212450300154620ustar00rootroot00000000000000import os import pathlib import re import setuptools root_dir = pathlib.Path(__file__).parent exec((root_dir / "src" / "websockets" / "version.py").read_text(encoding="utf-8")) # PyPI disables the "raw" directive. Remove this section of the README. long_description = re.sub( r"^\.\. raw:: html.*?^(?=\w)", "", (root_dir / "README.rst").read_text(encoding="utf-8"), flags=re.DOTALL | re.MULTILINE, ) # Set BUILD_EXTENSION to yes or no to force building or not building the # speedups extension. If unset, the extension is built only if possible. 
if os.environ.get("BUILD_EXTENSION") == "no": ext_modules = [] else: ext_modules = [ setuptools.Extension( "websockets.speedups", sources=["src/websockets/speedups.c"], optional=os.environ.get("BUILD_EXTENSION") != "yes", ) ] # Static values are declared in pyproject.toml. setuptools.setup( version=version, long_description=long_description, long_description_content_type="text/x-rst", ext_modules=ext_modules, ) websockets-15.0.1/src/000077500000000000000000000000001476212450300145425ustar00rootroot00000000000000websockets-15.0.1/src/websockets/000077500000000000000000000000001476212450300167135ustar00rootroot00000000000000websockets-15.0.1/src/websockets/__init__.py000066400000000000000000000156221476212450300210320ustar00rootroot00000000000000from __future__ import annotations # Importing the typing module would conflict with websockets.typing. from typing import TYPE_CHECKING from .imports import lazy_import from .version import version as __version__ # noqa: F401 __all__ = [ # .asyncio.client "connect", "unix_connect", "ClientConnection", # .asyncio.router "route", "unix_route", "Router", # .asyncio.server "basic_auth", "broadcast", "serve", "unix_serve", "ServerConnection", "Server", # .client "ClientProtocol", # .datastructures "Headers", "HeadersLike", "MultipleValuesError", # .exceptions "ConcurrencyError", "ConnectionClosed", "ConnectionClosedError", "ConnectionClosedOK", "DuplicateParameter", "InvalidHandshake", "InvalidHeader", "InvalidHeaderFormat", "InvalidHeaderValue", "InvalidMessage", "InvalidOrigin", "InvalidParameterName", "InvalidParameterValue", "InvalidProxy", "InvalidProxyMessage", "InvalidProxyStatus", "InvalidState", "InvalidStatus", "InvalidUpgrade", "InvalidURI", "NegotiationError", "PayloadTooBig", "ProtocolError", "ProxyError", "SecurityError", "WebSocketException", # .frames "Close", "CloseCode", "Frame", "Opcode", # .http11 "Request", "Response", # .protocol "Protocol", "Side", "State", # .server "ServerProtocol", # .typing "Data", 
"ExtensionName", "ExtensionParameter", "LoggerLike", "StatusLike", "Origin", "Subprotocol", ] # When type checking, import non-deprecated aliases eagerly. Else, import on demand. if TYPE_CHECKING: from .asyncio.client import ClientConnection, connect, unix_connect from .asyncio.router import Router, route, unix_route from .asyncio.server import ( Server, ServerConnection, basic_auth, broadcast, serve, unix_serve, ) from .client import ClientProtocol from .datastructures import Headers, HeadersLike, MultipleValuesError from .exceptions import ( ConcurrencyError, ConnectionClosed, ConnectionClosedError, ConnectionClosedOK, DuplicateParameter, InvalidHandshake, InvalidHeader, InvalidHeaderFormat, InvalidHeaderValue, InvalidMessage, InvalidOrigin, InvalidParameterName, InvalidParameterValue, InvalidProxy, InvalidProxyMessage, InvalidProxyStatus, InvalidState, InvalidStatus, InvalidUpgrade, InvalidURI, NegotiationError, PayloadTooBig, ProtocolError, ProxyError, SecurityError, WebSocketException, ) from .frames import Close, CloseCode, Frame, Opcode from .http11 import Request, Response from .protocol import Protocol, Side, State from .server import ServerProtocol from .typing import ( Data, ExtensionName, ExtensionParameter, LoggerLike, Origin, StatusLike, Subprotocol, ) else: lazy_import( globals(), aliases={ # .asyncio.client "connect": ".asyncio.client", "unix_connect": ".asyncio.client", "ClientConnection": ".asyncio.client", # .asyncio.router "route": ".asyncio.router", "unix_route": ".asyncio.router", "Router": ".asyncio.router", # .asyncio.server "basic_auth": ".asyncio.server", "broadcast": ".asyncio.server", "serve": ".asyncio.server", "unix_serve": ".asyncio.server", "ServerConnection": ".asyncio.server", "Server": ".asyncio.server", # .client "ClientProtocol": ".client", # .datastructures "Headers": ".datastructures", "HeadersLike": ".datastructures", "MultipleValuesError": ".datastructures", # .exceptions "ConcurrencyError": ".exceptions", 
"ConnectionClosed": ".exceptions", "ConnectionClosedError": ".exceptions", "ConnectionClosedOK": ".exceptions", "DuplicateParameter": ".exceptions", "InvalidHandshake": ".exceptions", "InvalidHeader": ".exceptions", "InvalidHeaderFormat": ".exceptions", "InvalidHeaderValue": ".exceptions", "InvalidMessage": ".exceptions", "InvalidOrigin": ".exceptions", "InvalidParameterName": ".exceptions", "InvalidParameterValue": ".exceptions", "InvalidProxy": ".exceptions", "InvalidProxyMessage": ".exceptions", "InvalidProxyStatus": ".exceptions", "InvalidState": ".exceptions", "InvalidStatus": ".exceptions", "InvalidUpgrade": ".exceptions", "InvalidURI": ".exceptions", "NegotiationError": ".exceptions", "PayloadTooBig": ".exceptions", "ProtocolError": ".exceptions", "ProxyError": ".exceptions", "SecurityError": ".exceptions", "WebSocketException": ".exceptions", # .frames "Close": ".frames", "CloseCode": ".frames", "Frame": ".frames", "Opcode": ".frames", # .http11 "Request": ".http11", "Response": ".http11", # .protocol "Protocol": ".protocol", "Side": ".protocol", "State": ".protocol", # .server "ServerProtocol": ".server", # .typing "Data": ".typing", "ExtensionName": ".typing", "ExtensionParameter": ".typing", "LoggerLike": ".typing", "Origin": ".typing", "StatusLike": ".typing", "Subprotocol": ".typing", }, deprecated_aliases={ # deprecated in 9.0 - 2021-09-01 "framing": ".legacy", "handshake": ".legacy", "parse_uri": ".uri", "WebSocketURI": ".uri", # deprecated in 14.0 - 2024-11-09 # .legacy.auth "BasicAuthWebSocketServerProtocol": ".legacy.auth", "basic_auth_protocol_factory": ".legacy.auth", # .legacy.client "WebSocketClientProtocol": ".legacy.client", # .legacy.exceptions "AbortHandshake": ".legacy.exceptions", "InvalidStatusCode": ".legacy.exceptions", "RedirectHandshake": ".legacy.exceptions", "WebSocketProtocolError": ".legacy.exceptions", # .legacy.protocol "WebSocketCommonProtocol": ".legacy.protocol", # .legacy.server "WebSocketServer": ".legacy.server", 
"WebSocketServerProtocol": ".legacy.server", }, ) websockets-15.0.1/src/websockets/__main__.py000066400000000000000000000000761476212450300210100ustar00rootroot00000000000000from .cli import main if __name__ == "__main__": main() websockets-15.0.1/src/websockets/asyncio/000077500000000000000000000000001476212450300203605ustar00rootroot00000000000000websockets-15.0.1/src/websockets/asyncio/__init__.py000066400000000000000000000000001476212450300224570ustar00rootroot00000000000000websockets-15.0.1/src/websockets/asyncio/async_timeout.py000066400000000000000000000214131476212450300236160ustar00rootroot00000000000000# From https://github.com/aio-libs/async-timeout/blob/master/async_timeout/__init__.py # Licensed under the Apache License (Apache-2.0) import asyncio import enum import sys import warnings from types import TracebackType from typing import Optional, Type if sys.version_info >= (3, 11): from typing import final else: # From https://github.com/python/typing_extensions/blob/main/src/typing_extensions.py # Licensed under the Python Software Foundation License (PSF-2.0) # @final exists in 3.8+, but we backport it for all versions # before 3.11 to keep support for the __final__ attribute. # See https://bugs.python.org/issue46342 def final(f): """This decorator can be used to indicate to type checkers that the decorated method cannot be overridden, and decorated class cannot be subclassed. For example: class Base: @final def done(self) -> None: ... class Sub(Base): def done(self) -> None: # Error reported by type checker ... @final class Leaf: ... class Other(Leaf): # Error reported by type checker ... There is no runtime checking of these properties. The decorator sets the ``__final__`` attribute to ``True`` on the decorated object to allow runtime introspection. """ try: f.__final__ = True except (AttributeError, TypeError): # Skip the attribute silently if it is not writable. 
# AttributeError happens if the object has __slots__ or a # read-only property, TypeError if it's a builtin class. pass return f # End https://github.com/python/typing_extensions/blob/main/src/typing_extensions.py if sys.version_info >= (3, 11): def _uncancel_task(task: "asyncio.Task[object]") -> None: task.uncancel() else: def _uncancel_task(task: "asyncio.Task[object]") -> None: pass __version__ = "4.0.3" __all__ = ("timeout", "timeout_at", "Timeout") def timeout(delay: Optional[float]) -> "Timeout": """timeout context manager. Useful in cases when you want to apply timeout logic around block of code or in cases when asyncio.wait_for is not suitable. For example: >>> async with timeout(0.001): ... async with aiohttp.get('https://github.com') as r: ... await r.text() delay - value in seconds or None to disable timeout logic """ loop = asyncio.get_running_loop() if delay is not None: deadline = loop.time() + delay # type: Optional[float] else: deadline = None return Timeout(deadline, loop) def timeout_at(deadline: Optional[float]) -> "Timeout": """Schedule the timeout at absolute time. deadline argument points on the time in the same clock system as loop.time(). Please note: it is not POSIX time but a time with undefined starting base, e.g. the time of the system power on. >>> async with timeout_at(loop.time() + 10): ... async with aiohttp.get('https://github.com') as r: ... await r.text() """ loop = asyncio.get_running_loop() return Timeout(deadline, loop) class _State(enum.Enum): INIT = "INIT" ENTER = "ENTER" TIMEOUT = "TIMEOUT" EXIT = "EXIT" @final class Timeout: # Internal class, please don't instantiate it directly # Use timeout() and timeout_at() public factories instead. # # Implementation note: `async with timeout()` is preferred # over `with timeout()`. # While technically the Timeout class implementation # doesn't need to be async at all, # the `async with` statement explicitly points that # the context manager should be used from async function context. 
# # This design allows to avoid many silly misusages. # # TimeoutError is raised immediately when scheduled # if the deadline is passed. # The purpose is to time out as soon as possible # without waiting for the next await expression. __slots__ = ("_deadline", "_loop", "_state", "_timeout_handler", "_task") def __init__( self, deadline: Optional[float], loop: asyncio.AbstractEventLoop ) -> None: self._loop = loop self._state = _State.INIT self._task: Optional["asyncio.Task[object]"] = None self._timeout_handler = None # type: Optional[asyncio.Handle] if deadline is None: self._deadline = None # type: Optional[float] else: self.update(deadline) def __enter__(self) -> "Timeout": warnings.warn( "with timeout() is deprecated, use async with timeout() instead", DeprecationWarning, stacklevel=2, ) self._do_enter() return self def __exit__( self, exc_type: Optional[Type[BaseException]], exc_val: Optional[BaseException], exc_tb: Optional[TracebackType], ) -> Optional[bool]: self._do_exit(exc_type) return None async def __aenter__(self) -> "Timeout": self._do_enter() return self async def __aexit__( self, exc_type: Optional[Type[BaseException]], exc_val: Optional[BaseException], exc_tb: Optional[TracebackType], ) -> Optional[bool]: self._do_exit(exc_type) return None @property def expired(self) -> bool: """Is timeout expired during execution?""" return self._state == _State.TIMEOUT @property def deadline(self) -> Optional[float]: return self._deadline def reject(self) -> None: """Reject scheduled timeout if any.""" # cancel is maybe better name but # task.cancel() raises CancelledError in asyncio world. if self._state not in (_State.INIT, _State.ENTER): raise RuntimeError(f"invalid state {self._state.value}") self._reject() def _reject(self) -> None: self._task = None if self._timeout_handler is not None: self._timeout_handler.cancel() self._timeout_handler = None def shift(self, delay: float) -> None: """Advance timeout on delay seconds. The delay can be negative. 
Raise RuntimeError if shift is called when deadline is not scheduled """ deadline = self._deadline if deadline is None: raise RuntimeError("cannot shift timeout if deadline is not scheduled") self.update(deadline + delay) def update(self, deadline: float) -> None: """Set deadline to absolute value. deadline argument points on the time in the same clock system as loop.time(). If new deadline is in the past the timeout is raised immediately. Please note: it is not POSIX time but a time with undefined starting base, e.g. the time of the system power on. """ if self._state == _State.EXIT: raise RuntimeError("cannot reschedule after exit from context manager") if self._state == _State.TIMEOUT: raise RuntimeError("cannot reschedule expired timeout") if self._timeout_handler is not None: self._timeout_handler.cancel() self._deadline = deadline if self._state != _State.INIT: self._reschedule() def _reschedule(self) -> None: assert self._state == _State.ENTER deadline = self._deadline if deadline is None: return now = self._loop.time() if self._timeout_handler is not None: self._timeout_handler.cancel() self._task = asyncio.current_task() if deadline <= now: self._timeout_handler = self._loop.call_soon(self._on_timeout) else: self._timeout_handler = self._loop.call_at(deadline, self._on_timeout) def _do_enter(self) -> None: if self._state != _State.INIT: raise RuntimeError(f"invalid state {self._state.value}") self._state = _State.ENTER self._reschedule() def _do_exit(self, exc_type: Optional[Type[BaseException]]) -> None: if exc_type is asyncio.CancelledError and self._state == _State.TIMEOUT: assert self._task is not None _uncancel_task(self._task) self._timeout_handler = None self._task = None raise asyncio.TimeoutError # timeout has not expired self._state = _State.EXIT self._reject() return None def _on_timeout(self) -> None: assert self._task is not None self._task.cancel() self._state = _State.TIMEOUT # drop the reference early self._timeout_handler = None # End 
https://github.com/aio-libs/async-timeout/blob/master/async_timeout/__init__.py websockets-15.0.1/src/websockets/asyncio/client.py000066400000000000000000000754021476212450300222200ustar00rootroot00000000000000from __future__ import annotations import asyncio import logging import os import socket import ssl as ssl_module import traceback import urllib.parse from collections.abc import AsyncIterator, Generator, Sequence from types import TracebackType from typing import Any, Callable, Literal, cast from ..client import ClientProtocol, backoff from ..datastructures import Headers, HeadersLike from ..exceptions import ( InvalidMessage, InvalidProxyMessage, InvalidProxyStatus, InvalidStatus, ProxyError, SecurityError, ) from ..extensions.base import ClientExtensionFactory from ..extensions.permessage_deflate import enable_client_permessage_deflate from ..headers import build_authorization_basic, build_host, validate_subprotocols from ..http11 import USER_AGENT, Response from ..protocol import CONNECTING, Event from ..streams import StreamReader from ..typing import LoggerLike, Origin, Subprotocol from ..uri import Proxy, WebSocketURI, get_proxy, parse_proxy, parse_uri from .compatibility import TimeoutError, asyncio_timeout from .connection import Connection __all__ = ["connect", "unix_connect", "ClientConnection"] MAX_REDIRECTS = int(os.environ.get("WEBSOCKETS_MAX_REDIRECTS", "10")) class ClientConnection(Connection): """ :mod:`asyncio` implementation of a WebSocket client connection. :class:`ClientConnection` provides :meth:`recv` and :meth:`send` coroutines for receiving and sending messages. It supports asynchronous iteration to receive messages:: async for message in websocket: await process(message) The iterator exits normally when the connection is closed with close code 1000 (OK) or 1001 (going away) or without a close code. It raises a :exc:`~websockets.exceptions.ConnectionClosedError` when the connection is closed with any other code. 
The ``ping_interval``, ``ping_timeout``, ``close_timeout``, ``max_queue``, and ``write_limit`` arguments have the same meaning as in :func:`connect`. Args: protocol: Sans-I/O connection. """ def __init__( self, protocol: ClientProtocol, *, ping_interval: float | None = 20, ping_timeout: float | None = 20, close_timeout: float | None = 10, max_queue: int | None | tuple[int | None, int | None] = 16, write_limit: int | tuple[int, int | None] = 2**15, ) -> None: self.protocol: ClientProtocol super().__init__( protocol, ping_interval=ping_interval, ping_timeout=ping_timeout, close_timeout=close_timeout, max_queue=max_queue, write_limit=write_limit, ) self.response_rcvd: asyncio.Future[None] = self.loop.create_future() async def handshake( self, additional_headers: HeadersLike | None = None, user_agent_header: str | None = USER_AGENT, ) -> None: """ Perform the opening handshake. """ async with self.send_context(expected_state=CONNECTING): self.request = self.protocol.connect() if additional_headers is not None: self.request.headers.update(additional_headers) if user_agent_header is not None: self.request.headers.setdefault("User-Agent", user_agent_header) self.protocol.send_request(self.request) await asyncio.wait( [self.response_rcvd, self.connection_lost_waiter], return_when=asyncio.FIRST_COMPLETED, ) # self.protocol.handshake_exc is set when the connection is lost before # receiving a response, when the response cannot be parsed, or when the # response fails the handshake. if self.protocol.handshake_exc is not None: raise self.protocol.handshake_exc def process_event(self, event: Event) -> None: """ Process one incoming event. """ # First event - handshake response. if self.response is None: assert isinstance(event, Response) self.response = event self.response_rcvd.set_result(None) # Later events - frames. else: super().process_event(event) def process_exception(exc: Exception) -> Exception | None: """ Determine whether a connection error is retryable or fatal. 
When reconnecting automatically with ``async for ... in connect(...)``, if a connection attempt fails, :func:`process_exception` is called to determine whether to retry connecting or to raise the exception. This function defines the default behavior, which is to retry on: * :exc:`EOFError`, :exc:`OSError`, :exc:`asyncio.TimeoutError`: network errors; * :exc:`~websockets.exceptions.InvalidStatus` when the status code is 500, 502, 503, or 504: server or proxy errors. All other exceptions are considered fatal. You can change this behavior with the ``process_exception`` argument of :func:`connect`. Return :obj:`None` if the exception is retryable i.e. when the error could be transient and trying to reconnect with the same parameters could succeed. The exception will be logged at the ``INFO`` level. Return an exception, either ``exc`` or a new exception, if the exception is fatal i.e. when trying to reconnect will most likely produce the same error. That exception will be raised, breaking out of the retry loop. """ # This catches python-socks' ProxyConnectionError and ProxyTimeoutError. # Remove asyncio.TimeoutError when dropping Python < 3.11. if isinstance(exc, (OSError, TimeoutError, asyncio.TimeoutError)): return None if isinstance(exc, InvalidMessage) and isinstance(exc.__cause__, EOFError): return None if isinstance(exc, InvalidStatus) and exc.response.status_code in [ 500, # Internal Server Error 502, # Bad Gateway 503, # Service Unavailable 504, # Gateway Timeout ]: return None return exc # This is spelled in lower case because it's exposed as a callable in the API. class connect: """ Connect to the WebSocket server at ``uri``. This coroutine returns a :class:`ClientConnection` instance, which you can use to send and receive messages. :func:`connect` may be used as an asynchronous context manager:: from websockets.asyncio.client import connect async with connect(...) as websocket: ... The connection is closed automatically when exiting the context. 
:func:`connect` can be used as an infinite asynchronous iterator to reconnect automatically on errors:: async for websocket in connect(...): try: ... except websockets.exceptions.ConnectionClosed: continue If the connection fails with a transient error, it is retried with exponential backoff. If it fails with a fatal error, the exception is raised, breaking out of the loop. The connection is closed automatically after each iteration of the loop. Args: uri: URI of the WebSocket server. origin: Value of the ``Origin`` header, for servers that require it. extensions: List of supported extensions, in order in which they should be negotiated and run. subprotocols: List of supported subprotocols, in order of decreasing preference. compression: The "permessage-deflate" extension is enabled by default. Set ``compression`` to :obj:`None` to disable it. See the :doc:`compression guide <../../topics/compression>` for details. additional_headers (HeadersLike | None): Arbitrary HTTP headers to add to the handshake request. user_agent_header: Value of the ``User-Agent`` request header. It defaults to ``"Python/x.y.z websockets/X.Y"``. Setting it to :obj:`None` removes the header. proxy: If a proxy is configured, it is used by default. Set ``proxy`` to :obj:`None` to disable the proxy or to the address of a proxy to override the system configuration. See the :doc:`proxy docs <../../topics/proxies>` for details. process_exception: When reconnecting automatically, tell whether an error is transient or fatal. The default behavior is defined by :func:`process_exception`. Refer to its documentation for details. open_timeout: Timeout for opening the connection in seconds. :obj:`None` disables the timeout. ping_interval: Interval between keepalive pings in seconds. :obj:`None` disables keepalive. ping_timeout: Timeout for keepalive pings in seconds. :obj:`None` disables timeouts. close_timeout: Timeout for closing the connection in seconds. :obj:`None` disables the timeout. 
max_size: Maximum size of incoming messages in bytes. :obj:`None` disables the limit. max_queue: High-water mark of the buffer where frames are received. It defaults to 16 frames. The low-water mark defaults to ``max_queue // 4``. You may pass a ``(high, low)`` tuple to set the high-water and low-water marks. If you want to disable flow control entirely, you may set it to ``None``, although that's a bad idea. write_limit: High-water mark of write buffer in bytes. It is passed to :meth:`~asyncio.WriteTransport.set_write_buffer_limits`. It defaults to 32 KiB. You may pass a ``(high, low)`` tuple to set the high-water and low-water marks. logger: Logger for this client. It defaults to ``logging.getLogger("websockets.client")``. See the :doc:`logging guide <../../topics/logging>` for details. create_connection: Factory for the :class:`ClientConnection` managing the connection. Set it to a wrapper or a subclass to customize connection handling. Any other keyword arguments are passed to the event loop's :meth:`~asyncio.loop.create_connection` method. For example: * You can set ``ssl`` to a :class:`~ssl.SSLContext` to enforce TLS settings. When connecting to a ``wss://`` URI, if ``ssl`` isn't provided, a TLS context is created with :func:`~ssl.create_default_context`. * You can set ``server_hostname`` to override the host name from ``uri`` in the TLS handshake. * You can set ``host`` and ``port`` to connect to a different host and port from those found in ``uri``. This only changes the destination of the TCP connection. The host name from ``uri`` is still used in the TLS handshake for secure connections and in the ``Host`` header. * You can set ``sock`` to provide a preexisting TCP socket. You may call :func:`socket.create_connection` (not to be confused with the event loop's :meth:`~asyncio.loop.create_connection` method) to create a suitable client socket and customize it. 
    def __init__(
        self,
        uri: str,
        *,
        # WebSocket
        origin: Origin | None = None,
        extensions: Sequence[ClientExtensionFactory] | None = None,
        subprotocols: Sequence[Subprotocol] | None = None,
        compression: str | None = "deflate",
        # HTTP
        additional_headers: HeadersLike | None = None,
        user_agent_header: str | None = USER_AGENT,
        proxy: str | Literal[True] | None = True,
        process_exception: Callable[[Exception], Exception | None] = process_exception,
        # Timeouts
        open_timeout: float | None = 10,
        ping_interval: float | None = 20,
        ping_timeout: float | None = 20,
        close_timeout: float | None = 10,
        # Limits
        max_size: int | None = 2**20,
        max_queue: int | None | tuple[int | None, int | None] = 16,
        write_limit: int | tuple[int, int | None] = 2**15,
        # Logging
        logger: LoggerLike | None = None,
        # Escape hatch for advanced customization
        create_connection: type[ClientConnection] | None = None,
        # Other keyword arguments are passed to loop.create_connection
        **kwargs: Any,
    ) -> None:
        """Validate and store settings; no I/O happens until awaited."""
        # Settings are kept on the instance so that create_connection() can be
        # called repeatedly: once per redirect and once per reconnection
        # attempt when iterating with ``async for``.
        self.uri = uri

        if subprotocols is not None:
            validate_subprotocols(subprotocols)

        # compression="deflate" enables permessage-deflate by prepending its
        # extension factory; any other non-None value is rejected.
        if compression == "deflate":
            extensions = enable_client_permessage_deflate(extensions)
        elif compression is not None:
            raise ValueError(f"unsupported compression: {compression}")

        if logger is None:
            logger = logging.getLogger("websockets.client")

        if create_connection is None:
            create_connection = ClientConnection

        def protocol_factory(uri: WebSocketURI) -> ClientConnection:
            # This is a protocol in the Sans-I/O implementation of websockets.
            protocol = ClientProtocol(
                uri,
                origin=origin,
                extensions=extensions,
                subprotocols=subprotocols,
                max_size=max_size,
                logger=logger,
            )
            # This is a connection in websockets and a protocol in asyncio.
            connection = create_connection(
                protocol,
                ping_interval=ping_interval,
                ping_timeout=ping_timeout,
                close_timeout=close_timeout,
                max_queue=max_queue,
                write_limit=write_limit,
            )
            return connection

        self.proxy = proxy
        self.protocol_factory = protocol_factory
        self.additional_headers = additional_headers
        self.user_agent_header = user_agent_header
        self.process_exception = process_exception
        self.open_timeout = open_timeout
        self.logger = logger
        self.connection_kwargs = kwargs

    async def create_connection(self) -> ClientConnection:
        """Create TCP or Unix connection."""
        loop = asyncio.get_running_loop()
        kwargs = self.connection_kwargs.copy()
        ws_uri = parse_uri(self.uri)

        # Proxies don't apply to Unix sockets or preexisting sockets.
        proxy = self.proxy
        if kwargs.get("unix", False):
            proxy = None
        if kwargs.get("sock") is not None:
            proxy = None
        # proxy=True (the default) means "use the system configuration".
        if proxy is True:
            proxy = get_proxy(ws_uri)

        def factory() -> ClientConnection:
            return self.protocol_factory(ws_uri)

        # Validate the ssl argument against the URI scheme: wss:// requires
        # TLS (ssl defaults to True, meaning a default context), ws:// forbids
        # the ssl argument entirely.
        if ws_uri.secure:
            kwargs.setdefault("ssl", True)
            kwargs.setdefault("server_hostname", ws_uri.host)
            if kwargs.get("ssl") is None:
                raise ValueError("ssl=None is incompatible with a wss:// URI")
        else:
            if kwargs.get("ssl") is not None:
                raise ValueError("ssl argument is incompatible with a ws:// URI")

        if kwargs.pop("unix", False):
            _, connection = await loop.create_unix_connection(factory, **kwargs)
        elif proxy is not None:
            proxy_parsed = parse_proxy(proxy)
            if proxy_parsed.scheme[:5] == "socks":
                # Connect to the server through the proxy.
                sock = await connect_socks_proxy(
                    proxy_parsed,
                    ws_uri,
                    local_addr=kwargs.pop("local_addr", None),
                )
                # Initialize WebSocket connection via the proxy.
                _, connection = await loop.create_connection(
                    factory,
                    sock=sock,
                    **kwargs,
                )
            elif proxy_parsed.scheme[:4] == "http":
                # Split keyword arguments between the proxy and the server.
                # ssl* and server_hostname configure TLS to the server; the
                # proxy_ prefix is stripped for TLS to the proxy itself.
                all_kwargs, proxy_kwargs, kwargs = kwargs, {}, {}
                for key, value in all_kwargs.items():
                    if key.startswith("ssl") or key == "server_hostname":
                        kwargs[key] = value
                    elif key.startswith("proxy_"):
                        proxy_kwargs[key[6:]] = value
                    else:
                        proxy_kwargs[key] = value
                # Validate the proxy_ssl argument.
                if proxy_parsed.scheme == "https":
                    proxy_kwargs.setdefault("ssl", True)
                    if proxy_kwargs.get("ssl") is None:
                        raise ValueError(
                            "proxy_ssl=None is incompatible with an https:// proxy"
                        )
                else:
                    if proxy_kwargs.get("ssl") is not None:
                        raise ValueError(
                            "proxy_ssl argument is incompatible with an http:// proxy"
                        )
                # Connect to the server through the proxy.
                transport = await connect_http_proxy(
                    proxy_parsed,
                    ws_uri,
                    user_agent_header=self.user_agent_header,
                    **proxy_kwargs,
                )
                # Initialize WebSocket connection via the proxy.
                connection = factory()
                transport.set_protocol(connection)
                ssl = kwargs.pop("ssl", None)
                if ssl is True:
                    ssl = ssl_module.create_default_context()
                if ssl is not None:
                    # Upgrade the tunneled connection to TLS in place.
                    new_transport = await loop.start_tls(
                        transport, connection, ssl, **kwargs
                    )
                    assert new_transport is not None  # help mypy
                    transport = new_transport
                connection.connection_made(transport)
            else:
                # parse_proxy() only returns socks* or http(s) schemes.
                raise AssertionError("unsupported proxy")
        else:
            # Connect to the server directly.
            if kwargs.get("sock") is None:
                kwargs.setdefault("host", ws_uri.host)
                kwargs.setdefault("port", ws_uri.port)
            # Initialize WebSocket connection.
            _, connection = await loop.create_connection(factory, **kwargs)
        return connection

    def process_redirect(self, exc: Exception) -> Exception | str:
        """
        Determine whether a connection error is a redirect that can be followed.

        Return the new URI if it's a valid redirect. Else, return an exception.

        """
        # Only an InvalidStatus with a 3xx redirect code and a Location
        # header qualifies; anything else is passed back unchanged.
        if not (
            isinstance(exc, InvalidStatus)
            and exc.response.status_code
            in [
                300,  # Multiple Choices
                301,  # Moved Permanently
                302,  # Found
                303,  # See Other
                307,  # Temporary Redirect
                308,  # Permanent Redirect
            ]
            and "Location" in exc.response.headers
        ):
            return exc

        old_ws_uri = parse_uri(self.uri)
        new_uri = urllib.parse.urljoin(self.uri, exc.response.headers["Location"])
        new_ws_uri = parse_uri(new_uri)

        # If connect() received a socket, it is closed and cannot be reused.
        if self.connection_kwargs.get("sock") is not None:
            return ValueError(
                f"cannot follow redirect to {new_uri} with a preexisting socket"
            )

        # TLS downgrade is forbidden.
        if old_ws_uri.secure and not new_ws_uri.secure:
            return SecurityError(f"cannot follow redirect to non-secure URI {new_uri}")

        # Apply restrictions to cross-origin redirects.
        if (
            old_ws_uri.secure != new_ws_uri.secure
            or old_ws_uri.host != new_ws_uri.host
            or old_ws_uri.port != new_ws_uri.port
        ):
            # Cross-origin redirects on Unix sockets don't quite make sense.
            if self.connection_kwargs.get("unix", False):
                return ValueError(
                    f"cannot follow cross-origin redirect to {new_uri} "
                    f"with a Unix socket"
                )
            # Cross-origin redirects when host and port are overridden are ill-defined.
            if (
                self.connection_kwargs.get("host") is not None
                or self.connection_kwargs.get("port") is not None
            ):
                return ValueError(
                    f"cannot follow cross-origin redirect to {new_uri} "
                    f"with an explicit host or port"
                )

        # Valid redirect: the caller swaps self.uri and retries.
        return new_uri
    async def __await_impl__(self) -> ClientConnection:
        """Open the connection, following up to MAX_REDIRECTS redirects."""
        try:
            # open_timeout covers the whole sequence, including redirects.
            async with asyncio_timeout(self.open_timeout):
                for _ in range(MAX_REDIRECTS):
                    self.connection = await self.create_connection()
                    try:
                        await self.connection.handshake(
                            self.additional_headers,
                            self.user_agent_header,
                        )
                    except asyncio.CancelledError:
                        self.connection.transport.abort()
                        raise
                    except Exception as exc:
                        # Always close the connection even though keep-alive is
                        # the default in HTTP/1.1 because create_connection ties
                        # opening the network connection with initializing the
                        # protocol. In the current design of connect(), there is
                        # no easy way to reuse the network connection that works
                        # in every case nor to reinitialize the protocol.
                        self.connection.transport.abort()

                        uri_or_exc = self.process_redirect(exc)
                        # Response is a valid redirect; follow it.
                        if isinstance(uri_or_exc, str):
                            self.uri = uri_or_exc
                            continue
                        # Response isn't a valid redirect; raise the exception.
                        if uri_or_exc is exc:
                            raise
                        else:
                            raise uri_or_exc from exc

                    else:
                        self.connection.start_keepalive()
                        return self.connection
                else:
                    # The for loop exhausted its redirect budget.
                    raise SecurityError(f"more than {MAX_REDIRECTS} redirects")

        except TimeoutError as exc:
            # Re-raise exception with an informative error message.
            raise TimeoutError("timed out during opening handshake") from exc

    # ... = yield from connect(...) - remove when dropping Python < 3.10

    __iter__ = __await__

    # async with connect(...) as ...: ...

    async def __aenter__(self) -> ClientConnection:
        return await self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: TracebackType | None,
    ) -> None:
        await self.connection.close()

    # async for ... in connect(...):

    async def __aiter__(self) -> AsyncIterator[ClientConnection]:
        """Yield a connection per iteration, reconnecting with backoff on errors."""
        delays: Generator[float] | None = None
        while True:
            try:
                async with self as protocol:
                    yield protocol
            except Exception as exc:
                # Determine whether the exception is retryable or fatal.
                # The API of process_exception is "return an exception or None";
                # "raise an exception" is also supported because it's a frequent
                # mistake. It isn't documented in order to keep the API simple.
                try:
                    new_exc = self.process_exception(exc)
                except Exception as raised_exc:
                    new_exc = raised_exc

                # The connection failed with a fatal error.
                # Raise the exception and exit the loop.
                if new_exc is exc:
                    raise
                if new_exc is not None:
                    raise new_exc from exc

                # The connection failed with a retryable error.
                # Start or continue backoff and reconnect.
                if delays is None:
                    delays = backoff()
                delay = next(delays)
                self.logger.info(
                    "connect failed; reconnecting in %.1f seconds: %s",
                    delay,
                    # Remove first argument when dropping Python 3.9.
                    traceback.format_exception_only(type(exc), exc)[0].strip(),
                )
                await asyncio.sleep(delay)
                continue

            else:
                # The connection succeeded. Reset backoff.
                delays = None


def unix_connect(
    path: str | None = None,
    uri: str | None = None,
    **kwargs: Any,
) -> connect:
    """
    Connect to a WebSocket server listening on a Unix socket.

    This function accepts the same keyword arguments as :func:`connect`.

    It's only available on Unix.

    It's mainly useful for debugging servers listening on Unix sockets.

    Args:
        path: File system path to the Unix socket.
        uri: URI of the WebSocket server. ``uri`` defaults to
            ``ws://localhost/`` or, when a ``ssl`` argument is provided, to
            ``wss://localhost/``.

    """
    # The URI only matters for the Host header and TLS; the actual endpoint
    # is the Unix socket at ``path``.
    if uri is None:
        if kwargs.get("ssl") is None:
            uri = "ws://localhost/"
        else:
            uri = "wss://localhost/"
    return connect(uri=uri, unix=True, path=path, **kwargs)


# SOCKS support is optional: it requires the third-party python-socks
# package. When it's missing, connect_socks_proxy is replaced by a stub
# that fails with an actionable ImportError only if a SOCKS proxy is used.
try:
    from python_socks import ProxyType
    from python_socks.async_.asyncio import Proxy as SocksProxy

    SOCKS_PROXY_TYPES = {
        "socks5h": ProxyType.SOCKS5,
        "socks5": ProxyType.SOCKS5,
        "socks4a": ProxyType.SOCKS4,
        "socks4": ProxyType.SOCKS4,
    }

    # The "h"/"a" scheme variants resolve host names remotely (rdns=True).
    SOCKS_PROXY_RDNS = {
        "socks5h": True,
        "socks5": False,
        "socks4a": True,
        "socks4": False,
    }

    async def connect_socks_proxy(
        proxy: Proxy,
        ws_uri: WebSocketURI,
        **kwargs: Any,
    ) -> socket.socket:
        """Connect via a SOCKS proxy and return the socket."""
        socks_proxy = SocksProxy(
            SOCKS_PROXY_TYPES[proxy.scheme],
            proxy.host,
            proxy.port,
            proxy.username,
            proxy.password,
            SOCKS_PROXY_RDNS[proxy.scheme],
        )
        # connect() is documented to raise OSError.
        # socks_proxy.connect() doesn't raise TimeoutError; it gets canceled.
        # Wrap other exceptions in ProxyError, a subclass of InvalidHandshake.
        try:
            return await socks_proxy.connect(ws_uri.host, ws_uri.port, **kwargs)
        except OSError:
            raise
        except Exception as exc:
            raise ProxyError("failed to connect to SOCKS proxy") from exc

except ImportError:

    async def connect_socks_proxy(
        proxy: Proxy,
        ws_uri: WebSocketURI,
        **kwargs: Any,
    ) -> socket.socket:
        raise ImportError("python-socks is required to use a SOCKS proxy")


def prepare_connect_request(
    proxy: Proxy,
    ws_uri: WebSocketURI,
    user_agent_header: str | None = None,
) -> bytes:
    """Serialize a CONNECT request establishing a tunnel to ``ws_uri``."""
    # The CONNECT target always includes the port, per RFC 9110.
    host = build_host(ws_uri.host, ws_uri.port, ws_uri.secure, always_include_port=True)
    headers = Headers()
    headers["Host"] = build_host(ws_uri.host, ws_uri.port, ws_uri.secure)
    if user_agent_header is not None:
        headers["User-Agent"] = user_agent_header
    if proxy.username is not None:
        assert proxy.password is not None  # enforced by parse_proxy()
        headers["Proxy-Authorization"] = build_authorization_basic(
            proxy.username, proxy.password
        )
    # We cannot use the Request class because it supports only GET requests.
    return f"CONNECT {host} HTTP/1.1\r\n".encode() + headers.serialize()
class HTTPProxyConnection(asyncio.Protocol):
    """
    Temporary asyncio protocol that performs the CONNECT handshake.

    It sends the CONNECT request as soon as the connection is made, feeds
    incoming bytes to a sans-I/O response parser, and resolves the
    :attr:`response` future with the proxy's response or an exception.
    Once the future resolves, the transport is handed over to the real
    WebSocket protocol via ``set_protocol``.
    """

    def __init__(
        self,
        ws_uri: WebSocketURI,
        proxy: Proxy,
        user_agent_header: str | None = None,
    ):
        self.ws_uri = ws_uri
        self.proxy = proxy
        self.user_agent_header = user_agent_header

        self.reader = StreamReader()
        # Generator-based sans-I/O parser; stepping it with next() consumes
        # whatever bytes the reader currently holds.
        self.parser = Response.parse(
            self.reader.read_line,
            self.reader.read_exact,
            self.reader.read_to_eof,
            include_body=False,
        )

        loop = asyncio.get_running_loop()
        self.response: asyncio.Future[Response] = loop.create_future()

    def run_parser(self) -> None:
        """Step the response parser; resolve the response future when done."""
        try:
            next(self.parser)
        except StopIteration as exc:
            # StopIteration.value carries the parsed Response.
            response = exc.value
            if 200 <= response.status_code < 300:
                self.response.set_result(response)
            else:
                self.response.set_exception(InvalidProxyStatus(response))
        except Exception as exc:
            proxy_exc = InvalidProxyMessage(
                "did not receive a valid HTTP response from proxy"
            )
            proxy_exc.__cause__ = exc
            self.response.set_exception(proxy_exc)

    def connection_made(self, transport: asyncio.BaseTransport) -> None:
        transport = cast(asyncio.Transport, transport)
        self.transport = transport
        # Send the CONNECT request immediately; the reply is parsed as it
        # arrives in data_received().
        self.transport.write(
            prepare_connect_request(self.proxy, self.ws_uri, self.user_agent_header)
        )

    def data_received(self, data: bytes) -> None:
        self.reader.feed_data(data)
        self.run_parser()

    def eof_received(self) -> None:
        self.reader.feed_eof()
        self.run_parser()

    def connection_lost(self, exc: Exception | None) -> None:
        # NOTE(review): if eof_received() already ran, this is a second
        # feed_eof() on the same reader; and if the response future is
        # already done, set_exception() would raise InvalidStateError —
        # presumably both are unreachable in practice because the transport's
        # protocol is swapped right after the future resolves; confirm.
        self.reader.feed_eof()
        if exc is not None:
            self.response.set_exception(exc)


async def connect_http_proxy(
    proxy: Proxy,
    ws_uri: WebSocketURI,
    user_agent_header: str | None = None,
    **kwargs: Any,
) -> asyncio.Transport:
    """Open a tunnel through an HTTP proxy and return the transport."""
    transport, protocol = await asyncio.get_running_loop().create_connection(
        lambda: HTTPProxyConnection(ws_uri, proxy, user_agent_header),
        proxy.host,
        proxy.port,
        **kwargs,
    )

    try:
        # This raises exceptions if the connection to the proxy fails.
        await protocol.response
    except Exception:
        transport.close()
        raise

    return transport
"""Compatibility shims so the asyncio implementation runs on Python < 3.11."""

from __future__ import annotations

import sys


__all__ = ["TimeoutError", "aiter", "anext", "asyncio_timeout", "asyncio_timeout_at"]


if sys.version_info[:2] >= (3, 11):
    # Re-export the builtins / stdlib names so callers can import everything
    # from this module regardless of the Python version.
    TimeoutError = TimeoutError
    aiter = aiter
    anext = anext
    from asyncio import (
        timeout as asyncio_timeout,  # noqa: F401
        timeout_at as asyncio_timeout_at,  # noqa: F401
    )

else:  # Python < 3.11
    # Before 3.11, asyncio raised its own TimeoutError, distinct from the
    # builtin; aiter()/anext() builtins don't exist; asyncio.timeout doesn't
    # exist either, so a vendored async_timeout provides it.
    from asyncio import TimeoutError

    def aiter(async_iterable):
        return type(async_iterable).__aiter__(async_iterable)

    async def anext(async_iterator):
        return await type(async_iterator).__anext__(async_iterator)

    from .async_timeout import (
        timeout as asyncio_timeout,  # noqa: F401
        timeout_at as asyncio_timeout_at,  # noqa: F401
    )
    def __init__(
        self,
        protocol: Protocol,
        *,
        ping_interval: float | None = 20,
        ping_timeout: float | None = 20,
        close_timeout: float | None = 10,
        max_queue: int | None | tuple[int | None, int | None] = 16,
        write_limit: int | tuple[int, int | None] = 2**15,
    ) -> None:
        """Wrap a sans-I/O ``protocol`` with asyncio I/O state. Must run in a loop."""
        self.protocol = protocol
        self.ping_interval = ping_interval
        self.ping_timeout = ping_timeout
        self.close_timeout = close_timeout
        # Normalize scalar limits to (high, low) tuples; low defaults later.
        if isinstance(max_queue, int) or max_queue is None:
            max_queue = (max_queue, None)
        self.max_queue = max_queue
        if isinstance(write_limit, int):
            write_limit = (write_limit, None)
        self.write_limit = write_limit

        # Inject reference to this instance in the protocol's logger.
        self.protocol.logger = logging.LoggerAdapter(
            self.protocol.logger,
            {"websocket": self},
        )

        # Copy attributes from the protocol for convenience.
        self.id: uuid.UUID = self.protocol.id
        """Unique identifier of the connection. Useful in logs."""
        self.logger: LoggerLike = self.protocol.logger
        """Logger for this connection."""
        self.debug = self.protocol.debug

        # HTTP handshake request and response.
        self.request: Request | None = None
        """Opening handshake request."""
        self.response: Response | None = None
        """Opening handshake response."""

        # Event loop running this connection.
        self.loop = asyncio.get_running_loop()

        # Assembler turning frames into messages and serializing reads.
        self.recv_messages: Assembler  # initialized in connection_made

        # Deadline for the closing handshake.
        self.close_deadline: float | None = None

        # Protect sending fragmented messages.
        self.fragmented_send_waiter: asyncio.Future[None] | None = None

        # Mapping of ping IDs to pong waiters, in chronological order.
        self.pong_waiters: dict[bytes, tuple[asyncio.Future[float], float]] = {}

        self.latency: float = 0
        """
        Latency of the connection, in seconds.

        Latency is defined as the round-trip time of the connection. It is
        measured by sending a Ping frame and waiting for a matching Pong frame.
        Before the first measurement, :attr:`latency` is ``0``.

        By default, websockets enables a :ref:`keepalive <keepalive>` mechanism
        that sends Ping frames automatically at regular intervals. You can also
        send Ping frames and measure latency with :meth:`ping`.

        """

        # Task that sends keepalive pings. None when ping_interval is None.
        self.keepalive_task: asyncio.Task[None] | None = None

        # Exception raised while reading from the connection, to be chained to
        # ConnectionClosed in order to show why the TCP connection dropped.
        self.recv_exc: BaseException | None = None

        # Completed when the TCP connection is closed and the WebSocket
        # connection state becomes CLOSED.
        self.connection_lost_waiter: asyncio.Future[None] = self.loop.create_future()

        # Adapted from asyncio.FlowControlMixin
        self.paused: bool = False
        self.drain_waiters: collections.deque[asyncio.Future[None]] = (
            collections.deque()
        )

    # Public attributes

    @property
    def local_address(self) -> Any:
        """
        Local address of the connection.

        For IPv4 connections, this is a ``(host, port)`` tuple.

        The format of the address depends on the address family.
        See :meth:`~socket.socket.getsockname`.

        """
        return self.transport.get_extra_info("sockname")

    @property
    def remote_address(self) -> Any:
        """
        Remote address of the connection.

        For IPv4 connections, this is a ``(host, port)`` tuple.

        The format of the address depends on the address family.
        See :meth:`~socket.socket.getpeername`.

        """
        return self.transport.get_extra_info("peername")
    @property
    def state(self) -> State:
        """
        State of the WebSocket connection, defined in :rfc:`6455`.

        This attribute is provided for completeness. Typical applications
        shouldn't check its value. Instead, they should call :meth:`~recv` or
        :meth:`send` and handle :exc:`~websockets.exceptions.ConnectionClosed`
        exceptions.

        """
        return self.protocol.state

    @property
    def subprotocol(self) -> Subprotocol | None:
        """
        Subprotocol negotiated during the opening handshake.

        :obj:`None` if no subprotocol was negotiated.

        """
        return self.protocol.subprotocol

    @property
    def close_code(self) -> int | None:
        """
        WebSocket close code, defined in :rfc:`6455`.

        This attribute is provided for completeness. Typical applications
        shouldn't check its value. Instead, they should inspect attributes
        of :exc:`~websockets.exceptions.ConnectionClosed` exceptions.

        """
        return self.protocol.close_code

    @property
    def close_reason(self) -> str | None:
        """
        WebSocket close reason, defined in :rfc:`6455`.

        This attribute is provided for completeness. Typical applications
        shouldn't check its value. Instead, they should inspect attributes
        of :exc:`~websockets.exceptions.ConnectionClosed` exceptions.

        """
        return self.protocol.close_reason

    # Public methods

    async def __aenter__(self) -> Connection:
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: TracebackType | None,
    ) -> None:
        # Close with 1011 when the body raised, so the peer learns that
        # something went wrong on our side.
        if exc_type is None:
            await self.close()
        else:
            await self.close(CloseCode.INTERNAL_ERROR)

    async def __aiter__(self) -> AsyncIterator[Data]:
        """
        Iterate on incoming messages.

        The iterator calls :meth:`recv` and yields messages asynchronously in an
        infinite loop.

        It exits when the connection is closed normally. It raises a
        :exc:`~websockets.exceptions.ConnectionClosedError` exception after a
        protocol error or a network failure.

        """
        try:
            while True:
                yield await self.recv()
        except ConnectionClosedOK:
            return

    @overload
    async def recv(self, decode: Literal[True]) -> str: ...

    @overload
    async def recv(self, decode: Literal[False]) -> bytes: ...

    @overload
    async def recv(self, decode: bool | None = None) -> Data: ...
    async def recv(self, decode: bool | None = None) -> Data:
        """
        Receive the next message.

        When the connection is closed, :meth:`recv` raises
        :exc:`~websockets.exceptions.ConnectionClosed`. Specifically, it raises
        :exc:`~websockets.exceptions.ConnectionClosedOK` after a normal closure
        and :exc:`~websockets.exceptions.ConnectionClosedError` after a protocol
        error or a network failure. This is how you detect the end of the
        message stream.

        Canceling :meth:`recv` is safe. There's no risk of losing data. The
        next invocation of :meth:`recv` will return the next message.

        This makes it possible to enforce a timeout by wrapping :meth:`recv` in
        :func:`~asyncio.timeout` or :func:`~asyncio.wait_for`.

        When the message is fragmented, :meth:`recv` waits until all fragments
        are received, reassembles them, and returns the whole message.

        Args:
            decode: Set this flag to override the default behavior of returning
                :class:`str` or :class:`bytes`. See below for details.

        Returns:
            A string (:class:`str`) for a Text_ frame or a bytestring
            (:class:`bytes`) for a Binary_ frame.

            .. _Text: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6
            .. _Binary: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6

            You may override this behavior with the ``decode`` argument:

            * Set ``decode=False`` to disable UTF-8 decoding of Text_ frames and
              return a bytestring (:class:`bytes`). This improves performance
              when decoding isn't needed, for example if the message contains
              JSON and you're using a JSON library that expects a bytestring.
            * Set ``decode=True`` to force UTF-8 decoding of Binary_ frames
              and return a string (:class:`str`). This may be useful for
              servers that send binary frames instead of text frames.

        Raises:
            ConnectionClosed: When the connection is closed.
            ConcurrencyError: If two coroutines call :meth:`recv` or
                :meth:`recv_streaming` concurrently.

        """
        try:
            return await self.recv_messages.get(decode)
        except EOFError:
            # The message stream ended; fall through to raise ConnectionClosed.
            pass
            # fallthrough
        except ConcurrencyError:
            raise ConcurrencyError(
                "cannot call recv while another coroutine "
                "is already running recv or recv_streaming"
            ) from None
        except UnicodeDecodeError as exc:
            # Invalid UTF-8 in a text frame: fail the connection with 1007,
            # then fall through to raise ConnectionClosed.
            async with self.send_context():
                self.protocol.fail(
                    CloseCode.INVALID_DATA,
                    f"{exc.reason} at position {exc.start}",
                )
            # fallthrough

        # Wait for the protocol state to be CLOSED before accessing close_exc.
        # shield() keeps the wait alive if recv() itself is canceled.
        await asyncio.shield(self.connection_lost_waiter)
        raise self.protocol.close_exc from self.recv_exc

    @overload
    def recv_streaming(self, decode: Literal[True]) -> AsyncIterator[str]: ...

    @overload
    def recv_streaming(self, decode: Literal[False]) -> AsyncIterator[bytes]: ...

    @overload
    def recv_streaming(self, decode: bool | None = None) -> AsyncIterator[Data]: ...

    async def recv_streaming(self, decode: bool | None = None) -> AsyncIterator[Data]:
        """
        Receive the next message frame by frame.

        This method is designed for receiving fragmented messages. It returns an
        asynchronous iterator that yields each fragment as it is received. This
        iterator must be fully consumed. Else, future calls to :meth:`recv` or
        :meth:`recv_streaming` will raise
        :exc:`~websockets.exceptions.ConcurrencyError`, making the connection
        unusable.

        :meth:`recv_streaming` raises the same exceptions as :meth:`recv`.

        Canceling :meth:`recv_streaming` before receiving the first frame is
        safe. Canceling it after receiving one or more frames leaves the
        iterator in a partially consumed state, making the connection unusable.
        Instead, you should close the connection with :meth:`close`.

        Args:
            decode: Set this flag to override the default behavior of returning
                :class:`str` or :class:`bytes`. See below for details.

        Returns:
            An iterator of strings (:class:`str`) for a Text_ frame or
            bytestrings (:class:`bytes`) for a Binary_ frame.

            .. _Text: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6
            .. _Binary: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6

            You may override this behavior with the ``decode`` argument:

            * Set ``decode=False`` to disable UTF-8 decoding of Text_ frames
              and return bytestrings (:class:`bytes`). This may be useful to
              optimize performance when decoding isn't needed.
            * Set ``decode=True`` to force UTF-8 decoding of Binary_ frames
              and return strings (:class:`str`). This is useful for servers
              that send binary frames instead of text frames.

        Raises:
            ConnectionClosed: When the connection is closed.
            ConcurrencyError: If two coroutines call :meth:`recv` or
                :meth:`recv_streaming` concurrently.

        """
        try:
            async for frame in self.recv_messages.get_iter(decode):
                yield frame
            return
        except EOFError:
            # The message stream ended; fall through to raise ConnectionClosed.
            pass
            # fallthrough
        except ConcurrencyError:
            raise ConcurrencyError(
                "cannot call recv_streaming while another coroutine "
                "is already running recv or recv_streaming"
            ) from None
        except UnicodeDecodeError as exc:
            # Invalid UTF-8 in a text frame: fail the connection with 1007,
            # then fall through to raise ConnectionClosed.
            async with self.send_context():
                self.protocol.fail(
                    CloseCode.INVALID_DATA,
                    f"{exc.reason} at position {exc.start}",
                )
            # fallthrough

        # Wait for the protocol state to be CLOSED before accessing close_exc.
        await asyncio.shield(self.connection_lost_waiter)
        raise self.protocol.close_exc from self.recv_exc
    async def send(
        self,
        message: Data | Iterable[Data] | AsyncIterable[Data],
        text: bool | None = None,
    ) -> None:
        """
        Send a message.

        A string (:class:`str`) is sent as a Text_ frame. A bytestring or
        bytes-like object (:class:`bytes`, :class:`bytearray`, or
        :class:`memoryview`) is sent as a Binary_ frame.

        .. _Text: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6
        .. _Binary: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6

        You may override this behavior with the ``text`` argument:

        * Set ``text=True`` to send a bytestring or bytes-like object
          (:class:`bytes`, :class:`bytearray`, or :class:`memoryview`) as a
          Text_ frame. This improves performance when the message is already
          UTF-8 encoded, for example if the message contains JSON and you're
          using a JSON library that produces a bytestring.
        * Set ``text=False`` to send a string (:class:`str`) in a Binary_
          frame. This may be useful for servers that expect binary frames
          instead of text frames.

        :meth:`send` also accepts an iterable or an asynchronous iterable of
        strings, bytestrings, or bytes-like objects to enable fragmentation_.
        Each item is treated as a message fragment and sent in its own frame.
        All items must be of the same type, or else :meth:`send` will raise a
        :exc:`TypeError` and the connection will be closed.

        .. _fragmentation: https://datatracker.ietf.org/doc/html/rfc6455#section-5.4

        :meth:`send` rejects dict-like objects because this is often an error.
        (If you really want to send the keys of a dict-like object as
        fragments, call its :meth:`~dict.keys` method and pass the result to
        :meth:`send`.)

        Canceling :meth:`send` is discouraged. Instead, you should close the
        connection with :meth:`close`. Indeed, there are only two situations
        where :meth:`send` may yield control to the event loop and then get
        canceled; in both cases, :meth:`close` has the same effect and is
        more clear:

        1. The write buffer is full. If you don't want to wait until enough
           data is sent, your only alternative is to close the connection.
           :meth:`close` will likely time out then abort the TCP connection.
        2. ``message`` is an asynchronous iterator that yields control.
           Stopping in the middle of a fragmented message will cause a
           protocol error and the connection will be closed.

        When the connection is closed, :meth:`send` raises
        :exc:`~websockets.exceptions.ConnectionClosed`. Specifically, it
        raises :exc:`~websockets.exceptions.ConnectionClosedOK` after a normal
        connection closure and
        :exc:`~websockets.exceptions.ConnectionClosedError` after a protocol
        error or a network failure.

        Args:
            message: Message to send.

        Raises:
            ConnectionClosed: When the connection is closed.
            TypeError: If ``message`` doesn't have a supported type.

        """
        # While sending a fragmented message, prevent sending other messages
        # until all fragments are sent.
        while self.fragmented_send_waiter is not None:
            # shield() so that canceling this send() doesn't cancel the
            # waiter that the in-progress fragmented send depends on.
            await asyncio.shield(self.fragmented_send_waiter)

        # Unfragmented message -- this case must be handled first because
        # strings and bytes-like objects are iterable.

        if isinstance(message, str):
            async with self.send_context():
                if text is False:
                    self.protocol.send_binary(message.encode())
                else:
                    self.protocol.send_text(message.encode())

        elif isinstance(message, BytesLike):
            async with self.send_context():
                if text is True:
                    self.protocol.send_text(message)
                else:
                    self.protocol.send_binary(message)

        # Catch a common mistake -- passing a dict to send().

        elif isinstance(message, Mapping):
            raise TypeError("data is a dict-like object")

        # Fragmented message -- regular iterator.

        elif isinstance(message, Iterable):
            chunks = iter(message)
            try:
                chunk = next(chunks)
            except StopIteration:
                # Empty iterable: send nothing at all.
                return

            assert self.fragmented_send_waiter is None
            self.fragmented_send_waiter = self.loop.create_future()
            try:
                # First fragment. The first chunk's type decides whether the
                # whole message is Text or Binary; ``encode`` records whether
                # subsequent chunks must be str (and get encoded) or bytes.
                if isinstance(chunk, str):
                    async with self.send_context():
                        if text is False:
                            self.protocol.send_binary(chunk.encode(), fin=False)
                        else:
                            self.protocol.send_text(chunk.encode(), fin=False)
                    encode = True
                elif isinstance(chunk, BytesLike):
                    async with self.send_context():
                        if text is True:
                            self.protocol.send_text(chunk, fin=False)
                        else:
                            self.protocol.send_binary(chunk, fin=False)
                    encode = False
                else:
                    raise TypeError("iterable must contain bytes or str")

                # Other fragments
                for chunk in chunks:
                    if isinstance(chunk, str) and encode:
                        async with self.send_context():
                            self.protocol.send_continuation(chunk.encode(), fin=False)
                    elif isinstance(chunk, BytesLike) and not encode:
                        async with self.send_context():
                            self.protocol.send_continuation(chunk, fin=False)
                    else:
                        raise TypeError("iterable must contain uniform types")

                # Final fragment.
                async with self.send_context():
                    self.protocol.send_continuation(b"", fin=True)

            except Exception:
                # We're half-way through a fragmented message and we can't
                # complete it. This makes the connection unusable.
                async with self.send_context():
                    self.protocol.fail(
                        CloseCode.INTERNAL_ERROR,
                        "error in fragmented message",
                    )
                raise

            finally:
                # Release any other send() calls queued behind this message.
                self.fragmented_send_waiter.set_result(None)
                self.fragmented_send_waiter = None

        # Fragmented message -- async iterator.

        elif isinstance(message, AsyncIterable):
            achunks = aiter(message)
            try:
                chunk = await anext(achunks)
            except StopAsyncIteration:
                # Empty async iterable: send nothing at all.
                return

            assert self.fragmented_send_waiter is None
            self.fragmented_send_waiter = self.loop.create_future()
            try:
                # First fragment. Same type-dispatch logic as the regular
                # iterator branch above.
                if isinstance(chunk, str):
                    if text is False:
                        async with self.send_context():
                            self.protocol.send_binary(chunk.encode(), fin=False)
                    else:
                        async with self.send_context():
                            self.protocol.send_text(chunk.encode(), fin=False)
                    encode = True
                elif isinstance(chunk, BytesLike):
                    if text is True:
                        async with self.send_context():
                            self.protocol.send_text(chunk, fin=False)
                    else:
                        async with self.send_context():
                            self.protocol.send_binary(chunk, fin=False)
                    encode = False
                else:
                    raise TypeError("async iterable must contain bytes or str")

                # Other fragments
                async for chunk in achunks:
                    if isinstance(chunk, str) and encode:
                        async with self.send_context():
                            self.protocol.send_continuation(chunk.encode(), fin=False)
                    elif isinstance(chunk, BytesLike) and not encode:
                        async with self.send_context():
                            self.protocol.send_continuation(chunk, fin=False)
                    else:
                        raise TypeError("async iterable must contain uniform types")

                # Final fragment.
                async with self.send_context():
                    self.protocol.send_continuation(b"", fin=True)

            except Exception:
                # We're half-way through a fragmented message and we can't
                # complete it. This makes the connection unusable.
                async with self.send_context():
                    self.protocol.fail(
                        CloseCode.INTERNAL_ERROR,
                        "error in fragmented message",
                    )
                raise

            finally:
                # Release any other send() calls queued behind this message.
                self.fragmented_send_waiter.set_result(None)
                self.fragmented_send_waiter = None

        else:
            raise TypeError("data must be str, bytes, iterable, or async iterable")
async with self.send_context(): if self.fragmented_send_waiter is not None: self.protocol.fail( CloseCode.INTERNAL_ERROR, "close during fragmented message", ) else: self.protocol.send_close(code, reason) except ConnectionClosed: # Ignore ConnectionClosed exceptions raised from send_context(). # They mean that the connection is closed, which was the goal. pass async def wait_closed(self) -> None: """ Wait until the connection is closed. :meth:`wait_closed` waits for the closing handshake to complete and for the TCP connection to terminate. """ await asyncio.shield(self.connection_lost_waiter) async def ping(self, data: Data | None = None) -> Awaitable[float]: """ Send a Ping_. .. _Ping: https://datatracker.ietf.org/doc/html/rfc6455#section-5.5.2 A ping may serve as a keepalive or as a check that the remote endpoint received all messages up to this point Args: data: Payload of the ping. A :class:`str` will be encoded to UTF-8. If ``data`` is :obj:`None`, the payload is four random bytes. Returns: A future that will be completed when the corresponding pong is received. You can ignore it if you don't intend to wait. The result of the future is the latency of the connection in seconds. :: pong_waiter = await ws.ping() # only if you want to wait for the corresponding pong latency = await pong_waiter Raises: ConnectionClosed: When the connection is closed. ConcurrencyError: If another ping was sent with the same data and the corresponding pong wasn't received yet. """ if isinstance(data, BytesLike): data = bytes(data) elif isinstance(data, str): data = data.encode() elif data is not None: raise TypeError("data must be str or bytes-like") async with self.send_context(): # Protect against duplicates if a payload is explicitly set. if data in self.pong_waiters: raise ConcurrencyError("already waiting for a pong with the same data") # Generate a unique random payload otherwise. 
while data is None or data in self.pong_waiters: data = struct.pack("!I", random.getrandbits(32)) pong_waiter = self.loop.create_future() # The event loop's default clock is time.monotonic(). Its resolution # is a bit low on Windows (~16ms). This is improved in Python 3.13. self.pong_waiters[data] = (pong_waiter, self.loop.time()) self.protocol.send_ping(data) return pong_waiter async def pong(self, data: Data = b"") -> None: """ Send a Pong_. .. _Pong: https://datatracker.ietf.org/doc/html/rfc6455#section-5.5.3 An unsolicited pong may serve as a unidirectional heartbeat. Args: data: Payload of the pong. A :class:`str` will be encoded to UTF-8. Raises: ConnectionClosed: When the connection is closed. """ if isinstance(data, BytesLike): data = bytes(data) elif isinstance(data, str): data = data.encode() else: raise TypeError("data must be str or bytes-like") async with self.send_context(): self.protocol.send_pong(data) # Private methods def process_event(self, event: Event) -> None: """ Process one incoming event. This method is overridden in subclasses to handle the handshake. """ assert isinstance(event, Frame) if event.opcode in DATA_OPCODES: self.recv_messages.put(event) if event.opcode is Opcode.PONG: self.acknowledge_pings(bytes(event.data)) def acknowledge_pings(self, data: bytes) -> None: """ Acknowledge pings when receiving a pong. """ # Ignore unsolicited pong. if data not in self.pong_waiters: return pong_timestamp = self.loop.time() # Sending a pong for only the most recent ping is legal. # Acknowledge all previous pings too in that case. ping_id = None ping_ids = [] for ping_id, (pong_waiter, ping_timestamp) in self.pong_waiters.items(): ping_ids.append(ping_id) latency = pong_timestamp - ping_timestamp if not pong_waiter.done(): pong_waiter.set_result(latency) if ping_id == data: self.latency = latency break else: raise AssertionError("solicited pong not found in pings") # Remove acknowledged pings from self.pong_waiters. 
for ping_id in ping_ids: del self.pong_waiters[ping_id] def abort_pings(self) -> None: """ Raise ConnectionClosed in pending pings. They'll never receive a pong once the connection is closed. """ assert self.protocol.state is CLOSED exc = self.protocol.close_exc for pong_waiter, _ping_timestamp in self.pong_waiters.values(): if not pong_waiter.done(): pong_waiter.set_exception(exc) # If the exception is never retrieved, it will be logged when ping # is garbage-collected. This is confusing for users. # Given that ping is done (with an exception), canceling it does # nothing, but it prevents logging the exception. pong_waiter.cancel() self.pong_waiters.clear() async def keepalive(self) -> None: """ Send a Ping frame and wait for a Pong frame at regular intervals. """ assert self.ping_interval is not None latency = 0.0 try: while True: # If self.ping_timeout > latency > self.ping_interval, # pings will be sent immediately after receiving pongs. # The period will be longer than self.ping_interval. await asyncio.sleep(self.ping_interval - latency) # This cannot raise ConnectionClosed when the connection is # closing because ping(), via send_context(), waits for the # connection to be closed before raising ConnectionClosed. # However, connection_lost() cancels keepalive_task before # it gets a chance to resume excuting. pong_waiter = await self.ping() if self.debug: self.logger.debug("% sent keepalive ping") if self.ping_timeout is not None: try: async with asyncio_timeout(self.ping_timeout): # connection_lost cancels keepalive immediately # after setting a ConnectionClosed exception on # pong_waiter. A CancelledError is raised here, # not a ConnectionClosed exception. 
latency = await pong_waiter self.logger.debug("% received keepalive pong") except asyncio.TimeoutError: if self.debug: self.logger.debug("- timed out waiting for keepalive pong") async with self.send_context(): self.protocol.fail( CloseCode.INTERNAL_ERROR, "keepalive ping timeout", ) raise AssertionError( "send_context() should wait for connection_lost(), " "which cancels keepalive()" ) except Exception: self.logger.error("keepalive ping failed", exc_info=True) def start_keepalive(self) -> None: """ Run :meth:`keepalive` in a task, unless keepalive is disabled. """ if self.ping_interval is not None: self.keepalive_task = self.loop.create_task(self.keepalive()) @contextlib.asynccontextmanager async def send_context( self, *, expected_state: State = OPEN, # CONNECTING during the opening handshake ) -> AsyncIterator[None]: """ Create a context for writing to the connection from user code. On entry, :meth:`send_context` checks that the connection is open; on exit, it writes outgoing data to the socket:: async with self.send_context(): self.protocol.send_text(message.encode()) When the connection isn't open on entry, when the connection is expected to close on exit, or when an unexpected error happens, terminating the connection, :meth:`send_context` waits until the connection is closed then raises :exc:`~websockets.exceptions.ConnectionClosed`. """ # Should we wait until the connection is closed? wait_for_close = False # Should we close the transport and raise ConnectionClosed? raise_close_exc = False # What exception should we chain ConnectionClosed to? original_exc: BaseException | None = None if self.protocol.state is expected_state: # Let the caller interact with the protocol. try: yield except (ProtocolError, ConcurrencyError): # The protocol state wasn't changed. Exit immediately. raise except Exception as exc: self.logger.error("unexpected internal error", exc_info=True) # This branch should never run. It's a safety net in case of # bugs. 
Since we don't know what happened, we will close the # connection and raise the exception to the caller. wait_for_close = False raise_close_exc = True original_exc = exc else: # Check if the connection is expected to close soon. if self.protocol.close_expected(): wait_for_close = True # If the connection is expected to close soon, set the # close deadline based on the close timeout. # Since we tested earlier that protocol.state was OPEN # (or CONNECTING), self.close_deadline is still None. if self.close_timeout is not None: assert self.close_deadline is None self.close_deadline = self.loop.time() + self.close_timeout # Write outgoing data to the socket and enforce flow control. try: self.send_data() await self.drain() except Exception as exc: if self.debug: self.logger.debug("! error while sending data", exc_info=True) # While the only expected exception here is OSError, # other exceptions would be treated identically. wait_for_close = False raise_close_exc = True original_exc = exc else: # self.protocol.state is not expected_state # Minor layering violation: we assume that the connection # will be closing soon if it isn't in the expected state. wait_for_close = True # Calculate close_deadline if it wasn't set yet. if self.close_timeout is not None: if self.close_deadline is None: self.close_deadline = self.loop.time() + self.close_timeout raise_close_exc = True # If the connection is expected to close soon and the close timeout # elapses, close the socket to terminate the connection. if wait_for_close: try: async with asyncio_timeout_at(self.close_deadline): await asyncio.shield(self.connection_lost_waiter) except TimeoutError: # There's no risk to overwrite another error because # original_exc is never set when wait_for_close is True. assert original_exc is None original_exc = TimeoutError("timed out while closing connection") # Set recv_exc before closing the transport in order to get # proper exception reporting. 
raise_close_exc = True self.set_recv_exc(original_exc) # If an error occurred, close the transport to terminate the connection and # raise an exception. if raise_close_exc: self.transport.abort() # Wait for the protocol state to be CLOSED before accessing close_exc. await asyncio.shield(self.connection_lost_waiter) raise self.protocol.close_exc from original_exc def send_data(self) -> None: """ Send outgoing data. Raises: OSError: When a socket operations fails. """ for data in self.protocol.data_to_send(): if data: self.transport.write(data) else: # Half-close the TCP connection when possible i.e. no TLS. if self.transport.can_write_eof(): if self.debug: self.logger.debug("x half-closing TCP connection") # write_eof() doesn't document which exceptions it raises. # OSError is plausible. uvloop can raise RuntimeError here. try: self.transport.write_eof() except (OSError, RuntimeError): # pragma: no cover pass # Else, close the TCP connection. else: # pragma: no cover if self.debug: self.logger.debug("x closing TCP connection") self.transport.close() def set_recv_exc(self, exc: BaseException | None) -> None: """ Set recv_exc, if not set yet. """ if self.recv_exc is None: self.recv_exc = exc # asyncio.Protocol methods # Connection callbacks def connection_made(self, transport: asyncio.BaseTransport) -> None: transport = cast(asyncio.Transport, transport) self.recv_messages = Assembler( *self.max_queue, pause=transport.pause_reading, resume=transport.resume_reading, ) transport.set_write_buffer_limits(*self.write_limit) self.transport = transport def connection_lost(self, exc: Exception | None) -> None: # Calling protocol.receive_eof() is safe because it's idempotent. # This guarantees that the protocol state becomes CLOSED. self.protocol.receive_eof() assert self.protocol.state is CLOSED self.set_recv_exc(exc) # Abort recv() and pending pings with a ConnectionClosed exception. 
self.recv_messages.close() self.abort_pings() if self.keepalive_task is not None: self.keepalive_task.cancel() # If self.connection_lost_waiter isn't pending, that's a bug, because: # - it's set only here in connection_lost() which is called only once; # - it must never be canceled. self.connection_lost_waiter.set_result(None) # Adapted from asyncio.streams.FlowControlMixin if self.paused: # pragma: no cover self.paused = False for waiter in self.drain_waiters: if not waiter.done(): if exc is None: waiter.set_result(None) else: waiter.set_exception(exc) # Flow control callbacks def pause_writing(self) -> None: # pragma: no cover # Adapted from asyncio.streams.FlowControlMixin assert not self.paused self.paused = True def resume_writing(self) -> None: # pragma: no cover # Adapted from asyncio.streams.FlowControlMixin assert self.paused self.paused = False for waiter in self.drain_waiters: if not waiter.done(): waiter.set_result(None) async def drain(self) -> None: # pragma: no cover # We don't check if the connection is closed because we call drain() # immediately after write() and write() would fail in that case. # Adapted from asyncio.streams.StreamWriter # Yield to the event loop so that connection_lost() may be called. if self.transport.is_closing(): await asyncio.sleep(0) # Adapted from asyncio.streams.FlowControlMixin if self.paused: waiter = self.loop.create_future() self.drain_waiters.append(waiter) try: await waiter finally: self.drain_waiters.remove(waiter) # Streaming protocol callbacks def data_received(self, data: bytes) -> None: # Feed incoming data to the protocol. self.protocol.receive_data(data) # This isn't expected to raise an exception. events = self.protocol.events_received() # Write outgoing data to the transport. try: self.send_data() except Exception as exc: if self.debug: self.logger.debug("! 
error while sending data", exc_info=True) self.set_recv_exc(exc) if self.protocol.close_expected(): # If the connection is expected to close soon, set the # close deadline based on the close timeout. if self.close_timeout is not None: if self.close_deadline is None: self.close_deadline = self.loop.time() + self.close_timeout for event in events: # This isn't expected to raise an exception. self.process_event(event) def eof_received(self) -> None: # Feed the end of the data stream to the connection. self.protocol.receive_eof() # This isn't expected to raise an exception. events = self.protocol.events_received() # There is no error handling because send_data() can only write # the end of the data stream here and it shouldn't raise errors. self.send_data() # This code path is triggered when receiving an HTTP response # without a Content-Length header. This is the only case where # reading until EOF generates an event; all other events have # a known length. Ignore for coverage measurement because tests # are in test_client.py rather than test_connection.py. for event in events: # pragma: no cover # This isn't expected to raise an exception. self.process_event(event) # The WebSocket protocol has its own closing handshake: endpoints close # the TCP or TLS connection after sending and receiving a close frame. # As a consequence, they never need to write after receiving EOF, so # there's no reason to keep the transport open by returning True. # Besides, that doesn't work on TLS connections. # broadcast() is defined in the connection module even though it's primarily # used by servers and documented in the server module because it works with # client connections too and because it's easier to test together with the # Connection class. def broadcast( connections: Iterable[Connection], message: Data, raise_exceptions: bool = False, ) -> None: """ Broadcast a message to several WebSocket connections. A string (:class:`str`) is sent as a Text_ frame. 
A bytestring or bytes-like object (:class:`bytes`, :class:`bytearray`, or :class:`memoryview`) is sent as a Binary_ frame. .. _Text: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 .. _Binary: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 :func:`broadcast` pushes the message synchronously to all connections even if their write buffers are overflowing. There's no backpressure. If you broadcast messages faster than a connection can handle them, messages will pile up in its write buffer until the connection times out. Keep ``ping_interval`` and ``ping_timeout`` low to prevent excessive memory usage from slow connections. Unlike :meth:`~websockets.asyncio.connection.Connection.send`, :func:`broadcast` doesn't support sending fragmented messages. Indeed, fragmentation is useful for sending large messages without buffering them in memory, while :func:`broadcast` buffers one copy per connection as fast as possible. :func:`broadcast` skips connections that aren't open in order to avoid errors on connections where the closing handshake is in progress. :func:`broadcast` ignores failures to write the message on some connections. It continues writing to other connections. On Python 3.11 and above, you may set ``raise_exceptions`` to :obj:`True` to record failures and raise all exceptions in a :pep:`654` :exc:`ExceptionGroup`. While :func:`broadcast` makes more sense for servers, it works identically with clients, if you have a use case for opening connections to many servers and broadcasting a message to them. Args: websockets: WebSocket connections to which the message will be sent. message: Message to send. raise_exceptions: Whether to raise an exception in case of failures. Raises: TypeError: If ``message`` doesn't have a supported type. 
""" if isinstance(message, str): send_method = "send_text" message = message.encode() elif isinstance(message, BytesLike): send_method = "send_binary" else: raise TypeError("data must be str or bytes") if raise_exceptions: if sys.version_info[:2] < (3, 11): # pragma: no cover raise ValueError("raise_exceptions requires at least Python 3.11") exceptions: list[Exception] = [] for connection in connections: exception: Exception if connection.protocol.state is not OPEN: continue if connection.fragmented_send_waiter is not None: if raise_exceptions: exception = ConcurrencyError("sending a fragmented message") exceptions.append(exception) else: connection.logger.warning( "skipped broadcast: sending a fragmented message", ) continue try: # Call connection.protocol.send_text or send_binary. # Either way, message is already converted to bytes. getattr(connection.protocol, send_method)(message) connection.send_data() except Exception as write_exception: if raise_exceptions: exception = RuntimeError("failed to write message") exception.__cause__ = write_exception exceptions.append(exception) else: connection.logger.warning( "skipped broadcast: failed to write message: %s", traceback.format_exception_only( # Remove first argument when dropping Python 3.9. type(write_exception), write_exception, )[0].strip(), ) if raise_exceptions and exceptions: raise ExceptionGroup("skipped broadcast", exceptions) # Pretend that broadcast is actually defined in the server module. 
broadcast.__module__ = "websockets.asyncio.server" websockets-15.0.1/src/websockets/asyncio/messages.py000066400000000000000000000253611476212450300225500ustar00rootroot00000000000000from __future__ import annotations import asyncio import codecs import collections from collections.abc import AsyncIterator, Iterable from typing import Any, Callable, Generic, Literal, TypeVar, overload from ..exceptions import ConcurrencyError from ..frames import OP_BINARY, OP_CONT, OP_TEXT, Frame from ..typing import Data __all__ = ["Assembler"] UTF8Decoder = codecs.getincrementaldecoder("utf-8") T = TypeVar("T") class SimpleQueue(Generic[T]): """ Simplified version of :class:`asyncio.Queue`. Provides only the subset of functionality needed by :class:`Assembler`. """ def __init__(self) -> None: self.loop = asyncio.get_running_loop() self.get_waiter: asyncio.Future[None] | None = None self.queue: collections.deque[T] = collections.deque() def __len__(self) -> int: return len(self.queue) def put(self, item: T) -> None: """Put an item into the queue without waiting.""" self.queue.append(item) if self.get_waiter is not None and not self.get_waiter.done(): self.get_waiter.set_result(None) async def get(self, block: bool = True) -> T: """Remove and return an item from the queue, waiting if necessary.""" if not self.queue: if not block: raise EOFError("stream of frames ended") assert self.get_waiter is None, "cannot call get() concurrently" self.get_waiter = self.loop.create_future() try: await self.get_waiter finally: self.get_waiter.cancel() self.get_waiter = None return self.queue.popleft() def reset(self, items: Iterable[T]) -> None: """Put back items into an empty, idle queue.""" assert self.get_waiter is None, "cannot reset() while get() is running" assert not self.queue, "cannot reset() while queue isn't empty" self.queue.extend(items) def abort(self) -> None: """Close the queue, raising EOFError in get() if necessary.""" if self.get_waiter is not None and not 
self.get_waiter.done(): self.get_waiter.set_exception(EOFError("stream of frames ended")) class Assembler: """ Assemble messages from frames. :class:`Assembler` expects only data frames. The stream of frames must respect the protocol; if it doesn't, the behavior is undefined. Args: pause: Called when the buffer of frames goes above the high water mark; should pause reading from the network. resume: Called when the buffer of frames goes below the low water mark; should resume reading from the network. """ # coverage reports incorrectly: "line NN didn't jump to the function exit" def __init__( # pragma: no cover self, high: int | None = None, low: int | None = None, pause: Callable[[], Any] = lambda: None, resume: Callable[[], Any] = lambda: None, ) -> None: # Queue of incoming frames. self.frames: SimpleQueue[Frame] = SimpleQueue() # We cannot put a hard limit on the size of the queue because a single # call to Protocol.data_received() could produce thousands of frames, # which must be buffered. Instead, we pause reading when the buffer goes # above the high limit and we resume when it goes under the low limit. if high is not None and low is None: low = high // 4 if high is None and low is not None: high = low * 4 if high is not None and low is not None: if low < 0: raise ValueError("low must be positive or equal to zero") if high < low: raise ValueError("high must be greater than or equal to low") self.high, self.low = high, low self.pause = pause self.resume = resume self.paused = False # This flag prevents concurrent calls to get() by user code. self.get_in_progress = False # This flag marks the end of the connection. self.closed = False @overload async def get(self, decode: Literal[True]) -> str: ... @overload async def get(self, decode: Literal[False]) -> bytes: ... @overload async def get(self, decode: bool | None = None) -> Data: ... async def get(self, decode: bool | None = None) -> Data: """ Read the next message. 
:meth:`get` returns a single :class:`str` or :class:`bytes`. If the message is fragmented, :meth:`get` waits until the last frame is received, then it reassembles the message and returns it. To receive messages frame by frame, use :meth:`get_iter` instead. Args: decode: :obj:`False` disables UTF-8 decoding of text frames and returns :class:`bytes`. :obj:`True` forces UTF-8 decoding of binary frames and returns :class:`str`. Raises: EOFError: If the stream of frames has ended. UnicodeDecodeError: If a text frame contains invalid UTF-8. ConcurrencyError: If two coroutines run :meth:`get` or :meth:`get_iter` concurrently. """ if self.get_in_progress: raise ConcurrencyError("get() or get_iter() is already running") self.get_in_progress = True # Locking with get_in_progress prevents concurrent execution # until get() fetches a complete message or is canceled. try: # First frame frame = await self.frames.get(not self.closed) self.maybe_resume() assert frame.opcode is OP_TEXT or frame.opcode is OP_BINARY if decode is None: decode = frame.opcode is OP_TEXT frames = [frame] # Following frames, for fragmented messages while not frame.fin: try: frame = await self.frames.get(not self.closed) except asyncio.CancelledError: # Put frames already received back into the queue # so that future calls to get() can return them. self.frames.reset(frames) raise self.maybe_resume() assert frame.opcode is OP_CONT frames.append(frame) finally: self.get_in_progress = False data = b"".join(frame.data for frame in frames) if decode: return data.decode() else: return data @overload def get_iter(self, decode: Literal[True]) -> AsyncIterator[str]: ... @overload def get_iter(self, decode: Literal[False]) -> AsyncIterator[bytes]: ... @overload def get_iter(self, decode: bool | None = None) -> AsyncIterator[Data]: ... async def get_iter(self, decode: bool | None = None) -> AsyncIterator[Data]: """ Stream the next message. 
Iterating the return value of :meth:`get_iter` asynchronously yields a :class:`str` or :class:`bytes` for each frame in the message. The iterator must be fully consumed before calling :meth:`get_iter` or :meth:`get` again. Else, :exc:`ConcurrencyError` is raised. This method only makes sense for fragmented messages. If messages aren't fragmented, use :meth:`get` instead. Args: decode: :obj:`False` disables UTF-8 decoding of text frames and returns :class:`bytes`. :obj:`True` forces UTF-8 decoding of binary frames and returns :class:`str`. Raises: EOFError: If the stream of frames has ended. UnicodeDecodeError: If a text frame contains invalid UTF-8. ConcurrencyError: If two coroutines run :meth:`get` or :meth:`get_iter` concurrently. """ if self.get_in_progress: raise ConcurrencyError("get() or get_iter() is already running") self.get_in_progress = True # Locking with get_in_progress prevents concurrent execution # until get_iter() fetches a complete message or is canceled. # If get_iter() raises an exception e.g. in decoder.decode(), # get_in_progress remains set and the connection becomes unusable. # First frame try: frame = await self.frames.get(not self.closed) except asyncio.CancelledError: self.get_in_progress = False raise self.maybe_resume() assert frame.opcode is OP_TEXT or frame.opcode is OP_BINARY if decode is None: decode = frame.opcode is OP_TEXT if decode: decoder = UTF8Decoder() yield decoder.decode(frame.data, frame.fin) else: yield frame.data # Following frames, for fragmented messages while not frame.fin: # We cannot handle asyncio.CancelledError because we don't buffer # previous fragments — we're streaming them. Canceling get_iter() # here will leave the assembler in a stuck state. Future calls to # get() or get_iter() will raise ConcurrencyError. 
frame = await self.frames.get(not self.closed) self.maybe_resume() assert frame.opcode is OP_CONT if decode: yield decoder.decode(frame.data, frame.fin) else: yield frame.data self.get_in_progress = False def put(self, frame: Frame) -> None: """ Add ``frame`` to the next message. Raises: EOFError: If the stream of frames has ended. """ if self.closed: raise EOFError("stream of frames ended") self.frames.put(frame) self.maybe_pause() def maybe_pause(self) -> None: """Pause the writer if queue is above the high water mark.""" # Skip if flow control is disabled if self.high is None: return # Check for "> high" to support high = 0 if len(self.frames) > self.high and not self.paused: self.paused = True self.pause() def maybe_resume(self) -> None: """Resume the writer if queue is below the low water mark.""" # Skip if flow control is disabled if self.low is None: return # Check for "<= low" to support low = 0 if len(self.frames) <= self.low and self.paused: self.paused = False self.resume() def close(self) -> None: """ End the stream of frames. Calling :meth:`close` concurrently with :meth:`get`, :meth:`get_iter`, or :meth:`put` is safe. They will raise :exc:`EOFError`. """ if self.closed: return self.closed = True # Unblock get() or get_iter(). 
self.frames.abort() websockets-15.0.1/src/websockets/asyncio/router.py000066400000000000000000000147111476212450300222560ustar00rootroot00000000000000from __future__ import annotations import http import ssl as ssl_module import urllib.parse from typing import Any, Awaitable, Callable, Literal from werkzeug.exceptions import NotFound from werkzeug.routing import Map, RequestRedirect from ..http11 import Request, Response from .server import Server, ServerConnection, serve __all__ = ["route", "unix_route", "Router"] class Router: """WebSocket router supporting :func:`route`.""" def __init__( self, url_map: Map, server_name: str | None = None, url_scheme: str = "ws", ) -> None: self.url_map = url_map self.server_name = server_name self.url_scheme = url_scheme for rule in self.url_map.iter_rules(): rule.websocket = True def get_server_name(self, connection: ServerConnection, request: Request) -> str: if self.server_name is None: return request.headers["Host"] else: return self.server_name def redirect(self, connection: ServerConnection, url: str) -> Response: response = connection.respond(http.HTTPStatus.FOUND, f"Found at {url}") response.headers["Location"] = url return response def not_found(self, connection: ServerConnection) -> Response: return connection.respond(http.HTTPStatus.NOT_FOUND, "Not Found") def route_request( self, connection: ServerConnection, request: Request ) -> Response | None: """Route incoming request.""" url_map_adapter = self.url_map.bind( server_name=self.get_server_name(connection, request), url_scheme=self.url_scheme, ) try: parsed = urllib.parse.urlparse(request.path) handler, kwargs = url_map_adapter.match( path_info=parsed.path, query_args=parsed.query, ) except RequestRedirect as redirect: return self.redirect(connection, redirect.new_url) except NotFound: return self.not_found(connection) connection.handler, connection.handler_kwargs = handler, kwargs return None async def handler(self, connection: ServerConnection) -> None: """Handle 
a connection.""" return await connection.handler(connection, **connection.handler_kwargs) def route( url_map: Map, *args: Any, server_name: str | None = None, ssl: ssl_module.SSLContext | Literal[True] | None = None, create_router: type[Router] | None = None, **kwargs: Any, ) -> Awaitable[Server]: """ Create a WebSocket server dispatching connections to different handlers. This feature requires the third-party library `werkzeug`_: .. code-block:: console $ pip install werkzeug .. _werkzeug: https://werkzeug.palletsprojects.com/ :func:`route` accepts the same arguments as :func:`~websockets.sync.server.serve`, except as described below. The first argument is a :class:`werkzeug.routing.Map` that maps URL patterns to connection handlers. In addition to the connection, handlers receive parameters captured in the URL as keyword arguments. Here's an example:: from websockets.asyncio.router import route from werkzeug.routing import Map, Rule async def channel_handler(websocket, channel_id): ... url_map = Map([ Rule("/channel/", endpoint=channel_handler), ... ]) # set this future to exit the server stop = asyncio.get_running_loop().create_future() async with route(url_map, ...) as server: await stop Refer to the documentation of :mod:`werkzeug.routing` for details. If you define redirects with ``Rule(..., redirect_to=...)`` in the URL map, when the server runs behind a reverse proxy that modifies the ``Host`` header or terminates TLS, you need additional configuration: * Set ``server_name`` to the name of the server as seen by clients. When not provided, websockets uses the value of the ``Host`` header. * Set ``ssl=True`` to generate ``wss://`` URIs without actually enabling TLS. Under the hood, this bind the URL map with a ``url_scheme`` of ``wss://`` instead of ``ws://``. There is no need to specify ``websocket=True`` in each rule. It is added automatically. Args: url_map: Mapping of URL patterns to connection handlers. server_name: Name of the server as seen by clients. 
If :obj:`None`, websockets uses the value of the ``Host`` header. ssl: Configuration for enabling TLS on the connection. Set it to :obj:`True` if a reverse proxy terminates TLS connections. create_router: Factory for the :class:`Router` dispatching requests to handlers. Set it to a wrapper or a subclass to customize routing. """ url_scheme = "ws" if ssl is None else "wss" if ssl is not True and ssl is not None: kwargs["ssl"] = ssl if create_router is None: create_router = Router router = create_router(url_map, server_name, url_scheme) _process_request: ( Callable[ [ServerConnection, Request], Awaitable[Response | None] | Response | None, ] | None ) = kwargs.pop("process_request", None) if _process_request is None: process_request: Callable[ [ServerConnection, Request], Awaitable[Response | None] | Response | None, ] = router.route_request else: async def process_request( connection: ServerConnection, request: Request ) -> Response | None: response = _process_request(connection, request) if isinstance(response, Awaitable): response = await response if response is not None: return response return router.route_request(connection, request) return serve(router.handler, *args, process_request=process_request, **kwargs) def unix_route( url_map: Map, path: str | None = None, **kwargs: Any, ) -> Awaitable[Server]: """ Create a WebSocket Unix server dispatching connections to different handlers. :func:`unix_route` combines the behaviors of :func:`route` and :func:`~websockets.asyncio.server.unix_serve`. Args: url_map: Mapping of URL patterns to connection handlers. path: File system path to the Unix socket. 
""" return route(url_map, unix=True, path=path, **kwargs) websockets-15.0.1/src/websockets/asyncio/server.py000066400000000000000000001110111476212450300222330ustar00rootroot00000000000000from __future__ import annotations import asyncio import hmac import http import logging import re import socket import sys from collections.abc import Awaitable, Generator, Iterable, Sequence from types import TracebackType from typing import Any, Callable, Mapping, cast from ..exceptions import InvalidHeader from ..extensions.base import ServerExtensionFactory from ..extensions.permessage_deflate import enable_server_permessage_deflate from ..frames import CloseCode from ..headers import ( build_www_authenticate_basic, parse_authorization_basic, validate_subprotocols, ) from ..http11 import SERVER, Request, Response from ..protocol import CONNECTING, OPEN, Event from ..server import ServerProtocol from ..typing import LoggerLike, Origin, StatusLike, Subprotocol from .compatibility import asyncio_timeout from .connection import Connection, broadcast __all__ = [ "broadcast", "serve", "unix_serve", "ServerConnection", "Server", "basic_auth", ] class ServerConnection(Connection): """ :mod:`asyncio` implementation of a WebSocket server connection. :class:`ServerConnection` provides :meth:`recv` and :meth:`send` methods for receiving and sending messages. It supports asynchronous iteration to receive messages:: async for message in websocket: await process(message) The iterator exits normally when the connection is closed with close code 1000 (OK) or 1001 (going away) or without a close code. It raises a :exc:`~websockets.exceptions.ConnectionClosedError` when the connection is closed with any other code. The ``ping_interval``, ``ping_timeout``, ``close_timeout``, ``max_queue``, and ``write_limit`` arguments have the same meaning as in :func:`serve`. Args: protocol: Sans-I/O connection. server: Server that manages this connection. 
""" def __init__( self, protocol: ServerProtocol, server: Server, *, ping_interval: float | None = 20, ping_timeout: float | None = 20, close_timeout: float | None = 10, max_queue: int | None | tuple[int | None, int | None] = 16, write_limit: int | tuple[int, int | None] = 2**15, ) -> None: self.protocol: ServerProtocol super().__init__( protocol, ping_interval=ping_interval, ping_timeout=ping_timeout, close_timeout=close_timeout, max_queue=max_queue, write_limit=write_limit, ) self.server = server self.request_rcvd: asyncio.Future[None] = self.loop.create_future() self.username: str # see basic_auth() self.handler: Callable[[ServerConnection], Awaitable[None]] # see route() self.handler_kwargs: Mapping[str, Any] # see route() def respond(self, status: StatusLike, text: str) -> Response: """ Create a plain text HTTP response. ``process_request`` and ``process_response`` may call this method to return an HTTP response instead of performing the WebSocket opening handshake. You can modify the response before returning it, for example by changing HTTP headers. Args: status: HTTP status code. text: HTTP response body; it will be encoded to UTF-8. Returns: HTTP response to send to the client. """ return self.protocol.reject(status, text) async def handshake( self, process_request: ( Callable[ [ServerConnection, Request], Awaitable[Response | None] | Response | None, ] | None ) = None, process_response: ( Callable[ [ServerConnection, Request, Response], Awaitable[Response | None] | Response | None, ] | None ) = None, server_header: str | None = SERVER, ) -> None: """ Perform the opening handshake. 
""" await asyncio.wait( [self.request_rcvd, self.connection_lost_waiter], return_when=asyncio.FIRST_COMPLETED, ) if self.request is not None: async with self.send_context(expected_state=CONNECTING): response = None if process_request is not None: try: response = process_request(self, self.request) if isinstance(response, Awaitable): response = await response except Exception as exc: self.protocol.handshake_exc = exc response = self.protocol.reject( http.HTTPStatus.INTERNAL_SERVER_ERROR, ( "Failed to open a WebSocket connection.\n" "See server log for more information.\n" ), ) if response is None: if self.server.is_serving(): self.response = self.protocol.accept(self.request) else: self.response = self.protocol.reject( http.HTTPStatus.SERVICE_UNAVAILABLE, "Server is shutting down.\n", ) else: assert isinstance(response, Response) # help mypy self.response = response if server_header: self.response.headers["Server"] = server_header response = None if process_response is not None: try: response = process_response(self, self.request, self.response) if isinstance(response, Awaitable): response = await response except Exception as exc: self.protocol.handshake_exc = exc response = self.protocol.reject( http.HTTPStatus.INTERNAL_SERVER_ERROR, ( "Failed to open a WebSocket connection.\n" "See server log for more information.\n" ), ) if response is not None: assert isinstance(response, Response) # help mypy self.response = response self.protocol.send_response(self.response) # self.protocol.handshake_exc is set when the connection is lost before # receiving a request, when the request cannot be parsed, or when the # handshake fails, including when process_request or process_response # raises an exception. # It isn't set when process_request or process_response sends an HTTP # response that rejects the handshake. if self.protocol.handshake_exc is not None: raise self.protocol.handshake_exc def process_event(self, event: Event) -> None: """ Process one incoming event. 
""" # First event - handshake request. if self.request is None: assert isinstance(event, Request) self.request = event self.request_rcvd.set_result(None) # Later events - frames. else: super().process_event(event) def connection_made(self, transport: asyncio.BaseTransport) -> None: super().connection_made(transport) self.server.start_connection_handler(self) class Server: """ WebSocket server returned by :func:`serve`. This class mirrors the API of :class:`asyncio.Server`. It keeps track of WebSocket connections in order to close them properly when shutting down. Args: handler: Connection handler. It receives the WebSocket connection, which is a :class:`ServerConnection`, in argument. process_request: Intercept the request during the opening handshake. Return an HTTP response to force the response. Return :obj:`None` to continue normally. When you force an HTTP 101 Continue response, the handshake is successful. Else, the connection is aborted. ``process_request`` may be a function or a coroutine. process_response: Intercept the response during the opening handshake. Modify the response or return a new HTTP response to force the response. Return :obj:`None` to continue normally. When you force an HTTP 101 Continue response, the handshake is successful. Else, the connection is aborted. ``process_response`` may be a function or a coroutine. server_header: Value of the ``Server`` response header. It defaults to ``"Python/x.y.z websockets/X.Y"``. Setting it to :obj:`None` removes the header. open_timeout: Timeout for opening connections in seconds. :obj:`None` disables the timeout. logger: Logger for this server. It defaults to ``logging.getLogger("websockets.server")``. See the :doc:`logging guide <../../topics/logging>` for details. 
""" def __init__( self, handler: Callable[[ServerConnection], Awaitable[None]], *, process_request: ( Callable[ [ServerConnection, Request], Awaitable[Response | None] | Response | None, ] | None ) = None, process_response: ( Callable[ [ServerConnection, Request, Response], Awaitable[Response | None] | Response | None, ] | None ) = None, server_header: str | None = SERVER, open_timeout: float | None = 10, logger: LoggerLike | None = None, ) -> None: self.loop = asyncio.get_running_loop() self.handler = handler self.process_request = process_request self.process_response = process_response self.server_header = server_header self.open_timeout = open_timeout if logger is None: logger = logging.getLogger("websockets.server") self.logger = logger # Keep track of active connections. self.handlers: dict[ServerConnection, asyncio.Task[None]] = {} # Task responsible for closing the server and terminating connections. self.close_task: asyncio.Task[None] | None = None # Completed when the server is closed and connections are terminated. self.closed_waiter: asyncio.Future[None] = self.loop.create_future() @property def connections(self) -> set[ServerConnection]: """ Set of active connections. This property contains all connections that completed the opening handshake successfully and didn't start the closing handshake yet. It can be useful in combination with :func:`~broadcast`. """ return {connection for connection in self.handlers if connection.state is OPEN} def wrap(self, server: asyncio.Server) -> None: """ Attach to a given :class:`asyncio.Server`. 
Since :meth:`~asyncio.loop.create_server` doesn't support injecting a custom ``Server`` class, the easiest solution that doesn't rely on private :mod:`asyncio` APIs is to: - instantiate a :class:`Server` - give the protocol factory a reference to that instance - call :meth:`~asyncio.loop.create_server` with the factory - attach the resulting :class:`asyncio.Server` with this method """ self.server = server for sock in server.sockets: if sock.family == socket.AF_INET: name = "%s:%d" % sock.getsockname() elif sock.family == socket.AF_INET6: name = "[%s]:%d" % sock.getsockname()[:2] elif sock.family == socket.AF_UNIX: name = sock.getsockname() # In the unlikely event that someone runs websockets over a # protocol other than IP or Unix sockets, avoid crashing. else: # pragma: no cover name = str(sock.getsockname()) self.logger.info("server listening on %s", name) async def conn_handler(self, connection: ServerConnection) -> None: """ Handle the lifecycle of a WebSocket connection. Since this method doesn't have a caller that can handle exceptions, it attempts to log relevant ones. It guarantees that the TCP connection is closed before exiting. """ try: async with asyncio_timeout(self.open_timeout): try: await connection.handshake( self.process_request, self.process_response, self.server_header, ) except asyncio.CancelledError: connection.transport.abort() raise except Exception: connection.logger.error("opening handshake failed", exc_info=True) connection.transport.abort() return if connection.protocol.state is not OPEN: # process_request or process_response rejected the handshake. connection.transport.abort() return try: connection.start_keepalive() await self.handler(connection) except Exception: connection.logger.error("connection handler failed", exc_info=True) await connection.close(CloseCode.INTERNAL_ERROR) else: await connection.close() except TimeoutError: # When the opening handshake times out, there's nothing to log. 
pass except Exception: # pragma: no cover # Don't leak connections on unexpected errors. connection.transport.abort() finally: # Registration is tied to the lifecycle of conn_handler() because # the server waits for connection handlers to terminate, even if # all connections are already closed. del self.handlers[connection] def start_connection_handler(self, connection: ServerConnection) -> None: """ Register a connection with this server. """ # The connection must be registered in self.handlers immediately. # If it was registered in conn_handler(), a race condition could # happen when closing the server after scheduling conn_handler() # but before it starts executing. self.handlers[connection] = self.loop.create_task(self.conn_handler(connection)) def close(self, close_connections: bool = True) -> None: """ Close the server. * Close the underlying :class:`asyncio.Server`. * When ``close_connections`` is :obj:`True`, which is the default, close existing connections. Specifically: * Reject opening WebSocket connections with an HTTP 503 (service unavailable) error. This happens when the server accepted the TCP connection but didn't complete the opening handshake before closing. * Close open WebSocket connections with close code 1001 (going away). * Wait until all connection handlers terminate. :meth:`close` is idempotent. """ if self.close_task is None: self.close_task = self.get_loop().create_task( self._close(close_connections) ) async def _close(self, close_connections: bool) -> None: """ Implementation of :meth:`close`. This calls :meth:`~asyncio.Server.close` on the underlying :class:`asyncio.Server` object to stop accepting new connections and then closes open connections with close code 1001. """ self.logger.info("server closing") # Stop accepting new connections. self.server.close() # Wait until all accepted connections reach connection_made() and call # register(). See https://github.com/python/cpython/issues/79033 for # details. 
This workaround can be removed when dropping Python < 3.11. await asyncio.sleep(0) if close_connections: # Close OPEN connections with close code 1001. After server.close(), # handshake() closes OPENING connections with an HTTP 503 error. close_tasks = [ asyncio.create_task(connection.close(1001)) for connection in self.handlers if connection.protocol.state is not CONNECTING ] # asyncio.wait doesn't accept an empty first argument. if close_tasks: await asyncio.wait(close_tasks) # Wait until all TCP connections are closed. await self.server.wait_closed() # Wait until all connection handlers terminate. # asyncio.wait doesn't accept an empty first argument. if self.handlers: await asyncio.wait(self.handlers.values()) # Tell wait_closed() to return. self.closed_waiter.set_result(None) self.logger.info("server closed") async def wait_closed(self) -> None: """ Wait until the server is closed. When :meth:`wait_closed` returns, all TCP connections are closed and all connection handlers have returned. To ensure a fast shutdown, a connection handler should always be awaiting at least one of: * :meth:`~ServerConnection.recv`: when the connection is closed, it raises :exc:`~websockets.exceptions.ConnectionClosedOK`; * :meth:`~ServerConnection.wait_closed`: when the connection is closed, it returns. Then the connection handler is immediately notified of the shutdown; it can clean up and exit. """ await asyncio.shield(self.closed_waiter) def get_loop(self) -> asyncio.AbstractEventLoop: """ See :meth:`asyncio.Server.get_loop`. """ return self.server.get_loop() def is_serving(self) -> bool: # pragma: no cover """ See :meth:`asyncio.Server.is_serving`. """ return self.server.is_serving() async def start_serving(self) -> None: # pragma: no cover """ See :meth:`asyncio.Server.start_serving`. Typical use:: server = await serve(..., start_serving=False) # perform additional setup here... # ... 
then start the server await server.start_serving() """ await self.server.start_serving() async def serve_forever(self) -> None: # pragma: no cover """ See :meth:`asyncio.Server.serve_forever`. Typical use:: server = await serve(...) # this coroutine doesn't return # canceling it stops the server await server.serve_forever() This is an alternative to using :func:`serve` as an asynchronous context manager. Shutdown is triggered by canceling :meth:`serve_forever` instead of exiting a :func:`serve` context. """ await self.server.serve_forever() @property def sockets(self) -> Iterable[socket.socket]: """ See :attr:`asyncio.Server.sockets`. """ return self.server.sockets async def __aenter__(self) -> Server: # pragma: no cover return self async def __aexit__( self, exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: TracebackType | None, ) -> None: # pragma: no cover self.close() await self.wait_closed() # This is spelled in lower case because it's exposed as a callable in the API. class serve: """ Create a WebSocket server listening on ``host`` and ``port``. Whenever a client connects, the server creates a :class:`ServerConnection`, performs the opening handshake, and delegates to the ``handler`` coroutine. The handler receives the :class:`ServerConnection` instance, which you can use to send and receive messages. Once the handler completes, either normally or with an exception, the server performs the closing handshake and closes the connection. This coroutine returns a :class:`Server` whose API mirrors :class:`asyncio.Server`. Treat it as an asynchronous context manager to ensure that the server will be closed:: from websockets.asyncio.server import serve def handler(websocket): ... 
# set this future to exit the server stop = asyncio.get_running_loop().create_future() async with serve(handler, host, port): await stop Alternatively, call :meth:`~Server.serve_forever` to serve requests and cancel it to stop the server:: server = await serve(handler, host, port) await server.serve_forever() Args: handler: Connection handler. It receives the WebSocket connection, which is a :class:`ServerConnection`, in argument. host: Network interfaces the server binds to. See :meth:`~asyncio.loop.create_server` for details. port: TCP port the server listens on. See :meth:`~asyncio.loop.create_server` for details. origins: Acceptable values of the ``Origin`` header, for defending against Cross-Site WebSocket Hijacking attacks. Values can be :class:`str` to test for an exact match or regular expressions compiled by :func:`re.compile` to test against a pattern. Include :obj:`None` in the list if the lack of an origin is acceptable. extensions: List of supported extensions, in order in which they should be negotiated and run. subprotocols: List of supported subprotocols, in order of decreasing preference. select_subprotocol: Callback for selecting a subprotocol among those supported by the client and the server. It receives a :class:`ServerConnection` (not a :class:`~websockets.server.ServerProtocol`!) instance and a list of subprotocols offered by the client. Other than the first argument, it has the same behavior as the :meth:`ServerProtocol.select_subprotocol ` method. compression: The "permessage-deflate" extension is enabled by default. Set ``compression`` to :obj:`None` to disable it. See the :doc:`compression guide <../../topics/compression>` for details. process_request: Intercept the request during the opening handshake. Return an HTTP response to force the response or :obj:`None` to continue normally. When you force an HTTP 101 Continue response, the handshake is successful. Else, the connection is aborted. 
``process_request`` may be a function or a coroutine. process_response: Intercept the response during the opening handshake. Return an HTTP response to force the response or :obj:`None` to continue normally. When you force an HTTP 101 Continue response, the handshake is successful. Else, the connection is aborted. ``process_response`` may be a function or a coroutine. server_header: Value of the ``Server`` response header. It defaults to ``"Python/x.y.z websockets/X.Y"``. Setting it to :obj:`None` removes the header. open_timeout: Timeout for opening connections in seconds. :obj:`None` disables the timeout. ping_interval: Interval between keepalive pings in seconds. :obj:`None` disables keepalive. ping_timeout: Timeout for keepalive pings in seconds. :obj:`None` disables timeouts. close_timeout: Timeout for closing connections in seconds. :obj:`None` disables the timeout. max_size: Maximum size of incoming messages in bytes. :obj:`None` disables the limit. max_queue: High-water mark of the buffer where frames are received. It defaults to 16 frames. The low-water mark defaults to ``max_queue // 4``. You may pass a ``(high, low)`` tuple to set the high-water and low-water marks. If you want to disable flow control entirely, you may set it to ``None``, although that's a bad idea. write_limit: High-water mark of write buffer in bytes. It is passed to :meth:`~asyncio.WriteTransport.set_write_buffer_limits`. It defaults to 32 KiB. You may pass a ``(high, low)`` tuple to set the high-water and low-water marks. logger: Logger for this server. It defaults to ``logging.getLogger("websockets.server")``. See the :doc:`logging guide <../../topics/logging>` for details. create_connection: Factory for the :class:`ServerConnection` managing the connection. Set it to a wrapper or a subclass to customize connection handling. Any other keyword arguments are passed to the event loop's :meth:`~asyncio.loop.create_server` method. 
For example: * You can set ``ssl`` to a :class:`~ssl.SSLContext` to enable TLS. * You can set ``sock`` to provide a preexisting TCP socket. You may call :func:`socket.create_server` (not to be confused with the event loop's :meth:`~asyncio.loop.create_server` method) to create a suitable server socket and customize it. * You can set ``start_serving`` to ``False`` to start accepting connections only after you call :meth:`~Server.start_serving()` or :meth:`~Server.serve_forever()`. """ def __init__( self, handler: Callable[[ServerConnection], Awaitable[None]], host: str | None = None, port: int | None = None, *, # WebSocket origins: Sequence[Origin | re.Pattern[str] | None] | None = None, extensions: Sequence[ServerExtensionFactory] | None = None, subprotocols: Sequence[Subprotocol] | None = None, select_subprotocol: ( Callable[ [ServerConnection, Sequence[Subprotocol]], Subprotocol | None, ] | None ) = None, compression: str | None = "deflate", # HTTP process_request: ( Callable[ [ServerConnection, Request], Awaitable[Response | None] | Response | None, ] | None ) = None, process_response: ( Callable[ [ServerConnection, Request, Response], Awaitable[Response | None] | Response | None, ] | None ) = None, server_header: str | None = SERVER, # Timeouts open_timeout: float | None = 10, ping_interval: float | None = 20, ping_timeout: float | None = 20, close_timeout: float | None = 10, # Limits max_size: int | None = 2**20, max_queue: int | None | tuple[int | None, int | None] = 16, write_limit: int | tuple[int, int | None] = 2**15, # Logging logger: LoggerLike | None = None, # Escape hatch for advanced customization create_connection: type[ServerConnection] | None = None, # Other keyword arguments are passed to loop.create_server **kwargs: Any, ) -> None: if subprotocols is not None: validate_subprotocols(subprotocols) if compression == "deflate": extensions = enable_server_permessage_deflate(extensions) elif compression is not None: raise ValueError(f"unsupported 
compression: {compression}") if create_connection is None: create_connection = ServerConnection self.server = Server( handler, process_request=process_request, process_response=process_response, server_header=server_header, open_timeout=open_timeout, logger=logger, ) if kwargs.get("ssl") is not None: kwargs.setdefault("ssl_handshake_timeout", open_timeout) if sys.version_info[:2] >= (3, 11): # pragma: no branch kwargs.setdefault("ssl_shutdown_timeout", close_timeout) def factory() -> ServerConnection: """ Create an asyncio protocol for managing a WebSocket connection. """ # Create a closure to give select_subprotocol access to connection. protocol_select_subprotocol: ( Callable[ [ServerProtocol, Sequence[Subprotocol]], Subprotocol | None, ] | None ) = None if select_subprotocol is not None: def protocol_select_subprotocol( protocol: ServerProtocol, subprotocols: Sequence[Subprotocol], ) -> Subprotocol | None: # mypy doesn't know that select_subprotocol is immutable. assert select_subprotocol is not None # Ensure this function is only used in the intended context. assert protocol is connection.protocol return select_subprotocol(connection, subprotocols) # This is a protocol in the Sans-I/O implementation of websockets. protocol = ServerProtocol( origins=origins, extensions=extensions, subprotocols=subprotocols, select_subprotocol=protocol_select_subprotocol, max_size=max_size, logger=logger, ) # This is a connection in websockets and a protocol in asyncio. connection = create_connection( protocol, self.server, ping_interval=ping_interval, ping_timeout=ping_timeout, close_timeout=close_timeout, max_queue=max_queue, write_limit=write_limit, ) return connection loop = asyncio.get_running_loop() if kwargs.pop("unix", False): self.create_server = loop.create_unix_server(factory, **kwargs) else: # mypy cannot tell that kwargs must provide sock when port is None. 
self.create_server = loop.create_server(factory, host, port, **kwargs) # type: ignore[arg-type] # async with serve(...) as ...: ... async def __aenter__(self) -> Server: return await self async def __aexit__( self, exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: TracebackType | None, ) -> None: self.server.close() await self.server.wait_closed() # ... = await serve(...) def __await__(self) -> Generator[Any, None, Server]: # Create a suitable iterator by calling __await__ on a coroutine. return self.__await_impl__().__await__() async def __await_impl__(self) -> Server: server = await self.create_server self.server.wrap(server) return self.server # ... = yield from serve(...) - remove when dropping Python < 3.10 __iter__ = __await__ def unix_serve( handler: Callable[[ServerConnection], Awaitable[None]], path: str | None = None, **kwargs: Any, ) -> Awaitable[Server]: """ Create a WebSocket server listening on a Unix socket. This function is identical to :func:`serve`, except the ``host`` and ``port`` arguments are replaced by ``path``. It's only available on Unix. It's useful for deploying a server behind a reverse proxy such as nginx. Args: handler: Connection handler. It receives the WebSocket connection, which is a :class:`ServerConnection`, in argument. path: File system path to the Unix socket. """ return serve(handler, unix=True, path=path, **kwargs) def is_credentials(credentials: Any) -> bool: try: username, password = credentials except (TypeError, ValueError): return False else: return isinstance(username, str) and isinstance(password, str) def basic_auth( realm: str = "", credentials: tuple[str, str] | Iterable[tuple[str, str]] | None = None, check_credentials: Callable[[str, str], Awaitable[bool] | bool] | None = None, ) -> Callable[[ServerConnection, Request], Awaitable[Response | None]]: """ Factory for ``process_request`` to enforce HTTP Basic Authentication. 
:func:`basic_auth` is designed to integrate with :func:`serve` as follows:: from websockets.asyncio.server import basic_auth, serve async with serve( ..., process_request=basic_auth( realm="my dev server", credentials=("hello", "iloveyou"), ), ): If authentication succeeds, the connection's ``username`` attribute is set. If it fails, the server responds with an HTTP 401 Unauthorized status. One of ``credentials`` or ``check_credentials`` must be provided; not both. Args: realm: Scope of protection. It should contain only ASCII characters because the encoding of non-ASCII characters is undefined. Refer to section 2.2 of :rfc:`7235` for details. credentials: Hard coded authorized credentials. It can be a ``(username, password)`` pair or a list of such pairs. check_credentials: Function or coroutine that verifies credentials. It receives ``username`` and ``password`` arguments and returns whether they're valid. Raises: TypeError: If ``credentials`` or ``check_credentials`` is wrong. ValueError: If ``credentials`` and ``check_credentials`` are both provided or both not provided. 
""" if (credentials is None) == (check_credentials is None): raise ValueError("provide either credentials or check_credentials") if credentials is not None: if is_credentials(credentials): credentials_list = [cast(tuple[str, str], credentials)] elif isinstance(credentials, Iterable): credentials_list = list(cast(Iterable[tuple[str, str]], credentials)) if not all(is_credentials(item) for item in credentials_list): raise TypeError(f"invalid credentials argument: {credentials}") else: raise TypeError(f"invalid credentials argument: {credentials}") credentials_dict = dict(credentials_list) def check_credentials(username: str, password: str) -> bool: try: expected_password = credentials_dict[username] except KeyError: return False return hmac.compare_digest(expected_password, password) assert check_credentials is not None # help mypy async def process_request( connection: ServerConnection, request: Request, ) -> Response | None: """ Perform HTTP Basic Authentication. If it succeeds, set the connection's ``username`` attribute and return :obj:`None`. If it fails, return an HTTP 401 Unauthorized responss. 
""" try: authorization = request.headers["Authorization"] except KeyError: response = connection.respond( http.HTTPStatus.UNAUTHORIZED, "Missing credentials\n", ) response.headers["WWW-Authenticate"] = build_www_authenticate_basic(realm) return response try: username, password = parse_authorization_basic(authorization) except InvalidHeader: response = connection.respond( http.HTTPStatus.UNAUTHORIZED, "Unsupported credentials\n", ) response.headers["WWW-Authenticate"] = build_www_authenticate_basic(realm) return response valid_credentials = check_credentials(username, password) if isinstance(valid_credentials, Awaitable): valid_credentials = await valid_credentials if not valid_credentials: response = connection.respond( http.HTTPStatus.UNAUTHORIZED, "Invalid credentials\n", ) response.headers["WWW-Authenticate"] = build_www_authenticate_basic(realm) return response connection.username = username return None return process_request websockets-15.0.1/src/websockets/auth.py000066400000000000000000000010701476212450300202240ustar00rootroot00000000000000from __future__ import annotations import warnings with warnings.catch_warnings(): # Suppress redundant DeprecationWarning raised by websockets.legacy. 
warnings.filterwarnings("ignore", category=DeprecationWarning) from .legacy.auth import * from .legacy.auth import __all__ # noqa: F401 warnings.warn( # deprecated in 14.0 - 2024-11-09 "websockets.auth, an alias for websockets.legacy.auth, is deprecated; " "see https://websockets.readthedocs.io/en/stable/howto/upgrade.html " "for upgrade instructions", DeprecationWarning, ) websockets-15.0.1/src/websockets/cli.py000066400000000000000000000123301476212450300200330ustar00rootroot00000000000000from __future__ import annotations import argparse import asyncio import os import sys from typing import Generator from .asyncio.client import ClientConnection, connect from .asyncio.messages import SimpleQueue from .exceptions import ConnectionClosed from .frames import Close from .streams import StreamReader from .version import version as websockets_version __all__ = ["main"] def print_during_input(string: str) -> None: sys.stdout.write( # Save cursor position "\N{ESC}7" # Add a new line "\N{LINE FEED}" # Move cursor up "\N{ESC}[A" # Insert blank line, scroll last line down "\N{ESC}[L" # Print string in the inserted blank line f"{string}\N{LINE FEED}" # Restore cursor position "\N{ESC}8" # Move cursor down "\N{ESC}[B" ) sys.stdout.flush() def print_over_input(string: str) -> None: sys.stdout.write( # Move cursor to beginning of line "\N{CARRIAGE RETURN}" # Delete current line "\N{ESC}[K" # Print string f"{string}\N{LINE FEED}" ) sys.stdout.flush() class ReadLines(asyncio.Protocol): def __init__(self) -> None: self.reader = StreamReader() self.messages: SimpleQueue[str] = SimpleQueue() def parse(self) -> Generator[None, None, None]: while True: sys.stdout.write("> ") sys.stdout.flush() line = yield from self.reader.read_line(sys.maxsize) self.messages.put(line.decode().rstrip("\r\n")) def connection_made(self, transport: asyncio.BaseTransport) -> None: self.parser = self.parse() next(self.parser) def data_received(self, data: bytes) -> None: self.reader.feed_data(data) 
next(self.parser) def eof_received(self) -> None: self.reader.feed_eof() # next(self.parser) isn't useful and would raise EOFError. def connection_lost(self, exc: Exception | None) -> None: self.reader.discard() self.messages.abort() async def print_incoming_messages(websocket: ClientConnection) -> None: async for message in websocket: if isinstance(message, str): print_during_input("< " + message) else: print_during_input("< (binary) " + message.hex()) async def send_outgoing_messages( websocket: ClientConnection, messages: SimpleQueue[str], ) -> None: while True: try: message = await messages.get() except EOFError: break try: await websocket.send(message) except ConnectionClosed: # pragma: no cover break async def interactive_client(uri: str) -> None: try: websocket = await connect(uri) except Exception as exc: print(f"Failed to connect to {uri}: {exc}.") sys.exit(1) else: print(f"Connected to {uri}.") loop = asyncio.get_running_loop() transport, protocol = await loop.connect_read_pipe(ReadLines, sys.stdin) incoming = asyncio.create_task( print_incoming_messages(websocket), ) outgoing = asyncio.create_task( send_outgoing_messages(websocket, protocol.messages), ) try: await asyncio.wait( [incoming, outgoing], # Clean up and exit when the server closes the connection # or the user enters EOT (^D), whichever happens first. return_when=asyncio.FIRST_COMPLETED, ) # asyncio.run() cancels the main task when the user triggers SIGINT (^C). # https://docs.python.org/3/library/asyncio-runner.html#handling-keyboard-interruption # Clean up and exit without re-raising CancelledError to prevent Python # from raising KeyboardInterrupt and displaying a stack track. 
except asyncio.CancelledError: # pragma: no cover pass finally: incoming.cancel() outgoing.cancel() transport.close() await websocket.close() assert websocket.close_code is not None and websocket.close_reason is not None close_status = Close(websocket.close_code, websocket.close_reason) print_over_input(f"Connection closed: {close_status}.") def main(argv: list[str] | None = None) -> None: parser = argparse.ArgumentParser( prog="websockets", description="Interactive WebSocket client.", add_help=False, ) group = parser.add_mutually_exclusive_group() group.add_argument("--version", action="store_true") group.add_argument("uri", metavar="", nargs="?") args = parser.parse_args(argv) if args.version: print(f"websockets {websockets_version}") return if args.uri is None: parser.print_usage() sys.exit(2) # Enable VT100 to support ANSI escape codes in Command Prompt on Windows. # See https://github.com/python/cpython/issues/74261 for why this works. if sys.platform == "win32": os.system("") try: import readline # noqa: F401 except ImportError: # readline isn't available on all platforms pass # Remove the try/except block when dropping Python < 3.11. 
try: asyncio.run(interactive_client(args.uri)) except KeyboardInterrupt: # pragma: no cover pass websockets-15.0.1/src/websockets/client.py000066400000000000000000000323741476212450300205540ustar00rootroot00000000000000from __future__ import annotations import os import random import warnings from collections.abc import Generator, Sequence from typing import Any from .datastructures import Headers, MultipleValuesError from .exceptions import ( InvalidHandshake, InvalidHeader, InvalidHeaderValue, InvalidMessage, InvalidStatus, InvalidUpgrade, NegotiationError, ) from .extensions import ClientExtensionFactory, Extension from .headers import ( build_authorization_basic, build_extension, build_host, build_subprotocol, parse_connection, parse_extension, parse_subprotocol, parse_upgrade, ) from .http11 import Request, Response from .imports import lazy_import from .protocol import CLIENT, CONNECTING, OPEN, Protocol, State from .typing import ( ConnectionOption, ExtensionHeader, LoggerLike, Origin, Subprotocol, UpgradeProtocol, ) from .uri import WebSocketURI from .utils import accept_key, generate_key __all__ = ["ClientProtocol"] class ClientProtocol(Protocol): """ Sans-I/O implementation of a WebSocket client connection. Args: uri: URI of the WebSocket server, parsed with :func:`~websockets.uri.parse_uri`. origin: Value of the ``Origin`` header. This is useful when connecting to a server that validates the ``Origin`` header to defend against Cross-Site WebSocket Hijacking attacks. extensions: List of supported extensions, in order in which they should be tried. subprotocols: List of supported subprotocols, in order of decreasing preference. state: Initial state of the WebSocket connection. max_size: Maximum size of incoming messages in bytes; :obj:`None` disables the limit. logger: Logger for this connection; defaults to ``logging.getLogger("websockets.client")``; see the :doc:`logging guide <../../topics/logging>` for details. 
""" def __init__( self, uri: WebSocketURI, *, origin: Origin | None = None, extensions: Sequence[ClientExtensionFactory] | None = None, subprotocols: Sequence[Subprotocol] | None = None, state: State = CONNECTING, max_size: int | None = 2**20, logger: LoggerLike | None = None, ) -> None: super().__init__( side=CLIENT, state=state, max_size=max_size, logger=logger, ) self.uri = uri self.origin = origin self.available_extensions = extensions self.available_subprotocols = subprotocols self.key = generate_key() def connect(self) -> Request: """ Create a handshake request to open a connection. You must send the handshake request with :meth:`send_request`. You can modify it before sending it, for example to add HTTP headers. Returns: WebSocket handshake request event to send to the server. """ headers = Headers() headers["Host"] = build_host(self.uri.host, self.uri.port, self.uri.secure) if self.uri.user_info: headers["Authorization"] = build_authorization_basic(*self.uri.user_info) if self.origin is not None: headers["Origin"] = self.origin headers["Upgrade"] = "websocket" headers["Connection"] = "Upgrade" headers["Sec-WebSocket-Key"] = self.key headers["Sec-WebSocket-Version"] = "13" if self.available_extensions is not None: headers["Sec-WebSocket-Extensions"] = build_extension( [ (extension_factory.name, extension_factory.get_request_params()) for extension_factory in self.available_extensions ] ) if self.available_subprotocols is not None: headers["Sec-WebSocket-Protocol"] = build_subprotocol( self.available_subprotocols ) return Request(self.uri.resource_name, headers) def process_response(self, response: Response) -> None: """ Check a handshake response. Args: request: WebSocket handshake response received from the server. Raises: InvalidHandshake: If the handshake response is invalid. 
""" if response.status_code != 101: raise InvalidStatus(response) headers = response.headers connection: list[ConnectionOption] = sum( [parse_connection(value) for value in headers.get_all("Connection")], [] ) if not any(value.lower() == "upgrade" for value in connection): raise InvalidUpgrade( "Connection", ", ".join(connection) if connection else None ) upgrade: list[UpgradeProtocol] = sum( [parse_upgrade(value) for value in headers.get_all("Upgrade")], [] ) # For compatibility with non-strict implementations, ignore case when # checking the Upgrade header. It's supposed to be 'WebSocket'. if not (len(upgrade) == 1 and upgrade[0].lower() == "websocket"): raise InvalidUpgrade("Upgrade", ", ".join(upgrade) if upgrade else None) try: s_w_accept = headers["Sec-WebSocket-Accept"] except KeyError: raise InvalidHeader("Sec-WebSocket-Accept") from None except MultipleValuesError: raise InvalidHeader("Sec-WebSocket-Accept", "multiple values") from None if s_w_accept != accept_key(self.key): raise InvalidHeaderValue("Sec-WebSocket-Accept", s_w_accept) self.extensions = self.process_extensions(headers) self.subprotocol = self.process_subprotocol(headers) def process_extensions(self, headers: Headers) -> list[Extension]: """ Handle the Sec-WebSocket-Extensions HTTP response header. Check that each extension is supported, as well as its parameters. :rfc:`6455` leaves the rules up to the specification of each extension. To provide this level of flexibility, for each extension accepted by the server, we check for a match with each extension available in the client configuration. If no match is found, an exception is raised. If several variants of the same extension are accepted by the server, it may be configured several times, which won't make sense in general. Extensions must implement their own requirements. For this purpose, the list of previously accepted extensions is provided. 
Other requirements, for example related to mandatory extensions or the order of extensions, may be implemented by overriding this method. Args: headers: WebSocket handshake response headers. Returns: List of accepted extensions. Raises: InvalidHandshake: To abort the handshake. """ accepted_extensions: list[Extension] = [] extensions = headers.get_all("Sec-WebSocket-Extensions") if extensions: if self.available_extensions is None: raise NegotiationError("no extensions supported") parsed_extensions: list[ExtensionHeader] = sum( [parse_extension(header_value) for header_value in extensions], [] ) for name, response_params in parsed_extensions: for extension_factory in self.available_extensions: # Skip non-matching extensions based on their name. if extension_factory.name != name: continue # Skip non-matching extensions based on their params. try: extension = extension_factory.process_response_params( response_params, accepted_extensions ) except NegotiationError: continue # Add matching extension to the final list. accepted_extensions.append(extension) # Break out of the loop once we have a match. break # If we didn't break from the loop, no extension in our list # matched what the server sent. Fail the connection. else: raise NegotiationError( f"Unsupported extension: " f"name = {name}, params = {response_params}" ) return accepted_extensions def process_subprotocol(self, headers: Headers) -> Subprotocol | None: """ Handle the Sec-WebSocket-Protocol HTTP response header. If provided, check that it contains exactly one supported subprotocol. Args: headers: WebSocket handshake response headers. Returns: Subprotocol, if one was selected. 
""" subprotocol: Subprotocol | None = None subprotocols = headers.get_all("Sec-WebSocket-Protocol") if subprotocols: if self.available_subprotocols is None: raise NegotiationError("no subprotocols supported") parsed_subprotocols: Sequence[Subprotocol] = sum( [parse_subprotocol(header_value) for header_value in subprotocols], [] ) if len(parsed_subprotocols) > 1: raise InvalidHeader( "Sec-WebSocket-Protocol", f"multiple values: {', '.join(parsed_subprotocols)}", ) subprotocol = parsed_subprotocols[0] if subprotocol not in self.available_subprotocols: raise NegotiationError(f"unsupported subprotocol: {subprotocol}") return subprotocol def send_request(self, request: Request) -> None: """ Send a handshake request to the server. Args: request: WebSocket handshake request event. """ if self.debug: self.logger.debug("> GET %s HTTP/1.1", request.path) for key, value in request.headers.raw_items(): self.logger.debug("> %s: %s", key, value) self.writes.append(request.serialize()) def parse(self) -> Generator[None]: if self.state is CONNECTING: try: response = yield from Response.parse( self.reader.read_line, self.reader.read_exact, self.reader.read_to_eof, ) except Exception as exc: self.handshake_exc = InvalidMessage( "did not receive a valid HTTP response" ) self.handshake_exc.__cause__ = exc self.send_eof() self.parser = self.discard() next(self.parser) # start coroutine yield if self.debug: code, phrase = response.status_code, response.reason_phrase self.logger.debug("< HTTP/1.1 %d %s", code, phrase) for key, value in response.headers.raw_items(): self.logger.debug("< %s: %s", key, value) if response.body: self.logger.debug("< [body] (%d bytes)", len(response.body)) try: self.process_response(response) except InvalidHandshake as exc: response._exception = exc self.events.append(response) self.handshake_exc = exc self.send_eof() self.parser = self.discard() next(self.parser) # start coroutine yield assert self.state is CONNECTING self.state = OPEN 
self.events.append(response) yield from super().parse() class ClientConnection(ClientProtocol): def __init__(self, *args: Any, **kwargs: Any) -> None: warnings.warn( # deprecated in 11.0 - 2023-04-02 "ClientConnection was renamed to ClientProtocol", DeprecationWarning, ) super().__init__(*args, **kwargs) BACKOFF_INITIAL_DELAY = float(os.environ.get("WEBSOCKETS_BACKOFF_INITIAL_DELAY", "5")) BACKOFF_MIN_DELAY = float(os.environ.get("WEBSOCKETS_BACKOFF_MIN_DELAY", "3.1")) BACKOFF_MAX_DELAY = float(os.environ.get("WEBSOCKETS_BACKOFF_MAX_DELAY", "90.0")) BACKOFF_FACTOR = float(os.environ.get("WEBSOCKETS_BACKOFF_FACTOR", "1.618")) def backoff( initial_delay: float = BACKOFF_INITIAL_DELAY, min_delay: float = BACKOFF_MIN_DELAY, max_delay: float = BACKOFF_MAX_DELAY, factor: float = BACKOFF_FACTOR, ) -> Generator[float]: """ Generate a series of backoff delays between reconnection attempts. Yields: How many seconds to wait before retrying to connect. """ # Add a random initial delay between 0 and 5 seconds. # See 7.2.3. Recovering from Abnormal Closure in RFC 6455. 
yield random.random() * initial_delay delay = min_delay while delay < max_delay: yield delay delay *= factor while True: yield max_delay lazy_import( globals(), deprecated_aliases={ # deprecated in 14.0 - 2024-11-09 "WebSocketClientProtocol": ".legacy.client", "connect": ".legacy.client", "unix_connect": ".legacy.client", }, ) websockets-15.0.1/src/websockets/connection.py000066400000000000000000000005031476212450300214220ustar00rootroot00000000000000from __future__ import annotations import warnings from .protocol import SEND_EOF, Protocol as Connection, Side, State # noqa: F401 warnings.warn( # deprecated in 11.0 - 2023-04-02 "websockets.connection was renamed to websockets.protocol " "and Connection was renamed to Protocol", DeprecationWarning, ) websockets-15.0.1/src/websockets/datastructures.py000066400000000000000000000127571476212450300223560ustar00rootroot00000000000000from __future__ import annotations from collections.abc import Iterable, Iterator, Mapping, MutableMapping from typing import Any, Protocol, Union __all__ = [ "Headers", "HeadersLike", "MultipleValuesError", ] class MultipleValuesError(LookupError): """ Exception raised when :class:`Headers` has multiple values for a key. """ def __str__(self) -> str: # Implement the same logic as KeyError_str in Objects/exceptions.c. if len(self.args) == 1: return repr(self.args[0]) return super().__str__() class Headers(MutableMapping[str, str]): """ Efficient data structure for manipulating HTTP headers. A :class:`list` of ``(name, values)`` is inefficient for lookups. A :class:`dict` doesn't suffice because header names are case-insensitive and multiple occurrences of headers with the same name are possible. :class:`Headers` stores HTTP headers in a hybrid data structure to provide efficient insertions and lookups while preserving the original data. 
In order to account for multiple values with minimal hassle, :class:`Headers` follows this logic: - When getting a header with ``headers[name]``: - if there's no value, :exc:`KeyError` is raised; - if there's exactly one value, it's returned; - if there's more than one value, :exc:`MultipleValuesError` is raised. - When setting a header with ``headers[name] = value``, the value is appended to the list of values for that header. - When deleting a header with ``del headers[name]``, all values for that header are removed (this is slow). Other methods for manipulating headers are consistent with this logic. As long as no header occurs multiple times, :class:`Headers` behaves like :class:`dict`, except keys are lower-cased to provide case-insensitivity. Two methods support manipulating multiple values explicitly: - :meth:`get_all` returns a list of all values for a header; - :meth:`raw_items` returns an iterator of ``(name, values)`` pairs. """ __slots__ = ["_dict", "_list"] # Like dict, Headers accepts an optional "mapping or iterable" argument. def __init__(self, *args: HeadersLike, **kwargs: str) -> None: self._dict: dict[str, list[str]] = {} self._list: list[tuple[str, str]] = [] self.update(*args, **kwargs) def __str__(self) -> str: return "".join(f"{key}: {value}\r\n" for key, value in self._list) + "\r\n" def __repr__(self) -> str: return f"{self.__class__.__name__}({self._list!r})" def copy(self) -> Headers: copy = self.__class__() copy._dict = self._dict.copy() copy._list = self._list.copy() return copy def serialize(self) -> bytes: # Since headers only contain ASCII characters, we can keep this simple. 
return str(self).encode() # Collection methods def __contains__(self, key: object) -> bool: return isinstance(key, str) and key.lower() in self._dict def __iter__(self) -> Iterator[str]: return iter(self._dict) def __len__(self) -> int: return len(self._dict) # MutableMapping methods def __getitem__(self, key: str) -> str: value = self._dict[key.lower()] if len(value) == 1: return value[0] else: raise MultipleValuesError(key) def __setitem__(self, key: str, value: str) -> None: self._dict.setdefault(key.lower(), []).append(value) self._list.append((key, value)) def __delitem__(self, key: str) -> None: key_lower = key.lower() self._dict.__delitem__(key_lower) # This is inefficient. Fortunately deleting HTTP headers is uncommon. self._list = [(k, v) for k, v in self._list if k.lower() != key_lower] def __eq__(self, other: Any) -> bool: if not isinstance(other, Headers): return NotImplemented return self._dict == other._dict def clear(self) -> None: """ Remove all headers. """ self._dict = {} self._list = [] def update(self, *args: HeadersLike, **kwargs: str) -> None: """ Update from a :class:`Headers` instance and/or keyword arguments. """ args = tuple( arg.raw_items() if isinstance(arg, Headers) else arg for arg in args ) super().update(*args, **kwargs) # Methods for handling multiple values def get_all(self, key: str) -> list[str]: """ Return the (possibly empty) list of all values for a header. Args: key: Header name. """ return self._dict.get(key.lower(), []) def raw_items(self) -> Iterator[tuple[str, str]]: """ Return an iterator of all values as ``(name, value)`` pairs. """ return iter(self._list) # copy of _typeshed.SupportsKeysAndGetItem. class SupportsKeysAndGetItem(Protocol): # pragma: no cover """ Dict-like types with ``keys() -> str`` and ``__getitem__(key: str) -> str`` methods. """ def keys(self) -> Iterable[str]: ... def __getitem__(self, key: str) -> str: ... # Change to Headers | Mapping[str, str] | ... when dropping Python < 3.10. 
HeadersLike = Union[ Headers, Mapping[str, str], Iterable[tuple[str, str]], SupportsKeysAndGetItem, ] """ Types accepted where :class:`Headers` is expected. In addition to :class:`Headers` itself, this includes dict-like types where both keys and values are :class:`str`. """ websockets-15.0.1/src/websockets/exceptions.py000066400000000000000000000310131476212450300214440ustar00rootroot00000000000000""" :mod:`websockets.exceptions` defines the following hierarchy of exceptions. * :exc:`WebSocketException` * :exc:`ConnectionClosed` * :exc:`ConnectionClosedOK` * :exc:`ConnectionClosedError` * :exc:`InvalidURI` * :exc:`InvalidProxy` * :exc:`InvalidHandshake` * :exc:`SecurityError` * :exc:`ProxyError` * :exc:`InvalidProxyMessage` * :exc:`InvalidProxyStatus` * :exc:`InvalidMessage` * :exc:`InvalidStatus` * :exc:`InvalidStatusCode` (legacy) * :exc:`InvalidHeader` * :exc:`InvalidHeaderFormat` * :exc:`InvalidHeaderValue` * :exc:`InvalidOrigin` * :exc:`InvalidUpgrade` * :exc:`NegotiationError` * :exc:`DuplicateParameter` * :exc:`InvalidParameterName` * :exc:`InvalidParameterValue` * :exc:`AbortHandshake` (legacy) * :exc:`RedirectHandshake` (legacy) * :exc:`ProtocolError` (Sans-I/O) * :exc:`PayloadTooBig` (Sans-I/O) * :exc:`InvalidState` (Sans-I/O) * :exc:`ConcurrencyError` """ from __future__ import annotations import warnings from .imports import lazy_import __all__ = [ "WebSocketException", "ConnectionClosed", "ConnectionClosedOK", "ConnectionClosedError", "InvalidURI", "InvalidProxy", "InvalidHandshake", "SecurityError", "ProxyError", "InvalidProxyMessage", "InvalidProxyStatus", "InvalidMessage", "InvalidStatus", "InvalidHeader", "InvalidHeaderFormat", "InvalidHeaderValue", "InvalidOrigin", "InvalidUpgrade", "NegotiationError", "DuplicateParameter", "InvalidParameterName", "InvalidParameterValue", "ProtocolError", "PayloadTooBig", "InvalidState", "ConcurrencyError", ] class WebSocketException(Exception): """ Base class for all exceptions defined by websockets. 
""" class ConnectionClosed(WebSocketException): """ Raised when trying to interact with a closed connection. Attributes: rcvd: If a close frame was received, its code and reason are available in ``rcvd.code`` and ``rcvd.reason``. sent: If a close frame was sent, its code and reason are available in ``sent.code`` and ``sent.reason``. rcvd_then_sent: If close frames were received and sent, this attribute tells in which order this happened, from the perspective of this side of the connection. """ def __init__( self, rcvd: frames.Close | None, sent: frames.Close | None, rcvd_then_sent: bool | None = None, ) -> None: self.rcvd = rcvd self.sent = sent self.rcvd_then_sent = rcvd_then_sent assert (self.rcvd_then_sent is None) == (self.rcvd is None or self.sent is None) def __str__(self) -> str: if self.rcvd is None: if self.sent is None: return "no close frame received or sent" else: return f"sent {self.sent}; no close frame received" else: if self.sent is None: return f"received {self.rcvd}; no close frame sent" else: if self.rcvd_then_sent: return f"received {self.rcvd}; then sent {self.sent}" else: return f"sent {self.sent}; then received {self.rcvd}" # code and reason attributes are provided for backwards-compatibility @property def code(self) -> int: warnings.warn( # deprecated in 13.1 - 2024-09-21 "ConnectionClosed.code is deprecated; " "use Protocol.close_code or ConnectionClosed.rcvd.code", DeprecationWarning, ) if self.rcvd is None: return frames.CloseCode.ABNORMAL_CLOSURE return self.rcvd.code @property def reason(self) -> str: warnings.warn( # deprecated in 13.1 - 2024-09-21 "ConnectionClosed.reason is deprecated; " "use Protocol.close_reason or ConnectionClosed.rcvd.reason", DeprecationWarning, ) if self.rcvd is None: return "" return self.rcvd.reason class ConnectionClosedOK(ConnectionClosed): """ Like :exc:`ConnectionClosed`, when the connection terminated properly. A close code with code 1000 (OK) or 1001 (going away) or without a code was received and sent. 
""" class ConnectionClosedError(ConnectionClosed): """ Like :exc:`ConnectionClosed`, when the connection terminated with an error. A close frame with a code other than 1000 (OK) or 1001 (going away) was received or sent, or the closing handshake didn't complete properly. """ class InvalidURI(WebSocketException): """ Raised when connecting to a URI that isn't a valid WebSocket URI. """ def __init__(self, uri: str, msg: str) -> None: self.uri = uri self.msg = msg def __str__(self) -> str: return f"{self.uri} isn't a valid URI: {self.msg}" class InvalidProxy(WebSocketException): """ Raised when connecting via a proxy that isn't valid. """ def __init__(self, proxy: str, msg: str) -> None: self.proxy = proxy self.msg = msg def __str__(self) -> str: return f"{self.proxy} isn't a valid proxy: {self.msg}" class InvalidHandshake(WebSocketException): """ Base class for exceptions raised when the opening handshake fails. """ class SecurityError(InvalidHandshake): """ Raised when a handshake request or response breaks a security rule. Security limits can be configured with :doc:`environment variables <../reference/variables>`. """ class ProxyError(InvalidHandshake): """ Raised when failing to connect to a proxy. """ class InvalidProxyMessage(ProxyError): """ Raised when an HTTP proxy response is malformed. """ class InvalidProxyStatus(ProxyError): """ Raised when an HTTP proxy rejects the connection. """ def __init__(self, response: http11.Response) -> None: self.response = response def __str__(self) -> str: return f"proxy rejected connection: HTTP {self.response.status_code:d}" class InvalidMessage(InvalidHandshake): """ Raised when a handshake request or response is malformed. """ class InvalidStatus(InvalidHandshake): """ Raised when a handshake response rejects the WebSocket upgrade. 
""" def __init__(self, response: http11.Response) -> None: self.response = response def __str__(self) -> str: return ( f"server rejected WebSocket connection: HTTP {self.response.status_code:d}" ) class InvalidHeader(InvalidHandshake): """ Raised when an HTTP header doesn't have a valid format or value. """ def __init__(self, name: str, value: str | None = None) -> None: self.name = name self.value = value def __str__(self) -> str: if self.value is None: return f"missing {self.name} header" elif self.value == "": return f"empty {self.name} header" else: return f"invalid {self.name} header: {self.value}" class InvalidHeaderFormat(InvalidHeader): """ Raised when an HTTP header cannot be parsed. The format of the header doesn't match the grammar for that header. """ def __init__(self, name: str, error: str, header: str, pos: int) -> None: super().__init__(name, f"{error} at {pos} in {header}") class InvalidHeaderValue(InvalidHeader): """ Raised when an HTTP header has a wrong value. The format of the header is correct but the value isn't acceptable. """ class InvalidOrigin(InvalidHeader): """ Raised when the Origin header in a request isn't allowed. """ def __init__(self, origin: str | None) -> None: super().__init__("Origin", origin) class InvalidUpgrade(InvalidHeader): """ Raised when the Upgrade or Connection header isn't correct. """ class NegotiationError(InvalidHandshake): """ Raised when negotiating an extension or a subprotocol fails. """ class DuplicateParameter(NegotiationError): """ Raised when a parameter name is repeated in an extension header. """ def __init__(self, name: str) -> None: self.name = name def __str__(self) -> str: return f"duplicate parameter: {self.name}" class InvalidParameterName(NegotiationError): """ Raised when a parameter name in an extension header is invalid. 
""" def __init__(self, name: str) -> None: self.name = name def __str__(self) -> str: return f"invalid parameter name: {self.name}" class InvalidParameterValue(NegotiationError): """ Raised when a parameter value in an extension header is invalid. """ def __init__(self, name: str, value: str | None) -> None: self.name = name self.value = value def __str__(self) -> str: if self.value is None: return f"missing value for parameter {self.name}" elif self.value == "": return f"empty value for parameter {self.name}" else: return f"invalid value for parameter {self.name}: {self.value}" class ProtocolError(WebSocketException): """ Raised when receiving or sending a frame that breaks the protocol. The Sans-I/O implementation raises this exception when: * receiving or sending a frame that contains invalid data; * receiving or sending an invalid sequence of frames. """ class PayloadTooBig(WebSocketException): """ Raised when parsing a frame with a payload that exceeds the maximum size. The Sans-I/O layer uses this exception internally. It doesn't bubble up to the I/O layer. The :meth:`~websockets.extensions.Extension.decode` method of extensions must raise :exc:`PayloadTooBig` if decoding a frame would exceed the limit. 
""" def __init__( self, size_or_message: int | None | str, max_size: int | None = None, cur_size: int | None = None, ) -> None: if isinstance(size_or_message, str): assert max_size is None assert cur_size is None warnings.warn( # deprecated in 14.0 - 2024-11-09 "PayloadTooBig(message) is deprecated; " "change to PayloadTooBig(size, max_size)", DeprecationWarning, ) self.message: str | None = size_or_message else: self.message = None self.size: int | None = size_or_message assert max_size is not None self.max_size: int = max_size self.cur_size: int | None = None self.set_current_size(cur_size) def __str__(self) -> str: if self.message is not None: return self.message else: message = "frame " if self.size is not None: message += f"with {self.size} bytes " if self.cur_size is not None: message += f"after reading {self.cur_size} bytes " message += f"exceeds limit of {self.max_size} bytes" return message def set_current_size(self, cur_size: int | None) -> None: assert self.cur_size is None if cur_size is not None: self.max_size += cur_size self.cur_size = cur_size class InvalidState(WebSocketException, AssertionError): """ Raised when sending a frame is forbidden in the current state. Specifically, the Sans-I/O layer raises this exception when: * sending a data frame to a connection in a state other :attr:`~websockets.protocol.State.OPEN`; * sending a control frame to a connection in a state other than :attr:`~websockets.protocol.State.OPEN` or :attr:`~websockets.protocol.State.CLOSING`. """ class ConcurrencyError(WebSocketException, RuntimeError): """ Raised when receiving or sending messages concurrently. WebSocket is a connection-oriented protocol. Reads must be serialized; so must be writes. However, reading and writing concurrently is possible. """ # At the bottom to break import cycles created by type annotations. from . 
import frames, http11 # noqa: E402 lazy_import( globals(), deprecated_aliases={ # deprecated in 14.0 - 2024-11-09 "AbortHandshake": ".legacy.exceptions", "InvalidStatusCode": ".legacy.exceptions", "RedirectHandshake": ".legacy.exceptions", "WebSocketProtocolError": ".legacy.exceptions", }, ) websockets-15.0.1/src/websockets/extensions/000077500000000000000000000000001476212450300211125ustar00rootroot00000000000000websockets-15.0.1/src/websockets/extensions/__init__.py000066400000000000000000000001421476212450300232200ustar00rootroot00000000000000from .base import * __all__ = ["Extension", "ClientExtensionFactory", "ServerExtensionFactory"] websockets-15.0.1/src/websockets/extensions/base.py000066400000000000000000000055271476212450300224070ustar00rootroot00000000000000from __future__ import annotations from collections.abc import Sequence from ..frames import Frame from ..typing import ExtensionName, ExtensionParameter __all__ = ["Extension", "ClientExtensionFactory", "ServerExtensionFactory"] class Extension: """ Base class for extensions. """ name: ExtensionName """Extension identifier.""" def decode(self, frame: Frame, *, max_size: int | None = None) -> Frame: """ Decode an incoming frame. Args: frame: Incoming frame. max_size: Maximum payload size in bytes. Returns: Decoded frame. Raises: PayloadTooBig: If decoding the payload exceeds ``max_size``. """ raise NotImplementedError def encode(self, frame: Frame) -> Frame: """ Encode an outgoing frame. Args: frame: Outgoing frame. Returns: Encoded frame. """ raise NotImplementedError class ClientExtensionFactory: """ Base class for client-side extension factories. """ name: ExtensionName """Extension identifier.""" def get_request_params(self) -> Sequence[ExtensionParameter]: """ Build parameters to send to the server for this extension. Returns: Parameters to send to the server. 
""" raise NotImplementedError def process_response_params( self, params: Sequence[ExtensionParameter], accepted_extensions: Sequence[Extension], ) -> Extension: """ Process parameters received from the server. Args: params: Parameters received from the server for this extension. accepted_extensions: List of previously accepted extensions. Returns: An extension instance. Raises: NegotiationError: If parameters aren't acceptable. """ raise NotImplementedError class ServerExtensionFactory: """ Base class for server-side extension factories. """ name: ExtensionName """Extension identifier.""" def process_request_params( self, params: Sequence[ExtensionParameter], accepted_extensions: Sequence[Extension], ) -> tuple[list[ExtensionParameter], Extension]: """ Process parameters received from the client. Args: params: Parameters received from the client for this extension. accepted_extensions: List of previously accepted extensions. Returns: To accept the offer, parameters to send to the client for this extension and an extension instance. Raises: NegotiationError: To reject the offer, if parameters received from the client aren't acceptable. """ raise NotImplementedError websockets-15.0.1/src/websockets/extensions/permessage_deflate.py000066400000000000000000000621571476212450300253160ustar00rootroot00000000000000from __future__ import annotations import zlib from collections.abc import Sequence from typing import Any, Literal from .. 
import frames from ..exceptions import ( DuplicateParameter, InvalidParameterName, InvalidParameterValue, NegotiationError, PayloadTooBig, ProtocolError, ) from ..typing import ExtensionName, ExtensionParameter from .base import ClientExtensionFactory, Extension, ServerExtensionFactory __all__ = [ "PerMessageDeflate", "ClientPerMessageDeflateFactory", "enable_client_permessage_deflate", "ServerPerMessageDeflateFactory", "enable_server_permessage_deflate", ] _EMPTY_UNCOMPRESSED_BLOCK = b"\x00\x00\xff\xff" _MAX_WINDOW_BITS_VALUES = [str(bits) for bits in range(8, 16)] class PerMessageDeflate(Extension): """ Per-Message Deflate extension. """ name = ExtensionName("permessage-deflate") def __init__( self, remote_no_context_takeover: bool, local_no_context_takeover: bool, remote_max_window_bits: int, local_max_window_bits: int, compress_settings: dict[Any, Any] | None = None, ) -> None: """ Configure the Per-Message Deflate extension. """ if compress_settings is None: compress_settings = {} assert remote_no_context_takeover in [False, True] assert local_no_context_takeover in [False, True] assert 8 <= remote_max_window_bits <= 15 assert 8 <= local_max_window_bits <= 15 assert "wbits" not in compress_settings self.remote_no_context_takeover = remote_no_context_takeover self.local_no_context_takeover = local_no_context_takeover self.remote_max_window_bits = remote_max_window_bits self.local_max_window_bits = local_max_window_bits self.compress_settings = compress_settings if not self.remote_no_context_takeover: self.decoder = zlib.decompressobj(wbits=-self.remote_max_window_bits) if not self.local_no_context_takeover: self.encoder = zlib.compressobj( wbits=-self.local_max_window_bits, **self.compress_settings, ) # To handle continuation frames properly, we must keep track of # whether that initial frame was encoded. self.decode_cont_data = False # There's no need for self.encode_cont_data because we always encode # outgoing frames, so it would always be True. 
def __repr__(self) -> str: return ( f"PerMessageDeflate(" f"remote_no_context_takeover={self.remote_no_context_takeover}, " f"local_no_context_takeover={self.local_no_context_takeover}, " f"remote_max_window_bits={self.remote_max_window_bits}, " f"local_max_window_bits={self.local_max_window_bits})" ) def decode( self, frame: frames.Frame, *, max_size: int | None = None, ) -> frames.Frame: """ Decode an incoming frame. """ # Skip control frames. if frame.opcode in frames.CTRL_OPCODES: return frame # Handle continuation data frames: # - skip if the message isn't encoded # - reset "decode continuation data" flag if it's a final frame if frame.opcode is frames.OP_CONT: if not self.decode_cont_data: return frame if frame.fin: self.decode_cont_data = False # Handle text and binary data frames: # - skip if the message isn't encoded # - unset the rsv1 flag on the first frame of a compressed message # - set "decode continuation data" flag if it's a non-final frame else: if not frame.rsv1: return frame if not frame.fin: self.decode_cont_data = True # Re-initialize per-message decoder. if self.remote_no_context_takeover: self.decoder = zlib.decompressobj(wbits=-self.remote_max_window_bits) # Uncompress data. Protect against zip bombs by preventing zlib from # decompressing more than max_length bytes (except when the limit is # disabled with max_size = None). if frame.fin and len(frame.data) < 2044: # Profiling shows that appending four bytes, which makes a copy, is # faster than calling decompress() again when data is less than 2kB. data = bytes(frame.data) + _EMPTY_UNCOMPRESSED_BLOCK else: data = frame.data max_length = 0 if max_size is None else max_size try: data = self.decoder.decompress(data, max_length) if self.decoder.unconsumed_tail: assert max_size is not None # help mypy raise PayloadTooBig(None, max_size) if frame.fin and len(frame.data) >= 2044: # This cannot generate additional data. 
self.decoder.decompress(_EMPTY_UNCOMPRESSED_BLOCK) except zlib.error as exc: raise ProtocolError("decompression failed") from exc # Allow garbage collection of the decoder if it won't be reused. if frame.fin and self.remote_no_context_takeover: del self.decoder return frames.Frame( frame.opcode, data, frame.fin, # Unset the rsv1 flag on the first frame of a compressed message. False, frame.rsv2, frame.rsv3, ) def encode(self, frame: frames.Frame) -> frames.Frame: """ Encode an outgoing frame. """ # Skip control frames. if frame.opcode in frames.CTRL_OPCODES: return frame # Since we always encode messages, there's no "encode continuation # data" flag similar to "decode continuation data" at this time. if frame.opcode is not frames.OP_CONT: # Re-initialize per-message decoder. if self.local_no_context_takeover: self.encoder = zlib.compressobj( wbits=-self.local_max_window_bits, **self.compress_settings, ) # Compress data. data = self.encoder.compress(frame.data) + self.encoder.flush(zlib.Z_SYNC_FLUSH) if frame.fin: # Sync flush generates between 5 or 6 bytes, ending with the bytes # 0x00 0x00 0xff 0xff, which must be removed. assert data[-4:] == _EMPTY_UNCOMPRESSED_BLOCK # Making a copy is faster than memoryview(a)[:-4] until 2kB. if len(data) < 2048: data = data[:-4] else: data = memoryview(data)[:-4] # Allow garbage collection of the encoder if it won't be reused. if frame.fin and self.local_no_context_takeover: del self.encoder return frames.Frame( frame.opcode, data, frame.fin, # Set the rsv1 flag on the first frame of a compressed message. frame.opcode is not frames.OP_CONT, frame.rsv2, frame.rsv3, ) def _build_parameters( server_no_context_takeover: bool, client_no_context_takeover: bool, server_max_window_bits: int | None, client_max_window_bits: int | Literal[True] | None, ) -> list[ExtensionParameter]: """ Build a list of ``(name, value)`` pairs for some compression parameters. 
""" params: list[ExtensionParameter] = [] if server_no_context_takeover: params.append(("server_no_context_takeover", None)) if client_no_context_takeover: params.append(("client_no_context_takeover", None)) if server_max_window_bits: params.append(("server_max_window_bits", str(server_max_window_bits))) if client_max_window_bits is True: # only in handshake requests params.append(("client_max_window_bits", None)) elif client_max_window_bits: params.append(("client_max_window_bits", str(client_max_window_bits))) return params def _extract_parameters( params: Sequence[ExtensionParameter], *, is_server: bool ) -> tuple[bool, bool, int | None, int | Literal[True] | None]: """ Extract compression parameters from a list of ``(name, value)`` pairs. If ``is_server`` is :obj:`True`, ``client_max_window_bits`` may be provided without a value. This is only allowed in handshake requests. """ server_no_context_takeover: bool = False client_no_context_takeover: bool = False server_max_window_bits: int | None = None client_max_window_bits: int | Literal[True] | None = None for name, value in params: if name == "server_no_context_takeover": if server_no_context_takeover: raise DuplicateParameter(name) if value is None: server_no_context_takeover = True else: raise InvalidParameterValue(name, value) elif name == "client_no_context_takeover": if client_no_context_takeover: raise DuplicateParameter(name) if value is None: client_no_context_takeover = True else: raise InvalidParameterValue(name, value) elif name == "server_max_window_bits": if server_max_window_bits is not None: raise DuplicateParameter(name) if value in _MAX_WINDOW_BITS_VALUES: server_max_window_bits = int(value) else: raise InvalidParameterValue(name, value) elif name == "client_max_window_bits": if client_max_window_bits is not None: raise DuplicateParameter(name) if is_server and value is None: # only in handshake requests client_max_window_bits = True elif value in _MAX_WINDOW_BITS_VALUES: client_max_window_bits 
= int(value) else: raise InvalidParameterValue(name, value) else: raise InvalidParameterName(name) return ( server_no_context_takeover, client_no_context_takeover, server_max_window_bits, client_max_window_bits, ) class ClientPerMessageDeflateFactory(ClientExtensionFactory): """ Client-side extension factory for the Per-Message Deflate extension. Parameters behave as described in `section 7.1 of RFC 7692`_. .. _section 7.1 of RFC 7692: https://datatracker.ietf.org/doc/html/rfc7692#section-7.1 Set them to :obj:`True` to include them in the negotiation offer without a value or to an integer value to include them with this value. Args: server_no_context_takeover: Prevent server from using context takeover. client_no_context_takeover: Prevent client from using context takeover. server_max_window_bits: Maximum size of the server's LZ77 sliding window in bits, between 8 and 15. client_max_window_bits: Maximum size of the client's LZ77 sliding window in bits, between 8 and 15, or :obj:`True` to indicate support without setting a limit. compress_settings: Additional keyword arguments for :func:`zlib.compressobj`, excluding ``wbits``. """ name = ExtensionName("permessage-deflate") def __init__( self, server_no_context_takeover: bool = False, client_no_context_takeover: bool = False, server_max_window_bits: int | None = None, client_max_window_bits: int | Literal[True] | None = True, compress_settings: dict[str, Any] | None = None, ) -> None: """ Configure the Per-Message Deflate extension factory. 
""" if not (server_max_window_bits is None or 8 <= server_max_window_bits <= 15): raise ValueError("server_max_window_bits must be between 8 and 15") if not ( client_max_window_bits is None or client_max_window_bits is True or 8 <= client_max_window_bits <= 15 ): raise ValueError("client_max_window_bits must be between 8 and 15") if compress_settings is not None and "wbits" in compress_settings: raise ValueError( "compress_settings must not include wbits, " "set client_max_window_bits instead" ) self.server_no_context_takeover = server_no_context_takeover self.client_no_context_takeover = client_no_context_takeover self.server_max_window_bits = server_max_window_bits self.client_max_window_bits = client_max_window_bits self.compress_settings = compress_settings def get_request_params(self) -> Sequence[ExtensionParameter]: """ Build request parameters. """ return _build_parameters( self.server_no_context_takeover, self.client_no_context_takeover, self.server_max_window_bits, self.client_max_window_bits, ) def process_response_params( self, params: Sequence[ExtensionParameter], accepted_extensions: Sequence[Extension], ) -> PerMessageDeflate: """ Process response parameters. Return an extension instance. """ if any(other.name == self.name for other in accepted_extensions): raise NegotiationError(f"received duplicate {self.name}") # Request parameters are available in instance variables. # Load response parameters in local variables. ( server_no_context_takeover, client_no_context_takeover, server_max_window_bits, client_max_window_bits, ) = _extract_parameters(params, is_server=False) # After comparing the request and the response, the final # configuration must be available in the local variables. # server_no_context_takeover # # Req. Resp. Result # ------ ------ -------------------------------------------------- # False False False # False True True # True False Error! 
# True True True if self.server_no_context_takeover: if not server_no_context_takeover: raise NegotiationError("expected server_no_context_takeover") # client_no_context_takeover # # Req. Resp. Result # ------ ------ -------------------------------------------------- # False False False # False True True # True False True - must change value # True True True if self.client_no_context_takeover: if not client_no_context_takeover: client_no_context_takeover = True # server_max_window_bits # Req. Resp. Result # ------ ------ -------------------------------------------------- # None None None # None 8≤M≤15 M # 8≤N≤15 None Error! # 8≤N≤15 8≤M≤N M # 8≤N≤15 N self.server_max_window_bits: raise NegotiationError("unsupported server_max_window_bits") # client_max_window_bits # Req. Resp. Result # ------ ------ -------------------------------------------------- # None None None # None 8≤M≤15 Error! # True None None # True 8≤M≤15 M # 8≤N≤15 None N - must change value # 8≤N≤15 8≤M≤N M # 8≤N≤15 N self.client_max_window_bits: raise NegotiationError("unsupported client_max_window_bits") return PerMessageDeflate( server_no_context_takeover, # remote_no_context_takeover client_no_context_takeover, # local_no_context_takeover server_max_window_bits or 15, # remote_max_window_bits client_max_window_bits or 15, # local_max_window_bits self.compress_settings, ) def enable_client_permessage_deflate( extensions: Sequence[ClientExtensionFactory] | None, ) -> Sequence[ClientExtensionFactory]: """ Enable Per-Message Deflate with default settings in client extensions. If the extension is already present, perhaps with non-default settings, the configuration isn't changed. 
""" if extensions is None: extensions = [] if not any( extension_factory.name == ClientPerMessageDeflateFactory.name for extension_factory in extensions ): extensions = list(extensions) + [ ClientPerMessageDeflateFactory( compress_settings={"memLevel": 5}, ) ] return extensions class ServerPerMessageDeflateFactory(ServerExtensionFactory): """ Server-side extension factory for the Per-Message Deflate extension. Parameters behave as described in `section 7.1 of RFC 7692`_. .. _section 7.1 of RFC 7692: https://datatracker.ietf.org/doc/html/rfc7692#section-7.1 Set them to :obj:`True` to include them in the negotiation offer without a value or to an integer value to include them with this value. Args: server_no_context_takeover: Prevent server from using context takeover. client_no_context_takeover: Prevent client from using context takeover. server_max_window_bits: Maximum size of the server's LZ77 sliding window in bits, between 8 and 15. client_max_window_bits: Maximum size of the client's LZ77 sliding window in bits, between 8 and 15. compress_settings: Additional keyword arguments for :func:`zlib.compressobj`, excluding ``wbits``. require_client_max_window_bits: Do not enable compression at all if client doesn't advertise support for ``client_max_window_bits``; the default behavior is to enable compression without enforcing ``client_max_window_bits``. """ name = ExtensionName("permessage-deflate") def __init__( self, server_no_context_takeover: bool = False, client_no_context_takeover: bool = False, server_max_window_bits: int | None = None, client_max_window_bits: int | None = None, compress_settings: dict[str, Any] | None = None, require_client_max_window_bits: bool = False, ) -> None: """ Configure the Per-Message Deflate extension factory. 
""" if not (server_max_window_bits is None or 8 <= server_max_window_bits <= 15): raise ValueError("server_max_window_bits must be between 8 and 15") if not (client_max_window_bits is None or 8 <= client_max_window_bits <= 15): raise ValueError("client_max_window_bits must be between 8 and 15") if compress_settings is not None and "wbits" in compress_settings: raise ValueError( "compress_settings must not include wbits, " "set server_max_window_bits instead" ) if client_max_window_bits is None and require_client_max_window_bits: raise ValueError( "require_client_max_window_bits is enabled, " "but client_max_window_bits isn't configured" ) self.server_no_context_takeover = server_no_context_takeover self.client_no_context_takeover = client_no_context_takeover self.server_max_window_bits = server_max_window_bits self.client_max_window_bits = client_max_window_bits self.compress_settings = compress_settings self.require_client_max_window_bits = require_client_max_window_bits def process_request_params( self, params: Sequence[ExtensionParameter], accepted_extensions: Sequence[Extension], ) -> tuple[list[ExtensionParameter], PerMessageDeflate]: """ Process request parameters. Return response params and an extension instance. """ if any(other.name == self.name for other in accepted_extensions): raise NegotiationError(f"skipped duplicate {self.name}") # Load request parameters in local variables. ( server_no_context_takeover, client_no_context_takeover, server_max_window_bits, client_max_window_bits, ) = _extract_parameters(params, is_server=True) # Configuration parameters are available in instance variables. # After comparing the request and the configuration, the response must # be available in the local variables. # server_no_context_takeover # # Config Req. Resp. 
# ------ ------ -------------------------------------------------- # False False False # False True True # True False True - must change value to True # True True True if self.server_no_context_takeover: if not server_no_context_takeover: server_no_context_takeover = True # client_no_context_takeover # # Config Req. Resp. # ------ ------ -------------------------------------------------- # False False False # False True True (or False) # True False True - must change value to True # True True True (or False) if self.client_no_context_takeover: if not client_no_context_takeover: client_no_context_takeover = True # server_max_window_bits # Config Req. Resp. # ------ ------ -------------------------------------------------- # None None None # None 8≤M≤15 M # 8≤N≤15 None N - must change value # 8≤N≤15 8≤M≤N M # 8≤N≤15 N self.server_max_window_bits: server_max_window_bits = self.server_max_window_bits # client_max_window_bits # Config Req. Resp. # ------ ------ -------------------------------------------------- # None None None # None True None - must change value # None 8≤M≤15 M (or None) # 8≤N≤15 None None or Error! # 8≤N≤15 True N - must change value # 8≤N≤15 8≤M≤N M (or None) # 8≤N≤15 N Sequence[ServerExtensionFactory]: """ Enable Per-Message Deflate with default settings in server extensions. If the extension is already present, perhaps with non-default settings, the configuration isn't changed. 
""" if extensions is None: extensions = [] if not any( ext_factory.name == ServerPerMessageDeflateFactory.name for ext_factory in extensions ): extensions = list(extensions) + [ ServerPerMessageDeflateFactory( server_max_window_bits=12, client_max_window_bits=12, compress_settings={"memLevel": 5}, ) ] return extensions websockets-15.0.1/src/websockets/frames.py000066400000000000000000000307271476212450300205530ustar00rootroot00000000000000from __future__ import annotations import dataclasses import enum import io import os import secrets import struct from collections.abc import Generator, Sequence from typing import Callable, Union from .exceptions import PayloadTooBig, ProtocolError try: from .speedups import apply_mask except ImportError: from .utils import apply_mask __all__ = [ "Opcode", "OP_CONT", "OP_TEXT", "OP_BINARY", "OP_CLOSE", "OP_PING", "OP_PONG", "DATA_OPCODES", "CTRL_OPCODES", "CloseCode", "Frame", "Close", ] class Opcode(enum.IntEnum): """Opcode values for WebSocket frames.""" CONT, TEXT, BINARY = 0x00, 0x01, 0x02 CLOSE, PING, PONG = 0x08, 0x09, 0x0A OP_CONT = Opcode.CONT OP_TEXT = Opcode.TEXT OP_BINARY = Opcode.BINARY OP_CLOSE = Opcode.CLOSE OP_PING = Opcode.PING OP_PONG = Opcode.PONG DATA_OPCODES = OP_CONT, OP_TEXT, OP_BINARY CTRL_OPCODES = OP_CLOSE, OP_PING, OP_PONG class CloseCode(enum.IntEnum): """Close code values for WebSocket close frames.""" NORMAL_CLOSURE = 1000 GOING_AWAY = 1001 PROTOCOL_ERROR = 1002 UNSUPPORTED_DATA = 1003 # 1004 is reserved NO_STATUS_RCVD = 1005 ABNORMAL_CLOSURE = 1006 INVALID_DATA = 1007 POLICY_VIOLATION = 1008 MESSAGE_TOO_BIG = 1009 MANDATORY_EXTENSION = 1010 INTERNAL_ERROR = 1011 SERVICE_RESTART = 1012 TRY_AGAIN_LATER = 1013 BAD_GATEWAY = 1014 TLS_HANDSHAKE = 1015 # See https://www.iana.org/assignments/websocket/websocket.xhtml CLOSE_CODE_EXPLANATIONS: dict[int, str] = { CloseCode.NORMAL_CLOSURE: "OK", CloseCode.GOING_AWAY: "going away", CloseCode.PROTOCOL_ERROR: "protocol error", CloseCode.UNSUPPORTED_DATA: 
"unsupported data", CloseCode.NO_STATUS_RCVD: "no status received [internal]", CloseCode.ABNORMAL_CLOSURE: "abnormal closure [internal]", CloseCode.INVALID_DATA: "invalid frame payload data", CloseCode.POLICY_VIOLATION: "policy violation", CloseCode.MESSAGE_TOO_BIG: "message too big", CloseCode.MANDATORY_EXTENSION: "mandatory extension", CloseCode.INTERNAL_ERROR: "internal error", CloseCode.SERVICE_RESTART: "service restart", CloseCode.TRY_AGAIN_LATER: "try again later", CloseCode.BAD_GATEWAY: "bad gateway", CloseCode.TLS_HANDSHAKE: "TLS handshake failure [internal]", } # Close code that are allowed in a close frame. # Using a set optimizes `code in EXTERNAL_CLOSE_CODES`. EXTERNAL_CLOSE_CODES = { CloseCode.NORMAL_CLOSURE, CloseCode.GOING_AWAY, CloseCode.PROTOCOL_ERROR, CloseCode.UNSUPPORTED_DATA, CloseCode.INVALID_DATA, CloseCode.POLICY_VIOLATION, CloseCode.MESSAGE_TOO_BIG, CloseCode.MANDATORY_EXTENSION, CloseCode.INTERNAL_ERROR, CloseCode.SERVICE_RESTART, CloseCode.TRY_AGAIN_LATER, CloseCode.BAD_GATEWAY, } OK_CLOSE_CODES = { CloseCode.NORMAL_CLOSURE, CloseCode.GOING_AWAY, CloseCode.NO_STATUS_RCVD, } BytesLike = bytes, bytearray, memoryview @dataclasses.dataclass class Frame: """ WebSocket frame. Attributes: opcode: Opcode. data: Payload data. fin: FIN bit. rsv1: RSV1 bit. rsv2: RSV2 bit. rsv3: RSV3 bit. Only these fields are needed. The MASK bit, payload length and masking-key are handled on the fly when parsing and serializing frames. """ opcode: Opcode data: Union[bytes, bytearray, memoryview] fin: bool = True rsv1: bool = False rsv2: bool = False rsv3: bool = False # Configure if you want to see more in logs. Should be a multiple of 3. MAX_LOG_SIZE = int(os.environ.get("WEBSOCKETS_MAX_LOG_SIZE", "75")) def __str__(self) -> str: """ Return a human-readable representation of a frame. 
""" coding = None length = f"{len(self.data)} byte{'' if len(self.data) == 1 else 's'}" non_final = "" if self.fin else "continued" if self.opcode is OP_TEXT: # Decoding only the beginning and the end is needlessly hard. # Decode the entire payload then elide later if necessary. data = repr(bytes(self.data).decode()) elif self.opcode is OP_BINARY: # We'll show at most the first 16 bytes and the last 8 bytes. # Encode just what we need, plus two dummy bytes to elide later. binary = self.data if len(binary) > self.MAX_LOG_SIZE // 3: cut = (self.MAX_LOG_SIZE // 3 - 1) // 3 # by default cut = 8 binary = b"".join([binary[: 2 * cut], b"\x00\x00", binary[-cut:]]) data = " ".join(f"{byte:02x}" for byte in binary) elif self.opcode is OP_CLOSE: data = str(Close.parse(self.data)) elif self.data: # We don't know if a Continuation frame contains text or binary. # Ping and Pong frames could contain UTF-8. # Attempt to decode as UTF-8 and display it as text; fallback to # binary. If self.data is a memoryview, it has no decode() method, # which raises AttributeError. try: data = repr(bytes(self.data).decode()) coding = "text" except (UnicodeDecodeError, AttributeError): binary = self.data if len(binary) > self.MAX_LOG_SIZE // 3: cut = (self.MAX_LOG_SIZE // 3 - 1) // 3 # by default cut = 8 binary = b"".join([binary[: 2 * cut], b"\x00\x00", binary[-cut:]]) data = " ".join(f"{byte:02x}" for byte in binary) coding = "binary" else: data = "''" if len(data) > self.MAX_LOG_SIZE: cut = self.MAX_LOG_SIZE // 3 - 1 # by default cut = 24 data = data[: 2 * cut] + "..." + data[-cut:] metadata = ", ".join(filter(None, [coding, length, non_final])) return f"{self.opcode.name} {data} [{metadata}]" @classmethod def parse( cls, read_exact: Callable[[int], Generator[None, None, bytes]], *, mask: bool, max_size: int | None = None, extensions: Sequence[extensions.Extension] | None = None, ) -> Generator[None, None, Frame]: """ Parse a WebSocket frame. This is a generator-based coroutine. 
Args: read_exact: Generator-based coroutine that reads the requested bytes or raises an exception if there isn't enough data. mask: Whether the frame should be masked i.e. whether the read happens on the server side. max_size: Maximum payload size in bytes. extensions: List of extensions, applied in reverse order. Raises: EOFError: If the connection is closed without a full WebSocket frame. PayloadTooBig: If the frame's payload size exceeds ``max_size``. ProtocolError: If the frame contains incorrect values. """ # Read the header. data = yield from read_exact(2) head1, head2 = struct.unpack("!BB", data) # While not Pythonic, this is marginally faster than calling bool(). fin = True if head1 & 0b10000000 else False rsv1 = True if head1 & 0b01000000 else False rsv2 = True if head1 & 0b00100000 else False rsv3 = True if head1 & 0b00010000 else False try: opcode = Opcode(head1 & 0b00001111) except ValueError as exc: raise ProtocolError("invalid opcode") from exc if (True if head2 & 0b10000000 else False) != mask: raise ProtocolError("incorrect masking") length = head2 & 0b01111111 if length == 126: data = yield from read_exact(2) (length,) = struct.unpack("!H", data) elif length == 127: data = yield from read_exact(8) (length,) = struct.unpack("!Q", data) if max_size is not None and length > max_size: raise PayloadTooBig(length, max_size) if mask: mask_bytes = yield from read_exact(4) # Read the data. data = yield from read_exact(length) if mask: data = apply_mask(data, mask_bytes) frame = cls(opcode, data, fin, rsv1, rsv2, rsv3) if extensions is None: extensions = [] for extension in reversed(extensions): frame = extension.decode(frame, max_size=max_size) frame.check() return frame def serialize( self, *, mask: bool, extensions: Sequence[extensions.Extension] | None = None, ) -> bytes: """ Serialize a WebSocket frame. Args: mask: Whether the frame should be masked i.e. whether the write happens on the client side. extensions: List of extensions, applied in order. 
Raises: ProtocolError: If the frame contains incorrect values. """ self.check() if extensions is None: extensions = [] for extension in extensions: self = extension.encode(self) output = io.BytesIO() # Prepare the header. head1 = ( (0b10000000 if self.fin else 0) | (0b01000000 if self.rsv1 else 0) | (0b00100000 if self.rsv2 else 0) | (0b00010000 if self.rsv3 else 0) | self.opcode ) head2 = 0b10000000 if mask else 0 length = len(self.data) if length < 126: output.write(struct.pack("!BB", head1, head2 | length)) elif length < 65536: output.write(struct.pack("!BBH", head1, head2 | 126, length)) else: output.write(struct.pack("!BBQ", head1, head2 | 127, length)) if mask: mask_bytes = secrets.token_bytes(4) output.write(mask_bytes) # Prepare the data. if mask: data = apply_mask(self.data, mask_bytes) else: data = self.data output.write(data) return output.getvalue() def check(self) -> None: """ Check that reserved bits and opcode have acceptable values. Raises: ProtocolError: If a reserved bit or the opcode is invalid. """ if self.rsv1 or self.rsv2 or self.rsv3: raise ProtocolError("reserved bits must be 0") if self.opcode in CTRL_OPCODES: if len(self.data) > 125: raise ProtocolError("control frame too long") if not self.fin: raise ProtocolError("fragmented control frame") @dataclasses.dataclass class Close: """ Code and reason for WebSocket close frames. Attributes: code: Close code. reason: Close reason. """ code: int reason: str def __str__(self) -> str: """ Return a human-readable representation of a close code and reason. """ if 3000 <= self.code < 4000: explanation = "registered" elif 4000 <= self.code < 5000: explanation = "private use" else: explanation = CLOSE_CODE_EXPLANATIONS.get(self.code, "unknown") result = f"{self.code} ({explanation})" if self.reason: result = f"{result} {self.reason}" return result @classmethod def parse(cls, data: bytes) -> Close: """ Parse the payload of a close frame. Args: data: Payload of the close frame. 
Raises: ProtocolError: If data is ill-formed. UnicodeDecodeError: If the reason isn't valid UTF-8. """ if len(data) >= 2: (code,) = struct.unpack("!H", data[:2]) reason = data[2:].decode() close = cls(code, reason) close.check() return close elif len(data) == 0: return cls(CloseCode.NO_STATUS_RCVD, "") else: raise ProtocolError("close frame too short") def serialize(self) -> bytes: """ Serialize the payload of a close frame. """ self.check() return struct.pack("!H", self.code) + self.reason.encode() def check(self) -> None: """ Check that the close code has a valid value for a close frame. Raises: ProtocolError: If the close code is invalid. """ if not (self.code in EXTERNAL_CLOSE_CODES or 3000 <= self.code < 5000): raise ProtocolError("invalid status code") # At the bottom to break import cycles created by type annotations. from . import extensions # noqa: E402 websockets-15.0.1/src/websockets/headers.py000066400000000000000000000372561476212450300207150ustar00rootroot00000000000000from __future__ import annotations import base64 import binascii import ipaddress import re from collections.abc import Sequence from typing import Callable, TypeVar, cast from .exceptions import InvalidHeaderFormat, InvalidHeaderValue from .typing import ( ConnectionOption, ExtensionHeader, ExtensionName, ExtensionParameter, Subprotocol, UpgradeProtocol, ) __all__ = [ "build_host", "parse_connection", "parse_upgrade", "parse_extension", "build_extension", "parse_subprotocol", "build_subprotocol", "validate_subprotocols", "build_www_authenticate_basic", "parse_authorization_basic", "build_authorization_basic", ] T = TypeVar("T") def build_host( host: str, port: int, secure: bool, *, always_include_port: bool = False, ) -> str: """ Build a ``Host`` header. """ # https://datatracker.ietf.org/doc/html/rfc3986#section-3.2.2 # IPv6 addresses must be enclosed in brackets. 
try: address = ipaddress.ip_address(host) except ValueError: # host is a hostname pass else: # host is an IP address if address.version == 6: host = f"[{host}]" if always_include_port or port != (443 if secure else 80): host = f"{host}:{port}" return host # To avoid a dependency on a parsing library, we implement manually the ABNF # described in https://datatracker.ietf.org/doc/html/rfc6455#section-9.1 and # https://datatracker.ietf.org/doc/html/rfc7230#appendix-B. def peek_ahead(header: str, pos: int) -> str | None: """ Return the next character from ``header`` at the given position. Return :obj:`None` at the end of ``header``. We never need to peek more than one character ahead. """ return None if pos == len(header) else header[pos] _OWS_re = re.compile(r"[\t ]*") def parse_OWS(header: str, pos: int) -> int: """ Parse optional whitespace from ``header`` at the given position. Return the new position. The whitespace itself isn't returned because it isn't significant. """ # There's always a match, possibly empty, whose content doesn't matter. match = _OWS_re.match(header, pos) assert match is not None return match.end() _token_re = re.compile(r"[-!#$%&\'*+.^_`|~0-9a-zA-Z]+") def parse_token(header: str, pos: int, header_name: str) -> tuple[str, int]: """ Parse a token from ``header`` at the given position. Return the token value and the new position. Raises: InvalidHeaderFormat: On invalid inputs. """ match = _token_re.match(header, pos) if match is None: raise InvalidHeaderFormat(header_name, "expected token", header, pos) return match.group(), match.end() _quoted_string_re = re.compile( r'"(?:[\x09\x20-\x21\x23-\x5b\x5d-\x7e]|\\[\x09\x20-\x7e\x80-\xff])*"' ) _unquote_re = re.compile(r"\\([\x09\x20-\x7e\x80-\xff])") def parse_quoted_string(header: str, pos: int, header_name: str) -> tuple[str, int]: """ Parse a quoted string from ``header`` at the given position. Return the unquoted value and the new position. Raises: InvalidHeaderFormat: On invalid inputs. 
""" match = _quoted_string_re.match(header, pos) if match is None: raise InvalidHeaderFormat(header_name, "expected quoted string", header, pos) return _unquote_re.sub(r"\1", match.group()[1:-1]), match.end() _quotable_re = re.compile(r"[\x09\x20-\x7e\x80-\xff]*") _quote_re = re.compile(r"([\x22\x5c])") def build_quoted_string(value: str) -> str: """ Format ``value`` as a quoted string. This is the reverse of :func:`parse_quoted_string`. """ match = _quotable_re.fullmatch(value) if match is None: raise ValueError("invalid characters for quoted-string encoding") return '"' + _quote_re.sub(r"\\\1", value) + '"' def parse_list( parse_item: Callable[[str, int, str], tuple[T, int]], header: str, pos: int, header_name: str, ) -> list[T]: """ Parse a comma-separated list from ``header`` at the given position. This is appropriate for parsing values with the following grammar: 1#item ``parse_item`` parses one item. ``header`` is assumed not to start or end with whitespace. (This function is designed for parsing an entire header value and :func:`~websockets.http.read_headers` strips whitespace from values.) Return a list of items. Raises: InvalidHeaderFormat: On invalid inputs. """ # Per https://datatracker.ietf.org/doc/html/rfc7230#section-7, "a recipient # MUST parse and ignore a reasonable number of empty list elements"; # hence while loops that remove extra delimiters. # Remove extra delimiters before the first item. while peek_ahead(header, pos) == ",": pos = parse_OWS(header, pos + 1) items = [] while True: # Loop invariant: a item starts at pos in header. item, pos = parse_item(header, pos, header_name) items.append(item) pos = parse_OWS(header, pos) # We may have reached the end of the header. if pos == len(header): break # There must be a delimiter after each element except the last one. 
if peek_ahead(header, pos) == ",": pos = parse_OWS(header, pos + 1) else: raise InvalidHeaderFormat(header_name, "expected comma", header, pos) # Remove extra delimiters before the next item. while peek_ahead(header, pos) == ",": pos = parse_OWS(header, pos + 1) # We may have reached the end of the header. if pos == len(header): break # Since we only advance in the header by one character with peek_ahead() # or with the end position of a regex match, we can't overshoot the end. assert pos == len(header) return items def parse_connection_option( header: str, pos: int, header_name: str ) -> tuple[ConnectionOption, int]: """ Parse a Connection option from ``header`` at the given position. Return the protocol value and the new position. Raises: InvalidHeaderFormat: On invalid inputs. """ item, pos = parse_token(header, pos, header_name) return cast(ConnectionOption, item), pos def parse_connection(header: str) -> list[ConnectionOption]: """ Parse a ``Connection`` header. Return a list of HTTP connection options. Args header: value of the ``Connection`` header. Raises: InvalidHeaderFormat: On invalid inputs. """ return parse_list(parse_connection_option, header, 0, "Connection") _protocol_re = re.compile( r"[-!#$%&\'*+.^_`|~0-9a-zA-Z]+(?:/[-!#$%&\'*+.^_`|~0-9a-zA-Z]+)?" ) def parse_upgrade_protocol( header: str, pos: int, header_name: str ) -> tuple[UpgradeProtocol, int]: """ Parse an Upgrade protocol from ``header`` at the given position. Return the protocol value and the new position. Raises: InvalidHeaderFormat: On invalid inputs. """ match = _protocol_re.match(header, pos) if match is None: raise InvalidHeaderFormat(header_name, "expected protocol", header, pos) return cast(UpgradeProtocol, match.group()), match.end() def parse_upgrade(header: str) -> list[UpgradeProtocol]: """ Parse an ``Upgrade`` header. Return a list of HTTP protocols. Args: header: Value of the ``Upgrade`` header. Raises: InvalidHeaderFormat: On invalid inputs. 
""" return parse_list(parse_upgrade_protocol, header, 0, "Upgrade") def parse_extension_item_param( header: str, pos: int, header_name: str ) -> tuple[ExtensionParameter, int]: """ Parse a single extension parameter from ``header`` at the given position. Return a ``(name, value)`` pair and the new position. Raises: InvalidHeaderFormat: On invalid inputs. """ # Extract parameter name. name, pos = parse_token(header, pos, header_name) pos = parse_OWS(header, pos) # Extract parameter value, if there is one. value: str | None = None if peek_ahead(header, pos) == "=": pos = parse_OWS(header, pos + 1) if peek_ahead(header, pos) == '"': pos_before = pos # for proper error reporting below value, pos = parse_quoted_string(header, pos, header_name) # https://datatracker.ietf.org/doc/html/rfc6455#section-9.1 says: # the value after quoted-string unescaping MUST conform to # the 'token' ABNF. if _token_re.fullmatch(value) is None: raise InvalidHeaderFormat( header_name, "invalid quoted header content", header, pos_before ) else: value, pos = parse_token(header, pos, header_name) pos = parse_OWS(header, pos) return (name, value), pos def parse_extension_item( header: str, pos: int, header_name: str ) -> tuple[ExtensionHeader, int]: """ Parse an extension definition from ``header`` at the given position. Return an ``(extension name, parameters)`` pair, where ``parameters`` is a list of ``(name, value)`` pairs, and the new position. Raises: InvalidHeaderFormat: On invalid inputs. """ # Extract extension name. name, pos = parse_token(header, pos, header_name) pos = parse_OWS(header, pos) # Extract all parameters. parameters = [] while peek_ahead(header, pos) == ";": pos = parse_OWS(header, pos + 1) parameter, pos = parse_extension_item_param(header, pos, header_name) parameters.append(parameter) return (cast(ExtensionName, name), parameters), pos def parse_extension(header: str) -> list[ExtensionHeader]: """ Parse a ``Sec-WebSocket-Extensions`` header. 
Return a list of WebSocket extensions and their parameters in this format:: [ ( 'extension name', [ ('parameter name', 'parameter value'), .... ] ), ... ] Parameter values are :obj:`None` when no value is provided. Raises: InvalidHeaderFormat: On invalid inputs. """ return parse_list(parse_extension_item, header, 0, "Sec-WebSocket-Extensions") parse_extension_list = parse_extension # alias for backwards compatibility def build_extension_item( name: ExtensionName, parameters: Sequence[ExtensionParameter] ) -> str: """ Build an extension definition. This is the reverse of :func:`parse_extension_item`. """ return "; ".join( [cast(str, name)] + [ # Quoted strings aren't necessary because values are always tokens. name if value is None else f"{name}={value}" for name, value in parameters ] ) def build_extension(extensions: Sequence[ExtensionHeader]) -> str: """ Build a ``Sec-WebSocket-Extensions`` header. This is the reverse of :func:`parse_extension`. """ return ", ".join( build_extension_item(name, parameters) for name, parameters in extensions ) build_extension_list = build_extension # alias for backwards compatibility def parse_subprotocol_item( header: str, pos: int, header_name: str ) -> tuple[Subprotocol, int]: """ Parse a subprotocol from ``header`` at the given position. Return the subprotocol value and the new position. Raises: InvalidHeaderFormat: On invalid inputs. """ item, pos = parse_token(header, pos, header_name) return cast(Subprotocol, item), pos def parse_subprotocol(header: str) -> list[Subprotocol]: """ Parse a ``Sec-WebSocket-Protocol`` header. Return a list of WebSocket subprotocols. Raises: InvalidHeaderFormat: On invalid inputs. """ return parse_list(parse_subprotocol_item, header, 0, "Sec-WebSocket-Protocol") parse_subprotocol_list = parse_subprotocol # alias for backwards compatibility def build_subprotocol(subprotocols: Sequence[Subprotocol]) -> str: """ Build a ``Sec-WebSocket-Protocol`` header. 
This is the reverse of :func:`parse_subprotocol`. """ return ", ".join(subprotocols) build_subprotocol_list = build_subprotocol # alias for backwards compatibility def validate_subprotocols(subprotocols: Sequence[Subprotocol]) -> None: """ Validate that ``subprotocols`` is suitable for :func:`build_subprotocol`. """ if not isinstance(subprotocols, Sequence): raise TypeError("subprotocols must be a list") if isinstance(subprotocols, str): raise TypeError("subprotocols must be a list, not a str") for subprotocol in subprotocols: if not _token_re.fullmatch(subprotocol): raise ValueError(f"invalid subprotocol: {subprotocol}") def build_www_authenticate_basic(realm: str) -> str: """ Build a ``WWW-Authenticate`` header for HTTP Basic Auth. Args: realm: Identifier of the protection space. """ # https://datatracker.ietf.org/doc/html/rfc7617#section-2 realm = build_quoted_string(realm) charset = build_quoted_string("UTF-8") return f"Basic realm={realm}, charset={charset}" _token68_re = re.compile(r"[A-Za-z0-9-._~+/]+=*") def parse_token68(header: str, pos: int, header_name: str) -> tuple[str, int]: """ Parse a token68 from ``header`` at the given position. Return the token value and the new position. Raises: InvalidHeaderFormat: On invalid inputs. """ match = _token68_re.match(header, pos) if match is None: raise InvalidHeaderFormat(header_name, "expected token68", header, pos) return match.group(), match.end() def parse_end(header: str, pos: int, header_name: str) -> None: """ Check that parsing reached the end of header. """ if pos < len(header): raise InvalidHeaderFormat(header_name, "trailing data", header, pos) def parse_authorization_basic(header: str) -> tuple[str, str]: """ Parse an ``Authorization`` header for HTTP Basic Auth. Return a ``(username, password)`` tuple. Args: header: Value of the ``Authorization`` header. Raises: InvalidHeaderFormat: On invalid inputs. InvalidHeaderValue: On unsupported inputs. 
""" # https://datatracker.ietf.org/doc/html/rfc7235#section-2.1 # https://datatracker.ietf.org/doc/html/rfc7617#section-2 scheme, pos = parse_token(header, 0, "Authorization") if scheme.lower() != "basic": raise InvalidHeaderValue( "Authorization", f"unsupported scheme: {scheme}", ) if peek_ahead(header, pos) != " ": raise InvalidHeaderFormat( "Authorization", "expected space after scheme", header, pos ) pos += 1 basic_credentials, pos = parse_token68(header, pos, "Authorization") parse_end(header, pos, "Authorization") try: user_pass = base64.b64decode(basic_credentials.encode()).decode() except binascii.Error: raise InvalidHeaderValue( "Authorization", "expected base64-encoded credentials", ) from None try: username, password = user_pass.split(":", 1) except ValueError: raise InvalidHeaderValue( "Authorization", "expected username:password credentials", ) from None return username, password def build_authorization_basic(username: str, password: str) -> str: """ Build an ``Authorization`` header for HTTP Basic Auth. This is the reverse of :func:`parse_authorization_basic`. """ # https://datatracker.ietf.org/doc/html/rfc7617#section-2 assert ":" not in username user_pass = f"{username}:{password}" basic_credentials = base64.b64encode(user_pass.encode()).decode() return "Basic " + basic_credentials websockets-15.0.1/src/websockets/http.py000066400000000000000000000012231476212450300202420ustar00rootroot00000000000000from __future__ import annotations import warnings from .datastructures import Headers, MultipleValuesError # noqa: F401 with warnings.catch_warnings(): # Suppress redundant DeprecationWarning raised by websockets.legacy. 
warnings.filterwarnings("ignore", category=DeprecationWarning) from .legacy.http import read_request, read_response # noqa: F401 warnings.warn( # deprecated in 9.0 - 2021-09-01 "Headers and MultipleValuesError were moved " "from websockets.http to websockets.datastructures" "and read_request and read_response were moved " "from websockets.http to websockets.legacy.http", DeprecationWarning, ) websockets-15.0.1/src/websockets/http11.py000066400000000000000000000351151476212450300204130ustar00rootroot00000000000000from __future__ import annotations import dataclasses import os import re import sys import warnings from collections.abc import Generator from typing import Callable from .datastructures import Headers from .exceptions import SecurityError from .version import version as websockets_version __all__ = [ "SERVER", "USER_AGENT", "Request", "Response", ] PYTHON_VERSION = "{}.{}".format(*sys.version_info) # User-Agent header for HTTP requests. USER_AGENT = os.environ.get( "WEBSOCKETS_USER_AGENT", f"Python/{PYTHON_VERSION} websockets/{websockets_version}", ) # Server header for HTTP responses. SERVER = os.environ.get( "WEBSOCKETS_SERVER", f"Python/{PYTHON_VERSION} websockets/{websockets_version}", ) # Maximum total size of headers is around 128 * 8 KiB = 1 MiB. MAX_NUM_HEADERS = int(os.environ.get("WEBSOCKETS_MAX_NUM_HEADERS", "128")) # Limit request line and header lines. 8KiB is the most common default # configuration of popular HTTP servers. MAX_LINE_LENGTH = int(os.environ.get("WEBSOCKETS_MAX_LINE_LENGTH", "8192")) # Support for HTTP response bodies is intended to read an error message # returned by a server. It isn't designed to perform large file transfers. MAX_BODY_SIZE = int(os.environ.get("WEBSOCKETS_MAX_BODY_SIZE", "1_048_576")) # 1 MiB def d(value: bytes) -> str: """ Decode a bytestring for interpolating into an error message. """ return value.decode(errors="backslashreplace") # See https://datatracker.ietf.org/doc/html/rfc7230#appendix-B. 
# Regex for validating header names. _token_re = re.compile(rb"[-!#$%&\'*+.^_`|~0-9a-zA-Z]+") # Regex for validating header values. # We don't attempt to support obsolete line folding. # Include HTAB (\x09), SP (\x20), VCHAR (\x21-\x7e), obs-text (\x80-\xff). # The ABNF is complicated because it attempts to express that optional # whitespace is ignored. We strip whitespace and don't revalidate that. # See also https://www.rfc-editor.org/errata_search.php?rfc=7230&eid=4189 _value_re = re.compile(rb"[\x09\x20-\x7e\x80-\xff]*") @dataclasses.dataclass class Request: """ WebSocket handshake request. Attributes: path: Request path, including optional query. headers: Request headers. """ path: str headers: Headers # body isn't useful is the context of this library. _exception: Exception | None = None @property def exception(self) -> Exception | None: # pragma: no cover warnings.warn( # deprecated in 10.3 - 2022-04-17 "Request.exception is deprecated; use ServerProtocol.handshake_exc instead", DeprecationWarning, ) return self._exception @classmethod def parse( cls, read_line: Callable[[int], Generator[None, None, bytes]], ) -> Generator[None, None, Request]: """ Parse a WebSocket handshake request. This is a generator-based coroutine. The request path isn't URL-decoded or validated in any way. The request path and headers are expected to contain only ASCII characters. Other characters are represented with surrogate escapes. :meth:`parse` doesn't attempt to read the request body because WebSocket handshake requests don't have one. If the request contains a body, it may be read from the data stream after :meth:`parse` returns. Args: read_line: Generator-based coroutine that reads a LF-terminated line or raises an exception if there isn't enough data Raises: EOFError: If the connection is closed without a full HTTP request. SecurityError: If the request exceeds a security limit. ValueError: If the request isn't well formatted. 
""" # https://datatracker.ietf.org/doc/html/rfc7230#section-3.1.1 # Parsing is simple because fixed values are expected for method and # version and because path isn't checked. Since WebSocket software tends # to implement HTTP/1.1 strictly, there's little need for lenient parsing. try: request_line = yield from parse_line(read_line) except EOFError as exc: raise EOFError("connection closed while reading HTTP request line") from exc try: method, raw_path, protocol = request_line.split(b" ", 2) except ValueError: # not enough values to unpack (expected 3, got 1-2) raise ValueError(f"invalid HTTP request line: {d(request_line)}") from None if protocol != b"HTTP/1.1": raise ValueError( f"unsupported protocol; expected HTTP/1.1: {d(request_line)}" ) if method != b"GET": raise ValueError(f"unsupported HTTP method; expected GET; got {d(method)}") path = raw_path.decode("ascii", "surrogateescape") headers = yield from parse_headers(read_line) # https://datatracker.ietf.org/doc/html/rfc7230#section-3.3.3 if "Transfer-Encoding" in headers: raise NotImplementedError("transfer codings aren't supported") if "Content-Length" in headers: raise ValueError("unsupported request body") return cls(path, headers) def serialize(self) -> bytes: """ Serialize a WebSocket handshake request. """ # Since the request line and headers only contain ASCII characters, # we can keep this simple. request = f"GET {self.path} HTTP/1.1\r\n".encode() request += self.headers.serialize() return request @dataclasses.dataclass class Response: """ WebSocket handshake response. Attributes: status_code: Response code. reason_phrase: Response reason. headers: Response headers. body: Response body. 
""" status_code: int reason_phrase: str headers: Headers body: bytes = b"" _exception: Exception | None = None @property def exception(self) -> Exception | None: # pragma: no cover warnings.warn( # deprecated in 10.3 - 2022-04-17 "Response.exception is deprecated; " "use ClientProtocol.handshake_exc instead", DeprecationWarning, ) return self._exception @classmethod def parse( cls, read_line: Callable[[int], Generator[None, None, bytes]], read_exact: Callable[[int], Generator[None, None, bytes]], read_to_eof: Callable[[int], Generator[None, None, bytes]], include_body: bool = True, ) -> Generator[None, None, Response]: """ Parse a WebSocket handshake response. This is a generator-based coroutine. The reason phrase and headers are expected to contain only ASCII characters. Other characters are represented with surrogate escapes. Args: read_line: Generator-based coroutine that reads a LF-terminated line or raises an exception if there isn't enough data. read_exact: Generator-based coroutine that reads the requested bytes or raises an exception if there isn't enough data. read_to_eof: Generator-based coroutine that reads until the end of the stream. Raises: EOFError: If the connection is closed without a full HTTP response. SecurityError: If the response exceeds a security limit. LookupError: If the response isn't well formatted. ValueError: If the response isn't well formatted. 
""" # https://datatracker.ietf.org/doc/html/rfc7230#section-3.1.2 try: status_line = yield from parse_line(read_line) except EOFError as exc: raise EOFError("connection closed while reading HTTP status line") from exc try: protocol, raw_status_code, raw_reason = status_line.split(b" ", 2) except ValueError: # not enough values to unpack (expected 3, got 1-2) raise ValueError(f"invalid HTTP status line: {d(status_line)}") from None if protocol != b"HTTP/1.1": raise ValueError( f"unsupported protocol; expected HTTP/1.1: {d(status_line)}" ) try: status_code = int(raw_status_code) except ValueError: # invalid literal for int() with base 10 raise ValueError( f"invalid status code; expected integer; got {d(raw_status_code)}" ) from None if not 100 <= status_code < 600: raise ValueError( f"invalid status code; expected 100–599; got {d(raw_status_code)}" ) if not _value_re.fullmatch(raw_reason): raise ValueError(f"invalid HTTP reason phrase: {d(raw_reason)}") reason = raw_reason.decode("ascii", "surrogateescape") headers = yield from parse_headers(read_line) if include_body: body = yield from read_body( status_code, headers, read_line, read_exact, read_to_eof ) else: body = b"" return cls(status_code, reason, headers, body) def serialize(self) -> bytes: """ Serialize a WebSocket handshake response. """ # Since the status line and headers only contain ASCII characters, # we can keep this simple. response = f"HTTP/1.1 {self.status_code} {self.reason_phrase}\r\n".encode() response += self.headers.serialize() response += self.body return response def parse_line( read_line: Callable[[int], Generator[None, None, bytes]], ) -> Generator[None, None, bytes]: """ Parse a single line. CRLF is stripped from the return value. Args: read_line: Generator-based coroutine that reads a LF-terminated line or raises an exception if there isn't enough data. Raises: EOFError: If the connection is closed without a CRLF. SecurityError: If the response exceeds a security limit. 
""" try: line = yield from read_line(MAX_LINE_LENGTH) except RuntimeError: raise SecurityError("line too long") # Not mandatory but safe - https://datatracker.ietf.org/doc/html/rfc7230#section-3.5 if not line.endswith(b"\r\n"): raise EOFError("line without CRLF") return line[:-2] def parse_headers( read_line: Callable[[int], Generator[None, None, bytes]], ) -> Generator[None, None, Headers]: """ Parse HTTP headers. Non-ASCII characters are represented with surrogate escapes. Args: read_line: Generator-based coroutine that reads a LF-terminated line or raises an exception if there isn't enough data. Raises: EOFError: If the connection is closed without complete headers. SecurityError: If the request exceeds a security limit. ValueError: If the request isn't well formatted. """ # https://datatracker.ietf.org/doc/html/rfc7230#section-3.2 # We don't attempt to support obsolete line folding. headers = Headers() for _ in range(MAX_NUM_HEADERS + 1): try: line = yield from parse_line(read_line) except EOFError as exc: raise EOFError("connection closed while reading HTTP headers") from exc if line == b"": break try: raw_name, raw_value = line.split(b":", 1) except ValueError: # not enough values to unpack (expected 2, got 1) raise ValueError(f"invalid HTTP header line: {d(line)}") from None if not _token_re.fullmatch(raw_name): raise ValueError(f"invalid HTTP header name: {d(raw_name)}") raw_value = raw_value.strip(b" \t") if not _value_re.fullmatch(raw_value): raise ValueError(f"invalid HTTP header value: {d(raw_value)}") name = raw_name.decode("ascii") # guaranteed to be ASCII at this point value = raw_value.decode("ascii", "surrogateescape") headers[name] = value else: raise SecurityError("too many HTTP headers") return headers def read_body( status_code: int, headers: Headers, read_line: Callable[[int], Generator[None, None, bytes]], read_exact: Callable[[int], Generator[None, None, bytes]], read_to_eof: Callable[[int], Generator[None, None, bytes]], ) -> 
Generator[None, None, bytes]: # https://datatracker.ietf.org/doc/html/rfc7230#section-3.3.3 # Since websockets only does GET requests (no HEAD, no CONNECT), all # responses except 1xx, 204, and 304 include a message body. if 100 <= status_code < 200 or status_code == 204 or status_code == 304: return b"" # MultipleValuesError is sufficiently unlikely that we don't attempt to # handle it when accessing headers. Instead we document that its parent # class, LookupError, may be raised. # Conversions from str to int are protected by sys.set_int_max_str_digits.. elif (coding := headers.get("Transfer-Encoding")) is not None: if coding != "chunked": raise NotImplementedError(f"transfer coding {coding} isn't supported") body = b"" while True: chunk_size_line = yield from parse_line(read_line) raw_chunk_size = chunk_size_line.split(b";", 1)[0] # Set a lower limit than default_max_str_digits; 1 EB is plenty. if len(raw_chunk_size) > 15: str_chunk_size = raw_chunk_size.decode(errors="backslashreplace") raise SecurityError(f"chunk too large: 0x{str_chunk_size} bytes") chunk_size = int(raw_chunk_size, 16) if chunk_size == 0: break if len(body) + chunk_size > MAX_BODY_SIZE: raise SecurityError( f"chunk too large: {chunk_size} bytes after {len(body)} bytes" ) body += yield from read_exact(chunk_size) if (yield from read_exact(2)) != b"\r\n": raise ValueError("chunk without CRLF") # Read the trailer. yield from parse_headers(read_line) return body elif (raw_content_length := headers.get("Content-Length")) is not None: # Set a lower limit than default_max_str_digits; 1 EiB is plenty. 
if len(raw_content_length) > 18: raise SecurityError(f"body too large: {raw_content_length} bytes") content_length = int(raw_content_length) if content_length > MAX_BODY_SIZE: raise SecurityError(f"body too large: {content_length} bytes") return (yield from read_exact(content_length)) else: try: return (yield from read_to_eof(MAX_BODY_SIZE)) except RuntimeError: raise SecurityError(f"body too large: over {MAX_BODY_SIZE} bytes") websockets-15.0.1/src/websockets/imports.py000066400000000000000000000053531476212450300207700ustar00rootroot00000000000000from __future__ import annotations import warnings from collections.abc import Iterable from typing import Any __all__ = ["lazy_import"] def import_name(name: str, source: str, namespace: dict[str, Any]) -> Any: """ Import ``name`` from ``source`` in ``namespace``. There are two use cases: - ``name`` is an object defined in ``source``; - ``name`` is a submodule of ``source``. Neither :func:`__import__` nor :func:`~importlib.import_module` does exactly this. :func:`__import__` is closer to the intended behavior. """ level = 0 while source[level] == ".": level += 1 assert level < len(source), "importing from parent isn't supported" module = __import__(source[level:], namespace, None, [name], level) return getattr(module, name) def lazy_import( namespace: dict[str, Any], aliases: dict[str, str] | None = None, deprecated_aliases: dict[str, str] | None = None, ) -> None: """ Provide lazy, module-level imports. Typical use:: __getattr__, __dir__ = lazy_import( globals(), aliases={ "": "", ... }, deprecated_aliases={ ..., } ) This function defines ``__getattr__`` and ``__dir__`` per :pep:`562`. 
""" if aliases is None: aliases = {} if deprecated_aliases is None: deprecated_aliases = {} namespace_set = set(namespace) aliases_set = set(aliases) deprecated_aliases_set = set(deprecated_aliases) assert not namespace_set & aliases_set, "namespace conflict" assert not namespace_set & deprecated_aliases_set, "namespace conflict" assert not aliases_set & deprecated_aliases_set, "namespace conflict" package = namespace["__name__"] def __getattr__(name: str) -> Any: assert aliases is not None # mypy cannot figure this out try: source = aliases[name] except KeyError: pass else: return import_name(name, source, namespace) assert deprecated_aliases is not None # mypy cannot figure this out try: source = deprecated_aliases[name] except KeyError: pass else: warnings.warn( f"{package}.{name} is deprecated", DeprecationWarning, stacklevel=2, ) return import_name(name, source, namespace) raise AttributeError(f"module {package!r} has no attribute {name!r}") namespace["__getattr__"] = __getattr__ def __dir__() -> Iterable[str]: return sorted(namespace_set | aliases_set | deprecated_aliases_set) namespace["__dir__"] = __dir__ websockets-15.0.1/src/websockets/legacy/000077500000000000000000000000001476212450300201575ustar00rootroot00000000000000websockets-15.0.1/src/websockets/legacy/__init__.py000066400000000000000000000004241476212450300222700ustar00rootroot00000000000000from __future__ import annotations import warnings warnings.warn( # deprecated in 14.0 - 2024-11-09 "websockets.legacy is deprecated; " "see https://websockets.readthedocs.io/en/stable/howto/upgrade.html " "for upgrade instructions", DeprecationWarning, ) websockets-15.0.1/src/websockets/legacy/auth.py000066400000000000000000000146031476212450300214760ustar00rootroot00000000000000from __future__ import annotations import functools import hmac import http from collections.abc import Awaitable, Iterable from typing import Any, Callable, cast from ..datastructures import Headers from ..exceptions import 
InvalidHeader from ..headers import build_www_authenticate_basic, parse_authorization_basic from .server import HTTPResponse, WebSocketServerProtocol __all__ = ["BasicAuthWebSocketServerProtocol", "basic_auth_protocol_factory"] Credentials = tuple[str, str] def is_credentials(value: Any) -> bool: try: username, password = value except (TypeError, ValueError): return False else: return isinstance(username, str) and isinstance(password, str) class BasicAuthWebSocketServerProtocol(WebSocketServerProtocol): """ WebSocket server protocol that enforces HTTP Basic Auth. """ realm: str = "" """ Scope of protection. If provided, it should contain only ASCII characters because the encoding of non-ASCII characters is undefined. """ username: str | None = None """Username of the authenticated user.""" def __init__( self, *args: Any, realm: str | None = None, check_credentials: Callable[[str, str], Awaitable[bool]] | None = None, **kwargs: Any, ) -> None: if realm is not None: self.realm = realm # shadow class attribute self._check_credentials = check_credentials super().__init__(*args, **kwargs) async def check_credentials(self, username: str, password: str) -> bool: """ Check whether credentials are authorized. This coroutine may be overridden in a subclass, for example to authenticate against a database or an external service. Args: username: HTTP Basic Auth username. password: HTTP Basic Auth password. Returns: :obj:`True` if the handshake should continue; :obj:`False` if it should fail with an HTTP 401 error. """ if self._check_credentials is not None: return await self._check_credentials(username, password) return False async def process_request( self, path: str, request_headers: Headers, ) -> HTTPResponse | None: """ Check HTTP Basic Auth and return an HTTP 401 response if needed. 
""" try: authorization = request_headers["Authorization"] except KeyError: return ( http.HTTPStatus.UNAUTHORIZED, [("WWW-Authenticate", build_www_authenticate_basic(self.realm))], b"Missing credentials\n", ) try: username, password = parse_authorization_basic(authorization) except InvalidHeader: return ( http.HTTPStatus.UNAUTHORIZED, [("WWW-Authenticate", build_www_authenticate_basic(self.realm))], b"Unsupported credentials\n", ) if not await self.check_credentials(username, password): return ( http.HTTPStatus.UNAUTHORIZED, [("WWW-Authenticate", build_www_authenticate_basic(self.realm))], b"Invalid credentials\n", ) self.username = username return await super().process_request(path, request_headers) def basic_auth_protocol_factory( realm: str | None = None, credentials: Credentials | Iterable[Credentials] | None = None, check_credentials: Callable[[str, str], Awaitable[bool]] | None = None, create_protocol: Callable[..., BasicAuthWebSocketServerProtocol] | None = None, ) -> Callable[..., BasicAuthWebSocketServerProtocol]: """ Protocol factory that enforces HTTP Basic Auth. :func:`basic_auth_protocol_factory` is designed to integrate with :func:`~websockets.legacy.server.serve` like this:: serve( ..., create_protocol=basic_auth_protocol_factory( realm="my dev server", credentials=("hello", "iloveyou"), ) ) Args: realm: Scope of protection. It should contain only ASCII characters because the encoding of non-ASCII characters is undefined. Refer to section 2.2 of :rfc:`7235` for details. credentials: Hard coded authorized credentials. It can be a ``(username, password)`` pair or a list of such pairs. check_credentials: Coroutine that verifies credentials. It receives ``username`` and ``password`` arguments and returns a :class:`bool`. One of ``credentials`` or ``check_credentials`` must be provided but not both. create_protocol: Factory that creates the protocol. By default, this is :class:`BasicAuthWebSocketServerProtocol`. It can be replaced by a subclass. 
Raises: TypeError: If the ``credentials`` or ``check_credentials`` argument is wrong. """ if (credentials is None) == (check_credentials is None): raise TypeError("provide either credentials or check_credentials") if credentials is not None: if is_credentials(credentials): credentials_list = [cast(Credentials, credentials)] elif isinstance(credentials, Iterable): credentials_list = list(cast(Iterable[Credentials], credentials)) if not all(is_credentials(item) for item in credentials_list): raise TypeError(f"invalid credentials argument: {credentials}") else: raise TypeError(f"invalid credentials argument: {credentials}") credentials_dict = dict(credentials_list) async def check_credentials(username: str, password: str) -> bool: try: expected_password = credentials_dict[username] except KeyError: return False return hmac.compare_digest(expected_password, password) if create_protocol is None: create_protocol = BasicAuthWebSocketServerProtocol # Help mypy and avoid this error: "type[BasicAuthWebSocketServerProtocol] | # Callable[..., BasicAuthWebSocketServerProtocol]" not callable [misc] create_protocol = cast( Callable[..., BasicAuthWebSocketServerProtocol], create_protocol ) return functools.partial( create_protocol, realm=realm, check_credentials=check_credentials, ) websockets-15.0.1/src/websockets/legacy/client.py000066400000000000000000000645511476212450300220220ustar00rootroot00000000000000from __future__ import annotations import asyncio import functools import logging import os import random import traceback import urllib.parse import warnings from collections.abc import AsyncIterator, Generator, Sequence from types import TracebackType from typing import Any, Callable, cast from ..asyncio.compatibility import asyncio_timeout from ..datastructures import Headers, HeadersLike from ..exceptions import ( InvalidHeader, InvalidHeaderValue, InvalidMessage, NegotiationError, SecurityError, ) from ..extensions import ClientExtensionFactory, Extension from 
..extensions.permessage_deflate import enable_client_permessage_deflate from ..headers import ( build_authorization_basic, build_extension, build_host, build_subprotocol, parse_extension, parse_subprotocol, validate_subprotocols, ) from ..http11 import USER_AGENT from ..typing import ExtensionHeader, LoggerLike, Origin, Subprotocol from ..uri import WebSocketURI, parse_uri from .exceptions import InvalidStatusCode, RedirectHandshake from .handshake import build_request, check_response from .http import read_response from .protocol import WebSocketCommonProtocol __all__ = ["connect", "unix_connect", "WebSocketClientProtocol"] class WebSocketClientProtocol(WebSocketCommonProtocol): """ WebSocket client connection. :class:`WebSocketClientProtocol` provides :meth:`recv` and :meth:`send` coroutines for receiving and sending messages. It supports asynchronous iteration to receive messages:: async for message in websocket: await process(message) The iterator exits normally when the connection is closed with close code 1000 (OK) or 1001 (going away) or without a close code. It raises a :exc:`~websockets.exceptions.ConnectionClosedError` when the connection is closed with any other code. See :func:`connect` for the documentation of ``logger``, ``origin``, ``extensions``, ``subprotocols``, ``extra_headers``, and ``user_agent_header``. See :class:`~websockets.legacy.protocol.WebSocketCommonProtocol` for the documentation of ``ping_interval``, ``ping_timeout``, ``close_timeout``, ``max_size``, ``max_queue``, ``read_limit``, and ``write_limit``. 
""" is_client = True side = "client" def __init__( self, *, logger: LoggerLike | None = None, origin: Origin | None = None, extensions: Sequence[ClientExtensionFactory] | None = None, subprotocols: Sequence[Subprotocol] | None = None, extra_headers: HeadersLike | None = None, user_agent_header: str | None = USER_AGENT, **kwargs: Any, ) -> None: if logger is None: logger = logging.getLogger("websockets.client") super().__init__(logger=logger, **kwargs) self.origin = origin self.available_extensions = extensions self.available_subprotocols = subprotocols self.extra_headers = extra_headers self.user_agent_header = user_agent_header def write_http_request(self, path: str, headers: Headers) -> None: """ Write request line and headers to the HTTP request. """ self.path = path self.request_headers = headers if self.debug: self.logger.debug("> GET %s HTTP/1.1", path) for key, value in headers.raw_items(): self.logger.debug("> %s: %s", key, value) # Since the path and headers only contain ASCII characters, # we can keep this simple. request = f"GET {path} HTTP/1.1\r\n" request += str(headers) self.transport.write(request.encode()) async def read_http_response(self) -> tuple[int, Headers]: """ Read status line and headers from the HTTP response. If the response contains a body, it may be read from ``self.reader`` after this coroutine returns. Raises: InvalidMessage: If the HTTP message is malformed or isn't an HTTP/1.1 GET response. 
""" try: status_code, reason, headers = await read_response(self.reader) except Exception as exc: raise InvalidMessage("did not receive a valid HTTP response") from exc if self.debug: self.logger.debug("< HTTP/1.1 %d %s", status_code, reason) for key, value in headers.raw_items(): self.logger.debug("< %s: %s", key, value) self.response_headers = headers return status_code, self.response_headers @staticmethod def process_extensions( headers: Headers, available_extensions: Sequence[ClientExtensionFactory] | None, ) -> list[Extension]: """ Handle the Sec-WebSocket-Extensions HTTP response header. Check that each extension is supported, as well as its parameters. Return the list of accepted extensions. Raise :exc:`~websockets.exceptions.InvalidHandshake` to abort the connection. :rfc:`6455` leaves the rules up to the specification of each :extension. To provide this level of flexibility, for each extension accepted by the server, we check for a match with each extension available in the client configuration. If no match is found, an exception is raised. If several variants of the same extension are accepted by the server, it may be configured several times, which won't make sense in general. Extensions must implement their own requirements. For this purpose, the list of previously accepted extensions is provided. Other requirements, for example related to mandatory extensions or the order of extensions, may be implemented by overriding this method. """ accepted_extensions: list[Extension] = [] header_values = headers.get_all("Sec-WebSocket-Extensions") if header_values: if available_extensions is None: raise NegotiationError("no extensions supported") parsed_header_values: list[ExtensionHeader] = sum( [parse_extension(header_value) for header_value in header_values], [] ) for name, response_params in parsed_header_values: for extension_factory in available_extensions: # Skip non-matching extensions based on their name. 
if extension_factory.name != name: continue # Skip non-matching extensions based on their params. try: extension = extension_factory.process_response_params( response_params, accepted_extensions ) except NegotiationError: continue # Add matching extension to the final list. accepted_extensions.append(extension) # Break out of the loop once we have a match. break # If we didn't break from the loop, no extension in our list # matched what the server sent. Fail the connection. else: raise NegotiationError( f"Unsupported extension: " f"name = {name}, params = {response_params}" ) return accepted_extensions @staticmethod def process_subprotocol( headers: Headers, available_subprotocols: Sequence[Subprotocol] | None ) -> Subprotocol | None: """ Handle the Sec-WebSocket-Protocol HTTP response header. Check that it contains exactly one supported subprotocol. Return the selected subprotocol. """ subprotocol: Subprotocol | None = None header_values = headers.get_all("Sec-WebSocket-Protocol") if header_values: if available_subprotocols is None: raise NegotiationError("no subprotocols supported") parsed_header_values: Sequence[Subprotocol] = sum( [parse_subprotocol(header_value) for header_value in header_values], [] ) if len(parsed_header_values) > 1: raise InvalidHeaderValue( "Sec-WebSocket-Protocol", f"multiple values: {', '.join(parsed_header_values)}", ) subprotocol = parsed_header_values[0] if subprotocol not in available_subprotocols: raise NegotiationError(f"unsupported subprotocol: {subprotocol}") return subprotocol async def handshake( self, wsuri: WebSocketURI, origin: Origin | None = None, available_extensions: Sequence[ClientExtensionFactory] | None = None, available_subprotocols: Sequence[Subprotocol] | None = None, extra_headers: HeadersLike | None = None, ) -> None: """ Perform the client side of the opening handshake. Args: wsuri: URI of the WebSocket server. origin: Value of the ``Origin`` header. 
extensions: List of supported extensions, in order in which they should be negotiated and run. subprotocols: List of supported subprotocols, in order of decreasing preference. extra_headers: Arbitrary HTTP headers to add to the handshake request. Raises: InvalidHandshake: If the handshake fails. """ request_headers = Headers() request_headers["Host"] = build_host(wsuri.host, wsuri.port, wsuri.secure) if wsuri.user_info: request_headers["Authorization"] = build_authorization_basic( *wsuri.user_info ) if origin is not None: request_headers["Origin"] = origin key = build_request(request_headers) if available_extensions is not None: extensions_header = build_extension( [ (extension_factory.name, extension_factory.get_request_params()) for extension_factory in available_extensions ] ) request_headers["Sec-WebSocket-Extensions"] = extensions_header if available_subprotocols is not None: protocol_header = build_subprotocol(available_subprotocols) request_headers["Sec-WebSocket-Protocol"] = protocol_header if self.extra_headers is not None: request_headers.update(self.extra_headers) if self.user_agent_header: request_headers.setdefault("User-Agent", self.user_agent_header) self.write_http_request(wsuri.resource_name, request_headers) status_code, response_headers = await self.read_http_response() if status_code in (301, 302, 303, 307, 308): if "Location" not in response_headers: raise InvalidHeader("Location") raise RedirectHandshake(response_headers["Location"]) elif status_code != 101: raise InvalidStatusCode(status_code, response_headers) check_response(response_headers, key) self.extensions = self.process_extensions( response_headers, available_extensions ) self.subprotocol = self.process_subprotocol( response_headers, available_subprotocols ) self.connection_open() class Connect: """ Connect to the WebSocket server at ``uri``. Awaiting :func:`connect` yields a :class:`WebSocketClientProtocol` which can then be used to send and receive messages. 
:func:`connect` can be used as a asynchronous context manager:: async with connect(...) as websocket: ... The connection is closed automatically when exiting the context. :func:`connect` can be used as an infinite asynchronous iterator to reconnect automatically on errors:: async for websocket in connect(...): try: ... except websockets.exceptions.ConnectionClosed: continue The connection is closed automatically after each iteration of the loop. If an error occurs while establishing the connection, :func:`connect` retries with exponential backoff. The backoff delay starts at three seconds and increases up to one minute. If an error occurs in the body of the loop, you can handle the exception and :func:`connect` will reconnect with the next iteration; or you can let the exception bubble up and break out of the loop. This lets you decide which errors trigger a reconnection and which errors are fatal. Args: uri: URI of the WebSocket server. create_protocol: Factory for the :class:`asyncio.Protocol` managing the connection. It defaults to :class:`WebSocketClientProtocol`. Set it to a wrapper or a subclass to customize connection handling. logger: Logger for this client. It defaults to ``logging.getLogger("websockets.client")``. See the :doc:`logging guide <../../topics/logging>` for details. compression: The "permessage-deflate" extension is enabled by default. Set ``compression`` to :obj:`None` to disable it. See the :doc:`compression guide <../../topics/compression>` for details. origin: Value of the ``Origin`` header, for servers that require it. extensions: List of supported extensions, in order in which they should be negotiated and run. subprotocols: List of supported subprotocols, in order of decreasing preference. extra_headers: Arbitrary HTTP headers to add to the handshake request. user_agent_header: Value of the ``User-Agent`` request header. It defaults to ``"Python/x.y.z websockets/X.Y"``. Setting it to :obj:`None` removes the header. 
open_timeout: Timeout for opening the connection in seconds. :obj:`None` disables the timeout. See :class:`~websockets.legacy.protocol.WebSocketCommonProtocol` for the documentation of ``ping_interval``, ``ping_timeout``, ``close_timeout``, ``max_size``, ``max_queue``, ``read_limit``, and ``write_limit``. Any other keyword arguments are passed the event loop's :meth:`~asyncio.loop.create_connection` method. For example: * You can set ``ssl`` to a :class:`~ssl.SSLContext` to enforce TLS settings. When connecting to a ``wss://`` URI, if ``ssl`` isn't provided, a TLS context is created with :func:`~ssl.create_default_context`. * You can set ``host`` and ``port`` to connect to a different host and port from those found in ``uri``. This only changes the destination of the TCP connection. The host name from ``uri`` is still used in the TLS handshake for secure connections and in the ``Host`` header. Raises: InvalidURI: If ``uri`` isn't a valid WebSocket URI. OSError: If the TCP connection fails. InvalidHandshake: If the opening handshake fails. ~asyncio.TimeoutError: If the opening handshake times out. """ MAX_REDIRECTS_ALLOWED = int(os.environ.get("WEBSOCKETS_MAX_REDIRECTS", "10")) def __init__( self, uri: str, *, create_protocol: Callable[..., WebSocketClientProtocol] | None = None, logger: LoggerLike | None = None, compression: str | None = "deflate", origin: Origin | None = None, extensions: Sequence[ClientExtensionFactory] | None = None, subprotocols: Sequence[Subprotocol] | None = None, extra_headers: HeadersLike | None = None, user_agent_header: str | None = USER_AGENT, open_timeout: float | None = 10, ping_interval: float | None = 20, ping_timeout: float | None = 20, close_timeout: float | None = None, max_size: int | None = 2**20, max_queue: int | None = 2**5, read_limit: int = 2**16, write_limit: int = 2**16, **kwargs: Any, ) -> None: # Backwards compatibility: close_timeout used to be called timeout. 
timeout: float | None = kwargs.pop("timeout", None) if timeout is None: timeout = 10 else: warnings.warn("rename timeout to close_timeout", DeprecationWarning) # If both are specified, timeout is ignored. if close_timeout is None: close_timeout = timeout # Backwards compatibility: create_protocol used to be called klass. klass: type[WebSocketClientProtocol] | None = kwargs.pop("klass", None) if klass is None: klass = WebSocketClientProtocol else: warnings.warn("rename klass to create_protocol", DeprecationWarning) # If both are specified, klass is ignored. if create_protocol is None: create_protocol = klass # Backwards compatibility: recv() used to return None on closed connections legacy_recv: bool = kwargs.pop("legacy_recv", False) # Backwards compatibility: the loop parameter used to be supported. _loop: asyncio.AbstractEventLoop | None = kwargs.pop("loop", None) if _loop is None: loop = asyncio.get_event_loop() else: loop = _loop warnings.warn("remove loop argument", DeprecationWarning) wsuri = parse_uri(uri) if wsuri.secure: kwargs.setdefault("ssl", True) elif kwargs.get("ssl") is not None: raise ValueError( "connect() received a ssl argument for a ws:// URI, " "use a wss:// URI to enable TLS" ) if compression == "deflate": extensions = enable_client_permessage_deflate(extensions) elif compression is not None: raise ValueError(f"unsupported compression: {compression}") if subprotocols is not None: validate_subprotocols(subprotocols) # Help mypy and avoid this error: "type[WebSocketClientProtocol] | # Callable[..., WebSocketClientProtocol]" not callable [misc] create_protocol = cast(Callable[..., WebSocketClientProtocol], create_protocol) factory = functools.partial( create_protocol, logger=logger, origin=origin, extensions=extensions, subprotocols=subprotocols, extra_headers=extra_headers, user_agent_header=user_agent_header, ping_interval=ping_interval, ping_timeout=ping_timeout, close_timeout=close_timeout, max_size=max_size, max_queue=max_queue, 
read_limit=read_limit, write_limit=write_limit, host=wsuri.host, port=wsuri.port, secure=wsuri.secure, legacy_recv=legacy_recv, loop=_loop, ) if kwargs.pop("unix", False): path: str | None = kwargs.pop("path", None) create_connection = functools.partial( loop.create_unix_connection, factory, path, **kwargs ) else: host: str | None port: int | None if kwargs.get("sock") is None: host, port = wsuri.host, wsuri.port else: # If sock is given, host and port shouldn't be specified. host, port = None, None if kwargs.get("ssl"): kwargs.setdefault("server_hostname", wsuri.host) # If host and port are given, override values from the URI. host = kwargs.pop("host", host) port = kwargs.pop("port", port) create_connection = functools.partial( loop.create_connection, factory, host, port, **kwargs ) self.open_timeout = open_timeout if logger is None: logger = logging.getLogger("websockets.client") self.logger = logger # This is a coroutine function. self._create_connection = create_connection self._uri = uri self._wsuri = wsuri def handle_redirect(self, uri: str) -> None: # Update the state of this instance to connect to a new URI. old_uri = self._uri old_wsuri = self._wsuri new_uri = urllib.parse.urljoin(old_uri, uri) new_wsuri = parse_uri(new_uri) # Forbid TLS downgrade. if old_wsuri.secure and not new_wsuri.secure: raise SecurityError("redirect from WSS to WS") same_origin = ( old_wsuri.secure == new_wsuri.secure and old_wsuri.host == new_wsuri.host and old_wsuri.port == new_wsuri.port ) # Rewrite secure, host, and port for cross-origin redirects. # This preserves connection overrides with the host and port # arguments if the redirect points to the same host and port. if not same_origin: factory = self._create_connection.args[0] # Support TLS upgrade. if not old_wsuri.secure and new_wsuri.secure: factory.keywords["secure"] = True self._create_connection.keywords.setdefault("ssl", True) # Replace secure, host, and port arguments of the protocol factory. 
factory = functools.partial( factory.func, *factory.args, **dict(factory.keywords, host=new_wsuri.host, port=new_wsuri.port), ) # Replace secure, host, and port arguments of create_connection. self._create_connection = functools.partial( self._create_connection.func, *(factory, new_wsuri.host, new_wsuri.port), **self._create_connection.keywords, ) # Set the new WebSocket URI. This suffices for same-origin redirects. self._uri = new_uri self._wsuri = new_wsuri # async for ... in connect(...): BACKOFF_INITIAL = float(os.environ.get("WEBSOCKETS_BACKOFF_INITIAL_DELAY", "5")) BACKOFF_MIN = float(os.environ.get("WEBSOCKETS_BACKOFF_MIN_DELAY", "3.1")) BACKOFF_MAX = float(os.environ.get("WEBSOCKETS_BACKOFF_MAX_DELAY", "90.0")) BACKOFF_FACTOR = float(os.environ.get("WEBSOCKETS_BACKOFF_FACTOR", "1.618")) async def __aiter__(self) -> AsyncIterator[WebSocketClientProtocol]: backoff_delay = self.BACKOFF_MIN / self.BACKOFF_FACTOR while True: try: async with self as protocol: yield protocol except Exception as exc: # Add a random initial delay between 0 and 5 seconds. # See 7.2.3. Recovering from Abnormal Closure in RFC 6455. if backoff_delay == self.BACKOFF_MIN: initial_delay = random.random() * self.BACKOFF_INITIAL self.logger.info( "connect failed; reconnecting in %.1f seconds: %s", initial_delay, # Remove first argument when dropping Python 3.9. traceback.format_exception_only(type(exc), exc)[0].strip(), ) await asyncio.sleep(initial_delay) else: self.logger.info( "connect failed again; retrying in %d seconds: %s", int(backoff_delay), # Remove first argument when dropping Python 3.9. traceback.format_exception_only(type(exc), exc)[0].strip(), ) await asyncio.sleep(int(backoff_delay)) # Increase delay with truncated exponential backoff. backoff_delay = backoff_delay * self.BACKOFF_FACTOR backoff_delay = min(backoff_delay, self.BACKOFF_MAX) continue else: # Connection succeeded - reset backoff delay backoff_delay = self.BACKOFF_MIN # async with connect(...) 
as ...: async def __aenter__(self) -> WebSocketClientProtocol: return await self async def __aexit__( self, exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: TracebackType | None, ) -> None: await self.protocol.close() # ... = await connect(...) def __await__(self) -> Generator[Any, None, WebSocketClientProtocol]: # Create a suitable iterator by calling __await__ on a coroutine. return self.__await_impl__().__await__() async def __await_impl__(self) -> WebSocketClientProtocol: async with asyncio_timeout(self.open_timeout): for _redirects in range(self.MAX_REDIRECTS_ALLOWED): _transport, protocol = await self._create_connection() try: await protocol.handshake( self._wsuri, origin=protocol.origin, available_extensions=protocol.available_extensions, available_subprotocols=protocol.available_subprotocols, extra_headers=protocol.extra_headers, ) except RedirectHandshake as exc: protocol.fail_connection() await protocol.wait_closed() self.handle_redirect(exc.uri) # Avoid leaking a connected socket when the handshake fails. except (Exception, asyncio.CancelledError): protocol.fail_connection() await protocol.wait_closed() raise else: self.protocol = protocol return protocol else: raise SecurityError("too many redirects") # ... = yield from connect(...) - remove when dropping Python < 3.10 __iter__ = __await__ connect = Connect def unix_connect( path: str | None = None, uri: str = "ws://localhost/", **kwargs: Any, ) -> Connect: """ Similar to :func:`connect`, but for connecting to a Unix socket. This function builds upon the event loop's :meth:`~asyncio.loop.create_unix_connection` method. It is only available on Unix. It's mainly useful for debugging servers listening on Unix sockets. Args: path: File system path to the Unix socket. uri: URI of the WebSocket server; the host is used in the TLS handshake for secure connections and in the ``Host`` header. 
""" return connect(uri=uri, path=path, unix=True, **kwargs) websockets-15.0.1/src/websockets/legacy/exceptions.py000066400000000000000000000036041476212450300227150ustar00rootroot00000000000000import http from .. import datastructures from ..exceptions import ( InvalidHandshake, # InvalidMessage was incorrectly moved here in versions 14.0 and 14.1. InvalidMessage, # noqa: F401 ProtocolError as WebSocketProtocolError, # noqa: F401 ) from ..typing import StatusLike class InvalidStatusCode(InvalidHandshake): """ Raised when a handshake response status code is invalid. """ def __init__(self, status_code: int, headers: datastructures.Headers) -> None: self.status_code = status_code self.headers = headers def __str__(self) -> str: return f"server rejected WebSocket connection: HTTP {self.status_code}" class AbortHandshake(InvalidHandshake): """ Raised to abort the handshake on purpose and return an HTTP response. This exception is an implementation detail. The public API is :meth:`~websockets.legacy.server.WebSocketServerProtocol.process_request`. Attributes: status (~http.HTTPStatus): HTTP status code. headers (Headers): HTTP response headers. body (bytes): HTTP response body. """ def __init__( self, status: StatusLike, headers: datastructures.HeadersLike, body: bytes = b"", ) -> None: # If a user passes an int instead of an HTTPStatus, fix it automatically. self.status = http.HTTPStatus(status) self.headers = datastructures.Headers(headers) self.body = body def __str__(self) -> str: return ( f"HTTP {self.status:d}, {len(self.headers)} headers, {len(self.body)} bytes" ) class RedirectHandshake(InvalidHandshake): """ Raised when a handshake gets redirected. This exception is an implementation detail. 
""" def __init__(self, uri: str) -> None: self.uri = uri def __str__(self) -> str: return f"redirect to {self.uri}" websockets-15.0.1/src/websockets/legacy/framing.py000066400000000000000000000143361476212450300221630ustar00rootroot00000000000000from __future__ import annotations import struct from collections.abc import Awaitable, Sequence from typing import Any, Callable, NamedTuple from .. import extensions, frames from ..exceptions import PayloadTooBig, ProtocolError from ..frames import BytesLike from ..typing import Data try: from ..speedups import apply_mask except ImportError: from ..utils import apply_mask class Frame(NamedTuple): fin: bool opcode: frames.Opcode data: bytes rsv1: bool = False rsv2: bool = False rsv3: bool = False @property def new_frame(self) -> frames.Frame: return frames.Frame( self.opcode, self.data, self.fin, self.rsv1, self.rsv2, self.rsv3, ) def __str__(self) -> str: return str(self.new_frame) def check(self) -> None: return self.new_frame.check() @classmethod async def read( cls, reader: Callable[[int], Awaitable[bytes]], *, mask: bool, max_size: int | None = None, extensions: Sequence[extensions.Extension] | None = None, ) -> Frame: """ Read a WebSocket frame. Args: reader: Coroutine that reads exactly the requested number of bytes, unless the end of file is reached. mask: Whether the frame should be masked i.e. whether the read happens on the server side. max_size: Maximum payload size in bytes. extensions: List of extensions, applied in reverse order. Raises: PayloadTooBig: If the frame exceeds ``max_size``. ProtocolError: If the frame contains incorrect values. """ # Read the header. data = await reader(2) head1, head2 = struct.unpack("!BB", data) # While not Pythonic, this is marginally faster than calling bool(). 
fin = True if head1 & 0b10000000 else False rsv1 = True if head1 & 0b01000000 else False rsv2 = True if head1 & 0b00100000 else False rsv3 = True if head1 & 0b00010000 else False try: opcode = frames.Opcode(head1 & 0b00001111) except ValueError as exc: raise ProtocolError("invalid opcode") from exc if (True if head2 & 0b10000000 else False) != mask: raise ProtocolError("incorrect masking") length = head2 & 0b01111111 if length == 126: data = await reader(2) (length,) = struct.unpack("!H", data) elif length == 127: data = await reader(8) (length,) = struct.unpack("!Q", data) if max_size is not None and length > max_size: raise PayloadTooBig(length, max_size) if mask: mask_bits = await reader(4) # Read the data. data = await reader(length) if mask: data = apply_mask(data, mask_bits) new_frame = frames.Frame(opcode, data, fin, rsv1, rsv2, rsv3) if extensions is None: extensions = [] for extension in reversed(extensions): new_frame = extension.decode(new_frame, max_size=max_size) new_frame.check() return cls( new_frame.fin, new_frame.opcode, new_frame.data, new_frame.rsv1, new_frame.rsv2, new_frame.rsv3, ) def write( self, write: Callable[[bytes], Any], *, mask: bool, extensions: Sequence[extensions.Extension] | None = None, ) -> None: """ Write a WebSocket frame. Args: frame: Frame to write. write: Function that writes bytes. mask: Whether the frame should be masked i.e. whether the write happens on the client side. extensions: List of extensions, applied in order. Raises: ProtocolError: If the frame contains incorrect values. """ # The frame is written in a single call to write in order to prevent # TCP fragmentation. See #68 for details. This also makes it safe to # send frames concurrently from multiple coroutines. write(self.new_frame.serialize(mask=mask, extensions=extensions)) def prepare_data(data: Data) -> tuple[int, bytes]: """ Convert a string or byte-like object to an opcode and a bytes-like object. This function is designed for data frames. 
If ``data`` is a :class:`str`, return ``OP_TEXT`` and a :class:`bytes` object encoding ``data`` in UTF-8. If ``data`` is a bytes-like object, return ``OP_BINARY`` and a bytes-like object. Raises: TypeError: If ``data`` doesn't have a supported type. """ if isinstance(data, str): return frames.Opcode.TEXT, data.encode() elif isinstance(data, BytesLike): return frames.Opcode.BINARY, data else: raise TypeError("data must be str or bytes-like") def prepare_ctrl(data: Data) -> bytes: """ Convert a string or byte-like object to bytes. This function is designed for ping and pong frames. If ``data`` is a :class:`str`, return a :class:`bytes` object encoding ``data`` in UTF-8. If ``data`` is a bytes-like object, return a :class:`bytes` object. Raises: TypeError: If ``data`` doesn't have a supported type. """ if isinstance(data, str): return data.encode() elif isinstance(data, BytesLike): return bytes(data) else: raise TypeError("data must be str or bytes-like") # Backwards compatibility with previously documented public APIs encode_data = prepare_ctrl # Backwards compatibility with previously documented public APIs from ..frames import Close # noqa: E402 F401, I001 def parse_close(data: bytes) -> tuple[int, str]: """ Parse the payload from a close frame. Returns: Close code and reason. Raises: ProtocolError: If data is ill-formed. UnicodeDecodeError: If the reason isn't valid UTF-8. """ close = Close.parse(data) return close.code, close.reason def serialize_close(code: int, reason: str) -> bytes: """ Serialize the payload for a close frame. 
""" return Close(code, reason).serialize() websockets-15.0.1/src/websockets/legacy/handshake.py000066400000000000000000000122451476212450300224630ustar00rootroot00000000000000from __future__ import annotations import base64 import binascii from ..datastructures import Headers, MultipleValuesError from ..exceptions import InvalidHeader, InvalidHeaderValue, InvalidUpgrade from ..headers import parse_connection, parse_upgrade from ..typing import ConnectionOption, UpgradeProtocol from ..utils import accept_key as accept, generate_key __all__ = ["build_request", "check_request", "build_response", "check_response"] def build_request(headers: Headers) -> str: """ Build a handshake request to send to the server. Update request headers passed in argument. Args: headers: Handshake request headers. Returns: ``key`` that must be passed to :func:`check_response`. """ key = generate_key() headers["Upgrade"] = "websocket" headers["Connection"] = "Upgrade" headers["Sec-WebSocket-Key"] = key headers["Sec-WebSocket-Version"] = "13" return key def check_request(headers: Headers) -> str: """ Check a handshake request received from the client. This function doesn't verify that the request is an HTTP/1.1 or higher GET request and doesn't perform ``Host`` and ``Origin`` checks. These controls are usually performed earlier in the HTTP request handling code. They're the responsibility of the caller. Args: headers: Handshake request headers. Returns: ``key`` that must be passed to :func:`build_response`. Raises: InvalidHandshake: If the handshake request is invalid. Then, the server must return a 400 Bad Request error. 
""" connection: list[ConnectionOption] = sum( [parse_connection(value) for value in headers.get_all("Connection")], [] ) if not any(value.lower() == "upgrade" for value in connection): raise InvalidUpgrade("Connection", ", ".join(connection)) upgrade: list[UpgradeProtocol] = sum( [parse_upgrade(value) for value in headers.get_all("Upgrade")], [] ) # For compatibility with non-strict implementations, ignore case when # checking the Upgrade header. The RFC always uses "websocket", except # in section 11.2. (IANA registration) where it uses "WebSocket". if not (len(upgrade) == 1 and upgrade[0].lower() == "websocket"): raise InvalidUpgrade("Upgrade", ", ".join(upgrade)) try: s_w_key = headers["Sec-WebSocket-Key"] except KeyError as exc: raise InvalidHeader("Sec-WebSocket-Key") from exc except MultipleValuesError as exc: raise InvalidHeader("Sec-WebSocket-Key", "multiple values") from exc try: raw_key = base64.b64decode(s_w_key.encode(), validate=True) except binascii.Error as exc: raise InvalidHeaderValue("Sec-WebSocket-Key", s_w_key) from exc if len(raw_key) != 16: raise InvalidHeaderValue("Sec-WebSocket-Key", s_w_key) try: s_w_version = headers["Sec-WebSocket-Version"] except KeyError as exc: raise InvalidHeader("Sec-WebSocket-Version") from exc except MultipleValuesError as exc: raise InvalidHeader("Sec-WebSocket-Version", "multiple values") from exc if s_w_version != "13": raise InvalidHeaderValue("Sec-WebSocket-Version", s_w_version) return s_w_key def build_response(headers: Headers, key: str) -> None: """ Build a handshake response to send to the client. Update response headers passed in argument. Args: headers: Handshake response headers. key: Returned by :func:`check_request`. """ headers["Upgrade"] = "websocket" headers["Connection"] = "Upgrade" headers["Sec-WebSocket-Accept"] = accept(key) def check_response(headers: Headers, key: str) -> None: """ Check a handshake response received from the server. 
This function doesn't verify that the response is an HTTP/1.1 or higher response with a 101 status code. These controls are the responsibility of the caller. Args: headers: Handshake response headers. key: Returned by :func:`build_request`. Raises: InvalidHandshake: If the handshake response is invalid. """ connection: list[ConnectionOption] = sum( [parse_connection(value) for value in headers.get_all("Connection")], [] ) if not any(value.lower() == "upgrade" for value in connection): raise InvalidUpgrade("Connection", " ".join(connection)) upgrade: list[UpgradeProtocol] = sum( [parse_upgrade(value) for value in headers.get_all("Upgrade")], [] ) # For compatibility with non-strict implementations, ignore case when # checking the Upgrade header. The RFC always uses "websocket", except # in section 11.2. (IANA registration) where it uses "WebSocket". if not (len(upgrade) == 1 and upgrade[0].lower() == "websocket"): raise InvalidUpgrade("Upgrade", ", ".join(upgrade)) try: s_w_accept = headers["Sec-WebSocket-Accept"] except KeyError as exc: raise InvalidHeader("Sec-WebSocket-Accept") from exc except MultipleValuesError as exc: raise InvalidHeader("Sec-WebSocket-Accept", "multiple values") from exc if s_w_accept != accept(key): raise InvalidHeaderValue("Sec-WebSocket-Accept", s_w_accept) websockets-15.0.1/src/websockets/legacy/http.py000066400000000000000000000156251476212450300215210ustar00rootroot00000000000000from __future__ import annotations import asyncio import os import re from ..datastructures import Headers from ..exceptions import SecurityError __all__ = ["read_request", "read_response"] MAX_NUM_HEADERS = int(os.environ.get("WEBSOCKETS_MAX_NUM_HEADERS", "128")) MAX_LINE_LENGTH = int(os.environ.get("WEBSOCKETS_MAX_LINE_LENGTH", "8192")) def d(value: bytes) -> str: """ Decode a bytestring for interpolating into an error message. """ return value.decode(errors="backslashreplace") # See https://datatracker.ietf.org/doc/html/rfc7230#appendix-B. 
# Regex for validating header names. _token_re = re.compile(rb"[-!#$%&\'*+.^_`|~0-9a-zA-Z]+") # Regex for validating header values. # We don't attempt to support obsolete line folding. # Include HTAB (\x09), SP (\x20), VCHAR (\x21-\x7e), obs-text (\x80-\xff). # The ABNF is complicated because it attempts to express that optional # whitespace is ignored. We strip whitespace and don't revalidate that. # See also https://www.rfc-editor.org/errata_search.php?rfc=7230&eid=4189 _value_re = re.compile(rb"[\x09\x20-\x7e\x80-\xff]*") async def read_request(stream: asyncio.StreamReader) -> tuple[str, Headers]: """ Read an HTTP/1.1 GET request and return ``(path, headers)``. ``path`` isn't URL-decoded or validated in any way. ``path`` and ``headers`` are expected to contain only ASCII characters. Other characters are represented with surrogate escapes. :func:`read_request` doesn't attempt to read the request body because WebSocket handshake requests don't have one. If the request contains a body, it may be read from ``stream`` after this coroutine returns. Args: stream: Input to read the request from. Raises: EOFError: If the connection is closed without a full HTTP request. SecurityError: If the request exceeds a security limit. ValueError: If the request isn't well formatted. """ # https://datatracker.ietf.org/doc/html/rfc7230#section-3.1.1 # Parsing is simple because fixed values are expected for method and # version and because path isn't checked. Since WebSocket software tends # to implement HTTP/1.1 strictly, there's little need for lenient parsing. 
try: request_line = await read_line(stream) except EOFError as exc: raise EOFError("connection closed while reading HTTP request line") from exc try: method, raw_path, version = request_line.split(b" ", 2) except ValueError: # not enough values to unpack (expected 3, got 1-2) raise ValueError(f"invalid HTTP request line: {d(request_line)}") from None if method != b"GET": raise ValueError(f"unsupported HTTP method: {d(method)}") if version != b"HTTP/1.1": raise ValueError(f"unsupported HTTP version: {d(version)}") path = raw_path.decode("ascii", "surrogateescape") headers = await read_headers(stream) return path, headers async def read_response(stream: asyncio.StreamReader) -> tuple[int, str, Headers]: """ Read an HTTP/1.1 response and return ``(status_code, reason, headers)``. ``reason`` and ``headers`` are expected to contain only ASCII characters. Other characters are represented with surrogate escapes. :func:`read_request` doesn't attempt to read the response body because WebSocket handshake responses don't have one. If the response contains a body, it may be read from ``stream`` after this coroutine returns. Args: stream: Input to read the response from. Raises: EOFError: If the connection is closed without a full HTTP response. SecurityError: If the response exceeds a security limit. ValueError: If the response isn't well formatted. """ # https://datatracker.ietf.org/doc/html/rfc7230#section-3.1.2 # As in read_request, parsing is simple because a fixed value is expected # for version, status_code is a 3-digit number, and reason can be ignored. 
    try:
        status_line = await read_line(stream)
    except EOFError as exc:
        raise EOFError("connection closed while reading HTTP status line") from exc

    try:
        version, raw_status_code, raw_reason = status_line.split(b" ", 2)
    except ValueError:  # not enough values to unpack (expected 3, got 1-2)
        raise ValueError(f"invalid HTTP status line: {d(status_line)}") from None

    if version != b"HTTP/1.1":
        raise ValueError(f"unsupported HTTP version: {d(version)}")
    try:
        status_code = int(raw_status_code)
    except ValueError:  # invalid literal for int() with base 10
        raise ValueError(f"invalid HTTP status code: {d(raw_status_code)}") from None
    if not 100 <= status_code < 1000:
        raise ValueError(f"unsupported HTTP status code: {d(raw_status_code)}")
    if not _value_re.fullmatch(raw_reason):
        raise ValueError(f"invalid HTTP reason phrase: {d(raw_reason)}")
    # NOTE(review): _value_re admits obs-text bytes (\x80-\xff), which this
    # default (UTF-8) decode could reject with UnicodeDecodeError; the
    # docstring says reason is expected to be ASCII — confirm intent.
    reason = raw_reason.decode()

    headers = await read_headers(stream)

    return status_code, reason, headers


async def read_headers(stream: asyncio.StreamReader) -> Headers:
    """
    Read HTTP headers from ``stream``.

    Non-ASCII characters are represented with surrogate escapes.

    """
    # https://datatracker.ietf.org/doc/html/rfc7230#section-3.2

    # We don't attempt to support obsolete line folding.
    headers = Headers()
    for _ in range(MAX_NUM_HEADERS + 1):
        try:
            line = await read_line(stream)
        except EOFError as exc:
            raise EOFError("connection closed while reading HTTP headers") from exc
        # An empty line terminates the header block.
        if line == b"":
            break

        try:
            raw_name, raw_value = line.split(b":", 1)
        except ValueError:  # not enough values to unpack (expected 2, got 1)
            raise ValueError(f"invalid HTTP header line: {d(line)}") from None
        if not _token_re.fullmatch(raw_name):
            raise ValueError(f"invalid HTTP header name: {d(raw_name)}")
        raw_value = raw_value.strip(b" \t")
        if not _value_re.fullmatch(raw_value):
            raise ValueError(f"invalid HTTP header value: {d(raw_value)}")
        name = raw_name.decode("ascii")  # guaranteed to be ASCII at this point
        value = raw_value.decode("ascii", "surrogateescape")
        headers[name] = value

    else:
        # The loop ran to completion without hitting ``break``: more than
        # MAX_NUM_HEADERS header lines were read before the empty line.
        raise SecurityError("too many HTTP headers")

    return headers


async def read_line(stream: asyncio.StreamReader) -> bytes:
    """
    Read a single line from ``stream``.

    CRLF is stripped from the return value.

    """
    # Security: this is bounded by the StreamReader's limit (default = 32 KiB).
    line = await stream.readline()

    # Security: this guarantees header values are small (default = 8 KiB,
    # configurable via the WEBSOCKETS_MAX_LINE_LENGTH environment variable).
    if len(line) > MAX_LINE_LENGTH:
        raise SecurityError("line too long")

    # Not mandatory but safe - https://datatracker.ietf.org/doc/html/rfc7230#section-3.5
    if not line.endswith(b"\r\n"):
        raise EOFError("line without CRLF")

    # Strip the trailing CRLF.
    return line[:-2]
websockets-15.0.1/src/websockets/legacy/protocol.py000066400000000000000000001746361476212450300224030ustar00rootroot00000000000000from __future__ import annotations

import asyncio
import codecs
import collections
import logging
import random
import ssl
import struct
import sys
import time
import traceback
import uuid
import warnings
from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Iterable, Mapping
from typing import Any, Callable, Deque, cast

from ..asyncio.compatibility import asyncio_timeout
from ..datastructures import Headers
from ..exceptions import (
    ConnectionClosed,
    ConnectionClosedError,
    ConnectionClosedOK,
    InvalidState,
    PayloadTooBig,
    ProtocolError,
)
from ..extensions import Extension
from ..frames import (
    OK_CLOSE_CODES,
    OP_BINARY,
    OP_CLOSE,
    OP_CONT,
    OP_PING,
    OP_PONG,
    OP_TEXT,
    Close,
    CloseCode,
    Opcode,
)
from ..protocol import State
from ..typing import Data, LoggerLike, Subprotocol
from .framing import Frame, prepare_ctrl, prepare_data


__all__ = ["WebSocketCommonProtocol"]


# In order to ensure consistency, the code always checks the current value of
# WebSocketCommonProtocol.state before assigning a new value and never yields
# between the check and the assignment.


class WebSocketCommonProtocol(asyncio.Protocol):
    """
    WebSocket connection.

    :class:`WebSocketCommonProtocol` provides APIs shared between WebSocket
    servers and clients.

    You shouldn't use it directly. Instead, use
    :class:`~websockets.legacy.client.WebSocketClientProtocol` or
    :class:`~websockets.legacy.server.WebSocketServerProtocol`.
This documentation focuses on low-level details that aren't covered in the documentation of :class:`~websockets.legacy.client.WebSocketClientProtocol` and :class:`~websockets.legacy.server.WebSocketServerProtocol` for the sake of simplicity. Once the connection is open, a Ping_ frame is sent every ``ping_interval`` seconds. This serves as a keepalive. It helps keeping the connection open, especially in the presence of proxies with short timeouts on inactive connections. Set ``ping_interval`` to :obj:`None` to disable this behavior. .. _Ping: https://datatracker.ietf.org/doc/html/rfc6455#section-5.5.2 If the corresponding Pong_ frame isn't received within ``ping_timeout`` seconds, the connection is considered unusable and is closed with code 1011. This ensures that the remote endpoint remains responsive. Set ``ping_timeout`` to :obj:`None` to disable this behavior. .. _Pong: https://datatracker.ietf.org/doc/html/rfc6455#section-5.5.3 See the discussion of :doc:`keepalive <../../topics/keepalive>` for details. The ``close_timeout`` parameter defines a maximum wait time for completing the closing handshake and terminating the TCP connection. For legacy reasons, :meth:`close` completes in at most ``5 * close_timeout`` seconds for clients and ``4 * close_timeout`` for servers. ``close_timeout`` is a parameter of the protocol because websockets usually calls :meth:`close` implicitly upon exit: * on the client side, when using :func:`~websockets.legacy.client.connect` as a context manager; * on the server side, when the connection handler terminates. To apply a timeout to any other API, wrap it in :func:`~asyncio.timeout` or :func:`~asyncio.wait_for`. The ``max_size`` parameter enforces the maximum size for incoming messages in bytes. The default value is 1 MiB. If a larger message is received, :meth:`recv` will raise :exc:`~websockets.exceptions.ConnectionClosedError` and the connection will be closed with code 1009. 
The ``max_queue`` parameter sets the maximum length of the queue that holds incoming messages. The default value is ``32``. Messages are added to an in-memory queue when they're received; then :meth:`recv` pops from that queue. In order to prevent excessive memory consumption when messages are received faster than they can be processed, the queue must be bounded. If the queue fills up, the protocol stops processing incoming data until :meth:`recv` is called. In this situation, various receive buffers (at least in :mod:`asyncio` and in the OS) will fill up, then the TCP receive window will shrink, slowing down transmission to avoid packet loss. Since Python can use up to 4 bytes of memory to represent a single character, each connection may use up to ``4 * max_size * max_queue`` bytes of memory to store incoming messages. By default, this is 128 MiB. You may want to lower the limits, depending on your application's requirements. The ``read_limit`` argument sets the high-water limit of the buffer for incoming bytes. The low-water limit is half the high-water limit. The default value is 64 KiB, half of asyncio's default (based on the current implementation of :class:`~asyncio.StreamReader`). The ``write_limit`` argument sets the high-water limit of the buffer for outgoing bytes. The low-water limit is a quarter of the high-water limit. The default value is 64 KiB, equal to asyncio's default (based on the current implementation of ``FlowControlMixin``). See the discussion of :doc:`memory usage <../../topics/memory>` for details. Args: logger: Logger for this server. It defaults to ``logging.getLogger("websockets.protocol")``. See the :doc:`logging guide <../../topics/logging>` for details. ping_interval: Interval between keepalive pings in seconds. :obj:`None` disables keepalive. ping_timeout: Timeout for keepalive pings in seconds. :obj:`None` disables timeouts. close_timeout: Timeout for closing the connection in seconds. 
For legacy reasons, the actual timeout is 4 or 5 times larger. max_size: Maximum size of incoming messages in bytes. :obj:`None` disables the limit. max_queue: Maximum number of incoming messages in receive buffer. :obj:`None` disables the limit. read_limit: High-water mark of read buffer in bytes. write_limit: High-water mark of write buffer in bytes. """ # There are only two differences between the client-side and server-side # behavior: masking the payload and closing the underlying TCP connection. # Set is_client = True/False and side = "client"/"server" to pick a side. is_client: bool side: str = "undefined" def __init__( self, *, logger: LoggerLike | None = None, ping_interval: float | None = 20, ping_timeout: float | None = 20, close_timeout: float | None = None, max_size: int | None = 2**20, max_queue: int | None = 2**5, read_limit: int = 2**16, write_limit: int = 2**16, # The following arguments are kept only for backwards compatibility. host: str | None = None, port: int | None = None, secure: bool | None = None, legacy_recv: bool = False, loop: asyncio.AbstractEventLoop | None = None, timeout: float | None = None, ) -> None: if legacy_recv: # pragma: no cover warnings.warn("legacy_recv is deprecated", DeprecationWarning) # Backwards compatibility: close_timeout used to be called timeout. if timeout is None: timeout = 10 else: warnings.warn("rename timeout to close_timeout", DeprecationWarning) # If both are specified, timeout is ignored. if close_timeout is None: close_timeout = timeout # Backwards compatibility: the loop parameter used to be supported. if loop is None: loop = asyncio.get_event_loop() else: warnings.warn("remove loop argument", DeprecationWarning) self.ping_interval = ping_interval self.ping_timeout = ping_timeout self.close_timeout = close_timeout self.max_size = max_size self.max_queue = max_queue self.read_limit = read_limit self.write_limit = write_limit # Unique identifier. For logs. 
self.id: uuid.UUID = uuid.uuid4() """Unique identifier of the connection. Useful in logs.""" # Logger or LoggerAdapter for this connection. if logger is None: logger = logging.getLogger("websockets.protocol") self.logger: LoggerLike = logging.LoggerAdapter(logger, {"websocket": self}) """Logger for this connection.""" # Track if DEBUG is enabled. Shortcut logging calls if it isn't. self.debug = logger.isEnabledFor(logging.DEBUG) self.loop = loop self._host = host self._port = port self._secure = secure self.legacy_recv = legacy_recv # Configure read buffer limits. The high-water limit is defined by # ``self.read_limit``. The ``limit`` argument controls the line length # limit and half the buffer limit of :class:`~asyncio.StreamReader`. # That's why it must be set to half of ``self.read_limit``. self.reader = asyncio.StreamReader(limit=read_limit // 2, loop=loop) # Copied from asyncio.FlowControlMixin self._paused = False self._drain_waiter: asyncio.Future[None] | None = None self._drain_lock = asyncio.Lock() # This class implements the data transfer and closing handshake, which # are shared between the client-side and the server-side. # Subclasses implement the opening handshake and, on success, execute # :meth:`connection_open` to change the state to OPEN. self.state = State.CONNECTING if self.debug: self.logger.debug("= connection is CONNECTING") # HTTP protocol parameters. self.path: str """Path of the opening handshake request.""" self.request_headers: Headers """Opening handshake request headers.""" self.response_headers: Headers """Opening handshake response headers.""" # WebSocket protocol parameters. self.extensions: list[Extension] = [] self.subprotocol: Subprotocol | None = None """Subprotocol, if one was negotiated.""" # Close code and reason, set when a close frame is sent or received. 
self.close_rcvd: Close | None = None self.close_sent: Close | None = None self.close_rcvd_then_sent: bool | None = None # Completed when the connection state becomes CLOSED. Translates the # :meth:`connection_lost` callback to a :class:`~asyncio.Future` # that can be awaited. (Other :class:`~asyncio.Protocol` callbacks are # translated by ``self.stream_reader``). self.connection_lost_waiter: asyncio.Future[None] = loop.create_future() # Queue of received messages. self.messages: Deque[Data] = collections.deque() self._pop_message_waiter: asyncio.Future[None] | None = None self._put_message_waiter: asyncio.Future[None] | None = None # Protect sending fragmented messages. self._fragmented_message_waiter: asyncio.Future[None] | None = None # Mapping of ping IDs to pong waiters, in chronological order. self.pings: dict[bytes, tuple[asyncio.Future[float], float]] = {} self.latency: float = 0 """ Latency of the connection, in seconds. Latency is defined as the round-trip time of the connection. It is measured by sending a Ping frame and waiting for a matching Pong frame. Before the first measurement, :attr:`latency` is ``0``. By default, websockets enables a :ref:`keepalive ` mechanism that sends Ping frames automatically at regular intervals. You can also send Ping frames and measure latency with :meth:`ping`. """ # Task running the data transfer. self.transfer_data_task: asyncio.Task[None] # Exception that occurred during data transfer, if any. self.transfer_data_exc: BaseException | None = None # Task sending keepalive pings. self.keepalive_ping_task: asyncio.Task[None] # Task closing the TCP connection. 
self.close_connection_task: asyncio.Task[None] # Copied from asyncio.FlowControlMixin async def _drain_helper(self) -> None: # pragma: no cover if self.connection_lost_waiter.done(): raise ConnectionResetError("Connection lost") if not self._paused: return waiter = self._drain_waiter assert waiter is None or waiter.cancelled() waiter = self.loop.create_future() self._drain_waiter = waiter await waiter # Copied from asyncio.StreamWriter async def _drain(self) -> None: # pragma: no cover if self.reader is not None: exc = self.reader.exception() if exc is not None: raise exc if self.transport is not None: if self.transport.is_closing(): # Yield to the event loop so connection_lost() may be # called. Without this, _drain_helper() would return # immediately, and code that calls # write(...); yield from drain() # in a loop would never call connection_lost(), so it # would not see an error when the socket is closed. await asyncio.sleep(0) await self._drain_helper() def connection_open(self) -> None: """ Callback when the WebSocket opening handshake completes. Enter the OPEN state and start the data transfer phase. """ # 4.1. The WebSocket Connection is Established. assert self.state is State.CONNECTING self.state = State.OPEN if self.debug: self.logger.debug("= connection is OPEN") # Start the task that receives incoming WebSocket messages. self.transfer_data_task = self.loop.create_task(self.transfer_data()) # Start the task that sends pings at regular intervals. self.keepalive_ping_task = self.loop.create_task(self.keepalive_ping()) # Start the task that eventually closes the TCP connection. 
self.close_connection_task = self.loop.create_task(self.close_connection()) @property def host(self) -> str | None: alternative = "remote_address" if self.is_client else "local_address" warnings.warn(f"use {alternative}[0] instead of host", DeprecationWarning) return self._host @property def port(self) -> int | None: alternative = "remote_address" if self.is_client else "local_address" warnings.warn(f"use {alternative}[1] instead of port", DeprecationWarning) return self._port @property def secure(self) -> bool | None: warnings.warn("don't use secure", DeprecationWarning) return self._secure # Public API @property def local_address(self) -> Any: """ Local address of the connection. For IPv4 connections, this is a ``(host, port)`` tuple. The format of the address depends on the address family; see :meth:`~socket.socket.getsockname`. :obj:`None` if the TCP connection isn't established yet. """ try: transport = self.transport except AttributeError: return None else: return transport.get_extra_info("sockname") @property def remote_address(self) -> Any: """ Remote address of the connection. For IPv4 connections, this is a ``(host, port)`` tuple. The format of the address depends on the address family; see :meth:`~socket.socket.getpeername`. :obj:`None` if the TCP connection isn't established yet. """ try: transport = self.transport except AttributeError: return None else: return transport.get_extra_info("peername") @property def open(self) -> bool: """ :obj:`True` when the connection is open; :obj:`False` otherwise. This attribute may be used to detect disconnections. However, this approach is discouraged per the EAFP_ principle. Instead, you should handle :exc:`~websockets.exceptions.ConnectionClosed` exceptions. .. _EAFP: https://docs.python.org/3/glossary.html#term-eafp """ return self.state is State.OPEN and not self.transfer_data_task.done() @property def closed(self) -> bool: """ :obj:`True` when the connection is closed; :obj:`False` otherwise. 
Be aware that both :attr:`open` and :attr:`closed` are :obj:`False` during the opening and closing sequences. """ return self.state is State.CLOSED @property def close_code(self) -> int | None: """ WebSocket close code, defined in `section 7.1.5 of RFC 6455`_. .. _section 7.1.5 of RFC 6455: https://datatracker.ietf.org/doc/html/rfc6455#section-7.1.5 :obj:`None` if the connection isn't closed yet. """ if self.state is not State.CLOSED: return None elif self.close_rcvd is None: return CloseCode.ABNORMAL_CLOSURE else: return self.close_rcvd.code @property def close_reason(self) -> str | None: """ WebSocket close reason, defined in `section 7.1.6 of RFC 6455`_. .. _section 7.1.6 of RFC 6455: https://datatracker.ietf.org/doc/html/rfc6455#section-7.1.6 :obj:`None` if the connection isn't closed yet. """ if self.state is not State.CLOSED: return None elif self.close_rcvd is None: return "" else: return self.close_rcvd.reason async def __aiter__(self) -> AsyncIterator[Data]: """ Iterate on incoming messages. The iterator exits normally when the connection is closed with the close code 1000 (OK) or 1001 (going away) or without a close code. It raises a :exc:`~websockets.exceptions.ConnectionClosedError` exception when the connection is closed with any other code. """ try: while True: yield await self.recv() except ConnectionClosedOK: return async def recv(self) -> Data: """ Receive the next message. When the connection is closed, :meth:`recv` raises :exc:`~websockets.exceptions.ConnectionClosed`. Specifically, it raises :exc:`~websockets.exceptions.ConnectionClosedOK` after a normal connection closure and :exc:`~websockets.exceptions.ConnectionClosedError` after a protocol error or a network failure. This is how you detect the end of the message stream. Canceling :meth:`recv` is safe. There's no risk of losing the next message. The next invocation of :meth:`recv` will return it. 
This makes it possible to enforce a timeout by wrapping :meth:`recv` in :func:`~asyncio.timeout` or :func:`~asyncio.wait_for`. Returns: A string (:class:`str`) for a Text_ frame. A bytestring (:class:`bytes`) for a Binary_ frame. .. _Text: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 .. _Binary: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 Raises: ConnectionClosed: When the connection is closed. RuntimeError: If two coroutines call :meth:`recv` concurrently. """ if self._pop_message_waiter is not None: raise RuntimeError( "cannot call recv while another coroutine " "is already waiting for the next message" ) # Don't await self.ensure_open() here: # - messages could be available in the queue even if the connection # is closed; # - messages could be received before the closing frame even if the # connection is closing. # Wait until there's a message in the queue (if necessary) or the # connection is closed. while len(self.messages) <= 0: pop_message_waiter: asyncio.Future[None] = self.loop.create_future() self._pop_message_waiter = pop_message_waiter try: # If asyncio.wait() is canceled, it doesn't cancel # pop_message_waiter and self.transfer_data_task. await asyncio.wait( [pop_message_waiter, self.transfer_data_task], return_when=asyncio.FIRST_COMPLETED, ) finally: self._pop_message_waiter = None # If asyncio.wait(...) exited because self.transfer_data_task # completed before receiving a new message, raise a suitable # exception (or return None if legacy_recv is enabled). if not pop_message_waiter.done(): if self.legacy_recv: return None # type: ignore else: # Wait until the connection is closed to raise # ConnectionClosed with the correct code and reason. await self.ensure_open() # Pop a message from the queue. message = self.messages.popleft() # Notify transfer_data(). 
if self._put_message_waiter is not None: self._put_message_waiter.set_result(None) self._put_message_waiter = None return message async def send( self, message: Data | Iterable[Data] | AsyncIterable[Data], ) -> None: """ Send a message. A string (:class:`str`) is sent as a Text_ frame. A bytestring or bytes-like object (:class:`bytes`, :class:`bytearray`, or :class:`memoryview`) is sent as a Binary_ frame. .. _Text: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 .. _Binary: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 :meth:`send` also accepts an iterable or an asynchronous iterable of strings, bytestrings, or bytes-like objects to enable fragmentation_. Each item is treated as a message fragment and sent in its own frame. All items must be of the same type, or else :meth:`send` will raise a :exc:`TypeError` and the connection will be closed. .. _fragmentation: https://datatracker.ietf.org/doc/html/rfc6455#section-5.4 :meth:`send` rejects dict-like objects because this is often an error. (If you want to send the keys of a dict-like object as fragments, call its :meth:`~dict.keys` method and pass the result to :meth:`send`.) Canceling :meth:`send` is discouraged. Instead, you should close the connection with :meth:`close`. Indeed, there are only two situations where :meth:`send` may yield control to the event loop and then get canceled; in both cases, :meth:`close` has the same effect and is more clear: 1. The write buffer is full. If you don't want to wait until enough data is sent, your only alternative is to close the connection. :meth:`close` will likely time out then abort the TCP connection. 2. ``message`` is an asynchronous iterator that yields control. Stopping in the middle of a fragmented message will cause a protocol error and the connection will be closed. When the connection is closed, :meth:`send` raises :exc:`~websockets.exceptions.ConnectionClosed`. 
Specifically, it raises :exc:`~websockets.exceptions.ConnectionClosedOK` after a normal connection closure and :exc:`~websockets.exceptions.ConnectionClosedError` after a protocol error or a network failure. Args: message: Message to send. Raises: ConnectionClosed: When the connection is closed. TypeError: If ``message`` doesn't have a supported type. """ await self.ensure_open() # While sending a fragmented message, prevent sending other messages # until all fragments are sent. while self._fragmented_message_waiter is not None: await asyncio.shield(self._fragmented_message_waiter) # Unfragmented message -- this case must be handled first because # strings and bytes-like objects are iterable. if isinstance(message, (str, bytes, bytearray, memoryview)): opcode, data = prepare_data(message) await self.write_frame(True, opcode, data) # Catch a common mistake -- passing a dict to send(). elif isinstance(message, Mapping): raise TypeError("data is a dict-like object") # Fragmented message -- regular iterator. elif isinstance(message, Iterable): # Work around https://github.com/python/mypy/issues/6227 message = cast(Iterable[Data], message) iter_message = iter(message) try: fragment = next(iter_message) except StopIteration: return opcode, data = prepare_data(fragment) self._fragmented_message_waiter = self.loop.create_future() try: # First fragment. await self.write_frame(False, opcode, data) # Other fragments. for fragment in iter_message: confirm_opcode, data = prepare_data(fragment) if confirm_opcode != opcode: raise TypeError("data contains inconsistent types") await self.write_frame(False, OP_CONT, data) # Final fragment. await self.write_frame(True, OP_CONT, b"") except (Exception, asyncio.CancelledError): # We're half-way through a fragmented message and we can't # complete it. This makes the connection unusable. 
self.fail_connection(CloseCode.INTERNAL_ERROR) raise finally: self._fragmented_message_waiter.set_result(None) self._fragmented_message_waiter = None # Fragmented message -- asynchronous iterator elif isinstance(message, AsyncIterable): # Implement aiter_message = aiter(message) without aiter # Work around https://github.com/python/mypy/issues/5738 aiter_message = cast( Callable[[AsyncIterable[Data]], AsyncIterator[Data]], type(message).__aiter__, )(message) try: # Implement fragment = anext(aiter_message) without anext # Work around https://github.com/python/mypy/issues/5738 fragment = await cast( Callable[[AsyncIterator[Data]], Awaitable[Data]], type(aiter_message).__anext__, )(aiter_message) except StopAsyncIteration: return opcode, data = prepare_data(fragment) self._fragmented_message_waiter = self.loop.create_future() try: # First fragment. await self.write_frame(False, opcode, data) # Other fragments. async for fragment in aiter_message: confirm_opcode, data = prepare_data(fragment) if confirm_opcode != opcode: raise TypeError("data contains inconsistent types") await self.write_frame(False, OP_CONT, data) # Final fragment. await self.write_frame(True, OP_CONT, b"") except (Exception, asyncio.CancelledError): # We're half-way through a fragmented message and we can't # complete it. This makes the connection unusable. self.fail_connection(CloseCode.INTERNAL_ERROR) raise finally: self._fragmented_message_waiter.set_result(None) self._fragmented_message_waiter = None else: raise TypeError("data must be str, bytes-like, or iterable") async def close( self, code: int = CloseCode.NORMAL_CLOSURE, reason: str = "", ) -> None: """ Perform the closing handshake. :meth:`close` waits for the other end to complete the handshake and for the TCP connection to terminate. As a consequence, there's no need to await :meth:`wait_closed` after :meth:`close`. :meth:`close` is idempotent: it doesn't do anything once the connection is closed. 
Wrapping :func:`close` in :func:`~asyncio.create_task` is safe, given that errors during connection termination aren't particularly useful. Canceling :meth:`close` is discouraged. If it takes too long, you can set a shorter ``close_timeout``. If you don't want to wait, let the Python process exit, then the OS will take care of closing the TCP connection. Args: code: WebSocket close code. reason: WebSocket close reason. """ try: async with asyncio_timeout(self.close_timeout): await self.write_close_frame(Close(code, reason)) except asyncio.TimeoutError: # If the close frame cannot be sent because the send buffers # are full, the closing handshake won't complete anyway. # Fail the connection to shut down faster. self.fail_connection() # If no close frame is received within the timeout, asyncio_timeout() # cancels the data transfer task and raises TimeoutError. # If close() is called multiple times concurrently and one of these # calls hits the timeout, the data transfer task will be canceled. # Other calls will receive a CancelledError here. try: # If close() is canceled during the wait, self.transfer_data_task # is canceled before the timeout elapses. async with asyncio_timeout(self.close_timeout): await self.transfer_data_task except (asyncio.TimeoutError, asyncio.CancelledError): pass # Wait for the close connection task to close the TCP connection. await asyncio.shield(self.close_connection_task) async def wait_closed(self) -> None: """ Wait until the connection is closed. This coroutine is identical to the :attr:`closed` attribute, except it can be awaited. This can make it easier to detect connection termination, regardless of its cause, in tasks that interact with the WebSocket connection. """ await asyncio.shield(self.connection_lost_waiter) async def ping(self, data: Data | None = None) -> Awaitable[float]: """ Send a Ping_. .. 
_Ping: https://datatracker.ietf.org/doc/html/rfc6455#section-5.5.2 A ping may serve as a keepalive, as a check that the remote endpoint received all messages up to this point, or to measure :attr:`latency`. Canceling :meth:`ping` is discouraged. If :meth:`ping` doesn't return immediately, it means the write buffer is full. If you don't want to wait, you should close the connection. Canceling the :class:`~asyncio.Future` returned by :meth:`ping` has no effect. Args: data: Payload of the ping. A string will be encoded to UTF-8. If ``data`` is :obj:`None`, the payload is four random bytes. Returns: A future that will be completed when the corresponding pong is received. You can ignore it if you don't intend to wait. The result of the future is the latency of the connection in seconds. :: pong_waiter = await ws.ping() # only if you want to wait for the corresponding pong latency = await pong_waiter Raises: ConnectionClosed: When the connection is closed. RuntimeError: If another ping was sent with the same data and the corresponding pong wasn't received yet. """ await self.ensure_open() if data is not None: data = prepare_ctrl(data) # Protect against duplicates if a payload is explicitly set. if data in self.pings: raise RuntimeError("already waiting for a pong with the same data") # Generate a unique random payload otherwise. while data is None or data in self.pings: data = struct.pack("!I", random.getrandbits(32)) pong_waiter = self.loop.create_future() # Resolution of time.monotonic() may be too low on Windows. ping_timestamp = time.perf_counter() self.pings[data] = (pong_waiter, ping_timestamp) await self.write_frame(True, OP_PING, data) return asyncio.shield(pong_waiter) async def pong(self, data: Data = b"") -> None: """ Send a Pong_. .. _Pong: https://datatracker.ietf.org/doc/html/rfc6455#section-5.5.3 An unsolicited pong may serve as a unidirectional heartbeat. Canceling :meth:`pong` is discouraged. 
If :meth:`pong` doesn't return immediately, it means the write buffer is full. If you don't want to wait, you should close the connection. Args: data: Payload of the pong. A string will be encoded to UTF-8. Raises: ConnectionClosed: When the connection is closed. """ await self.ensure_open() data = prepare_ctrl(data) await self.write_frame(True, OP_PONG, data) # Private methods - no guarantees. def connection_closed_exc(self) -> ConnectionClosed: exc: ConnectionClosed if ( self.close_rcvd is not None and self.close_rcvd.code in OK_CLOSE_CODES and self.close_sent is not None and self.close_sent.code in OK_CLOSE_CODES ): exc = ConnectionClosedOK( self.close_rcvd, self.close_sent, self.close_rcvd_then_sent, ) else: exc = ConnectionClosedError( self.close_rcvd, self.close_sent, self.close_rcvd_then_sent, ) # Chain to the exception that terminated data transfer, if any. exc.__cause__ = self.transfer_data_exc return exc async def ensure_open(self) -> None: """ Check that the WebSocket connection is open. Raise :exc:`~websockets.exceptions.ConnectionClosed` if it isn't. """ # Handle cases from most common to least common for performance. if self.state is State.OPEN: # If self.transfer_data_task exited without a closing handshake, # self.close_connection_task may be closing the connection, going # straight from OPEN to CLOSED. if self.transfer_data_task.done(): await asyncio.shield(self.close_connection_task) raise self.connection_closed_exc() else: return if self.state is State.CLOSED: raise self.connection_closed_exc() if self.state is State.CLOSING: # If we started the closing handshake, wait for its completion to # get the proper close code and reason. self.close_connection_task # will complete within 4 or 5 * close_timeout after close(). The # CLOSING state also occurs when failing the connection. In that # case self.close_connection_task will complete even faster. 
await asyncio.shield(self.close_connection_task) raise self.connection_closed_exc() # Control may only reach this point in buggy third-party subclasses. assert self.state is State.CONNECTING raise InvalidState("WebSocket connection isn't established yet") async def transfer_data(self) -> None: """ Read incoming messages and put them in a queue. This coroutine runs in a task until the closing handshake is started. """ try: while True: message = await self.read_message() # Exit the loop when receiving a close frame. if message is None: break # Wait until there's room in the queue (if necessary). if self.max_queue is not None: while len(self.messages) >= self.max_queue: self._put_message_waiter = self.loop.create_future() try: await asyncio.shield(self._put_message_waiter) finally: self._put_message_waiter = None # Put the message in the queue. self.messages.append(message) # Notify recv(). if self._pop_message_waiter is not None: self._pop_message_waiter.set_result(None) self._pop_message_waiter = None except asyncio.CancelledError as exc: self.transfer_data_exc = exc # If fail_connection() cancels this task, avoid logging the error # twice and failing the connection again. raise except ProtocolError as exc: self.transfer_data_exc = exc self.fail_connection(CloseCode.PROTOCOL_ERROR) except (ConnectionError, TimeoutError, EOFError, ssl.SSLError) as exc: # Reading data with self.reader.readexactly may raise: # - most subclasses of ConnectionError if the TCP connection # breaks, is reset, or is aborted; # - TimeoutError if the TCP connection times out; # - IncompleteReadError, a subclass of EOFError, if fewer # bytes are available than requested; # - ssl.SSLError if the other side infringes the TLS protocol. 
self.transfer_data_exc = exc self.fail_connection(CloseCode.ABNORMAL_CLOSURE) except UnicodeDecodeError as exc: self.transfer_data_exc = exc self.fail_connection(CloseCode.INVALID_DATA) except PayloadTooBig as exc: self.transfer_data_exc = exc self.fail_connection(CloseCode.MESSAGE_TOO_BIG) except Exception as exc: # This shouldn't happen often because exceptions expected under # regular circumstances are handled above. If it does, consider # catching and handling more exceptions. self.logger.error("data transfer failed", exc_info=True) self.transfer_data_exc = exc self.fail_connection(CloseCode.INTERNAL_ERROR) async def read_message(self) -> Data | None: """ Read a single message from the connection. Re-assemble data frames if the message is fragmented. Return :obj:`None` when the closing handshake is started. """ frame = await self.read_data_frame(max_size=self.max_size) # A close frame was received. if frame is None: return None if frame.opcode == OP_TEXT: text = True elif frame.opcode == OP_BINARY: text = False else: # frame.opcode == OP_CONT raise ProtocolError("unexpected opcode") # Shortcut for the common case - no fragmentation if frame.fin: return frame.data.decode() if text else frame.data # 5.4. 
Fragmentation fragments: list[Data] = [] max_size = self.max_size if text: decoder_factory = codecs.getincrementaldecoder("utf-8") decoder = decoder_factory(errors="strict") if max_size is None: def append(frame: Frame) -> None: nonlocal fragments fragments.append(decoder.decode(frame.data, frame.fin)) else: def append(frame: Frame) -> None: nonlocal fragments, max_size fragments.append(decoder.decode(frame.data, frame.fin)) assert isinstance(max_size, int) max_size -= len(frame.data) else: if max_size is None: def append(frame: Frame) -> None: nonlocal fragments fragments.append(frame.data) else: def append(frame: Frame) -> None: nonlocal fragments, max_size fragments.append(frame.data) assert isinstance(max_size, int) max_size -= len(frame.data) append(frame) while not frame.fin: frame = await self.read_data_frame(max_size=max_size) if frame is None: raise ProtocolError("incomplete fragmented message") if frame.opcode != OP_CONT: raise ProtocolError("unexpected opcode") append(frame) return ("" if text else b"").join(fragments) async def read_data_frame(self, max_size: int | None) -> Frame | None: """ Read a single data frame from the connection. Process control frames received before the next data frame. Return :obj:`None` if a close frame is encountered before any data frame. """ # 6.2. Receiving Data while True: frame = await self.read_frame(max_size) # 5.5. Control Frames if frame.opcode == OP_CLOSE: # 7.1.5. The WebSocket Connection Close Code # 7.1.6. The WebSocket Connection Close Reason self.close_rcvd = Close.parse(frame.data) if self.close_sent is not None: self.close_rcvd_then_sent = False try: # Echo the original data instead of re-serializing it with # Close.serialize() because that fails when the close frame # is empty and Close.parse() synthesizes a 1005 close code. await self.write_close_frame(self.close_rcvd, frame.data) except ConnectionClosed: # Connection closed before we could echo the close frame. 
pass return None elif frame.opcode == OP_PING: # Answer pings, unless connection is CLOSING. if self.state is State.OPEN: try: await self.pong(frame.data) except ConnectionClosed: # Connection closed while draining write buffer. pass elif frame.opcode == OP_PONG: if frame.data in self.pings: pong_timestamp = time.perf_counter() # Sending a pong for only the most recent ping is legal. # Acknowledge all previous pings too in that case. ping_id = None ping_ids = [] for ping_id, (pong_waiter, ping_timestamp) in self.pings.items(): ping_ids.append(ping_id) if not pong_waiter.done(): pong_waiter.set_result(pong_timestamp - ping_timestamp) if ping_id == frame.data: self.latency = pong_timestamp - ping_timestamp break else: raise AssertionError("solicited pong not found in pings") # Remove acknowledged pings from self.pings. for ping_id in ping_ids: del self.pings[ping_id] # 5.6. Data Frames else: return frame async def read_frame(self, max_size: int | None) -> Frame: """ Read a single frame from the connection. """ frame = await Frame.read( self.reader.readexactly, mask=not self.is_client, max_size=max_size, extensions=self.extensions, ) if self.debug: self.logger.debug("< %s", frame) return frame def write_frame_sync(self, fin: bool, opcode: int, data: bytes) -> None: frame = Frame(fin, Opcode(opcode), data) if self.debug: self.logger.debug("> %s", frame) frame.write( self.transport.write, mask=self.is_client, extensions=self.extensions, ) async def drain(self) -> None: try: # drain() cannot be called concurrently by multiple coroutines. # See https://github.com/python/cpython/issues/74116 for details. # This workaround can be removed when dropping Python < 3.10. async with self._drain_lock: # Handle flow control automatically. await self._drain() except ConnectionError: # Terminate the connection if the socket died. self.fail_connection() # Wait until the connection is closed to raise ConnectionClosed # with the correct code and reason. 
        await self.ensure_open()

    async def write_frame(
        self, fin: bool, opcode: int, data: bytes, *, _state: int = State.OPEN
    ) -> None:
        """
        Write a single frame, then wait for the write buffer to drain.

        ``_state`` is the connection state in which writing is allowed. It is
        overridden only by :meth:`write_close_frame`, which must write while
        the connection is already CLOSING.

        """
        # Defensive assertion for protocol compliance.
        if self.state is not _state:  # pragma: no cover
            raise InvalidState(
                f"Cannot write to a WebSocket in the {self.state.name} state"
            )
        self.write_frame_sync(fin, opcode, data)
        await self.drain()

    async def write_close_frame(self, close: Close, data: bytes | None = None) -> None:
        """
        Write a close frame if and only if the connection state is OPEN.

        This dedicated coroutine must be used for writing close frames to
        ensure that at most one close frame is sent on a given connection.

        """
        # Test and set the connection state before sending the close frame to
        # avoid sending two frames in case of concurrent calls.
        if self.state is State.OPEN:
            # 7.1.3. The WebSocket Closing Handshake is Started
            self.state = State.CLOSING
            if self.debug:
                self.logger.debug("= connection is CLOSING")

            self.close_sent = close
            if self.close_rcvd is not None:
                self.close_rcvd_then_sent = True
            # When echoing a received close frame, the caller passes the
            # original payload in ``data``; otherwise serialize ``close``.
            if data is None:
                data = close.serialize()

            # 7.1.2. Start the WebSocket Closing Handshake
            await self.write_frame(True, OP_CLOSE, data, _state=State.CLOSING)

    async def keepalive_ping(self) -> None:
        """
        Send a Ping frame and wait for a Pong frame at regular intervals.

        This coroutine exits when the connection terminates and one of the
        following happens:

        - :meth:`ping` raises :exc:`ConnectionClosed`, or
        - :meth:`close_connection` cancels :attr:`keepalive_ping_task`.

        """
        if self.ping_interval is None:
            return

        try:
            while True:
                await asyncio.sleep(self.ping_interval)

                self.logger.debug("% sending keepalive ping")
                pong_waiter = await self.ping()

                if self.ping_timeout is not None:
                    try:
                        async with asyncio_timeout(self.ping_timeout):
                            # Raises CancelledError if the connection is closed,
                            # when close_connection() cancels keepalive_ping().
                            # Raises ConnectionClosed if the connection is lost,
                            # when connection_lost() calls abort_pings().
await pong_waiter self.logger.debug("% received keepalive pong") except asyncio.TimeoutError: if self.debug: self.logger.debug("- timed out waiting for keepalive pong") self.fail_connection( CloseCode.INTERNAL_ERROR, "keepalive ping timeout", ) break except ConnectionClosed: pass except Exception: self.logger.error("keepalive ping failed", exc_info=True) async def close_connection(self) -> None: """ 7.1.1. Close the WebSocket Connection When the opening handshake succeeds, :meth:`connection_open` starts this coroutine in a task. It waits for the data transfer phase to complete then it closes the TCP connection cleanly. When the opening handshake fails, :meth:`fail_connection` does the same. There's no data transfer phase in that case. """ try: # Wait for the data transfer phase to complete. if hasattr(self, "transfer_data_task"): try: await self.transfer_data_task except asyncio.CancelledError: pass # Cancel the keepalive ping task. if hasattr(self, "keepalive_ping_task"): self.keepalive_ping_task.cancel() # A client should wait for a TCP close from the server. if self.is_client and hasattr(self, "transfer_data_task"): if await self.wait_for_connection_lost(): return if self.debug: self.logger.debug("- timed out waiting for TCP close") # Half-close the TCP connection if possible (when there's no TLS). if self.transport.can_write_eof(): if self.debug: self.logger.debug("x half-closing TCP connection") # write_eof() doesn't document which exceptions it raises. # "[Errno 107] Transport endpoint is not connected" happens # but it isn't completely clear under which circumstances. # uvloop can raise RuntimeError here. try: self.transport.write_eof() except (OSError, RuntimeError): # pragma: no cover pass if await self.wait_for_connection_lost(): return if self.debug: self.logger.debug("- timed out waiting for TCP close") finally: # The try/finally ensures that the transport never remains open, # even if this coroutine is canceled (for example). 
        await self.close_transport()

    async def close_transport(self) -> None:
        """
        Close the TCP connection.

        Try a clean close first; fall back to aborting the connection if the
        peer doesn't close its end within ``self.close_timeout``.

        """
        # If connection_lost() was called, the TCP connection is closed.
        # However, if TLS is enabled, the transport still needs closing.
        # Else asyncio complains: ResourceWarning: unclosed transport.
        if self.connection_lost_waiter.done() and self.transport.is_closing():
            return

        # Close the TCP connection. Buffers are flushed asynchronously.
        if self.debug:
            self.logger.debug("x closing TCP connection")
        self.transport.close()

        if await self.wait_for_connection_lost():
            return
        if self.debug:
            self.logger.debug("- timed out waiting for TCP close")

        # Abort the TCP connection. Buffers are discarded.
        if self.debug:
            self.logger.debug("x aborting TCP connection")
        self.transport.abort()

        # connection_lost() is called quickly after aborting.
        await self.wait_for_connection_lost()

    async def wait_for_connection_lost(self) -> bool:
        """
        Wait until the TCP connection is closed or ``self.close_timeout``
        elapses.

        Return :obj:`True` if the connection is closed and :obj:`False`
        otherwise.

        """
        if not self.connection_lost_waiter.done():
            try:
                async with asyncio_timeout(self.close_timeout):
                    # Shield the waiter so a timeout cancels only the wait,
                    # never the waiter itself.
                    await asyncio.shield(self.connection_lost_waiter)
            except asyncio.TimeoutError:
                pass
        # Re-check self.connection_lost_waiter.done() synchronously because
        # connection_lost() could run between the moment the timeout occurs
        # and the moment this coroutine resumes running.
        return self.connection_lost_waiter.done()

    def fail_connection(
        self,
        code: int = CloseCode.ABNORMAL_CLOSURE,
        reason: str = "",
    ) -> None:
        """
        7.1.7. Fail the WebSocket Connection

        This requires:

        1. Stopping all processing of incoming data, which means cancelling
           :attr:`transfer_data_task`. The close code will be 1006 unless a
           close frame was received earlier.

        2. Sending a close frame with an appropriate code if the opening
           handshake succeeded and the other side is likely to process it.

        3. Closing the connection.
           :meth:`close_connection` takes care of this once
           :attr:`transfer_data_task` exits after being canceled.

        (The specification describes these steps in the opposite order.)

        """
        if self.debug:
            self.logger.debug("! failing connection with code %d", code)

        # Cancel transfer_data_task if the opening handshake succeeded.
        # cancel() is idempotent and ignored if the task is done already.
        if hasattr(self, "transfer_data_task"):
            self.transfer_data_task.cancel()

        # Send a close frame when the state is OPEN (a close frame was already
        # sent if it's CLOSING), except when failing the connection because of
        # an error reading from or writing to the network.
        # Don't send a close frame if the connection is broken.
        if code != CloseCode.ABNORMAL_CLOSURE and self.state is State.OPEN:
            close = Close(code, reason)

            # Write the close frame without draining the write buffer.
            # Keeping fail_connection() synchronous guarantees it can't
            # get stuck and simplifies the implementation of the callers.
            # Not draining the write buffer is acceptable in this context.

            # This duplicates a few lines of code from write_close_frame().

            self.state = State.CLOSING
            if self.debug:
                self.logger.debug("= connection is CLOSING")

            # If self.close_rcvd was set, the connection state would be
            # CLOSING. Therefore self.close_rcvd isn't set and we don't
            # have to set self.close_rcvd_then_sent.
            assert self.close_rcvd is None
            self.close_sent = close

            self.write_frame_sync(True, OP_CLOSE, close.serialize())

        # Start close_connection_task if the opening handshake didn't succeed.
        if not hasattr(self, "close_connection_task"):
            self.close_connection_task = self.loop.create_task(self.close_connection())

    def abort_pings(self) -> None:
        """
        Raise ConnectionClosed in pending keepalive pings.

        They'll never receive a pong once the connection is closed.
        """
        assert self.state is State.CLOSED
        exc = self.connection_closed_exc()

        for pong_waiter, _ping_timestamp in self.pings.values():
            pong_waiter.set_exception(exc)
            # If the exception is never retrieved, it will be logged when ping
            # is garbage-collected. This is confusing for users.
            # Given that ping is done (with an exception), canceling it does
            # nothing, but it prevents logging the exception.
            pong_waiter.cancel()

    # asyncio.Protocol methods

    def connection_made(self, transport: asyncio.BaseTransport) -> None:
        """
        Configure write buffer limits.

        The high-water limit is defined by ``self.write_limit``.

        The low-water limit currently defaults to ``self.write_limit // 4`` in
        :meth:`~asyncio.WriteTransport.set_write_buffer_limits`, which should
        be all right for reasonable use cases of this library.

        This is the earliest point where we can get hold of the transport,
        which means it's the best point for configuring it.

        """
        transport = cast(asyncio.Transport, transport)
        transport.set_write_buffer_limits(self.write_limit)
        self.transport = transport

        # Copied from asyncio.StreamReaderProtocol
        self.reader.set_transport(transport)

    def connection_lost(self, exc: Exception | None) -> None:
        """
        7.1.4. The WebSocket Connection is Closed.

        """
        self.state = State.CLOSED
        self.logger.debug("= connection is CLOSED")

        self.abort_pings()

        # If self.connection_lost_waiter isn't pending, that's a bug, because:
        # - it's set only here in connection_lost() which is called only once;
        # - it must never be canceled.
        self.connection_lost_waiter.set_result(None)

        if True:  # pragma: no cover
            # Copied from asyncio.StreamReaderProtocol
            if self.reader is not None:
                if exc is None:
                    self.reader.feed_eof()
                else:
                    self.reader.set_exception(exc)

            # Copied from asyncio.FlowControlMixin
            # Wake up the writer if currently paused.
            if not self._paused:
                return
            waiter = self._drain_waiter
            if waiter is None:
                return
            self._drain_waiter = None
            if waiter.done():
                return
            if exc is None:
                waiter.set_result(None)
            else:
                waiter.set_exception(exc)

    def pause_writing(self) -> None:  # pragma: no cover
        # Called by the transport when the write buffer grows beyond the
        # high-water mark; drain() will block until resume_writing().
        assert not self._paused
        self._paused = True

    def resume_writing(self) -> None:  # pragma: no cover
        # Called by the transport when the write buffer drains below the
        # low-water mark; wake up the coroutine blocked in drain(), if any.
        assert self._paused
        self._paused = False

        waiter = self._drain_waiter
        if waiter is not None:
            self._drain_waiter = None
            if not waiter.done():
                waiter.set_result(None)

    def data_received(self, data: bytes) -> None:
        # Feed incoming bytes to the stream reader that read_frame() reads.
        self.reader.feed_data(data)

    def eof_received(self) -> None:
        """
        Close the transport after receiving EOF.

        The WebSocket protocol has its own closing handshake: endpoints close
        the TCP or TLS connection after sending and receiving a close frame.

        As a consequence, they never need to write after receiving EOF, so
        there's no reason to keep the transport open by returning :obj:`True`.

        Besides, that doesn't work on TLS connections.

        """
        self.reader.feed_eof()


# broadcast() is defined in the protocol module even though it's primarily
# used by servers and documented in the server module because it works with
# client connections too and because it's easier to test together with the
# WebSocketCommonProtocol class.


def broadcast(
    websockets: Iterable[WebSocketCommonProtocol],
    message: Data,
    raise_exceptions: bool = False,
) -> None:
    """
    Broadcast a message to several WebSocket connections.

    A string (:class:`str`) is sent as a Text_ frame. A bytestring or
    bytes-like object (:class:`bytes`, :class:`bytearray`, or
    :class:`memoryview`) is sent as a Binary_ frame.

    .. _Text: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6
    .. _Binary: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6

    :func:`broadcast` pushes the message synchronously to all connections even
    if their write buffers are overflowing. There's no backpressure.
If you broadcast messages faster than a connection can handle them, messages will pile up in its write buffer until the connection times out. Keep ``ping_interval`` and ``ping_timeout`` low to prevent excessive memory usage from slow connections. Unlike :meth:`~websockets.legacy.protocol.WebSocketCommonProtocol.send`, :func:`broadcast` doesn't support sending fragmented messages. Indeed, fragmentation is useful for sending large messages without buffering them in memory, while :func:`broadcast` buffers one copy per connection as fast as possible. :func:`broadcast` skips connections that aren't open in order to avoid errors on connections where the closing handshake is in progress. :func:`broadcast` ignores failures to write the message on some connections. It continues writing to other connections. On Python 3.11 and above, you may set ``raise_exceptions`` to :obj:`True` to record failures and raise all exceptions in a :pep:`654` :exc:`ExceptionGroup`. While :func:`broadcast` makes more sense for servers, it works identically with clients, if you have a use case for opening connections to many servers and broadcasting a message to them. Args: websockets: WebSocket connections to which the message will be sent. message: Message to send. raise_exceptions: Whether to raise an exception in case of failures. Raises: TypeError: If ``message`` doesn't have a supported type. 
    """
    if not isinstance(message, (str, bytes, bytearray, memoryview)):
        raise TypeError("data must be str or bytes-like")

    if raise_exceptions:
        # ExceptionGroup is a builtin only since Python 3.11.
        if sys.version_info[:2] < (3, 11):  # pragma: no cover
            raise ValueError("raise_exceptions requires at least Python 3.11")
        # Collected failures; raised together as an ExceptionGroup at the end.
        exceptions = []

    opcode, data = prepare_data(message)

    for websocket in websockets:
        # Skip connections that aren't open — see docstring above.
        if websocket.state is not State.OPEN:
            continue

        # Writing a whole frame in the middle of a fragmented message would
        # corrupt the stream, so skip (or record) such connections.
        if websocket._fragmented_message_waiter is not None:
            if raise_exceptions:
                exception = RuntimeError("sending a fragmented message")
                exceptions.append(exception)
            else:
                websocket.logger.warning(
                    "skipped broadcast: sending a fragmented message",
                )
            continue

        try:
            # Synchronous write: no draining, hence no backpressure.
            websocket.write_frame_sync(True, opcode, data)
        except Exception as write_exception:
            if raise_exceptions:
                exception = RuntimeError("failed to write message")
                exception.__cause__ = write_exception
                exceptions.append(exception)
            else:
                websocket.logger.warning(
                    "skipped broadcast: failed to write message: %s",
                    traceback.format_exception_only(
                        # Remove first argument when dropping Python 3.9.
                        type(write_exception),
                        write_exception,
                    )[0].strip(),
                )

    if raise_exceptions and exceptions:
        raise ExceptionGroup("skipped broadcast", exceptions)


# Pretend that broadcast is actually defined in the server module.
broadcast.__module__ = "websockets.legacy.server" websockets-15.0.1/src/websockets/legacy/server.py000066400000000000000000001303021476212450300220360ustar00rootroot00000000000000from __future__ import annotations import asyncio import email.utils import functools import http import inspect import logging import socket import warnings from collections.abc import Awaitable, Generator, Iterable, Sequence from types import TracebackType from typing import Any, Callable, Union, cast from ..asyncio.compatibility import asyncio_timeout from ..datastructures import Headers, HeadersLike, MultipleValuesError from ..exceptions import ( InvalidHandshake, InvalidHeader, InvalidMessage, InvalidOrigin, InvalidUpgrade, NegotiationError, ) from ..extensions import Extension, ServerExtensionFactory from ..extensions.permessage_deflate import enable_server_permessage_deflate from ..headers import ( build_extension, parse_extension, parse_subprotocol, validate_subprotocols, ) from ..http11 import SERVER from ..protocol import State from ..typing import ExtensionHeader, LoggerLike, Origin, StatusLike, Subprotocol from .exceptions import AbortHandshake from .handshake import build_response, check_request from .http import read_request from .protocol import WebSocketCommonProtocol, broadcast __all__ = [ "broadcast", "serve", "unix_serve", "WebSocketServerProtocol", "WebSocketServer", ] # Change to HeadersLike | ... when dropping Python < 3.10. HeadersLikeOrCallable = Union[HeadersLike, Callable[[str, Headers], HeadersLike]] HTTPResponse = tuple[StatusLike, HeadersLike, bytes] class WebSocketServerProtocol(WebSocketCommonProtocol): """ WebSocket server connection. :class:`WebSocketServerProtocol` provides :meth:`recv` and :meth:`send` coroutines for receiving and sending messages. 
It supports asynchronous iteration to receive messages:: async for message in websocket: await process(message) The iterator exits normally when the connection is closed with close code 1000 (OK) or 1001 (going away) or without a close code. It raises a :exc:`~websockets.exceptions.ConnectionClosedError` when the connection is closed with any other code. You may customize the opening handshake in a subclass by overriding :meth:`process_request` or :meth:`select_subprotocol`. Args: ws_server: WebSocket server that created this connection. See :func:`serve` for the documentation of ``ws_handler``, ``logger``, ``origins``, ``extensions``, ``subprotocols``, ``extra_headers``, and ``server_header``. See :class:`~websockets.legacy.protocol.WebSocketCommonProtocol` for the documentation of ``ping_interval``, ``ping_timeout``, ``close_timeout``, ``max_size``, ``max_queue``, ``read_limit``, and ``write_limit``. """ is_client = False side = "server" def __init__( self, # The version that accepts the path in the second argument is deprecated. ws_handler: ( Callable[[WebSocketServerProtocol], Awaitable[Any]] | Callable[[WebSocketServerProtocol, str], Awaitable[Any]] ), ws_server: WebSocketServer, *, logger: LoggerLike | None = None, origins: Sequence[Origin | None] | None = None, extensions: Sequence[ServerExtensionFactory] | None = None, subprotocols: Sequence[Subprotocol] | None = None, extra_headers: HeadersLikeOrCallable | None = None, server_header: str | None = SERVER, process_request: ( Callable[[str, Headers], Awaitable[HTTPResponse | None]] | None ) = None, select_subprotocol: ( Callable[[Sequence[Subprotocol], Sequence[Subprotocol]], Subprotocol] | None ) = None, open_timeout: float | None = 10, **kwargs: Any, ) -> None: if logger is None: logger = logging.getLogger("websockets.server") super().__init__(logger=logger, **kwargs) # For backwards compatibility with 6.0 or earlier. 
if origins is not None and "" in origins: warnings.warn("use None instead of '' in origins", DeprecationWarning) origins = [None if origin == "" else origin for origin in origins] # For backwards compatibility with 10.0 or earlier. Done here in # addition to serve to trigger the deprecation warning on direct # use of WebSocketServerProtocol. self.ws_handler = remove_path_argument(ws_handler) self.ws_server = ws_server self.origins = origins self.available_extensions = extensions self.available_subprotocols = subprotocols self.extra_headers = extra_headers self.server_header = server_header self._process_request = process_request self._select_subprotocol = select_subprotocol self.open_timeout = open_timeout def connection_made(self, transport: asyncio.BaseTransport) -> None: """ Register connection and initialize a task to handle it. """ super().connection_made(transport) # Register the connection with the server before creating the handler # task. Registering at the beginning of the handler coroutine would # create a race condition between the creation of the task, which # schedules its execution, and the moment the handler starts running. self.ws_server.register(self) self.handler_task = self.loop.create_task(self.handler()) async def handler(self) -> None: """ Handle the lifecycle of a WebSocket connection. Since this method doesn't have a caller able to handle exceptions, it attempts to log relevant ones and guarantees that the TCP connection is closed before exiting. 
""" try: try: async with asyncio_timeout(self.open_timeout): await self.handshake( origins=self.origins, available_extensions=self.available_extensions, available_subprotocols=self.available_subprotocols, extra_headers=self.extra_headers, ) except asyncio.TimeoutError: # pragma: no cover raise except ConnectionError: raise except Exception as exc: if isinstance(exc, AbortHandshake): status, headers, body = exc.status, exc.headers, exc.body elif isinstance(exc, InvalidOrigin): if self.debug: self.logger.debug("! invalid origin", exc_info=True) status, headers, body = ( http.HTTPStatus.FORBIDDEN, Headers(), f"Failed to open a WebSocket connection: {exc}.\n".encode(), ) elif isinstance(exc, InvalidUpgrade): if self.debug: self.logger.debug("! invalid upgrade", exc_info=True) status, headers, body = ( http.HTTPStatus.UPGRADE_REQUIRED, Headers([("Upgrade", "websocket")]), ( f"Failed to open a WebSocket connection: {exc}.\n" f"\n" f"You cannot access a WebSocket server directly " f"with a browser. You need a WebSocket client.\n" ).encode(), ) elif isinstance(exc, InvalidHandshake): if self.debug: self.logger.debug("! 
invalid handshake", exc_info=True) exc_chain = cast(BaseException, exc) exc_str = f"{exc_chain}" while exc_chain.__cause__ is not None: exc_chain = exc_chain.__cause__ exc_str += f"; {exc_chain}" status, headers, body = ( http.HTTPStatus.BAD_REQUEST, Headers(), f"Failed to open a WebSocket connection: {exc_str}.\n".encode(), ) else: self.logger.error("opening handshake failed", exc_info=True) status, headers, body = ( http.HTTPStatus.INTERNAL_SERVER_ERROR, Headers(), ( b"Failed to open a WebSocket connection.\n" b"See server log for more information.\n" ), ) headers.setdefault("Date", email.utils.formatdate(usegmt=True)) if self.server_header: headers.setdefault("Server", self.server_header) headers.setdefault("Content-Length", str(len(body))) headers.setdefault("Content-Type", "text/plain") headers.setdefault("Connection", "close") self.write_http_response(status, headers, body) self.logger.info( "connection rejected (%d %s)", status.value, status.phrase ) await self.close_transport() return try: await self.ws_handler(self) except Exception: self.logger.error("connection handler failed", exc_info=True) if not self.closed: self.fail_connection(1011) raise try: await self.close() except ConnectionError: raise except Exception: self.logger.error("closing handshake failed", exc_info=True) raise except Exception: # Last-ditch attempt to avoid leaking connections on errors. try: self.transport.close() except Exception: # pragma: no cover pass finally: # Unregister the connection with the server when the handler task # terminates. Registration is tied to the lifecycle of the handler # task because the server waits for tasks attached to registered # connections before terminating. self.ws_server.unregister(self) self.logger.info("connection closed") async def read_http_request(self) -> tuple[str, Headers]: """ Read request line and headers from the HTTP request. If the request contains a body, it may be read from ``self.reader`` after this coroutine returns. 
Raises: InvalidMessage: If the HTTP message is malformed or isn't an HTTP/1.1 GET request. """ try: path, headers = await read_request(self.reader) except asyncio.CancelledError: # pragma: no cover raise except Exception as exc: raise InvalidMessage("did not receive a valid HTTP request") from exc if self.debug: self.logger.debug("< GET %s HTTP/1.1", path) for key, value in headers.raw_items(): self.logger.debug("< %s: %s", key, value) self.path = path self.request_headers = headers return path, headers def write_http_response( self, status: http.HTTPStatus, headers: Headers, body: bytes | None = None ) -> None: """ Write status line and headers to the HTTP response. This coroutine is also able to write a response body. """ self.response_headers = headers if self.debug: self.logger.debug("> HTTP/1.1 %d %s", status.value, status.phrase) for key, value in headers.raw_items(): self.logger.debug("> %s: %s", key, value) if body is not None: self.logger.debug("> [body] (%d bytes)", len(body)) # Since the status line and headers only contain ASCII characters, # we can keep this simple. response = f"HTTP/1.1 {status.value} {status.phrase}\r\n" response += str(headers) self.transport.write(response.encode()) if body is not None: self.transport.write(body) async def process_request( self, path: str, request_headers: Headers ) -> HTTPResponse | None: """ Intercept the HTTP request and return an HTTP response if appropriate. You may override this method in a :class:`WebSocketServerProtocol` subclass, for example: * to return an HTTP 200 OK response on a given path; then a load balancer can use this path for a health check; * to authenticate the request and return an HTTP 401 Unauthorized or an HTTP 403 Forbidden when authentication fails. You may also override this method with the ``process_request`` argument of :func:`serve` and :class:`WebSocketServerProtocol`. 
This is equivalent, except ``process_request`` won't have access to the protocol instance, so it can't store information for later use. :meth:`process_request` is expected to complete quickly. If it may run for a long time, then it should await :meth:`wait_closed` and exit if :meth:`wait_closed` completes, or else it could prevent the server from shutting down. Args: path: Request path, including optional query string. request_headers: Request headers. Returns: tuple[StatusLike, HeadersLike, bytes] | None: :obj:`None` to continue the WebSocket handshake normally. An HTTP response, represented by a 3-uple of the response status, headers, and body, to abort the WebSocket handshake and return that HTTP response instead. """ if self._process_request is not None: response = self._process_request(path, request_headers) if isinstance(response, Awaitable): return await response else: # For backwards compatibility with 7.0. warnings.warn( "declare process_request as a coroutine", DeprecationWarning ) return response return None @staticmethod def process_origin( headers: Headers, origins: Sequence[Origin | None] | None = None ) -> Origin | None: """ Handle the Origin HTTP request header. Args: headers: Request headers. origins: Optional list of acceptable origins. Raises: InvalidOrigin: If the origin isn't acceptable. """ # "The user agent MUST NOT include more than one Origin header field" # per https://datatracker.ietf.org/doc/html/rfc6454#section-7.3. try: origin = headers.get("Origin") except MultipleValuesError as exc: raise InvalidHeader("Origin", "multiple values") from exc if origin is not None: origin = cast(Origin, origin) if origins is not None: if origin not in origins: raise InvalidOrigin(origin) return origin @staticmethod def process_extensions( headers: Headers, available_extensions: Sequence[ServerExtensionFactory] | None, ) -> tuple[str | None, list[Extension]]: """ Handle the Sec-WebSocket-Extensions HTTP request header. 
Accept or reject each extension proposed in the client request. Negotiate parameters for accepted extensions. Return the Sec-WebSocket-Extensions HTTP response header and the list of accepted extensions. :rfc:`6455` leaves the rules up to the specification of each :extension. To provide this level of flexibility, for each extension proposed by the client, we check for a match with each extension available in the server configuration. If no match is found, the extension is ignored. If several variants of the same extension are proposed by the client, it may be accepted several times, which won't make sense in general. Extensions must implement their own requirements. For this purpose, the list of previously accepted extensions is provided. This process doesn't allow the server to reorder extensions. It can only select a subset of the extensions proposed by the client. Other requirements, for example related to mandatory extensions or the order of extensions, may be implemented by overriding this method. Args: headers: Request headers. extensions: Optional list of supported extensions. Raises: InvalidHandshake: To abort the handshake with an HTTP 400 error. """ response_header_value: str | None = None extension_headers: list[ExtensionHeader] = [] accepted_extensions: list[Extension] = [] header_values = headers.get_all("Sec-WebSocket-Extensions") if header_values and available_extensions: parsed_header_values: list[ExtensionHeader] = sum( [parse_extension(header_value) for header_value in header_values], [] ) for name, request_params in parsed_header_values: for ext_factory in available_extensions: # Skip non-matching extensions based on their name. if ext_factory.name != name: continue # Skip non-matching extensions based on their params. try: response_params, extension = ext_factory.process_request_params( request_params, accepted_extensions ) except NegotiationError: continue # Add matching extension to the final list. 
extension_headers.append((name, response_params)) accepted_extensions.append(extension) # Break out of the loop once we have a match. break # If we didn't break from the loop, no extension in our list # matched what the client sent. The extension is declined. # Serialize extension header. if extension_headers: response_header_value = build_extension(extension_headers) return response_header_value, accepted_extensions # Not @staticmethod because it calls self.select_subprotocol() def process_subprotocol( self, headers: Headers, available_subprotocols: Sequence[Subprotocol] | None ) -> Subprotocol | None: """ Handle the Sec-WebSocket-Protocol HTTP request header. Return Sec-WebSocket-Protocol HTTP response header, which is the same as the selected subprotocol. Args: headers: Request headers. available_subprotocols: Optional list of supported subprotocols. Raises: InvalidHandshake: To abort the handshake with an HTTP 400 error. """ subprotocol: Subprotocol | None = None header_values = headers.get_all("Sec-WebSocket-Protocol") if header_values and available_subprotocols: parsed_header_values: list[Subprotocol] = sum( [parse_subprotocol(header_value) for header_value in header_values], [] ) subprotocol = self.select_subprotocol( parsed_header_values, available_subprotocols ) return subprotocol def select_subprotocol( self, client_subprotocols: Sequence[Subprotocol], server_subprotocols: Sequence[Subprotocol], ) -> Subprotocol | None: """ Pick a subprotocol among those supported by the client and the server. If several subprotocols are available, select the preferred subprotocol by giving equal weight to the preferences of the client and the server. If no subprotocol is available, proceed without a subprotocol. You may provide a ``select_subprotocol`` argument to :func:`serve` or :class:`WebSocketServerProtocol` to override this logic. 
For example, you could reject the handshake if the client doesn't support a particular subprotocol, rather than accept the handshake without that subprotocol. Args: client_subprotocols: List of subprotocols offered by the client. server_subprotocols: List of subprotocols available on the server. Returns: Selected subprotocol, if a common subprotocol was found. :obj:`None` to continue without a subprotocol. """ if self._select_subprotocol is not None: return self._select_subprotocol(client_subprotocols, server_subprotocols) subprotocols = set(client_subprotocols) & set(server_subprotocols) if not subprotocols: return None return sorted( subprotocols, key=lambda p: client_subprotocols.index(p) + server_subprotocols.index(p), )[0] async def handshake( self, origins: Sequence[Origin | None] | None = None, available_extensions: Sequence[ServerExtensionFactory] | None = None, available_subprotocols: Sequence[Subprotocol] | None = None, extra_headers: HeadersLikeOrCallable | None = None, ) -> str: """ Perform the server side of the opening handshake. Args: origins: List of acceptable values of the Origin HTTP header; include :obj:`None` if the lack of an origin is acceptable. extensions: List of supported extensions, in order in which they should be tried. subprotocols: List of supported subprotocols, in order of decreasing preference. extra_headers: Arbitrary HTTP headers to add to the response when the handshake succeeds. Returns: path of the URI of the request. Raises: InvalidHandshake: If the handshake fails. """ path, request_headers = await self.read_http_request() # Hook for customizing request handling, for example checking # authentication or treating some paths as plain HTTP endpoints. early_response_awaitable = self.process_request(path, request_headers) if isinstance(early_response_awaitable, Awaitable): early_response = await early_response_awaitable else: # For backwards compatibility with 7.0. 
warnings.warn("declare process_request as a coroutine", DeprecationWarning) early_response = early_response_awaitable # The connection may drop while process_request is running. if self.state is State.CLOSED: # This subclass of ConnectionError is silently ignored in handler(). raise BrokenPipeError("connection closed during opening handshake") # Change the response to a 503 error if the server is shutting down. if not self.ws_server.is_serving(): early_response = ( http.HTTPStatus.SERVICE_UNAVAILABLE, [], b"Server is shutting down.\n", ) if early_response is not None: raise AbortHandshake(*early_response) key = check_request(request_headers) self.origin = self.process_origin(request_headers, origins) extensions_header, self.extensions = self.process_extensions( request_headers, available_extensions ) protocol_header = self.subprotocol = self.process_subprotocol( request_headers, available_subprotocols ) response_headers = Headers() build_response(response_headers, key) if extensions_header is not None: response_headers["Sec-WebSocket-Extensions"] = extensions_header if protocol_header is not None: response_headers["Sec-WebSocket-Protocol"] = protocol_header if callable(extra_headers): extra_headers = extra_headers(path, self.request_headers) if extra_headers is not None: response_headers.update(extra_headers) response_headers.setdefault("Date", email.utils.formatdate(usegmt=True)) if self.server_header is not None: response_headers.setdefault("Server", self.server_header) self.write_http_response(http.HTTPStatus.SWITCHING_PROTOCOLS, response_headers) self.logger.info("connection open") self.connection_open() return path class WebSocketServer: """ WebSocket server returned by :func:`serve`. This class mirrors the API of :class:`~asyncio.Server`. It keeps track of WebSocket connections in order to close them properly when shutting down. Args: logger: Logger for this server. It defaults to ``logging.getLogger("websockets.server")``. 
See the :doc:`logging guide <../../topics/logging>` for details. """ def __init__(self, logger: LoggerLike | None = None) -> None: if logger is None: logger = logging.getLogger("websockets.server") self.logger = logger # Keep track of active connections. self.websockets: set[WebSocketServerProtocol] = set() # Task responsible for closing the server and terminating connections. self.close_task: asyncio.Task[None] | None = None # Completed when the server is closed and connections are terminated. self.closed_waiter: asyncio.Future[None] def wrap(self, server: asyncio.base_events.Server) -> None: """ Attach to a given :class:`~asyncio.Server`. Since :meth:`~asyncio.loop.create_server` doesn't support injecting a custom ``Server`` class, the easiest solution that doesn't rely on private :mod:`asyncio` APIs is to: - instantiate a :class:`WebSocketServer` - give the protocol factory a reference to that instance - call :meth:`~asyncio.loop.create_server` with the factory - attach the resulting :class:`~asyncio.Server` with this method """ self.server = server for sock in server.sockets: if sock.family == socket.AF_INET: name = "%s:%d" % sock.getsockname() elif sock.family == socket.AF_INET6: name = "[%s]:%d" % sock.getsockname()[:2] elif sock.family == socket.AF_UNIX: name = sock.getsockname() # In the unlikely event that someone runs websockets over a # protocol other than IP or Unix sockets, avoid crashing. else: # pragma: no cover name = str(sock.getsockname()) self.logger.info("server listening on %s", name) # Initialized here because we need a reference to the event loop. # This should be moved back to __init__ when dropping Python < 3.10. self.closed_waiter = server.get_loop().create_future() def register(self, protocol: WebSocketServerProtocol) -> None: """ Register a connection with this server. """ self.websockets.add(protocol) def unregister(self, protocol: WebSocketServerProtocol) -> None: """ Unregister a connection with this server. 
""" self.websockets.remove(protocol) def close(self, close_connections: bool = True) -> None: """ Close the server. * Close the underlying :class:`~asyncio.Server`. * When ``close_connections`` is :obj:`True`, which is the default, close existing connections. Specifically: * Reject opening WebSocket connections with an HTTP 503 (service unavailable) error. This happens when the server accepted the TCP connection but didn't complete the opening handshake before closing. * Close open WebSocket connections with close code 1001 (going away). * Wait until all connection handlers terminate. :meth:`close` is idempotent. """ if self.close_task is None: self.close_task = self.get_loop().create_task( self._close(close_connections) ) async def _close(self, close_connections: bool) -> None: """ Implementation of :meth:`close`. This calls :meth:`~asyncio.Server.close` on the underlying :class:`~asyncio.Server` object to stop accepting new connections and then closes open connections with close code 1001. """ self.logger.info("server closing") # Stop accepting new connections. self.server.close() # Wait until all accepted connections reach connection_made() and call # register(). See https://github.com/python/cpython/issues/79033 for # details. This workaround can be removed when dropping Python < 3.11. await asyncio.sleep(0) if close_connections: # Close OPEN connections with close code 1001. After server.close(), # handshake() closes OPENING connections with an HTTP 503 error. close_tasks = [ asyncio.create_task(websocket.close(1001)) for websocket in self.websockets if websocket.state is not State.CONNECTING ] # asyncio.wait doesn't accept an empty first argument. if close_tasks: await asyncio.wait(close_tasks) # Wait until all TCP connections are closed. await self.server.wait_closed() # Wait until all connection handlers terminate. # asyncio.wait doesn't accept an empty first argument. 
if self.websockets: await asyncio.wait( [websocket.handler_task for websocket in self.websockets] ) # Tell wait_closed() to return. self.closed_waiter.set_result(None) self.logger.info("server closed") async def wait_closed(self) -> None: """ Wait until the server is closed. When :meth:`wait_closed` returns, all TCP connections are closed and all connection handlers have returned. To ensure a fast shutdown, a connection handler should always be awaiting at least one of: * :meth:`~WebSocketServerProtocol.recv`: when the connection is closed, it raises :exc:`~websockets.exceptions.ConnectionClosedOK`; * :meth:`~WebSocketServerProtocol.wait_closed`: when the connection is closed, it returns. Then the connection handler is immediately notified of the shutdown; it can clean up and exit. """ await asyncio.shield(self.closed_waiter) def get_loop(self) -> asyncio.AbstractEventLoop: """ See :meth:`asyncio.Server.get_loop`. """ return self.server.get_loop() def is_serving(self) -> bool: """ See :meth:`asyncio.Server.is_serving`. """ return self.server.is_serving() async def start_serving(self) -> None: # pragma: no cover """ See :meth:`asyncio.Server.start_serving`. Typical use:: server = await serve(..., start_serving=False) # perform additional setup here... # ... then start the server await server.start_serving() """ await self.server.start_serving() async def serve_forever(self) -> None: # pragma: no cover """ See :meth:`asyncio.Server.serve_forever`. Typical use:: server = await serve(...) # this coroutine doesn't return # canceling it stops the server await server.serve_forever() This is an alternative to using :func:`serve` as an asynchronous context manager. Shutdown is triggered by canceling :meth:`serve_forever` instead of exiting a :func:`serve` context. """ await self.server.serve_forever() @property def sockets(self) -> Iterable[socket.socket]: """ See :attr:`asyncio.Server.sockets`. 
""" return self.server.sockets async def __aenter__(self) -> WebSocketServer: # pragma: no cover return self async def __aexit__( self, exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: TracebackType | None, ) -> None: # pragma: no cover self.close() await self.wait_closed() class Serve: """ Start a WebSocket server listening on ``host`` and ``port``. Whenever a client connects, the server creates a :class:`WebSocketServerProtocol`, performs the opening handshake, and delegates to the connection handler, ``ws_handler``. The handler receives the :class:`WebSocketServerProtocol` and uses it to send and receive messages. Once the handler completes, either normally or with an exception, the server performs the closing handshake and closes the connection. Awaiting :func:`serve` yields a :class:`WebSocketServer`. This object provides a :meth:`~WebSocketServer.close` method to shut down the server:: # set this future to exit the server stop = asyncio.get_running_loop().create_future() server = await serve(...) await stop server.close() await server.wait_closed() :func:`serve` can be used as an asynchronous context manager. Then, the server is shut down automatically when exiting the context:: # set this future to exit the server stop = asyncio.get_running_loop().create_future() async with serve(...): await stop Args: ws_handler: Connection handler. It receives the WebSocket connection, which is a :class:`WebSocketServerProtocol`, in argument. host: Network interfaces the server binds to. See :meth:`~asyncio.loop.create_server` for details. port: TCP port the server listens on. See :meth:`~asyncio.loop.create_server` for details. create_protocol: Factory for the :class:`asyncio.Protocol` managing the connection. It defaults to :class:`WebSocketServerProtocol`. Set it to a wrapper or a subclass to customize connection handling. logger: Logger for this server. It defaults to ``logging.getLogger("websockets.server")``. 
See the :doc:`logging guide <../../topics/logging>` for details. compression: The "permessage-deflate" extension is enabled by default. Set ``compression`` to :obj:`None` to disable it. See the :doc:`compression guide <../../topics/compression>` for details. origins: Acceptable values of the ``Origin`` header, for defending against Cross-Site WebSocket Hijacking attacks. Include :obj:`None` in the list if the lack of an origin is acceptable. extensions: List of supported extensions, in order in which they should be negotiated and run. subprotocols: List of supported subprotocols, in order of decreasing preference. extra_headers (HeadersLike | Callable[[str, Headers] | HeadersLike]): Arbitrary HTTP headers to add to the response. This can be a :data:`~websockets.datastructures.HeadersLike` or a callable taking the request path and headers in arguments and returning a :data:`~websockets.datastructures.HeadersLike`. server_header: Value of the ``Server`` response header. It defaults to ``"Python/x.y.z websockets/X.Y"``. Setting it to :obj:`None` removes the header. process_request (Callable[[str, Headers], \ Awaitable[tuple[StatusLike, HeadersLike, bytes] | None]] | None): Intercept HTTP request before the opening handshake. See :meth:`~WebSocketServerProtocol.process_request` for details. select_subprotocol: Select a subprotocol supported by the client. See :meth:`~WebSocketServerProtocol.select_subprotocol` for details. open_timeout: Timeout for opening connections in seconds. :obj:`None` disables the timeout. See :class:`~websockets.legacy.protocol.WebSocketCommonProtocol` for the documentation of ``ping_interval``, ``ping_timeout``, ``close_timeout``, ``max_size``, ``max_queue``, ``read_limit``, and ``write_limit``. Any other keyword arguments are passed the event loop's :meth:`~asyncio.loop.create_server` method. For example: * You can set ``ssl`` to a :class:`~ssl.SSLContext` to enable TLS. 
* You can set ``sock`` to a :obj:`~socket.socket` that you created outside of websockets. Returns: WebSocket server. """ def __init__( self, # The version that accepts the path in the second argument is deprecated. ws_handler: ( Callable[[WebSocketServerProtocol], Awaitable[Any]] | Callable[[WebSocketServerProtocol, str], Awaitable[Any]] ), host: str | Sequence[str] | None = None, port: int | None = None, *, create_protocol: Callable[..., WebSocketServerProtocol] | None = None, logger: LoggerLike | None = None, compression: str | None = "deflate", origins: Sequence[Origin | None] | None = None, extensions: Sequence[ServerExtensionFactory] | None = None, subprotocols: Sequence[Subprotocol] | None = None, extra_headers: HeadersLikeOrCallable | None = None, server_header: str | None = SERVER, process_request: ( Callable[[str, Headers], Awaitable[HTTPResponse | None]] | None ) = None, select_subprotocol: ( Callable[[Sequence[Subprotocol], Sequence[Subprotocol]], Subprotocol] | None ) = None, open_timeout: float | None = 10, ping_interval: float | None = 20, ping_timeout: float | None = 20, close_timeout: float | None = None, max_size: int | None = 2**20, max_queue: int | None = 2**5, read_limit: int = 2**16, write_limit: int = 2**16, **kwargs: Any, ) -> None: # Backwards compatibility: close_timeout used to be called timeout. timeout: float | None = kwargs.pop("timeout", None) if timeout is None: timeout = 10 else: warnings.warn("rename timeout to close_timeout", DeprecationWarning) # If both are specified, timeout is ignored. if close_timeout is None: close_timeout = timeout # Backwards compatibility: create_protocol used to be called klass. klass: type[WebSocketServerProtocol] | None = kwargs.pop("klass", None) if klass is None: klass = WebSocketServerProtocol else: warnings.warn("rename klass to create_protocol", DeprecationWarning) # If both are specified, klass is ignored. 
if create_protocol is None: create_protocol = klass # Backwards compatibility: recv() used to return None on closed connections legacy_recv: bool = kwargs.pop("legacy_recv", False) # Backwards compatibility: the loop parameter used to be supported. _loop: asyncio.AbstractEventLoop | None = kwargs.pop("loop", None) if _loop is None: loop = asyncio.get_event_loop() else: loop = _loop warnings.warn("remove loop argument", DeprecationWarning) ws_server = WebSocketServer(logger=logger) secure = kwargs.get("ssl") is not None if compression == "deflate": extensions = enable_server_permessage_deflate(extensions) elif compression is not None: raise ValueError(f"unsupported compression: {compression}") if subprotocols is not None: validate_subprotocols(subprotocols) # Help mypy and avoid this error: "type[WebSocketServerProtocol] | # Callable[..., WebSocketServerProtocol]" not callable [misc] create_protocol = cast(Callable[..., WebSocketServerProtocol], create_protocol) factory = functools.partial( create_protocol, # For backwards compatibility with 10.0 or earlier. Done here in # addition to WebSocketServerProtocol to trigger the deprecation # warning once per serve() call rather than once per connection. remove_path_argument(ws_handler), ws_server, host=host, port=port, secure=secure, open_timeout=open_timeout, ping_interval=ping_interval, ping_timeout=ping_timeout, close_timeout=close_timeout, max_size=max_size, max_queue=max_queue, read_limit=read_limit, write_limit=write_limit, loop=_loop, legacy_recv=legacy_recv, origins=origins, extensions=extensions, subprotocols=subprotocols, extra_headers=extra_headers, server_header=server_header, process_request=process_request, select_subprotocol=select_subprotocol, logger=logger, ) if kwargs.pop("unix", False): path: str | None = kwargs.pop("path", None) # unix_serve(path) must not specify host and port parameters. 
assert host is None and port is None create_server = functools.partial( loop.create_unix_server, factory, path, **kwargs ) else: create_server = functools.partial( loop.create_server, factory, host, port, **kwargs ) # This is a coroutine function. self._create_server = create_server self.ws_server = ws_server # async with serve(...) async def __aenter__(self) -> WebSocketServer: return await self async def __aexit__( self, exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: TracebackType | None, ) -> None: self.ws_server.close() await self.ws_server.wait_closed() # await serve(...) def __await__(self) -> Generator[Any, None, WebSocketServer]: # Create a suitable iterator by calling __await__ on a coroutine. return self.__await_impl__().__await__() async def __await_impl__(self) -> WebSocketServer: server = await self._create_server() self.ws_server.wrap(server) return self.ws_server # yield from serve(...) - remove when dropping Python < 3.10 __iter__ = __await__ serve = Serve def unix_serve( # The version that accepts the path in the second argument is deprecated. ws_handler: ( Callable[[WebSocketServerProtocol], Awaitable[Any]] | Callable[[WebSocketServerProtocol, str], Awaitable[Any]] ), path: str | None = None, **kwargs: Any, ) -> Serve: """ Start a WebSocket server listening on a Unix socket. This function is identical to :func:`serve`, except the ``host`` and ``port`` arguments are replaced by ``path``. It is only available on Unix. Unrecognized keyword arguments are passed the event loop's :meth:`~asyncio.loop.create_unix_server` method. It's useful for deploying a server behind a reverse proxy such as nginx. Args: path: File system path to the Unix socket. 
""" return serve(ws_handler, path=path, unix=True, **kwargs) def remove_path_argument( ws_handler: ( Callable[[WebSocketServerProtocol], Awaitable[Any]] | Callable[[WebSocketServerProtocol, str], Awaitable[Any]] ), ) -> Callable[[WebSocketServerProtocol], Awaitable[Any]]: try: inspect.signature(ws_handler).bind(None) except TypeError: try: inspect.signature(ws_handler).bind(None, "") except TypeError: # pragma: no cover # ws_handler accepts neither one nor two arguments; leave it alone. pass else: # ws_handler accepts two arguments; activate backwards compatibility. warnings.warn("remove second argument of ws_handler", DeprecationWarning) async def _ws_handler(websocket: WebSocketServerProtocol) -> Any: return await cast( Callable[[WebSocketServerProtocol, str], Awaitable[Any]], ws_handler, )(websocket, websocket.path) return _ws_handler return cast( Callable[[WebSocketServerProtocol], Awaitable[Any]], ws_handler, ) websockets-15.0.1/src/websockets/protocol.py000066400000000000000000000636511476212450300211410ustar00rootroot00000000000000from __future__ import annotations import enum import logging import uuid from collections.abc import Generator from typing import Union from .exceptions import ( ConnectionClosed, ConnectionClosedError, ConnectionClosedOK, InvalidState, PayloadTooBig, ProtocolError, ) from .extensions import Extension from .frames import ( OK_CLOSE_CODES, OP_BINARY, OP_CLOSE, OP_CONT, OP_PING, OP_PONG, OP_TEXT, Close, CloseCode, Frame, ) from .http11 import Request, Response from .streams import StreamReader from .typing import LoggerLike, Origin, Subprotocol __all__ = [ "Protocol", "Side", "State", "SEND_EOF", ] # Change to Request | Response | Frame when dropping Python < 3.10. 
Event = Union[Request, Response, Frame] """Events that :meth:`~Protocol.events_received` may return.""" class Side(enum.IntEnum): """A WebSocket connection is either a server or a client.""" SERVER, CLIENT = range(2) SERVER = Side.SERVER CLIENT = Side.CLIENT class State(enum.IntEnum): """A WebSocket connection is in one of these four states.""" CONNECTING, OPEN, CLOSING, CLOSED = range(4) CONNECTING = State.CONNECTING OPEN = State.OPEN CLOSING = State.CLOSING CLOSED = State.CLOSED SEND_EOF = b"" """Sentinel signaling that the TCP connection must be half-closed.""" class Protocol: """ Sans-I/O implementation of a WebSocket connection. Args: side: :attr:`~Side.CLIENT` or :attr:`~Side.SERVER`. state: Initial state of the WebSocket connection. max_size: Maximum size of incoming messages in bytes; :obj:`None` disables the limit. logger: Logger for this connection; depending on ``side``, defaults to ``logging.getLogger("websockets.client")`` or ``logging.getLogger("websockets.server")``; see the :doc:`logging guide <../../topics/logging>` for details. """ def __init__( self, side: Side, *, state: State = OPEN, max_size: int | None = 2**20, logger: LoggerLike | None = None, ) -> None: # Unique identifier. For logs. self.id: uuid.UUID = uuid.uuid4() """Unique identifier of the connection. Useful in logs.""" # Logger or LoggerAdapter for this connection. if logger is None: logger = logging.getLogger(f"websockets.{side.name.lower()}") self.logger: LoggerLike = logger """Logger for this connection.""" # Track if DEBUG is enabled. Shortcut logging calls if it isn't. self.debug = logger.isEnabledFor(logging.DEBUG) # Connection side. CLIENT or SERVER. self.side = side # Connection state. Initially OPEN because subclasses handle CONNECTING. self.state = state # Maximum size of incoming messages in bytes. self.max_size = max_size # Current size of incoming message in bytes. Only set while reading a # fragmented message i.e. a data frames with the FIN bit not set. 
self.cur_size: int | None = None # True while sending a fragmented message i.e. a data frames with the # FIN bit not set. self.expect_continuation_frame = False # WebSocket protocol parameters. self.origin: Origin | None = None self.extensions: list[Extension] = [] self.subprotocol: Subprotocol | None = None # Close code and reason, set when a close frame is sent or received. self.close_rcvd: Close | None = None self.close_sent: Close | None = None self.close_rcvd_then_sent: bool | None = None # Track if an exception happened during the handshake. self.handshake_exc: Exception | None = None """ Exception to raise if the opening handshake failed. :obj:`None` if the opening handshake succeeded. """ # Track if send_eof() was called. self.eof_sent = False # Parser state. self.reader = StreamReader() self.events: list[Event] = [] self.writes: list[bytes] = [] self.parser = self.parse() next(self.parser) # start coroutine self.parser_exc: Exception | None = None @property def state(self) -> State: """ State of the WebSocket connection. Defined in 4.1_, 4.2_, 7.1.3_, and 7.1.4_ of :rfc:`6455`. .. _4.1: https://datatracker.ietf.org/doc/html/rfc6455#section-4.1 .. _4.2: https://datatracker.ietf.org/doc/html/rfc6455#section-4.2 .. _7.1.3: https://datatracker.ietf.org/doc/html/rfc6455#section-7.1.3 .. _7.1.4: https://datatracker.ietf.org/doc/html/rfc6455#section-7.1.4 """ return self._state @state.setter def state(self, state: State) -> None: if self.debug: self.logger.debug("= connection is %s", state.name) self._state = state @property def close_code(self) -> int | None: """ WebSocket close code received from the remote endpoint. Defined in 7.1.5_ of :rfc:`6455`. .. _7.1.5: https://datatracker.ietf.org/doc/html/rfc6455#section-7.1.5 :obj:`None` if the connection isn't closed yet. 
""" if self.state is not CLOSED: return None elif self.close_rcvd is None: return CloseCode.ABNORMAL_CLOSURE else: return self.close_rcvd.code @property def close_reason(self) -> str | None: """ WebSocket close reason received from the remote endpoint. Defined in 7.1.6_ of :rfc:`6455`. .. _7.1.6: https://datatracker.ietf.org/doc/html/rfc6455#section-7.1.6 :obj:`None` if the connection isn't closed yet. """ if self.state is not CLOSED: return None elif self.close_rcvd is None: return "" else: return self.close_rcvd.reason @property def close_exc(self) -> ConnectionClosed: """ Exception to raise when trying to interact with a closed connection. Don't raise this exception while the connection :attr:`state` is :attr:`~websockets.protocol.State.CLOSING`; wait until it's :attr:`~websockets.protocol.State.CLOSED`. Indeed, the exception includes the close code and reason, which are known only once the connection is closed. Raises: AssertionError: If the connection isn't closed yet. """ assert self.state is CLOSED, "connection isn't closed yet" exc_type: type[ConnectionClosed] if ( self.close_rcvd is not None and self.close_sent is not None and self.close_rcvd.code in OK_CLOSE_CODES and self.close_sent.code in OK_CLOSE_CODES ): exc_type = ConnectionClosedOK else: exc_type = ConnectionClosedError exc: ConnectionClosed = exc_type( self.close_rcvd, self.close_sent, self.close_rcvd_then_sent, ) # Chain to the exception raised in the parser, if any. exc.__cause__ = self.parser_exc return exc # Public methods for receiving data. def receive_data(self, data: bytes) -> None: """ Receive data from the network. After calling this method: - You must call :meth:`data_to_send` and send this data to the network. - You should call :meth:`events_received` and process resulting events. Raises: EOFError: If :meth:`receive_eof` was called earlier. """ self.reader.feed_data(data) next(self.parser) def receive_eof(self) -> None: """ Receive the end of the data stream from the network. 
After calling this method: - You must call :meth:`data_to_send` and send this data to the network; it will return ``[b""]``, signaling the end of the stream, or ``[]``. - You aren't expected to call :meth:`events_received`; it won't return any new events. :meth:`receive_eof` is idempotent. """ if self.reader.eof: return self.reader.feed_eof() next(self.parser) # Public methods for sending events. def send_continuation(self, data: bytes, fin: bool) -> None: """ Send a `Continuation frame`_. .. _Continuation frame: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 Parameters: data: payload containing the same kind of data as the initial frame. fin: FIN bit; set it to :obj:`True` if this is the last frame of a fragmented message and to :obj:`False` otherwise. Raises: ProtocolError: If a fragmented message isn't in progress. """ if not self.expect_continuation_frame: raise ProtocolError("unexpected continuation frame") if self._state is not OPEN: raise InvalidState(f"connection is {self.state.name.lower()}") self.expect_continuation_frame = not fin self.send_frame(Frame(OP_CONT, data, fin)) def send_text(self, data: bytes, fin: bool = True) -> None: """ Send a `Text frame`_. .. _Text frame: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 Parameters: data: payload containing text encoded with UTF-8. fin: FIN bit; set it to :obj:`False` if this is the first frame of a fragmented message. Raises: ProtocolError: If a fragmented message is in progress. """ if self.expect_continuation_frame: raise ProtocolError("expected a continuation frame") if self._state is not OPEN: raise InvalidState(f"connection is {self.state.name.lower()}") self.expect_continuation_frame = not fin self.send_frame(Frame(OP_TEXT, data, fin)) def send_binary(self, data: bytes, fin: bool = True) -> None: """ Send a `Binary frame`_. .. _Binary frame: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 Parameters: data: payload containing arbitrary binary data. 
fin: FIN bit; set it to :obj:`False` if this is the first frame of a fragmented message. Raises: ProtocolError: If a fragmented message is in progress. """ if self.expect_continuation_frame: raise ProtocolError("expected a continuation frame") if self._state is not OPEN: raise InvalidState(f"connection is {self.state.name.lower()}") self.expect_continuation_frame = not fin self.send_frame(Frame(OP_BINARY, data, fin)) def send_close(self, code: int | None = None, reason: str = "") -> None: """ Send a `Close frame`_. .. _Close frame: https://datatracker.ietf.org/doc/html/rfc6455#section-5.5.1 Parameters: code: close code. reason: close reason. Raises: ProtocolError: If the code isn't valid or if a reason is provided without a code. """ # While RFC 6455 doesn't rule out sending more than one close Frame, # websockets is conservative in what it sends and doesn't allow that. if self._state is not OPEN: raise InvalidState(f"connection is {self.state.name.lower()}") if code is None: if reason != "": raise ProtocolError("cannot send a reason without a code") close = Close(CloseCode.NO_STATUS_RCVD, "") data = b"" else: close = Close(code, reason) data = close.serialize() # 7.1.3. The WebSocket Closing Handshake is Started self.send_frame(Frame(OP_CLOSE, data)) # Since the state is OPEN, no close frame was received yet. # As a consequence, self.close_rcvd_then_sent remains None. assert self.close_rcvd is None self.close_sent = close self.state = CLOSING def send_ping(self, data: bytes) -> None: """ Send a `Ping frame`_. .. _Ping frame: https://datatracker.ietf.org/doc/html/rfc6455#section-5.5.2 Parameters: data: payload containing arbitrary binary data. """ # RFC 6455 allows control frames after starting the closing handshake. if self._state is not OPEN and self._state is not CLOSING: raise InvalidState(f"connection is {self.state.name.lower()}") self.send_frame(Frame(OP_PING, data)) def send_pong(self, data: bytes) -> None: """ Send a `Pong frame`_. .. 
_Pong frame: https://datatracker.ietf.org/doc/html/rfc6455#section-5.5.3 Parameters: data: payload containing arbitrary binary data. """ # RFC 6455 allows control frames after starting the closing handshake. if self._state is not OPEN and self._state is not CLOSING: raise InvalidState(f"connection is {self.state.name.lower()}") self.send_frame(Frame(OP_PONG, data)) def fail(self, code: int, reason: str = "") -> None: """ `Fail the WebSocket connection`_. .. _Fail the WebSocket connection: https://datatracker.ietf.org/doc/html/rfc6455#section-7.1.7 Parameters: code: close code reason: close reason Raises: ProtocolError: If the code isn't valid. """ # 7.1.7. Fail the WebSocket Connection # Send a close frame when the state is OPEN (a close frame was already # sent if it's CLOSING), except when failing the connection because # of an error reading from or writing to the network. if self.state is OPEN: if code != CloseCode.ABNORMAL_CLOSURE: close = Close(code, reason) data = close.serialize() self.send_frame(Frame(OP_CLOSE, data)) self.close_sent = close # If recv_messages() raised an exception upon receiving a close # frame but before echoing it, then close_rcvd is not None even # though the state is OPEN. This happens when the connection is # closed while receiving a fragmented message. if self.close_rcvd is not None: self.close_rcvd_then_sent = True self.state = CLOSING # When failing the connection, a server closes the TCP connection # without waiting for the client to complete the handshake, while a # client waits for the server to close the TCP connection, possibly # after sending a close frame that the client will ignore. if self.side is SERVER and not self.eof_sent: self.send_eof() # 7.1.7. Fail the WebSocket Connection "An endpoint MUST NOT continue # to attempt to process data(including a responding Close frame) from # the remote endpoint after being instructed to _Fail the WebSocket # Connection_." 
self.parser = self.discard() next(self.parser) # start coroutine # Public method for getting incoming events after receiving data. def events_received(self) -> list[Event]: """ Fetch events generated from data received from the network. Call this method immediately after any of the ``receive_*()`` methods. Process resulting events, likely by passing them to the application. Returns: Events read from the connection. """ events, self.events = self.events, [] return events # Public method for getting outgoing data after receiving data or sending events. def data_to_send(self) -> list[bytes]: """ Obtain data to send to the network. Call this method immediately after any of the ``receive_*()``, ``send_*()``, or :meth:`fail` methods. Write resulting data to the connection. The empty bytestring :data:`~websockets.protocol.SEND_EOF` signals the end of the data stream. When you receive it, half-close the TCP connection. Returns: Data to write to the connection. """ writes, self.writes = self.writes, [] return writes def close_expected(self) -> bool: """ Tell if the TCP connection is expected to close soon. Call this method immediately after any of the ``receive_*()``, ``send_close()``, or :meth:`fail` methods. If it returns :obj:`True`, schedule closing the TCP connection after a short timeout if the other side hasn't already closed it. Returns: Whether the TCP connection is expected to close soon. """ # During the opening handshake, when our state is CONNECTING, we expect # a TCP close if and only if the hansdake fails. When it does, we start # the TCP closing handshake by sending EOF with send_eof(). # Once the opening handshake completes successfully, we expect a TCP # close if and only if we sent a close frame, meaning that our state # progressed to CLOSING: # * Normal closure: once we send a close frame, we expect a TCP close: # server waits for client to complete the TCP closing handshake; # client waits for server to initiate the TCP closing handshake. 
# * Abnormal closure: we always send a close frame and the same logic # applies, except on EOFError where we don't send a close frame # because we already received the TCP close, so we don't expect it. # If our state is CLOSED, we already received a TCP close so we don't # expect it anymore. # Micro-optimization: put the most common case first if self.state is OPEN: return False if self.state is CLOSING: return True if self.state is CLOSED: return False assert self.state is CONNECTING return self.eof_sent # Private methods for receiving data. def parse(self) -> Generator[None]: """ Parse incoming data into frames. :meth:`receive_data` and :meth:`receive_eof` run this generator coroutine until it needs more data or reaches EOF. :meth:`parse` never raises an exception. Instead, it sets the :attr:`parser_exc` and yields control. """ try: while True: if (yield from self.reader.at_eof()): if self.debug: self.logger.debug("< EOF") # If the WebSocket connection is closed cleanly, with a # closing handhshake, recv_frame() substitutes parse() # with discard(). This branch is reached only when the # connection isn't closed cleanly. raise EOFError("unexpected end of stream") if self.max_size is None: max_size = None elif self.cur_size is None: max_size = self.max_size else: max_size = self.max_size - self.cur_size # During a normal closure, execution ends here on the next # iteration of the loop after receiving a close frame. At # this point, recv_frame() replaced parse() by discard(). 
frame = yield from Frame.parse( self.reader.read_exact, mask=self.side is SERVER, max_size=max_size, extensions=self.extensions, ) if self.debug: self.logger.debug("< %s", frame) self.recv_frame(frame) except ProtocolError as exc: self.fail(CloseCode.PROTOCOL_ERROR, str(exc)) self.parser_exc = exc except EOFError as exc: self.fail(CloseCode.ABNORMAL_CLOSURE, str(exc)) self.parser_exc = exc except UnicodeDecodeError as exc: self.fail(CloseCode.INVALID_DATA, f"{exc.reason} at position {exc.start}") self.parser_exc = exc except PayloadTooBig as exc: exc.set_current_size(self.cur_size) self.fail(CloseCode.MESSAGE_TOO_BIG, str(exc)) self.parser_exc = exc except Exception as exc: self.logger.error("parser failed", exc_info=True) # Don't include exception details, which may be security-sensitive. self.fail(CloseCode.INTERNAL_ERROR) self.parser_exc = exc # During an abnormal closure, execution ends here after catching an # exception. At this point, fail() replaced parse() by discard(). yield raise AssertionError("parse() shouldn't step after error") def discard(self) -> Generator[None]: """ Discard incoming data. This coroutine replaces :meth:`parse`: - after receiving a close frame, during a normal closure (1.4); - after sending a close frame, during an abnormal closure (7.1.7). """ # After the opening handshake completes, the server closes the TCP # connection in the same circumstances where discard() replaces parse(). # The client closes it when it receives EOF from the server or times # out. (The latter case cannot be handled in this Sans-I/O layer.) assert (self.side is SERVER or self.state is CONNECTING) == (self.eof_sent) while not (yield from self.reader.at_eof()): self.reader.discard() if self.debug: self.logger.debug("< EOF") # A server closes the TCP connection immediately, while a client # waits for the server to close the TCP connection. 
if self.side is CLIENT and self.state is not CONNECTING: self.send_eof() self.state = CLOSED # If discard() completes normally, execution ends here. yield # Once the reader reaches EOF, its feed_data/eof() methods raise an # error, so our receive_data/eof() methods don't step the generator. raise AssertionError("discard() shouldn't step after EOF") def recv_frame(self, frame: Frame) -> None: """ Process an incoming frame. """ if frame.opcode is OP_TEXT or frame.opcode is OP_BINARY: if self.cur_size is not None: raise ProtocolError("expected a continuation frame") if not frame.fin: self.cur_size = len(frame.data) elif frame.opcode is OP_CONT: if self.cur_size is None: raise ProtocolError("unexpected continuation frame") if frame.fin: self.cur_size = None else: self.cur_size += len(frame.data) elif frame.opcode is OP_PING: # 5.5.2. Ping: "Upon receipt of a Ping frame, an endpoint MUST # send a Pong frame in response" pong_frame = Frame(OP_PONG, frame.data) self.send_frame(pong_frame) elif frame.opcode is OP_PONG: # 5.5.3 Pong: "A response to an unsolicited Pong frame is not # expected." pass elif frame.opcode is OP_CLOSE: # 7.1.5. The WebSocket Connection Close Code # 7.1.6. The WebSocket Connection Close Reason self.close_rcvd = Close.parse(frame.data) if self.state is CLOSING: assert self.close_sent is not None self.close_rcvd_then_sent = False if self.cur_size is not None: raise ProtocolError("incomplete fragmented message") # 5.5.1 Close: "If an endpoint receives a Close frame and did # not previously send a Close frame, the endpoint MUST send a # Close frame in response. (When sending a Close frame in # response, the endpoint typically echos the status code it # received.)" if self.state is OPEN: # Echo the original data instead of re-serializing it with # Close.serialize() because that fails when the close frame # is empty and Close.parse() synthesizes a 1005 close code. # The rest is identical to send_close(). 
self.send_frame(Frame(OP_CLOSE, frame.data)) self.close_sent = self.close_rcvd self.close_rcvd_then_sent = True self.state = CLOSING # 7.1.2. Start the WebSocket Closing Handshake: "Once an # endpoint has both sent and received a Close control frame, # that endpoint SHOULD _Close the WebSocket Connection_" # A server closes the TCP connection immediately, while a client # waits for the server to close the TCP connection. if self.side is SERVER: self.send_eof() # 1.4. Closing Handshake: "after receiving a control frame # indicating the connection should be closed, a peer discards # any further data received." # RFC 6455 allows reading Ping and Pong frames after a Close frame. # However, that doesn't seem useful; websockets doesn't support it. self.parser = self.discard() next(self.parser) # start coroutine else: # This can't happen because Frame.parse() validates opcodes. raise AssertionError(f"unexpected opcode: {frame.opcode:02x}") self.events.append(frame) # Private methods for sending events. 
def send_frame(self, frame: Frame) -> None: if self.debug: self.logger.debug("> %s", frame) self.writes.append( frame.serialize( mask=self.side is CLIENT, extensions=self.extensions, ) ) def send_eof(self) -> None: assert not self.eof_sent self.eof_sent = True if self.debug: self.logger.debug("> EOF") self.writes.append(SEND_EOF) websockets-15.0.1/src/websockets/py.typed000066400000000000000000000000001476212450300204000ustar00rootroot00000000000000websockets-15.0.1/src/websockets/server.py000066400000000000000000000520751476212450300206040ustar00rootroot00000000000000from __future__ import annotations import base64 import binascii import email.utils import http import re import warnings from collections.abc import Generator, Sequence from typing import Any, Callable, cast from .datastructures import Headers, MultipleValuesError from .exceptions import ( InvalidHandshake, InvalidHeader, InvalidHeaderValue, InvalidMessage, InvalidOrigin, InvalidUpgrade, NegotiationError, ) from .extensions import Extension, ServerExtensionFactory from .headers import ( build_extension, parse_connection, parse_extension, parse_subprotocol, parse_upgrade, ) from .http11 import Request, Response from .imports import lazy_import from .protocol import CONNECTING, OPEN, SERVER, Protocol, State from .typing import ( ConnectionOption, ExtensionHeader, LoggerLike, Origin, StatusLike, Subprotocol, UpgradeProtocol, ) from .utils import accept_key __all__ = ["ServerProtocol"] class ServerProtocol(Protocol): """ Sans-I/O implementation of a WebSocket server connection. Args: origins: Acceptable values of the ``Origin`` header. Values can be :class:`str` to test for an exact match or regular expressions compiled by :func:`re.compile` to test against a pattern. Include :obj:`None` in the list if the lack of an origin is acceptable. This is useful for defending against Cross-Site WebSocket Hijacking attacks. extensions: List of supported extensions, in order in which they should be tried. 
subprotocols: List of supported subprotocols, in order of decreasing preference. select_subprotocol: Callback for selecting a subprotocol among those supported by the client and the server. It has the same signature as the :meth:`select_subprotocol` method, including a :class:`ServerProtocol` instance as first argument. state: Initial state of the WebSocket connection. max_size: Maximum size of incoming messages in bytes; :obj:`None` disables the limit. logger: Logger for this connection; defaults to ``logging.getLogger("websockets.server")``; see the :doc:`logging guide <../../topics/logging>` for details. """ def __init__( self, *, origins: Sequence[Origin | re.Pattern[str] | None] | None = None, extensions: Sequence[ServerExtensionFactory] | None = None, subprotocols: Sequence[Subprotocol] | None = None, select_subprotocol: ( Callable[ [ServerProtocol, Sequence[Subprotocol]], Subprotocol | None, ] | None ) = None, state: State = CONNECTING, max_size: int | None = 2**20, logger: LoggerLike | None = None, ) -> None: super().__init__( side=SERVER, state=state, max_size=max_size, logger=logger, ) self.origins = origins self.available_extensions = extensions self.available_subprotocols = subprotocols if select_subprotocol is not None: # Bind select_subprotocol then shadow self.select_subprotocol. # Use setattr to work around https://github.com/python/mypy/issues/2427. setattr( self, "select_subprotocol", select_subprotocol.__get__(self, self.__class__), ) def accept(self, request: Request) -> Response: """ Create a handshake response to accept the connection. If the handshake request is valid and the handshake successful, :meth:`accept` returns an HTTP response with status code 101. Else, it returns an HTTP response with another status code. This rejects the connection, like :meth:`reject` would. You must send the handshake response with :meth:`send_response`. You may modify the response before sending it, typically by adding HTTP headers. 
Args: request: WebSocket handshake request received from the client. Returns: WebSocket handshake response or HTTP response to send to the client. """ try: ( accept_header, extensions_header, protocol_header, ) = self.process_request(request) except InvalidOrigin as exc: request._exception = exc self.handshake_exc = exc if self.debug: self.logger.debug("! invalid origin", exc_info=True) return self.reject( http.HTTPStatus.FORBIDDEN, f"Failed to open a WebSocket connection: {exc}.\n", ) except InvalidUpgrade as exc: request._exception = exc self.handshake_exc = exc if self.debug: self.logger.debug("! invalid upgrade", exc_info=True) response = self.reject( http.HTTPStatus.UPGRADE_REQUIRED, ( f"Failed to open a WebSocket connection: {exc}.\n" f"\n" f"You cannot access a WebSocket server directly " f"with a browser. You need a WebSocket client.\n" ), ) response.headers["Upgrade"] = "websocket" return response except InvalidHandshake as exc: request._exception = exc self.handshake_exc = exc if self.debug: self.logger.debug("! invalid handshake", exc_info=True) exc_chain = cast(BaseException, exc) exc_str = f"{exc_chain}" while exc_chain.__cause__ is not None: exc_chain = exc_chain.__cause__ exc_str += f"; {exc_chain}" return self.reject( http.HTTPStatus.BAD_REQUEST, f"Failed to open a WebSocket connection: {exc_str}.\n", ) except Exception as exc: # Handle exceptions raised by user-provided select_subprotocol and # unexpected errors. 
request._exception = exc self.handshake_exc = exc self.logger.error("opening handshake failed", exc_info=True) return self.reject( http.HTTPStatus.INTERNAL_SERVER_ERROR, ( "Failed to open a WebSocket connection.\n" "See server log for more information.\n" ), ) headers = Headers() headers["Date"] = email.utils.formatdate(usegmt=True) headers["Upgrade"] = "websocket" headers["Connection"] = "Upgrade" headers["Sec-WebSocket-Accept"] = accept_header if extensions_header is not None: headers["Sec-WebSocket-Extensions"] = extensions_header if protocol_header is not None: headers["Sec-WebSocket-Protocol"] = protocol_header return Response(101, "Switching Protocols", headers) def process_request( self, request: Request, ) -> tuple[str, str | None, str | None]: """ Check a handshake request and negotiate extensions and subprotocol. This function doesn't verify that the request is an HTTP/1.1 or higher GET request and doesn't check the ``Host`` header. These controls are usually performed earlier in the HTTP request handling code. They're the responsibility of the caller. Args: request: WebSocket handshake request received from the client. Returns: ``Sec-WebSocket-Accept``, ``Sec-WebSocket-Extensions``, and ``Sec-WebSocket-Protocol`` headers for the handshake response. Raises: InvalidHandshake: If the handshake request is invalid; then the server must return 400 Bad Request error. """ headers = request.headers connection: list[ConnectionOption] = sum( [parse_connection(value) for value in headers.get_all("Connection")], [] ) if not any(value.lower() == "upgrade" for value in connection): raise InvalidUpgrade( "Connection", ", ".join(connection) if connection else None ) upgrade: list[UpgradeProtocol] = sum( [parse_upgrade(value) for value in headers.get_all("Upgrade")], [] ) # For compatibility with non-strict implementations, ignore case when # checking the Upgrade header. The RFC always uses "websocket", except # in section 11.2. 
(IANA registration) where it uses "WebSocket". if not (len(upgrade) == 1 and upgrade[0].lower() == "websocket"): raise InvalidUpgrade("Upgrade", ", ".join(upgrade) if upgrade else None) try: key = headers["Sec-WebSocket-Key"] except KeyError: raise InvalidHeader("Sec-WebSocket-Key") from None except MultipleValuesError: raise InvalidHeader("Sec-WebSocket-Key", "multiple values") from None try: raw_key = base64.b64decode(key.encode(), validate=True) except binascii.Error as exc: raise InvalidHeaderValue("Sec-WebSocket-Key", key) from exc if len(raw_key) != 16: raise InvalidHeaderValue("Sec-WebSocket-Key", key) accept_header = accept_key(key) try: version = headers["Sec-WebSocket-Version"] except KeyError: raise InvalidHeader("Sec-WebSocket-Version") from None except MultipleValuesError: raise InvalidHeader("Sec-WebSocket-Version", "multiple values") from None if version != "13": raise InvalidHeaderValue("Sec-WebSocket-Version", version) self.origin = self.process_origin(headers) extensions_header, self.extensions = self.process_extensions(headers) protocol_header = self.subprotocol = self.process_subprotocol(headers) return (accept_header, extensions_header, protocol_header) def process_origin(self, headers: Headers) -> Origin | None: """ Handle the Origin HTTP request header. Args: headers: WebSocket handshake request headers. Returns: origin, if it is acceptable. Raises: InvalidHandshake: If the Origin header is invalid. InvalidOrigin: If the origin isn't acceptable. """ # "The user agent MUST NOT include more than one Origin header field" # per https://datatracker.ietf.org/doc/html/rfc6454#section-7.3. 
try: origin = headers.get("Origin") except MultipleValuesError: raise InvalidHeader("Origin", "multiple values") from None if origin is not None: origin = cast(Origin, origin) if self.origins is not None: for origin_or_regex in self.origins: if origin_or_regex == origin or ( isinstance(origin_or_regex, re.Pattern) and origin is not None and origin_or_regex.fullmatch(origin) is not None ): break else: raise InvalidOrigin(origin) return origin def process_extensions( self, headers: Headers, ) -> tuple[str | None, list[Extension]]: """ Handle the Sec-WebSocket-Extensions HTTP request header. Accept or reject each extension proposed in the client request. Negotiate parameters for accepted extensions. Per :rfc:`6455`, negotiation rules are defined by the specification of each extension. To provide this level of flexibility, for each extension proposed by the client, we check for a match with each extension available in the server configuration. If no match is found, the extension is ignored. If several variants of the same extension are proposed by the client, it may be accepted several times, which won't make sense in general. Extensions must implement their own requirements. For this purpose, the list of previously accepted extensions is provided. This process doesn't allow the server to reorder extensions. It can only select a subset of the extensions proposed by the client. Other requirements, for example related to mandatory extensions or the order of extensions, may be implemented by overriding this method. Args: headers: WebSocket handshake request headers. Returns: ``Sec-WebSocket-Extensions`` HTTP response header and list of accepted extensions. Raises: InvalidHandshake: If the Sec-WebSocket-Extensions header is invalid. 
""" response_header_value: str | None = None extension_headers: list[ExtensionHeader] = [] accepted_extensions: list[Extension] = [] header_values = headers.get_all("Sec-WebSocket-Extensions") if header_values and self.available_extensions: parsed_header_values: list[ExtensionHeader] = sum( [parse_extension(header_value) for header_value in header_values], [] ) for name, request_params in parsed_header_values: for ext_factory in self.available_extensions: # Skip non-matching extensions based on their name. if ext_factory.name != name: continue # Skip non-matching extensions based on their params. try: response_params, extension = ext_factory.process_request_params( request_params, accepted_extensions ) except NegotiationError: continue # Add matching extension to the final list. extension_headers.append((name, response_params)) accepted_extensions.append(extension) # Break out of the loop once we have a match. break # If we didn't break from the loop, no extension in our list # matched what the client sent. The extension is declined. # Serialize extension header. if extension_headers: response_header_value = build_extension(extension_headers) return response_header_value, accepted_extensions def process_subprotocol(self, headers: Headers) -> Subprotocol | None: """ Handle the Sec-WebSocket-Protocol HTTP request header. Args: headers: WebSocket handshake request headers. Returns: Subprotocol, if one was selected; this is also the value of the ``Sec-WebSocket-Protocol`` response header. Raises: InvalidHandshake: If the Sec-WebSocket-Subprotocol header is invalid. """ subprotocols: Sequence[Subprotocol] = sum( [ parse_subprotocol(header_value) for header_value in headers.get_all("Sec-WebSocket-Protocol") ], [], ) return self.select_subprotocol(subprotocols) def select_subprotocol( self, subprotocols: Sequence[Subprotocol], ) -> Subprotocol | None: """ Pick a subprotocol among those offered by the client. 
        If several subprotocols are supported by both the client and the
        server, pick the first one in the list declared by the server.
You may modify the response before sending it, for example by changing HTTP headers. Args: status: HTTP status code. text: HTTP response body; it will be encoded to UTF-8. Returns: HTTP response to send to the client. """ # If status is an int instead of an HTTPStatus, fix it automatically. status = http.HTTPStatus(status) body = text.encode() headers = Headers( [ ("Date", email.utils.formatdate(usegmt=True)), ("Connection", "close"), ("Content-Length", str(len(body))), ("Content-Type", "text/plain; charset=utf-8"), ] ) return Response(status.value, status.phrase, headers, body) def send_response(self, response: Response) -> None: """ Send a handshake response to the client. Args: response: WebSocket handshake response event to send. """ if self.debug: code, phrase = response.status_code, response.reason_phrase self.logger.debug("> HTTP/1.1 %d %s", code, phrase) for key, value in response.headers.raw_items(): self.logger.debug("> %s: %s", key, value) if response.body: self.logger.debug("> [body] (%d bytes)", len(response.body)) self.writes.append(response.serialize()) if response.status_code == 101: assert self.state is CONNECTING self.state = OPEN self.logger.info("connection open") else: self.logger.info( "connection rejected (%d %s)", response.status_code, response.reason_phrase, ) self.send_eof() self.parser = self.discard() next(self.parser) # start coroutine def parse(self) -> Generator[None]: if self.state is CONNECTING: try: request = yield from Request.parse( self.reader.read_line, ) except Exception as exc: self.handshake_exc = InvalidMessage( "did not receive a valid HTTP request" ) self.handshake_exc.__cause__ = exc self.send_eof() self.parser = self.discard() next(self.parser) # start coroutine yield if self.debug: self.logger.debug("< GET %s HTTP/1.1", request.path) for key, value in request.headers.raw_items(): self.logger.debug("< %s: %s", key, value) self.events.append(request) yield from super().parse() class ServerConnection(ServerProtocol): def 
__init__(self, *args: Any, **kwargs: Any) -> None: warnings.warn( # deprecated in 11.0 - 2023-04-02 "ServerConnection was renamed to ServerProtocol", DeprecationWarning, ) super().__init__(*args, **kwargs) lazy_import( globals(), deprecated_aliases={ # deprecated in 14.0 - 2024-11-09 "WebSocketServer": ".legacy.server", "WebSocketServerProtocol": ".legacy.server", "broadcast": ".legacy.server", "serve": ".legacy.server", "unix_serve": ".legacy.server", }, ) websockets-15.0.1/src/websockets/speedups.c000066400000000000000000000132071476212450300207120ustar00rootroot00000000000000/* C implementation of performance sensitive functions. */ #define PY_SSIZE_T_CLEAN #include #include /* uint8_t, uint32_t, uint64_t */ #if __ARM_NEON #include #elif __SSE2__ #include #endif static const Py_ssize_t MASK_LEN = 4; /* Similar to PyBytes_AsStringAndSize, but accepts more types */ static int _PyBytesLike_AsStringAndSize(PyObject *obj, PyObject **tmp, char **buffer, Py_ssize_t *length) { // This supports bytes, bytearrays, and memoryview objects, // which are common data structures for handling byte streams. // If *tmp isn't NULL, the caller gets a new reference. if (PyBytes_Check(obj)) { *tmp = NULL; *buffer = PyBytes_AS_STRING(obj); *length = PyBytes_GET_SIZE(obj); } else if (PyByteArray_Check(obj)) { *tmp = NULL; *buffer = PyByteArray_AS_STRING(obj); *length = PyByteArray_GET_SIZE(obj); } else if (PyMemoryView_Check(obj)) { *tmp = PyMemoryView_GetContiguous(obj, PyBUF_READ, 'C'); if (*tmp == NULL) { return -1; } Py_buffer *mv_buf; mv_buf = PyMemoryView_GET_BUFFER(*tmp); *buffer = mv_buf->buf; *length = mv_buf->len; } else { PyErr_Format( PyExc_TypeError, "expected a bytes-like object, %.200s found", Py_TYPE(obj)->tp_name); return -1; } return 0; } /* C implementation of websockets.utils.apply_mask */ static PyObject * apply_mask(PyObject *self, PyObject *args, PyObject *kwds) { // In order to support various bytes-like types, accept any Python object. 
static char *kwlist[] = {"data", "mask", NULL}; PyObject *input_obj; PyObject *mask_obj; // A pointer to a char * + length will be extracted from the data and mask // arguments, possibly via a Py_buffer. PyObject *input_tmp = NULL; char *input; Py_ssize_t input_len; PyObject *mask_tmp = NULL; char *mask; Py_ssize_t mask_len; // Initialize a PyBytesObject then get a pointer to the underlying char * // in order to avoid an extra memory copy in PyBytes_FromStringAndSize. PyObject *result = NULL; char *output; // Other variables. Py_ssize_t i = 0; // Parse inputs. if (!PyArg_ParseTupleAndKeywords( args, kwds, "OO", kwlist, &input_obj, &mask_obj)) { goto exit; } if (_PyBytesLike_AsStringAndSize(input_obj, &input_tmp, &input, &input_len) == -1) { goto exit; } if (_PyBytesLike_AsStringAndSize(mask_obj, &mask_tmp, &mask, &mask_len) == -1) { goto exit; } if (mask_len != MASK_LEN) { PyErr_SetString(PyExc_ValueError, "mask must contain 4 bytes"); goto exit; } // Create output. result = PyBytes_FromStringAndSize(NULL, input_len); if (result == NULL) { goto exit; } // Since we just created result, we don't need error checks. output = PyBytes_AS_STRING(result); // Perform the masking operation. // Apparently GCC cannot figure out the following optimizations by itself. // We need a new scope for MSVC 2010 (non C99 friendly) { #if __ARM_NEON // With NEON support, XOR by blocks of 16 bytes = 128 bits. Py_ssize_t input_len_128 = input_len & ~15; uint8x16_t mask_128 = vreinterpretq_u8_u32(vdupq_n_u32(*(uint32_t *)mask)); for (; i < input_len_128; i += 16) { uint8x16_t in_128 = vld1q_u8((uint8_t *)(input + i)); uint8x16_t out_128 = veorq_u8(in_128, mask_128); vst1q_u8((uint8_t *)(output + i), out_128); } #elif __SSE2__ // With SSE2 support, XOR by blocks of 16 bytes = 128 bits. // Since we cannot control the 16-bytes alignment of input and output // buffers, we rely on loadu/storeu rather than load/store. 
Py_ssize_t input_len_128 = input_len & ~15; __m128i mask_128 = _mm_set1_epi32(*(uint32_t *)mask); for (; i < input_len_128; i += 16) { __m128i in_128 = _mm_loadu_si128((__m128i *)(input + i)); __m128i out_128 = _mm_xor_si128(in_128, mask_128); _mm_storeu_si128((__m128i *)(output + i), out_128); } #else // Without SSE2 support, XOR by blocks of 8 bytes = 64 bits. // We assume the memory allocator aligns everything on 8 bytes boundaries. Py_ssize_t input_len_64 = input_len & ~7; uint32_t mask_32 = *(uint32_t *)mask; uint64_t mask_64 = ((uint64_t)mask_32 << 32) | (uint64_t)mask_32; for (; i < input_len_64; i += 8) { *(uint64_t *)(output + i) = *(uint64_t *)(input + i) ^ mask_64; } #endif } // XOR the remainder of the input byte by byte. for (; i < input_len; i++) { output[i] = input[i] ^ mask[i & (MASK_LEN - 1)]; } exit: Py_XDECREF(input_tmp); Py_XDECREF(mask_tmp); return result; } static PyMethodDef speedups_methods[] = { { "apply_mask", (PyCFunction)apply_mask, METH_VARARGS | METH_KEYWORDS, "Apply masking to the data of a WebSocket message.", }, {NULL, NULL, 0, NULL}, /* Sentinel */ }; static struct PyModuleDef speedups_module = { PyModuleDef_HEAD_INIT, "websocket.speedups", /* m_name */ "C implementation of performance sensitive functions.", /* m_doc */ -1, /* m_size */ speedups_methods, /* m_methods */ NULL, NULL, NULL, NULL }; PyMODINIT_FUNC PyInit_speedups(void) { return PyModule_Create(&speedups_module); } websockets-15.0.1/src/websockets/speedups.pyi000066400000000000000000000000671476212450300212710ustar00rootroot00000000000000def apply_mask(data: bytes, mask: bytes) -> bytes: ... websockets-15.0.1/src/websockets/streams.py000066400000000000000000000077171476212450300207570ustar00rootroot00000000000000from __future__ import annotations from collections.abc import Generator class StreamReader: """ Generator-based stream reader. This class doesn't support concurrent calls to :meth:`read_line`, :meth:`read_exact`, or :meth:`read_to_eof`. 
Make sure calls are serialized. """ def __init__(self) -> None: self.buffer = bytearray() self.eof = False def read_line(self, m: int) -> Generator[None, None, bytes]: """ Read a LF-terminated line from the stream. This is a generator-based coroutine. The return value includes the LF character. Args: m: Maximum number bytes to read; this is a security limit. Raises: EOFError: If the stream ends without a LF. RuntimeError: If the stream ends in more than ``m`` bytes. """ n = 0 # number of bytes to read p = 0 # number of bytes without a newline while True: n = self.buffer.find(b"\n", p) + 1 if n > 0: break p = len(self.buffer) if p > m: raise RuntimeError(f"read {p} bytes, expected no more than {m} bytes") if self.eof: raise EOFError(f"stream ends after {p} bytes, before end of line") yield if n > m: raise RuntimeError(f"read {n} bytes, expected no more than {m} bytes") r = self.buffer[:n] del self.buffer[:n] return r def read_exact(self, n: int) -> Generator[None, None, bytes]: """ Read a given number of bytes from the stream. This is a generator-based coroutine. Args: n: How many bytes to read. Raises: EOFError: If the stream ends in less than ``n`` bytes. """ assert n >= 0 while len(self.buffer) < n: if self.eof: p = len(self.buffer) raise EOFError(f"stream ends after {p} bytes, expected {n} bytes") yield r = self.buffer[:n] del self.buffer[:n] return r def read_to_eof(self, m: int) -> Generator[None, None, bytes]: """ Read all bytes from the stream. This is a generator-based coroutine. Args: m: Maximum number bytes to read; this is a security limit. Raises: RuntimeError: If the stream ends in more than ``m`` bytes. """ while not self.eof: p = len(self.buffer) if p > m: raise RuntimeError(f"read {p} bytes, expected no more than {m} bytes") yield r = self.buffer[:] del self.buffer[:] return r def at_eof(self) -> Generator[None, None, bool]: """ Tell whether the stream has ended and all data was read. This is a generator-based coroutine. 
""" while True: if self.buffer: return False if self.eof: return True # When all data was read but the stream hasn't ended, we can't # tell if until either feed_data() or feed_eof() is called. yield def feed_data(self, data: bytes) -> None: """ Write data to the stream. :meth:`feed_data` cannot be called after :meth:`feed_eof`. Args: data: Data to write. Raises: EOFError: If the stream has ended. """ if self.eof: raise EOFError("stream ended") self.buffer += data def feed_eof(self) -> None: """ End the stream. :meth:`feed_eof` cannot be called more than once. Raises: EOFError: If the stream has ended. """ if self.eof: raise EOFError("stream ended") self.eof = True def discard(self) -> None: """ Discard all buffered data, but don't end the stream. """ del self.buffer[:] websockets-15.0.1/src/websockets/sync/000077500000000000000000000000001476212450300176675ustar00rootroot00000000000000websockets-15.0.1/src/websockets/sync/__init__.py000066400000000000000000000000001476212450300217660ustar00rootroot00000000000000websockets-15.0.1/src/websockets/sync/client.py000066400000000000000000000541701476212450300215260ustar00rootroot00000000000000from __future__ import annotations import socket import ssl as ssl_module import threading import warnings from collections.abc import Sequence from typing import Any, Callable, Literal, TypeVar, cast from ..client import ClientProtocol from ..datastructures import Headers, HeadersLike from ..exceptions import InvalidProxyMessage, InvalidProxyStatus, ProxyError from ..extensions.base import ClientExtensionFactory from ..extensions.permessage_deflate import enable_client_permessage_deflate from ..headers import build_authorization_basic, build_host, validate_subprotocols from ..http11 import USER_AGENT, Response from ..protocol import CONNECTING, Event from ..streams import StreamReader from ..typing import LoggerLike, Origin, Subprotocol from ..uri import Proxy, WebSocketURI, get_proxy, parse_proxy, parse_uri from .connection import 
Connection from .utils import Deadline __all__ = ["connect", "unix_connect", "ClientConnection"] class ClientConnection(Connection): """ :mod:`threading` implementation of a WebSocket client connection. :class:`ClientConnection` provides :meth:`recv` and :meth:`send` methods for receiving and sending messages. It supports iteration to receive messages:: for message in websocket: process(message) The iterator exits normally when the connection is closed with close code 1000 (OK) or 1001 (going away) or without a close code. It raises a :exc:`~websockets.exceptions.ConnectionClosedError` when the connection is closed with any other code. The ``ping_interval``, ``ping_timeout``, ``close_timeout``, and ``max_queue`` arguments have the same meaning as in :func:`connect`. Args: socket: Socket connected to a WebSocket server. protocol: Sans-I/O connection. """ def __init__( self, socket: socket.socket, protocol: ClientProtocol, *, ping_interval: float | None = 20, ping_timeout: float | None = 20, close_timeout: float | None = 10, max_queue: int | None | tuple[int | None, int | None] = 16, ) -> None: self.protocol: ClientProtocol self.response_rcvd = threading.Event() super().__init__( socket, protocol, ping_interval=ping_interval, ping_timeout=ping_timeout, close_timeout=close_timeout, max_queue=max_queue, ) def handshake( self, additional_headers: HeadersLike | None = None, user_agent_header: str | None = USER_AGENT, timeout: float | None = None, ) -> None: """ Perform the opening handshake. 
""" with self.send_context(expected_state=CONNECTING): self.request = self.protocol.connect() if additional_headers is not None: self.request.headers.update(additional_headers) if user_agent_header is not None: self.request.headers.setdefault("User-Agent", user_agent_header) self.protocol.send_request(self.request) if not self.response_rcvd.wait(timeout): raise TimeoutError("timed out while waiting for handshake response") # self.protocol.handshake_exc is set when the connection is lost before # receiving a response, when the response cannot be parsed, or when the # response fails the handshake. if self.protocol.handshake_exc is not None: raise self.protocol.handshake_exc def process_event(self, event: Event) -> None: """ Process one incoming event. """ # First event - handshake response. if self.response is None: assert isinstance(event, Response) self.response = event self.response_rcvd.set() # Later events - frames. else: super().process_event(event) def recv_events(self) -> None: """ Read incoming data from the socket and process events. """ try: super().recv_events() finally: # If the connection is closed during the handshake, unblock it. 
self.response_rcvd.set() def connect( uri: str, *, # TCP/TLS sock: socket.socket | None = None, ssl: ssl_module.SSLContext | None = None, server_hostname: str | None = None, # WebSocket origin: Origin | None = None, extensions: Sequence[ClientExtensionFactory] | None = None, subprotocols: Sequence[Subprotocol] | None = None, compression: str | None = "deflate", # HTTP additional_headers: HeadersLike | None = None, user_agent_header: str | None = USER_AGENT, proxy: str | Literal[True] | None = True, proxy_ssl: ssl_module.SSLContext | None = None, proxy_server_hostname: str | None = None, # Timeouts open_timeout: float | None = 10, ping_interval: float | None = 20, ping_timeout: float | None = 20, close_timeout: float | None = 10, # Limits max_size: int | None = 2**20, max_queue: int | None | tuple[int | None, int | None] = 16, # Logging logger: LoggerLike | None = None, # Escape hatch for advanced customization create_connection: type[ClientConnection] | None = None, **kwargs: Any, ) -> ClientConnection: """ Connect to the WebSocket server at ``uri``. This function returns a :class:`ClientConnection` instance, which you can use to send and receive messages. :func:`connect` may be used as a context manager:: from websockets.sync.client import connect with connect(...) as websocket: ... The connection is closed automatically when exiting the context. Args: uri: URI of the WebSocket server. sock: Preexisting TCP socket. ``sock`` overrides the host and port from ``uri``. You may call :func:`socket.create_connection` to create a suitable TCP socket. ssl: Configuration for enabling TLS on the connection. server_hostname: Host name for the TLS handshake. ``server_hostname`` overrides the host name from ``uri``. origin: Value of the ``Origin`` header, for servers that require it. extensions: List of supported extensions, in order in which they should be negotiated and run. subprotocols: List of supported subprotocols, in order of decreasing preference. 
compression: The "permessage-deflate" extension is enabled by default. Set ``compression`` to :obj:`None` to disable it. See the :doc:`compression guide <../../topics/compression>` for details. additional_headers (HeadersLike | None): Arbitrary HTTP headers to add to the handshake request. user_agent_header: Value of the ``User-Agent`` request header. It defaults to ``"Python/x.y.z websockets/X.Y"``. Setting it to :obj:`None` removes the header. proxy: If a proxy is configured, it is used by default. Set ``proxy`` to :obj:`None` to disable the proxy or to the address of a proxy to override the system configuration. See the :doc:`proxy docs <../../topics/proxies>` for details. proxy_ssl: Configuration for enabling TLS on the proxy connection. proxy_server_hostname: Host name for the TLS handshake with the proxy. ``proxy_server_hostname`` overrides the host name from ``proxy``. open_timeout: Timeout for opening the connection in seconds. :obj:`None` disables the timeout. ping_interval: Interval between keepalive pings in seconds. :obj:`None` disables keepalive. ping_timeout: Timeout for keepalive pings in seconds. :obj:`None` disables timeouts. close_timeout: Timeout for closing the connection in seconds. :obj:`None` disables the timeout. max_size: Maximum size of incoming messages in bytes. :obj:`None` disables the limit. max_queue: High-water mark of the buffer where frames are received. It defaults to 16 frames. The low-water mark defaults to ``max_queue // 4``. You may pass a ``(high, low)`` tuple to set the high-water and low-water marks. If you want to disable flow control entirely, you may set it to ``None``, although that's a bad idea. logger: Logger for this client. It defaults to ``logging.getLogger("websockets.client")``. See the :doc:`logging guide <../../topics/logging>` for details. create_connection: Factory for the :class:`ClientConnection` managing the connection. Set it to a wrapper or a subclass to customize connection handling. 
Any other keyword arguments are passed to :func:`~socket.create_connection`. Raises: InvalidURI: If ``uri`` isn't a valid WebSocket URI. OSError: If the TCP connection fails. InvalidHandshake: If the opening handshake fails. TimeoutError: If the opening handshake times out. """ # Process parameters # Backwards compatibility: ssl used to be called ssl_context. if ssl is None and "ssl_context" in kwargs: ssl = kwargs.pop("ssl_context") warnings.warn( # deprecated in 13.0 - 2024-08-20 "ssl_context was renamed to ssl", DeprecationWarning, ) ws_uri = parse_uri(uri) if not ws_uri.secure and ssl is not None: raise ValueError("ssl argument is incompatible with a ws:// URI") # Private APIs for unix_connect() unix: bool = kwargs.pop("unix", False) path: str | None = kwargs.pop("path", None) if unix: if path is None and sock is None: raise ValueError("missing path argument") elif path is not None and sock is not None: raise ValueError("path and sock arguments are incompatible") if subprotocols is not None: validate_subprotocols(subprotocols) if compression == "deflate": extensions = enable_client_permessage_deflate(extensions) elif compression is not None: raise ValueError(f"unsupported compression: {compression}") if unix: proxy = None if sock is not None: proxy = None if proxy is True: proxy = get_proxy(ws_uri) # Calculate timeouts on the TCP, TLS, and WebSocket handshakes. # The TCP and TLS timeouts must be set on the socket, then removed # to avoid conflicting with the WebSocket timeout in handshake(). deadline = Deadline(open_timeout) if create_connection is None: create_connection = ClientConnection try: # Connect socket if sock is None: if unix: sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) sock.settimeout(deadline.timeout()) assert path is not None # mypy cannot figure this out sock.connect(path) elif proxy is not None: proxy_parsed = parse_proxy(proxy) if proxy_parsed.scheme[:5] == "socks": # Connect to the server through the proxy. 
sock = connect_socks_proxy( proxy_parsed, ws_uri, deadline, # websockets is consistent with the socket module while # python_socks is consistent across implementations. local_addr=kwargs.pop("source_address", None), ) elif proxy_parsed.scheme[:4] == "http": # Validate the proxy_ssl argument. if proxy_parsed.scheme != "https" and proxy_ssl is not None: raise ValueError( "proxy_ssl argument is incompatible with an http:// proxy" ) # Connect to the server through the proxy. sock = connect_http_proxy( proxy_parsed, ws_uri, deadline, user_agent_header=user_agent_header, ssl=proxy_ssl, server_hostname=proxy_server_hostname, **kwargs, ) else: raise AssertionError("unsupported proxy") else: kwargs.setdefault("timeout", deadline.timeout()) sock = socket.create_connection( (ws_uri.host, ws_uri.port), **kwargs, ) sock.settimeout(None) # Disable Nagle algorithm if not unix: sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, True) # Initialize TLS wrapper and perform TLS handshake if ws_uri.secure: if ssl is None: ssl = ssl_module.create_default_context() if server_hostname is None: server_hostname = ws_uri.host sock.settimeout(deadline.timeout()) if proxy_ssl is None: sock = ssl.wrap_socket(sock, server_hostname=server_hostname) else: sock_2 = SSLSSLSocket(sock, ssl, server_hostname=server_hostname) # Let's pretend that sock is a socket, even though it isn't. 
sock = cast(socket.socket, sock_2) sock.settimeout(None) # Initialize WebSocket protocol protocol = ClientProtocol( ws_uri, origin=origin, extensions=extensions, subprotocols=subprotocols, max_size=max_size, logger=logger, ) # Initialize WebSocket connection connection = create_connection( sock, protocol, ping_interval=ping_interval, ping_timeout=ping_timeout, close_timeout=close_timeout, max_queue=max_queue, ) except Exception: if sock is not None: sock.close() raise try: connection.handshake( additional_headers, user_agent_header, deadline.timeout(), ) except Exception: connection.close_socket() connection.recv_events_thread.join() raise connection.start_keepalive() return connection def unix_connect( path: str | None = None, uri: str | None = None, **kwargs: Any, ) -> ClientConnection: """ Connect to a WebSocket server listening on a Unix socket. This function accepts the same keyword arguments as :func:`connect`. It's only available on Unix. It's mainly useful for debugging servers listening on Unix sockets. Args: path: File system path to the Unix socket. uri: URI of the WebSocket server. ``uri`` defaults to ``ws://localhost/`` or, when a ``ssl`` is provided, to ``wss://localhost/``. """ if uri is None: # Backwards compatibility: ssl used to be called ssl_context. 
if kwargs.get("ssl") is None and kwargs.get("ssl_context") is None: uri = "ws://localhost/" else: uri = "wss://localhost/" return connect(uri=uri, unix=True, path=path, **kwargs) try: from python_socks import ProxyType from python_socks.sync import Proxy as SocksProxy SOCKS_PROXY_TYPES = { "socks5h": ProxyType.SOCKS5, "socks5": ProxyType.SOCKS5, "socks4a": ProxyType.SOCKS4, "socks4": ProxyType.SOCKS4, } SOCKS_PROXY_RDNS = { "socks5h": True, "socks5": False, "socks4a": True, "socks4": False, } def connect_socks_proxy( proxy: Proxy, ws_uri: WebSocketURI, deadline: Deadline, **kwargs: Any, ) -> socket.socket: """Connect via a SOCKS proxy and return the socket.""" socks_proxy = SocksProxy( SOCKS_PROXY_TYPES[proxy.scheme], proxy.host, proxy.port, proxy.username, proxy.password, SOCKS_PROXY_RDNS[proxy.scheme], ) kwargs.setdefault("timeout", deadline.timeout()) # connect() is documented to raise OSError and TimeoutError. # Wrap other exceptions in ProxyError, a subclass of InvalidHandshake. try: return socks_proxy.connect(ws_uri.host, ws_uri.port, **kwargs) except (OSError, TimeoutError, socket.timeout): raise except Exception as exc: raise ProxyError("failed to connect to SOCKS proxy") from exc except ImportError: def connect_socks_proxy( proxy: Proxy, ws_uri: WebSocketURI, deadline: Deadline, **kwargs: Any, ) -> socket.socket: raise ImportError("python-socks is required to use a SOCKS proxy") def prepare_connect_request( proxy: Proxy, ws_uri: WebSocketURI, user_agent_header: str | None = None, ) -> bytes: host = build_host(ws_uri.host, ws_uri.port, ws_uri.secure, always_include_port=True) headers = Headers() headers["Host"] = build_host(ws_uri.host, ws_uri.port, ws_uri.secure) if user_agent_header is not None: headers["User-Agent"] = user_agent_header if proxy.username is not None: assert proxy.password is not None # enforced by parse_proxy() headers["Proxy-Authorization"] = build_authorization_basic( proxy.username, proxy.password ) # We cannot use the Request class 
because it supports only GET requests. return f"CONNECT {host} HTTP/1.1\r\n".encode() + headers.serialize() def read_connect_response(sock: socket.socket, deadline: Deadline) -> Response: reader = StreamReader() parser = Response.parse( reader.read_line, reader.read_exact, reader.read_to_eof, include_body=False, ) try: while True: sock.settimeout(deadline.timeout()) data = sock.recv(4096) if data: reader.feed_data(data) else: reader.feed_eof() next(parser) except StopIteration as exc: assert isinstance(exc.value, Response) # help mypy response = exc.value if 200 <= response.status_code < 300: return response else: raise InvalidProxyStatus(response) except socket.timeout: raise TimeoutError("timed out while connecting to HTTP proxy") except Exception as exc: raise InvalidProxyMessage( "did not receive a valid HTTP response from proxy" ) from exc finally: sock.settimeout(None) def connect_http_proxy( proxy: Proxy, ws_uri: WebSocketURI, deadline: Deadline, *, user_agent_header: str | None = None, ssl: ssl_module.SSLContext | None = None, server_hostname: str | None = None, **kwargs: Any, ) -> socket.socket: # Connect socket kwargs.setdefault("timeout", deadline.timeout()) sock = socket.create_connection((proxy.host, proxy.port), **kwargs) # Initialize TLS wrapper and perform TLS handshake if proxy.scheme == "https": if ssl is None: ssl = ssl_module.create_default_context() if server_hostname is None: server_hostname = proxy.host sock.settimeout(deadline.timeout()) sock = ssl.wrap_socket(sock, server_hostname=server_hostname) sock.settimeout(None) # Send CONNECT request to the proxy and read response. sock.sendall(prepare_connect_request(proxy, ws_uri, user_agent_header)) try: read_connect_response(sock, deadline) except Exception: sock.close() raise return sock T = TypeVar("T") F = TypeVar("F", bound=Callable[..., T]) class SSLSSLSocket: """ Socket-like object providing TLS-in-TLS. Only methods that are used by websockets are implemented. 
""" recv_bufsize = 65536 def __init__( self, sock: socket.socket, ssl_context: ssl_module.SSLContext, server_hostname: str | None = None, ) -> None: self.incoming = ssl_module.MemoryBIO() self.outgoing = ssl_module.MemoryBIO() self.ssl_socket = sock self.ssl_object = ssl_context.wrap_bio( self.incoming, self.outgoing, server_hostname=server_hostname, ) self.run_io(self.ssl_object.do_handshake) def run_io(self, func: Callable[..., T], *args: Any) -> T: while True: want_read = False want_write = False try: result = func(*args) except ssl_module.SSLWantReadError: want_read = True except ssl_module.SSLWantWriteError: # pragma: no cover want_write = True # Write outgoing data in all cases. data = self.outgoing.read() if data: self.ssl_socket.sendall(data) # Read incoming data and retry on SSLWantReadError. if want_read: data = self.ssl_socket.recv(self.recv_bufsize) if data: self.incoming.write(data) else: self.incoming.write_eof() continue # Retry after writing outgoing data on SSLWantWriteError. if want_write: # pragma: no cover continue # Return result if no error happened. return result def recv(self, buflen: int) -> bytes: try: return self.run_io(self.ssl_object.read, buflen) except ssl_module.SSLEOFError: return b"" # always ignore ragged EOFs def send(self, data: bytes) -> int: return self.run_io(self.ssl_object.write, data) def sendall(self, data: bytes) -> None: # adapted from ssl_module.SSLSocket.sendall() count = 0 with memoryview(data) as view, view.cast("B") as byte_view: amount = len(byte_view) while count < amount: count += self.send(byte_view[count:]) # recv_into(), recvfrom(), recvfrom_into(), sendto(), unwrap(), and the # flags argument aren't implemented because websockets doesn't need them. 
def __getattr__(self, name: str) -> Any: return getattr(self.ssl_socket, name) websockets-15.0.1/src/websockets/sync/connection.py000066400000000000000000001211461476212450300224050ustar00rootroot00000000000000from __future__ import annotations import contextlib import logging import random import socket import struct import threading import time import uuid from collections.abc import Iterable, Iterator, Mapping from types import TracebackType from typing import Any, Literal, overload from ..exceptions import ( ConcurrencyError, ConnectionClosed, ConnectionClosedOK, ProtocolError, ) from ..frames import DATA_OPCODES, BytesLike, CloseCode, Frame, Opcode from ..http11 import Request, Response from ..protocol import CLOSED, OPEN, Event, Protocol, State from ..typing import Data, LoggerLike, Subprotocol from .messages import Assembler from .utils import Deadline __all__ = ["Connection"] class Connection: """ :mod:`threading` implementation of a WebSocket connection. :class:`Connection` provides APIs shared between WebSocket servers and clients. You shouldn't use it directly. Instead, use :class:`~websockets.sync.client.ClientConnection` or :class:`~websockets.sync.server.ServerConnection`. """ recv_bufsize = 65536 def __init__( self, socket: socket.socket, protocol: Protocol, *, ping_interval: float | None = 20, ping_timeout: float | None = 20, close_timeout: float | None = 10, max_queue: int | None | tuple[int | None, int | None] = 16, ) -> None: self.socket = socket self.protocol = protocol self.ping_interval = ping_interval self.ping_timeout = ping_timeout self.close_timeout = close_timeout if isinstance(max_queue, int) or max_queue is None: max_queue = (max_queue, None) self.max_queue = max_queue # Inject reference to this instance in the protocol's logger. self.protocol.logger = logging.LoggerAdapter( self.protocol.logger, {"websocket": self}, ) # Copy attributes from the protocol for convenience. 
self.id: uuid.UUID = self.protocol.id """Unique identifier of the connection. Useful in logs.""" self.logger: LoggerLike = self.protocol.logger """Logger for this connection.""" self.debug = self.protocol.debug # HTTP handshake request and response. self.request: Request | None = None """Opening handshake request.""" self.response: Response | None = None """Opening handshake response.""" # Mutex serializing interactions with the protocol. self.protocol_mutex = threading.Lock() # Lock stopping reads when the assembler buffer is full. self.recv_flow_control = threading.Lock() # Assembler turning frames into messages and serializing reads. self.recv_messages = Assembler( *self.max_queue, pause=self.recv_flow_control.acquire, resume=self.recv_flow_control.release, ) # Deadline for the closing handshake. self.close_deadline: Deadline | None = None # Whether we are busy sending a fragmented message. self.send_in_progress = False # Mapping of ping IDs to pong waiters, in chronological order. self.pong_waiters: dict[bytes, tuple[threading.Event, float, bool]] = {} self.latency: float = 0 """ Latency of the connection, in seconds. Latency is defined as the round-trip time of the connection. It is measured by sending a Ping frame and waiting for a matching Pong frame. Before the first measurement, :attr:`latency` is ``0``. By default, websockets enables a :ref:`keepalive ` mechanism that sends Ping frames automatically at regular intervals. You can also send Ping frames and measure latency with :meth:`ping`. """ # Thread that sends keepalive pings. None when ping_interval is None. self.keepalive_thread: threading.Thread | None = None # Exception raised in recv_events, to be chained to ConnectionClosed # in the user thread in order to show why the TCP connection dropped. self.recv_exc: BaseException | None = None # Receiving events from the socket. This thread is marked as daemon to # allow creating a connection in a non-daemon thread and using it in a # daemon thread. 
This mustn't prevent the interpreter from exiting. self.recv_events_thread = threading.Thread( target=self.recv_events, daemon=True, ) # Start recv_events only after all attributes are initialized. self.recv_events_thread.start() # Public attributes @property def local_address(self) -> Any: """ Local address of the connection. For IPv4 connections, this is a ``(host, port)`` tuple. The format of the address depends on the address family. See :meth:`~socket.socket.getsockname`. """ return self.socket.getsockname() @property def remote_address(self) -> Any: """ Remote address of the connection. For IPv4 connections, this is a ``(host, port)`` tuple. The format of the address depends on the address family. See :meth:`~socket.socket.getpeername`. """ return self.socket.getpeername() @property def state(self) -> State: """ State of the WebSocket connection, defined in :rfc:`6455`. This attribute is provided for completeness. Typical applications shouldn't check its value. Instead, they should call :meth:`~recv` or :meth:`send` and handle :exc:`~websockets.exceptions.ConnectionClosed` exceptions. """ return self.protocol.state @property def subprotocol(self) -> Subprotocol | None: """ Subprotocol negotiated during the opening handshake. :obj:`None` if no subprotocol was negotiated. """ return self.protocol.subprotocol @property def close_code(self) -> int | None: """ State of the WebSocket connection, defined in :rfc:`6455`. This attribute is provided for completeness. Typical applications shouldn't check its value. Instead, they should inspect attributes of :exc:`~websockets.exceptions.ConnectionClosed` exceptions. """ return self.protocol.close_code @property def close_reason(self) -> str | None: """ State of the WebSocket connection, defined in :rfc:`6455`. This attribute is provided for completeness. Typical applications shouldn't check its value. Instead, they should inspect attributes of :exc:`~websockets.exceptions.ConnectionClosed` exceptions. 
""" return self.protocol.close_reason # Public methods def __enter__(self) -> Connection: return self def __exit__( self, exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: TracebackType | None, ) -> None: if exc_type is None: self.close() else: self.close(CloseCode.INTERNAL_ERROR) def __iter__(self) -> Iterator[Data]: """ Iterate on incoming messages. The iterator calls :meth:`recv` and yields messages in an infinite loop. It exits when the connection is closed normally. It raises a :exc:`~websockets.exceptions.ConnectionClosedError` exception after a protocol error or a network failure. """ try: while True: yield self.recv() except ConnectionClosedOK: return # This overload structure is required to avoid the error: # "parameter without a default follows parameter with a default" @overload def recv(self, timeout: float | None, decode: Literal[True]) -> str: ... @overload def recv(self, timeout: float | None, decode: Literal[False]) -> bytes: ... @overload def recv(self, timeout: float | None = None, *, decode: Literal[True]) -> str: ... @overload def recv( self, timeout: float | None = None, *, decode: Literal[False] ) -> bytes: ... @overload def recv( self, timeout: float | None = None, decode: bool | None = None ) -> Data: ... def recv(self, timeout: float | None = None, decode: bool | None = None) -> Data: """ Receive the next message. When the connection is closed, :meth:`recv` raises :exc:`~websockets.exceptions.ConnectionClosed`. Specifically, it raises :exc:`~websockets.exceptions.ConnectionClosedOK` after a normal closure and :exc:`~websockets.exceptions.ConnectionClosedError` after a protocol error or a network failure. This is how you detect the end of the message stream. If ``timeout`` is :obj:`None`, block until a message is received. If ``timeout`` is set, wait up to ``timeout`` seconds for a message to be received and return it, else raise :exc:`TimeoutError`. 
If ``timeout`` is ``0`` or negative, check if a message has been received already and return it, else raise :exc:`TimeoutError`. If the message is fragmented, wait until all fragments are received, reassemble them, and return the whole message. Args: timeout: Timeout for receiving a message in seconds. decode: Set this flag to override the default behavior of returning :class:`str` or :class:`bytes`. See below for details. Returns: A string (:class:`str`) for a Text_ frame or a bytestring (:class:`bytes`) for a Binary_ frame. .. _Text: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 .. _Binary: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 You may override this behavior with the ``decode`` argument: * Set ``decode=False`` to disable UTF-8 decoding of Text_ frames and return a bytestring (:class:`bytes`). This improves performance when decoding isn't needed, for example if the message contains JSON and you're using a JSON library that expects a bytestring. * Set ``decode=True`` to force UTF-8 decoding of Binary_ frames and return a string (:class:`str`). This may be useful for servers that send binary frames instead of text frames. Raises: ConnectionClosed: When the connection is closed. ConcurrencyError: If two threads call :meth:`recv` or :meth:`recv_streaming` concurrently. """ try: return self.recv_messages.get(timeout, decode) except EOFError: pass # fallthrough except ConcurrencyError: raise ConcurrencyError( "cannot call recv while another thread " "is already running recv or recv_streaming" ) from None except UnicodeDecodeError as exc: with self.send_context(): self.protocol.fail( CloseCode.INVALID_DATA, f"{exc.reason} at position {exc.start}", ) # fallthrough # Wait for the protocol state to be CLOSED before accessing close_exc. self.recv_events_thread.join() raise self.protocol.close_exc from self.recv_exc @overload def recv_streaming(self, decode: Literal[True]) -> Iterator[str]: ... 
@overload def recv_streaming(self, decode: Literal[False]) -> Iterator[bytes]: ... @overload def recv_streaming(self, decode: bool | None = None) -> Iterator[Data]: ... def recv_streaming(self, decode: bool | None = None) -> Iterator[Data]: """ Receive the next message frame by frame. This method is designed for receiving fragmented messages. It returns an iterator that yields each fragment as it is received. This iterator must be fully consumed. Else, future calls to :meth:`recv` or :meth:`recv_streaming` will raise :exc:`~websockets.exceptions.ConcurrencyError`, making the connection unusable. :meth:`recv_streaming` raises the same exceptions as :meth:`recv`. Args: decode: Set this flag to override the default behavior of returning :class:`str` or :class:`bytes`. See below for details. Returns: An iterator of strings (:class:`str`) for a Text_ frame or bytestrings (:class:`bytes`) for a Binary_ frame. .. _Text: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 .. _Binary: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 You may override this behavior with the ``decode`` argument: * Set ``decode=False`` to disable UTF-8 decoding of Text_ frames and return bytestrings (:class:`bytes`). This may be useful to optimize performance when decoding isn't needed. * Set ``decode=True`` to force UTF-8 decoding of Binary_ frames and return strings (:class:`str`). This is useful for servers that send binary frames instead of text frames. Raises: ConnectionClosed: When the connection is closed. ConcurrencyError: If two threads call :meth:`recv` or :meth:`recv_streaming` concurrently. 
""" try: yield from self.recv_messages.get_iter(decode) return except EOFError: pass # fallthrough except ConcurrencyError: raise ConcurrencyError( "cannot call recv_streaming while another thread " "is already running recv or recv_streaming" ) from None except UnicodeDecodeError as exc: with self.send_context(): self.protocol.fail( CloseCode.INVALID_DATA, f"{exc.reason} at position {exc.start}", ) # fallthrough # Wait for the protocol state to be CLOSED before accessing close_exc. self.recv_events_thread.join() raise self.protocol.close_exc from self.recv_exc def send( self, message: Data | Iterable[Data], text: bool | None = None, ) -> None: """ Send a message. A string (:class:`str`) is sent as a Text_ frame. A bytestring or bytes-like object (:class:`bytes`, :class:`bytearray`, or :class:`memoryview`) is sent as a Binary_ frame. .. _Text: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 .. _Binary: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 You may override this behavior with the ``text`` argument: * Set ``text=True`` to send a bytestring or bytes-like object (:class:`bytes`, :class:`bytearray`, or :class:`memoryview`) as a Text_ frame. This improves performance when the message is already UTF-8 encoded, for example if the message contains JSON and you're using a JSON library that produces a bytestring. * Set ``text=False`` to send a string (:class:`str`) in a Binary_ frame. This may be useful for servers that expect binary frames instead of text frames. :meth:`send` also accepts an iterable of strings, bytestrings, or bytes-like objects to enable fragmentation_. Each item is treated as a message fragment and sent in its own frame. All items must be of the same type, or else :meth:`send` will raise a :exc:`TypeError` and the connection will be closed. .. _fragmentation: https://datatracker.ietf.org/doc/html/rfc6455#section-5.4 :meth:`send` rejects dict-like objects because this is often an error. 
(If you really want to send the keys of a dict-like object as fragments, call its :meth:`~dict.keys` method and pass the result to :meth:`send`.) When the connection is closed, :meth:`send` raises :exc:`~websockets.exceptions.ConnectionClosed`. Specifically, it raises :exc:`~websockets.exceptions.ConnectionClosedOK` after a normal connection closure and :exc:`~websockets.exceptions.ConnectionClosedError` after a protocol error or a network failure. Args: message: Message to send. Raises: ConnectionClosed: When the connection is closed. ConcurrencyError: If the connection is sending a fragmented message. TypeError: If ``message`` doesn't have a supported type. """ # Unfragmented message -- this case must be handled first because # strings and bytes-like objects are iterable. if isinstance(message, str): with self.send_context(): if self.send_in_progress: raise ConcurrencyError( "cannot call send while another thread is already running send" ) if text is False: self.protocol.send_binary(message.encode()) else: self.protocol.send_text(message.encode()) elif isinstance(message, BytesLike): with self.send_context(): if self.send_in_progress: raise ConcurrencyError( "cannot call send while another thread is already running send" ) if text is True: self.protocol.send_text(message) else: self.protocol.send_binary(message) # Catch a common mistake -- passing a dict to send(). elif isinstance(message, Mapping): raise TypeError("data is a dict-like object") # Fragmented message -- regular iterator. elif isinstance(message, Iterable): chunks = iter(message) try: chunk = next(chunks) except StopIteration: return try: # First fragment. 
if isinstance(chunk, str): with self.send_context(): if self.send_in_progress: raise ConcurrencyError( "cannot call send while another thread " "is already running send" ) self.send_in_progress = True if text is False: self.protocol.send_binary(chunk.encode(), fin=False) else: self.protocol.send_text(chunk.encode(), fin=False) encode = True elif isinstance(chunk, BytesLike): with self.send_context(): if self.send_in_progress: raise ConcurrencyError( "cannot call send while another thread " "is already running send" ) self.send_in_progress = True if text is True: self.protocol.send_text(chunk, fin=False) else: self.protocol.send_binary(chunk, fin=False) encode = False else: raise TypeError("data iterable must contain bytes or str") # Other fragments for chunk in chunks: if isinstance(chunk, str) and encode: with self.send_context(): assert self.send_in_progress self.protocol.send_continuation(chunk.encode(), fin=False) elif isinstance(chunk, BytesLike) and not encode: with self.send_context(): assert self.send_in_progress self.protocol.send_continuation(chunk, fin=False) else: raise TypeError("data iterable must contain uniform types") # Final fragment. with self.send_context(): self.protocol.send_continuation(b"", fin=True) self.send_in_progress = False except ConcurrencyError: # We didn't start sending a fragmented message. # The connection is still usable. raise except Exception: # We're half-way through a fragmented message and we can't # complete it. This makes the connection unusable. with self.send_context(): self.protocol.fail( CloseCode.INTERNAL_ERROR, "error in fragmented message", ) raise else: raise TypeError("data must be str, bytes, or iterable") def close(self, code: int = CloseCode.NORMAL_CLOSURE, reason: str = "") -> None: """ Perform the closing handshake. :meth:`close` waits for the other end to complete the handshake, for the TCP connection to terminate, and for all incoming messages to be read with :meth:`recv`. 
:meth:`close` is idempotent: it doesn't do anything once the connection is closed. Args: code: WebSocket close code. reason: WebSocket close reason. """ try: # The context manager takes care of waiting for the TCP connection # to terminate after calling a method that sends a close frame. with self.send_context(): if self.send_in_progress: self.protocol.fail( CloseCode.INTERNAL_ERROR, "close during fragmented message", ) else: self.protocol.send_close(code, reason) except ConnectionClosed: # Ignore ConnectionClosed exceptions raised from send_context(). # They mean that the connection is closed, which was the goal. pass def ping( self, data: Data | None = None, ack_on_close: bool = False, ) -> threading.Event: """ Send a Ping_. .. _Ping: https://datatracker.ietf.org/doc/html/rfc6455#section-5.5.2 A ping may serve as a keepalive or as a check that the remote endpoint received all messages up to this point Args: data: Payload of the ping. A :class:`str` will be encoded to UTF-8. If ``data`` is :obj:`None`, the payload is four random bytes. ack_on_close: when this option is :obj:`True`, the event will also be set when the connection is closed. While this avoids getting stuck waiting for a pong that will never arrive, it requires checking that the state of the connection is still ``OPEN`` to confirm that a pong was received, rather than the connection being closed. Returns: An event that will be set when the corresponding pong is received. You can ignore it if you don't intend to wait. :: pong_event = ws.ping() pong_event.wait() # only if you want to wait for the pong Raises: ConnectionClosed: When the connection is closed. ConcurrencyError: If another ping was sent with the same data and the corresponding pong wasn't received yet. 
""" if isinstance(data, BytesLike): data = bytes(data) elif isinstance(data, str): data = data.encode() elif data is not None: raise TypeError("data must be str or bytes-like") with self.send_context(): # Protect against duplicates if a payload is explicitly set. if data in self.pong_waiters: raise ConcurrencyError("already waiting for a pong with the same data") # Generate a unique random payload otherwise. while data is None or data in self.pong_waiters: data = struct.pack("!I", random.getrandbits(32)) pong_waiter = threading.Event() self.pong_waiters[data] = (pong_waiter, time.monotonic(), ack_on_close) self.protocol.send_ping(data) return pong_waiter def pong(self, data: Data = b"") -> None: """ Send a Pong_. .. _Pong: https://datatracker.ietf.org/doc/html/rfc6455#section-5.5.3 An unsolicited pong may serve as a unidirectional heartbeat. Args: data: Payload of the pong. A :class:`str` will be encoded to UTF-8. Raises: ConnectionClosed: When the connection is closed. """ if isinstance(data, BytesLike): data = bytes(data) elif isinstance(data, str): data = data.encode() else: raise TypeError("data must be str or bytes-like") with self.send_context(): self.protocol.send_pong(data) # Private methods def process_event(self, event: Event) -> None: """ Process one incoming event. This method is overridden in subclasses to handle the handshake. """ assert isinstance(event, Frame) if event.opcode in DATA_OPCODES: self.recv_messages.put(event) if event.opcode is Opcode.PONG: self.acknowledge_pings(bytes(event.data)) def acknowledge_pings(self, data: bytes) -> None: """ Acknowledge pings when receiving a pong. """ with self.protocol_mutex: # Ignore unsolicited pong. if data not in self.pong_waiters: return pong_timestamp = time.monotonic() # Sending a pong for only the most recent ping is legal. # Acknowledge all previous pings too in that case. 
ping_id = None ping_ids = [] for ping_id, ( pong_waiter, ping_timestamp, _ack_on_close, ) in self.pong_waiters.items(): ping_ids.append(ping_id) pong_waiter.set() if ping_id == data: self.latency = pong_timestamp - ping_timestamp break else: raise AssertionError("solicited pong not found in pings") # Remove acknowledged pings from self.pong_waiters. for ping_id in ping_ids: del self.pong_waiters[ping_id] def acknowledge_pending_pings(self) -> None: """ Acknowledge pending pings when the connection is closed. """ assert self.protocol.state is CLOSED for pong_waiter, _ping_timestamp, ack_on_close in self.pong_waiters.values(): if ack_on_close: pong_waiter.set() self.pong_waiters.clear() def keepalive(self) -> None: """ Send a Ping frame and wait for a Pong frame at regular intervals. """ assert self.ping_interval is not None try: while True: # If self.ping_timeout > self.latency > self.ping_interval, # pings will be sent immediately after receiving pongs. # The period will be longer than self.ping_interval. self.recv_events_thread.join(self.ping_interval - self.latency) if not self.recv_events_thread.is_alive(): break try: pong_waiter = self.ping(ack_on_close=True) except ConnectionClosed: break if self.debug: self.logger.debug("% sent keepalive ping") if self.ping_timeout is not None: # if pong_waiter.wait(self.ping_timeout): if self.debug: self.logger.debug("% received keepalive pong") else: if self.debug: self.logger.debug("- timed out waiting for keepalive pong") with self.send_context(): self.protocol.fail( CloseCode.INTERNAL_ERROR, "keepalive ping timeout", ) break except Exception: self.logger.error("keepalive ping failed", exc_info=True) def start_keepalive(self) -> None: """ Run :meth:`keepalive` in a thread, unless keepalive is disabled. """ if self.ping_interval is not None: # This thread is marked as daemon like self.recv_events_thread. 
self.keepalive_thread = threading.Thread( target=self.keepalive, daemon=True, ) self.keepalive_thread.start() def recv_events(self) -> None: """ Read incoming data from the socket and process events. Run this method in a thread as long as the connection is alive. ``recv_events()`` exits immediately when the ``self.socket`` is closed. """ try: while True: try: with self.recv_flow_control: if self.close_deadline is not None: self.socket.settimeout(self.close_deadline.timeout()) data = self.socket.recv(self.recv_bufsize) except Exception as exc: if self.debug: self.logger.debug( "! error while receiving data", exc_info=True, ) # When the closing handshake is initiated by our side, # recv() may block until send_context() closes the socket. # In that case, send_context() already set recv_exc. # Calling set_recv_exc() avoids overwriting it. with self.protocol_mutex: self.set_recv_exc(exc) break if data == b"": break # Acquire the connection lock. with self.protocol_mutex: # Feed incoming data to the protocol. self.protocol.receive_data(data) # This isn't expected to raise an exception. events = self.protocol.events_received() # Write outgoing data to the socket. try: self.send_data() except Exception as exc: if self.debug: self.logger.debug( "! error while sending data", exc_info=True, ) # Similarly to the above, avoid overriding an exception # set by send_context(), in case of a race condition # i.e. send_context() closes the socket after recv() # returns above but before send_data() calls send(). self.set_recv_exc(exc) break if self.protocol.close_expected(): # If the connection is expected to close soon, set the # close deadline based on the close timeout. if self.close_deadline is None: self.close_deadline = Deadline(self.close_timeout) # Unlock conn_mutex before processing events. Else, the # application can't send messages in response to events. # If self.send_data raised an exception, then events are lost. 
# Given that automatic responses write small amounts of data, # this should be uncommon, so we don't handle the edge case. for event in events: # This isn't expected to raise an exception. self.process_event(event) # Breaking out of the while True: ... loop means that we believe # that the socket doesn't work anymore. with self.protocol_mutex: # Feed the end of the data stream to the protocol. self.protocol.receive_eof() # This isn't expected to raise an exception. events = self.protocol.events_received() # There is no error handling because send_data() can only write # the end of the data stream here and it handles errors itself. self.send_data() # This code path is triggered when receiving an HTTP response # without a Content-Length header. This is the only case where # reading until EOF generates an event; all other events have # a known length. Ignore for coverage measurement because tests # are in test_client.py rather than test_connection.py. for event in events: # pragma: no cover # This isn't expected to raise an exception. self.process_event(event) except Exception as exc: # This branch should never run. It's a safety net in case of bugs. self.logger.error("unexpected internal error", exc_info=True) with self.protocol_mutex: self.set_recv_exc(exc) finally: # This isn't expected to raise an exception. self.close_socket() @contextlib.contextmanager def send_context( self, *, expected_state: State = OPEN, # CONNECTING during the opening handshake ) -> Iterator[None]: """ Create a context for writing to the connection from user code. 
On entry, :meth:`send_context` acquires the connection lock and checks that the connection is open; on exit, it writes outgoing data to the socket:: with self.send_context(): self.protocol.send_text(message.encode()) When the connection isn't open on entry, when the connection is expected to close on exit, or when an unexpected error happens, terminating the connection, :meth:`send_context` waits until the connection is closed then raises :exc:`~websockets.exceptions.ConnectionClosed`. """ # Should we wait until the connection is closed? wait_for_close = False # Should we close the socket and raise ConnectionClosed? raise_close_exc = False # What exception should we chain ConnectionClosed to? original_exc: BaseException | None = None # Acquire the protocol lock. with self.protocol_mutex: if self.protocol.state is expected_state: # Let the caller interact with the protocol. try: yield except (ProtocolError, ConcurrencyError): # The protocol state wasn't changed. Exit immediately. raise except Exception as exc: self.logger.error("unexpected internal error", exc_info=True) # This branch should never run. It's a safety net in case of # bugs. Since we don't know what happened, we will close the # connection and raise the exception to the caller. wait_for_close = False raise_close_exc = True original_exc = exc else: # Check if the connection is expected to close soon. if self.protocol.close_expected(): wait_for_close = True # If the connection is expected to close soon, set the # close deadline based on the close timeout. # Since we tested earlier that protocol.state was OPEN # (or CONNECTING) and we didn't release protocol_mutex, # it is certain that self.close_deadline is still None. assert self.close_deadline is None self.close_deadline = Deadline(self.close_timeout) # Write outgoing data to the socket. try: self.send_data() except Exception as exc: if self.debug: self.logger.debug( "! 
error while sending data", exc_info=True, ) # While the only expected exception here is OSError, # other exceptions would be treated identically. wait_for_close = False raise_close_exc = True original_exc = exc else: # self.protocol.state is not expected_state # Minor layering violation: we assume that the connection # will be closing soon if it isn't in the expected state. wait_for_close = True raise_close_exc = True # To avoid a deadlock, release the connection lock by exiting the # context manager before waiting for recv_events() to terminate. # If the connection is expected to close soon and the close timeout # elapses, close the socket to terminate the connection. if wait_for_close: if self.close_deadline is None: timeout = self.close_timeout else: # Thread.join() returns immediately if timeout is negative. timeout = self.close_deadline.timeout(raise_if_elapsed=False) self.recv_events_thread.join(timeout) if self.recv_events_thread.is_alive(): # There's no risk to overwrite another error because # original_exc is never set when wait_for_close is True. assert original_exc is None original_exc = TimeoutError("timed out while closing connection") # Set recv_exc before closing the socket in order to get # proper exception reporting. raise_close_exc = True with self.protocol_mutex: self.set_recv_exc(original_exc) # If an error occurred, close the socket to terminate the connection and # raise an exception. if raise_close_exc: self.close_socket() # Wait for the protocol state to be CLOSED before accessing close_exc. self.recv_events_thread.join() raise self.protocol.close_exc from original_exc def send_data(self) -> None: """ Send outgoing data. This method requires holding protocol_mutex. Raises: OSError: When a socket operations fails. 
""" assert self.protocol_mutex.locked() for data in self.protocol.data_to_send(): if data: if self.close_deadline is not None: self.socket.settimeout(self.close_deadline.timeout()) self.socket.sendall(data) else: try: self.socket.shutdown(socket.SHUT_WR) except OSError: # socket already closed pass def set_recv_exc(self, exc: BaseException | None) -> None: """ Set recv_exc, if not set yet. This method requires holding protocol_mutex. """ assert self.protocol_mutex.locked() if self.recv_exc is None: # pragma: no branch self.recv_exc = exc def close_socket(self) -> None: """ Shutdown and close socket. Close message assembler. Calling close_socket() guarantees that recv_events() terminates. Indeed, recv_events() may block only on socket.recv() or on recv_messages.put(). """ # shutdown() is required to interrupt recv() on Linux. try: self.socket.shutdown(socket.SHUT_RDWR) except OSError: pass # socket is already closed self.socket.close() # Calling protocol.receive_eof() is safe because it's idempotent. # This guarantees that the protocol state becomes CLOSED. with self.protocol_mutex: self.protocol.receive_eof() assert self.protocol.state is CLOSED # Abort recv() with a ConnectionClosed exception. self.recv_messages.close() # Acknowledge pings sent with the ack_on_close option. self.acknowledge_pending_pings() websockets-15.0.1/src/websockets/sync/messages.py000066400000000000000000000304771476212450300220630ustar00rootroot00000000000000from __future__ import annotations import codecs import queue import threading from typing import Any, Callable, Iterable, Iterator, Literal, overload from ..exceptions import ConcurrencyError from ..frames import OP_BINARY, OP_CONT, OP_TEXT, Frame from ..typing import Data from .utils import Deadline __all__ = ["Assembler"] UTF8Decoder = codecs.getincrementaldecoder("utf-8") class Assembler: """ Assemble messages from frames. :class:`Assembler` expects only data frames. 
The stream of frames must respect the protocol; if it doesn't, the behavior is undefined. Args: pause: Called when the buffer of frames goes above the high water mark; should pause reading from the network. resume: Called when the buffer of frames goes below the low water mark; should resume reading from the network. """ def __init__( self, high: int | None = None, low: int | None = None, pause: Callable[[], Any] = lambda: None, resume: Callable[[], Any] = lambda: None, ) -> None: # Serialize reads and writes -- except for reads via synchronization # primitives provided by the threading and queue modules. self.mutex = threading.Lock() # Queue of incoming frames. self.frames: queue.SimpleQueue[Frame | None] = queue.SimpleQueue() # We cannot put a hard limit on the size of the queue because a single # call to Protocol.data_received() could produce thousands of frames, # which must be buffered. Instead, we pause reading when the buffer goes # above the high limit and we resume when it goes under the low limit. if high is not None and low is None: low = high // 4 if high is None and low is not None: high = low * 4 if high is not None and low is not None: if low < 0: raise ValueError("low must be positive or equal to zero") if high < low: raise ValueError("high must be greater than or equal to low") self.high, self.low = high, low self.pause = pause self.resume = resume self.paused = False # This flag prevents concurrent calls to get() by user code. self.get_in_progress = False # This flag marks the end of the connection. self.closed = False def get_next_frame(self, timeout: float | None = None) -> Frame: # Helper to factor out the logic for getting the next frame from the # queue, while handling timeouts and reaching the end of the stream. if self.closed: try: frame = self.frames.get(block=False) except queue.Empty: raise EOFError("stream of frames ended") from None else: try: # Check for a frame that's already received if timeout <= 0. 
# SimpleQueue.get() doesn't support negative timeout values. if timeout is not None and timeout <= 0: frame = self.frames.get(block=False) else: frame = self.frames.get(block=True, timeout=timeout) except queue.Empty: raise TimeoutError(f"timed out in {timeout:.1f}s") from None if frame is None: raise EOFError("stream of frames ended") return frame def reset_queue(self, frames: Iterable[Frame]) -> None: # Helper to put frames back into the queue after they were fetched. # This happens only when the queue is empty. However, by the time # we acquire self.mutex, put() may have added items in the queue. # Therefore, we must handle the case where the queue is not empty. frame: Frame | None with self.mutex: queued = [] try: while True: queued.append(self.frames.get(block=False)) except queue.Empty: pass for frame in frames: self.frames.put(frame) # This loop runs only when a race condition occurs. for frame in queued: # pragma: no cover self.frames.put(frame) # This overload structure is required to avoid the error: # "parameter without a default follows parameter with a default" @overload def get(self, timeout: float | None, decode: Literal[True]) -> str: ... @overload def get(self, timeout: float | None, decode: Literal[False]) -> bytes: ... @overload def get(self, timeout: float | None = None, *, decode: Literal[True]) -> str: ... @overload def get(self, timeout: float | None = None, *, decode: Literal[False]) -> bytes: ... @overload def get(self, timeout: float | None = None, decode: bool | None = None) -> Data: ... def get(self, timeout: float | None = None, decode: bool | None = None) -> Data: """ Read the next message. :meth:`get` returns a single :class:`str` or :class:`bytes`. If the message is fragmented, :meth:`get` waits until the last frame is received, then it reassembles the message and returns it. To receive messages frame by frame, use :meth:`get_iter` instead. 
Args: timeout: If a timeout is provided and elapses before a complete message is received, :meth:`get` raises :exc:`TimeoutError`. decode: :obj:`False` disables UTF-8 decoding of text frames and returns :class:`bytes`. :obj:`True` forces UTF-8 decoding of binary frames and returns :class:`str`. Raises: EOFError: If the stream of frames has ended. UnicodeDecodeError: If a text frame contains invalid UTF-8. ConcurrencyError: If two coroutines run :meth:`get` or :meth:`get_iter` concurrently. TimeoutError: If a timeout is provided and elapses before a complete message is received. """ with self.mutex: if self.get_in_progress: raise ConcurrencyError("get() or get_iter() is already running") self.get_in_progress = True # Locking with get_in_progress prevents concurrent execution # until get() fetches a complete message or times out. try: deadline = Deadline(timeout) # First frame frame = self.get_next_frame(deadline.timeout(raise_if_elapsed=False)) with self.mutex: self.maybe_resume() assert frame.opcode is OP_TEXT or frame.opcode is OP_BINARY if decode is None: decode = frame.opcode is OP_TEXT frames = [frame] # Following frames, for fragmented messages while not frame.fin: try: frame = self.get_next_frame( deadline.timeout(raise_if_elapsed=False) ) except TimeoutError: # Put frames already received back into the queue # so that future calls to get() can return them. self.reset_queue(frames) raise with self.mutex: self.maybe_resume() assert frame.opcode is OP_CONT frames.append(frame) finally: self.get_in_progress = False data = b"".join(frame.data for frame in frames) if decode: return data.decode() else: return data @overload def get_iter(self, decode: Literal[True]) -> Iterator[str]: ... @overload def get_iter(self, decode: Literal[False]) -> Iterator[bytes]: ... @overload def get_iter(self, decode: bool | None = None) -> Iterator[Data]: ... def get_iter(self, decode: bool | None = None) -> Iterator[Data]: """ Stream the next message. 
Iterating the return value of :meth:`get_iter` yields a :class:`str` or :class:`bytes` for each frame in the message. The iterator must be fully consumed before calling :meth:`get_iter` or :meth:`get` again. Else, :exc:`ConcurrencyError` is raised. This method only makes sense for fragmented messages. If messages aren't fragmented, use :meth:`get` instead. Args: decode: :obj:`False` disables UTF-8 decoding of text frames and returns :class:`bytes`. :obj:`True` forces UTF-8 decoding of binary frames and returns :class:`str`. Raises: EOFError: If the stream of frames has ended. UnicodeDecodeError: If a text frame contains invalid UTF-8. ConcurrencyError: If two coroutines run :meth:`get` or :meth:`get_iter` concurrently. """ with self.mutex: if self.get_in_progress: raise ConcurrencyError("get() or get_iter() is already running") self.get_in_progress = True # Locking with get_in_progress prevents concurrent execution # until get_iter() fetches a complete message or times out. # If get_iter() raises an exception e.g. in decoder.decode(), # get_in_progress remains set and the connection becomes unusable. # First frame frame = self.get_next_frame() with self.mutex: self.maybe_resume() assert frame.opcode is OP_TEXT or frame.opcode is OP_BINARY if decode is None: decode = frame.opcode is OP_TEXT if decode: decoder = UTF8Decoder() yield decoder.decode(frame.data, frame.fin) else: yield frame.data # Following frames, for fragmented messages while not frame.fin: frame = self.get_next_frame() with self.mutex: self.maybe_resume() assert frame.opcode is OP_CONT if decode: yield decoder.decode(frame.data, frame.fin) else: yield frame.data self.get_in_progress = False def put(self, frame: Frame) -> None: """ Add ``frame`` to the next message. Raises: EOFError: If the stream of frames has ended. 
""" with self.mutex: if self.closed: raise EOFError("stream of frames ended") self.frames.put(frame) self.maybe_pause() # put() and get/get_iter() call maybe_pause() and maybe_resume() while # holding self.mutex. This guarantees that the calls interleave properly. # Specifically, it prevents a race condition where maybe_resume() would # run before maybe_pause(), leaving the connection incorrectly paused. # A race condition is possible when get/get_iter() call self.frames.get() # without holding self.mutex. However, it's harmless — and even beneficial! # It can only result in popping an item from the queue before maybe_resume() # runs and skipping a pause() - resume() cycle that would otherwise occur. def maybe_pause(self) -> None: """Pause the writer if queue is above the high water mark.""" # Skip if flow control is disabled if self.high is None: return assert self.mutex.locked() # Check for "> high" to support high = 0 if self.frames.qsize() > self.high and not self.paused: self.paused = True self.pause() def maybe_resume(self) -> None: """Resume the writer if queue is below the low water mark.""" # Skip if flow control is disabled if self.low is None: return assert self.mutex.locked() # Check for "<= low" to support low = 0 if self.frames.qsize() <= self.low and self.paused: self.paused = False self.resume() def close(self) -> None: """ End the stream of frames. Calling :meth:`close` concurrently with :meth:`get`, :meth:`get_iter`, or :meth:`put` is safe. They will raise :exc:`EOFError`. """ with self.mutex: if self.closed: return self.closed = True if self.get_in_progress: # Unblock get() or get_iter(). self.frames.put(None) if self.paused: # Unblock recv_events(). 
self.paused = False self.resume() websockets-15.0.1/src/websockets/sync/router.py000066400000000000000000000142231476212450300215630ustar00rootroot00000000000000from __future__ import annotations import http import ssl as ssl_module import urllib.parse from typing import Any, Callable, Literal from werkzeug.exceptions import NotFound from werkzeug.routing import Map, RequestRedirect from ..http11 import Request, Response from .server import Server, ServerConnection, serve __all__ = ["route", "unix_route", "Router"] class Router: """WebSocket router supporting :func:`route`.""" def __init__( self, url_map: Map, server_name: str | None = None, url_scheme: str = "ws", ) -> None: self.url_map = url_map self.server_name = server_name self.url_scheme = url_scheme for rule in self.url_map.iter_rules(): rule.websocket = True def get_server_name(self, connection: ServerConnection, request: Request) -> str: if self.server_name is None: return request.headers["Host"] else: return self.server_name def redirect(self, connection: ServerConnection, url: str) -> Response: response = connection.respond(http.HTTPStatus.FOUND, f"Found at {url}") response.headers["Location"] = url return response def not_found(self, connection: ServerConnection) -> Response: return connection.respond(http.HTTPStatus.NOT_FOUND, "Not Found") def route_request( self, connection: ServerConnection, request: Request ) -> Response | None: """Route incoming request.""" url_map_adapter = self.url_map.bind( server_name=self.get_server_name(connection, request), url_scheme=self.url_scheme, ) try: parsed = urllib.parse.urlparse(request.path) handler, kwargs = url_map_adapter.match( path_info=parsed.path, query_args=parsed.query, ) except RequestRedirect as redirect: return self.redirect(connection, redirect.new_url) except NotFound: return self.not_found(connection) connection.handler, connection.handler_kwargs = handler, kwargs return None def handler(self, connection: ServerConnection) -> None: """Handle a 
connection.""" return connection.handler(connection, **connection.handler_kwargs) def route( url_map: Map, *args: Any, server_name: str | None = None, ssl: ssl_module.SSLContext | Literal[True] | None = None, create_router: type[Router] | None = None, **kwargs: Any, ) -> Server: """ Create a WebSocket server dispatching connections to different handlers. This feature requires the third-party library `werkzeug`_: .. code-block:: console $ pip install werkzeug .. _werkzeug: https://werkzeug.palletsprojects.com/ :func:`route` accepts the same arguments as :func:`~websockets.sync.server.serve`, except as described below. The first argument is a :class:`werkzeug.routing.Map` that maps URL patterns to connection handlers. In addition to the connection, handlers receive parameters captured in the URL as keyword arguments. Here's an example:: from websockets.sync.router import route from werkzeug.routing import Map, Rule def channel_handler(websocket, channel_id): ... url_map = Map([ Rule("/channel/", endpoint=channel_handler), ... ]) with route(url_map, ...) as server: server.serve_forever() Refer to the documentation of :mod:`werkzeug.routing` for details. If you define redirects with ``Rule(..., redirect_to=...)`` in the URL map, when the server runs behind a reverse proxy that modifies the ``Host`` header or terminates TLS, you need additional configuration: * Set ``server_name`` to the name of the server as seen by clients. When not provided, websockets uses the value of the ``Host`` header. * Set ``ssl=True`` to generate ``wss://`` URIs without actually enabling TLS. Under the hood, this binds the URL map with a ``url_scheme`` of ``wss://`` instead of ``ws://``. There is no need to specify ``websocket=True`` in each rule. It is added automatically. Args: url_map: Mapping of URL patterns to connection handlers. server_name: Name of the server as seen by clients. If :obj:`None`, websockets uses the value of the ``Host`` header.
ssl: Configuration for enabling TLS on the connection. Set it to :obj:`True` if a reverse proxy terminates TLS connections. create_router: Factory for the :class:`Router` dispatching requests to handlers. Set it to a wrapper or a subclass to customize routing. """ url_scheme = "ws" if ssl is None else "wss" if ssl is not True and ssl is not None: kwargs["ssl"] = ssl if create_router is None: create_router = Router router = create_router(url_map, server_name, url_scheme) _process_request: ( Callable[ [ServerConnection, Request], Response | None, ] | None ) = kwargs.pop("process_request", None) if _process_request is None: process_request: Callable[ [ServerConnection, Request], Response | None, ] = router.route_request else: def process_request( connection: ServerConnection, request: Request ) -> Response | None: response = _process_request(connection, request) if response is not None: return response return router.route_request(connection, request) return serve(router.handler, *args, process_request=process_request, **kwargs) def unix_route( url_map: Map, path: str | None = None, **kwargs: Any, ) -> Server: """ Create a WebSocket Unix server dispatching connections to different handlers. :func:`unix_route` combines the behaviors of :func:`route` and :func:`~websockets.sync.server.unix_serve`. Args: url_map: Mapping of URL patterns to connection handlers. path: File system path to the Unix socket. 
""" return route(url_map, unix=True, path=path, **kwargs) websockets-15.0.1/src/websockets/sync/server.py000066400000000000000000000654541476212450300215650ustar00rootroot00000000000000from __future__ import annotations import hmac import http import logging import os import re import selectors import socket import ssl as ssl_module import sys import threading import warnings from collections.abc import Iterable, Sequence from types import TracebackType from typing import Any, Callable, Mapping, cast from ..exceptions import InvalidHeader from ..extensions.base import ServerExtensionFactory from ..extensions.permessage_deflate import enable_server_permessage_deflate from ..frames import CloseCode from ..headers import ( build_www_authenticate_basic, parse_authorization_basic, validate_subprotocols, ) from ..http11 import SERVER, Request, Response from ..protocol import CONNECTING, OPEN, Event from ..server import ServerProtocol from ..typing import LoggerLike, Origin, StatusLike, Subprotocol from .connection import Connection from .utils import Deadline __all__ = ["serve", "unix_serve", "ServerConnection", "Server", "basic_auth"] class ServerConnection(Connection): """ :mod:`threading` implementation of a WebSocket server connection. :class:`ServerConnection` provides :meth:`recv` and :meth:`send` methods for receiving and sending messages. It supports iteration to receive messages:: for message in websocket: process(message) The iterator exits normally when the connection is closed with close code 1000 (OK) or 1001 (going away) or without a close code. It raises a :exc:`~websockets.exceptions.ConnectionClosedError` when the connection is closed with any other code. The ``ping_interval``, ``ping_timeout``, ``close_timeout``, and ``max_queue`` arguments have the same meaning as in :func:`serve`. Args: socket: Socket connected to a WebSocket client. protocol: Sans-I/O connection. 
""" def __init__( self, socket: socket.socket, protocol: ServerProtocol, *, ping_interval: float | None = 20, ping_timeout: float | None = 20, close_timeout: float | None = 10, max_queue: int | None | tuple[int | None, int | None] = 16, ) -> None: self.protocol: ServerProtocol self.request_rcvd = threading.Event() super().__init__( socket, protocol, ping_interval=ping_interval, ping_timeout=ping_timeout, close_timeout=close_timeout, max_queue=max_queue, ) self.username: str # see basic_auth() self.handler: Callable[[ServerConnection], None] # see route() self.handler_kwargs: Mapping[str, Any] # see route() def respond(self, status: StatusLike, text: str) -> Response: """ Create a plain text HTTP response. ``process_request`` and ``process_response`` may call this method to return an HTTP response instead of performing the WebSocket opening handshake. You can modify the response before returning it, for example by changing HTTP headers. Args: status: HTTP status code. text: HTTP response body; it will be encoded to UTF-8. Returns: HTTP response to send to the client. """ return self.protocol.reject(status, text) def handshake( self, process_request: ( Callable[ [ServerConnection, Request], Response | None, ] | None ) = None, process_response: ( Callable[ [ServerConnection, Request, Response], Response | None, ] | None ) = None, server_header: str | None = SERVER, timeout: float | None = None, ) -> None: """ Perform the opening handshake. 
""" if not self.request_rcvd.wait(timeout): raise TimeoutError("timed out while waiting for handshake request") if self.request is not None: with self.send_context(expected_state=CONNECTING): response = None if process_request is not None: try: response = process_request(self, self.request) except Exception as exc: self.protocol.handshake_exc = exc response = self.protocol.reject( http.HTTPStatus.INTERNAL_SERVER_ERROR, ( "Failed to open a WebSocket connection.\n" "See server log for more information.\n" ), ) if response is None: self.response = self.protocol.accept(self.request) else: self.response = response if server_header: self.response.headers["Server"] = server_header response = None if process_response is not None: try: response = process_response(self, self.request, self.response) except Exception as exc: self.protocol.handshake_exc = exc response = self.protocol.reject( http.HTTPStatus.INTERNAL_SERVER_ERROR, ( "Failed to open a WebSocket connection.\n" "See server log for more information.\n" ), ) if response is not None: self.response = response self.protocol.send_response(self.response) # self.protocol.handshake_exc is set when the connection is lost before # receiving a request, when the request cannot be parsed, or when the # handshake fails, including when process_request or process_response # raises an exception. # It isn't set when process_request or process_response sends an HTTP # response that rejects the handshake. if self.protocol.handshake_exc is not None: raise self.protocol.handshake_exc def process_event(self, event: Event) -> None: """ Process one incoming event. """ # First event - handshake request. if self.request is None: assert isinstance(event, Request) self.request = event self.request_rcvd.set() # Later events - frames. else: super().process_event(event) def recv_events(self) -> None: """ Read incoming data from the socket and process events. 
""" try: super().recv_events() finally: # If the connection is closed during the handshake, unblock it. self.request_rcvd.set() class Server: """ WebSocket server returned by :func:`serve`. This class mirrors the API of :class:`~socketserver.BaseServer`, notably the :meth:`~socketserver.BaseServer.serve_forever` and :meth:`~socketserver.BaseServer.shutdown` methods, as well as the context manager protocol. Args: socket: Server socket listening for new connections. handler: Handler for one connection. Receives the socket and address returned by :meth:`~socket.socket.accept`. logger: Logger for this server. It defaults to ``logging.getLogger("websockets.server")``. See the :doc:`logging guide <../../topics/logging>` for details. """ def __init__( self, socket: socket.socket, handler: Callable[[socket.socket, Any], None], logger: LoggerLike | None = None, ) -> None: self.socket = socket self.handler = handler if logger is None: logger = logging.getLogger("websockets.server") self.logger = logger if sys.platform != "win32": self.shutdown_watcher, self.shutdown_notifier = os.pipe() def serve_forever(self) -> None: """ See :meth:`socketserver.BaseServer.serve_forever`. This method doesn't return. Calling :meth:`shutdown` from another thread stops the server. Typical use:: with serve(...) as server: server.serve_forever() """ poller = selectors.DefaultSelector() try: poller.register(self.socket, selectors.EVENT_READ) except ValueError: # pragma: no cover # If shutdown() is called before poller.register(), # the socket is closed and poller.register() raises # ValueError: Invalid file descriptor: -1 return if sys.platform != "win32": poller.register(self.shutdown_watcher, selectors.EVENT_READ) while True: poller.select() try: # If the socket is closed, this will raise an exception and exit # the loop. So we don't need to check the return value of select(). 
sock, addr = self.socket.accept() except OSError: break # Since there isn't a mechanism for tracking connections and waiting # for them to terminate, we cannot use daemon threads, or else all # connections would be terminated brutally when closing the server. thread = threading.Thread(target=self.handler, args=(sock, addr)) thread.start() def shutdown(self) -> None: """ See :meth:`socketserver.BaseServer.shutdown`. """ self.socket.close() if sys.platform != "win32": os.write(self.shutdown_notifier, b"x") def fileno(self) -> int: """ See :meth:`socketserver.BaseServer.fileno`. """ return self.socket.fileno() def __enter__(self) -> Server: return self def __exit__( self, exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: TracebackType | None, ) -> None: self.shutdown() def __getattr__(name: str) -> Any: if name == "WebSocketServer": warnings.warn( # deprecated in 13.0 - 2024-08-20 "WebSocketServer was renamed to Server", DeprecationWarning, ) return Server raise AttributeError(f"module {__name__!r} has no attribute {name!r}") def serve( handler: Callable[[ServerConnection], None], host: str | None = None, port: int | None = None, *, # TCP/TLS sock: socket.socket | None = None, ssl: ssl_module.SSLContext | None = None, # WebSocket origins: Sequence[Origin | re.Pattern[str] | None] | None = None, extensions: Sequence[ServerExtensionFactory] | None = None, subprotocols: Sequence[Subprotocol] | None = None, select_subprotocol: ( Callable[ [ServerConnection, Sequence[Subprotocol]], Subprotocol | None, ] | None ) = None, compression: str | None = "deflate", # HTTP process_request: ( Callable[ [ServerConnection, Request], Response | None, ] | None ) = None, process_response: ( Callable[ [ServerConnection, Request, Response], Response | None, ] | None ) = None, server_header: str | None = SERVER, # Timeouts open_timeout: float | None = 10, ping_interval: float | None = 20, ping_timeout: float | None = 20, close_timeout: float | None = 10, # 
Limits max_size: int | None = 2**20, max_queue: int | None | tuple[int | None, int | None] = 16, # Logging logger: LoggerLike | None = None, # Escape hatch for advanced customization create_connection: type[ServerConnection] | None = None, **kwargs: Any, ) -> Server: """ Create a WebSocket server listening on ``host`` and ``port``. Whenever a client connects, the server creates a :class:`ServerConnection`, performs the opening handshake, and delegates to the ``handler``. The handler receives the :class:`ServerConnection` instance, which you can use to send and receive messages. Once the handler completes, either normally or with an exception, the server performs the closing handshake and closes the connection. This function returns a :class:`Server` whose API mirrors :class:`~socketserver.BaseServer`. Treat it as a context manager to ensure that it will be closed and call :meth:`~Server.serve_forever` to serve requests:: from websockets.sync.server import serve def handler(websocket): ... with serve(handler, ...) as server: server.serve_forever() Args: handler: Connection handler. It receives the WebSocket connection, which is a :class:`ServerConnection`, in argument. host: Network interfaces the server binds to. See :func:`~socket.create_server` for details. port: TCP port the server listens on. See :func:`~socket.create_server` for details. sock: Preexisting TCP socket. ``sock`` replaces ``host`` and ``port``. You may call :func:`socket.create_server` to create a suitable TCP socket. ssl: Configuration for enabling TLS on the connection. origins: Acceptable values of the ``Origin`` header, for defending against Cross-Site WebSocket Hijacking attacks. Values can be :class:`str` to test for an exact match or regular expressions compiled by :func:`re.compile` to test against a pattern. Include :obj:`None` in the list if the lack of an origin is acceptable. extensions: List of supported extensions, in order in which they should be negotiated and run. 
subprotocols: List of supported subprotocols, in order of decreasing preference. select_subprotocol: Callback for selecting a subprotocol among those supported by the client and the server. It receives a :class:`ServerConnection` (not a :class:`~websockets.server.ServerProtocol`!) instance and a list of subprotocols offered by the client. Other than the first argument, it has the same behavior as the :meth:`ServerProtocol.select_subprotocol ` method. compression: The "permessage-deflate" extension is enabled by default. Set ``compression`` to :obj:`None` to disable it. See the :doc:`compression guide <../../topics/compression>` for details. process_request: Intercept the request during the opening handshake. Return an HTTP response to force the response. Return :obj:`None` to continue normally. When you force an HTTP 101 Switching Protocols response, the handshake is successful. Else, the connection is aborted. process_response: Intercept the response during the opening handshake. Modify the response or return a new HTTP response to force the response. Return :obj:`None` to continue normally. When you force an HTTP 101 Switching Protocols response, the handshake is successful. Else, the connection is aborted. server_header: Value of the ``Server`` response header. It defaults to ``"Python/x.y.z websockets/X.Y"``. Setting it to :obj:`None` removes the header. open_timeout: Timeout for opening connections in seconds. :obj:`None` disables the timeout. ping_interval: Interval between keepalive pings in seconds. :obj:`None` disables keepalive. ping_timeout: Timeout for keepalive pings in seconds. :obj:`None` disables timeouts. close_timeout: Timeout for closing connections in seconds. :obj:`None` disables the timeout. max_size: Maximum size of incoming messages in bytes. :obj:`None` disables the limit. max_queue: High-water mark of the buffer where frames are received. It defaults to 16 frames. The low-water mark defaults to ``max_queue // 4``. 
You may pass a ``(high, low)`` tuple to set the high-water and low-water marks. If you want to disable flow control entirely, you may set it to ``None``, although that's a bad idea. logger: Logger for this server. It defaults to ``logging.getLogger("websockets.server")``. See the :doc:`logging guide <../../topics/logging>` for details. create_connection: Factory for the :class:`ServerConnection` managing the connection. Set it to a wrapper or a subclass to customize connection handling. Any other keyword arguments are passed to :func:`~socket.create_server`. """ # Process parameters # Backwards compatibility: ssl used to be called ssl_context. if ssl is None and "ssl_context" in kwargs: ssl = kwargs.pop("ssl_context") warnings.warn( # deprecated in 13.0 - 2024-08-20 "ssl_context was renamed to ssl", DeprecationWarning, ) if subprotocols is not None: validate_subprotocols(subprotocols) if compression == "deflate": extensions = enable_server_permessage_deflate(extensions) elif compression is not None: raise ValueError(f"unsupported compression: {compression}") if create_connection is None: create_connection = ServerConnection # Bind socket and listen # Private APIs for unix_connect() unix: bool = kwargs.pop("unix", False) path: str | None = kwargs.pop("path", None) if sock is None: if unix: if path is None: raise ValueError("missing path argument") kwargs.setdefault("family", socket.AF_UNIX) sock = socket.create_server(path, **kwargs) else: sock = socket.create_server((host, port), **kwargs) else: if path is not None: raise ValueError("path and sock arguments are incompatible") # Initialize TLS wrapper if ssl is not None: sock = ssl.wrap_socket( sock, server_side=True, # Delay TLS handshake until after we set a timeout on the socket. do_handshake_on_connect=False, ) # Define request handler def conn_handler(sock: socket.socket, addr: Any) -> None: # Calculate timeouts on the TLS and WebSocket handshakes. 
# The TLS timeout must be set on the socket, then removed # to avoid conflicting with the WebSocket timeout in handshake(). deadline = Deadline(open_timeout) try: # Disable Nagle algorithm if not unix: sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, True) # Perform TLS handshake if ssl is not None: sock.settimeout(deadline.timeout()) # mypy cannot figure this out assert isinstance(sock, ssl_module.SSLSocket) sock.do_handshake() sock.settimeout(None) # Create a closure to give select_subprotocol access to connection. protocol_select_subprotocol: ( Callable[ [ServerProtocol, Sequence[Subprotocol]], Subprotocol | None, ] | None ) = None if select_subprotocol is not None: def protocol_select_subprotocol( protocol: ServerProtocol, subprotocols: Sequence[Subprotocol], ) -> Subprotocol | None: # mypy doesn't know that select_subprotocol is immutable. assert select_subprotocol is not None # Ensure this function is only used in the intended context. assert protocol is connection.protocol return select_subprotocol(connection, subprotocols) # Initialize WebSocket protocol protocol = ServerProtocol( origins=origins, extensions=extensions, subprotocols=subprotocols, select_subprotocol=protocol_select_subprotocol, max_size=max_size, logger=logger, ) # Initialize WebSocket connection assert create_connection is not None # help mypy connection = create_connection( sock, protocol, ping_interval=ping_interval, ping_timeout=ping_timeout, close_timeout=close_timeout, max_queue=max_queue, ) except Exception: sock.close() return try: try: connection.handshake( process_request, process_response, server_header, deadline.timeout(), ) except TimeoutError: connection.close_socket() connection.recv_events_thread.join() return except Exception: connection.logger.error("opening handshake failed", exc_info=True) connection.close_socket() connection.recv_events_thread.join() return assert connection.protocol.state is OPEN try: connection.start_keepalive() handler(connection) except 
Exception: connection.logger.error("connection handler failed", exc_info=True) connection.close(CloseCode.INTERNAL_ERROR) else: connection.close() except Exception: # pragma: no cover # Don't leak sockets on unexpected errors. sock.close() # Initialize server return Server(sock, conn_handler, logger) def unix_serve( handler: Callable[[ServerConnection], None], path: str | None = None, **kwargs: Any, ) -> Server: """ Create a WebSocket server listening on a Unix socket. This function accepts the same keyword arguments as :func:`serve`. It's only available on Unix. It's useful for deploying a server behind a reverse proxy such as nginx. Args: handler: Connection handler. It receives the WebSocket connection, which is a :class:`ServerConnection`, in argument. path: File system path to the Unix socket. """ return serve(handler, unix=True, path=path, **kwargs) def is_credentials(credentials: Any) -> bool: try: username, password = credentials except (TypeError, ValueError): return False else: return isinstance(username, str) and isinstance(password, str) def basic_auth( realm: str = "", credentials: tuple[str, str] | Iterable[tuple[str, str]] | None = None, check_credentials: Callable[[str, str], bool] | None = None, ) -> Callable[[ServerConnection, Request], Response | None]: """ Factory for ``process_request`` to enforce HTTP Basic Authentication. :func:`basic_auth` is designed to integrate with :func:`serve` as follows:: from websockets.sync.server import basic_auth, serve with serve( ..., process_request=basic_auth( realm="my dev server", credentials=("hello", "iloveyou"), ), ): If authentication succeeds, the connection's ``username`` attribute is set. If it fails, the server responds with an HTTP 401 Unauthorized status. One of ``credentials`` or ``check_credentials`` must be provided; not both. Args: realm: Scope of protection. It should contain only ASCII characters because the encoding of non-ASCII characters is undefined. 
Refer to section 2.2 of :rfc:`7235` for details. credentials: Hard coded authorized credentials. It can be a ``(username, password)`` pair or a list of such pairs. check_credentials: Function that verifies credentials. It receives ``username`` and ``password`` arguments and returns whether they're valid. Raises: TypeError: If ``credentials`` or ``check_credentials`` is wrong. ValueError: If ``credentials`` and ``check_credentials`` are both provided or both not provided. """ if (credentials is None) == (check_credentials is None): raise ValueError("provide either credentials or check_credentials") if credentials is not None: if is_credentials(credentials): credentials_list = [cast(tuple[str, str], credentials)] elif isinstance(credentials, Iterable): credentials_list = list(cast(Iterable[tuple[str, str]], credentials)) if not all(is_credentials(item) for item in credentials_list): raise TypeError(f"invalid credentials argument: {credentials}") else: raise TypeError(f"invalid credentials argument: {credentials}") credentials_dict = dict(credentials_list) def check_credentials(username: str, password: str) -> bool: try: expected_password = credentials_dict[username] except KeyError: return False return hmac.compare_digest(expected_password, password) assert check_credentials is not None  # help mypy def process_request( connection: ServerConnection, request: Request, ) -> Response | None: """ Perform HTTP Basic Authentication. If it succeeds, set the connection's ``username`` attribute and return :obj:`None`. If it fails, return an HTTP 401 Unauthorized response. 
""" try: authorization = request.headers["Authorization"] except KeyError: response = connection.respond( http.HTTPStatus.UNAUTHORIZED, "Missing credentials\n", ) response.headers["WWW-Authenticate"] = build_www_authenticate_basic(realm) return response try: username, password = parse_authorization_basic(authorization) except InvalidHeader: response = connection.respond( http.HTTPStatus.UNAUTHORIZED, "Unsupported credentials\n", ) response.headers["WWW-Authenticate"] = build_www_authenticate_basic(realm) return response if not check_credentials(username, password): response = connection.respond( http.HTTPStatus.UNAUTHORIZED, "Invalid credentials\n", ) response.headers["WWW-Authenticate"] = build_www_authenticate_basic(realm) return response connection.username = username return None return process_request websockets-15.0.1/src/websockets/sync/utils.py000066400000000000000000000021231476212450300213770ustar00rootroot00000000000000from __future__ import annotations import time __all__ = ["Deadline"] class Deadline: """ Manage timeouts across multiple steps. Args: timeout: Time available in seconds or :obj:`None` if there is no limit. """ def __init__(self, timeout: float | None) -> None: self.deadline: float | None if timeout is None: self.deadline = None else: self.deadline = time.monotonic() + timeout def timeout(self, *, raise_if_elapsed: bool = True) -> float | None: """ Calculate a timeout from a deadline. Args: raise_if_elapsed: Whether to raise :exc:`TimeoutError` if the deadline lapsed. Raises: TimeoutError: If the deadline lapsed. Returns: Time left in seconds or :obj:`None` if there is no limit. 
""" if self.deadline is None: return None timeout = self.deadline - time.monotonic() if raise_if_elapsed and timeout <= 0: raise TimeoutError("timed out") return timeout websockets-15.0.1/src/websockets/typing.py000066400000000000000000000037511476212450300206050ustar00rootroot00000000000000from __future__ import annotations import http import logging from typing import TYPE_CHECKING, Any, NewType, Optional, Sequence, Union __all__ = [ "Data", "LoggerLike", "StatusLike", "Origin", "Subprotocol", "ExtensionName", "ExtensionParameter", ] # Public types used in the signature of public APIs # Change to str | bytes when dropping Python < 3.10. Data = Union[str, bytes] """Types supported in a WebSocket message: :class:`str` for a Text_ frame, :class:`bytes` for a Binary_. .. _Text: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 .. _Binary : https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 """ # Change to logging.Logger | ... when dropping Python < 3.10. if TYPE_CHECKING: LoggerLike = Union[logging.Logger, logging.LoggerAdapter[Any]] """Types accepted where a :class:`~logging.Logger` is expected.""" else: # remove this branch when dropping support for Python < 3.11 LoggerLike = Union[logging.Logger, logging.LoggerAdapter] """Types accepted where a :class:`~logging.Logger` is expected.""" # Change to http.HTTPStatus | int when dropping Python < 3.10. StatusLike = Union[http.HTTPStatus, int] """ Types accepted where an :class:`~http.HTTPStatus` is expected.""" Origin = NewType("Origin", str) """Value of a ``Origin`` header.""" Subprotocol = NewType("Subprotocol", str) """Subprotocol in a ``Sec-WebSocket-Protocol`` header.""" ExtensionName = NewType("ExtensionName", str) """Name of a WebSocket extension.""" # Change to tuple[str, str | None] when dropping Python < 3.10. 
ExtensionParameter = tuple[str, Optional[str]] """Parameter of a WebSocket extension.""" # Private types ExtensionHeader = tuple[ExtensionName, Sequence[ExtensionParameter]] """Extension in a ``Sec-WebSocket-Extensions`` header.""" ConnectionOption = NewType("ConnectionOption", str) """Connection option in a ``Connection`` header.""" UpgradeProtocol = NewType("UpgradeProtocol", str) """Upgrade protocol in an ``Upgrade`` header.""" websockets-15.0.1/src/websockets/uri.py000066400000000000000000000155121476212450300200700ustar00rootroot00000000000000from __future__ import annotations import dataclasses import urllib.parse import urllib.request from .exceptions import InvalidProxy, InvalidURI __all__ = ["parse_uri", "WebSocketURI"] # All characters from the gen-delims and sub-delims sets in RFC 3987. DELIMS = ":/?#[]@!$&'()*+,;=" @dataclasses.dataclass class WebSocketURI: """ WebSocket URI. Attributes: secure: :obj:`True` for a ``wss`` URI, :obj:`False` for a ``ws`` URI. host: Normalized to lower case. port: Always set even if it's the default. path: May be empty. query: May be empty if the URI doesn't include a query component. username: Available when the URI contains `User Information`_. password: Available when the URI contains `User Information`_. .. _User Information: https://datatracker.ietf.org/doc/html/rfc3986#section-3.2.1 """ secure: bool host: str port: int path: str query: str username: str | None = None password: str | None = None @property def resource_name(self) -> str: if self.path: resource_name = self.path else: resource_name = "/" if self.query: resource_name += "?" + self.query return resource_name @property def user_info(self) -> tuple[str, str] | None: if self.username is None: return None assert self.password is not None return (self.username, self.password) def parse_uri(uri: str) -> WebSocketURI: """ Parse and validate a WebSocket URI. Args: uri: WebSocket URI. Returns: Parsed WebSocket URI. 
Raises: InvalidURI: If ``uri`` isn't a valid WebSocket URI. """ parsed = urllib.parse.urlparse(uri) if parsed.scheme not in ["ws", "wss"]: raise InvalidURI(uri, "scheme isn't ws or wss") if parsed.hostname is None: raise InvalidURI(uri, "hostname isn't provided") if parsed.fragment != "": raise InvalidURI(uri, "fragment identifier is meaningless") secure = parsed.scheme == "wss" host = parsed.hostname port = parsed.port or (443 if secure else 80) path = parsed.path query = parsed.query username = parsed.username password = parsed.password # urllib.parse.urlparse accepts URLs with a username but without a # password. This doesn't make sense for HTTP Basic Auth credentials. if username is not None and password is None: raise InvalidURI(uri, "username provided without password") try: uri.encode("ascii") except UnicodeEncodeError: # Input contains non-ASCII characters. # It must be an IRI. Convert it to a URI. host = host.encode("idna").decode() path = urllib.parse.quote(path, safe=DELIMS) query = urllib.parse.quote(query, safe=DELIMS) if username is not None: assert password is not None username = urllib.parse.quote(username, safe=DELIMS) password = urllib.parse.quote(password, safe=DELIMS) return WebSocketURI(secure, host, port, path, query, username, password) @dataclasses.dataclass class Proxy: """ Proxy. Attributes: scheme: ``"socks5h"``, ``"socks5"``, ``"socks4a"``, ``"socks4"``, ``"https"``, or ``"http"``. host: Normalized to lower case. port: Always set even if it's the default. username: Available when the proxy address contains `User Information`_. password: Available when the proxy address contains `User Information`_. .. 
_User Information: https://datatracker.ietf.org/doc/html/rfc3986#section-3.2.1 """ scheme: str host: str port: int username: str | None = None password: str | None = None @property def user_info(self) -> tuple[str, str] | None: if self.username is None: return None assert self.password is not None return (self.username, self.password) def parse_proxy(proxy: str) -> Proxy: """ Parse and validate a proxy. Args: proxy: proxy. Returns: Parsed proxy. Raises: InvalidProxy: If ``proxy`` isn't a valid proxy. """ parsed = urllib.parse.urlparse(proxy) if parsed.scheme not in ["socks5h", "socks5", "socks4a", "socks4", "https", "http"]: raise InvalidProxy(proxy, f"scheme {parsed.scheme} isn't supported") if parsed.hostname is None: raise InvalidProxy(proxy, "hostname isn't provided") if parsed.path not in ["", "/"]: raise InvalidProxy(proxy, "path is meaningless") if parsed.query != "": raise InvalidProxy(proxy, "query is meaningless") if parsed.fragment != "": raise InvalidProxy(proxy, "fragment is meaningless") scheme = parsed.scheme host = parsed.hostname port = parsed.port or (443 if parsed.scheme == "https" else 80) username = parsed.username password = parsed.password # urllib.parse.urlparse accepts URLs with a username but without a # password. This doesn't make sense for HTTP Basic Auth credentials. if username is not None and password is None: raise InvalidProxy(proxy, "username provided without password") try: proxy.encode("ascii") except UnicodeEncodeError: # Input contains non-ASCII characters. # It must be an IRI. Convert it to a URI. host = host.encode("idna").decode() if username is not None: assert password is not None username = urllib.parse.quote(username, safe=DELIMS) password = urllib.parse.quote(password, safe=DELIMS) return Proxy(scheme, host, port, username, password) def get_proxy(uri: WebSocketURI) -> str | None: """ Return the proxy to use for connecting to the given WebSocket URI, if any. 
""" if urllib.request.proxy_bypass(f"{uri.host}:{uri.port}"): return None # According to the _Proxy Usage_ section of RFC 6455, use a SOCKS5 proxy if # available, else favor the proxy for HTTPS connections over the proxy for # HTTP connections. # The priority of a proxy for WebSocket connections is unspecified. We give # it the highest priority. This makes it easy to configure a specific proxy # for websockets. # getproxies() may return SOCKS proxies as {"socks": "http://host:port"} or # as {"https": "socks5h://host:port"} depending on whether they're declared # in the operating system or in environment variables. proxies = urllib.request.getproxies() if uri.secure: schemes = ["wss", "socks", "https"] else: schemes = ["ws", "socks", "https", "http"] for scheme in schemes: proxy = proxies.get(scheme) if proxy is not None: if scheme == "socks" and proxy.startswith("http://"): proxy = "socks5h://" + proxy[7:] return proxy else: return None websockets-15.0.1/src/websockets/utils.py000066400000000000000000000021761476212450300204330ustar00rootroot00000000000000from __future__ import annotations import base64 import hashlib import secrets import sys __all__ = ["accept_key", "apply_mask"] GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" def generate_key() -> str: """ Generate a random key for the Sec-WebSocket-Key header. """ key = secrets.token_bytes(16) return base64.b64encode(key).decode() def accept_key(key: str) -> str: """ Compute the value of the Sec-WebSocket-Accept header. Args: key: Value of the Sec-WebSocket-Key header. """ sha1 = hashlib.sha1((key + GUID).encode()).digest() return base64.b64encode(sha1).decode() def apply_mask(data: bytes, mask: bytes) -> bytes: """ Apply masking to the data of a WebSocket message. Args: data: Data to mask. mask: 4-bytes mask. 
""" if len(mask) != 4: raise ValueError("mask must contain 4 bytes") data_int = int.from_bytes(data, sys.byteorder) mask_repeated = mask * (len(data) // 4) + mask[: len(data) % 4] mask_int = int.from_bytes(mask_repeated, sys.byteorder) return (data_int ^ mask_int).to_bytes(len(data), sys.byteorder) websockets-15.0.1/src/websockets/version.py000066400000000000000000000062041476212450300207540ustar00rootroot00000000000000from __future__ import annotations import importlib.metadata __all__ = ["tag", "version", "commit"] # ========= =========== =================== # release development # ========= =========== =================== # tag X.Y X.Y (upcoming) # version X.Y X.Y.dev1+g5678cde # commit X.Y 5678cde # ========= =========== =================== # When tagging a release, set `released = True`. # After tagging a release, set `released = False` and increment `tag`. released = True tag = version = commit = "15.0.1" if not released: # pragma: no cover import pathlib import re import subprocess def get_version(tag: str) -> str: # Since setup.py executes the contents of src/websockets/version.py, # __file__ can point to either of these two files. file_path = pathlib.Path(__file__) root_dir = file_path.parents[0 if file_path.name == "setup.py" else 2] # Read version from package metadata if it is installed. try: version = importlib.metadata.version("websockets") except ImportError: pass else: # Check that this file belongs to the installed package. files = importlib.metadata.files("websockets") if files: version_files = [f for f in files if f.name == file_path.name] if version_files: version_file = version_files[0] if version_file.locate() == file_path: return version # Read version from git if available. try: description = subprocess.run( ["git", "describe", "--dirty", "--tags", "--long"], capture_output=True, cwd=root_dir, timeout=1, check=True, text=True, ).stdout.strip() # subprocess.run raises FileNotFoundError if git isn't on $PATH. 
except ( FileNotFoundError, subprocess.CalledProcessError, subprocess.TimeoutExpired, ): pass else: description_re = r"[0-9.]+-([0-9]+)-(g[0-9a-f]{7,}(?:-dirty)?)" match = re.fullmatch(description_re, description) if match is None: raise ValueError(f"Unexpected git description: {description}") distance, remainder = match.groups() remainder = remainder.replace("-", ".") # required by PEP 440 return f"{tag}.dev{distance}+{remainder}" # Avoid crashing if the development version cannot be determined. return f"{tag}.dev0+gunknown" version = get_version(tag) def get_commit(tag: str, version: str) -> str: # Extract commit from version, falling back to tag if not available. version_re = r"[0-9.]+\.dev[0-9]+\+g([0-9a-f]{7,}|unknown)(?:\.dirty)?" match = re.fullmatch(version_re, version) if match is None: raise ValueError(f"Unexpected version: {version}") (commit,) = match.groups() return tag if commit == "unknown" else commit commit = get_commit(tag, version) websockets-15.0.1/tests/000077500000000000000000000000001476212450300151155ustar00rootroot00000000000000websockets-15.0.1/tests/__init__.py000066400000000000000000000005461476212450300172330ustar00rootroot00000000000000import logging import os format = "%(asctime)s %(levelname)s %(name)s %(message)s" if bool(os.environ.get("WEBSOCKETS_DEBUG")): # pragma: no cover # Display every frame sent or received in debug mode. level = logging.DEBUG else: # Hide stack traces of exceptions. 
level = logging.CRITICAL logging.basicConfig(format=format, level=level) websockets-15.0.1/tests/asyncio/000077500000000000000000000000001476212450300165625ustar00rootroot00000000000000websockets-15.0.1/tests/asyncio/__init__.py000066400000000000000000000000001476212450300206610ustar00rootroot00000000000000websockets-15.0.1/tests/asyncio/connection.py000066400000000000000000000063731476212450300213040ustar00rootroot00000000000000import asyncio import contextlib from websockets.asyncio.connection import Connection class InterceptingConnection(Connection): """ Connection subclass that can intercept outgoing packets. By interfacing with this connection, we simulate network conditions affecting what the component being tested receives during a test. """ def connection_made(self, transport): super().connection_made(InterceptingTransport(transport)) @contextlib.contextmanager def delay_frames_sent(self, delay): """ Add a delay before sending frames. This can result in out-of-order writes, which is unrealistic. """ assert self.transport.delay_write is None self.transport.delay_write = delay try: yield finally: self.transport.delay_write = None @contextlib.contextmanager def delay_eof_sent(self, delay): """ Add a delay before sending EOF. This can result in out-of-order writes, which is unrealistic. """ assert self.transport.delay_write_eof is None self.transport.delay_write_eof = delay try: yield finally: self.transport.delay_write_eof = None @contextlib.contextmanager def drop_frames_sent(self): """ Prevent frames from being sent. Since TCP is reliable, sending frames or EOF afterwards is unrealistic. """ assert not self.transport.drop_write self.transport.drop_write = True try: yield finally: self.transport.drop_write = False @contextlib.contextmanager def drop_eof_sent(self): """ Prevent EOF from being sent. Since TCP is reliable, sending frames or EOF afterwards is unrealistic. 
""" assert not self.transport.drop_write_eof self.transport.drop_write_eof = True try: yield finally: self.transport.drop_write_eof = False class InterceptingTransport: """ Transport wrapper that intercepts calls to ``write()`` and ``write_eof()``. This is coupled to the implementation, which relies on these two methods. Since ``write()`` and ``write_eof()`` are not coroutines, this effect is achieved by scheduling writes at a later time, after the methods return. This can easily result in out-of-order writes, which is unrealistic. """ def __init__(self, transport): self.loop = asyncio.get_running_loop() self.transport = transport self.delay_write = None self.delay_write_eof = None self.drop_write = False self.drop_write_eof = False def __getattr__(self, name): return getattr(self.transport, name) def write(self, data): if not self.drop_write: if self.delay_write is not None: self.loop.call_later(self.delay_write, self.transport.write, data) else: self.transport.write(data) def write_eof(self): if not self.drop_write_eof: if self.delay_write_eof is not None: self.loop.call_later(self.delay_write_eof, self.transport.write_eof) else: self.transport.write_eof() websockets-15.0.1/tests/asyncio/server.py000066400000000000000000000024531476212450300204460ustar00rootroot00000000000000import asyncio import socket import urllib.parse def get_host_port(server): for sock in server.sockets: if sock.family == socket.AF_INET: # pragma: no branch return sock.getsockname() raise AssertionError("expected at least one IPv4 socket") def get_uri(server, secure=None): if secure is None: secure = server.server._ssl_context is not None # hack protocol = "wss" if secure else "ws" host, port = get_host_port(server) return f"{protocol}://{host}:{port}" async def handler(ws): path = urllib.parse.urlparse(ws.request.path).path if path == "/": # The default path is an eval shell. 
async for expr in ws: value = eval(expr) await ws.send(str(value)) elif path == "/crash": raise RuntimeError elif path == "/no-op": pass elif path == "/delay": delay = float(await ws.recv()) await ws.close() await asyncio.sleep(delay) else: raise AssertionError(f"unexpected path: {path}") # This shortcut avoids repeating serve(handler, "localhost", 0) for every test. args = handler, "localhost", 0 class EvalShellMixin: async def assertEval(self, client, expr, value): await client.send(expr) self.assertEqual(await client.recv(), value) websockets-15.0.1/tests/asyncio/test_client.py000066400000000000000000001261351476212450300214610ustar00rootroot00000000000000import asyncio import contextlib import http import logging import os import socket import ssl import sys import unittest from unittest.mock import patch from websockets.asyncio.client import * from websockets.asyncio.compatibility import TimeoutError from websockets.asyncio.server import serve, unix_serve from websockets.client import backoff from websockets.exceptions import ( InvalidHandshake, InvalidMessage, InvalidProxy, InvalidProxyMessage, InvalidStatus, InvalidURI, ProxyError, SecurityError, ) from websockets.extensions.permessage_deflate import PerMessageDeflate from ..proxy import ProxyMixin from ..utils import CLIENT_CONTEXT, MS, SERVER_CONTEXT, temp_unix_socket_path from .server import args, get_host_port, get_uri, handler # Decorate tests that need it with @short_backoff_delay() instead of using it as # a context manager when dropping support for Python < 3.10. @contextlib.asynccontextmanager async def short_backoff_delay(): defaults = backoff.__defaults__ backoff.__defaults__ = ( defaults[0] * MS, defaults[1] * MS, defaults[2] * MS, defaults[3], ) try: yield finally: backoff.__defaults__ = defaults # Decorate tests that need it with @few_redirects() instead of using it as a # context manager when dropping support for Python < 3.10. 
@contextlib.asynccontextmanager async def few_redirects(): from websockets.asyncio import client max_redirects = client.MAX_REDIRECTS client.MAX_REDIRECTS = 2 try: yield finally: client.MAX_REDIRECTS = max_redirects class ClientTests(unittest.IsolatedAsyncioTestCase): async def test_connection(self): """Client connects to server.""" async with serve(*args) as server: async with connect(get_uri(server)) as client: self.assertEqual(client.protocol.state.name, "OPEN") async def test_explicit_host_port(self): """Client connects using an explicit host / port.""" async with serve(*args) as server: host, port = get_host_port(server) async with connect("ws://overridden/", host=host, port=port) as client: self.assertEqual(client.protocol.state.name, "OPEN") async def test_existing_socket(self): """Client connects using a pre-existing socket.""" async with serve(*args) as server: with socket.create_connection(get_host_port(server)) as sock: # Use a non-existing domain to ensure we connect to sock. 
async with connect("ws://invalid/", sock=sock) as client: self.assertEqual(client.protocol.state.name, "OPEN") async def test_compression_is_enabled(self): """Client enables compression by default.""" async with serve(*args) as server: async with connect(get_uri(server)) as client: self.assertEqual( [type(ext) for ext in client.protocol.extensions], [PerMessageDeflate], ) async def test_disable_compression(self): """Client disables compression.""" async with serve(*args) as server: async with connect(get_uri(server), compression=None) as client: self.assertEqual(client.protocol.extensions, []) async def test_additional_headers(self): """Client can set additional headers with additional_headers.""" async with serve(*args) as server: async with connect( get_uri(server), additional_headers={"Authorization": "Bearer ..."} ) as client: self.assertEqual(client.request.headers["Authorization"], "Bearer ...") async def test_override_user_agent(self): """Client can override User-Agent header with user_agent_header.""" async with serve(*args) as server: async with connect(get_uri(server), user_agent_header="Smith") as client: self.assertEqual(client.request.headers["User-Agent"], "Smith") async def test_remove_user_agent(self): """Client can remove User-Agent header with user_agent_header.""" async with serve(*args) as server: async with connect(get_uri(server), user_agent_header=None) as client: self.assertNotIn("User-Agent", client.request.headers) async def test_legacy_user_agent(self): """Client can override User-Agent header with additional_headers.""" async with serve(*args) as server: async with connect( get_uri(server), additional_headers={"User-Agent": "Smith"} ) as client: self.assertEqual(client.request.headers["User-Agent"], "Smith") async def test_keepalive_is_enabled(self): """Client enables keepalive and measures latency by default.""" async with serve(*args) as server: async with connect(get_uri(server), ping_interval=MS) as client: 
self.assertEqual(client.latency, 0) await asyncio.sleep(2 * MS) self.assertGreater(client.latency, 0) async def test_disable_keepalive(self): """Client disables keepalive.""" async with serve(*args) as server: async with connect(get_uri(server), ping_interval=None) as client: await asyncio.sleep(2 * MS) self.assertEqual(client.latency, 0) async def test_logger(self): """Client accepts a logger argument.""" logger = logging.getLogger("test") async with serve(*args) as server: async with connect(get_uri(server), logger=logger) as client: self.assertEqual(client.logger.name, logger.name) async def test_custom_connection_factory(self): """Client runs ClientConnection factory provided in create_connection.""" def create_connection(*args, **kwargs): client = ClientConnection(*args, **kwargs) client.create_connection_ran = True return client async with serve(*args) as server: async with connect( get_uri(server), create_connection=create_connection ) as client: self.assertTrue(client.create_connection_ran) async def test_reconnect(self): """Client reconnects to server.""" iterations = 0 successful = 0 async def process_request(connection, request): nonlocal iterations iterations += 1 # Retriable errors if iterations == 1: await asyncio.sleep(3 * MS) elif iterations == 2: connection.transport.close() elif iterations == 3: return connection.respond(http.HTTPStatus.SERVICE_UNAVAILABLE, "🚒") # Fatal error elif iterations == 6: return connection.respond(http.HTTPStatus.PAYMENT_REQUIRED, "💸") async with serve(*args, process_request=process_request) as server: with self.assertRaises(InvalidStatus) as raised: async with short_backoff_delay(): async for client in connect(get_uri(server), open_timeout=3 * MS): self.assertEqual(client.protocol.state.name, "OPEN") successful += 1 self.assertEqual( str(raised.exception), "server rejected WebSocket connection: HTTP 402", ) self.assertEqual(iterations, 6) self.assertEqual(successful, 2) async def 
test_reconnect_with_custom_process_exception(self): """Client runs process_exception to tell if errors are retryable or fatal.""" iteration = 0 def process_request(connection, request): nonlocal iteration iteration += 1 if iteration == 1: return connection.respond(http.HTTPStatus.SERVICE_UNAVAILABLE, "🚒") return connection.respond(http.HTTPStatus.IM_A_TEAPOT, "🫖") def process_exception(exc): if isinstance(exc, InvalidStatus): if 500 <= exc.response.status_code < 600: return None if exc.response.status_code == 418: return Exception("🫖 💔 ☕️") self.fail("unexpected exception") async with serve(*args, process_request=process_request) as server: with self.assertRaises(Exception) as raised: async with short_backoff_delay(): async for _ in connect( get_uri(server), process_exception=process_exception ): self.fail("did not raise") self.assertEqual(iteration, 2) self.assertEqual( str(raised.exception), "🫖 💔 ☕️", ) async def test_reconnect_with_custom_process_exception_raising_exception(self): """Client supports raising an exception in process_exception.""" def process_request(connection, request): return connection.respond(http.HTTPStatus.IM_A_TEAPOT, "🫖") def process_exception(exc): if isinstance(exc, InvalidStatus) and exc.response.status_code == 418: raise Exception("🫖 💔 ☕️") self.fail("unexpected exception") async with serve(*args, process_request=process_request) as server: with self.assertRaises(Exception) as raised: async with short_backoff_delay(): async for _ in connect( get_uri(server), process_exception=process_exception ): self.fail("did not raise") self.assertEqual( str(raised.exception), "🫖 💔 ☕️", ) async def test_redirect(self): """Client follows redirect.""" def redirect(connection, request): if request.path == "/redirect": response = connection.respond(http.HTTPStatus.FOUND, "") response.headers["Location"] = "/" return response async with serve(*args, process_request=redirect) as server: async with connect(get_uri(server) + "/redirect") as client: 
self.assertEqual(client.protocol.uri.path, "/") async def test_cross_origin_redirect(self): """Client follows redirect to a secure URI on a different origin.""" def redirect(connection, request): response = connection.respond(http.HTTPStatus.FOUND, "") response.headers["Location"] = get_uri(other_server) return response async with serve(*args, process_request=redirect) as server: async with serve(*args) as other_server: async with connect(get_uri(server)): self.assertFalse(server.connections) self.assertTrue(other_server.connections) async def test_redirect_limit(self): """Client stops following redirects after limit is reached.""" def redirect(connection, request): response = connection.respond(http.HTTPStatus.FOUND, "") response.headers["Location"] = request.path return response async with serve(*args, process_request=redirect) as server: async with few_redirects(): with self.assertRaises(SecurityError) as raised: async with connect(get_uri(server)): self.fail("did not raise") self.assertEqual( str(raised.exception), "more than 2 redirects", ) async def test_redirect_with_explicit_host_port(self): """Client follows redirect with an explicit host / port.""" def redirect(connection, request): if request.path == "/redirect": response = connection.respond(http.HTTPStatus.FOUND, "") response.headers["Location"] = "/" return response async with serve(*args, process_request=redirect) as server: host, port = get_host_port(server) async with connect( "ws://overridden/redirect", host=host, port=port ) as client: self.assertEqual(client.protocol.uri.path, "/") async def test_cross_origin_redirect_with_explicit_host_port(self): """Client doesn't follow cross-origin redirect with an explicit host / port.""" def redirect(connection, request): response = connection.respond(http.HTTPStatus.FOUND, "") response.headers["Location"] = "ws://other/" return response async with serve(*args, process_request=redirect) as server: host, port = get_host_port(server) with 
self.assertRaises(ValueError) as raised: async with connect("ws://overridden/", host=host, port=port): self.fail("did not raise") self.assertEqual( str(raised.exception), "cannot follow cross-origin redirect to ws://other/ " "with an explicit host or port", ) async def test_redirect_with_existing_socket(self): """Client doesn't follow redirect when using a pre-existing socket.""" def redirect(connection, request): response = connection.respond(http.HTTPStatus.FOUND, "") response.headers["Location"] = "/" return response async with serve(*args, process_request=redirect) as server: with socket.create_connection(get_host_port(server)) as sock: with self.assertRaises(ValueError) as raised: # Use a non-existing domain to ensure we connect to sock. async with connect("ws://invalid/redirect", sock=sock): self.fail("did not raise") self.assertEqual( str(raised.exception), "cannot follow redirect to ws://invalid/ with a preexisting socket", ) async def test_invalid_uri(self): """Client receives an invalid URI.""" with self.assertRaises(InvalidURI): async with connect("http://localhost"): # invalid scheme self.fail("did not raise") async def test_tcp_connection_fails(self): """Client fails to connect to server.""" with self.assertRaises(OSError): async with connect("ws://localhost:54321"): # invalid port self.fail("did not raise") async def test_handshake_fails(self): """Client connects to server but the handshake fails.""" def remove_accept_header(self, request, response): del response.headers["Sec-WebSocket-Accept"] # The connection will be open for the server but failed for the client. # Use a connection handler that exits immediately to avoid an exception. 
async with serve(*args, process_response=remove_accept_header) as server: with self.assertRaises(InvalidHandshake) as raised: async with connect(get_uri(server) + "/no-op", close_timeout=MS): self.fail("did not raise") self.assertEqual( str(raised.exception), "missing Sec-WebSocket-Accept header", ) async def test_timeout_during_handshake(self): """Client times out before receiving handshake response from server.""" # Replace the WebSocket server with a TCP server that doesn't respond. with socket.create_server(("localhost", 0)) as sock: host, port = sock.getsockname() with self.assertRaises(TimeoutError) as raised: async with connect(f"ws://{host}:{port}", open_timeout=MS): self.fail("did not raise") self.assertEqual( str(raised.exception), "timed out during opening handshake", ) async def test_connection_closed_during_handshake(self): """Client reads EOF before receiving handshake response from server.""" def close_connection(self, request): self.transport.close() async with serve(*args, process_request=close_connection) as server: with self.assertRaises(InvalidMessage) as raised: async with connect(get_uri(server)): self.fail("did not raise") self.assertEqual( str(raised.exception), "did not receive a valid HTTP response", ) self.assertIsInstance(raised.exception.__cause__, EOFError) self.assertEqual( str(raised.exception.__cause__), "connection closed while reading HTTP status line", ) async def test_http_response(self): """Client reads HTTP response.""" def http_response(connection, request): return connection.respond(http.HTTPStatus.OK, "👌") async with serve(*args, process_request=http_response) as server: with self.assertRaises(InvalidStatus) as raised: async with connect(get_uri(server)): self.fail("did not raise") self.assertEqual(raised.exception.response.status_code, 200) self.assertEqual(raised.exception.response.body.decode(), "👌") async def test_http_response_without_content_length(self): """Client reads HTTP response without a Content-Length 
header.""" def http_response(connection, request): response = connection.respond(http.HTTPStatus.OK, "👌") del response.headers["Content-Length"] return response async with serve(*args, process_request=http_response) as server: with self.assertRaises(InvalidStatus) as raised: async with connect(get_uri(server)): self.fail("did not raise") self.assertEqual(raised.exception.response.status_code, 200) self.assertEqual(raised.exception.response.body.decode(), "👌") async def test_junk_handshake(self): """Client closes the connection when receiving non-HTTP response from server.""" async def junk(reader, writer): await asyncio.sleep(MS) # wait for the client to send the handshake request writer.write(b"220 smtp.invalid ESMTP Postfix\r\n") await reader.read(4096) # wait for the client to close the connection writer.close() server = await asyncio.start_server(junk, "localhost", 0) host, port = get_host_port(server) async with server: with self.assertRaises(InvalidMessage) as raised: async with connect(f"ws://{host}:{port}"): self.fail("did not raise") self.assertEqual( str(raised.exception), "did not receive a valid HTTP response", ) self.assertIsInstance(raised.exception.__cause__, ValueError) self.assertEqual( str(raised.exception.__cause__), "unsupported protocol; expected HTTP/1.1: " "220 smtp.invalid ESMTP Postfix", ) class SecureClientTests(unittest.IsolatedAsyncioTestCase): async def test_connection(self): """Client connects to server securely.""" async with serve(*args, ssl=SERVER_CONTEXT) as server: async with connect(get_uri(server), ssl=CLIENT_CONTEXT) as client: self.assertEqual(client.protocol.state.name, "OPEN") ssl_object = client.transport.get_extra_info("ssl_object") self.assertEqual(ssl_object.version()[:3], "TLS") async def test_set_server_hostname_implicitly(self): """Client sets server_hostname to the host in the WebSocket URI.""" async with serve(*args, ssl=SERVER_CONTEXT) as server: host, port = get_host_port(server) async with connect( 
"wss://overridden/", host=host, port=port, ssl=CLIENT_CONTEXT ) as client: ssl_object = client.transport.get_extra_info("ssl_object") self.assertEqual(ssl_object.server_hostname, "overridden") async def test_set_server_hostname_explicitly(self): """Client sets server_hostname to the value provided in argument.""" async with serve(*args, ssl=SERVER_CONTEXT) as server: async with connect( get_uri(server), ssl=CLIENT_CONTEXT, server_hostname="overridden" ) as client: ssl_object = client.transport.get_extra_info("ssl_object") self.assertEqual(ssl_object.server_hostname, "overridden") async def test_reject_invalid_server_certificate(self): """Client rejects certificate where server certificate isn't trusted.""" async with serve(*args, ssl=SERVER_CONTEXT) as server: with self.assertRaises(ssl.SSLCertVerificationError) as raised: # The test certificate is self-signed. async with connect(get_uri(server)): self.fail("did not raise") self.assertIn( "certificate verify failed: self signed certificate", str(raised.exception).replace("-", " "), ) async def test_reject_invalid_server_hostname(self): """Client rejects certificate where server hostname doesn't match.""" async with serve(*args, ssl=SERVER_CONTEXT) as server: with self.assertRaises(ssl.SSLCertVerificationError) as raised: # This hostname isn't included in the test certificate. 
async with connect( get_uri(server), ssl=CLIENT_CONTEXT, server_hostname="invalid" ): self.fail("did not raise") self.assertIn( "certificate verify failed: Hostname mismatch", str(raised.exception), ) async def test_cross_origin_redirect(self): """Client follows redirect to a secure URI on a different origin.""" def redirect(connection, request): response = connection.respond(http.HTTPStatus.FOUND, "") response.headers["Location"] = get_uri(other_server) return response async with serve(*args, ssl=SERVER_CONTEXT, process_request=redirect) as server: async with serve(*args, ssl=SERVER_CONTEXT) as other_server: async with connect(get_uri(server), ssl=CLIENT_CONTEXT): self.assertFalse(server.connections) self.assertTrue(other_server.connections) async def test_redirect_to_insecure_uri(self): """Client doesn't follow redirect from secure URI to non-secure URI.""" def redirect(connection, request): response = connection.respond(http.HTTPStatus.FOUND, "") response.headers["Location"] = insecure_uri return response async with serve(*args, ssl=SERVER_CONTEXT, process_request=redirect) as server: with self.assertRaises(SecurityError) as raised: secure_uri = get_uri(server) insecure_uri = secure_uri.replace("wss://", "ws://") async with connect(secure_uri, ssl=CLIENT_CONTEXT): self.fail("did not raise") self.assertEqual( str(raised.exception), f"cannot follow redirect to non-secure URI {insecure_uri}", ) @unittest.skipUnless("mitmproxy" in sys.modules, "mitmproxy not installed") class SocksProxyClientTests(ProxyMixin, unittest.IsolatedAsyncioTestCase): proxy_mode = "socks5@51080" @patch.dict(os.environ, {"socks_proxy": "http://localhost:51080"}) async def test_socks_proxy(self): """Client connects to server through a SOCKS5 proxy.""" async with serve(*args) as server: async with connect(get_uri(server)) as client: self.assertEqual(client.protocol.state.name, "OPEN") self.assertNumFlows(1) @patch.dict(os.environ, {"socks_proxy": "http://localhost:51080"}) async def 
test_secure_socks_proxy(self): """Client connects to server securely through a SOCKS5 proxy.""" async with serve(*args, ssl=SERVER_CONTEXT) as server: async with connect(get_uri(server), ssl=CLIENT_CONTEXT) as client: self.assertEqual(client.protocol.state.name, "OPEN") self.assertNumFlows(1) @patch.dict(os.environ, {"socks_proxy": "http://hello:iloveyou@localhost:51080"}) async def test_authenticated_socks_proxy(self): """Client connects to server through an authenticated SOCKS5 proxy.""" try: self.proxy_options.update(proxyauth="hello:iloveyou") async with serve(*args) as server: async with connect(get_uri(server)) as client: self.assertEqual(client.protocol.state.name, "OPEN") finally: self.proxy_options.update(proxyauth=None) self.assertNumFlows(1) @patch.dict(os.environ, {"socks_proxy": "http://localhost:51080"}) async def test_authenticated_socks_proxy_error(self): """Client fails to authenticate to the SOCKS5 proxy.""" from python_socks import ProxyError as SocksProxyError try: self.proxy_options.update(proxyauth="any") with self.assertRaises(ProxyError) as raised: async with connect("ws://example.com/"): self.fail("did not raise") finally: self.proxy_options.update(proxyauth=None) self.assertEqual( str(raised.exception), "failed to connect to SOCKS proxy", ) self.assertIsInstance(raised.exception.__cause__, SocksProxyError) self.assertNumFlows(0) @patch.dict(os.environ, {"socks_proxy": "http://localhost:61080"}) # bad port async def test_socks_proxy_connection_failure(self): """Client fails to connect to the SOCKS5 proxy.""" from python_socks import ProxyConnectionError as SocksProxyConnectionError with self.assertRaises(OSError) as raised: async with connect("ws://example.com/"): self.fail("did not raise") # Don't test str(raised.exception) because we don't control it. 
self.assertIsInstance(raised.exception, SocksProxyConnectionError) self.assertNumFlows(0) async def test_socks_proxy_connection_timeout(self): """Client times out while connecting to the SOCKS5 proxy.""" # Replace the proxy with a TCP server that doesn't respond. with socket.create_server(("localhost", 0)) as sock: host, port = sock.getsockname() with patch.dict(os.environ, {"socks_proxy": f"http://{host}:{port}"}): with self.assertRaises(TimeoutError) as raised: async with connect("ws://example.com/", open_timeout=MS): self.fail("did not raise") self.assertEqual( str(raised.exception), "timed out during opening handshake", ) self.assertNumFlows(0) async def test_explicit_socks_proxy(self): """Client connects to server through a SOCKS5 proxy set explicitly.""" async with serve(*args) as server: async with connect( get_uri(server), # Take this opportunity to test socks5 instead of socks5h. proxy="socks5://localhost:51080", ) as client: self.assertEqual(client.protocol.state.name, "OPEN") self.assertNumFlows(1) @patch.dict(os.environ, {"socks_proxy": "http://localhost:51080"}) async def test_ignore_proxy_with_existing_socket(self): """Client connects using a pre-existing socket.""" async with serve(*args) as server: with socket.create_connection(get_host_port(server)) as sock: # Use a non-existing domain to ensure we connect to sock. 
async with connect("ws://invalid/", sock=sock) as client: self.assertEqual(client.protocol.state.name, "OPEN") self.assertNumFlows(0) @unittest.skipUnless("mitmproxy" in sys.modules, "mitmproxy not installed") class HTTPProxyClientTests(ProxyMixin, unittest.IsolatedAsyncioTestCase): proxy_mode = "regular@58080" @patch.dict(os.environ, {"https_proxy": "http://localhost:58080"}) async def test_http_proxy(self): """Client connects to server through an HTTP proxy.""" async with serve(*args) as server: async with connect(get_uri(server)) as client: self.assertEqual(client.protocol.state.name, "OPEN") self.assertNumFlows(1) @patch.dict(os.environ, {"https_proxy": "http://localhost:58080"}) async def test_secure_http_proxy(self): """Client connects to server securely through an HTTP proxy.""" async with serve(*args, ssl=SERVER_CONTEXT) as server: async with connect(get_uri(server), ssl=CLIENT_CONTEXT) as client: self.assertEqual(client.protocol.state.name, "OPEN") ssl_object = client.transport.get_extra_info("ssl_object") self.assertEqual(ssl_object.version()[:3], "TLS") self.assertNumFlows(1) @patch.dict(os.environ, {"https_proxy": "http://hello:iloveyou@localhost:58080"}) async def test_authenticated_http_proxy(self): """Client connects to server through an authenticated HTTP proxy.""" try: self.proxy_options.update(proxyauth="hello:iloveyou") async with serve(*args) as server: async with connect(get_uri(server)) as client: self.assertEqual(client.protocol.state.name, "OPEN") finally: self.proxy_options.update(proxyauth=None) self.assertNumFlows(1) @patch.dict(os.environ, {"https_proxy": "http://localhost:58080"}) async def test_authenticated_http_proxy_error(self): """Client fails to authenticate to the HTTP proxy.""" try: self.proxy_options.update(proxyauth="any") with self.assertRaises(ProxyError) as raised: async with connect("ws://example.com/"): self.fail("did not raise") finally: self.proxy_options.update(proxyauth=None) self.assertEqual( str(raised.exception), 
"proxy rejected connection: HTTP 407", ) self.assertNumFlows(0) @patch.dict(os.environ, {"https_proxy": "http://localhost:58080"}) async def test_http_proxy_override_user_agent(self): """Client can override User-Agent header with user_agent_header.""" async with serve(*args) as server: async with connect(get_uri(server), user_agent_header="Smith") as client: self.assertEqual(client.protocol.state.name, "OPEN") [http_connect] = self.get_http_connects() self.assertEqual(http_connect.request.headers[b"User-Agent"], "Smith") @patch.dict(os.environ, {"https_proxy": "http://localhost:58080"}) async def test_http_proxy_remove_user_agent(self): """Client can remove User-Agent header with user_agent_header.""" async with serve(*args) as server: async with connect(get_uri(server), user_agent_header=None) as client: self.assertEqual(client.protocol.state.name, "OPEN") [http_connect] = self.get_http_connects() self.assertNotIn(b"User-Agent", http_connect.request.headers) @patch.dict(os.environ, {"https_proxy": "http://localhost:58080"}) async def test_http_proxy_protocol_error(self): """Client receives invalid data when connecting to the HTTP proxy.""" try: self.proxy_options.update(break_http_connect=True) with self.assertRaises(InvalidProxyMessage) as raised: async with connect("ws://example.com/"): self.fail("did not raise") finally: self.proxy_options.update(break_http_connect=False) self.assertEqual( str(raised.exception), "did not receive a valid HTTP response from proxy", ) self.assertNumFlows(0) @patch.dict(os.environ, {"https_proxy": "http://localhost:58080"}) async def test_http_proxy_connection_error(self): """Client receives no response when connecting to the HTTP proxy.""" try: self.proxy_options.update(close_http_connect=True) with self.assertRaises(InvalidProxyMessage) as raised: async with connect("ws://example.com/"): self.fail("did not raise") finally: self.proxy_options.update(close_http_connect=False) self.assertEqual( str(raised.exception), "did not 
receive a valid HTTP response from proxy", ) self.assertNumFlows(0) @patch.dict(os.environ, {"https_proxy": "http://localhost:48080"}) # bad port async def test_http_proxy_connection_failure(self): """Client fails to connect to the HTTP proxy.""" with self.assertRaises(OSError): async with connect("ws://example.com/"): self.fail("did not raise") # Don't test str(raised.exception) because we don't control it. self.assertNumFlows(0) async def test_http_proxy_connection_timeout(self): """Client times out while connecting to the HTTP proxy.""" # Replace the proxy with a TCP server that doesn't respond. with socket.create_server(("localhost", 0)) as sock: host, port = sock.getsockname() with patch.dict(os.environ, {"https_proxy": f"http://{host}:{port}"}): with self.assertRaises(TimeoutError) as raised: async with connect("ws://example.com/", open_timeout=MS): self.fail("did not raise") self.assertEqual( str(raised.exception), "timed out during opening handshake", ) @patch.dict(os.environ, {"https_proxy": "https://localhost:58080"}) async def test_https_proxy(self): """Client connects to server through an HTTPS proxy.""" async with serve(*args) as server: async with connect( get_uri(server), proxy_ssl=self.proxy_context, ) as client: self.assertEqual(client.protocol.state.name, "OPEN") self.assertNumFlows(1) @patch.dict(os.environ, {"https_proxy": "https://localhost:58080"}) async def test_secure_https_proxy(self): """Client connects to server securely through an HTTPS proxy.""" async with serve(*args, ssl=SERVER_CONTEXT) as server: async with connect( get_uri(server), ssl=CLIENT_CONTEXT, proxy_ssl=self.proxy_context, ) as client: self.assertEqual(client.protocol.state.name, "OPEN") ssl_object = client.transport.get_extra_info("ssl_object") self.assertEqual(ssl_object.version()[:3], "TLS") self.assertNumFlows(1) @patch.dict(os.environ, {"https_proxy": "https://localhost:58080"}) async def test_https_server_hostname(self): """Client sets server_hostname to the value of 
proxy_server_hostname.""" async with serve(*args) as server: # Pass an argument not prefixed with proxy_ for coverage. kwargs = {"all_errors": True} if sys.version_info >= (3, 12) else {} async with connect( get_uri(server), proxy_ssl=self.proxy_context, proxy_server_hostname="overridden", **kwargs, ) as client: ssl_object = client.transport.get_extra_info("ssl_object") self.assertEqual(ssl_object.server_hostname, "overridden") self.assertNumFlows(1) @patch.dict(os.environ, {"https_proxy": "https://localhost:58080"}) async def test_https_proxy_invalid_proxy_certificate(self): """Client rejects certificate when proxy certificate isn't trusted.""" with self.assertRaises(ssl.SSLCertVerificationError) as raised: # The proxy certificate isn't trusted. async with connect("wss://example.com/"): self.fail("did not raise") self.assertIn( "certificate verify failed: unable to get local issuer certificate", str(raised.exception), ) @patch.dict(os.environ, {"https_proxy": "https://localhost:58080"}) async def test_https_proxy_invalid_server_certificate(self): """Client rejects certificate when proxy certificate isn't trusted.""" async with serve(*args, ssl=SERVER_CONTEXT) as server: with self.assertRaises(ssl.SSLCertVerificationError) as raised: # The test certificate is self-signed. 
async with connect(get_uri(server), proxy_ssl=self.proxy_context): self.fail("did not raise") self.assertIn( "certificate verify failed: self signed certificate", str(raised.exception).replace("-", " "), ) self.assertNumFlows(1) @unittest.skipUnless(hasattr(socket, "AF_UNIX"), "this test requires Unix sockets") class UnixClientTests(unittest.IsolatedAsyncioTestCase): async def test_connection(self): """Client connects to server over a Unix socket.""" with temp_unix_socket_path() as path: async with unix_serve(handler, path): async with unix_connect(path) as client: self.assertEqual(client.protocol.state.name, "OPEN") async def test_set_host_header(self): """Client sets the Host header to the host in the WebSocket URI.""" # This is part of the documented behavior of unix_connect(). with temp_unix_socket_path() as path: async with unix_serve(handler, path): async with unix_connect(path, uri="ws://overridden/") as client: self.assertEqual(client.request.headers["Host"], "overridden") async def test_cross_origin_redirect(self): """Client doesn't follows redirect to a URI on a different origin.""" def redirect(connection, request): response = connection.respond(http.HTTPStatus.FOUND, "") response.headers["Location"] = "ws://other/" return response with temp_unix_socket_path() as path: async with unix_serve(handler, path, process_request=redirect): with self.assertRaises(ValueError) as raised: async with unix_connect(path): self.fail("did not raise") self.assertEqual( str(raised.exception), "cannot follow cross-origin redirect to ws://other/ with a Unix socket", ) async def test_secure_connection(self): """Client connects to server securely over a Unix socket.""" with temp_unix_socket_path() as path: async with unix_serve(handler, path, ssl=SERVER_CONTEXT): async with unix_connect(path, ssl=CLIENT_CONTEXT) as client: self.assertEqual(client.protocol.state.name, "OPEN") ssl_object = client.transport.get_extra_info("ssl_object") self.assertEqual(ssl_object.version()[:3], 
"TLS") async def test_set_server_hostname(self): """Client sets server_hostname to the host in the WebSocket URI.""" # This is part of the documented behavior of unix_connect(). with temp_unix_socket_path() as path: async with unix_serve(handler, path, ssl=SERVER_CONTEXT): async with unix_connect( path, ssl=CLIENT_CONTEXT, uri="wss://overridden/", ) as client: ssl_object = client.transport.get_extra_info("ssl_object") self.assertEqual(ssl_object.server_hostname, "overridden") class ClientUsageErrorsTests(unittest.IsolatedAsyncioTestCase): async def test_ssl_without_secure_uri(self): """Client rejects ssl when URI isn't secure.""" with self.assertRaises(ValueError) as raised: await connect("ws://localhost/", ssl=CLIENT_CONTEXT) self.assertEqual( str(raised.exception), "ssl argument is incompatible with a ws:// URI", ) async def test_secure_uri_without_ssl(self): """Client rejects ssl=None when URI is secure.""" with self.assertRaises(ValueError) as raised: await connect("wss://localhost/", ssl=None) self.assertEqual( str(raised.exception), "ssl=None is incompatible with a wss:// URI", ) async def test_proxy_ssl_without_https_proxy(self): """Client rejects proxy_ssl when proxy isn't HTTPS.""" with self.assertRaises(ValueError) as raised: await connect( "ws://localhost/", proxy="http://localhost:8080", proxy_ssl=True, ) self.assertEqual( str(raised.exception), "proxy_ssl argument is incompatible with an http:// proxy", ) async def test_https_proxy_without_ssl(self): """Client rejects proxy_ssl=None when proxy is HTTPS.""" with self.assertRaises(ValueError) as raised: await connect( "ws://localhost/", proxy="https://localhost:8080", proxy_ssl=None, ) self.assertEqual( str(raised.exception), "proxy_ssl=None is incompatible with an https:// proxy", ) async def test_unsupported_proxy(self): """Client rejects unsupported proxy.""" with self.assertRaises(InvalidProxy) as raised: async with connect("ws://example.com/", proxy="other://localhost:51080"): self.fail("did not 
raise") self.assertEqual( str(raised.exception), "other://localhost:51080 isn't a valid proxy: scheme other isn't supported", ) async def test_unix_without_path_or_sock(self): """Unix client requires path when sock isn't provided.""" with self.assertRaises(ValueError) as raised: await unix_connect() self.assertEqual( str(raised.exception), "no path and sock were specified", ) async def test_unix_with_path_and_sock(self): """Unix client rejects path when sock is provided.""" sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) self.addCleanup(sock.close) with self.assertRaises(ValueError) as raised: await unix_connect(path="/", sock=sock) self.assertEqual( str(raised.exception), "path and sock can not be specified at the same time", ) async def test_invalid_subprotocol(self): """Client rejects single value of subprotocols.""" with self.assertRaises(TypeError) as raised: await connect("ws://localhost/", subprotocols="chat") self.assertEqual( str(raised.exception), "subprotocols must be a list, not a str", ) async def test_unsupported_compression(self): """Client rejects incorrect value of compression.""" with self.assertRaises(ValueError) as raised: await connect("ws://localhost/", compression=False) self.assertEqual( str(raised.exception), "unsupported compression: False", ) websockets-15.0.1/tests/asyncio/test_connection.py000066400000000000000000001576671476212450300223600ustar00rootroot00000000000000import asyncio import contextlib import logging import socket import sys import unittest import uuid from unittest.mock import Mock, patch from websockets.asyncio.compatibility import TimeoutError, aiter, anext, asyncio_timeout from websockets.asyncio.connection import * from websockets.asyncio.connection import broadcast from websockets.exceptions import ( ConcurrencyError, ConnectionClosedError, ConnectionClosedOK, ) from websockets.frames import CloseCode, Frame, Opcode from websockets.protocol import CLIENT, SERVER, Protocol, State from ..protocol import 
RecordingProtocol from ..utils import MS, AssertNoLogsMixin from .connection import InterceptingConnection from .utils import alist # Connection implements symmetrical behavior between clients and servers. # All tests run on the client side and the server side to validate this. class ClientConnectionTests(AssertNoLogsMixin, unittest.IsolatedAsyncioTestCase): LOCAL = CLIENT REMOTE = SERVER async def asyncSetUp(self): loop = asyncio.get_running_loop() socket_, remote_socket = socket.socketpair() self.transport, self.connection = await loop.create_connection( lambda: Connection(Protocol(self.LOCAL), close_timeout=2 * MS), sock=socket_, ) self.remote_transport, self.remote_connection = await loop.create_connection( lambda: InterceptingConnection(RecordingProtocol(self.REMOTE)), sock=remote_socket, ) async def asyncTearDown(self): await self.remote_connection.close() await self.connection.close() # Test helpers built upon RecordingProtocol and InterceptingConnection. async def assertFrameSent(self, frame): """Check that a single frame was sent.""" # Let the remote side process messages. # Two runs of the event loop are required for answering pings. await asyncio.sleep(0) await asyncio.sleep(0) self.assertEqual(self.remote_connection.protocol.get_frames_rcvd(), [frame]) async def assertFramesSent(self, frames): """Check that several frames were sent.""" # Let the remote side process messages. # Two runs of the event loop are required for answering pings. await asyncio.sleep(0) await asyncio.sleep(0) self.assertEqual(self.remote_connection.protocol.get_frames_rcvd(), frames) async def assertNoFrameSent(self): """Check that no frame was sent.""" # Run the event loop twice for consistency with assertFrameSent. 
await asyncio.sleep(0) await asyncio.sleep(0) self.assertEqual(self.remote_connection.protocol.get_frames_rcvd(), []) @contextlib.asynccontextmanager async def delay_frames_rcvd(self, delay): """Delay frames before they're received by the connection.""" with self.remote_connection.delay_frames_sent(delay): yield await asyncio.sleep(MS) # let the remote side process messages @contextlib.asynccontextmanager async def delay_eof_rcvd(self, delay): """Delay EOF before it's received by the connection.""" with self.remote_connection.delay_eof_sent(delay): yield await asyncio.sleep(MS) # let the remote side process messages @contextlib.asynccontextmanager async def drop_frames_rcvd(self): """Drop frames before they're received by the connection.""" with self.remote_connection.drop_frames_sent(): yield await asyncio.sleep(MS) # let the remote side process messages @contextlib.asynccontextmanager async def drop_eof_rcvd(self): """Drop EOF before it's received by the connection.""" with self.remote_connection.drop_eof_sent(): yield await asyncio.sleep(MS) # let the remote side process messages # Test __aenter__ and __aexit__. async def test_aenter(self): """__aenter__ returns the connection itself.""" async with self.connection as connection: self.assertIs(connection, self.connection) async def test_aexit(self): """__aexit__ closes the connection with code 1000.""" async with self.connection: await self.assertNoFrameSent() await self.assertFrameSent(Frame(Opcode.CLOSE, b"\x03\xe8")) async def test_exit_with_exception(self): """__exit__ with an exception closes the connection with code 1011.""" with self.assertRaises(RuntimeError): async with self.connection: raise RuntimeError await self.assertFrameSent(Frame(Opcode.CLOSE, b"\x03\xf3")) # Test __aiter__. 
    # Test __aiter__: iterating a connection yields incoming messages until
    # the connection closes.

    async def test_aiter_text(self):
        """__aiter__ yields text messages."""
        aiterator = aiter(self.connection)
        await self.remote_connection.send("😀")
        self.assertEqual(await anext(aiterator), "😀")
        await self.remote_connection.send("😀")
        self.assertEqual(await anext(aiterator), "😀")

    async def test_aiter_binary(self):
        """__aiter__ yields binary messages."""
        aiterator = aiter(self.connection)
        await self.remote_connection.send(b"\x01\x02\xfe\xff")
        self.assertEqual(await anext(aiterator), b"\x01\x02\xfe\xff")
        await self.remote_connection.send(b"\x01\x02\xfe\xff")
        self.assertEqual(await anext(aiterator), b"\x01\x02\xfe\xff")

    async def test_aiter_mixed(self):
        """__aiter__ yields a mix of text and binary messages."""
        aiterator = aiter(self.connection)
        await self.remote_connection.send("😀")
        self.assertEqual(await anext(aiterator), "😀")
        await self.remote_connection.send(b"\x01\x02\xfe\xff")
        self.assertEqual(await anext(aiterator), b"\x01\x02\xfe\xff")

    async def test_aiter_connection_closed_ok(self):
        """__aiter__ terminates after a normal closure."""
        aiterator = aiter(self.connection)
        await self.remote_connection.close()
        # A clean close (code 1000) ends iteration rather than raising.
        with self.assertRaises(StopAsyncIteration):
            await anext(aiterator)

    async def test_aiter_connection_closed_error(self):
        """__aiter__ raises ConnectionClosedError after an error."""
        aiterator = aiter(self.connection)
        await self.remote_connection.close(code=CloseCode.INTERNAL_ERROR)
        with self.assertRaises(ConnectionClosedError):
            await anext(aiterator)

    # Test recv.
    async def test_recv_text(self):
        """recv receives a text message."""
        await self.remote_connection.send("😀")
        self.assertEqual(await self.connection.recv(), "😀")

    async def test_recv_binary(self):
        """recv receives a binary message."""
        await self.remote_connection.send(b"\x01\x02\xfe\xff")
        self.assertEqual(await self.connection.recv(), b"\x01\x02\xfe\xff")

    async def test_recv_text_as_bytes(self):
        """recv receives a text message as bytes."""
        await self.remote_connection.send("😀")
        self.assertEqual(await self.connection.recv(decode=False), "😀".encode())

    async def test_recv_binary_as_text(self):
        """recv receives a binary message as a str."""
        await self.remote_connection.send("😀".encode())
        self.assertEqual(await self.connection.recv(decode=True), "😀")

    async def test_recv_fragmented_text(self):
        """recv receives a fragmented text message."""
        await self.remote_connection.send(["😀", "😀"])
        # recv reassembles fragments into a single message.
        self.assertEqual(await self.connection.recv(), "😀😀")

    async def test_recv_fragmented_binary(self):
        """recv receives a fragmented binary message."""
        await self.remote_connection.send([b"\x01\x02", b"\xfe\xff"])
        self.assertEqual(await self.connection.recv(), b"\x01\x02\xfe\xff")

    async def test_recv_connection_closed_ok(self):
        """recv raises ConnectionClosedOK after a normal closure."""
        await self.remote_connection.close()
        with self.assertRaises(ConnectionClosedOK):
            await self.connection.recv()

    async def test_recv_connection_closed_error(self):
        """recv raises ConnectionClosedError after an error."""
        await self.remote_connection.close(code=CloseCode.INTERNAL_ERROR)
        with self.assertRaises(ConnectionClosedError):
            await self.connection.recv()

    async def test_recv_non_utf8_text(self):
        """recv receives a non-UTF-8 text message."""
        await self.remote_connection.send(b"\x01\x02\xfe\xff", text=True)
        with self.assertRaises(ConnectionClosedError):
            await self.connection.recv()
        # The connection fails with close code 1007 (invalid frame payload data).
        await self.assertFrameSent(
            Frame(Opcode.CLOSE, b"\x03\xefinvalid start byte at position 2")
        )

    async def test_recv_during_recv(self):
        """recv raises ConcurrencyError when called concurrently."""
        recv_task = asyncio.create_task(self.connection.recv())
        await asyncio.sleep(0)  # let the event loop start recv_task
        self.addCleanup(recv_task.cancel)
        with self.assertRaises(ConcurrencyError) as raised:
            await self.connection.recv()
        self.assertEqual(
            str(raised.exception),
            "cannot call recv while another coroutine "
            "is already running recv or recv_streaming",
        )

    async def test_recv_during_recv_streaming(self):
        """recv raises ConcurrencyError when called concurrently with recv_streaming."""
        recv_streaming_task = asyncio.create_task(
            alist(self.connection.recv_streaming())
        )
        await asyncio.sleep(0)  # let the event loop start recv_streaming_task
        self.addCleanup(recv_streaming_task.cancel)
        with self.assertRaises(ConcurrencyError) as raised:
            await self.connection.recv()
        self.assertEqual(
            str(raised.exception),
            "cannot call recv while another coroutine "
            "is already running recv or recv_streaming",
        )

    async def test_recv_cancellation_before_receiving(self):
        """recv can be canceled before receiving a frame."""
        recv_task = asyncio.create_task(self.connection.recv())
        await asyncio.sleep(0)  # let the event loop start recv_task
        recv_task.cancel()
        await asyncio.sleep(0)  # let the event loop cancel recv_task
        # Running recv again receives the next message.
        await self.remote_connection.send("😀")
        self.assertEqual(await self.connection.recv(), "😀")

    async def test_recv_cancellation_while_receiving(self):
        """recv cannot be canceled after receiving a frame."""
        recv_task = asyncio.create_task(self.connection.recv())
        await asyncio.sleep(0)  # let the event loop start recv_task
        gate = asyncio.get_running_loop().create_future()

        # The remote side sends the first fragment, then waits for the gate
        # before sending the final fragment.
        async def fragments():
            yield "⏳"
            await gate
            yield "⌛️"

        asyncio.create_task(self.remote_connection.send(fragments()))
        await asyncio.sleep(MS)
        recv_task.cancel()
        await asyncio.sleep(0)  # let the event loop cancel recv_task
        # Running recv again receives the complete message.
        gate.set_result(None)
        self.assertEqual(await self.connection.recv(), "⏳⌛️")

    # Test recv_streaming.

    async def test_recv_streaming_text(self):
        """recv_streaming receives a text message."""
        await self.remote_connection.send("😀")
        self.assertEqual(
            await alist(self.connection.recv_streaming()),
            ["😀"],
        )

    async def test_recv_streaming_binary(self):
        """recv_streaming receives a binary message."""
        await self.remote_connection.send(b"\x01\x02\xfe\xff")
        self.assertEqual(
            await alist(self.connection.recv_streaming()),
            [b"\x01\x02\xfe\xff"],
        )

    async def test_recv_streaming_text_as_bytes(self):
        """recv_streaming receives a text message as bytes."""
        await self.remote_connection.send("😀")
        self.assertEqual(
            await alist(self.connection.recv_streaming(decode=False)),
            ["😀".encode()],
        )

    async def test_recv_streaming_binary_as_str(self):
        """recv_streaming receives a binary message as a str."""
        await self.remote_connection.send("😀".encode())
        self.assertEqual(
            await alist(self.connection.recv_streaming(decode=True)),
            ["😀"],
        )

    async def test_recv_streaming_fragmented_text(self):
        """recv_streaming receives a fragmented text message."""
        await self.remote_connection.send(["😀", "😀"])
        # websockets sends a trailing empty fragment. That's an implementation detail.
        self.assertEqual(
            await alist(self.connection.recv_streaming()),
            ["😀", "😀", ""],
        )

    async def test_recv_streaming_fragmented_binary(self):
        """recv_streaming receives a fragmented binary message."""
        await self.remote_connection.send([b"\x01\x02", b"\xfe\xff"])
        # websockets sends a trailing empty fragment. That's an implementation detail.
        self.assertEqual(
            await alist(self.connection.recv_streaming()),
            [b"\x01\x02", b"\xfe\xff", b""],
        )

    async def test_recv_streaming_connection_closed_ok(self):
        """recv_streaming raises ConnectionClosedOK after a normal closure."""
        await self.remote_connection.close()
        with self.assertRaises(ConnectionClosedOK):
            async for _ in self.connection.recv_streaming():
                self.fail("did not raise")

    async def test_recv_streaming_connection_closed_error(self):
        """recv_streaming raises ConnectionClosedError after an error."""
        await self.remote_connection.close(code=CloseCode.INTERNAL_ERROR)
        with self.assertRaises(ConnectionClosedError):
            async for _ in self.connection.recv_streaming():
                self.fail("did not raise")

    async def test_recv_streaming_non_utf8_text(self):
        """recv_streaming receives a non-UTF-8 text message."""
        await self.remote_connection.send(b"\x01\x02\xfe\xff", text=True)
        with self.assertRaises(ConnectionClosedError):
            await alist(self.connection.recv_streaming())
        # The connection fails with close code 1007 (invalid frame payload data).
        await self.assertFrameSent(
            Frame(Opcode.CLOSE, b"\x03\xefinvalid start byte at position 2")
        )

    async def test_recv_streaming_during_recv(self):
        """recv_streaming raises ConcurrencyError when called concurrently with recv."""
        recv_task = asyncio.create_task(self.connection.recv())
        await asyncio.sleep(0)  # let the event loop start recv_task
        self.addCleanup(recv_task.cancel)
        with self.assertRaises(ConcurrencyError) as raised:
            async for _ in self.connection.recv_streaming():
                self.fail("did not raise")
        self.assertEqual(
            str(raised.exception),
            "cannot call recv_streaming while another coroutine "
            "is already running recv or recv_streaming",
        )

    async def test_recv_streaming_during_recv_streaming(self):
        """recv_streaming raises ConcurrencyError when called concurrently."""
        recv_streaming_task = asyncio.create_task(
            alist(self.connection.recv_streaming())
        )
        await asyncio.sleep(0)  # let the event loop start recv_streaming_task
        self.addCleanup(recv_streaming_task.cancel)
        with self.assertRaises(ConcurrencyError) as raised:
            async for _ in self.connection.recv_streaming():
                self.fail("did not raise")
        self.assertEqual(
            str(raised.exception),
            r"cannot call recv_streaming while another coroutine "
            r"is already running recv or recv_streaming",
        )

    async def test_recv_streaming_cancellation_before_receiving(self):
        """recv_streaming can be canceled before receiving a frame."""
        recv_streaming_task = asyncio.create_task(
            alist(self.connection.recv_streaming())
        )
        await asyncio.sleep(0)  # let the event loop start recv_streaming_task
        recv_streaming_task.cancel()
        await asyncio.sleep(0)  # let the event loop cancel recv_streaming_task
        # Running recv_streaming again receives the next message.
        await self.remote_connection.send(["😀", "😀"])
        self.assertEqual(
            await alist(self.connection.recv_streaming()),
            ["😀", "😀", ""],
        )

    async def test_recv_streaming_cancellation_while_receiving(self):
        """recv_streaming cannot be canceled after receiving a frame."""
        recv_streaming_task = asyncio.create_task(
            alist(self.connection.recv_streaming())
        )
        await asyncio.sleep(0)  # let the event loop start recv_streaming_task
        gate = asyncio.get_running_loop().create_future()

        # The remote side sends the first fragment, then waits for the gate
        # before sending the final fragment.
        async def fragments():
            yield "⏳"
            await gate
            yield "⌛️"

        asyncio.create_task(self.remote_connection.send(fragments()))
        await asyncio.sleep(MS)
        recv_streaming_task.cancel()
        await asyncio.sleep(0)  # let the event loop cancel recv_streaming_task
        gate.set_result(None)
        # Running recv_streaming again fails.
        with self.assertRaises(ConcurrencyError):
            await alist(self.connection.recv_streaming())

    # Test send.
    async def test_send_text(self):
        """send sends a text message."""
        await self.connection.send("😀")
        self.assertEqual(await self.remote_connection.recv(), "😀")

    async def test_send_binary(self):
        """send sends a binary message."""
        await self.connection.send(b"\x01\x02\xfe\xff")
        self.assertEqual(await self.remote_connection.recv(), b"\x01\x02\xfe\xff")

    async def test_send_binary_from_str(self):
        """send sends a binary message from a str."""
        await self.connection.send("😀", text=False)
        self.assertEqual(await self.remote_connection.recv(), "😀".encode())

    async def test_send_text_from_bytes(self):
        """send sends a text message from bytes."""
        await self.connection.send("😀".encode(), text=True)
        self.assertEqual(await self.remote_connection.recv(), "😀")

    async def test_send_fragmented_text(self):
        """send sends a fragmented text message."""
        await self.connection.send(["😀", "😀"])
        # websockets sends a trailing empty fragment. That's an implementation detail.
        self.assertEqual(
            await alist(self.remote_connection.recv_streaming()),
            ["😀", "😀", ""],
        )

    async def test_send_fragmented_binary(self):
        """send sends a fragmented binary message."""
        await self.connection.send([b"\x01\x02", b"\xfe\xff"])
        # websockets sends a trailing empty fragment. That's an implementation detail.
        self.assertEqual(
            await alist(self.remote_connection.recv_streaming()),
            [b"\x01\x02", b"\xfe\xff", b""],
        )

    async def test_send_fragmented_binary_from_str(self):
        """send sends a fragmented binary message from a str."""
        await self.connection.send(["😀", "😀"], text=False)
        # websockets sends a trailing empty fragment. That's an implementation detail.
        self.assertEqual(
            await alist(self.remote_connection.recv_streaming()),
            ["😀".encode(), "😀".encode(), b""],
        )

    async def test_send_fragmented_text_from_bytes(self):
        """send sends a fragmented text message from bytes."""
        await self.connection.send(["😀".encode(), "😀".encode()], text=True)
        # websockets sends a trailing empty fragment. That's an implementation detail.
        self.assertEqual(
            await alist(self.remote_connection.recv_streaming()),
            ["😀", "😀", ""],
        )

    async def test_send_async_fragmented_text(self):
        """send sends a fragmented text message asynchronously."""

        async def fragments():
            yield "😀"
            yield "😀"

        await self.connection.send(fragments())
        # websockets sends a trailing empty fragment. That's an implementation detail.
        self.assertEqual(
            await alist(self.remote_connection.recv_streaming()),
            ["😀", "😀", ""],
        )

    async def test_send_async_fragmented_binary(self):
        """send sends a fragmented binary message asynchronously."""

        async def fragments():
            yield b"\x01\x02"
            yield b"\xfe\xff"

        await self.connection.send(fragments())
        # websockets sends a trailing empty fragment. That's an implementation detail.
        self.assertEqual(
            await alist(self.remote_connection.recv_streaming()),
            [b"\x01\x02", b"\xfe\xff", b""],
        )

    async def test_send_async_fragmented_binary_from_str(self):
        """send sends a fragmented binary message from a str asynchronously."""

        async def fragments():
            yield "😀"
            yield "😀"

        await self.connection.send(fragments(), text=False)
        # websockets sends a trailing empty fragment. That's an implementation detail.
        self.assertEqual(
            await alist(self.remote_connection.recv_streaming()),
            ["😀".encode(), "😀".encode(), b""],
        )

    async def test_send_async_fragmented_text_from_bytes(self):
        """send sends a fragmented text message from bytes asynchronously."""

        async def fragments():
            yield "😀".encode()
            yield "😀".encode()

        await self.connection.send(fragments(), text=True)
        # websockets sends a trailing empty fragment. That's an implementation detail.
        self.assertEqual(
            await alist(self.remote_connection.recv_streaming()),
            ["😀", "😀", ""],
        )

    async def test_send_connection_closed_ok(self):
        """send raises ConnectionClosedOK after a normal closure."""
        await self.remote_connection.close()
        with self.assertRaises(ConnectionClosedOK):
            await self.connection.send("😀")

    async def test_send_connection_closed_error(self):
        """send raises ConnectionClosedError after an error."""
        await self.remote_connection.close(code=CloseCode.INTERNAL_ERROR)
        with self.assertRaises(ConnectionClosedError):
            await self.connection.send("😀")

    async def test_send_while_send_blocked(self):
        """send waits for a previous call to send to complete."""
        # This test fails if the guard with fragmented_send_waiter is removed
        # from send() in the case when message is an Iterable.
        self.connection.pause_writing()
        asyncio.create_task(self.connection.send(["⏳", "⌛️"]))
        await asyncio.sleep(MS)
        await self.assertFrameSent(
            Frame(Opcode.TEXT, "⏳".encode(), fin=False),
        )
        # The second send must not interleave frames with the first one.
        asyncio.create_task(self.connection.send("✅"))
        await asyncio.sleep(MS)
        await self.assertNoFrameSent()
        self.connection.resume_writing()
        await asyncio.sleep(MS)
        await self.assertFramesSent(
            [
                Frame(Opcode.CONT, "⌛️".encode(), fin=False),
                Frame(Opcode.CONT, b"", fin=True),
                Frame(Opcode.TEXT, "✅".encode()),
            ]
        )

    async def test_send_while_send_async_blocked(self):
        """send waits for a previous call to send to complete."""
        # This test fails if the guard with fragmented_send_waiter is removed
        # from send() in the case when message is an AsyncIterable.
        self.connection.pause_writing()

        async def fragments():
            yield "⏳"
            yield "⌛️"

        asyncio.create_task(self.connection.send(fragments()))
        await asyncio.sleep(MS)
        await self.assertFrameSent(
            Frame(Opcode.TEXT, "⏳".encode(), fin=False),
        )
        # The second send must not interleave frames with the first one.
        asyncio.create_task(self.connection.send("✅"))
        await asyncio.sleep(MS)
        await self.assertNoFrameSent()
        self.connection.resume_writing()
        await asyncio.sleep(MS)
        await self.assertFramesSent(
            [
                Frame(Opcode.CONT, "⌛️".encode(), fin=False),
                Frame(Opcode.CONT, b"", fin=True),
                Frame(Opcode.TEXT, "✅".encode()),
            ]
        )

    async def test_send_during_send_async(self):
        """send waits for a previous call to send to complete."""
        # This test fails if the guard with fragmented_send_waiter is removed
        # from send() in the case when message is an AsyncIterable.
        gate = asyncio.get_running_loop().create_future()

        async def fragments():
            yield "⏳"
            await gate
            yield "⌛️"

        asyncio.create_task(self.connection.send(fragments()))
        await asyncio.sleep(MS)
        await self.assertFrameSent(
            Frame(Opcode.TEXT, "⏳".encode(), fin=False),
        )
        # The second send must not interleave frames with the first one.
        asyncio.create_task(self.connection.send("✅"))
        await asyncio.sleep(MS)
        await self.assertNoFrameSent()
        gate.set_result(None)
        await asyncio.sleep(MS)
        await self.assertFramesSent(
            [
                Frame(Opcode.CONT, "⌛️".encode(), fin=False),
                Frame(Opcode.CONT, b"", fin=True),
                Frame(Opcode.TEXT, "✅".encode()),
            ]
        )

    async def test_send_empty_iterable(self):
        """send does nothing when called with an empty iterable."""
        await self.connection.send([])
        await self.connection.close()
        self.assertEqual(await alist(self.remote_connection), [])

    async def test_send_mixed_iterable(self):
        """send raises TypeError when called with an iterable of inconsistent types."""
        with self.assertRaises(TypeError):
            await self.connection.send(["😀", b"\xfe\xff"])

    async def test_send_unsupported_iterable(self):
        """send raises TypeError when called with an iterable of unsupported type."""
        with self.assertRaises(TypeError):
            await self.connection.send([None])

    async def test_send_empty_async_iterable(self):
        """send does nothing when called with an empty async iterable."""

        # The unreachable yield makes this function an async generator.
        async def fragments():
            return
            yield  # pragma: no cover

        await self.connection.send(fragments())
        await self.connection.close()
        self.assertEqual(await alist(self.remote_connection), [])

    async def test_send_mixed_async_iterable(self):
        """send raises TypeError when called with an iterable of inconsistent types."""

        async def fragments():
            yield "😀"
            yield b"\xfe\xff"

        with self.assertRaises(TypeError):
            await self.connection.send(fragments())

    async def test_send_unsupported_async_iterable(self):
        """send raises TypeError when called with an iterable of unsupported type."""

        async def fragments():
            yield None

        with self.assertRaises(TypeError):
            await self.connection.send(fragments())

    async def test_send_dict(self):
        """send raises TypeError when called with a dict."""
        with self.assertRaises(TypeError):
            await self.connection.send({"type": "object"})

    async def test_send_unsupported_type(self):
        """send raises TypeError when called with an unsupported type."""
        with self.assertRaises(TypeError):
            await self.connection.send(None)

    # Test close.
async def test_close(self): """close sends a close frame.""" await self.connection.close() await self.assertFrameSent(Frame(Opcode.CLOSE, b"\x03\xe8")) async def test_close_explicit_code_reason(self): """close sends a close frame with a given code and reason.""" await self.connection.close(CloseCode.GOING_AWAY, "bye!") await self.assertFrameSent(Frame(Opcode.CLOSE, b"\x03\xe9bye!")) async def test_close_waits_for_close_frame(self): """close waits for a close frame (then EOF) before returning.""" async with self.delay_frames_rcvd(MS), self.delay_eof_rcvd(MS): await self.connection.close() with self.assertRaises(ConnectionClosedOK) as raised: await self.connection.recv() exc = raised.exception self.assertEqual(str(exc), "sent 1000 (OK); then received 1000 (OK)") self.assertIsNone(exc.__cause__) async def test_close_waits_for_connection_closed(self): """close waits for EOF before returning.""" if self.LOCAL is SERVER: self.skipTest("only relevant on the client-side") async with self.delay_eof_rcvd(MS): await self.connection.close() with self.assertRaises(ConnectionClosedOK) as raised: await self.connection.recv() exc = raised.exception self.assertEqual(str(exc), "sent 1000 (OK); then received 1000 (OK)") self.assertIsNone(exc.__cause__) async def test_close_no_timeout_waits_for_close_frame(self): """close without timeout waits for a close frame (then EOF) before returning.""" self.connection.close_timeout = None async with self.delay_frames_rcvd(MS), self.delay_eof_rcvd(MS): await self.connection.close() with self.assertRaises(ConnectionClosedOK) as raised: await self.connection.recv() exc = raised.exception self.assertEqual(str(exc), "sent 1000 (OK); then received 1000 (OK)") self.assertIsNone(exc.__cause__) async def test_close_no_timeout_waits_for_connection_closed(self): """close without timeout waits for EOF before returning.""" if self.LOCAL is SERVER: self.skipTest("only relevant on the client-side") self.connection.close_timeout = None async with 
self.delay_eof_rcvd(MS): await self.connection.close() with self.assertRaises(ConnectionClosedOK) as raised: await self.connection.recv() exc = raised.exception self.assertEqual(str(exc), "sent 1000 (OK); then received 1000 (OK)") self.assertIsNone(exc.__cause__) async def test_close_timeout_waiting_for_close_frame(self): """close times out if no close frame is received.""" async with self.drop_eof_rcvd(), self.drop_frames_rcvd(): await self.connection.close() with self.assertRaises(ConnectionClosedError) as raised: await self.connection.recv() exc = raised.exception self.assertEqual(str(exc), "sent 1000 (OK); no close frame received") self.assertIsInstance(exc.__cause__, TimeoutError) async def test_close_timeout_waiting_for_connection_closed(self): """close times out if EOF isn't received.""" if self.LOCAL is SERVER: self.skipTest("only relevant on the client-side") async with self.drop_eof_rcvd(): await self.connection.close() with self.assertRaises(ConnectionClosedOK) as raised: await self.connection.recv() exc = raised.exception self.assertEqual(str(exc), "sent 1000 (OK); then received 1000 (OK)") # Remove socket.timeout when dropping Python < 3.10. 
self.assertIsInstance(exc.__cause__, (socket.timeout, TimeoutError)) async def test_close_preserves_queued_messages(self): """close preserves messages buffered in the assembler.""" await self.remote_connection.send("😀") await self.connection.close() self.assertEqual(await self.connection.recv(), "😀") with self.assertRaises(ConnectionClosedOK) as raised: await self.connection.recv() exc = raised.exception self.assertEqual(str(exc), "sent 1000 (OK); then received 1000 (OK)") self.assertIsNone(exc.__cause__) async def test_close_idempotency(self): """close does nothing if the connection is already closed.""" await self.connection.close() await self.assertFrameSent(Frame(Opcode.CLOSE, b"\x03\xe8")) await self.connection.close() await self.assertNoFrameSent() async def test_close_during_recv(self): """close aborts recv when called concurrently with recv.""" recv_task = asyncio.create_task(self.connection.recv()) await asyncio.sleep(MS) await self.connection.close() with self.assertRaises(ConnectionClosedOK) as raised: await recv_task exc = raised.exception self.assertEqual(str(exc), "sent 1000 (OK); then received 1000 (OK)") self.assertIsNone(exc.__cause__) async def test_close_during_send(self): """close fails the connection when called concurrently with send.""" gate = asyncio.get_running_loop().create_future() async def fragments(): yield "⏳" await gate yield "⌛️" send_task = asyncio.create_task(self.connection.send(fragments())) await asyncio.sleep(MS) asyncio.create_task(self.connection.close()) await asyncio.sleep(MS) gate.set_result(None) with self.assertRaises(ConnectionClosedError) as raised: await send_task exc = raised.exception self.assertEqual( str(exc), "sent 1011 (internal error) close during fragmented message; " "no close frame received", ) self.assertIsNone(exc.__cause__) # Test wait_closed. 
async def test_wait_closed(self): """wait_closed waits for the connection to close.""" wait_closed_task = asyncio.create_task(self.connection.wait_closed()) await asyncio.sleep(0) # let the event loop start wait_closed_task self.assertFalse(wait_closed_task.done()) await self.connection.close() self.assertTrue(wait_closed_task.done()) # Test ping. @patch("random.getrandbits", return_value=1918987876) async def test_ping(self, getrandbits): """ping sends a ping frame with a random payload.""" await self.connection.ping() getrandbits.assert_called_once_with(32) await self.assertFrameSent(Frame(Opcode.PING, b"rand")) async def test_ping_explicit_text(self): """ping sends a ping frame with a payload provided as text.""" await self.connection.ping("ping") await self.assertFrameSent(Frame(Opcode.PING, b"ping")) async def test_ping_explicit_binary(self): """ping sends a ping frame with a payload provided as binary.""" await self.connection.ping(b"ping") await self.assertFrameSent(Frame(Opcode.PING, b"ping")) async def test_acknowledge_ping(self): """ping is acknowledged by a pong with the same payload.""" async with self.drop_frames_rcvd(): # drop automatic response to ping pong_waiter = await self.connection.ping("this") await self.remote_connection.pong("this") async with asyncio_timeout(MS): await pong_waiter async def test_acknowledge_canceled_ping(self): """ping is acknowledged by a pong with the same payload after being canceled.""" async with self.drop_frames_rcvd(): # drop automatic response to ping pong_waiter = await self.connection.ping("this") pong_waiter.cancel() await self.remote_connection.pong("this") with self.assertRaises(asyncio.CancelledError): await pong_waiter async def test_acknowledge_ping_non_matching_pong(self): """ping isn't acknowledged by a pong with a different payload.""" async with self.drop_frames_rcvd(): # drop automatic response to ping pong_waiter = await self.connection.ping("this") await self.remote_connection.pong("that") with 
self.assertRaises(TimeoutError): async with asyncio_timeout(MS): await pong_waiter async def test_acknowledge_previous_ping(self): """ping is acknowledged by a pong for a later ping.""" async with self.drop_frames_rcvd(): # drop automatic response to ping pong_waiter = await self.connection.ping("this") await self.connection.ping("that") await self.remote_connection.pong("that") async with asyncio_timeout(MS): await pong_waiter async def test_acknowledge_previous_canceled_ping(self): """ping is acknowledged by a pong for a later ping after being canceled.""" async with self.drop_frames_rcvd(): # drop automatic response to ping pong_waiter = await self.connection.ping("this") pong_waiter_2 = await self.connection.ping("that") pong_waiter.cancel() await self.remote_connection.pong("that") async with asyncio_timeout(MS): await pong_waiter_2 with self.assertRaises(asyncio.CancelledError): await pong_waiter async def test_ping_duplicate_payload(self): """ping rejects the same payload until receiving the pong.""" async with self.drop_frames_rcvd(): # drop automatic response to ping pong_waiter = await self.connection.ping("idem") with self.assertRaises(ConcurrencyError) as raised: await self.connection.ping("idem") self.assertEqual( str(raised.exception), "already waiting for a pong with the same data", ) await self.remote_connection.pong("idem") async with asyncio_timeout(MS): await pong_waiter await self.connection.ping("idem") # doesn't raise an exception async def test_ping_unsupported_type(self): """ping raises TypeError when called with an unsupported type.""" with self.assertRaises(TypeError): await self.connection.ping([]) # Test pong. 
async def test_pong(self): """pong sends a pong frame.""" await self.connection.pong() await self.assertFrameSent(Frame(Opcode.PONG, b"")) async def test_pong_explicit_text(self): """pong sends a pong frame with a payload provided as text.""" await self.connection.pong("pong") await self.assertFrameSent(Frame(Opcode.PONG, b"pong")) async def test_pong_explicit_binary(self): """pong sends a pong frame with a payload provided as binary.""" await self.connection.pong(b"pong") await self.assertFrameSent(Frame(Opcode.PONG, b"pong")) async def test_pong_unsupported_type(self): """pong raises TypeError when called with an unsupported type.""" with self.assertRaises(TypeError): await self.connection.pong([]) # Test keepalive. @patch("random.getrandbits", return_value=1918987876) async def test_keepalive(self, getrandbits): """keepalive sends pings at ping_interval and measures latency.""" self.connection.ping_interval = 3 * MS self.connection.start_keepalive() self.assertIsNotNone(self.connection.keepalive_task) self.assertEqual(self.connection.latency, 0) # 3 ms: keepalive() sends a ping frame. # 3.x ms: a pong frame is received. await asyncio.sleep(4 * MS) # 4 ms: check that the ping frame was sent. await self.assertFrameSent(Frame(Opcode.PING, b"rand")) self.assertGreater(self.connection.latency, 0) self.assertLess(self.connection.latency, MS) async def test_disable_keepalive(self): """keepalive is disabled when ping_interval is None.""" self.connection.ping_interval = None self.connection.start_keepalive() self.assertIsNone(self.connection.keepalive_task) @patch("random.getrandbits", return_value=1918987876) async def test_keepalive_times_out(self, getrandbits): """keepalive closes the connection if ping_timeout elapses.""" self.connection.ping_interval = 4 * MS self.connection.ping_timeout = 2 * MS async with self.drop_frames_rcvd(): self.connection.start_keepalive() # 4 ms: keepalive() sends a ping frame. 
await asyncio.sleep(4 * MS) # Exiting the context manager sleeps for 1 ms. # 4.x ms: a pong frame is dropped. # 6 ms: no pong frame is received; the connection is closed. await asyncio.sleep(2 * MS) # 7 ms: check that the connection is closed. self.assertEqual(self.connection.state, State.CLOSED) @patch("random.getrandbits", return_value=1918987876) async def test_keepalive_ignores_timeout(self, getrandbits): """keepalive ignores timeouts if ping_timeout isn't set.""" self.connection.ping_interval = 4 * MS self.connection.ping_timeout = None async with self.drop_frames_rcvd(): self.connection.start_keepalive() # 4 ms: keepalive() sends a ping frame. # 4.x ms: a pong frame is dropped. await asyncio.sleep(4 * MS) # Exiting the context manager sleeps for 1 ms. # 6 ms: no pong frame is received; the connection remains open. await asyncio.sleep(2 * MS) # 7 ms: check that the connection is still open. self.assertEqual(self.connection.state, State.OPEN) async def test_keepalive_terminates_while_sleeping(self): """keepalive task terminates while waiting to send a ping.""" self.connection.ping_interval = 3 * MS self.connection.start_keepalive() await asyncio.sleep(MS) await self.connection.close() self.assertTrue(self.connection.keepalive_task.done()) async def test_keepalive_terminates_while_waiting_for_pong(self): """keepalive task terminates while waiting to receive a pong.""" self.connection.ping_interval = MS self.connection.ping_timeout = 3 * MS async with self.drop_frames_rcvd(): self.connection.start_keepalive() # 1 ms: keepalive() sends a ping frame. # 1.x ms: a pong frame is dropped. await asyncio.sleep(MS) # Exiting the context manager sleeps for 1 ms. # 2 ms: close the connection before ping_timeout elapses. 
await self.connection.close() self.assertTrue(self.connection.keepalive_task.done()) async def test_keepalive_reports_errors(self): """keepalive reports unexpected errors in logs.""" self.connection.ping_interval = 2 * MS async with self.drop_frames_rcvd(): self.connection.start_keepalive() # 2 ms: keepalive() sends a ping frame. # 2.x ms: a pong frame is dropped. await asyncio.sleep(2 * MS) # Exiting the context manager sleeps for 1 ms. # 3 ms: inject a fault: raise an exception in the pending pong waiter. pong_waiter = next(iter(self.connection.pong_waiters.values()))[0] with self.assertLogs("websockets", logging.ERROR) as logs: pong_waiter.set_exception(Exception("BOOM")) await asyncio.sleep(0) self.assertEqual( [record.getMessage() for record in logs.records], ["keepalive ping failed"], ) self.assertEqual( [str(record.exc_info[1]) for record in logs.records], ["BOOM"], ) # Test parameters. async def test_close_timeout(self): """close_timeout parameter configures close timeout.""" connection = Connection(Protocol(self.LOCAL), close_timeout=42 * MS) self.assertEqual(connection.close_timeout, 42 * MS) async def test_max_queue(self): """max_queue configures high-water mark of frames buffer.""" connection = Connection(Protocol(self.LOCAL), max_queue=4) transport = Mock() connection.connection_made(transport) self.assertEqual(connection.recv_messages.high, 4) async def test_max_queue_none(self): """max_queue disables high-water mark of frames buffer.""" connection = Connection(Protocol(self.LOCAL), max_queue=None) transport = Mock() connection.connection_made(transport) self.assertEqual(connection.recv_messages.high, None) self.assertEqual(connection.recv_messages.low, None) async def test_max_queue_tuple(self): """max_queue configures high-water and low-water marks of frames buffer.""" connection = Connection( Protocol(self.LOCAL), max_queue=(4, 2), ) transport = Mock() connection.connection_made(transport) self.assertEqual(connection.recv_messages.high, 4) 
self.assertEqual(connection.recv_messages.low, 2) async def test_write_limit(self): """write_limit parameter configures high-water mark of write buffer.""" connection = Connection( Protocol(self.LOCAL), write_limit=4096, ) transport = Mock() connection.connection_made(transport) transport.set_write_buffer_limits.assert_called_once_with(4096, None) async def test_write_limits(self): """write_limit parameter configures high and low-water marks of write buffer.""" connection = Connection( Protocol(self.LOCAL), write_limit=(4096, 2048), ) transport = Mock() connection.connection_made(transport) transport.set_write_buffer_limits.assert_called_once_with(4096, 2048) # Test attributes. async def test_id(self): """Connection has an id attribute.""" self.assertIsInstance(self.connection.id, uuid.UUID) async def test_logger(self): """Connection has a logger attribute.""" self.assertIsInstance(self.connection.logger, logging.LoggerAdapter) @patch("asyncio.BaseTransport.get_extra_info", return_value=("sock", 1234)) async def test_local_address(self, get_extra_info): """Connection provides a local_address attribute.""" self.assertEqual(self.connection.local_address, ("sock", 1234)) get_extra_info.assert_called_with("sockname") @patch("asyncio.BaseTransport.get_extra_info", return_value=("peer", 1234)) async def test_remote_address(self, get_extra_info): """Connection provides a remote_address attribute.""" self.assertEqual(self.connection.remote_address, ("peer", 1234)) get_extra_info.assert_called_with("peername") async def test_state(self): """Connection has a state attribute.""" self.assertIs(self.connection.state, State.OPEN) async def test_request(self): """Connection has a request attribute.""" self.assertIsNone(self.connection.request) async def test_response(self): """Connection has a response attribute.""" self.assertIsNone(self.connection.response) async def test_subprotocol(self): """Connection has a subprotocol attribute.""" 
self.assertIsNone(self.connection.subprotocol) async def test_close_code(self): """Connection has a close_code attribute.""" self.assertIsNone(self.connection.close_code) async def test_close_reason(self): """Connection has a close_reason attribute.""" self.assertIsNone(self.connection.close_reason) # Test reporting of network errors. async def test_writing_in_data_received_fails(self): """Error when responding to incoming frames is correctly reported.""" # Inject a fault by shutting down the transport for writing — but not by # closing it because that would terminate the connection. self.transport.write_eof() # Receive a ping. Responding with a pong will fail. await self.remote_connection.ping() # The connection closed exception reports the injected fault. with self.assertRaises(ConnectionClosedError) as raised: await self.connection.recv() cause = raised.exception.__cause__ self.assertEqual(str(cause), "Cannot call write() after write_eof()") self.assertIsInstance(cause, RuntimeError) async def test_writing_in_send_context_fails(self): """Error when sending outgoing frame is correctly reported.""" # Inject a fault by shutting down the transport for writing — but not by # closing it because that would terminate the connection. self.transport.write_eof() # Sending a pong will fail. # The connection closed exception reports the injected fault. with self.assertRaises(ConnectionClosedError) as raised: await self.connection.pong() cause = raised.exception.__cause__ self.assertEqual(str(cause), "Cannot call write() after write_eof()") self.assertIsInstance(cause, RuntimeError) # Test safety nets — catching all exceptions in case of bugs. # Inject a fault in a random call in data_received(). # This test is tightly coupled to the implementation. 
@patch("websockets.protocol.Protocol.events_received", side_effect=AssertionError) async def test_unexpected_failure_in_data_received(self, events_received): """Unexpected internal error in data_received() is correctly reported.""" # Receive a message to trigger the fault. await self.remote_connection.send("😀") with self.assertRaises(ConnectionClosedError) as raised: await self.connection.recv() exc = raised.exception self.assertEqual(str(exc), "no close frame received or sent") self.assertIsInstance(exc.__cause__, AssertionError) # Inject a fault in a random call in send_context(). # This test is tightly coupled to the implementation. @patch("websockets.protocol.Protocol.send_text", side_effect=AssertionError) async def test_unexpected_failure_in_send_context(self, send_text): """Unexpected internal error in send_context() is correctly reported.""" # Send a message to trigger the fault. # The connection closed exception reports the injected fault. with self.assertRaises(ConnectionClosedError) as raised: await self.connection.send("😀") exc = raised.exception self.assertEqual(str(exc), "no close frame received or sent") self.assertIsInstance(exc.__cause__, AssertionError) # Test broadcast. 
async def test_broadcast_text(self): """broadcast broadcasts a text message.""" broadcast([self.connection], "😀") await self.assertFrameSent(Frame(Opcode.TEXT, "😀".encode())) @unittest.skipIf( sys.version_info[:2] < (3, 11), "raise_exceptions requires Python 3.11+", ) async def test_broadcast_text_reports_no_errors(self): """broadcast broadcasts a text message without raising exceptions.""" broadcast([self.connection], "😀", raise_exceptions=True) await self.assertFrameSent(Frame(Opcode.TEXT, "😀".encode())) async def test_broadcast_binary(self): """broadcast broadcasts a binary message.""" broadcast([self.connection], b"\x01\x02\xfe\xff") await self.assertFrameSent(Frame(Opcode.BINARY, b"\x01\x02\xfe\xff")) @unittest.skipIf( sys.version_info[:2] < (3, 11), "raise_exceptions requires Python 3.11+", ) async def test_broadcast_binary_reports_no_errors(self): """broadcast broadcasts a binary message without raising exceptions.""" broadcast([self.connection], b"\x01\x02\xfe\xff", raise_exceptions=True) await self.assertFrameSent(Frame(Opcode.BINARY, b"\x01\x02\xfe\xff")) async def test_broadcast_no_clients(self): """broadcast does nothing when called with an empty list of clients.""" broadcast([], "😀") await self.assertNoFrameSent() async def test_broadcast_two_clients(self): """broadcast broadcasts a message to several clients.""" broadcast([self.connection, self.connection], "😀") await self.assertFramesSent( [ Frame(Opcode.TEXT, "😀".encode()), Frame(Opcode.TEXT, "😀".encode()), ] ) async def test_broadcast_skips_closed_connection(self): """broadcast ignores closed connections.""" await self.connection.close() await self.assertFrameSent(Frame(Opcode.CLOSE, b"\x03\xe8")) with self.assertNoLogs("websockets", logging.WARNING): broadcast([self.connection], "😀") await self.assertNoFrameSent() async def test_broadcast_skips_closing_connection(self): """broadcast ignores closing connections.""" async with self.delay_frames_rcvd(MS): close_task = 
asyncio.create_task(self.connection.close()) await asyncio.sleep(0) await self.assertFrameSent(Frame(Opcode.CLOSE, b"\x03\xe8")) with self.assertNoLogs("websockets", logging.WARNING): broadcast([self.connection], "😀") await self.assertNoFrameSent() await close_task async def test_broadcast_skips_connection_with_send_blocked(self): """broadcast logs a warning when a connection is blocked in send.""" gate = asyncio.get_running_loop().create_future() async def fragments(): yield "⏳" await gate send_task = asyncio.create_task(self.connection.send(fragments())) await asyncio.sleep(MS) await self.assertFrameSent(Frame(Opcode.TEXT, "⏳".encode(), fin=False)) with self.assertLogs("websockets", logging.WARNING) as logs: broadcast([self.connection], "😀") self.assertEqual( [record.getMessage() for record in logs.records], ["skipped broadcast: sending a fragmented message"], ) gate.set_result(None) await send_task @unittest.skipIf( sys.version_info[:2] < (3, 11), "raise_exceptions requires Python 3.11+", ) async def test_broadcast_reports_connection_with_send_blocked(self): """broadcast raises exceptions for connections blocked in send.""" gate = asyncio.get_running_loop().create_future() async def fragments(): yield "⏳" await gate send_task = asyncio.create_task(self.connection.send(fragments())) await asyncio.sleep(MS) await self.assertFrameSent(Frame(Opcode.TEXT, "⏳".encode(), fin=False)) with self.assertRaises(ExceptionGroup) as raised: broadcast([self.connection], "😀", raise_exceptions=True) self.assertEqual(str(raised.exception), "skipped broadcast (1 sub-exception)") exc = raised.exception.exceptions[0] self.assertEqual(str(exc), "sending a fragmented message") self.assertIsInstance(exc, ConcurrencyError) gate.set_result(None) await send_task async def test_broadcast_skips_connection_failing_to_send(self): """broadcast logs a warning when a connection fails to send.""" # Inject a fault by shutting down the transport for writing. 
self.transport.write_eof() with self.assertLogs("websockets", logging.WARNING) as logs: broadcast([self.connection], "😀") self.assertEqual( [record.getMessage() for record in logs.records], [ "skipped broadcast: failed to write message: " "RuntimeError: Cannot call write() after write_eof()" ], ) @unittest.skipIf( sys.version_info[:2] < (3, 11), "raise_exceptions requires Python 3.11+", ) async def test_broadcast_reports_connection_failing_to_send(self): """broadcast raises exceptions for connections failing to send.""" # Inject a fault by shutting down the transport for writing. self.transport.write_eof() with self.assertRaises(ExceptionGroup) as raised: broadcast([self.connection], "😀", raise_exceptions=True) self.assertEqual(str(raised.exception), "skipped broadcast (1 sub-exception)") exc = raised.exception.exceptions[0] self.assertEqual(str(exc), "failed to write message") self.assertIsInstance(exc, RuntimeError) cause = exc.__cause__ self.assertEqual(str(cause), "Cannot call write() after write_eof()") self.assertIsInstance(cause, RuntimeError) async def test_broadcast_type_error(self): """broadcast raises TypeError when called with an unsupported type.""" with self.assertRaises(TypeError): broadcast([self.connection], ["⏳", "⌛️"]) class ServerConnectionTests(ClientConnectionTests): LOCAL = SERVER REMOTE = CLIENT websockets-15.0.1/tests/asyncio/test_messages.py000066400000000000000000000603221476212450300220050ustar00rootroot00000000000000import asyncio import unittest import unittest.mock from websockets.asyncio.compatibility import aiter, anext from websockets.asyncio.messages import * from websockets.asyncio.messages import SimpleQueue from websockets.exceptions import ConcurrencyError from websockets.frames import OP_BINARY, OP_CONT, OP_TEXT, Frame from .utils import alist class SimpleQueueTests(unittest.IsolatedAsyncioTestCase): async def asyncSetUp(self): self.queue = SimpleQueue() async def test_len(self): """__len__ returns queue length.""" 
self.assertEqual(len(self.queue), 0) self.queue.put(42) self.assertEqual(len(self.queue), 1) await self.queue.get() self.assertEqual(len(self.queue), 0) async def test_put_then_get(self): """get returns an item that is already put.""" self.queue.put(42) item = await self.queue.get() self.assertEqual(item, 42) async def test_get_then_put(self): """get returns an item when it is put.""" getter_task = asyncio.create_task(self.queue.get()) await asyncio.sleep(0) # let the task start self.queue.put(42) item = await getter_task self.assertEqual(item, 42) async def test_reset(self): """reset sets the content of the queue.""" self.queue.reset([42]) item = await self.queue.get() self.assertEqual(item, 42) async def test_abort(self): """abort throws an exception in get.""" getter_task = asyncio.create_task(self.queue.get()) await asyncio.sleep(0) # let the task start self.queue.abort() with self.assertRaises(EOFError): await getter_task class AssemblerTests(unittest.IsolatedAsyncioTestCase): async def asyncSetUp(self): self.pause = unittest.mock.Mock() self.resume = unittest.mock.Mock() self.assembler = Assembler(high=2, low=1, pause=self.pause, resume=self.resume) # Test get async def test_get_text_message_already_received(self): """get returns a text message that is already received.""" self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) message = await self.assembler.get() self.assertEqual(message, "café") async def test_get_binary_message_already_received(self): """get returns a binary message that is already received.""" self.assembler.put(Frame(OP_BINARY, b"tea")) message = await self.assembler.get() self.assertEqual(message, b"tea") async def test_get_text_message_not_received_yet(self): """get returns a text message when it is received.""" getter_task = asyncio.create_task(self.assembler.get()) await asyncio.sleep(0) # let the event loop start getter_task self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) message = await getter_task self.assertEqual(message, "café") 
async def test_get_binary_message_not_received_yet(self): """get returns a binary message when it is received.""" getter_task = asyncio.create_task(self.assembler.get()) await asyncio.sleep(0) # let the event loop start getter_task self.assembler.put(Frame(OP_BINARY, b"tea")) message = await getter_task self.assertEqual(message, b"tea") async def test_get_fragmented_text_message_already_received(self): """get reassembles a fragmented a text message that is already received.""" self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) self.assembler.put(Frame(OP_CONT, b"\xa9")) message = await self.assembler.get() self.assertEqual(message, "café") async def test_get_fragmented_binary_message_already_received(self): """get reassembles a fragmented binary message that is already received.""" self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) self.assembler.put(Frame(OP_CONT, b"e", fin=False)) self.assembler.put(Frame(OP_CONT, b"a")) message = await self.assembler.get() self.assertEqual(message, b"tea") async def test_get_fragmented_text_message_not_received_yet(self): """get reassembles a fragmented text message when it is received.""" getter_task = asyncio.create_task(self.assembler.get()) self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) self.assembler.put(Frame(OP_CONT, b"\xa9")) message = await getter_task self.assertEqual(message, "café") async def test_get_fragmented_binary_message_not_received_yet(self): """get reassembles a fragmented binary message when it is received.""" getter_task = asyncio.create_task(self.assembler.get()) self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) self.assembler.put(Frame(OP_CONT, b"e", fin=False)) self.assembler.put(Frame(OP_CONT, b"a")) message = await getter_task self.assertEqual(message, b"tea") async def test_get_fragmented_text_message_being_received(self): """get reassembles a fragmented text message that is 
partially received.""" self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) getter_task = asyncio.create_task(self.assembler.get()) self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) self.assembler.put(Frame(OP_CONT, b"\xa9")) message = await getter_task self.assertEqual(message, "café") async def test_get_fragmented_binary_message_being_received(self): """get reassembles a fragmented binary message that is partially received.""" self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) getter_task = asyncio.create_task(self.assembler.get()) self.assembler.put(Frame(OP_CONT, b"e", fin=False)) self.assembler.put(Frame(OP_CONT, b"a")) message = await getter_task self.assertEqual(message, b"tea") async def test_get_encoded_text_message(self): """get returns a text message without UTF-8 decoding.""" self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) message = await self.assembler.get(decode=False) self.assertEqual(message, b"caf\xc3\xa9") async def test_get_decoded_binary_message(self): """get returns a binary message with UTF-8 decoding.""" self.assembler.put(Frame(OP_BINARY, b"tea")) message = await self.assembler.get(decode=True) self.assertEqual(message, "tea") async def test_get_resumes_reading(self): """get resumes reading when queue goes below the low-water mark.""" self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) self.assembler.put(Frame(OP_TEXT, b"more caf\xc3\xa9")) self.assembler.put(Frame(OP_TEXT, b"water")) # queue is above the low-water mark await self.assembler.get() self.resume.assert_not_called() # queue is at the low-water mark await self.assembler.get() self.resume.assert_called_once_with() # queue is below the low-water mark await self.assembler.get() self.resume.assert_called_once_with() async def test_get_does_not_resume_reading(self): """get does not resume reading when the low-water mark is unset.""" self.assembler.low = None self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) self.assembler.put(Frame(OP_TEXT, b"more caf\xc3\xa9")) 
self.assembler.put(Frame(OP_TEXT, b"water")) await self.assembler.get() await self.assembler.get() await self.assembler.get() self.resume.assert_not_called() async def test_cancel_get_before_first_frame(self): """get can be canceled safely before reading the first frame.""" getter_task = asyncio.create_task(self.assembler.get()) await asyncio.sleep(0) # let the event loop start getter_task getter_task.cancel() with self.assertRaises(asyncio.CancelledError): await getter_task self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) message = await self.assembler.get() self.assertEqual(message, "café") async def test_cancel_get_after_first_frame(self): """get can be canceled safely after reading the first frame.""" self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) getter_task = asyncio.create_task(self.assembler.get()) await asyncio.sleep(0) # let the event loop start getter_task getter_task.cancel() with self.assertRaises(asyncio.CancelledError): await getter_task self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) self.assembler.put(Frame(OP_CONT, b"\xa9")) message = await self.assembler.get() self.assertEqual(message, "café") # Test get_iter async def test_get_iter_text_message_already_received(self): """get_iter yields a text message that is already received.""" self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) fragments = await alist(self.assembler.get_iter()) self.assertEqual(fragments, ["café"]) async def test_get_iter_binary_message_already_received(self): """get_iter yields a binary message that is already received.""" self.assembler.put(Frame(OP_BINARY, b"tea")) fragments = await alist(self.assembler.get_iter()) self.assertEqual(fragments, [b"tea"]) async def test_get_iter_text_message_not_received_yet(self): """get_iter yields a text message when it is received.""" getter_task = asyncio.create_task(alist(self.assembler.get_iter())) await asyncio.sleep(0) # let the event loop start getter_task self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) fragments = 
await getter_task self.assertEqual(fragments, ["café"]) async def test_get_iter_binary_message_not_received_yet(self): """get_iter yields a binary message when it is received.""" getter_task = asyncio.create_task(alist(self.assembler.get_iter())) await asyncio.sleep(0) # let the event loop start getter_task self.assembler.put(Frame(OP_BINARY, b"tea")) fragments = await getter_task self.assertEqual(fragments, [b"tea"]) async def test_get_iter_fragmented_text_message_already_received(self): """get_iter yields a fragmented text message that is already received.""" self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) self.assembler.put(Frame(OP_CONT, b"\xa9")) fragments = await alist(self.assembler.get_iter()) self.assertEqual(fragments, ["ca", "f", "é"]) async def test_get_iter_fragmented_binary_message_already_received(self): """get_iter yields a fragmented binary message that is already received.""" self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) self.assembler.put(Frame(OP_CONT, b"e", fin=False)) self.assembler.put(Frame(OP_CONT, b"a")) fragments = await alist(self.assembler.get_iter()) self.assertEqual(fragments, [b"t", b"e", b"a"]) async def test_get_iter_fragmented_text_message_not_received_yet(self): """get_iter yields a fragmented text message when it is received.""" iterator = aiter(self.assembler.get_iter()) self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) self.assertEqual(await anext(iterator), "ca") self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) self.assertEqual(await anext(iterator), "f") self.assembler.put(Frame(OP_CONT, b"\xa9")) self.assertEqual(await anext(iterator), "é") async def test_get_iter_fragmented_binary_message_not_received_yet(self): """get_iter yields a fragmented binary message when it is received.""" iterator = aiter(self.assembler.get_iter()) self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) self.assertEqual(await anext(iterator), b"t") 
self.assembler.put(Frame(OP_CONT, b"e", fin=False)) self.assertEqual(await anext(iterator), b"e") self.assembler.put(Frame(OP_CONT, b"a")) self.assertEqual(await anext(iterator), b"a") async def test_get_iter_fragmented_text_message_being_received(self): """get_iter yields a fragmented text message that is partially received.""" self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) iterator = aiter(self.assembler.get_iter()) self.assertEqual(await anext(iterator), "ca") self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) self.assertEqual(await anext(iterator), "f") self.assembler.put(Frame(OP_CONT, b"\xa9")) self.assertEqual(await anext(iterator), "é") async def test_get_iter_fragmented_binary_message_being_received(self): """get_iter yields a fragmented binary message that is partially received.""" self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) iterator = aiter(self.assembler.get_iter()) self.assertEqual(await anext(iterator), b"t") self.assembler.put(Frame(OP_CONT, b"e", fin=False)) self.assertEqual(await anext(iterator), b"e") self.assembler.put(Frame(OP_CONT, b"a")) self.assertEqual(await anext(iterator), b"a") async def test_get_iter_encoded_text_message(self): """get_iter yields a text message without UTF-8 decoding.""" self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) self.assembler.put(Frame(OP_CONT, b"\xa9")) fragments = await alist(self.assembler.get_iter(decode=False)) self.assertEqual(fragments, [b"ca", b"f\xc3", b"\xa9"]) async def test_get_iter_decoded_binary_message(self): """get_iter yields a binary message with UTF-8 decoding.""" self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) self.assembler.put(Frame(OP_CONT, b"e", fin=False)) self.assembler.put(Frame(OP_CONT, b"a")) fragments = await alist(self.assembler.get_iter(decode=True)) self.assertEqual(fragments, ["t", "e", "a"]) async def test_get_iter_resumes_reading(self): """get_iter resumes reading when queue goes below the 
low-water mark.""" self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) self.assembler.put(Frame(OP_CONT, b"e", fin=False)) self.assembler.put(Frame(OP_CONT, b"a")) iterator = aiter(self.assembler.get_iter()) # queue is above the low-water mark await anext(iterator) self.resume.assert_not_called() # queue is at the low-water mark await anext(iterator) self.resume.assert_called_once_with() # queue is below the low-water mark await anext(iterator) self.resume.assert_called_once_with() async def test_get_iter_does_not_resume_reading(self): """get_iter does not resume reading when the low-water mark is unset.""" self.assembler.low = None self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) self.assembler.put(Frame(OP_CONT, b"e", fin=False)) self.assembler.put(Frame(OP_CONT, b"a")) iterator = aiter(self.assembler.get_iter()) await anext(iterator) await anext(iterator) await anext(iterator) self.resume.assert_not_called() async def test_cancel_get_iter_before_first_frame(self): """get_iter can be canceled safely before reading the first frame.""" getter_task = asyncio.create_task(alist(self.assembler.get_iter())) await asyncio.sleep(0) # let the event loop start getter_task getter_task.cancel() with self.assertRaises(asyncio.CancelledError): await getter_task self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) fragments = await alist(self.assembler.get_iter()) self.assertEqual(fragments, ["café"]) async def test_cancel_get_iter_after_first_frame(self): """get_iter cannot be canceled after reading the first frame.""" self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) getter_task = asyncio.create_task(alist(self.assembler.get_iter())) await asyncio.sleep(0) # let the event loop start getter_task getter_task.cancel() with self.assertRaises(asyncio.CancelledError): await getter_task self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) self.assembler.put(Frame(OP_CONT, b"\xa9")) with self.assertRaises(ConcurrencyError): await alist(self.assembler.get_iter()) # Test put 
async def test_put_pauses_reading(self): """put pauses reading when queue goes above the high-water mark.""" # queue is below the high-water mark self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) self.pause.assert_not_called() # queue is at the high-water mark self.assembler.put(Frame(OP_CONT, b"e", fin=False)) self.pause.assert_called_once_with() # queue is above the high-water mark self.assembler.put(Frame(OP_CONT, b"a")) self.pause.assert_called_once_with() async def test_put_does_not_pause_reading(self): """put does not pause reading when the high-water mark is unset.""" self.assembler.high = None self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) self.assembler.put(Frame(OP_CONT, b"e", fin=False)) self.assembler.put(Frame(OP_CONT, b"a")) self.pause.assert_not_called() # Test termination async def test_get_fails_when_interrupted_by_close(self): """get raises EOFError when close is called.""" asyncio.get_running_loop().call_soon(self.assembler.close) with self.assertRaises(EOFError): await self.assembler.get() async def test_get_iter_fails_when_interrupted_by_close(self): """get_iter raises EOFError when close is called.""" asyncio.get_running_loop().call_soon(self.assembler.close) with self.assertRaises(EOFError): async for _ in self.assembler.get_iter(): self.fail("no fragment expected") async def test_get_fails_after_close(self): """get raises EOFError after close is called.""" self.assembler.close() with self.assertRaises(EOFError): await self.assembler.get() async def test_get_iter_fails_after_close(self): """get_iter raises EOFError after close is called.""" self.assembler.close() with self.assertRaises(EOFError): async for _ in self.assembler.get_iter(): self.fail("no fragment expected") async def test_get_queued_message_after_close(self): """get returns a message after close is called.""" self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) 
self.assembler.close() message = await self.assembler.get() self.assertEqual(message, "café") async def test_get_iter_queued_message_after_close(self): """get_iter yields a message after close is called.""" self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) self.assembler.close() fragments = await alist(self.assembler.get_iter()) self.assertEqual(fragments, ["café"]) async def test_get_queued_fragmented_message_after_close(self): """get reassembles a fragmented message after close is called.""" self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) self.assembler.put(Frame(OP_CONT, b"e", fin=False)) self.assembler.put(Frame(OP_CONT, b"a")) self.assembler.close() self.assembler.close() message = await self.assembler.get() self.assertEqual(message, b"tea") async def test_get_iter_queued_fragmented_message_after_close(self): """get_iter yields a fragmented message after close is called.""" self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) self.assembler.put(Frame(OP_CONT, b"e", fin=False)) self.assembler.put(Frame(OP_CONT, b"a")) self.assembler.close() fragments = await alist(self.assembler.get_iter()) self.assertEqual(fragments, [b"t", b"e", b"a"]) async def test_get_partially_queued_fragmented_message_after_close(self): """get raises EOF on a partial fragmented message after close is called.""" self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) self.assembler.put(Frame(OP_CONT, b"e", fin=False)) self.assembler.close() with self.assertRaises(EOFError): await self.assembler.get() async def test_get_iter_partially_queued_fragmented_message_after_close(self): """get_iter yields a partial fragmented message after close is called.""" self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) self.assembler.put(Frame(OP_CONT, b"e", fin=False)) self.assembler.close() fragments = [] with self.assertRaises(EOFError): async for fragment in self.assembler.get_iter(): fragments.append(fragment) self.assertEqual(fragments, [b"t", b"e"]) async def 
test_put_fails_after_close(self): """put raises EOFError after close is called.""" self.assembler.close() with self.assertRaises(EOFError): self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) async def test_close_is_idempotent(self): """close can be called multiple times safely.""" self.assembler.close() self.assembler.close() # Test (non-)concurrency async def test_get_fails_when_get_is_running(self): """get cannot be called concurrently.""" asyncio.create_task(self.assembler.get()) await asyncio.sleep(0) with self.assertRaises(ConcurrencyError): await self.assembler.get() self.assembler.put(Frame(OP_TEXT, b"")) # let task terminate async def test_get_fails_when_get_iter_is_running(self): """get cannot be called concurrently with get_iter.""" asyncio.create_task(alist(self.assembler.get_iter())) await asyncio.sleep(0) with self.assertRaises(ConcurrencyError): await self.assembler.get() self.assembler.put(Frame(OP_TEXT, b"")) # let task terminate async def test_get_iter_fails_when_get_is_running(self): """get_iter cannot be called concurrently with get.""" asyncio.create_task(self.assembler.get()) await asyncio.sleep(0) with self.assertRaises(ConcurrencyError): await alist(self.assembler.get_iter()) self.assembler.put(Frame(OP_TEXT, b"")) # let task terminate async def test_get_iter_fails_when_get_iter_is_running(self): """get_iter cannot be called concurrently.""" asyncio.create_task(alist(self.assembler.get_iter())) await asyncio.sleep(0) with self.assertRaises(ConcurrencyError): await alist(self.assembler.get_iter()) self.assembler.put(Frame(OP_TEXT, b"")) # let task terminate # Test setting limits async def test_set_high_water_mark(self): """high sets the high-water and low-water marks.""" assembler = Assembler(high=10) self.assertEqual(assembler.high, 10) self.assertEqual(assembler.low, 2) async def test_set_low_water_mark(self): """low sets the low-water and high-water marks.""" assembler = Assembler(low=5) self.assertEqual(assembler.low, 5) 
self.assertEqual(assembler.high, 20) async def test_set_high_and_low_water_marks(self): """high and low set the high-water and low-water marks.""" assembler = Assembler(high=10, low=5) self.assertEqual(assembler.high, 10) self.assertEqual(assembler.low, 5) async def test_unset_high_and_low_water_marks(self): """High-water and low-water marks are unset.""" assembler = Assembler() self.assertEqual(assembler.high, None) self.assertEqual(assembler.low, None) async def test_set_invalid_high_water_mark(self): """high must be a non-negative integer.""" with self.assertRaises(ValueError): Assembler(high=-1) async def test_set_invalid_low_water_mark(self): """low must be higher than high.""" with self.assertRaises(ValueError): Assembler(low=10, high=5) websockets-15.0.1/tests/asyncio/test_router.py000066400000000000000000000200631476212450300215140ustar00rootroot00000000000000import http import socket import sys import unittest from unittest.mock import patch from websockets.asyncio.client import connect, unix_connect from websockets.asyncio.router import * from websockets.exceptions import InvalidStatus from ..utils import CLIENT_CONTEXT, SERVER_CONTEXT, temp_unix_socket_path from .server import EvalShellMixin, get_uri, handler from .utils import alist try: from werkzeug.routing import Map, Rule except ImportError: pass async def echo(websocket, count): message = await websocket.recv() for _ in range(count): await websocket.send(message) @unittest.skipUnless("werkzeug" in sys.modules, "werkzeug not installed") class RouterTests(EvalShellMixin, unittest.IsolatedAsyncioTestCase): # This is a small realistic example of werkzeug's basic URL routing # features: path matching, parameter extraction, and default values. 
async def test_router_matches_paths_and_extracts_parameters(self): """Router matches paths and extracts parameters.""" url_map = Map( [ Rule("/echo", defaults={"count": 1}, endpoint=echo), Rule("/echo/", endpoint=echo), ] ) async with route(url_map, "localhost", 0) as server: async with connect(get_uri(server) + "/echo") as client: await client.send("hello") messages = await alist(client) self.assertEqual(messages, ["hello"]) async with connect(get_uri(server) + "/echo/3") as client: await client.send("hello") messages = await alist(client) self.assertEqual(messages, ["hello", "hello", "hello"]) @property # avoids an import-time dependency on werkzeug def url_map(self): return Map( [ Rule("/", endpoint=handler), Rule("/r", redirect_to="/"), ] ) async def test_route_with_query_string(self): """Router ignores query strings when matching paths.""" async with route(self.url_map, "localhost", 0) as server: async with connect(get_uri(server) + "/?a=b") as client: await self.assertEval(client, "ws.request.path", "/?a=b") async def test_redirect(self): """Router redirects connections according to redirect_to.""" async with route(self.url_map, "localhost", 0) as server: async with connect(get_uri(server) + "/r") as client: await self.assertEval(client, "ws.request.path", "/") async def test_secure_redirect(self): """Router redirects connections to a wss:// URI when TLS is enabled.""" async with route(self.url_map, "localhost", 0, ssl=SERVER_CONTEXT) as server: async with connect(get_uri(server) + "/r", ssl=CLIENT_CONTEXT) as client: await self.assertEval(client, "ws.request.path", "/") @patch("websockets.asyncio.client.connect.process_redirect", lambda _, exc: exc) async def test_force_secure_redirect(self): """Router redirects ws:// connections to a wss:// URI when ssl=True.""" async with route(self.url_map, "localhost", 0, ssl=True) as server: redirect_uri = get_uri(server, secure=True) with self.assertRaises(InvalidStatus) as raised: async with connect(get_uri(server) + 
"/r"): self.fail("did not raise") self.assertEqual( raised.exception.response.headers["Location"], redirect_uri + "/", ) @patch("websockets.asyncio.client.connect.process_redirect", lambda _, exc: exc) async def test_force_redirect_server_name(self): """Router redirects connections to the host declared in server_name.""" async with route(self.url_map, "localhost", 0, server_name="other") as server: with self.assertRaises(InvalidStatus) as raised: async with connect(get_uri(server) + "/r"): self.fail("did not raise") self.assertEqual( raised.exception.response.headers["Location"], "ws://other/", ) async def test_not_found(self): """Router rejects requests to unknown paths with an HTTP 404 error.""" async with route(self.url_map, "localhost", 0) as server: with self.assertRaises(InvalidStatus) as raised: async with connect(get_uri(server) + "/n"): self.fail("did not raise") self.assertEqual( str(raised.exception), "server rejected WebSocket connection: HTTP 404", ) async def test_process_request_function_returning_none(self): """Router supports a process_request function returning None.""" def process_request(ws, request): ws.process_request_ran = True async with route( self.url_map, "localhost", 0, process_request=process_request ) as server: async with connect(get_uri(server) + "/") as client: await self.assertEval(client, "ws.process_request_ran", "True") async def test_process_request_coroutine_returning_none(self): """Router supports a process_request coroutine returning None.""" async def process_request(ws, request): ws.process_request_ran = True async with route( self.url_map, "localhost", 0, process_request=process_request ) as server: async with connect(get_uri(server) + "/") as client: await self.assertEval(client, "ws.process_request_ran", "True") async def test_process_request_function_returning_response(self): """Router supports a process_request function returning a response.""" def process_request(ws, request): return 
ws.respond(http.HTTPStatus.FORBIDDEN, "Forbidden") async with route( self.url_map, "localhost", 0, process_request=process_request ) as server: with self.assertRaises(InvalidStatus) as raised: async with connect(get_uri(server) + "/"): self.fail("did not raise") self.assertEqual( str(raised.exception), "server rejected WebSocket connection: HTTP 403", ) async def test_process_request_coroutine_returning_response(self): """Router supports a process_request coroutine returning a response.""" async def process_request(ws, request): return ws.respond(http.HTTPStatus.FORBIDDEN, "Forbidden") async with route( self.url_map, "localhost", 0, process_request=process_request ) as server: with self.assertRaises(InvalidStatus) as raised: async with connect(get_uri(server) + "/"): self.fail("did not raise") self.assertEqual( str(raised.exception), "server rejected WebSocket connection: HTTP 403", ) async def test_custom_router_factory(self): """Router supports a custom router factory.""" class MyRouter(Router): async def handler(self, connection): connection.my_router_ran = True return await super().handler(connection) async with route( self.url_map, "localhost", 0, create_router=MyRouter ) as server: async with connect(get_uri(server)) as client: await self.assertEval(client, "ws.my_router_ran", "True") @unittest.skipUnless(hasattr(socket, "AF_UNIX"), "this test requires Unix sockets") class UnixRouterTests(EvalShellMixin, unittest.IsolatedAsyncioTestCase): async def test_router_supports_unix_sockets(self): """Router supports Unix sockets.""" url_map = Map([Rule("/echo/", endpoint=echo)]) with temp_unix_socket_path() as path: async with unix_route(url_map, path): async with unix_connect(path, "ws://localhost/echo/3") as client: await client.send("hello") messages = await alist(client) self.assertEqual(messages, ["hello", "hello", "hello"]) websockets-15.0.1/tests/asyncio/test_server.py000066400000000000000000001022221476212450300215000ustar00rootroot00000000000000import asyncio 
import dataclasses import hmac import http import logging import socket import unittest from websockets.asyncio.client import connect, unix_connect from websockets.asyncio.compatibility import TimeoutError, asyncio_timeout from websockets.asyncio.server import * from websockets.exceptions import ( ConnectionClosedError, ConnectionClosedOK, InvalidStatus, NegotiationError, ) from websockets.http11 import Request, Response from ..utils import ( CLIENT_CONTEXT, MS, SERVER_CONTEXT, AssertNoLogsMixin, temp_unix_socket_path, ) from .server import ( EvalShellMixin, args, get_host_port, get_uri, handler, ) class ServerTests(EvalShellMixin, AssertNoLogsMixin, unittest.IsolatedAsyncioTestCase): async def test_connection(self): """Server receives connection from client and the handshake succeeds.""" async with serve(*args) as server: async with connect(get_uri(server)) as client: await self.assertEval(client, "ws.protocol.state.name", "OPEN") async def test_connection_handler_returns(self): """Connection handler returns.""" async with serve(*args) as server: async with connect(get_uri(server) + "/no-op") as client: with self.assertRaises(ConnectionClosedOK) as raised: await client.recv() self.assertEqual( str(raised.exception), "received 1000 (OK); then sent 1000 (OK)", ) async def test_connection_handler_raises_exception(self): """Connection handler raises an exception.""" async with serve(*args) as server: async with connect(get_uri(server) + "/crash") as client: with self.assertRaises(ConnectionClosedError) as raised: await client.recv() self.assertEqual( str(raised.exception), "received 1011 (internal error); then sent 1011 (internal error)", ) async def test_existing_socket(self): """Server receives connection using a pre-existing socket.""" with socket.create_server(("localhost", 0)) as sock: host, port = sock.getsockname() async with serve(handler, sock=sock): async with connect(f"ws://{host}:{port}/") as client: await self.assertEval(client, "ws.protocol.state.name", 
"OPEN") async def test_select_subprotocol(self): """Server selects a subprotocol with the select_subprotocol callable.""" def select_subprotocol(ws, subprotocols): ws.select_subprotocol_ran = True assert "chat" in subprotocols return "chat" async with serve( *args, subprotocols=["chat"], select_subprotocol=select_subprotocol, ) as server: async with connect(get_uri(server), subprotocols=["chat"]) as client: await self.assertEval(client, "ws.select_subprotocol_ran", "True") await self.assertEval(client, "ws.subprotocol", "chat") async def test_select_subprotocol_rejects_handshake(self): """Server rejects handshake if select_subprotocol raises NegotiationError.""" def select_subprotocol(ws, subprotocols): raise NegotiationError async with serve(*args, select_subprotocol=select_subprotocol) as server: with self.assertRaises(InvalidStatus) as raised: async with connect(get_uri(server)): self.fail("did not raise") self.assertEqual( str(raised.exception), "server rejected WebSocket connection: HTTP 400", ) async def test_select_subprotocol_raises_exception(self): """Server returns an error if select_subprotocol raises an exception.""" def select_subprotocol(ws, subprotocols): raise RuntimeError async with serve(*args, select_subprotocol=select_subprotocol) as server: with self.assertRaises(InvalidStatus) as raised: async with connect(get_uri(server)): self.fail("did not raise") self.assertEqual( str(raised.exception), "server rejected WebSocket connection: HTTP 500", ) async def test_compression_is_enabled(self): """Server enables compression by default.""" async with serve(*args) as server: async with connect(get_uri(server)) as client: await self.assertEval( client, "[type(ext).__name__ for ext in ws.protocol.extensions]", "['PerMessageDeflate']", ) async def test_disable_compression(self): """Server disables compression.""" async with serve(*args, compression=None) as server: async with connect(get_uri(server)) as client: await self.assertEval(client, 
"ws.protocol.extensions", "[]") async def test_process_request_returns_none(self): """Server runs process_request and continues the handshake.""" def process_request(ws, request): self.assertIsInstance(request, Request) ws.process_request_ran = True async with serve(*args, process_request=process_request) as server: async with connect(get_uri(server)) as client: await self.assertEval(client, "ws.process_request_ran", "True") async def test_async_process_request_returns_none(self): """Server runs async process_request and continues the handshake.""" async def process_request(ws, request): self.assertIsInstance(request, Request) ws.process_request_ran = True async with serve(*args, process_request=process_request) as server: async with connect(get_uri(server)) as client: await self.assertEval(client, "ws.process_request_ran", "True") async def test_process_request_returns_response(self): """Server aborts handshake if process_request returns a response.""" def process_request(ws, request): return ws.respond(http.HTTPStatus.FORBIDDEN, "Forbidden") async def handler(ws): self.fail("handler must not run") with self.assertNoLogs("websockets", logging.ERROR): async with serve( handler, *args[1:], process_request=process_request ) as server: with self.assertRaises(InvalidStatus) as raised: async with connect(get_uri(server)): self.fail("did not raise") self.assertEqual( str(raised.exception), "server rejected WebSocket connection: HTTP 403", ) async def test_async_process_request_returns_response(self): """Server aborts handshake if async process_request returns a response.""" async def process_request(ws, request): return ws.respond(http.HTTPStatus.FORBIDDEN, "Forbidden") async def handler(ws): self.fail("handler must not run") with self.assertNoLogs("websockets", logging.ERROR): async with serve( handler, *args[1:], process_request=process_request ) as server: with self.assertRaises(InvalidStatus) as raised: async with connect(get_uri(server)): self.fail("did not raise") 
self.assertEqual( str(raised.exception), "server rejected WebSocket connection: HTTP 403", ) async def test_process_request_raises_exception(self): """Server returns an error if process_request raises an exception.""" def process_request(ws, request): raise RuntimeError("BOOM") with self.assertLogs("websockets", logging.ERROR) as logs: async with serve(*args, process_request=process_request) as server: with self.assertRaises(InvalidStatus) as raised: async with connect(get_uri(server)): self.fail("did not raise") self.assertEqual( str(raised.exception), "server rejected WebSocket connection: HTTP 500", ) self.assertEqual( [record.getMessage() for record in logs.records], ["opening handshake failed"], ) self.assertEqual( [str(record.exc_info[1]) for record in logs.records], ["BOOM"], ) async def test_async_process_request_raises_exception(self): """Server returns an error if async process_request raises an exception.""" async def process_request(ws, request): raise RuntimeError("BOOM") with self.assertLogs("websockets", logging.ERROR) as logs: async with serve(*args, process_request=process_request) as server: with self.assertRaises(InvalidStatus) as raised: async with connect(get_uri(server)): self.fail("did not raise") self.assertEqual( str(raised.exception), "server rejected WebSocket connection: HTTP 500", ) self.assertEqual( [record.getMessage() for record in logs.records], ["opening handshake failed"], ) self.assertEqual( [str(record.exc_info[1]) for record in logs.records], ["BOOM"], ) async def test_process_response_returns_none(self): """Server runs process_response but keeps the handshake response.""" def process_response(ws, request, response): self.assertIsInstance(request, Request) self.assertIsInstance(response, Response) ws.process_response_ran = True async with serve(*args, process_response=process_response) as server: async with connect(get_uri(server)) as client: await self.assertEval(client, "ws.process_response_ran", "True") async def 
test_async_process_response_returns_none(self): """Server runs async process_response but keeps the handshake response.""" async def process_response(ws, request, response): self.assertIsInstance(request, Request) self.assertIsInstance(response, Response) ws.process_response_ran = True async with serve(*args, process_response=process_response) as server: async with connect(get_uri(server)) as client: await self.assertEval(client, "ws.process_response_ran", "True") async def test_process_response_modifies_response(self): """Server runs process_response and modifies the handshake response.""" def process_response(ws, request, response): response.headers["X-ProcessResponse"] = "OK" async with serve(*args, process_response=process_response) as server: async with connect(get_uri(server)) as client: self.assertEqual(client.response.headers["X-ProcessResponse"], "OK") async def test_async_process_response_modifies_response(self): """Server runs async process_response and modifies the handshake response.""" async def process_response(ws, request, response): response.headers["X-ProcessResponse"] = "OK" async with serve(*args, process_response=process_response) as server: async with connect(get_uri(server)) as client: self.assertEqual(client.response.headers["X-ProcessResponse"], "OK") async def test_process_response_replaces_response(self): """Server runs process_response and replaces the handshake response.""" def process_response(ws, request, response): headers = response.headers.copy() headers["X-ProcessResponse"] = "OK" return dataclasses.replace(response, headers=headers) async with serve(*args, process_response=process_response) as server: async with connect(get_uri(server)) as client: self.assertEqual(client.response.headers["X-ProcessResponse"], "OK") async def test_async_process_response_replaces_response(self): """Server runs async process_response and replaces the handshake response.""" async def process_response(ws, request, response): headers = 
response.headers.copy() headers["X-ProcessResponse"] = "OK" return dataclasses.replace(response, headers=headers) async with serve(*args, process_response=process_response) as server: async with connect(get_uri(server)) as client: self.assertEqual(client.response.headers["X-ProcessResponse"], "OK") async def test_process_response_raises_exception(self): """Server returns an error if process_response raises an exception.""" def process_response(ws, request, response): raise RuntimeError("BOOM") with self.assertLogs("websockets", logging.ERROR) as logs: async with serve(*args, process_response=process_response) as server: with self.assertRaises(InvalidStatus) as raised: async with connect(get_uri(server)): self.fail("did not raise") self.assertEqual( str(raised.exception), "server rejected WebSocket connection: HTTP 500", ) self.assertEqual( [record.getMessage() for record in logs.records], ["opening handshake failed"], ) self.assertEqual( [str(record.exc_info[1]) for record in logs.records], ["BOOM"], ) async def test_async_process_response_raises_exception(self): """Server returns an error if async process_response raises an exception.""" async def process_response(ws, request, response): raise RuntimeError("BOOM") with self.assertLogs("websockets", logging.ERROR) as logs: async with serve(*args, process_response=process_response) as server: with self.assertRaises(InvalidStatus) as raised: async with connect(get_uri(server)): self.fail("did not raise") self.assertEqual( str(raised.exception), "server rejected WebSocket connection: HTTP 500", ) self.assertEqual( [record.getMessage() for record in logs.records], ["opening handshake failed"], ) self.assertEqual( [str(record.exc_info[1]) for record in logs.records], ["BOOM"], ) async def test_override_server(self): """Server can override Server header with server_header.""" async with serve(*args, server_header="Neo") as server: async with connect(get_uri(server)) as client: await self.assertEval(client, 
"ws.response.headers['Server']", "Neo") async def test_remove_server(self): """Server can remove Server header with server_header.""" async with serve(*args, server_header=None) as server: async with connect(get_uri(server)) as client: await self.assertEval( client, "'Server' in ws.response.headers", "False" ) async def test_keepalive_is_enabled(self): """Server enables keepalive and measures latency.""" async with serve(*args, ping_interval=MS) as server: async with connect(get_uri(server)) as client: await client.send("ws.latency") latency = eval(await client.recv()) self.assertEqual(latency, 0) await asyncio.sleep(2 * MS) await client.send("ws.latency") latency = eval(await client.recv()) self.assertGreater(latency, 0) async def test_disable_keepalive(self): """Server disables keepalive.""" async with serve(*args, ping_interval=None) as server: async with connect(get_uri(server)) as client: await asyncio.sleep(2 * MS) await client.send("ws.latency") latency = eval(await client.recv()) self.assertEqual(latency, 0) async def test_logger(self): """Server accepts a logger argument.""" logger = logging.getLogger("test") async with serve(*args, logger=logger) as server: self.assertEqual(server.logger.name, logger.name) async def test_custom_connection_factory(self): """Server runs ServerConnection factory provided in create_connection.""" def create_connection(*args, **kwargs): server = ServerConnection(*args, **kwargs) server.create_connection_ran = True return server async with serve(*args, create_connection=create_connection) as server: async with connect(get_uri(server)) as client: await self.assertEval(client, "ws.create_connection_ran", "True") async def test_connections(self): """Server provides a connections property.""" async with serve(*args) as server: self.assertEqual(server.connections, set()) async with connect(get_uri(server)) as client: self.assertEqual(len(server.connections), 1) ws_id = str(next(iter(server.connections)).id) await 
self.assertEval(client, "ws.id", ws_id) self.assertEqual(server.connections, set()) async def test_handshake_fails(self): """Server receives connection from client but the handshake fails.""" def remove_key_header(self, request): del request.headers["Sec-WebSocket-Key"] async with serve(*args, process_request=remove_key_header) as server: with self.assertRaises(InvalidStatus) as raised: async with connect(get_uri(server)): self.fail("did not raise") self.assertEqual( str(raised.exception), "server rejected WebSocket connection: HTTP 400", ) async def test_timeout_during_handshake(self): """Server times out before receiving handshake request from client.""" async with serve(*args, open_timeout=MS) as server: reader, writer = await asyncio.open_connection(*get_host_port(server)) try: self.assertEqual(await reader.read(4096), b"") finally: writer.close() async def test_connection_closed_during_handshake(self): """Server reads EOF before receiving handshake request from client.""" async with serve(*args) as server: _reader, writer = await asyncio.open_connection(*get_host_port(server)) writer.close() async def test_junk_handshake(self): """Server closes the connection when receiving non-HTTP request from client.""" with self.assertLogs("websockets", logging.ERROR) as logs: async with serve(*args) as server: reader, writer = await asyncio.open_connection(*get_host_port(server)) writer.write(b"HELO relay.invalid\r\n") try: # Wait for the server to close the connection. 
self.assertEqual(await reader.read(4096), b"") finally: writer.close() self.assertEqual( [record.getMessage() for record in logs.records], ["opening handshake failed"], ) self.assertEqual( [str(record.exc_info[1]) for record in logs.records], ["did not receive a valid HTTP request"], ) self.assertEqual( [str(record.exc_info[1].__cause__) for record in logs.records], ["invalid HTTP request line: HELO relay.invalid"], ) async def test_close_server_rejects_connecting_connections(self): """Server rejects connecting connections with HTTP 503 when closing.""" async def process_request(ws, _request): while ws.server.is_serving(): await asyncio.sleep(0) # pragma: no cover async with serve(*args, process_request=process_request) as server: asyncio.get_running_loop().call_later(MS, server.close) with self.assertRaises(InvalidStatus) as raised: async with connect(get_uri(server)): self.fail("did not raise") self.assertEqual( str(raised.exception), "server rejected WebSocket connection: HTTP 503", ) async def test_close_server_closes_open_connections(self): """Server closes open connections with close code 1001 when closing.""" async with serve(*args) as server: async with connect(get_uri(server)) as client: server.close() with self.assertRaises(ConnectionClosedOK) as raised: await client.recv() self.assertEqual( str(raised.exception), "received 1001 (going away); then sent 1001 (going away)", ) async def test_close_server_keeps_connections_open(self): """Server waits for client to close open connections when closing.""" async with serve(*args) as server: async with connect(get_uri(server)) as client: server.close(close_connections=False) # Server cannot receive new connections. await asyncio.sleep(0) self.assertFalse(server.sockets) # The server waits for the client to close the connection. with self.assertRaises(TimeoutError): async with asyncio_timeout(MS): await server.wait_closed() # Once the client closes the connection, the server terminates. 
await client.close() async with asyncio_timeout(MS): await server.wait_closed() async def test_close_server_keeps_handlers_running(self): """Server waits for connection handlers to terminate.""" async with serve(*args) as server: async with connect(get_uri(server) + "/delay") as client: # Delay termination of connection handler. await client.send(str(3 * MS)) server.close() # The server waits for the connection handler to terminate. with self.assertRaises(TimeoutError): async with asyncio_timeout(2 * MS): await server.wait_closed() # Set a large timeout here, else the test becomes flaky. async with asyncio_timeout(5 * MS): await server.wait_closed() SSL_OBJECT = "ws.transport.get_extra_info('ssl_object')" class SecureServerTests(EvalShellMixin, unittest.IsolatedAsyncioTestCase): async def test_connection(self): """Server receives secure connection from client.""" async with serve(*args, ssl=SERVER_CONTEXT) as server: async with connect(get_uri(server), ssl=CLIENT_CONTEXT) as client: await self.assertEval(client, "ws.protocol.state.name", "OPEN") await self.assertEval(client, SSL_OBJECT + ".version()[:3]", "TLS") async def test_timeout_during_tls_handshake(self): """Server times out before receiving TLS handshake request from client.""" async with serve(*args, ssl=SERVER_CONTEXT, open_timeout=MS) as server: reader, writer = await asyncio.open_connection(*get_host_port(server)) try: self.assertEqual(await reader.read(4096), b"") finally: writer.close() async def test_connection_closed_during_tls_handshake(self): """Server reads EOF before receiving TLS handshake request from client.""" async with serve(*args, ssl=SERVER_CONTEXT) as server: _reader, writer = await asyncio.open_connection(*get_host_port(server)) writer.close() @unittest.skipUnless(hasattr(socket, "AF_UNIX"), "this test requires Unix sockets") class UnixServerTests(EvalShellMixin, unittest.IsolatedAsyncioTestCase): async def test_connection(self): """Server receives connection from client over a Unix 
socket.""" with temp_unix_socket_path() as path: async with unix_serve(handler, path): async with unix_connect(path) as client: await self.assertEval(client, "ws.protocol.state.name", "OPEN") @unittest.skipUnless(hasattr(socket, "AF_UNIX"), "this test requires Unix sockets") class SecureUnixServerTests(EvalShellMixin, unittest.IsolatedAsyncioTestCase): async def test_connection(self): """Server receives secure connection from client over a Unix socket.""" with temp_unix_socket_path() as path: async with unix_serve(handler, path, ssl=SERVER_CONTEXT): async with unix_connect(path, ssl=CLIENT_CONTEXT) as client: await self.assertEval(client, "ws.protocol.state.name", "OPEN") await self.assertEval(client, SSL_OBJECT + ".version()[:3]", "TLS") class ServerUsageErrorsTests(unittest.IsolatedAsyncioTestCase): async def test_unix_without_path_or_sock(self): """Unix server requires path when sock isn't provided.""" with self.assertRaises(ValueError) as raised: await unix_serve(handler) self.assertEqual( str(raised.exception), "path was not specified, and no sock specified", ) async def test_unix_with_path_and_sock(self): """Unix server rejects path when sock is provided.""" sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) self.addCleanup(sock.close) with self.assertRaises(ValueError) as raised: await unix_serve(handler, path="/", sock=sock) self.assertEqual( str(raised.exception), "path and sock can not be specified at the same time", ) async def test_invalid_subprotocol(self): """Server rejects single value of subprotocols.""" with self.assertRaises(TypeError) as raised: await serve(*args, subprotocols="chat") self.assertEqual( str(raised.exception), "subprotocols must be a list, not a str", ) async def test_unsupported_compression(self): """Server rejects incorrect value of compression.""" with self.assertRaises(ValueError) as raised: await serve(*args, compression=False) self.assertEqual( str(raised.exception), "unsupported compression: False", ) class 
BasicAuthTests(EvalShellMixin, unittest.IsolatedAsyncioTestCase): async def test_valid_authorization(self): """basic_auth authenticates client with HTTP Basic Authentication.""" async with serve( *args, process_request=basic_auth(credentials=("hello", "iloveyou")), ) as server: async with connect( get_uri(server), additional_headers={"Authorization": "Basic aGVsbG86aWxvdmV5b3U="}, ) as client: await self.assertEval(client, "ws.username", "hello") async def test_missing_authorization(self): """basic_auth rejects client without credentials.""" async with serve( *args, process_request=basic_auth(credentials=("hello", "iloveyou")), ) as server: with self.assertRaises(InvalidStatus) as raised: async with connect(get_uri(server)): self.fail("did not raise") self.assertEqual( str(raised.exception), "server rejected WebSocket connection: HTTP 401", ) async def test_unsupported_authorization(self): """basic_auth rejects client with unsupported credentials.""" async with serve( *args, process_request=basic_auth(credentials=("hello", "iloveyou")), ) as server: with self.assertRaises(InvalidStatus) as raised: async with connect( get_uri(server), additional_headers={"Authorization": "Negotiate ..."}, ): self.fail("did not raise") self.assertEqual( str(raised.exception), "server rejected WebSocket connection: HTTP 401", ) async def test_authorization_with_unknown_username(self): """basic_auth rejects client with unknown username.""" async with serve( *args, process_request=basic_auth(credentials=("hello", "iloveyou")), ) as server: with self.assertRaises(InvalidStatus) as raised: async with connect( get_uri(server), additional_headers={"Authorization": "Basic YnllOnlvdWxvdmVtZQ=="}, ): self.fail("did not raise") self.assertEqual( str(raised.exception), "server rejected WebSocket connection: HTTP 401", ) async def test_authorization_with_incorrect_password(self): """basic_auth rejects client with incorrect password.""" async with serve( *args, 
process_request=basic_auth(credentials=("hello", "changeme")), ) as server: with self.assertRaises(InvalidStatus) as raised: async with connect( get_uri(server), additional_headers={"Authorization": "Basic aGVsbG86aWxvdmV5b3U="}, ): self.fail("did not raise") self.assertEqual( str(raised.exception), "server rejected WebSocket connection: HTTP 401", ) async def test_list_of_credentials(self): """basic_auth accepts a list of hard coded credentials.""" async with serve( *args, process_request=basic_auth( credentials=[ ("hello", "iloveyou"), ("bye", "youloveme"), ] ), ) as server: async with connect( get_uri(server), additional_headers={"Authorization": "Basic YnllOnlvdWxvdmVtZQ=="}, ) as client: await self.assertEval(client, "ws.username", "bye") async def test_check_credentials_function(self): """basic_auth accepts a check_credentials function.""" def check_credentials(username, password): return hmac.compare_digest(password, "iloveyou") async with serve( *args, process_request=basic_auth(check_credentials=check_credentials), ) as server: async with connect( get_uri(server), additional_headers={"Authorization": "Basic aGVsbG86aWxvdmV5b3U="}, ) as client: await self.assertEval(client, "ws.username", "hello") async def test_check_credentials_coroutine(self): """basic_auth accepts a check_credentials coroutine.""" async def check_credentials(username, password): return hmac.compare_digest(password, "iloveyou") async with serve( *args, process_request=basic_auth(check_credentials=check_credentials), ) as server: async with connect( get_uri(server), additional_headers={"Authorization": "Basic aGVsbG86aWxvdmV5b3U="}, ) as client: await self.assertEval(client, "ws.username", "hello") async def test_without_credentials_or_check_credentials(self): """basic_auth requires either credentials or check_credentials.""" with self.assertRaises(ValueError) as raised: basic_auth() self.assertEqual( str(raised.exception), "provide either credentials or check_credentials", ) async def 
test_with_credentials_and_check_credentials(self): """basic_auth requires only one of credentials and check_credentials.""" with self.assertRaises(ValueError) as raised: basic_auth( credentials=("hello", "iloveyou"), check_credentials=lambda: False, # pragma: no cover ) self.assertEqual( str(raised.exception), "provide either credentials or check_credentials", ) async def test_bad_credentials(self): """basic_auth receives an unsupported credentials argument.""" with self.assertRaises(TypeError) as raised: basic_auth(credentials=42) self.assertEqual( str(raised.exception), "invalid credentials argument: 42", ) async def test_bad_list_of_credentials(self): """basic_auth receives an unsupported credentials argument.""" with self.assertRaises(TypeError) as raised: basic_auth(credentials=[42]) self.assertEqual( str(raised.exception), "invalid credentials argument: [42]", ) websockets-15.0.1/tests/asyncio/utils.py000066400000000000000000000002021476212450300202660ustar00rootroot00000000000000async def alist(async_iterable): items = [] async for item in async_iterable: items.append(item) return items websockets-15.0.1/tests/extensions/000077500000000000000000000000001476212450300173145ustar00rootroot00000000000000websockets-15.0.1/tests/extensions/__init__.py000066400000000000000000000000001476212450300214130ustar00rootroot00000000000000websockets-15.0.1/tests/extensions/test_base.py000066400000000000000000000017631476212450300216460ustar00rootroot00000000000000import unittest from websockets.extensions.base import * from websockets.frames import Frame, Opcode class ExtensionTests(unittest.TestCase): def test_encode(self): with self.assertRaises(NotImplementedError): Extension().encode(Frame(Opcode.TEXT, b"")) def test_decode(self): with self.assertRaises(NotImplementedError): Extension().decode(Frame(Opcode.TEXT, b"")) class ClientExtensionFactoryTests(unittest.TestCase): def test_get_request_params(self): with self.assertRaises(NotImplementedError): 
ClientExtensionFactory().get_request_params() def test_process_response_params(self): with self.assertRaises(NotImplementedError): ClientExtensionFactory().process_response_params([], []) class ServerExtensionFactoryTests(unittest.TestCase): def test_process_request_params(self): with self.assertRaises(NotImplementedError): ServerExtensionFactory().process_request_params([], []) websockets-15.0.1/tests/extensions/test_permessage_deflate.py000066400000000000000000001026261476212450300245530ustar00rootroot00000000000000import dataclasses import os import unittest from websockets.exceptions import ( DuplicateParameter, InvalidParameterName, InvalidParameterValue, NegotiationError, PayloadTooBig, ProtocolError, ) from websockets.extensions.permessage_deflate import * from websockets.frames import ( OP_BINARY, OP_CLOSE, OP_CONT, OP_PING, OP_PONG, OP_TEXT, Close, CloseCode, Frame, ) from .utils import ClientNoOpExtensionFactory, ServerNoOpExtensionFactory class PerMessageDeflateTestsMixin: def assertExtensionEqual(self, extension1, extension2): self.assertEqual( extension1.remote_no_context_takeover, extension2.remote_no_context_takeover, ) self.assertEqual( extension1.local_no_context_takeover, extension2.local_no_context_takeover, ) self.assertEqual( extension1.remote_max_window_bits, extension2.remote_max_window_bits, ) self.assertEqual( extension1.local_max_window_bits, extension2.local_max_window_bits, ) class PerMessageDeflateTests(unittest.TestCase, PerMessageDeflateTestsMixin): def setUp(self): # Set up an instance of the permessage-deflate extension with the most # common settings. Since the extension is symmetrical, this instance # may be used for testing both encoding and decoding. self.extension = PerMessageDeflate(False, False, 15, 15) def test_name(self): assert self.extension.name == "permessage-deflate" def test_repr(self): self.assertExtensionEqual(eval(repr(self.extension)), self.extension) # Control frames aren't encoded or decoded. 
def test_no_encode_decode_ping_frame(self): frame = Frame(OP_PING, b"") self.assertEqual(self.extension.encode(frame), frame) self.assertEqual(self.extension.decode(frame), frame) def test_no_encode_decode_pong_frame(self): frame = Frame(OP_PONG, b"") self.assertEqual(self.extension.encode(frame), frame) self.assertEqual(self.extension.decode(frame), frame) def test_no_encode_decode_close_frame(self): frame = Frame(OP_CLOSE, Close(CloseCode.NORMAL_CLOSURE, "").serialize()) self.assertEqual(self.extension.encode(frame), frame) self.assertEqual(self.extension.decode(frame), frame) # Data frames are encoded and decoded. def test_encode_decode_text_frame(self): frame = Frame(OP_TEXT, "café".encode()) enc_frame = self.extension.encode(frame) self.assertEqual( enc_frame, dataclasses.replace(frame, rsv1=True, data=b"JNL;\xbc\x12\x00"), ) dec_frame = self.extension.decode(enc_frame) self.assertEqual(dec_frame, frame) def test_encode_decode_binary_frame(self): frame = Frame(OP_BINARY, b"tea") enc_frame = self.extension.encode(frame) self.assertEqual( enc_frame, dataclasses.replace(frame, rsv1=True, data=b"*IM\x04\x00"), ) dec_frame = self.extension.decode(enc_frame) self.assertEqual(dec_frame, frame) def test_encode_decode_fragmented_text_frame(self): frame1 = Frame(OP_TEXT, "café".encode(), fin=False) frame2 = Frame(OP_CONT, " & ".encode(), fin=False) frame3 = Frame(OP_CONT, "croissants".encode()) enc_frame1 = self.extension.encode(frame1) enc_frame2 = self.extension.encode(frame2) enc_frame3 = self.extension.encode(frame3) self.assertEqual( enc_frame1, dataclasses.replace( frame1, rsv1=True, data=b"JNL;\xbc\x12\x00\x00\x00\xff\xff" ), ) self.assertEqual( enc_frame2, dataclasses.replace(frame2, data=b"RPS\x00\x00\x00\x00\xff\xff"), ) self.assertEqual( enc_frame3, dataclasses.replace(frame3, data=b"J.\xca\xcf,.N\xcc+)\x06\x00"), ) dec_frame1 = self.extension.decode(enc_frame1) dec_frame2 = self.extension.decode(enc_frame2) dec_frame3 = self.extension.decode(enc_frame3) 
self.assertEqual(dec_frame1, frame1) self.assertEqual(dec_frame2, frame2) self.assertEqual(dec_frame3, frame3) def test_encode_decode_fragmented_binary_frame(self): frame1 = Frame(OP_TEXT, b"tea ", fin=False) frame2 = Frame(OP_CONT, b"time") enc_frame1 = self.extension.encode(frame1) enc_frame2 = self.extension.encode(frame2) self.assertEqual( enc_frame1, dataclasses.replace( frame1, rsv1=True, data=b"*IMT\x00\x00\x00\x00\xff\xff" ), ) self.assertEqual( enc_frame2, dataclasses.replace(frame2, data=b"*\xc9\xccM\x05\x00"), ) dec_frame1 = self.extension.decode(enc_frame1) dec_frame2 = self.extension.decode(enc_frame2) self.assertEqual(dec_frame1, frame1) self.assertEqual(dec_frame2, frame2) def test_encode_decode_large_frame(self): # There is a separate code path that avoids copying data # when frames are larger than 2kB. Test it for coverage. frame = Frame(OP_BINARY, os.urandom(4096)) enc_frame = self.extension.encode(frame) dec_frame = self.extension.decode(enc_frame) self.assertEqual(dec_frame, frame) def test_no_decode_text_frame(self): frame = Frame(OP_TEXT, "café".encode()) # Try decoding a frame that wasn't encoded. self.assertEqual(self.extension.decode(frame), frame) def test_no_decode_binary_frame(self): frame = Frame(OP_TEXT, b"tea") # Try decoding a frame that wasn't encoded. 
self.assertEqual(self.extension.decode(frame), frame) def test_no_decode_fragmented_text_frame(self): frame1 = Frame(OP_TEXT, "café".encode(), fin=False) frame2 = Frame(OP_CONT, " & ".encode(), fin=False) frame3 = Frame(OP_CONT, "croissants".encode()) dec_frame1 = self.extension.decode(frame1) dec_frame2 = self.extension.decode(frame2) dec_frame3 = self.extension.decode(frame3) self.assertEqual(dec_frame1, frame1) self.assertEqual(dec_frame2, frame2) self.assertEqual(dec_frame3, frame3) def test_no_decode_fragmented_binary_frame(self): frame1 = Frame(OP_TEXT, b"tea ", fin=False) frame2 = Frame(OP_CONT, b"time") dec_frame1 = self.extension.decode(frame1) dec_frame2 = self.extension.decode(frame2) self.assertEqual(dec_frame1, frame1) self.assertEqual(dec_frame2, frame2) def test_context_takeover(self): frame = Frame(OP_TEXT, "café".encode()) enc_frame1 = self.extension.encode(frame) enc_frame2 = self.extension.encode(frame) self.assertEqual(enc_frame1.data, b"JNL;\xbc\x12\x00") self.assertEqual(enc_frame2.data, b"J\x06\x11\x00\x00") def test_remote_no_context_takeover(self): # No context takeover when decoding messages. self.extension = PerMessageDeflate(True, False, 15, 15) frame = Frame(OP_TEXT, "café".encode()) enc_frame1 = self.extension.encode(frame) enc_frame2 = self.extension.encode(frame) self.assertEqual(enc_frame1.data, b"JNL;\xbc\x12\x00") self.assertEqual(enc_frame2.data, b"J\x06\x11\x00\x00") dec_frame1 = self.extension.decode(enc_frame1) self.assertEqual(dec_frame1, frame) with self.assertRaises(ProtocolError): self.extension.decode(enc_frame2) def test_local_no_context_takeover(self): # No context takeover when encoding and decoding messages. 
self.extension = PerMessageDeflate(True, True, 15, 15) frame = Frame(OP_TEXT, "café".encode()) enc_frame1 = self.extension.encode(frame) enc_frame2 = self.extension.encode(frame) self.assertEqual(enc_frame1.data, b"JNL;\xbc\x12\x00") self.assertEqual(enc_frame2.data, b"JNL;\xbc\x12\x00") dec_frame1 = self.extension.decode(enc_frame1) dec_frame2 = self.extension.decode(enc_frame2) self.assertEqual(dec_frame1, frame) self.assertEqual(dec_frame2, frame) # Compression settings can be customized. def test_compress_settings(self): # Configure an extension so that no compression actually occurs. extension = PerMessageDeflate(False, False, 15, 15, {"level": 0}) frame = Frame(OP_TEXT, "café".encode()) enc_frame = extension.encode(frame) self.assertEqual( enc_frame, dataclasses.replace( frame, rsv1=True, data=b"\x00\x05\x00\xfa\xffcaf\xc3\xa9\x00", # not compressed ), ) # Frames aren't decoded beyond max_size. def test_decompress_max_size(self): frame = Frame(OP_TEXT, ("a" * 20).encode()) enc_frame = self.extension.encode(frame) self.assertEqual(enc_frame.data, b"JL\xc4\x04\x00\x00") with self.assertRaises(PayloadTooBig): self.extension.decode(enc_frame, max_size=10) class ClientPerMessageDeflateFactoryTests( unittest.TestCase, PerMessageDeflateTestsMixin ): def test_name(self): assert ClientPerMessageDeflateFactory.name == "permessage-deflate" def test_init(self): for config in [ (False, False, 8, None), # server_max_window_bits ≥ 8 (False, True, 15, None), # server_max_window_bits ≤ 15 (True, False, None, 8), # client_max_window_bits ≥ 8 (True, True, None, 15), # client_max_window_bits ≤ 15 (False, False, None, True), # client_max_window_bits (False, False, None, None, {"memLevel": 4}), ]: with self.subTest(config=config): # This does not raise an exception. 
ClientPerMessageDeflateFactory(*config) def test_init_error(self): for config in [ (False, False, 7, 8), # server_max_window_bits < 8 (False, True, 8, 7), # client_max_window_bits < 8 (True, False, 16, 15), # server_max_window_bits > 15 (True, True, 15, 16), # client_max_window_bits > 15 (False, False, True, None), # server_max_window_bits (False, False, None, None, {"wbits": 11}), ]: with self.subTest(config=config): with self.assertRaises(ValueError): ClientPerMessageDeflateFactory(*config) def test_get_request_params(self): for config, result in [ # Test without any parameter ( (False, False, None, None), [], ), # Test server_no_context_takeover ( (True, False, None, None), [("server_no_context_takeover", None)], ), # Test client_no_context_takeover ( (False, True, None, None), [("client_no_context_takeover", None)], ), # Test server_max_window_bits ( (False, False, 10, None), [("server_max_window_bits", "10")], ), # Test client_max_window_bits ( (False, False, None, 10), [("client_max_window_bits", "10")], ), ( (False, False, None, True), [("client_max_window_bits", None)], ), # Test all parameters together ( (True, True, 12, 12), [ ("server_no_context_takeover", None), ("client_no_context_takeover", None), ("server_max_window_bits", "12"), ("client_max_window_bits", "12"), ], ), ]: with self.subTest(config=config): factory = ClientPerMessageDeflateFactory(*config) self.assertEqual(factory.get_request_params(), result) def test_process_response_params(self): for config, response_params, result in [ # Test without any parameter ( (False, False, None, None), [], (False, False, 15, 15), ), ( (False, False, None, None), [("unknown", None)], InvalidParameterName, ), # Test server_no_context_takeover ( (False, False, None, None), [("server_no_context_takeover", None)], (True, False, 15, 15), ), ( (True, False, None, None), [], NegotiationError, ), ( (True, False, None, None), [("server_no_context_takeover", None)], (True, False, 15, 15), ), ( (True, False, None, 
None), [("server_no_context_takeover", None)] * 2, DuplicateParameter, ), ( (True, False, None, None), [("server_no_context_takeover", "42")], InvalidParameterValue, ), # Test client_no_context_takeover ( (False, False, None, None), [("client_no_context_takeover", None)], (False, True, 15, 15), ), ( (False, True, None, None), [], (False, True, 15, 15), ), ( (False, True, None, None), [("client_no_context_takeover", None)], (False, True, 15, 15), ), ( (False, True, None, None), [("client_no_context_takeover", None)] * 2, DuplicateParameter, ), ( (False, True, None, None), [("client_no_context_takeover", "42")], InvalidParameterValue, ), # Test server_max_window_bits ( (False, False, None, None), [("server_max_window_bits", "7")], NegotiationError, ), ( (False, False, None, None), [("server_max_window_bits", "10")], (False, False, 10, 15), ), ( (False, False, None, None), [("server_max_window_bits", "16")], NegotiationError, ), ( (False, False, 12, None), [], NegotiationError, ), ( (False, False, 12, None), [("server_max_window_bits", "10")], (False, False, 10, 15), ), ( (False, False, 12, None), [("server_max_window_bits", "12")], (False, False, 12, 15), ), ( (False, False, 12, None), [("server_max_window_bits", "13")], NegotiationError, ), ( (False, False, 12, None), [("server_max_window_bits", "12")] * 2, DuplicateParameter, ), ( (False, False, 12, None), [("server_max_window_bits", "42")], InvalidParameterValue, ), # Test client_max_window_bits ( (False, False, None, None), [("client_max_window_bits", "10")], NegotiationError, ), ( (False, False, None, True), [], (False, False, 15, 15), ), ( (False, False, None, True), [("client_max_window_bits", "7")], NegotiationError, ), ( (False, False, None, True), [("client_max_window_bits", "10")], (False, False, 15, 10), ), ( (False, False, None, True), [("client_max_window_bits", "16")], NegotiationError, ), ( (False, False, None, 12), [], (False, False, 15, 12), ), ( (False, False, None, 12), [("client_max_window_bits", 
"10")], (False, False, 15, 10), ), ( (False, False, None, 12), [("client_max_window_bits", "12")], (False, False, 15, 12), ), ( (False, False, None, 12), [("client_max_window_bits", "13")], NegotiationError, ), ( (False, False, None, 12), [("client_max_window_bits", "12")] * 2, DuplicateParameter, ), ( (False, False, None, 12), [("client_max_window_bits", "42")], InvalidParameterValue, ), # Test all parameters together ( (True, True, 12, 12), [ ("server_no_context_takeover", None), ("client_no_context_takeover", None), ("server_max_window_bits", "10"), ("client_max_window_bits", "10"), ], (True, True, 10, 10), ), ( (False, False, None, True), [ ("server_no_context_takeover", None), ("client_no_context_takeover", None), ("server_max_window_bits", "10"), ("client_max_window_bits", "10"), ], (True, True, 10, 10), ), ( (True, True, 12, 12), [ ("server_no_context_takeover", None), ("server_max_window_bits", "12"), ], (True, True, 12, 12), ), ]: with self.subTest(config=config, response_params=response_params): factory = ClientPerMessageDeflateFactory(*config) if isinstance(result, type) and issubclass(result, Exception): with self.assertRaises(result): factory.process_response_params(response_params, []) else: extension = factory.process_response_params(response_params, []) expected = PerMessageDeflate(*result) self.assertExtensionEqual(extension, expected) def test_process_response_params_deduplication(self): factory = ClientPerMessageDeflateFactory(False, False, None, None) with self.assertRaises(NegotiationError): factory.process_response_params( [], [PerMessageDeflate(False, False, 15, 15)] ) def test_enable_client_permessage_deflate(self): for extensions, ( expected_len, expected_position, expected_compress_settings, ) in [ ( None, (1, 0, {"memLevel": 5}), ), ( [], (1, 0, {"memLevel": 5}), ), ( [ClientNoOpExtensionFactory()], (2, 1, {"memLevel": 5}), ), ( [ClientPerMessageDeflateFactory(compress_settings={"memLevel": 7})], (1, 0, {"memLevel": 7}), ), ( [ 
ClientPerMessageDeflateFactory(compress_settings={"memLevel": 7}), ClientNoOpExtensionFactory(), ], (2, 0, {"memLevel": 7}), ), ( [ ClientNoOpExtensionFactory(), ClientPerMessageDeflateFactory(compress_settings={"memLevel": 7}), ], (2, 1, {"memLevel": 7}), ), ]: with self.subTest(extensions=extensions): extensions = enable_client_permessage_deflate(extensions) self.assertEqual(len(extensions), expected_len) extension = extensions[expected_position] self.assertIsInstance(extension, ClientPerMessageDeflateFactory) self.assertEqual( extension.compress_settings, expected_compress_settings, ) class ServerPerMessageDeflateFactoryTests( unittest.TestCase, PerMessageDeflateTestsMixin ): def test_name(self): assert ServerPerMessageDeflateFactory.name == "permessage-deflate" def test_init(self): for config in [ (False, False, 8, None), # server_max_window_bits ≥ 8 (False, True, 15, None), # server_max_window_bits ≤ 15 (True, False, None, 8), # client_max_window_bits ≥ 8 (True, True, None, 15), # client_max_window_bits ≤ 15 (False, False, None, None, {"memLevel": 4}), (False, False, None, 12, {}, True), # require_client_max_window_bits ]: with self.subTest(config=config): # This does not raise an exception. ServerPerMessageDeflateFactory(*config) def test_init_error(self): for config in [ (False, False, 7, 8), # server_max_window_bits < 8 (False, True, 8, 7), # client_max_window_bits < 8 (True, False, 16, 15), # server_max_window_bits > 15 (True, True, 15, 16), # client_max_window_bits > 15 (False, False, None, True), # client_max_window_bits (False, False, True, None), # server_max_window_bits (False, False, None, None, {"wbits": 11}), (False, False, None, None, {}, True), # require_client_max_window_bits ]: with self.subTest(config=config): with self.assertRaises(ValueError): ServerPerMessageDeflateFactory(*config) def test_process_request_params(self): # Parameters in result appear swapped vs. config because the order is # (remote, local) vs. (server, client). 
for config, request_params, response_params, result in [ # Test without any parameter ( (False, False, None, None), [], [], (False, False, 15, 15), ), ( (False, False, None, None), [("unknown", None)], None, InvalidParameterName, ), # Test server_no_context_takeover ( (False, False, None, None), [("server_no_context_takeover", None)], [("server_no_context_takeover", None)], (False, True, 15, 15), ), ( (True, False, None, None), [], [("server_no_context_takeover", None)], (False, True, 15, 15), ), ( (True, False, None, None), [("server_no_context_takeover", None)], [("server_no_context_takeover", None)], (False, True, 15, 15), ), ( (True, False, None, None), [("server_no_context_takeover", None)] * 2, None, DuplicateParameter, ), ( (True, False, None, None), [("server_no_context_takeover", "42")], None, InvalidParameterValue, ), # Test client_no_context_takeover ( (False, False, None, None), [("client_no_context_takeover", None)], [("client_no_context_takeover", None)], # doesn't matter (True, False, 15, 15), ), ( (False, True, None, None), [], [("client_no_context_takeover", None)], (True, False, 15, 15), ), ( (False, True, None, None), [("client_no_context_takeover", None)], [("client_no_context_takeover", None)], # doesn't matter (True, False, 15, 15), ), ( (False, True, None, None), [("client_no_context_takeover", None)] * 2, None, DuplicateParameter, ), ( (False, True, None, None), [("client_no_context_takeover", "42")], None, InvalidParameterValue, ), # Test server_max_window_bits ( (False, False, None, None), [("server_max_window_bits", "7")], None, NegotiationError, ), ( (False, False, None, None), [("server_max_window_bits", "10")], [("server_max_window_bits", "10")], (False, False, 15, 10), ), ( (False, False, None, None), [("server_max_window_bits", "16")], None, NegotiationError, ), ( (False, False, 12, None), [], [("server_max_window_bits", "12")], (False, False, 15, 12), ), ( (False, False, 12, None), [("server_max_window_bits", "10")], 
[("server_max_window_bits", "10")], (False, False, 15, 10), ), ( (False, False, 12, None), [("server_max_window_bits", "12")], [("server_max_window_bits", "12")], (False, False, 15, 12), ), ( (False, False, 12, None), [("server_max_window_bits", "13")], [("server_max_window_bits", "12")], (False, False, 15, 12), ), ( (False, False, 12, None), [("server_max_window_bits", "12")] * 2, None, DuplicateParameter, ), ( (False, False, 12, None), [("server_max_window_bits", "42")], None, InvalidParameterValue, ), # Test client_max_window_bits ( (False, False, None, None), [("client_max_window_bits", None)], [], (False, False, 15, 15), ), ( (False, False, None, None), [("client_max_window_bits", "7")], None, InvalidParameterValue, ), ( (False, False, None, None), [("client_max_window_bits", "10")], [("client_max_window_bits", "10")], # doesn't matter (False, False, 10, 15), ), ( (False, False, None, None), [("client_max_window_bits", "16")], None, InvalidParameterValue, ), ( (False, False, None, 12), [], [], (False, False, 15, 15), ), ( (False, False, None, 12, {}, True), [], None, NegotiationError, ), ( (False, False, None, 12), [("client_max_window_bits", None)], [("client_max_window_bits", "12")], (False, False, 12, 15), ), ( (False, False, None, 12), [("client_max_window_bits", "10")], [("client_max_window_bits", "10")], (False, False, 10, 15), ), ( (False, False, None, 12), [("client_max_window_bits", "12")], [("client_max_window_bits", "12")], # doesn't matter (False, False, 12, 15), ), ( (False, False, None, 12), [("client_max_window_bits", "13")], [("client_max_window_bits", "12")], # doesn't matter (False, False, 12, 15), ), ( (False, False, None, 12), [("client_max_window_bits", "12")] * 2, None, DuplicateParameter, ), ( (False, False, None, 12), [("client_max_window_bits", "42")], None, InvalidParameterValue, ), # Test all parameters together ( (True, True, 12, 12), [ ("server_no_context_takeover", None), ("client_no_context_takeover", None), 
("server_max_window_bits", "10"), ("client_max_window_bits", "10"), ], [ ("server_no_context_takeover", None), ("client_no_context_takeover", None), ("server_max_window_bits", "10"), ("client_max_window_bits", "10"), ], (True, True, 10, 10), ), ( (False, False, None, None), [ ("server_no_context_takeover", None), ("client_no_context_takeover", None), ("server_max_window_bits", "10"), ("client_max_window_bits", "10"), ], [ ("server_no_context_takeover", None), ("client_no_context_takeover", None), ("server_max_window_bits", "10"), ("client_max_window_bits", "10"), ], (True, True, 10, 10), ), ( (True, True, 12, 12), [("client_max_window_bits", None)], [ ("server_no_context_takeover", None), ("client_no_context_takeover", None), ("server_max_window_bits", "12"), ("client_max_window_bits", "12"), ], (True, True, 12, 12), ), ]: with self.subTest( config=config, request_params=request_params, response_params=response_params, ): factory = ServerPerMessageDeflateFactory(*config) if isinstance(result, type) and issubclass(result, Exception): with self.assertRaises(result): factory.process_request_params(request_params, []) else: params, extension = factory.process_request_params( request_params, [] ) self.assertEqual(params, response_params) expected = PerMessageDeflate(*result) self.assertExtensionEqual(extension, expected) def test_process_response_params_deduplication(self): factory = ServerPerMessageDeflateFactory(False, False, None, None) with self.assertRaises(NegotiationError): factory.process_request_params( [], [PerMessageDeflate(False, False, 15, 15)] ) def test_enable_server_permessage_deflate(self): for extensions, ( expected_len, expected_position, expected_compress_settings, ) in [ ( None, (1, 0, {"memLevel": 5}), ), ( [], (1, 0, {"memLevel": 5}), ), ( [ServerNoOpExtensionFactory()], (2, 1, {"memLevel": 5}), ), ( [ServerPerMessageDeflateFactory(compress_settings={"memLevel": 7})], (1, 0, {"memLevel": 7}), ), ( [ 
ServerPerMessageDeflateFactory(compress_settings={"memLevel": 7}), ServerNoOpExtensionFactory(), ], (2, 0, {"memLevel": 7}), ), ( [ ServerNoOpExtensionFactory(), ServerPerMessageDeflateFactory(compress_settings={"memLevel": 7}), ], (2, 1, {"memLevel": 7}), ), ]: with self.subTest(extensions=extensions): extensions = enable_server_permessage_deflate(extensions) self.assertEqual(len(extensions), expected_len) extension = extensions[expected_position] self.assertIsInstance(extension, ServerPerMessageDeflateFactory) self.assertEqual( extension.compress_settings, expected_compress_settings, ) websockets-15.0.1/tests/extensions/utils.py000066400000000000000000000047651476212450300210420ustar00rootroot00000000000000import dataclasses from websockets.exceptions import NegotiationError class OpExtension: name = "x-op" def __init__(self, op=None): self.op = op def decode(self, frame, *, max_size=None): return frame # pragma: no cover def encode(self, frame): return frame # pragma: no cover def __eq__(self, other): return isinstance(other, OpExtension) and self.op == other.op class ClientOpExtensionFactory: name = "x-op" def __init__(self, op=None): self.op = op def get_request_params(self): return [("op", self.op)] def process_response_params(self, params, accepted_extensions): if params != [("op", self.op)]: raise NegotiationError() return OpExtension(self.op) class ServerOpExtensionFactory: name = "x-op" def __init__(self, op=None): self.op = op def process_request_params(self, params, accepted_extensions): if params != [("op", self.op)]: raise NegotiationError() return [("op", self.op)], OpExtension(self.op) class NoOpExtension: name = "x-no-op" def __repr__(self): return "NoOpExtension()" def decode(self, frame, *, max_size=None): return frame def encode(self, frame): return frame class ClientNoOpExtensionFactory: name = "x-no-op" def get_request_params(self): return [] def process_response_params(self, params, accepted_extensions): if params: raise NegotiationError() 
return NoOpExtension() class ServerNoOpExtensionFactory: name = "x-no-op" def __init__(self, params=None): self.params = params or [] def process_request_params(self, params, accepted_extensions): return self.params, NoOpExtension() class Rsv2Extension: name = "x-rsv2" def decode(self, frame, *, max_size=None): assert frame.rsv2 return dataclasses.replace(frame, rsv2=False) def encode(self, frame): assert not frame.rsv2 return dataclasses.replace(frame, rsv2=True) def __eq__(self, other): return isinstance(other, Rsv2Extension) class ClientRsv2ExtensionFactory: name = "x-rsv2" def get_request_params(self): return [] def process_response_params(self, params, accepted_extensions): return Rsv2Extension() class ServerRsv2ExtensionFactory: name = "x-rsv2" def process_request_params(self, params, accepted_extensions): return [], Rsv2Extension() websockets-15.0.1/tests/legacy/000077500000000000000000000000001476212450300163615ustar00rootroot00000000000000websockets-15.0.1/tests/legacy/__init__.py000066400000000000000000000004031476212450300204670ustar00rootroot00000000000000from __future__ import annotations import warnings with warnings.catch_warnings(): # Suppress DeprecationWarning raised by websockets.legacy. 
warnings.filterwarnings("ignore", category=DeprecationWarning) import websockets.legacy # noqa: F401 websockets-15.0.1/tests/legacy/test_auth.py000066400000000000000000000174161476212450300207440ustar00rootroot00000000000000import hmac import unittest import urllib.error from websockets.headers import build_authorization_basic from websockets.legacy.auth import * from websockets.legacy.auth import is_credentials from websockets.legacy.exceptions import InvalidStatusCode from .test_client_server import ClientServerTestsMixin, with_client, with_server from .utils import AsyncioTestCase class AuthTests(unittest.TestCase): def test_is_credentials(self): self.assertTrue(is_credentials(("username", "password"))) def test_is_not_credentials(self): self.assertFalse(is_credentials(None)) self.assertFalse(is_credentials("username")) class CustomWebSocketServerProtocol(BasicAuthWebSocketServerProtocol): async def process_request(self, path, request_headers): type(self).used = True return await super().process_request(path, request_headers) class CheckWebSocketServerProtocol(BasicAuthWebSocketServerProtocol): async def check_credentials(self, username, password): return hmac.compare_digest(password, "letmein") class AuthClientServerTests(ClientServerTestsMixin, AsyncioTestCase): create_protocol = basic_auth_protocol_factory( realm="auth-tests", credentials=("hello", "iloveyou") ) @with_server(create_protocol=create_protocol) @with_client(user_info=("hello", "iloveyou")) def test_basic_auth(self): req_headers = self.client.request_headers resp_headers = self.client.response_headers self.assertEqual(req_headers["Authorization"], "Basic aGVsbG86aWxvdmV5b3U=") self.assertNotIn("WWW-Authenticate", resp_headers) self.loop.run_until_complete(self.client.send("Hello!")) self.loop.run_until_complete(self.client.recv()) def test_basic_auth_server_no_credentials(self): with self.assertRaises(TypeError) as raised: basic_auth_protocol_factory(realm="auth-tests", credentials=None) 
self.assertEqual( str(raised.exception), "provide either credentials or check_credentials" ) def test_basic_auth_server_bad_credentials(self): with self.assertRaises(TypeError) as raised: basic_auth_protocol_factory(realm="auth-tests", credentials=42) self.assertEqual(str(raised.exception), "invalid credentials argument: 42") create_protocol_multiple_credentials = basic_auth_protocol_factory( realm="auth-tests", credentials=[("hello", "iloveyou"), ("goodbye", "stillloveu")], ) @with_server(create_protocol=create_protocol_multiple_credentials) @with_client(user_info=("hello", "iloveyou")) def test_basic_auth_server_multiple_credentials(self): self.loop.run_until_complete(self.client.send("Hello!")) self.loop.run_until_complete(self.client.recv()) def test_basic_auth_bad_multiple_credentials(self): with self.assertRaises(TypeError) as raised: basic_auth_protocol_factory( realm="auth-tests", credentials=[("hello", "iloveyou"), 42] ) self.assertEqual( str(raised.exception), "invalid credentials argument: [('hello', 'iloveyou'), 42]", ) async def check_credentials(username, password): return hmac.compare_digest(password, "iloveyou") create_protocol_check_credentials = basic_auth_protocol_factory( realm="auth-tests", check_credentials=check_credentials, ) @with_server(create_protocol=create_protocol_check_credentials) @with_client(user_info=("hello", "iloveyou")) def test_basic_auth_check_credentials(self): self.loop.run_until_complete(self.client.send("Hello!")) self.loop.run_until_complete(self.client.recv()) create_protocol_custom_protocol = basic_auth_protocol_factory( realm="auth-tests", credentials=[("hello", "iloveyou")], create_protocol=CustomWebSocketServerProtocol, ) @with_server(create_protocol=create_protocol_custom_protocol) @with_client(user_info=("hello", "iloveyou")) def test_basic_auth_custom_protocol(self): self.assertTrue(CustomWebSocketServerProtocol.used) del CustomWebSocketServerProtocol.used self.loop.run_until_complete(self.client.send("Hello!")) 
self.loop.run_until_complete(self.client.recv()) @with_server(create_protocol=CheckWebSocketServerProtocol) @with_client(user_info=("hello", "letmein")) def test_basic_auth_custom_protocol_subclass(self): self.loop.run_until_complete(self.client.send("Hello!")) self.loop.run_until_complete(self.client.recv()) # CustomWebSocketServerProtocol doesn't override check_credentials @with_server(create_protocol=CustomWebSocketServerProtocol) def test_basic_auth_defaults_to_deny_all(self): with self.assertRaises(InvalidStatusCode) as raised: self.start_client(user_info=("hello", "iloveyou")) self.assertEqual(raised.exception.status_code, 401) @with_server(create_protocol=create_protocol) def test_basic_auth_missing_credentials(self): with self.assertRaises(InvalidStatusCode) as raised: self.start_client() self.assertEqual(raised.exception.status_code, 401) @with_server(create_protocol=create_protocol) def test_basic_auth_missing_credentials_details(self): with self.assertRaises(urllib.error.HTTPError) as raised: self.loop.run_until_complete(self.make_http_request()) self.assertEqual(raised.exception.code, 401) self.assertEqual( raised.exception.headers["WWW-Authenticate"], 'Basic realm="auth-tests", charset="UTF-8"', ) self.assertEqual(raised.exception.read().decode(), "Missing credentials\n") @with_server(create_protocol=create_protocol) def test_basic_auth_unsupported_credentials(self): with self.assertRaises(InvalidStatusCode) as raised: self.start_client(extra_headers={"Authorization": "Digest ..."}) self.assertEqual(raised.exception.status_code, 401) @with_server(create_protocol=create_protocol) def test_basic_auth_unsupported_credentials_details(self): with self.assertRaises(urllib.error.HTTPError) as raised: self.loop.run_until_complete( self.make_http_request(headers={"Authorization": "Digest ..."}) ) self.assertEqual(raised.exception.code, 401) self.assertEqual( raised.exception.headers["WWW-Authenticate"], 'Basic realm="auth-tests", charset="UTF-8"', ) 
self.assertEqual(raised.exception.read().decode(), "Unsupported credentials\n") @with_server(create_protocol=create_protocol) def test_basic_auth_invalid_username(self): with self.assertRaises(InvalidStatusCode) as raised: self.start_client(user_info=("goodbye", "iloveyou")) self.assertEqual(raised.exception.status_code, 401) @with_server(create_protocol=create_protocol) def test_basic_auth_invalid_password(self): with self.assertRaises(InvalidStatusCode) as raised: self.start_client(user_info=("hello", "ihateyou")) self.assertEqual(raised.exception.status_code, 401) @with_server(create_protocol=create_protocol) def test_basic_auth_invalid_credentials_details(self): with self.assertRaises(urllib.error.HTTPError) as raised: authorization = build_authorization_basic("hello", "ihateyou") self.loop.run_until_complete( self.make_http_request(headers={"Authorization": authorization}) ) self.assertEqual(raised.exception.code, 401) self.assertEqual( raised.exception.headers["WWW-Authenticate"], 'Basic realm="auth-tests", charset="UTF-8"', ) self.assertEqual(raised.exception.read().decode(), "Invalid credentials\n") websockets-15.0.1/tests/legacy/test_client_server.py000066400000000000000000001741521476212450300226500ustar00rootroot00000000000000import asyncio import contextlib import functools import http import logging import platform import random import re import socket import ssl import sys import unittest import urllib.error import urllib.request import warnings from unittest.mock import patch from websockets.asyncio.compatibility import asyncio_timeout from websockets.datastructures import Headers from websockets.exceptions import ( ConnectionClosed, InvalidHandshake, InvalidHeader, NegotiationError, ) from websockets.extensions.permessage_deflate import ( ClientPerMessageDeflateFactory, PerMessageDeflate, ServerPerMessageDeflateFactory, ) from websockets.frames import CloseCode from websockets.http11 import USER_AGENT from websockets.legacy.client import * from 
websockets.legacy.exceptions import InvalidStatusCode from websockets.legacy.handshake import build_response from websockets.legacy.http import read_response from websockets.legacy.server import * from websockets.protocol import State from websockets.uri import parse_uri from ..extensions.utils import ( ClientNoOpExtensionFactory, NoOpExtension, ServerNoOpExtensionFactory, ) from ..utils import CERTIFICATE, MS, temp_unix_socket_path from .utils import AsyncioTestCase async def default_handler(ws): if ws.path == "/deprecated_attributes": await ws.recv() # delay that allows catching warnings await ws.send(repr((ws.host, ws.port, ws.secure))) elif ws.path == "/close_timeout": await ws.send(repr(ws.close_timeout)) elif ws.path == "/path": await ws.send(str(ws.path)) elif ws.path == "/headers": await ws.send(repr(ws.request_headers)) await ws.send(repr(ws.response_headers)) elif ws.path == "/extensions": await ws.send(repr(ws.extensions)) elif ws.path == "/subprotocol": await ws.send(repr(ws.subprotocol)) elif ws.path == "/slow_stop": await ws.wait_closed() await asyncio.sleep(2 * MS) else: await ws.send(await ws.recv()) async def redirect_request(path, headers, test, status): if path == "/absolute_redirect": location = get_server_uri(test.server, test.secure, "/") elif path == "/relative_redirect": location = "/" elif path == "/infinite": location = get_server_uri(test.server, test.secure, "/infinite") elif path == "/force_secure": location = get_server_uri(test.server, True, "/") elif path == "/force_insecure": location = get_server_uri(test.server, False, "/") elif path == "/missing_location": return status, {}, b"" else: return None return status, {"Location": location}, b"" @contextlib.contextmanager def temp_test_server(test, **kwargs): test.start_server(**kwargs) try: yield finally: test.stop_server() def temp_test_redirecting_server(test, status=http.HTTPStatus.FOUND, **kwargs): process_request = functools.partial(redirect_request, test=test, status=status) 
return temp_test_server(test, process_request=process_request, **kwargs) @contextlib.contextmanager def temp_test_client(test, *args, **kwargs): test.start_client(*args, **kwargs) try: yield finally: test.stop_client() def with_manager(manager, *args, **kwargs): """ Return a decorator that wraps a function with a context manager. """ def decorate(func): @functools.wraps(func) def _decorate(self, *_args, **_kwargs): with manager(self, *args, **kwargs): return func(self, *_args, **_kwargs) return _decorate return decorate def with_server(**kwargs): """ Return a decorator for TestCase methods that starts and stops a server. """ return with_manager(temp_test_server, **kwargs) def with_client(*args, **kwargs): """ Return a decorator for TestCase methods that starts and stops a client. """ return with_manager(temp_test_client, *args, **kwargs) def get_server_address(server): """ Return an address on which the given server listens. """ # Pick a random socket in order to test both IPv4 and IPv6 on systems # where both are available. Randomizing tests is usually a bad idea. If # needed, either use the first socket, or test separately IPv4 and IPv6. server_socket = random.choice(server.sockets) if server_socket.family == socket.AF_INET6: # pragma: no cover return server_socket.getsockname()[:2] # (no IPv6 on CI) elif server_socket.family == socket.AF_INET: return server_socket.getsockname() else: # pragma: no cover raise ValueError("expected an IPv6, IPv4, or Unix socket") def get_server_uri(server, secure=False, resource_name="/", user_info=None): """ Return a WebSocket URI for connecting to the given server. 
""" proto = "wss" if secure else "ws" user_info = ":".join(user_info) + "@" if user_info else "" host, port = get_server_address(server) if ":" in host: # IPv6 address host = f"[{host}]" return f"{proto}://{user_info}{host}:{port}{resource_name}" class UnauthorizedServerProtocol(WebSocketServerProtocol): async def process_request(self, path, request_headers): # Test returning headers as a Headers instance (1/3) return http.HTTPStatus.UNAUTHORIZED, Headers([("X-Access", "denied")]), b"" class ForbiddenServerProtocol(WebSocketServerProtocol): async def process_request(self, path, request_headers): # Test returning headers as a dict (2/3) return http.HTTPStatus.FORBIDDEN, {"X-Access": "denied"}, b"" class HealthCheckServerProtocol(WebSocketServerProtocol): async def process_request(self, path, request_headers): # Test returning headers as a list of pairs (3/3) if path == "/__health__/": return http.HTTPStatus.OK, [("X-Access", "OK")], b"status = green\n" class ProcessRequestReturningIntProtocol(WebSocketServerProtocol): async def process_request(self, path, request_headers): assert path == "/__health__/" return 200, [], b"OK\n" class SlowOpeningHandshakeProtocol(WebSocketServerProtocol): async def process_request(self, path, request_headers): await asyncio.sleep(10 * MS) class FooClientProtocol(WebSocketClientProtocol): pass class BarClientProtocol(WebSocketClientProtocol): pass class ClientServerTestsMixin: secure = False def setUp(self): super().setUp() self.server = None def start_server(self, deprecation_warnings=None, **kwargs): handler = kwargs.pop("handler", default_handler) # Disable compression by default in tests. kwargs.setdefault("compression", None) # Disable pings by default in tests. kwargs.setdefault("ping_interval", None) # This logic is encapsulated in a coroutine to prevent it from executing # before the event loop is running which causes asyncio.get_event_loop() # to raise a DeprecationWarning on Python ≥ 3.10. 
async def start_server(): return await serve(handler, "localhost", 0, **kwargs) with warnings.catch_warnings(record=True) as recorded_warnings: warnings.simplefilter("always") self.server = self.loop.run_until_complete(start_server()) expected_warnings = [] if deprecation_warnings is None else deprecation_warnings self.assertDeprecationWarnings(recorded_warnings, expected_warnings) def start_client( self, resource_name="/", user_info=None, deprecation_warnings=None, **kwargs ): # Disable compression by default in tests. kwargs.setdefault("compression", None) # Disable pings by default in tests. kwargs.setdefault("ping_interval", None) secure = kwargs.get("ssl") is not None try: server_uri = kwargs.pop("uri") except KeyError: server_uri = get_server_uri(self.server, secure, resource_name, user_info) # This logic is encapsulated in a coroutine to prevent it from executing # before the event loop is running which causes asyncio.get_event_loop() # to raise a DeprecationWarning on Python ≥ 3.10. async def start_client(): return await connect(server_uri, **kwargs) with warnings.catch_warnings(record=True) as recorded_warnings: warnings.simplefilter("always") self.client = self.loop.run_until_complete(start_client()) expected_warnings = [] if deprecation_warnings is None else deprecation_warnings self.assertDeprecationWarnings(recorded_warnings, expected_warnings) def stop_client(self): self.loop.run_until_complete( asyncio.wait_for(self.client.close_connection_task, timeout=1) ) def stop_server(self): self.server.close() self.loop.run_until_complete( asyncio.wait_for(self.server.wait_closed(), timeout=1) ) @contextlib.contextmanager def temp_server(self, **kwargs): with temp_test_server(self, **kwargs): yield @contextlib.contextmanager def temp_client(self, *args, **kwargs): with temp_test_client(self, *args, **kwargs): yield def make_http_request(self, path="/", headers=None): if headers is None: headers = {} # Set url to 'https?://:'. 
url = get_server_uri( self.server, resource_name=path, secure=self.secure ).replace("ws", "http") request = urllib.request.Request(url=url, headers=headers) if self.secure: open_health_check = functools.partial( urllib.request.urlopen, request, context=self.client_context ) else: open_health_check = functools.partial(urllib.request.urlopen, request) return self.loop.run_in_executor(None, open_health_check) class SecureClientServerTestsMixin(ClientServerTestsMixin): secure = True @property def server_context(self): ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) ssl_context.load_cert_chain(CERTIFICATE) return ssl_context @property def client_context(self): ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) ssl_context.load_verify_locations(CERTIFICATE) return ssl_context def start_server(self, **kwargs): kwargs.setdefault("ssl", self.server_context) super().start_server(**kwargs) def start_client(self, path="/", **kwargs): kwargs.setdefault("ssl", self.client_context) super().start_client(path, **kwargs) class CommonClientServerTests: """ Mixin that defines most tests but doesn't inherit unittest.TestCase. Tests are run by the ClientServerTests and SecureClientServerTests subclasses. 
""" @with_server() @with_client() def test_basic(self): self.loop.run_until_complete(self.client.send("Hello!")) reply = self.loop.run_until_complete(self.client.recv()) self.assertEqual(reply, "Hello!") def test_redirect(self): redirect_statuses = [ http.HTTPStatus.MOVED_PERMANENTLY, http.HTTPStatus.FOUND, http.HTTPStatus.SEE_OTHER, http.HTTPStatus.TEMPORARY_REDIRECT, http.HTTPStatus.PERMANENT_REDIRECT, ] for status in redirect_statuses: with temp_test_redirecting_server(self, status): with self.temp_client("/absolute_redirect"): self.loop.run_until_complete(self.client.send("Hello!")) reply = self.loop.run_until_complete(self.client.recv()) self.assertEqual(reply, "Hello!") def test_redirect_relative_location(self): with temp_test_redirecting_server(self): with self.temp_client("/relative_redirect"): self.loop.run_until_complete(self.client.send("Hello!")) reply = self.loop.run_until_complete(self.client.recv()) self.assertEqual(reply, "Hello!") def test_infinite_redirect(self): with temp_test_redirecting_server(self): with self.assertRaises(InvalidHandshake): with self.temp_client("/infinite"): self.fail("did not raise") def test_redirect_missing_location(self): with temp_test_redirecting_server(self): with self.assertRaises(InvalidHeader): with self.temp_client("/missing_location"): self.fail("did not raise") def test_loop_backwards_compatibility(self): with self.temp_server( loop=self.loop, deprecation_warnings=["remove loop argument"], ): with self.temp_client( loop=self.loop, deprecation_warnings=["remove loop argument"], ): self.loop.run_until_complete(self.client.send("Hello!")) reply = self.loop.run_until_complete(self.client.recv()) self.assertEqual(reply, "Hello!") @with_server() def test_explicit_host_port(self): uri = get_server_uri(self.server, self.secure) wsuri = parse_uri(uri) # Change host and port to invalid values. 
scheme = "wss" if wsuri.secure else "ws" port = 65535 - wsuri.port changed_uri = f"{scheme}://example.com:{port}/" with self.temp_client(uri=changed_uri, host=wsuri.host, port=wsuri.port): self.loop.run_until_complete(self.client.send("Hello!")) reply = self.loop.run_until_complete(self.client.recv()) self.assertEqual(reply, "Hello!") @with_server() def test_explicit_socket(self): class TrackedSocket(socket.socket): def __init__(self, *args, **kwargs): self.used_for_read = False self.used_for_write = False super().__init__(*args, **kwargs) def recv(self, *args, **kwargs): self.used_for_read = True return super().recv(*args, **kwargs) def recv_into(self, *args, **kwargs): self.used_for_read = True return super().recv_into(*args, **kwargs) def send(self, *args, **kwargs): self.used_for_write = True return super().send(*args, **kwargs) server_socket = [ sock for sock in self.server.sockets if sock.family == socket.AF_INET ][0] client_socket = TrackedSocket(socket.AF_INET, socket.SOCK_STREAM) client_socket.connect(server_socket.getsockname()) try: self.assertFalse(client_socket.used_for_read) self.assertFalse(client_socket.used_for_write) with self.temp_client(sock=client_socket): self.loop.run_until_complete(self.client.send("Hello!")) reply = self.loop.run_until_complete(self.client.recv()) self.assertEqual(reply, "Hello!") self.assertTrue(client_socket.used_for_read) self.assertTrue(client_socket.used_for_write) finally: client_socket.close() @unittest.skipUnless(hasattr(socket, "AF_UNIX"), "this test requires Unix sockets") def test_unix_socket(self): with temp_unix_socket_path() as path: # Like self.start_server() but with unix_serve(). 
async def start_server(): return await unix_serve(default_handler, path) self.server = self.loop.run_until_complete(start_server()) try: # Like self.start_client() but with unix_connect() async def start_client(): return await unix_connect(path) self.client = self.loop.run_until_complete(start_client()) try: self.loop.run_until_complete(self.client.send("Hello!")) reply = self.loop.run_until_complete(self.client.recv()) self.assertEqual(reply, "Hello!") finally: self.stop_client() finally: self.stop_server() def test_ws_handler_argument_backwards_compatibility(self): async def handler_with_path(ws, path): await ws.send(path) with self.temp_server( handler=handler_with_path, deprecation_warnings=["remove second argument of ws_handler"], ): with self.temp_client("/path"): self.assertEqual( self.loop.run_until_complete(self.client.recv()), "/path", ) def test_ws_handler_argument_backwards_compatibility_partial(self): async def handler_with_path(ws, path, extra): await ws.send(path) bound_handler_with_path = functools.partial(handler_with_path, extra=None) with self.temp_server( handler=bound_handler_with_path, deprecation_warnings=["remove second argument of ws_handler"], ): with self.temp_client("/path"): self.assertEqual( self.loop.run_until_complete(self.client.recv()), "/path", ) async def process_request_OK(path, request_headers): return http.HTTPStatus.OK, [], b"OK\n" @with_server(process_request=process_request_OK) def test_process_request_argument(self): response = self.loop.run_until_complete(self.make_http_request("/")) with contextlib.closing(response): self.assertEqual(response.code, 200) def legacy_process_request_OK(path, request_headers): return http.HTTPStatus.OK, [], b"OK\n" @with_server(process_request=legacy_process_request_OK) def test_process_request_argument_backwards_compatibility(self): with warnings.catch_warnings(record=True) as recorded_warnings: warnings.simplefilter("always") response = 
self.loop.run_until_complete(self.make_http_request("/")) with contextlib.closing(response): self.assertEqual(response.code, 200) self.assertDeprecationWarnings( recorded_warnings, ["declare process_request as a coroutine"] ) class ProcessRequestOKServerProtocol(WebSocketServerProtocol): async def process_request(self, path, request_headers): return http.HTTPStatus.OK, [], b"OK\n" @with_server(create_protocol=ProcessRequestOKServerProtocol) def test_process_request_override(self): response = self.loop.run_until_complete(self.make_http_request("/")) with contextlib.closing(response): self.assertEqual(response.code, 200) class LegacyProcessRequestOKServerProtocol(WebSocketServerProtocol): def process_request(self, path, request_headers): return http.HTTPStatus.OK, [], b"OK\n" @with_server(create_protocol=LegacyProcessRequestOKServerProtocol) def test_process_request_override_backwards_compatibility(self): with warnings.catch_warnings(record=True) as recorded_warnings: warnings.simplefilter("always") response = self.loop.run_until_complete(self.make_http_request("/")) with contextlib.closing(response): self.assertEqual(response.code, 200) self.assertDeprecationWarnings( recorded_warnings, ["declare process_request as a coroutine"] ) def select_subprotocol_chat(client_subprotocols, server_subprotocols): return "chat" @with_server( subprotocols=["superchat", "chat"], select_subprotocol=select_subprotocol_chat ) @with_client("/subprotocol", subprotocols=["superchat", "chat"]) def test_select_subprotocol_argument(self): server_subprotocol = self.loop.run_until_complete(self.client.recv()) self.assertEqual(server_subprotocol, repr("chat")) self.assertEqual(self.client.subprotocol, "chat") class SelectSubprotocolChatServerProtocol(WebSocketServerProtocol): def select_subprotocol(self, client_subprotocols, server_subprotocols): return "chat" @with_server( subprotocols=["superchat", "chat"], create_protocol=SelectSubprotocolChatServerProtocol, ) @with_client("/subprotocol", 
subprotocols=["superchat", "chat"]) def test_select_subprotocol_override(self): server_subprotocol = self.loop.run_until_complete(self.client.recv()) self.assertEqual(server_subprotocol, repr("chat")) self.assertEqual(self.client.subprotocol, "chat") @with_server() @with_client("/deprecated_attributes") def test_protocol_deprecated_attributes(self): # The test could be connecting with IPv6 or IPv4. expected_client_attrs = [ server_socket.getsockname()[:2] + (self.secure,) for server_socket in self.server.sockets ] with warnings.catch_warnings(record=True) as recorded_warnings: warnings.simplefilter("always") client_attrs = (self.client.host, self.client.port, self.client.secure) self.assertDeprecationWarnings( recorded_warnings, [ "use remote_address[0] instead of host", "use remote_address[1] instead of port", "don't use secure", ], ) self.assertIn(client_attrs, expected_client_attrs) expected_server_attrs = ("localhost", 0, self.secure) with warnings.catch_warnings(record=True) as recorded_warnings: warnings.simplefilter("always") self.loop.run_until_complete(self.client.send("")) server_attrs = self.loop.run_until_complete(self.client.recv()) self.assertDeprecationWarnings( recorded_warnings, [ "use local_address[0] instead of host", "use local_address[1] instead of port", "don't use secure", ], ) self.assertEqual(server_attrs, repr(expected_server_attrs)) @with_server() @with_client("/path") def test_protocol_path(self): client_path = self.client.path self.assertEqual(client_path, "/path") server_path = self.loop.run_until_complete(self.client.recv()) self.assertEqual(server_path, "/path") @with_server() @with_client("/headers") def test_protocol_headers(self): client_req = self.client.request_headers client_resp = self.client.response_headers self.assertEqual(client_req["User-Agent"], USER_AGENT) self.assertEqual(client_resp["Server"], USER_AGENT) server_req = self.loop.run_until_complete(self.client.recv()) server_resp = 
self.loop.run_until_complete(self.client.recv()) self.assertEqual(server_req, repr(client_req)) self.assertEqual(server_resp, repr(client_resp)) @with_server() @with_client("/headers", extra_headers={"X-Spam": "Eggs"}) def test_protocol_custom_request_headers(self): req_headers = self.loop.run_until_complete(self.client.recv()) self.loop.run_until_complete(self.client.recv()) self.assertIn("('X-Spam', 'Eggs')", req_headers) @with_server() @with_client("/headers", extra_headers={"User-Agent": "websockets"}) def test_protocol_custom_user_agent_header_legacy(self): req_headers = self.loop.run_until_complete(self.client.recv()) self.loop.run_until_complete(self.client.recv()) self.assertEqual(req_headers.count("User-Agent"), 1) self.assertIn("('User-Agent', 'websockets')", req_headers) @with_server() @with_client("/headers", user_agent_header=None) def test_protocol_no_user_agent_header(self): req_headers = self.loop.run_until_complete(self.client.recv()) self.loop.run_until_complete(self.client.recv()) self.assertNotIn("User-Agent", req_headers) @with_server() @with_client("/headers", user_agent_header="websockets") def test_protocol_custom_user_agent_header(self): req_headers = self.loop.run_until_complete(self.client.recv()) self.loop.run_until_complete(self.client.recv()) self.assertEqual(req_headers.count("User-Agent"), 1) self.assertIn("('User-Agent', 'websockets')", req_headers) @with_server(extra_headers=lambda p, r: {"X-Spam": "Eggs"}) @with_client("/headers") def test_protocol_custom_response_headers_callable(self): self.loop.run_until_complete(self.client.recv()) resp_headers = self.loop.run_until_complete(self.client.recv()) self.assertIn("('X-Spam', 'Eggs')", resp_headers) @with_server(extra_headers=lambda p, r: None) @with_client("/headers") def test_protocol_custom_response_headers_callable_none(self): self.loop.run_until_complete(self.client.recv()) # doesn't crash self.loop.run_until_complete(self.client.recv()) # nothing to check 
@with_server(extra_headers={"X-Spam": "Eggs"}) @with_client("/headers") def test_protocol_custom_response_headers(self): self.loop.run_until_complete(self.client.recv()) resp_headers = self.loop.run_until_complete(self.client.recv()) self.assertIn("('X-Spam', 'Eggs')", resp_headers) @with_server(extra_headers={"Server": "websockets"}) @with_client("/headers") def test_protocol_custom_server_header_legacy(self): self.loop.run_until_complete(self.client.recv()) resp_headers = self.loop.run_until_complete(self.client.recv()) self.assertEqual(resp_headers.count("Server"), 1) self.assertIn("('Server', 'websockets')", resp_headers) @with_server(server_header=None) @with_client("/headers") def test_protocol_no_server_header(self): self.loop.run_until_complete(self.client.recv()) resp_headers = self.loop.run_until_complete(self.client.recv()) self.assertNotIn("Server", resp_headers) @with_server(server_header="websockets") @with_client("/headers") def test_protocol_custom_server_header(self): self.loop.run_until_complete(self.client.recv()) resp_headers = self.loop.run_until_complete(self.client.recv()) self.assertEqual(resp_headers.count("Server"), 1) self.assertIn("('Server', 'websockets')", resp_headers) @with_server(create_protocol=HealthCheckServerProtocol) def test_http_request_http_endpoint(self): # Making an HTTP request to an HTTP endpoint succeeds. response = self.loop.run_until_complete(self.make_http_request("/__health__/")) with contextlib.closing(response): self.assertEqual(response.code, 200) self.assertEqual(response.read(), b"status = green\n") @with_server(create_protocol=HealthCheckServerProtocol) def test_http_request_ws_endpoint(self): # Making an HTTP request to a WS endpoint fails. 
with self.assertRaises(urllib.error.HTTPError) as raised: self.loop.run_until_complete(self.make_http_request()) self.assertEqual(raised.exception.code, 426) self.assertEqual(raised.exception.headers["Upgrade"], "websocket") @with_server(create_protocol=HealthCheckServerProtocol) def test_ws_connection_http_endpoint(self): # Making a WS connection to an HTTP endpoint fails. with self.assertRaises(InvalidStatusCode) as raised: self.start_client("/__health__/") self.assertEqual(raised.exception.status_code, 200) @with_server(create_protocol=HealthCheckServerProtocol) def test_ws_connection_ws_endpoint(self): # Making a WS connection to a WS endpoint succeeds. self.start_client() self.loop.run_until_complete(self.client.send("Hello!")) self.loop.run_until_complete(self.client.recv()) self.stop_client() @with_server(create_protocol=HealthCheckServerProtocol, server_header=None) def test_http_request_no_server_header(self): response = self.loop.run_until_complete(self.make_http_request("/__health__/")) with contextlib.closing(response): self.assertNotIn("Server", response.headers) @with_server(create_protocol=HealthCheckServerProtocol, server_header="websockets") def test_http_request_custom_server_header(self): response = self.loop.run_until_complete(self.make_http_request("/__health__/")) with contextlib.closing(response): self.assertEqual(response.headers["Server"], "websockets") @with_server(create_protocol=ProcessRequestReturningIntProtocol) def test_process_request_returns_int_status(self): response = self.loop.run_until_complete(self.make_http_request("/__health__/")) with contextlib.closing(response): self.assertEqual(response.code, 200) self.assertEqual(response.read(), b"OK\n") def assert_client_raises_code(self, status_code): with self.assertRaises(InvalidStatusCode) as raised: self.start_client() self.assertEqual(raised.exception.status_code, status_code) @with_server(create_protocol=UnauthorizedServerProtocol) def test_server_create_protocol(self): 
self.assert_client_raises_code(401) def create_unauthorized_server_protocol(*args, **kwargs): return UnauthorizedServerProtocol(*args, **kwargs) @with_server(create_protocol=create_unauthorized_server_protocol) def test_server_create_protocol_function(self): self.assert_client_raises_code(401) @with_server( klass=UnauthorizedServerProtocol, deprecation_warnings=["rename klass to create_protocol"], ) def test_server_klass_backwards_compatibility(self): self.assert_client_raises_code(401) @with_server( create_protocol=ForbiddenServerProtocol, klass=UnauthorizedServerProtocol, deprecation_warnings=["rename klass to create_protocol"], ) def test_server_create_protocol_over_klass(self): self.assert_client_raises_code(403) @with_server() @with_client("/path", create_protocol=FooClientProtocol) def test_client_create_protocol(self): self.assertIsInstance(self.client, FooClientProtocol) @with_server() @with_client( "/path", create_protocol=(lambda *args, **kwargs: FooClientProtocol(*args, **kwargs)), ) def test_client_create_protocol_function(self): self.assertIsInstance(self.client, FooClientProtocol) @with_server() @with_client( "/path", klass=FooClientProtocol, deprecation_warnings=["rename klass to create_protocol"], ) def test_client_klass(self): self.assertIsInstance(self.client, FooClientProtocol) @with_server() @with_client( "/path", create_protocol=BarClientProtocol, klass=FooClientProtocol, deprecation_warnings=["rename klass to create_protocol"], ) def test_client_create_protocol_over_klass(self): self.assertIsInstance(self.client, BarClientProtocol) @with_server(close_timeout=7) @with_client("/close_timeout") def test_server_close_timeout(self): close_timeout = self.loop.run_until_complete(self.client.recv()) self.assertEqual(eval(close_timeout), 7) @with_server(timeout=6, deprecation_warnings=["rename timeout to close_timeout"]) @with_client("/close_timeout") def test_server_timeout_backwards_compatibility(self): close_timeout = 
self.loop.run_until_complete(self.client.recv()) self.assertEqual(eval(close_timeout), 6) @with_server( close_timeout=7, timeout=6, deprecation_warnings=["rename timeout to close_timeout"], ) @with_client("/close_timeout") def test_server_close_timeout_over_timeout(self): close_timeout = self.loop.run_until_complete(self.client.recv()) self.assertEqual(eval(close_timeout), 7) @with_server() @with_client("/close_timeout", close_timeout=7) def test_client_close_timeout(self): self.assertEqual(self.client.close_timeout, 7) @with_server() @with_client( "/close_timeout", timeout=6, deprecation_warnings=["rename timeout to close_timeout"], ) def test_client_timeout_backwards_compatibility(self): self.assertEqual(self.client.close_timeout, 6) @with_server() @with_client( "/close_timeout", close_timeout=7, timeout=6, deprecation_warnings=["rename timeout to close_timeout"], ) def test_client_close_timeout_over_timeout(self): self.assertEqual(self.client.close_timeout, 7) @with_server() @with_client("/extensions") def test_no_extension(self): server_extensions = self.loop.run_until_complete(self.client.recv()) self.assertEqual(server_extensions, repr([])) self.assertEqual(repr(self.client.extensions), repr([])) @with_server(extensions=[ServerNoOpExtensionFactory()]) @with_client("/extensions", extensions=[ClientNoOpExtensionFactory()]) def test_extension(self): server_extensions = self.loop.run_until_complete(self.client.recv()) self.assertEqual(server_extensions, repr([NoOpExtension()])) self.assertEqual(repr(self.client.extensions), repr([NoOpExtension()])) @with_server() @with_client("/extensions", extensions=[ClientNoOpExtensionFactory()]) def test_extension_not_accepted(self): server_extensions = self.loop.run_until_complete(self.client.recv()) self.assertEqual(server_extensions, repr([])) self.assertEqual(repr(self.client.extensions), repr([])) @with_server(extensions=[ServerNoOpExtensionFactory()]) @with_client("/extensions") def test_extension_not_requested(self): 
server_extensions = self.loop.run_until_complete(self.client.recv()) self.assertEqual(server_extensions, repr([])) self.assertEqual(repr(self.client.extensions), repr([])) @with_server(extensions=[ServerNoOpExtensionFactory([("foo", None)])]) def test_extension_client_rejection(self): with self.assertRaises(NegotiationError): self.start_client("/extensions", extensions=[ClientNoOpExtensionFactory()]) @with_server( extensions=[ # No match because the client doesn't send client_max_window_bits. ServerPerMessageDeflateFactory( client_max_window_bits=10, require_client_max_window_bits=True, ), ServerPerMessageDeflateFactory(), ] ) @with_client( "/extensions", extensions=[ ClientPerMessageDeflateFactory(client_max_window_bits=None), ], ) def test_extension_no_match_then_match(self): # The order requested by the client has priority. server_extensions = self.loop.run_until_complete(self.client.recv()) self.assertEqual( server_extensions, repr([PerMessageDeflate(False, False, 15, 15)]) ) self.assertEqual( repr(self.client.extensions), repr([PerMessageDeflate(False, False, 15, 15)]), ) @with_server(extensions=[ServerPerMessageDeflateFactory()]) @with_client("/extensions", extensions=[ClientNoOpExtensionFactory()]) def test_extension_mismatch(self): server_extensions = self.loop.run_until_complete(self.client.recv()) self.assertEqual(server_extensions, repr([])) self.assertEqual(repr(self.client.extensions), repr([])) @with_server( extensions=[ServerNoOpExtensionFactory(), ServerPerMessageDeflateFactory()] ) @with_client( "/extensions", extensions=[ClientPerMessageDeflateFactory(), ClientNoOpExtensionFactory()], ) def test_extension_order(self): # The order requested by the client has priority. 
server_extensions = self.loop.run_until_complete(self.client.recv()) self.assertEqual( server_extensions, repr([PerMessageDeflate(False, False, 15, 15), NoOpExtension()]), ) self.assertEqual( repr(self.client.extensions), repr([PerMessageDeflate(False, False, 15, 15), NoOpExtension()]), ) @with_server(extensions=[ServerNoOpExtensionFactory()]) @patch.object(WebSocketServerProtocol, "process_extensions") def test_extensions_error(self, _process_extensions): _process_extensions.return_value = "x-no-op", [NoOpExtension()] with self.assertRaises(NegotiationError): self.start_client( "/extensions", extensions=[ClientPerMessageDeflateFactory()] ) @with_server(extensions=[ServerNoOpExtensionFactory()]) @patch.object(WebSocketServerProtocol, "process_extensions") def test_extensions_error_no_extensions(self, _process_extensions): _process_extensions.return_value = "x-no-op", [NoOpExtension()] with self.assertRaises(InvalidHandshake): self.start_client("/extensions") @with_server(compression="deflate") @with_client("/extensions", compression="deflate") def test_compression_deflate(self): server_extensions = self.loop.run_until_complete(self.client.recv()) self.assertEqual( server_extensions, repr([PerMessageDeflate(False, False, 12, 12)]) ) self.assertEqual( repr(self.client.extensions), repr([PerMessageDeflate(False, False, 12, 12)]), ) def test_compression_unsupported_server(self): with self.assertRaises(ValueError): self.start_server(compression="xz") @with_server() def test_compression_unsupported_client(self): with self.assertRaises(ValueError): self.start_client(compression="xz") @with_server() @with_client("/subprotocol") def test_no_subprotocol(self): server_subprotocol = self.loop.run_until_complete(self.client.recv()) self.assertEqual(server_subprotocol, repr(None)) self.assertEqual(self.client.subprotocol, None) @with_server(subprotocols=["superchat", "chat"]) @with_client("/subprotocol", subprotocols=["otherchat", "chat"]) def test_subprotocol(self): 
server_subprotocol = self.loop.run_until_complete(self.client.recv()) self.assertEqual(server_subprotocol, repr("chat")) self.assertEqual(self.client.subprotocol, "chat") def test_invalid_subprotocol_server(self): with self.assertRaises(TypeError): self.start_server(subprotocols="sip") @with_server() def test_invalid_subprotocol_client(self): with self.assertRaises(TypeError): self.start_client(subprotocols="sip") @with_server(subprotocols=["superchat"]) @with_client("/subprotocol", subprotocols=["otherchat"]) def test_subprotocol_not_accepted(self): server_subprotocol = self.loop.run_until_complete(self.client.recv()) self.assertEqual(server_subprotocol, repr(None)) self.assertEqual(self.client.subprotocol, None) @with_server() @with_client("/subprotocol", subprotocols=["otherchat", "chat"]) def test_subprotocol_not_offered(self): server_subprotocol = self.loop.run_until_complete(self.client.recv()) self.assertEqual(server_subprotocol, repr(None)) self.assertEqual(self.client.subprotocol, None) @with_server(subprotocols=["superchat", "chat"]) @with_client("/subprotocol") def test_subprotocol_not_requested(self): server_subprotocol = self.loop.run_until_complete(self.client.recv()) self.assertEqual(server_subprotocol, repr(None)) self.assertEqual(self.client.subprotocol, None) @with_server(subprotocols=["superchat"]) @patch.object(WebSocketServerProtocol, "process_subprotocol") def test_subprotocol_error(self, _process_subprotocol): _process_subprotocol.return_value = "superchat" with self.assertRaises(NegotiationError): self.start_client("/subprotocol", subprotocols=["otherchat"]) self.run_loop_once() @with_server(subprotocols=["superchat"]) @patch.object(WebSocketServerProtocol, "process_subprotocol") def test_subprotocol_error_no_subprotocols(self, _process_subprotocol): _process_subprotocol.return_value = "superchat" with self.assertRaises(InvalidHandshake): self.start_client("/subprotocol") self.run_loop_once() @with_server(subprotocols=["superchat", "chat"]) 
@patch.object(WebSocketServerProtocol, "process_subprotocol") def test_subprotocol_error_two_subprotocols(self, _process_subprotocol): _process_subprotocol.return_value = "superchat, chat" with self.assertRaises(InvalidHandshake): self.start_client("/subprotocol", subprotocols=["superchat", "chat"]) self.run_loop_once() @with_server() @patch("websockets.legacy.server.read_request") def test_server_receives_malformed_request(self, _read_request): _read_request.side_effect = ValueError("read_request failed") with self.assertRaises(InvalidHandshake): self.start_client() @with_server() @patch("websockets.legacy.client.read_response") def test_client_receives_malformed_response(self, _read_response): _read_response.side_effect = ValueError("read_response failed") with self.assertRaises(InvalidHandshake): self.start_client() self.run_loop_once() @with_server() @patch("websockets.legacy.client.build_request") def test_client_sends_invalid_handshake_request(self, _build_request): def wrong_build_request(headers): return "42" _build_request.side_effect = wrong_build_request with self.assertRaises(InvalidHandshake): self.start_client() @with_server() @patch("websockets.legacy.server.build_response") def test_server_sends_invalid_handshake_response(self, _build_response): def wrong_build_response(headers, key): return build_response(headers, "42") _build_response.side_effect = wrong_build_response with self.assertRaises(InvalidHandshake): self.start_client() @with_server() @patch("websockets.legacy.client.read_response") def test_server_does_not_switch_protocols(self, _read_response): async def wrong_read_response(stream): status_code, reason, headers = await read_response(stream) return 400, "Bad Request", headers _read_response.side_effect = wrong_read_response with self.assertRaises(InvalidStatusCode): self.start_client() self.run_loop_once() @with_server() @patch("websockets.legacy.server.WebSocketServerProtocol.process_request") def test_server_error_in_handshake(self, 
_process_request): _process_request.side_effect = Exception("process_request crashed") with self.assertRaises(InvalidHandshake): self.start_client() @with_server(create_protocol=SlowOpeningHandshakeProtocol) def test_client_connect_canceled_during_handshake(self): sock = socket.create_connection(get_server_address(self.server)) sock.send(b"") # socket is connected async def cancelled_client(): start_client = connect(get_server_uri(self.server), sock=sock) async with asyncio_timeout(5 * MS): await start_client with self.assertRaises(asyncio.TimeoutError): self.loop.run_until_complete(cancelled_client()) with self.assertRaises(OSError): sock.send(b"") # socket is closed @with_server() @patch("websockets.legacy.server.WebSocketServerProtocol.send") def test_server_handler_crashes(self, send): send.side_effect = ValueError("send failed") with self.temp_client(): self.loop.run_until_complete(self.client.send("Hello!")) with self.assertRaises(ConnectionClosed): self.loop.run_until_complete(self.client.recv()) # Connection ends with an unexpected error. self.assertEqual(self.client.close_code, CloseCode.INTERNAL_ERROR) @with_server() @patch("websockets.legacy.server.WebSocketServerProtocol.close") def test_server_close_crashes(self, close): close.side_effect = ValueError("close failed") with self.temp_client(): self.loop.run_until_complete(self.client.send("Hello!")) reply = self.loop.run_until_complete(self.client.recv()) self.assertEqual(reply, "Hello!") # Connection ends with an abnormal closure. self.assertEqual(self.client.close_code, CloseCode.ABNORMAL_CLOSURE) @with_server() @with_client() @patch.object(WebSocketClientProtocol, "handshake") def test_client_closes_connection_before_handshake(self, handshake): # We have mocked the handshake() method to prevent the client from # performing the opening handshake. Force it to close the connection. self.client.transport.close() # The server should stop properly anyway. 
It used to hang because the # task handling the connection was waiting for the opening handshake. @with_server(create_protocol=SlowOpeningHandshakeProtocol) def test_server_shuts_down_during_opening_handshake(self): self.loop.call_later(5 * MS, self.server.close) with self.assertRaises(InvalidStatusCode) as raised: self.start_client() exception = raised.exception self.assertEqual( str(exception), "server rejected WebSocket connection: HTTP 503" ) self.assertEqual(exception.status_code, 503) @with_server() def test_server_shuts_down_during_connection_handling(self): with self.temp_client(): server_ws = next(iter(self.server.websockets)) self.server.close() with self.assertRaises(ConnectionClosed): self.loop.run_until_complete(self.client.send("Hello!")) self.loop.run_until_complete(self.client.recv()) # Server closed the connection with 1001 Going Away. self.assertEqual(self.client.close_code, CloseCode.GOING_AWAY) self.assertEqual(server_ws.close_code, CloseCode.GOING_AWAY) @with_server() def test_server_shuts_down_gracefully_during_connection_handling(self): with self.temp_client(): server_ws = next(iter(self.server.websockets)) self.server.close(close_connections=False) self.loop.run_until_complete(self.client.send("Hello!")) self.loop.run_until_complete(self.client.recv()) # Client closed the connection with 1000 OK. self.assertEqual(self.client.close_code, CloseCode.NORMAL_CLOSURE) self.assertEqual(server_ws.close_code, CloseCode.NORMAL_CLOSURE) @with_server() def test_server_shuts_down_and_waits_until_handlers_terminate(self): # This handler waits a bit after the connection is closed in order # to test that wait_closed() really waits for handlers to complete. self.start_client("/slow_stop") server_ws = next(iter(self.server.websockets)) # Test that the handler task keeps running after close(). 
self.server.close() self.loop.run_until_complete(asyncio.sleep(MS)) self.assertFalse(server_ws.handler_task.done()) # Test that the handler task terminates before wait_closed() returns. self.loop.run_until_complete(self.server.wait_closed()) self.assertTrue(server_ws.handler_task.done()) @with_server(create_protocol=ForbiddenServerProtocol) def test_invalid_status_error_during_client_connect(self): with self.assertRaises(InvalidStatusCode) as raised: self.start_client() exception = raised.exception self.assertEqual( str(exception), "server rejected WebSocket connection: HTTP 403" ) self.assertEqual(exception.status_code, 403) @with_server() @patch("websockets.legacy.server.WebSocketServerProtocol.write_http_response") @patch("websockets.legacy.server.WebSocketServerProtocol.read_http_request") def test_connection_error_during_opening_handshake( self, _read_http_request, _write_http_response ): _read_http_request.side_effect = ConnectionError # This exception is currently platform-dependent. It was observed to # be ConnectionResetError on Linux in the non-TLS case, and # InvalidMessage otherwise (including both Linux and macOS). This # doesn't matter though since this test is primarily for testing a # code path on the server side. with self.assertRaises(Exception): self.start_client() # No response must not be written if the network connection is broken. _write_http_response.assert_not_called() @with_server() @patch("websockets.legacy.server.WebSocketServerProtocol.close") def test_connection_error_during_closing_handshake(self, close): close.side_effect = ConnectionError with self.temp_client(): self.loop.run_until_complete(self.client.send("Hello!")) reply = self.loop.run_until_complete(self.client.recv()) self.assertEqual(reply, "Hello!") # Connection ends with an abnormal closure. 
self.assertEqual(self.client.close_code, CloseCode.ABNORMAL_CLOSURE) class ClientServerTests( CommonClientServerTests, ClientServerTestsMixin, AsyncioTestCase ): def test_redirect_secure(self): with temp_test_redirecting_server(self): # websockets doesn't support serving non-TLS and TLS connections # from the same server and this test suite makes it difficult to # run two servers. Therefore, we expect the redirect to create a # TLS client connection to a non-TLS server, which will fail. with self.assertRaises(ssl.SSLError): with self.temp_client("/force_secure"): self.fail("did not raise") class SecureClientServerTests( CommonClientServerTests, SecureClientServerTestsMixin, AsyncioTestCase ): # The implementation of this test makes it hard to run it over TLS. test_client_connect_canceled_during_handshake = None # TLS over Unix sockets doesn't make sense. test_unix_socket = None # This test fails under PyPy due to a difference with CPython. if platform.python_implementation() == "PyPy": # pragma: no cover test_http_request_ws_endpoint = None @with_server() def test_ws_uri_is_rejected(self): with self.assertRaises(ValueError): self.start_client( uri=get_server_uri(self.server, secure=False), ssl=self.client_context ) def test_redirect_insecure(self): with temp_test_redirecting_server(self): with self.assertRaises(InvalidHandshake): with self.temp_client("/force_insecure"): self.fail("did not raise") class ClientServerOriginTests(ClientServerTestsMixin, AsyncioTestCase): @with_server(origins=["http://localhost"]) @with_client(origin="http://localhost") def test_checking_origin_succeeds(self): self.loop.run_until_complete(self.client.send("Hello!")) self.assertEqual(self.loop.run_until_complete(self.client.recv()), "Hello!") @with_server(origins=["http://localhost"]) def test_checking_origin_fails(self): with self.assertRaises(InvalidHandshake) as raised: self.start_client(origin="http://otherhost") self.assertEqual( str(raised.exception), "server rejected WebSocket 
connection: HTTP 403", ) @with_server(origins=["http://localhost"]) def test_checking_origins_fails_with_multiple_headers(self): with self.assertRaises(InvalidHandshake) as raised: self.start_client( origin="http://localhost", extra_headers=[("Origin", "http://otherhost")], ) self.assertEqual( str(raised.exception), "server rejected WebSocket connection: HTTP 400", ) @with_server(origins=[None]) @with_client() def test_checking_lack_of_origin_succeeds(self): self.loop.run_until_complete(self.client.send("Hello!")) self.assertEqual(self.loop.run_until_complete(self.client.recv()), "Hello!") @with_server(origins=[""]) # The deprecation warning is raised when a client connects to the server. @with_client(deprecation_warnings=["use None instead of '' in origins"]) def test_checking_lack_of_origin_succeeds_backwards_compatibility(self): self.loop.run_until_complete(self.client.send("Hello!")) self.assertEqual(self.loop.run_until_complete(self.client.recv()), "Hello!") @unittest.skipIf( sys.version_info[:2] >= (3, 11), "asyncio.coroutine has been removed in Python 3.11" ) class YieldFromTests(ClientServerTestsMixin, AsyncioTestCase): # pragma: no cover @with_server() def test_client(self): # @asyncio.coroutine is deprecated on Python ≥ 3.8 with warnings.catch_warnings(): warnings.simplefilter("ignore") @asyncio.coroutine def run_client(): # Yield from connect. client = yield from connect(get_server_uri(self.server)) self.assertEqual(client.state, State.OPEN) yield from client.close() self.assertEqual(client.state, State.CLOSED) self.loop.run_until_complete(run_client()) def test_server(self): # @asyncio.coroutine is deprecated on Python ≥ 3.8 with warnings.catch_warnings(): warnings.simplefilter("ignore") @asyncio.coroutine def run_server(): # Yield from serve. 
class AsyncAwaitTests(ClientServerTestsMixin, AsyncioTestCase):
    """Check that connect() and serve() support being awaited directly."""

    @with_server()
    def test_client(self):
        async def exercise_client():
            # Awaiting connect() yields an open client connection.
            ws = await connect(get_server_uri(self.server))
            self.assertEqual(ws.state, State.OPEN)
            await ws.close()
            self.assertEqual(ws.state, State.CLOSED)

        self.loop.run_until_complete(exercise_client())

    def test_server(self):
        async def exercise_server():
            # Awaiting serve() yields a running server with bound sockets.
            ws_server = await serve(default_handler, "localhost", 0)
            self.assertTrue(ws_server.sockets)
            ws_server.close()
            await ws_server.wait_closed()
            self.assertFalse(ws_server.sockets)

        self.loop.run_until_complete(exercise_server())
class AsyncIteratorTests(ClientServerTestsMixin, AsyncioTestCase):
    # This is a protocol-level feature, but since it's a high-level API, it is
    # much easier to exercise at the client or server level.

    MESSAGES = ["3", "2", "1", "Fire!"]

    # Defined without ``self`` on purpose: with_server() calls it as a plain
    # connection handler, not as a bound method.
    async def echo_handler(ws):
        for message in AsyncIteratorTests.MESSAGES:
            await ws.send(message)

    @with_server(handler=echo_handler)
    def test_iterate_on_messages(self):
        """Iterating a connection yields every message until the peer closes."""
        messages = []

        async def run_client():
            nonlocal messages
            async with connect(get_server_uri(self.server)) as ws:
                async for message in ws:
                    messages.append(message)

        self.loop.run_until_complete(run_client())
        self.assertEqual(messages, self.MESSAGES)

    # Same handler, but closes with GOING_AWAY, which iteration treats as OK.
    async def echo_handler_going_away(ws):
        for message in AsyncIteratorTests.MESSAGES:
            await ws.send(message)
        await ws.close(CloseCode.GOING_AWAY)

    @with_server(handler=echo_handler_going_away)
    def test_iterate_on_messages_going_away_exit_ok(self):
        """A GOING_AWAY close ends iteration without raising."""
        messages = []

        async def run_client():
            nonlocal messages
            async with connect(get_server_uri(self.server)) as ws:
                async for message in ws:
                    messages.append(message)

        self.loop.run_until_complete(run_client())
        self.assertEqual(messages, self.MESSAGES)

    # Same handler, but closes with INTERNAL_ERROR, an abnormal closure.
    async def echo_handler_internal_error(ws):
        for message in AsyncIteratorTests.MESSAGES:
            await ws.send(message)
        await ws.close(CloseCode.INTERNAL_ERROR)

    @with_server(handler=echo_handler_internal_error)
    def test_iterate_on_messages_internal_error_exit_not_ok(self):
        """An INTERNAL_ERROR close ends iteration by raising ConnectionClosed."""
        messages = []

        async def run_client():
            nonlocal messages
            async with connect(get_server_uri(self.server)) as ws:
                async for message in ws:
                    messages.append(message)

        with self.assertRaises(ConnectionClosed):
            self.loop.run_until_complete(run_client())
        # All messages were received before the connection failed.
        self.assertEqual(messages, self.MESSAGES)
    async def disable_server(self, duration):
        """Make the server answer 503 for *duration* seconds, then recover."""
        # maybe_service_unavailable() reads this class-level flag on every
        # incoming handshake, so flipping it affects all new connections.
        ReconnectionTests.service_available = False
        await asyncio.sleep(duration)
        ReconnectionTests.service_available = True
class LoggerTests(ClientServerTestsMixin, AsyncioTestCase):
    """Check that custom loggers passed to client and server are actually used."""

    def test_logger_client(self):
        server_logger = logging.getLogger("test.server")
        client_logger = logging.getLogger("test.client")
        with self.assertLogs("test.server", logging.DEBUG) as server_logs:
            self.start_server(logger=server_logger)
            with self.assertLogs("test.client", logging.DEBUG) as client_logs:
                self.start_client(logger=client_logger)
                self.loop.run_until_complete(self.client.send("Hello!"))
                self.loop.run_until_complete(self.client.recv())
                self.stop_client()
            self.stop_server()

        # Each custom logger must have captured at least one record.
        self.assertGreater(len(server_logs.records), 0)
        self.assertGreater(len(client_logs.records), 0)
class ExceptionsTests(unittest.TestCase):
    """Check the string representation of legacy handshake exceptions."""

    def test_str(self):
        cases = [
            (InvalidStatusCode(403, Headers()),
             "server rejected WebSocket connection: HTTP 403"),
            (AbortHandshake(200, Headers(), b"OK\n"),
             "HTTP 200, 0 headers, 3 bytes"),
            (RedirectHandshake("wss://example.com"),
             "redirect to wss://example.com"),
        ]
        for exception, expected in cases:
            with self.subTest(exception=exception):
                self.assertEqual(str(exception), expected)
    def round_trip(self, message, expected, mask=False, extensions=None):
        """
        Decode *message*, compare to *expected*, then re-encode and compare.

        With masking, encoding is non-deterministic (the mask is random), so
        the re-encoded bytes are decoded again instead of compared directly.
        """
        decoded = self.decode(message, mask, extensions=extensions)
        # check() validates the frame against the protocol rules.
        decoded.check()
        self.assertEqual(decoded, expected)
        encoded = self.encode(decoded, mask, extensions=extensions)
        if mask:  # non-deterministic encoding
            decoded = self.decode(encoded, mask, extensions=extensions)
            self.assertEqual(decoded, expected)
        else:  # deterministic encoding
            self.assertEqual(encoded, message)
with self.subTest(encoded=encoded): with self.assertRaises(ProtocolError): self.decode(encoded) def test_good_opcode(self): for opcode in list(range(0x00, 0x03)) + list(range(0x08, 0x0B)): encoded = bytes([0x80 | opcode, 0]) with self.subTest(encoded=encoded): self.decode(encoded) # does not raise an exception def test_bad_opcode(self): for opcode in list(range(0x03, 0x08)) + list(range(0x0B, 0x10)): encoded = bytes([0x80 | opcode, 0]) with self.subTest(encoded=encoded): with self.assertRaises(ProtocolError): self.decode(encoded) def test_mask_flag(self): # Mask flag correctly set. self.decode(b"\x80\x80\x00\x00\x00\x00", mask=True) # Mask flag incorrectly unset. with self.assertRaises(ProtocolError): self.decode(b"\x80\x80\x00\x00\x00\x00") # Mask flag correctly unset. self.decode(b"\x80\x00") # Mask flag incorrectly set. with self.assertRaises(ProtocolError): self.decode(b"\x80\x00", mask=True) def test_control_frame_max_length(self): # At maximum allowed length. self.decode(b"\x88\x7e\x00\x7d" + 125 * b"a") # Above maximum allowed length. with self.assertRaises(ProtocolError): self.decode(b"\x88\x7e\x00\x7e" + 126 * b"a") def test_fragmented_control_frame(self): # Fin bit correctly set. self.decode(b"\x88\x00") # Fin bit incorrectly unset. with self.assertRaises(ProtocolError): self.decode(b"\x08\x00") def test_extensions(self): class Rot13: @staticmethod def encode(frame): assert frame.opcode == OP_TEXT text = frame.data.decode() data = codecs.encode(text, "rot13").encode() return dataclasses.replace(frame, data=data) # This extensions is symmetrical. 
class PrepareDataTests(unittest.TestCase):
    """prepare_data() maps a payload to an (opcode, data) pair."""

    def test_prepare_data_str(self):
        # Text is encoded to UTF-8 and sent as a text frame.
        self.assertEqual(prepare_data("café"), (OP_TEXT, b"caf\xc3\xa9"))

    def test_prepare_data_bytes(self):
        # Bytes-like payloads are sent as binary frames, unmodified.
        self.assertEqual(prepare_data(b"tea"), (OP_BINARY, b"tea"))

    def test_prepare_data_bytearray(self):
        self.assertEqual(
            prepare_data(bytearray(b"tea")), (OP_BINARY, bytearray(b"tea"))
        )

    def test_prepare_data_memoryview(self):
        self.assertEqual(
            prepare_data(memoryview(b"tea")), (OP_BINARY, memoryview(b"tea"))
        )

    def test_prepare_data_list(self):
        # Anything that isn't str or bytes-like is rejected.
        with self.assertRaises(TypeError):
            prepare_data([])

    def test_prepare_data_none(self):
        with self.assertRaises(TypeError):
            prepare_data(None)
""" serialized = serialize_close(code, reason) self.assertEqual(serialized, data) parsed = parse_close(data) self.assertEqual(parsed, (code, reason)) def test_parse_close_and_serialize_close(self): self.assertCloseData(CloseCode.NORMAL_CLOSURE, "", b"\x03\xe8") self.assertCloseData(CloseCode.NORMAL_CLOSURE, "OK", b"\x03\xe8OK") def test_parse_close_empty(self): self.assertEqual(parse_close(b""), (CloseCode.NO_STATUS_RCVD, "")) def test_parse_close_errors(self): with self.assertRaises(ProtocolError): parse_close(b"\x03") with self.assertRaises(ProtocolError): parse_close(b"\x03\xe7") with self.assertRaises(UnicodeDecodeError): parse_close(b"\x03\xe8\xff\xff") def test_serialize_close_errors(self): with self.assertRaises(ProtocolError): serialize_close(999, "") websockets-15.0.1/tests/legacy/test_handshake.py000066400000000000000000000153471476212450300217320ustar00rootroot00000000000000import contextlib import unittest from websockets.datastructures import Headers from websockets.exceptions import ( InvalidHandshake, InvalidHeader, InvalidHeaderValue, InvalidUpgrade, ) from websockets.legacy.handshake import * from websockets.utils import accept_key class HandshakeTests(unittest.TestCase): def test_round_trip(self): request_headers = Headers() request_key = build_request(request_headers) response_key = check_request(request_headers) self.assertEqual(request_key, response_key) response_headers = Headers() build_response(response_headers, response_key) check_response(response_headers, request_key) @contextlib.contextmanager def assertValidRequestHeaders(self): """ Provide request headers for modification. Assert that the transformation kept them valid. """ headers = Headers() build_request(headers) yield headers check_request(headers) @contextlib.contextmanager def assertInvalidRequestHeaders(self, exc_type): """ Provide request headers for modification. Assert that the transformation made them invalid. 
    def test_request_invalid_connection(self):
        """A Connection header without "Upgrade" makes the request invalid."""
        with self.assertInvalidRequestHeaders(InvalidUpgrade) as headers:
            # Replace the valid value set by build_request() with a bad one.
            del headers["Connection"]
            headers["Connection"] = "Downgrade"
    @contextlib.contextmanager
    def assertValidResponseHeaders(self, key="CSIRmL8dWYxeAdr/XpEHRw=="):
        """
        Provide response headers for modification.

        Assert that the transformation kept them valid.

        """
        headers = Headers()
        build_response(headers, key)
        yield headers
        # check_response() raises if the caller's edits broke the headers.
        check_response(headers, key)
""" headers = Headers() build_response(headers, key) yield headers assert issubclass(exc_type, InvalidHandshake) with self.assertRaises(exc_type): check_response(headers, key) def test_response_invalid_connection(self): with self.assertInvalidResponseHeaders(InvalidUpgrade) as headers: del headers["Connection"] headers["Connection"] = "Downgrade" def test_response_missing_connection(self): with self.assertInvalidResponseHeaders(InvalidUpgrade) as headers: del headers["Connection"] def test_response_additional_connection(self): with self.assertValidResponseHeaders() as headers: headers["Connection"] = "close" def test_response_invalid_upgrade(self): with self.assertInvalidResponseHeaders(InvalidUpgrade) as headers: del headers["Upgrade"] headers["Upgrade"] = "socketweb" def test_response_missing_upgrade(self): with self.assertInvalidResponseHeaders(InvalidUpgrade) as headers: del headers["Upgrade"] def test_response_additional_upgrade(self): with self.assertInvalidResponseHeaders(InvalidUpgrade) as headers: headers["Upgrade"] = "socketweb" def test_response_invalid_accept(self): with self.assertInvalidResponseHeaders(InvalidHeaderValue) as headers: del headers["Sec-WebSocket-Accept"] other_key = "1Eq4UDEFQYg3YspNgqxv5g==" headers["Sec-WebSocket-Accept"] = accept_key(other_key) def test_response_missing_accept(self): with self.assertInvalidResponseHeaders(InvalidHeader) as headers: del headers["Sec-WebSocket-Accept"] def test_response_additional_accept(self): with self.assertInvalidResponseHeaders(InvalidHeader) as headers: # This duplicates the Sec-WebSocket-Accept header. 
headers["Sec-WebSocket-Accept"] = headers["Sec-WebSocket-Accept"] websockets-15.0.1/tests/legacy/test_http.py000066400000000000000000000147601476212450300207610ustar00rootroot00000000000000import asyncio from websockets.exceptions import SecurityError from websockets.legacy.http import * from websockets.legacy.http import read_headers from .utils import AsyncioTestCase class HTTPAsyncTests(AsyncioTestCase): def setUp(self): super().setUp() self.stream = asyncio.StreamReader(loop=self.loop) async def test_read_request(self): # Example from the protocol overview in RFC 6455 self.stream.feed_data( b"GET /chat HTTP/1.1\r\n" b"Host: server.example.com\r\n" b"Upgrade: websocket\r\n" b"Connection: Upgrade\r\n" b"Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n" b"Origin: http://example.com\r\n" b"Sec-WebSocket-Protocol: chat, superchat\r\n" b"Sec-WebSocket-Version: 13\r\n" b"\r\n" ) path, headers = await read_request(self.stream) self.assertEqual(path, "/chat") self.assertEqual(headers["Upgrade"], "websocket") async def test_read_request_empty(self): self.stream.feed_eof() with self.assertRaises(EOFError) as raised: await read_request(self.stream) self.assertEqual( str(raised.exception), "connection closed while reading HTTP request line", ) async def test_read_request_invalid_request_line(self): self.stream.feed_data(b"GET /\r\n\r\n") with self.assertRaises(ValueError) as raised: await read_request(self.stream) self.assertEqual( str(raised.exception), "invalid HTTP request line: GET /", ) async def test_read_request_unsupported_method(self): self.stream.feed_data(b"OPTIONS * HTTP/1.1\r\n\r\n") with self.assertRaises(ValueError) as raised: await read_request(self.stream) self.assertEqual( str(raised.exception), "unsupported HTTP method: OPTIONS", ) async def test_read_request_unsupported_version(self): self.stream.feed_data(b"GET /chat HTTP/1.0\r\n\r\n") with self.assertRaises(ValueError) as raised: await read_request(self.stream) self.assertEqual( str(raised.exception), 
"unsupported HTTP version: HTTP/1.0", ) async def test_read_request_invalid_header(self): self.stream.feed_data(b"GET /chat HTTP/1.1\r\nOops\r\n") with self.assertRaises(ValueError) as raised: await read_request(self.stream) self.assertEqual( str(raised.exception), "invalid HTTP header line: Oops", ) async def test_read_response(self): # Example from the protocol overview in RFC 6455 self.stream.feed_data( b"HTTP/1.1 101 Switching Protocols\r\n" b"Upgrade: websocket\r\n" b"Connection: Upgrade\r\n" b"Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\n" b"Sec-WebSocket-Protocol: chat\r\n" b"\r\n" ) status_code, reason, headers = await read_response(self.stream) self.assertEqual(status_code, 101) self.assertEqual(reason, "Switching Protocols") self.assertEqual(headers["Upgrade"], "websocket") async def test_read_response_empty(self): self.stream.feed_eof() with self.assertRaises(EOFError) as raised: await read_response(self.stream) self.assertEqual( str(raised.exception), "connection closed while reading HTTP status line", ) async def test_read_request_invalid_status_line(self): self.stream.feed_data(b"Hello!\r\n") with self.assertRaises(ValueError) as raised: await read_response(self.stream) self.assertEqual( str(raised.exception), "invalid HTTP status line: Hello!", ) async def test_read_response_unsupported_version(self): self.stream.feed_data(b"HTTP/1.0 400 Bad Request\r\n\r\n") with self.assertRaises(ValueError) as raised: await read_response(self.stream) self.assertEqual( str(raised.exception), "unsupported HTTP version: HTTP/1.0", ) async def test_read_response_invalid_status(self): self.stream.feed_data(b"HTTP/1.1 OMG WTF\r\n\r\n") with self.assertRaises(ValueError) as raised: await read_response(self.stream) self.assertEqual( str(raised.exception), "invalid HTTP status code: OMG", ) async def test_read_response_unsupported_status(self): self.stream.feed_data(b"HTTP/1.1 007 My name is Bond\r\n\r\n") with self.assertRaises(ValueError) as raised: await 
read_response(self.stream) self.assertEqual( str(raised.exception), "unsupported HTTP status code: 007", ) async def test_read_response_invalid_reason(self): self.stream.feed_data(b"HTTP/1.1 200 \x7f\r\n\r\n") with self.assertRaises(ValueError) as raised: await read_response(self.stream) self.assertEqual( str(raised.exception), "invalid HTTP reason phrase: \x7f", ) async def test_read_response_invalid_header(self): self.stream.feed_data(b"HTTP/1.1 500 Internal Server Error\r\nOops\r\n") with self.assertRaises(ValueError) as raised: await read_response(self.stream) self.assertEqual( str(raised.exception), "invalid HTTP header line: Oops", ) async def test_header_name(self): self.stream.feed_data(b"foo bar: baz qux\r\n\r\n") with self.assertRaises(ValueError): await read_headers(self.stream) async def test_header_value(self): self.stream.feed_data(b"foo: \x00\x00\x0f\r\n\r\n") with self.assertRaises(ValueError): await read_headers(self.stream) async def test_headers_limit(self): self.stream.feed_data(b"foo: bar\r\n" * 129 + b"\r\n") with self.assertRaises(SecurityError): await read_headers(self.stream) async def test_line_limit(self): # Header line contains 5 + 8186 + 2 = 8193 bytes. 
async def async_iterable(iterable):
    """Expose a regular iterable as an asynchronous iterator."""
    for element in iterable:
        yield element
    def make_drain_slow(self, delay=MS):
        """Monkeypatch the protocol so that draining takes *delay* seconds."""
        # Process connection_made in order to initialize self.protocol.transport.
        self.run_loop_once()

        original_drain = self.protocol._drain

        # Wrap the original drain with an artificial delay so tests can
        # observe the connection while writes are still pending.
        async def delayed_drain():
            await asyncio.sleep(delay)
            await original_drain()

        self.protocol._drain = delayed_drain
    def close_connection(self, code=CloseCode.NORMAL_CLOSURE, reason="close"):
        """
        Execute a closing handshake.

        This puts the connection in the CLOSED state.

        """
        close_frame_data = Close(code, reason).serialize()
        # Prepare the response to the closing handshake from the remote side.
        self.receive_frame(Frame(True, OP_CLOSE, close_frame_data))
        self.receive_eof_if_client()
        # Trigger the closing handshake from the local side and complete it.
        self.loop.run_until_complete(self.protocol.close(code, reason))
        # Empty the outgoing data stream so we can make assertions later on.
        self.assertOneFrameSent(True, OP_CLOSE, close_frame_data)

        assert self.protocol.state is State.CLOSED
    def half_close_connection_remote(
        self,
        code=CloseCode.NORMAL_CLOSURE,
        reason="close",
    ):
        """
        Receive a closing handshake but do not complete it.

        The main difference with `close_connection` is that the connection
        is left in the CLOSING state until the event loop runs again.

        """
        # On the server side, websockets completes the closing handshake and
        # closes the TCP connection immediately. Yield to the event loop after
        # sending the close frame to run the test while the connection is in
        # the CLOSING state.
        if not self.protocol.is_client:
            self.make_drain_slow()

        close_frame_data = Close(code, reason).serialize()
        # Trigger the closing handshake from the remote endpoint.
        self.receive_frame(Frame(True, OP_CLOSE, close_frame_data))
        self.run_loop_once()  # read_frame executes
        # Empty the outgoing data stream so we can make assertions later on.
        self.assertOneFrameSent(True, OP_CLOSE, close_frame_data)

        assert self.protocol.state is State.CLOSING

        # Complete the closing sequence at 1ms intervals so the test can run
        # at each point even it goes back to the event loop several times.
        self.loop.call_later(2 * MS, self.receive_eof_if_client)
To achieve this, this function triggers the protocol's eof_received, which interrupts pending reads waiting for more data. """ self.run_loop_once() self.receive_eof() self.loop.run_until_complete(self.protocol.close_connection_task) def sent_frames(self): """ Read all frames sent to the transport. """ stream = asyncio.StreamReader(loop=self.loop) for (data,), kw in self.transport.write.call_args_list: stream.feed_data(data) self.transport.write.call_args_list = [] stream.feed_eof() frames = [] while not stream.at_eof(): frames.append( self.loop.run_until_complete( Frame.read(stream.readexactly, mask=self.protocol.is_client) ) ) return frames def last_sent_frame(self): """ Read the last frame sent to the transport. This method assumes that at most one frame was sent. It raises an AssertionError otherwise. """ frames = self.sent_frames() if frames: assert len(frames) == 1 return frames[0] def assertFramesSent(self, *frames): self.assertEqual(self.sent_frames(), [Frame(*args) for args in frames]) def assertOneFrameSent(self, *args): self.assertEqual(self.last_sent_frame(), Frame(*args)) def assertNoFrameSent(self): self.assertIsNone(self.last_sent_frame()) def assertConnectionClosed(self, code, message): # The following line guarantees that connection_lost was called. self.assertEqual(self.protocol.state, State.CLOSED) # A close frame was received. self.assertEqual(self.protocol.close_code, code) self.assertEqual(self.protocol.close_reason, message) def assertConnectionFailed(self, code, message): # The following line guarantees that connection_lost was called. self.assertEqual(self.protocol.state, State.CLOSED) # No close frame was received. self.assertEqual(self.protocol.close_code, CloseCode.ABNORMAL_CLOSURE) self.assertEqual(self.protocol.close_reason, "") # A close frame was sent -- unless the connection was already lost. 
if code == CloseCode.ABNORMAL_CLOSURE: self.assertNoFrameSent() else: self.assertOneFrameSent(True, OP_CLOSE, Close(code, message).serialize()) @contextlib.contextmanager def assertCompletesWithin(self, min_time, max_time): t0 = self.loop.time() yield t1 = self.loop.time() dt = t1 - t0 self.assertGreaterEqual(dt, min_time, f"Too fast: {dt} < {min_time}") self.assertLess(dt, max_time, f"Too slow: {dt} >= {max_time}") # Test constructor. def test_timeout_backwards_compatibility(self): async def create_protocol(): return WebSocketCommonProtocol(ping_interval=None, timeout=5) with warnings.catch_warnings(record=True) as recorded: warnings.simplefilter("always") protocol = self.loop.run_until_complete(create_protocol()) self.assertEqual(protocol.close_timeout, 5) self.assertDeprecationWarnings(recorded, ["rename timeout to close_timeout"]) def test_loop_backwards_compatibility(self): loop = asyncio.new_event_loop() self.addCleanup(loop.close) with warnings.catch_warnings(record=True) as recorded: warnings.simplefilter("always") protocol = WebSocketCommonProtocol(ping_interval=None, loop=loop) self.assertEqual(protocol.loop, loop) self.assertDeprecationWarnings(recorded, ["remove loop argument"]) # Test public attributes. def test_local_address(self): get_extra_info = unittest.mock.Mock(return_value=("host", 4312)) self.transport.get_extra_info = get_extra_info self.assertEqual(self.protocol.local_address, ("host", 4312)) get_extra_info.assert_called_with("sockname") def test_local_address_before_connection(self): # Emulate the situation before connection_open() runs. 
_transport = self.protocol.transport del self.protocol.transport try: self.assertEqual(self.protocol.local_address, None) finally: self.protocol.transport = _transport def test_remote_address(self): get_extra_info = unittest.mock.Mock(return_value=("host", 4312)) self.transport.get_extra_info = get_extra_info self.assertEqual(self.protocol.remote_address, ("host", 4312)) get_extra_info.assert_called_with("peername") def test_remote_address_before_connection(self): # Emulate the situation before connection_open() runs. _transport = self.protocol.transport del self.protocol.transport try: self.assertEqual(self.protocol.remote_address, None) finally: self.protocol.transport = _transport def test_open(self): self.assertTrue(self.protocol.open) self.close_connection() self.assertFalse(self.protocol.open) def test_closed(self): self.assertFalse(self.protocol.closed) self.close_connection() self.assertTrue(self.protocol.closed) def test_wait_closed(self): wait_closed = self.loop.create_task(self.protocol.wait_closed()) self.assertFalse(wait_closed.done()) self.close_connection() self.assertTrue(wait_closed.done()) def test_close_code(self): self.close_connection(CloseCode.GOING_AWAY, "Bye!") self.assertEqual(self.protocol.close_code, CloseCode.GOING_AWAY) def test_close_reason(self): self.close_connection(CloseCode.GOING_AWAY, "Bye!") self.assertEqual(self.protocol.close_reason, "Bye!") def test_close_code_not_set(self): self.assertIsNone(self.protocol.close_code) def test_close_reason_not_set(self): self.assertIsNone(self.protocol.close_reason) # Test the recv coroutine. 
    def test_recv_text(self):
        self.receive_frame(Frame(True, OP_TEXT, "café".encode()))
        data = self.loop.run_until_complete(self.protocol.recv())
        self.assertEqual(data, "café")

    def test_recv_binary(self):
        self.receive_frame(Frame(True, OP_BINARY, b"tea"))
        data = self.loop.run_until_complete(self.protocol.recv())
        self.assertEqual(data, b"tea")

    def test_recv_on_closing_connection_local(self):
        close_task = self.half_close_connection_local()

        with self.assertRaises(ConnectionClosed):
            self.loop.run_until_complete(self.protocol.recv())

        self.loop.run_until_complete(close_task)  # cleanup

    def test_recv_on_closing_connection_remote(self):
        self.half_close_connection_remote()

        with self.assertRaises(ConnectionClosed):
            self.loop.run_until_complete(self.protocol.recv())

    def test_recv_on_closed_connection(self):
        self.close_connection()

        with self.assertRaises(ConnectionClosed):
            self.loop.run_until_complete(self.protocol.recv())

    def test_recv_protocol_error(self):
        # A continuation frame without a preceding data frame is invalid.
        self.receive_frame(Frame(True, OP_CONT, "café".encode()))
        self.process_invalid_frames()
        self.assertConnectionFailed(CloseCode.PROTOCOL_ERROR, "")

    def test_recv_unicode_error(self):
        # A text frame whose payload isn't valid UTF-8 is invalid.
        self.receive_frame(Frame(True, OP_TEXT, "café".encode("latin-1")))
        self.process_invalid_frames()
        self.assertConnectionFailed(CloseCode.INVALID_DATA, "")

    def test_recv_text_payload_too_big(self):
        self.protocol.max_size = 1024
        self.receive_frame(Frame(True, OP_TEXT, "café".encode() * 205))
        self.process_invalid_frames()
        self.assertConnectionFailed(CloseCode.MESSAGE_TOO_BIG, "")

    def test_recv_binary_payload_too_big(self):
        self.protocol.max_size = 1024
        self.receive_frame(Frame(True, OP_BINARY, b"tea" * 342))
        self.process_invalid_frames()
        self.assertConnectionFailed(CloseCode.MESSAGE_TOO_BIG, "")

    def test_recv_text_no_max_size(self):
        self.protocol.max_size = None  # for test coverage
        self.receive_frame(Frame(True, OP_TEXT, "café".encode() * 205))
        data = self.loop.run_until_complete(self.protocol.recv())
        self.assertEqual(data, "café" * 205)

    def test_recv_binary_no_max_size(self):
        self.protocol.max_size = None  # for test coverage
        self.receive_frame(Frame(True, OP_BINARY, b"tea" * 342))
        data = self.loop.run_until_complete(self.protocol.recv())
        self.assertEqual(data, b"tea" * 342)

    def test_recv_queue_empty(self):
        # recv() must block until a message arrives.
        recv = self.loop.create_task(self.protocol.recv())
        with self.assertRaises(asyncio.TimeoutError):
            self.loop.run_until_complete(
                asyncio.wait_for(asyncio.shield(recv), timeout=MS)
            )

        self.receive_frame(Frame(True, OP_TEXT, "café".encode()))
        data = self.loop.run_until_complete(recv)
        self.assertEqual(data, "café")

    def test_recv_queue_full(self):
        self.protocol.max_queue = 2
        # Test internals because it's hard to verify buffers from the outside.
        self.assertEqual(list(self.protocol.messages), [])

        self.receive_frame(Frame(True, OP_TEXT, "café".encode()))
        self.run_loop_once()
        self.assertEqual(list(self.protocol.messages), ["café"])

        self.receive_frame(Frame(True, OP_BINARY, b"tea"))
        self.run_loop_once()
        self.assertEqual(list(self.protocol.messages), ["café", b"tea"])

        # The third message doesn't enter the queue until a slot frees up.
        self.receive_frame(Frame(True, OP_BINARY, b"milk"))
        self.run_loop_once()
        self.assertEqual(list(self.protocol.messages), ["café", b"tea"])

        self.loop.run_until_complete(self.protocol.recv())
        self.run_loop_once()
        self.assertEqual(list(self.protocol.messages), [b"tea", b"milk"])

        self.loop.run_until_complete(self.protocol.recv())
        self.run_loop_once()
        self.assertEqual(list(self.protocol.messages), [b"milk"])

        self.loop.run_until_complete(self.protocol.recv())
        self.run_loop_once()
        self.assertEqual(list(self.protocol.messages), [])

    def test_recv_queue_no_limit(self):
        self.protocol.max_queue = None

        for _ in range(100):
            self.receive_frame(Frame(True, OP_TEXT, "café".encode()))
            self.run_loop_once()

        # Incoming message queue can contain at least 100 messages.
        self.assertEqual(list(self.protocol.messages), ["café"] * 100)

        for _ in range(100):
            self.loop.run_until_complete(self.protocol.recv())

        self.assertEqual(list(self.protocol.messages), [])

    def test_recv_other_error(self):
        # An unexpected error in read_message must fail the connection.
        async def read_message():
            raise Exception("BOOM")

        self.protocol.read_message = read_message
        self.process_invalid_frames()
        self.assertConnectionFailed(CloseCode.INTERNAL_ERROR, "")

    def test_recv_canceled(self):
        recv = self.loop.create_task(self.protocol.recv())
        self.loop.call_soon(recv.cancel)

        with self.assertRaises(asyncio.CancelledError):
            self.loop.run_until_complete(recv)

        # The next frame doesn't disappear in a vacuum (it used to).
        self.receive_frame(Frame(True, OP_TEXT, "café".encode()))
        data = self.loop.run_until_complete(self.protocol.recv())
        self.assertEqual(data, "café")

    def test_recv_canceled_race_condition(self):
        # Cancel recv() at the same moment a frame arrives.
        recv = self.loop.create_task(
            asyncio.wait_for(self.protocol.recv(), timeout=0.000_001)
        )
        self.loop.call_soon(self.receive_frame, Frame(True, OP_TEXT, "café".encode()))

        with self.assertRaises(asyncio.TimeoutError):
            self.loop.run_until_complete(recv)

        # The previous frame doesn't disappear in a vacuum (it used to).
        self.receive_frame(Frame(True, OP_TEXT, "tea".encode()))
        data = self.loop.run_until_complete(self.protocol.recv())
        # If we're getting "tea" there, it means "café" was swallowed (ha, ha).
        self.assertEqual(data, "café")

    def test_recv_when_transfer_data_cancelled(self):
        # Clog incoming queue.
        self.protocol.max_queue = 1
        self.receive_frame(Frame(True, OP_TEXT, "café".encode()))
        self.receive_frame(Frame(True, OP_BINARY, b"tea"))
        self.run_loop_once()
        # Flow control kicks in (check with an implementation detail).
        self.assertFalse(self.protocol._put_message_waiter.done())
        # Schedule recv().
        recv = self.loop.create_task(self.protocol.recv())
        # Cancel transfer_data_task (again, implementation detail).
        self.protocol.fail_connection()
        self.run_loop_once()
        self.assertTrue(self.protocol.transfer_data_task.cancelled())
        # recv() completes properly.
        self.assertEqual(self.loop.run_until_complete(recv), "café")

    def test_recv_prevents_concurrent_calls(self):
        recv = self.loop.create_task(self.protocol.recv())

        with self.assertRaises(RuntimeError) as raised:
            self.loop.run_until_complete(self.protocol.recv())
        self.assertEqual(
            str(raised.exception),
            "cannot call recv while another coroutine "
            "is already waiting for the next message",
        )

        recv.cancel()

    # Test the send coroutine.

    def test_send_text(self):
        self.loop.run_until_complete(self.protocol.send("café"))
        self.assertOneFrameSent(True, OP_TEXT, "café".encode())

    def test_send_binary(self):
        self.loop.run_until_complete(self.protocol.send(b"tea"))
        self.assertOneFrameSent(True, OP_BINARY, b"tea")

    def test_send_binary_from_bytearray(self):
        self.loop.run_until_complete(self.protocol.send(bytearray(b"tea")))
        self.assertOneFrameSent(True, OP_BINARY, b"tea")

    def test_send_binary_from_memoryview(self):
        self.loop.run_until_complete(self.protocol.send(memoryview(b"tea")))
        self.assertOneFrameSent(True, OP_BINARY, b"tea")

    def test_send_dict(self):
        with self.assertRaises(TypeError):
            self.loop.run_until_complete(self.protocol.send({"not": "encoded"}))
        self.assertNoFrameSent()

    def test_send_type_error(self):
        with self.assertRaises(TypeError):
            self.loop.run_until_complete(self.protocol.send(42))
        self.assertNoFrameSent()

    def test_send_iterable_text(self):
        # An iterable of strings is sent as a fragmented text message.
        self.loop.run_until_complete(self.protocol.send(["ca", "fé"]))
        self.assertFramesSent(
            (False, OP_TEXT, "ca".encode()),
            (False, OP_CONT, "fé".encode()),
            (True, OP_CONT, "".encode()),
        )

    def test_send_iterable_binary(self):
        self.loop.run_until_complete(self.protocol.send([b"te", b"a"]))
        self.assertFramesSent(
            (False, OP_BINARY, b"te"), (False, OP_CONT, b"a"), (True, OP_CONT, b"")
        )

    def test_send_iterable_binary_from_bytearray(self):
        self.loop.run_until_complete(
            self.protocol.send([bytearray(b"te"), bytearray(b"a")])
        )
        self.assertFramesSent(
            (False, OP_BINARY, b"te"), (False, OP_CONT, b"a"), (True, OP_CONT, b"")
        )

    def test_send_iterable_binary_from_memoryview(self):
        self.loop.run_until_complete(
            self.protocol.send([memoryview(b"te"), memoryview(b"a")])
        )
        self.assertFramesSent(
            (False, OP_BINARY, b"te"), (False, OP_CONT, b"a"), (True, OP_CONT, b"")
        )

    def test_send_empty_iterable(self):
        self.loop.run_until_complete(self.protocol.send([]))
        self.assertNoFrameSent()

    def test_send_iterable_type_error(self):
        with self.assertRaises(TypeError):
            self.loop.run_until_complete(self.protocol.send([42]))
        self.assertNoFrameSent()

    def test_send_iterable_mixed_type_error(self):
        # Mixing text and binary fragments fails the connection mid-message.
        with self.assertRaises(TypeError):
            self.loop.run_until_complete(self.protocol.send(["café", b"tea"]))
        self.assertFramesSent(
            (False, OP_TEXT, "café".encode()),
            (True, OP_CLOSE, Close(CloseCode.INTERNAL_ERROR, "").serialize()),
        )

    def test_send_iterable_prevents_concurrent_send(self):
        # A concurrent send must wait until the fragmented message completes.
        self.make_drain_slow(2 * MS)

        async def send_iterable():
            await self.protocol.send(["ca", "fé"])

        async def send_concurrent():
            await asyncio.sleep(MS)
            await self.protocol.send(b"tea")

        async def run_concurrently():
            await asyncio.gather(
                send_iterable(),
                send_concurrent(),
            )

        self.loop.run_until_complete(run_concurrently())

        self.assertFramesSent(
            (False, OP_TEXT, "ca".encode()),
            (False, OP_CONT, "fé".encode()),
            (True, OP_CONT, "".encode()),
            (True, OP_BINARY, b"tea"),
        )

    def test_send_async_iterable_text(self):
        self.loop.run_until_complete(self.protocol.send(async_iterable(["ca", "fé"])))
        self.assertFramesSent(
            (False, OP_TEXT, "ca".encode()),
            (False, OP_CONT, "fé".encode()),
            (True, OP_CONT, "".encode()),
        )

    def test_send_async_iterable_binary(self):
        self.loop.run_until_complete(self.protocol.send(async_iterable([b"te", b"a"])))
        self.assertFramesSent(
            (False, OP_BINARY, b"te"), (False, OP_CONT, b"a"), (True, OP_CONT, b"")
        )

    def test_send_async_iterable_binary_from_bytearray(self):
        self.loop.run_until_complete(
            self.protocol.send(async_iterable([bytearray(b"te"), bytearray(b"a")]))
        )
        self.assertFramesSent(
            (False, OP_BINARY, b"te"), (False, OP_CONT, b"a"), (True, OP_CONT, b"")
        )

    def test_send_async_iterable_binary_from_memoryview(self):
        self.loop.run_until_complete(
            self.protocol.send(async_iterable([memoryview(b"te"), memoryview(b"a")]))
        )
        self.assertFramesSent(
            (False, OP_BINARY, b"te"), (False, OP_CONT, b"a"), (True, OP_CONT, b"")
        )

    def test_send_empty_async_iterable(self):
        self.loop.run_until_complete(self.protocol.send(async_iterable([])))
        self.assertNoFrameSent()

    def test_send_async_iterable_type_error(self):
        with self.assertRaises(TypeError):
            self.loop.run_until_complete(self.protocol.send(async_iterable([42])))
        self.assertNoFrameSent()

    def test_send_async_iterable_mixed_type_error(self):
        # Mixing text and binary fragments fails the connection mid-message.
        with self.assertRaises(TypeError):
            self.loop.run_until_complete(
                self.protocol.send(async_iterable(["café", b"tea"]))
            )
        self.assertFramesSent(
            (False, OP_TEXT, "café".encode()),
            (True, OP_CLOSE, Close(CloseCode.INTERNAL_ERROR, "").serialize()),
        )

    def test_send_async_iterable_prevents_concurrent_send(self):
        # A concurrent send must wait until the fragmented message completes.
        self.make_drain_slow(2 * MS)

        async def send_async_iterable():
            await self.protocol.send(async_iterable(["ca", "fé"]))

        async def send_concurrent():
            await asyncio.sleep(MS)
            await self.protocol.send(b"tea")

        async def run_concurrently():
            await asyncio.gather(
                send_async_iterable(),
                send_concurrent(),
            )

        self.loop.run_until_complete(run_concurrently())

        self.assertFramesSent(
            (False, OP_TEXT, "ca".encode()),
            (False, OP_CONT, "fé".encode()),
            (True, OP_CONT, "".encode()),
            (True, OP_BINARY, b"tea"),
        )

    def test_send_on_closing_connection_local(self):
        close_task = self.half_close_connection_local()

        with self.assertRaises(ConnectionClosed):
            self.loop.run_until_complete(self.protocol.send("foobar"))
        self.assertNoFrameSent()

        self.loop.run_until_complete(close_task)  # cleanup

    def test_send_on_closing_connection_remote(self):
        self.half_close_connection_remote()

        with self.assertRaises(ConnectionClosed):
            self.loop.run_until_complete(self.protocol.send("foobar"))
        self.assertNoFrameSent()

    def test_send_on_closed_connection(self):
        self.close_connection()

        with self.assertRaises(ConnectionClosed):
            self.loop.run_until_complete(self.protocol.send("foobar"))
        self.assertNoFrameSent()

    # Test the ping coroutine.

    def test_ping_default(self):
        self.loop.run_until_complete(self.protocol.ping())
        # With our testing tools, it's more convenient to extract the expected
        # ping data from the library's internals than from the frame sent.
        ping_data = next(iter(self.protocol.pings))
        self.assertIsInstance(ping_data, bytes)
        self.assertEqual(len(ping_data), 4)
        self.assertOneFrameSent(True, OP_PING, ping_data)

    def test_ping_text(self):
        self.loop.run_until_complete(self.protocol.ping("café"))
        self.assertOneFrameSent(True, OP_PING, "café".encode())

    def test_ping_binary(self):
        self.loop.run_until_complete(self.protocol.ping(b"tea"))
        self.assertOneFrameSent(True, OP_PING, b"tea")

    def test_ping_binary_from_bytearray(self):
        self.loop.run_until_complete(self.protocol.ping(bytearray(b"tea")))
        self.assertOneFrameSent(True, OP_PING, b"tea")

    def test_ping_binary_from_memoryview(self):
        self.loop.run_until_complete(self.protocol.ping(memoryview(b"tea")))
        self.assertOneFrameSent(True, OP_PING, b"tea")

    def test_ping_type_error(self):
        with self.assertRaises(TypeError):
            self.loop.run_until_complete(self.protocol.ping(42))
        self.assertNoFrameSent()

    def test_ping_on_closing_connection_local(self):
        close_task = self.half_close_connection_local()

        with self.assertRaises(ConnectionClosed):
            self.loop.run_until_complete(self.protocol.ping())
        self.assertNoFrameSent()

        self.loop.run_until_complete(close_task)  # cleanup

    def test_ping_on_closing_connection_remote(self):
        self.half_close_connection_remote()

        with self.assertRaises(ConnectionClosed):
            self.loop.run_until_complete(self.protocol.ping())
        self.assertNoFrameSent()

    def test_ping_on_closed_connection(self):
        self.close_connection()

        with self.assertRaises(ConnectionClosed):
            self.loop.run_until_complete(self.protocol.ping())
        self.assertNoFrameSent()

    # Test the pong coroutine.

    def test_pong_default(self):
        self.loop.run_until_complete(self.protocol.pong())
        self.assertOneFrameSent(True, OP_PONG, b"")

    def test_pong_text(self):
        self.loop.run_until_complete(self.protocol.pong("café"))
        self.assertOneFrameSent(True, OP_PONG, "café".encode())

    def test_pong_binary(self):
        self.loop.run_until_complete(self.protocol.pong(b"tea"))
        self.assertOneFrameSent(True, OP_PONG, b"tea")

    def test_pong_binary_from_bytearray(self):
        self.loop.run_until_complete(self.protocol.pong(bytearray(b"tea")))
        self.assertOneFrameSent(True, OP_PONG, b"tea")

    def test_pong_binary_from_memoryview(self):
        self.loop.run_until_complete(self.protocol.pong(memoryview(b"tea")))
        self.assertOneFrameSent(True, OP_PONG, b"tea")

    def test_pong_type_error(self):
        with self.assertRaises(TypeError):
            self.loop.run_until_complete(self.protocol.pong(42))
        self.assertNoFrameSent()

    def test_pong_on_closing_connection_local(self):
        close_task = self.half_close_connection_local()

        with self.assertRaises(ConnectionClosed):
            self.loop.run_until_complete(self.protocol.pong())
        self.assertNoFrameSent()

        self.loop.run_until_complete(close_task)  # cleanup

    def test_pong_on_closing_connection_remote(self):
        self.half_close_connection_remote()

        with self.assertRaises(ConnectionClosed):
            self.loop.run_until_complete(self.protocol.pong())
        self.assertNoFrameSent()

    def test_pong_on_closed_connection(self):
        self.close_connection()

        with self.assertRaises(ConnectionClosed):
            self.loop.run_until_complete(self.protocol.pong())
        self.assertNoFrameSent()

    # Test the protocol's logic for acknowledging pings with pongs.
    def test_answer_ping(self):
        # An incoming ping must be answered with a pong carrying the same data.
        self.receive_frame(Frame(True, OP_PING, b"test"))
        self.run_loop_once()
        self.assertOneFrameSent(True, OP_PONG, b"test")

    def test_answer_ping_does_not_crash_if_connection_closing(self):
        close_task = self.half_close_connection_local()

        self.receive_frame(Frame(True, OP_PING, b"test"))
        self.run_loop_once()

        with self.assertNoLogs("websockets", logging.ERROR):
            self.loop.run_until_complete(self.protocol.close())

        self.loop.run_until_complete(close_task)  # cleanup

    def test_answer_ping_does_not_crash_if_connection_closed(self):
        self.make_drain_slow()
        # Drop the connection right after receiving a ping frame,
        # which prevents responding with a pong frame properly.
        self.receive_frame(Frame(True, OP_PING, b"test"))
        self.receive_eof()
        self.run_loop_once()

        with self.assertNoLogs("websockets", logging.ERROR):
            self.loop.run_until_complete(self.protocol.close())

    def test_ignore_pong(self):
        # An unsolicited pong must be ignored silently.
        self.receive_frame(Frame(True, OP_PONG, b"test"))
        self.run_loop_once()
        self.assertNoFrameSent()

    def test_acknowledge_ping(self):
        pong_waiter = self.loop.run_until_complete(self.protocol.ping())
        self.assertFalse(pong_waiter.done())
        ping_frame = self.last_sent_frame()
        pong_frame = Frame(True, OP_PONG, ping_frame.data)
        self.receive_frame(pong_frame)
        self.run_loop_once()
        self.run_loop_once()
        self.assertTrue(pong_waiter.done())

    def test_abort_ping(self):
        # Closing the connection must abort pending pings.
        pong_waiter = self.loop.run_until_complete(self.protocol.ping())
        # Remove the frame from the buffer, else close_connection() complains.
        self.last_sent_frame()
        self.assertFalse(pong_waiter.done())
        self.close_connection()
        self.assertTrue(pong_waiter.done())
        self.assertIsInstance(pong_waiter.exception(), ConnectionClosed)

    def test_abort_ping_does_not_log_exception_if_not_retreived(self):
        self.loop.run_until_complete(self.protocol.ping())
        # Get the internal Future, which isn't directly returned by ping().
        ((pong_waiter, _timestamp),) = self.protocol.pings.values()
        # Remove the frame from the buffer, else close_connection() complains.
        self.last_sent_frame()
        self.close_connection()
        # Check a private attribute, for lack of a better solution.
        self.assertFalse(pong_waiter._log_traceback)

    def test_acknowledge_previous_pings(self):
        pings = [
            (self.loop.run_until_complete(self.protocol.ping()), self.last_sent_frame())
            for i in range(3)
        ]
        # Unsolicited pong doesn't acknowledge pings
        self.receive_frame(Frame(True, OP_PONG, b""))
        self.run_loop_once()
        self.run_loop_once()
        self.assertFalse(pings[0][0].done())
        self.assertFalse(pings[1][0].done())
        self.assertFalse(pings[2][0].done())
        # Pong acknowledges all previous pings
        self.receive_frame(Frame(True, OP_PONG, pings[1][1].data))
        self.run_loop_once()
        self.run_loop_once()
        self.assertTrue(pings[0][0].done())
        self.assertTrue(pings[1][0].done())
        self.assertFalse(pings[2][0].done())

    def test_acknowledge_aborted_ping(self):
        pong_waiter = self.loop.run_until_complete(self.protocol.ping())
        ping_frame = self.last_sent_frame()
        # Clog incoming queue. This lets connection_lost() abort pending pings
        # with a ConnectionClosed exception before transfer_data_task
        # terminates and close_connection cancels keepalive_ping_task.
        self.protocol.max_queue = 1
        self.receive_frame(Frame(True, OP_TEXT, b"1"))
        self.receive_frame(Frame(True, OP_TEXT, b"2"))
        # Add pong frame to the queue.
        pong_frame = Frame(True, OP_PONG, ping_frame.data)
        self.receive_frame(pong_frame)
        # Connection drops.
        self.receive_eof()
        self.loop.run_until_complete(self.protocol.wait_closed())
        # Ping receives a ConnectionClosed exception.
        with self.assertRaises(ConnectionClosed):
            pong_waiter.result()

        # transfer_data doesn't crash, which would be logged.
        with self.assertNoLogs("websockets", logging.ERROR):
            # Unclog incoming queue.
            self.loop.run_until_complete(self.protocol.recv())
            self.loop.run_until_complete(self.protocol.recv())

    def test_canceled_ping(self):
        # A pong matching a canceled ping must not crash the protocol.
        pong_waiter = self.loop.run_until_complete(self.protocol.ping())
        ping_frame = self.last_sent_frame()
        pong_waiter.cancel()
        pong_frame = Frame(True, OP_PONG, ping_frame.data)
        self.receive_frame(pong_frame)
        self.run_loop_once()
        self.run_loop_once()
        self.assertTrue(pong_waiter.cancelled())

    def test_duplicate_ping(self):
        # Sending a ping with the same payload as a pending one is an error.
        self.loop.run_until_complete(self.protocol.ping(b"foobar"))
        self.assertOneFrameSent(True, OP_PING, b"foobar")
        with self.assertRaises(RuntimeError):
            self.loop.run_until_complete(self.protocol.ping(b"foobar"))
        self.assertNoFrameSent()

    # Test the protocol's logic for measuring latency

    def test_record_latency_on_pong(self):
        self.assertEqual(self.protocol.latency, 0)
        self.loop.run_until_complete(self.protocol.ping(b"test"))
        self.receive_frame(Frame(True, OP_PONG, b"test"))
        self.run_loop_once()
        self.assertGreater(self.protocol.latency, 0)

    def test_return_latency_on_pong(self):
        pong_waiter = self.loop.run_until_complete(self.protocol.ping())
        ping_frame = self.last_sent_frame()
        pong_frame = Frame(True, OP_PONG, ping_frame.data)
        self.receive_frame(pong_frame)
        latency = self.loop.run_until_complete(pong_waiter)
        self.assertGreater(latency, 0)

    # Test the protocol's logic for rebuilding fragmented messages.
    def test_fragmented_text(self):
        self.receive_frame(Frame(False, OP_TEXT, "ca".encode()))
        self.receive_frame(Frame(True, OP_CONT, "fé".encode()))
        data = self.loop.run_until_complete(self.protocol.recv())
        self.assertEqual(data, "café")

    def test_fragmented_binary(self):
        self.receive_frame(Frame(False, OP_BINARY, b"t"))
        self.receive_frame(Frame(False, OP_CONT, b"e"))
        self.receive_frame(Frame(True, OP_CONT, b"a"))
        data = self.loop.run_until_complete(self.protocol.recv())
        self.assertEqual(data, b"tea")

    def test_fragmented_text_payload_too_big(self):
        # max_size applies to the total reassembled message, not per fragment.
        self.protocol.max_size = 1024
        self.receive_frame(Frame(False, OP_TEXT, "café".encode() * 100))
        self.receive_frame(Frame(True, OP_CONT, "café".encode() * 105))
        self.process_invalid_frames()
        self.assertConnectionFailed(CloseCode.MESSAGE_TOO_BIG, "")

    def test_fragmented_binary_payload_too_big(self):
        self.protocol.max_size = 1024
        self.receive_frame(Frame(False, OP_BINARY, b"tea" * 171))
        self.receive_frame(Frame(True, OP_CONT, b"tea" * 171))
        self.process_invalid_frames()
        self.assertConnectionFailed(CloseCode.MESSAGE_TOO_BIG, "")

    def test_fragmented_text_no_max_size(self):
        self.protocol.max_size = None  # for test coverage
        self.receive_frame(Frame(False, OP_TEXT, "café".encode() * 100))
        self.receive_frame(Frame(True, OP_CONT, "café".encode() * 105))
        data = self.loop.run_until_complete(self.protocol.recv())
        self.assertEqual(data, "café" * 205)

    def test_fragmented_binary_no_max_size(self):
        self.protocol.max_size = None  # for test coverage
        self.receive_frame(Frame(False, OP_BINARY, b"tea" * 171))
        self.receive_frame(Frame(True, OP_CONT, b"tea" * 171))
        data = self.loop.run_until_complete(self.protocol.recv())
        self.assertEqual(data, b"tea" * 342)

    def test_control_frame_within_fragmented_text(self):
        # Control frames may be interjected in a fragmented message.
        self.receive_frame(Frame(False, OP_TEXT, "ca".encode()))
        self.receive_frame(Frame(True, OP_PING, b""))
        self.receive_frame(Frame(True, OP_CONT, "fé".encode()))
        data = self.loop.run_until_complete(self.protocol.recv())
        self.assertEqual(data, "café")
        self.assertOneFrameSent(True, OP_PONG, b"")

    def test_unterminated_fragmented_text(self):
        self.receive_frame(Frame(False, OP_TEXT, "ca".encode()))
        # Missing the second part of the fragmented frame.
        self.receive_frame(Frame(True, OP_BINARY, b"tea"))
        self.process_invalid_frames()
        self.assertConnectionFailed(CloseCode.PROTOCOL_ERROR, "")

    def test_close_handshake_in_fragmented_text(self):
        self.receive_frame(Frame(False, OP_TEXT, "ca".encode()))
        self.receive_frame(Frame(True, OP_CLOSE, b""))
        self.process_invalid_frames()
        # The RFC may have overlooked this case: it says that control frames
        # can be interjected in the middle of a fragmented message and that a
        # close frame must be echoed. Even though there's an unterminated
        # message, technically, the closing handshake was successful.
        self.assertConnectionClosed(CloseCode.NO_STATUS_RCVD, "")

    def test_connection_close_in_fragmented_text(self):
        # Connection drops in the middle of a fragmented message.
        self.receive_frame(Frame(False, OP_TEXT, "ca".encode()))
        self.process_invalid_frames()
        self.assertConnectionFailed(CloseCode.ABNORMAL_CLOSURE, "")

    # Test miscellaneous code paths to ensure full coverage.

    def test_connection_lost(self):
        # Test calling connection_lost without going through close_connection.
        self.protocol.connection_lost(None)

        self.assertConnectionFailed(CloseCode.ABNORMAL_CLOSURE, "")

    def test_ensure_open_before_opening_handshake(self):
        # Simulate a bug by forcibly reverting the protocol state.
        self.protocol.state = State.CONNECTING

        with self.assertRaises(InvalidState):
            self.loop.run_until_complete(self.protocol.ensure_open())

    def test_ensure_open_during_unclean_close(self):
        # Process connection_made in order to start transfer_data_task.
        self.run_loop_once()

        # Ensure the test terminates quickly.
        self.loop.call_later(MS, self.receive_eof_if_client)

        # Simulate the case when close() times out sending a close frame.
        self.protocol.fail_connection()

        with self.assertRaises(ConnectionClosed):
            self.loop.run_until_complete(self.protocol.ensure_open())

    def test_legacy_recv(self):
        # By default legacy_recv in disabled.
        self.assertEqual(self.protocol.legacy_recv, False)

        self.close_connection()

        # Enable legacy_recv.
        self.protocol.legacy_recv = True

        # Now recv() returns None instead of raising ConnectionClosed.
        self.assertIsNone(self.loop.run_until_complete(self.protocol.recv()))

    # Test the protocol logic for sending keepalive pings.

    def restart_protocol_with_keepalive_ping(
        self,
        ping_interval=3 * MS,
        ping_timeout=3 * MS,
    ):
        # Replace the protocol created by setUp with one that sends
        # keepalive pings, preserving the client/server role.
        initial_protocol = self.protocol
        # copied from tearDown
        self.transport.close()
        self.loop.run_until_complete(self.protocol.close())

        # copied from setUp, but enables keepalive pings
        async def create_protocol():
            return WebSocketCommonProtocol(
                ping_interval=ping_interval,
                ping_timeout=ping_timeout,
            )

        self.protocol = self.loop.run_until_complete(create_protocol())
        self.transport = TransportMock()
        self.transport.setup_mock(self.loop, self.protocol)
        self.protocol.is_client = initial_protocol.is_client
        self.protocol.side = initial_protocol.side

    def test_keepalive_ping(self):
        self.restart_protocol_with_keepalive_ping()

        # Ping is sent at 3ms and acknowledged at 4ms.
        self.loop.run_until_complete(asyncio.sleep(4 * MS))
        (ping_1,) = tuple(self.protocol.pings)
        self.assertOneFrameSent(True, OP_PING, ping_1)
        self.receive_frame(Frame(True, OP_PONG, ping_1))

        # Next ping is sent at 7ms.
        self.loop.run_until_complete(asyncio.sleep(4 * MS))
        (ping_2,) = tuple(self.protocol.pings)
        self.assertOneFrameSent(True, OP_PING, ping_2)

        # The keepalive ping task goes on.
        self.assertFalse(self.protocol.keepalive_ping_task.done())

    def test_keepalive_ping_not_acknowledged_closes_connection(self):
        self.restart_protocol_with_keepalive_ping()

        # Ping is sent at 3ms and not acknowledged.
        self.loop.run_until_complete(asyncio.sleep(4 * MS))
        (ping_1,) = tuple(self.protocol.pings)
        self.assertOneFrameSent(True, OP_PING, ping_1)

        # Connection is closed at 6ms.
        self.loop.run_until_complete(asyncio.sleep(4 * MS))
        self.assertOneFrameSent(
            True,
            OP_CLOSE,
            Close(CloseCode.INTERNAL_ERROR, "keepalive ping timeout").serialize(),
        )

        # The keepalive ping task is complete.
        self.assertEqual(self.protocol.keepalive_ping_task.result(), None)

    def test_keepalive_ping_stops_when_connection_closing(self):
        self.restart_protocol_with_keepalive_ping()
        close_task = self.half_close_connection_local()

        # No ping sent at 3ms because the closing handshake is in progress.
        self.loop.run_until_complete(asyncio.sleep(4 * MS))
        self.assertNoFrameSent()

        # The keepalive ping task terminated.
        self.assertTrue(self.protocol.keepalive_ping_task.cancelled())

        self.loop.run_until_complete(close_task)  # cleanup

    def test_keepalive_ping_stops_when_connection_closed(self):
        self.restart_protocol_with_keepalive_ping()
        self.close_connection()

        # The keepalive ping task terminated.
        self.assertTrue(self.protocol.keepalive_ping_task.cancelled())

    def test_keepalive_ping_does_not_crash_when_connection_lost(self):
        self.restart_protocol_with_keepalive_ping()
        # Clog incoming queue. This lets connection_lost() abort pending pings
        # with a ConnectionClosed exception before transfer_data_task
        # terminates and close_connection cancels keepalive_ping_task.
        self.protocol.max_queue = 1
        self.receive_frame(Frame(True, OP_TEXT, b"1"))
        self.receive_frame(Frame(True, OP_TEXT, b"2"))
        # Ping is sent at 3ms.
        self.loop.run_until_complete(asyncio.sleep(4 * MS))
        ((pong_waiter, _timestamp),) = self.protocol.pings.values()
        # Connection drops.
        self.receive_eof()
        self.loop.run_until_complete(self.protocol.wait_closed())

        # The ping waiter receives a ConnectionClosed exception.
        with self.assertRaises(ConnectionClosed):
            pong_waiter.result()
        # The keepalive ping task terminated properly.
        self.assertIsNone(self.protocol.keepalive_ping_task.result())

        # Unclog incoming queue to terminate the test quickly.
        self.loop.run_until_complete(self.protocol.recv())
        self.loop.run_until_complete(self.protocol.recv())

    def test_keepalive_ping_with_no_ping_interval(self):
        self.restart_protocol_with_keepalive_ping(ping_interval=None)

        # No ping is sent at 3ms.
        self.loop.run_until_complete(asyncio.sleep(4 * MS))
        self.assertNoFrameSent()

    def test_keepalive_ping_with_no_ping_timeout(self):
        self.restart_protocol_with_keepalive_ping(ping_timeout=None)

        # Ping is sent at 3ms and not acknowledged.
        self.loop.run_until_complete(asyncio.sleep(4 * MS))
        (ping_1,) = tuple(self.protocol.pings)
        self.assertOneFrameSent(True, OP_PING, ping_1)

        # Next ping is sent at 7ms anyway.
        self.loop.run_until_complete(asyncio.sleep(4 * MS))
        ping_1_again, ping_2 = tuple(self.protocol.pings)
        self.assertEqual(ping_1, ping_1_again)
        self.assertOneFrameSent(True, OP_PING, ping_2)

        # The keepalive ping task goes on.
        self.assertFalse(self.protocol.keepalive_ping_task.done())

    def test_keepalive_ping_unexpected_error(self):
        self.restart_protocol_with_keepalive_ping()

        async def ping():
            raise Exception("BOOM")

        self.protocol.ping = ping

        # The keepalive ping task fails when sending a ping at 3ms.
        self.loop.run_until_complete(asyncio.sleep(4 * MS))

        # The keepalive ping task is complete.
        # It logs and swallows the exception.
        self.assertEqual(self.protocol.keepalive_ping_task.result(), None)

    # Test the protocol logic for closing the connection.

    def test_local_close(self):
        # Emulate how the remote endpoint answers the closing handshake.
        self.loop.call_later(MS, self.receive_frame, self.close_frame)
        self.loop.call_later(MS, self.receive_eof_if_client)

        # Run the closing handshake.
        self.loop.run_until_complete(self.protocol.close(reason="close"))

        self.assertConnectionClosed(CloseCode.NORMAL_CLOSURE, "close")
        self.assertOneFrameSent(*self.close_frame)

        # Closing the connection again is a no-op.
        self.loop.run_until_complete(self.protocol.close(reason="oh noes!"))
        self.assertConnectionClosed(CloseCode.NORMAL_CLOSURE, "close")
        self.assertNoFrameSent()

    def test_remote_close(self):
        # Emulate how the remote endpoint initiates the closing handshake.
        self.loop.call_later(MS, self.receive_frame, self.close_frame)
        self.loop.call_later(MS, self.receive_eof_if_client)

        # Wait for some data in order to process the handshake.
        # After recv() raises ConnectionClosed, the connection is closed.
        with self.assertRaises(ConnectionClosed):
            self.loop.run_until_complete(self.protocol.recv())

        self.assertConnectionClosed(CloseCode.NORMAL_CLOSURE, "close")
        self.assertOneFrameSent(*self.close_frame)

        # Closing the connection again is a no-op.
        self.loop.run_until_complete(self.protocol.close(reason="oh noes!"))
        self.assertConnectionClosed(CloseCode.NORMAL_CLOSURE, "close")
        self.assertNoFrameSent()

    def test_remote_close_and_connection_lost(self):
        self.make_drain_slow()
        # Drop the connection right after receiving a close frame,
        # which prevents echoing the close frame properly.
        self.receive_frame(self.close_frame)
        self.receive_eof()
        self.run_loop_once()

        with self.assertNoLogs("websockets", logging.ERROR):
            self.loop.run_until_complete(self.protocol.close(reason="oh noes!"))

        self.assertConnectionClosed(CloseCode.NORMAL_CLOSURE, "close")
        self.assertOneFrameSent(*self.close_frame)

    def test_simultaneous_close(self):
        # Receive the incoming close frame right after self.protocol.close()
        # starts executing. This reproduces the error described in:
        # https://github.com/python-websockets/websockets/issues/339
        self.loop.call_soon(self.receive_frame, self.remote_close)
        self.loop.call_soon(self.receive_eof_if_client)
        self.run_loop_once()

        self.loop.run_until_complete(self.protocol.close(reason="local"))

        self.assertConnectionClosed(CloseCode.NORMAL_CLOSURE, "remote")
        # The current implementation sends a close frame in response to the
        # close frame received from the remote end.
It skips the close frame # that should be sent as a result of calling close(). self.assertOneFrameSent(*self.remote_close) def test_close_preserves_incoming_frames(self): self.receive_frame(Frame(True, OP_TEXT, b"hello")) self.run_loop_once() self.loop.call_later(MS, self.receive_frame, self.close_frame) self.loop.call_later(MS, self.receive_eof_if_client) self.loop.run_until_complete(self.protocol.close(reason="close")) self.assertConnectionClosed(CloseCode.NORMAL_CLOSURE, "close") self.assertOneFrameSent(*self.close_frame) next_message = self.loop.run_until_complete(self.protocol.recv()) self.assertEqual(next_message, "hello") def test_close_protocol_error(self): invalid_close_frame = Frame(True, OP_CLOSE, b"\x00") self.receive_frame(invalid_close_frame) self.receive_eof_if_client() self.run_loop_once() self.loop.run_until_complete(self.protocol.close(reason="close")) self.assertConnectionFailed(CloseCode.PROTOCOL_ERROR, "") def test_close_connection_lost(self): self.receive_eof() self.run_loop_once() self.loop.run_until_complete(self.protocol.close(reason="close")) self.assertConnectionFailed(CloseCode.ABNORMAL_CLOSURE, "") def test_local_close_during_recv(self): recv = self.loop.create_task(self.protocol.recv()) self.loop.call_later(MS, self.receive_frame, self.close_frame) self.loop.call_later(MS, self.receive_eof_if_client) self.loop.run_until_complete(self.protocol.close(reason="close")) with self.assertRaises(ConnectionClosed): self.loop.run_until_complete(recv) self.assertConnectionClosed(CloseCode.NORMAL_CLOSURE, "close") # There is no test_remote_close_during_recv because it would be identical # to test_remote_close. 
def test_remote_close_during_send(self): self.make_drain_slow() send = self.loop.create_task(self.protocol.send("hello")) self.receive_frame(self.close_frame) self.receive_eof() with self.assertRaises(ConnectionClosed): self.loop.run_until_complete(send) self.assertConnectionClosed(CloseCode.NORMAL_CLOSURE, "close") # There is no test_local_close_during_send because this cannot really # happen, considering that writes are serialized. def test_broadcast_text(self): broadcast([self.protocol], "café") self.assertOneFrameSent(True, OP_TEXT, "café".encode()) @unittest.skipIf( sys.version_info[:2] < (3, 11), "raise_exceptions requires Python 3.11+", ) def test_broadcast_text_reports_no_errors(self): broadcast([self.protocol], "café", raise_exceptions=True) self.assertOneFrameSent(True, OP_TEXT, "café".encode()) def test_broadcast_binary(self): broadcast([self.protocol], b"tea") self.assertOneFrameSent(True, OP_BINARY, b"tea") @unittest.skipIf( sys.version_info[:2] < (3, 11), "raise_exceptions requires Python 3.11+", ) def test_broadcast_binary_reports_no_errors(self): broadcast([self.protocol], b"tea", raise_exceptions=True) self.assertOneFrameSent(True, OP_BINARY, b"tea") def test_broadcast_type_error(self): with self.assertRaises(TypeError): broadcast([self.protocol], ["ca", "fé"]) def test_broadcast_no_clients(self): broadcast([], "café") self.assertNoFrameSent() def test_broadcast_two_clients(self): broadcast([self.protocol, self.protocol], "café") self.assertFramesSent( (True, OP_TEXT, "café".encode()), (True, OP_TEXT, "café".encode()), ) def test_broadcast_skips_closed_connection(self): self.close_connection() with self.assertNoLogs("websockets", logging.ERROR): broadcast([self.protocol], "café") self.assertNoFrameSent() def test_broadcast_skips_closing_connection(self): close_task = self.half_close_connection_local() with self.assertNoLogs("websockets", logging.ERROR): broadcast([self.protocol], "café") self.assertNoFrameSent() 
self.loop.run_until_complete(close_task) # cleanup def test_broadcast_skips_connection_sending_fragmented_text(self): self.make_drain_slow() self.loop.create_task(self.protocol.send(["ca", "fé"])) self.run_loop_once() self.assertOneFrameSent(False, OP_TEXT, "ca".encode()) with self.assertLogs("websockets", logging.WARNING) as logs: broadcast([self.protocol], "café") self.assertEqual( [record.getMessage() for record in logs.records], ["skipped broadcast: sending a fragmented message"], ) @unittest.skipIf( sys.version_info[:2] < (3, 11), "raise_exceptions requires Python 3.11+", ) def test_broadcast_reports_connection_sending_fragmented_text(self): self.make_drain_slow() self.loop.create_task(self.protocol.send(["ca", "fé"])) self.run_loop_once() self.assertOneFrameSent(False, OP_TEXT, "ca".encode()) with self.assertRaises(ExceptionGroup) as raised: broadcast([self.protocol], "café", raise_exceptions=True) self.assertEqual(str(raised.exception), "skipped broadcast (1 sub-exception)") self.assertEqual( str(raised.exception.exceptions[0]), "sending a fragmented message" ) def test_broadcast_skips_connection_failing_to_send(self): # Configure mock to raise an exception when writing to the network. self.protocol.transport.write.side_effect = RuntimeError("BOOM") with self.assertLogs("websockets", logging.WARNING) as logs: broadcast([self.protocol], "café") self.assertEqual( [record.getMessage() for record in logs.records], ["skipped broadcast: failed to write message: RuntimeError: BOOM"], ) @unittest.skipIf( sys.version_info[:2] < (3, 11), "raise_exceptions requires Python 3.11+", ) def test_broadcast_reports_connection_failing_to_send(self): # Configure mock to raise an exception when writing to the network. 
self.protocol.transport.write.side_effect = RuntimeError("BOOM") with self.assertRaises(ExceptionGroup) as raised: broadcast([self.protocol], "café", raise_exceptions=True) self.assertEqual(str(raised.exception), "skipped broadcast (1 sub-exception)") self.assertEqual(str(raised.exception.exceptions[0]), "failed to write message") self.assertEqual(str(raised.exception.exceptions[0].__cause__), "BOOM") class ServerTests(CommonTests, AsyncioTestCase): def setUp(self): super().setUp() self.protocol.is_client = False self.protocol.side = "server" def test_local_close_send_close_frame_timeout(self): self.protocol.close_timeout = 10 * MS self.make_drain_slow(50 * MS) # If we can't send a close frame, time out in 10ms. # Check the timing within -1/+9ms for robustness. with self.assertCompletesWithin(9 * MS, 19 * MS): self.loop.run_until_complete(self.protocol.close(reason="close")) self.assertConnectionClosed(CloseCode.ABNORMAL_CLOSURE, "") def test_local_close_receive_close_frame_timeout(self): self.protocol.close_timeout = 10 * MS # If the client doesn't send a close frame, time out in 10ms. # Check the timing within -1/+9ms for robustness. with self.assertCompletesWithin(9 * MS, 19 * MS): self.loop.run_until_complete(self.protocol.close(reason="close")) self.assertConnectionClosed(CloseCode.ABNORMAL_CLOSURE, "") def test_local_close_connection_lost_timeout_after_write_eof(self): self.protocol.close_timeout = 10 * MS # If the client doesn't close its side of the TCP connection after we # half-close our side with write_eof(), time out in 10ms. # Check the timing within -1/+9ms for robustness. with self.assertCompletesWithin(9 * MS, 19 * MS): # HACK: disable write_eof => other end drops connection emulation. 
self.transport._eof = True self.receive_frame(self.close_frame) self.run_loop_once() self.loop.run_until_complete(self.protocol.close(reason="close")) self.assertConnectionClosed( CloseCode.NORMAL_CLOSURE, "close", ) def test_local_close_connection_lost_timeout_after_close(self): self.protocol.close_timeout = 10 * MS # If the client doesn't close its side of the TCP connection after we # half-close our side with write_eof() and close it with close(), time # out in 20ms. # Check the timing within -1/+9ms for robustness. # Add another 10ms because this test is flaky and I don't understand. with self.assertCompletesWithin(19 * MS, 39 * MS): # HACK: disable write_eof => other end drops connection emulation. self.transport._eof = True # HACK: disable close => other end drops connection emulation. self.transport._closing = True self.receive_frame(self.close_frame) self.run_loop_once() self.loop.run_until_complete(self.protocol.close(reason="close")) self.assertConnectionClosed( CloseCode.NORMAL_CLOSURE, "close", ) class ClientTests(CommonTests, AsyncioTestCase): def setUp(self): super().setUp() self.protocol.is_client = True self.protocol.side = "client" def test_local_close_send_close_frame_timeout(self): self.protocol.close_timeout = 10 * MS self.make_drain_slow(50 * MS) # If we can't send a close frame, time out in 20ms. # - 10ms waiting for sending a close frame # - 10ms waiting for receiving a half-close # Check the timing within -1/+9ms for robustness. with self.assertCompletesWithin(19 * MS, 29 * MS): self.loop.run_until_complete(self.protocol.close(reason="close")) self.assertConnectionClosed( CloseCode.ABNORMAL_CLOSURE, "", ) def test_local_close_receive_close_frame_timeout(self): self.protocol.close_timeout = 10 * MS # If the server doesn't send a close frame, time out in 20ms: # - 10ms waiting for receiving a close frame # - 10ms waiting for receiving a half-close # Check the timing within -1/+9ms for robustness. 
with self.assertCompletesWithin(19 * MS, 29 * MS): self.loop.run_until_complete(self.protocol.close(reason="close")) self.assertConnectionClosed( CloseCode.ABNORMAL_CLOSURE, "", ) def test_local_close_connection_lost_timeout_after_write_eof(self): self.protocol.close_timeout = 10 * MS # If the server doesn't half-close its side of the TCP connection # after we send a close frame, time out in 20ms: # - 10ms waiting for receiving a half-close # - 10ms waiting for receiving a close after write_eof # Check the timing within -1/+9ms for robustness. with self.assertCompletesWithin(19 * MS, 29 * MS): # HACK: disable write_eof => other end drops connection emulation. self.transport._eof = True self.receive_frame(self.close_frame) self.run_loop_once() self.loop.run_until_complete(self.protocol.close(reason="close")) self.assertConnectionClosed( CloseCode.NORMAL_CLOSURE, "close", ) def test_local_close_connection_lost_timeout_after_close(self): self.protocol.close_timeout = 10 * MS # If the client doesn't close its side of the TCP connection after we # half-close our side with write_eof() and close it with close(), time # out in 30ms. # - 10ms waiting for receiving a half-close # - 10ms waiting for receiving a close after write_eof # - 10ms waiting for receiving a close after close # Check the timing within -1/+9ms for robustness. # Add another 10ms because this test is flaky and I don't understand. with self.assertCompletesWithin(29 * MS, 49 * MS): # HACK: disable write_eof => other end drops connection emulation. self.transport._eof = True # HACK: disable close => other end drops connection emulation. 
self.transport._closing = True self.receive_frame(self.close_frame) self.run_loop_once() self.loop.run_until_complete(self.protocol.close(reason="close")) self.assertConnectionClosed( CloseCode.NORMAL_CLOSURE, "close", ) websockets-15.0.1/tests/legacy/utils.py000066400000000000000000000046641476212450300201050ustar00rootroot00000000000000import asyncio import functools import sys import unittest from ..utils import AssertNoLogsMixin class AsyncioTestCase(AssertNoLogsMixin, unittest.TestCase): """ Base class for tests that sets up an isolated event loop for each test. IsolatedAsyncioTestCase was introduced in Python 3.8 for similar purposes but isn't a drop-in replacement. """ def __init_subclass__(cls, **kwargs): """ Convert test coroutines to test functions. This supports asynchronous tests transparently. """ super().__init_subclass__(**kwargs) for name in unittest.defaultTestLoader.getTestCaseNames(cls): test = getattr(cls, name) if asyncio.iscoroutinefunction(test): setattr(cls, name, cls.convert_async_to_sync(test)) @staticmethod def convert_async_to_sync(test): """ Convert a test coroutine to a test function. """ @functools.wraps(test) def test_func(self, *args, **kwargs): return self.loop.run_until_complete(test(self, *args, **kwargs)) return test_func def setUp(self): super().setUp() self.loop = asyncio.new_event_loop() asyncio.set_event_loop(self.loop) def tearDown(self): self.loop.close() super().tearDown() def run_loop_once(self): # Process callbacks scheduled with call_soon by appending a callback # to stop the event loop then running it until it hits that callback. self.loop.call_soon(self.loop.stop) self.loop.run_forever() def assertDeprecationWarnings(self, recorded_warnings, expected_warnings): """ Check recorded deprecation warnings match a list of expected messages. """ # Work around https://github.com/python/cpython/issues/90476. 
if sys.version_info[:2] < (3, 11): # pragma: no cover recorded_warnings = [ recorded for recorded in recorded_warnings if not ( type(recorded.message) is ResourceWarning and str(recorded.message).startswith("unclosed transport") ) ] for recorded in recorded_warnings: self.assertIs(type(recorded.message), DeprecationWarning) self.assertEqual( {str(recorded.message) for recorded in recorded_warnings}, set(expected_warnings), ) websockets-15.0.1/tests/maxi_cov.py000077500000000000000000000120031476212450300172730ustar00rootroot00000000000000#!/usr/bin/env python """Measure coverage of each module by its test module.""" import glob import os.path import subprocess import sys UNMAPPED_SRC_FILES = [ "websockets/typing.py", "websockets/version.py", ] UNMAPPED_TEST_FILES = [ "tests/test_exports.py", ] def check_environment(): """Check that prerequisites for running this script are met.""" try: import websockets # noqa: F401 except ImportError: print("failed to import websockets; is src on PYTHONPATH?") return False try: import coverage # noqa: F401 except ImportError: print("failed to locate Coverage.py; is it installed?") return False return True def get_mapping(src_dir="src"): """Return a dict mapping each source file to its test file.""" # List source and test files. 
src_files = glob.glob( os.path.join(src_dir, "websockets/**/*.py"), recursive=True, ) test_files = glob.glob( "tests/**/*.py", recursive=True, ) src_files = [ os.path.relpath(src_file, src_dir) for src_file in sorted(src_files) if "legacy" not in os.path.dirname(src_file) and os.path.basename(src_file) != "__init__.py" and os.path.basename(src_file) != "__main__.py" and os.path.basename(src_file) != "async_timeout.py" and os.path.basename(src_file) != "compatibility.py" ] test_files = [ test_file for test_file in sorted(test_files) if "legacy" not in os.path.dirname(test_file) and os.path.basename(test_file) != "__init__.py" and os.path.basename(test_file).startswith("test_") ] # Map source files to test files. mapping = {} unmapped_test_files = set() for test_file in test_files: dir_name, file_name = os.path.split(test_file) assert dir_name.startswith("tests") assert file_name.startswith("test_") src_file = os.path.join( "websockets" + dir_name[len("tests") :], file_name[len("test_") :], ) if src_file in src_files: mapping[src_file] = test_file else: unmapped_test_files.add(test_file) unmapped_src_files = set(src_files) - set(mapping) # Ensure that all files are mapped. assert unmapped_src_files == set(UNMAPPED_SRC_FILES) assert unmapped_test_files == set(UNMAPPED_TEST_FILES) return mapping def get_ignored_files(src_dir="src"): """Return the list of files to exclude from coverage measurement.""" # */websockets matches src/websockets and .tox/**/site-packages/websockets. return [ # There are no tests for the __main__ module. "*/websockets/__main__.py", # There is nothing to test on type declarations. "*/websockets/typing.py", # We don't test compatibility modules with previous versions of Python # or websockets (import locations). "*/websockets/asyncio/async_timeout.py", "*/websockets/asyncio/compatibility.py", # This approach isn't applicable to the test suite of the legacy # implementation, due to the huge test_client_server test module. 
"*/websockets/legacy/*", "tests/legacy/*", ] + [ # Exclude test utilities that are shared between several test modules. # Also excludes this script. test_file for test_file in sorted(glob.glob("tests/**/*.py", recursive=True)) if "legacy" not in os.path.dirname(test_file) and os.path.basename(test_file) != "__init__.py" and not os.path.basename(test_file).startswith("test_") ] def run_coverage(mapping, src_dir="src"): # Initialize a new coverage measurement session. The --source option # includes all files in the report, even if they're never imported. print("\nInitializing session\n", flush=True) subprocess.run( [ sys.executable, "-m", "coverage", "run", "--source", ",".join([os.path.join(src_dir, "websockets"), "tests"]), "--omit", ",".join(get_ignored_files(src_dir)), "-m", "unittest", ] + list(UNMAPPED_TEST_FILES), check=True, ) # Append coverage of each source module by the corresponding test module. for src_file, test_file in mapping.items(): print(f"\nTesting {src_file} with {test_file}\n", flush=True) subprocess.run( [ sys.executable, "-m", "coverage", "run", "--append", "--include", ",".join([os.path.join(src_dir, src_file), test_file]), "-m", "unittest", test_file, ], check=True, ) if __name__ == "__main__": if not check_environment(): sys.exit(1) src_dir = sys.argv[1] if len(sys.argv) == 2 else "src" mapping = get_mapping(src_dir) run_coverage(mapping, src_dir) websockets-15.0.1/tests/protocol.py000066400000000000000000000013771476212450300173400ustar00rootroot00000000000000from websockets.protocol import Protocol class RecordingProtocol(Protocol): """ Protocol subclass that records incoming frames. By interfacing with this protocol, you can check easily what the component being testing sends during a test. """ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.frames_rcvd = [] def get_frames_rcvd(self): """ Get incoming frames received up to this point. Calling this method clears the list. Each frame is returned only once. 
""" frames_rcvd, self.frames_rcvd = self.frames_rcvd, [] return frames_rcvd def recv_frame(self, frame): self.frames_rcvd.append(frame) super().recv_frame(frame) websockets-15.0.1/tests/proxy.py000066400000000000000000000114471476212450300166570ustar00rootroot00000000000000import asyncio import pathlib import ssl import threading import warnings try: # Ignore deprecation warnings raised by mitmproxy dependencies at import time. warnings.filterwarnings("ignore", category=DeprecationWarning, module="passlib") warnings.filterwarnings("ignore", category=DeprecationWarning, module="pyasn1") from mitmproxy import ctx from mitmproxy.addons import core, next_layer, proxyauth, proxyserver, tlsconfig from mitmproxy.http import Response from mitmproxy.master import Master from mitmproxy.options import CONF_BASENAME, CONF_DIR, Options except ImportError: pass class RecordFlows: def __init__(self, on_running): self.running = on_running self.http_connects = [] self.tcp_flows = [] def http_connect(self, flow): self.http_connects.append(flow) def tcp_start(self, flow): self.tcp_flows.append(flow) def get_http_connects(self): http_connects, self.http_connects[:] = self.http_connects[:], [] return http_connects def get_tcp_flows(self): tcp_flows, self.tcp_flows[:] = self.tcp_flows[:], [] return tcp_flows def reset(self): self.http_connects = [] self.tcp_flows = [] class AlterRequest: def load(self, loader): loader.add_option( name="break_http_connect", typespec=bool, default=False, help="Respond to HTTP CONNECT requests with a 999 status code.", ) loader.add_option( name="close_http_connect", typespec=bool, default=False, help="Do not respond to HTTP CONNECT requests.", ) def http_connect(self, flow): if ctx.options.break_http_connect: # mitmproxy can send a response with a status code not between 100 # and 599, while websockets treats it as a protocol error. # This is used for testing HTTP parsing errors. 
flow.response = Response.make(999, "not a valid HTTP response") if ctx.options.close_http_connect: flow.kill() class ProxyMixin: """ Run mitmproxy in a background thread. While it's uncommon to run two event loops in two threads, tests for the asyncio implementation rely on this class too because it starts an event loop for mitm proxy once, then a new event loop for each test. """ proxy_mode = None @classmethod async def run_proxy(cls): cls.proxy_loop = loop = asyncio.get_event_loop() cls.proxy_stop = stop = loop.create_future() cls.proxy_options = options = Options( mode=[cls.proxy_mode], # Don't intercept connections, but record them. ignore_hosts=["^localhost:", "^127.0.0.1:", "^::1:"], # This option requires mitmproxy 11.0.0, which requires Python 3.11. show_ignored_hosts=True, ) cls.proxy_master = master = Master(options) master.addons.add( core.Core(), proxyauth.ProxyAuth(), proxyserver.Proxyserver(), next_layer.NextLayer(), tlsconfig.TlsConfig(), RecordFlows(on_running=cls.proxy_ready.set), AlterRequest(), ) task = loop.create_task(cls.proxy_master.run()) await stop for server in master.addons.get("proxyserver").servers: await server.stop() master.shutdown() await task @classmethod def setUpClass(cls): super().setUpClass() # Ignore deprecation warnings raised by mitmproxy at run time. 
warnings.filterwarnings( "ignore", category=DeprecationWarning, module="mitmproxy" ) cls.proxy_ready = threading.Event() cls.proxy_thread = threading.Thread(target=asyncio.run, args=(cls.run_proxy(),)) cls.proxy_thread.start() cls.proxy_ready.wait() certificate = pathlib.Path(CONF_DIR) / f"{CONF_BASENAME}-ca-cert.pem" certificate = certificate.expanduser() cls.proxy_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) cls.proxy_context.load_verify_locations(bytes(certificate)) def get_http_connects(self): return self.proxy_master.addons.get("recordflows").get_http_connects() def get_tcp_flows(self): return self.proxy_master.addons.get("recordflows").get_tcp_flows() def assertNumFlows(self, num_tcp_flows): self.assertEqual(len(self.get_tcp_flows()), num_tcp_flows) def tearDown(self): record_tcp_flows = self.proxy_master.addons.get("recordflows") record_tcp_flows.reset() super().tearDown() @classmethod def tearDownClass(cls): cls.proxy_loop.call_soon_threadsafe(cls.proxy_stop.set_result, None) cls.proxy_thread.join() super().tearDownClass() websockets-15.0.1/tests/requirements.txt000066400000000000000000000000401476212450300203730ustar00rootroot00000000000000python-socks[asyncio] mitmproxy websockets-15.0.1/tests/sync/000077500000000000000000000000001476212450300160715ustar00rootroot00000000000000websockets-15.0.1/tests/sync/__init__.py000066400000000000000000000000001476212450300201700ustar00rootroot00000000000000websockets-15.0.1/tests/sync/connection.py000066400000000000000000000055531476212450300206120ustar00rootroot00000000000000import contextlib import time from websockets.sync.connection import Connection class InterceptingConnection(Connection): """ Connection subclass that can intercept outgoing packets. By interfacing with this connection, we simulate network conditions affecting what the component being tested receives during a test. 
""" def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.socket = InterceptingSocket(self.socket) @contextlib.contextmanager def delay_frames_sent(self, delay): """ Add a delay before sending frames. Delays cumulate: they're added before every frame or before EOF. """ assert self.socket.delay_sendall is None self.socket.delay_sendall = delay try: yield finally: self.socket.delay_sendall = None @contextlib.contextmanager def delay_eof_sent(self, delay): """ Add a delay before sending EOF. Delays cumulate: they're added before every frame or before EOF. """ assert self.socket.delay_shutdown is None self.socket.delay_shutdown = delay try: yield finally: self.socket.delay_shutdown = None @contextlib.contextmanager def drop_frames_sent(self): """ Prevent frames from being sent. Since TCP is reliable, sending frames or EOF afterwards is unrealistic. """ assert not self.socket.drop_sendall self.socket.drop_sendall = True try: yield finally: self.socket.drop_sendall = False @contextlib.contextmanager def drop_eof_sent(self): """ Prevent EOF from being sent. Since TCP is reliable, sending frames or EOF afterwards is unrealistic. """ assert not self.socket.drop_shutdown self.socket.drop_shutdown = True try: yield finally: self.socket.drop_shutdown = False class InterceptingSocket: """ Socket wrapper that intercepts calls to ``sendall()`` and ``shutdown()``. This is coupled to the implementation, which relies on these two methods. 
""" def __init__(self, socket): self.socket = socket self.delay_sendall = None self.delay_shutdown = None self.drop_sendall = False self.drop_shutdown = False def __getattr__(self, name): return getattr(self.socket, name) def sendall(self, bytes, flags=0): if self.delay_sendall is not None: time.sleep(self.delay_sendall) if not self.drop_sendall: self.socket.sendall(bytes, flags) def shutdown(self, how): if self.delay_shutdown is not None: time.sleep(self.delay_shutdown) if not self.drop_shutdown: self.socket.shutdown(how) websockets-15.0.1/tests/sync/server.py000066400000000000000000000054751476212450300177640ustar00rootroot00000000000000import contextlib import ssl import threading import urllib.parse from websockets.sync.router import * from websockets.sync.server import * def get_uri(server, secure=None): if secure is None: secure = isinstance(server.socket, ssl.SSLSocket) # hack protocol = "wss" if secure else "ws" host, port = server.socket.getsockname() return f"{protocol}://{host}:{port}" def handler(ws): path = urllib.parse.urlparse(ws.request.path).path if path == "/": # The default path is an eval shell. for expr in ws: value = eval(expr) ws.send(str(value)) elif path == "/crash": raise RuntimeError elif path == "/no-op": pass else: raise AssertionError(f"unexpected path: {path}") class EvalShellMixin: def assertEval(self, client, expr, value): client.send(expr) self.assertEqual(client.recv(), value) @contextlib.contextmanager def run_server_or_router( serve_or_route, handler_or_url_map, host="localhost", port=0, **kwargs, ): with serve_or_route(handler_or_url_map, host, port, **kwargs) as server: thread = threading.Thread(target=server.serve_forever) thread.start() # HACK: since the sync server doesn't track connections (yet), we record # a reference to the thread handling the most recent connection, then we # can wait for that thread to terminate when exiting the context. 
handler_thread = None original_handler = server.handler def handler(sock, addr): nonlocal handler_thread handler_thread = threading.current_thread() original_handler(sock, addr) server.handler = handler try: yield server finally: server.shutdown() thread.join() # HACK: wait for the thread handling the most recent connection. if handler_thread is not None: handler_thread.join() def run_server(handler=handler, **kwargs): return run_server_or_router(serve, handler, **kwargs) def run_router(url_map, **kwargs): return run_server_or_router(route, url_map, **kwargs) @contextlib.contextmanager def run_unix_server_or_router( path, unix_serve_or_route, handler_or_url_map, **kwargs, ): with unix_serve_or_route(handler_or_url_map, path, **kwargs) as server: thread = threading.Thread(target=server.serve_forever) thread.start() try: yield server finally: server.shutdown() thread.join() def run_unix_server(path, handler=handler, **kwargs): return run_unix_server_or_router(path, unix_serve, handler, **kwargs) def run_unix_router(path, url_map, **kwargs): return run_unix_server_or_router(path, unix_route, url_map, **kwargs) websockets-15.0.1/tests/sync/test_client.py000066400000000000000000000746461476212450300210010ustar00rootroot00000000000000import http import logging import os import socket import socketserver import ssl import sys import threading import time import unittest from unittest.mock import patch from websockets.exceptions import ( InvalidHandshake, InvalidMessage, InvalidProxy, InvalidProxyMessage, InvalidStatus, InvalidURI, ProxyError, ) from websockets.extensions.permessage_deflate import PerMessageDeflate from websockets.sync.client import * from ..proxy import ProxyMixin from ..utils import ( CLIENT_CONTEXT, MS, SERVER_CONTEXT, DeprecationTestCase, temp_unix_socket_path, ) from .server import get_uri, run_server, run_unix_server class ClientTests(unittest.TestCase): def test_connection(self): """Client connects to server and the handshake succeeds.""" with 
run_server() as server: with connect(get_uri(server)) as client: self.assertEqual(client.protocol.state.name, "OPEN") def test_existing_socket(self): """Client connects using a pre-existing socket.""" with run_server() as server: with socket.create_connection(server.socket.getsockname()) as sock: # Use a non-existing domain to ensure we connect to sock. with connect("ws://invalid/", sock=sock) as client: self.assertEqual(client.protocol.state.name, "OPEN") def test_compression_is_enabled(self): """Client enables compression by default.""" with run_server() as server: with connect(get_uri(server)) as client: self.assertEqual( [type(ext) for ext in client.protocol.extensions], [PerMessageDeflate], ) def test_disable_compression(self): """Client disables compression.""" with run_server() as server: with connect(get_uri(server), compression=None) as client: self.assertEqual(client.protocol.extensions, []) def test_additional_headers(self): """Client can set additional headers with additional_headers.""" with run_server() as server: with connect( get_uri(server), additional_headers={"Authorization": "Bearer ..."} ) as client: self.assertEqual(client.request.headers["Authorization"], "Bearer ...") def test_override_user_agent(self): """Client can override User-Agent header with user_agent_header.""" with run_server() as server: with connect(get_uri(server), user_agent_header="Smith") as client: self.assertEqual(client.request.headers["User-Agent"], "Smith") def test_remove_user_agent(self): """Client can remove User-Agent header with user_agent_header.""" with run_server() as server: with connect(get_uri(server), user_agent_header=None) as client: self.assertNotIn("User-Agent", client.request.headers) def test_legacy_user_agent(self): """Client can override User-Agent header with additional_headers.""" with run_server() as server: with connect( get_uri(server), additional_headers={"User-Agent": "Smith"} ) as client: self.assertEqual(client.request.headers["User-Agent"], 
"Smith") def test_keepalive_is_enabled(self): """Client enables keepalive and measures latency by default.""" with run_server() as server: with connect(get_uri(server), ping_interval=MS) as client: self.assertEqual(client.latency, 0) time.sleep(2 * MS) self.assertGreater(client.latency, 0) def test_disable_keepalive(self): """Client disables keepalive.""" with run_server() as server: with connect(get_uri(server), ping_interval=None) as client: time.sleep(2 * MS) self.assertEqual(client.latency, 0) def test_logger(self): """Client accepts a logger argument.""" logger = logging.getLogger("test") with run_server() as server: with connect(get_uri(server), logger=logger) as client: self.assertEqual(client.logger.name, logger.name) def test_custom_connection_factory(self): """Client runs ClientConnection factory provided in create_connection.""" def create_connection(*args, **kwargs): client = ClientConnection(*args, **kwargs) client.create_connection_ran = True return client with run_server() as server: with connect( get_uri(server), create_connection=create_connection ) as client: self.assertTrue(client.create_connection_ran) def test_invalid_uri(self): """Client receives an invalid URI.""" with self.assertRaises(InvalidURI): with connect("http://localhost"): # invalid scheme self.fail("did not raise") def test_tcp_connection_fails(self): """Client fails to connect to server.""" with self.assertRaises(OSError): with connect("ws://localhost:54321"): # invalid port self.fail("did not raise") def test_handshake_fails(self): """Client connects to server but the handshake fails.""" def remove_accept_header(self, request, response): del response.headers["Sec-WebSocket-Accept"] # The connection will be open for the server but failed for the client. # Use a connection handler that exits immediately to avoid an exception. 
with run_server(process_response=remove_accept_header) as server: with self.assertRaises(InvalidHandshake) as raised: with connect(get_uri(server) + "/no-op", close_timeout=MS): self.fail("did not raise") self.assertEqual( str(raised.exception), "missing Sec-WebSocket-Accept header", ) def test_timeout_during_handshake(self): """Client times out before receiving handshake response from server.""" # Replace the WebSocket server with a TCP server that doesn't respond. with socket.create_server(("localhost", 0)) as sock: host, port = sock.getsockname() with self.assertRaises(TimeoutError) as raised: with connect(f"ws://{host}:{port}", open_timeout=MS): self.fail("did not raise") self.assertEqual( str(raised.exception), "timed out while waiting for handshake response", ) def test_connection_closed_during_handshake(self): """Client reads EOF before receiving handshake response from server.""" def close_connection(self, request): self.socket.shutdown(socket.SHUT_RDWR) self.socket.close() with run_server(process_request=close_connection) as server: with self.assertRaises(InvalidMessage) as raised: with connect(get_uri(server)): self.fail("did not raise") self.assertEqual( str(raised.exception), "did not receive a valid HTTP response", ) self.assertIsInstance(raised.exception.__cause__, EOFError) self.assertEqual( str(raised.exception.__cause__), "connection closed while reading HTTP status line", ) def test_http_response(self): """Client reads HTTP response.""" def http_response(connection, request): return connection.respond(http.HTTPStatus.OK, "👌") with run_server(process_request=http_response) as server: with self.assertRaises(InvalidStatus) as raised: with connect(get_uri(server)): self.fail("did not raise") self.assertEqual(raised.exception.response.status_code, 200) self.assertEqual(raised.exception.response.body.decode(), "👌") def test_http_response_without_content_length(self): """Client reads HTTP response without a Content-Length header.""" def 
http_response(connection, request): response = connection.respond(http.HTTPStatus.OK, "👌") del response.headers["Content-Length"] return response with run_server(process_request=http_response) as server: with self.assertRaises(InvalidStatus) as raised: with connect(get_uri(server)): self.fail("did not raise") self.assertEqual(raised.exception.response.status_code, 200) self.assertEqual(raised.exception.response.body.decode(), "👌") def test_junk_handshake(self): """Client closes the connection when receiving non-HTTP response from server.""" class JunkHandler(socketserver.BaseRequestHandler): def handle(self): time.sleep(MS) # wait for the client to send the handshake request self.request.send(b"220 smtp.invalid ESMTP Postfix\r\n") self.request.recv(4096) # wait for the client to close the connection self.request.close() server = socketserver.TCPServer(("localhost", 0), JunkHandler) host, port = server.server_address with server: thread = threading.Thread(target=server.serve_forever, args=(MS,)) thread.start() try: with self.assertRaises(InvalidMessage) as raised: with connect(f"ws://{host}:{port}"): self.fail("did not raise") self.assertEqual( str(raised.exception), "did not receive a valid HTTP response", ) self.assertIsInstance(raised.exception.__cause__, ValueError) self.assertEqual( str(raised.exception.__cause__), "unsupported protocol; expected HTTP/1.1: " "220 smtp.invalid ESMTP Postfix", ) finally: server.shutdown() thread.join() class SecureClientTests(unittest.TestCase): def test_connection(self): """Client connects to server securely.""" with run_server(ssl=SERVER_CONTEXT) as server: with connect(get_uri(server), ssl=CLIENT_CONTEXT) as client: self.assertEqual(client.protocol.state.name, "OPEN") self.assertEqual(client.socket.version()[:3], "TLS") def test_set_server_hostname_implicitly(self): """Client sets server_hostname to the host in the WebSocket URI.""" with temp_unix_socket_path() as path: with run_unix_server(path, ssl=SERVER_CONTEXT): with 
unix_connect( path, ssl=CLIENT_CONTEXT, uri="wss://overridden/" ) as client: self.assertEqual(client.socket.server_hostname, "overridden") def test_set_server_hostname_explicitly(self): """Client sets server_hostname to the value provided in argument.""" with temp_unix_socket_path() as path: with run_unix_server(path, ssl=SERVER_CONTEXT): with unix_connect( path, ssl=CLIENT_CONTEXT, server_hostname="overridden" ) as client: self.assertEqual(client.socket.server_hostname, "overridden") def test_reject_invalid_server_certificate(self): """Client rejects certificate where server certificate isn't trusted.""" with run_server(ssl=SERVER_CONTEXT) as server: with self.assertRaises(ssl.SSLCertVerificationError) as raised: # The test certificate is self-signed. with connect(get_uri(server)): self.fail("did not raise") self.assertIn( "certificate verify failed: self signed certificate", str(raised.exception).replace("-", " "), ) def test_reject_invalid_server_hostname(self): """Client rejects certificate where server hostname doesn't match.""" with run_server(ssl=SERVER_CONTEXT) as server: with self.assertRaises(ssl.SSLCertVerificationError) as raised: # This hostname isn't included in the test certificate. 
with connect( get_uri(server), ssl=CLIENT_CONTEXT, server_hostname="invalid" ): self.fail("did not raise") self.assertIn( "certificate verify failed: Hostname mismatch", str(raised.exception), ) @unittest.skipUnless("mitmproxy" in sys.modules, "mitmproxy not installed") class SocksProxyClientTests(ProxyMixin, unittest.TestCase): proxy_mode = "socks5@51080" @patch.dict(os.environ, {"socks_proxy": "http://localhost:51080"}) def test_socks_proxy(self): """Client connects to server through a SOCKS5 proxy.""" with run_server() as server: with connect(get_uri(server)) as client: self.assertEqual(client.protocol.state.name, "OPEN") self.assertNumFlows(1) @patch.dict(os.environ, {"socks_proxy": "http://localhost:51080"}) def test_secure_socks_proxy(self): """Client connects to server securely through a SOCKS5 proxy.""" with run_server(ssl=SERVER_CONTEXT) as server: with connect(get_uri(server), ssl=CLIENT_CONTEXT) as client: self.assertEqual(client.protocol.state.name, "OPEN") self.assertNumFlows(1) @patch.dict(os.environ, {"socks_proxy": "http://hello:iloveyou@localhost:51080"}) def test_authenticated_socks_proxy(self): """Client connects to server through an authenticated SOCKS5 proxy.""" try: self.proxy_options.update(proxyauth="hello:iloveyou") with run_server() as server: with connect(get_uri(server)) as client: self.assertEqual(client.protocol.state.name, "OPEN") finally: self.proxy_options.update(proxyauth=None) self.assertNumFlows(1) @patch.dict(os.environ, {"socks_proxy": "http://localhost:51080"}) def test_authenticated_socks_proxy_error(self): """Client fails to authenticate to the SOCKS5 proxy.""" from python_socks import ProxyError as SocksProxyError try: self.proxy_options.update(proxyauth="any") with self.assertRaises(ProxyError) as raised: with connect("ws://example.com/"): self.fail("did not raise") finally: self.proxy_options.update(proxyauth=None) self.assertEqual( str(raised.exception), "failed to connect to SOCKS proxy", ) 
self.assertIsInstance(raised.exception.__cause__, SocksProxyError) self.assertNumFlows(0) @patch.dict(os.environ, {"socks_proxy": "http://localhost:61080"}) # bad port def test_socks_proxy_connection_failure(self): """Client fails to connect to the SOCKS5 proxy.""" from python_socks import ProxyConnectionError as SocksProxyConnectionError with self.assertRaises(OSError) as raised: with connect("ws://example.com/"): self.fail("did not raise") # Don't test str(raised.exception) because we don't control it. self.assertIsInstance(raised.exception, SocksProxyConnectionError) self.assertNumFlows(0) def test_socks_proxy_connection_timeout(self): """Client times out while connecting to the SOCKS5 proxy.""" from python_socks import ProxyTimeoutError as SocksProxyTimeoutError # Replace the proxy with a TCP server that doesn't respond. with socket.create_server(("localhost", 0)) as sock: host, port = sock.getsockname() with patch.dict(os.environ, {"socks_proxy": f"http://{host}:{port}"}): with self.assertRaises(TimeoutError) as raised: with connect("ws://example.com/", open_timeout=MS): self.fail("did not raise") # Don't test str(raised.exception) because we don't control it. self.assertIsInstance(raised.exception, SocksProxyTimeoutError) self.assertNumFlows(0) def test_explicit_socks_proxy(self): """Client connects to server through a SOCKS5 proxy set explicitly.""" with run_server() as server: with connect( get_uri(server), # Take this opportunity to test socks5 instead of socks5h. proxy="socks5://localhost:51080", ) as client: self.assertEqual(client.protocol.state.name, "OPEN") self.assertNumFlows(1) @patch.dict(os.environ, {"ws_proxy": "http://localhost:58080"}) def test_ignore_proxy_with_existing_socket(self): """Client connects using a pre-existing socket.""" with run_server() as server: with socket.create_connection(server.socket.getsockname()) as sock: # Use a non-existing domain to ensure we connect to sock. 
with connect("ws://invalid/", sock=sock) as client: self.assertEqual(client.protocol.state.name, "OPEN") self.assertNumFlows(0) @unittest.skipUnless("mitmproxy" in sys.modules, "mitmproxy not installed") class HTTPProxyClientTests(ProxyMixin, unittest.IsolatedAsyncioTestCase): proxy_mode = "regular@58080" @patch.dict(os.environ, {"https_proxy": "http://localhost:58080"}) def test_http_proxy(self): """Client connects to server through an HTTP proxy.""" with run_server() as server: with connect(get_uri(server)) as client: self.assertEqual(client.protocol.state.name, "OPEN") self.assertNumFlows(1) @patch.dict(os.environ, {"https_proxy": "http://localhost:58080"}) def test_secure_http_proxy(self): """Client connects to server securely through an HTTP proxy.""" with run_server(ssl=SERVER_CONTEXT) as server: with connect(get_uri(server), ssl=CLIENT_CONTEXT) as client: self.assertEqual(client.protocol.state.name, "OPEN") self.assertEqual(client.socket.version()[:3], "TLS") self.assertNumFlows(1) @patch.dict(os.environ, {"https_proxy": "http://hello:iloveyou@localhost:58080"}) def test_authenticated_http_proxy(self): """Client connects to server through an authenticated HTTP proxy.""" try: self.proxy_options.update(proxyauth="hello:iloveyou") with run_server() as server: with connect(get_uri(server)) as client: self.assertEqual(client.protocol.state.name, "OPEN") finally: self.proxy_options.update(proxyauth=None) self.assertNumFlows(1) @patch.dict(os.environ, {"https_proxy": "http://localhost:58080"}) def test_authenticated_http_proxy_error(self): """Client fails to authenticate to the HTTP proxy.""" try: self.proxy_options.update(proxyauth="any") with self.assertRaises(ProxyError) as raised: with connect("ws://example.com/"): self.fail("did not raise") finally: self.proxy_options.update(proxyauth=None) self.assertEqual( str(raised.exception), "proxy rejected connection: HTTP 407", ) self.assertNumFlows(0) @patch.dict(os.environ, {"https_proxy": "http://localhost:58080"}) 
def test_http_proxy_override_user_agent(self): """Client can override User-Agent header with user_agent_header.""" with run_server() as server: with connect(get_uri(server), user_agent_header="Smith") as client: self.assertEqual(client.protocol.state.name, "OPEN") [http_connect] = self.get_http_connects() self.assertEqual(http_connect.request.headers[b"User-Agent"], "Smith") @patch.dict(os.environ, {"https_proxy": "http://localhost:58080"}) def test_http_proxy_remove_user_agent(self): """Client can remove User-Agent header with user_agent_header.""" with run_server() as server: with connect(get_uri(server), user_agent_header=None) as client: self.assertEqual(client.protocol.state.name, "OPEN") [http_connect] = self.get_http_connects() self.assertNotIn(b"User-Agent", http_connect.request.headers) @patch.dict(os.environ, {"https_proxy": "http://localhost:58080"}) def test_http_proxy_protocol_error(self): """Client receives invalid data when connecting to the HTTP proxy.""" try: self.proxy_options.update(break_http_connect=True) with self.assertRaises(InvalidProxyMessage) as raised: with connect("ws://example.com/"): self.fail("did not raise") finally: self.proxy_options.update(break_http_connect=False) self.assertEqual( str(raised.exception), "did not receive a valid HTTP response from proxy", ) self.assertNumFlows(0) @patch.dict(os.environ, {"https_proxy": "http://localhost:58080"}) def test_http_proxy_connection_error(self): """Client receives no response when connecting to the HTTP proxy.""" try: self.proxy_options.update(close_http_connect=True) with self.assertRaises(InvalidProxyMessage) as raised: with connect("ws://example.com/"): self.fail("did not raise") finally: self.proxy_options.update(close_http_connect=False) self.assertEqual( str(raised.exception), "did not receive a valid HTTP response from proxy", ) self.assertNumFlows(0) @patch.dict(os.environ, {"https_proxy": "http://localhost:48080"}) # bad port def test_http_proxy_connection_failure(self): 
"""Client fails to connect to the HTTP proxy.""" with self.assertRaises(OSError): with connect("ws://example.com/"): self.fail("did not raise") # Don't test str(raised.exception) because we don't control it. self.assertNumFlows(0) def test_http_proxy_connection_timeout(self): """Client times out while connecting to the HTTP proxy.""" # Replace the proxy with a TCP server that does't respond. with socket.create_server(("localhost", 0)) as sock: host, port = sock.getsockname() with patch.dict(os.environ, {"https_proxy": f"http://{host}:{port}"}): with self.assertRaises(TimeoutError) as raised: with connect("ws://example.com/", open_timeout=MS): self.fail("did not raise") self.assertEqual( str(raised.exception), "timed out while connecting to HTTP proxy", ) @patch.dict(os.environ, {"https_proxy": "https://localhost:58080"}) def test_https_proxy(self): """Client connects to server through an HTTPS proxy.""" with run_server() as server: with connect( get_uri(server), proxy_ssl=self.proxy_context, ) as client: self.assertEqual(client.protocol.state.name, "OPEN") self.assertNumFlows(1) @patch.dict(os.environ, {"https_proxy": "https://localhost:58080"}) def test_secure_https_proxy(self): """Client connects to server securely through an HTTPS proxy.""" with run_server(ssl=SERVER_CONTEXT) as server: with connect( get_uri(server), ssl=CLIENT_CONTEXT, proxy_ssl=self.proxy_context, ) as client: self.assertEqual(client.protocol.state.name, "OPEN") self.assertEqual(client.socket.version()[:3], "TLS") self.assertNumFlows(1) @patch.dict(os.environ, {"https_proxy": "https://localhost:58080"}) def test_https_proxy_server_hostname(self): """Client sets server_hostname to the value of proxy_server_hostname.""" with run_server() as server: # Pass an argument not prefixed with proxy_ for coverage. 
kwargs = {"all_errors": True} if sys.version_info >= (3, 11) else {} with connect( get_uri(server), proxy_ssl=self.proxy_context, proxy_server_hostname="overridden", **kwargs, ) as client: self.assertEqual(client.socket.server_hostname, "overridden") self.assertNumFlows(1) @patch.dict(os.environ, {"https_proxy": "https://localhost:58080"}) def test_https_proxy_invalid_proxy_certificate(self): """Client rejects certificate when proxy certificate isn't trusted.""" with self.assertRaises(ssl.SSLCertVerificationError) as raised: # The proxy certificate isn't trusted. with connect("wss://example.com/"): self.fail("did not raise") self.assertIn( "certificate verify failed: unable to get local issuer certificate", str(raised.exception), ) self.assertNumFlows(0) @patch.dict(os.environ, {"https_proxy": "https://localhost:58080"}) def test_https_proxy_invalid_server_certificate(self): """Client rejects certificate when server certificate isn't trusted.""" with run_server(ssl=SERVER_CONTEXT) as server: with self.assertRaises(ssl.SSLCertVerificationError) as raised: # The test certificate is self-signed. with connect(get_uri(server), proxy_ssl=self.proxy_context): self.fail("did not raise") self.assertIn( "certificate verify failed: self signed certificate", str(raised.exception).replace("-", " "), ) self.assertNumFlows(1) @unittest.skipUnless(hasattr(socket, "AF_UNIX"), "this test requires Unix sockets") class UnixClientTests(unittest.TestCase): def test_connection(self): """Client connects to server over a Unix socket.""" with temp_unix_socket_path() as path: with run_unix_server(path): with unix_connect(path) as client: self.assertEqual(client.protocol.state.name, "OPEN") def test_set_host_header(self): """Client sets the Host header to the host in the WebSocket URI.""" # This is part of the documented behavior of unix_connect(). 
with temp_unix_socket_path() as path: with run_unix_server(path): with unix_connect(path, uri="ws://overridden/") as client: self.assertEqual(client.request.headers["Host"], "overridden") def test_secure_connection(self): """Client connects to server securely over a Unix socket.""" with temp_unix_socket_path() as path: with run_unix_server(path, ssl=SERVER_CONTEXT): with unix_connect(path, ssl=CLIENT_CONTEXT) as client: self.assertEqual(client.protocol.state.name, "OPEN") self.assertEqual(client.socket.version()[:3], "TLS") def test_set_server_hostname(self): """Client sets server_hostname to the host in the WebSocket URI.""" # This is part of the documented behavior of unix_connect(). with temp_unix_socket_path() as path: with run_unix_server(path, ssl=SERVER_CONTEXT): with unix_connect( path, ssl=CLIENT_CONTEXT, uri="wss://overridden/" ) as client: self.assertEqual(client.socket.server_hostname, "overridden") class ClientUsageErrorsTests(unittest.TestCase): def test_ssl_without_secure_uri(self): """Client rejects ssl when URI isn't secure.""" with self.assertRaises(ValueError) as raised: connect("ws://localhost/", ssl=CLIENT_CONTEXT) self.assertEqual( str(raised.exception), "ssl argument is incompatible with a ws:// URI", ) def test_proxy_ssl_without_https_proxy(self): """Client rejects proxy_ssl when proxy isn't HTTPS.""" with self.assertRaises(ValueError) as raised: connect( "ws://localhost/", proxy="http://localhost:8080", proxy_ssl=True, ) self.assertEqual( str(raised.exception), "proxy_ssl argument is incompatible with an http:// proxy", ) def test_unix_without_path_or_sock(self): """Unix client requires path when sock isn't provided.""" with self.assertRaises(ValueError) as raised: unix_connect() self.assertEqual( str(raised.exception), "missing path argument", ) def test_unsupported_proxy(self): """Client rejects unsupported proxy.""" with self.assertRaises(InvalidProxy) as raised: with connect("ws://example.com/", proxy="other://localhost:58080"): 
self.fail("did not raise") self.assertEqual( str(raised.exception), "other://localhost:58080 isn't a valid proxy: scheme other isn't supported", ) def test_unix_with_path_and_sock(self): """Unix client rejects path when sock is provided.""" sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) self.addCleanup(sock.close) with self.assertRaises(ValueError) as raised: unix_connect(path="/", sock=sock) self.assertEqual( str(raised.exception), "path and sock arguments are incompatible", ) def test_invalid_subprotocol(self): """Client rejects single value of subprotocols.""" with self.assertRaises(TypeError) as raised: connect("ws://localhost/", subprotocols="chat") self.assertEqual( str(raised.exception), "subprotocols must be a list, not a str", ) def test_unsupported_compression(self): """Client rejects incorrect value of compression.""" with self.assertRaises(ValueError) as raised: connect("ws://localhost/", compression=False) self.assertEqual( str(raised.exception), "unsupported compression: False", ) class BackwardsCompatibilityTests(DeprecationTestCase): def test_ssl_context_argument(self): """Client supports the deprecated ssl_context argument.""" with run_server(ssl=SERVER_CONTEXT) as server: with self.assertDeprecationWarning("ssl_context was renamed to ssl"): with connect(get_uri(server), ssl_context=CLIENT_CONTEXT): pass websockets-15.0.1/tests/sync/test_connection.py000066400000000000000000001153511476212450300216470ustar00rootroot00000000000000import contextlib import logging import socket import sys import threading import time import unittest import uuid from unittest.mock import patch from websockets.exceptions import ( ConcurrencyError, ConnectionClosedError, ConnectionClosedOK, ) from websockets.frames import CloseCode, Frame, Opcode from websockets.protocol import CLIENT, SERVER, Protocol, State from websockets.sync.connection import * from ..protocol import RecordingProtocol from ..utils import MS from .connection import InterceptingConnection # 
Connection implements symmetrical behavior between clients and servers. # All tests run on the client side and the server side to validate this. class ClientConnectionTests(unittest.TestCase): LOCAL = CLIENT REMOTE = SERVER def setUp(self): socket_, remote_socket = socket.socketpair() protocol = Protocol(self.LOCAL) remote_protocol = RecordingProtocol(self.REMOTE) self.connection = Connection(socket_, protocol, close_timeout=2 * MS) self.remote_connection = InterceptingConnection(remote_socket, remote_protocol) def tearDown(self): self.remote_connection.close() self.connection.close() # Test helpers built upon RecordingProtocol and InterceptingConnection. def assertFrameSent(self, frame): """Check that a single frame was sent.""" time.sleep(MS) # let the remote side process messages self.assertEqual(self.remote_connection.protocol.get_frames_rcvd(), [frame]) def assertNoFrameSent(self): """Check that no frame was sent.""" time.sleep(MS) # let the remote side process messages self.assertEqual(self.remote_connection.protocol.get_frames_rcvd(), []) @contextlib.contextmanager def delay_frames_rcvd(self, delay): """Delay frames before they're received by the connection.""" with self.remote_connection.delay_frames_sent(delay): yield time.sleep(MS) # let the remote side process messages @contextlib.contextmanager def delay_eof_rcvd(self, delay): """Delay EOF before it's received by the connection.""" with self.remote_connection.delay_eof_sent(delay): yield time.sleep(MS) # let the remote side process messages @contextlib.contextmanager def drop_frames_rcvd(self): """Drop frames before they're received by the connection.""" with self.remote_connection.drop_frames_sent(): yield time.sleep(MS) # let the remote side process messages @contextlib.contextmanager def drop_eof_rcvd(self): """Drop EOF before it's received by the connection.""" with self.remote_connection.drop_eof_sent(): yield time.sleep(MS) # let the remote side process messages # Test __enter__ and __exit__. 
def test_enter(self): """__enter__ returns the connection itself.""" with self.connection as connection: self.assertIs(connection, self.connection) def test_exit(self): """__exit__ closes the connection with code 1000.""" with self.connection: self.assertNoFrameSent() self.assertFrameSent(Frame(Opcode.CLOSE, b"\x03\xe8")) def test_exit_with_exception(self): """__exit__ with an exception closes the connection with code 1011.""" with self.assertRaises(RuntimeError): with self.connection: raise RuntimeError self.assertFrameSent(Frame(Opcode.CLOSE, b"\x03\xf3")) # Test __iter__. def test_iter_text(self): """__iter__ yields text messages.""" iterator = iter(self.connection) self.remote_connection.send("😀") self.assertEqual(next(iterator), "😀") self.remote_connection.send("😀") self.assertEqual(next(iterator), "😀") def test_iter_binary(self): """__iter__ yields binary messages.""" iterator = iter(self.connection) self.remote_connection.send(b"\x01\x02\xfe\xff") self.assertEqual(next(iterator), b"\x01\x02\xfe\xff") self.remote_connection.send(b"\x01\x02\xfe\xff") self.assertEqual(next(iterator), b"\x01\x02\xfe\xff") def test_iter_mixed(self): """__iter__ yields a mix of text and binary messages.""" iterator = iter(self.connection) self.remote_connection.send("😀") self.assertEqual(next(iterator), "😀") self.remote_connection.send(b"\x01\x02\xfe\xff") self.assertEqual(next(iterator), b"\x01\x02\xfe\xff") def test_iter_connection_closed_ok(self): """__iter__ terminates after a normal closure.""" iterator = iter(self.connection) self.remote_connection.close() with self.assertRaises(StopIteration): next(iterator) def test_iter_connection_closed_error(self): """__iter__ raises ConnectionClosedError after an error.""" iterator = iter(self.connection) self.remote_connection.close(code=CloseCode.INTERNAL_ERROR) with self.assertRaises(ConnectionClosedError): next(iterator) # Test recv. 
def test_recv_text(self): """recv receives a text message.""" self.remote_connection.send("😀") self.assertEqual(self.connection.recv(), "😀") def test_recv_binary(self): """recv receives a binary message.""" self.remote_connection.send(b"\x01\x02\xfe\xff") self.assertEqual(self.connection.recv(), b"\x01\x02\xfe\xff") def test_recv_text_as_bytes(self): """recv receives a text message as bytes.""" self.remote_connection.send("😀") self.assertEqual(self.connection.recv(decode=False), "😀".encode()) def test_recv_binary_as_text(self): """recv receives a binary message as a str.""" self.remote_connection.send("😀".encode()) self.assertEqual(self.connection.recv(decode=True), "😀") def test_recv_fragmented_text(self): """recv receives a fragmented text message.""" self.remote_connection.send(["😀", "😀"]) self.assertEqual(self.connection.recv(), "😀😀") def test_recv_fragmented_binary(self): """recv receives a fragmented binary message.""" self.remote_connection.send([b"\x01\x02", b"\xfe\xff"]) self.assertEqual(self.connection.recv(), b"\x01\x02\xfe\xff") def test_recv_connection_closed_ok(self): """recv raises ConnectionClosedOK after a normal closure.""" self.remote_connection.close() with self.assertRaises(ConnectionClosedOK): self.connection.recv() def test_recv_connection_closed_error(self): """recv raises ConnectionClosedError after an error.""" self.remote_connection.close(code=CloseCode.INTERNAL_ERROR) with self.assertRaises(ConnectionClosedError): self.connection.recv() def test_recv_non_utf8_text(self): """recv receives a non-UTF-8 text message.""" self.remote_connection.send(b"\x01\x02\xfe\xff", text=True) with self.assertRaises(ConnectionClosedError): self.connection.recv() self.assertFrameSent( Frame(Opcode.CLOSE, b"\x03\xefinvalid start byte at position 2") ) def test_recv_during_recv(self): """recv raises ConcurrencyError when called concurrently.""" recv_thread = threading.Thread(target=self.connection.recv) recv_thread.start() with 
self.assertRaises(ConcurrencyError) as raised: self.connection.recv() self.assertEqual( str(raised.exception), "cannot call recv while another thread " "is already running recv or recv_streaming", ) self.remote_connection.send("") recv_thread.join() def test_recv_during_recv_streaming(self): """recv raises ConcurrencyError when called concurrently with recv_streaming.""" recv_streaming_thread = threading.Thread( target=lambda: list(self.connection.recv_streaming()) ) recv_streaming_thread.start() with self.assertRaises(ConcurrencyError) as raised: self.connection.recv() self.assertEqual( str(raised.exception), "cannot call recv while another thread " "is already running recv or recv_streaming", ) self.remote_connection.send("") recv_streaming_thread.join() # Test recv_streaming. def test_recv_streaming_text(self): """recv_streaming receives a text message.""" self.remote_connection.send("😀") self.assertEqual( list(self.connection.recv_streaming()), ["😀"], ) def test_recv_streaming_binary(self): """recv_streaming receives a binary message.""" self.remote_connection.send(b"\x01\x02\xfe\xff") self.assertEqual( list(self.connection.recv_streaming()), [b"\x01\x02\xfe\xff"], ) def test_recv_streaming_text_as_bytes(self): """recv_streaming receives a text message as bytes.""" self.remote_connection.send("😀") self.assertEqual( list(self.connection.recv_streaming(decode=False)), ["😀".encode()], ) def test_recv_streaming_binary_as_str(self): """recv_streaming receives a binary message as a str.""" self.remote_connection.send("😀".encode()) self.assertEqual( list(self.connection.recv_streaming(decode=True)), ["😀"], ) def test_recv_streaming_fragmented_text(self): """recv_streaming receives a fragmented text message.""" self.remote_connection.send(["😀", "😀"]) # websockets sends an trailing empty fragment. That's an implementation detail. 
self.assertEqual( list(self.connection.recv_streaming()), ["😀", "😀", ""], ) def test_recv_streaming_fragmented_binary(self): """recv_streaming receives a fragmented binary message.""" self.remote_connection.send([b"\x01\x02", b"\xfe\xff"]) # websockets sends an trailing empty fragment. That's an implementation detail. self.assertEqual( list(self.connection.recv_streaming()), [b"\x01\x02", b"\xfe\xff", b""], ) def test_recv_streaming_connection_closed_ok(self): """recv_streaming raises ConnectionClosedOK after a normal closure.""" self.remote_connection.close() with self.assertRaises(ConnectionClosedOK): for _ in self.connection.recv_streaming(): self.fail("did not raise") def test_recv_streaming_connection_closed_error(self): """recv_streaming raises ConnectionClosedError after an error.""" self.remote_connection.close(code=CloseCode.INTERNAL_ERROR) with self.assertRaises(ConnectionClosedError): for _ in self.connection.recv_streaming(): self.fail("did not raise") def test_recv_streaming_non_utf8_text(self): """recv_streaming receives a non-UTF-8 text message.""" self.remote_connection.send(b"\x01\x02\xfe\xff", text=True) with self.assertRaises(ConnectionClosedError): list(self.connection.recv_streaming()) self.assertFrameSent( Frame(Opcode.CLOSE, b"\x03\xefinvalid start byte at position 2") ) def test_recv_streaming_during_recv(self): """recv_streaming raises ConcurrencyError when called concurrently with recv.""" recv_thread = threading.Thread(target=self.connection.recv) recv_thread.start() with self.assertRaises(ConcurrencyError) as raised: for _ in self.connection.recv_streaming(): self.fail("did not raise") self.assertEqual( str(raised.exception), "cannot call recv_streaming while another thread " "is already running recv or recv_streaming", ) self.remote_connection.send("") recv_thread.join() def test_recv_streaming_during_recv_streaming(self): """recv_streaming raises ConcurrencyError when called concurrently.""" recv_streaming_thread = threading.Thread( 
target=lambda: list(self.connection.recv_streaming()) ) recv_streaming_thread.start() with self.assertRaises(ConcurrencyError) as raised: for _ in self.connection.recv_streaming(): self.fail("did not raise") self.assertEqual( str(raised.exception), r"cannot call recv_streaming while another thread " r"is already running recv or recv_streaming", ) self.remote_connection.send("") recv_streaming_thread.join() # Test send. def test_send_text(self): """send sends a text message.""" self.connection.send("😀") self.assertEqual(self.remote_connection.recv(), "😀") def test_send_binary(self): """send sends a binary message.""" self.connection.send(b"\x01\x02\xfe\xff") self.assertEqual(self.remote_connection.recv(), b"\x01\x02\xfe\xff") def test_send_binary_from_str(self): """send sends a binary message from a str.""" self.connection.send("😀", text=False) self.assertEqual(self.remote_connection.recv(), "😀".encode()) def test_send_text_from_bytes(self): """send sends a text message from bytes.""" self.connection.send("😀".encode(), text=True) self.assertEqual(self.remote_connection.recv(), "😀") def test_send_fragmented_text(self): """send sends a fragmented text message.""" self.connection.send(["😀", "😀"]) # websockets sends an trailing empty fragment. That's an implementation detail. self.assertEqual( list(self.remote_connection.recv_streaming()), ["😀", "😀", ""], ) def test_send_fragmented_binary(self): """send sends a fragmented binary message.""" self.connection.send([b"\x01\x02", b"\xfe\xff"]) # websockets sends an trailing empty fragment. That's an implementation detail. self.assertEqual( list(self.remote_connection.recv_streaming()), [b"\x01\x02", b"\xfe\xff", b""], ) def test_send_fragmented_binary_from_str(self): """send sends a fragmented binary message from a str.""" self.connection.send(["😀", "😀"], text=False) # websockets sends an trailing empty fragment. That's an implementation detail. 
self.assertEqual( list(self.remote_connection.recv_streaming()), ["😀".encode(), "😀".encode(), b""], ) def test_send_fragmented_text_from_bytes(self): """send sends a fragmented text message from bytes.""" self.connection.send(["😀".encode(), "😀".encode()], text=True) # websockets sends an trailing empty fragment. That's an implementation detail. self.assertEqual( list(self.remote_connection.recv_streaming()), ["😀", "😀", ""], ) def test_send_connection_closed_ok(self): """send raises ConnectionClosedOK after a normal closure.""" self.remote_connection.close() with self.assertRaises(ConnectionClosedOK): self.connection.send("😀") def test_send_connection_closed_error(self): """send raises ConnectionClosedError after an error.""" self.remote_connection.close(code=CloseCode.INTERNAL_ERROR) with self.assertRaises(ConnectionClosedError): self.connection.send("😀") def test_send_during_send(self): """send raises ConcurrencyError when called concurrently.""" recv_thread = threading.Thread(target=self.remote_connection.recv) recv_thread.start() send_gate = threading.Event() exit_gate = threading.Event() def fragments(): yield "😀" send_gate.set() exit_gate.wait() yield "😀" send_thread = threading.Thread( target=self.connection.send, args=(fragments(),), ) send_thread.start() send_gate.wait() # The check happens in four code paths, depending on the argument. 
for message in [ "😀", b"\x01\x02\xfe\xff", ["😀", "😀"], [b"\x01\x02", b"\xfe\xff"], ]: with self.subTest(message=message): with self.assertRaises(ConcurrencyError) as raised: self.connection.send(message) self.assertEqual( str(raised.exception), "cannot call send while another thread is already running send", ) exit_gate.set() send_thread.join() recv_thread.join() def test_send_empty_iterable(self): """send does nothing when called with an empty iterable.""" self.connection.send([]) self.connection.close() self.assertEqual(list(self.remote_connection), []) def test_send_mixed_iterable(self): """send raises TypeError when called with an iterable of inconsistent types.""" with self.assertRaises(TypeError): self.connection.send(["😀", b"\xfe\xff"]) def test_send_unsupported_iterable(self): """send raises TypeError when called with an iterable of unsupported type.""" with self.assertRaises(TypeError): self.connection.send([None]) def test_send_dict(self): """send raises TypeError when called with a dict.""" with self.assertRaises(TypeError): self.connection.send({"type": "object"}) def test_send_unsupported_type(self): """send raises TypeError when called with an unsupported type.""" with self.assertRaises(TypeError): self.connection.send(None) # Test close. 
def test_close(self): """close sends a close frame.""" self.connection.close() self.assertFrameSent(Frame(Opcode.CLOSE, b"\x03\xe8")) def test_close_explicit_code_reason(self): """close sends a close frame with a given code and reason.""" self.connection.close(CloseCode.GOING_AWAY, "bye!") self.assertFrameSent(Frame(Opcode.CLOSE, b"\x03\xe9bye!")) def test_close_waits_for_close_frame(self): """close waits for a close frame (then EOF) before returning.""" with self.delay_frames_rcvd(MS): self.connection.close() with self.assertRaises(ConnectionClosedOK) as raised: self.connection.recv() exc = raised.exception self.assertEqual(str(exc), "sent 1000 (OK); then received 1000 (OK)") self.assertIsNone(exc.__cause__) def test_close_waits_for_connection_closed(self): """close waits for EOF before returning.""" if self.LOCAL is SERVER: self.skipTest("only relevant on the client-side") with self.delay_eof_rcvd(MS): self.connection.close() with self.assertRaises(ConnectionClosedOK) as raised: self.connection.recv() exc = raised.exception self.assertEqual(str(exc), "sent 1000 (OK); then received 1000 (OK)") self.assertIsNone(exc.__cause__) def test_close_timeout_waiting_for_close_frame(self): """close times out if no close frame is received.""" with self.drop_frames_rcvd(), self.drop_eof_rcvd(): self.connection.close() with self.assertRaises(ConnectionClosedError) as raised: self.connection.recv() exc = raised.exception self.assertEqual(str(exc), "sent 1000 (OK); no close frame received") self.assertIsInstance(exc.__cause__, TimeoutError) def test_close_timeout_waiting_for_connection_closed(self): """close times out if EOF isn't received.""" if self.LOCAL is SERVER: self.skipTest("only relevant on the client-side") with self.drop_eof_rcvd(): self.connection.close() with self.assertRaises(ConnectionClosedOK) as raised: self.connection.recv() exc = raised.exception self.assertEqual(str(exc), "sent 1000 (OK); then received 1000 (OK)") # Remove socket.timeout when dropping Python < 
3.10. self.assertIsInstance(exc.__cause__, (socket.timeout, TimeoutError)) def test_close_preserves_queued_messages(self): """close preserves messages buffered in the assembler.""" self.remote_connection.send("😀") self.connection.close() self.assertEqual(self.connection.recv(), "😀") with self.assertRaises(ConnectionClosedOK) as raised: self.connection.recv() exc = raised.exception self.assertEqual(str(exc), "sent 1000 (OK); then received 1000 (OK)") self.assertIsNone(exc.__cause__) def test_close_idempotency(self): """close does nothing if the connection is already closed.""" self.connection.close() self.assertFrameSent(Frame(Opcode.CLOSE, b"\x03\xe8")) self.connection.close() self.assertNoFrameSent() def test_close_idempotency_race_condition(self): """close waits if the connection is already closing.""" self.connection.close_timeout = 6 * MS def closer(): with self.delay_frames_rcvd(4 * MS): self.connection.close() close_thread = threading.Thread(target=closer) close_thread.start() # Let closer() initiate the closing handshake and send a close frame. time.sleep(MS) self.assertFrameSent(Frame(Opcode.CLOSE, b"\x03\xe8")) # Connection isn't closed yet. with self.assertRaises(TimeoutError): self.connection.recv(timeout=MS) self.connection.close() self.assertNoFrameSent() # Connection is closed now. 
with self.assertRaises(ConnectionClosedOK): self.connection.recv(timeout=MS) close_thread.join() def test_close_during_recv(self): """close aborts recv when called concurrently with recv.""" def closer(): time.sleep(MS) self.connection.close() close_thread = threading.Thread(target=closer) close_thread.start() with self.assertRaises(ConnectionClosedOK) as raised: self.connection.recv() exc = raised.exception self.assertEqual(str(exc), "sent 1000 (OK); then received 1000 (OK)") self.assertIsNone(exc.__cause__) close_thread.join() def test_close_during_send(self): """close fails the connection when called concurrently with send.""" close_gate = threading.Event() exit_gate = threading.Event() def closer(): close_gate.wait() self.connection.close() exit_gate.set() def fragments(): yield "😀" close_gate.set() exit_gate.wait() yield "😀" close_thread = threading.Thread(target=closer) close_thread.start() with self.assertRaises(ConnectionClosedError) as raised: self.connection.send(fragments()) exc = raised.exception self.assertEqual( str(exc), "sent 1011 (internal error) close during fragmented message; " "no close frame received", ) self.assertIsNone(exc.__cause__) close_thread.join() # Test ping. 
@patch("random.getrandbits", return_value=1918987876) def test_ping(self, getrandbits): """ping sends a ping frame with a random payload.""" self.connection.ping() getrandbits.assert_called_once_with(32) self.assertFrameSent(Frame(Opcode.PING, b"rand")) def test_ping_explicit_text(self): """ping sends a ping frame with a payload provided as text.""" self.connection.ping("ping") self.assertFrameSent(Frame(Opcode.PING, b"ping")) def test_ping_explicit_binary(self): """ping sends a ping frame with a payload provided as binary.""" self.connection.ping(b"ping") self.assertFrameSent(Frame(Opcode.PING, b"ping")) def test_acknowledge_ping(self): """ping is acknowledged by a pong with the same payload.""" with self.drop_frames_rcvd(): # drop automatic response to ping pong_waiter = self.connection.ping("this") self.remote_connection.pong("this") self.assertTrue(pong_waiter.wait(MS)) def test_acknowledge_ping_non_matching_pong(self): """ping isn't acknowledged by a pong with a different payload.""" with self.drop_frames_rcvd(): # drop automatic response to ping pong_waiter = self.connection.ping("this") self.remote_connection.pong("that") self.assertFalse(pong_waiter.wait(MS)) def test_acknowledge_previous_ping(self): """ping is acknowledged by a pong for as a later ping.""" with self.drop_frames_rcvd(): # drop automatic response to ping pong_waiter = self.connection.ping("this") self.connection.ping("that") self.remote_connection.pong("that") self.assertTrue(pong_waiter.wait(MS)) def test_acknowledge_ping_on_close(self): """ping with ack_on_close is acknowledged when the connection is closed.""" with self.drop_frames_rcvd(): # drop automatic response to ping pong_waiter_ack_on_close = self.connection.ping("this", ack_on_close=True) pong_waiter = self.connection.ping("that") self.connection.close() self.assertTrue(pong_waiter_ack_on_close.wait(MS)) self.assertFalse(pong_waiter.wait(MS)) def test_ping_duplicate_payload(self): """ping rejects the same payload until receiving 
the pong.""" with self.drop_frames_rcvd(): # drop automatic response to ping pong_waiter = self.connection.ping("idem") with self.assertRaises(ConcurrencyError) as raised: self.connection.ping("idem") self.assertEqual( str(raised.exception), "already waiting for a pong with the same data", ) self.remote_connection.pong("idem") self.assertTrue(pong_waiter.wait(MS)) self.connection.ping("idem") # doesn't raise an exception def test_ping_unsupported_type(self): """ping raises TypeError when called with an unsupported type.""" with self.assertRaises(TypeError): self.connection.ping([]) # Test pong. def test_pong(self): """pong sends a pong frame.""" self.connection.pong() self.assertFrameSent(Frame(Opcode.PONG, b"")) def test_pong_explicit_text(self): """pong sends a pong frame with a payload provided as text.""" self.connection.pong("pong") self.assertFrameSent(Frame(Opcode.PONG, b"pong")) def test_pong_explicit_binary(self): """pong sends a pong frame with a payload provided as binary.""" self.connection.pong(b"pong") self.assertFrameSent(Frame(Opcode.PONG, b"pong")) def test_pong_unsupported_type(self): """pong raises TypeError when called with an unsupported type.""" with self.assertRaises(TypeError): self.connection.pong([]) # Test keepalive. @patch("random.getrandbits", return_value=1918987876) def test_keepalive(self, getrandbits): """keepalive sends pings at ping_interval and measures latency.""" self.connection.ping_interval = 4 * MS self.connection.start_keepalive() self.assertIsNotNone(self.connection.keepalive_thread) self.assertEqual(self.connection.latency, 0) # 3 ms: keepalive() sends a ping frame. # 3.x ms: a pong frame is received. time.sleep(4 * MS) # 4 ms: check that the ping frame was sent. 
self.assertFrameSent(Frame(Opcode.PING, b"rand")) self.assertGreater(self.connection.latency, 0) self.assertLess(self.connection.latency, MS) def test_disable_keepalive(self): """keepalive is disabled when ping_interval is None.""" self.connection.ping_interval = None self.connection.start_keepalive() self.assertIsNone(self.connection.keepalive_thread) @patch("random.getrandbits", return_value=1918987876) def test_keepalive_times_out(self, getrandbits): """keepalive closes the connection if ping_timeout elapses.""" self.connection.ping_interval = 4 * MS self.connection.ping_timeout = 2 * MS with self.drop_frames_rcvd(): self.connection.start_keepalive() # 4 ms: keepalive() sends a ping frame. time.sleep(4 * MS) # Exiting the context manager sleeps for 1 ms. # 4.x ms: a pong frame is dropped. # 6 ms: no pong frame is received; the connection is closed. time.sleep(2 * MS) # 7 ms: check that the connection is closed. self.assertEqual(self.connection.state, State.CLOSED) @patch("random.getrandbits", return_value=1918987876) def test_keepalive_ignores_timeout(self, getrandbits): """keepalive ignores timeouts if ping_timeout isn't set.""" self.connection.ping_interval = 4 * MS self.connection.ping_timeout = None with self.drop_frames_rcvd(): self.connection.start_keepalive() # 4 ms: keepalive() sends a ping frame. time.sleep(4 * MS) # Exiting the context manager sleeps for 1 ms. # 4.x ms: a pong frame is dropped. # 6 ms: no pong frame is received; the connection remains open. time.sleep(2 * MS) # 7 ms: check that the connection is still open. 
self.assertEqual(self.connection.state, State.OPEN) def test_keepalive_terminates_while_sleeping(self): """keepalive task terminates while waiting to send a ping.""" self.connection.ping_interval = 3 * MS self.connection.start_keepalive() time.sleep(MS) self.connection.close() self.connection.keepalive_thread.join(MS) self.assertFalse(self.connection.keepalive_thread.is_alive()) def test_keepalive_terminates_when_sending_ping_fails(self): """keepalive task terminates when sending a ping fails.""" self.connection.ping_interval = 1 * MS self.connection.start_keepalive() with self.drop_eof_rcvd(), self.drop_frames_rcvd(): self.connection.close() self.assertFalse(self.connection.keepalive_thread.is_alive()) def test_keepalive_terminates_while_waiting_for_pong(self): """keepalive task terminates while waiting to receive a pong.""" self.connection.ping_interval = MS self.connection.ping_timeout = 4 * MS with self.drop_frames_rcvd(): self.connection.start_keepalive() # 1 ms: keepalive() sends a ping frame. # 1.x ms: a pong frame is dropped. time.sleep(MS) # Exiting the context manager sleeps for 1 ms. # 2 ms: close the connection before ping_timeout elapses. self.connection.close() self.connection.keepalive_thread.join(MS) self.assertFalse(self.connection.keepalive_thread.is_alive()) def test_keepalive_reports_errors(self): """keepalive reports unexpected errors in logs.""" self.connection.ping_interval = 2 * MS with self.drop_frames_rcvd(): self.connection.start_keepalive() # 2 ms: keepalive() sends a ping frame. # 2.x ms: a pong frame is dropped. with self.assertLogs("websockets", logging.ERROR) as logs: with patch("threading.Event.wait", side_effect=Exception("BOOM")): time.sleep(3 * MS) # Exiting the context manager sleeps for 1 ms. self.assertEqual( [record.getMessage() for record in logs.records], ["keepalive ping failed"], ) self.assertEqual( [str(record.exc_info[1]) for record in logs.records], ["BOOM"], ) # Test parameters. 
def test_close_timeout(self): """close_timeout parameter configures close timeout.""" socket_, remote_socket = socket.socketpair() self.addCleanup(socket_.close) self.addCleanup(remote_socket.close) connection = Connection( socket_, Protocol(self.LOCAL), close_timeout=42 * MS, ) self.assertEqual(connection.close_timeout, 42 * MS) def test_max_queue(self): """max_queue configures high-water mark of frames buffer.""" socket_, remote_socket = socket.socketpair() self.addCleanup(socket_.close) self.addCleanup(remote_socket.close) connection = Connection( socket_, Protocol(self.LOCAL), max_queue=4, ) self.assertEqual(connection.recv_messages.high, 4) def test_max_queue_none(self): """max_queue disables high-water mark of frames buffer.""" socket_, remote_socket = socket.socketpair() self.addCleanup(socket_.close) self.addCleanup(remote_socket.close) connection = Connection( socket_, Protocol(self.LOCAL), max_queue=None, ) self.assertEqual(connection.recv_messages.high, None) self.assertEqual(connection.recv_messages.high, None) def test_max_queue_tuple(self): """max_queue configures high-water and low-water marks of frames buffer.""" socket_, remote_socket = socket.socketpair() self.addCleanup(socket_.close) self.addCleanup(remote_socket.close) connection = Connection( socket_, Protocol(self.LOCAL), max_queue=(4, 2), ) self.assertEqual(connection.recv_messages.high, 4) self.assertEqual(connection.recv_messages.low, 2) # Test attributes. 
    def test_id(self):
        """Connection has an id attribute."""
        self.assertIsInstance(self.connection.id, uuid.UUID)

    def test_logger(self):
        """Connection has a logger attribute."""
        self.assertIsInstance(self.connection.logger, logging.LoggerAdapter)

    @patch("socket.socket.getsockname", return_value=("sock", 1234))
    def test_local_address(self, getsockname):
        """Connection provides a local_address attribute."""
        self.assertEqual(self.connection.local_address, ("sock", 1234))
        getsockname.assert_called_with()

    @patch("socket.socket.getpeername", return_value=("peer", 1234))
    def test_remote_address(self, getpeername):
        """Connection provides a remote_address attribute."""
        self.assertEqual(self.connection.remote_address, ("peer", 1234))
        getpeername.assert_called_with()

    def test_state(self):
        """Connection has a state attribute."""
        self.assertIs(self.connection.state, State.OPEN)

    def test_request(self):
        """Connection has a request attribute."""
        # No handshake is performed in these tests, hence no request.
        self.assertIsNone(self.connection.request)

    def test_response(self):
        """Connection has a response attribute."""
        # No handshake is performed in these tests, hence no response.
        self.assertIsNone(self.connection.response)

    def test_subprotocol(self):
        """Connection has a subprotocol attribute."""
        self.assertIsNone(self.connection.subprotocol)

    def test_close_code(self):
        """Connection has a close_code attribute."""
        # None while the connection is open; set once it is closed.
        self.assertIsNone(self.connection.close_code)

    def test_close_reason(self):
        """Connection has a close_reason attribute."""
        # None while the connection is open; set once it is closed.
        self.assertIsNone(self.connection.close_reason)

    # Test reporting of network errors.

    @unittest.skipUnless(sys.platform == "darwin", "works only on BSD")
    def test_reading_in_recv_events_fails(self):
        """Error when reading incoming frames is correctly reported."""
        # Inject a fault by closing the socket. This works only on BSD.
        # I cannot find a way to achieve the same effect on Linux.
        self.connection.socket.close()
        # The connection closed exception reports the injected fault.
        with self.assertRaises(ConnectionClosedError) as raised:
            self.connection.recv()
        self.assertIsInstance(raised.exception.__cause__, IOError)

    def test_writing_in_recv_events_fails(self):
        """Error when responding to incoming frames is correctly reported."""
        # Inject a fault by shutting down the socket for writing — but not by
        # closing it because that would terminate the connection.
        self.connection.socket.shutdown(socket.SHUT_WR)
        # Receive a ping. Responding with a pong will fail.
        self.remote_connection.ping()
        # The connection closed exception reports the injected fault.
        with self.assertRaises(ConnectionClosedError) as raised:
            self.connection.recv()
        self.assertIsInstance(raised.exception.__cause__, BrokenPipeError)

    def test_writing_in_send_context_fails(self):
        """Error when sending outgoing frame is correctly reported."""
        # Inject a fault by shutting down the socket for writing — but not by
        # closing it because that would terminate the connection.
        self.connection.socket.shutdown(socket.SHUT_WR)
        # Sending a pong will fail.
        # The connection closed exception reports the injected fault.
        with self.assertRaises(ConnectionClosedError) as raised:
            self.connection.pong()
        self.assertIsInstance(raised.exception.__cause__, BrokenPipeError)

    # Test safety nets — catching all exceptions in case of bugs.

    # Inject a fault in a random call in recv_events().
    # This test is tightly coupled to the implementation.
    @patch("websockets.protocol.Protocol.events_received", side_effect=AssertionError)
    def test_unexpected_failure_in_recv_events(self, events_received):
        """Unexpected internal error in recv_events() is correctly reported."""
        # Receive a message to trigger the fault.
        self.remote_connection.send("😀")
        with self.assertRaises(ConnectionClosedError) as raised:
            self.connection.recv()
        exc = raised.exception
        self.assertEqual(str(exc), "no close frame received or sent")
        self.assertIsInstance(exc.__cause__, AssertionError)

    # Inject a fault in a random call in send_context().
# This test is tightly coupled to the implementation. @patch("websockets.protocol.Protocol.send_text", side_effect=AssertionError) def test_unexpected_failure_in_send_context(self, send_text): """Unexpected internal error in send_context() is correctly reported.""" # Send a message to trigger the fault. # The connection closed exception reports the injected fault. with self.assertRaises(ConnectionClosedError) as raised: self.connection.send("😀") exc = raised.exception self.assertEqual(str(exc), "no close frame received or sent") self.assertIsInstance(exc.__cause__, AssertionError) class ServerConnectionTests(ClientConnectionTests): LOCAL = SERVER REMOTE = CLIENT websockets-15.0.1/tests/sync/test_messages.py000066400000000000000000000552611476212450300213220ustar00rootroot00000000000000import time import unittest import unittest.mock from websockets.exceptions import ConcurrencyError from websockets.frames import OP_BINARY, OP_CONT, OP_TEXT, Frame from websockets.sync.messages import * from ..utils import MS from .utils import ThreadTestCase class AssemblerTests(ThreadTestCase): def setUp(self): self.pause = unittest.mock.Mock() self.resume = unittest.mock.Mock() self.assembler = Assembler(high=2, low=1, pause=self.pause, resume=self.resume) # Test get def test_get_text_message_already_received(self): """get returns a text message that is already received.""" self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) message = self.assembler.get() self.assertEqual(message, "café") def test_get_binary_message_already_received(self): """get returns a binary message that is already received.""" self.assembler.put(Frame(OP_BINARY, b"tea")) message = self.assembler.get() self.assertEqual(message, b"tea") def test_get_text_message_not_received_yet(self): """get returns a text message when it is received.""" message = None def getter(): nonlocal message message = self.assembler.get() with self.run_in_thread(getter): self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) 
self.assertEqual(message, "café") def test_get_binary_message_not_received_yet(self): """get returns a binary message when it is received.""" message = None def getter(): nonlocal message message = self.assembler.get() with self.run_in_thread(getter): self.assembler.put(Frame(OP_BINARY, b"tea")) self.assertEqual(message, b"tea") def test_get_fragmented_text_message_already_received(self): """get reassembles a fragmented a text message that is already received.""" self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) self.assembler.put(Frame(OP_CONT, b"\xa9")) message = self.assembler.get() self.assertEqual(message, "café") def test_get_fragmented_binary_message_already_received(self): """get reassembles a fragmented binary message that is already received.""" self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) self.assembler.put(Frame(OP_CONT, b"e", fin=False)) self.assembler.put(Frame(OP_CONT, b"a")) message = self.assembler.get() self.assertEqual(message, b"tea") def test_get_fragmented_text_message_not_received_yet(self): """get reassembles a fragmented text message when it is received.""" message = None def getter(): nonlocal message message = self.assembler.get() with self.run_in_thread(getter): self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) self.assembler.put(Frame(OP_CONT, b"\xa9")) self.assertEqual(message, "café") def test_get_fragmented_binary_message_not_received_yet(self): """get reassembles a fragmented binary message when it is received.""" message = None def getter(): nonlocal message message = self.assembler.get() with self.run_in_thread(getter): self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) self.assembler.put(Frame(OP_CONT, b"e", fin=False)) self.assembler.put(Frame(OP_CONT, b"a")) self.assertEqual(message, b"tea") def test_get_fragmented_text_message_being_received(self): """get reassembles a fragmented text message that 
is partially received.""" message = None def getter(): nonlocal message message = self.assembler.get() self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) with self.run_in_thread(getter): self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) self.assembler.put(Frame(OP_CONT, b"\xa9")) self.assertEqual(message, "café") def test_get_fragmented_binary_message_being_received(self): """get reassembles a fragmented binary message that is partially received.""" message = None def getter(): nonlocal message message = self.assembler.get() self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) with self.run_in_thread(getter): self.assembler.put(Frame(OP_CONT, b"e", fin=False)) self.assembler.put(Frame(OP_CONT, b"a")) self.assertEqual(message, b"tea") def test_get_encoded_text_message(self): """get returns a text message without UTF-8 decoding.""" self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) message = self.assembler.get(decode=False) self.assertEqual(message, b"caf\xc3\xa9") def test_get_decoded_binary_message(self): """get returns a binary message with UTF-8 decoding.""" self.assembler.put(Frame(OP_BINARY, b"tea")) message = self.assembler.get(decode=True) self.assertEqual(message, "tea") def test_get_resumes_reading(self): """get resumes reading when queue goes below the low-water mark.""" self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) self.assembler.put(Frame(OP_TEXT, b"more caf\xc3\xa9")) self.assembler.put(Frame(OP_TEXT, b"water")) # queue is above the low-water mark self.assembler.get() self.resume.assert_not_called() # queue is at the low-water mark self.assembler.get() self.resume.assert_called_once_with() # queue is below the low-water mark self.assembler.get() self.resume.assert_called_once_with() def test_get_does_not_resume_reading(self): """get does not resume reading when the low-water mark is unset.""" self.assembler.low = None self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) self.assembler.put(Frame(OP_TEXT, b"more caf\xc3\xa9")) 
self.assembler.put(Frame(OP_TEXT, b"water")) self.assembler.get() self.assembler.get() self.assembler.get() self.resume.assert_not_called() def test_get_timeout_before_first_frame(self): """get times out before reading the first frame.""" with self.assertRaises(TimeoutError): self.assembler.get(timeout=MS) self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) message = self.assembler.get() self.assertEqual(message, "café") def test_get_timeout_after_first_frame(self): """get times out after reading the first frame.""" self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) with self.assertRaises(TimeoutError): self.assembler.get(timeout=MS) self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) self.assembler.put(Frame(OP_CONT, b"\xa9")) message = self.assembler.get() self.assertEqual(message, "café") def test_get_timeout_0_message_already_received(self): """get(timeout=0) returns a message that is already received.""" self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) message = self.assembler.get(timeout=0) self.assertEqual(message, "café") def test_get_timeout_0_message_not_received_yet(self): """get(timeout=0) times out when no message is already received.""" with self.assertRaises(TimeoutError): self.assembler.get(timeout=0) def test_get_timeout_0_fragmented_message_already_received(self): """get(timeout=0) returns a fragmented message that is already received.""" self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) self.assembler.put(Frame(OP_CONT, b"\xa9")) message = self.assembler.get(timeout=0) self.assertEqual(message, "café") def test_get_timeout_0_fragmented_message_partially_received(self): """get(timeout=0) times out when a fragmented message is partially received.""" self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) with self.assertRaises(TimeoutError): self.assembler.get(timeout=0) # Test get_iter def 
test_get_iter_text_message_already_received(self): """get_iter yields a text message that is already received.""" self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) fragments = list(self.assembler.get_iter()) self.assertEqual(fragments, ["café"]) def test_get_iter_binary_message_already_received(self): """get_iter yields a binary message that is already received.""" self.assembler.put(Frame(OP_BINARY, b"tea")) fragments = list(self.assembler.get_iter()) self.assertEqual(fragments, [b"tea"]) def test_get_iter_text_message_not_received_yet(self): """get_iter yields a text message when it is received.""" fragments = [] def getter(): nonlocal fragments for fragment in self.assembler.get_iter(): fragments.append(fragment) with self.run_in_thread(getter): self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) self.assertEqual(fragments, ["café"]) def test_get_iter_binary_message_not_received_yet(self): """get_iter yields a binary message when it is received.""" fragments = [] def getter(): nonlocal fragments for fragment in self.assembler.get_iter(): fragments.append(fragment) with self.run_in_thread(getter): self.assembler.put(Frame(OP_BINARY, b"tea")) self.assertEqual(fragments, [b"tea"]) def test_get_iter_fragmented_text_message_already_received(self): """get_iter yields a fragmented text message that is already received.""" self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) self.assembler.put(Frame(OP_CONT, b"\xa9")) fragments = list(self.assembler.get_iter()) self.assertEqual(fragments, ["ca", "f", "é"]) def test_get_iter_fragmented_binary_message_already_received(self): """get_iter yields a fragmented binary message that is already received.""" self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) self.assembler.put(Frame(OP_CONT, b"e", fin=False)) self.assembler.put(Frame(OP_CONT, b"a")) fragments = list(self.assembler.get_iter()) self.assertEqual(fragments, [b"t", b"e", b"a"]) def 
test_get_iter_fragmented_text_message_not_received_yet(self): """get_iter yields a fragmented text message when it is received.""" iterator = self.assembler.get_iter() self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) self.assertEqual(next(iterator), "ca") self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) self.assertEqual(next(iterator), "f") self.assembler.put(Frame(OP_CONT, b"\xa9")) self.assertEqual(next(iterator), "é") def test_get_iter_fragmented_binary_message_not_received_yet(self): """get_iter yields a fragmented binary message when it is received.""" iterator = self.assembler.get_iter() self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) self.assertEqual(next(iterator), b"t") self.assembler.put(Frame(OP_CONT, b"e", fin=False)) self.assertEqual(next(iterator), b"e") self.assembler.put(Frame(OP_CONT, b"a")) self.assertEqual(next(iterator), b"a") def test_get_iter_fragmented_text_message_being_received(self): """get_iter yields a fragmented text message that is partially received.""" self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) iterator = self.assembler.get_iter() self.assertEqual(next(iterator), "ca") self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) self.assertEqual(next(iterator), "f") self.assembler.put(Frame(OP_CONT, b"\xa9")) self.assertEqual(next(iterator), "é") def test_get_iter_fragmented_binary_message_being_received(self): """get_iter yields a fragmented binary message that is partially received.""" self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) iterator = self.assembler.get_iter() self.assertEqual(next(iterator), b"t") self.assembler.put(Frame(OP_CONT, b"e", fin=False)) self.assertEqual(next(iterator), b"e") self.assembler.put(Frame(OP_CONT, b"a")) self.assertEqual(next(iterator), b"a") def test_get_iter_encoded_text_message(self): """get_iter yields a text message without UTF-8 decoding.""" self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) 
self.assembler.put(Frame(OP_CONT, b"\xa9")) fragments = list(self.assembler.get_iter(decode=False)) self.assertEqual(fragments, [b"ca", b"f\xc3", b"\xa9"]) def test_get_iter_decoded_binary_message(self): """get_iter yields a binary message with UTF-8 decoding.""" self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) self.assembler.put(Frame(OP_CONT, b"e", fin=False)) self.assembler.put(Frame(OP_CONT, b"a")) fragments = list(self.assembler.get_iter(decode=True)) self.assertEqual(fragments, ["t", "e", "a"]) def test_get_iter_resumes_reading(self): """get_iter resumes reading when queue goes below the low-water mark.""" self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) self.assembler.put(Frame(OP_CONT, b"e", fin=False)) self.assembler.put(Frame(OP_CONT, b"a")) iterator = self.assembler.get_iter() # queue is above the low-water mark next(iterator) self.resume.assert_not_called() # queue is at the low-water mark next(iterator) self.resume.assert_called_once_with() # queue is below the low-water mark next(iterator) self.resume.assert_called_once_with() def test_get_iter_does_not_resume_reading(self): """get_iter does not resume reading when the low-water mark is unset.""" self.assembler.low = None self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) self.assembler.put(Frame(OP_CONT, b"e", fin=False)) self.assembler.put(Frame(OP_CONT, b"a")) iterator = self.assembler.get_iter() next(iterator) next(iterator) next(iterator) self.resume.assert_not_called() # Test put def test_put_pauses_reading(self): """put pauses reading when queue goes above the high-water mark.""" # queue is below the high-water mark self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) self.pause.assert_not_called() # queue is at the high-water mark self.assembler.put(Frame(OP_CONT, b"e", fin=False)) self.pause.assert_called_once_with() # queue is above the high-water mark self.assembler.put(Frame(OP_CONT, b"a")) 
self.pause.assert_called_once_with() def test_put_does_not_pause_reading(self): """put does not pause reading when the high-water mark is unset.""" self.assembler.high = None self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) self.assembler.put(Frame(OP_CONT, b"e", fin=False)) self.assembler.put(Frame(OP_CONT, b"a")) self.pause.assert_not_called() # Test termination def test_get_fails_when_interrupted_by_close(self): """get raises EOFError when close is called.""" def closer(): time.sleep(2 * MS) self.assembler.close() with self.run_in_thread(closer): with self.assertRaises(EOFError): self.assembler.get() def test_get_iter_fails_when_interrupted_by_close(self): """get_iter raises EOFError when close is called.""" def closer(): time.sleep(2 * MS) self.assembler.close() with self.run_in_thread(closer): with self.assertRaises(EOFError): for _ in self.assembler.get_iter(): self.fail("no fragment expected") def test_get_fails_after_close(self): """get raises EOFError after close is called.""" self.assembler.close() with self.assertRaises(EOFError): self.assembler.get() def test_get_iter_fails_after_close(self): """get_iter raises EOFError after close is called.""" self.assembler.close() with self.assertRaises(EOFError): for _ in self.assembler.get_iter(): self.fail("no fragment expected") def test_get_queued_message_after_close(self): """get returns a message after close is called.""" self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) self.assembler.close() message = self.assembler.get() self.assertEqual(message, "café") def test_get_iter_queued_message_after_close(self): """get_iter yields a message after close is called.""" self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) self.assembler.close() fragments = list(self.assembler.get_iter()) self.assertEqual(fragments, ["café"]) def test_get_queued_fragmented_message_after_close(self): """get reassembles a fragmented message after close is called.""" 
self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) self.assembler.put(Frame(OP_CONT, b"e", fin=False)) self.assembler.put(Frame(OP_CONT, b"a")) self.assembler.close() self.assembler.close() message = self.assembler.get() self.assertEqual(message, b"tea") def test_get_iter_queued_fragmented_message_after_close(self): """get_iter yields a fragmented message after close is called.""" self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) self.assembler.put(Frame(OP_CONT, b"e", fin=False)) self.assembler.put(Frame(OP_CONT, b"a")) self.assembler.close() fragments = list(self.assembler.get_iter()) self.assertEqual(fragments, [b"t", b"e", b"a"]) def test_get_partially_queued_fragmented_message_after_close(self): """get raises EOF on a partial fragmented message after close is called.""" self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) self.assembler.put(Frame(OP_CONT, b"e", fin=False)) self.assembler.close() with self.assertRaises(EOFError): self.assembler.get() def test_get_iter_partially_queued_fragmented_message_after_close(self): """get_iter yields a partial fragmented message after close is called.""" self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) self.assembler.put(Frame(OP_CONT, b"e", fin=False)) self.assembler.close() fragments = [] with self.assertRaises(EOFError): for fragment in self.assembler.get_iter(): fragments.append(fragment) self.assertEqual(fragments, [b"t", b"e"]) def test_put_fails_after_close(self): """put raises EOFError after close is called.""" self.assembler.close() with self.assertRaises(EOFError): self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) def test_close_resumes_reading(self): """close unblocks reading when queue is above the high-water mark.""" self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) self.assembler.put(Frame(OP_TEXT, b"more caf\xc3\xa9")) self.assembler.put(Frame(OP_TEXT, b"water")) # queue is at the high-water mark assert self.assembler.paused self.assembler.close() self.resume.assert_called_once_with() def 
test_close_is_idempotent(self): """close can be called multiple times safely.""" self.assembler.close() self.assembler.close() # Test (non-)concurrency def test_get_fails_when_get_is_running(self): """get cannot be called concurrently.""" with self.run_in_thread(self.assembler.get): with self.assertRaises(ConcurrencyError): self.assembler.get() self.assembler.put(Frame(OP_TEXT, b"")) # unlock other thread def test_get_fails_when_get_iter_is_running(self): """get cannot be called concurrently with get_iter.""" with self.run_in_thread(lambda: list(self.assembler.get_iter())): with self.assertRaises(ConcurrencyError): self.assembler.get() self.assembler.put(Frame(OP_TEXT, b"")) # unlock other thread def test_get_iter_fails_when_get_is_running(self): """get_iter cannot be called concurrently with get.""" with self.run_in_thread(self.assembler.get): with self.assertRaises(ConcurrencyError): list(self.assembler.get_iter()) self.assembler.put(Frame(OP_TEXT, b"")) # unlock other thread def test_get_iter_fails_when_get_iter_is_running(self): """get_iter cannot be called concurrently.""" with self.run_in_thread(lambda: list(self.assembler.get_iter())): with self.assertRaises(ConcurrencyError): list(self.assembler.get_iter()) self.assembler.put(Frame(OP_TEXT, b"")) # unlock other thread # Test setting limits def test_set_high_water_mark(self): """high sets the high-water and low-water marks.""" assembler = Assembler(high=10) self.assertEqual(assembler.high, 10) self.assertEqual(assembler.low, 2) def test_set_low_water_mark(self): """low sets the low-water and high-water marks.""" assembler = Assembler(low=5) self.assertEqual(assembler.low, 5) self.assertEqual(assembler.high, 20) def test_set_high_and_low_water_marks(self): """high and low set the high-water and low-water marks.""" assembler = Assembler(high=10, low=5) self.assertEqual(assembler.high, 10) self.assertEqual(assembler.low, 5) def test_unset_high_and_low_water_marks(self): """High-water and low-water marks are 
unset.""" assembler = Assembler() self.assertEqual(assembler.high, None) self.assertEqual(assembler.low, None) def test_set_invalid_high_water_mark(self): """high must be a non-negative integer.""" with self.assertRaises(ValueError): Assembler(high=-1) def test_set_invalid_low_water_mark(self): """low must be higher than high.""" with self.assertRaises(ValueError): Assembler(low=10, high=5) websockets-15.0.1/tests/sync/test_router.py000066400000000000000000000154541476212450300210330ustar00rootroot00000000000000import http import socket import sys import unittest from unittest.mock import patch from websockets.exceptions import InvalidStatus from websockets.sync.client import connect, unix_connect from websockets.sync.router import * from ..utils import CLIENT_CONTEXT, SERVER_CONTEXT, temp_unix_socket_path from .server import EvalShellMixin, get_uri, handler, run_router, run_unix_router try: from werkzeug.routing import Map, Rule except ImportError: pass def echo(websocket, count): message = websocket.recv() for _ in range(count): websocket.send(message) @unittest.skipUnless("werkzeug" in sys.modules, "werkzeug not installed") class RouterTests(EvalShellMixin, unittest.TestCase): # This is a small realistic example of werkzeug's basic URL routing # features: path matching, parameter extraction, and default values. 
def test_router_matches_paths_and_extracts_parameters(self): """Router matches paths and extracts parameters.""" url_map = Map( [ Rule("/echo", defaults={"count": 1}, endpoint=echo), Rule("/echo/", endpoint=echo), ] ) with run_router(url_map) as server: with connect(get_uri(server) + "/echo") as client: client.send("hello") messages = list(client) self.assertEqual(messages, ["hello"]) with connect(get_uri(server) + "/echo/3") as client: client.send("hello") messages = list(client) self.assertEqual(messages, ["hello", "hello", "hello"]) @property # avoids an import-time dependency on werkzeug def url_map(self): return Map( [ Rule("/", endpoint=handler), Rule("/r", redirect_to="/"), ] ) def test_route_with_query_string(self): """Router ignores query strings when matching paths.""" with run_router(self.url_map) as server: with connect(get_uri(server) + "/?a=b") as client: self.assertEval(client, "ws.request.path", "/?a=b") def test_redirect(self): """Router redirects connections according to redirect_to.""" with run_router(self.url_map, server_name="localhost") as server: with self.assertRaises(InvalidStatus) as raised: with connect(get_uri(server) + "/r"): self.fail("did not raise") self.assertEqual( raised.exception.response.headers["Location"], "ws://localhost/", ) def test_secure_redirect(self): """Router redirects connections to a wss:// URI when TLS is enabled.""" with run_router( self.url_map, server_name="localhost", ssl=SERVER_CONTEXT ) as server: with self.assertRaises(InvalidStatus) as raised: with connect(get_uri(server) + "/r", ssl=CLIENT_CONTEXT): self.fail("did not raise") self.assertEqual( raised.exception.response.headers["Location"], "wss://localhost/", ) @patch("websockets.asyncio.client.connect.process_redirect", lambda _, exc: exc) def test_force_secure_redirect(self): """Router redirects ws:// connections to a wss:// URI when ssl=True.""" with run_router(self.url_map, ssl=True) as server: redirect_uri = get_uri(server, secure=True) with 
self.assertRaises(InvalidStatus) as raised: with connect(get_uri(server) + "/r"): self.fail("did not raise") self.assertEqual( raised.exception.response.headers["Location"], redirect_uri + "/", ) @patch("websockets.asyncio.client.connect.process_redirect", lambda _, exc: exc) def test_force_redirect_server_name(self): """Router redirects connections to the host declared in server_name.""" with run_router(self.url_map, server_name="other") as server: with self.assertRaises(InvalidStatus) as raised: with connect(get_uri(server) + "/r"): self.fail("did not raise") self.assertEqual( raised.exception.response.headers["Location"], "ws://other/", ) def test_not_found(self): """Router rejects requests to unknown paths with an HTTP 404 error.""" with run_router(self.url_map) as server: with self.assertRaises(InvalidStatus) as raised: with connect(get_uri(server) + "/n"): self.fail("did not raise") self.assertEqual( str(raised.exception), "server rejected WebSocket connection: HTTP 404", ) def test_process_request_returning_none(self): """Router supports a process_request returning None.""" def process_request(ws, request): ws.process_request_ran = True with run_router(self.url_map, process_request=process_request) as server: with connect(get_uri(server) + "/") as client: self.assertEval(client, "ws.process_request_ran", "True") def test_process_request_returning_response(self): """Router supports a process_request returning a response.""" def process_request(ws, request): return ws.respond(http.HTTPStatus.FORBIDDEN, "Forbidden") with run_router(self.url_map, process_request=process_request) as server: with self.assertRaises(InvalidStatus) as raised: with connect(get_uri(server) + "/"): self.fail("did not raise") self.assertEqual( str(raised.exception), "server rejected WebSocket connection: HTTP 403", ) def test_custom_router_factory(self): """Router supports a custom router factory.""" class MyRouter(Router): def handler(self, connection): connection.my_router_ran = True 
return super().handler(connection) with run_router(self.url_map, create_router=MyRouter) as server: with connect(get_uri(server)) as client: self.assertEval(client, "ws.my_router_ran", "True") @unittest.skipUnless(hasattr(socket, "AF_UNIX"), "this test requires Unix sockets") class UnixRouterTests(EvalShellMixin, unittest.IsolatedAsyncioTestCase): def test_router_supports_unix_sockets(self): """Router supports Unix sockets.""" url_map = Map([Rule("/echo/", endpoint=echo)]) with temp_unix_socket_path() as path: with run_unix_router(path, url_map): with unix_connect(path, "ws://localhost/echo/3") as client: client.send("hello") messages = list(client) self.assertEqual(messages, ["hello", "hello", "hello"]) websockets-15.0.1/tests/sync/test_server.py000066400000000000000000000557671476212450300210340ustar00rootroot00000000000000import dataclasses import hmac import http import logging import socket import time import unittest from websockets.exceptions import ( ConnectionClosedError, ConnectionClosedOK, InvalidStatus, NegotiationError, ) from websockets.http11 import Request, Response from websockets.sync.client import connect, unix_connect from websockets.sync.server import * from ..utils import ( CLIENT_CONTEXT, MS, SERVER_CONTEXT, DeprecationTestCase, temp_unix_socket_path, ) from .server import ( EvalShellMixin, get_uri, handler, run_server, run_unix_server, ) class ServerTests(EvalShellMixin, unittest.TestCase): def test_connection(self): """Server receives connection from client and the handshake succeeds.""" with run_server() as server: with connect(get_uri(server)) as client: self.assertEval(client, "ws.protocol.state.name", "OPEN") def test_connection_handler_returns(self): """Connection handler returns.""" with run_server() as server: with connect(get_uri(server) + "/no-op") as client: with self.assertRaises(ConnectionClosedOK) as raised: client.recv() self.assertEqual( str(raised.exception), "received 1000 (OK); then sent 1000 (OK)", ) def 
test_connection_handler_raises_exception(self): """Connection handler raises an exception.""" with run_server() as server: with connect(get_uri(server) + "/crash") as client: with self.assertRaises(ConnectionClosedError) as raised: client.recv() self.assertEqual( str(raised.exception), "received 1011 (internal error); then sent 1011 (internal error)", ) def test_existing_socket(self): """Server receives connection using a pre-existing socket.""" with socket.create_server(("localhost", 0)) as sock: host, port = sock.getsockname() with run_server(sock=sock): with connect(f"ws://{host}:{port}/") as client: self.assertEval(client, "ws.protocol.state.name", "OPEN") def test_select_subprotocol(self): """Server selects a subprotocol with the select_subprotocol callable.""" def select_subprotocol(ws, subprotocols): ws.select_subprotocol_ran = True assert "chat" in subprotocols return "chat" with run_server( subprotocols=["chat"], select_subprotocol=select_subprotocol, ) as server: with connect(get_uri(server), subprotocols=["chat"]) as client: self.assertEval(client, "ws.select_subprotocol_ran", "True") self.assertEval(client, "ws.subprotocol", "chat") def test_select_subprotocol_rejects_handshake(self): """Server rejects handshake if select_subprotocol raises NegotiationError.""" def select_subprotocol(ws, subprotocols): raise NegotiationError with run_server(select_subprotocol=select_subprotocol) as server: with self.assertRaises(InvalidStatus) as raised: with connect(get_uri(server)): self.fail("did not raise") self.assertEqual( str(raised.exception), "server rejected WebSocket connection: HTTP 400", ) def test_select_subprotocol_raises_exception(self): """Server returns an error if select_subprotocol raises an exception.""" def select_subprotocol(ws, subprotocols): raise RuntimeError with run_server(select_subprotocol=select_subprotocol) as server: with self.assertRaises(InvalidStatus) as raised: with connect(get_uri(server)): self.fail("did not raise") 
self.assertEqual( str(raised.exception), "server rejected WebSocket connection: HTTP 500", ) def test_compression_is_enabled(self): """Server enables compression by default.""" with run_server() as server: with connect(get_uri(server)) as client: self.assertEval( client, "[type(ext).__name__ for ext in ws.protocol.extensions]", "['PerMessageDeflate']", ) def test_disable_compression(self): """Server disables compression.""" with run_server(compression=None) as server: with connect(get_uri(server)) as client: self.assertEval(client, "ws.protocol.extensions", "[]") def test_process_request_returns_none(self): """Server runs process_request and continues the handshake.""" def process_request(ws, request): self.assertIsInstance(request, Request) ws.process_request_ran = True with run_server(process_request=process_request) as server: with connect(get_uri(server)) as client: self.assertEval(client, "ws.process_request_ran", "True") def test_process_request_returns_response(self): """Server aborts handshake if process_request returns a response.""" def process_request(ws, request): return ws.respond(http.HTTPStatus.FORBIDDEN, "Forbidden") def handler(ws): self.fail("handler must not run") with run_server(handler, process_request=process_request) as server: with self.assertRaises(InvalidStatus) as raised: with connect(get_uri(server)): self.fail("did not raise") self.assertEqual( str(raised.exception), "server rejected WebSocket connection: HTTP 403", ) def test_process_request_raises_exception(self): """Server returns an error if process_request raises an exception.""" def process_request(ws, request): raise RuntimeError with run_server(process_request=process_request) as server: with self.assertRaises(InvalidStatus) as raised: with connect(get_uri(server)): self.fail("did not raise") self.assertEqual( str(raised.exception), "server rejected WebSocket connection: HTTP 500", ) def test_process_response_returns_none(self): """Server runs process_response but keeps the 
handshake response.""" def process_response(ws, request, response): self.assertIsInstance(request, Request) self.assertIsInstance(response, Response) ws.process_response_ran = True with run_server(process_response=process_response) as server: with connect(get_uri(server)) as client: self.assertEval(client, "ws.process_response_ran", "True") def test_process_response_modifies_response(self): """Server runs process_response and modifies the handshake response.""" def process_response(ws, request, response): response.headers["X-ProcessResponse"] = "OK" with run_server(process_response=process_response) as server: with connect(get_uri(server)) as client: self.assertEqual(client.response.headers["X-ProcessResponse"], "OK") def test_process_response_replaces_response(self): """Server runs process_response and replaces the handshake response.""" def process_response(ws, request, response): headers = response.headers.copy() headers["X-ProcessResponse"] = "OK" return dataclasses.replace(response, headers=headers) with run_server(process_response=process_response) as server: with connect(get_uri(server)) as client: self.assertEqual(client.response.headers["X-ProcessResponse"], "OK") def test_process_response_raises_exception(self): """Server returns an error if process_response raises an exception.""" def process_response(ws, request, response): raise RuntimeError with run_server(process_response=process_response) as server: with self.assertRaises(InvalidStatus) as raised: with connect(get_uri(server)): self.fail("did not raise") self.assertEqual( str(raised.exception), "server rejected WebSocket connection: HTTP 500", ) def test_override_server(self): """Server can override Server header with server_header.""" with run_server(server_header="Neo") as server: with connect(get_uri(server)) as client: self.assertEval(client, "ws.response.headers['Server']", "Neo") def test_remove_server(self): """Server can remove Server header with server_header.""" with 
run_server(server_header=None) as server: with connect(get_uri(server)) as client: self.assertEval(client, "'Server' in ws.response.headers", "False") def test_keepalive_is_enabled(self): """Server enables keepalive and measures latency.""" with run_server(ping_interval=MS) as server: with connect(get_uri(server)) as client: client.send("ws.latency") latency = eval(client.recv()) self.assertEqual(latency, 0) time.sleep(2 * MS) client.send("ws.latency") latency = eval(client.recv()) self.assertGreater(latency, 0) def test_disable_keepalive(self): """Server disables keepalive.""" with run_server(ping_interval=None) as server: with connect(get_uri(server)) as client: time.sleep(2 * MS) client.send("ws.latency") latency = eval(client.recv()) self.assertEqual(latency, 0) def test_logger(self): """Server accepts a logger argument.""" logger = logging.getLogger("test") with run_server(logger=logger) as server: self.assertEqual(server.logger.name, logger.name) def test_custom_connection_factory(self): """Server runs ServerConnection factory provided in create_connection.""" def create_connection(*args, **kwargs): server = ServerConnection(*args, **kwargs) server.create_connection_ran = True return server with run_server(create_connection=create_connection) as server: with connect(get_uri(server)) as client: self.assertEval(client, "ws.create_connection_ran", "True") def test_fileno(self): """Server provides a fileno attribute.""" with run_server() as server: self.assertIsInstance(server.fileno(), int) def test_shutdown(self): """Server provides a shutdown method.""" with run_server() as server: server.shutdown() # Check that the server socket is closed. 
with self.assertRaises(OSError): server.socket.accept() def test_handshake_fails(self): """Server receives connection from client but the handshake fails.""" def remove_key_header(self, request): del request.headers["Sec-WebSocket-Key"] with run_server(process_request=remove_key_header) as server: with self.assertRaises(InvalidStatus) as raised: with connect(get_uri(server)): self.fail("did not raise") self.assertEqual( str(raised.exception), "server rejected WebSocket connection: HTTP 400", ) def test_timeout_during_handshake(self): """Server times out before receiving handshake request from client.""" with run_server(open_timeout=MS) as server: with socket.create_connection(server.socket.getsockname()) as sock: self.assertEqual(sock.recv(4096), b"") def test_connection_closed_during_handshake(self): """Server reads EOF before receiving handshake request from client.""" with run_server() as server: with socket.create_connection(server.socket.getsockname()): # Wait for the server to receive the connection, then close it. time.sleep(MS) def test_junk_handshake(self): """Server closes the connection when receiving non-HTTP request from client.""" with self.assertLogs("websockets.server", logging.ERROR) as logs: with run_server() as server: with socket.create_connection(server.socket.getsockname()) as sock: sock.send(b"HELO relay.invalid\r\n") # Wait for the server to close the connection. 
self.assertEqual(sock.recv(4096), b"") self.assertEqual( [record.getMessage() for record in logs.records], ["opening handshake failed"], ) self.assertEqual( [str(record.exc_info[1]) for record in logs.records], ["did not receive a valid HTTP request"], ) self.assertEqual( [str(record.exc_info[1].__cause__) for record in logs.records], ["invalid HTTP request line: HELO relay.invalid"], ) class SecureServerTests(EvalShellMixin, unittest.TestCase): def test_connection(self): """Server receives secure connection from client.""" with run_server(ssl=SERVER_CONTEXT) as server: with connect(get_uri(server), ssl=CLIENT_CONTEXT) as client: self.assertEval(client, "ws.protocol.state.name", "OPEN") self.assertEval(client, "ws.socket.version()[:3]", "TLS") def test_timeout_during_tls_handshake(self): """Server times out before receiving TLS handshake request from client.""" with run_server(ssl=SERVER_CONTEXT, open_timeout=MS) as server: with socket.create_connection(server.socket.getsockname()) as sock: self.assertEqual(sock.recv(4096), b"") def test_connection_closed_during_tls_handshake(self): """Server reads EOF before receiving TLS handshake request from client.""" with run_server(ssl=SERVER_CONTEXT) as server: with socket.create_connection(server.socket.getsockname()): # Wait for the server to receive the connection, then close it. 
time.sleep(MS) @unittest.skipUnless(hasattr(socket, "AF_UNIX"), "this test requires Unix sockets") class UnixServerTests(EvalShellMixin, unittest.TestCase): def test_connection(self): """Server receives connection from client over a Unix socket.""" with temp_unix_socket_path() as path: with run_unix_server(path): with unix_connect(path) as client: self.assertEval(client, "ws.protocol.state.name", "OPEN") @unittest.skipUnless(hasattr(socket, "AF_UNIX"), "this test requires Unix sockets") class SecureUnixServerTests(EvalShellMixin, unittest.TestCase): def test_connection(self): """Server receives secure connection from client over a Unix socket.""" with temp_unix_socket_path() as path: with run_unix_server(path, ssl=SERVER_CONTEXT): with unix_connect(path, ssl=CLIENT_CONTEXT) as client: self.assertEval(client, "ws.protocol.state.name", "OPEN") self.assertEval(client, "ws.socket.version()[:3]", "TLS") class ServerUsageErrorsTests(unittest.TestCase): def test_unix_without_path_or_sock(self): """Unix server requires path when sock isn't provided.""" with self.assertRaises(ValueError) as raised: unix_serve(handler) self.assertEqual( str(raised.exception), "missing path argument", ) def test_unix_with_path_and_sock(self): """Unix server rejects path when sock is provided.""" sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) self.addCleanup(sock.close) with self.assertRaises(ValueError) as raised: unix_serve(handler, path="/", sock=sock) self.assertEqual( str(raised.exception), "path and sock arguments are incompatible", ) def test_invalid_subprotocol(self): """Server rejects single value of subprotocols.""" with self.assertRaises(TypeError) as raised: serve(handler, subprotocols="chat") self.assertEqual( str(raised.exception), "subprotocols must be a list, not a str", ) def test_unsupported_compression(self): """Server rejects incorrect value of compression.""" with self.assertRaises(ValueError) as raised: serve(handler, compression=False) self.assertEqual( 
str(raised.exception), "unsupported compression: False", ) class BasicAuthTests(EvalShellMixin, unittest.IsolatedAsyncioTestCase): def test_valid_authorization(self): """basic_auth authenticates client with HTTP Basic Authentication.""" with run_server( process_request=basic_auth(credentials=("hello", "iloveyou")), ) as server: with connect( get_uri(server), additional_headers={"Authorization": "Basic aGVsbG86aWxvdmV5b3U="}, ) as client: self.assertEval(client, "ws.username", "hello") def test_missing_authorization(self): """basic_auth rejects client without credentials.""" with run_server( process_request=basic_auth(credentials=("hello", "iloveyou")), ) as server: with self.assertRaises(InvalidStatus) as raised: with connect(get_uri(server)): self.fail("did not raise") self.assertEqual( str(raised.exception), "server rejected WebSocket connection: HTTP 401", ) def test_unsupported_authorization(self): """basic_auth rejects client with unsupported credentials.""" with run_server( process_request=basic_auth(credentials=("hello", "iloveyou")), ) as server: with self.assertRaises(InvalidStatus) as raised: with connect( get_uri(server), additional_headers={"Authorization": "Negotiate ..."}, ): self.fail("did not raise") self.assertEqual( str(raised.exception), "server rejected WebSocket connection: HTTP 401", ) def test_authorization_with_unknown_username(self): """basic_auth rejects client with unknown username.""" with run_server( process_request=basic_auth(credentials=("hello", "iloveyou")), ) as server: with self.assertRaises(InvalidStatus) as raised: with connect( get_uri(server), additional_headers={"Authorization": "Basic YnllOnlvdWxvdmVtZQ=="}, ): self.fail("did not raise") self.assertEqual( str(raised.exception), "server rejected WebSocket connection: HTTP 401", ) def test_authorization_with_incorrect_password(self): """basic_auth rejects client with incorrect password.""" with run_server( process_request=basic_auth(credentials=("hello", "changeme")), ) as 
server: with self.assertRaises(InvalidStatus) as raised: with connect( get_uri(server), additional_headers={"Authorization": "Basic aGVsbG86aWxvdmV5b3U="}, ): self.fail("did not raise") self.assertEqual( str(raised.exception), "server rejected WebSocket connection: HTTP 401", ) def test_list_of_credentials(self): """basic_auth accepts a list of hard coded credentials.""" with run_server( process_request=basic_auth( credentials=[ ("hello", "iloveyou"), ("bye", "youloveme"), ] ), ) as server: with connect( get_uri(server), additional_headers={"Authorization": "Basic YnllOnlvdWxvdmVtZQ=="}, ) as client: self.assertEval(client, "ws.username", "bye") def test_check_credentials(self): """basic_auth accepts a check_credentials function.""" def check_credentials(username, password): return hmac.compare_digest(password, "iloveyou") with run_server( process_request=basic_auth(check_credentials=check_credentials), ) as server: with connect( get_uri(server), additional_headers={"Authorization": "Basic aGVsbG86aWxvdmV5b3U="}, ) as client: self.assertEval(client, "ws.username", "hello") def test_without_credentials_or_check_credentials(self): """basic_auth requires either credentials or check_credentials.""" with self.assertRaises(ValueError) as raised: basic_auth() self.assertEqual( str(raised.exception), "provide either credentials or check_credentials", ) def test_with_credentials_and_check_credentials(self): """basic_auth requires only one of credentials and check_credentials.""" with self.assertRaises(ValueError) as raised: basic_auth( credentials=("hello", "iloveyou"), check_credentials=lambda: False, # pragma: no cover ) self.assertEqual( str(raised.exception), "provide either credentials or check_credentials", ) def test_bad_credentials(self): """basic_auth receives an unsupported credentials argument.""" with self.assertRaises(TypeError) as raised: basic_auth(credentials=42) self.assertEqual( str(raised.exception), "invalid credentials argument: 42", ) def 
test_bad_list_of_credentials(self): """basic_auth receives an unsupported credentials argument.""" with self.assertRaises(TypeError) as raised: basic_auth(credentials=[42]) self.assertEqual( str(raised.exception), "invalid credentials argument: [42]", ) class BackwardsCompatibilityTests(DeprecationTestCase): def test_ssl_context_argument(self): """Server supports the deprecated ssl_context argument.""" with self.assertDeprecationWarning("ssl_context was renamed to ssl"): with run_server(ssl_context=SERVER_CONTEXT) as server: with connect(get_uri(server), ssl=CLIENT_CONTEXT): pass def test_web_socket_server_class(self): with self.assertDeprecationWarning("WebSocketServer was renamed to Server"): from websockets.sync.server import WebSocketServer self.assertIs(WebSocketServer, Server) websockets-15.0.1/tests/sync/test_utils.py000066400000000000000000000021331476212450300206410ustar00rootroot00000000000000import unittest from websockets.sync.utils import * from ..utils import MS class DeadlineTests(unittest.TestCase): def test_timeout_pending(self): """timeout returns remaining time if deadline is in the future.""" deadline = Deadline(MS) timeout = deadline.timeout() self.assertGreater(timeout, 0) self.assertLess(timeout, MS) def test_timeout_elapsed_exception(self): """timeout raises TimeoutError if deadline is in the past.""" deadline = Deadline(-MS) with self.assertRaises(TimeoutError): deadline.timeout() def test_timeout_elapsed_no_exception(self): """timeout doesn't raise TimeoutError when raise_if_elapsed is disabled.""" deadline = Deadline(-MS) timeout = deadline.timeout(raise_if_elapsed=False) self.assertGreater(timeout, -2 * MS) self.assertLess(timeout, -MS) def test_no_timeout(self): """timeout returns None when no deadline is set.""" deadline = Deadline(None) timeout = deadline.timeout() self.assertIsNone(timeout, None) websockets-15.0.1/tests/sync/utils.py000066400000000000000000000012221476212450300176000ustar00rootroot00000000000000import contextlib 
import threading import time import unittest from ..utils import MS class ThreadTestCase(unittest.TestCase): @contextlib.contextmanager def run_in_thread(self, target): """ Run ``target`` function without arguments in a thread. In order to facilitate writing tests, this helper lets the thread run for 1ms on entry and joins the thread with a 1ms timeout on exit. """ thread = threading.Thread(target=target) thread.start() time.sleep(MS) try: yield finally: thread.join(MS) self.assertFalse(thread.is_alive()) websockets-15.0.1/tests/test_auth.py000066400000000000000000000010731476212450300174700ustar00rootroot00000000000000from .utils import DeprecationTestCase class BackwardsCompatibilityTests(DeprecationTestCase): def test_headers_class(self): with self.assertDeprecationWarning( "websockets.auth, an alias for websockets.legacy.auth, is deprecated; " "see https://websockets.readthedocs.io/en/stable/howto/upgrade.html " "for upgrade instructions", ): from websockets.auth import ( BasicAuthWebSocketServerProtocol, # noqa: F401 basic_auth_protocol_factory, # noqa: F401 ) websockets-15.0.1/tests/test_cli.py000066400000000000000000000076771476212450300173160ustar00rootroot00000000000000import io import os import re import unittest from unittest.mock import patch from websockets.cli import * from websockets.exceptions import ConnectionClosed from websockets.version import version # Run a test server in a thread. This is easier than running an asyncio server # because we would have to run main() in a thread, due to using asyncio.run(). 
from .sync.server import get_uri, run_server vt100_commands = re.compile(r"\x1b\[[A-Z]|\x1b[78]|\r") def remove_commands_and_prompts(output): return vt100_commands.sub("", output).replace("> ", "") def add_connection_messages(output, server_uri): return f"Connected to {server_uri}.\n{output}Connection closed: 1000 (OK).\n" class CLITests(unittest.TestCase): def run_main(self, argv, inputs="", close_input=False, expected_exit_code=None): # Replace sys.stdin with a file-like object backed by a file descriptor # for compatibility with loop.connect_read_pipe(). stdin_read_fd, stdin_write_fd = os.pipe() stdin = io.FileIO(stdin_read_fd) self.addCleanup(stdin.close) os.write(stdin_write_fd, inputs.encode()) if close_input: os.close(stdin_write_fd) else: self.addCleanup(os.close, stdin_write_fd) # Replace sys.stdout with a file-like object to record outputs. stdout = io.StringIO() with patch("sys.stdin", new=stdin), patch("sys.stdout", new=stdout): # Catch sys.exit() calls when expected. if expected_exit_code is not None: with self.assertRaises(SystemExit) as raised: main(argv) self.assertEqual(raised.exception.code, expected_exit_code) else: main(argv) return stdout.getvalue() def test_version(self): output = self.run_main(["--version"]) self.assertEqual(output, f"websockets {version}\n") def test_receive_text_message(self): def text_handler(websocket): websocket.send("café") with run_server(text_handler) as server: server_uri = get_uri(server) output = self.run_main([server_uri], "") self.assertEqual( remove_commands_and_prompts(output), add_connection_messages("\n< café\n", server_uri), ) def test_receive_binary_message(self): def binary_handler(websocket): websocket.send(b"tea") with run_server(binary_handler) as server: server_uri = get_uri(server) output = self.run_main([server_uri], "") self.assertEqual( remove_commands_and_prompts(output), add_connection_messages("\n< (binary) 746561\n", server_uri), ) def test_send_message(self): def echo_handler(websocket): 
websocket.send(websocket.recv()) with run_server(echo_handler) as server: server_uri = get_uri(server) output = self.run_main([server_uri], "hello\n") self.assertEqual( remove_commands_and_prompts(output), add_connection_messages("\n< hello\n", server_uri), ) def test_close_connection(self): def wait_handler(websocket): with self.assertRaises(ConnectionClosed): websocket.recv() with run_server(wait_handler) as server: server_uri = get_uri(server) output = self.run_main([server_uri], "", close_input=True) self.assertEqual( remove_commands_and_prompts(output), add_connection_messages("", server_uri), ) def test_connection_failure(self): output = self.run_main(["ws://localhost:54321"], expected_exit_code=1) self.assertTrue( output.startswith("Failed to connect to ws://localhost:54321: ") ) def test_no_args(self): output = self.run_main([], expected_exit_code=2) self.assertEqual(output, "usage: websockets [--version | ]\n") websockets-15.0.1/tests/test_client.py000066400000000000000000000614601476212450300200130ustar00rootroot00000000000000import contextlib import dataclasses import logging import types import unittest from unittest.mock import patch from websockets.client import * from websockets.client import backoff from websockets.datastructures import Headers from websockets.exceptions import ( InvalidHandshake, InvalidHeader, InvalidMessage, InvalidStatus, ) from websockets.frames import OP_TEXT, Frame from websockets.http11 import Request, Response from websockets.protocol import CONNECTING, OPEN from websockets.uri import parse_uri from websockets.utils import accept_key from .extensions.utils import ( ClientOpExtensionFactory, ClientRsv2ExtensionFactory, OpExtension, Rsv2Extension, ) from .test_utils import ACCEPT, KEY from .utils import DATE, DeprecationTestCase URI = parse_uri("wss://example.com/test") # for tests where the URI doesn't matter @patch("websockets.client.generate_key", return_value=KEY) class BasicTests(unittest.TestCase): """Test basic opening 
handshake scenarios.""" def test_send_request(self, _generate_key): """Client sends a handshake request.""" client = ClientProtocol(URI) request = client.connect() client.send_request(request) self.assertEqual( client.data_to_send(), [ f"GET /test HTTP/1.1\r\n" f"Host: example.com\r\n" f"Upgrade: websocket\r\n" f"Connection: Upgrade\r\n" f"Sec-WebSocket-Key: {KEY}\r\n" f"Sec-WebSocket-Version: 13\r\n" f"\r\n".encode() ], ) self.assertFalse(client.close_expected()) self.assertEqual(client.state, CONNECTING) def test_receive_successful_response(self, _generate_key): """Client receives a successful handshake response.""" client = ClientProtocol(URI) client.receive_data( ( f"HTTP/1.1 101 Switching Protocols\r\n" f"Upgrade: websocket\r\n" f"Connection: Upgrade\r\n" f"Sec-WebSocket-Accept: {ACCEPT}\r\n" f"Date: {DATE}\r\n" f"\r\n" ).encode(), ) self.assertEqual(client.data_to_send(), []) self.assertFalse(client.close_expected()) self.assertEqual(client.state, OPEN) def test_receive_failed_response(self, _generate_key): """Client receives a failed handshake response.""" client = ClientProtocol(URI) client.receive_data( ( f"HTTP/1.1 404 Not Found\r\n" f"Date: {DATE}\r\n" f"Content-Length: 13\r\n" f"Content-Type: text/plain; charset=utf-8\r\n" f"Connection: close\r\n" f"\r\n" f"Sorry folks.\n" ).encode(), ) self.assertEqual(client.data_to_send(), [b""]) self.assertTrue(client.close_expected()) self.assertEqual(client.state, CONNECTING) class RequestTests(unittest.TestCase): """Test generating opening handshake requests.""" @patch("websockets.client.generate_key", return_value=KEY) def test_connect(self, _generate_key): """connect() creates an opening handshake request.""" client = ClientProtocol(URI) request = client.connect() self.assertIsInstance(request, Request) self.assertEqual(request.path, "/test") self.assertEqual( request.headers, Headers( { "Host": "example.com", "Upgrade": "websocket", "Connection": "Upgrade", "Sec-WebSocket-Key": KEY, "Sec-WebSocket-Version": 
"13", } ), ) def test_path(self): """connect() uses the path from the URI.""" client = ClientProtocol(parse_uri("wss://example.com/endpoint?test=1")) request = client.connect() self.assertEqual(request.path, "/endpoint?test=1") def test_port(self): """connect() uses the port from the URI or the default port.""" for uri, host in [ ("ws://example.com/", "example.com"), ("ws://example.com:80/", "example.com"), ("ws://example.com:8080/", "example.com:8080"), ("wss://example.com/", "example.com"), ("wss://example.com:443/", "example.com"), ("wss://example.com:8443/", "example.com:8443"), ]: with self.subTest(uri=uri): client = ClientProtocol(parse_uri(uri)) request = client.connect() self.assertEqual(request.headers["Host"], host) def test_user_info(self): """connect() perfoms HTTP Basic Authentication with user info from the URI.""" client = ClientProtocol(parse_uri("wss://hello:iloveyou@example.com/")) request = client.connect() self.assertEqual(request.headers["Authorization"], "Basic aGVsbG86aWxvdmV5b3U=") def test_origin(self): """connect(origin=...) generates an Origin header.""" client = ClientProtocol(URI, origin="https://example.com") request = client.connect() self.assertEqual(request.headers["Origin"], "https://example.com") def test_extensions(self): """connect(extensions=...) generates a Sec-WebSocket-Extensions header.""" client = ClientProtocol(URI, extensions=[ClientOpExtensionFactory()]) request = client.connect() self.assertEqual(request.headers["Sec-WebSocket-Extensions"], "x-op; op") def test_subprotocols(self): """connect(subprotocols=...) 
generates a Sec-WebSocket-Protocol header.""" client = ClientProtocol(URI, subprotocols=["chat"]) request = client.connect() self.assertEqual(request.headers["Sec-WebSocket-Protocol"], "chat") @patch("websockets.client.generate_key", return_value=KEY) class ResponseTests(unittest.TestCase): """Test receiving opening handshake responses.""" def test_receive_successful_response(self, _generate_key): """Client receives a successful handshake response.""" client = ClientProtocol(URI) client.receive_data( ( f"HTTP/1.1 101 Switching Protocols\r\n" f"Upgrade: websocket\r\n" f"Connection: Upgrade\r\n" f"Sec-WebSocket-Accept: {ACCEPT}\r\n" f"Date: {DATE}\r\n" f"\r\n" ).encode(), ) [response] = client.events_received() self.assertEqual(response.status_code, 101) self.assertEqual(response.reason_phrase, "Switching Protocols") self.assertEqual( response.headers, Headers( { "Upgrade": "websocket", "Connection": "Upgrade", "Sec-WebSocket-Accept": ACCEPT, "Date": DATE, } ), ) self.assertEqual(response.body, b"") self.assertIsNone(client.handshake_exc) def test_receive_failed_response(self, _generate_key): """Client receives a failed handshake response.""" client = ClientProtocol(URI) client.receive_data( ( f"HTTP/1.1 404 Not Found\r\n" f"Date: {DATE}\r\n" f"Content-Length: 13\r\n" f"Content-Type: text/plain; charset=utf-8\r\n" f"Connection: close\r\n" f"\r\n" f"Sorry folks.\n" ).encode(), ) [response] = client.events_received() self.assertEqual(response.status_code, 404) self.assertEqual(response.reason_phrase, "Not Found") self.assertEqual( response.headers, Headers( { "Date": DATE, "Content-Length": "13", "Content-Type": "text/plain; charset=utf-8", "Connection": "close", } ), ) self.assertEqual(response.body, b"Sorry folks.\n") self.assertIsInstance(client.handshake_exc, InvalidStatus) self.assertEqual( str(client.handshake_exc), "server rejected WebSocket connection: HTTP 404", ) def test_receive_no_response(self, _generate_key): """Client receives no handshake response.""" 
client = ClientProtocol(URI) client.receive_eof() self.assertEqual(client.events_received(), []) self.assertIsInstance(client.handshake_exc, InvalidMessage) self.assertEqual( str(client.handshake_exc), "did not receive a valid HTTP response", ) self.assertIsInstance(client.handshake_exc.__cause__, EOFError) self.assertEqual( str(client.handshake_exc.__cause__), "connection closed while reading HTTP status line", ) def test_receive_truncated_response(self, _generate_key): """Client receives a truncated handshake response.""" client = ClientProtocol(URI) client.receive_data(b"HTTP/1.1 101 Switching Protocols\r\n") client.receive_eof() self.assertEqual(client.events_received(), []) self.assertIsInstance(client.handshake_exc, InvalidMessage) self.assertEqual( str(client.handshake_exc), "did not receive a valid HTTP response", ) self.assertIsInstance(client.handshake_exc.__cause__, EOFError) self.assertEqual( str(client.handshake_exc.__cause__), "connection closed while reading HTTP headers", ) def test_receive_random_response(self, _generate_key): """Client receives a junk handshake response.""" client = ClientProtocol(URI) client.receive_data(b"220 smtp.invalid\r\n") client.receive_data(b"250 Hello relay.invalid\r\n") client.receive_data(b"250 Ok\r\n") client.receive_data(b"250 Ok\r\n") self.assertEqual(client.events_received(), []) self.assertIsInstance(client.handshake_exc, InvalidMessage) self.assertEqual( str(client.handshake_exc), "did not receive a valid HTTP response", ) self.assertIsInstance(client.handshake_exc.__cause__, ValueError) self.assertEqual( str(client.handshake_exc.__cause__), "invalid HTTP status line: 220 smtp.invalid", ) @contextlib.contextmanager def alter_and_receive_response(client): """Generate a handshake response that can be altered for testing.""" # We could start by sending a handshake request, i.e.: # request = client.connect() # client.send_request(request) # However, in the current implementation, these calls have no effect on the # 
state of the client. Therefore, they're unnecessary and can be skipped. response = Response( status_code=101, reason_phrase="Switching Protocols", headers=Headers( { "Upgrade": "websocket", "Connection": "Upgrade", "Sec-WebSocket-Accept": accept_key(client.key), } ), ) yield response client.receive_data(response.serialize()) [parsed_response] = client.events_received() assert response == dataclasses.replace(parsed_response, _exception=None) class HandshakeTests(unittest.TestCase): """Test processing of handshake responses to configure the connection.""" def assertHandshakeSuccess(self, client): """Assert that the opening handshake succeeded.""" self.assertEqual(client.state, OPEN) self.assertIsNone(client.handshake_exc) def assertHandshakeError(self, client, exc_type, msg): """Assert that the opening handshake failed with the given exception.""" self.assertEqual(client.state, CONNECTING) self.assertIsInstance(client.handshake_exc, exc_type) # Exception chaining isn't used is client handshake implementation. 
assert client.handshake_exc.__cause__ is None self.assertEqual(str(client.handshake_exc), msg) def test_basic(self): """Handshake succeeds.""" client = ClientProtocol(URI) with alter_and_receive_response(client): pass self.assertHandshakeSuccess(client) def test_missing_connection(self): """Handshake fails when the Connection header is missing.""" client = ClientProtocol(URI) with alter_and_receive_response(client) as response: del response.headers["Connection"] self.assertHandshakeError( client, InvalidHeader, "missing Connection header", ) def test_invalid_connection(self): """Handshake fails when the Connection header is invalid.""" client = ClientProtocol(URI) with alter_and_receive_response(client) as response: del response.headers["Connection"] response.headers["Connection"] = "close" self.assertHandshakeError( client, InvalidHeader, "invalid Connection header: close", ) def test_missing_upgrade(self): """Handshake fails when the Upgrade header is missing.""" client = ClientProtocol(URI) with alter_and_receive_response(client) as response: del response.headers["Upgrade"] self.assertHandshakeError( client, InvalidHeader, "missing Upgrade header", ) def test_invalid_upgrade(self): """Handshake fails when the Upgrade header is invalid.""" client = ClientProtocol(URI) with alter_and_receive_response(client) as response: del response.headers["Upgrade"] response.headers["Upgrade"] = "h2c" self.assertHandshakeError( client, InvalidHeader, "invalid Upgrade header: h2c", ) def test_missing_accept(self): """Handshake fails when the Sec-WebSocket-Accept header is missing.""" client = ClientProtocol(URI) with alter_and_receive_response(client) as response: del response.headers["Sec-WebSocket-Accept"] self.assertHandshakeError( client, InvalidHeader, "missing Sec-WebSocket-Accept header", ) def test_multiple_accept(self): """Handshake fails when the Sec-WebSocket-Accept header is repeated.""" client = ClientProtocol(URI) with alter_and_receive_response(client) as 
response: response.headers["Sec-WebSocket-Accept"] = ACCEPT self.assertHandshakeError( client, InvalidHeader, "invalid Sec-WebSocket-Accept header: multiple values", ) def test_invalid_accept(self): """Handshake fails when the Sec-WebSocket-Accept header is invalid.""" client = ClientProtocol(URI) with alter_and_receive_response(client) as response: del response.headers["Sec-WebSocket-Accept"] response.headers["Sec-WebSocket-Accept"] = ACCEPT self.assertHandshakeError( client, InvalidHeader, f"invalid Sec-WebSocket-Accept header: {ACCEPT}", ) def test_no_extensions(self): """Handshake succeeds without extensions.""" client = ClientProtocol(URI) with alter_and_receive_response(client): pass self.assertHandshakeSuccess(client) self.assertEqual(client.extensions, []) def test_offer_extension(self): """Client offers an extension.""" client = ClientProtocol(URI, extensions=[ClientRsv2ExtensionFactory()]) request = client.connect() self.assertEqual(request.headers["Sec-WebSocket-Extensions"], "x-rsv2") def test_enable_extension(self): """Client offers an extension and the server enables it.""" client = ClientProtocol(URI, extensions=[ClientRsv2ExtensionFactory()]) with alter_and_receive_response(client) as response: response.headers["Sec-WebSocket-Extensions"] = "x-rsv2" self.assertHandshakeSuccess(client) self.assertEqual(client.extensions, [Rsv2Extension()]) def test_extension_not_enabled(self): """Client offers an extension, but the server doesn't enable it.""" client = ClientProtocol(URI, extensions=[ClientRsv2ExtensionFactory()]) with alter_and_receive_response(client): pass self.assertHandshakeSuccess(client) self.assertEqual(client.extensions, []) def test_no_extensions_offered(self): """Server enables an extension when the client didn't offer any.""" client = ClientProtocol(URI) with alter_and_receive_response(client) as response: response.headers["Sec-WebSocket-Extensions"] = "x-rsv2" self.assertHandshakeError( client, InvalidHandshake, "no extensions 
supported", ) def test_extension_not_offered(self): """Server enables an extension that the client didn't offer.""" client = ClientProtocol(URI, extensions=[ClientRsv2ExtensionFactory()]) with alter_and_receive_response(client) as response: response.headers["Sec-WebSocket-Extensions"] = "x-op; op" self.assertHandshakeError( client, InvalidHandshake, "Unsupported extension: name = x-op, params = [('op', None)]", ) def test_supported_extension_parameters(self): """Server enables an extension with parameters supported by the client.""" client = ClientProtocol(URI, extensions=[ClientOpExtensionFactory("this")]) with alter_and_receive_response(client) as response: response.headers["Sec-WebSocket-Extensions"] = "x-op; op=this" self.assertHandshakeSuccess(client) self.assertEqual(client.extensions, [OpExtension("this")]) def test_unsupported_extension_parameters(self): """Server enables an extension with parameters unsupported by the client.""" client = ClientProtocol(URI, extensions=[ClientOpExtensionFactory("this")]) with alter_and_receive_response(client) as response: response.headers["Sec-WebSocket-Extensions"] = "x-op; op=that" self.assertHandshakeError( client, InvalidHandshake, "Unsupported extension: name = x-op, params = [('op', 'that')]", ) def test_multiple_supported_extension_parameters(self): """Client offers the same extension with several parameters.""" client = ClientProtocol( URI, extensions=[ ClientOpExtensionFactory("this"), ClientOpExtensionFactory("that"), ], ) with alter_and_receive_response(client) as response: response.headers["Sec-WebSocket-Extensions"] = "x-op; op=that" self.assertHandshakeSuccess(client) self.assertEqual(client.extensions, [OpExtension("that")]) def test_multiple_extensions(self): """Client offers several extensions and the server enables them.""" client = ClientProtocol( URI, extensions=[ ClientOpExtensionFactory(), ClientRsv2ExtensionFactory(), ], ) with alter_and_receive_response(client) as response: 
response.headers["Sec-WebSocket-Extensions"] = "x-op; op" response.headers["Sec-WebSocket-Extensions"] = "x-rsv2" self.assertHandshakeSuccess(client) self.assertEqual(client.extensions, [OpExtension(), Rsv2Extension()]) def test_multiple_extensions_order(self): """Client respects the order of extensions chosen by the server.""" client = ClientProtocol( URI, extensions=[ ClientOpExtensionFactory(), ClientRsv2ExtensionFactory(), ], ) with alter_and_receive_response(client) as response: response.headers["Sec-WebSocket-Extensions"] = "x-rsv2" response.headers["Sec-WebSocket-Extensions"] = "x-op; op" self.assertHandshakeSuccess(client) self.assertEqual(client.extensions, [Rsv2Extension(), OpExtension()]) def test_no_subprotocols(self): """Handshake succeeds without subprotocols.""" client = ClientProtocol(URI) with alter_and_receive_response(client): pass self.assertHandshakeSuccess(client) self.assertIsNone(client.subprotocol) def test_no_subprotocol_requested(self): """Client doesn't offer a subprotocol, but the server enables one.""" client = ClientProtocol(URI) with alter_and_receive_response(client) as response: response.headers["Sec-WebSocket-Protocol"] = "chat" self.assertHandshakeError( client, InvalidHandshake, "no subprotocols supported", ) def test_offer_subprotocol(self): """Client offers a subprotocol.""" client = ClientProtocol(URI, subprotocols=["chat"]) request = client.connect() self.assertEqual(request.headers["Sec-WebSocket-Protocol"], "chat") def test_enable_subprotocol(self): """Client offers a subprotocol and the server enables it.""" client = ClientProtocol(URI, subprotocols=["chat"]) with alter_and_receive_response(client) as response: response.headers["Sec-WebSocket-Protocol"] = "chat" self.assertHandshakeSuccess(client) self.assertEqual(client.subprotocol, "chat") def test_no_subprotocol_accepted(self): """Client offers a subprotocol, but the server doesn't enable it.""" client = ClientProtocol(URI, subprotocols=["chat"]) with 
alter_and_receive_response(client): pass self.assertHandshakeSuccess(client) self.assertIsNone(client.subprotocol) def test_multiple_subprotocols(self): """Client offers several subprotocols and the server enables one.""" client = ClientProtocol(URI, subprotocols=["superchat", "chat"]) with alter_and_receive_response(client) as response: response.headers["Sec-WebSocket-Protocol"] = "chat" self.assertHandshakeSuccess(client) self.assertEqual(client.subprotocol, "chat") def test_unsupported_subprotocol(self): """Client offers subprotocols but the server enables another one.""" client = ClientProtocol(URI, subprotocols=["superchat", "chat"]) with alter_and_receive_response(client) as response: response.headers["Sec-WebSocket-Protocol"] = "otherchat" self.assertHandshakeError( client, InvalidHandshake, "unsupported subprotocol: otherchat", ) def test_multiple_subprotocols_accepted(self): """Server attempts to enable multiple subprotocols.""" client = ClientProtocol(URI, subprotocols=["superchat", "chat"]) with alter_and_receive_response(client) as response: response.headers["Sec-WebSocket-Protocol"] = "superchat" response.headers["Sec-WebSocket-Protocol"] = "chat" self.assertHandshakeError( client, InvalidHandshake, "invalid Sec-WebSocket-Protocol header: multiple values: superchat, chat", ) class MiscTests(unittest.TestCase): def test_bypass_handshake(self): """ClientProtocol bypasses the opening handshake.""" client = ClientProtocol(URI, state=OPEN) client.receive_data(b"\x81\x06Hello!") [frame] = client.events_received() self.assertEqual(frame, Frame(OP_TEXT, b"Hello!")) def test_custom_logger(self): """ClientProtocol accepts a logger argument.""" logger = logging.getLogger("test") with self.assertLogs("test", logging.DEBUG) as logs: ClientProtocol(URI, logger=logger) self.assertEqual(len(logs.records), 1) class BackwardsCompatibilityTests(DeprecationTestCase): def test_client_connection_class(self): """ClientConnection is a deprecated alias for ClientProtocol.""" 
with self.assertDeprecationWarning( "ClientConnection was renamed to ClientProtocol" ): from websockets.client import ClientConnection client = ClientConnection("ws://localhost/") self.assertIsInstance(client, ClientProtocol) class BackoffTests(unittest.TestCase): def test_backoff(self): """backoff() yields a random delay, then exponentially increasing delays.""" backoff_gen = backoff() self.assertIsInstance(backoff_gen, types.GeneratorType) initial_delay = next(backoff_gen) self.assertGreaterEqual(initial_delay, 0) self.assertLess(initial_delay, 5) following_delays = [int(next(backoff_gen)) for _ in range(9)] self.assertEqual(following_delays, [3, 5, 8, 13, 21, 34, 55, 89, 90]) websockets-15.0.1/tests/test_connection.py000066400000000000000000000010071476212450300206630ustar00rootroot00000000000000from websockets.protocol import Protocol from .utils import DeprecationTestCase class BackwardsCompatibilityTests(DeprecationTestCase): def test_connection_class(self): """Connection is a deprecated alias for Protocol.""" with self.assertDeprecationWarning( "websockets.connection was renamed to websockets.protocol " "and Connection was renamed to Protocol" ): from websockets.connection import Connection self.assertIs(Connection, Protocol) websockets-15.0.1/tests/test_datastructures.py000066400000000000000000000160621476212450300216100ustar00rootroot00000000000000import unittest from websockets.datastructures import * class MultipleValuesErrorTests(unittest.TestCase): def test_multiple_values_error_str(self): self.assertEqual(str(MultipleValuesError("Connection")), "'Connection'") self.assertEqual(str(MultipleValuesError()), "") class HeadersTests(unittest.TestCase): def setUp(self): self.headers = Headers([("Connection", "Upgrade"), ("Server", "websockets")]) def test_init(self): self.assertEqual( Headers(), Headers(), ) def test_init_from_kwargs(self): self.assertEqual( Headers(connection="Upgrade", server="websockets"), self.headers, ) def test_init_from_headers(self): 
self.assertEqual( Headers(self.headers), self.headers, ) def test_init_from_headers_and_kwargs(self): self.assertEqual( Headers(Headers(connection="Upgrade"), server="websockets"), self.headers, ) def test_init_from_mapping(self): self.assertEqual( Headers({"Connection": "Upgrade", "Server": "websockets"}), self.headers, ) def test_init_from_mapping_and_kwargs(self): self.assertEqual( Headers({"Connection": "Upgrade"}, server="websockets"), self.headers, ) def test_init_from_iterable(self): self.assertEqual( Headers([("Connection", "Upgrade"), ("Server", "websockets")]), self.headers, ) def test_init_from_iterable_and_kwargs(self): self.assertEqual( Headers([("Connection", "Upgrade")], server="websockets"), self.headers, ) def test_init_multiple_positional_arguments(self): with self.assertRaises(TypeError): Headers(Headers(connection="Upgrade"), Headers(server="websockets")) def test_str(self): self.assertEqual( str(self.headers), "Connection: Upgrade\r\nServer: websockets\r\n\r\n" ) def test_repr(self): self.assertEqual( repr(self.headers), "Headers([('Connection', 'Upgrade'), ('Server', 'websockets')])", ) def test_copy(self): self.assertEqual(repr(self.headers.copy()), repr(self.headers)) def test_serialize(self): self.assertEqual( self.headers.serialize(), b"Connection: Upgrade\r\nServer: websockets\r\n\r\n", ) def test_contains(self): self.assertIn("Server", self.headers) def test_contains_case_insensitive(self): self.assertIn("server", self.headers) def test_contains_not_found(self): self.assertNotIn("Date", self.headers) def test_contains_non_string_key(self): self.assertNotIn(42, self.headers) def test_iter(self): self.assertEqual(set(iter(self.headers)), {"connection", "server"}) def test_len(self): self.assertEqual(len(self.headers), 2) def test_getitem(self): self.assertEqual(self.headers["Server"], "websockets") def test_getitem_case_insensitive(self): self.assertEqual(self.headers["server"], "websockets") def test_getitem_key_error(self): with 
self.assertRaises(KeyError): self.headers["Upgrade"] def test_setitem(self): self.headers["Upgrade"] = "websocket" self.assertEqual(self.headers["Upgrade"], "websocket") def test_setitem_case_insensitive(self): self.headers["upgrade"] = "websocket" self.assertEqual(self.headers["Upgrade"], "websocket") def test_delitem(self): del self.headers["Connection"] with self.assertRaises(KeyError): self.headers["Connection"] def test_delitem_case_insensitive(self): del self.headers["connection"] with self.assertRaises(KeyError): self.headers["Connection"] def test_eq(self): other_headers = Headers([("Connection", "Upgrade"), ("Server", "websockets")]) self.assertEqual(self.headers, other_headers) def test_eq_case_insensitive(self): other_headers = Headers(connection="Upgrade", server="websockets") self.assertEqual(self.headers, other_headers) def test_eq_not_equal(self): other_headers = Headers([("Connection", "close"), ("Server", "websockets")]) self.assertNotEqual(self.headers, other_headers) def test_eq_other_type(self): self.assertNotEqual( self.headers, "Connection: Upgrade\r\nServer: websockets\r\n\r\n" ) def test_clear(self): self.headers.clear() self.assertFalse(self.headers) self.assertEqual(self.headers, Headers()) def test_get_all(self): self.assertEqual(self.headers.get_all("Connection"), ["Upgrade"]) def test_get_all_case_insensitive(self): self.assertEqual(self.headers.get_all("connection"), ["Upgrade"]) def test_get_all_no_values(self): self.assertEqual(self.headers.get_all("Upgrade"), []) def test_raw_items(self): self.assertEqual( list(self.headers.raw_items()), [("Connection", "Upgrade"), ("Server", "websockets")], ) class MultiValueHeadersTests(unittest.TestCase): def setUp(self): self.headers = Headers([("Server", "Python"), ("Server", "websockets")]) def test_init_from_headers(self): self.assertEqual( Headers(self.headers), self.headers, ) def test_init_from_headers_and_kwargs(self): self.assertEqual( Headers(Headers(server="Python"), 
server="websockets"), self.headers, ) def test_str(self): self.assertEqual( str(self.headers), "Server: Python\r\nServer: websockets\r\n\r\n" ) def test_repr(self): self.assertEqual( repr(self.headers), "Headers([('Server', 'Python'), ('Server', 'websockets')])", ) def test_copy(self): self.assertEqual(repr(self.headers.copy()), repr(self.headers)) def test_serialize(self): self.assertEqual( self.headers.serialize(), b"Server: Python\r\nServer: websockets\r\n\r\n", ) def test_iter(self): self.assertEqual(set(iter(self.headers)), {"server"}) def test_len(self): self.assertEqual(len(self.headers), 1) def test_getitem_multiple_values_error(self): with self.assertRaises(MultipleValuesError): self.headers["Server"] def test_setitem(self): self.headers["Server"] = "redux" self.assertEqual( self.headers.get_all("Server"), ["Python", "websockets", "redux"] ) def test_delitem(self): del self.headers["Server"] with self.assertRaises(KeyError): self.headers["Server"] def test_get_all(self): self.assertEqual(self.headers.get_all("Server"), ["Python", "websockets"]) def test_raw_items(self): self.assertEqual( list(self.headers.raw_items()), [("Server", "Python"), ("Server", "websockets")], ) websockets-15.0.1/tests/test_exceptions.py000066400000000000000000000205041476212450300207100ustar00rootroot00000000000000import unittest from websockets.datastructures import Headers from websockets.exceptions import * from websockets.frames import Close, CloseCode from websockets.http11 import Response from .utils import DeprecationTestCase class ExceptionsTests(unittest.TestCase): def test_str(self): for exception, exception_str in [ ( WebSocketException("something went wrong"), "something went wrong", ), ( ConnectionClosed( Close(CloseCode.NORMAL_CLOSURE, ""), Close(CloseCode.NORMAL_CLOSURE, ""), True, ), "received 1000 (OK); then sent 1000 (OK)", ), ( ConnectionClosed( Close(CloseCode.GOING_AWAY, "Bye!"), Close(CloseCode.GOING_AWAY, "Bye!"), False, ), "sent 1001 (going away) Bye!; then 
received 1001 (going away) Bye!", ), ( ConnectionClosed( Close(CloseCode.NORMAL_CLOSURE, "race"), Close(CloseCode.NORMAL_CLOSURE, "cond"), True, ), "received 1000 (OK) race; then sent 1000 (OK) cond", ), ( ConnectionClosed( Close(CloseCode.NORMAL_CLOSURE, "cond"), Close(CloseCode.NORMAL_CLOSURE, "race"), False, ), "sent 1000 (OK) race; then received 1000 (OK) cond", ), ( ConnectionClosed( None, Close(CloseCode.MESSAGE_TOO_BIG, ""), None, ), "sent 1009 (message too big); no close frame received", ), ( ConnectionClosed( Close(CloseCode.PROTOCOL_ERROR, ""), None, None, ), "received 1002 (protocol error); no close frame sent", ), ( ConnectionClosedOK( Close(CloseCode.NORMAL_CLOSURE, ""), Close(CloseCode.NORMAL_CLOSURE, ""), True, ), "received 1000 (OK); then sent 1000 (OK)", ), ( ConnectionClosedError( None, None, None, ), "no close frame received or sent", ), ( InvalidURI("|", "not at all!"), "| isn't a valid URI: not at all!", ), ( InvalidProxy("|", "not at all!"), "| isn't a valid proxy: not at all!", ), ( InvalidHandshake("invalid request"), "invalid request", ), ( SecurityError("redirect from WSS to WS"), "redirect from WSS to WS", ), ( ProxyError("failed to connect to SOCKS proxy"), "failed to connect to SOCKS proxy", ), ( InvalidMessage("malformed HTTP message"), "malformed HTTP message", ), ( InvalidStatus(Response(401, "Unauthorized", Headers())), "server rejected WebSocket connection: HTTP 401", ), ( InvalidProxyMessage("malformed HTTP message"), "malformed HTTP message", ), ( InvalidProxyStatus(Response(401, "Unauthorized", Headers())), "proxy rejected connection: HTTP 401", ), ( InvalidHeader("Name"), "missing Name header", ), ( InvalidHeader("Name", None), "missing Name header", ), ( InvalidHeader("Name", ""), "empty Name header", ), ( InvalidHeader("Name", "Value"), "invalid Name header: Value", ), ( InvalidHeaderFormat("Sec-WebSocket-Protocol", "exp. token", "a=|", 3), "invalid Sec-WebSocket-Protocol header: exp. 
token at 3 in a=|", ), ( InvalidHeaderValue("Sec-WebSocket-Version", "42"), "invalid Sec-WebSocket-Version header: 42", ), ( InvalidOrigin("http://bad.origin"), "invalid Origin header: http://bad.origin", ), ( InvalidUpgrade("Upgrade"), "missing Upgrade header", ), ( InvalidUpgrade("Connection", "websocket"), "invalid Connection header: websocket", ), ( NegotiationError("unsupported subprotocol: spam"), "unsupported subprotocol: spam", ), ( DuplicateParameter("a"), "duplicate parameter: a", ), ( InvalidParameterName("|"), "invalid parameter name: |", ), ( InvalidParameterValue("a", None), "missing value for parameter a", ), ( InvalidParameterValue("a", ""), "empty value for parameter a", ), ( InvalidParameterValue("a", "|"), "invalid value for parameter a: |", ), ( ProtocolError("invalid opcode: 7"), "invalid opcode: 7", ), ( PayloadTooBig(None, 4), "frame exceeds limit of 4 bytes", ), ( PayloadTooBig(8, 4), "frame with 8 bytes exceeds limit of 4 bytes", ), ( PayloadTooBig(8, 4, 12), "frame with 8 bytes after reading 12 bytes exceeds limit of 16 bytes", ), ( InvalidState("WebSocket connection isn't established yet"), "WebSocket connection isn't established yet", ), ( ConcurrencyError("get() or get_iter() is already running"), "get() or get_iter() is already running", ), ]: with self.subTest(exception=exception): self.assertEqual(str(exception), exception_str) class DeprecationTests(DeprecationTestCase): def test_connection_closed_attributes_deprecation(self): exception = ConnectionClosed(Close(CloseCode.NORMAL_CLOSURE, "OK"), None, None) with self.assertDeprecationWarning( "ConnectionClosed.code is deprecated; " "use Protocol.close_code or ConnectionClosed.rcvd.code" ): self.assertEqual(exception.code, CloseCode.NORMAL_CLOSURE) with self.assertDeprecationWarning( "ConnectionClosed.reason is deprecated; " "use Protocol.close_reason or ConnectionClosed.rcvd.reason" ): self.assertEqual(exception.reason, "OK") def 
test_connection_closed_attributes_deprecation_defaults(self): exception = ConnectionClosed(None, None, None) with self.assertDeprecationWarning( "ConnectionClosed.code is deprecated; " "use Protocol.close_code or ConnectionClosed.rcvd.code" ): self.assertEqual(exception.code, CloseCode.ABNORMAL_CLOSURE) with self.assertDeprecationWarning( "ConnectionClosed.reason is deprecated; " "use Protocol.close_reason or ConnectionClosed.rcvd.reason" ): self.assertEqual(exception.reason, "") def test_payload_too_big_with_message(self): with self.assertDeprecationWarning( "PayloadTooBig(message) is deprecated; " "change to PayloadTooBig(size, max_size)", ): exc = PayloadTooBig("payload length exceeds limit: 2 > 1 bytes") self.assertEqual(str(exc), "payload length exceeds limit: 2 > 1 bytes") websockets-15.0.1/tests/test_exports.py000066400000000000000000000023161476212450300202340ustar00rootroot00000000000000import unittest import websockets import websockets.asyncio.client import websockets.asyncio.router import websockets.asyncio.server import websockets.client import websockets.datastructures import websockets.exceptions import websockets.server import websockets.typing import websockets.uri combined_exports = [ name for name in ( [] + websockets.asyncio.client.__all__ + websockets.asyncio.router.__all__ + websockets.asyncio.server.__all__ + websockets.client.__all__ + websockets.datastructures.__all__ + websockets.exceptions.__all__ + websockets.frames.__all__ + websockets.http11.__all__ + websockets.protocol.__all__ + websockets.server.__all__ + websockets.typing.__all__ ) if not name.isupper() # filter out constants ] class ExportsTests(unittest.TestCase): def test_top_level_module_reexports_submodule_exports(self): self.assertEqual( set(combined_exports), set(websockets.__all__), ) def test_submodule_exports_are_globally_unique(self): self.assertEqual( len(set(combined_exports)), len(combined_exports), ) 
websockets-15.0.1/tests/test_frames.py000066400000000000000000000325251476212450300200120ustar00rootroot00000000000000import codecs import dataclasses import unittest from unittest.mock import patch from websockets.exceptions import PayloadTooBig, ProtocolError from websockets.frames import * from websockets.frames import CloseCode from websockets.streams import StreamReader from .utils import GeneratorTestCase class FramesTestCase(GeneratorTestCase): def parse(self, data, mask, max_size=None, extensions=None): """ Parse a frame from a bytestring. """ reader = StreamReader() reader.feed_data(data) reader.feed_eof() parser = Frame.parse( reader.read_exact, mask=mask, max_size=max_size, extensions=extensions ) return self.assertGeneratorReturns(parser) def assertFrameData(self, frame, data, mask, extensions=None): """ Serializing frame yields data. Parsing data yields frame. """ # Compare frames first, because test failures are easier to read, # especially when mask = True. parsed = self.parse(data, mask=mask, extensions=extensions) self.assertEqual(parsed, frame) # Make masking deterministic by reusing the same "random" mask. # This has an effect only when mask is True. 
mask_bytes = data[2:6] if mask else b"" with patch("secrets.token_bytes", return_value=mask_bytes): serialized = frame.serialize(mask=mask, extensions=extensions) self.assertEqual(serialized, data) class FrameTests(FramesTestCase): def test_text_unmasked(self): self.assertFrameData( Frame(OP_TEXT, b"Spam"), b"\x81\x04Spam", mask=False, ) def test_text_masked(self): self.assertFrameData( Frame(OP_TEXT, b"Spam"), b"\x81\x84\x5b\xfb\xe1\xa8\x08\x8b\x80\xc5", mask=True, ) def test_binary_unmasked(self): self.assertFrameData( Frame(OP_BINARY, b"Eggs"), b"\x82\x04Eggs", mask=False, ) def test_binary_masked(self): self.assertFrameData( Frame(OP_BINARY, b"Eggs"), b"\x82\x84\x53\xcd\xe2\x89\x16\xaa\x85\xfa", mask=True, ) def test_non_ascii_text_unmasked(self): self.assertFrameData( Frame(OP_TEXT, "café".encode()), b"\x81\x05caf\xc3\xa9", mask=False, ) def test_non_ascii_text_masked(self): self.assertFrameData( Frame(OP_TEXT, "café".encode()), b"\x81\x85\x64\xbe\xee\x7e\x07\xdf\x88\xbd\xcd", mask=True, ) def test_close(self): self.assertFrameData( Frame(OP_CLOSE, b""), b"\x88\x00", mask=False, ) def test_ping(self): self.assertFrameData( Frame(OP_PING, b"ping"), b"\x89\x04ping", mask=False, ) def test_pong(self): self.assertFrameData( Frame(OP_PONG, b"pong"), b"\x8a\x04pong", mask=False, ) def test_long(self): self.assertFrameData( Frame(OP_BINARY, 126 * b"a"), b"\x82\x7e\x00\x7e" + 126 * b"a", mask=False, ) def test_very_long(self): self.assertFrameData( Frame(OP_BINARY, 65536 * b"a"), b"\x82\x7f\x00\x00\x00\x00\x00\x01\x00\x00" + 65536 * b"a", mask=False, ) def test_payload_too_big(self): with self.assertRaises(PayloadTooBig): self.parse(b"\x82\x7e\x04\x01" + 1025 * b"a", mask=False, max_size=1024) def test_bad_reserved_bits(self): for data in [b"\xc0\x00", b"\xa0\x00", b"\x90\x00"]: with self.subTest(data=data): with self.assertRaises(ProtocolError): self.parse(data, mask=False) def test_good_opcode(self): for opcode in list(range(0x00, 0x03)) + list(range(0x08, 0x0B)): 
data = bytes([0x80 | opcode, 0]) with self.subTest(data=data): self.parse(data, mask=False) # does not raise an exception def test_bad_opcode(self): for opcode in list(range(0x03, 0x08)) + list(range(0x0B, 0x10)): data = bytes([0x80 | opcode, 0]) with self.subTest(data=data): with self.assertRaises(ProtocolError): self.parse(data, mask=False) def test_mask_flag(self): # Mask flag correctly set. self.parse(b"\x80\x80\x00\x00\x00\x00", mask=True) # Mask flag incorrectly unset. with self.assertRaises(ProtocolError): self.parse(b"\x80\x80\x00\x00\x00\x00", mask=False) # Mask flag correctly unset. self.parse(b"\x80\x00", mask=False) # Mask flag incorrectly set. with self.assertRaises(ProtocolError): self.parse(b"\x80\x00", mask=True) def test_control_frame_max_length(self): # At maximum allowed length. self.parse(b"\x88\x7e\x00\x7d" + 125 * b"a", mask=False) # Above maximum allowed length. with self.assertRaises(ProtocolError): self.parse(b"\x88\x7e\x00\x7e" + 126 * b"a", mask=False) def test_fragmented_control_frame(self): # Fin bit correctly set. self.parse(b"\x88\x00", mask=False) # Fin bit incorrectly unset. with self.assertRaises(ProtocolError): self.parse(b"\x08\x00", mask=False) def test_extensions(self): class Rot13: @staticmethod def encode(frame): assert frame.opcode == OP_TEXT text = frame.data.decode() data = codecs.encode(text, "rot13").encode() return dataclasses.replace(frame, data=data) # This extensions is symmetrical. 
@staticmethod def decode(frame, *, max_size=None): return Rot13.encode(frame) self.assertFrameData( Frame(OP_TEXT, b"hello"), b"\x81\x05uryyb", mask=False, extensions=[Rot13()], ) class StrTests(unittest.TestCase): def test_cont_text(self): self.assertEqual( str(Frame(OP_CONT, b" cr\xc3\xa8me", fin=False)), "CONT ' crème' [text, 7 bytes, continued]", ) def test_cont_binary(self): self.assertEqual( str(Frame(OP_CONT, b"\xfc\xfd\xfe\xff", fin=False)), "CONT fc fd fe ff [binary, 4 bytes, continued]", ) def test_cont_binary_from_memoryview(self): self.assertEqual( str(Frame(OP_CONT, memoryview(b"\xfc\xfd\xfe\xff"), fin=False)), "CONT fc fd fe ff [binary, 4 bytes, continued]", ) def test_cont_final_text(self): self.assertEqual( str(Frame(OP_CONT, b" cr\xc3\xa8me")), "CONT ' crème' [text, 7 bytes]", ) def test_cont_final_binary(self): self.assertEqual( str(Frame(OP_CONT, b"\xfc\xfd\xfe\xff")), "CONT fc fd fe ff [binary, 4 bytes]", ) def test_cont_final_binary_from_memoryview(self): self.assertEqual( str(Frame(OP_CONT, memoryview(b"\xfc\xfd\xfe\xff"))), "CONT fc fd fe ff [binary, 4 bytes]", ) def test_cont_text_truncated(self): self.assertEqual( str(Frame(OP_CONT, b"caf\xc3\xa9 " * 16, fin=False)), "CONT 'café café café café café café café café café ca..." "fé café café café café ' [text, 96 bytes, continued]", ) def test_cont_binary_truncated(self): self.assertEqual( str(Frame(OP_CONT, bytes(range(256)), fin=False)), "CONT 00 01 02 03 04 05 06 07 08 09 0a 0b 0c 0d 0e 0f ..." " f8 f9 fa fb fc fd fe ff [binary, 256 bytes, continued]", ) def test_cont_binary_truncated_from_memoryview(self): self.assertEqual( str(Frame(OP_CONT, memoryview(bytes(range(256))), fin=False)), "CONT 00 01 02 03 04 05 06 07 08 09 0a 0b 0c 0d 0e 0f ..." 
" f8 f9 fa fb fc fd fe ff [binary, 256 bytes, continued]", ) def test_text(self): self.assertEqual( str(Frame(OP_TEXT, b"caf\xc3\xa9")), "TEXT 'café' [5 bytes]", ) def test_text_non_final(self): self.assertEqual( str(Frame(OP_TEXT, b"caf\xc3\xa9", fin=False)), "TEXT 'café' [5 bytes, continued]", ) def test_text_truncated(self): self.assertEqual( str(Frame(OP_TEXT, b"caf\xc3\xa9 " * 16)), "TEXT 'café café café café café café café café café ca..." "fé café café café café ' [96 bytes]", ) def test_text_with_newline(self): self.assertEqual( str(Frame(OP_TEXT, b"Hello\nworld!")), "TEXT 'Hello\\nworld!' [12 bytes]", ) def test_binary(self): self.assertEqual( str(Frame(OP_BINARY, b"\x00\x01\x02\x03")), "BINARY 00 01 02 03 [4 bytes]", ) def test_binary_from_memoryview(self): self.assertEqual( str(Frame(OP_BINARY, memoryview(b"\x00\x01\x02\x03"))), "BINARY 00 01 02 03 [4 bytes]", ) def test_binary_non_final(self): self.assertEqual( str(Frame(OP_BINARY, b"\x00\x01\x02\x03", fin=False)), "BINARY 00 01 02 03 [4 bytes, continued]", ) def test_binary_non_final_from_memoryview(self): self.assertEqual( str(Frame(OP_BINARY, memoryview(b"\x00\x01\x02\x03"), fin=False)), "BINARY 00 01 02 03 [4 bytes, continued]", ) def test_binary_truncated(self): self.assertEqual( str(Frame(OP_BINARY, bytes(range(256)))), "BINARY 00 01 02 03 04 05 06 07 08 09 0a 0b 0c 0d 0e 0f ..." " f8 f9 fa fb fc fd fe ff [256 bytes]", ) def test_binary_truncated_from_memoryview(self): self.assertEqual( str(Frame(OP_BINARY, memoryview(bytes(range(256))))), "BINARY 00 01 02 03 04 05 06 07 08 09 0a 0b 0c 0d 0e 0f ..." " f8 f9 fa fb fc fd fe ff [256 bytes]", ) def test_close(self): self.assertEqual( str(Frame(OP_CLOSE, b"\x03\xe8")), "CLOSE 1000 (OK) [2 bytes]", ) def test_close_reason(self): self.assertEqual( str(Frame(OP_CLOSE, b"\x03\xe9Bye!")), "CLOSE 1001 (going away) Bye! 
[6 bytes]", ) def test_ping(self): self.assertEqual( str(Frame(OP_PING, b"")), "PING '' [0 bytes]", ) def test_ping_text(self): self.assertEqual( str(Frame(OP_PING, b"ping")), "PING 'ping' [text, 4 bytes]", ) def test_ping_text_with_newline(self): self.assertEqual( str(Frame(OP_PING, b"ping\n")), "PING 'ping\\n' [text, 5 bytes]", ) def test_ping_binary(self): self.assertEqual( str(Frame(OP_PING, b"\xff\x00\xff\x00")), "PING ff 00 ff 00 [binary, 4 bytes]", ) def test_pong(self): self.assertEqual( str(Frame(OP_PONG, b"")), "PONG '' [0 bytes]", ) def test_pong_text(self): self.assertEqual( str(Frame(OP_PONG, b"pong")), "PONG 'pong' [text, 4 bytes]", ) def test_pong_text_with_newline(self): self.assertEqual( str(Frame(OP_PONG, b"pong\n")), "PONG 'pong\\n' [text, 5 bytes]", ) def test_pong_binary(self): self.assertEqual( str(Frame(OP_PONG, b"\xff\x00\xff\x00")), "PONG ff 00 ff 00 [binary, 4 bytes]", ) class CloseTests(unittest.TestCase): def assertCloseData(self, close, data): """ Serializing close yields data. Parsing data yields close. 
""" serialized = close.serialize() self.assertEqual(serialized, data) parsed = Close.parse(data) self.assertEqual(parsed, close) def test_str(self): self.assertEqual( str(Close(CloseCode.NORMAL_CLOSURE, "")), "1000 (OK)", ) self.assertEqual( str(Close(CloseCode.GOING_AWAY, "Bye!")), "1001 (going away) Bye!", ) self.assertEqual( str(Close(3000, "")), "3000 (registered)", ) self.assertEqual( str(Close(4000, "")), "4000 (private use)", ) self.assertEqual( str(Close(5000, "")), "5000 (unknown)", ) def test_parse_and_serialize(self): self.assertCloseData( Close(CloseCode.NORMAL_CLOSURE, "OK"), b"\x03\xe8OK", ) self.assertCloseData( Close(CloseCode.GOING_AWAY, ""), b"\x03\xe9", ) def test_parse_empty(self): self.assertEqual( Close.parse(b""), Close(CloseCode.NO_STATUS_RCVD, ""), ) def test_parse_errors(self): with self.assertRaises(ProtocolError): Close.parse(b"\x03") with self.assertRaises(ProtocolError): Close.parse(b"\x03\xe7") with self.assertRaises(UnicodeDecodeError): Close.parse(b"\x03\xe8\xff\xff") def test_serialize_errors(self): with self.assertRaises(ProtocolError): Close(999, "").serialize() websockets-15.0.1/tests/test_headers.py000066400000000000000000000223541476212450300201470ustar00rootroot00000000000000import unittest from websockets.exceptions import InvalidHeaderFormat, InvalidHeaderValue from websockets.headers import * class HeadersTests(unittest.TestCase): def test_build_host(self): for (host, port, secure), (result, result_with_port) in [ (("localhost", 80, False), ("localhost", "localhost:80")), (("localhost", 8000, False), ("localhost:8000", "localhost:8000")), (("localhost", 443, True), ("localhost", "localhost:443")), (("localhost", 8443, True), ("localhost:8443", "localhost:8443")), (("example.com", 80, False), ("example.com", "example.com:80")), (("example.com", 8000, False), ("example.com:8000", "example.com:8000")), (("example.com", 443, True), ("example.com", "example.com:443")), (("example.com", 8443, True), ("example.com:8443", 
"example.com:8443")), (("127.0.0.1", 80, False), ("127.0.0.1", "127.0.0.1:80")), (("127.0.0.1", 8000, False), ("127.0.0.1:8000", "127.0.0.1:8000")), (("127.0.0.1", 443, True), ("127.0.0.1", "127.0.0.1:443")), (("127.0.0.1", 8443, True), ("127.0.0.1:8443", "127.0.0.1:8443")), (("::1", 80, False), ("[::1]", "[::1]:80")), (("::1", 8000, False), ("[::1]:8000", "[::1]:8000")), (("::1", 443, True), ("[::1]", "[::1]:443")), (("::1", 8443, True), ("[::1]:8443", "[::1]:8443")), ]: with self.subTest(host=host, port=port, secure=secure): self.assertEqual( build_host(host, port, secure), result, ) self.assertEqual( build_host(host, port, secure, always_include_port=True), result_with_port, ) def test_parse_connection(self): for header, parsed in [ # Realistic use cases ("Upgrade", ["Upgrade"]), # Safari, Chrome ("keep-alive, Upgrade", ["keep-alive", "Upgrade"]), # Firefox # Pathological example (",,\t, , ,Upgrade ,,", ["Upgrade"]), ]: with self.subTest(header=header): self.assertEqual(parse_connection(header), parsed) def test_parse_connection_invalid_header_format(self): for header in ["???", "keep-alive; Upgrade"]: with self.subTest(header=header): with self.assertRaises(InvalidHeaderFormat): parse_connection(header) def test_parse_upgrade(self): for header, parsed in [ # Realistic use case ("websocket", ["websocket"]), # Synthetic example ("http/3.0, websocket", ["http/3.0", "websocket"]), # Pathological example (",, WebSocket, \t,,", ["WebSocket"]), ]: with self.subTest(header=header): self.assertEqual(parse_upgrade(header), parsed) def test_parse_upgrade_invalid_header_format(self): for header in ["???", "websocket 2", "http/3.0; websocket"]: with self.subTest(header=header): with self.assertRaises(InvalidHeaderFormat): parse_upgrade(header) def test_parse_extension(self): for header, parsed in [ # Synthetic examples ("foo", [("foo", [])]), ("foo, bar", [("foo", []), ("bar", [])]), ( 'foo; name; token=token; quoted-string="quoted-string", ' "bar; quux; quuux", [ ( "foo", 
[ ("name", None), ("token", "token"), ("quoted-string", "quoted-string"), ], ), ("bar", [("quux", None), ("quuux", None)]), ], ), # Pathological example ( ",\t, , ,foo ;bar = 42,, baz,,", [("foo", [("bar", "42")]), ("baz", [])], ), # Realistic use cases for permessage-deflate ("permessage-deflate", [("permessage-deflate", [])]), ( "permessage-deflate; client_max_window_bits", [("permessage-deflate", [("client_max_window_bits", None)])], ), ( "permessage-deflate; server_max_window_bits=10", [("permessage-deflate", [("server_max_window_bits", "10")])], ), ]: with self.subTest(header=header): self.assertEqual(parse_extension(header), parsed) # Also ensure that build_extension round-trips cleanly. unparsed = build_extension(parsed) self.assertEqual(parse_extension(unparsed), parsed) def test_parse_extension_invalid_header_format(self): for header in [ # Truncated examples "", ",\t,", "foo;", "foo; bar;", "foo; bar=", 'foo; bar="baz', # Wrong delimiter "foo, bar, baz=quux; quuux", # Value in quoted string parameter that isn't a token 'foo; bar=" "', ]: with self.subTest(header=header): with self.assertRaises(InvalidHeaderFormat): parse_extension(header) def test_parse_subprotocol(self): for header, parsed in [ # Synthetic examples ("foo", ["foo"]), ("foo, bar", ["foo", "bar"]), # Pathological example (",\t, , ,foo ,, bar,baz,,", ["foo", "bar", "baz"]), ]: with self.subTest(header=header): self.assertEqual(parse_subprotocol(header), parsed) # Also ensure that build_subprotocol round-trips cleanly. 
unparsed = build_subprotocol(parsed) self.assertEqual(parse_subprotocol(unparsed), parsed) def test_parse_subprotocol_invalid_header(self): for header in [ # Truncated examples "", ",\t,", # Wrong delimiter "foo; bar", ]: with self.subTest(header=header): with self.assertRaises(InvalidHeaderFormat): parse_subprotocol(header) def test_validate_subprotocols(self): for subprotocols in [[], ["sip"], ["v1.usp"], ["sip", "v1.usp"]]: with self.subTest(subprotocols=subprotocols): validate_subprotocols(subprotocols) def test_validate_subprotocols_invalid(self): for subprotocols, exception in [ ({"sip": None}, TypeError), ("sip", TypeError), ([""], ValueError), ]: with self.subTest(subprotocols=subprotocols): with self.assertRaises(exception): validate_subprotocols(subprotocols) def test_build_www_authenticate_basic(self): # Test vector from RFC 7617 self.assertEqual( build_www_authenticate_basic("foo"), 'Basic realm="foo", charset="UTF-8"' ) def test_build_www_authenticate_basic_invalid_realm(self): # Realm contains a control character forbidden in quoted-string encoding with self.assertRaises(ValueError): build_www_authenticate_basic("\u0007") def test_build_authorization_basic(self): # Test vector from RFC 7617 self.assertEqual( build_authorization_basic("Aladdin", "open sesame"), "Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==", ) def test_build_authorization_basic_utf8(self): # Test vector from RFC 7617 self.assertEqual( build_authorization_basic("test", "123£"), "Basic dGVzdDoxMjPCow==" ) def test_parse_authorization_basic(self): for header, parsed in [ ("Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==", ("Aladdin", "open sesame")), # Password contains non-ASCII character ("Basic dGVzdDoxMjPCow==", ("test", "123£")), # Password contains a colon ("Basic YWxhZGRpbjpvcGVuOnNlc2FtZQ==", ("aladdin", "open:sesame")), # Scheme name must be case insensitive ("basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==", ("Aladdin", "open sesame")), ]: with self.subTest(header=header): 
self.assertEqual(parse_authorization_basic(header), parsed) def test_parse_authorization_basic_invalid_header_format(self): for header in [ "// Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==", "Basic\tQWxhZGRpbjpvcGVuIHNlc2FtZQ==", "Basic ****************************", "Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ== //", ]: with self.subTest(header=header): with self.assertRaises(InvalidHeaderFormat): parse_authorization_basic(header) def test_parse_authorization_basic_invalid_header_value(self): for header in [ "Digest ...", "Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ", "Basic QWxhZGNlc2FtZQ==", ]: with self.subTest(header=header): with self.assertRaises(InvalidHeaderValue): parse_authorization_basic(header) websockets-15.0.1/tests/test_http.py000066400000000000000000000011101476212450300174760ustar00rootroot00000000000000from websockets.datastructures import Headers from .utils import DeprecationTestCase class BackwardsCompatibilityTests(DeprecationTestCase): def test_headers_class(self): with self.assertDeprecationWarning( "Headers and MultipleValuesError were moved " "from websockets.http to websockets.datastructures" "and read_request and read_response were moved " "from websockets.http to websockets.legacy.http", ): from websockets.http import Headers as OldHeaders self.assertIs(OldHeaders, Headers) websockets-15.0.1/tests/test_http11.py000066400000000000000000000344161476212450300176570ustar00rootroot00000000000000from websockets.datastructures import Headers from websockets.exceptions import SecurityError from websockets.http11 import * from websockets.http11 import parse_headers from websockets.streams import StreamReader from .utils import GeneratorTestCase class RequestTests(GeneratorTestCase): def setUp(self): super().setUp() self.reader = StreamReader() def parse(self): return Request.parse(self.reader.read_line) def test_parse(self): # Example from the protocol overview in RFC 6455 self.reader.feed_data( b"GET /chat HTTP/1.1\r\n" b"Host: server.example.com\r\n" b"Upgrade: 
websocket\r\n" b"Connection: Upgrade\r\n" b"Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n" b"Origin: http://example.com\r\n" b"Sec-WebSocket-Protocol: chat, superchat\r\n" b"Sec-WebSocket-Version: 13\r\n" b"\r\n" ) request = self.assertGeneratorReturns(self.parse()) self.assertEqual(request.path, "/chat") self.assertEqual(request.headers["Upgrade"], "websocket") def test_parse_empty(self): self.reader.feed_eof() with self.assertRaises(EOFError) as raised: next(self.parse()) self.assertEqual( str(raised.exception), "connection closed while reading HTTP request line", ) def test_parse_invalid_request_line(self): self.reader.feed_data(b"GET /\r\n\r\n") with self.assertRaises(ValueError) as raised: next(self.parse()) self.assertEqual( str(raised.exception), "invalid HTTP request line: GET /", ) def test_parse_unsupported_protocol(self): self.reader.feed_data(b"GET /chat HTTP/1.0\r\n\r\n") with self.assertRaises(ValueError) as raised: next(self.parse()) self.assertEqual( str(raised.exception), "unsupported protocol; expected HTTP/1.1: GET /chat HTTP/1.0", ) def test_parse_unsupported_method(self): self.reader.feed_data(b"OPTIONS * HTTP/1.1\r\n\r\n") with self.assertRaises(ValueError) as raised: next(self.parse()) self.assertEqual( str(raised.exception), "unsupported HTTP method; expected GET; got OPTIONS", ) def test_parse_invalid_header(self): self.reader.feed_data(b"GET /chat HTTP/1.1\r\nOops\r\n") with self.assertRaises(ValueError) as raised: next(self.parse()) self.assertEqual( str(raised.exception), "invalid HTTP header line: Oops", ) def test_parse_body(self): self.reader.feed_data(b"GET / HTTP/1.1\r\nContent-Length: 3\r\n\r\nYo\n") with self.assertRaises(ValueError) as raised: next(self.parse()) self.assertEqual( str(raised.exception), "unsupported request body", ) def test_parse_body_with_transfer_encoding(self): self.reader.feed_data(b"GET / HTTP/1.1\r\nTransfer-Encoding: compress\r\n\r\n") with self.assertRaises(NotImplementedError) as raised: 
next(self.parse()) self.assertEqual( str(raised.exception), "transfer codings aren't supported", ) def test_serialize(self): # Example from the protocol overview in RFC 6455 request = Request( "/chat", Headers( [ ("Host", "server.example.com"), ("Upgrade", "websocket"), ("Connection", "Upgrade"), ("Sec-WebSocket-Key", "dGhlIHNhbXBsZSBub25jZQ=="), ("Origin", "http://example.com"), ("Sec-WebSocket-Protocol", "chat, superchat"), ("Sec-WebSocket-Version", "13"), ] ), ) self.assertEqual( request.serialize(), b"GET /chat HTTP/1.1\r\n" b"Host: server.example.com\r\n" b"Upgrade: websocket\r\n" b"Connection: Upgrade\r\n" b"Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n" b"Origin: http://example.com\r\n" b"Sec-WebSocket-Protocol: chat, superchat\r\n" b"Sec-WebSocket-Version: 13\r\n" b"\r\n", ) class ResponseTests(GeneratorTestCase): def setUp(self): super().setUp() self.reader = StreamReader() def parse(self, **kwargs): return Response.parse( self.reader.read_line, self.reader.read_exact, self.reader.read_to_eof, **kwargs, ) def test_parse(self): # Example from the protocol overview in RFC 6455 self.reader.feed_data( b"HTTP/1.1 101 Switching Protocols\r\n" b"Upgrade: websocket\r\n" b"Connection: Upgrade\r\n" b"Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\n" b"Sec-WebSocket-Protocol: chat\r\n" b"\r\n" ) response = self.assertGeneratorReturns(self.parse()) self.assertEqual(response.status_code, 101) self.assertEqual(response.reason_phrase, "Switching Protocols") self.assertEqual(response.headers["Upgrade"], "websocket") self.assertEqual(response.body, b"") def test_parse_empty(self): self.reader.feed_eof() with self.assertRaises(EOFError) as raised: next(self.parse()) self.assertEqual( str(raised.exception), "connection closed while reading HTTP status line", ) def test_parse_invalid_status_line(self): self.reader.feed_data(b"Hello!\r\n") with self.assertRaises(ValueError) as raised: next(self.parse()) self.assertEqual( str(raised.exception), "invalid HTTP status 
line: Hello!", ) def test_parse_unsupported_protocol(self): self.reader.feed_data(b"HTTP/1.0 400 Bad Request\r\n\r\n") with self.assertRaises(ValueError) as raised: next(self.parse()) self.assertEqual( str(raised.exception), "unsupported protocol; expected HTTP/1.1: HTTP/1.0 400 Bad Request", ) def test_parse_non_integer_status(self): self.reader.feed_data(b"HTTP/1.1 OMG WTF\r\n\r\n") with self.assertRaises(ValueError) as raised: next(self.parse()) self.assertEqual( str(raised.exception), "invalid status code; expected integer; got OMG", ) def test_parse_non_three_digit_status(self): self.reader.feed_data(b"HTTP/1.1 007 My name is Bond\r\n\r\n") with self.assertRaises(ValueError) as raised: next(self.parse()) self.assertEqual( str(raised.exception), "invalid status code; expected 100–599; got 007" ) def test_parse_invalid_reason(self): self.reader.feed_data(b"HTTP/1.1 200 \x7f\r\n\r\n") with self.assertRaises(ValueError) as raised: next(self.parse()) self.assertEqual( str(raised.exception), "invalid HTTP reason phrase: \x7f", ) def test_parse_invalid_header(self): self.reader.feed_data(b"HTTP/1.1 500 Internal Server Error\r\nOops\r\n") with self.assertRaises(ValueError) as raised: next(self.parse()) self.assertEqual( str(raised.exception), "invalid HTTP header line: Oops", ) def test_parse_body(self): self.reader.feed_data(b"HTTP/1.1 200 OK\r\n\r\nHello world!\n") gen = self.parse() self.assertGeneratorRunning(gen) self.reader.feed_eof() response = self.assertGeneratorReturns(gen) self.assertEqual(response.body, b"Hello world!\n") def test_parse_body_too_large(self): self.reader.feed_data(b"HTTP/1.1 200 OK\r\n\r\n" + b"a" * 1048577) with self.assertRaises(SecurityError) as raised: next(self.parse()) self.assertEqual( str(raised.exception), "body too large: over 1048576 bytes", ) def test_parse_body_with_content_length(self): self.reader.feed_data( b"HTTP/1.1 200 OK\r\nContent-Length: 13\r\n\r\nHello world!\n" ) response = self.assertGeneratorReturns(self.parse()) 
self.assertEqual(response.body, b"Hello world!\n") def test_parse_body_with_content_length_and_body_too_large(self): self.reader.feed_data(b"HTTP/1.1 200 OK\r\nContent-Length: 1048577\r\n\r\n") with self.assertRaises(SecurityError) as raised: next(self.parse()) self.assertEqual( str(raised.exception), "body too large: 1048577 bytes", ) def test_parse_body_with_content_length_and_body_way_too_large(self): self.reader.feed_data( b"HTTP/1.1 200 OK\r\nContent-Length: 1234567890123456789\r\n\r\n" ) with self.assertRaises(SecurityError) as raised: next(self.parse()) self.assertEqual( str(raised.exception), "body too large: 1234567890123456789 bytes", ) def test_parse_body_with_chunked_transfer_encoding(self): self.reader.feed_data( b"HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n" b"6\r\nHello \r\n7\r\nworld!\n\r\n0\r\n\r\n" ) response = self.assertGeneratorReturns(self.parse()) self.assertEqual(response.body, b"Hello world!\n") def test_parse_body_with_chunked_transfer_encoding_and_chunk_without_crlf(self): self.reader.feed_data( b"HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n" b"6\r\nHello 7\r\nworld!\n0\r\n" ) with self.assertRaises(ValueError) as raised: next(self.parse()) self.assertEqual( str(raised.exception), "chunk without CRLF", ) def test_parse_body_with_chunked_transfer_encoding_and_chunk_too_large(self): self.reader.feed_data( b"HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n" b"100000\r\n" + b"a" * 1048576 + b"\r\n1\r\na\r\n0\r\n\r\n" ) with self.assertRaises(SecurityError) as raised: next(self.parse()) self.assertEqual( str(raised.exception), "chunk too large: 1 bytes after 1048576 bytes", ) def test_parse_body_with_chunked_transfer_encoding_and_chunk_way_too_large(self): self.reader.feed_data( b"HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n" b"1234567890ABCDEF\r\n\r\n" ) with self.assertRaises(SecurityError) as raised: next(self.parse()) self.assertEqual( str(raised.exception), "chunk too large: 0x1234567890ABCDEF bytes", ) 
def test_parse_body_with_unsupported_transfer_encoding(self): self.reader.feed_data(b"HTTP/1.1 200 OK\r\nTransfer-Encoding: compress\r\n\r\n") with self.assertRaises(NotImplementedError) as raised: next(self.parse()) self.assertEqual( str(raised.exception), "transfer coding compress isn't supported", ) def test_parse_body_no_content(self): self.reader.feed_data(b"HTTP/1.1 204 No Content\r\n\r\n") response = self.assertGeneratorReturns(self.parse()) self.assertEqual(response.body, b"") def test_parse_body_not_modified(self): self.reader.feed_data(b"HTTP/1.1 304 Not Modified\r\n\r\n") response = self.assertGeneratorReturns(self.parse()) self.assertEqual(response.body, b"") def test_parse_without_body(self): self.reader.feed_data(b"HTTP/1.1 200 Connection Established\r\n\r\n") response = self.assertGeneratorReturns(self.parse(include_body=False)) self.assertEqual(response.body, b"") def test_serialize(self): # Example from the protocol overview in RFC 6455 response = Response( 101, "Switching Protocols", Headers( [ ("Upgrade", "websocket"), ("Connection", "Upgrade"), ("Sec-WebSocket-Accept", "s3pPLMBiTxaQ9kYGzzhZRbK+xOo="), ("Sec-WebSocket-Protocol", "chat"), ] ), ) self.assertEqual( response.serialize(), b"HTTP/1.1 101 Switching Protocols\r\n" b"Upgrade: websocket\r\n" b"Connection: Upgrade\r\n" b"Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\n" b"Sec-WebSocket-Protocol: chat\r\n" b"\r\n", ) def test_serialize_with_body(self): response = Response( 200, "OK", Headers([("Content-Length", "13"), ("Content-Type", "text/plain")]), b"Hello world!\n", ) self.assertEqual( response.serialize(), b"HTTP/1.1 200 OK\r\n" b"Content-Length: 13\r\n" b"Content-Type: text/plain\r\n" b"\r\n" b"Hello world!\n", ) class HeadersTests(GeneratorTestCase): def setUp(self): super().setUp() self.reader = StreamReader() def parse_headers(self): return parse_headers(self.reader.read_line) def test_parse_invalid_name(self): self.reader.feed_data(b"foo bar: baz qux\r\n\r\n") with 
self.assertRaises(ValueError): next(self.parse_headers()) def test_parse_invalid_value(self): self.reader.feed_data(b"foo: \x00\x00\x0f\r\n\r\n") with self.assertRaises(ValueError): next(self.parse_headers()) def test_parse_too_long_value(self): self.reader.feed_data(b"foo: bar\r\n" * 129 + b"\r\n") with self.assertRaises(SecurityError): next(self.parse_headers()) def test_parse_too_long_line(self): # Header line contains 5 + 8186 + 2 = 8193 bytes. self.reader.feed_data(b"foo: " + b"a" * 8186 + b"\r\n\r\n") with self.assertRaises(SecurityError): next(self.parse_headers()) def test_parse_invalid_line_ending(self): self.reader.feed_data(b"foo: bar\n\n") with self.assertRaises(EOFError): next(self.parse_headers()) websockets-15.0.1/tests/test_imports.py000066400000000000000000000032651476212450300202310ustar00rootroot00000000000000import types import unittest import warnings from websockets.imports import * foo = object() bar = object() class ImportsTests(unittest.TestCase): def setUp(self): self.mod = types.ModuleType("tests.test_imports.test_alias") self.mod.__package__ = self.mod.__name__ def test_get_alias(self): lazy_import( vars(self.mod), aliases={"foo": "...test_imports"}, ) self.assertEqual(self.mod.foo, foo) def test_get_deprecated_alias(self): lazy_import( vars(self.mod), deprecated_aliases={"bar": "...test_imports"}, ) with warnings.catch_warnings(record=True) as recorded_warnings: warnings.simplefilter("always") self.assertEqual(self.mod.bar, bar) self.assertEqual(len(recorded_warnings), 1) warning = recorded_warnings[0].message self.assertEqual( str(warning), "tests.test_imports.test_alias.bar is deprecated" ) self.assertEqual(type(warning), DeprecationWarning) def test_dir(self): lazy_import( vars(self.mod), aliases={"foo": "...test_imports"}, deprecated_aliases={"bar": "...test_imports"}, ) self.assertEqual( [item for item in dir(self.mod) if not item[:2] == item[-2:] == "__"], ["bar", "foo"], ) def test_attribute_error(self): 
lazy_import(vars(self.mod)) with self.assertRaises(AttributeError) as raised: self.mod.foo self.assertEqual( str(raised.exception), "module 'tests.test_imports.test_alias' has no attribute 'foo'", ) websockets-15.0.1/tests/test_localhost.cnf000066400000000000000000000004331476212450300206340ustar00rootroot00000000000000[ req ] default_md = sha256 encrypt_key = no prompt = no distinguished_name = dn x509_extensions = ext [ dn ] C = "FR" L = "Paris" O = "Aymeric Augustin" CN = "localhost" [ ext ] subjectAltName = @san [ san ] DNS.1 = localhost DNS.2 = overridden IP.3 = 127.0.0.1 IP.4 = ::1 websockets-15.0.1/tests/test_localhost.pem000066400000000000000000000055541476212450300206600ustar00rootroot00000000000000-----BEGIN PRIVATE KEY----- MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQDYOOQyq8yYtn5x K3yRborFxTFse16JIVb4x/ZhZgGm49eARCi09fmczQxJdQpHz81Ij6z0xi7AUYH7 9wS8T0Lh3uGFDDS1GzITUVPIqSUi0xim2T6XPzXFVQYI1D/OjUxlHm+3/up+WwbL sBgBO/lDmzoa3ZN7kt9HQoGc/14oQz1Qsv1QTDQs69r+o7mmBJr/hf/g7S0Csyy3 iC6aaq+yCUyzDbjXceTI7WJqbTGNnK0/DjdFD/SJS/uSDNEg0AH53eqcCSjm+Ei/ UF8qR5Pu4sSsNwToOW2MVgjtHFazc+kG3rzD6+3Dp+t6x6uI/npyuudOMCmOtd6z kX0UPQaNAgMBAAECggEAS4eMBztGC+5rusKTEAZKSY15l0h9HG/d/qdzJFDKsO6T /8VPZu8pk6F48kwFHFK1hexSYWq9OAcA3fBK4jDZzybZJm2+F6l5U5AsMUMMqt6M lPP8Tj8RXG433muuIkvvbL82DVLpvNu1Qv+vUvcNOpWFtY7DDv6eKjlMJ3h4/pzh 89MNt26VMCYOlq1NSjuZBzFohL2u9nsFehlOpcVsqNfNfcYCq9+5yoH8fWJP90Op hqhvqUoGLN7DRKV1f+AWHSA4nmGgvVviV5PQgMhtk5exlN7kG+rDc3LbzhefS1Sp Tat1qIgm8fK2n+Q/obQPjHOGOGuvE5cIF7E275ZKgQKBgQDt87BqALKWnbkbQnb7 GS1h6LRcKyZhFbxnO2qbviBWSo15LEF8jPGV33Dj+T56hqufa/rUkbZiUbIR9yOX dnOwpAVTo+ObAwZfGfHvrnufiIbHFqJBumaYLqjRZ7AC0QtS3G+kjS9dbllrr7ok fO4JdfKRXzBJKrkQdCn8hR22rQKBgQDon0b49Dxs1EfdSDbDode2TSwE83fI3vmR SKUkNY8ma6CRbomVRWijhBM458wJeuhpjPZOvjNMsnDzGwrtdAp2VfFlMIDnA8ZC fEWIAAH2QYKXKGmkoXOcWB2QbvbI154zCm6zFGtzvRKOCGmTXuhFajO8VPwOyJVt aSJA3bLrYQKBgQDJM2/tAfAAKRdW9GlUwqI8Ep9G+/l0yANJqtTnIemH7XwYhJJO 9YJlPszfB2aMBgliQNSUHy1/jyKpzDYdITyLlPUoFwEilnkxuud2yiuf5rpH51yF 
hU6wyWtXvXv3tbkEdH42PmdZcjBMPQeBSN2hxEi6ISncBDL9tau26PwJ9QKBgQCs cNYl2reoXTzgtpWSNDk6NL769JjJWTFcF6QD0YhKjOI8rNpkw00sWc3+EybXqDr9 c7dq6+gPZQAB1vwkxi6zRkZqIqiLl+qygnjwtkC+EhYCg7y8g8q2DUPtO7TJcb0e TQ9+xRZad8B3dZj93A8G1hF//OfU9bB/qL3xo+bsQQKBgC/9YJvgLIWA/UziLcB2 29Ai0nbPkN5df7z4PifUHHSlbQJHKak8UKbMP+8S064Ul0F7g8UCjZMk2LzSbaNY XU5+2j0sIOnGUFoSlvcpdowzYrD2LN5PkKBot7AOq/v7HlcOoR8J8RGWAMpCrHsI a/u/dlZs+/K16RcavQwx8rag -----END PRIVATE KEY----- -----BEGIN CERTIFICATE----- MIIDWTCCAkGgAwIBAgIJAOL9UKiOOxupMA0GCSqGSIb3DQEBCwUAMEwxCzAJBgNV BAYTAkZSMQ4wDAYDVQQHDAVQYXJpczEZMBcGA1UECgwQQXltZXJpYyBBdWd1c3Rp bjESMBAGA1UEAwwJbG9jYWxob3N0MCAXDTIyMTAxNTE5Mjg0MVoYDzIwNjQxMDE0 MTkyODQxWjBMMQswCQYDVQQGEwJGUjEOMAwGA1UEBwwFUGFyaXMxGTAXBgNVBAoM EEF5bWVyaWMgQXVndXN0aW4xEjAQBgNVBAMMCWxvY2FsaG9zdDCCASIwDQYJKoZI hvcNAQEBBQADggEPADCCAQoCggEBANg45DKrzJi2fnErfJFuisXFMWx7XokhVvjH 9mFmAabj14BEKLT1+ZzNDEl1CkfPzUiPrPTGLsBRgfv3BLxPQuHe4YUMNLUbMhNR U8ipJSLTGKbZPpc/NcVVBgjUP86NTGUeb7f+6n5bBsuwGAE7+UObOhrdk3uS30dC gZz/XihDPVCy/VBMNCzr2v6juaYEmv+F/+DtLQKzLLeILppqr7IJTLMNuNdx5Mjt YmptMY2crT8ON0UP9IlL+5IM0SDQAfnd6pwJKOb4SL9QXypHk+7ixKw3BOg5bYxW CO0cVrNz6QbevMPr7cOn63rHq4j+enK6504wKY613rORfRQ9Bo0CAwEAAaM8MDow OAYDVR0RBDEwL4IJbG9jYWxob3N0ggpvdmVycmlkZGVuhwR/AAABhxAAAAAAAAAA AAAAAAAAAAABMA0GCSqGSIb3DQEBCwUAA4IBAQBPNDGDdl4wsCRlDuyCHBC8o+vW Vb14thUw9Z6UrlsQRXLONxHOXbNAj1sYQACNwIWuNz36HXu5m8Xw/ID/bOhnIg+b Y6l/JU/kZQYB7SV1aR3ZdbCK0gjfkE0POBHuKOjUFIOPBCtJ4tIBUX94zlgJrR9v 2rqJC3TIYrR7pVQumHZsI5GZEMpM5NxfreWwxcgltgxmGdm7elcizHfz7k5+szwh 4eZ/rxK9bw1q8BIvVBWelRvUR55mIrCjzfZp5ZObSYQTZlW7PzXBe5Jk+1w31YHM RSBA2EpPhYlGNqPidi7bg7rnQcsc6+hE0OqzTL/hWxPm9Vbp9dj3HFTik1wa -----END CERTIFICATE----- websockets-15.0.1/tests/test_protocol.py000066400000000000000000002206321476212450300203740ustar00rootroot00000000000000import logging import unittest from unittest.mock import patch from websockets.exceptions import ( ConnectionClosedError, ConnectionClosedOK, InvalidState, PayloadTooBig, ProtocolError, ) from websockets.frames import 
( OP_BINARY, OP_CLOSE, OP_CONT, OP_PING, OP_PONG, OP_TEXT, Close, CloseCode, Frame, ) from websockets.protocol import * from websockets.protocol import CLIENT, CLOSED, CLOSING, CONNECTING, SERVER from .extensions.utils import Rsv2Extension from .test_frames import FramesTestCase class ProtocolTestCase(FramesTestCase): def assertFrameSent(self, connection, frame, eof=False): """ Outgoing data for ``connection`` contains the given frame. ``frame`` may be ``None`` if no frame is expected. When ``eof`` is ``True``, the end of the stream is also expected. """ frames_sent = [ ( None if write is SEND_EOF else self.parse( write, mask=connection.side is CLIENT, extensions=connection.extensions, ) ) for write in connection.data_to_send() ] frames_expected = [] if frame is None else [frame] if eof: frames_expected += [None] self.assertEqual(frames_sent, frames_expected) def assertFrameReceived(self, connection, frame): """ Incoming data for ``connection`` contains the given frame. ``frame`` may be ``None`` if no frame is expected. """ frames_received = connection.events_received() frames_expected = [] if frame is None else [frame] self.assertEqual(frames_received, frames_expected) def assertConnectionClosing(self, connection, code=None, reason=""): """ Incoming data caused the "Start the WebSocket Closing Handshake" process. """ close_frame = Frame( OP_CLOSE, b"" if code is None else Close(code, reason).serialize(), ) # A close frame was received. self.assertFrameReceived(connection, close_frame) # A close frame and possibly the end of stream were sent. self.assertFrameSent(connection, close_frame, eof=connection.side is SERVER) def assertConnectionFailing(self, connection, code=None, reason=""): """ Incoming data caused the "Fail the WebSocket Connection" process. """ close_frame = Frame( OP_CLOSE, b"" if code is None else Close(code, reason).serialize(), ) # No frame was received. 
self.assertFrameReceived(connection, None) # A close frame and possibly the end of stream were sent. self.assertFrameSent(connection, close_frame, eof=connection.side is SERVER) class MaskingTests(ProtocolTestCase): """ Test frame masking. 5.1. Overview """ unmasked_text_frame_date = b"\x81\x04Spam" masked_text_frame_data = b"\x81\x84\x00\xff\x00\xff\x53\x8f\x61\x92" def test_client_sends_masked_frame(self): client = Protocol(CLIENT) with patch("secrets.token_bytes", return_value=b"\x00\xff\x00\xff"): client.send_text(b"Spam", True) self.assertEqual(client.data_to_send(), [self.masked_text_frame_data]) def test_server_sends_unmasked_frame(self): server = Protocol(SERVER) server.send_text(b"Spam", True) self.assertEqual(server.data_to_send(), [self.unmasked_text_frame_date]) def test_client_receives_unmasked_frame(self): client = Protocol(CLIENT) client.receive_data(self.unmasked_text_frame_date) self.assertFrameReceived( client, Frame(OP_TEXT, b"Spam"), ) def test_server_receives_masked_frame(self): server = Protocol(SERVER) server.receive_data(self.masked_text_frame_data) self.assertFrameReceived( server, Frame(OP_TEXT, b"Spam"), ) def test_client_receives_masked_frame(self): client = Protocol(CLIENT) client.receive_data(self.masked_text_frame_data) self.assertIsInstance(client.parser_exc, ProtocolError) self.assertEqual(str(client.parser_exc), "incorrect masking") self.assertConnectionFailing( client, CloseCode.PROTOCOL_ERROR, "incorrect masking" ) def test_server_receives_unmasked_frame(self): server = Protocol(SERVER) server.receive_data(self.unmasked_text_frame_date) self.assertIsInstance(server.parser_exc, ProtocolError) self.assertEqual(str(server.parser_exc), "incorrect masking") self.assertConnectionFailing( server, CloseCode.PROTOCOL_ERROR, "incorrect masking" ) class ContinuationTests(ProtocolTestCase): """ Test continuation frames without text or binary frames. 
""" def test_client_sends_unexpected_continuation(self): client = Protocol(CLIENT) with self.assertRaises(ProtocolError) as raised: client.send_continuation(b"", fin=False) self.assertEqual(str(raised.exception), "unexpected continuation frame") def test_server_sends_unexpected_continuation(self): server = Protocol(SERVER) with self.assertRaises(ProtocolError) as raised: server.send_continuation(b"", fin=False) self.assertEqual(str(raised.exception), "unexpected continuation frame") def test_client_receives_unexpected_continuation(self): client = Protocol(CLIENT) client.receive_data(b"\x00\x00") self.assertIsInstance(client.parser_exc, ProtocolError) self.assertEqual(str(client.parser_exc), "unexpected continuation frame") self.assertConnectionFailing( client, CloseCode.PROTOCOL_ERROR, "unexpected continuation frame" ) def test_server_receives_unexpected_continuation(self): server = Protocol(SERVER) server.receive_data(b"\x00\x80\x00\x00\x00\x00") self.assertIsInstance(server.parser_exc, ProtocolError) self.assertEqual(str(server.parser_exc), "unexpected continuation frame") self.assertConnectionFailing( server, CloseCode.PROTOCOL_ERROR, "unexpected continuation frame" ) def test_client_sends_continuation_after_sending_close(self): client = Protocol(CLIENT) # Since it isn't possible to send a close frame in a fragmented # message (see test_client_send_close_in_fragmented_message), in fact, # this is the same test as test_client_sends_unexpected_continuation. 
with patch("secrets.token_bytes", return_value=b"\x00\x00\x00\x00"): client.send_close(CloseCode.GOING_AWAY) self.assertEqual(client.data_to_send(), [b"\x88\x82\x00\x00\x00\x00\x03\xe9"]) with self.assertRaises(ProtocolError) as raised: client.send_continuation(b"", fin=False) self.assertEqual(str(raised.exception), "unexpected continuation frame") def test_server_sends_continuation_after_sending_close(self): # Since it isn't possible to send a close frame in a fragmented # message (see test_server_send_close_in_fragmented_message), in fact, # this is the same test as test_server_sends_unexpected_continuation. server = Protocol(SERVER) server.send_close(CloseCode.NORMAL_CLOSURE) self.assertEqual(server.data_to_send(), [b"\x88\x02\x03\xe8"]) with self.assertRaises(ProtocolError) as raised: server.send_continuation(b"", fin=False) self.assertEqual(str(raised.exception), "unexpected continuation frame") def test_client_receives_continuation_after_receiving_close(self): client = Protocol(CLIENT) client.receive_data(b"\x88\x02\x03\xe8") self.assertConnectionClosing(client, CloseCode.NORMAL_CLOSURE) client.receive_data(b"\x00\x00") self.assertFrameReceived(client, None) self.assertFrameSent(client, None) def test_server_receives_continuation_after_receiving_close(self): server = Protocol(SERVER) server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xe9") self.assertConnectionClosing(server, CloseCode.GOING_AWAY) server.receive_data(b"\x00\x80\x00\xff\x00\xff") self.assertFrameReceived(server, None) self.assertFrameSent(server, None) class TextTests(ProtocolTestCase): """ Test text frames and continuation frames. 
""" def test_client_sends_text(self): client = Protocol(CLIENT) with patch("secrets.token_bytes", return_value=b"\x00\x00\x00\x00"): client.send_text("😀".encode()) self.assertEqual( client.data_to_send(), [b"\x81\x84\x00\x00\x00\x00\xf0\x9f\x98\x80"] ) def test_server_sends_text(self): server = Protocol(SERVER) server.send_text("😀".encode()) self.assertEqual(server.data_to_send(), [b"\x81\x04\xf0\x9f\x98\x80"]) def test_client_receives_text(self): client = Protocol(CLIENT) client.receive_data(b"\x81\x04\xf0\x9f\x98\x80") self.assertFrameReceived( client, Frame(OP_TEXT, "😀".encode()), ) def test_server_receives_text(self): server = Protocol(SERVER) server.receive_data(b"\x81\x84\x00\x00\x00\x00\xf0\x9f\x98\x80") self.assertFrameReceived( server, Frame(OP_TEXT, "😀".encode()), ) def test_client_receives_text_over_size_limit(self): client = Protocol(CLIENT, max_size=3) client.receive_data(b"\x81\x04\xf0\x9f\x98\x80") self.assertIsInstance(client.parser_exc, PayloadTooBig) self.assertEqual( str(client.parser_exc), "frame with 4 bytes exceeds limit of 3 bytes", ) self.assertConnectionFailing( client, CloseCode.MESSAGE_TOO_BIG, "frame with 4 bytes exceeds limit of 3 bytes", ) def test_server_receives_text_over_size_limit(self): server = Protocol(SERVER, max_size=3) server.receive_data(b"\x81\x84\x00\x00\x00\x00\xf0\x9f\x98\x80") self.assertIsInstance(server.parser_exc, PayloadTooBig) self.assertEqual( str(server.parser_exc), "frame with 4 bytes exceeds limit of 3 bytes", ) self.assertConnectionFailing( server, CloseCode.MESSAGE_TOO_BIG, "frame with 4 bytes exceeds limit of 3 bytes", ) def test_client_receives_text_without_size_limit(self): client = Protocol(CLIENT, max_size=None) client.receive_data(b"\x81\x04\xf0\x9f\x98\x80") self.assertFrameReceived( client, Frame(OP_TEXT, "😀".encode()), ) def test_server_receives_text_without_size_limit(self): server = Protocol(SERVER, max_size=None) server.receive_data(b"\x81\x84\x00\x00\x00\x00\xf0\x9f\x98\x80") 
self.assertFrameReceived( server, Frame(OP_TEXT, "😀".encode()), ) def test_client_sends_fragmented_text(self): client = Protocol(CLIENT) with patch("secrets.token_bytes", return_value=b"\x00\x00\x00\x00"): client.send_text("😀".encode()[:2], fin=False) self.assertEqual(client.data_to_send(), [b"\x01\x82\x00\x00\x00\x00\xf0\x9f"]) with patch("secrets.token_bytes", return_value=b"\x00\x00\x00\x00"): client.send_continuation("😀😀".encode()[2:6], fin=False) self.assertEqual( client.data_to_send(), [b"\x00\x84\x00\x00\x00\x00\x98\x80\xf0\x9f"] ) with patch("secrets.token_bytes", return_value=b"\x00\x00\x00\x00"): client.send_continuation("😀".encode()[2:], fin=True) self.assertEqual(client.data_to_send(), [b"\x80\x82\x00\x00\x00\x00\x98\x80"]) def test_server_sends_fragmented_text(self): server = Protocol(SERVER) server.send_text("😀".encode()[:2], fin=False) self.assertEqual(server.data_to_send(), [b"\x01\x02\xf0\x9f"]) server.send_continuation("😀😀".encode()[2:6], fin=False) self.assertEqual(server.data_to_send(), [b"\x00\x04\x98\x80\xf0\x9f"]) server.send_continuation("😀".encode()[2:], fin=True) self.assertEqual(server.data_to_send(), [b"\x80\x02\x98\x80"]) def test_client_receives_fragmented_text(self): client = Protocol(CLIENT) client.receive_data(b"\x01\x02\xf0\x9f") self.assertFrameReceived( client, Frame(OP_TEXT, "😀".encode()[:2], fin=False), ) client.receive_data(b"\x00\x04\x98\x80\xf0\x9f") self.assertFrameReceived( client, Frame(OP_CONT, "😀😀".encode()[2:6], fin=False), ) client.receive_data(b"\x80\x02\x98\x80") self.assertFrameReceived( client, Frame(OP_CONT, "😀".encode()[2:]), ) def test_server_receives_fragmented_text(self): server = Protocol(SERVER) server.receive_data(b"\x01\x82\x00\x00\x00\x00\xf0\x9f") self.assertFrameReceived( server, Frame(OP_TEXT, "😀".encode()[:2], fin=False), ) server.receive_data(b"\x00\x84\x00\x00\x00\x00\x98\x80\xf0\x9f") self.assertFrameReceived( server, Frame(OP_CONT, "😀😀".encode()[2:6], fin=False), ) 
server.receive_data(b"\x80\x82\x00\x00\x00\x00\x98\x80") self.assertFrameReceived( server, Frame(OP_CONT, "😀".encode()[2:]), ) def test_client_receives_fragmented_text_over_size_limit(self): client = Protocol(CLIENT, max_size=3) client.receive_data(b"\x01\x02\xf0\x9f") self.assertFrameReceived( client, Frame(OP_TEXT, "😀".encode()[:2], fin=False), ) client.receive_data(b"\x80\x02\x98\x80") self.assertIsInstance(client.parser_exc, PayloadTooBig) self.assertEqual( str(client.parser_exc), "frame with 2 bytes after reading 2 bytes exceeds limit of 3 bytes", ) self.assertConnectionFailing( client, CloseCode.MESSAGE_TOO_BIG, "frame with 2 bytes after reading 2 bytes exceeds limit of 3 bytes", ) def test_server_receives_fragmented_text_over_size_limit(self): server = Protocol(SERVER, max_size=3) server.receive_data(b"\x01\x82\x00\x00\x00\x00\xf0\x9f") self.assertFrameReceived( server, Frame(OP_TEXT, "😀".encode()[:2], fin=False), ) server.receive_data(b"\x80\x82\x00\x00\x00\x00\x98\x80") self.assertIsInstance(server.parser_exc, PayloadTooBig) self.assertEqual( str(server.parser_exc), "frame with 2 bytes after reading 2 bytes exceeds limit of 3 bytes", ) self.assertConnectionFailing( server, CloseCode.MESSAGE_TOO_BIG, "frame with 2 bytes after reading 2 bytes exceeds limit of 3 bytes", ) def test_client_receives_fragmented_text_without_size_limit(self): client = Protocol(CLIENT, max_size=None) client.receive_data(b"\x01\x02\xf0\x9f") self.assertFrameReceived( client, Frame(OP_TEXT, "😀".encode()[:2], fin=False), ) client.receive_data(b"\x00\x04\x98\x80\xf0\x9f") self.assertFrameReceived( client, Frame(OP_CONT, "😀😀".encode()[2:6], fin=False), ) client.receive_data(b"\x80\x02\x98\x80") self.assertFrameReceived( client, Frame(OP_CONT, "😀".encode()[2:]), ) def test_server_receives_fragmented_text_without_size_limit(self): server = Protocol(SERVER, max_size=None) server.receive_data(b"\x01\x82\x00\x00\x00\x00\xf0\x9f") self.assertFrameReceived( server, Frame(OP_TEXT, 
"😀".encode()[:2], fin=False), ) server.receive_data(b"\x00\x84\x00\x00\x00\x00\x98\x80\xf0\x9f") self.assertFrameReceived( server, Frame(OP_CONT, "😀😀".encode()[2:6], fin=False), ) server.receive_data(b"\x80\x82\x00\x00\x00\x00\x98\x80") self.assertFrameReceived( server, Frame(OP_CONT, "😀".encode()[2:]), ) def test_client_sends_unexpected_text(self): client = Protocol(CLIENT) client.send_text(b"", fin=False) with self.assertRaises(ProtocolError) as raised: client.send_text(b"", fin=False) self.assertEqual(str(raised.exception), "expected a continuation frame") def test_server_sends_unexpected_text(self): server = Protocol(SERVER) server.send_text(b"", fin=False) with self.assertRaises(ProtocolError) as raised: server.send_text(b"", fin=False) self.assertEqual(str(raised.exception), "expected a continuation frame") def test_client_receives_unexpected_text(self): client = Protocol(CLIENT) client.receive_data(b"\x01\x00") self.assertFrameReceived( client, Frame(OP_TEXT, b"", fin=False), ) client.receive_data(b"\x01\x00") self.assertIsInstance(client.parser_exc, ProtocolError) self.assertEqual(str(client.parser_exc), "expected a continuation frame") self.assertConnectionFailing( client, CloseCode.PROTOCOL_ERROR, "expected a continuation frame" ) def test_server_receives_unexpected_text(self): server = Protocol(SERVER) server.receive_data(b"\x01\x80\x00\x00\x00\x00") self.assertFrameReceived( server, Frame(OP_TEXT, b"", fin=False), ) server.receive_data(b"\x01\x80\x00\x00\x00\x00") self.assertIsInstance(server.parser_exc, ProtocolError) self.assertEqual(str(server.parser_exc), "expected a continuation frame") self.assertConnectionFailing( server, CloseCode.PROTOCOL_ERROR, "expected a continuation frame" ) def test_client_sends_text_after_sending_close(self): client = Protocol(CLIENT) with patch("secrets.token_bytes", return_value=b"\x00\x00\x00\x00"): client.send_close(CloseCode.GOING_AWAY) self.assertEqual(client.data_to_send(), [b"\x88\x82\x00\x00\x00\x00\x03\xe9"]) 
with self.assertRaises(InvalidState) as raised: client.send_text(b"") self.assertEqual(str(raised.exception), "connection is closing") def test_server_sends_text_after_sending_close(self): server = Protocol(SERVER) server.send_close(CloseCode.NORMAL_CLOSURE) self.assertEqual(server.data_to_send(), [b"\x88\x02\x03\xe8"]) with self.assertRaises(InvalidState) as raised: server.send_text(b"") self.assertEqual(str(raised.exception), "connection is closing") def test_client_receives_text_after_receiving_close(self): client = Protocol(CLIENT) client.receive_data(b"\x88\x02\x03\xe8") self.assertConnectionClosing(client, CloseCode.NORMAL_CLOSURE) client.receive_data(b"\x81\x00") self.assertFrameReceived(client, None) self.assertFrameSent(client, None) def test_server_receives_text_after_receiving_close(self): server = Protocol(SERVER) server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xe9") self.assertConnectionClosing(server, CloseCode.GOING_AWAY) server.receive_data(b"\x81\x80\x00\xff\x00\xff") self.assertFrameReceived(server, None) self.assertFrameSent(server, None) class BinaryTests(ProtocolTestCase): """ Test binary frames and continuation frames. 
""" def test_client_sends_binary(self): client = Protocol(CLIENT) with patch("secrets.token_bytes", return_value=b"\x00\x00\x00\x00"): client.send_binary(b"\x01\x02\xfe\xff") self.assertEqual( client.data_to_send(), [b"\x82\x84\x00\x00\x00\x00\x01\x02\xfe\xff"] ) def test_server_sends_binary(self): server = Protocol(SERVER) server.send_binary(b"\x01\x02\xfe\xff") self.assertEqual(server.data_to_send(), [b"\x82\x04\x01\x02\xfe\xff"]) def test_client_receives_binary(self): client = Protocol(CLIENT) client.receive_data(b"\x82\x04\x01\x02\xfe\xff") self.assertFrameReceived( client, Frame(OP_BINARY, b"\x01\x02\xfe\xff"), ) def test_server_receives_binary(self): server = Protocol(SERVER) server.receive_data(b"\x82\x84\x00\x00\x00\x00\x01\x02\xfe\xff") self.assertFrameReceived( server, Frame(OP_BINARY, b"\x01\x02\xfe\xff"), ) def test_client_receives_binary_over_size_limit(self): client = Protocol(CLIENT, max_size=3) client.receive_data(b"\x82\x04\x01\x02\xfe\xff") self.assertIsInstance(client.parser_exc, PayloadTooBig) self.assertEqual( str(client.parser_exc), "frame with 4 bytes exceeds limit of 3 bytes", ) self.assertConnectionFailing( client, CloseCode.MESSAGE_TOO_BIG, "frame with 4 bytes exceeds limit of 3 bytes", ) def test_server_receives_binary_over_size_limit(self): server = Protocol(SERVER, max_size=3) server.receive_data(b"\x82\x84\x00\x00\x00\x00\x01\x02\xfe\xff") self.assertIsInstance(server.parser_exc, PayloadTooBig) self.assertEqual( str(server.parser_exc), "frame with 4 bytes exceeds limit of 3 bytes", ) self.assertConnectionFailing( server, CloseCode.MESSAGE_TOO_BIG, "frame with 4 bytes exceeds limit of 3 bytes", ) def test_client_sends_fragmented_binary(self): client = Protocol(CLIENT) with patch("secrets.token_bytes", return_value=b"\x00\x00\x00\x00"): client.send_binary(b"\x01\x02", fin=False) self.assertEqual(client.data_to_send(), [b"\x02\x82\x00\x00\x00\x00\x01\x02"]) with patch("secrets.token_bytes", return_value=b"\x00\x00\x00\x00"): 
client.send_continuation(b"\xee\xff\x01\x02", fin=False) self.assertEqual( client.data_to_send(), [b"\x00\x84\x00\x00\x00\x00\xee\xff\x01\x02"] ) with patch("secrets.token_bytes", return_value=b"\x00\x00\x00\x00"): client.send_continuation(b"\xee\xff", fin=True) self.assertEqual(client.data_to_send(), [b"\x80\x82\x00\x00\x00\x00\xee\xff"]) def test_server_sends_fragmented_binary(self): server = Protocol(SERVER) server.send_binary(b"\x01\x02", fin=False) self.assertEqual(server.data_to_send(), [b"\x02\x02\x01\x02"]) server.send_continuation(b"\xee\xff\x01\x02", fin=False) self.assertEqual(server.data_to_send(), [b"\x00\x04\xee\xff\x01\x02"]) server.send_continuation(b"\xee\xff", fin=True) self.assertEqual(server.data_to_send(), [b"\x80\x02\xee\xff"]) def test_client_receives_fragmented_binary(self): client = Protocol(CLIENT) client.receive_data(b"\x02\x02\x01\x02") self.assertFrameReceived( client, Frame(OP_BINARY, b"\x01\x02", fin=False), ) client.receive_data(b"\x00\x04\xfe\xff\x01\x02") self.assertFrameReceived( client, Frame(OP_CONT, b"\xfe\xff\x01\x02", fin=False), ) client.receive_data(b"\x80\x02\xfe\xff") self.assertFrameReceived( client, Frame(OP_CONT, b"\xfe\xff"), ) def test_server_receives_fragmented_binary(self): server = Protocol(SERVER) server.receive_data(b"\x02\x82\x00\x00\x00\x00\x01\x02") self.assertFrameReceived( server, Frame(OP_BINARY, b"\x01\x02", fin=False), ) server.receive_data(b"\x00\x84\x00\x00\x00\x00\xee\xff\x01\x02") self.assertFrameReceived( server, Frame(OP_CONT, b"\xee\xff\x01\x02", fin=False), ) server.receive_data(b"\x80\x82\x00\x00\x00\x00\xfe\xff") self.assertFrameReceived( server, Frame(OP_CONT, b"\xfe\xff"), ) def test_client_receives_fragmented_binary_over_size_limit(self): client = Protocol(CLIENT, max_size=3) client.receive_data(b"\x02\x02\x01\x02") self.assertFrameReceived( client, Frame(OP_BINARY, b"\x01\x02", fin=False), ) client.receive_data(b"\x80\x02\xfe\xff") self.assertIsInstance(client.parser_exc, PayloadTooBig) 
self.assertEqual( str(client.parser_exc), "frame with 2 bytes after reading 2 bytes exceeds limit of 3 bytes", ) self.assertConnectionFailing( client, CloseCode.MESSAGE_TOO_BIG, "frame with 2 bytes after reading 2 bytes exceeds limit of 3 bytes", ) def test_server_receives_fragmented_binary_over_size_limit(self): server = Protocol(SERVER, max_size=3) server.receive_data(b"\x02\x82\x00\x00\x00\x00\x01\x02") self.assertFrameReceived( server, Frame(OP_BINARY, b"\x01\x02", fin=False), ) server.receive_data(b"\x80\x82\x00\x00\x00\x00\xfe\xff") self.assertIsInstance(server.parser_exc, PayloadTooBig) self.assertEqual( str(server.parser_exc), "frame with 2 bytes after reading 2 bytes exceeds limit of 3 bytes", ) self.assertConnectionFailing( server, CloseCode.MESSAGE_TOO_BIG, "frame with 2 bytes after reading 2 bytes exceeds limit of 3 bytes", ) def test_client_sends_unexpected_binary(self): client = Protocol(CLIENT) client.send_binary(b"", fin=False) with self.assertRaises(ProtocolError) as raised: client.send_binary(b"", fin=False) self.assertEqual(str(raised.exception), "expected a continuation frame") def test_server_sends_unexpected_binary(self): server = Protocol(SERVER) server.send_binary(b"", fin=False) with self.assertRaises(ProtocolError) as raised: server.send_binary(b"", fin=False) self.assertEqual(str(raised.exception), "expected a continuation frame") def test_client_receives_unexpected_binary(self): client = Protocol(CLIENT) client.receive_data(b"\x02\x00") self.assertFrameReceived( client, Frame(OP_BINARY, b"", fin=False), ) client.receive_data(b"\x02\x00") self.assertIsInstance(client.parser_exc, ProtocolError) self.assertEqual(str(client.parser_exc), "expected a continuation frame") self.assertConnectionFailing( client, CloseCode.PROTOCOL_ERROR, "expected a continuation frame" ) def test_server_receives_unexpected_binary(self): server = Protocol(SERVER) server.receive_data(b"\x02\x80\x00\x00\x00\x00") self.assertFrameReceived( server, Frame(OP_BINARY, 
b"", fin=False), ) server.receive_data(b"\x02\x80\x00\x00\x00\x00") self.assertIsInstance(server.parser_exc, ProtocolError) self.assertEqual(str(server.parser_exc), "expected a continuation frame") self.assertConnectionFailing( server, CloseCode.PROTOCOL_ERROR, "expected a continuation frame" ) def test_client_sends_binary_after_sending_close(self): client = Protocol(CLIENT) with patch("secrets.token_bytes", return_value=b"\x00\x00\x00\x00"): client.send_close(CloseCode.GOING_AWAY) self.assertEqual(client.data_to_send(), [b"\x88\x82\x00\x00\x00\x00\x03\xe9"]) with self.assertRaises(InvalidState) as raised: client.send_binary(b"") self.assertEqual(str(raised.exception), "connection is closing") def test_server_sends_binary_after_sending_close(self): server = Protocol(SERVER) server.send_close(CloseCode.NORMAL_CLOSURE) self.assertEqual(server.data_to_send(), [b"\x88\x02\x03\xe8"]) with self.assertRaises(InvalidState) as raised: server.send_binary(b"") self.assertEqual(str(raised.exception), "connection is closing") def test_client_receives_binary_after_receiving_close(self): client = Protocol(CLIENT) client.receive_data(b"\x88\x02\x03\xe8") self.assertConnectionClosing(client, CloseCode.NORMAL_CLOSURE) client.receive_data(b"\x82\x00") self.assertFrameReceived(client, None) self.assertFrameSent(client, None) def test_server_receives_binary_after_receiving_close(self): server = Protocol(SERVER) server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xe9") self.assertConnectionClosing(server, CloseCode.GOING_AWAY) server.receive_data(b"\x82\x80\x00\xff\x00\xff") self.assertFrameReceived(server, None) self.assertFrameSent(server, None) class CloseTests(ProtocolTestCase): """ Test close frames. See RFC 6455: 5.5.1. Close 7.1.6. The WebSocket Connection Close Reason 7.1.7. 
    Fail the WebSocket Connection

    """

    # Wire-format notes: 0x88 = FIN + close opcode; the first two payload
    # bytes are the big-endian close code (b"\x03\xe8" = 1000 NORMAL_CLOSURE,
    # b"\x03\xe9" = 1001 GOING_AWAY), followed by an optional UTF-8 reason.

    def test_close_code(self):
        client = Protocol(CLIENT)
        client.receive_data(b"\x88\x04\x03\xe8OK")
        client.receive_eof()
        self.assertEqual(client.close_code, CloseCode.NORMAL_CLOSURE)

    def test_close_reason(self):
        server = Protocol(SERVER)
        server.receive_data(b"\x88\x84\x00\x00\x00\x00\x03\xe8OK")
        server.receive_eof()
        self.assertEqual(server.close_reason, "OK")

    def test_close_code_not_provided(self):
        # An empty close frame maps to NO_STATUS_RCVD (1005).
        server = Protocol(SERVER)
        server.receive_data(b"\x88\x80\x00\x00\x00\x00")
        server.receive_eof()
        self.assertEqual(server.close_code, CloseCode.NO_STATUS_RCVD)

    def test_close_reason_not_provided(self):
        client = Protocol(CLIENT)
        client.receive_data(b"\x88\x00")
        client.receive_eof()
        self.assertEqual(client.close_reason, "")

    def test_close_code_not_available(self):
        # EOF without any close frame maps to ABNORMAL_CLOSURE (1006).
        client = Protocol(CLIENT)
        client.receive_eof()
        self.assertEqual(client.close_code, CloseCode.ABNORMAL_CLOSURE)

    def test_close_reason_not_available(self):
        server = Protocol(SERVER)
        server.receive_eof()
        self.assertEqual(server.close_reason, "")

    def test_close_code_not_available_yet(self):
        # While the connection is open, close_code / close_reason are None.
        server = Protocol(SERVER)
        self.assertIsNone(server.close_code)

    def test_close_reason_not_available_yet(self):
        client = Protocol(CLIENT)
        self.assertIsNone(client.close_reason)

    def test_client_sends_close(self):
        client = Protocol(CLIENT)
        with patch("secrets.token_bytes", return_value=b"\x3c\x3c\x3c\x3c"):
            client.send_close()
        self.assertEqual(client.data_to_send(), [b"\x88\x80\x3c\x3c\x3c\x3c"])
        self.assertIs(client.state, CLOSING)

    def test_server_sends_close(self):
        server = Protocol(SERVER)
        server.send_close()
        self.assertEqual(server.data_to_send(), [b"\x88\x00"])
        self.assertIs(server.state, CLOSING)

    def test_client_receives_close(self):
        # Receiving a close frame triggers an automatic close echo.
        client = Protocol(CLIENT)
        with patch("secrets.token_bytes", return_value=b"\x3c\x3c\x3c\x3c"):
            client.receive_data(b"\x88\x00")
        self.assertEqual(client.events_received(), [Frame(OP_CLOSE, b"")])
        self.assertEqual(client.data_to_send(), [b"\x88\x80\x3c\x3c\x3c\x3c"])
        self.assertIs(client.state, CLOSING)

    def test_server_receives_close(self):
        # The server echoes the close frame, then half-closes (empty bytes).
        server = Protocol(SERVER)
        server.receive_data(b"\x88\x80\x3c\x3c\x3c\x3c")
        self.assertEqual(server.events_received(), [Frame(OP_CLOSE, b"")])
        self.assertEqual(server.data_to_send(), [b"\x88\x00", b""])
        self.assertIs(server.state, CLOSING)

    def test_client_sends_close_then_receives_close(self):
        # Client-initiated close handshake on the client side.
        client = Protocol(CLIENT)
        client.send_close()
        self.assertFrameReceived(client, None)
        self.assertFrameSent(client, Frame(OP_CLOSE, b""))
        client.receive_data(b"\x88\x00")
        self.assertFrameReceived(client, Frame(OP_CLOSE, b""))
        self.assertFrameSent(client, None)
        client.receive_eof()
        self.assertFrameReceived(client, None)
        self.assertFrameSent(client, None, eof=True)

    def test_server_sends_close_then_receives_close(self):
        # Server-initiated close handshake on the server side.
        server = Protocol(SERVER)
        server.send_close()
        self.assertFrameReceived(server, None)
        self.assertFrameSent(server, Frame(OP_CLOSE, b""))
        server.receive_data(b"\x88\x80\x3c\x3c\x3c\x3c")
        self.assertFrameReceived(server, Frame(OP_CLOSE, b""))
        self.assertFrameSent(server, None, eof=True)
        server.receive_eof()
        self.assertFrameReceived(server, None)
        self.assertFrameSent(server, None)

    def test_client_receives_close_then_sends_close(self):
        # Server-initiated close handshake on the client side.
        client = Protocol(CLIENT)
        client.receive_data(b"\x88\x00")
        self.assertFrameReceived(client, Frame(OP_CLOSE, b""))
        self.assertFrameSent(client, Frame(OP_CLOSE, b""))
        client.receive_eof()
        self.assertFrameReceived(client, None)
        self.assertFrameSent(client, None, eof=True)

    def test_server_receives_close_then_sends_close(self):
        # Client-initiated close handshake on the server side.
        server = Protocol(SERVER)
        server.receive_data(b"\x88\x80\x3c\x3c\x3c\x3c")
        self.assertFrameReceived(server, Frame(OP_CLOSE, b""))
        self.assertFrameSent(server, Frame(OP_CLOSE, b""), eof=True)
        server.receive_eof()
        self.assertFrameReceived(server, None)
        self.assertFrameSent(server, None)

    def test_client_sends_close_with_code(self):
        client = Protocol(CLIENT)
        with patch("secrets.token_bytes", return_value=b"\x00\x00\x00\x00"):
            client.send_close(CloseCode.GOING_AWAY)
        self.assertEqual(client.data_to_send(), [b"\x88\x82\x00\x00\x00\x00\x03\xe9"])
        self.assertIs(client.state, CLOSING)

    def test_server_sends_close_with_code(self):
        server = Protocol(SERVER)
        server.send_close(CloseCode.NORMAL_CLOSURE)
        self.assertEqual(server.data_to_send(), [b"\x88\x02\x03\xe8"])
        self.assertIs(server.state, CLOSING)

    def test_client_receives_close_with_code(self):
        client = Protocol(CLIENT)
        client.receive_data(b"\x88\x02\x03\xe8")
        self.assertConnectionClosing(client, CloseCode.NORMAL_CLOSURE, "")
        self.assertIs(client.state, CLOSING)

    def test_server_receives_close_with_code(self):
        server = Protocol(SERVER)
        server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xe9")
        self.assertConnectionClosing(server, CloseCode.GOING_AWAY, "")
        self.assertIs(server.state, CLOSING)

    def test_client_sends_close_with_code_and_reason(self):
        client = Protocol(CLIENT)
        with patch("secrets.token_bytes", return_value=b"\x00\x00\x00\x00"):
            client.send_close(CloseCode.GOING_AWAY, "going away")
        self.assertEqual(
            client.data_to_send(), [b"\x88\x8c\x00\x00\x00\x00\x03\xe9going away"]
        )
        self.assertIs(client.state, CLOSING)

    def test_server_sends_close_with_code_and_reason(self):
        server = Protocol(SERVER)
        server.send_close(CloseCode.NORMAL_CLOSURE, "OK")
        self.assertEqual(server.data_to_send(), [b"\x88\x04\x03\xe8OK"])
        self.assertIs(server.state, CLOSING)

    def test_client_receives_close_with_code_and_reason(self):
        client = Protocol(CLIENT)
        client.receive_data(b"\x88\x04\x03\xe8OK")
        self.assertConnectionClosing(client, CloseCode.NORMAL_CLOSURE, "OK")
        self.assertIs(client.state, CLOSING)

    def test_server_receives_close_with_code_and_reason(self):
        server = Protocol(SERVER)
        server.receive_data(b"\x88\x8c\x00\x00\x00\x00\x03\xe9going away")
        self.assertConnectionClosing(server, CloseCode.GOING_AWAY, "going away")
        self.assertIs(server.state, CLOSING)

    def test_client_sends_close_with_reason_only(self):
        # A reason requires a code: the frame encodes the code first.
        client = Protocol(CLIENT)
        with self.assertRaises(ProtocolError) as raised:
            client.send_close(reason="going away")
        self.assertEqual(str(raised.exception), "cannot send a reason without a code")

    def test_server_sends_close_with_reason_only(self):
        server = Protocol(SERVER)
        with self.assertRaises(ProtocolError) as raised:
            server.send_close(reason="OK")
        self.assertEqual(str(raised.exception), "cannot send a reason without a code")

    def test_client_receives_close_with_truncated_code(self):
        # A 1-byte close payload cannot hold the 2-byte close code.
        client = Protocol(CLIENT)
        client.receive_data(b"\x88\x01\x03")
        self.assertIsInstance(client.parser_exc, ProtocolError)
        self.assertEqual(str(client.parser_exc), "close frame too short")
        self.assertConnectionFailing(
            client, CloseCode.PROTOCOL_ERROR, "close frame too short"
        )
        self.assertIs(client.state, CLOSING)

    def test_server_receives_close_with_truncated_code(self):
        server = Protocol(SERVER)
        server.receive_data(b"\x88\x81\x00\x00\x00\x00\x03")
        self.assertIsInstance(server.parser_exc, ProtocolError)
        self.assertEqual(str(server.parser_exc), "close frame too short")
        self.assertConnectionFailing(
            server, CloseCode.PROTOCOL_ERROR, "close frame too short"
        )
        self.assertIs(server.state, CLOSING)

    def test_client_receives_close_with_non_utf8_reason(self):
        # The close reason must be valid UTF-8; 0xff can never start a
        # UTF-8 sequence.
        client = Protocol(CLIENT)
        client.receive_data(b"\x88\x04\x03\xe8\xff\xff")
        self.assertIsInstance(client.parser_exc, UnicodeDecodeError)
        self.assertEqual(
            str(client.parser_exc),
            "'utf-8' codec can't decode byte 0xff in position 0: invalid start byte",
        )
        self.assertConnectionFailing(
            client, CloseCode.INVALID_DATA, "invalid start byte at position 0"
        )
        self.assertIs(client.state, CLOSING)

    def test_server_receives_close_with_non_utf8_reason(self):
        server = Protocol(SERVER)
        server.receive_data(b"\x88\x84\x00\x00\x00\x00\x03\xe9\xff\xff")
        self.assertIsInstance(server.parser_exc, UnicodeDecodeError)
        self.assertEqual(
            str(server.parser_exc),
            "'utf-8' codec can't decode byte 0xff in position 0: invalid start byte",
        )
        self.assertConnectionFailing(
            server, CloseCode.INVALID_DATA, "invalid start byte at position 0"
        )
        self.assertIs(server.state, CLOSING)

    def test_client_sends_close_twice(self):
        client = Protocol(CLIENT)
        with patch("secrets.token_bytes", return_value=b"\x00\x00\x00\x00"):
            client.send_close(CloseCode.GOING_AWAY)
        self.assertEqual(client.data_to_send(), [b"\x88\x82\x00\x00\x00\x00\x03\xe9"])
        with self.assertRaises(InvalidState) as raised:
            client.send_close(CloseCode.GOING_AWAY)
        self.assertEqual(str(raised.exception), "connection is closing")

    def test_server_sends_close_twice(self):
        server = Protocol(SERVER)
        server.send_close(CloseCode.NORMAL_CLOSURE)
        self.assertEqual(server.data_to_send(), [b"\x88\x02\x03\xe8"])
        with self.assertRaises(InvalidState) as raised:
            server.send_close(CloseCode.NORMAL_CLOSURE)
        self.assertEqual(str(raised.exception), "connection is closing")

    def test_client_sends_close_after_connection_is_closed(self):
        client = Protocol(CLIENT)
        client.receive_eof()
        with self.assertRaises(InvalidState) as raised:
            client.send_close(CloseCode.GOING_AWAY)
        self.assertEqual(str(raised.exception), "connection is closed")

    def test_server_sends_close_after_connection_is_closed(self):
        server = Protocol(SERVER)
        server.receive_eof()
        with self.assertRaises(InvalidState) as raised:
            server.send_close(CloseCode.NORMAL_CLOSURE)
        self.assertEqual(str(raised.exception), "connection is closed")


class PingTests(ProtocolTestCase):
    """
    Test ping.

    See 5.5.2. Ping in RFC 6455.
""" def test_client_sends_ping(self): client = Protocol(CLIENT) with patch("secrets.token_bytes", return_value=b"\x00\x44\x88\xcc"): client.send_ping(b"") self.assertEqual(client.data_to_send(), [b"\x89\x80\x00\x44\x88\xcc"]) def test_server_sends_ping(self): server = Protocol(SERVER) server.send_ping(b"") self.assertEqual(server.data_to_send(), [b"\x89\x00"]) def test_client_receives_ping(self): client = Protocol(CLIENT) client.receive_data(b"\x89\x00") self.assertFrameReceived( client, Frame(OP_PING, b""), ) self.assertFrameSent( client, Frame(OP_PONG, b""), ) def test_server_receives_ping(self): server = Protocol(SERVER) server.receive_data(b"\x89\x80\x00\x44\x88\xcc") self.assertFrameReceived( server, Frame(OP_PING, b""), ) self.assertFrameSent( server, Frame(OP_PONG, b""), ) def test_client_sends_ping_with_data(self): client = Protocol(CLIENT) with patch("secrets.token_bytes", return_value=b"\x00\x44\x88\xcc"): client.send_ping(b"\x22\x66\xaa\xee") self.assertEqual( client.data_to_send(), [b"\x89\x84\x00\x44\x88\xcc\x22\x22\x22\x22"] ) def test_server_sends_ping_with_data(self): server = Protocol(SERVER) server.send_ping(b"\x22\x66\xaa\xee") self.assertEqual(server.data_to_send(), [b"\x89\x04\x22\x66\xaa\xee"]) def test_client_receives_ping_with_data(self): client = Protocol(CLIENT) client.receive_data(b"\x89\x04\x22\x66\xaa\xee") self.assertFrameReceived( client, Frame(OP_PING, b"\x22\x66\xaa\xee"), ) self.assertFrameSent( client, Frame(OP_PONG, b"\x22\x66\xaa\xee"), ) def test_server_receives_ping_with_data(self): server = Protocol(SERVER) server.receive_data(b"\x89\x84\x00\x44\x88\xcc\x22\x22\x22\x22") self.assertFrameReceived( server, Frame(OP_PING, b"\x22\x66\xaa\xee"), ) self.assertFrameSent( server, Frame(OP_PONG, b"\x22\x66\xaa\xee"), ) def test_client_sends_fragmented_ping_frame(self): client = Protocol(CLIENT) # This is only possible through a private API. 
with self.assertRaises(ProtocolError) as raised: client.send_frame(Frame(OP_PING, b"", fin=False)) self.assertEqual(str(raised.exception), "fragmented control frame") def test_server_sends_fragmented_ping_frame(self): server = Protocol(SERVER) # This is only possible through a private API. with self.assertRaises(ProtocolError) as raised: server.send_frame(Frame(OP_PING, b"", fin=False)) self.assertEqual(str(raised.exception), "fragmented control frame") def test_client_receives_fragmented_ping_frame(self): client = Protocol(CLIENT) client.receive_data(b"\x09\x00") self.assertIsInstance(client.parser_exc, ProtocolError) self.assertEqual(str(client.parser_exc), "fragmented control frame") self.assertConnectionFailing( client, CloseCode.PROTOCOL_ERROR, "fragmented control frame" ) def test_server_receives_fragmented_ping_frame(self): server = Protocol(SERVER) server.receive_data(b"\x09\x80\x3c\x3c\x3c\x3c") self.assertIsInstance(server.parser_exc, ProtocolError) self.assertEqual(str(server.parser_exc), "fragmented control frame") self.assertConnectionFailing( server, CloseCode.PROTOCOL_ERROR, "fragmented control frame" ) def test_client_sends_ping_after_sending_close(self): client = Protocol(CLIENT) with patch("secrets.token_bytes", return_value=b"\x00\x00\x00\x00"): client.send_close(CloseCode.GOING_AWAY) self.assertEqual(client.data_to_send(), [b"\x88\x82\x00\x00\x00\x00\x03\xe9"]) with patch("secrets.token_bytes", return_value=b"\x00\x44\x88\xcc"): client.send_ping(b"") self.assertEqual(client.data_to_send(), [b"\x89\x80\x00\x44\x88\xcc"]) def test_server_sends_ping_after_sending_close(self): server = Protocol(SERVER) server.send_close(CloseCode.NORMAL_CLOSURE) self.assertEqual(server.data_to_send(), [b"\x88\x02\x03\xe8"]) server.send_ping(b"") self.assertEqual(server.data_to_send(), [b"\x89\x00"]) def test_client_receives_ping_after_receiving_close(self): client = Protocol(CLIENT) client.receive_data(b"\x88\x02\x03\xe8") self.assertConnectionClosing(client, 
CloseCode.NORMAL_CLOSURE) client.receive_data(b"\x89\x04\x22\x66\xaa\xee") # websockets ignores control frames after a close frame. self.assertFrameReceived(client, None) self.assertFrameSent(client, None) def test_server_receives_ping_after_receiving_close(self): server = Protocol(SERVER) server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xe9") self.assertConnectionClosing(server, CloseCode.GOING_AWAY) server.receive_data(b"\x89\x84\x00\x44\x88\xcc\x22\x22\x22\x22") # websockets ignores control frames after a close frame. self.assertFrameReceived(server, None) self.assertFrameSent(server, None) def test_client_sends_ping_after_connection_is_closed(self): client = Protocol(CLIENT) client.receive_eof() with self.assertRaises(InvalidState) as raised: client.send_ping(b"") self.assertEqual(str(raised.exception), "connection is closed") def test_server_sends_ping_after_connection_is_closed(self): server = Protocol(SERVER) server.receive_eof() with self.assertRaises(InvalidState) as raised: server.send_ping(b"") self.assertEqual(str(raised.exception), "connection is closed") class PongTests(ProtocolTestCase): """ Test pong frames. See 5.5.3. Pong in RFC 6455. 
""" def test_client_sends_pong(self): client = Protocol(CLIENT) with patch("secrets.token_bytes", return_value=b"\x00\x44\x88\xcc"): client.send_pong(b"") self.assertEqual(client.data_to_send(), [b"\x8a\x80\x00\x44\x88\xcc"]) def test_server_sends_pong(self): server = Protocol(SERVER) server.send_pong(b"") self.assertEqual(server.data_to_send(), [b"\x8a\x00"]) def test_client_receives_pong(self): client = Protocol(CLIENT) client.receive_data(b"\x8a\x00") self.assertFrameReceived( client, Frame(OP_PONG, b""), ) def test_server_receives_pong(self): server = Protocol(SERVER) server.receive_data(b"\x8a\x80\x00\x44\x88\xcc") self.assertFrameReceived( server, Frame(OP_PONG, b""), ) def test_client_sends_pong_with_data(self): client = Protocol(CLIENT) with patch("secrets.token_bytes", return_value=b"\x00\x44\x88\xcc"): client.send_pong(b"\x22\x66\xaa\xee") self.assertEqual( client.data_to_send(), [b"\x8a\x84\x00\x44\x88\xcc\x22\x22\x22\x22"] ) def test_server_sends_pong_with_data(self): server = Protocol(SERVER) server.send_pong(b"\x22\x66\xaa\xee") self.assertEqual(server.data_to_send(), [b"\x8a\x04\x22\x66\xaa\xee"]) def test_client_receives_pong_with_data(self): client = Protocol(CLIENT) client.receive_data(b"\x8a\x04\x22\x66\xaa\xee") self.assertFrameReceived( client, Frame(OP_PONG, b"\x22\x66\xaa\xee"), ) def test_server_receives_pong_with_data(self): server = Protocol(SERVER) server.receive_data(b"\x8a\x84\x00\x44\x88\xcc\x22\x22\x22\x22") self.assertFrameReceived( server, Frame(OP_PONG, b"\x22\x66\xaa\xee"), ) def test_client_sends_fragmented_pong_frame(self): client = Protocol(CLIENT) # This is only possible through a private API. with self.assertRaises(ProtocolError) as raised: client.send_frame(Frame(OP_PONG, b"", fin=False)) self.assertEqual(str(raised.exception), "fragmented control frame") def test_server_sends_fragmented_pong_frame(self): server = Protocol(SERVER) # This is only possible through a private API. 
with self.assertRaises(ProtocolError) as raised: server.send_frame(Frame(OP_PONG, b"", fin=False)) self.assertEqual(str(raised.exception), "fragmented control frame") def test_client_receives_fragmented_pong_frame(self): client = Protocol(CLIENT) client.receive_data(b"\x0a\x00") self.assertIsInstance(client.parser_exc, ProtocolError) self.assertEqual(str(client.parser_exc), "fragmented control frame") self.assertConnectionFailing( client, CloseCode.PROTOCOL_ERROR, "fragmented control frame" ) def test_server_receives_fragmented_pong_frame(self): server = Protocol(SERVER) server.receive_data(b"\x0a\x80\x3c\x3c\x3c\x3c") self.assertIsInstance(server.parser_exc, ProtocolError) self.assertEqual(str(server.parser_exc), "fragmented control frame") self.assertConnectionFailing( server, CloseCode.PROTOCOL_ERROR, "fragmented control frame" ) def test_client_sends_pong_after_sending_close(self): client = Protocol(CLIENT) with patch("secrets.token_bytes", return_value=b"\x00\x00\x00\x00"): client.send_close(CloseCode.GOING_AWAY) self.assertEqual(client.data_to_send(), [b"\x88\x82\x00\x00\x00\x00\x03\xe9"]) with patch("secrets.token_bytes", return_value=b"\x00\x44\x88\xcc"): client.send_pong(b"") self.assertEqual(client.data_to_send(), [b"\x8a\x80\x00\x44\x88\xcc"]) def test_server_sends_pong_after_sending_close(self): server = Protocol(SERVER) server.send_close(CloseCode.NORMAL_CLOSURE) self.assertEqual(server.data_to_send(), [b"\x88\x02\x03\xe8"]) server.send_pong(b"") self.assertEqual(server.data_to_send(), [b"\x8a\x00"]) def test_client_receives_pong_after_receiving_close(self): client = Protocol(CLIENT) client.receive_data(b"\x88\x02\x03\xe8") self.assertConnectionClosing(client, CloseCode.NORMAL_CLOSURE) client.receive_data(b"\x8a\x04\x22\x66\xaa\xee") # websockets ignores control frames after a close frame. 
self.assertFrameReceived(client, None) self.assertFrameSent(client, None) def test_server_receives_pong_after_receiving_close(self): server = Protocol(SERVER) server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xe9") self.assertConnectionClosing(server, CloseCode.GOING_AWAY) server.receive_data(b"\x8a\x84\x00\x44\x88\xcc\x22\x22\x22\x22") # websockets ignores control frames after a close frame. self.assertFrameReceived(server, None) self.assertFrameSent(server, None) def test_client_sends_pong_after_connection_is_closed(self): client = Protocol(CLIENT) client.receive_eof() with self.assertRaises(InvalidState) as raised: client.send_pong(b"") self.assertEqual(str(raised.exception), "connection is closed") def test_server_sends_pong_after_connection_is_closed(self): server = Protocol(SERVER) server.receive_eof() with self.assertRaises(InvalidState) as raised: server.send_pong(b"") self.assertEqual(str(raised.exception), "connection is closed") class FailTests(ProtocolTestCase): """ Test failing the connection. See 7.1.7. Fail the WebSocket Connection in RFC 6455. """ def test_client_stops_processing_frames_after_fail(self): client = Protocol(CLIENT) client.fail(CloseCode.PROTOCOL_ERROR) self.assertConnectionFailing(client, CloseCode.PROTOCOL_ERROR) client.receive_data(b"\x88\x02\x03\xea") self.assertFrameReceived(client, None) def test_server_stops_processing_frames_after_fail(self): server = Protocol(SERVER) server.fail(CloseCode.PROTOCOL_ERROR) self.assertConnectionFailing(server, CloseCode.PROTOCOL_ERROR) server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xea") self.assertFrameReceived(server, None) class FragmentationTests(ProtocolTestCase): """ Test message fragmentation. See 5.4. Fragmentation in RFC 6455. 
""" def test_client_send_ping_pong_in_fragmented_message(self): client = Protocol(CLIENT) client.send_text(b"Spam", fin=False) self.assertFrameSent(client, Frame(OP_TEXT, b"Spam", fin=False)) client.send_ping(b"Ping") self.assertFrameSent(client, Frame(OP_PING, b"Ping")) client.send_continuation(b"Ham", fin=False) self.assertFrameSent(client, Frame(OP_CONT, b"Ham", fin=False)) client.send_pong(b"Pong") self.assertFrameSent(client, Frame(OP_PONG, b"Pong")) client.send_continuation(b"Eggs", fin=True) self.assertFrameSent(client, Frame(OP_CONT, b"Eggs")) def test_server_send_ping_pong_in_fragmented_message(self): server = Protocol(SERVER) server.send_text(b"Spam", fin=False) self.assertFrameSent(server, Frame(OP_TEXT, b"Spam", fin=False)) server.send_ping(b"Ping") self.assertFrameSent(server, Frame(OP_PING, b"Ping")) server.send_continuation(b"Ham", fin=False) self.assertFrameSent(server, Frame(OP_CONT, b"Ham", fin=False)) server.send_pong(b"Pong") self.assertFrameSent(server, Frame(OP_PONG, b"Pong")) server.send_continuation(b"Eggs", fin=True) self.assertFrameSent(server, Frame(OP_CONT, b"Eggs")) def test_client_receive_ping_pong_in_fragmented_message(self): client = Protocol(CLIENT) client.receive_data(b"\x01\x04Spam") self.assertFrameReceived( client, Frame(OP_TEXT, b"Spam", fin=False), ) client.receive_data(b"\x89\x04Ping") self.assertFrameReceived( client, Frame(OP_PING, b"Ping"), ) self.assertFrameSent( client, Frame(OP_PONG, b"Ping"), ) client.receive_data(b"\x00\x03Ham") self.assertFrameReceived( client, Frame(OP_CONT, b"Ham", fin=False), ) client.receive_data(b"\x8a\x04Pong") self.assertFrameReceived( client, Frame(OP_PONG, b"Pong"), ) client.receive_data(b"\x80\x04Eggs") self.assertFrameReceived( client, Frame(OP_CONT, b"Eggs"), ) def test_server_receive_ping_pong_in_fragmented_message(self): server = Protocol(SERVER) server.receive_data(b"\x01\x84\x00\x00\x00\x00Spam") self.assertFrameReceived( server, Frame(OP_TEXT, b"Spam", fin=False), ) 
server.receive_data(b"\x89\x84\x00\x00\x00\x00Ping") self.assertFrameReceived( server, Frame(OP_PING, b"Ping"), ) self.assertFrameSent( server, Frame(OP_PONG, b"Ping"), ) server.receive_data(b"\x00\x83\x00\x00\x00\x00Ham") self.assertFrameReceived( server, Frame(OP_CONT, b"Ham", fin=False), ) server.receive_data(b"\x8a\x84\x00\x00\x00\x00Pong") self.assertFrameReceived( server, Frame(OP_PONG, b"Pong"), ) server.receive_data(b"\x80\x84\x00\x00\x00\x00Eggs") self.assertFrameReceived( server, Frame(OP_CONT, b"Eggs"), ) def test_client_send_close_in_fragmented_message(self): client = Protocol(CLIENT) client.send_text(b"Spam", fin=False) self.assertFrameSent(client, Frame(OP_TEXT, b"Spam", fin=False)) with patch("secrets.token_bytes", return_value=b"\x3c\x3c\x3c\x3c"): client.send_close() self.assertEqual(client.data_to_send(), [b"\x88\x80\x3c\x3c\x3c\x3c"]) self.assertIs(client.state, CLOSING) with self.assertRaises(InvalidState) as raised: client.send_continuation(b"Eggs", fin=True) self.assertEqual(str(raised.exception), "connection is closing") def test_server_send_close_in_fragmented_message(self): server = Protocol(SERVER) server.send_text(b"Spam", fin=False) self.assertFrameSent(server, Frame(OP_TEXT, b"Spam", fin=False)) server.send_close() self.assertEqual(server.data_to_send(), [b"\x88\x00"]) self.assertIs(server.state, CLOSING) with self.assertRaises(InvalidState) as raised: server.send_continuation(b"Eggs", fin=True) self.assertEqual(str(raised.exception), "connection is closing") def test_client_receive_close_in_fragmented_message(self): client = Protocol(CLIENT) client.receive_data(b"\x01\x04Spam") self.assertFrameReceived( client, Frame(OP_TEXT, b"Spam", fin=False), ) client.receive_data(b"\x88\x02\x03\xe8") self.assertIsInstance(client.parser_exc, ProtocolError) self.assertEqual(str(client.parser_exc), "incomplete fragmented message") self.assertConnectionFailing( client, CloseCode.PROTOCOL_ERROR, "incomplete fragmented message" ) def 
test_server_receive_close_in_fragmented_message(self): server = Protocol(SERVER) server.receive_data(b"\x01\x84\x00\x00\x00\x00Spam") self.assertFrameReceived( server, Frame(OP_TEXT, b"Spam", fin=False), ) server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xe9") self.assertIsInstance(server.parser_exc, ProtocolError) self.assertEqual(str(server.parser_exc), "incomplete fragmented message") self.assertConnectionFailing( server, CloseCode.PROTOCOL_ERROR, "incomplete fragmented message" ) class EOFTests(ProtocolTestCase): """ Test half-closes on connection termination. """ def test_client_receives_eof(self): client = Protocol(CLIENT) client.receive_data(b"\x88\x00") self.assertConnectionClosing(client) client.receive_eof() self.assertIs(client.state, CLOSED) def test_server_receives_eof(self): server = Protocol(SERVER) server.receive_data(b"\x88\x80\x3c\x3c\x3c\x3c") self.assertConnectionClosing(server) server.receive_eof() self.assertIs(server.state, CLOSED) def test_client_receives_eof_between_frames(self): client = Protocol(CLIENT) client.receive_eof() self.assertIsInstance(client.parser_exc, EOFError) self.assertEqual(str(client.parser_exc), "unexpected end of stream") self.assertIs(client.state, CLOSED) def test_server_receives_eof_between_frames(self): server = Protocol(SERVER) server.receive_eof() self.assertIsInstance(server.parser_exc, EOFError) self.assertEqual(str(server.parser_exc), "unexpected end of stream") self.assertIs(server.state, CLOSED) def test_client_receives_eof_inside_frame(self): client = Protocol(CLIENT) client.receive_data(b"\x81") client.receive_eof() self.assertIsInstance(client.parser_exc, EOFError) self.assertEqual( str(client.parser_exc), "stream ends after 1 bytes, expected 2 bytes", ) self.assertIs(client.state, CLOSED) def test_server_receives_eof_inside_frame(self): server = Protocol(SERVER) server.receive_data(b"\x81") server.receive_eof() self.assertIsInstance(server.parser_exc, EOFError) self.assertEqual( str(server.parser_exc), 
"stream ends after 1 bytes, expected 2 bytes", ) self.assertIs(server.state, CLOSED) def test_client_receives_data_after_exception(self): client = Protocol(CLIENT) client.receive_data(b"\xff\xff") self.assertConnectionFailing(client, CloseCode.PROTOCOL_ERROR, "invalid opcode") client.receive_data(b"\x00\x00") self.assertFrameSent(client, None) def test_server_receives_data_after_exception(self): server = Protocol(SERVER) server.receive_data(b"\xff\xff") self.assertConnectionFailing(server, CloseCode.PROTOCOL_ERROR, "invalid opcode") server.receive_data(b"\x00\x00") self.assertFrameSent(server, None) def test_client_receives_eof_after_exception(self): client = Protocol(CLIENT) client.receive_data(b"\xff\xff") self.assertConnectionFailing(client, CloseCode.PROTOCOL_ERROR, "invalid opcode") client.receive_eof() self.assertFrameSent(client, None, eof=True) def test_server_receives_eof_after_exception(self): server = Protocol(SERVER) server.receive_data(b"\xff\xff") self.assertConnectionFailing(server, CloseCode.PROTOCOL_ERROR, "invalid opcode") server.receive_eof() self.assertFrameSent(server, None) def test_client_receives_data_and_eof_after_exception(self): client = Protocol(CLIENT) client.receive_data(b"\xff\xff") self.assertConnectionFailing(client, CloseCode.PROTOCOL_ERROR, "invalid opcode") client.receive_data(b"\x00\x00") client.receive_eof() self.assertFrameSent(client, None, eof=True) def test_server_receives_data_and_eof_after_exception(self): server = Protocol(SERVER) server.receive_data(b"\xff\xff") self.assertConnectionFailing(server, CloseCode.PROTOCOL_ERROR, "invalid opcode") server.receive_data(b"\x00\x00") server.receive_eof() self.assertFrameSent(server, None) def test_client_receives_data_after_eof(self): client = Protocol(CLIENT) client.receive_data(b"\x88\x00") self.assertConnectionClosing(client) client.receive_eof() with self.assertRaises(EOFError) as raised: client.receive_data(b"\x88\x00") self.assertEqual(str(raised.exception), "stream ended") 
def test_server_receives_data_after_eof(self): server = Protocol(SERVER) server.receive_data(b"\x88\x80\x3c\x3c\x3c\x3c") self.assertConnectionClosing(server) server.receive_eof() with self.assertRaises(EOFError) as raised: server.receive_data(b"\x88\x80\x00\x00\x00\x00") self.assertEqual(str(raised.exception), "stream ended") def test_client_receives_eof_after_eof(self): client = Protocol(CLIENT) client.receive_data(b"\x88\x00") self.assertConnectionClosing(client) client.receive_eof() client.receive_eof() # this is idempotent def test_server_receives_eof_after_eof(self): server = Protocol(SERVER) server.receive_data(b"\x88\x80\x3c\x3c\x3c\x3c") self.assertConnectionClosing(server) server.receive_eof() server.receive_eof() # this is idempotent class TCPCloseTests(ProtocolTestCase): """ Test expectation of TCP close on connection termination. """ def test_client_default(self): client = Protocol(CLIENT) self.assertFalse(client.close_expected()) def test_server_default(self): server = Protocol(SERVER) self.assertFalse(server.close_expected()) def test_client_sends_close(self): client = Protocol(CLIENT) client.send_close() self.assertTrue(client.close_expected()) def test_server_sends_close(self): server = Protocol(SERVER) server.send_close() self.assertTrue(server.close_expected()) def test_client_receives_close(self): client = Protocol(CLIENT) client.receive_data(b"\x88\x00") self.assertTrue(client.close_expected()) def test_client_receives_close_then_eof(self): client = Protocol(CLIENT) client.receive_data(b"\x88\x00") client.receive_eof() self.assertFalse(client.close_expected()) def test_server_receives_close_then_eof(self): server = Protocol(SERVER) server.receive_data(b"\x88\x80\x3c\x3c\x3c\x3c") server.receive_eof() self.assertFalse(server.close_expected()) def test_server_receives_close(self): server = Protocol(SERVER) server.receive_data(b"\x88\x80\x3c\x3c\x3c\x3c") self.assertTrue(server.close_expected()) def test_client_fails_connection(self): client = 
Protocol(CLIENT) client.fail(CloseCode.PROTOCOL_ERROR) self.assertTrue(client.close_expected()) def test_server_fails_connection(self): server = Protocol(SERVER) server.fail(CloseCode.PROTOCOL_ERROR) self.assertTrue(server.close_expected()) def test_client_is_connecting(self): client = Protocol(CLIENT, state=CONNECTING) self.assertFalse(client.close_expected()) def test_server_is_connecting(self): server = Protocol(SERVER, state=CONNECTING) self.assertFalse(server.close_expected()) def test_client_failed_connecting(self): client = Protocol(CLIENT, state=CONNECTING) client.send_eof() self.assertTrue(client.close_expected()) def test_server_failed_connecting(self): server = Protocol(SERVER, state=CONNECTING) server.send_eof() self.assertTrue(server.close_expected()) class ConnectionClosedTests(ProtocolTestCase): """ Test connection closed exception. """ def test_client_sends_close_then_receives_close(self): # Client-initiated close handshake on the client side complete. client = Protocol(CLIENT) client.send_close(CloseCode.NORMAL_CLOSURE, "") client.receive_data(b"\x88\x02\x03\xe8") client.receive_eof() exc = client.close_exc self.assertIsInstance(exc, ConnectionClosedOK) self.assertEqual(exc.rcvd, Close(CloseCode.NORMAL_CLOSURE, "")) self.assertEqual(exc.sent, Close(CloseCode.NORMAL_CLOSURE, "")) self.assertFalse(exc.rcvd_then_sent) def test_server_sends_close_then_receives_close(self): # Server-initiated close handshake on the server side complete. server = Protocol(SERVER) server.send_close(CloseCode.NORMAL_CLOSURE, "") server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xe8") server.receive_eof() exc = server.close_exc self.assertIsInstance(exc, ConnectionClosedOK) self.assertEqual(exc.rcvd, Close(CloseCode.NORMAL_CLOSURE, "")) self.assertEqual(exc.sent, Close(CloseCode.NORMAL_CLOSURE, "")) self.assertFalse(exc.rcvd_then_sent) def test_client_receives_close_then_sends_close(self): # Server-initiated close handshake on the client side complete. 
client = Protocol(CLIENT) client.receive_data(b"\x88\x02\x03\xe8") client.receive_eof() exc = client.close_exc self.assertIsInstance(exc, ConnectionClosedOK) self.assertEqual(exc.rcvd, Close(CloseCode.NORMAL_CLOSURE, "")) self.assertEqual(exc.sent, Close(CloseCode.NORMAL_CLOSURE, "")) self.assertTrue(exc.rcvd_then_sent) def test_server_receives_close_then_sends_close(self): # Client-initiated close handshake on the server side complete. server = Protocol(SERVER) server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xe8") server.receive_eof() exc = server.close_exc self.assertIsInstance(exc, ConnectionClosedOK) self.assertEqual(exc.rcvd, Close(CloseCode.NORMAL_CLOSURE, "")) self.assertEqual(exc.sent, Close(CloseCode.NORMAL_CLOSURE, "")) self.assertTrue(exc.rcvd_then_sent) def test_client_sends_close_then_receives_eof(self): # Client-initiated close handshake on the client side times out. client = Protocol(CLIENT) client.send_close(CloseCode.NORMAL_CLOSURE, "") client.receive_eof() exc = client.close_exc self.assertIsInstance(exc, ConnectionClosedError) self.assertIsNone(exc.rcvd) self.assertEqual(exc.sent, Close(CloseCode.NORMAL_CLOSURE, "")) self.assertIsNone(exc.rcvd_then_sent) def test_server_sends_close_then_receives_eof(self): # Server-initiated close handshake on the server side times out. server = Protocol(SERVER) server.send_close(CloseCode.NORMAL_CLOSURE, "") server.receive_eof() exc = server.close_exc self.assertIsInstance(exc, ConnectionClosedError) self.assertIsNone(exc.rcvd) self.assertEqual(exc.sent, Close(CloseCode.NORMAL_CLOSURE, "")) self.assertIsNone(exc.rcvd_then_sent) def test_client_receives_eof(self): # Server-initiated close handshake on the client side times out. 
client = Protocol(CLIENT) client.receive_eof() exc = client.close_exc self.assertIsInstance(exc, ConnectionClosedError) self.assertIsNone(exc.rcvd) self.assertIsNone(exc.sent) self.assertIsNone(exc.rcvd_then_sent) def test_server_receives_eof(self): # Client-initiated close handshake on the server side times out. server = Protocol(SERVER) server.receive_eof() exc = server.close_exc self.assertIsInstance(exc, ConnectionClosedError) self.assertIsNone(exc.rcvd) self.assertIsNone(exc.sent) self.assertIsNone(exc.rcvd_then_sent) class ErrorTests(ProtocolTestCase): """ Test other error cases. """ def test_client_hits_internal_error_reading_frame(self): client = Protocol(CLIENT) # This isn't supposed to happen, so we're simulating it. with patch("struct.unpack", side_effect=RuntimeError("BOOM")): client.receive_data(b"\x81\x00") self.assertIsInstance(client.parser_exc, RuntimeError) self.assertEqual(str(client.parser_exc), "BOOM") self.assertConnectionFailing(client, CloseCode.INTERNAL_ERROR, "") def test_server_hits_internal_error_reading_frame(self): server = Protocol(SERVER) # This isn't supposed to happen, so we're simulating it. with patch("struct.unpack", side_effect=RuntimeError("BOOM")): server.receive_data(b"\x81\x80\x00\x00\x00\x00") self.assertIsInstance(server.parser_exc, RuntimeError) self.assertEqual(str(server.parser_exc), "BOOM") self.assertConnectionFailing(server, CloseCode.INTERNAL_ERROR, "") class ExtensionsTests(ProtocolTestCase): """ Test how extensions affect frames. 
""" def test_client_extension_encodes_frame(self): client = Protocol(CLIENT) client.extensions = [Rsv2Extension()] with patch("secrets.token_bytes", return_value=b"\x00\x44\x88\xcc"): client.send_ping(b"") self.assertEqual(client.data_to_send(), [b"\xa9\x80\x00\x44\x88\xcc"]) def test_server_extension_encodes_frame(self): server = Protocol(SERVER) server.extensions = [Rsv2Extension()] server.send_ping(b"") self.assertEqual(server.data_to_send(), [b"\xa9\x00"]) def test_client_extension_decodes_frame(self): client = Protocol(CLIENT) client.extensions = [Rsv2Extension()] client.receive_data(b"\xaa\x00") self.assertEqual(client.events_received(), [Frame(OP_PONG, b"")]) def test_server_extension_decodes_frame(self): server = Protocol(SERVER) server.extensions = [Rsv2Extension()] server.receive_data(b"\xaa\x80\x00\x44\x88\xcc") self.assertEqual(server.events_received(), [Frame(OP_PONG, b"")]) class MiscTests(unittest.TestCase): def test_client_default_logger(self): client = Protocol(CLIENT) logger = logging.getLogger("websockets.client") self.assertIs(client.logger, logger) def test_server_default_logger(self): server = Protocol(SERVER) logger = logging.getLogger("websockets.server") self.assertIs(server.logger, logger) def test_client_custom_logger(self): logger = logging.getLogger("test") client = Protocol(CLIENT, logger=logger) self.assertIs(client.logger, logger) def test_server_custom_logger(self): logger = logging.getLogger("test") server = Protocol(SERVER, logger=logger) self.assertIs(server.logger, logger) websockets-15.0.1/tests/test_server.py000066400000000000000000001061561476212450300200450ustar00rootroot00000000000000import http import logging import re import sys import unittest from unittest.mock import patch from websockets.datastructures import Headers from websockets.exceptions import ( InvalidHeader, InvalidMessage, InvalidOrigin, InvalidUpgrade, NegotiationError, ) from websockets.frames import OP_TEXT, Frame from websockets.http11 import Request, 
Response from websockets.protocol import CONNECTING, OPEN from websockets.server import * from .extensions.utils import ( OpExtension, Rsv2Extension, ServerOpExtensionFactory, ServerRsv2ExtensionFactory, ) from .test_utils import ACCEPT, KEY from .utils import DATE, DeprecationTestCase def make_request(): """Generate a handshake request that can be altered for testing.""" return Request( path="/test", headers=Headers( { "Host": "example.com", "Upgrade": "websocket", "Connection": "Upgrade", "Sec-WebSocket-Key": KEY, "Sec-WebSocket-Version": "13", } ), ) @patch("email.utils.formatdate", return_value=DATE) class BasicTests(unittest.TestCase): """Test basic opening handshake scenarios.""" def test_receive_request(self, _formatdate): """Server receives a handshake request.""" server = ServerProtocol() server.receive_data( ( f"GET /test HTTP/1.1\r\n" f"Host: example.com\r\n" f"Upgrade: websocket\r\n" f"Connection: Upgrade\r\n" f"Sec-WebSocket-Key: {KEY}\r\n" f"Sec-WebSocket-Version: 13\r\n" f"\r\n" ).encode(), ) self.assertEqual(server.data_to_send(), []) self.assertFalse(server.close_expected()) self.assertEqual(server.state, CONNECTING) def test_accept_and_send_successful_response(self, _formatdate): """Server accepts a handshake request and sends a successful response.""" server = ServerProtocol() request = make_request() response = server.accept(request) server.send_response(response) self.assertEqual( server.data_to_send(), [ f"HTTP/1.1 101 Switching Protocols\r\n" f"Date: {DATE}\r\n" f"Upgrade: websocket\r\n" f"Connection: Upgrade\r\n" f"Sec-WebSocket-Accept: {ACCEPT}\r\n" f"\r\n".encode() ], ) self.assertFalse(server.close_expected()) self.assertEqual(server.state, OPEN) def test_send_response_after_failed_accept(self, _formatdate): """Server accepts a handshake request but sends a failed response.""" server = ServerProtocol() request = make_request() del request.headers["Sec-WebSocket-Key"] response = server.accept(request) server.send_response(response) 
self.assertEqual( server.data_to_send(), [ f"HTTP/1.1 400 Bad Request\r\n" f"Date: {DATE}\r\n" f"Connection: close\r\n" f"Content-Length: 73\r\n" f"Content-Type: text/plain; charset=utf-8\r\n" f"\r\n" f"Failed to open a WebSocket connection: " f"missing Sec-WebSocket-Key header.\n".encode(), b"", ], ) self.assertTrue(server.close_expected()) self.assertEqual(server.state, CONNECTING) def test_send_response_after_reject(self, _formatdate): """Server rejects a handshake request and sends a failed response.""" server = ServerProtocol() response = server.reject(http.HTTPStatus.NOT_FOUND, "Sorry folks.\n") server.send_response(response) self.assertEqual( server.data_to_send(), [ f"HTTP/1.1 404 Not Found\r\n" f"Date: {DATE}\r\n" f"Connection: close\r\n" f"Content-Length: 13\r\n" f"Content-Type: text/plain; charset=utf-8\r\n" f"\r\n" f"Sorry folks.\n".encode(), b"", ], ) self.assertTrue(server.close_expected()) self.assertEqual(server.state, CONNECTING) def test_send_response_without_accept_or_reject(self, _formatdate): """Server doesn't accept or reject and sends a failed response.""" server = ServerProtocol() server.send_response( Response( 410, "Gone", Headers( { "Connection": "close", "Content-Length": 6, "Content-Type": "text/plain", } ), b"AWOL.\n", ) ) self.assertEqual( server.data_to_send(), [ "HTTP/1.1 410 Gone\r\n" "Connection: close\r\n" "Content-Length: 6\r\n" "Content-Type: text/plain\r\n" "\r\n" "AWOL.\n".encode(), b"", ], ) self.assertTrue(server.close_expected()) self.assertEqual(server.state, CONNECTING) class RequestTests(unittest.TestCase): """Test receiving opening handshake requests.""" def test_receive_request(self): """Server receives a handshake request.""" server = ServerProtocol() server.receive_data( ( f"GET /test HTTP/1.1\r\n" f"Host: example.com\r\n" f"Upgrade: websocket\r\n" f"Connection: Upgrade\r\n" f"Sec-WebSocket-Key: {KEY}\r\n" f"Sec-WebSocket-Version: 13\r\n" f"\r\n" ).encode(), ) [request] = server.events_received() 
self.assertIsInstance(request, Request) self.assertEqual(request.path, "/test") self.assertEqual( request.headers, Headers( { "Host": "example.com", "Upgrade": "websocket", "Connection": "Upgrade", "Sec-WebSocket-Key": KEY, "Sec-WebSocket-Version": "13", } ), ) self.assertIsNone(server.handshake_exc) def test_receive_no_request(self): """Server receives no handshake request.""" server = ServerProtocol() server.receive_eof() self.assertEqual(server.events_received(), []) self.assertEqual(server.events_received(), []) self.assertIsInstance(server.handshake_exc, InvalidMessage) self.assertEqual( str(server.handshake_exc), "did not receive a valid HTTP request", ) self.assertIsInstance(server.handshake_exc.__cause__, EOFError) self.assertEqual( str(server.handshake_exc.__cause__), "connection closed while reading HTTP request line", ) def test_receive_truncated_request(self): """Server receives a truncated handshake request.""" server = ServerProtocol() server.receive_data(b"GET /test HTTP/1.1\r\n") server.receive_eof() self.assertEqual(server.events_received(), []) self.assertIsInstance(server.handshake_exc, InvalidMessage) self.assertEqual( str(server.handshake_exc), "did not receive a valid HTTP request", ) self.assertIsInstance(server.handshake_exc.__cause__, EOFError) self.assertEqual( str(server.handshake_exc.__cause__), "connection closed while reading HTTP headers", ) def test_receive_junk_request(self): """Server receives a junk handshake request.""" server = ServerProtocol() server.receive_data(b"HELO relay.invalid\r\n") server.receive_data(b"MAIL FROM: \r\n") server.receive_data(b"RCPT TO: \r\n") self.assertIsInstance(server.handshake_exc, InvalidMessage) self.assertEqual( str(server.handshake_exc), "did not receive a valid HTTP request", ) self.assertIsInstance(server.handshake_exc.__cause__, ValueError) self.assertEqual( str(server.handshake_exc.__cause__), "invalid HTTP request line: HELO relay.invalid", ) class ResponseTests(unittest.TestCase): """Test 
generating opening handshake responses.""" @patch("email.utils.formatdate", return_value=DATE) def test_accept_response(self, _formatdate): """accept() creates a successful opening handshake response.""" server = ServerProtocol() request = make_request() response = server.accept(request) self.assertIsInstance(response, Response) self.assertEqual(response.status_code, 101) self.assertEqual(response.reason_phrase, "Switching Protocols") self.assertEqual( response.headers, Headers( { "Date": DATE, "Upgrade": "websocket", "Connection": "Upgrade", "Sec-WebSocket-Accept": ACCEPT, } ), ) self.assertEqual(response.body, b"") @patch("email.utils.formatdate", return_value=DATE) def test_reject_response(self, _formatdate): """reject() creates a failed opening handshake response.""" server = ServerProtocol() response = server.reject(http.HTTPStatus.NOT_FOUND, "Sorry folks.\n") self.assertIsInstance(response, Response) self.assertEqual(response.status_code, 404) self.assertEqual(response.reason_phrase, "Not Found") self.assertEqual( response.headers, Headers( { "Date": DATE, "Connection": "close", "Content-Length": "13", "Content-Type": "text/plain; charset=utf-8", } ), ) self.assertEqual(response.body, b"Sorry folks.\n") def test_reject_response_supports_int_status(self): """reject() accepts an integer status code instead of an HTTPStatus.""" server = ServerProtocol() response = server.reject(404, "Sorry folks.\n") self.assertEqual(response.status_code, 404) self.assertEqual(response.reason_phrase, "Not Found") @patch( "websockets.server.ServerProtocol.process_request", side_effect=Exception("BOOM"), ) def test_unexpected_error(self, process_request): """accept() handles unexpected errors and returns an error response.""" server = ServerProtocol() request = make_request() response = server.accept(request) self.assertEqual(response.status_code, 500) self.assertIsInstance(server.handshake_exc, Exception) self.assertEqual(str(server.handshake_exc), "BOOM") class 
HandshakeTests(unittest.TestCase): """Test processing of handshake responses to configure the connection.""" def assertHandshakeSuccess(self, server): """Assert that the opening handshake succeeded.""" self.assertEqual(server.state, OPEN) self.assertIsNone(server.handshake_exc) def assertHandshakeError(self, server, exc_type, msg): """Assert that the opening handshake failed with the given exception.""" self.assertEqual(server.state, CONNECTING) self.assertIsInstance(server.handshake_exc, exc_type) exc = server.handshake_exc exc_str = str(exc) while exc.__cause__ is not None: exc = exc.__cause__ exc_str += "; " + str(exc) self.assertEqual(exc_str, msg) def test_basic(self): """Handshake succeeds.""" server = ServerProtocol() request = make_request() response = server.accept(request) server.send_response(response) self.assertHandshakeSuccess(server) def test_missing_connection(self): """Handshake fails when the Connection header is missing.""" server = ServerProtocol() request = make_request() del request.headers["Connection"] response = server.accept(request) server.send_response(response) self.assertEqual(response.status_code, 426) self.assertEqual(response.headers["Upgrade"], "websocket") self.assertHandshakeError( server, InvalidUpgrade, "missing Connection header", ) def test_invalid_connection(self): """Handshake fails when the Connection header is invalid.""" server = ServerProtocol() request = make_request() del request.headers["Connection"] request.headers["Connection"] = "close" response = server.accept(request) server.send_response(response) self.assertEqual(response.status_code, 426) self.assertEqual(response.headers["Upgrade"], "websocket") self.assertHandshakeError( server, InvalidUpgrade, "invalid Connection header: close", ) def test_missing_upgrade(self): """Handshake fails when the Upgrade header is missing.""" server = ServerProtocol() request = make_request() del request.headers["Upgrade"] response = server.accept(request) 
server.send_response(response) self.assertEqual(response.status_code, 426) self.assertEqual(response.headers["Upgrade"], "websocket") self.assertHandshakeError( server, InvalidUpgrade, "missing Upgrade header", ) def test_invalid_upgrade(self): """Handshake fails when the Upgrade header is invalid.""" server = ServerProtocol() request = make_request() del request.headers["Upgrade"] request.headers["Upgrade"] = "h2c" response = server.accept(request) server.send_response(response) self.assertEqual(response.status_code, 426) self.assertEqual(response.headers["Upgrade"], "websocket") self.assertHandshakeError( server, InvalidUpgrade, "invalid Upgrade header: h2c", ) def test_missing_key(self): """Handshake fails when the Sec-WebSocket-Key header is missing.""" server = ServerProtocol() request = make_request() del request.headers["Sec-WebSocket-Key"] response = server.accept(request) server.send_response(response) self.assertEqual(response.status_code, 400) self.assertHandshakeError( server, InvalidHeader, "missing Sec-WebSocket-Key header", ) def test_multiple_key(self): """Handshake fails when the Sec-WebSocket-Key header is repeated.""" server = ServerProtocol() request = make_request() request.headers["Sec-WebSocket-Key"] = KEY response = server.accept(request) server.send_response(response) self.assertEqual(response.status_code, 400) self.assertHandshakeError( server, InvalidHeader, "invalid Sec-WebSocket-Key header: multiple values", ) def test_invalid_key(self): """Handshake fails when the Sec-WebSocket-Key header is invalid.""" server = ServerProtocol() request = make_request() del request.headers["Sec-WebSocket-Key"] request.headers["Sec-WebSocket-Key"] = "" response = server.accept(request) server.send_response(response) self.assertEqual(response.status_code, 400) if sys.version_info[:2] >= (3, 11): b64_exc = "Only base64 data is allowed" else: # pragma: no cover b64_exc = "Non-base64 digit found" self.assertHandshakeError( server, InvalidHeader, f"invalid 
Sec-WebSocket-Key header: ; {b64_exc}", ) def test_truncated_key(self): """Handshake fails when the Sec-WebSocket-Key header is truncated.""" server = ServerProtocol() request = make_request() del request.headers["Sec-WebSocket-Key"] # 12 bytes instead of 16, Base64-encoded request.headers["Sec-WebSocket-Key"] = KEY[:16] response = server.accept(request) server.send_response(response) self.assertEqual(response.status_code, 400) self.assertHandshakeError( server, InvalidHeader, f"invalid Sec-WebSocket-Key header: {KEY[:16]}", ) def test_missing_version(self): """Handshake fails when the Sec-WebSocket-Version header is missing.""" server = ServerProtocol() request = make_request() del request.headers["Sec-WebSocket-Version"] response = server.accept(request) server.send_response(response) self.assertEqual(response.status_code, 400) self.assertHandshakeError( server, InvalidHeader, "missing Sec-WebSocket-Version header", ) def test_multiple_version(self): """Handshake fails when the Sec-WebSocket-Version header is repeated.""" server = ServerProtocol() request = make_request() request.headers["Sec-WebSocket-Version"] = "11" response = server.accept(request) server.send_response(response) self.assertEqual(response.status_code, 400) self.assertHandshakeError( server, InvalidHeader, "invalid Sec-WebSocket-Version header: multiple values", ) def test_invalid_version(self): """Handshake fails when the Sec-WebSocket-Version header is invalid.""" server = ServerProtocol() request = make_request() del request.headers["Sec-WebSocket-Version"] request.headers["Sec-WebSocket-Version"] = "11" response = server.accept(request) server.send_response(response) self.assertEqual(response.status_code, 400) self.assertHandshakeError( server, InvalidHeader, "invalid Sec-WebSocket-Version header: 11", ) def test_origin(self): """Handshake succeeds when checking origin.""" server = ServerProtocol(origins=["https://example.com"]) request = make_request() request.headers["Origin"] = 
"https://example.com" response = server.accept(request) server.send_response(response) self.assertHandshakeSuccess(server) self.assertEqual(server.origin, "https://example.com") def test_no_origin(self): """Handshake fails when checking origin and the Origin header is missing.""" server = ServerProtocol(origins=["https://example.com"]) request = make_request() response = server.accept(request) server.send_response(response) self.assertEqual(response.status_code, 403) self.assertHandshakeError( server, InvalidOrigin, "missing Origin header", ) def test_unexpected_origin(self): """Handshake fails when checking origin and the Origin header is unexpected.""" server = ServerProtocol(origins=["https://example.com"]) request = make_request() request.headers["Origin"] = "https://other.example.com" response = server.accept(request) server.send_response(response) self.assertEqual(response.status_code, 403) self.assertHandshakeError( server, InvalidOrigin, "invalid Origin header: https://other.example.com", ) def test_multiple_origin(self): """Handshake fails when checking origins and the Origin header is repeated.""" server = ServerProtocol( origins=["https://example.com", "https://other.example.com"] ) request = make_request() request.headers["Origin"] = "https://example.com" request.headers["Origin"] = "https://other.example.com" response = server.accept(request) server.send_response(response) # This is prohibited by the HTTP specification, so the return code is # 400 Bad Request rather than 403 Forbidden. 
self.assertEqual(response.status_code, 400) self.assertHandshakeError( server, InvalidHeader, "invalid Origin header: multiple values", ) def test_supported_origin(self): """Handshake succeeds when checking origins and the origin is supported.""" server = ServerProtocol( origins=["https://example.com", "https://other.example.com"] ) request = make_request() request.headers["Origin"] = "https://other.example.com" response = server.accept(request) server.send_response(response) self.assertHandshakeSuccess(server) self.assertEqual(server.origin, "https://other.example.com") def test_unsupported_origin(self): """Handshake fails when checking origins and the origin is unsupported.""" server = ServerProtocol( origins=["https://example.com", "https://other.example.com"] ) request = make_request() request.headers["Origin"] = "https://original.example.com" response = server.accept(request) server.send_response(response) self.assertEqual(response.status_code, 403) self.assertHandshakeError( server, InvalidOrigin, "invalid Origin header: https://original.example.com", ) def test_supported_origin_regex(self): """Handshake succeeds when checking origins and the origin is supported.""" server = ServerProtocol( origins=[re.compile(r"https://(?!original)[a-z]+\.example\.com")] ) request = make_request() request.headers["Origin"] = "https://other.example.com" response = server.accept(request) server.send_response(response) self.assertHandshakeSuccess(server) self.assertEqual(server.origin, "https://other.example.com") def test_unsupported_origin_regex(self): """Handshake fails when checking origins and the origin is unsupported.""" server = ServerProtocol( origins=[re.compile(r"https://(?!original)[a-z]+\.example\.com")] ) request = make_request() request.headers["Origin"] = "https://original.example.com" response = server.accept(request) server.send_response(response) self.assertEqual(response.status_code, 403) self.assertHandshakeError( server, InvalidOrigin, "invalid Origin 
header: https://original.example.com", ) def test_partial_match_origin_regex(self): """Handshake fails when checking origins and the origin a partial match.""" server = ServerProtocol( origins=[re.compile(r"https://(?!original)[a-z]+\.example\.com")] ) request = make_request() request.headers["Origin"] = "https://other.example.com.hacked" response = server.accept(request) server.send_response(response) self.assertEqual(response.status_code, 403) self.assertHandshakeError( server, InvalidOrigin, "invalid Origin header: https://other.example.com.hacked", ) def test_no_origin_accepted(self): """Handshake succeeds when the lack of an origin is accepted.""" server = ServerProtocol(origins=[None]) request = make_request() response = server.accept(request) server.send_response(response) self.assertHandshakeSuccess(server) self.assertIsNone(server.origin) def test_no_extensions(self): """Handshake succeeds without extensions.""" server = ServerProtocol() request = make_request() response = server.accept(request) server.send_response(response) self.assertHandshakeSuccess(server) self.assertNotIn("Sec-WebSocket-Extensions", response.headers) self.assertEqual(server.extensions, []) def test_extension(self): """Server enables an extension when the client offers it.""" server = ServerProtocol(extensions=[ServerOpExtensionFactory()]) request = make_request() request.headers["Sec-WebSocket-Extensions"] = "x-op; op" response = server.accept(request) server.send_response(response) self.assertHandshakeSuccess(server) self.assertEqual(response.headers["Sec-WebSocket-Extensions"], "x-op; op") self.assertEqual(server.extensions, [OpExtension()]) def test_extension_not_enabled(self): """Server doesn't enable an extension when the client doesn't offer it.""" server = ServerProtocol(extensions=[ServerOpExtensionFactory()]) request = make_request() response = server.accept(request) server.send_response(response) self.assertHandshakeSuccess(server) 
self.assertNotIn("Sec-WebSocket-Extensions", response.headers) self.assertEqual(server.extensions, []) def test_no_extensions_supported(self): """Client offers an extension, but the server doesn't support any.""" server = ServerProtocol() request = make_request() request.headers["Sec-WebSocket-Extensions"] = "x-op; op" response = server.accept(request) server.send_response(response) self.assertHandshakeSuccess(server) self.assertNotIn("Sec-WebSocket-Extensions", response.headers) self.assertEqual(server.extensions, []) def test_extension_not_supported(self): """Client offers an extension, but the server doesn't support it.""" server = ServerProtocol(extensions=[ServerRsv2ExtensionFactory()]) request = make_request() request.headers["Sec-WebSocket-Extensions"] = "x-op; op" response = server.accept(request) server.send_response(response) self.assertHandshakeSuccess(server) self.assertNotIn("Sec-WebSocket-Extensions", response.headers) self.assertEqual(server.extensions, []) def test_supported_extension_parameters(self): """Client offers an extension with parameters supported by the server.""" server = ServerProtocol(extensions=[ServerOpExtensionFactory("this")]) request = make_request() request.headers["Sec-WebSocket-Extensions"] = "x-op; op=this" response = server.accept(request) server.send_response(response) self.assertHandshakeSuccess(server) self.assertEqual(response.headers["Sec-WebSocket-Extensions"], "x-op; op=this") self.assertEqual(server.extensions, [OpExtension("this")]) def test_unsupported_extension_parameters(self): """Client offers an extension with parameters unsupported by the server.""" server = ServerProtocol(extensions=[ServerOpExtensionFactory("this")]) request = make_request() request.headers["Sec-WebSocket-Extensions"] = "x-op; op=that" response = server.accept(request) server.send_response(response) self.assertHandshakeSuccess(server) self.assertNotIn("Sec-WebSocket-Extensions", response.headers) self.assertEqual(server.extensions, []) def 
test_multiple_supported_extension_parameters(self): """Server supports the same extension with several parameters.""" server = ServerProtocol( extensions=[ ServerOpExtensionFactory("this"), ServerOpExtensionFactory("that"), ] ) request = make_request() request.headers["Sec-WebSocket-Extensions"] = "x-op; op=that" response = server.accept(request) server.send_response(response) self.assertHandshakeSuccess(server) self.assertEqual(response.headers["Sec-WebSocket-Extensions"], "x-op; op=that") self.assertEqual(server.extensions, [OpExtension("that")]) def test_multiple_extensions(self): """Server enables several extensions when the client offers them.""" server = ServerProtocol( extensions=[ServerOpExtensionFactory(), ServerRsv2ExtensionFactory()] ) request = make_request() request.headers["Sec-WebSocket-Extensions"] = "x-op; op" request.headers["Sec-WebSocket-Extensions"] = "x-rsv2" response = server.accept(request) server.send_response(response) self.assertHandshakeSuccess(server) self.assertEqual( response.headers["Sec-WebSocket-Extensions"], "x-op; op, x-rsv2" ) self.assertEqual(server.extensions, [OpExtension(), Rsv2Extension()]) def test_multiple_extensions_order(self): """Server respects the order of extensions set in its configuration.""" server = ServerProtocol( extensions=[ServerOpExtensionFactory(), ServerRsv2ExtensionFactory()] ) request = make_request() request.headers["Sec-WebSocket-Extensions"] = "x-rsv2" request.headers["Sec-WebSocket-Extensions"] = "x-op; op" response = server.accept(request) server.send_response(response) self.assertHandshakeSuccess(server) self.assertEqual( response.headers["Sec-WebSocket-Extensions"], "x-rsv2, x-op; op" ) self.assertEqual(server.extensions, [Rsv2Extension(), OpExtension()]) def test_no_subprotocols(self): """Handshake succeeds without subprotocols.""" server = ServerProtocol() request = make_request() response = server.accept(request) server.send_response(response) self.assertHandshakeSuccess(server) 
self.assertNotIn("Sec-WebSocket-Protocol", response.headers) self.assertIsNone(server.subprotocol) def test_no_subprotocol_requested(self): """Server expects a subprotocol, but the client doesn't offer it.""" server = ServerProtocol(subprotocols=["chat"]) request = make_request() response = server.accept(request) server.send_response(response) self.assertEqual(response.status_code, 400) self.assertHandshakeError( server, NegotiationError, "missing subprotocol", ) def test_subprotocol(self): """Server enables a subprotocol when the client offers it.""" server = ServerProtocol(subprotocols=["chat"]) request = make_request() request.headers["Sec-WebSocket-Protocol"] = "chat" response = server.accept(request) server.send_response(response) self.assertHandshakeSuccess(server) self.assertEqual(response.headers["Sec-WebSocket-Protocol"], "chat") self.assertEqual(server.subprotocol, "chat") def test_no_subprotocols_supported(self): """Client offers a subprotocol, but the server doesn't support any.""" server = ServerProtocol() request = make_request() request.headers["Sec-WebSocket-Protocol"] = "chat" response = server.accept(request) server.send_response(response) self.assertHandshakeSuccess(server) self.assertNotIn("Sec-WebSocket-Protocol", response.headers) self.assertIsNone(server.subprotocol) def test_multiple_subprotocols(self): """Server enables all of the subprotocols when the client offers them.""" server = ServerProtocol(subprotocols=["superchat", "chat"]) request = make_request() request.headers["Sec-WebSocket-Protocol"] = "chat" request.headers["Sec-WebSocket-Protocol"] = "superchat" response = server.accept(request) server.send_response(response) self.assertHandshakeSuccess(server) self.assertEqual(response.headers["Sec-WebSocket-Protocol"], "superchat") self.assertEqual(server.subprotocol, "superchat") def test_supported_subprotocol(self): """Server enables one of the subprotocols when the client offers it.""" server = 
ServerProtocol(subprotocols=["superchat", "chat"]) request = make_request() request.headers["Sec-WebSocket-Protocol"] = "chat" response = server.accept(request) server.send_response(response) self.assertHandshakeSuccess(server) self.assertEqual(response.headers["Sec-WebSocket-Protocol"], "chat") self.assertEqual(server.subprotocol, "chat") def test_unsupported_subprotocol(self): """Server expects one of the subprotocols, but the client doesn't offer any.""" server = ServerProtocol(subprotocols=["superchat", "chat"]) request = make_request() request.headers["Sec-WebSocket-Protocol"] = "otherchat" response = server.accept(request) server.send_response(response) self.assertEqual(response.status_code, 400) self.assertHandshakeError( server, NegotiationError, "invalid subprotocol; expected one of superchat, chat", ) @staticmethod def optional_chat(protocol, subprotocols): if "chat" in subprotocols: return "chat" def test_select_subprotocol(self): """Server enables a subprotocol with select_subprotocol.""" server = ServerProtocol(select_subprotocol=self.optional_chat) request = make_request() request.headers["Sec-WebSocket-Protocol"] = "chat" response = server.accept(request) server.send_response(response) self.assertHandshakeSuccess(server) self.assertEqual(response.headers["Sec-WebSocket-Protocol"], "chat") self.assertEqual(server.subprotocol, "chat") def test_select_no_subprotocol(self): """Server doesn't enable any subprotocol with select_subprotocol.""" server = ServerProtocol(select_subprotocol=self.optional_chat) request = make_request() request.headers["Sec-WebSocket-Protocol"] = "otherchat" response = server.accept(request) server.send_response(response) self.assertHandshakeSuccess(server) self.assertNotIn("Sec-WebSocket-Protocol", response.headers) self.assertIsNone(server.subprotocol) class MiscTests(unittest.TestCase): def test_bypass_handshake(self): """ServerProtocol bypasses the opening handshake.""" server = ServerProtocol(state=OPEN) 
server.receive_data(b"\x81\x86\x00\x00\x00\x00Hello!") [frame] = server.events_received() self.assertEqual(frame, Frame(OP_TEXT, b"Hello!")) def test_custom_logger(self): """ServerProtocol accepts a logger argument.""" logger = logging.getLogger("test") with self.assertLogs("test", logging.DEBUG) as logs: ServerProtocol(logger=logger) self.assertEqual(len(logs.records), 1) class BackwardsCompatibilityTests(DeprecationTestCase): def test_server_connection_class(self): """ServerConnection is a deprecated alias for ServerProtocol.""" with self.assertDeprecationWarning( "ServerConnection was renamed to ServerProtocol" ): from websockets.server import ServerConnection server = ServerConnection() self.assertIsInstance(server, ServerProtocol) websockets-15.0.1/tests/test_streams.py000066400000000000000000000136471476212450300202170ustar00rootroot00000000000000from websockets.streams import StreamReader from .utils import GeneratorTestCase class StreamReaderTests(GeneratorTestCase): def setUp(self): self.reader = StreamReader() def test_read_line(self): self.reader.feed_data(b"spam\neggs\n") gen = self.reader.read_line(32) line = self.assertGeneratorReturns(gen) self.assertEqual(line, b"spam\n") gen = self.reader.read_line(32) line = self.assertGeneratorReturns(gen) self.assertEqual(line, b"eggs\n") def test_read_line_need_more_data(self): self.reader.feed_data(b"spa") gen = self.reader.read_line(32) self.assertGeneratorRunning(gen) self.reader.feed_data(b"m\neg") line = self.assertGeneratorReturns(gen) self.assertEqual(line, b"spam\n") gen = self.reader.read_line(32) self.assertGeneratorRunning(gen) self.reader.feed_data(b"gs\n") line = self.assertGeneratorReturns(gen) self.assertEqual(line, b"eggs\n") def test_read_line_not_enough_data(self): self.reader.feed_data(b"spa") self.reader.feed_eof() gen = self.reader.read_line(32) with self.assertRaises(EOFError) as raised: next(gen) self.assertEqual( str(raised.exception), "stream ends after 3 bytes, before end of line", ) 
def test_read_line_too_long(self): self.reader.feed_data(b"spam\neggs\n") gen = self.reader.read_line(2) with self.assertRaises(RuntimeError) as raised: next(gen) self.assertEqual( str(raised.exception), "read 5 bytes, expected no more than 2 bytes", ) def test_read_line_too_long_need_more_data(self): self.reader.feed_data(b"spa") gen = self.reader.read_line(2) with self.assertRaises(RuntimeError) as raised: next(gen) self.assertEqual( str(raised.exception), "read 3 bytes, expected no more than 2 bytes", ) def test_read_exact(self): self.reader.feed_data(b"spameggs") gen = self.reader.read_exact(4) data = self.assertGeneratorReturns(gen) self.assertEqual(data, b"spam") gen = self.reader.read_exact(4) data = self.assertGeneratorReturns(gen) self.assertEqual(data, b"eggs") def test_read_exact_need_more_data(self): self.reader.feed_data(b"spa") gen = self.reader.read_exact(4) self.assertGeneratorRunning(gen) self.reader.feed_data(b"meg") data = self.assertGeneratorReturns(gen) self.assertEqual(data, b"spam") gen = self.reader.read_exact(4) self.assertGeneratorRunning(gen) self.reader.feed_data(b"gs") data = self.assertGeneratorReturns(gen) self.assertEqual(data, b"eggs") def test_read_exact_not_enough_data(self): self.reader.feed_data(b"spa") self.reader.feed_eof() gen = self.reader.read_exact(4) with self.assertRaises(EOFError) as raised: next(gen) self.assertEqual( str(raised.exception), "stream ends after 3 bytes, expected 4 bytes", ) def test_read_to_eof(self): gen = self.reader.read_to_eof(32) self.reader.feed_data(b"spam") self.assertGeneratorRunning(gen) self.reader.feed_eof() data = self.assertGeneratorReturns(gen) self.assertEqual(data, b"spam") def test_read_to_eof_at_eof(self): self.reader.feed_eof() gen = self.reader.read_to_eof(32) data = self.assertGeneratorReturns(gen) self.assertEqual(data, b"") def test_read_to_eof_too_long(self): gen = self.reader.read_to_eof(2) self.reader.feed_data(b"spam") with self.assertRaises(RuntimeError) as raised: next(gen) 
self.assertEqual( str(raised.exception), "read 4 bytes, expected no more than 2 bytes", ) def test_at_eof_after_feed_data(self): gen = self.reader.at_eof() self.assertGeneratorRunning(gen) self.reader.feed_data(b"spam") eof = self.assertGeneratorReturns(gen) self.assertFalse(eof) def test_at_eof_after_feed_eof(self): gen = self.reader.at_eof() self.assertGeneratorRunning(gen) self.reader.feed_eof() eof = self.assertGeneratorReturns(gen) self.assertTrue(eof) def test_feed_data_after_feed_data(self): self.reader.feed_data(b"spam") self.reader.feed_data(b"eggs") gen = self.reader.read_exact(8) data = self.assertGeneratorReturns(gen) self.assertEqual(data, b"spameggs") gen = self.reader.at_eof() self.assertGeneratorRunning(gen) def test_feed_eof_after_feed_data(self): self.reader.feed_data(b"spam") self.reader.feed_eof() gen = self.reader.read_exact(4) data = self.assertGeneratorReturns(gen) self.assertEqual(data, b"spam") gen = self.reader.at_eof() eof = self.assertGeneratorReturns(gen) self.assertTrue(eof) def test_feed_data_after_feed_eof(self): self.reader.feed_eof() with self.assertRaises(EOFError) as raised: self.reader.feed_data(b"spam") self.assertEqual( str(raised.exception), "stream ended", ) def test_feed_eof_after_feed_eof(self): self.reader.feed_eof() with self.assertRaises(EOFError) as raised: self.reader.feed_eof() self.assertEqual( str(raised.exception), "stream ended", ) def test_discard(self): gen = self.reader.read_to_eof(32) self.reader.feed_data(b"spam") self.reader.discard() self.assertGeneratorRunning(gen) self.reader.feed_eof() data = self.assertGeneratorReturns(gen) self.assertEqual(data, b"") websockets-15.0.1/tests/test_uri.py000066400000000000000000000155721476212450300173370ustar00rootroot00000000000000import os import unittest from unittest.mock import patch from websockets.exceptions import InvalidProxy, InvalidURI from websockets.uri import * from websockets.uri import Proxy, get_proxy, parse_proxy VALID_URIS = [ ( "ws://localhost/", 
WebSocketURI(False, "localhost", 80, "/", "", None, None), ), ( "wss://localhost/", WebSocketURI(True, "localhost", 443, "/", "", None, None), ), ( "ws://localhost", WebSocketURI(False, "localhost", 80, "", "", None, None), ), ( "ws://localhost/path?query", WebSocketURI(False, "localhost", 80, "/path", "query", None, None), ), ( "ws://localhost/path;params", WebSocketURI(False, "localhost", 80, "/path;params", "", None, None), ), ( "WS://LOCALHOST/PATH?QUERY", WebSocketURI(False, "localhost", 80, "/PATH", "QUERY", None, None), ), ( "ws://user:pass@localhost/", WebSocketURI(False, "localhost", 80, "/", "", "user", "pass"), ), ( "ws://høst/", WebSocketURI(False, "xn--hst-0na", 80, "/", "", None, None), ), ( "ws://üser:påss@høst/πass?qùéry", WebSocketURI( False, "xn--hst-0na", 80, "/%CF%80ass", "q%C3%B9%C3%A9ry", "%C3%BCser", "p%C3%A5ss", ), ), ] INVALID_URIS = [ "http://localhost/", "https://localhost/", "ws://localhost/path#fragment", "ws://user@localhost/", "ws:///path", ] URIS_WITH_RESOURCE_NAMES = [ ("ws://localhost/", "/"), ("ws://localhost", "/"), ("ws://localhost/path?query", "/path?query"), ("ws://høst/πass?qùéry", "/%CF%80ass?q%C3%B9%C3%A9ry"), ] URIS_WITH_USER_INFO = [ ("ws://localhost/", None), ("ws://user:pass@localhost/", ("user", "pass")), ("ws://üser:påss@høst/", ("%C3%BCser", "p%C3%A5ss")), ] VALID_PROXIES = [ ( "http://proxy:8080", Proxy("http", "proxy", 8080, None, None), ), ( "https://proxy:8080", Proxy("https", "proxy", 8080, None, None), ), ( "http://proxy", Proxy("http", "proxy", 80, None, None), ), ( "http://proxy:8080/", Proxy("http", "proxy", 8080, None, None), ), ( "http://PROXY:8080", Proxy("http", "proxy", 8080, None, None), ), ( "http://user:pass@proxy:8080", Proxy("http", "proxy", 8080, "user", "pass"), ), ( "http://høst:8080/", Proxy("http", "xn--hst-0na", 8080, None, None), ), ( "http://üser:påss@høst:8080", Proxy("http", "xn--hst-0na", 8080, "%C3%BCser", "p%C3%A5ss"), ), ] INVALID_PROXIES = [ "ws://proxy:8080", "wss://proxy:8080", 
"http://proxy:8080/path", "http://proxy:8080/?query", "http://proxy:8080/#fragment", "http://user@proxy", "http:///", ] PROXIES_WITH_USER_INFO = [ ("http://proxy", None), ("http://user:pass@proxy", ("user", "pass")), ("http://üser:påss@høst", ("%C3%BCser", "p%C3%A5ss")), ] PROXY_ENVS = [ ( {"ws_proxy": "http://proxy:8080"}, "ws://example.com/", "http://proxy:8080", ), ( {"ws_proxy": "http://proxy:8080"}, "wss://example.com/", None, ), ( {"wss_proxy": "http://proxy:8080"}, "ws://example.com/", None, ), ( {"wss_proxy": "http://proxy:8080"}, "wss://example.com/", "http://proxy:8080", ), ( {"http_proxy": "http://proxy:8080"}, "ws://example.com/", "http://proxy:8080", ), ( {"http_proxy": "http://proxy:8080"}, "wss://example.com/", None, ), ( {"https_proxy": "http://proxy:8080"}, "ws://example.com/", "http://proxy:8080", ), ( {"https_proxy": "http://proxy:8080"}, "wss://example.com/", "http://proxy:8080", ), ( {"socks_proxy": "http://proxy:1080"}, "ws://example.com/", "socks5h://proxy:1080", ), ( {"socks_proxy": "http://proxy:1080"}, "wss://example.com/", "socks5h://proxy:1080", ), ( {"ws_proxy": "http://proxy1:8080", "wss_proxy": "http://proxy2:8080"}, "ws://example.com/", "http://proxy1:8080", ), ( {"ws_proxy": "http://proxy1:8080", "wss_proxy": "http://proxy2:8080"}, "wss://example.com/", "http://proxy2:8080", ), ( {"http_proxy": "http://proxy1:8080", "https_proxy": "http://proxy2:8080"}, "ws://example.com/", "http://proxy2:8080", ), ( {"http_proxy": "http://proxy1:8080", "https_proxy": "http://proxy2:8080"}, "wss://example.com/", "http://proxy2:8080", ), ( {"https_proxy": "http://proxy:8080", "socks_proxy": "http://proxy:1080"}, "ws://example.com/", "socks5h://proxy:1080", ), ( {"https_proxy": "http://proxy:8080", "socks_proxy": "http://proxy:1080"}, "wss://example.com/", "socks5h://proxy:1080", ), ( {"socks_proxy": "http://proxy:1080", "no_proxy": ".local"}, "ws://example.local/", None, ), ] class URITests(unittest.TestCase): def test_parse_valid_uris(self): for 
uri, parsed in VALID_URIS: with self.subTest(uri=uri): self.assertEqual(parse_uri(uri), parsed) def test_parse_invalid_uris(self): for uri in INVALID_URIS: with self.subTest(uri=uri): with self.assertRaises(InvalidURI): parse_uri(uri) def test_parse_resource_name(self): for uri, resource_name in URIS_WITH_RESOURCE_NAMES: with self.subTest(uri=uri): self.assertEqual(parse_uri(uri).resource_name, resource_name) def test_parse_user_info(self): for uri, user_info in URIS_WITH_USER_INFO: with self.subTest(uri=uri): self.assertEqual(parse_uri(uri).user_info, user_info) def test_parse_valid_proxies(self): for proxy, parsed in VALID_PROXIES: with self.subTest(proxy=proxy): self.assertEqual(parse_proxy(proxy), parsed) def test_parse_invalid_proxies(self): for proxy in INVALID_PROXIES: with self.subTest(proxy=proxy): with self.assertRaises(InvalidProxy): parse_proxy(proxy) def test_parse_proxy_user_info(self): for proxy, user_info in PROXIES_WITH_USER_INFO: with self.subTest(proxy=proxy): self.assertEqual(parse_proxy(proxy).user_info, user_info) def test_get_proxy(self): for environ, uri, proxy in PROXY_ENVS: with patch.dict(os.environ, environ): with self.subTest(environ=environ, uri=uri): self.assertEqual(get_proxy(parse_uri(uri)), proxy) websockets-15.0.1/tests/test_utils.py000066400000000000000000000071341476212450300176730ustar00rootroot00000000000000import base64 import itertools import platform import unittest from websockets.utils import accept_key, apply_mask as py_apply_mask, generate_key # Test vector from RFC 6455 KEY = "dGhlIHNhbXBsZSBub25jZQ==" ACCEPT = "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=" class UtilsTests(unittest.TestCase): def test_generate_key(self): key = generate_key() self.assertEqual(len(base64.b64decode(key.encode())), 16) def test_accept_key(self): self.assertEqual(accept_key(KEY), ACCEPT) class ApplyMaskTests(unittest.TestCase): @staticmethod def apply_mask(*args, **kwargs): return py_apply_mask(*args, **kwargs) apply_mask_type_combos = 
list(itertools.product([bytes, bytearray], repeat=2)) apply_mask_test_values = [ (b"", b"1234", b""), (b"aBcDe", b"\x00\x00\x00\x00", b"aBcDe"), (b"abcdABCD", b"1234", b"PPPPpppp"), (b"abcdABCD" * 10, b"1234", b"PPPPpppp" * 10), ] def test_apply_mask(self): for data_type, mask_type in self.apply_mask_type_combos: for data_in, mask, data_out in self.apply_mask_test_values: data_in, mask = data_type(data_in), mask_type(mask) with self.subTest(data_in=data_in, mask=mask): result = self.apply_mask(data_in, mask) self.assertEqual(result, data_out) def test_apply_mask_memoryview(self): for mask_type in [bytes, bytearray]: for data_in, mask, data_out in self.apply_mask_test_values: data_in, mask = memoryview(data_in), mask_type(mask) with self.subTest(data_in=data_in, mask=mask): result = self.apply_mask(data_in, mask) self.assertEqual(result, data_out) def test_apply_mask_non_contiguous_memoryview(self): for mask_type in [bytes, bytearray]: for data_in, mask, data_out in self.apply_mask_test_values: data_in, mask = memoryview(data_in)[::-1], mask_type(mask)[::-1] data_out = data_out[::-1] with self.subTest(data_in=data_in, mask=mask): result = self.apply_mask(data_in, mask) self.assertEqual(result, data_out) def test_apply_mask_check_input_types(self): for data_in, mask in [(None, None), (b"abcd", None), (None, b"abcd")]: with self.subTest(data_in=data_in, mask=mask): with self.assertRaises(TypeError): self.apply_mask(data_in, mask) def test_apply_mask_check_mask_length(self): for data_in, mask in [ (b"", b""), (b"abcd", b"123"), (b"", b"aBcDe"), (b"12345678", b"12345678"), ]: with self.subTest(data_in=data_in, mask=mask): with self.assertRaises(ValueError): self.apply_mask(data_in, mask) try: from websockets.speedups import apply_mask as c_apply_mask except ImportError: pass else: class SpeedupsTests(ApplyMaskTests): @staticmethod def apply_mask(*args, **kwargs): try: return c_apply_mask(*args, **kwargs) except NotImplementedError as exc: # pragma: no cover # PyPy 
doesn't implement creating contiguous readonly buffer # from non-contiguous. We don't care about this edge case. if ( platform.python_implementation() == "PyPy" and "not implemented yet" in str(exc) ): raise unittest.SkipTest(str(exc)) else: raise websockets-15.0.1/tests/utils.py000066400000000000000000000104141476212450300166270ustar00rootroot00000000000000import contextlib import email.utils import logging import os import pathlib import platform import ssl import sys import tempfile import time import unittest import warnings from websockets.version import released # Generate TLS certificate with: # $ openssl req -x509 -config test_localhost.cnf -days 15340 -newkey rsa:2048 \ # -out test_localhost.crt -keyout test_localhost.key # $ cat test_localhost.key test_localhost.crt > test_localhost.pem # $ rm test_localhost.key test_localhost.crt CERTIFICATE = pathlib.Path(__file__).with_name("test_localhost.pem") CLIENT_CONTEXT = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) CLIENT_CONTEXT.load_verify_locations(bytes(CERTIFICATE)) SERVER_CONTEXT = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) SERVER_CONTEXT.load_cert_chain(bytes(CERTIFICATE)) # Work around https://github.com/openssl/openssl/issues/7967 # This bug causes connect() to hang in tests for the client. Including this # workaround acknowledges that the issue could happen outside of the test suite. # It shouldn't happen too often, or else OpenSSL 1.1.1 would be unusable. If it # happens, we can look for a library-level fix, but it won't be easy. SERVER_CONTEXT.num_tickets = 0 DATE = email.utils.formatdate(usegmt=True) # Unit for timeouts. May be increased in slow or noisy environments by setting # the WEBSOCKETS_TESTS_TIMEOUT_FACTOR environment variable. # Downstream distributors insist on running the test suite despites my pleas to # the contrary. They do it on build farms with unstable performance, leading to # flakiness, and then they file bugs. Make tests 100x slower to avoid flakiness. 
MS = 0.001 * float( os.environ.get( "WEBSOCKETS_TESTS_TIMEOUT_FACTOR", "100" if released else "1", ) ) # PyPy, asyncio's debug mode, and coverage penalize performance of this # test suite. Increase timeouts to reduce the risk of spurious failures. if platform.python_implementation() == "PyPy": # pragma: no cover MS *= 2 if os.environ.get("PYTHONASYNCIODEBUG"): # pragma: no cover MS *= 2 if os.environ.get("COVERAGE_RUN"): # pragma: no branch MS *= 2 # Ensure that timeouts are larger than the clock's resolution (for Windows). MS = max(MS, 2.5 * time.get_clock_info("monotonic").resolution) class GeneratorTestCase(unittest.TestCase): """ Base class for testing generator-based coroutines. """ def assertGeneratorRunning(self, gen): """ Check that a generator-based coroutine hasn't completed yet. """ next(gen) def assertGeneratorReturns(self, gen): """ Check that a generator-based coroutine completes and return its value. """ with self.assertRaises(StopIteration) as raised: next(gen) return raised.exception.value class DeprecationTestCase(unittest.TestCase): """ Base class for testing deprecations. """ @contextlib.contextmanager def assertDeprecationWarning(self, message): """ Check that a deprecation warning was raised with the given message. """ with warnings.catch_warnings(record=True) as recorded_warnings: warnings.simplefilter("always") yield self.assertEqual(len(recorded_warnings), 1) warning = recorded_warnings[0] self.assertEqual(warning.category, DeprecationWarning) self.assertEqual(str(warning.message), message) class AssertNoLogsMixin: """ Backport of assertNoLogs for Python 3.9. """ if sys.version_info[:2] < (3, 10): # pragma: no cover @contextlib.contextmanager def assertNoLogs(self, logger=None, level=None): """ No message is logged on the given logger with at least the given level. """ with self.assertLogs(logger, level) as logs: # We want to test that no log message is emitted # but assertLogs expects at least one log message. 
logging.getLogger(logger).log(level, "dummy") yield level_name = logging.getLevelName(level) self.assertEqual(logs.output, [f"{level_name}:{logger}:dummy"]) @contextlib.contextmanager def temp_unix_socket_path(): with tempfile.TemporaryDirectory() as temp_dir: yield str(pathlib.Path(temp_dir) / "websockets") websockets-15.0.1/tox.ini000066400000000000000000000017341476212450300152730ustar00rootroot00000000000000[tox] env_list = py39 py310 py311 py312 py313 coverage ruff mypy [testenv] commands = python -W error::DeprecationWarning -W error::PendingDeprecationWarning -m unittest {posargs} pass_env = WEBSOCKETS_* deps = py311,py312,py313,coverage,maxi_cov: mitmproxy py311,py312,py313,coverage,maxi_cov: python-socks[asyncio] werkzeug [testenv:coverage] commands = python -m coverage run --source {envsitepackagesdir}/websockets,tests -m unittest {posargs} python -m coverage report --show-missing --fail-under=100 deps = coverage {[testenv]deps} [testenv:maxi_cov] commands = python tests/maxi_cov.py {envsitepackagesdir} python -m coverage report --show-missing --fail-under=100 deps = coverage {[testenv]deps} [testenv:ruff] commands = ruff format --check src tests ruff check src tests deps = ruff [testenv:mypy] commands = mypy --strict src deps = mypy python-socks werkzeug