=== aioquic-1.2.0/LICENSE ===

Copyright (c) Jeremy Lainé.
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

    * Redistributions of source code must retain the above copyright notice,
      this list of conditions and the following disclaimer.

    * Redistributions in binary form must reproduce the above copyright
      notice, this list of conditions and the following disclaimer in the
      documentation and/or other materials provided with the distribution.

    * Neither the name of aioquic nor the names of its contributors may
      be used to endorse or promote products derived from this software
      without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
POSSIBILITY OF SUCH DAMAGE.
=== aioquic-1.2.0/MANIFEST.in ===

exclude .readthedocs.yaml
include LICENSE
recursive-include docs *.py *.rst *.svg Makefile
recursive-include examples *.css *.html *.py *.rst *.txt
recursive-include requirements *.txt
recursive-include scripts *.json *.py
recursive-include tests *.bin *.pem *.py

=== aioquic-1.2.0/PKG-INFO ===

Metadata-Version: 2.1
Name: aioquic
Version: 1.2.0
Summary: An implementation of QUIC and HTTP/3
Author-email: Jeremy Lainé
License: BSD-3-Clause
Project-URL: Homepage, https://github.com/aiortc/aioquic
Project-URL: Changelog, https://aioquic.readthedocs.io/en/stable/changelog.html
Project-URL: Documentation, https://aioquic.readthedocs.io/
Classifier: Development Status :: 5 - Production/Stable
Classifier: Environment :: Web Environment
Classifier: Intended Audience :: Developers
Classifier: License :: OSI Approved :: BSD License
Classifier: Operating System :: OS Independent
Classifier: Programming Language :: Python
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.8
Classifier: Programming Language :: Python :: 3.9
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Topic :: Internet :: WWW/HTTP
Requires-Python: >=3.8
Description-Content-Type: text/x-rst
License-File: LICENSE
Requires-Dist: certifi
Requires-Dist: cryptography>=42.0.0
Requires-Dist: pylsqpack<0.4.0,>=0.3.3
Requires-Dist: pyopenssl>=24
Requires-Dist: service-identity>=24.1.0
Provides-Extra: dev
Requires-Dist: coverage[toml]>=7.2.2; extra == "dev"

(The long description embedded in PKG-INFO is a verbatim copy of README.rst, reproduced below.)
=== aioquic-1.2.0/README.rst ===

aioquic
=======

.. image:: https://img.shields.io/pypi/l/aioquic.svg
   :target: https://pypi.python.org/pypi/aioquic
   :alt: License

.. image:: https://img.shields.io/pypi/v/aioquic.svg
   :target: https://pypi.python.org/pypi/aioquic
   :alt: Version

.. image:: https://img.shields.io/pypi/pyversions/aioquic.svg
   :target: https://pypi.python.org/pypi/aioquic
   :alt: Python versions

.. image:: https://github.com/aiortc/aioquic/workflows/tests/badge.svg
   :target: https://github.com/aiortc/aioquic/actions
   :alt: Tests

.. image:: https://img.shields.io/codecov/c/github/aiortc/aioquic.svg
   :target: https://codecov.io/gh/aiortc/aioquic
   :alt: Coverage

.. image:: https://readthedocs.org/projects/aioquic/badge/?version=latest
   :target: https://aioquic.readthedocs.io/
   :alt: Documentation

What is ``aioquic``?
--------------------

``aioquic`` is a library for the QUIC network protocol in Python. It features
a minimal TLS 1.3 implementation, a QUIC stack and an HTTP/3 stack.

``aioquic`` is used by Python open-source projects such as `dnspython`_,
`hypercorn`_, `mitmproxy`_ and the `Web Platform Tests`_ cross-browser test
suite. It has also been used extensively in research papers about QUIC.

To learn more about ``aioquic`` please `read the documentation`_.

Why should I use ``aioquic``?
-----------------------------

``aioquic`` has been designed to be embedded into Python client and server
libraries wishing to support QUIC and/or HTTP/3. The goal is to provide a
common codebase for Python libraries in the hope of avoiding duplicated
effort.

Both the QUIC and the HTTP/3 APIs follow the "bring your own I/O" pattern,
leaving actual I/O operations to the API user. This approach has a number of
advantages, including making the code testable and allowing integration with
different concurrency models.

A lot of effort has gone into writing an extensive test suite for the
``aioquic`` code to ensure best-in-class code quality, and it is regularly
`tested for interoperability`_ against other `QUIC implementations`_.
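To make the pattern concrete, here is a minimal, illustrative sketch of
driving a client ``QuicConnection`` by hand; the server address and the
socket plumbing are assumptions for the example, not part of the library:

.. code-block:: python

   import socket
   import time

   from aioquic.quic.configuration import QuicConfiguration
   from aioquic.quic.connection import QuicConnection

   # illustrative peer address; any QUIC server would do
   server_addr = ("localhost", 4433)

   connection = QuicConnection(configuration=QuicConfiguration(is_client=True))
   connection.connect(server_addr, now=time.time())

   sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)

   # the library never touches the network itself: we fetch the datagrams
   # it wants to send and put them on the wire ourselves...
   for data, addr in connection.datagrams_to_send(now=time.time()):
       sock.sendto(data, addr)

   # ...and feed whatever arrives back into the connection
   data, addr = sock.recvfrom(65536)
   connection.receive_datagram(data, addr, now=time.time())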
Features
--------

- minimal TLS 1.3 implementation conforming with `RFC 8446`_
- QUIC stack conforming with `RFC 9000`_ (QUIC v1) and `RFC 9369`_ (QUIC v2)

  * IPv4 and IPv6 support
  * connection migration and NAT rebinding
  * logging TLS traffic secrets
  * logging QUIC events in QLOG format
  * version negotiation conforming with `RFC 9368`_

- HTTP/3 stack conforming with `RFC 9114`_

  * server push support
  * WebSocket bootstrapping conforming with `RFC 9220`_
  * datagram support conforming with `RFC 9297`_

Installing
----------

The easiest way to install ``aioquic`` is to run:

.. code:: bash

   pip install aioquic

Building from source
--------------------

If there are no wheels for your system or if you wish to build ``aioquic``
from source you will need the OpenSSL development headers.

Linux
.....

On Debian/Ubuntu run:

.. code-block:: console

   sudo apt install libssl-dev python3-dev

On Alpine Linux run:

.. code-block:: console

   sudo apk add openssl-dev python3-dev bsd-compat-headers libffi-dev

OS X
....

On OS X run:

.. code-block:: console

   brew install openssl

You will need to set some environment variables to link against OpenSSL:

.. code-block:: console

   export CFLAGS=-I$(brew --prefix openssl)/include
   export LDFLAGS=-L$(brew --prefix openssl)/lib

Windows
.......

On Windows the easiest way to install OpenSSL is to use `Chocolatey`_.

.. code-block:: console

   choco install openssl

You will need to set some environment variables to link against OpenSSL:

.. code-block:: console

   $Env:INCLUDE = "C:\Progra~1\OpenSSL\include"
   $Env:LIB = "C:\Progra~1\OpenSSL\lib"

Running the examples
--------------------

``aioquic`` comes with a number of examples illustrating various QUIC use
cases. You can browse these examples here:
https://github.com/aiortc/aioquic/tree/main/examples

License
-------

``aioquic`` is released under the `BSD license`_.

.. _read the documentation: https://aioquic.readthedocs.io/en/latest/
.. _dnspython: https://github.com/rthalley/dnspython
.. _hypercorn: https://github.com/pgjones/hypercorn
.. _mitmproxy: https://github.com/mitmproxy/mitmproxy
.. _Web Platform Tests: https://github.com/web-platform-tests/wpt
.. _tested for interoperability: https://interop.seemann.io/
.. _QUIC implementations: https://github.com/quicwg/base-drafts/wiki/Implementations
.. _cryptography: https://cryptography.io/
.. _Chocolatey: https://chocolatey.org/
.. _BSD license: https://aioquic.readthedocs.io/en/latest/license.html
.. _RFC 8446: https://datatracker.ietf.org/doc/html/rfc8446
.. _RFC 9000: https://datatracker.ietf.org/doc/html/rfc9000
.. _RFC 9114: https://datatracker.ietf.org/doc/html/rfc9114
.. _RFC 9220: https://datatracker.ietf.org/doc/html/rfc9220
.. _RFC 9297: https://datatracker.ietf.org/doc/html/rfc9297
.. _RFC 9368: https://datatracker.ietf.org/doc/html/rfc9368
.. _RFC 9369: https://datatracker.ietf.org/doc/html/rfc9369

=== aioquic-1.2.0/docs/Makefile ===

# Minimal makefile for Sphinx documentation
#

# You can set these variables from the command line.
SPHINXOPTS    =
SPHINXBUILD   = sphinx-build
SPHINXPROJ    = aioquic
SOURCEDIR     = .
BUILDDIR      = _build

# Put it first so that "make" without argument is like "make help".
help:
	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)

.PHONY: help Makefile

# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option.  $(O) is meant as a shortcut for $(SPHINXOPTS).
%: Makefile
	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)

=== aioquic-1.2.0/docs/_ext/sphinx_aioquic.py ===

from docutils import nodes
from docutils.parsers.rst import Directive
from docutils.statemachine import StringList


class AioquicTransmit(Directive):
    def run(self):
        content = StringList(
            [
                ".. note::",
                "   After calling this method you need to call the QUIC connection "
                ":meth:`~aioquic.quic.connection.QuicConnection.datagrams_to_send` "
                "method to retrieve data which needs to be sent over the network. "
                "If you are using the :doc:`asyncio API <asyncio>`, calling the "
                ":meth:`~aioquic.asyncio.QuicConnectionProtocol.transmit` method "
                "will do it for you.",
            ]
        )
        node = nodes.paragraph()
        self.state.nested_parse(content, 0, node)
        return [node]


def setup(app):
    app.add_directive("aioquic_transmit", AioquicTransmit)

    return {
        "version": "0.1",
        "parallel_read_safe": True,
        "parallel_write_safe": True,
    }
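For reference, the directive registered by ``setup()`` above is used from
reStructuredText content (for example a method docstring) as a plain,
argument-less directive; this one-line usage sketch is illustrative:

.. code-block:: rst

   .. aioquic_transmit::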
note::", " After calling this method you need to call the QUIC connection " ":meth:`~aioquic.quic.connection.QuicConnection.datagrams_to_send` " "method to retrieve data which needs to be sent over the network. " "If you are using the :doc:`asyncio API `, calling the " ":meth:`~aioquic.asyncio.QuicConnectionProtocol.transmit` method " "will do it for you.", ] ) node = nodes.paragraph() self.state.nested_parse(content, 0, node) return [node] def setup(app): app.add_directive("aioquic_transmit", AioquicTransmit) return { "version": "0.1", "parallel_read_safe": True, "parallel_write_safe": True, } ././@PaxHeader0000000000000000000000000000003300000000000011451 xustar000000000000000027 mtime=1720306888.117294 aioquic-1.2.0/docs/_static/0000755000175100001770000000000000000000000016330 5ustar00runnerdocker00000000000000././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1720306884.0 aioquic-1.2.0/docs/_static/aioquic.svg0000644000175100001770000000774300000000000020516 0ustar00runnerdocker00000000000000 ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1720306884.0 aioquic-1.2.0/docs/asyncio.rst0000644000175100001770000000103600000000000017101 0ustar00runnerdocker00000000000000asyncio API =========== The asyncio API provides a high-level QUIC API built on top of :mod:`asyncio`, Python's standard asynchronous I/O framework. ``aioquic`` comes with a selection of examples, including: - an HTTP/3 client - an HTTP/3 server The examples can be browsed on GitHub: https://github.com/aiortc/aioquic/tree/main/examples .. automodule:: aioquic.asyncio Client ------ .. autofunction:: connect Server ------ .. autofunction:: serve Common ------ .. autoclass:: QuicConnectionProtocol :members: ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1720306884.0 aioquic-1.2.0/docs/changelog.rst0000644000175100001770000000525600000000000017373 0ustar00runnerdocker00000000000000Changelog ========= 1.2.0 ----- * Add support for compatible version handling as defined in :rfc:`9368`. * Add support for QUIC Version 2, as defined in :rfc:`9369`. * Drop support for draft QUIC versions which were obsoleted by :rfc:`9000`. * Improve datagram padding to allow better packet coalescing and reduce the number of roundtrips during connection establishement. * Fix server anti-amplification checks during address validation to take into account invalid packets, such as datagram-level padding. * Allow asyncio clients to make efficient use of 0-RTT by passing `wait_connected=False` to :meth:`~aioquic.asyncio.connect`. * Add command-line arguments to the `http3_client` example for client certificates and negotiating QUIC Version 2. 1.1.0 ----- * Improve path challenge handling and compliance with :rfc:`9000`. * Limit the amount of buffered CRYPTO data to avoid memory exhaustion. * Enable SHA-384 based signature algorithms and SECP384R1 key exchange. * Build binary wheels against `OpenSSL`_ 3.3.0. 1.0.0 ----- * Ensure no data is sent after a stream reset. * Make :class:`~aioquic.h3.connection.H3Connection`'s :meth:`~aioquic.h3.connection.H3Connection.send_datagram` and :meth:`~aioquic.h3.connection.H3Connection.send_push_promise` methods raise an :class:`~aioquic.h3.exceptions.InvalidStreamTypeError` exception if an invalid stream ID is specified. * Improve the documentation for :class:`~aioquic.asyncio.QuicConnectionProtocol`'s :meth:`~aioquic.asyncio.QuicConnectionProtocol.transmit` method. 
=== aioquic-1.2.0/docs/changelog.rst ===

Changelog
=========

1.2.0
-----

* Add support for compatible version handling as defined in :rfc:`9368`.
* Add support for QUIC Version 2, as defined in :rfc:`9369`.
* Drop support for draft QUIC versions which were obsoleted by :rfc:`9000`.
* Improve datagram padding to allow better packet coalescing and reduce the
  number of roundtrips during connection establishment.
* Fix server anti-amplification checks during address validation to take
  into account invalid packets, such as datagram-level padding.
* Allow asyncio clients to make efficient use of 0-RTT by passing
  ``wait_connected=False`` to :meth:`~aioquic.asyncio.connect`.
* Add command-line arguments to the ``http3_client`` example for client
  certificates and negotiating QUIC Version 2.

1.1.0
-----

* Improve path challenge handling and compliance with :rfc:`9000`.
* Limit the amount of buffered CRYPTO data to avoid memory exhaustion.
* Enable SHA-384 based signature algorithms and SECP384R1 key exchange.
* Build binary wheels against `OpenSSL`_ 3.3.0.

1.0.0
-----

* Ensure no data is sent after a stream reset.
* Make :class:`~aioquic.h3.connection.H3Connection`'s
  :meth:`~aioquic.h3.connection.H3Connection.send_datagram` and
  :meth:`~aioquic.h3.connection.H3Connection.send_push_promise` methods raise
  an :class:`~aioquic.h3.exceptions.InvalidStreamTypeError` exception if an
  invalid stream ID is specified.
* Improve the documentation for
  :class:`~aioquic.asyncio.QuicConnectionProtocol`'s
  :meth:`~aioquic.asyncio.QuicConnectionProtocol.transmit` method.
* Fix :meth:`~datetime.datetime.utcnow` deprecation warning on Python 3.12
  by using `cryptography`_ 42.0 and timezone-aware :class:`~datetime.datetime`
  instances when validating TLS certificates.
* Build binary wheels against `OpenSSL`_ 3.2.0.
* Ignore any non-ASCII ALPN values received.
* Perform more extensive HTTP/3 header validation in
  :class:`~aioquic.h3.connection.H3Connection`.
* Fix exceptions when draining stream writers in the :doc:`asyncio API <asyncio>`.
* Set the :class:`~aioquic.quic.connection.QuicConnection` idle timer
  according to :rfc:`9000` section 10.1.
* Implement fairer stream scheduling in
  :class:`~aioquic.quic.connection.QuicConnection` to avoid head-of-line
  blocking.
* Only load `certifi`_ root certificates if none was specified in the
  :class:`~aioquic.quic.configuration.QuicConfiguration`.
* Improve padding of UDP datagrams containing Initial packets to comply with
  :rfc:`9000` section 14.1.
* Limit the number of pending connection IDs marked for retirement to
  prevent a possible DoS attack.

.. _certifi: https://github.com/certifi/python-certifi
.. _cryptography: https://cryptography.io/
.. _OpenSSL: https://www.openssl.org/

=== aioquic-1.2.0/docs/conf.py ===

# Configuration file for the Sphinx documentation builder.
#
# This file only contains a selection of the most common options. For a full
# list see the documentation:
# https://www.sphinx-doc.org/en/master/usage/configuration.html

import os
import sys

# -- Project information -----------------------------------------------------

project = "aioquic"
author = "Jeremy Lainé"
copyright = author

# -- General configuration ------------------------------------------------

sys.path.append(os.path.abspath("./_ext"))

# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = [
    "sphinx.ext.autodoc",
    "sphinx.ext.intersphinx",
    "sphinx_autodoc_typehints",
    "sphinxcontrib_trio",
    "sphinx_aioquic",
]

intersphinx_mapping = {
    "cryptography": ("https://cryptography.io/en/latest", None),
    "python": ("https://docs.python.org/3", None),
}

# Add any paths that contain templates here, relative to this directory.
templates_path = ["_templates"]

# The language for content autogenerated by Sphinx. Refer to documentation
# for a list of supported languages.
#
# This is also used if you do content translation via gettext catalogs.
# Usually you set "language" from the command line for these cases.
language = "en"

# -- Options for HTML output ----------------------------------------------

# The theme to use for HTML and HTML Help pages.  See the documentation for
# a list of builtin themes.
#
html_theme = "alabaster"

# Theme options are theme-specific and customize the look and feel of a theme
# further.  For a list of options available for each theme, see the
# documentation.
#
html_theme_options = {
    "description": "A library for QUIC in Python.",
    "github_button": True,
    "github_user": "aiortc",
    "github_repo": "aioquic",
    "logo": "aioquic.svg",
}

# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ["_static"]

# Custom sidebar templates, must be a dictionary that maps document names
# to template names.
#
# This is required for the alabaster theme
# refs: http://alabaster.readthedocs.io/en/latest/installation.html#sidebars
html_sidebars = {
    "**": [
        "about.html",
        "navigation.html",
        "relations.html",
        "searchbox.html",
    ]
}
=== aioquic-1.2.0/docs/design.rst ===

Design
======

Sans-IO APIs
............

Both the QUIC and the HTTP/3 APIs follow the `sans I/O`_ pattern, leaving
actual I/O operations to the API user.

This approach has a number of advantages, including making the code testable
and allowing integration with different concurrency models.

TLS and encryption
..................

TLS 1.3
+++++++

``aioquic`` features a minimal TLS 1.3 implementation built upon the
`cryptography`_ library. This is because QUIC requires some APIs which are
currently unavailable in mainstream TLS implementations such as OpenSSL:

- the ability to extract traffic secrets
- the ability to operate directly on TLS messages, without using the TLS
  record layer

Header protection and payload encryption
++++++++++++++++++++++++++++++++++++++++

QUIC makes extensive use of cryptographic operations to protect QUIC packet
headers and encrypt packet payloads. These operations occur for every single
packet and are a determining factor for performance. For this reason, they
are implemented as a C extension linked to `OpenSSL`_.

.. _sans I/O: https://sans-io.readthedocs.io/
.. _cryptography: https://cryptography.io/
.. _OpenSSL: https://www.openssl.org/

=== aioquic-1.2.0/docs/h3.rst ===

HTTP/3 API
==========

The HTTP/3 API performs no I/O on its own, leaving this to the API user.
This allows you to integrate HTTP/3 in any Python application, regardless of
the concurrency model you are using.

Connection
----------

.. automodule:: aioquic.h3.connection

.. autoclass:: H3Connection
    :members:

Events
------

.. automodule:: aioquic.h3.events

.. autoclass:: H3Event
    :members:

.. autoclass:: DatagramReceived
    :members:

.. autoclass:: DataReceived
    :members:

.. autoclass:: HeadersReceived
    :members:

.. autoclass:: PushPromiseReceived
    :members:

.. autoclass:: WebTransportStreamDataReceived
    :members:

Exceptions
----------

.. automodule:: aioquic.h3.exceptions

.. autoclass:: H3Error
    :members:

.. autoclass:: InvalidStreamTypeError
    :members:

.. autoclass:: NoAvailablePushIDError
    :members:
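As a sketch of how :class:`H3Connection` wraps a ``QuicConnection``; the
ALPN value and the request target below are illustrative assumptions:

.. code-block:: python

   from aioquic.h3.connection import H3Connection
   from aioquic.quic.configuration import QuicConfiguration
   from aioquic.quic.connection import QuicConnection

   quic = QuicConnection(
       configuration=QuicConfiguration(is_client=True, alpn_protocols=["h3"])
   )
   http = H3Connection(quic)

   # once the handshake has completed, issue a GET request on a new stream
   stream_id = quic.get_next_available_stream_id()
   http.send_headers(
       stream_id=stream_id,
       headers=[
           (b":method", b"GET"),
           (b":scheme", b"https"),
           (b":authority", b"localhost"),  # illustrative host
           (b":path", b"/"),
       ],
       end_stream=True,
   )

   # QUIC events must then be fed to the HTTP layer, which turns them into
   # HTTP/3 events such as HeadersReceived and DataReceived:
   #
   #     for http_event in http.handle_event(quic_event):
   #         ...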
=== aioquic-1.2.0/docs/index.rst ===

aioquic
=======

.. image:: https://img.shields.io/pypi/l/aioquic.svg
   :target: https://pypi.python.org/pypi/aioquic
   :alt: License

.. image:: https://img.shields.io/pypi/v/aioquic.svg
   :target: https://pypi.python.org/pypi/aioquic
   :alt: Version

.. image:: https://img.shields.io/pypi/pyversions/aioquic.svg
   :target: https://pypi.python.org/pypi/aioquic
   :alt: Python versions

.. image:: https://github.com/aiortc/aioquic/workflows/tests/badge.svg
   :target: https://github.com/aiortc/aioquic/actions
   :alt: Tests

.. image:: https://img.shields.io/codecov/c/github/aiortc/aioquic.svg
   :target: https://codecov.io/gh/aiortc/aioquic
   :alt: Coverage

``aioquic`` is a library for the QUIC network protocol in Python. It features
several APIs:

- a QUIC API following the "bring your own I/O" pattern, suitable for
  embedding in any framework,
- an HTTP/3 API which also follows the "bring your own I/O" pattern,
- a QUIC convenience API built on top of :mod:`asyncio`, Python's standard
  asynchronous I/O framework.

.. toctree::
   :maxdepth: 2

   design
   quic
   h3
   asyncio
   changelog
   license

=== aioquic-1.2.0/docs/license.rst ===

License
-------

.. literalinclude:: ../LICENSE

=== aioquic-1.2.0/docs/quic.rst ===

QUIC API
========

The QUIC API performs no I/O on its own, leaving this to the API user.
This allows you to integrate QUIC in any Python application, regardless of
the concurrency model you are using.

Connection
----------

.. automodule:: aioquic.quic.connection

.. autoclass:: QuicConnection
    :members:

Configuration
-------------

.. automodule:: aioquic.quic.configuration

.. autoclass:: QuicConfiguration
    :members:

.. automodule:: aioquic.quic.logger

.. autoclass:: QuicLogger
    :members:

Events
------

.. automodule:: aioquic.quic.events

.. autoclass:: QuicEvent
    :members:

.. autoclass:: ConnectionTerminated
    :members:

.. autoclass:: HandshakeCompleted
    :members:

.. autoclass:: PingAcknowledged
    :members:

.. autoclass:: StopSendingReceived
    :members:

.. autoclass:: StreamDataReceived
    :members:

.. autoclass:: StreamReset
    :members:
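Putting some of the configuration knobs above together; the file and
directory paths are illustrative assumptions (``tests/pycacert.pem`` ships
with the repository, the QLOG directory must already exist):

.. code-block:: python

   from aioquic.quic.configuration import QuicConfiguration
   from aioquic.quic.logger import QuicFileLogger

   configuration = QuicConfiguration(is_client=True, alpn_protocols=["h3"])
   # validate the server certificate against a specific CA bundle
   configuration.load_verify_locations("tests/pycacert.pem")
   # write one QLOG file per connection into an existing directory
   configuration.quic_logger = QuicFileLogger("qlog")
   # log TLS traffic secrets so Wireshark can decrypt captures
   configuration.secrets_log_file = open("secrets.log", "a")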
=== aioquic-1.2.0/examples/README.rst ===

Examples
========

After checking out the code using git you can run:

.. code-block:: console

   pip install . dnslib jinja2 starlette wsproto

HTTP/3
------

HTTP/3 server
.............

You can run the example server, which handles both HTTP/0.9 and HTTP/3:

.. code-block:: console

   python examples/http3_server.py --certificate tests/ssl_cert.pem --private-key tests/ssl_key.pem

HTTP/3 client
.............

You can run the example client to perform an HTTP/3 request:

.. code-block:: console

   python examples/http3_client.py --ca-certs tests/pycacert.pem https://localhost:4433/

Alternatively you can perform an HTTP/0.9 request:

.. code-block:: console

   python examples/http3_client.py --ca-certs tests/pycacert.pem --legacy-http https://localhost:4433/

You can also open a WebSocket over HTTP/3:

.. code-block:: console

   python examples/http3_client.py --ca-certs tests/pycacert.pem wss://localhost:4433/ws

Chromium and Chrome usage
.........................

Some flags are needed to allow Chrome to communicate with the demo server.
Most are not necessary in a more production-oriented deployment with HTTP/2
fallback and a valid certificate, as demonstrated on https://quic.aiortc.org/

- The `--ignore-certificate-errors-spki-list`_ flag instructs Chrome to
  accept the demo TLS certificate, even though it is not signed by a known
  certificate authority. If you use your own valid certificate, you do not
  need this flag.
- The `--origin-to-force-quic-on` flag forces Chrome to communicate using
  HTTP/3. This is needed because the demo server *only* provides an HTTP/3
  server. Usually Chrome will connect to an HTTP/2 or HTTP/1.1 server and
  "discover" that the server supports HTTP/3 through an Alt-Svc header.
- The `--enable-experimental-web-platform-features`_ flag enables
  WebTransport, because the specifications and implementation are not yet
  finalised. For HTTP/3 itself, you do not need this flag.

To access the demo server running on the local machine, launch Chromium or
Chrome as follows:

.. code:: bash

   google-chrome \
     --enable-experimental-web-platform-features \
     --ignore-certificate-errors-spki-list=BSQJ0jkQ7wwhR7KvPZ+DSNk2XTZ/MS6xCbo9qu++VdQ= \
     --origin-to-force-quic-on=localhost:4433 \
     https://localhost:4433/

The fingerprint passed to the `--ignore-certificate-errors-spki-list`_
option is obtained by running:

.. code:: bash

   openssl x509 -in tests/ssl_cert.pem -pubkey -noout | \
     openssl pkey -pubin -outform der | \
     openssl dgst -sha256 -binary | \
     openssl enc -base64

WebTransport
............

The demo server runs a :code:`WebTransport` echo service at `/wt`. You can
connect by opening Developer Tools and running the following:

.. code:: javascript

   let transport = new WebTransport('https://localhost:4433/wt');
   await transport.ready;

   let stream = await transport.createBidirectionalStream();
   let reader = stream.readable.getReader();
   let writer = stream.writable.getWriter();
   await writer.write(new Uint8Array([65, 66, 67]));
   let received = await reader.read();
   await transport.close();
   console.log('received', received);

If all is well you should see:

.. image:: https://user-images.githubusercontent.com/1567624/126713050-e3c0664c-b0b9-4ac8-a393-9b647c9cab6b.png

DNS over QUIC
-------------

By default the server will use the `Google Public DNS`_ service; you can
override this with the ``--resolver`` argument.

By default the server will listen for requests on port 853, which requires a
privileged user. You can override this with the ``--port`` argument.

You can run the server locally using:

.. code-block:: console

   python examples/doq_server.py --certificate tests/ssl_cert.pem --private-key tests/ssl_key.pem --port 8053

You can then run the client with a specific query:

.. code-block:: console

   python examples/doq_client.py --ca-certs tests/pycacert.pem --query-type A --query-name quic.aiortc.org --port 8053

Please note that for real-world usage you will need to obtain a valid TLS
certificate.

.. _Google Public DNS: https://developers.google.com/speed/public-dns
.. _--enable-experimental-web-platform-features: https://peter.sh/experiments/chromium-command-line-switches/#enable-experimental-web-platform-features
.. _--ignore-certificate-errors-spki-list: https://peter.sh/experiments/chromium-command-line-switches/#ignore-certificate-errors-spki-list
=== aioquic-1.2.0/examples/demo.py ===

#
# demo application for http3_server.py
#

import datetime
import os
from urllib.parse import urlencode

from starlette.applications import Starlette
from starlette.responses import PlainTextResponse, Response
from starlette.routing import Mount, Route, WebSocketRoute
from starlette.staticfiles import StaticFiles
from starlette.templating import Jinja2Templates
from starlette.types import Receive, Scope, Send
from starlette.websockets import WebSocketDisconnect

ROOT = os.path.dirname(__file__)
STATIC_ROOT = os.environ.get("STATIC_ROOT", os.path.join(ROOT, "htdocs"))
STATIC_URL = "/"
LOGS_PATH = os.path.join(STATIC_ROOT, "logs")
QVIS_URL = "https://qvis.quictools.info/"

templates = Jinja2Templates(directory=os.path.join(ROOT, "templates"))


async def homepage(request):
    """
    Simple homepage.
    """
    await request.send_push_promise("/style.css")
    return templates.TemplateResponse("index.html", {"request": request})


async def echo(request):
    """
    HTTP echo endpoint.
    """
    content = await request.body()
    media_type = request.headers.get("content-type")
    return Response(content, media_type=media_type)


async def logs(request):
    """
    Browsable list of QLOG files.
    """
    logs = []
    for name in os.listdir(LOGS_PATH):
        if name.endswith(".qlog"):
            s = os.stat(os.path.join(LOGS_PATH, name))
            file_url = "https://" + request.headers["host"] + "/logs/" + name
            logs.append(
                {
                    "date": datetime.datetime.utcfromtimestamp(s.st_mtime).strftime(
                        "%Y-%m-%d %H:%M:%S"
                    ),
                    "file_url": file_url,
                    "name": name[:-5],
                    "qvis_url": QVIS_URL
                    + "?"
                    + urlencode({"file": file_url})
                    + "#/sequence",
                    "size": s.st_size,
                }
            )
    return templates.TemplateResponse(
        "logs.html",
        {
            "logs": sorted(logs, key=lambda x: x["date"], reverse=True),
            "request": request,
        },
    )


async def padding(request):
    """
    Dynamically generated data, maximum 50MB.
    """
    size = min(50000000, request.path_params["size"])
    return PlainTextResponse("Z" * size)


async def ws(websocket):
    """
    WebSocket echo endpoint.
    """
    if "chat" in websocket.scope["subprotocols"]:
        subprotocol = "chat"
    else:
        subprotocol = None
    await websocket.accept(subprotocol=subprotocol)

    try:
        while True:
            message = await websocket.receive_text()
            await websocket.send_text(message)
    except WebSocketDisconnect:
        pass


async def wt(scope: Scope, receive: Receive, send: Send) -> None:
    """
    WebTransport echo endpoint.
    """
    # accept connection
    message = await receive()
    assert message["type"] == "webtransport.connect"
    await send({"type": "webtransport.accept"})

    # echo back received data
    while True:
        message = await receive()
        if message["type"] == "webtransport.datagram.receive":
            await send(
                {
                    "data": message["data"],
                    "type": "webtransport.datagram.send",
                }
            )
        elif message["type"] == "webtransport.stream.receive":
            await send(
                {
                    "data": message["data"],
                    "stream": message["stream"],
                    "type": "webtransport.stream.send",
                }
            )


starlette = Starlette(
    routes=[
        Route("/", homepage),
        Route("/{size:int}", padding),
        Route("/echo", echo, methods=["POST"]),
        Route("/logs", logs),
        WebSocketRoute("/ws", ws),
        Mount(STATIC_URL, StaticFiles(directory=STATIC_ROOT, html=True)),
    ]
)


async def app(scope: Scope, receive: Receive, send: Send) -> None:
    if scope["type"] == "webtransport" and scope["path"] == "/wt":
        await wt(scope, receive, send)
    else:
        await starlette(scope, receive, send)
""" # accept connection message = await receive() assert message["type"] == "webtransport.connect" await send({"type": "webtransport.accept"}) # echo back received data while True: message = await receive() if message["type"] == "webtransport.datagram.receive": await send( { "data": message["data"], "type": "webtransport.datagram.send", } ) elif message["type"] == "webtransport.stream.receive": await send( { "data": message["data"], "stream": message["stream"], "type": "webtransport.stream.send", } ) starlette = Starlette( routes=[ Route("/", homepage), Route("/{size:int}", padding), Route("/echo", echo, methods=["POST"]), Route("/logs", logs), WebSocketRoute("/ws", ws), Mount(STATIC_URL, StaticFiles(directory=STATIC_ROOT, html=True)), ] ) async def app(scope: Scope, receive: Receive, send: Send) -> None: if scope["type"] == "webtransport" and scope["path"] == "/wt": await wt(scope, receive, send) else: await starlette(scope, receive, send) ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1720306884.0 aioquic-1.2.0/examples/doq_client.py0000644000175100001770000001224300000000000020265 0ustar00runnerdocker00000000000000import argparse import asyncio import logging import pickle import ssl import struct from typing import Optional, cast from aioquic.asyncio.client import connect from aioquic.asyncio.protocol import QuicConnectionProtocol from aioquic.quic.configuration import QuicConfiguration from aioquic.quic.events import QuicEvent, StreamDataReceived from aioquic.quic.logger import QuicFileLogger from dnslib.dns import QTYPE, DNSHeader, DNSQuestion, DNSRecord logger = logging.getLogger("client") class DnsClientProtocol(QuicConnectionProtocol): def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self._ack_waiter: Optional[asyncio.Future[DNSRecord]] = None async def query(self, query_name: str, query_type: str) -> DNSRecord: # serialize query query = DNSRecord( header=DNSHeader(id=0), q=DNSQuestion(query_name, getattr(QTYPE, query_type)), ) data = bytes(query.pack()) data = struct.pack("!H", len(data)) + data # send query and wait for answer stream_id = self._quic.get_next_available_stream_id() self._quic.send_stream_data(stream_id, data, end_stream=True) waiter = self._loop.create_future() self._ack_waiter = waiter self.transmit() return await asyncio.shield(waiter) def quic_event_received(self, event: QuicEvent) -> None: if self._ack_waiter is not None: if isinstance(event, StreamDataReceived): # parse answer length = struct.unpack("!H", bytes(event.data[:2]))[0] answer = DNSRecord.parse(event.data[2 : 2 + length]) # return answer waiter = self._ack_waiter self._ack_waiter = None waiter.set_result(answer) def save_session_ticket(ticket): """ Callback which is invoked by the TLS engine when a new session ticket is received. 
""" logger.info("New session ticket received") if args.session_ticket: with open(args.session_ticket, "wb") as fp: pickle.dump(ticket, fp) async def main( configuration: QuicConfiguration, host: str, port: int, query_name: str, query_type: str, ) -> None: logger.debug(f"Connecting to {host}:{port}") async with connect( host, port, configuration=configuration, session_ticket_handler=save_session_ticket, create_protocol=DnsClientProtocol, ) as client: client = cast(DnsClientProtocol, client) logger.debug("Sending DNS query") answer = await client.query(query_name, query_type) logger.info("Received DNS answer\n%s" % answer) if __name__ == "__main__": parser = argparse.ArgumentParser(description="DNS over QUIC client") parser.add_argument( "--host", type=str, default="localhost", help="The remote peer's host name or IP address", ) parser.add_argument( "--port", type=int, default=853, help="The remote peer's port number" ) parser.add_argument( "-k", "--insecure", action="store_true", help="do not validate server certificate", ) parser.add_argument( "--ca-certs", type=str, help="load CA certificates from the specified file" ) parser.add_argument("--query-name", required=True, help="Domain to query") parser.add_argument("--query-type", default="A", help="The DNS query type to send") parser.add_argument( "-q", "--quic-log", type=str, help="log QUIC events to QLOG files in the specified directory", ) parser.add_argument( "-l", "--secrets-log", type=str, help="log secrets to a file, for use with Wireshark", ) parser.add_argument( "-s", "--session-ticket", type=str, help="read and write session ticket from the specified file", ) parser.add_argument( "-v", "--verbose", action="store_true", help="increase logging verbosity" ) args = parser.parse_args() logging.basicConfig( format="%(asctime)s %(levelname)s %(name)s %(message)s", level=logging.DEBUG if args.verbose else logging.INFO, ) configuration = QuicConfiguration(alpn_protocols=["doq"], is_client=True) if args.ca_certs: configuration.load_verify_locations(args.ca_certs) if args.insecure: configuration.verify_mode = ssl.CERT_NONE if args.quic_log: configuration.quic_logger = QuicFileLogger(args.quic_log) if args.secrets_log: configuration.secrets_log_file = open(args.secrets_log, "a") if args.session_ticket: try: with open(args.session_ticket, "rb") as fp: configuration.session_ticket = pickle.load(fp) except FileNotFoundError: logger.debug(f"Unable to read {args.session_ticket}") pass else: logger.debug("No session ticket defined...") asyncio.run( main( configuration=configuration, host=args.host, port=args.port, query_name=args.query_name, query_type=args.query_type, ) ) ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1720306884.0 aioquic-1.2.0/examples/doq_server.py0000644000175100001770000000754100000000000020322 0ustar00runnerdocker00000000000000import argparse import asyncio import logging import struct from typing import Dict, Optional from aioquic.asyncio import QuicConnectionProtocol, serve from aioquic.quic.configuration import QuicConfiguration from aioquic.quic.events import QuicEvent, StreamDataReceived from aioquic.quic.logger import QuicFileLogger from aioquic.tls import SessionTicket from dnslib.dns import DNSRecord class DnsServerProtocol(QuicConnectionProtocol): def quic_event_received(self, event: QuicEvent): if isinstance(event, StreamDataReceived): # parse query length = struct.unpack("!H", bytes(event.data[:2]))[0] query = DNSRecord.parse(event.data[2 : 2 + length]) # perform 
=== aioquic-1.2.0/examples/htdocs/robots.txt ===

User-agent: *
Disallow: /logs

=== aioquic-1.2.0/examples/htdocs/style.css ===

body {
    font-family: Arial, sans-serif;
    font-size: 16px;
    margin: 0 auto;
    width: 40em;
}

table.logs {
    width: 100%;
}

=== aioquic-1.2.0/examples/http3_client.py ===
import argparse
import asyncio
import logging
import os
import pickle
import ssl
import time
from collections import deque
from typing import BinaryIO, Callable, Deque, Dict, List, Optional, Union, cast
from urllib.parse import urlparse

import aioquic
import wsproto
import wsproto.events
from aioquic.asyncio.client import connect
from aioquic.asyncio.protocol import QuicConnectionProtocol
from aioquic.h0.connection import H0_ALPN, H0Connection
from aioquic.h3.connection import H3_ALPN, ErrorCode, H3Connection
from aioquic.h3.events import (
    DataReceived,
    H3Event,
    HeadersReceived,
    PushPromiseReceived,
)
from aioquic.quic.configuration import QuicConfiguration
from aioquic.quic.events import QuicEvent
from aioquic.quic.logger import QuicFileLogger
from aioquic.quic.packet import QuicProtocolVersion
from aioquic.tls import CipherSuite, SessionTicket

try:
    import uvloop
except ImportError:
    uvloop = None

logger = logging.getLogger("client")

HttpConnection = Union[H0Connection, H3Connection]

USER_AGENT = "aioquic/" + aioquic.__version__


class URL:
    def __init__(self, url: str) -> None:
        parsed = urlparse(url)

        self.authority = parsed.netloc
        self.full_path = parsed.path or "/"
        if parsed.query:
            self.full_path += "?" + parsed.query
        self.scheme = parsed.scheme


class HttpRequest:
    def __init__(
        self,
        method: str,
        url: URL,
        content: bytes = b"",
        headers: Optional[Dict] = None,
    ) -> None:
        if headers is None:
            headers = {}

        self.content = content
        self.headers = headers
        self.method = method
        self.url = url


class WebSocket:
    def __init__(
        self, http: HttpConnection, stream_id: int, transmit: Callable[[], None]
    ) -> None:
        self.http = http
        self.queue: asyncio.Queue[str] = asyncio.Queue()
        self.stream_id = stream_id
        self.subprotocol: Optional[str] = None
        self.transmit = transmit
        self.websocket = wsproto.Connection(wsproto.ConnectionType.CLIENT)

    async def close(self, code: int = 1000, reason: str = "") -> None:
        """
        Perform the closing handshake.
        """
        data = self.websocket.send(
            wsproto.events.CloseConnection(code=code, reason=reason)
        )
        self.http.send_data(stream_id=self.stream_id, data=data, end_stream=True)
        self.transmit()

    async def recv(self) -> str:
        """
        Receive the next message.
        """
        return await self.queue.get()

    async def send(self, message: str) -> None:
        """
        Send a message.
        """
""" assert isinstance(message, str) data = self.websocket.send(wsproto.events.TextMessage(data=message)) self.http.send_data(stream_id=self.stream_id, data=data, end_stream=False) self.transmit() def http_event_received(self, event: H3Event) -> None: if isinstance(event, HeadersReceived): for header, value in event.headers: if header == b"sec-websocket-protocol": self.subprotocol = value.decode() elif isinstance(event, DataReceived): self.websocket.receive_data(event.data) for ws_event in self.websocket.events(): self.websocket_event_received(ws_event) def websocket_event_received(self, event: wsproto.events.Event) -> None: if isinstance(event, wsproto.events.TextMessage): self.queue.put_nowait(event.data) class HttpClient(QuicConnectionProtocol): def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self.pushes: Dict[int, Deque[H3Event]] = {} self._http: Optional[HttpConnection] = None self._request_events: Dict[int, Deque[H3Event]] = {} self._request_waiter: Dict[int, asyncio.Future[Deque[H3Event]]] = {} self._websockets: Dict[int, WebSocket] = {} if self._quic.configuration.alpn_protocols[0].startswith("hq-"): self._http = H0Connection(self._quic) else: self._http = H3Connection(self._quic) async def get(self, url: str, headers: Optional[Dict] = None) -> Deque[H3Event]: """ Perform a GET request. """ return await self._request( HttpRequest(method="GET", url=URL(url), headers=headers) ) async def post( self, url: str, data: bytes, headers: Optional[Dict] = None ) -> Deque[H3Event]: """ Perform a POST request. """ return await self._request( HttpRequest(method="POST", url=URL(url), content=data, headers=headers) ) async def websocket( self, url: str, subprotocols: Optional[List[str]] = None ) -> WebSocket: """ Open a WebSocket. 
""" request = HttpRequest(method="CONNECT", url=URL(url)) stream_id = self._quic.get_next_available_stream_id() websocket = WebSocket( http=self._http, stream_id=stream_id, transmit=self.transmit ) self._websockets[stream_id] = websocket headers = [ (b":method", b"CONNECT"), (b":scheme", b"https"), (b":authority", request.url.authority.encode()), (b":path", request.url.full_path.encode()), (b":protocol", b"websocket"), (b"user-agent", USER_AGENT.encode()), (b"sec-websocket-version", b"13"), ] if subprotocols: headers.append( (b"sec-websocket-protocol", ", ".join(subprotocols).encode()) ) self._http.send_headers(stream_id=stream_id, headers=headers) self.transmit() return websocket def http_event_received(self, event: H3Event) -> None: if isinstance(event, (HeadersReceived, DataReceived)): stream_id = event.stream_id if stream_id in self._request_events: # http self._request_events[event.stream_id].append(event) if event.stream_ended: request_waiter = self._request_waiter.pop(stream_id) request_waiter.set_result(self._request_events.pop(stream_id)) elif stream_id in self._websockets: # websocket websocket = self._websockets[stream_id] websocket.http_event_received(event) elif event.push_id in self.pushes: # push self.pushes[event.push_id].append(event) elif isinstance(event, PushPromiseReceived): self.pushes[event.push_id] = deque() self.pushes[event.push_id].append(event) def quic_event_received(self, event: QuicEvent) -> None: #  pass event to the HTTP layer if self._http is not None: for http_event in self._http.handle_event(event): self.http_event_received(http_event) async def _request(self, request: HttpRequest) -> Deque[H3Event]: stream_id = self._quic.get_next_available_stream_id() self._http.send_headers( stream_id=stream_id, headers=[ (b":method", request.method.encode()), (b":scheme", request.url.scheme.encode()), (b":authority", request.url.authority.encode()), (b":path", request.url.full_path.encode()), (b"user-agent", USER_AGENT.encode()), ] + [(k.encode(), v.encode()) for (k, v) in request.headers.items()], end_stream=not request.content, ) if request.content: self._http.send_data( stream_id=stream_id, data=request.content, end_stream=True ) waiter = self._loop.create_future() self._request_events[stream_id] = deque() self._request_waiter[stream_id] = waiter self.transmit() return await asyncio.shield(waiter) async def perform_http_request( client: HttpClient, url: str, data: Optional[str], include: bool, output_dir: Optional[str], ) -> None: # perform request start = time.time() if data is not None: data_bytes = data.encode() http_events = await client.post( url, data=data_bytes, headers={ "content-length": str(len(data_bytes)), "content-type": "application/x-www-form-urlencoded", }, ) method = "POST" else: http_events = await client.get(url) method = "GET" elapsed = time.time() - start # print speed octets = 0 for http_event in http_events: if isinstance(http_event, DataReceived): octets += len(http_event.data) logger.info( "Response received for %s %s : %d bytes in %.1f s (%.3f Mbps)" % (method, urlparse(url).path, octets, elapsed, octets * 8 / elapsed / 1000000) ) # output response if output_dir is not None: output_path = os.path.join( output_dir, os.path.basename(urlparse(url).path) or "index.html" ) with open(output_path, "wb") as output_file: write_response( http_events=http_events, include=include, output_file=output_file ) def process_http_pushes( client: HttpClient, include: bool, output_dir: Optional[str], ) -> None: for _, http_events in client.pushes.items(): 
method = "" octets = 0 path = "" for http_event in http_events: if isinstance(http_event, DataReceived): octets += len(http_event.data) elif isinstance(http_event, PushPromiseReceived): for header, value in http_event.headers: if header == b":method": method = value.decode() elif header == b":path": path = value.decode() logger.info("Push received for %s %s : %s bytes", method, path, octets) # output response if output_dir is not None: output_path = os.path.join( output_dir, os.path.basename(path) or "index.html" ) with open(output_path, "wb") as output_file: write_response( http_events=http_events, include=include, output_file=output_file ) def write_response( http_events: Deque[H3Event], output_file: BinaryIO, include: bool ) -> None: for http_event in http_events: if isinstance(http_event, HeadersReceived) and include: headers = b"" for k, v in http_event.headers: headers += k + b": " + v + b"\r\n" if headers: output_file.write(headers + b"\r\n") elif isinstance(http_event, DataReceived): output_file.write(http_event.data) def save_session_ticket(ticket: SessionTicket) -> None: """ Callback which is invoked by the TLS engine when a new session ticket is received. """ logger.info("New session ticket received") if args.session_ticket: with open(args.session_ticket, "wb") as fp: pickle.dump(ticket, fp) async def main( configuration: QuicConfiguration, urls: List[str], data: Optional[str], include: bool, output_dir: Optional[str], local_port: int, zero_rtt: bool, ) -> None: # parse URL parsed = urlparse(urls[0]) assert parsed.scheme in ( "https", "wss", ), "Only https:// or wss:// URLs are supported." host = parsed.hostname if parsed.port is not None: port = parsed.port else: port = 443 # check validity of 2nd urls and later. for i in range(1, len(urls)): _p = urlparse(urls[i]) # fill in if empty _scheme = _p.scheme or parsed.scheme _host = _p.hostname or host _port = _p.port or port assert _scheme == parsed.scheme, "URL scheme doesn't match" assert _host == host, "URL hostname doesn't match" assert _port == port, "URL port doesn't match" # reconstruct url with new hostname and port _p = _p._replace(scheme=_scheme) _p = _p._replace(netloc="{}:{}".format(_host, _port)) _p = urlparse(_p.geturl()) urls[i] = _p.geturl() async with connect( host, port, configuration=configuration, create_protocol=HttpClient, session_ticket_handler=save_session_ticket, local_port=local_port, wait_connected=not zero_rtt, ) as client: client = cast(HttpClient, client) if parsed.scheme == "wss": ws = await client.websocket(urls[0], subprotocols=["chat", "superchat"]) # send some messages and receive reply for i in range(2): message = "Hello {}, WebSocket!".format(i) print("> " + message) await ws.send(message) message = await ws.recv() print("< " + message) await ws.close() else: # perform request coros = [ perform_http_request( client=client, url=url, data=data, include=include, output_dir=output_dir, ) for url in urls ] await asyncio.gather(*coros) # process http pushes process_http_pushes(client=client, include=include, output_dir=output_dir) client.close(error_code=ErrorCode.H3_NO_ERROR) if __name__ == "__main__": defaults = QuicConfiguration(is_client=True) parser = argparse.ArgumentParser(description="HTTP/3 client") parser.add_argument( "url", type=str, nargs="+", help="the URL to query (must be HTTPS)" ) parser.add_argument( "--ca-certs", type=str, help="load CA certificates from the specified file" ) parser.add_argument( "--certificate", type=str, help="load the TLS certificate from the specified file", ) 
    parser.add_argument(
        "--cipher-suites",
        type=str,
        help=(
            "only advertise the given cipher suites, e.g. `AES_256_GCM_SHA384,"
            "CHACHA20_POLY1305_SHA256`"
        ),
    )
    parser.add_argument(
        "--congestion-control-algorithm",
        type=str,
        default="reno",
        help="use the specified congestion control algorithm",
    )
    parser.add_argument(
        "-d", "--data", type=str, help="send the specified data in a POST request"
    )
    parser.add_argument(
        "-i",
        "--include",
        action="store_true",
        help="include the HTTP response headers in the output",
    )
    parser.add_argument(
        "--insecure",
        action="store_true",
        help="do not validate server certificate",
    )
    parser.add_argument(
        "--legacy-http",
        action="store_true",
        help="use HTTP/0.9",
    )
    parser.add_argument(
        "--max-data",
        type=int,
        help="connection-wide flow control limit (default: %d)" % defaults.max_data,
    )
    parser.add_argument(
        "--max-stream-data",
        type=int,
        help="per-stream flow control limit (default: %d)" % defaults.max_stream_data,
    )
    parser.add_argument(
        "--negotiate-v2",
        action="store_true",
        help="start with QUIC v1 and try to negotiate QUIC v2",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        help="write downloaded files to this directory",
    )
    parser.add_argument(
        "--private-key",
        type=str,
        help="load the TLS private key from the specified file",
    )
    parser.add_argument(
        "-q",
        "--quic-log",
        type=str,
        help="log QUIC events to QLOG files in the specified directory",
    )
    parser.add_argument(
        "-l",
        "--secrets-log",
        type=str,
        help="log secrets to a file, for use with Wireshark",
    )
    parser.add_argument(
        "-s",
        "--session-ticket",
        type=str,
        help="read and write session ticket from the specified file",
    )
    parser.add_argument(
        "-v", "--verbose", action="store_true", help="increase logging verbosity"
    )
    parser.add_argument(
        "--local-port",
        type=int,
        default=0,
        help="local port to bind for connections",
    )
    parser.add_argument(
        "--max-datagram-size",
        type=int,
        default=defaults.max_datagram_size,
        help="maximum datagram size to send, excluding UDP or IP overhead",
    )
    parser.add_argument(
        "--zero-rtt", action="store_true", help="try to send requests using 0-RTT"
    )

    args = parser.parse_args()

    logging.basicConfig(
        format="%(asctime)s %(levelname)s %(name)s %(message)s",
        level=logging.DEBUG if args.verbose else logging.INFO,
    )

    if args.output_dir is not None and not os.path.isdir(args.output_dir):
        raise Exception("%s is not a directory" % args.output_dir)

    # prepare configuration
    configuration = QuicConfiguration(
        is_client=True,
        alpn_protocols=H0_ALPN if args.legacy_http else H3_ALPN,
        congestion_control_algorithm=args.congestion_control_algorithm,
        max_datagram_size=args.max_datagram_size,
    )
    if args.ca_certs:
        configuration.load_verify_locations(args.ca_certs)
    if args.cipher_suites:
        configuration.cipher_suites = [
            CipherSuite[s] for s in args.cipher_suites.split(",")
        ]
    if args.insecure:
        configuration.verify_mode = ssl.CERT_NONE
    if args.max_data:
        configuration.max_data = args.max_data
    if args.max_stream_data:
        configuration.max_stream_data = args.max_stream_data
    if args.negotiate_v2:
        configuration.original_version = QuicProtocolVersion.VERSION_1
        configuration.supported_versions = [
            QuicProtocolVersion.VERSION_2,
            QuicProtocolVersion.VERSION_1,
        ]
    if args.quic_log:
        configuration.quic_logger = QuicFileLogger(args.quic_log)
    if args.secrets_log:
        configuration.secrets_log_file = open(args.secrets_log, "a")
    if args.session_ticket:
        try:
            with open(args.session_ticket, "rb") as fp:
                configuration.session_ticket = pickle.load(fp)
        except FileNotFoundError:
            pass

    # load SSL certificate and key
    if args.certificate is not None:
configuration.load_cert_chain(args.certificate, args.private_key) if uvloop is not None: uvloop.install() asyncio.run( main( configuration=configuration, urls=args.url, data=args.data, include=args.include, output_dir=args.output_dir, local_port=args.local_port, zero_rtt=args.zero_rtt, ) ) ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1720306884.0 aioquic-1.2.0/examples/http3_server.py0000644000175100001770000005231000000000000020573 0ustar00runnerdocker00000000000000import argparse import asyncio import importlib import logging import time from collections import deque from email.utils import formatdate from typing import Callable, Deque, Dict, List, Optional, Union, cast import aioquic import wsproto import wsproto.events from aioquic.asyncio import QuicConnectionProtocol, serve from aioquic.h0.connection import H0_ALPN, H0Connection from aioquic.h3.connection import H3_ALPN, H3Connection from aioquic.h3.events import ( DatagramReceived, DataReceived, H3Event, HeadersReceived, WebTransportStreamDataReceived, ) from aioquic.h3.exceptions import NoAvailablePushIDError from aioquic.quic.configuration import QuicConfiguration from aioquic.quic.events import DatagramFrameReceived, ProtocolNegotiated, QuicEvent from aioquic.quic.logger import QuicFileLogger from aioquic.tls import SessionTicket try: import uvloop except ImportError: uvloop = None AsgiApplication = Callable HttpConnection = Union[H0Connection, H3Connection] SERVER_NAME = "aioquic/" + aioquic.__version__ class HttpRequestHandler: def __init__( self, *, authority: bytes, connection: HttpConnection, protocol: QuicConnectionProtocol, scope: Dict, stream_ended: bool, stream_id: int, transmit: Callable[[], None], ) -> None: self.authority = authority self.connection = connection self.protocol = protocol self.queue: asyncio.Queue[Dict] = asyncio.Queue() self.scope = scope self.stream_id = stream_id self.transmit = transmit if stream_ended: self.queue.put_nowait({"type": "http.request"}) def http_event_received(self, event: H3Event) -> None: if isinstance(event, DataReceived): self.queue.put_nowait( { "type": "http.request", "body": event.data, "more_body": not event.stream_ended, } ) elif isinstance(event, HeadersReceived) and event.stream_ended: self.queue.put_nowait( {"type": "http.request", "body": b"", "more_body": False} ) async def run_asgi(self, app: AsgiApplication) -> None: await app(self.scope, self.receive, self.send) async def receive(self) -> Dict: return await self.queue.get() async def send(self, message: Dict) -> None: if message["type"] == "http.response.start": self.connection.send_headers( stream_id=self.stream_id, headers=[ (b":status", str(message["status"]).encode()), (b"server", SERVER_NAME.encode()), (b"date", formatdate(time.time(), usegmt=True).encode()), ] + [(k, v) for k, v in message["headers"]], ) elif message["type"] == "http.response.body": self.connection.send_data( stream_id=self.stream_id, data=message.get("body", b""), end_stream=not message.get("more_body", False), ) elif message["type"] == "http.response.push" and isinstance( self.connection, H3Connection ): request_headers = [ (b":method", b"GET"), (b":scheme", b"https"), (b":authority", self.authority), (b":path", message["path"].encode()), ] + [(k, v) for k, v in message["headers"]] # send push promise try: push_stream_id = self.connection.send_push_promise( stream_id=self.stream_id, headers=request_headers ) except NoAvailablePushIDError: return # fake request cast(HttpServerProtocol, 
self.protocol).http_event_received( HeadersReceived( headers=request_headers, stream_ended=True, stream_id=push_stream_id ) ) self.transmit() class WebSocketHandler: def __init__( self, *, connection: HttpConnection, scope: Dict, stream_id: int, transmit: Callable[[], None], ) -> None: self.closed = False self.connection = connection self.http_event_queue: Deque[DataReceived] = deque() self.queue: asyncio.Queue[Dict] = asyncio.Queue() self.scope = scope self.stream_id = stream_id self.transmit = transmit self.websocket: Optional[wsproto.Connection] = None def http_event_received(self, event: H3Event) -> None: if isinstance(event, DataReceived) and not self.closed: if self.websocket is not None: self.websocket.receive_data(event.data) for ws_event in self.websocket.events(): self.websocket_event_received(ws_event) else: # delay event processing until we get `websocket.accept` # from the ASGI application self.http_event_queue.append(event) def websocket_event_received(self, event: wsproto.events.Event) -> None: if isinstance(event, wsproto.events.TextMessage): self.queue.put_nowait({"type": "websocket.receive", "text": event.data}) elif isinstance(event, wsproto.events.Message): self.queue.put_nowait({"type": "websocket.receive", "bytes": event.data}) elif isinstance(event, wsproto.events.CloseConnection): self.queue.put_nowait({"type": "websocket.disconnect", "code": event.code}) async def run_asgi(self, app: AsgiApplication) -> None: self.queue.put_nowait({"type": "websocket.connect"}) try: await app(self.scope, self.receive, self.send) finally: if not self.closed: await self.send({"type": "websocket.close", "code": 1000}) async def receive(self) -> Dict: return await self.queue.get() async def send(self, message: Dict) -> None: data = b"" end_stream = False if message["type"] == "websocket.accept": subprotocol = message.get("subprotocol") self.websocket = wsproto.Connection(wsproto.ConnectionType.SERVER) headers = [ (b":status", b"200"), (b"server", SERVER_NAME.encode()), (b"date", formatdate(time.time(), usegmt=True).encode()), ] if subprotocol is not None: headers.append((b"sec-websocket-protocol", subprotocol.encode())) self.connection.send_headers(stream_id=self.stream_id, headers=headers) # consume backlog while self.http_event_queue: self.http_event_received(self.http_event_queue.popleft()) elif message["type"] == "websocket.close": if self.websocket is not None: data = self.websocket.send( wsproto.events.CloseConnection(code=message["code"]) ) else: self.connection.send_headers( stream_id=self.stream_id, headers=[(b":status", b"403")] ) end_stream = True elif message["type"] == "websocket.send": if message.get("text") is not None: data = self.websocket.send( wsproto.events.TextMessage(data=message["text"]) ) elif message.get("bytes") is not None: data = self.websocket.send( wsproto.events.Message(data=message["bytes"]) ) if data: self.connection.send_data( stream_id=self.stream_id, data=data, end_stream=end_stream ) if end_stream: self.closed = True self.transmit() class WebTransportHandler: def __init__( self, *, connection: H3Connection, scope: Dict, stream_id: int, transmit: Callable[[], None], ) -> None: self.accepted = False self.closed = False self.connection = connection self.http_event_queue: Deque[H3Event] = deque() self.queue: asyncio.Queue[Dict] = asyncio.Queue() self.scope = scope self.stream_id = stream_id self.transmit = transmit def http_event_received(self, event: H3Event) -> None: if not self.closed: if self.accepted: if isinstance(event, DatagramReceived): 
self.queue.put_nowait( { "data": event.data, "type": "webtransport.datagram.receive", } ) elif isinstance(event, WebTransportStreamDataReceived): self.queue.put_nowait( { "data": event.data, "stream": event.stream_id, "type": "webtransport.stream.receive", } ) else: # delay event processing until we get `webtransport.accept` # from the ASGI application self.http_event_queue.append(event) async def run_asgi(self, app: AsgiApplication) -> None: self.queue.put_nowait({"type": "webtransport.connect"}) try: await app(self.scope, self.receive, self.send) finally: if not self.closed: await self.send({"type": "webtransport.close"}) async def receive(self) -> Dict: return await self.queue.get() async def send(self, message: Dict) -> None: data = b"" end_stream = False if message["type"] == "webtransport.accept": self.accepted = True headers = [ (b":status", b"200"), (b"server", SERVER_NAME.encode()), (b"date", formatdate(time.time(), usegmt=True).encode()), (b"sec-webtransport-http3-draft", b"draft02"), ] self.connection.send_headers(stream_id=self.stream_id, headers=headers) # consume backlog while self.http_event_queue: self.http_event_received(self.http_event_queue.popleft()) elif message["type"] == "webtransport.close": if not self.accepted: self.connection.send_headers( stream_id=self.stream_id, headers=[(b":status", b"403")] ) end_stream = True elif message["type"] == "webtransport.datagram.send": self.connection.send_datagram( stream_id=self.stream_id, data=message["data"] ) elif message["type"] == "webtransport.stream.send": self.connection._quic.send_stream_data( stream_id=message["stream"], data=message["data"] ) if data or end_stream: self.connection.send_data( stream_id=self.stream_id, data=data, end_stream=end_stream ) if end_stream: self.closed = True self.transmit() Handler = Union[HttpRequestHandler, WebSocketHandler, WebTransportHandler] class HttpServerProtocol(QuicConnectionProtocol): def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self._handlers: Dict[int, Handler] = {} self._http: Optional[HttpConnection] = None def http_event_received(self, event: H3Event) -> None: if isinstance(event, HeadersReceived) and event.stream_id not in self._handlers: authority = None headers = [] http_version = "0.9" if isinstance(self._http, H0Connection) else "3" raw_path = b"" method = "" protocol = None for header, value in event.headers: if header == b":authority": authority = value headers.append((b"host", value)) elif header == b":method": method = value.decode() elif header == b":path": raw_path = value elif header == b":protocol": protocol = value.decode() elif header and not header.startswith(b":"): headers.append((header, value)) if b"?" 
in raw_path: path_bytes, query_string = raw_path.split(b"?", maxsplit=1) else: path_bytes, query_string = raw_path, b"" path = path_bytes.decode() self._quic._logger.info("HTTP request %s %s", method, path) # FIXME: add a public API to retrieve peer address client_addr = self._http._quic._network_paths[0].addr client = (client_addr[0], client_addr[1]) handler: Handler scope: Dict if method == "CONNECT" and protocol == "websocket": subprotocols: List[str] = [] for header, value in event.headers: if header == b"sec-websocket-protocol": subprotocols = [x.strip() for x in value.decode().split(",")] scope = { "client": client, "headers": headers, "http_version": http_version, "method": method, "path": path, "query_string": query_string, "raw_path": raw_path, "root_path": "", "scheme": "wss", "subprotocols": subprotocols, "type": "websocket", } handler = WebSocketHandler( connection=self._http, scope=scope, stream_id=event.stream_id, transmit=self.transmit, ) elif method == "CONNECT" and protocol == "webtransport": assert isinstance( self._http, H3Connection ), "WebTransport is only supported over HTTP/3" scope = { "client": client, "headers": headers, "http_version": http_version, "method": method, "path": path, "query_string": query_string, "raw_path": raw_path, "root_path": "", "scheme": "https", "type": "webtransport", } handler = WebTransportHandler( connection=self._http, scope=scope, stream_id=event.stream_id, transmit=self.transmit, ) else: extensions: Dict[str, Dict] = {} if isinstance(self._http, H3Connection): extensions["http.response.push"] = {} scope = { "client": client, "extensions": extensions, "headers": headers, "http_version": http_version, "method": method, "path": path, "query_string": query_string, "raw_path": raw_path, "root_path": "", "scheme": "https", "type": "http", } handler = HttpRequestHandler( authority=authority, connection=self._http, protocol=self, scope=scope, stream_ended=event.stream_ended, stream_id=event.stream_id, transmit=self.transmit, ) self._handlers[event.stream_id] = handler asyncio.ensure_future(handler.run_asgi(application)) elif ( isinstance(event, (DataReceived, HeadersReceived)) and event.stream_id in self._handlers ): handler = self._handlers[event.stream_id] handler.http_event_received(event) elif isinstance(event, DatagramReceived): handler = self._handlers[event.stream_id] handler.http_event_received(event) elif isinstance(event, WebTransportStreamDataReceived): handler = self._handlers[event.session_id] handler.http_event_received(event) def quic_event_received(self, event: QuicEvent) -> None: if isinstance(event, ProtocolNegotiated): if event.alpn_protocol in H3_ALPN: self._http = H3Connection(self._quic, enable_webtransport=True) elif event.alpn_protocol in H0_ALPN: self._http = H0Connection(self._quic) elif isinstance(event, DatagramFrameReceived): if event.data == b"quack": self._quic.send_datagram_frame(b"quack-ack") #  pass event to the HTTP layer if self._http is not None: for http_event in self._http.handle_event(event): self.http_event_received(http_event) class SessionTicketStore: """ Simple in-memory store for session tickets. 
""" def __init__(self) -> None: self.tickets: Dict[bytes, SessionTicket] = {} def add(self, ticket: SessionTicket) -> None: self.tickets[ticket.ticket] = ticket def pop(self, label: bytes) -> Optional[SessionTicket]: return self.tickets.pop(label, None) async def main( host: str, port: int, configuration: QuicConfiguration, session_ticket_store: SessionTicketStore, retry: bool, ) -> None: await serve( host, port, configuration=configuration, create_protocol=HttpServerProtocol, session_ticket_fetcher=session_ticket_store.pop, session_ticket_handler=session_ticket_store.add, retry=retry, ) await asyncio.Future() if __name__ == "__main__": defaults = QuicConfiguration(is_client=False) parser = argparse.ArgumentParser(description="QUIC server") parser.add_argument( "app", type=str, nargs="?", default="demo:app", help="the ASGI application as :", ) parser.add_argument( "-c", "--certificate", type=str, required=True, help="load the TLS certificate from the specified file", ) parser.add_argument( "--congestion-control-algorithm", type=str, default="reno", help="use the specified congestion control algorithm", ) parser.add_argument( "--host", type=str, default="::", help="listen on the specified address (defaults to ::)", ) parser.add_argument( "--port", type=int, default=4433, help="listen on the specified port (defaults to 4433)", ) parser.add_argument( "-k", "--private-key", type=str, help="load the TLS private key from the specified file", ) parser.add_argument( "-l", "--secrets-log", type=str, help="log secrets to a file, for use with Wireshark", ) parser.add_argument( "--max-datagram-size", type=int, default=defaults.max_datagram_size, help="maximum datagram size to send, excluding UDP or IP overhead", ) parser.add_argument( "-q", "--quic-log", type=str, help="log QUIC events to QLOG files in the specified directory", ) parser.add_argument( "--retry", action="store_true", help="send a retry for new connections", ) parser.add_argument( "-v", "--verbose", action="store_true", help="increase logging verbosity" ) args = parser.parse_args() logging.basicConfig( format="%(asctime)s %(levelname)s %(name)s %(message)s", level=logging.DEBUG if args.verbose else logging.INFO, ) # import ASGI application module_str, attr_str = args.app.split(":", maxsplit=1) module = importlib.import_module(module_str) application = getattr(module, attr_str) # create QUIC logger if args.quic_log: quic_logger = QuicFileLogger(args.quic_log) else: quic_logger = None # open SSL log file if args.secrets_log: secrets_log_file = open(args.secrets_log, "a") else: secrets_log_file = None configuration = QuicConfiguration( alpn_protocols=H3_ALPN + H0_ALPN + ["siduck"], congestion_control_algorithm=args.congestion_control_algorithm, is_client=False, max_datagram_frame_size=65536, max_datagram_size=args.max_datagram_size, quic_logger=quic_logger, secrets_log_file=secrets_log_file, ) # load SSL certificate and key configuration.load_cert_chain(args.certificate, args.private_key) if uvloop is not None: uvloop.install() try: asyncio.run( main( host=args.host, port=args.port, configuration=configuration, session_ticket_store=SessionTicketStore(), retry=args.retry, ) ) except KeyboardInterrupt: pass ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1720306884.0 aioquic-1.2.0/examples/httpx_client.py0000644000175100001770000002312300000000000020650 0ustar00runnerdocker00000000000000import argparse import asyncio import logging import os import pickle import ssl import time from collections 
import deque from typing import AsyncIterator, Deque, Dict, Optional, Tuple, cast from urllib.parse import urlparse import httpx from aioquic.asyncio.client import connect from aioquic.asyncio.protocol import QuicConnectionProtocol from aioquic.h3.connection import H3_ALPN, H3Connection from aioquic.h3.events import DataReceived, H3Event, Headers, HeadersReceived from aioquic.quic.configuration import QuicConfiguration from aioquic.quic.events import QuicEvent from aioquic.quic.logger import QuicFileLogger logger = logging.getLogger("client") class H3ResponseStream(httpx.AsyncByteStream): def __init__(self, aiterator: AsyncIterator[bytes]): self._aiterator = aiterator async def __aiter__(self) -> AsyncIterator[bytes]: async for part in self._aiterator: yield part class H3Transport(QuicConnectionProtocol, httpx.AsyncBaseTransport): def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self._http = H3Connection(self._quic) self._read_queue: Dict[int, Deque[H3Event]] = {} self._read_ready: Dict[int, asyncio.Event] = {} async def handle_async_request(self, request: httpx.Request) -> httpx.Response: assert isinstance(request.stream, httpx.AsyncByteStream) stream_id = self._quic.get_next_available_stream_id() self._read_queue[stream_id] = deque() self._read_ready[stream_id] = asyncio.Event() # prepare request self._http.send_headers( stream_id=stream_id, headers=[ (b":method", request.method.encode()), (b":scheme", request.url.raw_scheme), (b":authority", request.url.netloc), (b":path", request.url.raw_path), ] + [ (k.lower(), v) for (k, v) in request.headers.raw if k.lower() not in (b"connection", b"host") ], ) async for data in request.stream: self._http.send_data(stream_id=stream_id, data=data, end_stream=False) self._http.send_data(stream_id=stream_id, data=b"", end_stream=True) # transmit request self.transmit() # process response status_code, headers, stream_ended = await self._receive_response(stream_id) return httpx.Response( status_code=status_code, headers=headers, stream=H3ResponseStream( self._receive_response_data(stream_id, stream_ended) ), extensions={ "http_version": b"HTTP/3", }, ) def http_event_received(self, event: H3Event): if isinstance(event, (HeadersReceived, DataReceived)): stream_id = event.stream_id if stream_id in self._read_queue: self._read_queue[event.stream_id].append(event) self._read_ready[event.stream_id].set() def quic_event_received(self, event: QuicEvent): #  pass event to the HTTP layer if self._http is not None: for http_event in self._http.handle_event(event): self.http_event_received(http_event) async def _receive_response(self, stream_id: int) -> Tuple[int, Headers, bool]: """ Read the response status and headers. """ stream_ended = False while True: event = await self._wait_for_http_event(stream_id) if isinstance(event, HeadersReceived): stream_ended = event.stream_ended break headers = [] status_code = 0 for header, value in event.headers: if header == b":status": status_code = int(value.decode()) else: headers.append((header, value)) return status_code, headers, stream_ended async def _receive_response_data( self, stream_id: int, stream_ended: bool ) -> AsyncIterator[bytes]: """ Read the response data. 
""" while not stream_ended: event = await self._wait_for_http_event(stream_id) if isinstance(event, DataReceived): stream_ended = event.stream_ended yield event.data elif isinstance(event, HeadersReceived): stream_ended = event.stream_ended async def _wait_for_http_event(self, stream_id: int) -> H3Event: """ Returns the next HTTP/3 event for the given stream. """ if not self._read_queue[stream_id]: await self._read_ready[stream_id].wait() event = self._read_queue[stream_id].popleft() if not self._read_queue[stream_id]: self._read_ready[stream_id].clear() return event def save_session_ticket(ticket): """ Callback which is invoked by the TLS engine when a new session ticket is received. """ logger.info("New session ticket received") if args.session_ticket: with open(args.session_ticket, "wb") as fp: pickle.dump(ticket, fp) async def main( configuration: QuicConfiguration, url: str, data: Optional[str], include: bool, output_dir: Optional[str], ) -> None: # parse URL parsed = urlparse(url) assert parsed.scheme == "https", "Only https:// URLs are supported." host = parsed.hostname if parsed.port is not None: port = parsed.port else: port = 443 async with connect( host, port, configuration=configuration, create_protocol=H3Transport, session_ticket_handler=save_session_ticket, ) as transport: async with httpx.AsyncClient( transport=cast(httpx.AsyncBaseTransport, transport) ) as client: # perform request start = time.time() if data is not None: response = await client.post( url, content=data.encode(), headers={"content-type": "application/x-www-form-urlencoded"}, ) else: response = await client.get(url) elapsed = time.time() - start # print speed octets = len(response.content) logger.info( "Received %d bytes in %.1f s (%.3f Mbps)" % (octets, elapsed, octets * 8 / elapsed / 1000000) ) # output response if output_dir is not None: output_path = os.path.join( output_dir, os.path.basename(urlparse(url).path) or "index.html" ) with open(output_path, "wb") as output_file: if include: headers = "" for header, value in response.headers.items(): headers += header + ": " + value + "\r\n" if headers: output_file.write(headers.encode() + b"\r\n") output_file.write(response.content) if __name__ == "__main__": parser = argparse.ArgumentParser(description="HTTP/3 client") parser.add_argument("url", type=str, help="the URL to query (must be HTTPS)") parser.add_argument( "--ca-certs", type=str, help="load CA certificates from the specified file" ) parser.add_argument( "-d", "--data", type=str, help="send the specified data in a POST request" ) parser.add_argument( "-i", "--include", action="store_true", help="include the HTTP response headers in the output", ) parser.add_argument( "-k", "--insecure", action="store_true", help="do not validate server certificate", ) parser.add_argument( "--output-dir", type=str, help="write downloaded files to this directory", ) parser.add_argument( "-q", "--quic-log", type=str, help="log QUIC events to QLOG files in the specified directory", ) parser.add_argument( "-l", "--secrets-log", type=str, help="log secrets to a file, for use with Wireshark", ) parser.add_argument( "-s", "--session-ticket", type=str, help="read and write session ticket from the specified file", ) parser.add_argument( "-v", "--verbose", action="store_true", help="increase logging verbosity" ) args = parser.parse_args() logging.basicConfig( format="%(asctime)s %(levelname)s %(name)s %(message)s", level=logging.DEBUG if args.verbose else logging.INFO, ) if args.output_dir is not None and not 
os.path.isdir(args.output_dir): raise Exception("%s is not a directory" % args.output_dir) # prepare configuration configuration = QuicConfiguration(is_client=True, alpn_protocols=H3_ALPN) if args.ca_certs: configuration.load_verify_locations(args.ca_certs) if args.insecure: configuration.verify_mode = ssl.CERT_NONE if args.quic_log: configuration.quic_logger = QuicFileLogger(args.quic_log) if args.secrets_log: configuration.secrets_log_file = open(args.secrets_log, "a") if args.session_ticket: try: with open(args.session_ticket, "rb") as fp: configuration.session_ticket = pickle.load(fp) except FileNotFoundError: pass asyncio.run( main( configuration=configuration, url=args.url, data=args.data, include=args.include, output_dir=args.output_dir, ) ) ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1720306884.0 aioquic-1.2.0/examples/interop.py0000644000175100001770000004343400000000000017632 0ustar00runnerdocker00000000000000# # !!! WARNING !!! # # This example uses some private APIs. # import argparse import asyncio import logging import ssl import time from dataclasses import dataclass, field from enum import Flag from typing import Optional, cast import httpx from aioquic.asyncio import connect from aioquic.h0.connection import H0_ALPN from aioquic.h3.connection import H3_ALPN, H3Connection from aioquic.h3.events import DataReceived, HeadersReceived, PushPromiseReceived from aioquic.quic.configuration import QuicConfiguration from aioquic.quic.logger import QuicFileLogger, QuicLogger from http3_client import HttpClient class Result(Flag): V = 0x000001 H = 0x000002 D = 0x000004 C = 0x000008 R = 0x000010 Z = 0x000020 S = 0x000040 Q = 0x000080 M = 0x000100 B = 0x000200 A = 0x000400 U = 0x000800 P = 0x001000 E = 0x002000 L = 0x004000 T = 0x008000 three = 0x010000 d = 0x020000 p = 0x040000 def __str__(self): flags = sorted( map( lambda x: getattr(Result, x), filter(lambda x: not x.startswith("_"), dir(Result)), ), key=lambda x: x.value, ) result_str = "" for flag in flags: if self & flag: result_str += flag.name else: result_str += "-" return result_str @dataclass class Server: name: str host: str port: int = 4433 http3: bool = True http3_port: Optional[int] = None retry_port: Optional[int] = 4434 path: str = "/" push_path: Optional[str] = None result: Result = field(default_factory=lambda: Result(0)) session_resumption_port: Optional[int] = None structured_logging: bool = False throughput_path: Optional[str] = "/%(size)d" verify_mode: Optional[int] = None SERVERS = [ Server("akamaiquic", "ietf.akaquic.com", port=443, verify_mode=ssl.CERT_NONE), Server( "aioquic", "quic.aiortc.org", port=443, push_path="/", structured_logging=True ), Server("ats", "quic.ogre.com"), Server("f5", "f5quic.com", retry_port=4433, throughput_path=None), Server( "haskell", "mew.org", structured_logging=True, throughput_path="/num/%(size)s" ), Server("gquic", "quic.rocks", retry_port=None), Server("lsquic", "http3-test.litespeedtech.com", push_path="/200?push=/100"), Server( "msquic", "quic.westus.cloudapp.azure.com", structured_logging=True, throughput_path=None, # "/%(size)d.txt", verify_mode=ssl.CERT_NONE, ), Server( "mvfst", "fb.mvfst.net", port=443, push_path="/push", retry_port=None, structured_logging=True, ), Server( "ngtcp2", "nghttp2.org", push_path="/?push=/100", structured_logging=True, throughput_path=None, ), Server("ngx_quic", "cloudflare-quic.com", port=443, retry_port=None), Server("pandora", "pandora.cm.in.tum.de", verify_mode=ssl.CERT_NONE), 
Server("picoquic", "test.privateoctopus.com", structured_logging=True), Server("quant", "quant.eggert.org", http3=False, structured_logging=True), Server("quic-go", "interop.seemann.io", port=443, retry_port=443), Server("quiche", "quic.tech", port=8443, retry_port=8444), Server("quicly", "quic.examp1e.net", http3_port=443), Server("quinn", "h3.stammw.eu", port=443), ] async def test_version_negotiation(server: Server, configuration: QuicConfiguration): # force version negotiation configuration.supported_versions.insert(0, 0x1A2A3A4A) async with connect( server.host, server.port, configuration=configuration ) as protocol: await protocol.ping() # check log for event in configuration.quic_logger.to_dict()["traces"][0]["events"]: if ( event["name"] == "transport:packet_received" and event["data"]["header"]["packet_type"] == "version_negotiation" ): server.result |= Result.V async def test_handshake_and_close(server: Server, configuration: QuicConfiguration): async with connect( server.host, server.port, configuration=configuration ) as protocol: await protocol.ping() server.result |= Result.H server.result |= Result.C async def test_retry(server: Server, configuration: QuicConfiguration): # skip test if there is not retry port if server.retry_port is None: return async with connect( server.host, server.retry_port, configuration=configuration ) as protocol: await protocol.ping() # check log for event in configuration.quic_logger.to_dict()["traces"][0]["events"]: if ( event["name"] == "transport:packet_received" and event["data"]["header"]["packet_type"] == "retry" ): server.result |= Result.S async def test_quantum_readiness(server: Server, configuration: QuicConfiguration): configuration.quantum_readiness_test = True async with connect( server.host, server.port, configuration=configuration ) as protocol: await protocol.ping() server.result |= Result.Q async def test_http_0(server: Server, configuration: QuicConfiguration): if server.path is None: return configuration.alpn_protocols = H0_ALPN async with connect( server.host, server.port, configuration=configuration, create_protocol=HttpClient, ) as protocol: protocol = cast(HttpClient, protocol) # perform HTTP request events = await protocol.get( "https://{}:{}{}".format(server.host, server.port, server.path) ) if events and isinstance(events[0], HeadersReceived): server.result |= Result.D async def test_http_3(server: Server, configuration: QuicConfiguration): port = server.http3_port or server.port if server.path is None: return configuration.alpn_protocols = H3_ALPN async with connect( server.host, port, configuration=configuration, create_protocol=HttpClient, ) as protocol: protocol = cast(HttpClient, protocol) # perform HTTP request events = await protocol.get( "https://{}:{}{}".format(server.host, server.port, server.path) ) if events and isinstance(events[0], HeadersReceived): server.result |= Result.D server.result |= Result.three # perform more HTTP requests to use QPACK dynamic tables for i in range(2): events = await protocol.get( "https://{}:{}{}".format(server.host, server.port, server.path) ) if events and isinstance(events[0], HeadersReceived): http = cast(H3Connection, protocol._http) protocol._quic._logger.info( "QPACK decoder bytes RX %d TX %d", http._decoder_bytes_received, http._decoder_bytes_sent, ) protocol._quic._logger.info( "QPACK encoder bytes RX %d TX %d", http._encoder_bytes_received, http._encoder_bytes_sent, ) if ( http._decoder_bytes_received and http._decoder_bytes_sent and http._encoder_bytes_received and 
http._encoder_bytes_sent ): server.result |= Result.d # check push support if server.push_path is not None: protocol.pushes.clear() await protocol.get( "https://{}:{}{}".format(server.host, server.port, server.push_path) ) await asyncio.sleep(0.5) for push_id, events in protocol.pushes.items(): if ( len(events) >= 3 and isinstance(events[0], PushPromiseReceived) and isinstance(events[1], HeadersReceived) and isinstance(events[2], DataReceived) ): protocol._quic._logger.info( "Push promise %d for %s received (status %s)", push_id, dict(events[0].headers)[b":path"].decode("ascii"), int(dict(events[1].headers)[b":status"]), ) server.result |= Result.p async def test_session_resumption(server: Server, configuration: QuicConfiguration): port = server.session_resumption_port or server.port saved_ticket = None def session_ticket_handler(ticket): nonlocal saved_ticket saved_ticket = ticket # connect a first time, receive a ticket async with connect( server.host, port, configuration=configuration, session_ticket_handler=session_ticket_handler, ) as protocol: await protocol.ping() # some servers don't send the ticket immediately await asyncio.sleep(1) # connect a second time, with the ticket if saved_ticket is not None: configuration.session_ticket = saved_ticket async with connect(server.host, port, configuration=configuration) as protocol: await protocol.ping() # check session was resumed if protocol._quic.tls.session_resumed: server.result |= Result.R # check early data was accepted if protocol._quic.tls.early_data_accepted: server.result |= Result.Z async def test_key_update(server: Server, configuration: QuicConfiguration): async with connect( server.host, server.port, configuration=configuration ) as protocol: # cause some traffic await protocol.ping() # request key update protocol.request_key_update() # cause more traffic await protocol.ping() server.result |= Result.U async def test_server_cid_change(server: Server, configuration: QuicConfiguration): async with connect( server.host, server.port, configuration=configuration ) as protocol: # cause some traffic await protocol.ping() # change connection ID protocol.change_connection_id() # cause more traffic await protocol.ping() server.result |= Result.M async def test_nat_rebinding(server: Server, configuration: QuicConfiguration): async with connect( server.host, server.port, configuration=configuration ) as protocol: # cause some traffic await protocol.ping() # replace transport protocol._transport.close() loop = asyncio.get_event_loop() await loop.create_datagram_endpoint(lambda: protocol, local_addr=("::", 0)) # cause more traffic await protocol.ping() # check log path_challenges = 0 for event in configuration.quic_logger.to_dict()["traces"][0]["events"]: if ( event["name"] == "transport:packet_received" and event["data"]["header"]["packet_type"] == "1RTT" ): for frame in event["data"]["frames"]: if frame["frame_type"] == "path_challenge": path_challenges += 1 if not path_challenges: protocol._quic._logger.warning("No PATH_CHALLENGE received") else: server.result |= Result.B async def test_address_mobility(server: Server, configuration: QuicConfiguration): async with connect( server.host, server.port, configuration=configuration ) as protocol: # cause some traffic await protocol.ping() # replace transport protocol._transport.close() loop = asyncio.get_event_loop() await loop.create_datagram_endpoint(lambda: protocol, local_addr=("::", 0)) # change connection ID protocol.change_connection_id() # cause more traffic await protocol.ping() # 
check log path_challenges = 0 for event in configuration.quic_logger.to_dict()["traces"][0]["events"]: if ( event["name"] == "transport:packet_received" and event["data"]["header"]["packet_type"] == "1RTT" ): for frame in event["data"]["frames"]: if frame["frame_type"] == "path_challenge": path_challenges += 1 if not path_challenges: protocol._quic._logger.warning("No PATH_CHALLENGE received") else: server.result |= Result.A async def test_spin_bit(server: Server, configuration: QuicConfiguration): async with connect( server.host, server.port, configuration=configuration ) as protocol: for i in range(5): await protocol.ping() # check log spin_bits = set() for event in configuration.quic_logger.to_dict()["traces"][0]["events"]: if event["name"] == "connectivity:spin_bit_updated": spin_bits.add(event["data"]["state"]) if len(spin_bits) == 2: server.result |= Result.P async def test_throughput(server: Server, configuration: QuicConfiguration): failures = 0 if server.throughput_path is None: return for size in [5000000, 10000000]: path = server.throughput_path % {"size": size} print("Testing %d bytes download: %s" % (size, path)) # perform HTTP request over TCP start = time.time() response = httpx.get("https://" + server.host + path, verify=False) tcp_octets = len(response.content) tcp_elapsed = time.time() - start assert tcp_octets == size, "HTTP/TCP response size mismatch" # perform HTTP request over QUIC if server.http3: configuration.alpn_protocols = H3_ALPN port = server.http3_port or server.port else: configuration.alpn_protocols = H0_ALPN port = server.port start = time.time() async with connect( server.host, port, configuration=configuration, create_protocol=HttpClient, ) as protocol: protocol = cast(HttpClient, protocol) http_events = await protocol.get( "https://{}:{}{}".format(server.host, server.port, path) ) quic_elapsed = time.time() - start quic_octets = 0 for http_event in http_events: if isinstance(http_event, DataReceived): quic_octets += len(http_event.data) assert quic_octets == size, "HTTP/QUIC response size mismatch" print(" - HTTP/TCP completed in %.3f s" % tcp_elapsed) print(" - HTTP/QUIC completed in %.3f s" % quic_elapsed) if quic_elapsed > 1.1 * tcp_elapsed: failures += 1 print(" => FAIL") else: print(" => PASS") if failures == 0: server.result |= Result.T def print_result(server: Server) -> None: result = str(server.result).replace("three", "3") result = result[0:8] + " " + result[8:16] + " " + result[16:] print("%s%s%s" % (server.name, " " * (20 - len(server.name)), result)) async def main(servers, tests, quic_log=False, secrets_log_file=None) -> None: for server in servers: if server.structured_logging: server.result |= Result.L for test_name, test_func in tests: print("\n=== %s %s ===\n" % (server.name, test_name)) configuration = QuicConfiguration( alpn_protocols=H3_ALPN + H0_ALPN, is_client=True, quic_logger=QuicFileLogger(quic_log) if quic_log else QuicLogger(), secrets_log_file=secrets_log_file, verify_mode=server.verify_mode, ) if test_name == "test_throughput": timeout = 120 else: timeout = 10 try: await asyncio.wait_for( test_func(server, configuration), timeout=timeout ) except Exception as exc: print(exc) print("") print_result(server) # print summary if len(servers) > 1: print("SUMMARY") for server in servers: print_result(server) if __name__ == "__main__": parser = argparse.ArgumentParser(description="QUIC interop client") parser.add_argument( "-q", "--quic-log", type=str, help="log QUIC events to QLOG files in the specified directory", ) 
parser.add_argument( "--server", type=str, help="only run against the specified server." ) parser.add_argument("--test", type=str, help="only run the specified test.") parser.add_argument( "-l", "--secrets-log", type=str, help="log secrets to a file, for use with Wireshark", ) parser.add_argument( "-v", "--verbose", action="store_true", help="increase logging verbosity" ) args = parser.parse_args() logging.basicConfig( format="%(asctime)s %(levelname)s %(name)s %(message)s", level=logging.DEBUG if args.verbose else logging.INFO, ) # open SSL log file if args.secrets_log: secrets_log_file = open(args.secrets_log, "a") else: secrets_log_file = None # determine what to run servers = SERVERS tests = list(filter(lambda x: x[0].startswith("test_"), globals().items())) if args.server: servers = list(filter(lambda x: x.name == args.server, servers)) if args.test: tests = list(filter(lambda x: x[0] == args.test, tests)) asyncio.run( main( servers=servers, tests=tests, quic_log=args.quic_log, secrets_log_file=secrets_log_file, ) ) ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1720306884.0 aioquic-1.2.0/examples/siduck_client.py0000644000175100001770000000610200000000000020761 0ustar00runnerdocker00000000000000import argparse import asyncio import logging import ssl from typing import Optional, cast from aioquic.asyncio.client import connect from aioquic.asyncio.protocol import QuicConnectionProtocol from aioquic.quic.configuration import QuicConfiguration from aioquic.quic.events import DatagramFrameReceived, QuicEvent from aioquic.quic.logger import QuicFileLogger logger = logging.getLogger("client") class SiduckClient(QuicConnectionProtocol): def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self._ack_waiter: Optional[asyncio.Future[None]] = None async def quack(self) -> None: assert self._ack_waiter is None, "Only one quack at a time." 
self._quic.send_datagram_frame(b"quack") waiter = self._loop.create_future() self._ack_waiter = waiter self.transmit() return await asyncio.shield(waiter) def quic_event_received(self, event: QuicEvent) -> None: if self._ack_waiter is not None: if isinstance(event, DatagramFrameReceived) and event.data == b"quack-ack": waiter = self._ack_waiter self._ack_waiter = None waiter.set_result(None) async def main(configuration: QuicConfiguration, host: str, port: int) -> None: async with connect( host, port, configuration=configuration, create_protocol=SiduckClient ) as client: client = cast(SiduckClient, client) logger.info("sending quack") await client.quack() logger.info("received quack-ack") if __name__ == "__main__": parser = argparse.ArgumentParser(description="SiDUCK client") parser.add_argument( "host", type=str, help="The remote peer's host name or IP address" ) parser.add_argument("port", type=int, help="The remote peer's port number") parser.add_argument( "-k", "--insecure", action="store_true", help="do not validate server certificate", ) parser.add_argument( "-q", "--quic-log", type=str, help="log QUIC events to QLOG files in the specified directory", ) parser.add_argument( "-l", "--secrets-log", type=str, help="log secrets to a file, for use with Wireshark", ) parser.add_argument( "-v", "--verbose", action="store_true", help="increase logging verbosity" ) args = parser.parse_args() logging.basicConfig( format="%(asctime)s %(levelname)s %(name)s %(message)s", level=logging.DEBUG if args.verbose else logging.INFO, ) configuration = QuicConfiguration( alpn_protocols=["siduck"], is_client=True, max_datagram_frame_size=65536 ) if args.insecure: configuration.verify_mode = ssl.CERT_NONE if args.quic_log: configuration.quic_logger = QuicFileLogger(args.quic_log) if args.secrets_log: configuration.secrets_log_file = open(args.secrets_log, "a") asyncio.run( main( configuration=configuration, host=args.host, port=args.port, ) ) ././@PaxHeader0000000000000000000000000000003300000000000011451 xustar000000000000000027 mtime=1720306888.121294 aioquic-1.2.0/examples/templates/0000755000175100001770000000000000000000000017566 5ustar00runnerdocker00000000000000././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1720306884.0 aioquic-1.2.0/examples/templates/index.html0000644000175100001770000000234500000000000021567 0ustar00runnerdocker00000000000000 aioquic

Welcome to aioquic

This is a test page for aioquic, a QUIC and HTTP/3 implementation written in Python.

{% if request.scope["http_version"] == "3" %}

Congratulations, you loaded this page using HTTP/3!

{% endif %}
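The conditional above reads the "http_version" field of the ASGI scope built by examples/http3_server.py, which sets it to "3" for an H3Connection and "0.9" for the H0Connection fallback. A minimal sketch of the relevant scope fields (illustrative values only, not the server's full scope):

# Illustrative subset of the ASGI scope assembled in
# HttpServerProtocol.http_event_received (examples/http3_server.py).
scope = {
    "type": "http",
    "http_version": "3",  # "0.9" when HTTP/0.9 was negotiated instead
    "method": "GET",
    "path": "/",
    "scheme": "https",
}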

Available endpoints (an example client sketch follows this list)

  • GET / returns the homepage
  • GET /NNNNN returns NNNNN bytes of plain text
  • POST /echo returns the request data
  • CONNECT /ws runs a WebSocket echo service. You must set the :protocol pseudo-header to "websocket".
  • CONNECT /wt runs a WebTransport echo service. You must set the :protocol pseudo-header to "webtransport".
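The endpoints above can be exercised with the HttpClient class from examples/http3_client.py. A minimal sketch, not part of the archive: it assumes the demo server from examples/http3_server.py is running on localhost:4433 with a self-signed certificate, hence verify_mode = ssl.CERT_NONE.

import asyncio
import ssl
from typing import cast

from aioquic.asyncio.client import connect
from aioquic.h3.connection import H3_ALPN
from aioquic.quic.configuration import QuicConfiguration

from http3_client import HttpClient


async def demo() -> None:
    configuration = QuicConfiguration(is_client=True, alpn_protocols=H3_ALPN)
    configuration.verify_mode = ssl.CERT_NONE  # demo certificate is self-signed
    async with connect(
        "localhost", 4433, configuration=configuration, create_protocol=HttpClient
    ) as protocol:
        protocol = cast(HttpClient, protocol)
        # GET /5000 returns 5000 bytes of plain text
        events = await protocol.get("https://localhost:4433/5000")
        print(events)


asyncio.run(demo())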
././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1720306884.0 aioquic-1.2.0/examples/templates/logs.html0000644000175100001770000000130700000000000021421 0ustar00runnerdocker00000000000000 aioquic - logs

QLOG files

name | date (UTC) | size
{% for log in logs %}
{{ log.name }} [qvis] | {{ log.date }} | {{ log.size }}
{% endfor %}
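The rows above are rendered from QLOG files, which both example programs can emit via their -q/--quic-log option. A minimal sketch, assuming a writable "qlogs" directory, of how a configuration is pointed at it:

# Hypothetical snippet mirroring what http3_client.py does when --quic-log
# is given: write QLOG traces into the ./qlogs directory.
from aioquic.quic.configuration import QuicConfiguration
from aioquic.quic.logger import QuicFileLogger

configuration = QuicConfiguration(is_client=True)
configuration.quic_logger = QuicFileLogger("qlogs")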
././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1720306884.0 aioquic-1.2.0/pyproject.toml0000644000175100001770000000332700000000000016673 0ustar00runnerdocker00000000000000[build-system] requires = ["setuptools", "wheel"] build-backend = "setuptools.build_meta" [project] name = "aioquic" description = "An implementation of QUIC and HTTP/3" readme = "README.rst" requires-python = ">=3.8" license = { text = "BSD-3-Clause" } authors = [ { name = "Jeremy Lainé", email = "jeremy.laine@m4x.org" }, ] classifiers = [ "Development Status :: 5 - Production/Stable", "Environment :: Web Environment", "Intended Audience :: Developers", "License :: OSI Approved :: BSD License", "Operating System :: OS Independent", "Programming Language :: Python", "Programming Language :: Python :: 3", "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", "Topic :: Internet :: WWW/HTTP", ] dependencies = [ "certifi", "cryptography>=42.0.0", "pylsqpack>=0.3.3,<0.4.0", "pyopenssl>=24", "service-identity>=24.1.0", ] dynamic = ["version"] [project.optional-dependencies] dev = [ "coverage[toml]>=7.2.2", ] [project.urls] Homepage = "https://github.com/aiortc/aioquic" Changelog = "https://aioquic.readthedocs.io/en/stable/changelog.html" Documentation = "https://aioquic.readthedocs.io/" [tool.coverage.run] source = ["aioquic"] [tool.mypy] disallow_untyped_calls = true disallow_untyped_decorators = true ignore_missing_imports = true strict_optional = false warn_redundant_casts = true warn_unused_ignores = true [tool.ruff.lint] select = [ "E", # pycodestyle "F", # Pyflakes "W", # pycodestyle "I", # isort ] [tool.setuptools.dynamic] version = {attr = "aioquic.__version__"} ././@PaxHeader0000000000000000000000000000003300000000000011451 xustar000000000000000027 mtime=1720306888.121294 aioquic-1.2.0/requirements/0000755000175100001770000000000000000000000016475 5ustar00runnerdocker00000000000000././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1720306884.0 aioquic-1.2.0/requirements/doc.txt0000644000175100001770000000007100000000000020001 0ustar00runnerdocker00000000000000cryptography sphinx_autodoc_typehints sphinxcontrib-trio ././@PaxHeader0000000000000000000000000000003300000000000011451 xustar000000000000000027 mtime=1720306888.121294 aioquic-1.2.0/scripts/0000755000175100001770000000000000000000000015441 5ustar00runnerdocker00000000000000././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1720306884.0 aioquic-1.2.0/scripts/fetch-vendor.json0000644000175100001770000000016200000000000020717 0ustar00runnerdocker00000000000000{ "urls": ["https://github.com/aiortc/aioquic-openssl/releases/download/3.3.0-1/openssl-{platform}.tar.gz"] } ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1720306884.0 aioquic-1.2.0/scripts/fetch-vendor.py0000644000175100001770000000411300000000000020376 0ustar00runnerdocker00000000000000import argparse import json import logging import os import platform import shutil import struct import subprocess def get_platform(): system = platform.system() machine = platform.machine() if system == "Linux": return f"manylinux_{machine}" elif system == "Darwin": # cibuildwheel sets ARCHFLAGS: # 
https://github.com/pypa/cibuildwheel/blob/5255155bc57eb6224354356df648dc42e31a0028/cibuildwheel/macos.py#L207-L220 if "ARCHFLAGS" in os.environ: machine = os.environ["ARCHFLAGS"].split()[1] return f"macosx_{machine}" elif system == "Windows": if struct.calcsize("P") * 8 == 64: return "win_amd64" else: return "win32" else: raise Exception(f"Unsupported system {system}") parser = argparse.ArgumentParser(description="Fetch and extract tarballs") parser.add_argument("destination_dir") parser.add_argument("--cache-dir", default="tarballs") parser.add_argument("--config-file", default=os.path.splitext(__file__)[0] + ".json") args = parser.parse_args() logging.basicConfig(level=logging.INFO) # read config file with open(args.config_file, "r") as fp: config = json.load(fp) # create fresh destination directory logging.info("Creating directory %s" % args.destination_dir) if os.path.exists(args.destination_dir): shutil.rmtree(args.destination_dir) os.makedirs(args.destination_dir) for url_template in config["urls"]: tarball_url = url_template.replace("{platform}", get_platform()) # download tarball tarball_name = tarball_url.split("/")[-1] tarball_file = os.path.join(args.cache_dir, tarball_name) if not os.path.exists(tarball_file): logging.info("Downloading %s" % tarball_url) if not os.path.exists(args.cache_dir): os.mkdir(args.cache_dir) subprocess.check_call( ["curl", "--location", "--output", tarball_file, "--silent", tarball_url] ) # extract tarball logging.info("Extracting %s" % tarball_name) subprocess.check_call(["tar", "-C", args.destination_dir, "-xf", tarball_file]) ././@PaxHeader0000000000000000000000000000003400000000000011452 xustar000000000000000028 mtime=1720306888.1372943 aioquic-1.2.0/setup.cfg0000644000175100001770000000004600000000000015573 0ustar00runnerdocker00000000000000[egg_info] tag_build = tag_date = 0 ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1720306884.0 aioquic-1.2.0/setup.py0000644000175100001770000000220700000000000015465 0ustar00runnerdocker00000000000000import sys import setuptools from wheel.bdist_wheel import bdist_wheel if sys.platform == "win32": extra_compile_args = [] libraries = ["libcrypto", "advapi32", "crypt32", "gdi32", "user32", "ws2_32"] else: extra_compile_args = ["-std=c99"] libraries = ["crypto"] class bdist_wheel_abi3(bdist_wheel): def get_tag(self): python, abi, plat = super().get_tag() if python.startswith("cp"): return "cp38", "abi3", plat return python, abi, plat setuptools.setup( ext_modules=[ setuptools.Extension( "aioquic._buffer", extra_compile_args=extra_compile_args, sources=["src/aioquic/_buffer.c"], define_macros=[("Py_LIMITED_API", "0x03080000")], py_limited_api=True, ), setuptools.Extension( "aioquic._crypto", extra_compile_args=extra_compile_args, libraries=libraries, sources=["src/aioquic/_crypto.c"], define_macros=[("Py_LIMITED_API", "0x03080000")], py_limited_api=True, ), ], cmdclass={"bdist_wheel": bdist_wheel_abi3}, ) ././@PaxHeader0000000000000000000000000000003400000000000011452 xustar000000000000000028 mtime=1720306888.1132941 aioquic-1.2.0/src/0000755000175100001770000000000000000000000014541 5ustar00runnerdocker00000000000000././@PaxHeader0000000000000000000000000000003300000000000011451 xustar000000000000000027 mtime=1720306888.121294 aioquic-1.2.0/src/aioquic/0000755000175100001770000000000000000000000016173 5ustar00runnerdocker00000000000000././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1720306884.0 
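setup.py above compiles the two C extensions against the limited API (Py_LIMITED_API 0x03080000) and tags CPython wheels with the stable ABI via bdist_wheel_abi3. A standalone sketch of what that get_tag() override amounts to (illustrative values, not part of the build):

# Illustrative only: the override collapses any CPython tag to the stable-ABI
# tag, so one wheel built on e.g. CPython 3.12 installs on CPython 3.8+.
python, abi, plat = "cp312", "cp312", "manylinux_x86_64"
if python.startswith("cp"):
    python, abi = "cp38", "abi3"
print("-".join((python, abi, plat)))  # cp38-abi3-manylinux_x86_64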
aioquic-1.2.0/src/aioquic/__init__.py0000644000175100001770000000002600000000000020302 0ustar00runnerdocker00000000000000__version__ = "1.2.0" ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1720306884.0 aioquic-1.2.0/src/aioquic/_buffer.c0000644000175100001770000002777200000000000017766 0ustar00runnerdocker00000000000000#define PY_SSIZE_T_CLEAN #include #include #define MODULE_NAME "aioquic._buffer" static PyObject *BufferReadError; static PyObject *BufferWriteError; typedef struct { PyObject_HEAD uint8_t *base; uint8_t *end; uint8_t *pos; } BufferObject; static PyObject *BufferType; #define CHECK_READ_BOUNDS(self, len) \ if (len < 0 || self->pos + len > self->end) { \ PyErr_SetString(BufferReadError, "Read out of bounds"); \ return NULL; \ } #define CHECK_WRITE_BOUNDS(self, len) \ if (self->pos + len > self->end) { \ PyErr_SetString(BufferWriteError, "Write out of bounds"); \ return NULL; \ } static int Buffer_init(BufferObject *self, PyObject *args, PyObject *kwargs) { const char *kwlist[] = {"capacity", "data", NULL}; Py_ssize_t capacity = 0; const unsigned char *data = NULL; Py_ssize_t data_len = 0; if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|ny#", (char**)kwlist, &capacity, &data, &data_len)) return -1; if (data != NULL) { self->base = malloc(data_len); self->end = self->base + data_len; memcpy(self->base, data, data_len); } else { self->base = malloc(capacity); self->end = self->base + capacity; } self->pos = self->base; return 0; } static void Buffer_dealloc(BufferObject *self) { free(self->base); PyTypeObject *tp = Py_TYPE(self); freefunc free = PyType_GetSlot(tp, Py_tp_free); free(self); Py_DECREF(tp); } static PyObject * Buffer_data_slice(BufferObject *self, PyObject *args) { Py_ssize_t start, stop; if (!PyArg_ParseTuple(args, "nn", &start, &stop)) return NULL; if (start < 0 || self->base + start > self->end || stop < 0 || self->base + stop > self->end || stop < start) { PyErr_SetString(BufferReadError, "Read out of bounds"); return NULL; } return PyBytes_FromStringAndSize((const char*)(self->base + start), (stop - start)); } static PyObject * Buffer_eof(BufferObject *self, PyObject *args) { if (self->pos == self->end) Py_RETURN_TRUE; Py_RETURN_FALSE; } static PyObject * Buffer_pull_bytes(BufferObject *self, PyObject *args) { Py_ssize_t len; if (!PyArg_ParseTuple(args, "n", &len)) return NULL; CHECK_READ_BOUNDS(self, len); PyObject *o = PyBytes_FromStringAndSize((const char*)self->pos, len); self->pos += len; return o; } static PyObject * Buffer_pull_uint8(BufferObject *self, PyObject *args) { CHECK_READ_BOUNDS(self, 1) return PyLong_FromUnsignedLong( (uint8_t)(*(self->pos++)) ); } static PyObject * Buffer_pull_uint16(BufferObject *self, PyObject *args) { CHECK_READ_BOUNDS(self, 2) uint16_t value = (uint16_t)(*(self->pos)) << 8 | (uint16_t)(*(self->pos + 1)); self->pos += 2; return PyLong_FromUnsignedLong(value); } static PyObject * Buffer_pull_uint32(BufferObject *self, PyObject *args) { CHECK_READ_BOUNDS(self, 4) uint32_t value = (uint32_t)(*(self->pos)) << 24 | (uint32_t)(*(self->pos + 1)) << 16 | (uint32_t)(*(self->pos + 2)) << 8 | (uint32_t)(*(self->pos + 3)); self->pos += 4; return PyLong_FromUnsignedLong(value); } static PyObject * Buffer_pull_uint64(BufferObject *self, PyObject *args) { CHECK_READ_BOUNDS(self, 8) uint64_t value = (uint64_t)(*(self->pos)) << 56 | (uint64_t)(*(self->pos + 1)) << 48 | (uint64_t)(*(self->pos + 2)) << 40 | (uint64_t)(*(self->pos + 3)) << 32 | (uint64_t)(*(self->pos + 4)) << 24 | 
(uint64_t)(*(self->pos + 5)) << 16 | (uint64_t)(*(self->pos + 6)) << 8 | (uint64_t)(*(self->pos + 7)); self->pos += 8; return PyLong_FromUnsignedLongLong(value); } static PyObject * Buffer_pull_uint_var(BufferObject *self, PyObject *args) { uint64_t value; CHECK_READ_BOUNDS(self, 1) switch (*(self->pos) >> 6) { case 0: value = *(self->pos++) & 0x3F; break; case 1: CHECK_READ_BOUNDS(self, 2) value = (uint16_t)(*(self->pos) & 0x3F) << 8 | (uint16_t)(*(self->pos + 1)); self->pos += 2; break; case 2: CHECK_READ_BOUNDS(self, 4) value = (uint32_t)(*(self->pos) & 0x3F) << 24 | (uint32_t)(*(self->pos + 1)) << 16 | (uint32_t)(*(self->pos + 2)) << 8 | (uint32_t)(*(self->pos + 3)); self->pos += 4; break; default: CHECK_READ_BOUNDS(self, 8) value = (uint64_t)(*(self->pos) & 0x3F) << 56 | (uint64_t)(*(self->pos + 1)) << 48 | (uint64_t)(*(self->pos + 2)) << 40 | (uint64_t)(*(self->pos + 3)) << 32 | (uint64_t)(*(self->pos + 4)) << 24 | (uint64_t)(*(self->pos + 5)) << 16 | (uint64_t)(*(self->pos + 6)) << 8 | (uint64_t)(*(self->pos + 7)); self->pos += 8; break; } return PyLong_FromUnsignedLongLong(value); } static PyObject * Buffer_push_bytes(BufferObject *self, PyObject *args) { const unsigned char *data; Py_ssize_t data_len; if (!PyArg_ParseTuple(args, "y#", &data, &data_len)) return NULL; CHECK_WRITE_BOUNDS(self, data_len) memcpy(self->pos, data, data_len); self->pos += data_len; Py_RETURN_NONE; } static PyObject * Buffer_push_uint8(BufferObject *self, PyObject *args) { uint8_t value; if (!PyArg_ParseTuple(args, "B", &value)) return NULL; CHECK_WRITE_BOUNDS(self, 1) *(self->pos++) = value; Py_RETURN_NONE; } static PyObject * Buffer_push_uint16(BufferObject *self, PyObject *args) { uint16_t value; if (!PyArg_ParseTuple(args, "H", &value)) return NULL; CHECK_WRITE_BOUNDS(self, 2) *(self->pos++) = (value >> 8); *(self->pos++) = value; Py_RETURN_NONE; } static PyObject * Buffer_push_uint32(BufferObject *self, PyObject *args) { uint32_t value; if (!PyArg_ParseTuple(args, "I", &value)) return NULL; CHECK_WRITE_BOUNDS(self, 4) *(self->pos++) = (value >> 24); *(self->pos++) = (value >> 16); *(self->pos++) = (value >> 8); *(self->pos++) = value; Py_RETURN_NONE; } static PyObject * Buffer_push_uint64(BufferObject *self, PyObject *args) { uint64_t value; if (!PyArg_ParseTuple(args, "K", &value)) return NULL; CHECK_WRITE_BOUNDS(self, 8) *(self->pos++) = (value >> 56); *(self->pos++) = (value >> 48); *(self->pos++) = (value >> 40); *(self->pos++) = (value >> 32); *(self->pos++) = (value >> 24); *(self->pos++) = (value >> 16); *(self->pos++) = (value >> 8); *(self->pos++) = value; Py_RETURN_NONE; } static PyObject * Buffer_push_uint_var(BufferObject *self, PyObject *args) { uint64_t value; if (!PyArg_ParseTuple(args, "K", &value)) return NULL; if (value <= 0x3F) { CHECK_WRITE_BOUNDS(self, 1) *(self->pos++) = value; Py_RETURN_NONE; } else if (value <= 0x3FFF) { CHECK_WRITE_BOUNDS(self, 2) *(self->pos++) = (value >> 8) | 0x40; *(self->pos++) = value; Py_RETURN_NONE; } else if (value <= 0x3FFFFFFF) { CHECK_WRITE_BOUNDS(self, 4) *(self->pos++) = (value >> 24) | 0x80; *(self->pos++) = (value >> 16); *(self->pos++) = (value >> 8); *(self->pos++) = value; Py_RETURN_NONE; } else if (value <= 0x3FFFFFFFFFFFFFFF) { CHECK_WRITE_BOUNDS(self, 8) *(self->pos++) = (value >> 56) | 0xC0; *(self->pos++) = (value >> 48); *(self->pos++) = (value >> 40); *(self->pos++) = (value >> 32); *(self->pos++) = (value >> 24); *(self->pos++) = (value >> 16); *(self->pos++) = (value >> 8); *(self->pos++) = value; Py_RETURN_NONE; } else { 
PyErr_SetString(PyExc_ValueError, "Integer is too big for a variable-length integer"); return NULL; } } static PyObject * Buffer_seek(BufferObject *self, PyObject *args) { Py_ssize_t pos; if (!PyArg_ParseTuple(args, "n", &pos)) return NULL; if (pos < 0 || self->base + pos > self->end) { PyErr_SetString(BufferReadError, "Seek out of bounds"); return NULL; } self->pos = self->base + pos; Py_RETURN_NONE; } static PyObject * Buffer_tell(BufferObject *self, PyObject *args) { return PyLong_FromSsize_t(self->pos - self->base); } static PyMethodDef Buffer_methods[] = { {"data_slice", (PyCFunction)Buffer_data_slice, METH_VARARGS, ""}, {"eof", (PyCFunction)Buffer_eof, METH_VARARGS, ""}, {"pull_bytes", (PyCFunction)Buffer_pull_bytes, METH_VARARGS, "Pull bytes."}, {"pull_uint8", (PyCFunction)Buffer_pull_uint8, METH_VARARGS, "Pull an 8-bit unsigned integer."}, {"pull_uint16", (PyCFunction)Buffer_pull_uint16, METH_VARARGS, "Pull a 16-bit unsigned integer."}, {"pull_uint32", (PyCFunction)Buffer_pull_uint32, METH_VARARGS, "Pull a 32-bit unsigned integer."}, {"pull_uint64", (PyCFunction)Buffer_pull_uint64, METH_VARARGS, "Pull a 64-bit unsigned integer."}, {"pull_uint_var", (PyCFunction)Buffer_pull_uint_var, METH_VARARGS, "Pull a QUIC variable-length unsigned integer."}, {"push_bytes", (PyCFunction)Buffer_push_bytes, METH_VARARGS, "Push bytes."}, {"push_uint8", (PyCFunction)Buffer_push_uint8, METH_VARARGS, "Push an 8-bit unsigned integer."}, {"push_uint16", (PyCFunction)Buffer_push_uint16, METH_VARARGS, "Push a 16-bit unsigned integer."}, {"push_uint32", (PyCFunction)Buffer_push_uint32, METH_VARARGS, "Push a 32-bit unsigned integer."}, {"push_uint64", (PyCFunction)Buffer_push_uint64, METH_VARARGS, "Push a 64-bit unsigned integer."}, {"push_uint_var", (PyCFunction)Buffer_push_uint_var, METH_VARARGS, "Push a QUIC variable-length unsigned integer."}, {"seek", (PyCFunction)Buffer_seek, METH_VARARGS, ""}, {"tell", (PyCFunction)Buffer_tell, METH_VARARGS, ""}, {NULL} }; static PyObject* Buffer_capacity_getter(BufferObject* self, void *closure) { return PyLong_FromSsize_t(self->end - self->base); } static PyObject* Buffer_data_getter(BufferObject* self, void *closure) { return PyBytes_FromStringAndSize((const char*)self->base, self->pos - self->base); } static PyGetSetDef Buffer_getset[] = { {"capacity", (getter) Buffer_capacity_getter, NULL, "", NULL }, {"data", (getter) Buffer_data_getter, NULL, "", NULL }, {NULL} }; static PyType_Slot BufferType_slots[] = { {Py_tp_dealloc, Buffer_dealloc}, {Py_tp_methods, Buffer_methods}, {Py_tp_doc, "Buffer objects"}, {Py_tp_getset, Buffer_getset}, {Py_tp_init, Buffer_init}, {0, 0}, }; static PyType_Spec BufferType_spec = { MODULE_NAME ".Buffer", sizeof(BufferObject), 0, Py_TPFLAGS_DEFAULT, BufferType_slots }; static struct PyModuleDef moduledef = { PyModuleDef_HEAD_INIT, MODULE_NAME, /* m_name */ "Serialization utilities.", /* m_doc */ -1, /* m_size */ NULL, /* m_methods */ NULL, /* m_reload */ NULL, /* m_traverse */ NULL, /* m_clear */ NULL, /* m_free */ }; PyMODINIT_FUNC PyInit__buffer(void) { PyObject* m; m = PyModule_Create(&moduledef); if (m == NULL) return NULL; BufferReadError = PyErr_NewException(MODULE_NAME ".BufferReadError", PyExc_ValueError, NULL); Py_INCREF(BufferReadError); PyModule_AddObject(m, "BufferReadError", BufferReadError); BufferWriteError = PyErr_NewException(MODULE_NAME ".BufferWriteError", PyExc_ValueError, NULL); Py_INCREF(BufferWriteError); PyModule_AddObject(m, "BufferWriteError", BufferWriteError); BufferType = PyType_FromSpec(&BufferType_spec); 
if (BufferType == NULL) return NULL; PyModule_AddObject(m, "Buffer", BufferType); return m; } ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1720306884.0 aioquic-1.2.0/src/aioquic/_buffer.pyi0000644000175100001770000000176600000000000020330 0ustar00runnerdocker00000000000000from typing import Optional class BufferReadError(ValueError): ... class BufferWriteError(ValueError): ... class Buffer: def __init__(self, capacity: int = 0, data: Optional[bytes] = None): ... @property def capacity(self) -> int: ... @property def data(self) -> bytes: ... def data_slice(self, start: int, end: int) -> bytes: ... def eof(self) -> bool: ... def seek(self, pos: int) -> None: ... def tell(self) -> int: ... def pull_bytes(self, length: int) -> bytes: ... def pull_uint8(self) -> int: ... def pull_uint16(self) -> int: ... def pull_uint32(self) -> int: ... def pull_uint64(self) -> int: ... def pull_uint_var(self) -> int: ... def push_bytes(self, value: bytes) -> None: ... def push_uint8(self, value: int) -> None: ... def push_uint16(self, value: int) -> None: ... def push_uint32(self, value: int) -> None: ... def push_uint64(self, value: int) -> None: ... def push_uint_var(self, value: int) -> None: ... ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1720306884.0 aioquic-1.2.0/src/aioquic/_crypto.c0000644000175100001770000002747100000000000020025 0ustar00runnerdocker00000000000000#define PY_SSIZE_T_CLEAN #include <Python.h> #include <openssl/err.h> #include <openssl/evp.h> #define MODULE_NAME "aioquic._crypto" #define AEAD_KEY_LENGTH_MAX 32 #define AEAD_NONCE_LENGTH 12 #define AEAD_TAG_LENGTH 16 #define PACKET_LENGTH_MAX 1500 #define PACKET_NUMBER_LENGTH_MAX 4 #define SAMPLE_LENGTH 16 #define CHECK_RESULT(expr) \ if (!(expr)) { \ ERR_clear_error(); \ PyErr_SetString(CryptoError, "OpenSSL call failed"); \ return NULL; \ } #define CHECK_RESULT_CTOR(expr) \ if (!(expr)) { \ ERR_clear_error(); \ PyErr_SetString(CryptoError, "OpenSSL call failed"); \ return -1; \ } static PyObject *CryptoError; /* AEAD */ typedef struct { PyObject_HEAD EVP_CIPHER_CTX *decrypt_ctx; EVP_CIPHER_CTX *encrypt_ctx; unsigned char buffer[PACKET_LENGTH_MAX]; unsigned char key[AEAD_KEY_LENGTH_MAX]; unsigned char iv[AEAD_NONCE_LENGTH]; unsigned char nonce[AEAD_NONCE_LENGTH]; } AEADObject; static PyObject *AEADType; static EVP_CIPHER_CTX * create_ctx(const EVP_CIPHER *cipher, int key_length, int operation) { EVP_CIPHER_CTX *ctx; int res; ctx = EVP_CIPHER_CTX_new(); CHECK_RESULT(ctx != 0); res = EVP_CipherInit_ex(ctx, cipher, NULL, NULL, NULL, operation); CHECK_RESULT(res != 0); res = EVP_CIPHER_CTX_set_key_length(ctx, key_length); CHECK_RESULT(res != 0); res = EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_CCM_SET_IVLEN, AEAD_NONCE_LENGTH, NULL); CHECK_RESULT(res != 0); return ctx; } static int AEAD_init(AEADObject *self, PyObject *args, PyObject *kwargs) { const char *cipher_name; const unsigned char *key, *iv; Py_ssize_t cipher_name_len, key_len, iv_len; if (!PyArg_ParseTuple(args, "y#y#y#", &cipher_name, &cipher_name_len, &key, &key_len, &iv, &iv_len)) return -1; const EVP_CIPHER *evp_cipher = EVP_get_cipherbyname(cipher_name); if (evp_cipher == 0) { PyErr_Format(CryptoError, "Invalid cipher name: %s", cipher_name); return -1; } if (key_len > AEAD_KEY_LENGTH_MAX) { PyErr_SetString(CryptoError, "Invalid key length"); return -1; } if (iv_len > AEAD_NONCE_LENGTH) { PyErr_SetString(CryptoError, "Invalid iv length"); return -1; } memcpy(self->key, key, key_len); memcpy(self->iv, iv, iv_len);
self->decrypt_ctx = create_ctx(evp_cipher, key_len, 0); CHECK_RESULT_CTOR(self->decrypt_ctx != 0); self->encrypt_ctx = create_ctx(evp_cipher, key_len, 1); CHECK_RESULT_CTOR(self->encrypt_ctx != 0); return 0; } static void AEAD_dealloc(AEADObject *self) { EVP_CIPHER_CTX_free(self->decrypt_ctx); EVP_CIPHER_CTX_free(self->encrypt_ctx); PyTypeObject *tp = Py_TYPE(self); freefunc free = PyType_GetSlot(tp, Py_tp_free); free(self); Py_DECREF(tp); } static PyObject* AEAD_decrypt(AEADObject *self, PyObject *args) { const unsigned char *data, *associated; Py_ssize_t data_len, associated_len; int outlen, outlen2, res; uint64_t pn; if (!PyArg_ParseTuple(args, "y#y#K", &data, &data_len, &associated, &associated_len, &pn)) return NULL; if (data_len < AEAD_TAG_LENGTH || data_len > PACKET_LENGTH_MAX) { PyErr_SetString(CryptoError, "Invalid payload length"); return NULL; } memcpy(self->nonce, self->iv, AEAD_NONCE_LENGTH); for (int i = 0; i < 8; ++i) { self->nonce[AEAD_NONCE_LENGTH - 1 - i] ^= (uint8_t)(pn >> 8 * i); } res = EVP_CIPHER_CTX_ctrl(self->decrypt_ctx, EVP_CTRL_CCM_SET_TAG, AEAD_TAG_LENGTH, (void*)(data + (data_len - AEAD_TAG_LENGTH))); CHECK_RESULT(res != 0); res = EVP_CipherInit_ex(self->decrypt_ctx, NULL, NULL, self->key, self->nonce, 0); CHECK_RESULT(res != 0); res = EVP_CipherUpdate(self->decrypt_ctx, NULL, &outlen, associated, associated_len); CHECK_RESULT(res != 0); res = EVP_CipherUpdate(self->decrypt_ctx, self->buffer, &outlen, data, data_len - AEAD_TAG_LENGTH); CHECK_RESULT(res != 0); res = EVP_CipherFinal_ex(self->decrypt_ctx, NULL, &outlen2); if (res == 0) { PyErr_SetString(CryptoError, "Payload decryption failed"); return NULL; } return PyBytes_FromStringAndSize((const char*)self->buffer, outlen); } static PyObject* AEAD_encrypt(AEADObject *self, PyObject *args) { const unsigned char *data, *associated; Py_ssize_t data_len, associated_len; int outlen, outlen2, res; uint64_t pn; if (!PyArg_ParseTuple(args, "y#y#K", &data, &data_len, &associated, &associated_len, &pn)) return NULL; if (data_len > PACKET_LENGTH_MAX) { PyErr_SetString(CryptoError, "Invalid payload length"); return NULL; } memcpy(self->nonce, self->iv, AEAD_NONCE_LENGTH); for (int i = 0; i < 8; ++i) { self->nonce[AEAD_NONCE_LENGTH - 1 - i] ^= (uint8_t)(pn >> 8 * i); } res = EVP_CipherInit_ex(self->encrypt_ctx, NULL, NULL, self->key, self->nonce, 1); CHECK_RESULT(res != 0); res = EVP_CipherUpdate(self->encrypt_ctx, NULL, &outlen, associated, associated_len); CHECK_RESULT(res != 0); res = EVP_CipherUpdate(self->encrypt_ctx, self->buffer, &outlen, data, data_len); CHECK_RESULT(res != 0); res = EVP_CipherFinal_ex(self->encrypt_ctx, NULL, &outlen2); CHECK_RESULT(res != 0 && outlen2 == 0); res = EVP_CIPHER_CTX_ctrl(self->encrypt_ctx, EVP_CTRL_CCM_GET_TAG, AEAD_TAG_LENGTH, self->buffer + outlen); CHECK_RESULT(res != 0); return PyBytes_FromStringAndSize((const char*)self->buffer, outlen + AEAD_TAG_LENGTH); } static PyMethodDef AEAD_methods[] = { {"decrypt", (PyCFunction)AEAD_decrypt, METH_VARARGS, ""}, {"encrypt", (PyCFunction)AEAD_encrypt, METH_VARARGS, ""}, {NULL} }; static PyType_Slot AEADType_slots[] = { {Py_tp_dealloc, AEAD_dealloc}, {Py_tp_methods, AEAD_methods}, {Py_tp_doc, "AEAD objects"}, {Py_tp_init, AEAD_init}, {0, 0}, }; static PyType_Spec AEADType_spec = { MODULE_NAME ".AEADType", sizeof(AEADObject), 0, Py_TPFLAGS_DEFAULT, AEADType_slots }; /* HeaderProtection */ typedef struct { PyObject_HEAD EVP_CIPHER_CTX *ctx; int is_chacha20; unsigned char buffer[PACKET_LENGTH_MAX]; unsigned char mask[31]; unsigned char zero[5]; 
} HeaderProtectionObject; static PyObject *HeaderProtectionType; static int HeaderProtection_init(HeaderProtectionObject *self, PyObject *args, PyObject *kwargs) { const char *cipher_name; const unsigned char *key; Py_ssize_t cipher_name_len, key_len; int res; if (!PyArg_ParseTuple(args, "y#y#", &cipher_name, &cipher_name_len, &key, &key_len)) return -1; const EVP_CIPHER *evp_cipher = EVP_get_cipherbyname(cipher_name); if (evp_cipher == 0) { PyErr_Format(CryptoError, "Invalid cipher name: %s", cipher_name); return -1; } memset(self->mask, 0, sizeof(self->mask)); memset(self->zero, 0, sizeof(self->zero)); self->is_chacha20 = cipher_name_len == 8 && memcmp(cipher_name, "chacha20", 8) == 0; self->ctx = EVP_CIPHER_CTX_new(); CHECK_RESULT_CTOR(self->ctx != 0); res = EVP_CipherInit_ex(self->ctx, evp_cipher, NULL, NULL, NULL, 1); CHECK_RESULT_CTOR(res != 0); res = EVP_CIPHER_CTX_set_key_length(self->ctx, key_len); CHECK_RESULT_CTOR(res != 0); res = EVP_CipherInit_ex(self->ctx, NULL, NULL, key, NULL, 1); CHECK_RESULT_CTOR(res != 0); return 0; } static void HeaderProtection_dealloc(HeaderProtectionObject *self) { EVP_CIPHER_CTX_free(self->ctx); PyTypeObject *tp = Py_TYPE(self); freefunc free = PyType_GetSlot(tp, Py_tp_free); free(self); Py_DECREF(tp); } static int HeaderProtection_mask(HeaderProtectionObject *self, const unsigned char* sample) { int outlen; if (self->is_chacha20) { return EVP_CipherInit_ex(self->ctx, NULL, NULL, NULL, sample, 1) && EVP_CipherUpdate(self->ctx, self->mask, &outlen, self->zero, sizeof(self->zero)); } else { return EVP_CipherUpdate(self->ctx, self->mask, &outlen, sample, SAMPLE_LENGTH); } } static PyObject* HeaderProtection_apply(HeaderProtectionObject *self, PyObject *args) { const unsigned char *header, *payload; Py_ssize_t header_len, payload_len; int res; if (!PyArg_ParseTuple(args, "y#y#", &header, &header_len, &payload, &payload_len)) return NULL; int pn_length = (header[0] & 0x03) + 1; int pn_offset = header_len - pn_length; res = HeaderProtection_mask(self, payload + PACKET_NUMBER_LENGTH_MAX - pn_length); CHECK_RESULT(res != 0); memcpy(self->buffer, header, header_len); memcpy(self->buffer + header_len, payload, payload_len); if (self->buffer[0] & 0x80) { self->buffer[0] ^= self->mask[0] & 0x0F; } else { self->buffer[0] ^= self->mask[0] & 0x1F; } for (int i = 0; i < pn_length; ++i) { self->buffer[pn_offset + i] ^= self->mask[1 + i]; } return PyBytes_FromStringAndSize((const char*)self->buffer, header_len + payload_len); } static PyObject* HeaderProtection_remove(HeaderProtectionObject *self, PyObject *args) { const unsigned char *packet; Py_ssize_t packet_len; int pn_offset, res; if (!PyArg_ParseTuple(args, "y#I", &packet, &packet_len, &pn_offset)) return NULL; res = HeaderProtection_mask(self, packet + pn_offset + PACKET_NUMBER_LENGTH_MAX); CHECK_RESULT(res != 0); memcpy(self->buffer, packet, pn_offset + PACKET_NUMBER_LENGTH_MAX); if (self->buffer[0] & 0x80) { self->buffer[0] ^= self->mask[0] & 0x0F; } else { self->buffer[0] ^= self->mask[0] & 0x1F; } int pn_length = (self->buffer[0] & 0x03) + 1; uint32_t pn_truncated = 0; for (int i = 0; i < pn_length; ++i) { self->buffer[pn_offset + i] ^= self->mask[1 + i]; pn_truncated = self->buffer[pn_offset + i] | (pn_truncated << 8); } return Py_BuildValue("y#i", self->buffer, pn_offset + pn_length, pn_truncated); } static PyMethodDef HeaderProtection_methods[] = { {"apply", (PyCFunction)HeaderProtection_apply, METH_VARARGS, ""}, {"remove", (PyCFunction)HeaderProtection_remove, METH_VARARGS, ""}, {NULL} }; static 
PyType_Slot HeaderProtectionType_slots[] = { {Py_tp_dealloc, HeaderProtection_dealloc}, {Py_tp_methods, HeaderProtection_methods}, {Py_tp_doc, "HeaderProtection objects"}, {Py_tp_init, HeaderProtection_init}, {0, 0}, }; static PyType_Spec HeaderProtectionType_spec = { MODULE_NAME ".HeaderProtectionType", sizeof(HeaderProtectionObject), 0, Py_TPFLAGS_DEFAULT, HeaderProtectionType_slots }; static struct PyModuleDef moduledef = { PyModuleDef_HEAD_INIT, MODULE_NAME, /* m_name */ "Cryptography utilities.", /* m_doc */ -1, /* m_size */ NULL, /* m_methods */ NULL, /* m_reload */ NULL, /* m_traverse */ NULL, /* m_clear */ NULL, /* m_free */ }; PyMODINIT_FUNC PyInit__crypto(void) { PyObject* m; m = PyModule_Create(&moduledef); if (m == NULL) return NULL; CryptoError = PyErr_NewException(MODULE_NAME ".CryptoError", PyExc_ValueError, NULL); Py_INCREF(CryptoError); PyModule_AddObject(m, "CryptoError", CryptoError); AEADType = PyType_FromSpec(&AEADType_spec); if (AEADType == NULL) return NULL; PyModule_AddObject(m, "AEAD", AEADType); HeaderProtectionType = PyType_FromSpec(&HeaderProtectionType_spec); if (HeaderProtectionType == NULL) return NULL; PyModule_AddObject(m, "HeaderProtection", HeaderProtectionType); // ensure required ciphers are initialised EVP_add_cipher(EVP_aes_128_ecb()); EVP_add_cipher(EVP_aes_128_gcm()); EVP_add_cipher(EVP_aes_256_ecb()); EVP_add_cipher(EVP_aes_256_gcm()); return m; } ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1720306884.0 aioquic-1.2.0/src/aioquic/_crypto.pyi0000644000175100001770000000114200000000000020373 0ustar00runnerdocker00000000000000from typing import Tuple class AEAD: def __init__(self, cipher_name: bytes, key: bytes, iv: bytes): ... def decrypt( self, data: bytes, associated_data: bytes, packet_number: int ) -> bytes: ... def encrypt( self, data: bytes, associated_data: bytes, packet_number: int ) -> bytes: ... class CryptoError(ValueError): ... class HeaderProtection: def __init__(self, cipher_name: bytes, key: bytes): ... def apply(self, plain_header: bytes, protected_payload: bytes) -> bytes: ... def remove(self, packet: bytes, encrypted_offset: int) -> Tuple[bytes, int]: ... 
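# A minimal usage sketch for the AEAD interface declared in the stub above;
# the cipher name, key, IV and packet number below are illustrative values
# (AES-128-GCM takes a 16-byte key and a 12-byte IV, and packet number 0
# leaves the IV unchanged when the nonce is derived).
from aioquic._crypto import AEAD

aead = AEAD(b"aes-128-gcm", b"\x00" * 16, b"\x00" * 12)
sealed = aead.encrypt(b"payload", b"associated data", 0)  # ciphertext plus 16-byte tag
assert aead.decrypt(sealed, b"associated data", 0) == b"payload"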
././@PaxHeader0000000000000000000000000000003400000000000011452 xustar000000000000000028 mtime=1720306888.1252942 aioquic-1.2.0/src/aioquic/asyncio/0000755000175100001770000000000000000000000017640 5ustar00runnerdocker00000000000000././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1720306884.0 aioquic-1.2.0/src/aioquic/asyncio/__init__.py0000644000175100001770000000017300000000000021752 0ustar00runnerdocker00000000000000from .client import connect # noqa from .protocol import QuicConnectionProtocol # noqa from .server import serve # noqa ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1720306884.0 aioquic-1.2.0/src/aioquic/asyncio/client.py0000644000175100001770000000731700000000000021500 0ustar00runnerdocker00000000000000import asyncio import socket from contextlib import asynccontextmanager from typing import AsyncGenerator, Callable, Optional, cast from ..quic.configuration import QuicConfiguration from ..quic.connection import QuicConnection, QuicTokenHandler from ..tls import SessionTicketHandler from .protocol import QuicConnectionProtocol, QuicStreamHandler __all__ = ["connect"] @asynccontextmanager async def connect( host: str, port: int, *, configuration: Optional[QuicConfiguration] = None, create_protocol: Optional[Callable] = QuicConnectionProtocol, session_ticket_handler: Optional[SessionTicketHandler] = None, stream_handler: Optional[QuicStreamHandler] = None, token_handler: Optional[QuicTokenHandler] = None, wait_connected: bool = True, local_port: int = 0, ) -> AsyncGenerator[QuicConnectionProtocol, None]: """ Connect to a QUIC server at the given `host` and `port`. :meth:`connect()` returns an awaitable. Awaiting it yields a :class:`~aioquic.asyncio.QuicConnectionProtocol` which can be used to create streams. :func:`connect` also accepts the following optional arguments: * ``configuration`` is a :class:`~aioquic.quic.configuration.QuicConfiguration` configuration object. * ``create_protocol`` allows customizing the :class:`~asyncio.Protocol` that manages the connection. It should be a callable or class accepting the same arguments as :class:`~aioquic.asyncio.QuicConnectionProtocol` and returning an instance of :class:`~aioquic.asyncio.QuicConnectionProtocol` or a subclass. * ``session_ticket_handler`` is a callback which is invoked by the TLS engine when a new session ticket is received. * ``stream_handler`` is a callback which is invoked whenever a stream is created. It must accept two arguments: a :class:`asyncio.StreamReader` and a :class:`asyncio.StreamWriter`. * ``wait_connected`` indicates whether the context manager should wait for the connection to be established before yielding the :class:`~aioquic.asyncio.QuicConnectionProtocol`. By default this is `True` but you can set it to `False` if you want to immediately start sending data using 0-RTT. * ``local_port`` is the UDP port number that this client wants to bind. 
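    Usage example (a sketch: the host, port and ALPN protocol are
    illustrative, and the server's certificate must be acceptable to the
    configuration)::

        configuration = QuicConfiguration(alpn_protocols=["hq-interop"])
        async with connect("localhost", 4433, configuration=configuration) as protocol:
            await protocol.ping()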
""" loop = asyncio.get_event_loop() local_host = "::" # lookup remote address infos = await loop.getaddrinfo(host, port, type=socket.SOCK_DGRAM) addr = infos[0][4] if len(addr) == 2: addr = ("::ffff:" + addr[0], addr[1], 0, 0) # prepare QUIC connection if configuration is None: configuration = QuicConfiguration(is_client=True) if configuration.server_name is None: configuration.server_name = host connection = QuicConnection( configuration=configuration, session_ticket_handler=session_ticket_handler, token_handler=token_handler, ) # explicitly enable IPv4/IPv6 dual stack sock = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM) completed = False try: sock.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 0) sock.bind((local_host, local_port, 0, 0)) completed = True finally: if not completed: sock.close() # connect transport, protocol = await loop.create_datagram_endpoint( lambda: create_protocol(connection, stream_handler=stream_handler), sock=sock, ) protocol = cast(QuicConnectionProtocol, protocol) try: protocol.connect(addr, transmit=wait_connected) if wait_connected: await protocol.wait_connected() yield protocol finally: protocol.close() await protocol.wait_closed() transport.close() ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1720306884.0 aioquic-1.2.0/src/aioquic/asyncio/protocol.py0000644000175100001770000002316300000000000022060 0ustar00runnerdocker00000000000000import asyncio from typing import Any, Callable, Dict, Optional, Text, Tuple, Union, cast from ..quic import events from ..quic.connection import NetworkAddress, QuicConnection from ..quic.packet import QuicErrorCode QuicConnectionIdHandler = Callable[[bytes], None] QuicStreamHandler = Callable[[asyncio.StreamReader, asyncio.StreamWriter], None] class QuicConnectionProtocol(asyncio.DatagramProtocol): def __init__( self, quic: QuicConnection, stream_handler: Optional[QuicStreamHandler] = None ): loop = asyncio.get_event_loop() self._closed = asyncio.Event() self._connected = False self._connected_waiter: Optional[asyncio.Future[None]] = None self._loop = loop self._ping_waiters: Dict[int, asyncio.Future[None]] = {} self._quic = quic self._stream_readers: Dict[int, asyncio.StreamReader] = {} self._timer: Optional[asyncio.TimerHandle] = None self._timer_at: Optional[float] = None self._transmit_task: Optional[asyncio.Handle] = None self._transport: Optional[asyncio.DatagramTransport] = None # callbacks self._connection_id_issued_handler: QuicConnectionIdHandler = lambda c: None self._connection_id_retired_handler: QuicConnectionIdHandler = lambda c: None self._connection_terminated_handler: Callable[[], None] = lambda: None if stream_handler is not None: self._stream_handler = stream_handler else: self._stream_handler = lambda r, w: None def change_connection_id(self) -> None: """ Change the connection ID used to communicate with the peer. The previous connection ID will be retired. """ self._quic.change_connection_id() self.transmit() def close( self, error_code: int = QuicErrorCode.NO_ERROR, reason_phrase: str = "", ) -> None: """ Close the connection. :param error_code: An error code indicating why the connection is being closed. :param reason_phrase: A human-readable explanation of why the connection is being closed. """ self._quic.close( error_code=error_code, reason_phrase=reason_phrase, ) self.transmit() def connect(self, addr: NetworkAddress, transmit=True) -> None: """ Initiate the TLS handshake. This method can only be called for clients and a single time. 
""" self._quic.connect(addr, now=self._loop.time()) if transmit: self.transmit() async def create_stream( self, is_unidirectional: bool = False ) -> Tuple[asyncio.StreamReader, asyncio.StreamWriter]: """ Create a QUIC stream and return a pair of (reader, writer) objects. The returned reader and writer objects are instances of :class:`asyncio.StreamReader` and :class:`asyncio.StreamWriter` classes. """ stream_id = self._quic.get_next_available_stream_id( is_unidirectional=is_unidirectional ) return self._create_stream(stream_id) def request_key_update(self) -> None: """ Request an update of the encryption keys. """ self._quic.request_key_update() self.transmit() async def ping(self) -> None: """ Ping the peer and wait for the response. """ waiter = self._loop.create_future() uid = id(waiter) self._ping_waiters[uid] = waiter self._quic.send_ping(uid) self.transmit() await asyncio.shield(waiter) def transmit(self) -> None: """ Send pending datagrams to the peer and arm the timer if needed. This method is called automatically when data is received from the peer or when a timer goes off. If you interact directly with the underlying :class:`~aioquic.quic.connection.QuicConnection`, make sure you call this method whenever data needs to be sent out to the network. """ self._transmit_task = None # send datagrams for data, addr in self._quic.datagrams_to_send(now=self._loop.time()): self._transport.sendto(data, addr) # re-arm timer timer_at = self._quic.get_timer() if self._timer is not None and self._timer_at != timer_at: self._timer.cancel() self._timer = None if self._timer is None and timer_at is not None: self._timer = self._loop.call_at(timer_at, self._handle_timer) self._timer_at = timer_at async def wait_closed(self) -> None: """ Wait for the connection to be closed. """ await self._closed.wait() async def wait_connected(self) -> None: """ Wait for the TLS handshake to complete. """ assert self._connected_waiter is None, "already awaiting connected" if not self._connected: self._connected_waiter = self._loop.create_future() await asyncio.shield(self._connected_waiter) # asyncio.Transport def connection_made(self, transport: asyncio.BaseTransport) -> None: """:meta private:""" self._transport = cast(asyncio.DatagramTransport, transport) def datagram_received(self, data: Union[bytes, Text], addr: NetworkAddress) -> None: """:meta private:""" self._quic.receive_datagram(cast(bytes, data), addr, now=self._loop.time()) self._process_events() self.transmit() # overridable def quic_event_received(self, event: events.QuicEvent) -> None: """ Called when a QUIC event is received. Reimplement this in your subclass to handle the events. 
""" # FIXME: move this to a subclass if isinstance(event, events.ConnectionTerminated): for reader in self._stream_readers.values(): reader.feed_eof() elif isinstance(event, events.StreamDataReceived): reader = self._stream_readers.get(event.stream_id, None) if reader is None: reader, writer = self._create_stream(event.stream_id) self._stream_handler(reader, writer) reader.feed_data(event.data) if event.end_stream: reader.feed_eof() # private def _create_stream( self, stream_id: int ) -> Tuple[asyncio.StreamReader, asyncio.StreamWriter]: adapter = QuicStreamAdapter(self, stream_id) reader = asyncio.StreamReader() protocol = asyncio.streams.StreamReaderProtocol(reader) writer = asyncio.StreamWriter(adapter, protocol, reader, self._loop) self._stream_readers[stream_id] = reader return reader, writer def _handle_timer(self) -> None: now = max(self._timer_at, self._loop.time()) self._timer = None self._timer_at = None self._quic.handle_timer(now=now) self._process_events() self.transmit() def _process_events(self) -> None: event = self._quic.next_event() while event is not None: if isinstance(event, events.ConnectionIdIssued): self._connection_id_issued_handler(event.connection_id) elif isinstance(event, events.ConnectionIdRetired): self._connection_id_retired_handler(event.connection_id) elif isinstance(event, events.ConnectionTerminated): self._connection_terminated_handler() # abort connection waiter if self._connected_waiter is not None: waiter = self._connected_waiter self._connected_waiter = None waiter.set_exception(ConnectionError) # abort ping waiters for waiter in self._ping_waiters.values(): waiter.set_exception(ConnectionError) self._ping_waiters.clear() self._closed.set() elif isinstance(event, events.HandshakeCompleted): if self._connected_waiter is not None: waiter = self._connected_waiter self._connected = True self._connected_waiter = None waiter.set_result(None) elif isinstance(event, events.PingAcknowledged): waiter = self._ping_waiters.pop(event.uid, None) if waiter is not None: waiter.set_result(None) self.quic_event_received(event) event = self._quic.next_event() def _transmit_soon(self) -> None: if self._transmit_task is None: self._transmit_task = self._loop.call_soon(self.transmit) class QuicStreamAdapter(asyncio.Transport): def __init__(self, protocol: QuicConnectionProtocol, stream_id: int): self.protocol = protocol self.stream_id = stream_id self._closing = False def can_write_eof(self) -> bool: return True def get_extra_info(self, name: str, default: Any = None) -> Any: """ Get information about the underlying QUIC stream. 
""" if name == "stream_id": return self.stream_id def write(self, data): self.protocol._quic.send_stream_data(self.stream_id, data) self.protocol._transmit_soon() def write_eof(self): if self._closing: return self._closing = True self.protocol._quic.send_stream_data(self.stream_id, b"", end_stream=True) self.protocol._transmit_soon() def close(self): self.write_eof() def is_closing(self) -> bool: return self._closing ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1720306884.0 aioquic-1.2.0/src/aioquic/asyncio/server.py0000644000175100001770000002036000000000000021521 0ustar00runnerdocker00000000000000import asyncio import os from functools import partial from typing import Callable, Dict, Optional, Text, Union, cast from ..buffer import Buffer from ..quic.configuration import SMALLEST_MAX_DATAGRAM_SIZE, QuicConfiguration from ..quic.connection import NetworkAddress, QuicConnection from ..quic.packet import ( QuicPacketType, encode_quic_retry, encode_quic_version_negotiation, pull_quic_header, ) from ..quic.retry import QuicRetryTokenHandler from ..tls import SessionTicketFetcher, SessionTicketHandler from .protocol import QuicConnectionProtocol, QuicStreamHandler __all__ = ["serve"] class QuicServer(asyncio.DatagramProtocol): def __init__( self, *, configuration: QuicConfiguration, create_protocol: Callable = QuicConnectionProtocol, session_ticket_fetcher: Optional[SessionTicketFetcher] = None, session_ticket_handler: Optional[SessionTicketHandler] = None, retry: bool = False, stream_handler: Optional[QuicStreamHandler] = None, ) -> None: self._configuration = configuration self._create_protocol = create_protocol self._loop = asyncio.get_event_loop() self._protocols: Dict[bytes, QuicConnectionProtocol] = {} self._session_ticket_fetcher = session_ticket_fetcher self._session_ticket_handler = session_ticket_handler self._transport: Optional[asyncio.DatagramTransport] = None self._stream_handler = stream_handler if retry: self._retry = QuicRetryTokenHandler() else: self._retry = None def close(self): for protocol in set(self._protocols.values()): protocol.close() self._protocols.clear() self._transport.close() def connection_made(self, transport: asyncio.BaseTransport) -> None: self._transport = cast(asyncio.DatagramTransport, transport) def datagram_received(self, data: Union[bytes, Text], addr: NetworkAddress) -> None: data = cast(bytes, data) buf = Buffer(data=data) try: header = pull_quic_header( buf, host_cid_length=self._configuration.connection_id_length ) except ValueError: return # version negotiation if ( header.version is not None and header.version not in self._configuration.supported_versions ): self._transport.sendto( encode_quic_version_negotiation( source_cid=header.destination_cid, destination_cid=header.source_cid, supported_versions=self._configuration.supported_versions, ), addr, ) return protocol = self._protocols.get(header.destination_cid, None) original_destination_connection_id: Optional[bytes] = None retry_source_connection_id: Optional[bytes] = None if ( protocol is None and len(data) >= SMALLEST_MAX_DATAGRAM_SIZE and header.packet_type == QuicPacketType.INITIAL ): # retry if self._retry is not None: if not header.token: # create a retry token source_cid = os.urandom(8) self._transport.sendto( encode_quic_retry( version=header.version, source_cid=source_cid, destination_cid=header.source_cid, original_destination_cid=header.destination_cid, retry_token=self._retry.create_token( addr, header.destination_cid, source_cid 
), ), addr, ) return else: # validate retry token try: ( original_destination_connection_id, retry_source_connection_id, ) = self._retry.validate_token(addr, header.token) except ValueError: return else: original_destination_connection_id = header.destination_cid # create new connection connection = QuicConnection( configuration=self._configuration, original_destination_connection_id=original_destination_connection_id, retry_source_connection_id=retry_source_connection_id, session_ticket_fetcher=self._session_ticket_fetcher, session_ticket_handler=self._session_ticket_handler, ) protocol = self._create_protocol( connection, stream_handler=self._stream_handler ) protocol.connection_made(self._transport) # register callbacks protocol._connection_id_issued_handler = partial( self._connection_id_issued, protocol=protocol ) protocol._connection_id_retired_handler = partial( self._connection_id_retired, protocol=protocol ) protocol._connection_terminated_handler = partial( self._connection_terminated, protocol=protocol ) self._protocols[header.destination_cid] = protocol self._protocols[connection.host_cid] = protocol if protocol is not None: protocol.datagram_received(data, addr) def _connection_id_issued(self, cid: bytes, protocol: QuicConnectionProtocol): self._protocols[cid] = protocol def _connection_id_retired( self, cid: bytes, protocol: QuicConnectionProtocol ) -> None: assert self._protocols[cid] == protocol del self._protocols[cid] def _connection_terminated(self, protocol: QuicConnectionProtocol): for cid, proto in list(self._protocols.items()): if proto == protocol: del self._protocols[cid] async def serve( host: str, port: int, *, configuration: QuicConfiguration, create_protocol: Callable = QuicConnectionProtocol, session_ticket_fetcher: Optional[SessionTicketFetcher] = None, session_ticket_handler: Optional[SessionTicketHandler] = None, retry: bool = False, stream_handler: QuicStreamHandler = None, ) -> QuicServer: """ Start a QUIC server at the given `host` and `port`. :func:`serve` requires a :class:`~aioquic.quic.configuration.QuicConfiguration` containing TLS certificate and private key as the ``configuration`` argument. :func:`serve` also accepts the following optional arguments: * ``create_protocol`` allows customizing the :class:`~asyncio.Protocol` that manages the connection. It should be a callable or class accepting the same arguments as :class:`~aioquic.asyncio.QuicConnectionProtocol` and returning an instance of :class:`~aioquic.asyncio.QuicConnectionProtocol` or a subclass. * ``session_ticket_fetcher`` is a callback which is invoked by the TLS engine when a session ticket is presented by the peer. It should return the session ticket with the specified ID or `None` if it is not found. * ``session_ticket_handler`` is a callback which is invoked by the TLS engine when a new session ticket is issued. It should store the session ticket for future lookup. * ``retry`` specifies whether client addresses should be validated prior to the cryptographic handshake using a retry packet. * ``stream_handler`` is a callback which is invoked whenever a stream is created. It must accept two arguments: a :class:`asyncio.StreamReader` and a :class:`asyncio.StreamWriter`. 
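    Usage example (a sketch; the certificate and key paths are illustrative)::

        configuration = QuicConfiguration(is_client=False)
        configuration.load_cert_chain("ssl_cert.pem", "ssl_key.pem")
        server = await serve("::", 4433, configuration=configuration)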
""" loop = asyncio.get_event_loop() _, protocol = await loop.create_datagram_endpoint( lambda: QuicServer( configuration=configuration, create_protocol=create_protocol, session_ticket_fetcher=session_ticket_fetcher, session_ticket_handler=session_ticket_handler, retry=retry, stream_handler=stream_handler, ), local_addr=(host, port), ) return protocol ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1720306884.0 aioquic-1.2.0/src/aioquic/buffer.py0000644000175100001770000000140200000000000020013 0ustar00runnerdocker00000000000000from ._buffer import Buffer, BufferReadError, BufferWriteError # noqa UINT_VAR_MAX = 0x3FFFFFFFFFFFFFFF UINT_VAR_MAX_SIZE = 8 def encode_uint_var(value: int) -> bytes: """ Encode a variable-length unsigned integer. """ buf = Buffer(capacity=UINT_VAR_MAX_SIZE) buf.push_uint_var(value) return buf.data def size_uint_var(value: int) -> int: """ Return the number of bytes required to encode the given value as a QUIC variable-length unsigned integer. """ if value <= 0x3F: return 1 elif value <= 0x3FFF: return 2 elif value <= 0x3FFFFFFF: return 4 elif value <= 0x3FFFFFFFFFFFFFFF: return 8 else: raise ValueError("Integer is too big for a variable-length integer") ././@PaxHeader0000000000000000000000000000003400000000000011452 xustar000000000000000028 mtime=1720306888.1252942 aioquic-1.2.0/src/aioquic/h0/0000755000175100001770000000000000000000000016502 5ustar00runnerdocker00000000000000././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1720306884.0 aioquic-1.2.0/src/aioquic/h0/__init__.py0000644000175100001770000000000000000000000020601 0ustar00runnerdocker00000000000000././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1720306884.0 aioquic-1.2.0/src/aioquic/h0/connection.py0000644000175100001770000000477600000000000021231 0ustar00runnerdocker00000000000000from typing import Dict, List from aioquic.h3.events import DataReceived, H3Event, Headers, HeadersReceived from aioquic.quic.connection import QuicConnection from aioquic.quic.events import QuicEvent, StreamDataReceived H0_ALPN = ["hq-interop"] class H0Connection: """ An HTTP/0.9 connection object. 
""" def __init__(self, quic: QuicConnection): self._buffer: Dict[int, bytes] = {} self._headers_received: Dict[int, bool] = {} self._is_client = quic.configuration.is_client self._quic = quic def handle_event(self, event: QuicEvent) -> List[H3Event]: http_events: List[H3Event] = [] if isinstance(event, StreamDataReceived) and (event.stream_id % 4) == 0: data = self._buffer.pop(event.stream_id, b"") + event.data if not self._headers_received.get(event.stream_id, False): if self._is_client: http_events.append( HeadersReceived( headers=[], stream_ended=False, stream_id=event.stream_id ) ) elif data.endswith(b"\r\n") or event.end_stream: method, path = data.rstrip().split(b" ", 1) http_events.append( HeadersReceived( headers=[(b":method", method), (b":path", path)], stream_ended=False, stream_id=event.stream_id, ) ) data = b"" else: # incomplete request, stash the data self._buffer[event.stream_id] = data return http_events self._headers_received[event.stream_id] = True http_events.append( DataReceived( data=data, stream_ended=event.end_stream, stream_id=event.stream_id ) ) return http_events def send_data(self, stream_id: int, data: bytes, end_stream: bool) -> None: self._quic.send_stream_data(stream_id, data, end_stream) def send_headers( self, stream_id: int, headers: Headers, end_stream: bool = False ) -> None: if self._is_client: headers_dict = dict(headers) data = headers_dict[b":method"] + b" " + headers_dict[b":path"] + b"\r\n" else: data = b"" self._quic.send_stream_data(stream_id, data, end_stream) ././@PaxHeader0000000000000000000000000000003400000000000011452 xustar000000000000000028 mtime=1720306888.1252942 aioquic-1.2.0/src/aioquic/h3/0000755000175100001770000000000000000000000016505 5ustar00runnerdocker00000000000000././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1720306884.0 aioquic-1.2.0/src/aioquic/h3/__init__.py0000644000175100001770000000000000000000000020604 0ustar00runnerdocker00000000000000././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1720306884.0 aioquic-1.2.0/src/aioquic/h3/connection.py0000644000175100001770000012434600000000000021230 0ustar00runnerdocker00000000000000import logging import re from enum import Enum, IntEnum from typing import Dict, FrozenSet, List, Optional, Set import pylsqpack from aioquic.buffer import UINT_VAR_MAX_SIZE, Buffer, BufferReadError, encode_uint_var from aioquic.h3.events import ( DatagramReceived, DataReceived, H3Event, Headers, HeadersReceived, PushPromiseReceived, WebTransportStreamDataReceived, ) from aioquic.h3.exceptions import InvalidStreamTypeError, NoAvailablePushIDError from aioquic.quic.connection import QuicConnection, stream_is_unidirectional from aioquic.quic.events import DatagramFrameReceived, QuicEvent, StreamDataReceived from aioquic.quic.logger import QuicLoggerTrace logger = logging.getLogger("http3") H3_ALPN = ["h3"] RESERVED_SETTINGS = (0x0, 0x2, 0x3, 0x4, 0x5) UPPERCASE = re.compile(b"[A-Z]") COLON = 0x3A NUL = 0x00 LF = 0x0A CR = 0x0D SP = 0x20 HTAB = 0x09 WHITESPACE = (SP, HTAB) class ErrorCode(IntEnum): H3_DATAGRAM_ERROR = 0x33 H3_NO_ERROR = 0x100 H3_GENERAL_PROTOCOL_ERROR = 0x101 H3_INTERNAL_ERROR = 0x102 H3_STREAM_CREATION_ERROR = 0x103 H3_CLOSED_CRITICAL_STREAM = 0x104 H3_FRAME_UNEXPECTED = 0x105 H3_FRAME_ERROR = 0x106 H3_EXCESSIVE_LOAD = 0x107 H3_ID_ERROR = 0x108 H3_SETTINGS_ERROR = 0x109 H3_MISSING_SETTINGS = 0x10A H3_REQUEST_REJECTED = 0x10B H3_REQUEST_CANCELLED = 0x10C H3_REQUEST_INCOMPLETE = 0x10D 
H3_MESSAGE_ERROR = 0x10E H3_CONNECT_ERROR = 0x10F H3_VERSION_FALLBACK = 0x110 QPACK_DECOMPRESSION_FAILED = 0x200 QPACK_ENCODER_STREAM_ERROR = 0x201 QPACK_DECODER_STREAM_ERROR = 0x202 class FrameType(IntEnum): DATA = 0x0 HEADERS = 0x1 PRIORITY = 0x2 CANCEL_PUSH = 0x3 SETTINGS = 0x4 PUSH_PROMISE = 0x5 GOAWAY = 0x7 MAX_PUSH_ID = 0xD DUPLICATE_PUSH = 0xE WEBTRANSPORT_STREAM = 0x41 class HeadersState(Enum): INITIAL = 0 AFTER_HEADERS = 1 AFTER_TRAILERS = 2 class Setting(IntEnum): QPACK_MAX_TABLE_CAPACITY = 0x1 MAX_FIELD_SECTION_SIZE = 0x6 QPACK_BLOCKED_STREAMS = 0x7 # https://datatracker.ietf.org/doc/html/rfc9220#section-5 ENABLE_CONNECT_PROTOCOL = 0x8 # https://datatracker.ietf.org/doc/html/rfc9297#section-5.1 H3_DATAGRAM = 0x33 # https://datatracker.ietf.org/doc/html/draft-ietf-webtrans-http3-02#section-8.2 ENABLE_WEBTRANSPORT = 0x2B603742 # Dummy setting to check it is correctly ignored by the peer. # https://datatracker.ietf.org/doc/html/rfc9114#section-7.2.4.1 DUMMY = 0x21 class StreamType(IntEnum): CONTROL = 0 PUSH = 1 QPACK_ENCODER = 2 QPACK_DECODER = 3 WEBTRANSPORT = 0x54 class ProtocolError(Exception): """ Base class for protocol errors. These errors are not exposed to the API user, they are handled in :meth:`H3Connection.handle_event`. """ error_code = ErrorCode.H3_GENERAL_PROTOCOL_ERROR def __init__(self, reason_phrase: str = ""): self.reason_phrase = reason_phrase class QpackDecompressionFailed(ProtocolError): error_code = ErrorCode.QPACK_DECOMPRESSION_FAILED class QpackDecoderStreamError(ProtocolError): error_code = ErrorCode.QPACK_DECODER_STREAM_ERROR class QpackEncoderStreamError(ProtocolError): error_code = ErrorCode.QPACK_ENCODER_STREAM_ERROR class ClosedCriticalStream(ProtocolError): error_code = ErrorCode.H3_CLOSED_CRITICAL_STREAM class DatagramError(ProtocolError): error_code = ErrorCode.H3_DATAGRAM_ERROR class FrameUnexpected(ProtocolError): error_code = ErrorCode.H3_FRAME_UNEXPECTED class MessageError(ProtocolError): error_code = ErrorCode.H3_MESSAGE_ERROR class MissingSettingsError(ProtocolError): error_code = ErrorCode.H3_MISSING_SETTINGS class SettingsError(ProtocolError): error_code = ErrorCode.H3_SETTINGS_ERROR class StreamCreationError(ProtocolError): error_code = ErrorCode.H3_STREAM_CREATION_ERROR def encode_frame(frame_type: int, frame_data: bytes) -> bytes: frame_length = len(frame_data) buf = Buffer(capacity=frame_length + 2 * UINT_VAR_MAX_SIZE) buf.push_uint_var(frame_type) buf.push_uint_var(frame_length) buf.push_bytes(frame_data) return buf.data def encode_settings(settings: Dict[int, int]) -> bytes: buf = Buffer(capacity=1024) for setting, value in settings.items(): buf.push_uint_var(setting) buf.push_uint_var(value) return buf.data def parse_max_push_id(data: bytes) -> int: buf = Buffer(data=data) max_push_id = buf.pull_uint_var() assert buf.eof() return max_push_id def parse_settings(data: bytes) -> Dict[int, int]: buf = Buffer(data=data) settings: Dict[int, int] = {} while not buf.eof(): setting = buf.pull_uint_var() value = buf.pull_uint_var() if setting in RESERVED_SETTINGS: raise SettingsError("Setting identifier 0x%x is reserved" % setting) if setting in settings: raise SettingsError("Setting identifier 0x%x is included twice" % setting) settings[setting] = value return dict(settings) def stream_is_request_response(stream_id: int): """ Returns True if the stream is a client-initiated bidirectional stream. """ return stream_id % 4 == 0 def validate_header_name(key: bytes) -> None: """ Validate a header name as specified by RFC 9113 section 8.2.1.
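    For example::

        validate_header_name(b"content-type")  # passes
        validate_header_name(b"Content-Type")  # raises MessageError (uppercase)
        validate_header_name(b"bad:name")      # raises MessageError (non-initial colon)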
""" for i, c in enumerate(key): if c <= 0x20 or (c >= 0x41 and c <= 0x5A) or c >= 0x7F: raise MessageError("Header %r contains invalid characters" % key) if c == COLON and i != 0: # Colon not at start, definitely bad. Keys starting with a colon # will be checked in pseudo-header validation code. raise MessageError("Header %r contains a non-initial colon" % key) def validate_header_value(key: bytes, value: bytes): """ Validate a header value as specified by RFC 9113 section 8.2.1. """ for c in value: if c == NUL or c == LF or c == CR: raise MessageError("Header %r value has forbidden characters" % key) if len(value) > 0: first = value[0] if first in WHITESPACE: raise MessageError("Header %r value starts with whitespace" % key) if len(value) > 1: last = value[-1] if last in WHITESPACE: raise MessageError("Header %r value ends with whitespace" % key) def validate_headers( headers: Headers, allowed_pseudo_headers: FrozenSet[bytes], required_pseudo_headers: FrozenSet[bytes], stream: Optional["H3Stream"] = None, ) -> None: after_pseudo_headers = False authority: Optional[bytes] = None path: Optional[bytes] = None scheme: Optional[bytes] = None seen_pseudo_headers: Set[bytes] = set() for key, value in headers: validate_header_name(key) validate_header_value(key, value) if key.startswith(b":"): # pseudo-headers if after_pseudo_headers: raise MessageError( "Pseudo-header %r is not allowed after regular headers" % key ) if key not in allowed_pseudo_headers: raise MessageError("Pseudo-header %r is not valid" % key) if key in seen_pseudo_headers: raise MessageError("Pseudo-header %r is included twice" % key) seen_pseudo_headers.add(key) # store value if key == b":authority": authority = value elif key == b":path": path = value elif key == b":scheme": scheme = value else: # regular headers after_pseudo_headers = True # a few more semantic checks if key == b"content-length": try: content_length = int(value) if content_length < 0: raise ValueError except ValueError: raise MessageError("content-length is not a non-negative integer") if stream: stream.expected_content_length = content_length elif key == b"transfer-encoding" and value != b"trailers": raise MessageError( "The only valid value for transfer-encoding is trailers" ) # check required pseudo-headers are present missing = required_pseudo_headers.difference(seen_pseudo_headers) if missing: raise MessageError("Pseudo-headers %s are missing" % sorted(missing)) if scheme in (b"http", b"https"): if not authority: raise MessageError("Pseudo-header b':authority' cannot be empty") if not path: raise MessageError("Pseudo-header b':path' cannot be empty") def validate_push_promise_headers(headers: Headers) -> None: validate_headers( headers, allowed_pseudo_headers=frozenset( (b":method", b":scheme", b":authority", b":path") ), required_pseudo_headers=frozenset( (b":method", b":scheme", b":authority", b":path") ), ) def validate_request_headers( headers: Headers, stream: Optional["H3Stream"] = None ) -> None: validate_headers( headers, allowed_pseudo_headers=frozenset( # FIXME: The pseudo-header :protocol is not actually defined, but # we use it for the WebSocket demo. 
(b":method", b":scheme", b":authority", b":path", b":protocol") ), required_pseudo_headers=frozenset((b":method", b":authority")), stream=stream, ) def validate_response_headers( headers: Headers, stream: Optional["H3Stream"] = None ) -> None: validate_headers( headers, allowed_pseudo_headers=frozenset((b":status",)), required_pseudo_headers=frozenset((b":status",)), stream=stream, ) def validate_trailers(headers: Headers) -> None: validate_headers( headers, allowed_pseudo_headers=frozenset(), required_pseudo_headers=frozenset(), ) class H3Stream: def __init__(self, stream_id: int) -> None: self.blocked = False self.blocked_frame_size: Optional[int] = None self.buffer = b"" self.ended = False self.frame_size: Optional[int] = None self.frame_type: Optional[int] = None self.headers_recv_state: HeadersState = HeadersState.INITIAL self.headers_send_state: HeadersState = HeadersState.INITIAL self.push_id: Optional[int] = None self.session_id: Optional[int] = None self.stream_id = stream_id self.stream_type: Optional[int] = None self.expected_content_length: Optional[int] = None self.content_length: int = 0 class H3Connection: """ A low-level HTTP/3 connection object. :param quic: A :class:`~aioquic.quic.connection.QuicConnection` instance. """ def __init__(self, quic: QuicConnection, enable_webtransport: bool = False) -> None: # settings self._max_table_capacity = 4096 self._blocked_streams = 16 self._enable_webtransport = enable_webtransport self._is_client = quic.configuration.is_client self._is_done = False self._quic = quic self._quic_logger: Optional[QuicLoggerTrace] = quic._quic_logger self._decoder = pylsqpack.Decoder( self._max_table_capacity, self._blocked_streams ) self._decoder_bytes_received = 0 self._decoder_bytes_sent = 0 self._encoder = pylsqpack.Encoder() self._encoder_bytes_received = 0 self._encoder_bytes_sent = 0 self._settings_received = False self._stream: Dict[int, H3Stream] = {} self._max_push_id: Optional[int] = 8 if self._is_client else None self._next_push_id: int = 0 self._local_control_stream_id: Optional[int] = None self._local_decoder_stream_id: Optional[int] = None self._local_encoder_stream_id: Optional[int] = None self._peer_control_stream_id: Optional[int] = None self._peer_decoder_stream_id: Optional[int] = None self._peer_encoder_stream_id: Optional[int] = None self._received_settings: Optional[Dict[int, int]] = None self._sent_settings: Optional[Dict[int, int]] = None self._init_connection() def create_webtransport_stream( self, session_id: int, is_unidirectional: bool = False ) -> int: """ Create a WebTransport stream and return the stream ID. .. aioquic_transmit:: :param session_id: The WebTransport session identifier. :param is_unidirectional: Whether to create a unidirectional stream. """ if is_unidirectional: stream_id = self._create_uni_stream(StreamType.WEBTRANSPORT) self._quic.send_stream_data(stream_id, encode_uint_var(session_id)) else: stream_id = self._quic.get_next_available_stream_id() self._log_stream_type( stream_id=stream_id, stream_type=StreamType.WEBTRANSPORT ) self._quic.send_stream_data( stream_id, encode_uint_var(FrameType.WEBTRANSPORT_STREAM) + encode_uint_var(session_id), ) return stream_id def handle_event(self, event: QuicEvent) -> List[H3Event]: """ Handle a QUIC event and return a list of HTTP events. :param event: The QUIC event to handle. 
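        A polling sketch, where ``quic`` is the underlying
        :class:`~aioquic.quic.connection.QuicConnection`::

            h3 = H3Connection(quic)
            quic_event = quic.next_event()
            while quic_event is not None:
                for h3_event in h3.handle_event(quic_event):
                    ...  # DataReceived, HeadersReceived, etc.
                quic_event = quic.next_event()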
""" if not self._is_done: try: if isinstance(event, StreamDataReceived): stream_id = event.stream_id stream = self._get_or_create_stream(stream_id) if stream_is_unidirectional(stream_id): return self._receive_stream_data_uni( stream, event.data, event.end_stream ) else: return self._receive_request_or_push_data( stream, event.data, event.end_stream ) elif isinstance(event, DatagramFrameReceived): return self._receive_datagram(event.data) except ProtocolError as exc: self._is_done = True self._quic.close( error_code=exc.error_code, reason_phrase=exc.reason_phrase ) return [] def send_datagram(self, stream_id: int, data: bytes) -> None: """ Send a datagram for the specified stream. If the stream ID is not a client-initiated bidirectional stream, an :class:`~aioquic.h3.exceptions.InvalidStreamTypeError` exception is raised. .. aioquic_transmit:: :param stream_id: The stream ID. :param data: The HTTP/3 datagram payload. """ # check stream ID is valid if not stream_is_request_response(stream_id): raise InvalidStreamTypeError( "Datagrams can only be sent for client-initiated bidirectional streams" ) self._quic.send_datagram_frame(encode_uint_var(stream_id // 4) + data) def send_push_promise(self, stream_id: int, headers: Headers) -> int: """ Send a push promise related to the specified stream. Returns the stream ID on which headers and data can be sent. If the stream ID is not a client-initiated bidirectional stream, an :class:`~aioquic.h3.exceptions.InvalidStreamTypeError` exception is raised. If there are not available push IDs, an :class:`~aioquic.h3.exceptions.NoAvailablePushIDError` exception is raised. .. aioquic_transmit:: :param stream_id: The stream ID on which to send the data. :param headers: The HTTP request headers for this push. """ assert not self._is_client, "Only servers may send a push promise." # check stream ID is valid if not stream_is_request_response(stream_id): raise InvalidStreamTypeError( "Push promises can only be sent for client-initiated bidirectional " "streams" ) # check a push ID is available if self._max_push_id is None or self._next_push_id >= self._max_push_id: raise NoAvailablePushIDError # send push promise push_id = self._next_push_id self._next_push_id += 1 self._quic.send_stream_data( stream_id, encode_frame( FrameType.PUSH_PROMISE, encode_uint_var(push_id) + self._encode_headers(stream_id, headers), ), ) #  create push stream push_stream_id = self._create_uni_stream(StreamType.PUSH, push_id=push_id) self._quic.send_stream_data(push_stream_id, encode_uint_var(push_id)) return push_stream_id def send_data(self, stream_id: int, data: bytes, end_stream: bool) -> None: """ Send data on the given stream. .. aioquic_transmit:: :param stream_id: The stream ID on which to send the data. :param data: The data to send. :param end_stream: Whether to end the stream. """ # check DATA frame is allowed stream = self._get_or_create_stream(stream_id) if stream.headers_send_state != HeadersState.AFTER_HEADERS: raise FrameUnexpected("DATA frame is not allowed in this state") # log frame if self._quic_logger is not None: self._quic_logger.log_event( category="http", event="frame_created", data=self._quic_logger.encode_http3_data_frame( length=len(data), stream_id=stream_id ), ) self._quic.send_stream_data( stream_id, encode_frame(FrameType.DATA, data), end_stream ) def send_headers( self, stream_id: int, headers: Headers, end_stream: bool = False ) -> None: """ Send headers on the given stream. .. 
aioquic_transmit:: :param stream_id: The stream ID on which to send the headers. :param headers: The HTTP headers to send. :param end_stream: Whether to end the stream. """ # check HEADERS frame is allowed stream = self._get_or_create_stream(stream_id) if stream.headers_send_state == HeadersState.AFTER_TRAILERS: raise FrameUnexpected("HEADERS frame is not allowed in this state") frame_data = self._encode_headers(stream_id, headers) # log frame if self._quic_logger is not None: self._quic_logger.log_event( category="http", event="frame_created", data=self._quic_logger.encode_http3_headers_frame( length=len(frame_data), headers=headers, stream_id=stream_id ), ) # update state and send headers if stream.headers_send_state == HeadersState.INITIAL: stream.headers_send_state = HeadersState.AFTER_HEADERS else: stream.headers_send_state = HeadersState.AFTER_TRAILERS self._quic.send_stream_data( stream_id, encode_frame(FrameType.HEADERS, frame_data), end_stream ) @property def received_settings(self) -> Optional[Dict[int, int]]: """ Return the received SETTINGS frame, or None. """ return self._received_settings @property def sent_settings(self) -> Optional[Dict[int, int]]: """ Return the sent SETTINGS frame, or None. """ return self._sent_settings def _create_uni_stream( self, stream_type: int, push_id: Optional[int] = None ) -> int: """ Create a unidirectional stream of the given type. """ stream_id = self._quic.get_next_available_stream_id(is_unidirectional=True) self._log_stream_type( push_id=push_id, stream_id=stream_id, stream_type=stream_type ) self._quic.send_stream_data(stream_id, encode_uint_var(stream_type)) return stream_id def _decode_headers(self, stream_id: int, frame_data: Optional[bytes]) -> Headers: """ Decode a HEADERS block and send decoder updates on the decoder stream. This is called with frame_data=None when a stream becomes unblocked. """ try: if frame_data is None: decoder, headers = self._decoder.resume_header(stream_id) else: decoder, headers = self._decoder.feed_header(stream_id, frame_data) self._decoder_bytes_sent += len(decoder) self._quic.send_stream_data(self._local_decoder_stream_id, decoder) except pylsqpack.DecompressionFailed as exc: raise QpackDecompressionFailed() from exc return headers def _encode_headers(self, stream_id: int, headers: Headers) -> bytes: """ Encode a HEADERS block and send encoder updates on the encoder stream. """ encoder, frame_data = self._encoder.encode(stream_id, headers) self._encoder_bytes_sent += len(encoder) self._quic.send_stream_data(self._local_encoder_stream_id, encoder) return frame_data def _get_or_create_stream(self, stream_id: int) -> H3Stream: if stream_id not in self._stream: self._stream[stream_id] = H3Stream(stream_id) return self._stream[stream_id] def _get_local_settings(self) -> Dict[int, int]: """ Return the local HTTP/3 settings. """ settings: Dict[int, int] = { Setting.QPACK_MAX_TABLE_CAPACITY: self._max_table_capacity, Setting.QPACK_BLOCKED_STREAMS: self._blocked_streams, Setting.ENABLE_CONNECT_PROTOCOL: 1, Setting.DUMMY: 1, } if self._enable_webtransport: settings[Setting.H3_DATAGRAM] = 1 settings[Setting.ENABLE_WEBTRANSPORT] = 1 return settings def _handle_control_frame(self, frame_type: int, frame_data: bytes) -> None: """ Handle a frame received on the peer's control stream.
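        The first frame on this stream must be SETTINGS; its payload is a
        sequence of (identifier, value) variable-length integer pairs, as
        produced and consumed by :func:`encode_settings` and
        :func:`parse_settings`::

            parse_settings(encode_settings({Setting.H3_DATAGRAM: 1}))  # {51: 1}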
""" if frame_type != FrameType.SETTINGS and not self._settings_received: raise MissingSettingsError if frame_type == FrameType.SETTINGS: if self._settings_received: raise FrameUnexpected("SETTINGS have already been received") settings = parse_settings(frame_data) self._validate_settings(settings) self._received_settings = settings encoder = self._encoder.apply_settings( max_table_capacity=settings.get(Setting.QPACK_MAX_TABLE_CAPACITY, 0), blocked_streams=settings.get(Setting.QPACK_BLOCKED_STREAMS, 0), ) self._quic.send_stream_data(self._local_encoder_stream_id, encoder) self._settings_received = True elif frame_type == FrameType.MAX_PUSH_ID: if self._is_client: raise FrameUnexpected("Servers must not send MAX_PUSH_ID") self._max_push_id = parse_max_push_id(frame_data) elif frame_type in ( FrameType.DATA, FrameType.HEADERS, FrameType.PUSH_PROMISE, FrameType.DUPLICATE_PUSH, ): raise FrameUnexpected("Invalid frame type on control stream") def _check_content_length(self, stream: H3Stream): if ( stream.expected_content_length is not None and stream.content_length != stream.expected_content_length ): raise MessageError("content-length does not match data size") def _handle_request_or_push_frame( self, frame_type: int, frame_data: Optional[bytes], stream: H3Stream, stream_ended: bool, ) -> List[H3Event]: """ Handle a frame received on a request or push stream. """ http_events: List[H3Event] = [] if frame_type == FrameType.DATA: # check DATA frame is allowed if stream.headers_recv_state != HeadersState.AFTER_HEADERS: raise FrameUnexpected("DATA frame is not allowed in this state") if frame_data is not None: stream.content_length += len(frame_data) if stream_ended: self._check_content_length(stream) if stream_ended or frame_data: http_events.append( DataReceived( data=frame_data, push_id=stream.push_id, stream_ended=stream_ended, stream_id=stream.stream_id, ) ) elif frame_type == FrameType.HEADERS: # check HEADERS frame is allowed if stream.headers_recv_state == HeadersState.AFTER_TRAILERS: raise FrameUnexpected("HEADERS frame is not allowed in this state") # try to decode HEADERS, may raise pylsqpack.StreamBlocked headers = self._decode_headers(stream.stream_id, frame_data) # validate headers if stream.headers_recv_state == HeadersState.INITIAL: if self._is_client: validate_response_headers(headers, stream) else: validate_request_headers(headers, stream) else: validate_trailers(headers) # content-length needs checking even when there is no data if stream_ended: self._check_content_length(stream) # log frame if self._quic_logger is not None: self._quic_logger.log_event( category="http", event="frame_parsed", data=self._quic_logger.encode_http3_headers_frame( length=( stream.blocked_frame_size if frame_data is None else len(frame_data) ), headers=headers, stream_id=stream.stream_id, ), ) # update state and emit headers if stream.headers_recv_state == HeadersState.INITIAL: stream.headers_recv_state = HeadersState.AFTER_HEADERS else: stream.headers_recv_state = HeadersState.AFTER_TRAILERS http_events.append( HeadersReceived( headers=headers, push_id=stream.push_id, stream_id=stream.stream_id, stream_ended=stream_ended, ) ) elif frame_type == FrameType.PUSH_PROMISE and stream.push_id is None: if not self._is_client: raise FrameUnexpected("Clients must not send PUSH_PROMISE") frame_buf = Buffer(data=frame_data) push_id = frame_buf.pull_uint_var() headers = self._decode_headers( stream.stream_id, frame_data[frame_buf.tell() :] ) # validate headers validate_push_promise_headers(headers) # log frame if 
self._quic_logger is not None: self._quic_logger.log_event( category="http", event="frame_parsed", data=self._quic_logger.encode_http3_push_promise_frame( length=len(frame_data), headers=headers, push_id=push_id, stream_id=stream.stream_id, ), ) # emit event http_events.append( PushPromiseReceived( headers=headers, push_id=push_id, stream_id=stream.stream_id ) ) elif frame_type in ( FrameType.PRIORITY, FrameType.CANCEL_PUSH, FrameType.SETTINGS, FrameType.PUSH_PROMISE, FrameType.GOAWAY, FrameType.MAX_PUSH_ID, FrameType.DUPLICATE_PUSH, ): raise FrameUnexpected( "Invalid frame type on request stream" if stream.push_id is None else "Invalid frame type on push stream" ) return http_events def _init_connection(self) -> None: # send our settings self._local_control_stream_id = self._create_uni_stream(StreamType.CONTROL) self._sent_settings = self._get_local_settings() self._quic.send_stream_data( self._local_control_stream_id, encode_frame(FrameType.SETTINGS, encode_settings(self._sent_settings)), ) if self._is_client and self._max_push_id is not None: self._quic.send_stream_data( self._local_control_stream_id, encode_frame(FrameType.MAX_PUSH_ID, encode_uint_var(self._max_push_id)), ) # create encoder and decoder streams self._local_encoder_stream_id = self._create_uni_stream( StreamType.QPACK_ENCODER ) self._local_decoder_stream_id = self._create_uni_stream( StreamType.QPACK_DECODER ) def _log_stream_type( self, stream_id: int, stream_type: int, push_id: Optional[int] = None ) -> None: if self._quic_logger is not None: type_name = { 0: "control", 1: "push", 2: "qpack_encoder", 3: "qpack_decoder", 0x54: "webtransport", # NOTE: not standardized yet }.get(stream_type, "unknown") data = {"new": type_name, "stream_id": stream_id} if push_id is not None: data["associated_push_id"] = push_id self._quic_logger.log_event( category="http", event="stream_type_set", data=data, ) def _receive_datagram(self, data: bytes) -> List[H3Event]: """ Handle a datagram. """ buf = Buffer(data=data) try: quarter_stream_id = buf.pull_uint_var() except BufferReadError: raise DatagramError("Could not parse quarter stream ID") return [ DatagramReceived(data=data[buf.tell() :], stream_id=quarter_stream_id * 4) ] def _receive_request_or_push_data( self, stream: H3Stream, data: bytes, stream_ended: bool ) -> List[H3Event]: """ Handle data received on a request or push stream. 
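_receive_datagram() above applies the RFC 9297 mapping: an HTTP/3 datagram is prefixed with a "quarter stream ID", the request stream ID divided by four (request streams are client-initiated bidirectional, hence multiples of four). A self-contained sketch of the decoding step, with illustrative names:

def h3_datagram_stream_id(payload: bytes) -> int:
    # decode the leading variable-length integer...
    first = payload[0]
    length = 1 << (first >> 6)
    quarter_stream_id = first & 0x3F
    for byte in payload[1:length]:
        quarter_stream_id = (quarter_stream_id << 8) | byte
    # ...and multiply by four to recover the request stream ID
    return quarter_stream_id * 4

# quarter stream ID 1 maps to request stream 4
assert h3_datagram_stream_id(b"\x01payload") == 4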
""" http_events: List[H3Event] = [] stream.buffer += data if stream_ended: stream.ended = True if stream.blocked: return http_events # shortcut for WEBTRANSPORT_STREAM frame fragments if ( stream.frame_type == FrameType.WEBTRANSPORT_STREAM and stream.session_id is not None ): http_events.append( WebTransportStreamDataReceived( data=stream.buffer, session_id=stream.session_id, stream_id=stream.stream_id, stream_ended=stream_ended, ) ) stream.buffer = b"" return http_events # shortcut for DATA frame fragments if ( stream.frame_type == FrameType.DATA and stream.frame_size is not None and len(stream.buffer) < stream.frame_size ): stream.content_length += len(stream.buffer) http_events.append( DataReceived( data=stream.buffer, push_id=stream.push_id, stream_id=stream.stream_id, stream_ended=False, ) ) stream.frame_size -= len(stream.buffer) stream.buffer = b"" return http_events # handle lone FIN if stream_ended and not stream.buffer: self._check_content_length(stream) http_events.append( DataReceived( data=b"", push_id=stream.push_id, stream_id=stream.stream_id, stream_ended=True, ) ) return http_events buf = Buffer(data=stream.buffer) consumed = 0 while not buf.eof(): # fetch next frame header if stream.frame_size is None: try: stream.frame_type = buf.pull_uint_var() stream.frame_size = buf.pull_uint_var() except BufferReadError: break consumed = buf.tell() # WEBTRANSPORT_STREAM frames last until the end of the stream if stream.frame_type == FrameType.WEBTRANSPORT_STREAM: stream.session_id = stream.frame_size stream.frame_size = None frame_data = stream.buffer[consumed:] stream.buffer = b"" self._log_stream_type( stream_id=stream.stream_id, stream_type=StreamType.WEBTRANSPORT ) if frame_data or stream_ended: http_events.append( WebTransportStreamDataReceived( data=frame_data, session_id=stream.session_id, stream_id=stream.stream_id, stream_ended=stream_ended, ) ) return http_events # log frame if ( self._quic_logger is not None and stream.frame_type == FrameType.DATA ): self._quic_logger.log_event( category="http", event="frame_parsed", data=self._quic_logger.encode_http3_data_frame( length=stream.frame_size, stream_id=stream.stream_id ), ) # check how much data is available chunk_size = min(stream.frame_size, buf.capacity - consumed) if stream.frame_type != FrameType.DATA and chunk_size < stream.frame_size: break # read available data frame_data = buf.pull_bytes(chunk_size) frame_type = stream.frame_type consumed = buf.tell() # detect end of frame stream.frame_size -= chunk_size if not stream.frame_size: stream.frame_size = None stream.frame_type = None try: http_events.extend( self._handle_request_or_push_frame( frame_type=frame_type, frame_data=frame_data, stream=stream, stream_ended=stream.ended and buf.eof(), ) ) except pylsqpack.StreamBlocked: stream.blocked = True stream.blocked_frame_size = len(frame_data) break # remove processed data from buffer stream.buffer = stream.buffer[consumed:] return http_events def _receive_stream_data_uni( self, stream: H3Stream, data: bytes, stream_ended: bool ) -> List[H3Event]: http_events: List[H3Event] = [] stream.buffer += data if stream_ended: stream.ended = True buf = Buffer(data=stream.buffer) consumed = 0 unblocked_streams: Set[int] = set() while ( stream.stream_type in (StreamType.PUSH, StreamType.CONTROL, StreamType.WEBTRANSPORT) or not buf.eof() ): # fetch stream type for unidirectional streams if stream.stream_type is None: try: stream.stream_type = buf.pull_uint_var() except BufferReadError: break consumed = buf.tell() # check unicity if 
stream.stream_type == StreamType.CONTROL: if self._peer_control_stream_id is not None: raise StreamCreationError("Only one control stream is allowed") self._peer_control_stream_id = stream.stream_id elif stream.stream_type == StreamType.QPACK_DECODER: if self._peer_decoder_stream_id is not None: raise StreamCreationError( "Only one QPACK decoder stream is allowed" ) self._peer_decoder_stream_id = stream.stream_id elif stream.stream_type == StreamType.QPACK_ENCODER: if self._peer_encoder_stream_id is not None: raise StreamCreationError( "Only one QPACK encoder stream is allowed" ) self._peer_encoder_stream_id = stream.stream_id # for PUSH, logging is performed once the push_id is known if stream.stream_type != StreamType.PUSH: self._log_stream_type( stream_id=stream.stream_id, stream_type=stream.stream_type ) if stream.stream_type == StreamType.CONTROL: if stream_ended: raise ClosedCriticalStream("Closing control stream is not allowed") # fetch next frame try: frame_type = buf.pull_uint_var() frame_length = buf.pull_uint_var() frame_data = buf.pull_bytes(frame_length) except BufferReadError: break consumed = buf.tell() self._handle_control_frame(frame_type, frame_data) elif stream.stream_type == StreamType.PUSH: # fetch push id if stream.push_id is None: try: stream.push_id = buf.pull_uint_var() except BufferReadError: break consumed = buf.tell() self._log_stream_type( push_id=stream.push_id, stream_id=stream.stream_id, stream_type=stream.stream_type, ) # remove processed data from buffer stream.buffer = stream.buffer[consumed:] return self._receive_request_or_push_data(stream, b"", stream_ended) elif stream.stream_type == StreamType.WEBTRANSPORT: # fetch session id if stream.session_id is None: try: stream.session_id = buf.pull_uint_var() except BufferReadError: break consumed = buf.tell() frame_data = stream.buffer[consumed:] stream.buffer = b"" if frame_data or stream_ended: http_events.append( WebTransportStreamDataReceived( data=frame_data, session_id=stream.session_id, stream_ended=stream.ended, stream_id=stream.stream_id, ) ) return http_events elif stream.stream_type == StreamType.QPACK_DECODER: # feed unframed data to decoder data = buf.pull_bytes(buf.capacity - buf.tell()) consumed = buf.tell() try: self._encoder.feed_decoder(data) except pylsqpack.DecoderStreamError as exc: raise QpackDecoderStreamError() from exc self._decoder_bytes_received += len(data) elif stream.stream_type == StreamType.QPACK_ENCODER: # feed unframed data to encoder data = buf.pull_bytes(buf.capacity - buf.tell()) consumed = buf.tell() try: unblocked_streams.update(self._decoder.feed_encoder(data)) except pylsqpack.EncoderStreamError as exc: raise QpackEncoderStreamError() from exc self._encoder_bytes_received += len(data) else: # unknown stream type, discard data buf.seek(buf.capacity) consumed = buf.tell() # remove processed data from buffer stream.buffer = stream.buffer[consumed:] # process unblocked streams for stream_id in unblocked_streams: stream = self._stream[stream_id] # resume headers http_events.extend( self._handle_request_or_push_frame( frame_type=FrameType.HEADERS, frame_data=None, stream=stream, stream_ended=stream.ended and not stream.buffer, ) ) stream.blocked = False stream.blocked_frame_size = None # resume processing if stream.buffer: http_events.extend( self._receive_request_or_push_data(stream, b"", stream.ended) ) return http_events def _validate_settings(self, settings: Dict[int, int]) -> None: for setting in [ Setting.ENABLE_CONNECT_PROTOCOL, Setting.ENABLE_WEBTRANSPORT, 
            Setting.H3_DATAGRAM,
        ]:
            if setting in settings and settings[setting] not in (0, 1):
                raise SettingsError(f"{setting.name} setting must be 0 or 1")
        if (
            settings.get(Setting.H3_DATAGRAM) == 1
            and self._quic._remote_max_datagram_frame_size is None
        ):
            raise SettingsError(
                "H3_DATAGRAM requires max_datagram_frame_size transport parameter"
            )
        if (
            settings.get(Setting.ENABLE_WEBTRANSPORT) == 1
            and settings.get(Setting.H3_DATAGRAM) != 1
        ):
            raise SettingsError("ENABLE_WEBTRANSPORT requires H3_DATAGRAM")

aioquic-1.2.0/src/aioquic/h3/events.py

from dataclasses import dataclass
from typing import List, Optional, Tuple

Headers = List[Tuple[bytes, bytes]]


class H3Event:
    """
    Base class for HTTP/3 events.
    """


@dataclass
class DataReceived(H3Event):
    """
    The DataReceived event is fired whenever data is received on a stream
    from the remote peer.
    """

    data: bytes
    "The data which was received."

    stream_id: int
    "The ID of the stream the data was received for."

    stream_ended: bool
    "Whether the STREAM frame had the FIN bit set."

    push_id: Optional[int] = None
    "The Push ID or `None` if this is not a push."


@dataclass
class DatagramReceived(H3Event):
    """
    The DatagramReceived event is fired whenever a datagram is received
    from the remote peer.
    """

    data: bytes
    "The data which was received."

    stream_id: int
    "The ID of the stream the data was received for."


@dataclass
class HeadersReceived(H3Event):
    """
    The HeadersReceived event is fired whenever headers are received.
    """

    headers: Headers
    "The headers."

    stream_id: int
    "The ID of the stream the headers were received for."

    stream_ended: bool
    "Whether the STREAM frame had the FIN bit set."

    push_id: Optional[int] = None
    "The Push ID or `None` if this is not a push."


@dataclass
class PushPromiseReceived(H3Event):
    """
    The PushPromiseReceived event is fired whenever a push promise is
    received from the remote peer.
    """

    headers: Headers
    "The request headers."

    push_id: int
    "The Push ID of the push promise."

    stream_id: int
    "The Stream ID of the stream that the push is related to."


@dataclass
class WebTransportStreamDataReceived(H3Event):
    """
    The WebTransportStreamDataReceived event is fired whenever data is
    received for a WebTransport stream.
    """

    data: bytes
    "The data which was received."

    stream_id: int
    "The ID of the stream the data was received for."

    stream_ended: bool
    "Whether the STREAM frame had the FIN bit set."

    session_id: int
    "The ID of the session the data was received for."

aioquic-1.2.0/src/aioquic/h3/exceptions.py

class H3Error(Exception):
    """
    Base class for HTTP/3 exceptions.
    """


class InvalidStreamTypeError(H3Error):
    """
    An action was attempted on an invalid stream type.
    """


class NoAvailablePushIDError(H3Error):
    """
    There are no available push IDs left, or push is not supported by the
    remote party.
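Since the event classes above are plain dataclasses, an application consuming the HTTP/3 layer typically dispatches on their type. A hypothetical consumer (handle_h3_event is not part of aioquic):

def handle_h3_event(event: H3Event) -> None:
    if isinstance(event, HeadersReceived):
        print("stream", event.stream_id, "headers:", dict(event.headers))
    elif isinstance(event, DataReceived):
        fin = " (FIN)" if event.stream_ended else ""
        print("stream", event.stream_id, "data:", len(event.data), "bytes" + fin)
    elif isinstance(event, DatagramReceived):
        print("datagram for stream", event.stream_id)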
""" ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1720306884.0 aioquic-1.2.0/src/aioquic/py.typed0000644000175100001770000000000700000000000017667 0ustar00runnerdocker00000000000000Marker ././@PaxHeader0000000000000000000000000000003400000000000011452 xustar000000000000000028 mtime=1720306888.1292942 aioquic-1.2.0/src/aioquic/quic/0000755000175100001770000000000000000000000017134 5ustar00runnerdocker00000000000000././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1720306884.0 aioquic-1.2.0/src/aioquic/quic/__init__.py0000644000175100001770000000000000000000000021233 0ustar00runnerdocker00000000000000././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1720306884.0 aioquic-1.2.0/src/aioquic/quic/configuration.py0000644000175100001770000001070500000000000022360 0ustar00runnerdocker00000000000000from dataclasses import dataclass, field from os import PathLike from re import split from typing import Any, List, Optional, TextIO, Union from ..tls import ( CipherSuite, SessionTicket, load_pem_private_key, load_pem_x509_certificates, ) from .logger import QuicLogger from .packet import QuicProtocolVersion SMALLEST_MAX_DATAGRAM_SIZE = 1200 @dataclass class QuicConfiguration: """ A QUIC configuration. """ alpn_protocols: Optional[List[str]] = None """ A list of supported ALPN protocols. """ congestion_control_algorithm: str = "reno" """ The name of the congestion control algorithm to use. Currently supported algorithms: `"reno", `"cubic"`. """ connection_id_length: int = 8 """ The length in bytes of local connection IDs. """ idle_timeout: float = 60.0 """ The idle timeout in seconds. The connection is terminated if nothing is received for the given duration. """ is_client: bool = True """ Whether this is the client side of the QUIC connection. """ max_data: int = 1048576 """ Connection-wide flow control limit. """ max_datagram_size: int = SMALLEST_MAX_DATAGRAM_SIZE """ The maximum QUIC payload size in bytes to send, excluding UDP or IP overhead. """ max_stream_data: int = 1048576 """ Per-stream flow control limit. """ quic_logger: Optional[QuicLogger] = None """ The :class:`~aioquic.quic.logger.QuicLogger` instance to log events to. """ secrets_log_file: TextIO = None """ A file-like object in which to log traffic secrets. This is useful to analyze traffic captures with Wireshark. """ server_name: Optional[str] = None """ The server name to use when verifying the server's TLS certificate, which can either be a DNS name or an IP address. If it is a DNS name, it is also sent during the TLS handshake in the Server Name Indication (SNI) extension. .. note:: This is only used by clients. """ session_ticket: Optional[SessionTicket] = None """ The TLS session ticket which should be used for session resumption. """ token: bytes = b"" """ The address validation token that can be used to validate future connections. .. note:: This is only used by clients. """ # For internal purposes, not guaranteed to be stable. 
cadata: Optional[bytes] = None cafile: Optional[str] = None capath: Optional[str] = None certificate: Any = None certificate_chain: List[Any] = field(default_factory=list) cipher_suites: Optional[List[CipherSuite]] = None initial_rtt: float = 0.1 max_datagram_frame_size: Optional[int] = None original_version: Optional[int] = None private_key: Any = None quantum_readiness_test: bool = False supported_versions: List[int] = field( default_factory=lambda: [ QuicProtocolVersion.VERSION_1, QuicProtocolVersion.VERSION_2, ] ) verify_mode: Optional[int] = None def load_cert_chain( self, certfile: PathLike, keyfile: Optional[PathLike] = None, password: Optional[Union[bytes, str]] = None, ) -> None: """ Load a private key and the corresponding certificate. """ with open(certfile, "rb") as fp: boundary = b"-----BEGIN PRIVATE KEY-----\n" chunks = split(b"\n" + boundary, fp.read()) certificates = load_pem_x509_certificates(chunks[0]) if len(chunks) == 2: private_key = boundary + chunks[1] self.private_key = load_pem_private_key(private_key) self.certificate = certificates[0] self.certificate_chain = certificates[1:] if keyfile is not None: with open(keyfile, "rb") as fp: self.private_key = load_pem_private_key( fp.read(), password=password.encode("utf8") if isinstance(password, str) else password, ) def load_verify_locations( self, cafile: Optional[str] = None, capath: Optional[str] = None, cadata: Optional[bytes] = None, ) -> None: """ Load a set of "certification authority" (CA) certificates used to validate other peers' certificates. """ self.cafile = cafile self.capath = capath self.cadata = cadata ././@PaxHeader0000000000000000000000000000003400000000000011452 xustar000000000000000028 mtime=1720306888.1292942 aioquic-1.2.0/src/aioquic/quic/congestion/0000755000175100001770000000000000000000000021304 5ustar00runnerdocker00000000000000././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1720306884.0 aioquic-1.2.0/src/aioquic/quic/congestion/__init__.py0000644000175100001770000000000000000000000023403 0ustar00runnerdocker00000000000000././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1720306884.0 aioquic-1.2.0/src/aioquic/quic/congestion/base.py0000644000175100001770000000742200000000000022575 0ustar00runnerdocker00000000000000import abc from typing import Any, Dict, Iterable, Optional, Protocol from ..packet_builder import QuicSentPacket K_GRANULARITY = 0.001 # seconds K_INITIAL_WINDOW = 10 K_MINIMUM_WINDOW = 2 class QuicCongestionControl(abc.ABC): """ Base class for congestion control implementations. """ bytes_in_flight: int = 0 congestion_window: int = 0 ssthresh: Optional[int] = None def __init__(self, *, max_datagram_size: int) -> None: self.congestion_window = K_INITIAL_WINDOW * max_datagram_size @abc.abstractmethod def on_packet_acked(self, *, now: float, packet: QuicSentPacket) -> None: ... @abc.abstractmethod def on_packet_sent(self, *, packet: QuicSentPacket) -> None: ... @abc.abstractmethod def on_packets_expired(self, *, packets: Iterable[QuicSentPacket]) -> None: ... @abc.abstractmethod def on_packets_lost( self, *, now: float, packets: Iterable[QuicSentPacket] ) -> None: ... @abc.abstractmethod def on_rtt_measurement(self, *, now: float, rtt: float) -> None: ... 
def get_log_data(self) -> Dict[str, Any]: data = {"cwnd": self.congestion_window, "bytes_in_flight": self.bytes_in_flight} if self.ssthresh is not None: data["ssthresh"] = self.ssthresh return data class QuicCongestionControlFactory(Protocol): def __call__(self, *, max_datagram_size: int) -> QuicCongestionControl: ... class QuicRttMonitor: """ Roundtrip time monitor for HyStart. """ def __init__(self) -> None: self._increases = 0 self._last_time = None self._ready = False self._size = 5 self._filtered_min: Optional[float] = None self._sample_idx = 0 self._sample_max: Optional[float] = None self._sample_min: Optional[float] = None self._sample_time = 0.0 self._samples = [0.0 for i in range(self._size)] def add_rtt(self, *, rtt: float) -> None: self._samples[self._sample_idx] = rtt self._sample_idx += 1 if self._sample_idx >= self._size: self._sample_idx = 0 self._ready = True if self._ready: self._sample_max = self._samples[0] self._sample_min = self._samples[0] for sample in self._samples[1:]: if sample < self._sample_min: self._sample_min = sample elif sample > self._sample_max: self._sample_max = sample def is_rtt_increasing(self, *, now: float, rtt: float) -> bool: if now > self._sample_time + K_GRANULARITY: self.add_rtt(rtt=rtt) self._sample_time = now if self._ready: if self._filtered_min is None or self._filtered_min > self._sample_max: self._filtered_min = self._sample_max delta = self._sample_min - self._filtered_min if delta * 4 >= self._filtered_min: self._increases += 1 if self._increases >= self._size: return True elif delta > 0: self._increases = 0 return False _factories: Dict[str, QuicCongestionControlFactory] = {} def create_congestion_control( name: str, *, max_datagram_size: int ) -> QuicCongestionControl: """ Create an instance of the `name` congestion control algorithm. """ try: factory = _factories[name] except KeyError: raise Exception(f"Unknown congestion control algorithm: {name}") return factory(max_datagram_size=max_datagram_size) def register_congestion_control( name: str, factory: QuicCongestionControlFactory ) -> None: """ Register a congestion control algorithm named `name`. 
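To illustrate the registry defined here: a custom algorithm subclasses QuicCongestionControl, implements the abstract callbacks, and registers a factory under a name that can then be passed as QuicConfiguration.congestion_control_algorithm. The name "fixed" and class FixedWindowCongestionControl below are illustrative sketches, not part of aioquic; they assume the names from this module are in scope:

class FixedWindowCongestionControl(QuicCongestionControl):
    """Keeps the initial congestion window constant; only useful for tests."""

    def on_packet_acked(self, *, now: float, packet: QuicSentPacket) -> None:
        self.bytes_in_flight -= packet.sent_bytes

    def on_packet_sent(self, *, packet: QuicSentPacket) -> None:
        self.bytes_in_flight += packet.sent_bytes

    def on_packets_expired(self, *, packets: Iterable[QuicSentPacket]) -> None:
        for packet in packets:
            self.bytes_in_flight -= packet.sent_bytes

    def on_packets_lost(
        self, *, now: float, packets: Iterable[QuicSentPacket]
    ) -> None:
        for packet in packets:
            self.bytes_in_flight -= packet.sent_bytes

    def on_rtt_measurement(self, *, now: float, rtt: float) -> None:
        pass


register_congestion_control("fixed", FixedWindowCongestionControl)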
""" _factories[name] = factory ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1720306884.0 aioquic-1.2.0/src/aioquic/quic/congestion/cubic.py0000644000175100001770000001752200000000000022752 0ustar00runnerdocker00000000000000from typing import Any, Dict, Iterable from ..packet_builder import QuicSentPacket from .base import ( K_INITIAL_WINDOW, K_MINIMUM_WINDOW, QuicCongestionControl, QuicRttMonitor, register_congestion_control, ) # cubic specific variables (see https://www.rfc-editor.org/rfc/rfc9438.html#name-definitions) K_CUBIC_C = 0.4 K_CUBIC_LOSS_REDUCTION_FACTOR = 0.7 K_CUBIC_MAX_IDLE_TIME = 2 # reset the cwnd after 2 seconds of inactivity def better_cube_root(x: float) -> float: if x < 0: # avoid precision errors that make the cube root returns an imaginary number return -((-x) ** (1.0 / 3.0)) else: return (x) ** (1.0 / 3.0) class CubicCongestionControl(QuicCongestionControl): """ Cubic congestion control implementation for aioquic """ def __init__(self, max_datagram_size: int) -> None: super().__init__(max_datagram_size=max_datagram_size) # increase by one segment self.additive_increase_factor: int = max_datagram_size self._max_datagram_size: int = max_datagram_size self._congestion_recovery_start_time = 0.0 self._rtt_monitor = QuicRttMonitor() self.rtt = 0.02 # starting RTT is considered to be 20ms self.reset() self.last_ack = 0.0 def W_cubic(self, t) -> int: W_max_segments = self._W_max / self._max_datagram_size target_segments = K_CUBIC_C * (t - self.K) ** 3 + (W_max_segments) return int(target_segments * self._max_datagram_size) def is_reno_friendly(self, t) -> bool: return self.W_cubic(t) < self._W_est def is_concave(self) -> bool: return self.congestion_window < self._W_max def reset(self) -> None: self.congestion_window = K_INITIAL_WINDOW * self._max_datagram_size self.ssthresh = None self._first_slow_start = True self._starting_congestion_avoidance = False self.K: float = 0.0 self._W_est = 0 self._cwnd_epoch = 0 self._t_epoch = 0.0 self._W_max = self.congestion_window def on_packet_acked(self, *, now: float, packet: QuicSentPacket) -> None: self.bytes_in_flight -= packet.sent_bytes self.last_ack = packet.sent_time if self.ssthresh is None or self.congestion_window < self.ssthresh: # slow start self.congestion_window += packet.sent_bytes else: # congestion avoidance if self._first_slow_start and not self._starting_congestion_avoidance: # exiting slow start without having a loss self._first_slow_start = False self._W_max = self.congestion_window self._t_epoch = now self._cwnd_epoch = self.congestion_window self._W_est = self._cwnd_epoch # calculate K W_max_segments = self._W_max / self._max_datagram_size cwnd_epoch_segments = self._cwnd_epoch / self._max_datagram_size self.K = better_cube_root( (W_max_segments - cwnd_epoch_segments) / K_CUBIC_C ) # initialize the variables used at start of congestion avoidance if self._starting_congestion_avoidance: self._starting_congestion_avoidance = False self._first_slow_start = False self._t_epoch = now self._cwnd_epoch = self.congestion_window self._W_est = self._cwnd_epoch # calculate K W_max_segments = self._W_max / self._max_datagram_size cwnd_epoch_segments = self._cwnd_epoch / self._max_datagram_size self.K = better_cube_root( (W_max_segments - cwnd_epoch_segments) / K_CUBIC_C ) self._W_est = int( self._W_est + self.additive_increase_factor * (packet.sent_bytes / self.congestion_window) ) t = now - self._t_epoch target: int = 0 W_cubic = self.W_cubic(t + self.rtt) if W_cubic < 
self.congestion_window: target = self.congestion_window elif W_cubic > 1.5 * self.congestion_window: target = int(self.congestion_window * 1.5) else: target = W_cubic if self.is_reno_friendly(t): # reno friendly region of cubic # (https://www.rfc-editor.org/rfc/rfc9438.html#name-reno-friendly-region) self.congestion_window = self._W_est elif self.is_concave(): # concave region of cubic # (https://www.rfc-editor.org/rfc/rfc9438.html#name-concave-region) self.congestion_window = int( self.congestion_window + ( (target - self.congestion_window) * (self._max_datagram_size / self.congestion_window) ) ) else: # convex region of cubic # (https://www.rfc-editor.org/rfc/rfc9438.html#name-convex-region) self.congestion_window = int( self.congestion_window + ( (target - self.congestion_window) * (self._max_datagram_size / self.congestion_window) ) ) def on_packet_sent(self, *, packet: QuicSentPacket) -> None: self.bytes_in_flight += packet.sent_bytes if self.last_ack == 0.0: return elapsed_idle = packet.sent_time - self.last_ack if elapsed_idle >= K_CUBIC_MAX_IDLE_TIME: self.reset() def on_packets_expired(self, *, packets: Iterable[QuicSentPacket]) -> None: for packet in packets: self.bytes_in_flight -= packet.sent_bytes def on_packets_lost(self, *, now: float, packets: Iterable[QuicSentPacket]) -> None: lost_largest_time = 0.0 for packet in packets: self.bytes_in_flight -= packet.sent_bytes lost_largest_time = packet.sent_time # start a new congestion event if packet was sent after the # start of the previous congestion recovery period. if lost_largest_time > self._congestion_recovery_start_time: self._congestion_recovery_start_time = now # Normal congestion handle, can't be used in same time as fast convergence # self._W_max = self.congestion_window # fast convergence if self._W_max is not None and self.congestion_window < self._W_max: self._W_max = int( self.congestion_window * (1 + K_CUBIC_LOSS_REDUCTION_FACTOR) / 2 ) else: self._W_max = self.congestion_window # normal congestion MD flight_size = self.bytes_in_flight new_ssthresh = max( int(flight_size * K_CUBIC_LOSS_REDUCTION_FACTOR), K_MINIMUM_WINDOW * self._max_datagram_size, ) self.ssthresh = new_ssthresh self.congestion_window = max( self.ssthresh, K_MINIMUM_WINDOW * self._max_datagram_size ) # restart a new congestion avoidance phase self._starting_congestion_avoidance = True def on_rtt_measurement(self, *, now: float, rtt: float) -> None: self.rtt = rtt # check whether we should exit slow start if self.ssthresh is None and self._rtt_monitor.is_rtt_increasing( rtt=rtt, now=now ): self.ssthresh = self.congestion_window def get_log_data(self) -> Dict[str, Any]: data = super().get_log_data() data["cubic-wmax"] = int(self._W_max) return data register_congestion_control("cubic", CubicCongestionControl) ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1720306884.0 aioquic-1.2.0/src/aioquic/quic/congestion/reno.py0000644000175100001770000000544700000000000022633 0ustar00runnerdocker00000000000000from typing import Iterable from ..packet_builder import QuicSentPacket from .base import ( K_MINIMUM_WINDOW, QuicCongestionControl, QuicRttMonitor, register_congestion_control, ) K_LOSS_REDUCTION_FACTOR = 0.5 class RenoCongestionControl(QuicCongestionControl): """ New Reno congestion control. 
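Before the Reno implementation below, a standalone worked example of the CUBIC curve used above (RFC 9438): W_cubic(t) = C * (t - K)^3 + W_max, with K chosen so the curve passes through the post-loss window at t = 0. Units are segments; this is an illustrative sketch, not aioquic code:

K_CUBIC_C = 0.4

def cubic_k(w_max: float, cwnd_epoch: float) -> float:
    return ((w_max - cwnd_epoch) / K_CUBIC_C) ** (1.0 / 3.0)

def w_cubic(t: float, w_max: float, cwnd_epoch: float) -> float:
    return K_CUBIC_C * (t - cubic_k(w_max, cwnd_epoch)) ** 3 + w_max

# after a loss at w_max = 100 segments with the window cut to 70 segments,
# the curve regains w_max about K ~= 4.2 seconds later
k = cubic_k(100, 70)
assert abs(w_cubic(k, 100, 70) - 100) < 1e-6   # plateau back at w_max
assert abs(w_cubic(0, 100, 70) - 70) < 1e-6    # starts at the reduced window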
""" def __init__(self, *, max_datagram_size: int) -> None: super().__init__(max_datagram_size=max_datagram_size) self._max_datagram_size = max_datagram_size self._congestion_recovery_start_time = 0.0 self._congestion_stash = 0 self._rtt_monitor = QuicRttMonitor() def on_packet_acked(self, *, now: float, packet: QuicSentPacket) -> None: self.bytes_in_flight -= packet.sent_bytes # don't increase window in congestion recovery if packet.sent_time <= self._congestion_recovery_start_time: return if self.ssthresh is None or self.congestion_window < self.ssthresh: # slow start self.congestion_window += packet.sent_bytes else: # congestion avoidance self._congestion_stash += packet.sent_bytes count = self._congestion_stash // self.congestion_window if count: self._congestion_stash -= count * self.congestion_window self.congestion_window += count * self._max_datagram_size def on_packet_sent(self, *, packet: QuicSentPacket) -> None: self.bytes_in_flight += packet.sent_bytes def on_packets_expired(self, *, packets: Iterable[QuicSentPacket]) -> None: for packet in packets: self.bytes_in_flight -= packet.sent_bytes def on_packets_lost(self, *, now: float, packets: Iterable[QuicSentPacket]) -> None: lost_largest_time = 0.0 for packet in packets: self.bytes_in_flight -= packet.sent_bytes lost_largest_time = packet.sent_time # start a new congestion event if packet was sent after the # start of the previous congestion recovery period. if lost_largest_time > self._congestion_recovery_start_time: self._congestion_recovery_start_time = now self.congestion_window = max( int(self.congestion_window * K_LOSS_REDUCTION_FACTOR), K_MINIMUM_WINDOW * self._max_datagram_size, ) self.ssthresh = self.congestion_window # TODO : collapse congestion window if persistent congestion def on_rtt_measurement(self, *, now: float, rtt: float) -> None: # check whether we should exit slow start if self.ssthresh is None and self._rtt_monitor.is_rtt_increasing( now=now, rtt=rtt ): self.ssthresh = self.congestion_window register_congestion_control("reno", RenoCongestionControl) ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1720306884.0 aioquic-1.2.0/src/aioquic/quic/connection.py0000644000175100001770000041777300000000000021670 0ustar00runnerdocker00000000000000import binascii import logging import os from collections import deque from dataclasses import dataclass from enum import Enum from functools import partial from typing import ( Any, Callable, Deque, Dict, FrozenSet, List, Optional, Sequence, Set, Tuple, ) from .. import tls from ..buffer import ( UINT_VAR_MAX, UINT_VAR_MAX_SIZE, Buffer, BufferReadError, size_uint_var, ) from . 
import events from .configuration import SMALLEST_MAX_DATAGRAM_SIZE, QuicConfiguration from .congestion.base import K_GRANULARITY from .crypto import CryptoError, CryptoPair, KeyUnavailableError, NoCallback from .logger import QuicLoggerTrace from .packet import ( CONNECTION_ID_MAX_SIZE, NON_ACK_ELICITING_FRAME_TYPES, PROBING_FRAME_TYPES, RETRY_INTEGRITY_TAG_SIZE, STATELESS_RESET_TOKEN_SIZE, QuicErrorCode, QuicFrameType, QuicHeader, QuicPacketType, QuicProtocolVersion, QuicStreamFrame, QuicTransportParameters, QuicVersionInformation, get_retry_integrity_tag, get_spin_bit, pretty_protocol_version, pull_ack_frame, pull_quic_header, pull_quic_transport_parameters, push_ack_frame, push_quic_transport_parameters, ) from .packet_builder import QuicDeliveryState, QuicPacketBuilder, QuicPacketBuilderStop from .recovery import QuicPacketRecovery, QuicPacketSpace from .stream import FinalSizeError, QuicStream, StreamFinishedError logger = logging.getLogger("quic") CRYPTO_BUFFER_SIZE = 16384 EPOCH_SHORTCUTS = { "I": tls.Epoch.INITIAL, "H": tls.Epoch.HANDSHAKE, "0": tls.Epoch.ZERO_RTT, "1": tls.Epoch.ONE_RTT, } MAX_EARLY_DATA = 0xFFFFFFFF MAX_REMOTE_CHALLENGES = 5 MAX_LOCAL_CHALLENGES = 5 SECRETS_LABELS = [ [ None, "CLIENT_EARLY_TRAFFIC_SECRET", "CLIENT_HANDSHAKE_TRAFFIC_SECRET", "CLIENT_TRAFFIC_SECRET_0", ], [ None, None, "SERVER_HANDSHAKE_TRAFFIC_SECRET", "SERVER_TRAFFIC_SECRET_0", ], ] STREAM_FLAGS = 0x07 STREAM_COUNT_MAX = 0x1000000000000000 UDP_HEADER_SIZE = 8 MAX_PENDING_RETIRES = 100 MAX_PENDING_CRYPTO = 524288 # in bytes NetworkAddress = Any # frame sizes ACK_FRAME_CAPACITY = 64 # FIXME: this is arbitrary! APPLICATION_CLOSE_FRAME_CAPACITY = 1 + 2 * UINT_VAR_MAX_SIZE # + reason length CONNECTION_LIMIT_FRAME_CAPACITY = 1 + UINT_VAR_MAX_SIZE HANDSHAKE_DONE_FRAME_CAPACITY = 1 MAX_STREAM_DATA_FRAME_CAPACITY = 1 + 2 * UINT_VAR_MAX_SIZE NEW_CONNECTION_ID_FRAME_CAPACITY = ( 1 + 2 * UINT_VAR_MAX_SIZE + 1 + CONNECTION_ID_MAX_SIZE + STATELESS_RESET_TOKEN_SIZE ) PATH_CHALLENGE_FRAME_CAPACITY = 1 + 8 PATH_RESPONSE_FRAME_CAPACITY = 1 + 8 PING_FRAME_CAPACITY = 1 RESET_STREAM_FRAME_CAPACITY = 1 + 3 * UINT_VAR_MAX_SIZE RETIRE_CONNECTION_ID_CAPACITY = 1 + UINT_VAR_MAX_SIZE STOP_SENDING_FRAME_CAPACITY = 1 + 2 * UINT_VAR_MAX_SIZE STREAMS_BLOCKED_CAPACITY = 1 + UINT_VAR_MAX_SIZE TRANSPORT_CLOSE_FRAME_CAPACITY = 1 + 3 * UINT_VAR_MAX_SIZE # + reason length def EPOCHS(shortcut: str) -> FrozenSet[tls.Epoch]: return frozenset(EPOCH_SHORTCUTS[i] for i in shortcut) def is_version_compatible(from_version: int, to_version: int) -> bool: """ Return whether it is possible to perform compatible version negotiation from `from_version` to `to_version`. """ # Version 1 is compatible with version 2 and vice versa. These are the # only compatible versions so far. return set([from_version, to_version]) == set( [QuicProtocolVersion.VERSION_1, QuicProtocolVersion.VERSION_2] ) def dump_cid(cid: bytes) -> str: return binascii.hexlify(cid).decode("ascii") def get_epoch(packet_type: QuicPacketType) -> tls.Epoch: if packet_type == QuicPacketType.INITIAL: return tls.Epoch.INITIAL elif packet_type == QuicPacketType.ZERO_RTT: return tls.Epoch.ZERO_RTT elif packet_type == QuicPacketType.HANDSHAKE: return tls.Epoch.HANDSHAKE else: return tls.Epoch.ONE_RTT def stream_is_client_initiated(stream_id: int) -> bool: """ Returns True if the stream is client initiated. """ return not (stream_id & 1) def stream_is_unidirectional(stream_id: int) -> bool: """ Returns True if the stream is unidirectional. 
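The two helpers here encode RFC 9000 section 2.1: the low bit of a stream ID names the initiator (0 = client, 1 = server) and the second bit the directionality (0 = bidirectional, 1 = unidirectional). A quick self-check against the four combinations:

for stream_id, client_initiated, unidirectional in [
    (0, True, False),
    (1, False, False),
    (2, True, True),
    (3, False, True),
]:
    assert stream_is_client_initiated(stream_id) == client_initiated
    assert stream_is_unidirectional(stream_id) == unidirectional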
""" return bool(stream_id & 2) class Limit: def __init__(self, frame_type: int, name: str, value: int): self.frame_type = frame_type self.name = name self.sent = value self.used = 0 self.value = value class QuicConnectionError(Exception): def __init__(self, error_code: int, frame_type: int, reason_phrase: str): self.error_code = error_code self.frame_type = frame_type self.reason_phrase = reason_phrase def __str__(self) -> str: s = "Error: %d, reason: %s" % (self.error_code, self.reason_phrase) if self.frame_type is not None: s += ", frame_type: %s" % self.frame_type return s class QuicConnectionAdapter(logging.LoggerAdapter): def process(self, msg: str, kwargs: Any) -> Tuple[str, Any]: return "[%s] %s" % (self.extra["id"], msg), kwargs @dataclass class QuicConnectionId: cid: bytes sequence_number: int stateless_reset_token: bytes = b"" was_sent: bool = False class QuicConnectionState(Enum): FIRSTFLIGHT = 0 CONNECTED = 1 CLOSING = 2 DRAINING = 3 TERMINATED = 4 class QuicNetworkPath: def __init__(self, addr: NetworkAddress, is_validated: bool = False): self.addr: NetworkAddress = addr self.bytes_received: int = 0 self.bytes_sent: int = 0 self.is_validated: bool = is_validated self.local_challenge_sent: bool = False self.remote_challenges: Deque[bytes] = deque() def can_send(self, size: int) -> bool: return self.is_validated or (self.bytes_sent + size) <= 3 * self.bytes_received @dataclass class QuicReceiveContext: epoch: tls.Epoch host_cid: bytes network_path: QuicNetworkPath quic_logger_frames: Optional[List[Any]] time: float version: Optional[int] QuicTokenHandler = Callable[[bytes], None] END_STATES = frozenset( [ QuicConnectionState.CLOSING, QuicConnectionState.DRAINING, QuicConnectionState.TERMINATED, ] ) class QuicConnection: """ A QUIC connection. The state machine is driven by three kinds of sources: - the API user requesting data to be send out (see :meth:`connect`, :meth:`reset_stream`, :meth:`send_ping`, :meth:`send_datagram_frame` and :meth:`send_stream_data`) - data being received from the network (see :meth:`receive_datagram`) - a timer firing (see :meth:`handle_timer`) :param configuration: The QUIC configuration to use. 
""" def __init__( self, *, configuration: QuicConfiguration, original_destination_connection_id: Optional[bytes] = None, retry_source_connection_id: Optional[bytes] = None, session_ticket_fetcher: Optional[tls.SessionTicketFetcher] = None, session_ticket_handler: Optional[tls.SessionTicketHandler] = None, token_handler: Optional[QuicTokenHandler] = None, ) -> None: assert configuration.max_datagram_size >= SMALLEST_MAX_DATAGRAM_SIZE, ( "The smallest allowed maximum datagram size is " f"{SMALLEST_MAX_DATAGRAM_SIZE} bytes" ) if configuration.is_client: assert ( original_destination_connection_id is None ), "Cannot set original_destination_connection_id for a client" assert ( retry_source_connection_id is None ), "Cannot set retry_source_connection_id for a client" else: assert token_handler is None, "Cannot set `token_handler` for a server" assert ( configuration.token == b"" ), "Cannot set `configuration.token` for a server" assert ( configuration.certificate is not None ), "SSL certificate is required for a server" assert ( configuration.private_key is not None ), "SSL private key is required for a server" assert ( original_destination_connection_id is not None ), "original_destination_connection_id is required for a server" # configuration self._configuration = configuration self._is_client = configuration.is_client self._ack_delay = K_GRANULARITY self._close_at: Optional[float] = None self._close_event: Optional[events.ConnectionTerminated] = None self._connect_called = False self._cryptos: Dict[tls.Epoch, CryptoPair] = {} self._cryptos_initial: Dict[int, CryptoPair] = {} self._crypto_buffers: Dict[tls.Epoch, Buffer] = {} self._crypto_frame_type: Optional[int] = None self._crypto_packet_version: Optional[int] = None self._crypto_retransmitted = False self._crypto_streams: Dict[tls.Epoch, QuicStream] = {} self._events: Deque[events.QuicEvent] = deque() self._handshake_complete = False self._handshake_confirmed = False self._host_cids = [ QuicConnectionId( cid=os.urandom(configuration.connection_id_length), sequence_number=0, stateless_reset_token=os.urandom(16) if not self._is_client else None, was_sent=True, ) ] self.host_cid = self._host_cids[0].cid self._host_cid_seq = 1 self._local_ack_delay_exponent = 3 self._local_active_connection_id_limit = 8 self._local_challenges: Dict[bytes, QuicNetworkPath] = {} self._local_initial_source_connection_id = self._host_cids[0].cid self._local_max_data = Limit( frame_type=QuicFrameType.MAX_DATA, name="max_data", value=configuration.max_data, ) self._local_max_stream_data_bidi_local = configuration.max_stream_data self._local_max_stream_data_bidi_remote = configuration.max_stream_data self._local_max_stream_data_uni = configuration.max_stream_data self._local_max_streams_bidi = Limit( frame_type=QuicFrameType.MAX_STREAMS_BIDI, name="max_streams_bidi", value=128, ) self._local_max_streams_uni = Limit( frame_type=QuicFrameType.MAX_STREAMS_UNI, name="max_streams_uni", value=128 ) self._local_next_stream_id_bidi = 0 if self._is_client else 1 self._local_next_stream_id_uni = 2 if self._is_client else 3 self._loss_at: Optional[float] = None self._max_datagram_size = configuration.max_datagram_size self._network_paths: List[QuicNetworkPath] = [] self._pacing_at: Optional[float] = None self._packet_number = 0 self._peer_cid = QuicConnectionId( cid=os.urandom(configuration.connection_id_length), sequence_number=None ) self._peer_cid_available: List[QuicConnectionId] = [] self._peer_cid_sequence_numbers: Set[int] = set([0]) self._peer_retire_prior_to = 0 
self._peer_token = configuration.token self._quic_logger: Optional[QuicLoggerTrace] = None self._remote_ack_delay_exponent = 3 self._remote_active_connection_id_limit = 2 self._remote_initial_source_connection_id: Optional[bytes] = None self._remote_max_idle_timeout: Optional[float] = None # seconds self._remote_max_data = 0 self._remote_max_data_used = 0 self._remote_max_datagram_frame_size: Optional[int] = None self._remote_max_stream_data_bidi_local = 0 self._remote_max_stream_data_bidi_remote = 0 self._remote_max_stream_data_uni = 0 self._remote_max_streams_bidi = 0 self._remote_max_streams_uni = 0 self._remote_version_information: Optional[QuicVersionInformation] = None self._retry_count = 0 self._retry_source_connection_id = retry_source_connection_id self._spaces: Dict[tls.Epoch, QuicPacketSpace] = {} self._spin_bit = False self._spin_highest_pn = 0 self._state = QuicConnectionState.FIRSTFLIGHT self._streams: Dict[int, QuicStream] = {} self._streams_queue: List[QuicStream] = [] self._streams_blocked_bidi: List[QuicStream] = [] self._streams_blocked_uni: List[QuicStream] = [] self._streams_finished: Set[int] = set() self._version: Optional[int] = None self._version_negotiated_compatible = False self._version_negotiated_incompatible = False if self._is_client: self._original_destination_connection_id = self._peer_cid.cid else: self._original_destination_connection_id = ( original_destination_connection_id ) # logging self._logger = QuicConnectionAdapter( logger, {"id": dump_cid(self._original_destination_connection_id)} ) if configuration.quic_logger: self._quic_logger = configuration.quic_logger.start_trace( is_client=configuration.is_client, odcid=self._original_destination_connection_id, ) # loss recovery self._loss = QuicPacketRecovery( congestion_control_algorithm=configuration.congestion_control_algorithm, initial_rtt=configuration.initial_rtt, max_datagram_size=self._max_datagram_size, peer_completed_address_validation=not self._is_client, quic_logger=self._quic_logger, send_probe=self._send_probe, logger=self._logger, ) # things to send self._close_pending = False self._datagrams_pending: Deque[bytes] = deque() self._handshake_done_pending = False self._ping_pending: List[int] = [] self._probe_pending = False self._retire_connection_ids: List[int] = [] self._streams_blocked_pending = False # callbacks self._session_ticket_fetcher = session_ticket_fetcher self._session_ticket_handler = session_ticket_handler self._token_handler = token_handler # frame handlers self.__frame_handlers = { 0x00: (self._handle_padding_frame, EPOCHS("IH01")), 0x01: (self._handle_ping_frame, EPOCHS("IH01")), 0x02: (self._handle_ack_frame, EPOCHS("IH1")), 0x03: (self._handle_ack_frame, EPOCHS("IH1")), 0x04: (self._handle_reset_stream_frame, EPOCHS("01")), 0x05: (self._handle_stop_sending_frame, EPOCHS("01")), 0x06: (self._handle_crypto_frame, EPOCHS("IH1")), 0x07: (self._handle_new_token_frame, EPOCHS("1")), 0x08: (self._handle_stream_frame, EPOCHS("01")), 0x09: (self._handle_stream_frame, EPOCHS("01")), 0x0A: (self._handle_stream_frame, EPOCHS("01")), 0x0B: (self._handle_stream_frame, EPOCHS("01")), 0x0C: (self._handle_stream_frame, EPOCHS("01")), 0x0D: (self._handle_stream_frame, EPOCHS("01")), 0x0E: (self._handle_stream_frame, EPOCHS("01")), 0x0F: (self._handle_stream_frame, EPOCHS("01")), 0x10: (self._handle_max_data_frame, EPOCHS("01")), 0x11: (self._handle_max_stream_data_frame, EPOCHS("01")), 0x12: (self._handle_max_streams_bidi_frame, EPOCHS("01")), 0x13: (self._handle_max_streams_uni_frame, 
EPOCHS("01")), 0x14: (self._handle_data_blocked_frame, EPOCHS("01")), 0x15: (self._handle_stream_data_blocked_frame, EPOCHS("01")), 0x16: (self._handle_streams_blocked_frame, EPOCHS("01")), 0x17: (self._handle_streams_blocked_frame, EPOCHS("01")), 0x18: (self._handle_new_connection_id_frame, EPOCHS("01")), 0x19: (self._handle_retire_connection_id_frame, EPOCHS("01")), 0x1A: (self._handle_path_challenge_frame, EPOCHS("01")), 0x1B: (self._handle_path_response_frame, EPOCHS("01")), 0x1C: (self._handle_connection_close_frame, EPOCHS("IH01")), 0x1D: (self._handle_connection_close_frame, EPOCHS("01")), 0x1E: (self._handle_handshake_done_frame, EPOCHS("1")), 0x30: (self._handle_datagram_frame, EPOCHS("01")), 0x31: (self._handle_datagram_frame, EPOCHS("01")), } @property def configuration(self) -> QuicConfiguration: return self._configuration @property def original_destination_connection_id(self) -> bytes: return self._original_destination_connection_id def change_connection_id(self) -> None: """ Switch to the next available connection ID and retire the previous one. .. aioquic_transmit:: """ if self._peer_cid_available: # retire previous CID self._retire_peer_cid(self._peer_cid) # assign new CID self._consume_peer_cid() def close( self, error_code: int = QuicErrorCode.NO_ERROR, frame_type: Optional[int] = None, reason_phrase: str = "", ) -> None: """ Close the connection. .. aioquic_transmit:: :param error_code: An error code indicating why the connection is being closed. :param reason_phrase: A human-readable explanation of why the connection is being closed. """ if self._close_event is None and self._state not in END_STATES: self._close_event = events.ConnectionTerminated( error_code=error_code, frame_type=frame_type, reason_phrase=reason_phrase, ) self._close_pending = True def connect(self, addr: NetworkAddress, now: float) -> None: """ Initiate the TLS handshake. This method can only be called for clients and a single time. .. aioquic_transmit:: :param addr: The network address of the remote peer. :param now: The current time. """ assert ( self._is_client and not self._connect_called ), "connect() can only be called for clients and a single time" self._connect_called = True self._network_paths = [QuicNetworkPath(addr, is_validated=True)] if self._configuration.original_version is not None: self._version = self._configuration.original_version else: self._version = self._configuration.supported_versions[0] self._connect(now=now) def datagrams_to_send(self, now: float) -> List[Tuple[bytes, NetworkAddress]]: """ Return a list of `(data, addr)` tuples of datagrams which need to be sent, and the network address to which they need to be sent. After calling this method call :meth:`get_timer` to know when the next timer needs to be set. :param now: The current time. 
""" network_path = self._network_paths[0] if self._state in END_STATES: return [] # build datagrams builder = QuicPacketBuilder( host_cid=self.host_cid, is_client=self._is_client, max_datagram_size=self._max_datagram_size, packet_number=self._packet_number, peer_cid=self._peer_cid.cid, peer_token=self._peer_token, quic_logger=self._quic_logger, spin_bit=self._spin_bit, version=self._version, ) if self._close_pending: epoch_packet_types = [] if not self._handshake_confirmed: epoch_packet_types += [ (tls.Epoch.INITIAL, QuicPacketType.INITIAL), (tls.Epoch.HANDSHAKE, QuicPacketType.HANDSHAKE), ] epoch_packet_types.append((tls.Epoch.ONE_RTT, QuicPacketType.ONE_RTT)) for epoch, packet_type in epoch_packet_types: crypto = self._cryptos[epoch] if crypto.send.is_valid(): builder.start_packet(packet_type, crypto) self._write_connection_close_frame( builder=builder, epoch=epoch, error_code=self._close_event.error_code, frame_type=self._close_event.frame_type, reason_phrase=self._close_event.reason_phrase, ) self._logger.info( "Connection close sent (code 0x%X, reason %s)", self._close_event.error_code, self._close_event.reason_phrase, ) self._close_pending = False self._close_begin(is_initiator=True, now=now) else: # congestion control builder.max_flight_bytes = ( self._loss.congestion_window - self._loss.bytes_in_flight ) if ( self._probe_pending and builder.max_flight_bytes < self._max_datagram_size ): builder.max_flight_bytes = self._max_datagram_size # limit data on un-validated network paths if not network_path.is_validated: builder.max_total_bytes = ( network_path.bytes_received * 3 - network_path.bytes_sent ) try: if not self._handshake_confirmed: for epoch in [tls.Epoch.INITIAL, tls.Epoch.HANDSHAKE]: self._write_handshake(builder, epoch, now) self._write_application(builder, network_path, now) except QuicPacketBuilderStop: pass datagrams, packets = builder.flush() if datagrams: self._packet_number = builder.packet_number # register packets sent_handshake = False for packet in packets: packet.sent_time = now self._loss.on_packet_sent( packet=packet, space=self._spaces[packet.epoch] ) if packet.epoch == tls.Epoch.HANDSHAKE: sent_handshake = True # log packet if self._quic_logger is not None: self._quic_logger.log_event( category="transport", event="packet_sent", data={ "frames": packet.quic_logger_frames, "header": { "packet_number": packet.packet_number, "packet_type": self._quic_logger.packet_type( packet.packet_type ), "scid": ( "" if packet.packet_type == QuicPacketType.ONE_RTT else dump_cid(self.host_cid) ), "dcid": dump_cid(self._peer_cid.cid), }, "raw": {"length": packet.sent_bytes}, }, ) # check if we can discard initial keys if sent_handshake and self._is_client: self._discard_epoch(tls.Epoch.INITIAL) # return datagrams to send and the destination network address ret = [] for datagram in datagrams: payload_length = len(datagram) network_path.bytes_sent += payload_length ret.append((datagram, network_path.addr)) if self._quic_logger is not None: self._quic_logger.log_event( category="transport", event="datagrams_sent", data={ "count": 1, "raw": [ { "length": UDP_HEADER_SIZE + payload_length, "payload_length": payload_length, } ], }, ) return ret def get_next_available_stream_id(self, is_unidirectional=False) -> int: """ Return the stream ID for the next stream created by this endpoint. 
""" if is_unidirectional: return self._local_next_stream_id_uni else: return self._local_next_stream_id_bidi def get_timer(self) -> Optional[float]: """ Return the time at which the timer should fire or None if no timer is needed. """ timer_at = self._close_at if self._state not in END_STATES: # ack timer for space in self._loss.spaces: if space.ack_at is not None and space.ack_at < timer_at: timer_at = space.ack_at # loss detection timer self._loss_at = self._loss.get_loss_detection_time() if self._loss_at is not None and self._loss_at < timer_at: timer_at = self._loss_at # pacing timer if self._pacing_at is not None and self._pacing_at < timer_at: timer_at = self._pacing_at return timer_at def handle_timer(self, now: float) -> None: """ Handle the timer. .. aioquic_transmit:: :param now: The current time. """ # end of closing period or idle timeout if now >= self._close_at: if self._close_event is None: self._close_event = events.ConnectionTerminated( error_code=QuicErrorCode.INTERNAL_ERROR, frame_type=QuicFrameType.PADDING, reason_phrase="Idle timeout", ) self._close_end() return # loss detection timeout if self._loss_at is not None and now >= self._loss_at: self._logger.debug("Loss detection triggered") self._loss.on_loss_detection_timeout(now=now) def next_event(self) -> Optional[events.QuicEvent]: """ Retrieve the next event from the event buffer. Returns `None` if there are no buffered events. """ try: return self._events.popleft() except IndexError: return None def _idle_timeout(self) -> float: # RFC 9000 section 10.1 # Start with our local timeout. idle_timeout = self._configuration.idle_timeout if self._remote_max_idle_timeout is not None: # Our peer has a preference too, so pick the smaller timeout. idle_timeout = min(idle_timeout, self._remote_max_idle_timeout) # But not too small! return max(idle_timeout, 3 * self._loss.get_probe_timeout()) def receive_datagram(self, data: bytes, addr: NetworkAddress, now: float) -> None: """ Handle an incoming datagram. .. aioquic_transmit:: :param data: The datagram which was received. :param addr: The network address from which the datagram was received. :param now: The current time. """ payload_length = len(data) # stop handling packets when closing if self._state in END_STATES: return # log datagram if self._quic_logger is not None: self._quic_logger.log_event( category="transport", event="datagrams_received", data={ "count": 1, "raw": [ { "length": UDP_HEADER_SIZE + payload_length, "payload_length": payload_length, } ], }, ) # For anti-amplification purposes, servers need to keep track of the # amount of data received on unvalidated network paths. We must count the # entire datagram size regardless of whether packets are processed or # dropped. # # This is particularly important when talking to clients who pad # datagrams containing INITIAL packets by appending bytes after the # long-header packets, which is legitimate behaviour. 
# # https://datatracker.ietf.org/doc/html/rfc9000#section-8.1 network_path = self._find_network_path(addr) if not network_path.is_validated: network_path.bytes_received += payload_length # for servers, arm the idle timeout on the first datagram if self._close_at is None: self._close_at = now + self._idle_timeout() buf = Buffer(data=data) while not buf.eof(): start_off = buf.tell() try: header = pull_quic_header( buf, host_cid_length=self._configuration.connection_id_length ) except ValueError: if self._quic_logger is not None: self._quic_logger.log_event( category="transport", event="packet_dropped", data={ "trigger": "header_parse_error", "raw": {"length": buf.capacity - start_off}, }, ) return # RFC 9000 section 14.1 requires servers to drop all initial packets # contained in a datagram smaller than 1200 bytes. if ( not self._is_client and header.packet_type == QuicPacketType.INITIAL and payload_length < SMALLEST_MAX_DATAGRAM_SIZE ): if self._quic_logger is not None: self._quic_logger.log_event( category="transport", event="packet_dropped", data={ "trigger": "initial_packet_datagram_too_small", "raw": {"length": header.packet_length}, }, ) return # Check destination CID matches. destination_cid_seq: Optional[int] = None for connection_id in self._host_cids: if header.destination_cid == connection_id.cid: destination_cid_seq = connection_id.sequence_number break if ( self._is_client or header.packet_type == QuicPacketType.HANDSHAKE ) and destination_cid_seq is None: if self._quic_logger is not None: self._quic_logger.log_event( category="transport", event="packet_dropped", data={ "trigger": "unknown_connection_id", "raw": {"length": header.packet_length}, }, ) return # Handle version negotiation packet. if header.packet_type == QuicPacketType.VERSION_NEGOTIATION: self._receive_version_negotiation_packet(header=header, now=now) return # Check long header packet protocol version. if ( header.version is not None and header.version not in self._configuration.supported_versions ): if self._quic_logger is not None: self._quic_logger.log_event( category="transport", event="packet_dropped", data={ "trigger": "unsupported_version", "raw": {"length": header.packet_length}, }, ) return # Handle retry packet. if header.packet_type == QuicPacketType.RETRY: self._receive_retry_packet( header=header, packet_without_tag=buf.data_slice( start_off, buf.tell() - RETRY_INTEGRITY_TAG_SIZE ), now=now, ) return crypto_frame_required = False # Server initialization. if not self._is_client and self._state == QuicConnectionState.FIRSTFLIGHT: assert ( header.packet_type == QuicPacketType.INITIAL ), "first packet must be INITIAL" crypto_frame_required = True self._network_paths = [network_path] self._version = header.version self._initialize(header.destination_cid) # Determine crypto and packet space. 
epoch = get_epoch(header.packet_type) if epoch == tls.Epoch.INITIAL: crypto = self._cryptos_initial[header.version] else: crypto = self._cryptos[epoch] if epoch == tls.Epoch.ZERO_RTT: space = self._spaces[tls.Epoch.ONE_RTT] else: space = self._spaces[epoch] # decrypt packet encrypted_off = buf.tell() - start_off end_off = start_off + header.packet_length buf.seek(end_off) try: plain_header, plain_payload, packet_number = crypto.decrypt_packet( data[start_off:end_off], encrypted_off, space.expected_packet_number ) except KeyUnavailableError as exc: self._logger.debug(exc) if self._quic_logger is not None: self._quic_logger.log_event( category="transport", event="packet_dropped", data={ "trigger": "key_unavailable", "raw": {"length": header.packet_length}, }, ) # If a client receives HANDSHAKE or 1-RTT packets before it has # handshake keys, it can assume that the server's INITIAL was lost. if ( self._is_client and epoch in (tls.Epoch.HANDSHAKE, tls.Epoch.ONE_RTT) and not self._crypto_retransmitted ): self._loss.reschedule_data(now=now) self._crypto_retransmitted = True continue except CryptoError as exc: self._logger.debug(exc) if self._quic_logger is not None: self._quic_logger.log_event( category="transport", event="packet_dropped", data={ "trigger": "payload_decrypt_error", "raw": {"length": header.packet_length}, }, ) continue # check reserved bits if header.packet_type == QuicPacketType.ONE_RTT: reserved_mask = 0x18 else: reserved_mask = 0x0C if plain_header[0] & reserved_mask: self.close( error_code=QuicErrorCode.PROTOCOL_VIOLATION, frame_type=QuicFrameType.PADDING, reason_phrase="Reserved bits must be zero", ) return # log packet quic_logger_frames: Optional[List[Dict]] = None if self._quic_logger is not None: quic_logger_frames = [] self._quic_logger.log_event( category="transport", event="packet_received", data={ "frames": quic_logger_frames, "header": { "packet_number": packet_number, "packet_type": self._quic_logger.packet_type( header.packet_type ), "dcid": dump_cid(header.destination_cid), "scid": dump_cid(header.source_cid), }, "raw": {"length": header.packet_length}, }, ) # raise expected packet number if packet_number > space.expected_packet_number: space.expected_packet_number = packet_number + 1 # discard initial keys and packet space if not self._is_client and epoch == tls.Epoch.HANDSHAKE: self._discard_epoch(tls.Epoch.INITIAL) # update state if self._peer_cid.sequence_number is None: self._peer_cid.cid = header.source_cid self._peer_cid.sequence_number = 0 if self._state == QuicConnectionState.FIRSTFLIGHT: self._remote_initial_source_connection_id = header.source_cid self._set_state(QuicConnectionState.CONNECTED) # update spin bit if ( header.packet_type == QuicPacketType.ONE_RTT and packet_number > self._spin_highest_pn ): spin_bit = get_spin_bit(plain_header[0]) if self._is_client: self._spin_bit = not spin_bit else: self._spin_bit = spin_bit self._spin_highest_pn = packet_number if self._quic_logger is not None: self._quic_logger.log_event( category="connectivity", event="spin_bit_updated", data={"state": self._spin_bit}, ) # handle payload context = QuicReceiveContext( epoch=epoch, host_cid=header.destination_cid, network_path=network_path, quic_logger_frames=quic_logger_frames, time=now, version=header.version, ) try: is_ack_eliciting, is_probing = self._payload_received( context, plain_payload, crypto_frame_required=crypto_frame_required ) except QuicConnectionError as exc: self._logger.warning(exc) self.close( error_code=exc.error_code, frame_type=exc.frame_type, 
reason_phrase=exc.reason_phrase, ) if self._state in END_STATES or self._close_pending: return # update idle timeout self._close_at = now + self._idle_timeout() # handle migration if ( not self._is_client and context.host_cid != self.host_cid and epoch == tls.Epoch.ONE_RTT ): self._logger.debug( "Peer switching to CID %s (%d)", dump_cid(context.host_cid), destination_cid_seq, ) self.host_cid = context.host_cid self.change_connection_id() # update network path if not network_path.is_validated and epoch == tls.Epoch.HANDSHAKE: self._logger.debug( "Network path %s validated by handshake", network_path.addr ) network_path.is_validated = True if network_path not in self._network_paths: self._network_paths.append(network_path) idx = self._network_paths.index(network_path) if idx and not is_probing and packet_number > space.largest_received_packet: self._logger.debug("Network path %s promoted", network_path.addr) self._network_paths.pop(idx) self._network_paths.insert(0, network_path) # record packet as received if not space.discarded: if packet_number > space.largest_received_packet: space.largest_received_packet = packet_number space.largest_received_time = now space.ack_queue.add(packet_number) if is_ack_eliciting and space.ack_at is None: space.ack_at = now + self._ack_delay def request_key_update(self) -> None: """ Request an update of the encryption keys. .. aioquic_transmit:: """ assert self._handshake_complete, "cannot change key before handshake completes" self._cryptos[tls.Epoch.ONE_RTT].update_key() def reset_stream(self, stream_id: int, error_code: int) -> None: """ Abruptly terminate the sending part of a stream. .. aioquic_transmit:: :param stream_id: The stream's ID. :param error_code: An error code indicating why the stream is being reset. """ stream = self._get_or_create_stream_for_send(stream_id) stream.sender.reset(error_code) def send_ping(self, uid: int) -> None: """ Send a PING frame to the peer. .. aioquic_transmit:: :param uid: A unique ID for this PING. """ self._ping_pending.append(uid) def send_datagram_frame(self, data: bytes) -> None: """ Send a DATAGRAM frame. .. aioquic_transmit:: :param data: The data to be sent. """ self._datagrams_pending.append(data) def send_stream_data( self, stream_id: int, data: bytes, end_stream: bool = False ) -> None: """ Send data on the specific stream. .. aioquic_transmit:: :param stream_id: The stream's ID. :param data: The data to be sent. :param end_stream: If set to `True`, the FIN bit will be set. """ stream = self._get_or_create_stream_for_send(stream_id) stream.sender.write(data, end_stream=end_stream) def stop_stream(self, stream_id: int, error_code: int) -> None: """ Request termination of the receiving part of a stream. .. aioquic_transmit:: :param stream_id: The stream's ID. :param error_code: An error code indicating why the stream is being stopped. """ if not self._stream_can_receive(stream_id): raise ValueError( "Cannot stop receiving on a local-initiated unidirectional stream" ) stream = self._streams.get(stream_id, None) if stream is None: raise ValueError("Cannot stop receiving on an unknown stream") stream.receiver.stop(error_code) # Private def _alpn_handler(self, alpn_protocol: str) -> None: """ Callback which is invoked by the TLS engine at most once, when the ALPN negotiation completes. At this point, TLS extensions have been received so we can parse the transport parameters. """ # Parse the remote transport parameters. 
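# Illustrative application-side sketch (hypothetical peer address; not part
# of this module): the send_* methods above only queue frames. In aioquic's
# sans-I/O design nothing reaches the network until the application collects
# the pending datagrams itself, e.g.:
def _example_client_send(now: float) -> None:
    from aioquic.quic.configuration import QuicConfiguration
    from aioquic.quic.connection import QuicConnection

    conn = QuicConnection(configuration=QuicConfiguration(is_client=True))
    conn.connect(("203.0.113.1", 4433), now=now)  # hypothetical address
    stream_id = conn.get_next_available_stream_id()
    conn.send_stream_data(stream_id, b"hello", end_stream=True)
    for datagram, addr in conn.datagrams_to_send(now=now):
        ...  # hand each datagram to a UDP socket aimed at addr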
for ext_type, ext_data in self.tls.received_extensions: if ext_type == tls.ExtensionType.QUIC_TRANSPORT_PARAMETERS: self._parse_transport_parameters(ext_data) break else: raise QuicConnectionError( error_code=QuicErrorCode.CRYPTO_ERROR + tls.AlertDescription.missing_extension, frame_type=self._crypto_frame_type, reason_phrase="No QUIC transport parameters received", ) # For servers, determine the Negotiated Version. if not self._is_client and not self._version_negotiated_compatible: if self._remote_version_information is not None: # Pick the first version we support in the client's available versions, # which is compatible with the current version. for version in self._remote_version_information.available_versions: if version == self._version: # Stay with the current version. break elif ( version in self._configuration.supported_versions and is_version_compatible(self._version, version) ): # Change version. self._version = version self._cryptos[tls.Epoch.INITIAL] = self._cryptos_initial[ version ] # Update our transport parameters to reflect the chosen version. self.tls.handshake_extensions = [ ( tls.ExtensionType.QUIC_TRANSPORT_PARAMETERS, self._serialize_transport_parameters(), ) ] break self._version_negotiated_compatible = True self._logger.info( "Negotiated protocol version %s", pretty_protocol_version(self._version) ) # Notify the application. self._events.append(events.ProtocolNegotiated(alpn_protocol=alpn_protocol)) def _assert_stream_can_receive(self, frame_type: int, stream_id: int) -> None: """ Check the specified stream can receive data or raises a QuicConnectionError. """ if not self._stream_can_receive(stream_id): raise QuicConnectionError( error_code=QuicErrorCode.STREAM_STATE_ERROR, frame_type=frame_type, reason_phrase="Stream is send-only", ) def _assert_stream_can_send(self, frame_type: int, stream_id: int) -> None: """ Check the specified stream can send data or raises a QuicConnectionError. """ if not self._stream_can_send(stream_id): raise QuicConnectionError( error_code=QuicErrorCode.STREAM_STATE_ERROR, frame_type=frame_type, reason_phrase="Stream is receive-only", ) def _consume_peer_cid(self) -> None: """ Update the destination connection ID by taking the next available connection ID provided by the peer. """ self._peer_cid = self._peer_cid_available.pop(0) self._logger.debug( "Switching to CID %s (%d)", dump_cid(self._peer_cid.cid), self._peer_cid.sequence_number, ) def _close_begin(self, is_initiator: bool, now: float) -> None: """ Begin the close procedure. """ self._close_at = now + 3 * self._loss.get_probe_timeout() if is_initiator: self._set_state(QuicConnectionState.CLOSING) else: self._set_state(QuicConnectionState.DRAINING) def _close_end(self) -> None: """ End the close procedure. """ self._close_at = None for epoch in self._spaces.keys(): self._discard_epoch(epoch) self._events.append(self._close_event) self._set_state(QuicConnectionState.TERMINATED) # signal log end if self._quic_logger is not None: self._configuration.quic_logger.end_trace(self._quic_logger) self._quic_logger = None def _connect(self, now: float) -> None: """ Start the client handshake. 
""" assert self._is_client if self._quic_logger is not None: self._quic_logger.log_event( category="transport", event="version_information", data={ "client_versions": self._configuration.supported_versions, "chosen_version": self._version, }, ) self._quic_logger.log_event( category="transport", event="alpn_information", data={"client_alpns": self._configuration.alpn_protocols}, ) self._close_at = now + self._idle_timeout() self._initialize(self._peer_cid.cid) self.tls.handle_message(b"", self._crypto_buffers) self._push_crypto_data() def _discard_epoch(self, epoch: tls.Epoch) -> None: if not self._spaces[epoch].discarded: self._logger.debug("Discarding epoch %s", epoch) self._cryptos[epoch].teardown() if epoch == tls.Epoch.INITIAL: # Tear the crypto pairs, but do not log the event, # to avoid duplicate log entries. for crypto in self._cryptos_initial.values(): crypto.recv._teardown_cb = NoCallback crypto.send._teardown_cb = NoCallback crypto.teardown() self._loss.discard_space(self._spaces[epoch]) self._spaces[epoch].discarded = True def _find_network_path(self, addr: NetworkAddress) -> QuicNetworkPath: # check existing network paths for idx, network_path in enumerate(self._network_paths): if network_path.addr == addr: return network_path # new network path network_path = QuicNetworkPath(addr) self._logger.debug("Network path %s discovered", network_path.addr) return network_path def _get_or_create_stream(self, frame_type: int, stream_id: int) -> QuicStream: """ Get or create a stream in response to a received frame. """ if stream_id in self._streams_finished: # the stream was created, but its state was since discarded raise StreamFinishedError stream = self._streams.get(stream_id, None) if stream is None: # check initiator if stream_is_client_initiated(stream_id) == self._is_client: raise QuicConnectionError( error_code=QuicErrorCode.STREAM_STATE_ERROR, frame_type=frame_type, reason_phrase="Wrong stream initiator", ) # determine limits if stream_is_unidirectional(stream_id): max_stream_data_local = self._local_max_stream_data_uni max_stream_data_remote = 0 max_streams = self._local_max_streams_uni else: max_stream_data_local = self._local_max_stream_data_bidi_remote max_stream_data_remote = self._remote_max_stream_data_bidi_local max_streams = self._local_max_streams_bidi # check max streams stream_count = (stream_id // 4) + 1 if stream_count > max_streams.value: raise QuicConnectionError( error_code=QuicErrorCode.STREAM_LIMIT_ERROR, frame_type=frame_type, reason_phrase="Too many streams open", ) elif stream_count > max_streams.used: max_streams.used = stream_count # create stream self._logger.debug("Stream %d created by peer" % stream_id) stream = self._streams[stream_id] = QuicStream( stream_id=stream_id, max_stream_data_local=max_stream_data_local, max_stream_data_remote=max_stream_data_remote, writable=not stream_is_unidirectional(stream_id), ) self._streams_queue.append(stream) return stream def _get_or_create_stream_for_send(self, stream_id: int) -> QuicStream: """ Get or create a QUIC stream in order to send data to the peer. This always occurs as a result of an API call. 
""" if not self._stream_can_send(stream_id): raise ValueError("Cannot send data on peer-initiated unidirectional stream") stream = self._streams.get(stream_id, None) if stream is None: # check initiator if stream_is_client_initiated(stream_id) != self._is_client: raise ValueError("Cannot send data on unknown peer-initiated stream") # determine limits if stream_is_unidirectional(stream_id): max_stream_data_local = 0 max_stream_data_remote = self._remote_max_stream_data_uni max_streams = self._remote_max_streams_uni streams_blocked = self._streams_blocked_uni else: max_stream_data_local = self._local_max_stream_data_bidi_local max_stream_data_remote = self._remote_max_stream_data_bidi_remote max_streams = self._remote_max_streams_bidi streams_blocked = self._streams_blocked_bidi # create stream is_unidirectional = stream_is_unidirectional(stream_id) stream = self._streams[stream_id] = QuicStream( stream_id=stream_id, max_stream_data_local=max_stream_data_local, max_stream_data_remote=max_stream_data_remote, readable=not is_unidirectional, ) self._streams_queue.append(stream) if is_unidirectional: self._local_next_stream_id_uni = stream_id + 4 else: self._local_next_stream_id_bidi = stream_id + 4 # mark stream as blocked if needed if stream_id // 4 >= max_streams: stream.is_blocked = True streams_blocked.append(stream) self._streams_blocked_pending = True return stream def _handle_session_ticket(self, session_ticket: tls.SessionTicket) -> None: if ( session_ticket.max_early_data_size is not None and session_ticket.max_early_data_size != MAX_EARLY_DATA ): raise QuicConnectionError( error_code=QuicErrorCode.PROTOCOL_VIOLATION, frame_type=QuicFrameType.CRYPTO, reason_phrase="Invalid max_early_data value %s" % session_ticket.max_early_data_size, ) self._session_ticket_handler(session_ticket) def _initialize(self, peer_cid: bytes) -> None: # TLS self.tls = tls.Context( alpn_protocols=self._configuration.alpn_protocols, cadata=self._configuration.cadata, cafile=self._configuration.cafile, capath=self._configuration.capath, cipher_suites=self.configuration.cipher_suites, is_client=self._is_client, logger=self._logger, max_early_data=None if self._is_client else MAX_EARLY_DATA, server_name=self._configuration.server_name, verify_mode=self._configuration.verify_mode, ) self.tls.certificate = self._configuration.certificate self.tls.certificate_chain = self._configuration.certificate_chain self.tls.certificate_private_key = self._configuration.private_key self.tls.handshake_extensions = [ ( tls.ExtensionType.QUIC_TRANSPORT_PARAMETERS, self._serialize_transport_parameters(), ) ] # TLS session resumption session_ticket = self._configuration.session_ticket if ( self._is_client and session_ticket is not None and session_ticket.is_valid and session_ticket.server_name == self._configuration.server_name ): self.tls.session_ticket = self._configuration.session_ticket # parse saved QUIC transport parameters - for 0-RTT if session_ticket.max_early_data_size == MAX_EARLY_DATA: for ext_type, ext_data in session_ticket.other_extensions: if ext_type == tls.ExtensionType.QUIC_TRANSPORT_PARAMETERS: self._parse_transport_parameters( ext_data, from_session_ticket=True ) break # TLS callbacks self.tls.alpn_cb = self._alpn_handler if self._session_ticket_fetcher is not None: self.tls.get_session_ticket_cb = self._session_ticket_fetcher if self._session_ticket_handler is not None: self.tls.new_session_ticket_cb = self._handle_session_ticket self.tls.update_traffic_key_cb = self._update_traffic_key # packet spaces def 
create_crypto_pair(epoch: tls.Epoch) -> CryptoPair: epoch_name = ["initial", "0rtt", "handshake", "1rtt"][epoch.value] secret_names = [ "server_%s_secret" % epoch_name, "client_%s_secret" % epoch_name, ] recv_secret_name = secret_names[not self._is_client] send_secret_name = secret_names[self._is_client] return CryptoPair( recv_setup_cb=partial(self._log_key_updated, recv_secret_name), recv_teardown_cb=partial(self._log_key_retired, recv_secret_name), send_setup_cb=partial(self._log_key_updated, send_secret_name), send_teardown_cb=partial(self._log_key_retired, send_secret_name), ) # To enable version negotiation, setup encryption keys for all # our supported versions. self._cryptos_initial = {} for version in self._configuration.supported_versions: pair = CryptoPair() pair.setup_initial(cid=peer_cid, is_client=self._is_client, version=version) self._cryptos_initial[version] = pair self._cryptos = dict( (epoch, create_crypto_pair(epoch)) for epoch in ( tls.Epoch.ZERO_RTT, tls.Epoch.HANDSHAKE, tls.Epoch.ONE_RTT, ) ) self._cryptos[tls.Epoch.INITIAL] = self._cryptos_initial[self._version] self._crypto_buffers = { tls.Epoch.INITIAL: Buffer(capacity=CRYPTO_BUFFER_SIZE), tls.Epoch.HANDSHAKE: Buffer(capacity=CRYPTO_BUFFER_SIZE), tls.Epoch.ONE_RTT: Buffer(capacity=CRYPTO_BUFFER_SIZE), } self._crypto_streams = { tls.Epoch.INITIAL: QuicStream(), tls.Epoch.HANDSHAKE: QuicStream(), tls.Epoch.ONE_RTT: QuicStream(), } self._spaces = { tls.Epoch.INITIAL: QuicPacketSpace(), tls.Epoch.HANDSHAKE: QuicPacketSpace(), tls.Epoch.ONE_RTT: QuicPacketSpace(), } self._loss.spaces = list(self._spaces.values()) def _handle_ack_frame( self, context: QuicReceiveContext, frame_type: int, buf: Buffer ) -> None: """ Handle an ACK frame. """ ack_rangeset, ack_delay_encoded = pull_ack_frame(buf) if frame_type == QuicFrameType.ACK_ECN: buf.pull_uint_var() buf.pull_uint_var() buf.pull_uint_var() ack_delay = (ack_delay_encoded << self._remote_ack_delay_exponent) / 1000000 # log frame if self._quic_logger is not None: context.quic_logger_frames.append( self._quic_logger.encode_ack_frame(ack_rangeset, ack_delay) ) # check whether peer completed address validation if not self._loss.peer_completed_address_validation and context.epoch in ( tls.Epoch.HANDSHAKE, tls.Epoch.ONE_RTT, ): self._loss.peer_completed_address_validation = True self._loss.on_ack_received( ack_rangeset=ack_rangeset, ack_delay=ack_delay, now=context.time, space=self._spaces[context.epoch], ) def _handle_connection_close_frame( self, context: QuicReceiveContext, frame_type: int, buf: Buffer ) -> None: """ Handle a CONNECTION_CLOSE frame. """ error_code = buf.pull_uint_var() if frame_type == QuicFrameType.TRANSPORT_CLOSE: frame_type = buf.pull_uint_var() else: frame_type = None reason_length = buf.pull_uint_var() try: reason_phrase = buf.pull_bytes(reason_length).decode("utf8") except UnicodeDecodeError: reason_phrase = "" # log frame if self._quic_logger is not None: context.quic_logger_frames.append( self._quic_logger.encode_connection_close_frame( error_code=error_code, frame_type=frame_type, reason_phrase=reason_phrase, ) ) self._logger.info( "Connection close received (code 0x%X, reason %s)", error_code, reason_phrase, ) if self._close_event is None: self._close_event = events.ConnectionTerminated( error_code=error_code, frame_type=frame_type, reason_phrase=reason_phrase, ) self._close_begin(is_initiator=False, now=context.time) def _handle_crypto_frame( self, context: QuicReceiveContext, frame_type: int, buf: Buffer ) -> None: """ Handle a CRYPTO frame. 
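A quick worked example (illustrative) of the bound enforced below: CRYPTO
offsets and lengths are variable-length integers, so their sum may not
exceed 2**62 - 1.

>>> offset, length = 2**62 - 8, 16
>>> offset + length > 2**62 - 1
True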
""" offset = buf.pull_uint_var() length = buf.pull_uint_var() if offset + length > UINT_VAR_MAX: raise QuicConnectionError( error_code=QuicErrorCode.FRAME_ENCODING_ERROR, frame_type=frame_type, reason_phrase="offset + length cannot exceed 2^62 - 1", ) frame = QuicStreamFrame(offset=offset, data=buf.pull_bytes(length)) # Log the frame. if self._quic_logger is not None: context.quic_logger_frames.append( self._quic_logger.encode_crypto_frame(frame) ) stream = self._crypto_streams[context.epoch] pending = offset + length - stream.receiver.starting_offset() if pending > MAX_PENDING_CRYPTO: raise QuicConnectionError( error_code=QuicErrorCode.CRYPTO_BUFFER_EXCEEDED, frame_type=frame_type, reason_phrase="too much crypto buffering", ) event = stream.receiver.handle_frame(frame) if event is not None: # Pass data to TLS layer, which may cause calls to: # - _alpn_handler # - _update_traffic_key self._crypto_frame_type = frame_type self._crypto_packet_version = context.version try: self.tls.handle_message(event.data, self._crypto_buffers) self._push_crypto_data() except tls.Alert as exc: raise QuicConnectionError( error_code=QuicErrorCode.CRYPTO_ERROR + int(exc.description), frame_type=frame_type, reason_phrase=str(exc), ) # Update the current epoch. if not self._handshake_complete and self.tls.state in [ tls.State.CLIENT_POST_HANDSHAKE, tls.State.SERVER_POST_HANDSHAKE, ]: self._handshake_complete = True # for servers, the handshake is now confirmed if not self._is_client: self._discard_epoch(tls.Epoch.HANDSHAKE) self._handshake_confirmed = True self._handshake_done_pending = True self._replenish_connection_ids() self._events.append( events.HandshakeCompleted( alpn_protocol=self.tls.alpn_negotiated, early_data_accepted=self.tls.early_data_accepted, session_resumed=self.tls.session_resumed, ) ) self._unblock_streams(is_unidirectional=False) self._unblock_streams(is_unidirectional=True) self._logger.info( "ALPN negotiated protocol %s", self.tls.alpn_negotiated ) else: self._logger.info( "Duplicate CRYPTO data received for epoch %s", context.epoch ) # If a server receives duplicate CRYPTO in an INITIAL packet, # it can assume the client did not receive the server's CRYPTO. if ( not self._is_client and context.epoch == tls.Epoch.INITIAL and not self._crypto_retransmitted ): self._loss.reschedule_data(now=context.time) self._crypto_retransmitted = True def _handle_data_blocked_frame( self, context: QuicReceiveContext, frame_type: int, buf: Buffer ) -> None: """ Handle a DATA_BLOCKED frame. """ limit = buf.pull_uint_var() # log frame if self._quic_logger is not None: context.quic_logger_frames.append( self._quic_logger.encode_data_blocked_frame(limit=limit) ) def _handle_datagram_frame( self, context: QuicReceiveContext, frame_type: int, buf: Buffer ) -> None: """ Handle a DATAGRAM frame. 
""" start = buf.tell() if frame_type == QuicFrameType.DATAGRAM_WITH_LENGTH: length = buf.pull_uint_var() else: length = buf.capacity - start data = buf.pull_bytes(length) # log frame if self._quic_logger is not None: context.quic_logger_frames.append( self._quic_logger.encode_datagram_frame(length=length) ) # check frame is allowed if ( self._configuration.max_datagram_frame_size is None or buf.tell() - start >= self._configuration.max_datagram_frame_size ): raise QuicConnectionError( error_code=QuicErrorCode.PROTOCOL_VIOLATION, frame_type=frame_type, reason_phrase="Unexpected DATAGRAM frame", ) self._events.append(events.DatagramFrameReceived(data=data)) def _handle_handshake_done_frame( self, context: QuicReceiveContext, frame_type: int, buf: Buffer ) -> None: """ Handle a HANDSHAKE_DONE frame. """ # log frame if self._quic_logger is not None: context.quic_logger_frames.append( self._quic_logger.encode_handshake_done_frame() ) if not self._is_client: raise QuicConnectionError( error_code=QuicErrorCode.PROTOCOL_VIOLATION, frame_type=frame_type, reason_phrase="Clients must not send HANDSHAKE_DONE frames", ) # for clients, the handshake is now confirmed if not self._handshake_confirmed: self._discard_epoch(tls.Epoch.HANDSHAKE) self._handshake_confirmed = True self._loss.peer_completed_address_validation = True def _handle_max_data_frame( self, context: QuicReceiveContext, frame_type: int, buf: Buffer ) -> None: """ Handle a MAX_DATA frame. This adjusts the total amount of we can send to the peer. """ max_data = buf.pull_uint_var() # log frame if self._quic_logger is not None: context.quic_logger_frames.append( self._quic_logger.encode_connection_limit_frame( frame_type=frame_type, maximum=max_data ) ) if max_data > self._remote_max_data: self._logger.debug("Remote max_data raised to %d", max_data) self._remote_max_data = max_data def _handle_max_stream_data_frame( self, context: QuicReceiveContext, frame_type: int, buf: Buffer ) -> None: """ Handle a MAX_STREAM_DATA frame. This adjusts the amount of data we can send on a specific stream. """ stream_id = buf.pull_uint_var() max_stream_data = buf.pull_uint_var() # log frame if self._quic_logger is not None: context.quic_logger_frames.append( self._quic_logger.encode_max_stream_data_frame( maximum=max_stream_data, stream_id=stream_id ) ) # check stream direction self._assert_stream_can_send(frame_type, stream_id) stream = self._get_or_create_stream(frame_type, stream_id) if max_stream_data > stream.max_stream_data_remote: self._logger.debug( "Stream %d remote max_stream_data raised to %d", stream_id, max_stream_data, ) stream.max_stream_data_remote = max_stream_data def _handle_max_streams_bidi_frame( self, context: QuicReceiveContext, frame_type: int, buf: Buffer ) -> None: """ Handle a MAX_STREAMS_BIDI frame. This raises number of bidirectional streams we can initiate to the peer. 
""" max_streams = buf.pull_uint_var() if max_streams > STREAM_COUNT_MAX: raise QuicConnectionError( error_code=QuicErrorCode.FRAME_ENCODING_ERROR, frame_type=frame_type, reason_phrase="Maximum Streams cannot exceed 2^60", ) # log frame if self._quic_logger is not None: context.quic_logger_frames.append( self._quic_logger.encode_connection_limit_frame( frame_type=frame_type, maximum=max_streams ) ) if max_streams > self._remote_max_streams_bidi: self._logger.debug("Remote max_streams_bidi raised to %d", max_streams) self._remote_max_streams_bidi = max_streams self._unblock_streams(is_unidirectional=False) def _handle_max_streams_uni_frame( self, context: QuicReceiveContext, frame_type: int, buf: Buffer ) -> None: """ Handle a MAX_STREAMS_UNI frame. This raises number of unidirectional streams we can initiate to the peer. """ max_streams = buf.pull_uint_var() if max_streams > STREAM_COUNT_MAX: raise QuicConnectionError( error_code=QuicErrorCode.FRAME_ENCODING_ERROR, frame_type=frame_type, reason_phrase="Maximum Streams cannot exceed 2^60", ) # log frame if self._quic_logger is not None: context.quic_logger_frames.append( self._quic_logger.encode_connection_limit_frame( frame_type=frame_type, maximum=max_streams ) ) if max_streams > self._remote_max_streams_uni: self._logger.debug("Remote max_streams_uni raised to %d", max_streams) self._remote_max_streams_uni = max_streams self._unblock_streams(is_unidirectional=True) def _handle_new_connection_id_frame( self, context: QuicReceiveContext, frame_type: int, buf: Buffer ) -> None: """ Handle a NEW_CONNECTION_ID frame. """ sequence_number = buf.pull_uint_var() retire_prior_to = buf.pull_uint_var() length = buf.pull_uint8() connection_id = buf.pull_bytes(length) stateless_reset_token = buf.pull_bytes(STATELESS_RESET_TOKEN_SIZE) if not connection_id or len(connection_id) > CONNECTION_ID_MAX_SIZE: raise QuicConnectionError( error_code=QuicErrorCode.FRAME_ENCODING_ERROR, frame_type=frame_type, reason_phrase="Length must be greater than 0 and less than 20", ) # log frame if self._quic_logger is not None: context.quic_logger_frames.append( self._quic_logger.encode_new_connection_id_frame( connection_id=connection_id, retire_prior_to=retire_prior_to, sequence_number=sequence_number, stateless_reset_token=stateless_reset_token, ) ) # sanity check if retire_prior_to > sequence_number: raise QuicConnectionError( error_code=QuicErrorCode.PROTOCOL_VIOLATION, frame_type=frame_type, reason_phrase="Retire Prior To is greater than Sequence Number", ) # only accept retire_prior_to if it is bigger than the one we know self._peer_retire_prior_to = max(retire_prior_to, self._peer_retire_prior_to) # determine which CIDs to retire change_cid = False retire = [ cid for cid in self._peer_cid_available if cid.sequence_number < self._peer_retire_prior_to ] if self._peer_cid.sequence_number < self._peer_retire_prior_to: change_cid = True retire.insert(0, self._peer_cid) # update available CIDs self._peer_cid_available = [ cid for cid in self._peer_cid_available if cid.sequence_number >= self._peer_retire_prior_to ] if ( sequence_number >= self._peer_retire_prior_to and sequence_number not in self._peer_cid_sequence_numbers ): self._peer_cid_available.append( QuicConnectionId( cid=connection_id, sequence_number=sequence_number, stateless_reset_token=stateless_reset_token, ) ) self._peer_cid_sequence_numbers.add(sequence_number) # retire previous CIDs for quic_connection_id in retire: self._retire_peer_cid(quic_connection_id) # assign new CID if we retired the active one if 
change_cid: self._consume_peer_cid() # check number of active connection IDs, including the selected one if 1 + len(self._peer_cid_available) > self._local_active_connection_id_limit: raise QuicConnectionError( error_code=QuicErrorCode.CONNECTION_ID_LIMIT_ERROR, frame_type=frame_type, reason_phrase="Too many active connection IDs", ) # Check the number of retired connection IDs pending, though with a safer limit # than the 2x recommended in section 5.1.2 of the RFC. Note that we are doing # the check here and not in _retire_peer_cid() because we know the frame type to # use here, and because it is the new connection id path that is potentially # dangerous. We may transiently go a bit over the limit due to unacked frames # getting added back to the list, but that's ok as it is bounded. if len(self._retire_connection_ids) > min( self._local_active_connection_id_limit * 4, MAX_PENDING_RETIRES ): raise QuicConnectionError( error_code=QuicErrorCode.CONNECTION_ID_LIMIT_ERROR, frame_type=frame_type, reason_phrase="Too many pending retired connection IDs", ) def _handle_new_token_frame( self, context: QuicReceiveContext, frame_type: int, buf: Buffer ) -> None: """ Handle a NEW_TOKEN frame. """ length = buf.pull_uint_var() token = buf.pull_bytes(length) # log frame if self._quic_logger is not None: context.quic_logger_frames.append( self._quic_logger.encode_new_token_frame(token=token) ) if not self._is_client: raise QuicConnectionError( error_code=QuicErrorCode.PROTOCOL_VIOLATION, frame_type=frame_type, reason_phrase="Clients must not send NEW_TOKEN frames", ) if self._token_handler is not None: self._token_handler(token) def _handle_padding_frame( self, context: QuicReceiveContext, frame_type: int, buf: Buffer ) -> None: """ Handle a PADDING frame. """ # consume padding pos = buf.tell() for byte in buf.data_slice(pos, buf.capacity): if byte: break pos += 1 buf.seek(pos) # log frame if self._quic_logger is not None: context.quic_logger_frames.append(self._quic_logger.encode_padding_frame()) def _handle_path_challenge_frame( self, context: QuicReceiveContext, frame_type: int, buf: Buffer ) -> None: """ Handle a PATH_CHALLENGE frame. """ data = buf.pull_bytes(8) # log frame if self._quic_logger is not None: context.quic_logger_frames.append( self._quic_logger.encode_path_challenge_frame(data=data) ) context.network_path.remote_challenges.append(data) def _handle_path_response_frame( self, context: QuicReceiveContext, frame_type: int, buf: Buffer ) -> None: """ Handle a PATH_RESPONSE frame. """ data = buf.pull_bytes(8) # log frame if self._quic_logger is not None: context.quic_logger_frames.append( self._quic_logger.encode_path_response_frame(data=data) ) try: network_path = self._local_challenges.pop(data) except KeyError: raise QuicConnectionError( error_code=QuicErrorCode.PROTOCOL_VIOLATION, frame_type=frame_type, reason_phrase="Response does not match challenge", ) self._logger.debug("Network path %s validated by challenge", network_path.addr) network_path.is_validated = True def _handle_ping_frame( self, context: QuicReceiveContext, frame_type: int, buf: Buffer ) -> None: """ Handle a PING frame. """ # log frame if self._quic_logger is not None: context.quic_logger_frames.append(self._quic_logger.encode_ping_frame()) def _handle_reset_stream_frame( self, context: QuicReceiveContext, frame_type: int, buf: Buffer ) -> None: """ Handle a RESET_STREAM frame. 
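The final size announced by a RESET_STREAM still counts against flow
control; only the bytes beyond what was already received are newly charged.
A worked example (illustrative) of the accounting used below:

>>> final_size, highest_offset = 150, 100
>>> max(0, final_size - highest_offset)
50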
""" stream_id = buf.pull_uint_var() error_code = buf.pull_uint_var() final_size = buf.pull_uint_var() # log frame if self._quic_logger is not None: context.quic_logger_frames.append( self._quic_logger.encode_reset_stream_frame( error_code=error_code, final_size=final_size, stream_id=stream_id ) ) # check stream direction self._assert_stream_can_receive(frame_type, stream_id) # check flow-control limits stream = self._get_or_create_stream(frame_type, stream_id) if final_size > stream.max_stream_data_local: raise QuicConnectionError( error_code=QuicErrorCode.FLOW_CONTROL_ERROR, frame_type=frame_type, reason_phrase="Over stream data limit", ) newly_received = max(0, final_size - stream.receiver.highest_offset) if self._local_max_data.used + newly_received > self._local_max_data.value: raise QuicConnectionError( error_code=QuicErrorCode.FLOW_CONTROL_ERROR, frame_type=frame_type, reason_phrase="Over connection data limit", ) # process reset self._logger.info( "Stream %d reset by peer (error code %d, final size %d)", stream_id, error_code, final_size, ) try: event = stream.receiver.handle_reset( error_code=error_code, final_size=final_size ) except FinalSizeError as exc: raise QuicConnectionError( error_code=QuicErrorCode.FINAL_SIZE_ERROR, frame_type=frame_type, reason_phrase=str(exc), ) if event is not None: self._events.append(event) self._local_max_data.used += newly_received def _handle_retire_connection_id_frame( self, context: QuicReceiveContext, frame_type: int, buf: Buffer ) -> None: """ Handle a RETIRE_CONNECTION_ID frame. """ sequence_number = buf.pull_uint_var() # log frame if self._quic_logger is not None: context.quic_logger_frames.append( self._quic_logger.encode_retire_connection_id_frame(sequence_number) ) if sequence_number >= self._host_cid_seq: raise QuicConnectionError( error_code=QuicErrorCode.PROTOCOL_VIOLATION, frame_type=frame_type, reason_phrase="Cannot retire unknown connection ID", ) # find the connection ID by sequence number for index, connection_id in enumerate(self._host_cids): if connection_id.sequence_number == sequence_number: if connection_id.cid == context.host_cid: raise QuicConnectionError( error_code=QuicErrorCode.PROTOCOL_VIOLATION, frame_type=frame_type, reason_phrase="Cannot retire current connection ID", ) self._logger.debug( "Peer retiring CID %s (%d)", dump_cid(connection_id.cid), connection_id.sequence_number, ) del self._host_cids[index] self._events.append( events.ConnectionIdRetired(connection_id=connection_id.cid) ) break # issue a new connection ID self._replenish_connection_ids() def _handle_stop_sending_frame( self, context: QuicReceiveContext, frame_type: int, buf: Buffer ) -> None: """ Handle a STOP_SENDING frame. """ stream_id = buf.pull_uint_var() error_code = buf.pull_uint_var() # application error code # log frame if self._quic_logger is not None: context.quic_logger_frames.append( self._quic_logger.encode_stop_sending_frame( error_code=error_code, stream_id=stream_id ) ) # check stream direction self._assert_stream_can_send(frame_type, stream_id) # reset the stream stream = self._get_or_create_stream(frame_type, stream_id) stream.sender.reset(error_code=QuicErrorCode.NO_ERROR) self._events.append( events.StopSendingReceived(error_code=error_code, stream_id=stream_id) ) def _handle_stream_frame( self, context: QuicReceiveContext, frame_type: int, buf: Buffer ) -> None: """ Handle a STREAM frame. 
""" stream_id = buf.pull_uint_var() if frame_type & 4: offset = buf.pull_uint_var() else: offset = 0 if frame_type & 2: length = buf.pull_uint_var() else: length = buf.capacity - buf.tell() if offset + length > UINT_VAR_MAX: raise QuicConnectionError( error_code=QuicErrorCode.FRAME_ENCODING_ERROR, frame_type=frame_type, reason_phrase="offset + length cannot exceed 2^62 - 1", ) frame = QuicStreamFrame( offset=offset, data=buf.pull_bytes(length), fin=bool(frame_type & 1) ) # log frame if self._quic_logger is not None: context.quic_logger_frames.append( self._quic_logger.encode_stream_frame(frame, stream_id=stream_id) ) # check stream direction self._assert_stream_can_receive(frame_type, stream_id) # check flow-control limits stream = self._get_or_create_stream(frame_type, stream_id) if offset + length > stream.max_stream_data_local: raise QuicConnectionError( error_code=QuicErrorCode.FLOW_CONTROL_ERROR, frame_type=frame_type, reason_phrase="Over stream data limit", ) newly_received = max(0, offset + length - stream.receiver.highest_offset) if self._local_max_data.used + newly_received > self._local_max_data.value: raise QuicConnectionError( error_code=QuicErrorCode.FLOW_CONTROL_ERROR, frame_type=frame_type, reason_phrase="Over connection data limit", ) # process data try: event = stream.receiver.handle_frame(frame) except FinalSizeError as exc: raise QuicConnectionError( error_code=QuicErrorCode.FINAL_SIZE_ERROR, frame_type=frame_type, reason_phrase=str(exc), ) if event is not None: self._events.append(event) self._local_max_data.used += newly_received def _handle_stream_data_blocked_frame( self, context: QuicReceiveContext, frame_type: int, buf: Buffer ) -> None: """ Handle a STREAM_DATA_BLOCKED frame. """ stream_id = buf.pull_uint_var() limit = buf.pull_uint_var() # log frame if self._quic_logger is not None: context.quic_logger_frames.append( self._quic_logger.encode_stream_data_blocked_frame( limit=limit, stream_id=stream_id ) ) # check stream direction self._assert_stream_can_receive(frame_type, stream_id) self._get_or_create_stream(frame_type, stream_id) def _handle_streams_blocked_frame( self, context: QuicReceiveContext, frame_type: int, buf: Buffer ) -> None: """ Handle a STREAMS_BLOCKED frame. """ limit = buf.pull_uint_var() if limit > STREAM_COUNT_MAX: raise QuicConnectionError( error_code=QuicErrorCode.FRAME_ENCODING_ERROR, frame_type=frame_type, reason_phrase="Maximum Streams cannot exceed 2^60", ) # log frame if self._quic_logger is not None: context.quic_logger_frames.append( self._quic_logger.encode_streams_blocked_frame( is_unidirectional=frame_type == QuicFrameType.STREAMS_BLOCKED_UNI, limit=limit, ) ) def _log_key_retired(self, key_type: str, trigger: str) -> None: """ Log a key retirement. """ if self._quic_logger is not None: self._quic_logger.log_event( category="security", event="key_retired", data={"key_type": key_type, "trigger": trigger}, ) def _log_key_updated(self, key_type: str, trigger: str) -> None: """ Log a key update. """ if self._quic_logger is not None: self._quic_logger.log_event( category="security", event="key_updated", data={"key_type": key_type, "trigger": trigger}, ) def _on_ack_delivery( self, delivery: QuicDeliveryState, space: QuicPacketSpace, highest_acked: int ) -> None: """ Callback when an ACK frame is acknowledged or lost. 
""" if delivery == QuicDeliveryState.ACKED: space.ack_queue.subtract(0, highest_acked + 1) def _on_connection_limit_delivery( self, delivery: QuicDeliveryState, limit: Limit ) -> None: """ Callback when a MAX_DATA or MAX_STREAMS frame is acknowledged or lost. """ if delivery != QuicDeliveryState.ACKED: limit.sent = 0 def _on_handshake_done_delivery(self, delivery: QuicDeliveryState) -> None: """ Callback when a HANDSHAKE_DONE frame is acknowledged or lost. """ if delivery != QuicDeliveryState.ACKED: self._handshake_done_pending = True def _on_max_stream_data_delivery( self, delivery: QuicDeliveryState, stream: QuicStream ) -> None: """ Callback when a MAX_STREAM_DATA frame is acknowledged or lost. """ if delivery != QuicDeliveryState.ACKED: stream.max_stream_data_local_sent = 0 def _on_new_connection_id_delivery( self, delivery: QuicDeliveryState, connection_id: QuicConnectionId ) -> None: """ Callback when a NEW_CONNECTION_ID frame is acknowledged or lost. """ if delivery != QuicDeliveryState.ACKED: connection_id.was_sent = False def _on_ping_delivery( self, delivery: QuicDeliveryState, uids: Sequence[int] ) -> None: """ Callback when a PING frame is acknowledged or lost. """ if delivery == QuicDeliveryState.ACKED: self._logger.debug("Received PING%s response", "" if uids else " (probe)") for uid in uids: self._events.append(events.PingAcknowledged(uid=uid)) else: self._ping_pending.extend(uids) def _on_retire_connection_id_delivery( self, delivery: QuicDeliveryState, sequence_number: int ) -> None: """ Callback when a RETIRE_CONNECTION_ID frame is acknowledged or lost. """ if delivery != QuicDeliveryState.ACKED: self._retire_connection_ids.append(sequence_number) def _payload_received( self, context: QuicReceiveContext, plain: bytes, crypto_frame_required: bool = False, ) -> Tuple[bool, bool]: """ Handle a QUIC packet payload. """ buf = Buffer(data=plain) crypto_frame_found = False frame_found = False is_ack_eliciting = False is_probing = None while not buf.eof(): # get frame type try: frame_type = buf.pull_uint_var() except BufferReadError: raise QuicConnectionError( error_code=QuicErrorCode.FRAME_ENCODING_ERROR, frame_type=None, reason_phrase="Malformed frame type", ) # check frame type is known try: frame_handler, frame_epochs = self.__frame_handlers[frame_type] except KeyError: raise QuicConnectionError( error_code=QuicErrorCode.FRAME_ENCODING_ERROR, frame_type=frame_type, reason_phrase="Unknown frame type", ) # check frame type is allowed for the epoch if context.epoch not in frame_epochs: raise QuicConnectionError( error_code=QuicErrorCode.PROTOCOL_VIOLATION, frame_type=frame_type, reason_phrase="Unexpected frame type", ) # handle the frame try: frame_handler(context, frame_type, buf) except BufferReadError: raise QuicConnectionError( error_code=QuicErrorCode.FRAME_ENCODING_ERROR, frame_type=frame_type, reason_phrase="Failed to parse frame", ) except StreamFinishedError: # we lack the state for the stream, ignore the frame pass # update ACK only / probing flags frame_found = True if frame_type == QuicFrameType.CRYPTO: crypto_frame_found = True if frame_type not in NON_ACK_ELICITING_FRAME_TYPES: is_ack_eliciting = True if frame_type not in PROBING_FRAME_TYPES: is_probing = False elif is_probing is None: is_probing = True if not frame_found: raise QuicConnectionError( error_code=QuicErrorCode.PROTOCOL_VIOLATION, frame_type=QuicFrameType.PADDING, reason_phrase="Packet contains no frames", ) # RFC 9000 - 17.2.2. 
Initial Packet # The first packet sent by a client always includes a CRYPTO frame. if crypto_frame_required and not crypto_frame_found: raise QuicConnectionError( error_code=QuicErrorCode.PROTOCOL_VIOLATION, frame_type=QuicFrameType.PADDING, reason_phrase="Packet contains no CRYPTO frame", ) return is_ack_eliciting, bool(is_probing) def _receive_retry_packet( self, header: QuicHeader, packet_without_tag: bytes, now: float ) -> None: """ Handle a retry packet. """ if ( self._is_client and not self._retry_count and header.destination_cid == self.host_cid and header.integrity_tag == get_retry_integrity_tag( packet_without_tag, self._peer_cid.cid, version=header.version, ) ): if self._quic_logger is not None: self._quic_logger.log_event( category="transport", event="packet_received", data={ "frames": [], "header": { "packet_type": "retry", "scid": dump_cid(header.source_cid), "dcid": dump_cid(header.destination_cid), }, "raw": {"length": header.packet_length}, }, ) self._peer_cid.cid = header.source_cid self._peer_token = header.token self._retry_count += 1 self._retry_source_connection_id = header.source_cid self._logger.info("Retrying with token (%d bytes)" % len(header.token)) self._connect(now=now) else: # Unexpected or invalid retry packet. if self._quic_logger is not None: self._quic_logger.log_event( category="transport", event="packet_dropped", data={ "trigger": "unexpected_packet", "raw": {"length": header.packet_length}, }, ) def _receive_version_negotiation_packet( self, header: QuicHeader, now: float ) -> None: """ Handle a version negotiation packet. This is used in "Incompatible Version Negotiation", see: https://datatracker.ietf.org/doc/html/rfc9368#section-2.2 """ # Only clients process Version Negotiation, and once a Version # Negotiation packet has been acted upon, any further # such packets must be ignored. # # https://datatracker.ietf.org/doc/html/rfc9368#section-4 if ( self._is_client and self._state == QuicConnectionState.FIRSTFLIGHT and not self._version_negotiated_incompatible ): if self._quic_logger is not None: self._quic_logger.log_event( category="transport", event="packet_received", data={ "frames": [], "header": { "packet_type": self._quic_logger.packet_type( header.packet_type ), "scid": dump_cid(header.source_cid), "dcid": dump_cid(header.destination_cid), }, "raw": {"length": header.packet_length}, }, ) # Ignore any Version Negotiation packets that contain the # original version. # # https://datatracker.ietf.org/doc/html/rfc9368#section-4 if self._version in header.supported_versions: self._logger.warning( "Version negotiation packet contains protocol version %s", pretty_protocol_version(self._version), ) return # Look for a common protocol version. common = [ x for x in self._configuration.supported_versions if x in header.supported_versions ]
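# The selection below, restated as a standalone sketch (illustrative helper,
# not this module's API): keep our own preference order and take the first
# version the peer also offers (0x00000001 is QUIC v1, 0x6B3343CF is QUIC v2).
def _choose_common_version(ours, theirs):
    common = [version for version in ours if version in theirs]
    return common[0] if common else None

assert _choose_common_version([0x6B3343CF, 0x00000001], [0x00000001]) == 0x00000001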
chosen_version = common[0] if common else None if self._quic_logger is not None: self._quic_logger.log_event( category="transport", event="version_information", data={ "server_versions": header.supported_versions, "client_versions": self._configuration.supported_versions, "chosen_version": chosen_version, }, ) if chosen_version is None: self._logger.error("Could not find a common protocol version") self._close_event = events.ConnectionTerminated( error_code=QuicErrorCode.INTERNAL_ERROR, frame_type=QuicFrameType.PADDING, reason_phrase="Could not find a common protocol version", ) self._close_end() return self._packet_number = 0 self._version = chosen_version self._version_negotiated_incompatible = True self._logger.info( "Retrying with protocol version %s", pretty_protocol_version(self._version), ) self._connect(now=now) else: # Unexpected version negotiation packet. if self._quic_logger is not None: self._quic_logger.log_event( category="transport", event="packet_dropped", data={ "trigger": "unexpected_packet", "raw": {"length": header.packet_length}, }, ) def _replenish_connection_ids(self) -> None: """ Generate new connection IDs. """ while len(self._host_cids) < min(8, self._remote_active_connection_id_limit): self._host_cids.append( QuicConnectionId( cid=os.urandom(self._configuration.connection_id_length), sequence_number=self._host_cid_seq, stateless_reset_token=os.urandom(16), ) ) self._host_cid_seq += 1 def _retire_peer_cid(self, connection_id: QuicConnectionId) -> None: """ Retire a destination connection ID. """ self._logger.debug( "Retiring CID %s (%d) [%d]", dump_cid(connection_id.cid), connection_id.sequence_number, len(self._retire_connection_ids) + 1, ) self._retire_connection_ids.append(connection_id.sequence_number) def _push_crypto_data(self) -> None: for epoch, buf in self._crypto_buffers.items(): self._crypto_streams[epoch].sender.write(buf.data) buf.seek(0) def _send_probe(self) -> None: self._probe_pending = True def _parse_transport_parameters( self, data: bytes, from_session_ticket: bool = False ) -> None: """ Parse and apply remote transport parameters. `from_session_ticket` is `True` when restoring saved transport parameters, and `False` when handling received transport parameters. """ try: quic_transport_parameters = pull_quic_transport_parameters( Buffer(data=data) ) except ValueError: raise QuicConnectionError( error_code=QuicErrorCode.TRANSPORT_PARAMETER_ERROR, frame_type=QuicFrameType.CRYPTO, reason_phrase="Could not parse QUIC transport parameters", ) # log event if self._quic_logger is not None and not from_session_ticket: self._quic_logger.log_event( category="transport", event="parameters_set", data=self._quic_logger.encode_transport_parameters( owner="remote", parameters=quic_transport_parameters ), ) # Validate remote parameters. 
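# Aside: the transport parameters parsed above arrive encoded as RFC 9000
# variable-length integers, where the two high bits of the first byte give
# the total length (1, 2, 4 or 8 bytes). A minimal standalone decoder,
# independent of the Buffer class used in this module (illustrative; the
# test values are the worked examples from RFC 9000 appendix A.1):
def _decode_var_int(data: bytes):
    prefix = data[0] >> 6
    length = 1 << prefix
    value = data[0] & 0x3F
    for byte in data[1:length]:
        value = (value << 8) | byte
    return value, length

assert _decode_var_int(bytes([0x25])) == (37, 1)
assert _decode_var_int(bytes([0x7B, 0xBD])) == (15293, 2)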
if not self._is_client: for attr in [ "original_destination_connection_id", "preferred_address", "retry_source_connection_id", "stateless_reset_token", ]: if getattr(quic_transport_parameters, attr) is not None: raise QuicConnectionError( error_code=QuicErrorCode.TRANSPORT_PARAMETER_ERROR, frame_type=QuicFrameType.CRYPTO, reason_phrase="%s is not allowed for clients" % attr, ) if not from_session_ticket: if ( quic_transport_parameters.initial_source_connection_id != self._remote_initial_source_connection_id ): raise QuicConnectionError( error_code=QuicErrorCode.TRANSPORT_PARAMETER_ERROR, frame_type=QuicFrameType.CRYPTO, reason_phrase="initial_source_connection_id does not match", ) if self._is_client and ( quic_transport_parameters.original_destination_connection_id != self._original_destination_connection_id ): raise QuicConnectionError( error_code=QuicErrorCode.TRANSPORT_PARAMETER_ERROR, frame_type=QuicFrameType.CRYPTO, reason_phrase="original_destination_connection_id does not match", ) if self._is_client and ( quic_transport_parameters.retry_source_connection_id != self._retry_source_connection_id ): raise QuicConnectionError( error_code=QuicErrorCode.TRANSPORT_PARAMETER_ERROR, frame_type=QuicFrameType.CRYPTO, reason_phrase="retry_source_connection_id does not match", ) if ( quic_transport_parameters.active_connection_id_limit is not None and quic_transport_parameters.active_connection_id_limit < 2 ): raise QuicConnectionError( error_code=QuicErrorCode.TRANSPORT_PARAMETER_ERROR, frame_type=QuicFrameType.CRYPTO, reason_phrase="active_connection_id_limit must be no less than 2", ) if ( quic_transport_parameters.ack_delay_exponent is not None and quic_transport_parameters.ack_delay_exponent > 20 ): raise QuicConnectionError( error_code=QuicErrorCode.TRANSPORT_PARAMETER_ERROR, frame_type=QuicFrameType.CRYPTO, reason_phrase="ack_delay_exponent must be <= 20", ) if ( quic_transport_parameters.max_ack_delay is not None and quic_transport_parameters.max_ack_delay >= 2**14 ): raise QuicConnectionError( error_code=QuicErrorCode.TRANSPORT_PARAMETER_ERROR, frame_type=QuicFrameType.CRYPTO, reason_phrase="max_ack_delay must be < 2^14", ) if quic_transport_parameters.max_udp_payload_size is not None and ( quic_transport_parameters.max_udp_payload_size < SMALLEST_MAX_DATAGRAM_SIZE ): raise QuicConnectionError( error_code=QuicErrorCode.TRANSPORT_PARAMETER_ERROR, frame_type=QuicFrameType.CRYPTO, reason_phrase=( f"max_udp_payload_size must be >= {SMALLEST_MAX_DATAGRAM_SIZE}" ), ) # Validate Version Information extension. # # https://datatracker.ietf.org/doc/html/rfc9368#section-4 if quic_transport_parameters.version_information is not None: version_information = quic_transport_parameters.version_information # If a server receives Version Information where the Chosen Version # is not included in Available Versions, it MUST treat it as a # parsing failure. if ( not self._is_client and version_information.chosen_version not in version_information.available_versions ): raise QuicConnectionError( error_code=QuicErrorCode.TRANSPORT_PARAMETER_ERROR, frame_type=QuicFrameType.CRYPTO, reason_phrase=( "version_information's chosen_version is not included " "in available_versions" ), )
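# The bounds enforced above, condensed into a standalone sketch (the dict
# and function names are illustrative, not this module's API):
def _validate_remote_bounds(params: dict) -> None:
    checks = {
        "ack_delay_exponent": lambda v: v <= 20,
        "max_ack_delay": lambda v: v < 2**14,
        "active_connection_id_limit": lambda v: v >= 2,
        "max_udp_payload_size": lambda v: v >= 1200,  # SMALLEST_MAX_DATAGRAM_SIZE
    }
    for name, ok in checks.items():
        value = params.get(name)
        if value is not None and not ok(value):
            raise ValueError("invalid transport parameter %s=%r" % (name, value))

_validate_remote_bounds({"ack_delay_exponent": 3, "max_ack_delay": 25})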
# Validate that the Chosen Version matches the version in use for the # connection. if version_information.chosen_version != self._crypto_packet_version: raise QuicConnectionError( error_code=QuicErrorCode.VERSION_NEGOTIATION_ERROR, frame_type=QuicFrameType.CRYPTO, reason_phrase=( "version_information's chosen_version does not match " "the version in use" ), ) # Store remote parameters. if not from_session_ticket: if quic_transport_parameters.ack_delay_exponent is not None: self._remote_ack_delay_exponent = ( quic_transport_parameters.ack_delay_exponent ) if quic_transport_parameters.max_ack_delay is not None: self._loss.max_ack_delay = ( quic_transport_parameters.max_ack_delay / 1000.0 ) if ( self._is_client and self._peer_cid.sequence_number == 0 and quic_transport_parameters.stateless_reset_token is not None ): self._peer_cid.stateless_reset_token = ( quic_transport_parameters.stateless_reset_token ) self._remote_version_information = ( quic_transport_parameters.version_information ) if quic_transport_parameters.active_connection_id_limit is not None: self._remote_active_connection_id_limit = ( quic_transport_parameters.active_connection_id_limit ) if quic_transport_parameters.max_idle_timeout is not None: self._remote_max_idle_timeout = ( quic_transport_parameters.max_idle_timeout / 1000.0 ) self._remote_max_datagram_frame_size = ( quic_transport_parameters.max_datagram_frame_size ) for param in [ "max_data", "max_stream_data_bidi_local", "max_stream_data_bidi_remote", "max_stream_data_uni", "max_streams_bidi", "max_streams_uni", ]: value = getattr(quic_transport_parameters, "initial_" + param) if value is not None: setattr(self, "_remote_" + param, value) def _serialize_transport_parameters(self) -> bytes: quic_transport_parameters = QuicTransportParameters( ack_delay_exponent=self._local_ack_delay_exponent, active_connection_id_limit=self._local_active_connection_id_limit, max_idle_timeout=int(self._configuration.idle_timeout * 1000), initial_max_data=self._local_max_data.value, initial_max_stream_data_bidi_local=self._local_max_stream_data_bidi_local, initial_max_stream_data_bidi_remote=self._local_max_stream_data_bidi_remote, initial_max_stream_data_uni=self._local_max_stream_data_uni, initial_max_streams_bidi=self._local_max_streams_bidi.value, initial_max_streams_uni=self._local_max_streams_uni.value, initial_source_connection_id=self._local_initial_source_connection_id, max_ack_delay=25, max_datagram_frame_size=self._configuration.max_datagram_frame_size, quantum_readiness=( b"Q" * SMALLEST_MAX_DATAGRAM_SIZE if self._configuration.quantum_readiness_test else None ), stateless_reset_token=self._host_cids[0].stateless_reset_token, version_information=QuicVersionInformation( chosen_version=self._version, available_versions=self._configuration.supported_versions, ), ) if not self._is_client: quic_transport_parameters.original_destination_connection_id = ( self._original_destination_connection_id ) quic_transport_parameters.retry_source_connection_id = ( self._retry_source_connection_id ) # log event if self._quic_logger is not None: self._quic_logger.log_event( category="transport", event="parameters_set", data=self._quic_logger.encode_transport_parameters( owner="local", parameters=quic_transport_parameters ), ) buf = Buffer(capacity=3 * self._max_datagram_size) push_quic_transport_parameters(buf, quic_transport_parameters) return buf.data def _set_state(self, state: QuicConnectionState) -> None: self._logger.debug("%s -> %s", self._state, state) self._state = state def _stream_can_receive(self, stream_id: int) -> bool: return stream_is_client_initiated( stream_id ) !=
self._is_client or not stream_is_unidirectional(stream_id) def _stream_can_send(self, stream_id: int) -> bool: return stream_is_client_initiated( stream_id ) == self._is_client or not stream_is_unidirectional(stream_id) def _unblock_streams(self, is_unidirectional: bool) -> None: if is_unidirectional: max_stream_data_remote = self._remote_max_stream_data_uni max_streams = self._remote_max_streams_uni streams_blocked = self._streams_blocked_uni else: max_stream_data_remote = self._remote_max_stream_data_bidi_remote max_streams = self._remote_max_streams_bidi streams_blocked = self._streams_blocked_bidi while streams_blocked and streams_blocked[0].stream_id // 4 < max_streams: stream = streams_blocked.pop(0) stream.is_blocked = False stream.max_stream_data_remote = max_stream_data_remote if not self._streams_blocked_bidi and not self._streams_blocked_uni: self._streams_blocked_pending = False def _update_traffic_key( self, direction: tls.Direction, epoch: tls.Epoch, cipher_suite: tls.CipherSuite, secret: bytes, ) -> None: """ Callback which is invoked by the TLS engine when new traffic keys are available. """ # For clients, determine the negotiated protocol version. if ( self._is_client and self._crypto_packet_version is not None and not self._version_negotiated_compatible ): self._version = self._crypto_packet_version self._version_negotiated_compatible = True self._logger.info( "Negotiated protocol version %s", pretty_protocol_version(self._version) ) secrets_log_file = self._configuration.secrets_log_file if secrets_log_file is not None: label_row = self._is_client == (direction == tls.Direction.DECRYPT) label = SECRETS_LABELS[label_row][epoch.value] secrets_log_file.write( "%s %s %s\n" % (label, self.tls.client_random.hex(), secret.hex()) ) secrets_log_file.flush() crypto = self._cryptos[epoch] if direction == tls.Direction.ENCRYPT: crypto.send.setup( cipher_suite=cipher_suite, secret=secret, version=self._version ) else: crypto.recv.setup( cipher_suite=cipher_suite, secret=secret, version=self._version ) def _add_local_challenge(self, challenge: bytes, network_path: QuicNetworkPath): self._local_challenges[challenge] = network_path while len(self._local_challenges) > MAX_LOCAL_CHALLENGES: # Dictionaries are ordered, so pop the first key until we are below the # limit. 
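# As a standalone illustration of the eviction performed just below, with a
# toy limit of 2 standing in for MAX_LOCAL_CHALLENGES:
challenges = {}
for token in (b"a" * 8, b"b" * 8, b"c" * 8):
    challenges[token] = "path"
    while len(challenges) > 2:
        challenges.pop(next(iter(challenges)))  # oldest entry goes first
assert list(challenges) == [b"b" * 8, b"c" * 8]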
key = next(iter(self._local_challenges.keys())) del self._local_challenges[key] def _write_application( self, builder: QuicPacketBuilder, network_path: QuicNetworkPath, now: float ) -> None: crypto_stream: Optional[QuicStream] = None if self._cryptos[tls.Epoch.ONE_RTT].send.is_valid(): crypto = self._cryptos[tls.Epoch.ONE_RTT] crypto_stream = self._crypto_streams[tls.Epoch.ONE_RTT] packet_type = QuicPacketType.ONE_RTT elif self._cryptos[tls.Epoch.ZERO_RTT].send.is_valid(): crypto = self._cryptos[tls.Epoch.ZERO_RTT] packet_type = QuicPacketType.ZERO_RTT else: return space = self._spaces[tls.Epoch.ONE_RTT] while True: # apply pacing, except if we have ACKs to send if space.ack_at is None or space.ack_at >= now: self._pacing_at = self._loss._pacer.next_send_time(now=now) if self._pacing_at is not None: break builder.start_packet(packet_type, crypto) if self._handshake_complete: # ACK if space.ack_at is not None and space.ack_at <= now: self._write_ack_frame(builder=builder, space=space, now=now) # HANDSHAKE_DONE if self._handshake_done_pending: self._write_handshake_done_frame(builder=builder) self._handshake_done_pending = False # PATH CHALLENGE if not (network_path.is_validated or network_path.local_challenge_sent): challenge = os.urandom(8) self._add_local_challenge( challenge=challenge, network_path=network_path ) self._write_path_challenge_frame( builder=builder, challenge=challenge ) network_path.local_challenge_sent = True # PATH RESPONSE while len(network_path.remote_challenges) > 0: challenge = network_path.remote_challenges.popleft() self._write_path_response_frame( builder=builder, challenge=challenge ) # NEW_CONNECTION_ID for connection_id in self._host_cids: if not connection_id.was_sent: self._write_new_connection_id_frame( builder=builder, connection_id=connection_id ) # RETIRE_CONNECTION_ID for sequence_number in self._retire_connection_ids[:]: self._write_retire_connection_id_frame( builder=builder, sequence_number=sequence_number ) self._retire_connection_ids.pop(0) # STREAMS_BLOCKED if self._streams_blocked_pending: if self._streams_blocked_bidi: self._write_streams_blocked_frame( builder=builder, frame_type=QuicFrameType.STREAMS_BLOCKED_BIDI, limit=self._remote_max_streams_bidi, ) if self._streams_blocked_uni: self._write_streams_blocked_frame( builder=builder, frame_type=QuicFrameType.STREAMS_BLOCKED_UNI, limit=self._remote_max_streams_uni, ) self._streams_blocked_pending = False # MAX_DATA and MAX_STREAMS self._write_connection_limits(builder=builder, space=space) # stream-level limits for stream in self._streams.values(): self._write_stream_limits(builder=builder, space=space, stream=stream) # PING (user-request) if self._ping_pending: self._write_ping_frame(builder, self._ping_pending) self._ping_pending.clear() # PING (probe) if self._probe_pending: self._write_ping_frame(builder, comment="probe") self._probe_pending = False # CRYPTO if crypto_stream is not None and not crypto_stream.sender.buffer_is_empty: self._write_crypto_frame( builder=builder, space=space, stream=crypto_stream ) # DATAGRAM while self._datagrams_pending: try: self._write_datagram_frame( builder=builder, data=self._datagrams_pending[0], frame_type=QuicFrameType.DATAGRAM_WITH_LENGTH, ) self._datagrams_pending.popleft() except QuicPacketBuilderStop: break sent: Set[QuicStream] = set() discarded: Set[QuicStream] = set() try: for stream in self._streams_queue: # if the stream is finished, discard it if stream.is_finished: self._logger.debug("Stream %d discarded", stream.stream_id) 
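# Aside: the finally clause below rebuilds _streams_queue so that streams
# which just sent move to the back, giving a simple round-robin. The same
# reordering in isolation (a list stands in for the queue; illustrative):
queue = ["a", "b", "c"]
sent, discarded = {"a"}, {"c"}
queue = [s for s in queue if s not in sent and s not in discarded]
queue.extend(sent)  # served streams are re-queued at the end
assert queue == ["b", "a"]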
self._streams.pop(stream.stream_id) self._streams_finished.add(stream.stream_id) discarded.add(stream) continue if stream.receiver.stop_pending: # STOP_SENDING self._write_stop_sending_frame(builder=builder, stream=stream) if stream.sender.reset_pending: # RESET_STREAM self._write_reset_stream_frame(builder=builder, stream=stream) elif not stream.is_blocked and not stream.sender.buffer_is_empty: # STREAM used = self._write_stream_frame( builder=builder, space=space, stream=stream, max_offset=min( stream.sender.highest_offset + self._remote_max_data - self._remote_max_data_used, stream.max_stream_data_remote, ), ) self._remote_max_data_used += used if used > 0: sent.add(stream) finally: # Make a new stream service order, putting served ones at the end. # # This method of updating the streams queue ensures that discarded # streams are removed and ones which sent are moved to the end even # if an exception occurs in the loop. self._streams_queue = [ stream for stream in self._streams_queue if not (stream in discarded or stream in sent) ] self._streams_queue.extend(sent) if builder.packet_is_empty: break else: self._loss._pacer.update_after_send(now=now) def _write_handshake( self, builder: QuicPacketBuilder, epoch: tls.Epoch, now: float ) -> None: crypto = self._cryptos[epoch] if not crypto.send.is_valid(): return crypto_stream = self._crypto_streams[epoch] space = self._spaces[epoch] while True: if epoch == tls.Epoch.INITIAL: packet_type = QuicPacketType.INITIAL else: packet_type = QuicPacketType.HANDSHAKE builder.start_packet(packet_type, crypto) # ACK if space.ack_at is not None: self._write_ack_frame(builder=builder, space=space, now=now) # CRYPTO if not crypto_stream.sender.buffer_is_empty: if self._write_crypto_frame( builder=builder, space=space, stream=crypto_stream ): self._probe_pending = False # PING (probe) if ( self._probe_pending and not self._handshake_complete and ( epoch == tls.Epoch.HANDSHAKE or not self._cryptos[tls.Epoch.HANDSHAKE].send.is_valid() ) ): self._write_ping_frame(builder, comment="probe") self._probe_pending = False if builder.packet_is_empty: break def _write_ack_frame( self, builder: QuicPacketBuilder, space: QuicPacketSpace, now: float ) -> None: # calculate ACK delay ack_delay = now - space.largest_received_time ack_delay_encoded = int(ack_delay * 1000000) >> self._local_ack_delay_exponent buf = builder.start_frame( QuicFrameType.ACK, capacity=ACK_FRAME_CAPACITY, handler=self._on_ack_delivery, handler_args=(space, space.largest_received_packet), ) ranges = push_ack_frame(buf, space.ack_queue, ack_delay_encoded) space.ack_at = None # log frame if self._quic_logger is not None: builder.quic_logger_frames.append( self._quic_logger.encode_ack_frame( ranges=space.ack_queue, delay=ack_delay ) ) # check if we need to trigger an ACK-of-ACK if ranges > 1 and builder.packet_number % 8 == 0: self._write_ping_frame(builder, comment="ACK-of-ACK trigger") def _write_connection_close_frame( self, builder: QuicPacketBuilder, epoch: tls.Epoch, error_code: int, frame_type: Optional[int], reason_phrase: str, ) -> None: # convert application-level close to transport-level close in early stages if frame_type is None and epoch in (tls.Epoch.INITIAL, tls.Epoch.HANDSHAKE): error_code = QuicErrorCode.APPLICATION_ERROR frame_type = QuicFrameType.PADDING reason_phrase = "" reason_bytes = reason_phrase.encode("utf8") reason_length = len(reason_bytes) if frame_type is None: buf = builder.start_frame( QuicFrameType.APPLICATION_CLOSE, capacity=APPLICATION_CLOSE_FRAME_CAPACITY + 
reason_length, ) buf.push_uint_var(error_code) buf.push_uint_var(reason_length) buf.push_bytes(reason_bytes) else: buf = builder.start_frame( QuicFrameType.TRANSPORT_CLOSE, capacity=TRANSPORT_CLOSE_FRAME_CAPACITY + reason_length, ) buf.push_uint_var(error_code) buf.push_uint_var(frame_type) buf.push_uint_var(reason_length) buf.push_bytes(reason_bytes) # log frame if self._quic_logger is not None: builder.quic_logger_frames.append( self._quic_logger.encode_connection_close_frame( error_code=error_code, frame_type=frame_type, reason_phrase=reason_phrase, ) ) def _write_connection_limits( self, builder: QuicPacketBuilder, space: QuicPacketSpace ) -> None: """ Raise MAX_DATA or MAX_STREAMS if needed. """ for limit in ( self._local_max_data, self._local_max_streams_bidi, self._local_max_streams_uni, ): if limit.used * 2 > limit.value: limit.value *= 2 self._logger.debug("Local %s raised to %d", limit.name, limit.value) if limit.value != limit.sent: buf = builder.start_frame( limit.frame_type, capacity=CONNECTION_LIMIT_FRAME_CAPACITY, handler=self._on_connection_limit_delivery, handler_args=(limit,), ) buf.push_uint_var(limit.value) limit.sent = limit.value # log frame if self._quic_logger is not None: builder.quic_logger_frames.append( self._quic_logger.encode_connection_limit_frame( frame_type=limit.frame_type, maximum=limit.value, ) ) def _write_crypto_frame( self, builder: QuicPacketBuilder, space: QuicPacketSpace, stream: QuicStream ) -> bool: frame_overhead = 3 + size_uint_var(stream.sender.next_offset) frame = stream.sender.get_frame(builder.remaining_flight_space - frame_overhead) if frame is not None: buf = builder.start_frame( QuicFrameType.CRYPTO, capacity=frame_overhead, handler=stream.sender.on_data_delivery, handler_args=(frame.offset, frame.offset + len(frame.data), False), ) buf.push_uint_var(frame.offset) buf.push_uint16(len(frame.data) | 0x4000) buf.push_bytes(frame.data) # log frame if self._quic_logger is not None: builder.quic_logger_frames.append( self._quic_logger.encode_crypto_frame(frame) ) return True return False def _write_datagram_frame( self, builder: QuicPacketBuilder, data: bytes, frame_type: QuicFrameType ) -> bool: """ Write a DATAGRAM frame. Returns True if the frame was processed, False otherwise. 
""" assert frame_type == QuicFrameType.DATAGRAM_WITH_LENGTH length = len(data) frame_size = 1 + size_uint_var(length) + length buf = builder.start_frame(frame_type, capacity=frame_size) buf.push_uint_var(length) buf.push_bytes(data) # log frame if self._quic_logger is not None: builder.quic_logger_frames.append( self._quic_logger.encode_datagram_frame(length=length) ) return True def _write_handshake_done_frame(self, builder: QuicPacketBuilder) -> None: builder.start_frame( QuicFrameType.HANDSHAKE_DONE, capacity=HANDSHAKE_DONE_FRAME_CAPACITY, handler=self._on_handshake_done_delivery, ) # log frame if self._quic_logger is not None: builder.quic_logger_frames.append( self._quic_logger.encode_handshake_done_frame() ) def _write_new_connection_id_frame( self, builder: QuicPacketBuilder, connection_id: QuicConnectionId ) -> None: retire_prior_to = 0 # FIXME buf = builder.start_frame( QuicFrameType.NEW_CONNECTION_ID, capacity=NEW_CONNECTION_ID_FRAME_CAPACITY, handler=self._on_new_connection_id_delivery, handler_args=(connection_id,), ) buf.push_uint_var(connection_id.sequence_number) buf.push_uint_var(retire_prior_to) buf.push_uint8(len(connection_id.cid)) buf.push_bytes(connection_id.cid) buf.push_bytes(connection_id.stateless_reset_token) connection_id.was_sent = True self._events.append(events.ConnectionIdIssued(connection_id=connection_id.cid)) # log frame if self._quic_logger is not None: builder.quic_logger_frames.append( self._quic_logger.encode_new_connection_id_frame( connection_id=connection_id.cid, retire_prior_to=retire_prior_to, sequence_number=connection_id.sequence_number, stateless_reset_token=connection_id.stateless_reset_token, ) ) def _write_path_challenge_frame( self, builder: QuicPacketBuilder, challenge: bytes ) -> None: buf = builder.start_frame( QuicFrameType.PATH_CHALLENGE, capacity=PATH_CHALLENGE_FRAME_CAPACITY ) buf.push_bytes(challenge) # log frame if self._quic_logger is not None: builder.quic_logger_frames.append( self._quic_logger.encode_path_challenge_frame(data=challenge) ) def _write_path_response_frame( self, builder: QuicPacketBuilder, challenge: bytes ) -> None: buf = builder.start_frame( QuicFrameType.PATH_RESPONSE, capacity=PATH_RESPONSE_FRAME_CAPACITY ) buf.push_bytes(challenge) # log frame if self._quic_logger is not None: builder.quic_logger_frames.append( self._quic_logger.encode_path_response_frame(data=challenge) ) def _write_ping_frame( self, builder: QuicPacketBuilder, uids: List[int] = [], comment="" ): builder.start_frame( QuicFrameType.PING, capacity=PING_FRAME_CAPACITY, handler=self._on_ping_delivery, handler_args=(tuple(uids),), ) self._logger.debug( "Sending PING%s in packet %d", " (%s)" % comment if comment else "", builder.packet_number, ) # log frame if self._quic_logger is not None: builder.quic_logger_frames.append(self._quic_logger.encode_ping_frame()) def _write_reset_stream_frame( self, builder: QuicPacketBuilder, stream: QuicStream, ) -> None: buf = builder.start_frame( frame_type=QuicFrameType.RESET_STREAM, capacity=RESET_STREAM_FRAME_CAPACITY, handler=stream.sender.on_reset_delivery, ) frame = stream.sender.get_reset_frame() buf.push_uint_var(frame.stream_id) buf.push_uint_var(frame.error_code) buf.push_uint_var(frame.final_size) # log frame if self._quic_logger is not None: builder.quic_logger_frames.append( self._quic_logger.encode_reset_stream_frame( error_code=frame.error_code, final_size=frame.final_size, stream_id=frame.stream_id, ) ) def _write_retire_connection_id_frame( self, builder: QuicPacketBuilder, sequence_number: int 
) -> None: buf = builder.start_frame( QuicFrameType.RETIRE_CONNECTION_ID, capacity=RETIRE_CONNECTION_ID_CAPACITY, handler=self._on_retire_connection_id_delivery, handler_args=(sequence_number,), ) buf.push_uint_var(sequence_number) # log frame if self._quic_logger is not None: builder.quic_logger_frames.append( self._quic_logger.encode_retire_connection_id_frame(sequence_number) ) def _write_stop_sending_frame( self, builder: QuicPacketBuilder, stream: QuicStream, ) -> None: buf = builder.start_frame( frame_type=QuicFrameType.STOP_SENDING, capacity=STOP_SENDING_FRAME_CAPACITY, handler=stream.receiver.on_stop_sending_delivery, ) frame = stream.receiver.get_stop_frame() buf.push_uint_var(frame.stream_id) buf.push_uint_var(frame.error_code) # log frame if self._quic_logger is not None: builder.quic_logger_frames.append( self._quic_logger.encode_stop_sending_frame( error_code=frame.error_code, stream_id=frame.stream_id ) ) def _write_stream_frame( self, builder: QuicPacketBuilder, space: QuicPacketSpace, stream: QuicStream, max_offset: int, ) -> int: # the frame data size is constrained by our peer's MAX_DATA and # the space available in the current packet frame_overhead = ( 3 + size_uint_var(stream.stream_id) + ( size_uint_var(stream.sender.next_offset) if stream.sender.next_offset else 0 ) ) previous_send_highest = stream.sender.highest_offset frame = stream.sender.get_frame( builder.remaining_flight_space - frame_overhead, max_offset ) if frame is not None: frame_type = QuicFrameType.STREAM_BASE | 2 # length if frame.offset: frame_type |= 4 if frame.fin: frame_type |= 1 buf = builder.start_frame( frame_type, capacity=frame_overhead, handler=stream.sender.on_data_delivery, handler_args=(frame.offset, frame.offset + len(frame.data), frame.fin), ) buf.push_uint_var(stream.stream_id) if frame.offset: buf.push_uint_var(frame.offset) buf.push_uint16(len(frame.data) | 0x4000) buf.push_bytes(frame.data) # log frame if self._quic_logger is not None: builder.quic_logger_frames.append( self._quic_logger.encode_stream_frame( frame, stream_id=stream.stream_id ) ) return stream.sender.highest_offset - previous_send_highest else: return 0 def _write_stream_limits( self, builder: QuicPacketBuilder, space: QuicPacketSpace, stream: QuicStream ) -> None: """ Raise MAX_STREAM_DATA if needed. The only case where `stream.max_stream_data_local` is zero is for locally created unidirectional streams. We skip such streams to avoid spurious logging. 
""" if ( stream.max_stream_data_local and stream.receiver.highest_offset * 2 > stream.max_stream_data_local ): stream.max_stream_data_local *= 2 self._logger.debug( "Stream %d local max_stream_data raised to %d", stream.stream_id, stream.max_stream_data_local, ) if stream.max_stream_data_local_sent != stream.max_stream_data_local: buf = builder.start_frame( QuicFrameType.MAX_STREAM_DATA, capacity=MAX_STREAM_DATA_FRAME_CAPACITY, handler=self._on_max_stream_data_delivery, handler_args=(stream,), ) buf.push_uint_var(stream.stream_id) buf.push_uint_var(stream.max_stream_data_local) stream.max_stream_data_local_sent = stream.max_stream_data_local # log frame if self._quic_logger is not None: builder.quic_logger_frames.append( self._quic_logger.encode_max_stream_data_frame( maximum=stream.max_stream_data_local, stream_id=stream.stream_id ) ) def _write_streams_blocked_frame( self, builder: QuicPacketBuilder, frame_type: QuicFrameType, limit: int ) -> None: buf = builder.start_frame(frame_type, capacity=STREAMS_BLOCKED_CAPACITY) buf.push_uint_var(limit) # log frame if self._quic_logger is not None: builder.quic_logger_frames.append( self._quic_logger.encode_streams_blocked_frame( is_unidirectional=frame_type == QuicFrameType.STREAMS_BLOCKED_UNI, limit=limit, ) ) ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1720306884.0 aioquic-1.2.0/src/aioquic/quic/crypto.py0000644000175100001770000001772600000000000021043 0ustar00runnerdocker00000000000000import binascii from typing import Callable, Optional, Tuple from .._crypto import AEAD, CryptoError, HeaderProtection from ..tls import CipherSuite, cipher_suite_hash, hkdf_expand_label, hkdf_extract from .packet import ( QuicProtocolVersion, decode_packet_number, is_long_header, ) CIPHER_SUITES = { CipherSuite.AES_128_GCM_SHA256: (b"aes-128-ecb", b"aes-128-gcm"), CipherSuite.AES_256_GCM_SHA384: (b"aes-256-ecb", b"aes-256-gcm"), CipherSuite.CHACHA20_POLY1305_SHA256: (b"chacha20", b"chacha20-poly1305"), } INITIAL_CIPHER_SUITE = CipherSuite.AES_128_GCM_SHA256 INITIAL_SALT_VERSION_1 = binascii.unhexlify("38762cf7f55934b34d179ae6a4c80cadccbb7f0a") INITIAL_SALT_VERSION_2 = binascii.unhexlify("0dede3def700a6db819381be6e269dcbf9bd2ed9") SAMPLE_SIZE = 16 Callback = Callable[[str], None] def NoCallback(trigger: str) -> None: pass class KeyUnavailableError(CryptoError): pass def derive_key_iv_hp( *, cipher_suite: CipherSuite, secret: bytes, version: int ) -> Tuple[bytes, bytes, bytes]: algorithm = cipher_suite_hash(cipher_suite) if cipher_suite in [ CipherSuite.AES_256_GCM_SHA384, CipherSuite.CHACHA20_POLY1305_SHA256, ]: key_size = 32 else: key_size = 16 if version == QuicProtocolVersion.VERSION_2: return ( hkdf_expand_label(algorithm, secret, b"quicv2 key", b"", key_size), hkdf_expand_label(algorithm, secret, b"quicv2 iv", b"", 12), hkdf_expand_label(algorithm, secret, b"quicv2 hp", b"", key_size), ) else: return ( hkdf_expand_label(algorithm, secret, b"quic key", b"", key_size), hkdf_expand_label(algorithm, secret, b"quic iv", b"", 12), hkdf_expand_label(algorithm, secret, b"quic hp", b"", key_size), ) class CryptoContext: def __init__( self, key_phase: int = 0, setup_cb: Callback = NoCallback, teardown_cb: Callback = NoCallback, ) -> None: self.aead: Optional[AEAD] = None self.cipher_suite: Optional[CipherSuite] = None self.hp: Optional[HeaderProtection] = None self.key_phase = key_phase self.secret: Optional[bytes] = None self.version: Optional[int] = None self._setup_cb = setup_cb self._teardown_cb = teardown_cb 
def decrypt_packet( self, packet: bytes, encrypted_offset: int, expected_packet_number: int ) -> Tuple[bytes, bytes, int, bool]: if self.aead is None: raise KeyUnavailableError("Decryption key is not available") # header protection plain_header, packet_number = self.hp.remove(packet, encrypted_offset) first_byte = plain_header[0] # packet number pn_length = (first_byte & 0x03) + 1 packet_number = decode_packet_number( packet_number, pn_length * 8, expected_packet_number ) # detect key phase change crypto = self if not is_long_header(first_byte): key_phase = (first_byte & 4) >> 2 if key_phase != self.key_phase: crypto = next_key_phase(self) # payload protection payload = crypto.aead.decrypt( packet[len(plain_header) :], plain_header, packet_number ) return plain_header, payload, packet_number, crypto != self def encrypt_packet( self, plain_header: bytes, plain_payload: bytes, packet_number: int ) -> bytes: assert self.is_valid(), "Encryption key is not available" # payload protection protected_payload = self.aead.encrypt( plain_payload, plain_header, packet_number ) # header protection return self.hp.apply(plain_header, protected_payload) def is_valid(self) -> bool: return self.aead is not None def setup(self, *, cipher_suite: CipherSuite, secret: bytes, version: int) -> None: hp_cipher_name, aead_cipher_name = CIPHER_SUITES[cipher_suite] key, iv, hp = derive_key_iv_hp( cipher_suite=cipher_suite, secret=secret, version=version, ) self.aead = AEAD(aead_cipher_name, key, iv) self.cipher_suite = cipher_suite self.hp = HeaderProtection(hp_cipher_name, hp) self.secret = secret self.version = version # trigger callback self._setup_cb("tls") def teardown(self) -> None: self.aead = None self.cipher_suite = None self.hp = None self.secret = None # trigger callback self._teardown_cb("tls") def apply_key_phase(self: CryptoContext, crypto: CryptoContext, trigger: str) -> None: self.aead = crypto.aead self.key_phase = crypto.key_phase self.secret = crypto.secret # trigger callback self._setup_cb(trigger) def next_key_phase(self: CryptoContext) -> CryptoContext: algorithm = cipher_suite_hash(self.cipher_suite) crypto = CryptoContext(key_phase=int(not self.key_phase)) crypto.setup( cipher_suite=self.cipher_suite, secret=hkdf_expand_label( algorithm, self.secret, b"quic ku", b"", algorithm.digest_size ), version=self.version, ) return crypto class CryptoPair: def __init__( self, recv_setup_cb: Callback = NoCallback, recv_teardown_cb: Callback = NoCallback, send_setup_cb: Callback = NoCallback, send_teardown_cb: Callback = NoCallback, ) -> None: self.aead_tag_size = 16 self.recv = CryptoContext(setup_cb=recv_setup_cb, teardown_cb=recv_teardown_cb) self.send = CryptoContext(setup_cb=send_setup_cb, teardown_cb=send_teardown_cb) self._update_key_requested = False def decrypt_packet( self, packet: bytes, encrypted_offset: int, expected_packet_number: int ) -> Tuple[bytes, bytes, int]: plain_header, payload, packet_number, update_key = self.recv.decrypt_packet( packet, encrypted_offset, expected_packet_number ) if update_key: self._update_key("remote_update") return plain_header, payload, packet_number def encrypt_packet( self, plain_header: bytes, plain_payload: bytes, packet_number: int ) -> bytes: if self._update_key_requested: self._update_key("local_update") return self.send.encrypt_packet(plain_header, plain_payload, packet_number) def setup_initial(self, cid: bytes, is_client: bool, version: int) -> None: if is_client: recv_label, send_label = b"server in", b"client in" else: recv_label, send_label = 
b"client in", b"server in" if version == QuicProtocolVersion.VERSION_2: initial_salt = INITIAL_SALT_VERSION_2 else: initial_salt = INITIAL_SALT_VERSION_1 algorithm = cipher_suite_hash(INITIAL_CIPHER_SUITE) initial_secret = hkdf_extract(algorithm, initial_salt, cid) self.recv.setup( cipher_suite=INITIAL_CIPHER_SUITE, secret=hkdf_expand_label( algorithm, initial_secret, recv_label, b"", algorithm.digest_size ), version=version, ) self.send.setup( cipher_suite=INITIAL_CIPHER_SUITE, secret=hkdf_expand_label( algorithm, initial_secret, send_label, b"", algorithm.digest_size ), version=version, ) def teardown(self) -> None: self.recv.teardown() self.send.teardown() def update_key(self) -> None: self._update_key_requested = True @property def key_phase(self) -> int: if self._update_key_requested: return int(not self.recv.key_phase) else: return self.recv.key_phase def _update_key(self, trigger: str) -> None: apply_key_phase(self.recv, next_key_phase(self.recv), trigger=trigger) apply_key_phase(self.send, next_key_phase(self.send), trigger=trigger) self._update_key_requested = False ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1720306884.0 aioquic-1.2.0/src/aioquic/quic/events.py0000644000175100001770000000525000000000000021014 0ustar00runnerdocker00000000000000from dataclasses import dataclass from typing import Optional class QuicEvent: """ Base class for QUIC events. """ pass @dataclass class ConnectionIdIssued(QuicEvent): connection_id: bytes @dataclass class ConnectionIdRetired(QuicEvent): connection_id: bytes @dataclass class ConnectionTerminated(QuicEvent): """ The ConnectionTerminated event is fired when the QUIC connection is terminated. """ error_code: int "The error code which was specified when closing the connection." frame_type: Optional[int] "The frame type which caused the connection to be closed, or `None`." reason_phrase: str "The human-readable reason for which the connection was closed." @dataclass class DatagramFrameReceived(QuicEvent): """ The DatagramFrameReceived event is fired when a DATAGRAM frame is received. """ data: bytes "The data which was received." @dataclass class HandshakeCompleted(QuicEvent): """ The HandshakeCompleted event is fired when the TLS handshake completes. """ alpn_protocol: Optional[str] "The protocol which was negotiated using ALPN, or `None`." early_data_accepted: bool "Whether early (0-RTT) data was accepted by the remote peer." session_resumed: bool "Whether a TLS session was resumed." @dataclass class PingAcknowledged(QuicEvent): """ The PingAcknowledged event is fired when a PING frame is acknowledged. """ uid: int "The unique ID of the PING." @dataclass class ProtocolNegotiated(QuicEvent): """ The ProtocolNegotiated event is fired when ALPN negotiation completes. """ alpn_protocol: Optional[str] "The protocol which was negotiated using ALPN, or `None`." @dataclass class StopSendingReceived(QuicEvent): """ The StopSendingReceived event is fired when the remote peer requests stopping data transmission on a stream. """ error_code: int "The error code that was sent from the peer." stream_id: int "The ID of the stream that the peer requested stopping data transmission." @dataclass class StreamDataReceived(QuicEvent): """ The StreamDataReceived event is fired whenever data is received on a stream. """ data: bytes "The data which was received." end_stream: bool "Whether the STREAM frame had the FIN bit set." stream_id: int "The ID of the stream the data was received for." 
@dataclass class StreamReset(QuicEvent): """ The StreamReset event is fired when the remote peer resets a stream. """ error_code: int "The error code that triggered the reset." stream_id: int "The ID of the stream that was reset." ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1720306884.0 aioquic-1.2.0/src/aioquic/quic/logger.py0000644000175100001770000002414000000000000020766 0ustar00runnerdocker00000000000000import binascii import json import os import time from collections import deque from typing import Any, Deque, Dict, List, Optional from ..h3.events import Headers from .packet import ( QuicFrameType, QuicPacketType, QuicStreamFrame, QuicTransportParameters, ) from .rangeset import RangeSet PACKET_TYPE_NAMES = { QuicPacketType.INITIAL: "initial", QuicPacketType.HANDSHAKE: "handshake", QuicPacketType.ZERO_RTT: "0RTT", QuicPacketType.ONE_RTT: "1RTT", QuicPacketType.RETRY: "retry", QuicPacketType.VERSION_NEGOTIATION: "version_negotiation", } QLOG_VERSION = "0.3" def hexdump(data: bytes) -> str: return binascii.hexlify(data).decode("ascii") class QuicLoggerTrace: """ A QUIC event trace. Events are logged in the format defined by qlog. See: - https://datatracker.ietf.org/doc/html/draft-ietf-quic-qlog-main-schema-02 - https://datatracker.ietf.org/doc/html/draft-marx-quic-qlog-quic-events - https://datatracker.ietf.org/doc/html/draft-marx-quic-qlog-h3-events """ def __init__(self, *, is_client: bool, odcid: bytes) -> None: self._odcid = odcid self._events: Deque[Dict[str, Any]] = deque() self._vantage_point = { "name": "aioquic", "type": "client" if is_client else "server", } # QUIC def encode_ack_frame(self, ranges: RangeSet, delay: float) -> Dict: return { "ack_delay": self.encode_time(delay), "acked_ranges": [[x.start, x.stop - 1] for x in ranges], "frame_type": "ack", } def encode_connection_close_frame( self, error_code: int, frame_type: Optional[int], reason_phrase: str ) -> Dict: attrs = { "error_code": error_code, "error_space": "application" if frame_type is None else "transport", "frame_type": "connection_close", "raw_error_code": error_code, "reason": reason_phrase, } if frame_type is not None: attrs["trigger_frame_type"] = frame_type return attrs def encode_connection_limit_frame(self, frame_type: int, maximum: int) -> Dict: if frame_type == QuicFrameType.MAX_DATA: return {"frame_type": "max_data", "maximum": maximum} else: return { "frame_type": "max_streams", "maximum": maximum, "stream_type": "unidirectional" if frame_type == QuicFrameType.MAX_STREAMS_UNI else "bidirectional", } def encode_crypto_frame(self, frame: QuicStreamFrame) -> Dict: return { "frame_type": "crypto", "length": len(frame.data), "offset": frame.offset, } def encode_data_blocked_frame(self, limit: int) -> Dict: return {"frame_type": "data_blocked", "limit": limit} def encode_datagram_frame(self, length: int) -> Dict: return {"frame_type": "datagram", "length": length} def encode_handshake_done_frame(self) -> Dict: return {"frame_type": "handshake_done"} def encode_max_stream_data_frame(self, maximum: int, stream_id: int) -> Dict: return { "frame_type": "max_stream_data", "maximum": maximum, "stream_id": stream_id, } def encode_new_connection_id_frame( self, connection_id: bytes, retire_prior_to: int, sequence_number: int, stateless_reset_token: bytes, ) -> Dict: return { "connection_id": hexdump(connection_id), "frame_type": "new_connection_id", "length": len(connection_id), "reset_token": hexdump(stateless_reset_token), "retire_prior_to": retire_prior_to, 
"sequence_number": sequence_number, } def encode_new_token_frame(self, token: bytes) -> Dict: return { "frame_type": "new_token", "length": len(token), "token": hexdump(token), } def encode_padding_frame(self) -> Dict: return {"frame_type": "padding"} def encode_path_challenge_frame(self, data: bytes) -> Dict: return {"data": hexdump(data), "frame_type": "path_challenge"} def encode_path_response_frame(self, data: bytes) -> Dict: return {"data": hexdump(data), "frame_type": "path_response"} def encode_ping_frame(self) -> Dict: return {"frame_type": "ping"} def encode_reset_stream_frame( self, error_code: int, final_size: int, stream_id: int ) -> Dict: return { "error_code": error_code, "final_size": final_size, "frame_type": "reset_stream", "stream_id": stream_id, } def encode_retire_connection_id_frame(self, sequence_number: int) -> Dict: return { "frame_type": "retire_connection_id", "sequence_number": sequence_number, } def encode_stream_data_blocked_frame(self, limit: int, stream_id: int) -> Dict: return { "frame_type": "stream_data_blocked", "limit": limit, "stream_id": stream_id, } def encode_stop_sending_frame(self, error_code: int, stream_id: int) -> Dict: return { "frame_type": "stop_sending", "error_code": error_code, "stream_id": stream_id, } def encode_stream_frame(self, frame: QuicStreamFrame, stream_id: int) -> Dict: return { "fin": frame.fin, "frame_type": "stream", "length": len(frame.data), "offset": frame.offset, "stream_id": stream_id, } def encode_streams_blocked_frame(self, is_unidirectional: bool, limit: int) -> Dict: return { "frame_type": "streams_blocked", "limit": limit, "stream_type": "unidirectional" if is_unidirectional else "bidirectional", } def encode_time(self, seconds: float) -> float: """ Convert a time to milliseconds. """ return seconds * 1000 def encode_transport_parameters( self, owner: str, parameters: QuicTransportParameters ) -> Dict[str, Any]: data: Dict[str, Any] = {"owner": owner} for param_name, param_value in parameters.__dict__.items(): if isinstance(param_value, bool): data[param_name] = param_value elif isinstance(param_value, bytes): data[param_name] = hexdump(param_value) elif isinstance(param_value, int): data[param_name] = param_value return data def packet_type(self, packet_type: QuicPacketType) -> str: return PACKET_TYPE_NAMES[packet_type] # HTTP/3 def encode_http3_data_frame(self, length: int, stream_id: int) -> Dict: return { "frame": {"frame_type": "data"}, "length": length, "stream_id": stream_id, } def encode_http3_headers_frame( self, length: int, headers: Headers, stream_id: int ) -> Dict: return { "frame": { "frame_type": "headers", "headers": self._encode_http3_headers(headers), }, "length": length, "stream_id": stream_id, } def encode_http3_push_promise_frame( self, length: int, headers: Headers, push_id: int, stream_id: int ) -> Dict: return { "frame": { "frame_type": "push_promise", "headers": self._encode_http3_headers(headers), "push_id": push_id, }, "length": length, "stream_id": stream_id, } def _encode_http3_headers(self, headers: Headers) -> List[Dict]: return [ {"name": h[0].decode("utf8"), "value": h[1].decode("utf8")} for h in headers ] # CORE def log_event(self, *, category: str, event: str, data: Dict) -> None: self._events.append( { "data": data, "name": category + ":" + event, "time": self.encode_time(time.time()), } ) def to_dict(self) -> Dict[str, Any]: """ Return the trace as a dictionary which can be written as JSON. 
""" return { "common_fields": { "ODCID": hexdump(self._odcid), }, "events": list(self._events), "vantage_point": self._vantage_point, } class QuicLogger: """ A QUIC event logger which stores traces in memory. """ def __init__(self) -> None: self._traces: List[QuicLoggerTrace] = [] def start_trace(self, is_client: bool, odcid: bytes) -> QuicLoggerTrace: trace = QuicLoggerTrace(is_client=is_client, odcid=odcid) self._traces.append(trace) return trace def end_trace(self, trace: QuicLoggerTrace) -> None: assert trace in self._traces, "QuicLoggerTrace does not belong to QuicLogger" def to_dict(self) -> Dict[str, Any]: """ Return the traces as a dictionary which can be written as JSON. """ return { "qlog_format": "JSON", "qlog_version": QLOG_VERSION, "traces": [trace.to_dict() for trace in self._traces], } class QuicFileLogger(QuicLogger): """ A QUIC event logger which writes one trace per file. """ def __init__(self, path: str) -> None: if not os.path.isdir(path): raise ValueError("QUIC log output directory '%s' does not exist" % path) self.path = path super().__init__() def end_trace(self, trace: QuicLoggerTrace) -> None: trace_dict = trace.to_dict() trace_path = os.path.join( self.path, trace_dict["common_fields"]["ODCID"] + ".qlog" ) with open(trace_path, "w") as logger_fp: json.dump( { "qlog_format": "JSON", "qlog_version": QLOG_VERSION, "traces": [trace_dict], }, logger_fp, ) self._traces.remove(trace) ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1720306884.0 aioquic-1.2.0/src/aioquic/quic/packet.py0000644000175100001770000004715400000000000020770 0ustar00runnerdocker00000000000000import binascii import ipaddress import os from dataclasses import dataclass from enum import Enum, IntEnum from typing import List, Optional, Tuple from cryptography.hazmat.primitives.ciphers.aead import AESGCM from ..buffer import Buffer from .rangeset import RangeSet PACKET_LONG_HEADER = 0x80 PACKET_FIXED_BIT = 0x40 PACKET_SPIN_BIT = 0x20 CONNECTION_ID_MAX_SIZE = 20 PACKET_NUMBER_MAX_SIZE = 4 RETRY_AEAD_KEY_VERSION_1 = binascii.unhexlify("be0c690b9f66575a1d766b54e368c84e") RETRY_AEAD_KEY_VERSION_2 = binascii.unhexlify("8fb4b01b56ac48e260fbcbcead7ccc92") RETRY_AEAD_NONCE_VERSION_1 = binascii.unhexlify("461599d35d632bf2239825bb") RETRY_AEAD_NONCE_VERSION_2 = binascii.unhexlify("d86969bc2d7c6d9990efb04a") RETRY_INTEGRITY_TAG_SIZE = 16 STATELESS_RESET_TOKEN_SIZE = 16 class QuicErrorCode(IntEnum): NO_ERROR = 0x0 INTERNAL_ERROR = 0x1 CONNECTION_REFUSED = 0x2 FLOW_CONTROL_ERROR = 0x3 STREAM_LIMIT_ERROR = 0x4 STREAM_STATE_ERROR = 0x5 FINAL_SIZE_ERROR = 0x6 FRAME_ENCODING_ERROR = 0x7 TRANSPORT_PARAMETER_ERROR = 0x8 CONNECTION_ID_LIMIT_ERROR = 0x9 PROTOCOL_VIOLATION = 0xA INVALID_TOKEN = 0xB APPLICATION_ERROR = 0xC CRYPTO_BUFFER_EXCEEDED = 0xD KEY_UPDATE_ERROR = 0xE AEAD_LIMIT_REACHED = 0xF VERSION_NEGOTIATION_ERROR = 0x11 CRYPTO_ERROR = 0x100 class QuicPacketType(Enum): INITIAL = 0 ZERO_RTT = 1 HANDSHAKE = 2 RETRY = 3 VERSION_NEGOTIATION = 4 ONE_RTT = 5 # For backwards compatibility only, use `QuicPacketType` in new code. 
PACKET_TYPE_INITIAL = QuicPacketType.INITIAL

# QUIC version 1
# https://datatracker.ietf.org/doc/html/rfc9000#section-17.2
PACKET_LONG_TYPE_ENCODE_VERSION_1 = {
    QuicPacketType.INITIAL: 0,
    QuicPacketType.ZERO_RTT: 1,
    QuicPacketType.HANDSHAKE: 2,
    QuicPacketType.RETRY: 3,
}
PACKET_LONG_TYPE_DECODE_VERSION_1 = dict(
    (v, i) for (i, v) in PACKET_LONG_TYPE_ENCODE_VERSION_1.items()
)

# QUIC version 2
# https://datatracker.ietf.org/doc/html/rfc9369#section-3.2
PACKET_LONG_TYPE_ENCODE_VERSION_2 = {
    QuicPacketType.INITIAL: 1,
    QuicPacketType.ZERO_RTT: 2,
    QuicPacketType.HANDSHAKE: 3,
    QuicPacketType.RETRY: 0,
}
PACKET_LONG_TYPE_DECODE_VERSION_2 = dict(
    (v, i) for (i, v) in PACKET_LONG_TYPE_ENCODE_VERSION_2.items()
)


class QuicProtocolVersion(IntEnum):
    NEGOTIATION = 0
    VERSION_1 = 0x00000001
    VERSION_2 = 0x6B3343CF


@dataclass
class QuicHeader:
    version: Optional[int]
    "The protocol version. Only present in long header packets."

    packet_type: QuicPacketType
    "The type of the packet."

    packet_length: int
    "The total length of the packet, in bytes."

    destination_cid: bytes
    "The destination connection ID."

    source_cid: bytes
    "The source connection ID."

    token: bytes
    "The address verification token. Only present in `INITIAL` and `RETRY` packets."

    integrity_tag: bytes
    "The retry integrity tag. Only present in `RETRY` packets."

    supported_versions: List[int]
    "Supported protocol versions. Only present in `VERSION_NEGOTIATION` packets."


def decode_packet_number(truncated: int, num_bits: int, expected: int) -> int:
    """
    Recover a packet number from a truncated packet number.

    See: Appendix A - Sample Packet Number Decoding Algorithm
    """
    window = 1 << num_bits
    half_window = window // 2
    candidate = (expected & ~(window - 1)) | truncated

    if candidate <= expected - half_window and candidate < (1 << 62) - window:
        return candidate + window
    elif candidate > expected + half_window and candidate >= window:
        return candidate - window
    else:
        return candidate


def get_retry_integrity_tag(
    packet_without_tag: bytes, original_destination_cid: bytes, version: int
) -> bytes:
    """
    Calculate the integrity tag for a RETRY packet.
    """
    # build Retry pseudo packet
    buf = Buffer(capacity=1 + len(original_destination_cid) + len(packet_without_tag))
    buf.push_uint8(len(original_destination_cid))
    buf.push_bytes(original_destination_cid)
    buf.push_bytes(packet_without_tag)
    assert buf.eof()

    if version == QuicProtocolVersion.VERSION_2:
        aead_key = RETRY_AEAD_KEY_VERSION_2
        aead_nonce = RETRY_AEAD_NONCE_VERSION_2
    else:
        aead_key = RETRY_AEAD_KEY_VERSION_1
        aead_nonce = RETRY_AEAD_NONCE_VERSION_1

    # run AES-128-GCM
    aead = AESGCM(aead_key)
    integrity_tag = aead.encrypt(aead_nonce, b"", buf.data)
    assert len(integrity_tag) == RETRY_INTEGRITY_TAG_SIZE
    return integrity_tag


def get_spin_bit(first_byte: int) -> bool:
    return bool(first_byte & PACKET_SPIN_BIT)


def is_long_header(first_byte: int) -> bool:
    return bool(first_byte & PACKET_LONG_HEADER)


def pretty_protocol_version(version: int) -> str:
    """
    Return a user-friendly representation of a protocol version.
    """
    try:
        version_name = QuicProtocolVersion(version).name
    except ValueError:
        version_name = "UNKNOWN"
    return f"0x{version:08x} ({version_name})"


def pull_quic_header(buf: Buffer, host_cid_length: Optional[int] = None) -> QuicHeader:
    packet_start = buf.tell()

    version = None
    integrity_tag = b""
    supported_versions = []
    token = b""

    first_byte = buf.pull_uint8()
    if is_long_header(first_byte):
        # Long Header Packets.
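        # --- [Editor's example; not part of aioquic] -----------------------
        # decode_packet_number() above implements RFC 9000 Appendix A. Using
        # the RFC's sample values: with a next expected packet number of
        # 0xa82f30eb (largest received + 1), the 16-bit truncated value
        # 0x9b32 decodes to 0xa82f9b32:
        assert decode_packet_number(0x9B32, 16, 0xA82F30EB) == 0xA82F9B32
        # --------------------------------------------------------------------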
# https://datatracker.ietf.org/doc/html/rfc9000#section-17.2 version = buf.pull_uint32() destination_cid_length = buf.pull_uint8() if destination_cid_length > CONNECTION_ID_MAX_SIZE: raise ValueError( "Destination CID is too long (%d bytes)" % destination_cid_length ) destination_cid = buf.pull_bytes(destination_cid_length) source_cid_length = buf.pull_uint8() if source_cid_length > CONNECTION_ID_MAX_SIZE: raise ValueError("Source CID is too long (%d bytes)" % source_cid_length) source_cid = buf.pull_bytes(source_cid_length) if version == QuicProtocolVersion.NEGOTIATION: # Version Negotiation Packet. # https://datatracker.ietf.org/doc/html/rfc9000#section-17.2.1 packet_type = QuicPacketType.VERSION_NEGOTIATION while not buf.eof(): supported_versions.append(buf.pull_uint32()) packet_end = buf.tell() else: if not (first_byte & PACKET_FIXED_BIT): raise ValueError("Packet fixed bit is zero") if version == QuicProtocolVersion.VERSION_2: packet_type = PACKET_LONG_TYPE_DECODE_VERSION_2[ (first_byte & 0x30) >> 4 ] else: packet_type = PACKET_LONG_TYPE_DECODE_VERSION_1[ (first_byte & 0x30) >> 4 ] if packet_type == QuicPacketType.INITIAL: token_length = buf.pull_uint_var() token = buf.pull_bytes(token_length) rest_length = buf.pull_uint_var() elif packet_type == QuicPacketType.ZERO_RTT: rest_length = buf.pull_uint_var() elif packet_type == QuicPacketType.HANDSHAKE: rest_length = buf.pull_uint_var() else: token_length = buf.capacity - buf.tell() - RETRY_INTEGRITY_TAG_SIZE token = buf.pull_bytes(token_length) integrity_tag = buf.pull_bytes(RETRY_INTEGRITY_TAG_SIZE) rest_length = 0 # Check remainder length. packet_end = buf.tell() + rest_length if packet_end > buf.capacity: raise ValueError("Packet payload is truncated") else: # Short Header Packets. # https://datatracker.ietf.org/doc/html/rfc9000#section-17.3 if not (first_byte & PACKET_FIXED_BIT): raise ValueError("Packet fixed bit is zero") version = None packet_type = QuicPacketType.ONE_RTT destination_cid = buf.pull_bytes(host_cid_length) source_cid = b"" packet_end = buf.capacity return QuicHeader( version=version, packet_type=packet_type, packet_length=packet_end - packet_start, destination_cid=destination_cid, source_cid=source_cid, token=token, integrity_tag=integrity_tag, supported_versions=supported_versions, ) def encode_long_header_first_byte( version: int, packet_type: QuicPacketType, bits: int ) -> int: """ Encode the first byte of a long header packet. 
""" if version == QuicProtocolVersion.VERSION_2: long_type_encode = PACKET_LONG_TYPE_ENCODE_VERSION_2 else: long_type_encode = PACKET_LONG_TYPE_ENCODE_VERSION_1 return ( PACKET_LONG_HEADER | PACKET_FIXED_BIT | long_type_encode[packet_type] << 4 | bits ) def encode_quic_retry( version: int, source_cid: bytes, destination_cid: bytes, original_destination_cid: bytes, retry_token: bytes, unused: int = 0, ) -> bytes: buf = Buffer( capacity=7 + len(destination_cid) + len(source_cid) + len(retry_token) + RETRY_INTEGRITY_TAG_SIZE ) buf.push_uint8(encode_long_header_first_byte(version, QuicPacketType.RETRY, unused)) buf.push_uint32(version) buf.push_uint8(len(destination_cid)) buf.push_bytes(destination_cid) buf.push_uint8(len(source_cid)) buf.push_bytes(source_cid) buf.push_bytes(retry_token) buf.push_bytes( get_retry_integrity_tag(buf.data, original_destination_cid, version=version) ) assert buf.eof() return buf.data def encode_quic_version_negotiation( source_cid: bytes, destination_cid: bytes, supported_versions: List[int] ) -> bytes: buf = Buffer( capacity=7 + len(destination_cid) + len(source_cid) + 4 * len(supported_versions) ) buf.push_uint8(os.urandom(1)[0] | PACKET_LONG_HEADER) buf.push_uint32(QuicProtocolVersion.NEGOTIATION) buf.push_uint8(len(destination_cid)) buf.push_bytes(destination_cid) buf.push_uint8(len(source_cid)) buf.push_bytes(source_cid) for version in supported_versions: buf.push_uint32(version) return buf.data # TLS EXTENSION @dataclass class QuicPreferredAddress: ipv4_address: Optional[Tuple[str, int]] ipv6_address: Optional[Tuple[str, int]] connection_id: bytes stateless_reset_token: bytes @dataclass class QuicVersionInformation: chosen_version: int available_versions: List[int] @dataclass class QuicTransportParameters: original_destination_connection_id: Optional[bytes] = None max_idle_timeout: Optional[int] = None stateless_reset_token: Optional[bytes] = None max_udp_payload_size: Optional[int] = None initial_max_data: Optional[int] = None initial_max_stream_data_bidi_local: Optional[int] = None initial_max_stream_data_bidi_remote: Optional[int] = None initial_max_stream_data_uni: Optional[int] = None initial_max_streams_bidi: Optional[int] = None initial_max_streams_uni: Optional[int] = None ack_delay_exponent: Optional[int] = None max_ack_delay: Optional[int] = None disable_active_migration: Optional[bool] = False preferred_address: Optional[QuicPreferredAddress] = None active_connection_id_limit: Optional[int] = None initial_source_connection_id: Optional[bytes] = None retry_source_connection_id: Optional[bytes] = None version_information: Optional[QuicVersionInformation] = None max_datagram_frame_size: Optional[int] = None quantum_readiness: Optional[bytes] = None PARAMS = { 0x00: ("original_destination_connection_id", bytes), 0x01: ("max_idle_timeout", int), 0x02: ("stateless_reset_token", bytes), 0x03: ("max_udp_payload_size", int), 0x04: ("initial_max_data", int), 0x05: ("initial_max_stream_data_bidi_local", int), 0x06: ("initial_max_stream_data_bidi_remote", int), 0x07: ("initial_max_stream_data_uni", int), 0x08: ("initial_max_streams_bidi", int), 0x09: ("initial_max_streams_uni", int), 0x0A: ("ack_delay_exponent", int), 0x0B: ("max_ack_delay", int), 0x0C: ("disable_active_migration", bool), 0x0D: ("preferred_address", QuicPreferredAddress), 0x0E: ("active_connection_id_limit", int), 0x0F: ("initial_source_connection_id", bytes), 0x10: ("retry_source_connection_id", bytes), # https://datatracker.ietf.org/doc/html/rfc9368#section-3 0x11: ("version_information", 
QuicVersionInformation), # extensions 0x0020: ("max_datagram_frame_size", int), 0x0C37: ("quantum_readiness", bytes), } def pull_quic_preferred_address(buf: Buffer) -> QuicPreferredAddress: ipv4_address = None ipv4_host = buf.pull_bytes(4) ipv4_port = buf.pull_uint16() if ipv4_host != bytes(4): ipv4_address = (str(ipaddress.IPv4Address(ipv4_host)), ipv4_port) ipv6_address = None ipv6_host = buf.pull_bytes(16) ipv6_port = buf.pull_uint16() if ipv6_host != bytes(16): ipv6_address = (str(ipaddress.IPv6Address(ipv6_host)), ipv6_port) connection_id_length = buf.pull_uint8() connection_id = buf.pull_bytes(connection_id_length) stateless_reset_token = buf.pull_bytes(16) return QuicPreferredAddress( ipv4_address=ipv4_address, ipv6_address=ipv6_address, connection_id=connection_id, stateless_reset_token=stateless_reset_token, ) def push_quic_preferred_address( buf: Buffer, preferred_address: QuicPreferredAddress ) -> None: if preferred_address.ipv4_address is not None: buf.push_bytes(ipaddress.IPv4Address(preferred_address.ipv4_address[0]).packed) buf.push_uint16(preferred_address.ipv4_address[1]) else: buf.push_bytes(bytes(6)) if preferred_address.ipv6_address is not None: buf.push_bytes(ipaddress.IPv6Address(preferred_address.ipv6_address[0]).packed) buf.push_uint16(preferred_address.ipv6_address[1]) else: buf.push_bytes(bytes(18)) buf.push_uint8(len(preferred_address.connection_id)) buf.push_bytes(preferred_address.connection_id) buf.push_bytes(preferred_address.stateless_reset_token) def pull_quic_version_information(buf: Buffer, length: int) -> QuicVersionInformation: chosen_version = buf.pull_uint32() available_versions = [] for i in range(length // 4 - 1): available_versions.append(buf.pull_uint32()) # If an endpoint receives a Chosen Version equal to zero, or any Available Version # equal to zero, it MUST treat it as a parsing failure. # # https://datatracker.ietf.org/doc/html/rfc9368#section-4 if chosen_version == 0 or 0 in available_versions: raise ValueError("Version Information must not contain version 0") return QuicVersionInformation( chosen_version=chosen_version, available_versions=available_versions, ) def push_quic_version_information( buf: Buffer, version_information: QuicVersionInformation ) -> None: buf.push_uint32(version_information.chosen_version) for version in version_information.available_versions: buf.push_uint32(version) def pull_quic_transport_parameters(buf: Buffer) -> QuicTransportParameters: params = QuicTransportParameters() while not buf.eof(): param_id = buf.pull_uint_var() param_len = buf.pull_uint_var() param_start = buf.tell() if param_id in PARAMS: # Parse known parameter. param_name, param_type = PARAMS[param_id] if param_type is int: setattr(params, param_name, buf.pull_uint_var()) elif param_type is bytes: setattr(params, param_name, buf.pull_bytes(param_len)) elif param_type is QuicPreferredAddress: setattr(params, param_name, pull_quic_preferred_address(buf)) elif param_type is QuicVersionInformation: setattr( params, param_name, pull_quic_version_information(buf, param_len), ) else: setattr(params, param_name, True) else: # Skip unknown parameter. 
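# --- [Editor's example; not part of aioquic] -------------------------------
# Round-tripping parameters through the helpers in this module: each value
# is written as id / length / value, which is why unknown parameters can be
# skipped as the branch below does. A small self-contained check:

from aioquic.buffer import Buffer
from aioquic.quic.packet import (
    QuicTransportParameters,
    pull_quic_transport_parameters,
    push_quic_transport_parameters,
)

params = QuicTransportParameters(initial_max_data=1048576, max_idle_timeout=60000)
buf = Buffer(capacity=512)
push_quic_transport_parameters(buf, params)
decoded = pull_quic_transport_parameters(Buffer(data=buf.data))
assert decoded.initial_max_data == 1048576
# ----------------------------------------------------------------------------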
buf.pull_bytes(param_len) if buf.tell() != param_start + param_len: raise ValueError("Transport parameter length does not match") return params def push_quic_transport_parameters( buf: Buffer, params: QuicTransportParameters ) -> None: for param_id, (param_name, param_type) in PARAMS.items(): param_value = getattr(params, param_name) if param_value is not None and param_value is not False: param_buf = Buffer(capacity=65536) if param_type is int: param_buf.push_uint_var(param_value) elif param_type is bytes: param_buf.push_bytes(param_value) elif param_type is QuicPreferredAddress: push_quic_preferred_address(param_buf, param_value) elif param_type is QuicVersionInformation: push_quic_version_information(param_buf, param_value) buf.push_uint_var(param_id) buf.push_uint_var(param_buf.tell()) buf.push_bytes(param_buf.data) # FRAMES class QuicFrameType(IntEnum): PADDING = 0x00 PING = 0x01 ACK = 0x02 ACK_ECN = 0x03 RESET_STREAM = 0x04 STOP_SENDING = 0x05 CRYPTO = 0x06 NEW_TOKEN = 0x07 STREAM_BASE = 0x08 MAX_DATA = 0x10 MAX_STREAM_DATA = 0x11 MAX_STREAMS_BIDI = 0x12 MAX_STREAMS_UNI = 0x13 DATA_BLOCKED = 0x14 STREAM_DATA_BLOCKED = 0x15 STREAMS_BLOCKED_BIDI = 0x16 STREAMS_BLOCKED_UNI = 0x17 NEW_CONNECTION_ID = 0x18 RETIRE_CONNECTION_ID = 0x19 PATH_CHALLENGE = 0x1A PATH_RESPONSE = 0x1B TRANSPORT_CLOSE = 0x1C APPLICATION_CLOSE = 0x1D HANDSHAKE_DONE = 0x1E DATAGRAM = 0x30 DATAGRAM_WITH_LENGTH = 0x31 NON_ACK_ELICITING_FRAME_TYPES = frozenset( [ QuicFrameType.ACK, QuicFrameType.ACK_ECN, QuicFrameType.PADDING, QuicFrameType.TRANSPORT_CLOSE, QuicFrameType.APPLICATION_CLOSE, ] ) NON_IN_FLIGHT_FRAME_TYPES = frozenset( [ QuicFrameType.ACK, QuicFrameType.ACK_ECN, QuicFrameType.TRANSPORT_CLOSE, QuicFrameType.APPLICATION_CLOSE, ] ) PROBING_FRAME_TYPES = frozenset( [ QuicFrameType.PATH_CHALLENGE, QuicFrameType.PATH_RESPONSE, QuicFrameType.PADDING, QuicFrameType.NEW_CONNECTION_ID, ] ) @dataclass class QuicResetStreamFrame: error_code: int final_size: int stream_id: int @dataclass class QuicStopSendingFrame: error_code: int stream_id: int @dataclass class QuicStreamFrame: data: bytes = b"" fin: bool = False offset: int = 0 def pull_ack_frame(buf: Buffer) -> Tuple[RangeSet, int]: rangeset = RangeSet() end = buf.pull_uint_var() # largest acknowledged delay = buf.pull_uint_var() ack_range_count = buf.pull_uint_var() ack_count = buf.pull_uint_var() # first ack range rangeset.add(end - ack_count, end + 1) end -= ack_count for _ in range(ack_range_count): end -= buf.pull_uint_var() + 2 ack_count = buf.pull_uint_var() rangeset.add(end - ack_count, end + 1) end -= ack_count return rangeset, delay def push_ack_frame(buf: Buffer, rangeset: RangeSet, delay: int) -> int: ranges = len(rangeset) index = ranges - 1 r = rangeset[index] buf.push_uint_var(r.stop - 1) buf.push_uint_var(delay) buf.push_uint_var(index) buf.push_uint_var(r.stop - 1 - r.start) start = r.start while index > 0: index -= 1 r = rangeset[index] buf.push_uint_var(start - r.stop - 1) buf.push_uint_var(r.stop - r.start - 1) start = r.start return ranges ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1720306884.0 aioquic-1.2.0/src/aioquic/quic/packet_builder.py0000644000175100001770000003142000000000000022463 0ustar00runnerdocker00000000000000from dataclasses import dataclass, field from enum import Enum from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple from ..buffer import Buffer, size_uint_var from ..tls import Epoch from .crypto import CryptoPair from .logger import QuicLoggerTrace from 
.packet import ( NON_ACK_ELICITING_FRAME_TYPES, NON_IN_FLIGHT_FRAME_TYPES, PACKET_FIXED_BIT, PACKET_NUMBER_MAX_SIZE, QuicFrameType, QuicPacketType, encode_long_header_first_byte, ) PACKET_LENGTH_SEND_SIZE = 2 PACKET_NUMBER_SEND_SIZE = 2 QuicDeliveryHandler = Callable[..., None] class QuicDeliveryState(Enum): ACKED = 0 LOST = 1 @dataclass class QuicSentPacket: epoch: Epoch in_flight: bool is_ack_eliciting: bool is_crypto_packet: bool packet_number: int packet_type: QuicPacketType sent_time: Optional[float] = None sent_bytes: int = 0 delivery_handlers: List[Tuple[QuicDeliveryHandler, Any]] = field( default_factory=list ) quic_logger_frames: List[Dict] = field(default_factory=list) class QuicPacketBuilderStop(Exception): pass class QuicPacketBuilder: """ Helper for building QUIC packets. """ def __init__( self, *, host_cid: bytes, peer_cid: bytes, version: int, is_client: bool, max_datagram_size: int, packet_number: int = 0, peer_token: bytes = b"", quic_logger: Optional[QuicLoggerTrace] = None, spin_bit: bool = False, ): self.max_flight_bytes: Optional[int] = None self.max_total_bytes: Optional[int] = None self.quic_logger_frames: Optional[List[Dict]] = None self._host_cid = host_cid self._is_client = is_client self._peer_cid = peer_cid self._peer_token = peer_token self._quic_logger = quic_logger self._spin_bit = spin_bit self._version = version # assembled datagrams and packets self._datagrams: List[bytes] = [] self._datagram_flight_bytes = 0 self._datagram_init = True self._datagram_needs_padding = False self._packets: List[QuicSentPacket] = [] self._flight_bytes = 0 self._total_bytes = 0 # current packet self._header_size = 0 self._packet: Optional[QuicSentPacket] = None self._packet_crypto: Optional[CryptoPair] = None self._packet_number = packet_number self._packet_start = 0 self._packet_type: Optional[QuicPacketType] = None self._buffer = Buffer(max_datagram_size) self._buffer_capacity = max_datagram_size self._flight_capacity = max_datagram_size @property def packet_is_empty(self) -> bool: """ Returns `True` if the current packet is empty. """ assert self._packet is not None packet_size = self._buffer.tell() - self._packet_start return packet_size <= self._header_size @property def packet_number(self) -> int: """ Returns the packet number for the next packet. """ return self._packet_number @property def remaining_buffer_space(self) -> int: """ Returns the remaining number of bytes which can be used in the current packet. """ return ( self._buffer_capacity - self._buffer.tell() - self._packet_crypto.aead_tag_size ) @property def remaining_flight_space(self) -> int: """ Returns the remaining number of bytes which can be used in the current packet. """ return ( self._flight_capacity - self._buffer.tell() - self._packet_crypto.aead_tag_size ) def flush(self) -> Tuple[List[bytes], List[QuicSentPacket]]: """ Returns the assembled datagrams. """ if self._packet is not None: self._end_packet() self._flush_current_datagram() datagrams = self._datagrams packets = self._packets self._datagrams = [] self._packets = [] return datagrams, packets def start_frame( self, frame_type: int, capacity: int = 1, handler: Optional[QuicDeliveryHandler] = None, handler_args: Sequence[Any] = [], ) -> Buffer: """ Starts a new frame. 
""" if self.remaining_buffer_space < capacity or ( frame_type not in NON_IN_FLIGHT_FRAME_TYPES and self.remaining_flight_space < capacity ): raise QuicPacketBuilderStop self._buffer.push_uint_var(frame_type) if frame_type not in NON_ACK_ELICITING_FRAME_TYPES: self._packet.is_ack_eliciting = True if frame_type not in NON_IN_FLIGHT_FRAME_TYPES: self._packet.in_flight = True if frame_type == QuicFrameType.CRYPTO: self._packet.is_crypto_packet = True if handler is not None: self._packet.delivery_handlers.append((handler, handler_args)) return self._buffer def start_packet(self, packet_type: QuicPacketType, crypto: CryptoPair) -> None: """ Starts a new packet. """ assert packet_type in ( QuicPacketType.INITIAL, QuicPacketType.HANDSHAKE, QuicPacketType.ZERO_RTT, QuicPacketType.ONE_RTT, ), "Invalid packet type" buf = self._buffer # finish previous datagram if self._packet is not None: self._end_packet() # if there is too little space remaining, start a new datagram # FIXME: the limit is arbitrary! packet_start = buf.tell() if self._buffer_capacity - packet_start < 128: self._flush_current_datagram() packet_start = 0 # initialize datagram if needed if self._datagram_init: if self.max_total_bytes is not None: remaining_total_bytes = self.max_total_bytes - self._total_bytes if remaining_total_bytes < self._buffer_capacity: self._buffer_capacity = remaining_total_bytes self._flight_capacity = self._buffer_capacity if self.max_flight_bytes is not None: remaining_flight_bytes = self.max_flight_bytes - self._flight_bytes if remaining_flight_bytes < self._flight_capacity: self._flight_capacity = remaining_flight_bytes self._datagram_flight_bytes = 0 self._datagram_init = False self._datagram_needs_padding = False # calculate header size if packet_type != QuicPacketType.ONE_RTT: header_size = 11 + len(self._peer_cid) + len(self._host_cid) if packet_type == QuicPacketType.INITIAL: token_length = len(self._peer_token) header_size += size_uint_var(token_length) + token_length else: header_size = 3 + len(self._peer_cid) # check we have enough space if packet_start + header_size >= self._buffer_capacity: raise QuicPacketBuilderStop # determine ack epoch if packet_type == QuicPacketType.INITIAL: epoch = Epoch.INITIAL elif packet_type == QuicPacketType.HANDSHAKE: epoch = Epoch.HANDSHAKE else: epoch = Epoch.ONE_RTT self._header_size = header_size self._packet = QuicSentPacket( epoch=epoch, in_flight=False, is_ack_eliciting=False, is_crypto_packet=False, packet_number=self._packet_number, packet_type=packet_type, ) self._packet_crypto = crypto self._packet_start = packet_start self._packet_type = packet_type self.quic_logger_frames = self._packet.quic_logger_frames buf.seek(self._packet_start + self._header_size) def _end_packet(self) -> None: """ Ends the current packet. """ buf = self._buffer packet_size = buf.tell() - self._packet_start if packet_size > self._header_size: # padding to ensure sufficient sample size padding_size = ( PACKET_NUMBER_MAX_SIZE - PACKET_NUMBER_SEND_SIZE + self._header_size - packet_size ) # Padding for datagrams containing initial packets; see RFC 9000 # section 14.1. if ( self._is_client or self._packet.is_ack_eliciting ) and self._packet_type == QuicPacketType.INITIAL: self._datagram_needs_padding = True # For datagrams containing 1-RTT data, we *must* apply the padding # inside the packet, we cannot tack bytes onto the end of the # datagram. 
if ( self._datagram_needs_padding and self._packet_type == QuicPacketType.ONE_RTT ): if self.remaining_flight_space > padding_size: padding_size = self.remaining_flight_space self._datagram_needs_padding = False # write padding if padding_size > 0: buf.push_bytes(bytes(padding_size)) packet_size += padding_size self._packet.in_flight = True # log frame if self._quic_logger is not None: self._packet.quic_logger_frames.append( self._quic_logger.encode_padding_frame() ) # write header if self._packet_type != QuicPacketType.ONE_RTT: length = ( packet_size - self._header_size + PACKET_NUMBER_SEND_SIZE + self._packet_crypto.aead_tag_size ) buf.seek(self._packet_start) buf.push_uint8( encode_long_header_first_byte( self._version, self._packet_type, PACKET_NUMBER_SEND_SIZE - 1 ) ) buf.push_uint32(self._version) buf.push_uint8(len(self._peer_cid)) buf.push_bytes(self._peer_cid) buf.push_uint8(len(self._host_cid)) buf.push_bytes(self._host_cid) if self._packet_type == QuicPacketType.INITIAL: buf.push_uint_var(len(self._peer_token)) buf.push_bytes(self._peer_token) buf.push_uint16(length | 0x4000) buf.push_uint16(self._packet_number & 0xFFFF) else: buf.seek(self._packet_start) buf.push_uint8( PACKET_FIXED_BIT | (self._spin_bit << 5) | (self._packet_crypto.key_phase << 2) | (PACKET_NUMBER_SEND_SIZE - 1) ) buf.push_bytes(self._peer_cid) buf.push_uint16(self._packet_number & 0xFFFF) # encrypt in place plain = buf.data_slice(self._packet_start, self._packet_start + packet_size) buf.seek(self._packet_start) buf.push_bytes( self._packet_crypto.encrypt_packet( plain[0 : self._header_size], plain[self._header_size : packet_size], self._packet_number, ) ) self._packet.sent_bytes = buf.tell() - self._packet_start self._packets.append(self._packet) if self._packet.in_flight: self._datagram_flight_bytes += self._packet.sent_bytes # Short header packets cannot be coalesced, we need a new datagram. if self._packet_type == QuicPacketType.ONE_RTT: self._flush_current_datagram() self._packet_number += 1 else: # "cancel" the packet buf.seek(self._packet_start) self._packet = None self.quic_logger_frames = None def _flush_current_datagram(self) -> None: datagram_bytes = self._buffer.tell() if datagram_bytes: # Padding for datagrams containing initial packets; see RFC 9000 # section 14.1. 
if self._datagram_needs_padding: extra_bytes = self._flight_capacity - self._buffer.tell() if extra_bytes > 0: self._buffer.push_bytes(bytes(extra_bytes)) self._datagram_flight_bytes += extra_bytes datagram_bytes += extra_bytes self._datagrams.append(self._buffer.data) self._flight_bytes += self._datagram_flight_bytes self._total_bytes += datagram_bytes self._datagram_init = True self._buffer.seek(0) ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1720306884.0 aioquic-1.2.0/src/aioquic/quic/rangeset.py0000644000175100001770000000607500000000000021326 0ustar00runnerdocker00000000000000from collections.abc import Sequence from typing import Any, Iterable, List, Optional class RangeSet(Sequence): def __init__(self, ranges: Iterable[range] = []): self.__ranges: List[range] = [] for r in ranges: assert r.step == 1 self.add(r.start, r.stop) def add(self, start: int, stop: Optional[int] = None) -> None: if stop is None: stop = start + 1 assert stop > start for i, r in enumerate(self.__ranges): # the added range is entirely before current item, insert here if stop < r.start: self.__ranges.insert(i, range(start, stop)) return # the added range is entirely after current item, keep looking if start > r.stop: continue # the added range touches the current item, merge it start = min(start, r.start) stop = max(stop, r.stop) while i < len(self.__ranges) - 1 and self.__ranges[i + 1].start <= stop: stop = max(self.__ranges[i + 1].stop, stop) self.__ranges.pop(i + 1) self.__ranges[i] = range(start, stop) return # the added range is entirely after all existing items, append it self.__ranges.append(range(start, stop)) def bounds(self) -> range: return range(self.__ranges[0].start, self.__ranges[-1].stop) def shift(self) -> range: return self.__ranges.pop(0) def subtract(self, start: int, stop: int) -> None: assert stop > start i = 0 while i < len(self.__ranges): r = self.__ranges[i] # the removed range is entirely before current item, stop here if stop <= r.start: return # the removed range is entirely after current item, keep looking if start >= r.stop: i += 1 continue # the removed range completely covers the current item, remove it if start <= r.start and stop >= r.stop: self.__ranges.pop(i) continue # the removed range touches the current item if start > r.start: self.__ranges[i] = range(r.start, start) if stop < r.stop: self.__ranges.insert(i + 1, range(stop, r.stop)) else: self.__ranges[i] = range(stop, r.stop) i += 1 def __bool__(self) -> bool: raise NotImplementedError def __contains__(self, val: Any) -> bool: for r in self.__ranges: if val in r: return True return False def __eq__(self, other: object) -> bool: if not isinstance(other, RangeSet): return NotImplemented return self.__ranges == other.__ranges def __getitem__(self, key: Any) -> range: return self.__ranges[key] def __len__(self) -> int: return len(self.__ranges) def __repr__(self) -> str: return "RangeSet({})".format(repr(self.__ranges)) ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1720306884.0 aioquic-1.2.0/src/aioquic/quic/recovery.py0000644000175100001770000003233600000000000021353 0ustar00runnerdocker00000000000000import logging import math from typing import Any, Callable, Dict, Iterable, List, Optional from .congestion import cubic, reno # noqa from .congestion.base import K_GRANULARITY, create_congestion_control from .logger import QuicLoggerTrace from .packet_builder import QuicDeliveryState, QuicSentPacket from .rangeset import 
RangeSet # loss detection K_PACKET_THRESHOLD = 3 K_TIME_THRESHOLD = 9 / 8 K_MICRO_SECOND = 0.000001 K_SECOND = 1.0 class QuicPacketSpace: def __init__(self) -> None: self.ack_at: Optional[float] = None self.ack_queue = RangeSet() self.discarded = False self.expected_packet_number = 0 self.largest_received_packet = -1 self.largest_received_time: Optional[float] = None # sent packets and loss self.ack_eliciting_in_flight = 0 self.largest_acked_packet = 0 self.loss_time: Optional[float] = None self.sent_packets: Dict[int, QuicSentPacket] = {} class QuicPacketPacer: def __init__(self, *, max_datagram_size: int) -> None: self._max_datagram_size = max_datagram_size self.bucket_max: float = 0.0 self.bucket_time: float = 0.0 self.evaluation_time: float = 0.0 self.packet_time: Optional[float] = None def next_send_time(self, now: float) -> float: if self.packet_time is not None: self.update_bucket(now=now) if self.bucket_time <= 0: return now + self.packet_time return None def update_after_send(self, now: float) -> None: if self.packet_time is not None: self.update_bucket(now=now) if self.bucket_time < self.packet_time: self.bucket_time = 0.0 else: self.bucket_time -= self.packet_time def update_bucket(self, now: float) -> None: if now > self.evaluation_time: self.bucket_time = min( self.bucket_time + (now - self.evaluation_time), self.bucket_max ) self.evaluation_time = now def update_rate(self, congestion_window: int, smoothed_rtt: float) -> None: pacing_rate = congestion_window / max(smoothed_rtt, K_MICRO_SECOND) self.packet_time = max( K_MICRO_SECOND, min(self._max_datagram_size / pacing_rate, K_SECOND) ) self.bucket_max = ( max( 2 * self._max_datagram_size, min(congestion_window // 4, 16 * self._max_datagram_size), ) / pacing_rate ) if self.bucket_time > self.bucket_max: self.bucket_time = self.bucket_max class QuicPacketRecovery: """ Packet loss and congestion controller. 
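    Loss detection follows the packet-number and time thresholds of
    RFC 9002 (K_PACKET_THRESHOLD and K_TIME_THRESHOLD above), backed by a
    probe timeout (PTO), while congestion control is delegated to a
    pluggable algorithm selected by name through
    create_congestion_control(), e.g. "reno" or "cubic".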
""" def __init__( self, *, congestion_control_algorithm: str, initial_rtt: float, max_datagram_size: int, peer_completed_address_validation: bool, send_probe: Callable[[], None], logger: Optional[logging.LoggerAdapter] = None, quic_logger: Optional[QuicLoggerTrace] = None, ) -> None: self.max_ack_delay = 0.025 self.peer_completed_address_validation = peer_completed_address_validation self.spaces: List[QuicPacketSpace] = [] # callbacks self._logger = logger self._quic_logger = quic_logger self._send_probe = send_probe # loss detection self._pto_count = 0 self._rtt_initial = initial_rtt self._rtt_initialized = False self._rtt_latest = 0.0 self._rtt_min = math.inf self._rtt_smoothed = 0.0 self._rtt_variance = 0.0 self._time_of_last_sent_ack_eliciting_packet = 0.0 # congestion control self._cc = create_congestion_control( congestion_control_algorithm, max_datagram_size=max_datagram_size ) self._pacer = QuicPacketPacer(max_datagram_size=max_datagram_size) @property def bytes_in_flight(self) -> int: return self._cc.bytes_in_flight @property def congestion_window(self) -> int: return self._cc.congestion_window def discard_space(self, space: QuicPacketSpace) -> None: assert space in self.spaces self._cc.on_packets_expired( packets=filter(lambda x: x.in_flight, space.sent_packets.values()) ) space.sent_packets.clear() space.ack_at = None space.ack_eliciting_in_flight = 0 space.loss_time = None # reset PTO count self._pto_count = 0 if self._quic_logger is not None: self._log_metrics_updated() def get_loss_detection_time(self) -> float: # loss timer loss_space = self._get_loss_space() if loss_space is not None: return loss_space.loss_time # packet timer if ( not self.peer_completed_address_validation or sum(space.ack_eliciting_in_flight for space in self.spaces) > 0 ): timeout = self.get_probe_timeout() * (2**self._pto_count) return self._time_of_last_sent_ack_eliciting_packet + timeout return None def get_probe_timeout(self) -> float: if not self._rtt_initialized: return 2 * self._rtt_initial return ( self._rtt_smoothed + max(4 * self._rtt_variance, K_GRANULARITY) + self.max_ack_delay ) def on_ack_received( self, *, ack_rangeset: RangeSet, ack_delay: float, now: float, space: QuicPacketSpace, ) -> None: """ Update metrics as the result of an ACK being received. 
""" is_ack_eliciting = False largest_acked = ack_rangeset.bounds().stop - 1 largest_newly_acked = None largest_sent_time = None if largest_acked > space.largest_acked_packet: space.largest_acked_packet = largest_acked for packet_number in sorted(space.sent_packets.keys()): if packet_number > largest_acked: break if packet_number in ack_rangeset: # remove packet and update counters packet = space.sent_packets.pop(packet_number) if packet.is_ack_eliciting: is_ack_eliciting = True space.ack_eliciting_in_flight -= 1 if packet.in_flight: self._cc.on_packet_acked(packet=packet, now=now) largest_newly_acked = packet_number largest_sent_time = packet.sent_time # trigger callbacks for handler, args in packet.delivery_handlers: handler(QuicDeliveryState.ACKED, *args) # nothing to do if there are no newly acked packets if largest_newly_acked is None: return if largest_acked == largest_newly_acked and is_ack_eliciting: latest_rtt = now - largest_sent_time log_rtt = True # limit ACK delay to max_ack_delay ack_delay = min(ack_delay, self.max_ack_delay) # update RTT estimate, which cannot be < 1 ms self._rtt_latest = max(latest_rtt, 0.001) if self._rtt_latest < self._rtt_min: self._rtt_min = self._rtt_latest if self._rtt_latest > self._rtt_min + ack_delay: self._rtt_latest -= ack_delay if not self._rtt_initialized: self._rtt_initialized = True self._rtt_variance = latest_rtt / 2 self._rtt_smoothed = latest_rtt else: self._rtt_variance = 3 / 4 * self._rtt_variance + 1 / 4 * abs( self._rtt_min - self._rtt_latest ) self._rtt_smoothed = ( 7 / 8 * self._rtt_smoothed + 1 / 8 * self._rtt_latest ) # inform congestion controller self._cc.on_rtt_measurement(now=now, rtt=latest_rtt) self._pacer.update_rate( congestion_window=self._cc.congestion_window, smoothed_rtt=self._rtt_smoothed, ) else: log_rtt = False self._detect_loss(now=now, space=space) # reset PTO count self._pto_count = 0 if self._quic_logger is not None: self._log_metrics_updated(log_rtt=log_rtt) def on_loss_detection_timeout(self, *, now: float) -> None: loss_space = self._get_loss_space() if loss_space is not None: self._detect_loss(now=now, space=loss_space) else: self._pto_count += 1 self.reschedule_data(now=now) def on_packet_sent(self, *, packet: QuicSentPacket, space: QuicPacketSpace) -> None: space.sent_packets[packet.packet_number] = packet if packet.is_ack_eliciting: space.ack_eliciting_in_flight += 1 if packet.in_flight: if packet.is_ack_eliciting: self._time_of_last_sent_ack_eliciting_packet = packet.sent_time # add packet to bytes in flight self._cc.on_packet_sent(packet=packet) if self._quic_logger is not None: self._log_metrics_updated() def reschedule_data(self, *, now: float) -> None: """ Schedule some data for retransmission. """ # if there is any outstanding CRYPTO, retransmit it crypto_scheduled = False for space in self.spaces: packets = tuple( filter(lambda i: i.is_crypto_packet, space.sent_packets.values()) ) if packets: self._on_packets_lost(now=now, packets=packets, space=space) crypto_scheduled = True if crypto_scheduled and self._logger is not None: self._logger.debug("Scheduled CRYPTO data for retransmission") # ensure an ACK-elliciting packet is sent self._send_probe() def _detect_loss(self, *, now: float, space: QuicPacketSpace) -> None: """ Check whether any packets should be declared lost. 
""" loss_delay = K_TIME_THRESHOLD * ( max(self._rtt_latest, self._rtt_smoothed) if self._rtt_initialized else self._rtt_initial ) packet_threshold = space.largest_acked_packet - K_PACKET_THRESHOLD time_threshold = now - loss_delay lost_packets = [] space.loss_time = None for packet_number, packet in space.sent_packets.items(): if packet_number > space.largest_acked_packet: break if packet_number <= packet_threshold or packet.sent_time <= time_threshold: lost_packets.append(packet) else: packet_loss_time = packet.sent_time + loss_delay if space.loss_time is None or space.loss_time > packet_loss_time: space.loss_time = packet_loss_time self._on_packets_lost(now=now, packets=lost_packets, space=space) def _get_loss_space(self) -> Optional[QuicPacketSpace]: loss_space = None for space in self.spaces: if space.loss_time is not None and ( loss_space is None or space.loss_time < loss_space.loss_time ): loss_space = space return loss_space def _log_metrics_updated(self, log_rtt=False) -> None: data: Dict[str, Any] = self._cc.get_log_data() if log_rtt: data.update( { "latest_rtt": self._quic_logger.encode_time(self._rtt_latest), "min_rtt": self._quic_logger.encode_time(self._rtt_min), "smoothed_rtt": self._quic_logger.encode_time(self._rtt_smoothed), "rtt_variance": self._quic_logger.encode_time(self._rtt_variance), } ) self._quic_logger.log_event( category="recovery", event="metrics_updated", data=data ) def _on_packets_lost( self, *, now: float, packets: Iterable[QuicSentPacket], space: QuicPacketSpace ) -> None: lost_packets_cc = [] for packet in packets: del space.sent_packets[packet.packet_number] if packet.in_flight: lost_packets_cc.append(packet) if packet.is_ack_eliciting: space.ack_eliciting_in_flight -= 1 if self._quic_logger is not None: self._quic_logger.log_event( category="recovery", event="packet_lost", data={ "type": self._quic_logger.packet_type(packet.packet_type), "packet_number": packet.packet_number, }, ) self._log_metrics_updated() # trigger callbacks for handler, args in packet.delivery_handlers: handler(QuicDeliveryState.LOST, *args) # inform congestion controller if lost_packets_cc: self._cc.on_packets_lost(now=now, packets=lost_packets_cc) self._pacer.update_rate( congestion_window=self._cc.congestion_window, smoothed_rtt=self._rtt_smoothed, ) if self._quic_logger is not None: self._log_metrics_updated() ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1720306884.0 aioquic-1.2.0/src/aioquic/quic/retry.py0000644000175100001770000000353200000000000020656 0ustar00runnerdocker00000000000000import ipaddress from typing import Tuple from cryptography.hazmat.primitives import hashes from cryptography.hazmat.primitives.asymmetric import padding, rsa from ..buffer import Buffer from ..tls import pull_opaque, push_opaque from .connection import NetworkAddress def encode_address(addr: NetworkAddress) -> bytes: return ipaddress.ip_address(addr[0]).packed + bytes([addr[1] >> 8, addr[1] & 0xFF]) class QuicRetryTokenHandler: def __init__(self) -> None: self._key = rsa.generate_private_key(public_exponent=65537, key_size=2048) def create_token( self, addr: NetworkAddress, original_destination_connection_id: bytes, retry_source_connection_id: bytes, ) -> bytes: buf = Buffer(capacity=512) push_opaque(buf, 1, encode_address(addr)) push_opaque(buf, 1, original_destination_connection_id) push_opaque(buf, 1, retry_source_connection_id) return self._key.public_key().encrypt( buf.data, padding.OAEP( mgf=padding.MGF1(hashes.SHA256()), 
algorithm=hashes.SHA256(), label=None ), ) def validate_token(self, addr: NetworkAddress, token: bytes) -> Tuple[bytes, bytes]: buf = Buffer( data=self._key.decrypt( token, padding.OAEP( mgf=padding.MGF1(hashes.SHA256()), algorithm=hashes.SHA256(), label=None, ), ) ) encoded_addr = pull_opaque(buf, 1) original_destination_connection_id = pull_opaque(buf, 1) retry_source_connection_id = pull_opaque(buf, 1) if encoded_addr != encode_address(addr): raise ValueError("Remote address does not match.") return original_destination_connection_id, retry_source_connection_id ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1720306884.0 aioquic-1.2.0/src/aioquic/quic/stream.py0000644000175100001770000002746100000000000021013 0ustar00runnerdocker00000000000000from typing import Optional from . import events from .packet import ( QuicErrorCode, QuicResetStreamFrame, QuicStopSendingFrame, QuicStreamFrame, ) from .packet_builder import QuicDeliveryState from .rangeset import RangeSet class FinalSizeError(Exception): pass class StreamFinishedError(Exception): pass class QuicStreamReceiver: """ The receive part of a QUIC stream. It finishes: - immediately for a send-only stream - upon reception of a STREAM_RESET frame - upon reception of a data frame with the FIN bit set """ def __init__(self, stream_id: Optional[int], readable: bool) -> None: self.highest_offset = 0 # the highest offset ever seen self.is_finished = False self.stop_pending = False self._buffer = bytearray() self._buffer_start = 0 # the offset for the start of the buffer self._final_size: Optional[int] = None self._ranges = RangeSet() self._stream_id = stream_id self._stop_error_code: Optional[int] = None def get_stop_frame(self) -> QuicStopSendingFrame: self.stop_pending = False return QuicStopSendingFrame( error_code=self._stop_error_code, stream_id=self._stream_id, ) def starting_offset(self) -> int: return self._buffer_start def handle_frame( self, frame: QuicStreamFrame ) -> Optional[events.StreamDataReceived]: """ Handle a frame of received data. """ pos = frame.offset - self._buffer_start count = len(frame.data) frame_end = frame.offset + count # we should receive no more data beyond FIN! 
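        # (Once a FIN has fixed the stream's final size, RFC 9000 section 4.5
        # forbids the peer from sending data past that size or announcing a
        # different final size; both violations surface as FinalSizeError.)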
        if self._final_size is not None:
            if frame_end > self._final_size:
                raise FinalSizeError("Data received beyond final size")
            elif frame.fin and frame_end != self._final_size:
                raise FinalSizeError("Cannot change final size")
        if frame.fin:
            self._final_size = frame_end
        if frame_end > self.highest_offset:
            self.highest_offset = frame_end

        # fast path: new in-order chunk
        if pos == 0 and count and not self._buffer:
            self._buffer_start += count
            if frame.fin:
                # all data up to the FIN has been received, we're done receiving
                self.is_finished = True
            return events.StreamDataReceived(
                data=frame.data, end_stream=frame.fin, stream_id=self._stream_id
            )

        # discard duplicate data
        if pos < 0:
            frame.data = frame.data[-pos:]
            frame.offset -= pos
            pos = 0
            count = len(frame.data)

        # mark the received range
        if frame_end > frame.offset:
            self._ranges.add(frame.offset, frame_end)

        # add new data
        gap = pos - len(self._buffer)
        if gap > 0:
            self._buffer += bytearray(gap)
        self._buffer[pos : pos + count] = frame.data

        # return data from the front of the buffer
        data = self._pull_data()
        end_stream = self._buffer_start == self._final_size
        if end_stream:
            # all data up to the FIN has been received, we're done receiving
            self.is_finished = True
        if data or end_stream:
            return events.StreamDataReceived(
                data=data, end_stream=end_stream, stream_id=self._stream_id
            )
        else:
            return None

    def handle_reset(
        self, *, final_size: int, error_code: int = QuicErrorCode.NO_ERROR
    ) -> Optional[events.StreamReset]:
        """
        Handle an abrupt termination of the receiving part of the QUIC stream.
        """
        if self._final_size is not None and final_size != self._final_size:
            raise FinalSizeError("Cannot change final size")

        # we are done receiving
        self._final_size = final_size
        self.is_finished = True
        return events.StreamReset(error_code=error_code, stream_id=self._stream_id)

    def on_stop_sending_delivery(self, delivery: QuicDeliveryState) -> None:
        """
        Callback when a STOP_SENDING is ACK'd.
        """
        if delivery != QuicDeliveryState.ACKED:
            self.stop_pending = True

    def stop(self, error_code: int = QuicErrorCode.NO_ERROR) -> None:
        """
        Request the peer stop sending data on the QUIC stream.
        """
        self._stop_error_code = error_code
        self.stop_pending = True

    def _pull_data(self) -> bytes:
        """
        Remove data from the front of the buffer.
        """
        try:
            has_data_to_read = self._ranges[0].start == self._buffer_start
        except IndexError:
            has_data_to_read = False
        if not has_data_to_read:
            return b""

        r = self._ranges.shift()
        pos = r.stop - r.start
        data = bytes(self._buffer[:pos])
        del self._buffer[:pos]
        self._buffer_start = r.stop
        return data


class QuicStreamSender:
    """
    The send part of a QUIC stream.

    It finishes:
    - immediately for a receive-only stream
    - upon acknowledgement of a STREAM_RESET frame
    - upon acknowledgement of a data frame with the FIN bit set
    """

    def __init__(self, stream_id: Optional[int], writable: bool) -> None:
        self.buffer_is_empty = True
        self.highest_offset = 0
        self.is_finished = not writable
        self.reset_pending = False
        self._acked = RangeSet()
        self._acked_fin = False
        self._buffer = bytearray()
        self._buffer_fin: Optional[int] = None
        self._buffer_start = 0  # the offset for the start of the buffer
        self._buffer_stop = 0  # the offset for the stop of the buffer
        self._pending = RangeSet()
        self._pending_eof = False
        self._reset_error_code: Optional[int] = None
        self._stream_id = stream_id

    @property
    def next_offset(self) -> int:
        """
        The offset for the next frame to send.

        This is used to determine the space needed for the frame's
        `offset` field.
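
        For example, if bytes 0-99 were already sent and bytes 100-199 are
        pending, this returns 100; with nothing pending it returns the end
        of the buffered data, where the next write() would start.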
""" try: return self._pending[0].start except IndexError: return self._buffer_stop def get_frame( self, max_size: int, max_offset: Optional[int] = None ) -> Optional[QuicStreamFrame]: """ Get a frame of data to send. """ assert self._reset_error_code is None, "cannot call get_frame() after reset()" # get the first pending data range try: r = self._pending[0] except IndexError: if self._pending_eof: # FIN only self._pending_eof = False return QuicStreamFrame(fin=True, offset=self._buffer_fin) self.buffer_is_empty = True return None # apply flow control start = r.start stop = min(r.stop, start + max_size) if max_offset is not None and stop > max_offset: stop = max_offset if stop <= start: return None # create frame frame = QuicStreamFrame( data=bytes( self._buffer[start - self._buffer_start : stop - self._buffer_start] ), offset=start, ) self._pending.subtract(start, stop) # track the highest offset ever sent if stop > self.highest_offset: self.highest_offset = stop # if the buffer is empty and EOF was written, set the FIN bit if self._buffer_fin == stop: frame.fin = True self._pending_eof = False return frame def get_reset_frame(self) -> QuicResetStreamFrame: self.reset_pending = False return QuicResetStreamFrame( error_code=self._reset_error_code, final_size=self.highest_offset, stream_id=self._stream_id, ) def on_data_delivery( self, delivery: QuicDeliveryState, start: int, stop: int, fin: bool ) -> None: """ Callback when sent data is ACK'd. """ # If the frame had the FIN bit set, its end MUST match otherwise # we have a programming error. assert ( not fin or stop == self._buffer_fin ), "on_data_delivered() was called with inconsistent fin / stop" # If a reset has been requested, stop processing data delivery. # The transition to the finished state only depends on the reset # being acknowledged. if self._reset_error_code is not None: return if delivery == QuicDeliveryState.ACKED: if stop > start: # Some data has been ACK'd, discard it. self._acked.add(start, stop) first_range = self._acked[0] if first_range.start == self._buffer_start: size = first_range.stop - first_range.start self._acked.shift() self._buffer_start += size del self._buffer[:size] if fin: # The FIN has been ACK'd. self._acked_fin = True if self._buffer_start == self._buffer_fin and self._acked_fin: # All data and the FIN have been ACK'd, we're done sending. self.is_finished = True else: if stop > start: # Some data has been lost, reschedule it. self.buffer_is_empty = False self._pending.add(start, stop) if fin: # The FIN has been lost, reschedule it. self.buffer_is_empty = False self._pending_eof = True def on_reset_delivery(self, delivery: QuicDeliveryState) -> None: """ Callback when a reset is ACK'd. """ if delivery == QuicDeliveryState.ACKED: # The reset has been ACK'd, we're done sending. self.is_finished = True else: # The reset has been lost, reschedule it. self.reset_pending = True def reset(self, error_code: int) -> None: """ Abruptly terminate the sending part of the QUIC stream. """ assert self._reset_error_code is None, "cannot call reset() more than once" self._reset_error_code = error_code self.reset_pending = True # Prevent any more data from being sent or re-sent. self.buffer_is_empty = True def write(self, data: bytes, end_stream: bool = False) -> None: """ Write some data bytes to the QUIC stream. 
""" assert self._buffer_fin is None, "cannot call write() after FIN" assert self._reset_error_code is None, "cannot call write() after reset()" size = len(data) if size: self.buffer_is_empty = False self._pending.add(self._buffer_stop, self._buffer_stop + size) self._buffer += data self._buffer_stop += size if end_stream: self.buffer_is_empty = False self._buffer_fin = self._buffer_stop self._pending_eof = True class QuicStream: def __init__( self, stream_id: Optional[int] = None, max_stream_data_local: int = 0, max_stream_data_remote: int = 0, readable: bool = True, writable: bool = True, ) -> None: self.is_blocked = False self.max_stream_data_local = max_stream_data_local self.max_stream_data_local_sent = max_stream_data_local self.max_stream_data_remote = max_stream_data_remote self.receiver = QuicStreamReceiver(stream_id=stream_id, readable=readable) self.sender = QuicStreamSender(stream_id=stream_id, writable=writable) self.stream_id = stream_id @property def is_finished(self) -> bool: return self.receiver.is_finished and self.sender.is_finished ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1720306884.0 aioquic-1.2.0/src/aioquic/tls.py0000644000175100001770000022517200000000000017360 0ustar00runnerdocker00000000000000import datetime import ipaddress import logging import os import ssl import struct from contextlib import contextmanager from dataclasses import dataclass, field from enum import Enum, IntEnum from functools import partial from typing import ( Any, Callable, Dict, Generator, List, Optional, Sequence, Tuple, TypeVar, Union, cast, ) import certifi import service_identity from cryptography import x509 from cryptography.exceptions import InvalidSignature from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import hashes, hmac, serialization from cryptography.hazmat.primitives.asymmetric import ( dsa, ec, ed448, ed25519, padding, rsa, x448, x25519, ) from cryptography.hazmat.primitives.asymmetric.types import ( CertificateIssuerPublicKeyTypes, PrivateKeyTypes, ) from cryptography.hazmat.primitives.kdf.hkdf import HKDFExpand from cryptography.hazmat.primitives.serialization import Encoding, PublicFormat from OpenSSL import crypto from .buffer import Buffer, BufferReadError TLS_VERSION_1_2 = 0x0303 TLS_VERSION_1_3 = 0x0304 TLS_VERSION_1_3_DRAFT_28 = 0x7F1C TLS_VERSION_1_3_DRAFT_27 = 0x7F1B TLS_VERSION_1_3_DRAFT_26 = 0x7F1A CLIENT_CONTEXT_STRING = b"TLS 1.3, client CertificateVerify" SERVER_CONTEXT_STRING = b"TLS 1.3, server CertificateVerify" T = TypeVar("T") # facilitate mocking for the test suite def utcnow() -> datetime.datetime: return datetime.datetime.now(datetime.timezone.utc) class AlertDescription(IntEnum): close_notify = 0 unexpected_message = 10 bad_record_mac = 20 record_overflow = 22 handshake_failure = 40 bad_certificate = 42 unsupported_certificate = 43 certificate_revoked = 44 certificate_expired = 45 certificate_unknown = 46 illegal_parameter = 47 unknown_ca = 48 access_denied = 49 decode_error = 50 decrypt_error = 51 protocol_version = 70 insufficient_security = 71 internal_error = 80 inappropriate_fallback = 86 user_canceled = 90 missing_extension = 109 unsupported_extension = 110 unrecognized_name = 112 bad_certificate_status_response = 113 unknown_psk_identity = 115 certificate_required = 116 no_application_protocol = 120 class Alert(Exception): description: AlertDescription class AlertBadCertificate(Alert): description = AlertDescription.bad_certificate class 
AlertCertificateExpired(Alert): description = AlertDescription.certificate_expired class AlertDecodeError(Alert): description = AlertDescription.decode_error class AlertDecryptError(Alert): description = AlertDescription.decrypt_error class AlertHandshakeFailure(Alert): description = AlertDescription.handshake_failure class AlertIllegalParameter(Alert): description = AlertDescription.illegal_parameter class AlertInternalError(Alert): description = AlertDescription.internal_error class AlertProtocolVersion(Alert): description = AlertDescription.protocol_version class AlertUnexpectedMessage(Alert): description = AlertDescription.unexpected_message class Direction(Enum): DECRYPT = 0 ENCRYPT = 1 class Epoch(Enum): INITIAL = 0 ZERO_RTT = 1 HANDSHAKE = 2 ONE_RTT = 3 class State(Enum): CLIENT_HANDSHAKE_START = 0 CLIENT_EXPECT_SERVER_HELLO = 1 CLIENT_EXPECT_ENCRYPTED_EXTENSIONS = 2 CLIENT_EXPECT_CERTIFICATE_REQUEST_OR_CERTIFICATE = 3 CLIENT_EXPECT_CERTIFICATE = 4 CLIENT_EXPECT_CERTIFICATE_VERIFY = 5 CLIENT_EXPECT_FINISHED = 6 CLIENT_POST_HANDSHAKE = 7 SERVER_EXPECT_CLIENT_HELLO = 8 SERVER_EXPECT_CERTIFICATE = 9 SERVER_EXPECT_CERTIFICATE_VERIFY = 10 SERVER_EXPECT_FINISHED = 11 SERVER_POST_HANDSHAKE = 12 def hkdf_label(label: bytes, hash_value: bytes, length: int) -> bytes: full_label = b"tls13 " + label return ( struct.pack("!HB", length, len(full_label)) + full_label + struct.pack("!B", len(hash_value)) + hash_value ) def hkdf_expand_label( algorithm: hashes.HashAlgorithm, secret: bytes, label: bytes, hash_value: bytes, length: int, ) -> bytes: return HKDFExpand( algorithm=algorithm, length=length, info=hkdf_label(label, hash_value, length), ).derive(secret) def hkdf_extract( algorithm: hashes.HashAlgorithm, salt: bytes, key_material: bytes ) -> bytes: h = hmac.HMAC(salt, algorithm) h.update(key_material) return h.finalize() def load_pem_private_key( data: bytes, password: Optional[bytes] = None ) -> PrivateKeyTypes: """ Load a PEM-encoded private key. """ return serialization.load_pem_private_key(data, password=password) def load_pem_x509_certificates(data: bytes) -> List[x509.Certificate]: """ Load a chain of PEM-encoded X509 certificates. 
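    The input is split on the standard end-of-certificate marker, so a
    typical chain file made by concatenating PEM certificates (leaf first)
    yields one x509.Certificate per PEM block, in file order.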
""" boundary = b"-----END CERTIFICATE-----\n" certificates = [] for chunk in data.split(boundary): if chunk: certificates.append(x509.load_pem_x509_certificate(chunk + boundary)) return certificates def verify_certificate( certificate: x509.Certificate, chain: List[x509.Certificate] = [], server_name: Optional[str] = None, cadata: Optional[bytes] = None, cafile: Optional[str] = None, capath: Optional[str] = None, ) -> None: # verify dates now = utcnow() if now < certificate.not_valid_before_utc: raise AlertCertificateExpired("Certificate is not valid yet") if now > certificate.not_valid_after_utc: raise AlertCertificateExpired("Certificate is no longer valid") # verify subject if server_name is not None: try: ipaddress.ip_address(server_name) except ValueError: is_ip = False else: is_ip = True try: if is_ip: service_identity.cryptography.verify_certificate_ip_address( certificate, server_name ) else: service_identity.cryptography.verify_certificate_hostname( certificate, server_name ) except ( service_identity.CertificateError, service_identity.VerificationError, ) as exc: patterns = service_identity.cryptography.extract_patterns(certificate) if len(patterns) == 0: errmsg = str(exc) elif len(patterns) == 1: errmsg = f"hostname {server_name!r} doesn't match {patterns[0]!r}" else: patterns_repr = ", ".join(repr(pattern) for pattern in patterns) errmsg = ( f"hostname {server_name!r} doesn't match " f"either of {patterns_repr}" ) raise AlertBadCertificate(errmsg) from exc # load CAs store = crypto.X509Store() if cadata is None and cafile is None and capath is None: # Load defaults from certifi. store.load_locations(certifi.where()) if cadata is not None: for cert in load_pem_x509_certificates(cadata): store.add_cert(crypto.X509.from_cryptography(cert)) if cafile is not None or capath is not None: store.load_locations(cafile, capath) # verify certificate chain store_ctx = crypto.X509StoreContext( store, crypto.X509.from_cryptography(certificate), [crypto.X509.from_cryptography(cert) for cert in chain], ) try: store_ctx.verify_certificate() except crypto.X509StoreContextError as exc: raise AlertBadCertificate(exc.args[0]) class CipherSuite(IntEnum): AES_128_GCM_SHA256 = 0x1301 AES_256_GCM_SHA384 = 0x1302 CHACHA20_POLY1305_SHA256 = 0x1303 EMPTY_RENEGOTIATION_INFO_SCSV = 0x00FF class CompressionMethod(IntEnum): NULL = 0 class ExtensionType(IntEnum): SERVER_NAME = 0 STATUS_REQUEST = 5 SUPPORTED_GROUPS = 10 SIGNATURE_ALGORITHMS = 13 ALPN = 16 COMPRESS_CERTIFICATE = 27 PRE_SHARED_KEY = 41 EARLY_DATA = 42 SUPPORTED_VERSIONS = 43 COOKIE = 44 PSK_KEY_EXCHANGE_MODES = 45 KEY_SHARE = 51 QUIC_TRANSPORT_PARAMETERS = 0x0039 QUIC_TRANSPORT_PARAMETERS_DRAFT = 0xFFA5 ENCRYPTED_SERVER_NAME = 65486 class Group(IntEnum): SECP256R1 = 0x0017 SECP384R1 = 0x0018 SECP521R1 = 0x0019 X25519 = 0x001D X448 = 0x001E GREASE = 0xAAAA class HandshakeType(IntEnum): CLIENT_HELLO = 1 SERVER_HELLO = 2 NEW_SESSION_TICKET = 4 END_OF_EARLY_DATA = 5 ENCRYPTED_EXTENSIONS = 8 CERTIFICATE = 11 CERTIFICATE_REQUEST = 13 CERTIFICATE_VERIFY = 15 FINISHED = 20 KEY_UPDATE = 24 COMPRESSED_CERTIFICATE = 25 MESSAGE_HASH = 254 class NameType(IntEnum): HOST_NAME = 0 class PskKeyExchangeMode(IntEnum): PSK_KE = 0 PSK_DHE_KE = 1 class SignatureAlgorithm(IntEnum): ECDSA_SECP256R1_SHA256 = 0x0403 ECDSA_SECP384R1_SHA384 = 0x0503 ECDSA_SECP521R1_SHA512 = 0x0603 ED25519 = 0x0807 ED448 = 0x0808 RSA_PKCS1_SHA256 = 0x0401 RSA_PKCS1_SHA384 = 0x0501 RSA_PKCS1_SHA512 = 0x0601 RSA_PSS_PSS_SHA256 = 0x0809 RSA_PSS_PSS_SHA384 = 0x080A RSA_PSS_PSS_SHA512 = 0x080B 
    RSA_PSS_RSAE_SHA256 = 0x0804
    RSA_PSS_RSAE_SHA384 = 0x0805
    RSA_PSS_RSAE_SHA512 = 0x0806

    # legacy
    RSA_PKCS1_SHA1 = 0x0201
    SHA1_DSA = 0x0202
    ECDSA_SHA1 = 0x0203


# BLOCKS


@contextmanager
def pull_block(buf: Buffer, capacity: int) -> Generator:
    length = int.from_bytes(buf.pull_bytes(capacity), byteorder="big")
    end = buf.tell() + length
    yield length
    if buf.tell() != end:
        # There was trailing garbage or our parsing was bad.
        raise AlertDecodeError("extra bytes at the end of a block")


@contextmanager
def push_block(buf: Buffer, capacity: int) -> Generator:
    """
    Context manager to push a variable-length block, with `capacity` bytes
    to write the length.
    """
    start = buf.tell() + capacity
    buf.seek(start)
    yield
    end = buf.tell()
    length = end - start
    buf.seek(start - capacity)
    buf.push_bytes(length.to_bytes(capacity, byteorder="big"))
    buf.seek(end)


# LISTS


class SkipItem(Exception):
    "There is nothing to append for this invocation of a pull_list() func"


def pull_list(buf: Buffer, capacity: int, func: Callable[[], T]) -> List[T]:
    """
    Pull a list of items.

    If the callable raises SkipItem, then iteration continues but
    nothing is added to the list.
    """
    items = []
    with pull_block(buf, capacity) as length:
        end = buf.tell() + length
        while buf.tell() < end:
            try:
                items.append(func())
            except SkipItem:
                pass
    return items


def push_list(
    buf: Buffer, capacity: int, func: Callable[[T], None], values: Sequence[T]
) -> None:
    """
    Push a list of items.
    """
    with push_block(buf, capacity):
        for value in values:
            func(value)


def pull_opaque(buf: Buffer, capacity: int) -> bytes:
    """
    Pull an opaque value prefixed by a length.
    """
    with pull_block(buf, capacity) as length:
        return buf.pull_bytes(length)


def push_opaque(buf: Buffer, capacity: int, value: bytes) -> None:
    """
    Push an opaque value prefixed by a length.
    """
    with push_block(buf, capacity):
        buf.push_bytes(value)


@contextmanager
def push_extension(buf: Buffer, extension_type: int) -> Generator:
    buf.push_uint16(extension_type)
    with push_block(buf, 2):
        yield


# ServerName


def pull_server_name(buf: Buffer) -> str:
    with pull_block(buf, 2):
        name_type = buf.pull_uint8()
        if name_type != NameType.HOST_NAME:
            # We don't know this name_type.
            raise AlertIllegalParameter(
                f"ServerName has an unknown name type {name_type}"
            )
        return pull_opaque(buf, 2).decode("ascii")


def push_server_name(buf: Buffer, server_name: str) -> None:
    with push_block(buf, 2):
        buf.push_uint8(NameType.HOST_NAME)
        push_opaque(buf, 2, server_name.encode("ascii"))


# KeyShareEntry


KeyShareEntry = Tuple[int, bytes]


def pull_key_share(buf: Buffer) -> KeyShareEntry:
    group = buf.pull_uint16()
    data = pull_opaque(buf, 2)
    return (group, data)


def push_key_share(buf: Buffer, value: KeyShareEntry) -> None:
    buf.push_uint16(value[0])
    push_opaque(buf, 2, value[1])


# ALPN


def pull_alpn_protocol(buf: Buffer) -> str:
    try:
        return pull_opaque(buf, 1).decode("ascii")
    except UnicodeDecodeError:
        # We can get arbitrary byte values for ALPNs from greasing,
        # but we expect them to be strings in the rest of the API, so
        # we ignore them if they don't decode as ASCII.
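        # (GREASE, RFC 8701, is the usual source of such values: peers
        # advertise reserved code points to ensure unknown entries are
        # tolerated, so skipping them here is the intended behaviour.)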
raise SkipItem def push_alpn_protocol(buf: Buffer, protocol: str) -> None: push_opaque(buf, 1, protocol.encode("ascii")) # PRE SHARED KEY PskIdentity = Tuple[bytes, int] @dataclass class OfferedPsks: identities: List[PskIdentity] binders: List[bytes] def pull_psk_identity(buf: Buffer) -> PskIdentity: identity = pull_opaque(buf, 2) obfuscated_ticket_age = buf.pull_uint32() return (identity, obfuscated_ticket_age) def push_psk_identity(buf: Buffer, entry: PskIdentity) -> None: push_opaque(buf, 2, entry[0]) buf.push_uint32(entry[1]) def pull_psk_binder(buf: Buffer) -> bytes: return pull_opaque(buf, 1) def push_psk_binder(buf: Buffer, binder: bytes) -> None: push_opaque(buf, 1, binder) def pull_offered_psks(buf: Buffer) -> OfferedPsks: return OfferedPsks( identities=pull_list(buf, 2, partial(pull_psk_identity, buf)), binders=pull_list(buf, 2, partial(pull_psk_binder, buf)), ) def push_offered_psks(buf: Buffer, pre_shared_key: OfferedPsks) -> None: push_list( buf, 2, partial(push_psk_identity, buf), pre_shared_key.identities, ) push_list( buf, 2, partial(push_psk_binder, buf), pre_shared_key.binders, ) # MESSAGES Extension = Tuple[int, bytes] @dataclass class ClientHello: random: bytes legacy_session_id: bytes cipher_suites: List[int] legacy_compression_methods: List[int] # extensions alpn_protocols: Optional[List[str]] = None early_data: bool = False key_share: Optional[List[KeyShareEntry]] = None pre_shared_key: Optional[OfferedPsks] = None psk_key_exchange_modes: Optional[List[int]] = None server_name: Optional[str] = None signature_algorithms: Optional[List[int]] = None supported_groups: Optional[List[int]] = None supported_versions: Optional[List[int]] = None other_extensions: List[Extension] = field(default_factory=list) def pull_handshake_type(buf: Buffer, expected_type: HandshakeType) -> None: """ Pull the message type and assert it is the expected one. If it is not, we have a programming error. """ message_type = buf.pull_uint8() assert message_type == expected_type def pull_client_hello(buf: Buffer) -> ClientHello: pull_handshake_type(buf, HandshakeType.CLIENT_HELLO) with pull_block(buf, 3): if buf.pull_uint16() != TLS_VERSION_1_2: raise AlertDecodeError("ClientHello version is not 1.2") hello = ClientHello( random=buf.pull_bytes(32), legacy_session_id=pull_opaque(buf, 1), cipher_suites=pull_list(buf, 2, buf.pull_uint16), legacy_compression_methods=pull_list(buf, 1, buf.pull_uint8), ) # extensions after_psk = False def pull_extension() -> None: # pre_shared_key MUST be last nonlocal after_psk if after_psk: # the alert is Illegal Parameter per RFC 8446 section 4.2.11. 
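            # (It must come last because each PSK binder is an HMAC over a
            # transcript hash of the ClientHello up to, but not including,
            # the binders themselves; see RFC 8446 section 4.2.11.2.)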
raise AlertIllegalParameter("PreSharedKey is not the last extension") extension_type = buf.pull_uint16() extension_length = buf.pull_uint16() if extension_type == ExtensionType.KEY_SHARE: hello.key_share = pull_list(buf, 2, partial(pull_key_share, buf)) elif extension_type == ExtensionType.SUPPORTED_VERSIONS: hello.supported_versions = pull_list(buf, 1, buf.pull_uint16) elif extension_type == ExtensionType.SIGNATURE_ALGORITHMS: hello.signature_algorithms = pull_list(buf, 2, buf.pull_uint16) elif extension_type == ExtensionType.SUPPORTED_GROUPS: hello.supported_groups = pull_list(buf, 2, buf.pull_uint16) elif extension_type == ExtensionType.PSK_KEY_EXCHANGE_MODES: hello.psk_key_exchange_modes = pull_list(buf, 1, buf.pull_uint8) elif extension_type == ExtensionType.SERVER_NAME: hello.server_name = pull_server_name(buf) elif extension_type == ExtensionType.ALPN: hello.alpn_protocols = pull_list( buf, 2, partial(pull_alpn_protocol, buf) ) elif extension_type == ExtensionType.EARLY_DATA: hello.early_data = True elif extension_type == ExtensionType.PRE_SHARED_KEY: hello.pre_shared_key = pull_offered_psks(buf) after_psk = True else: hello.other_extensions.append( (extension_type, buf.pull_bytes(extension_length)) ) pull_list(buf, 2, pull_extension) return hello def push_client_hello(buf: Buffer, hello: ClientHello) -> None: buf.push_uint8(HandshakeType.CLIENT_HELLO) with push_block(buf, 3): buf.push_uint16(TLS_VERSION_1_2) buf.push_bytes(hello.random) push_opaque(buf, 1, hello.legacy_session_id) push_list(buf, 2, buf.push_uint16, hello.cipher_suites) push_list(buf, 1, buf.push_uint8, hello.legacy_compression_methods) # extensions with push_block(buf, 2): with push_extension(buf, ExtensionType.KEY_SHARE): push_list(buf, 2, partial(push_key_share, buf), hello.key_share) with push_extension(buf, ExtensionType.SUPPORTED_VERSIONS): push_list(buf, 1, buf.push_uint16, hello.supported_versions) with push_extension(buf, ExtensionType.SIGNATURE_ALGORITHMS): push_list(buf, 2, buf.push_uint16, hello.signature_algorithms) with push_extension(buf, ExtensionType.SUPPORTED_GROUPS): push_list(buf, 2, buf.push_uint16, hello.supported_groups) if hello.psk_key_exchange_modes is not None: with push_extension(buf, ExtensionType.PSK_KEY_EXCHANGE_MODES): push_list(buf, 1, buf.push_uint8, hello.psk_key_exchange_modes) if hello.server_name is not None: with push_extension(buf, ExtensionType.SERVER_NAME): push_server_name(buf, hello.server_name) if hello.alpn_protocols is not None: with push_extension(buf, ExtensionType.ALPN): push_list( buf, 2, partial(push_alpn_protocol, buf), hello.alpn_protocols ) for extension_type, extension_value in hello.other_extensions: with push_extension(buf, extension_type): buf.push_bytes(extension_value) if hello.early_data: with push_extension(buf, ExtensionType.EARLY_DATA): pass # pre_shared_key MUST be last if hello.pre_shared_key is not None: with push_extension(buf, ExtensionType.PRE_SHARED_KEY): push_offered_psks(buf, hello.pre_shared_key) @dataclass class ServerHello: random: bytes legacy_session_id: bytes cipher_suite: int compression_method: int # extensions key_share: Optional[KeyShareEntry] = None pre_shared_key: Optional[int] = None supported_version: Optional[int] = None other_extensions: List[Tuple[int, bytes]] = field(default_factory=list) def pull_server_hello(buf: Buffer) -> ServerHello: pull_handshake_type(buf, HandshakeType.SERVER_HELLO) with pull_block(buf, 3): if buf.pull_uint16() != TLS_VERSION_1_2: raise AlertDecodeError("ServerHello version is not 1.2") hello = 
ServerHello( random=buf.pull_bytes(32), legacy_session_id=pull_opaque(buf, 1), cipher_suite=buf.pull_uint16(), compression_method=buf.pull_uint8(), ) # extensions def pull_extension() -> None: extension_type = buf.pull_uint16() extension_length = buf.pull_uint16() if extension_type == ExtensionType.SUPPORTED_VERSIONS: hello.supported_version = buf.pull_uint16() elif extension_type == ExtensionType.KEY_SHARE: hello.key_share = pull_key_share(buf) elif extension_type == ExtensionType.PRE_SHARED_KEY: hello.pre_shared_key = buf.pull_uint16() else: hello.other_extensions.append( (extension_type, buf.pull_bytes(extension_length)) ) pull_list(buf, 2, pull_extension) return hello def push_server_hello(buf: Buffer, hello: ServerHello) -> None: buf.push_uint8(HandshakeType.SERVER_HELLO) with push_block(buf, 3): buf.push_uint16(TLS_VERSION_1_2) buf.push_bytes(hello.random) push_opaque(buf, 1, hello.legacy_session_id) buf.push_uint16(hello.cipher_suite) buf.push_uint8(hello.compression_method) # extensions with push_block(buf, 2): if hello.supported_version is not None: with push_extension(buf, ExtensionType.SUPPORTED_VERSIONS): buf.push_uint16(hello.supported_version) if hello.key_share is not None: with push_extension(buf, ExtensionType.KEY_SHARE): push_key_share(buf, hello.key_share) if hello.pre_shared_key is not None: with push_extension(buf, ExtensionType.PRE_SHARED_KEY): buf.push_uint16(hello.pre_shared_key) for extension_type, extension_value in hello.other_extensions: with push_extension(buf, extension_type): buf.push_bytes(extension_value) @dataclass class NewSessionTicket: ticket_lifetime: int = 0 ticket_age_add: int = 0 ticket_nonce: bytes = b"" ticket: bytes = b"" # extensions max_early_data_size: Optional[int] = None other_extensions: List[Tuple[int, bytes]] = field(default_factory=list) def pull_new_session_ticket(buf: Buffer) -> NewSessionTicket: new_session_ticket = NewSessionTicket() pull_handshake_type(buf, HandshakeType.NEW_SESSION_TICKET) with pull_block(buf, 3): new_session_ticket.ticket_lifetime = buf.pull_uint32() new_session_ticket.ticket_age_add = buf.pull_uint32() new_session_ticket.ticket_nonce = pull_opaque(buf, 1) new_session_ticket.ticket = pull_opaque(buf, 2) def pull_extension() -> None: extension_type = buf.pull_uint16() extension_length = buf.pull_uint16() if extension_type == ExtensionType.EARLY_DATA: new_session_ticket.max_early_data_size = buf.pull_uint32() else: new_session_ticket.other_extensions.append( (extension_type, buf.pull_bytes(extension_length)) ) pull_list(buf, 2, pull_extension) return new_session_ticket def push_new_session_ticket(buf: Buffer, new_session_ticket: NewSessionTicket) -> None: buf.push_uint8(HandshakeType.NEW_SESSION_TICKET) with push_block(buf, 3): buf.push_uint32(new_session_ticket.ticket_lifetime) buf.push_uint32(new_session_ticket.ticket_age_add) push_opaque(buf, 1, new_session_ticket.ticket_nonce) push_opaque(buf, 2, new_session_ticket.ticket) with push_block(buf, 2): if new_session_ticket.max_early_data_size is not None: with push_extension(buf, ExtensionType.EARLY_DATA): buf.push_uint32(new_session_ticket.max_early_data_size) for extension_type, extension_value in new_session_ticket.other_extensions: with push_extension(buf, extension_type): buf.push_bytes(extension_value) @dataclass class EncryptedExtensions: alpn_protocol: Optional[str] = None early_data: bool = False other_extensions: List[Tuple[int, bytes]] = field(default_factory=list) def pull_encrypted_extensions(buf: Buffer) -> EncryptedExtensions: extensions = 
EncryptedExtensions() pull_handshake_type(buf, HandshakeType.ENCRYPTED_EXTENSIONS) with pull_block(buf, 3): def pull_extension() -> None: extension_type = buf.pull_uint16() extension_length = buf.pull_uint16() if extension_type == ExtensionType.ALPN: extensions.alpn_protocol = pull_list( buf, 2, partial(pull_alpn_protocol, buf) )[0] elif extension_type == ExtensionType.EARLY_DATA: extensions.early_data = True else: extensions.other_extensions.append( (extension_type, buf.pull_bytes(extension_length)) ) pull_list(buf, 2, pull_extension) return extensions def push_encrypted_extensions(buf: Buffer, extensions: EncryptedExtensions) -> None: buf.push_uint8(HandshakeType.ENCRYPTED_EXTENSIONS) with push_block(buf, 3): with push_block(buf, 2): if extensions.alpn_protocol is not None: with push_extension(buf, ExtensionType.ALPN): push_list( buf, 2, partial(push_alpn_protocol, buf), [extensions.alpn_protocol], ) if extensions.early_data: with push_extension(buf, ExtensionType.EARLY_DATA): pass for extension_type, extension_value in extensions.other_extensions: with push_extension(buf, extension_type): buf.push_bytes(extension_value) CertificateEntry = Tuple[bytes, bytes] @dataclass class Certificate: request_context: bytes = b"" certificates: List[CertificateEntry] = field(default_factory=list) def pull_certificate(buf: Buffer) -> Certificate: certificate = Certificate() pull_handshake_type(buf, HandshakeType.CERTIFICATE) with pull_block(buf, 3): certificate.request_context = pull_opaque(buf, 1) def pull_certificate_entry(buf: Buffer) -> CertificateEntry: data = pull_opaque(buf, 3) extensions = pull_opaque(buf, 2) return (data, extensions) certificate.certificates = pull_list( buf, 3, partial(pull_certificate_entry, buf) ) return certificate def push_certificate(buf: Buffer, certificate: Certificate) -> None: buf.push_uint8(HandshakeType.CERTIFICATE) with push_block(buf, 3): push_opaque(buf, 1, certificate.request_context) def push_certificate_entry(buf: Buffer, entry: CertificateEntry) -> None: push_opaque(buf, 3, entry[0]) push_opaque(buf, 2, entry[1]) push_list( buf, 3, partial(push_certificate_entry, buf), certificate.certificates ) @dataclass class CertificateRequest: request_context: bytes = b"" signature_algorithms: Optional[List[int]] = None other_extensions: List[Tuple[int, bytes]] = field(default_factory=list) def pull_certificate_request(buf: Buffer) -> CertificateRequest: certificate_request = CertificateRequest() pull_handshake_type(buf, HandshakeType.CERTIFICATE_REQUEST) with pull_block(buf, 3): certificate_request.request_context = pull_opaque(buf, 1) def pull_extension() -> None: extension_type = buf.pull_uint16() extension_length = buf.pull_uint16() if extension_type == ExtensionType.SIGNATURE_ALGORITHMS: certificate_request.signature_algorithms = pull_list( buf, 2, buf.pull_uint16 ) else: certificate_request.other_extensions.append( (extension_type, buf.pull_bytes(extension_length)) ) pull_list(buf, 2, pull_extension) return certificate_request def push_certificate_request( buf: Buffer, certificate_request: CertificateRequest ) -> None: buf.push_uint8(HandshakeType.CERTIFICATE_REQUEST) with push_block(buf, 3): push_opaque(buf, 1, certificate_request.request_context) with push_block(buf, 2): with push_extension(buf, ExtensionType.SIGNATURE_ALGORITHMS): push_list( buf, 2, buf.push_uint16, certificate_request.signature_algorithms ) for extension_type, extension_value in certificate_request.other_extensions: with push_extension(buf, extension_type): buf.push_bytes(extension_value) 
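
# A minimal round-trip sketch of the codecs above (illustrative only, not
# executed by the library): pushing a message and pulling it back from the
# resulting bytes yields an equal object.
#
#     buf = Buffer(capacity=1024)
#     push_certificate_request(
#         buf,
#         CertificateRequest(
#             signature_algorithms=[SignatureAlgorithm.RSA_PKCS1_SHA256]
#         ),
#     )
#     parsed = pull_certificate_request(Buffer(data=buf.data))
#     assert parsed.signature_algorithms == [SignatureAlgorithm.RSA_PKCS1_SHA256]
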
@dataclass class CertificateVerify: algorithm: int signature: bytes def pull_certificate_verify(buf: Buffer) -> CertificateVerify: pull_handshake_type(buf, HandshakeType.CERTIFICATE_VERIFY) with pull_block(buf, 3): algorithm = buf.pull_uint16() signature = pull_opaque(buf, 2) return CertificateVerify(algorithm=algorithm, signature=signature) def push_certificate_verify(buf: Buffer, verify: CertificateVerify) -> None: buf.push_uint8(HandshakeType.CERTIFICATE_VERIFY) with push_block(buf, 3): buf.push_uint16(verify.algorithm) push_opaque(buf, 2, verify.signature) @dataclass class Finished: verify_data: bytes = b"" def pull_finished(buf: Buffer) -> Finished: finished = Finished() pull_handshake_type(buf, HandshakeType.FINISHED) finished.verify_data = pull_opaque(buf, 3) return finished def push_finished(buf: Buffer, finished: Finished) -> None: buf.push_uint8(HandshakeType.FINISHED) push_opaque(buf, 3, finished.verify_data) # CONTEXT class KeySchedule: def __init__(self, cipher_suite: CipherSuite): self.algorithm = cipher_suite_hash(cipher_suite) self.cipher_suite = cipher_suite self.generation = 0 self.hash = hashes.Hash(self.algorithm) self.hash_empty_value = self.hash.copy().finalize() self.secret = bytes(self.algorithm.digest_size) def certificate_verify_data(self, context_string: bytes) -> bytes: return b" " * 64 + context_string + b"\x00" + self.hash.copy().finalize() def finished_verify_data(self, secret: bytes) -> bytes: hmac_key = hkdf_expand_label( algorithm=self.algorithm, secret=secret, label=b"finished", hash_value=b"", length=self.algorithm.digest_size, ) h = hmac.HMAC(hmac_key, algorithm=self.algorithm) h.update(self.hash.copy().finalize()) return h.finalize() def derive_secret(self, label: bytes) -> bytes: return hkdf_expand_label( algorithm=self.algorithm, secret=self.secret, label=label, hash_value=self.hash.copy().finalize(), length=self.algorithm.digest_size, ) def extract(self, key_material: Optional[bytes] = None) -> None: if key_material is None: key_material = bytes(self.algorithm.digest_size) if self.generation: self.secret = hkdf_expand_label( algorithm=self.algorithm, secret=self.secret, label=b"derived", hash_value=self.hash_empty_value, length=self.algorithm.digest_size, ) self.generation += 1 self.secret = hkdf_extract( algorithm=self.algorithm, salt=self.secret, key_material=key_material ) def update_hash(self, data: bytes) -> None: self.hash.update(data) class KeyScheduleProxy: def __init__(self, cipher_suites: List[CipherSuite]): self.__schedules = dict(map(lambda c: (c, KeySchedule(c)), cipher_suites)) def extract(self, key_material: Optional[bytes] = None) -> None: for k in self.__schedules.values(): k.extract(key_material) def select(self, cipher_suite: CipherSuite) -> KeySchedule: return self.__schedules[cipher_suite] def update_hash(self, data: bytes) -> None: for k in self.__schedules.values(): k.update_hash(data) CIPHER_SUITES: Dict = { CipherSuite.AES_128_GCM_SHA256: hashes.SHA256, CipherSuite.AES_256_GCM_SHA384: hashes.SHA384, CipherSuite.CHACHA20_POLY1305_SHA256: hashes.SHA256, } SIGNATURE_ALGORITHMS: Dict = { SignatureAlgorithm.ECDSA_SECP256R1_SHA256: (None, hashes.SHA256), SignatureAlgorithm.ECDSA_SECP384R1_SHA384: (None, hashes.SHA384), SignatureAlgorithm.ECDSA_SECP521R1_SHA512: (None, hashes.SHA512), SignatureAlgorithm.RSA_PKCS1_SHA1: (padding.PKCS1v15, hashes.SHA1), SignatureAlgorithm.RSA_PKCS1_SHA256: (padding.PKCS1v15, hashes.SHA256), SignatureAlgorithm.RSA_PKCS1_SHA384: (padding.PKCS1v15, hashes.SHA384), SignatureAlgorithm.RSA_PKCS1_SHA512: 
(padding.PKCS1v15, hashes.SHA512), SignatureAlgorithm.RSA_PSS_RSAE_SHA256: (padding.PSS, hashes.SHA256), SignatureAlgorithm.RSA_PSS_RSAE_SHA384: (padding.PSS, hashes.SHA384), SignatureAlgorithm.RSA_PSS_RSAE_SHA512: (padding.PSS, hashes.SHA512), } GROUP_TO_CURVE: Dict = { Group.SECP256R1: ec.SECP256R1, Group.SECP384R1: ec.SECP384R1, Group.SECP521R1: ec.SECP521R1, } CURVE_TO_GROUP = dict((v, k) for k, v in GROUP_TO_CURVE.items()) def cipher_suite_hash(cipher_suite: CipherSuite) -> hashes.HashAlgorithm: return CIPHER_SUITES[cipher_suite]() def decode_public_key( key_share: KeyShareEntry, ) -> Union[ec.EllipticCurvePublicKey, x25519.X25519PublicKey, x448.X448PublicKey, None]: if key_share[0] == Group.X25519: return x25519.X25519PublicKey.from_public_bytes(key_share[1]) elif key_share[0] == Group.X448: return x448.X448PublicKey.from_public_bytes(key_share[1]) elif key_share[0] in GROUP_TO_CURVE: return ec.EllipticCurvePublicKey.from_encoded_point( GROUP_TO_CURVE[key_share[0]](), key_share[1] ) else: return None def encode_public_key( public_key: Union[ ec.EllipticCurvePublicKey, x25519.X25519PublicKey, x448.X448PublicKey ], ) -> KeyShareEntry: if isinstance(public_key, x25519.X25519PublicKey): return (Group.X25519, public_key.public_bytes(Encoding.Raw, PublicFormat.Raw)) elif isinstance(public_key, x448.X448PublicKey): return (Group.X448, public_key.public_bytes(Encoding.Raw, PublicFormat.Raw)) return ( CURVE_TO_GROUP[public_key.curve.__class__], public_key.public_bytes(Encoding.X962, PublicFormat.UncompressedPoint), ) def negotiate( supported: List[T], offered: Optional[List[Any]], exc: Optional[Alert] = None ) -> T: if offered is not None: for c in supported: if c in offered: return c if exc is not None: raise exc return None def signature_algorithm_params(signature_algorithm: int) -> Tuple: if signature_algorithm in (SignatureAlgorithm.ED25519, SignatureAlgorithm.ED448): return tuple() padding_cls, algorithm_cls = SIGNATURE_ALGORITHMS[signature_algorithm] algorithm = algorithm_cls() if padding_cls is None: return (ec.ECDSA(algorithm),) elif padding_cls == padding.PSS: padding_obj = padding_cls( mgf=padding.MGF1(algorithm), salt_length=algorithm.digest_size ) else: padding_obj = padding_cls() return padding_obj, algorithm @contextmanager def push_message( key_schedule: Union[KeySchedule, KeyScheduleProxy], buf: Buffer ) -> Generator: hash_start = buf.tell() yield key_schedule.update_hash(buf.data_slice(hash_start, buf.tell())) # callback types @dataclass class SessionTicket: """ A TLS session ticket for session resumption. 
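    The obfuscated_age property implements the obfuscated_ticket_age of
    RFC 8446 section 4.2.11: the ticket's age in milliseconds plus the
    server-provided age_add, modulo 2**32.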
""" age_add: int cipher_suite: CipherSuite not_valid_after: datetime.datetime not_valid_before: datetime.datetime resumption_secret: bytes server_name: str ticket: bytes max_early_data_size: Optional[int] = None other_extensions: List[Tuple[int, bytes]] = field(default_factory=list) @property def is_valid(self) -> bool: now = utcnow() return now >= self.not_valid_before and now <= self.not_valid_after @property def obfuscated_age(self) -> int: age = int((utcnow() - self.not_valid_before).total_seconds() * 1000) return (age + self.age_add) % (1 << 32) AlpnHandler = Callable[[str], None] SessionTicketFetcher = Callable[[bytes], Optional[SessionTicket]] SessionTicketHandler = Callable[[SessionTicket], None] class Context: def __init__( self, is_client: bool, alpn_protocols: Optional[List[str]] = None, cadata: Optional[bytes] = None, cafile: Optional[str] = None, capath: Optional[str] = None, cipher_suites: Optional[List[CipherSuite]] = None, logger: Optional[Union[logging.Logger, logging.LoggerAdapter]] = None, max_early_data: Optional[int] = None, server_name: Optional[str] = None, verify_mode: Optional[int] = None, ): # configuration self._alpn_protocols = alpn_protocols self._cadata = cadata self._cafile = cafile self._capath = capath self.certificate: Optional[x509.Certificate] = None self.certificate_chain: List[x509.Certificate] = [] self.certificate_private_key: Optional[ Union[dsa.DSAPrivateKey, ec.EllipticCurvePrivateKey, rsa.RSAPrivateKey] ] = None self.handshake_extensions: List[Extension] = [] self._is_client = is_client self._max_early_data = max_early_data self.session_ticket: Optional[SessionTicket] = None self._request_client_certificate = False # For test purposes only self._server_name = server_name if verify_mode is not None: self._verify_mode = verify_mode else: self._verify_mode = ssl.CERT_REQUIRED if is_client else ssl.CERT_NONE # callbacks self.alpn_cb: Optional[AlpnHandler] = None self.get_session_ticket_cb: Optional[SessionTicketFetcher] = None self.new_session_ticket_cb: Optional[SessionTicketHandler] = None self.update_traffic_key_cb: Callable[ [Direction, Epoch, CipherSuite, bytes], None ] = lambda d, e, c, s: None # supported parameters if cipher_suites is not None: self._cipher_suites = cipher_suites else: self._cipher_suites = [ CipherSuite.AES_256_GCM_SHA384, CipherSuite.AES_128_GCM_SHA256, CipherSuite.CHACHA20_POLY1305_SHA256, ] self._legacy_compression_methods: List[int] = [CompressionMethod.NULL] self._psk_key_exchange_modes: List[int] = [PskKeyExchangeMode.PSK_DHE_KE] self._signature_algorithms: List[int] = [ SignatureAlgorithm.ECDSA_SECP256R1_SHA256, SignatureAlgorithm.RSA_PSS_RSAE_SHA256, SignatureAlgorithm.RSA_PKCS1_SHA256, SignatureAlgorithm.ECDSA_SECP384R1_SHA384, SignatureAlgorithm.RSA_PSS_RSAE_SHA384, SignatureAlgorithm.RSA_PKCS1_SHA384, SignatureAlgorithm.RSA_PKCS1_SHA1, ] if default_backend().ed25519_supported(): self._signature_algorithms.append(SignatureAlgorithm.ED25519) if default_backend().ed448_supported(): self._signature_algorithms.append(SignatureAlgorithm.ED448) self._supported_groups = [Group.SECP256R1, Group.SECP384R1] if default_backend().x25519_supported(): self._supported_groups.append(Group.X25519) if default_backend().x448_supported(): self._supported_groups.append(Group.X448) self._supported_versions = [TLS_VERSION_1_3] # state self.alpn_negotiated: Optional[str] = None self.early_data_accepted = False self.key_schedule: Optional[KeySchedule] = None self.received_extensions: Optional[List[Extension]] = None 
self._certificate_request: Optional[CertificateRequest] = None self._key_schedule_psk: Optional[KeySchedule] = None self._key_schedule_proxy: Optional[KeyScheduleProxy] = None self._new_session_ticket: Optional[NewSessionTicket] = None self._peer_certificate: Optional[x509.Certificate] = None self._peer_certificate_chain: List[x509.Certificate] = [] self._psk_key_exchange_mode: Optional[int] = None self._receive_buffer = b"" self._session_resumed = False self._enc_key: Optional[bytes] = None self._dec_key: Optional[bytes] = None self.__logger = logger self._ec_private_keys: List[ec.EllipticCurvePrivateKey] = [] self._x25519_private_key: Optional[x25519.X25519PrivateKey] = None self._x448_private_key: Optional[x448.X448PrivateKey] = None if is_client: self.client_random = os.urandom(32) self.legacy_session_id = b"" self.state = State.CLIENT_HANDSHAKE_START else: self.client_random = None self.legacy_session_id = None self.state = State.SERVER_EXPECT_CLIENT_HELLO @property def session_resumed(self) -> bool: """ Returns True if session resumption was successfully used. """ return self._session_resumed def handle_message( self, input_data: bytes, output_buf: Dict[Epoch, Buffer] ) -> None: if self.state == State.CLIENT_HANDSHAKE_START: self._client_send_hello(output_buf[Epoch.INITIAL]) return self._receive_buffer += input_data while len(self._receive_buffer) >= 4: # determine message length message_type = self._receive_buffer[0] message_length = 4 + int.from_bytes( self._receive_buffer[1:4], byteorder="big" ) # check message is complete if len(self._receive_buffer) < message_length: break message = self._receive_buffer[:message_length] self._receive_buffer = self._receive_buffer[message_length:] # process the message try: self._handle_reassembled_message( message_type=message_type, input_buf=Buffer(data=message), output_buf=output_buf, ) except BufferReadError: raise AlertDecodeError("Could not parse TLS message") def _handle_reassembled_message( self, message_type: int, input_buf: Buffer, output_buf: Dict[Epoch, Buffer] ) -> None: # client states if self.state == State.CLIENT_EXPECT_SERVER_HELLO: if message_type == HandshakeType.SERVER_HELLO: self._client_handle_hello(input_buf, output_buf[Epoch.INITIAL]) else: raise AlertUnexpectedMessage elif self.state == State.CLIENT_EXPECT_ENCRYPTED_EXTENSIONS: if message_type == HandshakeType.ENCRYPTED_EXTENSIONS: self._client_handle_encrypted_extensions(input_buf) else: raise AlertUnexpectedMessage elif self.state == State.CLIENT_EXPECT_CERTIFICATE_REQUEST_OR_CERTIFICATE: if message_type == HandshakeType.CERTIFICATE: self._client_handle_certificate(input_buf) elif message_type == HandshakeType.CERTIFICATE_REQUEST: self._client_handle_certificate_request(input_buf) else: raise AlertUnexpectedMessage elif self.state == State.CLIENT_EXPECT_CERTIFICATE: if message_type == HandshakeType.CERTIFICATE: self._client_handle_certificate(input_buf) else: raise AlertUnexpectedMessage elif self.state == State.CLIENT_EXPECT_CERTIFICATE_VERIFY: if message_type == HandshakeType.CERTIFICATE_VERIFY: self._client_handle_certificate_verify(input_buf) else: raise AlertUnexpectedMessage elif self.state == State.CLIENT_EXPECT_FINISHED: if message_type == HandshakeType.FINISHED: self._client_handle_finished(input_buf, output_buf[Epoch.HANDSHAKE]) else: raise AlertUnexpectedMessage elif self.state == State.CLIENT_POST_HANDSHAKE: if message_type == HandshakeType.NEW_SESSION_TICKET: self._client_handle_new_session_ticket(input_buf) else: raise AlertUnexpectedMessage # server 
states elif self.state == State.SERVER_EXPECT_CLIENT_HELLO: if message_type == HandshakeType.CLIENT_HELLO: self._server_handle_hello( input_buf, output_buf[Epoch.INITIAL], output_buf[Epoch.HANDSHAKE], output_buf[Epoch.ONE_RTT], ) else: raise AlertUnexpectedMessage elif self.state == State.SERVER_EXPECT_CERTIFICATE: if message_type == HandshakeType.CERTIFICATE: self._server_handle_certificate(input_buf, output_buf[Epoch.ONE_RTT]) else: raise AlertUnexpectedMessage elif self.state == State.SERVER_EXPECT_CERTIFICATE_VERIFY: if message_type == HandshakeType.CERTIFICATE_VERIFY: self._server_handle_certificate_verify( input_buf, output_buf[Epoch.ONE_RTT] ) else: raise AlertUnexpectedMessage elif self.state == State.SERVER_EXPECT_FINISHED: if message_type == HandshakeType.FINISHED: self._server_handle_finished(input_buf, output_buf[Epoch.ONE_RTT]) else: raise AlertUnexpectedMessage elif self.state == State.SERVER_POST_HANDSHAKE: raise AlertUnexpectedMessage # This condition should never be reached, because if the message # contains any extra bytes, the `pull_block` inside the message # parser will raise `AlertDecodeError`. assert input_buf.eof() def _build_session_ticket( self, new_session_ticket: NewSessionTicket, other_extensions: List[Extension] ) -> SessionTicket: resumption_master_secret = self.key_schedule.derive_secret(b"res master") resumption_secret = hkdf_expand_label( algorithm=self.key_schedule.algorithm, secret=resumption_master_secret, label=b"resumption", hash_value=new_session_ticket.ticket_nonce, length=self.key_schedule.algorithm.digest_size, ) timestamp = utcnow() return SessionTicket( age_add=new_session_ticket.ticket_age_add, cipher_suite=self.key_schedule.cipher_suite, max_early_data_size=new_session_ticket.max_early_data_size, not_valid_after=timestamp + datetime.timedelta(seconds=new_session_ticket.ticket_lifetime), not_valid_before=timestamp, other_extensions=other_extensions, resumption_secret=resumption_secret, server_name=self._server_name, ticket=new_session_ticket.ticket, ) def _check_certificate_verify_signature(self, verify: CertificateVerify) -> None: if verify.algorithm not in self._signature_algorithms: raise AlertDecryptError( "CertificateVerify has a signature algorithm we did not advertise" ) try: # The type of public_key() is CertificatePublicKeyTypes, but along with # ed25519 and ed448, which are fine, this type includes # x25519 and x448 which can be public keys but can't sign. We know # we won't get x25519 and x448 as they are not on our list of # signature algorithms, so we can cast public key to # CertificateIssuerPublicKeyTypes safely and make mypy happy. 
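# The signature is verified over the payload built by
# certificate_verify_data(), i.e. the CertificateVerify construction of
# RFC 8446, section 4.4.3; the context string differs depending on whether
# the server or the client produced the signature.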
public_key = cast( CertificateIssuerPublicKeyTypes, self._peer_certificate.public_key() ) public_key.verify( verify.signature, self.key_schedule.certificate_verify_data( SERVER_CONTEXT_STRING if self._is_client else CLIENT_CONTEXT_STRING ), *signature_algorithm_params(verify.algorithm), ) except InvalidSignature: raise AlertDecryptError def _client_send_hello(self, output_buf: Buffer) -> None: key_share: List[KeyShareEntry] = [] supported_groups: List[int] = [] for group in self._supported_groups: if group == Group.X25519: self._x25519_private_key = x25519.X25519PrivateKey.generate() key_share.append( encode_public_key(self._x25519_private_key.public_key()) ) supported_groups.append(Group.X25519) elif group == Group.X448: self._x448_private_key = x448.X448PrivateKey.generate() key_share.append(encode_public_key(self._x448_private_key.public_key())) supported_groups.append(Group.X448) elif group == Group.GREASE: key_share.append((Group.GREASE, b"\x00")) supported_groups.append(Group.GREASE) elif group in GROUP_TO_CURVE: ec_private_key = ec.generate_private_key(GROUP_TO_CURVE[group]()) self._ec_private_keys.append(ec_private_key) key_share.append(encode_public_key(ec_private_key.public_key())) supported_groups.append(group) assert len(key_share), "no key share entries" # Literal IPv4 and IPv6 addresses are not permitted in # Server Name Indication (SNI) hostname. try: ipaddress.ip_address(self._server_name) except ValueError: server_name = self._server_name else: server_name = None hello = ClientHello( random=self.client_random, legacy_session_id=self.legacy_session_id, cipher_suites=[int(x) for x in self._cipher_suites], legacy_compression_methods=self._legacy_compression_methods, alpn_protocols=self._alpn_protocols, key_share=key_share, psk_key_exchange_modes=( self._psk_key_exchange_modes if (self.session_ticket or self.new_session_ticket_cb is not None) else None ), server_name=server_name, signature_algorithms=self._signature_algorithms, supported_groups=supported_groups, supported_versions=self._supported_versions, other_extensions=self.handshake_extensions, ) # PSK if self.session_ticket and self.session_ticket.is_valid: self._key_schedule_psk = KeySchedule(self.session_ticket.cipher_suite) self._key_schedule_psk.extract(self.session_ticket.resumption_secret) binder_key = self._key_schedule_psk.derive_secret(b"res binder") binder_length = self._key_schedule_psk.algorithm.digest_size # update hello if self.session_ticket.max_early_data_size is not None: hello.early_data = True hello.pre_shared_key = OfferedPsks( identities=[ (self.session_ticket.ticket, self.session_ticket.obfuscated_age) ], binders=[bytes(binder_length)], ) # serialize hello without binder tmp_buf = Buffer(capacity=1024) push_client_hello(tmp_buf, hello) # calculate binder hash_offset = tmp_buf.tell() - binder_length - 3 self._key_schedule_psk.update_hash(tmp_buf.data_slice(0, hash_offset)) binder = self._key_schedule_psk.finished_verify_data(binder_key) hello.pre_shared_key.binders[0] = binder self._key_schedule_psk.update_hash( tmp_buf.data_slice(hash_offset, hash_offset + 3) + binder ) # calculate early data key if hello.early_data: early_key = self._key_schedule_psk.derive_secret(b"c e traffic") self.update_traffic_key_cb( Direction.ENCRYPT, Epoch.ZERO_RTT, self._key_schedule_psk.cipher_suite, early_key, ) self._key_schedule_proxy = KeyScheduleProxy(self._cipher_suites) self._key_schedule_proxy.extract(None) with push_message(self._key_schedule_proxy, output_buf): push_client_hello(output_buf, hello) 
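# Every candidate key schedule held by the proxy has now hashed the
# serialized ClientHello; _client_handle_hello() will keep only the
# schedule matching the cipher suite selected by the server.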
self._set_state(State.CLIENT_EXPECT_SERVER_HELLO) def _client_handle_hello(self, input_buf: Buffer, output_buf: Buffer) -> None: peer_hello = pull_server_hello(input_buf) cipher_suite = negotiate( self._cipher_suites, [peer_hello.cipher_suite], AlertHandshakeFailure("Unsupported cipher suite"), ) if peer_hello.compression_method not in self._legacy_compression_methods: raise AlertIllegalParameter( "ServerHello has a compression method we did not advertise" ) if peer_hello.supported_version not in self._supported_versions: raise AlertIllegalParameter( "ServerHello has a version we did not advertise" ) # select key schedule if peer_hello.pre_shared_key is not None: if ( self._key_schedule_psk is None or peer_hello.pre_shared_key != 0 or cipher_suite != self._key_schedule_psk.cipher_suite ): raise AlertIllegalParameter self.key_schedule = self._key_schedule_psk self._session_resumed = True else: self.key_schedule = self._key_schedule_proxy.select(cipher_suite) self._key_schedule_psk = None self._key_schedule_proxy = None # perform key exchange peer_public_key = decode_public_key(peer_hello.key_share) shared_key: Optional[bytes] = None if ( isinstance(peer_public_key, x25519.X25519PublicKey) and self._x25519_private_key is not None ): shared_key = self._x25519_private_key.exchange(peer_public_key) elif ( isinstance(peer_public_key, x448.X448PublicKey) and self._x448_private_key is not None ): shared_key = self._x448_private_key.exchange(peer_public_key) elif isinstance(peer_public_key, ec.EllipticCurvePublicKey): for ec_private_key in self._ec_private_keys: if ( ec_private_key.public_key().curve.__class__ == peer_public_key.curve.__class__ ): shared_key = ec_private_key.exchange(ec.ECDH(), peer_public_key) assert shared_key is not None self.key_schedule.update_hash(input_buf.data) self.key_schedule.extract(shared_key) self._setup_traffic_protection( Direction.DECRYPT, Epoch.HANDSHAKE, b"s hs traffic" ) self._set_state(State.CLIENT_EXPECT_ENCRYPTED_EXTENSIONS) def _client_handle_encrypted_extensions(self, input_buf: Buffer) -> None: encrypted_extensions = pull_encrypted_extensions(input_buf) self.alpn_negotiated = encrypted_extensions.alpn_protocol self.early_data_accepted = encrypted_extensions.early_data self.received_extensions = encrypted_extensions.other_extensions # notify application if self.alpn_cb: self.alpn_cb(self.alpn_negotiated) self._setup_traffic_protection( Direction.ENCRYPT, Epoch.HANDSHAKE, b"c hs traffic" ) self.key_schedule.update_hash(input_buf.data) # if the server accepted our PSK we are done, otherwise we want its certificate if self._session_resumed: self._set_state(State.CLIENT_EXPECT_FINISHED) else: self._set_state(State.CLIENT_EXPECT_CERTIFICATE_REQUEST_OR_CERTIFICATE) def _client_handle_certificate_request(self, input_buf: Buffer) -> None: self._certificate_request = pull_certificate_request(input_buf) self.key_schedule.update_hash(input_buf.data) self._set_state(State.CLIENT_EXPECT_CERTIFICATE) def _client_handle_certificate(self, input_buf: Buffer) -> None: certificate = pull_certificate(input_buf) self.key_schedule.update_hash(input_buf.data) self._set_peer_certificate(certificate) self._set_state(State.CLIENT_EXPECT_CERTIFICATE_VERIFY) def _client_handle_certificate_verify(self, input_buf: Buffer) -> None: verify = pull_certificate_verify(input_buf) # check signature self._check_certificate_verify_signature(verify) # check certificate if self._verify_mode != ssl.CERT_NONE: verify_certificate( cadata=self._cadata, cafile=self._cafile, capath=self._capath,
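# verify_certificate() both validates the chain against the configured
# trust roots and matches server_name against the peer certificate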
certificate=self._peer_certificate, chain=self._peer_certificate_chain, server_name=self._server_name, ) self.key_schedule.update_hash(input_buf.data) self._set_state(State.CLIENT_EXPECT_FINISHED) def _client_handle_finished(self, input_buf: Buffer, output_buf: Buffer) -> None: finished = pull_finished(input_buf) # check verify data expected_verify_data = self.key_schedule.finished_verify_data(self._dec_key) if finished.verify_data != expected_verify_data: raise AlertDecryptError self.key_schedule.update_hash(input_buf.data) # prepare traffic keys assert self.key_schedule.generation == 2 self.key_schedule.extract(None) self._setup_traffic_protection( Direction.DECRYPT, Epoch.ONE_RTT, b"s ap traffic" ) next_enc_key = self.key_schedule.derive_secret(b"c ap traffic") if self._certificate_request is not None: # check whether we have a suitable signature algorithm if ( self.certificate is not None and self.certificate_private_key is not None ): signature_algorithm = negotiate( self._signature_algorithms_for_private_key(), self._certificate_request.signature_algorithms, ) else: signature_algorithm = None # send certificate with push_message(self.key_schedule, output_buf): push_certificate( output_buf, Certificate( request_context=self._certificate_request.request_context, certificates=( [ (x.public_bytes(Encoding.DER), b"") for x in [self.certificate] + self.certificate_chain ] if signature_algorithm else [] ), ), ) # send certificate verify if signature_algorithm: signature = self.certificate_private_key.sign( self.key_schedule.certificate_verify_data(CLIENT_CONTEXT_STRING), *signature_algorithm_params(signature_algorithm), ) with push_message(self.key_schedule, output_buf): push_certificate_verify( output_buf, CertificateVerify( algorithm=signature_algorithm, signature=signature ), ) # send finished with push_message(self.key_schedule, output_buf): push_finished( output_buf, Finished( verify_data=self.key_schedule.finished_verify_data(self._enc_key) ), ) # commit traffic key self._enc_key = next_enc_key self.update_traffic_key_cb( Direction.ENCRYPT, Epoch.ONE_RTT, self.key_schedule.cipher_suite, self._enc_key, ) self._set_state(State.CLIENT_POST_HANDSHAKE) def _client_handle_new_session_ticket(self, input_buf: Buffer) -> None: new_session_ticket = pull_new_session_ticket(input_buf) # notify application if self.new_session_ticket_cb is not None: ticket = self._build_session_ticket( new_session_ticket, self.received_extensions ) self.new_session_ticket_cb(ticket) def _server_expect_finished(self, onertt_buf: Buffer): # anticipate client's FINISHED self._expected_verify_data = self.key_schedule.finished_verify_data( self._dec_key ) buf = Buffer(capacity=64) push_finished(buf, Finished(verify_data=self._expected_verify_data)) self.key_schedule.update_hash(buf.data) # create a new session ticket if ( self.new_session_ticket_cb is not None and self._psk_key_exchange_mode is not None ): self._new_session_ticket = NewSessionTicket( ticket_lifetime=86400, ticket_age_add=struct.unpack("I", os.urandom(4))[0], ticket_nonce=b"", ticket=os.urandom(64), max_early_data_size=self._max_early_data, ) # send message push_new_session_ticket(onertt_buf, self._new_session_ticket) # notify application ticket = self._build_session_ticket( self._new_session_ticket, self.handshake_extensions ) self.new_session_ticket_cb(ticket) self._set_state(State.SERVER_EXPECT_FINISHED) def _server_handle_hello( self, input_buf: Buffer, initial_buf: Buffer, handshake_buf: Buffer, onertt_buf: Buffer, ) -> None: peer_hello = 
pull_client_hello(input_buf) # negotiate parameters cipher_suite = negotiate( self._cipher_suites, peer_hello.cipher_suites, AlertHandshakeFailure("No supported cipher suite"), ) compression_method = negotiate( self._legacy_compression_methods, peer_hello.legacy_compression_methods, AlertHandshakeFailure("No supported compression method"), ) psk_key_exchange_mode = negotiate( self._psk_key_exchange_modes, peer_hello.psk_key_exchange_modes ) signature_algorithm = negotiate( self._signature_algorithms_for_private_key(), peer_hello.signature_algorithms, AlertHandshakeFailure("No supported signature algorithm"), ) supported_version = negotiate( self._supported_versions, peer_hello.supported_versions, AlertProtocolVersion("No supported protocol version"), ) # negotiate ALPN if self._alpn_protocols is not None: self.alpn_negotiated = negotiate( self._alpn_protocols, peer_hello.alpn_protocols, AlertHandshakeFailure("No common ALPN protocols"), ) self.client_random = peer_hello.random self.server_random = os.urandom(32) self.legacy_session_id = peer_hello.legacy_session_id self.received_extensions = peer_hello.other_extensions # notify application if self.alpn_cb: self.alpn_cb(self.alpn_negotiated) # select key schedule pre_shared_key = None if ( self.get_session_ticket_cb is not None and psk_key_exchange_mode is not None and peer_hello.pre_shared_key is not None and len(peer_hello.pre_shared_key.identities) == 1 and len(peer_hello.pre_shared_key.binders) == 1 ): # ask application to find session ticket identity = peer_hello.pre_shared_key.identities[0] session_ticket = self.get_session_ticket_cb(identity[0]) # validate session ticket if ( session_ticket is not None and session_ticket.is_valid and session_ticket.cipher_suite == cipher_suite ): self.key_schedule = KeySchedule(cipher_suite) self.key_schedule.extract(session_ticket.resumption_secret) binder_key = self.key_schedule.derive_secret(b"res binder") binder_length = self.key_schedule.algorithm.digest_size hash_offset = input_buf.tell() - binder_length - 3 binder = input_buf.data_slice( hash_offset + 3, hash_offset + 3 + binder_length ) self.key_schedule.update_hash(input_buf.data_slice(0, hash_offset)) expected_binder = self.key_schedule.finished_verify_data(binder_key) if binder != expected_binder: raise AlertHandshakeFailure("PSK validation failed") self.key_schedule.update_hash( input_buf.data_slice(hash_offset, hash_offset + 3 + binder_length) ) self._session_resumed = True # calculate early data key if peer_hello.early_data: early_key = self.key_schedule.derive_secret(b"c e traffic") self.early_data_accepted = True self.update_traffic_key_cb( Direction.DECRYPT, Epoch.ZERO_RTT, self.key_schedule.cipher_suite, early_key, ) pre_shared_key = 0 # if PSK is not used, initialize key schedule if pre_shared_key is None: self.key_schedule = KeySchedule(cipher_suite) self.key_schedule.extract(None) self.key_schedule.update_hash(input_buf.data) # perform key exchange public_key: Union[ ec.EllipticCurvePublicKey, x25519.X25519PublicKey, x448.X448PublicKey ] shared_key: Optional[bytes] = None for key_share in peer_hello.key_share: peer_public_key = decode_public_key(key_share) if isinstance(peer_public_key, x25519.X25519PublicKey): self._x25519_private_key = x25519.X25519PrivateKey.generate() public_key = self._x25519_private_key.public_key() shared_key = self._x25519_private_key.exchange(peer_public_key) break elif isinstance(peer_public_key, x448.X448PublicKey): self._x448_private_key = x448.X448PrivateKey.generate() public_key = 
self._x448_private_key.public_key() shared_key = self._x448_private_key.exchange(peer_public_key) break elif isinstance(peer_public_key, ec.EllipticCurvePublicKey): ec_private_key = ec.generate_private_key(GROUP_TO_CURVE[key_share[0]]()) self._ec_private_keys.append(ec_private_key) public_key = ec_private_key.public_key() shared_key = ec_private_key.exchange(ec.ECDH(), peer_public_key) break assert shared_key is not None # send hello hello = ServerHello( random=self.server_random, legacy_session_id=self.legacy_session_id, cipher_suite=cipher_suite, compression_method=compression_method, key_share=encode_public_key(public_key), pre_shared_key=pre_shared_key, supported_version=supported_version, ) with push_message(self.key_schedule, initial_buf): push_server_hello(initial_buf, hello) self.key_schedule.extract(shared_key) self._setup_traffic_protection( Direction.ENCRYPT, Epoch.HANDSHAKE, b"s hs traffic" ) self._setup_traffic_protection( Direction.DECRYPT, Epoch.HANDSHAKE, b"c hs traffic" ) # send encrypted extensions with push_message(self.key_schedule, handshake_buf): push_encrypted_extensions( handshake_buf, EncryptedExtensions( alpn_protocol=self.alpn_negotiated, early_data=self.early_data_accepted, other_extensions=self.handshake_extensions, ), ) if pre_shared_key is None: # send certificate request if self._request_client_certificate: with push_message(self.key_schedule, handshake_buf): push_certificate_request( handshake_buf, CertificateRequest( request_context=b"", signature_algorithms=self._signature_algorithms, ), ) # send certificate with push_message(self.key_schedule, handshake_buf): push_certificate( handshake_buf, Certificate( request_context=b"", certificates=[ (x.public_bytes(Encoding.DER), b"") for x in [self.certificate] + self.certificate_chain ], ), ) # send certificate verify signature = self.certificate_private_key.sign( self.key_schedule.certificate_verify_data(SERVER_CONTEXT_STRING), *signature_algorithm_params(signature_algorithm), ) with push_message(self.key_schedule, handshake_buf): push_certificate_verify( handshake_buf, CertificateVerify( algorithm=signature_algorithm, signature=signature ), ) # send finished with push_message(self.key_schedule, handshake_buf): push_finished( handshake_buf, Finished( verify_data=self.key_schedule.finished_verify_data(self._enc_key) ), ) # prepare traffic keys assert self.key_schedule.generation == 2 self.key_schedule.extract(None) self._setup_traffic_protection( Direction.ENCRYPT, Epoch.ONE_RTT, b"s ap traffic" ) self._next_dec_key = self.key_schedule.derive_secret(b"c ap traffic") self._psk_key_exchange_mode = psk_key_exchange_mode if self._request_client_certificate: self._set_state(State.SERVER_EXPECT_CERTIFICATE) else: self._server_expect_finished(onertt_buf) def _server_handle_certificate(self, input_buf: Buffer, output_buf: Buffer) -> None: certificate = pull_certificate(input_buf) self.key_schedule.update_hash(input_buf.data) if certificate.certificates: self._set_peer_certificate(certificate) self._set_state(State.SERVER_EXPECT_CERTIFICATE_VERIFY) else: self._server_expect_finished(output_buf) def _server_handle_certificate_verify( self, input_buf: Buffer, output_buf: Buffer ) -> None: verify = pull_certificate_verify(input_buf) # check signature self._check_certificate_verify_signature(verify) self.key_schedule.update_hash(input_buf.data) self._server_expect_finished(output_buf) def _server_handle_finished(self, input_buf: Buffer, output_buf: Buffer) -> None: finished = pull_finished(input_buf) # check verify data if 
finished.verify_data != self._expected_verify_data: raise AlertDecryptError # commit traffic key self._dec_key = self._next_dec_key self._next_dec_key = None self.update_traffic_key_cb( Direction.DECRYPT, Epoch.ONE_RTT, self.key_schedule.cipher_suite, self._dec_key, ) self._set_state(State.SERVER_POST_HANDSHAKE) def _setup_traffic_protection( self, direction: Direction, epoch: Epoch, label: bytes ) -> None: key = self.key_schedule.derive_secret(label) if direction == Direction.ENCRYPT: self._enc_key = key else: self._dec_key = key self.update_traffic_key_cb( direction, epoch, self.key_schedule.cipher_suite, key ) def _set_peer_certificate(self, certificate: Certificate) -> None: self._peer_certificate = x509.load_der_x509_certificate( certificate.certificates[0][0] ) self._peer_certificate_chain = [ x509.load_der_x509_certificate(certificate.certificates[i][0]) for i in range(1, len(certificate.certificates)) ] def _set_state(self, state: State) -> None: if self.__logger: self.__logger.debug("TLS %s -> %s", self.state, state) self.state = state def _signature_algorithms_for_private_key(self) -> List[SignatureAlgorithm]: signature_algorithms: List[SignatureAlgorithm] = [] if isinstance(self.certificate_private_key, rsa.RSAPrivateKey): signature_algorithms = [ SignatureAlgorithm.RSA_PSS_RSAE_SHA256, SignatureAlgorithm.RSA_PKCS1_SHA256, SignatureAlgorithm.RSA_PSS_RSAE_SHA384, SignatureAlgorithm.RSA_PKCS1_SHA384, SignatureAlgorithm.RSA_PKCS1_SHA1, ] elif isinstance( self.certificate_private_key, ec.EllipticCurvePrivateKey ) and isinstance(self.certificate_private_key.curve, ec.SECP256R1): signature_algorithms = [SignatureAlgorithm.ECDSA_SECP256R1_SHA256] elif isinstance( self.certificate_private_key, ec.EllipticCurvePrivateKey ) and isinstance(self.certificate_private_key.curve, ec.SECP384R1): signature_algorithms = [SignatureAlgorithm.ECDSA_SECP384R1_SHA384] elif isinstance(self.certificate_private_key, ed25519.Ed25519PrivateKey): signature_algorithms = [SignatureAlgorithm.ED25519] elif isinstance(self.certificate_private_key, ed448.Ed448PrivateKey): signature_algorithms = [SignatureAlgorithm.ED448] return signature_algorithms ././@PaxHeader0000000000000000000000000000003400000000000011452 xustar000000000000000028 mtime=1720306888.1372943 aioquic-1.2.0/src/aioquic.egg-info/0000755000175100001770000000000000000000000017665 5ustar00runnerdocker00000000000000
././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1720306888.0 aioquic-1.2.0/src/aioquic.egg-info/SOURCES.txt0000644000175100001770000000554400000000000021551 0ustar00runnerdocker00000000000000LICENSE MANIFEST.in README.rst pyproject.toml setup.py docs/Makefile docs/asyncio.rst docs/changelog.rst docs/conf.py docs/design.rst docs/h3.rst docs/index.rst docs/license.rst docs/quic.rst docs/_ext/sphinx_aioquic.py docs/_static/aioquic.svg examples/README.rst examples/demo.py examples/doq_client.py examples/doq_server.py examples/http3_client.py examples/http3_server.py examples/httpx_client.py examples/interop.py examples/siduck_client.py examples/htdocs/robots.txt examples/htdocs/style.css examples/templates/index.html examples/templates/logs.html requirements/doc.txt scripts/fetch-vendor.json scripts/fetch-vendor.py src/aioquic/__init__.py src/aioquic/_buffer.c src/aioquic/_buffer.pyi src/aioquic/_crypto.c src/aioquic/_crypto.pyi src/aioquic/buffer.py src/aioquic/py.typed src/aioquic/tls.py src/aioquic.egg-info/PKG-INFO src/aioquic.egg-info/SOURCES.txt src/aioquic.egg-info/dependency_links.txt src/aioquic.egg-info/requires.txt src/aioquic.egg-info/top_level.txt src/aioquic/asyncio/__init__.py src/aioquic/asyncio/client.py src/aioquic/asyncio/protocol.py src/aioquic/asyncio/server.py src/aioquic/h0/__init__.py src/aioquic/h0/connection.py src/aioquic/h3/__init__.py src/aioquic/h3/connection.py src/aioquic/h3/events.py src/aioquic/h3/exceptions.py src/aioquic/quic/__init__.py src/aioquic/quic/configuration.py src/aioquic/quic/connection.py src/aioquic/quic/crypto.py src/aioquic/quic/events.py src/aioquic/quic/logger.py src/aioquic/quic/packet.py src/aioquic/quic/packet_builder.py src/aioquic/quic/rangeset.py src/aioquic/quic/recovery.py src/aioquic/quic/retry.py
src/aioquic/quic/stream.py src/aioquic/quic/congestion/__init__.py src/aioquic/quic/congestion/base.py src/aioquic/quic/congestion/cubic.py src/aioquic/quic/congestion/reno.py tests/__init__.py tests/pycacert.pem tests/ssl_cert.pem tests/ssl_cert_with_chain.pem tests/ssl_combined.pem tests/ssl_key.pem tests/test_asyncio.py tests/test_buffer.py tests/test_connection.py tests/test_crypto_v1.py tests/test_crypto_v2.py tests/test_h0.py tests/test_h3.py tests/test_logger.py tests/test_packet.py tests/test_packet_builder.py tests/test_rangeset.py tests/test_recovery.py tests/test_recovery_cubic.py tests/test_recovery_reno.py tests/test_retry.py tests/test_stream.py tests/test_tls.py tests/test_webtransport.py tests/tls_certificate.bin tests/tls_certificate_request.bin tests/tls_certificate_verify.bin tests/tls_client_hello.bin tests/tls_client_hello_with_alpn.bin tests/tls_client_hello_with_psk.bin tests/tls_client_hello_with_sni.bin tests/tls_encrypted_extensions.bin tests/tls_encrypted_extensions_with_alpn.bin tests/tls_encrypted_extensions_with_alpn_and_early_data.bin tests/tls_finished.bin tests/tls_new_session_ticket.bin tests/tls_new_session_ticket_with_unknown_extension.bin tests/tls_server_hello.bin tests/tls_server_hello_with_psk.bin tests/tls_server_hello_with_unknown_extension.bin tests/utils.py././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1720306888.0 aioquic-1.2.0/src/aioquic.egg-info/dependency_links.txt0000644000175100001770000000000100000000000023733 0ustar00runnerdocker00000000000000 ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1720306888.0 aioquic-1.2.0/src/aioquic.egg-info/requires.txt0000644000175100001770000000017100000000000022264 0ustar00runnerdocker00000000000000certifi cryptography>=42.0.0 pylsqpack<0.4.0,>=0.3.3 pyopenssl>=24 service-identity>=24.1.0 [dev] coverage[toml]>=7.2.2 ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1720306888.0 aioquic-1.2.0/src/aioquic.egg-info/top_level.txt0000644000175100001770000000001000000000000022406 0ustar00runnerdocker00000000000000aioquic ././@PaxHeader0000000000000000000000000000003400000000000011452 xustar000000000000000028 mtime=1720306888.1372943 aioquic-1.2.0/tests/0000755000175100001770000000000000000000000015114 5ustar00runnerdocker00000000000000././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1720306884.0 aioquic-1.2.0/tests/__init__.py0000644000175100001770000000000000000000000017213 0ustar00runnerdocker00000000000000././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1720306884.0 aioquic-1.2.0/tests/pycacert.pem0000644000175100001770000001303000000000000017426 0ustar00runnerdocker00000000000000Certificate: Data: Version: 3 (0x2) Serial Number: cb:2d:80:99:5a:69:52:5b Signature Algorithm: sha256WithRSAEncryption Issuer: C=XY, O=Python Software Foundation CA, CN=our-ca-server Validity Not Before: Aug 29 14:23:16 2018 GMT Not After : Aug 26 14:23:16 2028 GMT Subject: C=XY, O=Python Software Foundation CA, CN=our-ca-server Subject Public Key Info: Public Key Algorithm: rsaEncryption Public-Key: (3072 bit) Modulus: 00:97:ed:55:41:ba:36:17:95:db:71:1c:d3:e1:61: ac:58:73:e3:c6:96:cf:2b:1f:b8:08:f5:9d:4b:4b: c7:30:f6:b8:0b:b3:52:72:a0:bb:c9:4d:3b:8e:df: 22:8e:01:57:81:c9:92:73:cc:00:c6:ec:70:b0:3a: 17:40:c1:df:f2:8c:36:4c:c4:a7:81:e7:b6:24:68: e2:a0:7e:35:07:2f:a0:5b:f9:45:46:f7:1e:f0:46: 
11:fe:ca:1a:3c:50:f1:26:a9:5f:9c:22:9c:f8:41: e1:df:4f:12:95:19:2f:5c:90:01:17:6e:7e:3e:7d: cf:e9:09:af:25:f8:f8:42:77:2d:6d:5f:36:f2:78: 1e:7d:4a:87:68:63:6c:06:71:1b:8d:fa:25:fe:d4: d3:f5:a5:17:b1:ef:ea:17:cb:54:c8:27:99:80:cb: 3c:45:f1:2c:52:1c:dd:1f:51:45:20:50:1e:5e:ab: 57:73:1b:41:78:96:de:84:a4:7a:dd:8f:30:85:36: 58:79:76:a0:d2:61:c8:1b:a9:94:99:63:c6:ee:f8: 14:bf:b4:52:56:31:97:fa:eb:ac:53:9e:95:ce:4c: c4:5a:4a:b7:ca:03:27:5b:35:57:ce:02:dc:ec:ca: 69:f8:8a:5a:39:cb:16:20:15:03:24:61:6c:f4:7a: fc:b6:48:e5:59:10:5c:49:d0:23:9f:fb:71:5e:3a: e9:68:9f:34:72:80:27:b6:3f:4c:b1:d9:db:63:7f: 67:68:4a:6e:11:f8:e8:c0:f4:5a:16:39:53:0b:68: de:77:fa:45:e7:f8:91:cd:78:cd:28:94:97:71:54: fb:cf:f0:37:de:c9:26:c5:dc:1b:9e:89:6d:09:ac: c8:44:71:cb:6d:f1:97:31:d5:4c:20:33:bf:75:4a: a0:e0:dc:69:11:ed:2a:b4:64:10:11:30:8b:0e:b0: a7:10:d8:8a:c5:aa:1b:c8:26:8a:25:e7:66:9f:a5: 6a:1a:2f:7c:5f:83:c6:78:4f:1f Exponent: 65537 (0x10001) X509v3 extensions: X509v3 Subject Key Identifier: DD:BF:CA:DA:E6:D1:34:BA:37:75:21:CA:6F:9A:08:28:F2:35:B6:48 X509v3 Authority Key Identifier: keyid:DD:BF:CA:DA:E6:D1:34:BA:37:75:21:CA:6F:9A:08:28:F2:35:B6:48 X509v3 Basic Constraints: CA:TRUE Signature Algorithm: sha256WithRSAEncryption 33:6a:54:d3:6b:c0:d7:01:5f:9d:f4:05:c1:93:66:90:50:d0: b7:18:e9:b0:1e:4a:a0:b6:da:76:93:af:84:db:ad:15:54:31: 15:13:e4:de:7e:4e:0c:d5:09:1c:34:35:b6:e5:4c:d6:6f:65: 7d:32:5f:eb:fc:a9:6b:07:f7:49:82:e5:81:7e:07:80:9a:63: f8:2c:c3:40:bc:8f:d4:2a:da:3e:d1:ee:08:b7:4d:a7:84:ca: f4:3f:a1:98:45:be:b1:05:69:e7:df:d7:99:ab:1b:ee:8b:30: cc:f7:fc:e7:d4:0b:17:ae:97:bf:e4:7b:fd:0f:a7:b4:85:79: e3:59:e2:16:87:bf:1f:29:45:2c:23:93:76:be:c0:87:1d:de: ec:2b:42:6a:e5:bb:c8:f4:0a:4a:08:0a:8c:5c:d8:7d:4d:d1: b8:bf:d5:f7:29:ed:92:d1:94:04:e8:35:06:57:7f:2c:23:97: 87:a5:35:8d:26:d3:1a:47:f2:16:d7:d9:c6:d4:1f:23:43:d3: 26:99:39:ca:20:f4:71:23:6f:0c:4a:76:76:f7:76:1f:b3:fe: bf:47:b0:fc:2a:56:81:e1:d2:dd:ee:08:d8:f4:ff:5a:dc:25: 61:8a:91:02:b9:86:1c:f2:50:73:76:25:35:fc:b6:25:26:15: cb:eb:c4:2b:61:0c:1c:e7:ee:2f:17:9b:ec:f0:d4:a1:84:e7: d2:af:de:e4:1b:24:14:a7:01:87:e3:ab:29:58:46:a0:d9:c0: 0a:e0:8d:d7:59:d3:1b:f8:54:20:3e:78:a5:a5:c8:4f:8b:03: c4:96:9f:ec:fb:47:cf:76:2d:8d:65:34:27:bf:fa:ae:01:05: 8a:f3:92:0a:dd:89:6c:97:a1:c7:e7:60:51:e7:ac:eb:4b:7d: 2c:b8:65:c9:fe:5d:6a:48:55:8e:e4:c7:f9:6a:40:e1:b8:64: 45:e9:b5:59:29:a5:5f:cf:7d:58:7d:64:79:e5:a4:09:ac:1e: 76:65:3d:94:c4:68 -----BEGIN CERTIFICATE----- MIIEbTCCAtWgAwIBAgIJAMstgJlaaVJbMA0GCSqGSIb3DQEBCwUAME0xCzAJBgNV BAYTAlhZMSYwJAYDVQQKDB1QeXRob24gU29mdHdhcmUgRm91bmRhdGlvbiBDQTEW MBQGA1UEAwwNb3VyLWNhLXNlcnZlcjAeFw0xODA4MjkxNDIzMTZaFw0yODA4MjYx NDIzMTZaME0xCzAJBgNVBAYTAlhZMSYwJAYDVQQKDB1QeXRob24gU29mdHdhcmUg Rm91bmRhdGlvbiBDQTEWMBQGA1UEAwwNb3VyLWNhLXNlcnZlcjCCAaIwDQYJKoZI hvcNAQEBBQADggGPADCCAYoCggGBAJftVUG6NheV23Ec0+FhrFhz48aWzysfuAj1 nUtLxzD2uAuzUnKgu8lNO47fIo4BV4HJknPMAMbscLA6F0DB3/KMNkzEp4HntiRo 4qB+NQcvoFv5RUb3HvBGEf7KGjxQ8SapX5winPhB4d9PEpUZL1yQARdufj59z+kJ ryX4+EJ3LW1fNvJ4Hn1Kh2hjbAZxG436Jf7U0/WlF7Hv6hfLVMgnmYDLPEXxLFIc 3R9RRSBQHl6rV3MbQXiW3oSket2PMIU2WHl2oNJhyBuplJljxu74FL+0UlYxl/rr rFOelc5MxFpKt8oDJ1s1V84C3OzKafiKWjnLFiAVAyRhbPR6/LZI5VkQXEnQI5/7 cV466WifNHKAJ7Y/TLHZ22N/Z2hKbhH46MD0WhY5Uwto3nf6Ref4kc14zSiUl3FU +8/wN97JJsXcG56JbQmsyERxy23xlzHVTCAzv3VKoODcaRHtKrRkEBEwiw6wpxDY isWqG8gmiiXnZp+lahovfF+DxnhPHwIDAQABo1AwTjAdBgNVHQ4EFgQU3b/K2ubR NLo3dSHKb5oIKPI1tkgwHwYDVR0jBBgwFoAU3b/K2ubRNLo3dSHKb5oIKPI1tkgw DAYDVR0TBAUwAwEB/zANBgkqhkiG9w0BAQsFAAOCAYEAM2pU02vA1wFfnfQFwZNm kFDQtxjpsB5KoLbadpOvhNutFVQxFRPk3n5ODNUJHDQ1tuVM1m9lfTJf6/ypawf3 
SYLlgX4HgJpj+CzDQLyP1CraPtHuCLdNp4TK9D+hmEW+sQVp59/Xmasb7oswzPf8 59QLF66Xv+R7/Q+ntIV541niFoe/HylFLCOTdr7Ahx3e7CtCauW7yPQKSggKjFzY fU3RuL/V9yntktGUBOg1Bld/LCOXh6U1jSbTGkfyFtfZxtQfI0PTJpk5yiD0cSNv DEp2dvd2H7P+v0ew/CpWgeHS3e4I2PT/WtwlYYqRArmGHPJQc3YlNfy2JSYVy+vE K2EMHOfuLxeb7PDUoYTn0q/e5BskFKcBh+OrKVhGoNnACuCN11nTG/hUID54paXI T4sDxJaf7PtHz3YtjWU0J7/6rgEFivOSCt2JbJehx+dgUees60t9LLhlyf5dakhV juTH+WpA4bhkRem1WSmlX899WH1keeWkCawedmU9lMRo -----END CERTIFICATE----- ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1720306884.0 aioquic-1.2.0/tests/ssl_cert.pem0000644000175100001770000000411200000000000017433 0ustar00runnerdocker00000000000000-----BEGIN CERTIFICATE----- MIIF8TCCBFmgAwIBAgIJAMstgJlaaVJcMA0GCSqGSIb3DQEBCwUAME0xCzAJBgNV BAYTAlhZMSYwJAYDVQQKDB1QeXRob24gU29mdHdhcmUgRm91bmRhdGlvbiBDQTEW MBQGA1UEAwwNb3VyLWNhLXNlcnZlcjAeFw0xODA4MjkxNDIzMTZaFw0yODA3MDcx NDIzMTZaMF8xCzAJBgNVBAYTAlhZMRcwFQYDVQQHDA5DYXN0bGUgQW50aHJheDEj MCEGA1UECgwaUHl0aG9uIFNvZnR3YXJlIEZvdW5kYXRpb24xEjAQBgNVBAMMCWxv Y2FsaG9zdDCCAaIwDQYJKoZIhvcNAQEBBQADggGPADCCAYoCggGBAJ8oLzdB739k YxZiFukBFGIpyjqYkj0I015p/sDz1MT7DljcZLBLy7OqnkLpB5tnM8256DwdihPA 3zlnfEzTfr9DD0qFBW2H5cMCoz7X17koeRhzGDd3dkjUeBjXvR5qRosG8wM3lQug U7AizY+3Azaj1yN3mZ9K5a20jr58Kqinz+Xxx6sb2JfYYff2neJbBahNm5id0AD2 pi/TthZqO5DURJYo+MdgZOcy+7jEjOJsLWZd3Yzq78iM07qDjbpIoVpENZCTHTWA hX8LIqz0OBmh4weQpm4+plU7E4r4D82uauocWw8iyuznCTtABWO7n9fWySmf9QZC WYxHAFpBQs6zUVqAD7nhFdTqpQ9bRiaEnjE4HiAccPW+MAoSxFnv/rNzEzI6b4zU NspFMfg1aNVamdjxdpUZ1GG1Okf0yPJykqEX4PZl3La1Be2q7YZ1wydR523Xd+f3 EO4/g+imETSKn8gyCf6Rvib175L4r2WV1CXQH7gFwZYCod6WHYq5TQIDAQABo4IB wDCCAbwwFAYDVR0RBA0wC4IJbG9jYWxob3N0MA4GA1UdDwEB/wQEAwIFoDAdBgNV HSUEFjAUBggrBgEFBQcDAQYIKwYBBQUHAwIwDAYDVR0TAQH/BAIwADAdBgNVHQ4E FgQUj+od4zNcABazi29rb9NMy7XLfFUwfQYDVR0jBHYwdIAU3b/K2ubRNLo3dSHK b5oIKPI1tkihUaRPME0xCzAJBgNVBAYTAlhZMSYwJAYDVQQKDB1QeXRob24gU29m dHdhcmUgRm91bmRhdGlvbiBDQTEWMBQGA1UEAwwNb3VyLWNhLXNlcnZlcoIJAMst gJlaaVJbMIGDBggrBgEFBQcBAQR3MHUwPAYIKwYBBQUHMAKGMGh0dHA6Ly90ZXN0 Y2EucHl0aG9udGVzdC5uZXQvdGVzdGNhL3B5Y2FjZXJ0LmNlcjA1BggrBgEFBQcw AYYpaHR0cDovL3Rlc3RjYS5weXRob250ZXN0Lm5ldC90ZXN0Y2Evb2NzcC8wQwYD VR0fBDwwOjA4oDagNIYyaHR0cDovL3Rlc3RjYS5weXRob250ZXN0Lm5ldC90ZXN0 Y2EvcmV2b2NhdGlvbi5jcmwwDQYJKoZIhvcNAQELBQADggGBACf1jFkQ9MbnKAC/ uo17EwPxHKZfswZVpCK527LVRr33DN1DbrR5ZWchDCpV7kCOhZ+fR7sKKk22ZHSY oH+u3PEu20J3GOB1iyY1aMNB7WvId3JvappdVWkC/VpUyFfLsGUDFuIPADmZZqCb iJMX4loteTVfl1d4xK/1mV6Gq9MRrRqiDfpSELn+v53OM9mGspwW+NZ1CIrbCuW0 KxZ/tPkqn8PSd9fNZR70bB7rWbnwrl+kH8xKxLl6qdlrMmg74WWwhLeQxK7+9DdP IaDenzqx5cwWBGY/C0HcQj0gPuy3lSs1V/q+f7Y6uspPWP51PgiJLIywXS75iRAr +UFGTzwAtyfTZSQoFyMmMULqfk6T5HtoVMqfRvPvK+mFDLWEstU1NIB1K/CRI7gI AY65ClTU+zRS/tlF8IA7tsFvgtEf8jsI9kamlidhS1gyeg4dWcVErV4aeTPB1AUv StPYQkKNM+NjytWHl5tNuBoDNLsc0gI/WSPiI4CIY8LwomOoiw== -----END CERTIFICATE----- ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1720306884.0 aioquic-1.2.0/tests/ssl_cert_with_chain.pem0000644000175100001770000000720400000000000021635 0ustar00runnerdocker00000000000000-----BEGIN CERTIFICATE----- MIIF8TCCBFmgAwIBAgIJAMstgJlaaVJcMA0GCSqGSIb3DQEBCwUAME0xCzAJBgNV BAYTAlhZMSYwJAYDVQQKDB1QeXRob24gU29mdHdhcmUgRm91bmRhdGlvbiBDQTEW MBQGA1UEAwwNb3VyLWNhLXNlcnZlcjAeFw0xODA4MjkxNDIzMTZaFw0yODA3MDcx NDIzMTZaMF8xCzAJBgNVBAYTAlhZMRcwFQYDVQQHDA5DYXN0bGUgQW50aHJheDEj MCEGA1UECgwaUHl0aG9uIFNvZnR3YXJlIEZvdW5kYXRpb24xEjAQBgNVBAMMCWxv Y2FsaG9zdDCCAaIwDQYJKoZIhvcNAQEBBQADggGPADCCAYoCggGBAJ8oLzdB739k YxZiFukBFGIpyjqYkj0I015p/sDz1MT7DljcZLBLy7OqnkLpB5tnM8256DwdihPA 
3zlnfEzTfr9DD0qFBW2H5cMCoz7X17koeRhzGDd3dkjUeBjXvR5qRosG8wM3lQug U7AizY+3Azaj1yN3mZ9K5a20jr58Kqinz+Xxx6sb2JfYYff2neJbBahNm5id0AD2 pi/TthZqO5DURJYo+MdgZOcy+7jEjOJsLWZd3Yzq78iM07qDjbpIoVpENZCTHTWA hX8LIqz0OBmh4weQpm4+plU7E4r4D82uauocWw8iyuznCTtABWO7n9fWySmf9QZC WYxHAFpBQs6zUVqAD7nhFdTqpQ9bRiaEnjE4HiAccPW+MAoSxFnv/rNzEzI6b4zU NspFMfg1aNVamdjxdpUZ1GG1Okf0yPJykqEX4PZl3La1Be2q7YZ1wydR523Xd+f3 EO4/g+imETSKn8gyCf6Rvib175L4r2WV1CXQH7gFwZYCod6WHYq5TQIDAQABo4IB wDCCAbwwFAYDVR0RBA0wC4IJbG9jYWxob3N0MA4GA1UdDwEB/wQEAwIFoDAdBgNV HSUEFjAUBggrBgEFBQcDAQYIKwYBBQUHAwIwDAYDVR0TAQH/BAIwADAdBgNVHQ4E FgQUj+od4zNcABazi29rb9NMy7XLfFUwfQYDVR0jBHYwdIAU3b/K2ubRNLo3dSHK b5oIKPI1tkihUaRPME0xCzAJBgNVBAYTAlhZMSYwJAYDVQQKDB1QeXRob24gU29m dHdhcmUgRm91bmRhdGlvbiBDQTEWMBQGA1UEAwwNb3VyLWNhLXNlcnZlcoIJAMst gJlaaVJbMIGDBggrBgEFBQcBAQR3MHUwPAYIKwYBBQUHMAKGMGh0dHA6Ly90ZXN0 Y2EucHl0aG9udGVzdC5uZXQvdGVzdGNhL3B5Y2FjZXJ0LmNlcjA1BggrBgEFBQcw AYYpaHR0cDovL3Rlc3RjYS5weXRob250ZXN0Lm5ldC90ZXN0Y2Evb2NzcC8wQwYD VR0fBDwwOjA4oDagNIYyaHR0cDovL3Rlc3RjYS5weXRob250ZXN0Lm5ldC90ZXN0 Y2EvcmV2b2NhdGlvbi5jcmwwDQYJKoZIhvcNAQELBQADggGBACf1jFkQ9MbnKAC/ uo17EwPxHKZfswZVpCK527LVRr33DN1DbrR5ZWchDCpV7kCOhZ+fR7sKKk22ZHSY oH+u3PEu20J3GOB1iyY1aMNB7WvId3JvappdVWkC/VpUyFfLsGUDFuIPADmZZqCb iJMX4loteTVfl1d4xK/1mV6Gq9MRrRqiDfpSELn+v53OM9mGspwW+NZ1CIrbCuW0 KxZ/tPkqn8PSd9fNZR70bB7rWbnwrl+kH8xKxLl6qdlrMmg74WWwhLeQxK7+9DdP IaDenzqx5cwWBGY/C0HcQj0gPuy3lSs1V/q+f7Y6uspPWP51PgiJLIywXS75iRAr +UFGTzwAtyfTZSQoFyMmMULqfk6T5HtoVMqfRvPvK+mFDLWEstU1NIB1K/CRI7gI AY65ClTU+zRS/tlF8IA7tsFvgtEf8jsI9kamlidhS1gyeg4dWcVErV4aeTPB1AUv StPYQkKNM+NjytWHl5tNuBoDNLsc0gI/WSPiI4CIY8LwomOoiw== -----END CERTIFICATE----- -----BEGIN CERTIFICATE----- MIIEbTCCAtWgAwIBAgIJAMstgJlaaVJbMA0GCSqGSIb3DQEBCwUAME0xCzAJBgNV BAYTAlhZMSYwJAYDVQQKDB1QeXRob24gU29mdHdhcmUgRm91bmRhdGlvbiBDQTEW MBQGA1UEAwwNb3VyLWNhLXNlcnZlcjAeFw0xODA4MjkxNDIzMTZaFw0yODA4MjYx NDIzMTZaME0xCzAJBgNVBAYTAlhZMSYwJAYDVQQKDB1QeXRob24gU29mdHdhcmUg Rm91bmRhdGlvbiBDQTEWMBQGA1UEAwwNb3VyLWNhLXNlcnZlcjCCAaIwDQYJKoZI hvcNAQEBBQADggGPADCCAYoCggGBAJftVUG6NheV23Ec0+FhrFhz48aWzysfuAj1 nUtLxzD2uAuzUnKgu8lNO47fIo4BV4HJknPMAMbscLA6F0DB3/KMNkzEp4HntiRo 4qB+NQcvoFv5RUb3HvBGEf7KGjxQ8SapX5winPhB4d9PEpUZL1yQARdufj59z+kJ ryX4+EJ3LW1fNvJ4Hn1Kh2hjbAZxG436Jf7U0/WlF7Hv6hfLVMgnmYDLPEXxLFIc 3R9RRSBQHl6rV3MbQXiW3oSket2PMIU2WHl2oNJhyBuplJljxu74FL+0UlYxl/rr rFOelc5MxFpKt8oDJ1s1V84C3OzKafiKWjnLFiAVAyRhbPR6/LZI5VkQXEnQI5/7 cV466WifNHKAJ7Y/TLHZ22N/Z2hKbhH46MD0WhY5Uwto3nf6Ref4kc14zSiUl3FU +8/wN97JJsXcG56JbQmsyERxy23xlzHVTCAzv3VKoODcaRHtKrRkEBEwiw6wpxDY isWqG8gmiiXnZp+lahovfF+DxnhPHwIDAQABo1AwTjAdBgNVHQ4EFgQU3b/K2ubR NLo3dSHKb5oIKPI1tkgwHwYDVR0jBBgwFoAU3b/K2ubRNLo3dSHKb5oIKPI1tkgw DAYDVR0TBAUwAwEB/zANBgkqhkiG9w0BAQsFAAOCAYEAM2pU02vA1wFfnfQFwZNm kFDQtxjpsB5KoLbadpOvhNutFVQxFRPk3n5ODNUJHDQ1tuVM1m9lfTJf6/ypawf3 SYLlgX4HgJpj+CzDQLyP1CraPtHuCLdNp4TK9D+hmEW+sQVp59/Xmasb7oswzPf8 59QLF66Xv+R7/Q+ntIV541niFoe/HylFLCOTdr7Ahx3e7CtCauW7yPQKSggKjFzY fU3RuL/V9yntktGUBOg1Bld/LCOXh6U1jSbTGkfyFtfZxtQfI0PTJpk5yiD0cSNv DEp2dvd2H7P+v0ew/CpWgeHS3e4I2PT/WtwlYYqRArmGHPJQc3YlNfy2JSYVy+vE K2EMHOfuLxeb7PDUoYTn0q/e5BskFKcBh+OrKVhGoNnACuCN11nTG/hUID54paXI T4sDxJaf7PtHz3YtjWU0J7/6rgEFivOSCt2JbJehx+dgUees60t9LLhlyf5dakhV juTH+WpA4bhkRem1WSmlX899WH1keeWkCawedmU9lMRo -----END CERTIFICATE----- ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1720306884.0 aioquic-1.2.0/tests/ssl_combined.pem0000644000175100001770000001077700000000000020274 0ustar00runnerdocker00000000000000-----BEGIN CERTIFICATE----- MIIF8TCCBFmgAwIBAgIJAMstgJlaaVJcMA0GCSqGSIb3DQEBCwUAME0xCzAJBgNV 
BAYTAlhZMSYwJAYDVQQKDB1QeXRob24gU29mdHdhcmUgRm91bmRhdGlvbiBDQTEW MBQGA1UEAwwNb3VyLWNhLXNlcnZlcjAeFw0xODA4MjkxNDIzMTZaFw0yODA3MDcx NDIzMTZaMF8xCzAJBgNVBAYTAlhZMRcwFQYDVQQHDA5DYXN0bGUgQW50aHJheDEj MCEGA1UECgwaUHl0aG9uIFNvZnR3YXJlIEZvdW5kYXRpb24xEjAQBgNVBAMMCWxv Y2FsaG9zdDCCAaIwDQYJKoZIhvcNAQEBBQADggGPADCCAYoCggGBAJ8oLzdB739k YxZiFukBFGIpyjqYkj0I015p/sDz1MT7DljcZLBLy7OqnkLpB5tnM8256DwdihPA 3zlnfEzTfr9DD0qFBW2H5cMCoz7X17koeRhzGDd3dkjUeBjXvR5qRosG8wM3lQug U7AizY+3Azaj1yN3mZ9K5a20jr58Kqinz+Xxx6sb2JfYYff2neJbBahNm5id0AD2 pi/TthZqO5DURJYo+MdgZOcy+7jEjOJsLWZd3Yzq78iM07qDjbpIoVpENZCTHTWA hX8LIqz0OBmh4weQpm4+plU7E4r4D82uauocWw8iyuznCTtABWO7n9fWySmf9QZC WYxHAFpBQs6zUVqAD7nhFdTqpQ9bRiaEnjE4HiAccPW+MAoSxFnv/rNzEzI6b4zU NspFMfg1aNVamdjxdpUZ1GG1Okf0yPJykqEX4PZl3La1Be2q7YZ1wydR523Xd+f3 EO4/g+imETSKn8gyCf6Rvib175L4r2WV1CXQH7gFwZYCod6WHYq5TQIDAQABo4IB wDCCAbwwFAYDVR0RBA0wC4IJbG9jYWxob3N0MA4GA1UdDwEB/wQEAwIFoDAdBgNV HSUEFjAUBggrBgEFBQcDAQYIKwYBBQUHAwIwDAYDVR0TAQH/BAIwADAdBgNVHQ4E FgQUj+od4zNcABazi29rb9NMy7XLfFUwfQYDVR0jBHYwdIAU3b/K2ubRNLo3dSHK b5oIKPI1tkihUaRPME0xCzAJBgNVBAYTAlhZMSYwJAYDVQQKDB1QeXRob24gU29m dHdhcmUgRm91bmRhdGlvbiBDQTEWMBQGA1UEAwwNb3VyLWNhLXNlcnZlcoIJAMst gJlaaVJbMIGDBggrBgEFBQcBAQR3MHUwPAYIKwYBBQUHMAKGMGh0dHA6Ly90ZXN0 Y2EucHl0aG9udGVzdC5uZXQvdGVzdGNhL3B5Y2FjZXJ0LmNlcjA1BggrBgEFBQcw AYYpaHR0cDovL3Rlc3RjYS5weXRob250ZXN0Lm5ldC90ZXN0Y2Evb2NzcC8wQwYD VR0fBDwwOjA4oDagNIYyaHR0cDovL3Rlc3RjYS5weXRob250ZXN0Lm5ldC90ZXN0 Y2EvcmV2b2NhdGlvbi5jcmwwDQYJKoZIhvcNAQELBQADggGBACf1jFkQ9MbnKAC/ uo17EwPxHKZfswZVpCK527LVRr33DN1DbrR5ZWchDCpV7kCOhZ+fR7sKKk22ZHSY oH+u3PEu20J3GOB1iyY1aMNB7WvId3JvappdVWkC/VpUyFfLsGUDFuIPADmZZqCb iJMX4loteTVfl1d4xK/1mV6Gq9MRrRqiDfpSELn+v53OM9mGspwW+NZ1CIrbCuW0 KxZ/tPkqn8PSd9fNZR70bB7rWbnwrl+kH8xKxLl6qdlrMmg74WWwhLeQxK7+9DdP IaDenzqx5cwWBGY/C0HcQj0gPuy3lSs1V/q+f7Y6uspPWP51PgiJLIywXS75iRAr +UFGTzwAtyfTZSQoFyMmMULqfk6T5HtoVMqfRvPvK+mFDLWEstU1NIB1K/CRI7gI AY65ClTU+zRS/tlF8IA7tsFvgtEf8jsI9kamlidhS1gyeg4dWcVErV4aeTPB1AUv StPYQkKNM+NjytWHl5tNuBoDNLsc0gI/WSPiI4CIY8LwomOoiw== -----END CERTIFICATE----- -----BEGIN PRIVATE KEY----- MIIG/QIBADANBgkqhkiG9w0BAQEFAASCBucwggbjAgEAAoIBgQCfKC83Qe9/ZGMW YhbpARRiKco6mJI9CNNeaf7A89TE+w5Y3GSwS8uzqp5C6QebZzPNueg8HYoTwN85 Z3xM036/Qw9KhQVth+XDAqM+19e5KHkYcxg3d3ZI1HgY170eakaLBvMDN5ULoFOw Is2PtwM2o9cjd5mfSuWttI6+fCqop8/l8cerG9iX2GH39p3iWwWoTZuYndAA9qYv 07YWajuQ1ESWKPjHYGTnMvu4xIzibC1mXd2M6u/IjNO6g426SKFaRDWQkx01gIV/ CyKs9DgZoeMHkKZuPqZVOxOK+A/NrmrqHFsPIsrs5wk7QAVju5/X1skpn/UGQlmM RwBaQULOs1FagA+54RXU6qUPW0YmhJ4xOB4gHHD1vjAKEsRZ7/6zcxMyOm+M1DbK RTH4NWjVWpnY8XaVGdRhtTpH9MjycpKhF+D2Zdy2tQXtqu2GdcMnUedt13fn9xDu P4PophE0ip/IMgn+kb4m9e+S+K9lldQl0B+4BcGWAqHelh2KuU0CAwEAAQKCAYEA lKiWIYjmyRjdLKUGPTES9vWNvNmRjozV0RQ0LcoSbMMLDZkeO0UwyWqOVHUQ8+ib jIcfEjeNJxI57oZopeHOO5vJhpNlFH+g7ltiW2qERqA1K88lSXm99Bzw6FNqhCRE K8ub5N9fyfJA+P4o/xm0WK8EXk5yIUV17p/9zJJxzgKgv2jsVTi3QG2OZGvn4Oug ByomMZEGHkBDzdxz8c/cP1Tlk1RFuwSgews178k2xq7AYSM/s0YmHi7b/RSvptX6 1v8P8kXNUe4AwTaNyrlvF2lwIadZ8h1hA7tCE2n44b7a7KfhAkwcbr1T59ioYh6P zxsyPT678uD51dbtD/DXJCcoeeFOb8uzkR2KNcrnQzZpCJnRq4Gp5ybxwsxxuzpr gz0gbNlhuWtE7EoSzmIK9t+WTS7IM2CvZymd6/OAh1Fuw6AQhSp64XRp3OfMMAAC Ie2EPtKj4islWGT8VoUjuRYGmdRh4duAH1dkiAXOWA3R7y5a1/y/iE8KE8BtxocB AoHBAM8aiURgpu1Fs0Oqz6izec7KSLL3l8hmW+MKUOfk/Ybng6FrTFsL5YtzR+Ap wW4wwWnnIKEc1JLiZ7g8agRETK8hr5PwFXUn/GSWC0SMsazLJToySQS5LOV0tLzK kJ3jtNU7tnlDGNkCHTHSoVL2T/8t+IkZI/h5Z6wjlYPvU2Iu0nVIXtiG+alv4A6M Hrh9l5or4mjB6rGnVXeYohLkCm6s/W97ahVxLMcEdbsBo1prm2JqGnSoiR/tEFC/ QHQnbQKBwQDEu7kW0Yg9sZ89QtYtVQ1YpixFZORaUeRIRLnpEs1w7L1mCbOZ2Lj9 JHxsH05cYAc7HJfPwwxv3+3aGAIC/dfu4VSwEFtatAzUpzlhzKS5+HQCWB4JUNNU 
MQ3+FwK2xQX4Ph8t+OzrFiYcK2g0An5UxWMa2HWIAWUOhnTOydAVsoH6yP31cVm4 0hxoABCwflaNLNGjRUyfBpLTAcNu/YtcE+KREy7YAAgXXrhRSO4XpLsSXwLnLT7/ YOkoBWDcTWECgcBPWnSUDZCIQ3efithMZJBciqd2Y2X19Dpq8O31HImD4jtOY0V7 cUB/wSkeHAGwjd/eCyA2e0x8B2IEdqmMfvr+86JJxekC3dJYXCFvH5WIhsH53YCa 3bT1KlWCLP9ib/g+58VQC0R/Cc9T4sfLePNH7D5ZkZd1wlbV30CPr+i8KwKay6MD xhvtLx+jk07GE+E9wmjbCMo7TclyrLoVEOlqZMAqshgApT+p9eyCPetwXuDHwa3n WxhHclcZCV7R4rUCgcAkdGSnxcvpIrDPOUNWwxvmAWTStw9ZbTNP8OxCNCm9cyDl d4bAS1h8D/a+Uk7C70hnu7Sl2w7C7Eu2zhwRUdhhe3+l4GINPK/j99i6NqGPlGpq xMlMEJ4YS768BqeKFpg0l85PRoEgTsphDeoROSUPsEPdBZ9BxIBlYKTkbKESZDGR twzYHljx1n1NCDYPflmrb1KpXn4EOcObNghw2KqqNUUWfOeBPwBA1FxzM4BrAStp DBINpGS4Dc0mjViVegECgcA3hTtm82XdxQXj9LQmb/E3lKx/7H87XIOeNMmvjYuZ iS9wKrkF+u42vyoDxcKMCnxP5056wpdST4p56r+SBwVTHcc3lGBSGcMTIfwRXrj3 thOA2our2n4ouNIsYyTlcsQSzifwmpRmVMRPxl9fYVdEWUgB83FgHT0D9avvZnF9 t9OccnGJXShAIZIBADhVj/JwG4FbaX42NijD5PNpVLk1Y17OV0I576T9SfaQoBjJ aH1M/zC4aVaS0DYB/Gxq7v8= -----END PRIVATE KEY----- ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1720306884.0 aioquic-1.2.0/tests/ssl_key.pem0000644000175100001770000000466400000000000017302 0ustar00runnerdocker00000000000000-----BEGIN PRIVATE KEY----- MIIG/QIBADANBgkqhkiG9w0BAQEFAASCBucwggbjAgEAAoIBgQCfKC83Qe9/ZGMW YhbpARRiKco6mJI9CNNeaf7A89TE+w5Y3GSwS8uzqp5C6QebZzPNueg8HYoTwN85 Z3xM036/Qw9KhQVth+XDAqM+19e5KHkYcxg3d3ZI1HgY170eakaLBvMDN5ULoFOw Is2PtwM2o9cjd5mfSuWttI6+fCqop8/l8cerG9iX2GH39p3iWwWoTZuYndAA9qYv 07YWajuQ1ESWKPjHYGTnMvu4xIzibC1mXd2M6u/IjNO6g426SKFaRDWQkx01gIV/ CyKs9DgZoeMHkKZuPqZVOxOK+A/NrmrqHFsPIsrs5wk7QAVju5/X1skpn/UGQlmM RwBaQULOs1FagA+54RXU6qUPW0YmhJ4xOB4gHHD1vjAKEsRZ7/6zcxMyOm+M1DbK RTH4NWjVWpnY8XaVGdRhtTpH9MjycpKhF+D2Zdy2tQXtqu2GdcMnUedt13fn9xDu P4PophE0ip/IMgn+kb4m9e+S+K9lldQl0B+4BcGWAqHelh2KuU0CAwEAAQKCAYEA lKiWIYjmyRjdLKUGPTES9vWNvNmRjozV0RQ0LcoSbMMLDZkeO0UwyWqOVHUQ8+ib jIcfEjeNJxI57oZopeHOO5vJhpNlFH+g7ltiW2qERqA1K88lSXm99Bzw6FNqhCRE K8ub5N9fyfJA+P4o/xm0WK8EXk5yIUV17p/9zJJxzgKgv2jsVTi3QG2OZGvn4Oug ByomMZEGHkBDzdxz8c/cP1Tlk1RFuwSgews178k2xq7AYSM/s0YmHi7b/RSvptX6 1v8P8kXNUe4AwTaNyrlvF2lwIadZ8h1hA7tCE2n44b7a7KfhAkwcbr1T59ioYh6P zxsyPT678uD51dbtD/DXJCcoeeFOb8uzkR2KNcrnQzZpCJnRq4Gp5ybxwsxxuzpr gz0gbNlhuWtE7EoSzmIK9t+WTS7IM2CvZymd6/OAh1Fuw6AQhSp64XRp3OfMMAAC Ie2EPtKj4islWGT8VoUjuRYGmdRh4duAH1dkiAXOWA3R7y5a1/y/iE8KE8BtxocB AoHBAM8aiURgpu1Fs0Oqz6izec7KSLL3l8hmW+MKUOfk/Ybng6FrTFsL5YtzR+Ap wW4wwWnnIKEc1JLiZ7g8agRETK8hr5PwFXUn/GSWC0SMsazLJToySQS5LOV0tLzK kJ3jtNU7tnlDGNkCHTHSoVL2T/8t+IkZI/h5Z6wjlYPvU2Iu0nVIXtiG+alv4A6M Hrh9l5or4mjB6rGnVXeYohLkCm6s/W97ahVxLMcEdbsBo1prm2JqGnSoiR/tEFC/ QHQnbQKBwQDEu7kW0Yg9sZ89QtYtVQ1YpixFZORaUeRIRLnpEs1w7L1mCbOZ2Lj9 JHxsH05cYAc7HJfPwwxv3+3aGAIC/dfu4VSwEFtatAzUpzlhzKS5+HQCWB4JUNNU MQ3+FwK2xQX4Ph8t+OzrFiYcK2g0An5UxWMa2HWIAWUOhnTOydAVsoH6yP31cVm4 0hxoABCwflaNLNGjRUyfBpLTAcNu/YtcE+KREy7YAAgXXrhRSO4XpLsSXwLnLT7/ YOkoBWDcTWECgcBPWnSUDZCIQ3efithMZJBciqd2Y2X19Dpq8O31HImD4jtOY0V7 cUB/wSkeHAGwjd/eCyA2e0x8B2IEdqmMfvr+86JJxekC3dJYXCFvH5WIhsH53YCa 3bT1KlWCLP9ib/g+58VQC0R/Cc9T4sfLePNH7D5ZkZd1wlbV30CPr+i8KwKay6MD xhvtLx+jk07GE+E9wmjbCMo7TclyrLoVEOlqZMAqshgApT+p9eyCPetwXuDHwa3n WxhHclcZCV7R4rUCgcAkdGSnxcvpIrDPOUNWwxvmAWTStw9ZbTNP8OxCNCm9cyDl d4bAS1h8D/a+Uk7C70hnu7Sl2w7C7Eu2zhwRUdhhe3+l4GINPK/j99i6NqGPlGpq xMlMEJ4YS768BqeKFpg0l85PRoEgTsphDeoROSUPsEPdBZ9BxIBlYKTkbKESZDGR twzYHljx1n1NCDYPflmrb1KpXn4EOcObNghw2KqqNUUWfOeBPwBA1FxzM4BrAStp DBINpGS4Dc0mjViVegECgcA3hTtm82XdxQXj9LQmb/E3lKx/7H87XIOeNMmvjYuZ iS9wKrkF+u42vyoDxcKMCnxP5056wpdST4p56r+SBwVTHcc3lGBSGcMTIfwRXrj3 thOA2our2n4ouNIsYyTlcsQSzifwmpRmVMRPxl9fYVdEWUgB83FgHT0D9avvZnF9 
t9OccnGJXShAIZIBADhVj/JwG4FbaX42NijD5PNpVLk1Y17OV0I576T9SfaQoBjJ aH1M/zC4aVaS0DYB/Gxq7v8= -----END PRIVATE KEY----- ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1720306884.0 aioquic-1.2.0/tests/test_asyncio.py0000644000175100001770000004043700000000000020202 0ustar00runnerdocker00000000000000import asyncio import binascii import contextlib import random import socket from unittest import TestCase, skipIf from unittest.mock import patch from aioquic.asyncio.client import connect from aioquic.asyncio.protocol import QuicConnectionProtocol from aioquic.asyncio.server import serve from aioquic.quic.configuration import QuicConfiguration from aioquic.quic.logger import QuicLogger from cryptography.hazmat.primitives import serialization from .utils import ( SERVER_CACERTFILE, SERVER_CERTFILE, SERVER_COMBINEDFILE, SERVER_KEYFILE, SKIP_TESTS, asynctest, generate_ec_certificate, generate_ed448_certificate, generate_ed25519_certificate, generate_rsa_certificate, ) real_sendto = socket.socket.sendto def sendto_with_loss(self, data, addr=None): """ Simulate 25% packet loss. """ if random.random() > 0.25: real_sendto(self, data, addr) class SessionTicketStore: def __init__(self): self.tickets = {} def add(self, ticket): self.tickets[ticket.ticket] = ticket def pop(self, label): return self.tickets.pop(label, None) def handle_stream(reader, writer): async def serve(): data = await reader.read() writer.write(bytes(reversed(data))) writer.write_eof() asyncio.ensure_future(serve()) class HighLevelTest(TestCase): def setUp(self): self.bogus_port = 1024 self.server_host = "localhost" async def run_client( self, *, port: int, host=None, cadata=None, cafile=SERVER_CACERTFILE, configuration=None, request=b"ping", **kwargs, ): if host is None: host = self.server_host if configuration is None: configuration = QuicConfiguration(is_client=True) configuration.load_verify_locations(cadata=cadata, cafile=cafile) async with connect(host, port, configuration=configuration, **kwargs) as client: # waiting for connected when connected returns immediately await client.wait_connected() reader, writer = await client.create_stream() self.assertEqual(writer.can_write_eof(), True) self.assertEqual(writer.get_extra_info("stream_id"), 0) writer.write(request) writer.write_eof() response = await reader.read() # explicit no-op close to test that multiple closes are harmless. 
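# (the stream was already ended by the write_eof() call above)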
            writer.close()

        # waiting for "closed" when already closed returns immediately
        await client.wait_closed()

        return response

    @contextlib.asynccontextmanager
    async def run_server(self, configuration=None, host="::", **kwargs):
        if configuration is None:
            configuration = QuicConfiguration(is_client=False)
            configuration.load_cert_chain(SERVER_CERTFILE, SERVER_KEYFILE)
        server = await serve(
            host=host,
            port=0,
            configuration=configuration,
            stream_handler=handle_stream,
            **kwargs,
        )
        try:
            yield server._transport.get_extra_info("sockname")[1]
        finally:
            server.close()

    @asynctest
    async def test_connect_and_serve(self):
        async with self.run_server() as server_port:
            response = await self.run_client(port=server_port)
            self.assertEqual(response, b"gnip")

    @asynctest
    async def test_connect_and_serve_ipv4(self):
        certificate, private_key = generate_rsa_certificate(
            alternative_names=["localhost", "127.0.0.1"], common_name="localhost"
        )
        async with self.run_server(
            configuration=QuicConfiguration(
                certificate=certificate,
                private_key=private_key,
                is_client=False,
            ),
            host="0.0.0.0",
        ) as server_port:
            response = await self.run_client(
                cadata=certificate.public_bytes(serialization.Encoding.PEM),
                cafile=None,
                host="127.0.0.1",
                port=server_port,
            )
            self.assertEqual(response, b"gnip")

    @skipIf("ipv6" in SKIP_TESTS, "Skipping IPv6 tests")
    @asynctest
    async def test_connect_and_serve_ipv6(self):
        certificate, private_key = generate_rsa_certificate(
            alternative_names=["localhost", "::1"], common_name="localhost"
        )
        async with self.run_server(
            configuration=QuicConfiguration(
                certificate=certificate,
                private_key=private_key,
                is_client=False,
            ),
            host="::",
        ) as server_port:
            response = await self.run_client(
                cadata=certificate.public_bytes(serialization.Encoding.PEM),
                cafile=None,
                host="::1",
                port=server_port,
            )
            self.assertEqual(response, b"gnip")

    async def _test_connect_and_serve_with_certificate(self, certificate, private_key):
        async with self.run_server(
            configuration=QuicConfiguration(
                certificate=certificate,
                private_key=private_key,
                is_client=False,
            )
        ) as server_port:
            response = await self.run_client(
                cadata=certificate.public_bytes(serialization.Encoding.PEM),
                cafile=None,
                port=server_port,
            )
            self.assertEqual(response, b"gnip")

    @asynctest
    async def test_connect_and_serve_with_ec_certificate(self):
        await self._test_connect_and_serve_with_certificate(
            *generate_ec_certificate(
                alternative_names=["localhost"], common_name="localhost"
            )
        )

    @asynctest
    async def test_connect_and_serve_with_ed25519_certificate(self):
        await self._test_connect_and_serve_with_certificate(
            *generate_ed25519_certificate(
                alternative_names=["localhost"], common_name="localhost"
            )
        )

    @asynctest
    async def test_connect_and_serve_with_ed448_certificate(self):
        await self._test_connect_and_serve_with_certificate(
            *generate_ed448_certificate(
                alternative_names=["localhost"], common_name="localhost"
            )
        )

    @asynctest
    async def test_connect_and_serve_with_rsa_certificate(self):
        await self._test_connect_and_serve_with_certificate(
            *generate_rsa_certificate(
                alternative_names=["localhost"], common_name="localhost"
            )
        )

    @asynctest
    async def test_connect_and_serve_large(self):
        """
        Transfer enough data to require raising MAX_DATA and MAX_STREAM_DATA.
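
        The 2 MiB payload deliberately exceeds the initial flow-control
        credit (assumed here to be 1 MiB for both the connection and the
        stream), so the transfer only completes if MAX_DATA and
        MAX_STREAM_DATA are raised along the way.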
""" data = b"Z" * 2097152 async with self.run_server() as server_port: response = await self.run_client(port=server_port, request=data) self.assertEqual(response, data) @asynctest async def test_connect_and_serve_without_client_configuration(self): async with self.run_server() as server_port: with self.assertRaises(ConnectionError): async with connect(self.server_host, server_port) as client: await client.ping() @asynctest async def test_connect_and_serve_writelines(self): async with self.run_server() as server_port: configuration = QuicConfiguration(is_client=True) configuration.load_verify_locations(cafile=SERVER_CACERTFILE) async with connect( self.server_host, server_port, configuration=configuration ) as client: reader, writer = await client.create_stream() assert writer.can_write_eof() is True writer.writelines([b"01234567", b"89012345"]) writer.write_eof() response = await reader.read() self.assertEqual(response, b"5432109876543210") @skipIf("loss" in SKIP_TESTS, "Skipping loss tests") @patch("socket.socket.sendto", new_callable=lambda: sendto_with_loss) @asynctest async def test_connect_and_serve_with_packet_loss(self, mock_sendto): """ This test ensures handshake success and stream data is successfully sent and received in the presence of packet loss (randomized 25% in each direction). """ data = b"Z" * 65536 server_configuration = QuicConfiguration( is_client=False, quic_logger=QuicLogger() ) server_configuration.load_cert_chain(SERVER_CERTFILE, SERVER_KEYFILE) async with self.run_server(configuration=server_configuration) as server_port: response = await self.run_client( configuration=QuicConfiguration( is_client=True, quic_logger=QuicLogger() ), port=server_port, request=data, ) self.assertEqual(response, data) @asynctest async def test_connect_and_serve_with_session_ticket(self): client_ticket = None store = SessionTicketStore() def save_ticket(t): nonlocal client_ticket client_ticket = t async with self.run_server( session_ticket_fetcher=store.pop, session_ticket_handler=store.add ) as server_port: # first request response = await self.run_client( port=server_port, session_ticket_handler=save_ticket ) self.assertEqual(response, b"gnip") self.assertIsNotNone(client_ticket) # second request response = await self.run_client( configuration=QuicConfiguration( is_client=True, session_ticket=client_ticket ), port=server_port, ) self.assertEqual(response, b"gnip") @asynctest async def test_connect_and_serve_with_retry(self): async with self.run_server(retry=True) as server_port: response = await self.run_client(port=server_port) self.assertEqual(response, b"gnip") @asynctest async def test_connect_and_serve_with_retry_bad_original_destination_connection_id( self, ): """ If the server's transport parameters do not have the correct original_destination_connection_id the connection must fail. """ def create_protocol(*args, **kwargs): protocol = QuicConnectionProtocol(*args, **kwargs) protocol._quic._original_destination_connection_id = None return protocol async with self.run_server( create_protocol=create_protocol, retry=True ) as server_port: with self.assertRaises(ConnectionError): await self.run_client(port=server_port) @asynctest async def test_connect_and_serve_with_retry_bad_retry_source_connection_id(self): """ If the server's transport parameters do not have the correct retry_source_connection_id the connection must fail. 
""" def create_protocol(*args, **kwargs): protocol = QuicConnectionProtocol(*args, **kwargs) protocol._quic._retry_source_connection_id = None return protocol async with self.run_server( create_protocol=create_protocol, retry=True ) as server_port: with self.assertRaises(ConnectionError): await self.run_client(port=server_port) @patch("aioquic.quic.retry.QuicRetryTokenHandler.validate_token") @asynctest async def test_connect_and_serve_with_retry_bad_token(self, mock_validate): mock_validate.side_effect = ValueError("Decryption failed.") async with self.run_server(retry=True) as server_port: with self.assertRaises(ConnectionError): await self.run_client( configuration=QuicConfiguration(is_client=True, idle_timeout=4.0), port=server_port, ) @asynctest async def test_connect_and_serve_with_version_negotiation(self): async with self.run_server() as server_port: # force version negotiation configuration = QuicConfiguration(is_client=True, quic_logger=QuicLogger()) configuration.supported_versions.insert(0, 0x1A2A3A4A) response = await self.run_client( configuration=configuration, port=server_port ) self.assertEqual(response, b"gnip") @asynctest async def test_connect_timeout(self): with self.assertRaises(ConnectionError): await self.run_client( port=self.bogus_port, configuration=QuicConfiguration(is_client=True, idle_timeout=5), ) @asynctest async def test_connect_timeout_no_wait_connected(self): with self.assertRaises(ConnectionError): configuration = QuicConfiguration(is_client=True, idle_timeout=5) configuration.load_verify_locations(cafile=SERVER_CACERTFILE) async with connect( self.server_host, self.bogus_port, configuration=configuration, wait_connected=False, ) as client: await client.ping() @asynctest async def test_connect_local_port(self): async with self.run_server() as server_port: response = await self.run_client(local_port=3456, port=server_port) self.assertEqual(response, b"gnip") @asynctest async def test_connect_local_port_bind(self): with self.assertRaises(OverflowError): await self.run_client(local_port=-1, port=self.bogus_port) @asynctest async def test_change_connection_id(self): async with self.run_server() as server_port: configuration = QuicConfiguration(is_client=True) configuration.load_verify_locations(cafile=SERVER_CACERTFILE) async with connect( self.server_host, server_port, configuration=configuration ) as client: await client.ping() client.change_connection_id() await client.ping() @asynctest async def test_key_update(self): async with self.run_server() as server_port: configuration = QuicConfiguration(is_client=True) configuration.load_verify_locations(cafile=SERVER_CACERTFILE) async with connect( self.server_host, server_port, configuration=configuration ) as client: await client.ping() client.request_key_update() await client.ping() @asynctest async def test_ping(self): async with self.run_server() as server_port: configuration = QuicConfiguration(is_client=True) configuration.load_verify_locations(cafile=SERVER_CACERTFILE) async with connect( self.server_host, server_port, configuration=configuration ) as client: await client.ping() await client.ping() @asynctest async def test_ping_parallel(self): async with self.run_server() as server_port: configuration = QuicConfiguration(is_client=True) configuration.load_verify_locations(cafile=SERVER_CACERTFILE) async with connect( self.server_host, server_port, configuration=configuration ) as client: coros = [client.ping() for x in range(16)] await asyncio.gather(*coros) @asynctest async def 
test_server_receives_garbage(self): configuration = QuicConfiguration(is_client=False) configuration.load_cert_chain(SERVER_CERTFILE, SERVER_KEYFILE) server = await serve( host=self.server_host, port=0, configuration=configuration, ) server.datagram_received(binascii.unhexlify("c00000000080"), ("1.2.3.4", 1234)) server.close() @asynctest async def test_combined_key(self): config1 = QuicConfiguration() config2 = QuicConfiguration() config1.load_cert_chain(SERVER_CERTFILE, SERVER_KEYFILE) config2.load_cert_chain(SERVER_COMBINEDFILE) self.assertEqual(config1.certificate, config2.certificate) ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1720306884.0 aioquic-1.2.0/tests/test_buffer.py0000644000175100001770000001512600000000000020003 0ustar00runnerdocker00000000000000from unittest import TestCase from aioquic.buffer import Buffer, BufferReadError, BufferWriteError, size_uint_var class BufferTest(TestCase): def test_data_slice(self): buf = Buffer(data=b"\x08\x07\x06\x05\x04\x03\x02\x01") self.assertEqual(buf.data_slice(0, 8), b"\x08\x07\x06\x05\x04\x03\x02\x01") self.assertEqual(buf.data_slice(1, 3), b"\x07\x06") with self.assertRaises(BufferReadError): buf.data_slice(-1, 3) with self.assertRaises(BufferReadError): buf.data_slice(0, 9) with self.assertRaises(BufferReadError): buf.data_slice(1, 0) def test_pull_bytes(self): buf = Buffer(data=b"\x08\x07\x06\x05\x04\x03\x02\x01") self.assertEqual(buf.pull_bytes(3), b"\x08\x07\x06") def test_pull_bytes_negative(self): buf = Buffer(data=b"\x08\x07\x06\x05\x04\x03\x02\x01") with self.assertRaises(BufferReadError): buf.pull_bytes(-1) def test_pull_bytes_truncated(self): buf = Buffer(capacity=0) with self.assertRaises(BufferReadError): buf.pull_bytes(2) self.assertEqual(buf.tell(), 0) def test_pull_bytes_zero(self): buf = Buffer(data=b"\x08\x07\x06\x05\x04\x03\x02\x01") self.assertEqual(buf.pull_bytes(0), b"") def test_pull_uint8(self): buf = Buffer(data=b"\x08\x07\x06\x05\x04\x03\x02\x01") self.assertEqual(buf.pull_uint8(), 0x08) self.assertEqual(buf.tell(), 1) def test_pull_uint8_truncated(self): buf = Buffer(capacity=0) with self.assertRaises(BufferReadError): buf.pull_uint8() self.assertEqual(buf.tell(), 0) def test_pull_uint16(self): buf = Buffer(data=b"\x08\x07\x06\x05\x04\x03\x02\x01") self.assertEqual(buf.pull_uint16(), 0x0807) self.assertEqual(buf.tell(), 2) def test_pull_uint16_truncated(self): buf = Buffer(capacity=1) with self.assertRaises(BufferReadError): buf.pull_uint16() self.assertEqual(buf.tell(), 0) def test_pull_uint32(self): buf = Buffer(data=b"\x08\x07\x06\x05\x04\x03\x02\x01") self.assertEqual(buf.pull_uint32(), 0x08070605) self.assertEqual(buf.tell(), 4) def test_pull_uint32_truncated(self): buf = Buffer(capacity=3) with self.assertRaises(BufferReadError): buf.pull_uint32() self.assertEqual(buf.tell(), 0) def test_pull_uint64(self): buf = Buffer(data=b"\x08\x07\x06\x05\x04\x03\x02\x01") self.assertEqual(buf.pull_uint64(), 0x0807060504030201) self.assertEqual(buf.tell(), 8) def test_pull_uint64_truncated(self): buf = Buffer(capacity=7) with self.assertRaises(BufferReadError): buf.pull_uint64() self.assertEqual(buf.tell(), 0) def test_push_bytes(self): buf = Buffer(capacity=3) buf.push_bytes(b"\x08\x07\x06") self.assertEqual(buf.data, b"\x08\x07\x06") self.assertEqual(buf.tell(), 3) def test_push_bytes_truncated(self): buf = Buffer(capacity=3) with self.assertRaises(BufferWriteError): buf.push_bytes(b"\x08\x07\x06\x05") self.assertEqual(buf.tell(), 0) def test_push_bytes_zero(self): 
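        # Pushing an empty byte string is a no-op: no data is written and,
        # as asserted below, tell() stays at 0.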
        buf = Buffer(capacity=3)
        buf.push_bytes(b"")
        self.assertEqual(buf.data, b"")
        self.assertEqual(buf.tell(), 0)

    def test_push_uint8(self):
        buf = Buffer(capacity=1)
        buf.push_uint8(0x08)
        self.assertEqual(buf.data, b"\x08")
        self.assertEqual(buf.tell(), 1)

    def test_push_uint16(self):
        buf = Buffer(capacity=2)
        buf.push_uint16(0x0807)
        self.assertEqual(buf.data, b"\x08\x07")
        self.assertEqual(buf.tell(), 2)

    def test_push_uint32(self):
        buf = Buffer(capacity=4)
        buf.push_uint32(0x08070605)
        self.assertEqual(buf.data, b"\x08\x07\x06\x05")
        self.assertEqual(buf.tell(), 4)

    def test_push_uint64(self):
        buf = Buffer(capacity=8)
        buf.push_uint64(0x0807060504030201)
        self.assertEqual(buf.data, b"\x08\x07\x06\x05\x04\x03\x02\x01")
        self.assertEqual(buf.tell(), 8)

    def test_seek(self):
        buf = Buffer(data=b"01234567")
        self.assertFalse(buf.eof())
        self.assertEqual(buf.tell(), 0)

        buf.seek(4)
        self.assertFalse(buf.eof())
        self.assertEqual(buf.tell(), 4)

        buf.seek(8)
        self.assertTrue(buf.eof())
        self.assertEqual(buf.tell(), 8)

        with self.assertRaises(BufferReadError):
            buf.seek(-1)
        self.assertEqual(buf.tell(), 8)
        with self.assertRaises(BufferReadError):
            buf.seek(9)
        self.assertEqual(buf.tell(), 8)


class UintVarTest(TestCase):
    def roundtrip(self, data, value):
        buf = Buffer(data=data)
        self.assertEqual(buf.pull_uint_var(), value)
        self.assertEqual(buf.tell(), len(data))

        buf = Buffer(capacity=8)
        buf.push_uint_var(value)
        self.assertEqual(buf.data, data)

    def test_uint_var(self):
        # 1 byte
        self.roundtrip(b"\x00", 0)
        self.roundtrip(b"\x01", 1)
        self.roundtrip(b"\x25", 37)
        self.roundtrip(b"\x3f", 63)

        # 2 bytes
        self.roundtrip(b"\x7b\xbd", 15293)
        self.roundtrip(b"\x7f\xff", 16383)

        # 4 bytes
        self.roundtrip(b"\x9d\x7f\x3e\x7d", 494878333)
        self.roundtrip(b"\xbf\xff\xff\xff", 1073741823)

        # 8 bytes
        self.roundtrip(b"\xc2\x19\x7c\x5e\xff\x14\xe8\x8c", 151288809941952652)
        self.roundtrip(b"\xff\xff\xff\xff\xff\xff\xff\xff", 4611686018427387903)

    def test_pull_uint_var_truncated(self):
        buf = Buffer(capacity=0)
        with self.assertRaises(BufferReadError):
            buf.pull_uint_var()

        buf = Buffer(data=b"\xff")
        with self.assertRaises(BufferReadError):
            buf.pull_uint_var()

    def test_push_uint_var_too_big(self):
        buf = Buffer(capacity=8)
        with self.assertRaises(ValueError) as cm:
            buf.push_uint_var(4611686018427387904)
        self.assertEqual(
            str(cm.exception), "Integer is too big for a variable-length integer"
        )

    def test_size_uint_var(self):
        self.assertEqual(size_uint_var(63), 1)
        self.assertEqual(size_uint_var(16383), 2)
        self.assertEqual(size_uint_var(1073741823), 4)
        self.assertEqual(size_uint_var(4611686018427387903), 8)

        with self.assertRaises(ValueError) as cm:
            size_uint_var(4611686018427387904)
        self.assertEqual(
            str(cm.exception), "Integer is too big for a variable-length integer"
        )

==> aioquic-1.2.0/tests/test_connection.py <==
import binascii
import contextlib
import io
import time
from typing import List, Tuple
from unittest import TestCase, skipIf

from aioquic import tls
from aioquic.buffer import UINT_VAR_MAX, Buffer, encode_uint_var
from aioquic.quic import events
from aioquic.quic.configuration import SMALLEST_MAX_DATAGRAM_SIZE, QuicConfiguration
from aioquic.quic.connection import (
    MAX_LOCAL_CHALLENGES,
    MAX_PENDING_CRYPTO,
    STREAM_COUNT_MAX,
    NetworkAddress,
    QuicConnection,
    QuicConnectionError,
    QuicNetworkPath,
    QuicReceiveContext,
)
from aioquic.quic.crypto import CryptoPair
from aioquic.quic.logger import QuicLogger
from aioquic.quic.packet import (
    QuicErrorCode,
    QuicFrameType,
    QuicPacketType,
    QuicProtocolVersion,
    QuicTransportParameters,
    QuicVersionInformation,
    encode_quic_retry,
    encode_quic_version_negotiation,
    push_quic_transport_parameters,
)
from aioquic.quic.packet_builder import QuicDeliveryState, QuicPacketBuilder
from aioquic.quic.recovery import QuicPacketPacer

from .utils import (
    SERVER_CACERTFILE,
    SERVER_CERTFILE,
    SERVER_CERTFILE_WITH_CHAIN,
    SERVER_KEYFILE,
    SKIP_TESTS,
)

CLIENT_ADDR = ("1.2.3.4", 1234)
CLIENT_HANDSHAKE_DATAGRAM_SIZES = [1200]

SERVER_ADDR = ("2.3.4.5", 4433)
SERVER_INITIAL_DATAGRAM_SIZES = [1200, 1162]

HANDSHAKE_COMPLETED_EVENTS = [
    events.HandshakeCompleted,
    events.ConnectionIdIssued,
    events.ConnectionIdIssued,
    events.ConnectionIdIssued,
    events.ConnectionIdIssued,
    events.ConnectionIdIssued,
    events.ConnectionIdIssued,
    events.ConnectionIdIssued,
]

TICK = 0.05  # seconds


class SessionTicketStore:
    def __init__(self):
        self.tickets = {}

    def add(self, ticket):
        self.tickets[ticket.ticket] = ticket

    def pop(self, label):
        return self.tickets.pop(label, None)


def client_receive_context(client, epoch=tls.Epoch.ONE_RTT):
    return QuicReceiveContext(
        epoch=epoch,
        host_cid=client.host_cid,
        network_path=client._network_paths[0],
        quic_logger_frames=[],
        time=time.time(),
        version=None,
    )


def consume_events(connection):
    while True:
        event = connection.next_event()
        if event is None:
            break


def create_standalone_client(self, **client_options):
    client = QuicConnection(
        configuration=QuicConfiguration(
            is_client=True, quic_logger=QuicLogger(), **client_options
        )
    )
    client._ack_delay = 0

    # kick-off handshake
    client.connect(SERVER_ADDR, now=time.time())
    self.assertEqual(drop(client), 1)

    return client


def create_standalone_server(self, original_destination_connection_id=bytes(8)):
    server_configuration = QuicConfiguration(is_client=False, quic_logger=QuicLogger())
    server_configuration.load_cert_chain(SERVER_CERTFILE, SERVER_KEYFILE)

    server = QuicConnection(
        configuration=server_configuration,
        original_destination_connection_id=original_destination_connection_id,
    )
    server._ack_delay = 0

    return server


def datagram_sizes(items: List[Tuple[bytes, NetworkAddress]]) -> List[int]:
    return [len(x[0]) for x in items]


def new_connection_id(
    *,
    sequence_number: int,
    retire_prior_to: int = 0,
    connection_id: bytes = bytes(8),
    capacity: int = 100,
):
    buf = Buffer(capacity=capacity)
    buf.push_uint_var(sequence_number)
    buf.push_uint_var(retire_prior_to)
    buf.push_uint_var(len(connection_id))
    buf.push_bytes(connection_id)
    buf.push_bytes(bytes(16))  # stateless reset token
    buf.seek(0)
    return buf


@contextlib.contextmanager
def client_and_server(
    client_kwargs={},
    client_options={},
    client_patch=lambda x: None,
    handshake=True,
    server_kwargs={},
    server_certfile=SERVER_CERTFILE,
    server_keyfile=SERVER_KEYFILE,
    server_options={},
    server_patch=lambda x: None,
):
    client_configuration = QuicConfiguration(
        is_client=True, quic_logger=QuicLogger(), **client_options
    )
    client_configuration.load_verify_locations(cafile=SERVER_CACERTFILE)

    client = QuicConnection(configuration=client_configuration, **client_kwargs)
    client._ack_delay = 0
    disable_packet_pacing(client)
    client_patch(client)

    server_configuration = QuicConfiguration(
        is_client=False, quic_logger=QuicLogger(), **server_options
    )
    server_configuration.load_cert_chain(server_certfile, server_keyfile)

    server = QuicConnection(
        configuration=server_configuration,
        original_destination_connection_id=client.original_destination_connection_id,
        **server_kwargs,
    )
    server._ack_delay = 0
    disable_packet_pacing(server)
    server_patch(server)

    # perform handshake
    if handshake:
        client.connect(SERVER_ADDR, now=time.time())
        for i in range(3):
            roundtrip(client, server)

    yield client, server

    # close
    client.close()
    server.close()


def disable_packet_pacing(connection):
    class DummyPacketPacer(QuicPacketPacer):
        def __init__(self):
            super().__init__(max_datagram_size=SMALLEST_MAX_DATAGRAM_SIZE)

        def next_send_time(self, now):
            return None

    connection._loss._pacer = DummyPacketPacer()


def encode_transport_parameters(parameters: QuicTransportParameters) -> bytes:
    buf = Buffer(capacity=512)
    push_quic_transport_parameters(buf, parameters)
    return buf.data


def sequence_numbers(connection_ids):
    return list(map(lambda x: x.sequence_number, connection_ids))


def drop(sender):
    """
    Drop datagrams from `sender`.
    """
    return len(sender.datagrams_to_send(now=time.time()))


def roundtrip(sender, receiver):
    """
    Send datagrams from `sender` to `receiver` and back.
    """
    return (transfer(sender, receiver), transfer(receiver, sender))


def roundtrip_until_done(sender, receiver):
    """
    Send datagrams from `sender` to `receiver` and back repeatedly
    until no more datagrams are exchanged.
    """
    rounds = 0
    while roundtrip(sender, receiver) != (0, 0):
        rounds += 1
        assert rounds < 10, "Too many roundtrips!"


def transfer(sender, receiver):
    """
    Send datagrams from `sender` to `receiver`.
    """
    datagrams = 0
    from_addr = CLIENT_ADDR if sender._is_client else SERVER_ADDR
    for data, addr in sender.datagrams_to_send(now=time.time()):
        datagrams += 1
        receiver.receive_datagram(data, from_addr, now=time.time())
    return datagrams


class QuicConnectionTest(TestCase):
    def assertEvents(self, connection: QuicConnection, expected: list):
        types = []
        while True:
            event = connection.next_event()
            if event is not None:
                types.append(type(event))
            else:
                break
        self.assertEqual(types, expected)

    def assertPacketDropped(self, connection: QuicConnection, trigger: str):
        log = connection.configuration.quic_logger.to_dict()
        found_trigger = None
        for event in log["traces"][0]["events"]:
            if event["name"] == "transport:packet_dropped":
                found_trigger = event["data"]["trigger"]
                break
        self.assertEqual(found_trigger, trigger)

    def assertSentPackets(self, connection: QuicConnection, expected: List[int]):
        counts = [len(space.sent_packets) for space in connection._loss.spaces]
        self.assertEqual(counts, expected)

    def check_handshake(self, client, server, alpn_protocol=None):
        """
        Check handshake completed.
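
        The expected event sequence on each side is ProtocolNegotiated,
        HandshakeCompleted, then seven ConnectionIdIssued events (one per
        additional connection ID advertised to the peer), as asserted below.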
""" event = client.next_event() self.assertEqual(type(event), events.ProtocolNegotiated) self.assertEqual(event.alpn_protocol, alpn_protocol) event = client.next_event() self.assertEqual(type(event), events.HandshakeCompleted) self.assertEqual(event.alpn_protocol, alpn_protocol) self.assertEqual(event.early_data_accepted, False) self.assertEqual(event.session_resumed, False) for i in range(7): self.assertEqual(type(client.next_event()), events.ConnectionIdIssued) self.assertIsNone(client.next_event()) event = server.next_event() self.assertEqual(type(event), events.ProtocolNegotiated) self.assertEqual(event.alpn_protocol, alpn_protocol) event = server.next_event() self.assertEqual(type(event), events.HandshakeCompleted) self.assertEqual(event.alpn_protocol, alpn_protocol) for i in range(7): self.assertEqual(type(server.next_event()), events.ConnectionIdIssued) self.assertIsNone(server.next_event()) def test_connect(self): with client_and_server() as (client, server): # check handshake completed self.check_handshake(client=client, server=server) # check each endpoint has available connection IDs for the peer self.assertEqual( sequence_numbers(client._peer_cid_available), [1, 2, 3, 4, 5, 6, 7] ) self.assertEqual( sequence_numbers(server._peer_cid_available), [1, 2, 3, 4, 5, 6, 7] ) # client closes the connection client.close() self.assertEqual(transfer(client, server), 1) # check connection closes on the client side client.handle_timer(client.get_timer()) event = client.next_event() self.assertEqual(type(event), events.ConnectionTerminated) self.assertEqual(event.error_code, QuicErrorCode.NO_ERROR) self.assertEqual(event.frame_type, None) self.assertEqual(event.reason_phrase, "") self.assertIsNone(client.next_event()) # check connection closes on the server side server.handle_timer(server.get_timer()) event = server.next_event() self.assertEqual(type(event), events.ConnectionTerminated) self.assertEqual(event.error_code, QuicErrorCode.NO_ERROR) self.assertEqual(event.frame_type, None) self.assertEqual(event.reason_phrase, "") self.assertIsNone(server.next_event()) # check client log client_log = client.configuration.quic_logger.to_dict() self.assertGreater(len(client_log["traces"][0]["events"]), 20) # check server log server_log = server.configuration.quic_logger.to_dict() self.assertGreater(len(server_log["traces"][0]["events"]), 20) def test_connect_with_alpn(self): with client_and_server( client_options={"alpn_protocols": ["h3", "hq-interop"]}, server_options={"alpn_protocols": ["hq-interop"]}, ) as (client, server): # check handshake completed self.check_handshake( client=client, server=server, alpn_protocol="hq-interop" ) def test_connect_with_secrets_log(self): client_log_file = io.StringIO() server_log_file = io.StringIO() with client_and_server( client_options={"secrets_log_file": client_log_file}, server_options={"secrets_log_file": server_log_file}, ) as (client, server): # check handshake completed self.check_handshake(client=client, server=server) # check secrets were logged client_log = client_log_file.getvalue() server_log = server_log_file.getvalue() self.assertEqual(client_log, server_log) labels = [] for line in client_log.splitlines(): labels.append(line.split()[0]) self.assertEqual( labels, [ "SERVER_HANDSHAKE_TRAFFIC_SECRET", "CLIENT_HANDSHAKE_TRAFFIC_SECRET", "SERVER_TRAFFIC_SECRET_0", "CLIENT_TRAFFIC_SECRET_0", ], ) def test_connect_with_cert_chain(self): with client_and_server(server_certfile=SERVER_CERTFILE_WITH_CHAIN) as ( client, server, ): # check handshake 
completed self.check_handshake(client=client, server=server) def test_connect_with_cipher_suite_aes128(self): with client_and_server( client_options={"cipher_suites": [tls.CipherSuite.AES_128_GCM_SHA256]} ) as (client, server): # check handshake completed self.check_handshake(client=client, server=server) # check selected cipher suite self.assertEqual( client.tls.key_schedule.cipher_suite, tls.CipherSuite.AES_128_GCM_SHA256 ) self.assertEqual( server.tls.key_schedule.cipher_suite, tls.CipherSuite.AES_128_GCM_SHA256 ) def test_connect_with_cipher_suite_aes256(self): with client_and_server( client_options={"cipher_suites": [tls.CipherSuite.AES_256_GCM_SHA384]} ) as (client, server): # check handshake completed self.check_handshake(client=client, server=server) # check selected cipher suite self.assertEqual( client.tls.key_schedule.cipher_suite, tls.CipherSuite.AES_256_GCM_SHA384 ) self.assertEqual( server.tls.key_schedule.cipher_suite, tls.CipherSuite.AES_256_GCM_SHA384 ) @skipIf("chacha20" in SKIP_TESTS, "Skipping chacha20 tests") def test_connect_with_cipher_suite_chacha20(self): with client_and_server( client_options={"cipher_suites": [tls.CipherSuite.CHACHA20_POLY1305_SHA256]} ) as (client, server): # check handshake completed self.check_handshake(client=client, server=server) # check selected cipher suite self.assertEqual( client.tls.key_schedule.cipher_suite, tls.CipherSuite.CHACHA20_POLY1305_SHA256, ) self.assertEqual( server.tls.key_schedule.cipher_suite, tls.CipherSuite.CHACHA20_POLY1305_SHA256, ) def test_connect_with_custom_packet_size(self): """ Check that the size of the initial QUIC packet corresponds to the packet size configuration. """ client_configuration = QuicConfiguration(is_client=True, max_datagram_size=1480) client_configuration.load_verify_locations(cafile=SERVER_CACERTFILE) client = QuicConnection(configuration=client_configuration) now = 0.0 client.connect(SERVER_ADDR, now=now) items = client.datagrams_to_send(now=now) self.assertEqual(datagram_sizes(items), [1480]) def test_connect_without_loss(self): """ Check connection is established in the absence of loss. 
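
        Expected exchange, with datagram sizes as asserted below:

            client INITIAL (1200)
            server INITIAL + HANDSHAKE (1200, 1162)
            client HANDSHAKE (1200)
            server short header (229)
            client ACK (32)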
""" with client_and_server(handshake=False) as (client, server): # client sends INITIAL now = 0.0 client.connect(SERVER_ADDR, now=now) items = client.datagrams_to_send(now=now) self.assertEqual(datagram_sizes(items), [1200]) self.assertEqual(client.get_timer(), 0.2) self.assertSentPackets(client, [1, 0, 0]) self.assertEvents(client, []) # server receives INITIAL, sends INITIAL + HANDSHAKE now += TICK server.receive_datagram(items[0][0], CLIENT_ADDR, now=now) items = server.datagrams_to_send(now=now) self.assertEqual(datagram_sizes(items), SERVER_INITIAL_DATAGRAM_SIZES) self.assertAlmostEqual(server.get_timer(), 0.25) self.assertSentPackets(server, [1, 2, 0]) self.assertEvents(server, [events.ProtocolNegotiated]) # handshake continues normally now += TICK client.receive_datagram(items[0][0], SERVER_ADDR, now=now) client.receive_datagram(items[1][0], SERVER_ADDR, now=now) items = client.datagrams_to_send(now=now) self.assertEqual(datagram_sizes(items), CLIENT_HANDSHAKE_DATAGRAM_SIZES) self.assertAlmostEqual(client.get_timer(), 0.425) self.assertSentPackets(client, [0, 1, 1]) self.assertEvents( client, [events.ProtocolNegotiated] + HANDSHAKE_COMPLETED_EVENTS ) now += TICK server.receive_datagram(items[0][0], CLIENT_ADDR, now=now) items = server.datagrams_to_send(now=now) self.assertEqual(datagram_sizes(items), [229]) self.assertAlmostEqual(server.get_timer(), 0.425) self.assertSentPackets(server, [0, 0, 1]) self.assertEvents(server, HANDSHAKE_COMPLETED_EVENTS) now += TICK client.receive_datagram(items[0][0], SERVER_ADDR, now=now) items = client.datagrams_to_send(now=now) self.assertEqual(datagram_sizes(items), [32]) self.assertAlmostEqual(client.get_timer(), 60.2) # idle timeout self.assertSentPackets(client, [0, 0, 1]) self.assertEvents(client, []) def test_connect_with_loss_1(self): """ Check connection is established even in the client's INITIAL is lost. The client's PTO fires, triggering retransmission. 
""" with client_and_server(handshake=False) as (client, server): # client sends INITIAL now = 0.0 client.connect(SERVER_ADDR, now=now) items = client.datagrams_to_send(now=now) self.assertEqual(datagram_sizes(items), [1200]) self.assertEqual(client.get_timer(), 0.2) self.assertSentPackets(client, [1, 0, 0]) self.assertEvents(client, []) # INITIAL is lost and retransmitted now = client.get_timer() client.handle_timer(now=now) items = client.datagrams_to_send(now=now) self.assertEqual(datagram_sizes(items), [1200]) self.assertAlmostEqual(client.get_timer(), 0.6) self.assertSentPackets(client, [1, 0, 0]) self.assertEvents(client, []) # server receives INITIAL, sends INITIAL + HANDSHAKE now += TICK server.receive_datagram(items[0][0], CLIENT_ADDR, now=now) items = server.datagrams_to_send(now=now) self.assertEqual(datagram_sizes(items), SERVER_INITIAL_DATAGRAM_SIZES) self.assertAlmostEqual(server.get_timer(), 0.45) self.assertSentPackets(server, [1, 2, 0]) self.assertEvents(server, [events.ProtocolNegotiated]) # handshake continues normally now += TICK client.receive_datagram(items[0][0], SERVER_ADDR, now=now) client.receive_datagram(items[1][0], SERVER_ADDR, now=now) items = client.datagrams_to_send(now=now) self.assertEqual(datagram_sizes(items), CLIENT_HANDSHAKE_DATAGRAM_SIZES) self.assertAlmostEqual(client.get_timer(), 0.625) self.assertSentPackets(client, [0, 1, 1]) self.assertEvents( client, [events.ProtocolNegotiated] + HANDSHAKE_COMPLETED_EVENTS ) now += TICK server.receive_datagram(items[0][0], CLIENT_ADDR, now=now) items = server.datagrams_to_send(now=now) self.assertEqual(datagram_sizes(items), [229]) self.assertAlmostEqual(server.get_timer(), 0.625) self.assertSentPackets(server, [0, 0, 1]) self.assertEvents(server, HANDSHAKE_COMPLETED_EVENTS) now += TICK client.receive_datagram(items[0][0], SERVER_ADDR, now=now) items = client.datagrams_to_send(now=now) self.assertEqual(datagram_sizes(items), [32]) self.assertAlmostEqual(client.get_timer(), 60.4) # idle timeout self.assertSentPackets(client, [0, 0, 1]) self.assertEvents(client, []) def test_connect_with_loss_2(self): """ Check connection is established even in the server's INITIAL is lost. The client receives HANDSHAKE packets before it has the corresponding keys and decides to retransmit its own CRYPTO to speedup handshake completion. 
""" with client_and_server(handshake=False) as (client, server): # client sends INITIAL now = 0.0 client.connect(SERVER_ADDR, now=now) items = client.datagrams_to_send(now=now) self.assertEqual(datagram_sizes(items), [1200]) self.assertEqual(client.get_timer(), 0.2) self.assertSentPackets(client, [1, 0, 0]) self.assertEvents(client, []) # server receives INITIAL, sends INITIAL + HANDSHAKE but first datagram # is lost now += TICK server.receive_datagram(items[0][0], CLIENT_ADDR, now=now) items = server.datagrams_to_send(now=now) self.assertEqual(datagram_sizes(items), SERVER_INITIAL_DATAGRAM_SIZES) self.assertEqual(server.get_timer(), 0.25) self.assertSentPackets(server, [1, 2, 0]) self.assertEvents(server, [events.ProtocolNegotiated]) # client only receives second datagram, retransmits INITIAL now += TICK client.receive_datagram(items[1][0], SERVER_ADDR, now=now) items = client.datagrams_to_send(now=now) self.assertEqual(datagram_sizes(items), [1200]) self.assertAlmostEqual(client.get_timer(), 0.3) self.assertSentPackets(client, [1, 0, 0]) self.assertEvents(client, []) self.assertPacketDropped(client, "key_unavailable") # server receives duplicate INITIAL, retransmits INITIAL + HANDSHAKE now += TICK server.receive_datagram(items[0][0], CLIENT_ADDR, now=now) items = server.datagrams_to_send(now=now) self.assertEqual(datagram_sizes(items), SERVER_INITIAL_DATAGRAM_SIZES) self.assertAlmostEqual(server.get_timer(), 0.35) self.assertSentPackets(server, [1, 2, 0]) self.assertEvents(server, []) # handshake continues normally now += TICK client.receive_datagram(items[0][0], SERVER_ADDR, now=now) client.receive_datagram(items[1][0], SERVER_ADDR, now=now) items = client.datagrams_to_send(now=now) self.assertEqual(datagram_sizes(items), CLIENT_HANDSHAKE_DATAGRAM_SIZES) self.assertAlmostEqual(client.get_timer(), 0.525) self.assertSentPackets(client, [0, 1, 1]) self.assertEvents( client, [events.ProtocolNegotiated] + HANDSHAKE_COMPLETED_EVENTS ) now += TICK server.receive_datagram(items[0][0], CLIENT_ADDR, now=now) items = server.datagrams_to_send(now=now) self.assertEqual(datagram_sizes(items), [229]) self.assertAlmostEqual(server.get_timer(), 0.525) self.assertSentPackets(server, [0, 0, 1]) self.assertEvents(server, HANDSHAKE_COMPLETED_EVENTS) now += TICK client.receive_datagram(items[0][0], SERVER_ADDR, now=now) items = client.datagrams_to_send(now=now) self.assertEqual(datagram_sizes(items), [32]) self.assertAlmostEqual(client.get_timer(), 60.3) # idle timeout self.assertSentPackets(client, [0, 0, 1]) self.assertEvents(client, []) def test_connect_with_loss_3(self): """ Check connection is established even in the server's INITIAL + HANDSHAKE are lost. The server receives duplicate CRYPTO and decides to retransmit its own CRYPTO to speedup handshake completion. 
""" with client_and_server(handshake=False) as (client, server): # client sends INITIAL now = 0.0 client.connect(SERVER_ADDR, now=now) items = client.datagrams_to_send(now=now) self.assertEqual(datagram_sizes(items), [1200]) self.assertEqual(client.get_timer(), 0.2) self.assertSentPackets(client, [1, 0, 0]) self.assertEvents(client, []) # server receives INITIAL, sends INITIAL + HANDSHAKE now += TICK server.receive_datagram(items[0][0], CLIENT_ADDR, now=now) items = server.datagrams_to_send(now=now) self.assertEqual(datagram_sizes(items), SERVER_INITIAL_DATAGRAM_SIZES) self.assertEqual(server.get_timer(), 0.25) self.assertSentPackets(server, [1, 2, 0]) self.assertEvents(server, [events.ProtocolNegotiated]) # INITIAL + HANDSHAKE are lost, client retransmits INITIAL now = client.get_timer() client.handle_timer(now=now) items = client.datagrams_to_send(now=now) self.assertEqual(datagram_sizes(items), [1200]) self.assertAlmostEqual(client.get_timer(), 0.6) self.assertSentPackets(client, [1, 0, 0]) self.assertEvents(client, []) # server receives duplicate INITIAL, retransmits INITIAL + HANDSHAKE now += TICK server.receive_datagram(items[0][0], CLIENT_ADDR, now=now) items = server.datagrams_to_send(now=now) self.assertEqual(datagram_sizes(items), SERVER_INITIAL_DATAGRAM_SIZES) self.assertEqual(server.get_timer(), 0.45) self.assertSentPackets(server, [1, 2, 0]) self.assertEvents(server, []) # handshake continues normally now += TICK client.receive_datagram(items[0][0], SERVER_ADDR, now=now) client.receive_datagram(items[1][0], SERVER_ADDR, now=now) items = client.datagrams_to_send(now=now) self.assertEqual(datagram_sizes(items), CLIENT_HANDSHAKE_DATAGRAM_SIZES) self.assertAlmostEqual(client.get_timer(), 0.625) self.assertSentPackets(client, [0, 1, 1]) self.assertEvents( client, [events.ProtocolNegotiated] + HANDSHAKE_COMPLETED_EVENTS ) now += TICK server.receive_datagram(items[0][0], CLIENT_ADDR, now=now) items = server.datagrams_to_send(now=now) self.assertEqual(datagram_sizes(items), [229]) self.assertAlmostEqual(server.get_timer(), 0.625) self.assertSentPackets(server, [0, 0, 1]) self.assertEvents(server, HANDSHAKE_COMPLETED_EVENTS) now += TICK client.receive_datagram(items[0][0], SERVER_ADDR, now=now) items = client.datagrams_to_send(now=now) self.assertEqual(datagram_sizes(items), [32]) self.assertAlmostEqual(client.get_timer(), 60.4) # idle timeout self.assertSentPackets(client, [0, 0, 1]) self.assertEvents(client, []) def test_connect_with_loss_4(self): """ Check connection is established even in the server's HANDSHAKE is lost. 
""" with client_and_server(handshake=False) as (client, server): # client sends INITIAL now = 0.0 client.connect(SERVER_ADDR, now=now) items = client.datagrams_to_send(now=now) self.assertEqual(datagram_sizes(items), [1200]) self.assertEqual(client.get_timer(), 0.2) self.assertSentPackets(client, [1, 0, 0]) self.assertEvents(client, []) # server receives INITIAL, sends ACK + INITIAL + HANDSHAKE but third # datagram is lost now += TICK server.receive_datagram(items[0][0], CLIENT_ADDR, now=now) items = server.datagrams_to_send(now=now) self.assertEqual(datagram_sizes(items), SERVER_INITIAL_DATAGRAM_SIZES) self.assertEqual(server.get_timer(), 0.25) self.assertSentPackets(server, [1, 2, 0]) self.assertEvents(server, [events.ProtocolNegotiated]) # client only receives the first datagram and sends ACKS now += TICK client.receive_datagram(items[0][0], SERVER_ADDR, now=now) items = client.datagrams_to_send(now=now) self.assertEqual(datagram_sizes(items), [1200]) self.assertAlmostEqual(client.get_timer(), 0.325) self.assertSentPackets(client, [0, 1, 0]) self.assertEvents(client, [events.ProtocolNegotiated]) # client PTO - HANDSHAKE PING now = client.get_timer() client.handle_timer(now=now) items = client.datagrams_to_send(now=now) self.assertEqual(datagram_sizes(items), [45]) self.assertAlmostEqual(client.get_timer(), 0.975) self.assertSentPackets(client, [0, 2, 0]) self.assertEvents(client, []) # server receives PING, discards INITIAL and sends ACK now += TICK server.receive_datagram(items[0][0], CLIENT_ADDR, now=now) items = server.datagrams_to_send(now=now) self.assertEqual(datagram_sizes(items), [48]) self.assertAlmostEqual(server.get_timer(), 0.25) self.assertSentPackets(server, [0, 3, 0]) self.assertEvents(server, []) # ACKs are lost, server retransmits HANDSHAKE now = server.get_timer() server.handle_timer(now=now) items = server.datagrams_to_send(now=now) self.assertEqual(datagram_sizes(items), [1200, 986]) self.assertAlmostEqual(server.get_timer(), 0.65) self.assertSentPackets(server, [0, 3, 0]) self.assertEvents(server, []) # handshake continues normally now += TICK client.receive_datagram(items[0][0], SERVER_ADDR, now=now) client.receive_datagram(items[1][0], SERVER_ADDR, now=now) items = client.datagrams_to_send(now=now) self.assertEqual(datagram_sizes(items), [329]) self.assertAlmostEqual(client.get_timer(), 0.95) self.assertSentPackets(client, [0, 3, 1]) self.assertEvents(client, HANDSHAKE_COMPLETED_EVENTS) now += TICK server.receive_datagram(items[0][0], CLIENT_ADDR, now=now) items = server.datagrams_to_send(now=now) self.assertEqual(datagram_sizes(items), [229]) self.assertAlmostEqual(server.get_timer(), 0.675) self.assertSentPackets(server, [0, 0, 1]) self.assertEvents(server, HANDSHAKE_COMPLETED_EVENTS) now += TICK client.receive_datagram(items[0][0], SERVER_ADDR, now=now) items = client.datagrams_to_send(now=now) self.assertEqual(datagram_sizes(items), [32]) self.assertAlmostEqual(client.get_timer(), 60.4) # idle timeout self.assertSentPackets(client, [0, 0, 1]) self.assertEvents(client, []) def test_connect_with_loss_5(self): """ Check connection is established even in the server's HANDSHAKE_DONE is lost. 
""" with client_and_server(handshake=False) as (client, server): # client sends INITIAL now = 0.0 client.connect(SERVER_ADDR, now=now) items = client.datagrams_to_send(now=now) self.assertEqual(datagram_sizes(items), [1200]) self.assertEqual(client.get_timer(), 0.2) # server receives INITIAL, sends INITIAL + HANDSHAKE now += TICK server.receive_datagram(items[0][0], CLIENT_ADDR, now=now) items = server.datagrams_to_send(now=now) self.assertEqual(datagram_sizes(items), SERVER_INITIAL_DATAGRAM_SIZES) self.assertEqual(server.get_timer(), 0.25) self.assertSentPackets(server, [1, 2, 0]) self.assertEvents(server, [events.ProtocolNegotiated]) # client receives INITIAL + HANDSHAKE now += TICK client.receive_datagram(items[0][0], SERVER_ADDR, now=now) client.receive_datagram(items[1][0], SERVER_ADDR, now=now) items = client.datagrams_to_send(now=now) self.assertEqual(datagram_sizes(items), CLIENT_HANDSHAKE_DATAGRAM_SIZES) self.assertAlmostEqual(client.get_timer(), 0.425) self.assertSentPackets(client, [0, 1, 1]) self.assertEvents( client, [events.ProtocolNegotiated] + HANDSHAKE_COMPLETED_EVENTS ) # server completes handshake, but HANDSHAKE_DONE is lost now += TICK server.receive_datagram(items[0][0], CLIENT_ADDR, now=now) items = server.datagrams_to_send(now=now) self.assertEqual(datagram_sizes(items), [229]) self.assertAlmostEqual(server.get_timer(), 0.425) self.assertSentPackets(server, [0, 0, 1]) self.assertEvents(server, HANDSHAKE_COMPLETED_EVENTS) # server PTO - 1-RTT PING now = server.get_timer() server.handle_timer(now=now) items = server.datagrams_to_send(now=now) self.assertEqual(datagram_sizes(items), [29]) self.assertAlmostEqual(server.get_timer(), 0.975) self.assertSentPackets(server, [0, 0, 2]) self.assertEvents(server, []) # client receives PING, sends ACK now += TICK client.receive_datagram(items[0][0], SERVER_ADDR, now=now) items = client.datagrams_to_send(now=now) self.assertEqual(datagram_sizes(items), [32]) self.assertAlmostEqual(client.get_timer(), 0.425) self.assertSentPackets(client, [0, 1, 2]) self.assertEvents(client, []) # server receives ACK, retransmits HANDSHAKE_DONE now += TICK self.assertFalse(server._handshake_done_pending) server.receive_datagram(items[0][0], CLIENT_ADDR, now=now) self.assertTrue(server._handshake_done_pending) items = server.datagrams_to_send(now=now) self.assertFalse(server._handshake_done_pending) self.assertEqual(datagram_sizes(items), [224]) self.assertAlmostEqual(server.get_timer(), 0.7625) self.assertSentPackets(server, [0, 0, 1]) # FIXME: the server re-emits the ConnectionIdIssued events self.assertEvents(server, HANDSHAKE_COMPLETED_EVENTS[1:]) now += TICK client.receive_datagram(items[0][0], SERVER_ADDR, now=now) items = client.datagrams_to_send(now=now) self.assertEqual(datagram_sizes(items), [32]) self.assertAlmostEqual(client.get_timer(), 0.425) self.assertSentPackets(client, [0, 0, 3]) self.assertEvents(client, []) now += TICK server.receive_datagram(items[0][0], CLIENT_ADDR, now=now) items = server.datagrams_to_send(now=now) self.assertEqual(datagram_sizes(items), []) self.assertAlmostEqual(server.get_timer(), 60.625) # idle timeout self.assertSentPackets(server, [0, 0, 0]) self.assertEvents(server, []) def test_initial_that_is_too_small(self): """ Check that a too-small initial datagram from the client is dropped by the server. 
""" client = create_standalone_client(self) server = create_standalone_server( self, original_destination_connection_id=client.original_destination_connection_id, ) builder = QuicPacketBuilder( host_cid=client._peer_cid.cid, is_client=False, max_datagram_size=1000, # too small! peer_cid=client.host_cid, version=client._version, ) crypto = CryptoPair() crypto.setup_initial( client._peer_cid.cid, is_client=False, version=client._version ) builder.start_packet(QuicPacketType.INITIAL, crypto) buf = builder.start_frame(QuicFrameType.PADDING) buf.push_bytes(bytes(builder.remaining_flight_space)) for datagram in builder.flush()[0]: server.receive_datagram(datagram, SERVER_ADDR, now=time.time()) # Look for the drop event. self.assertPacketDropped(server, "initial_packet_datagram_too_small") def test_connect_with_no_crypto_frame(self): def patch(client): """ Patch client to send PING instead of CRYPTO. """ client._push_crypto_data = client._send_probe with client_and_server(client_patch=patch) as (client, server): self.assertEqual( server._close_event.reason_phrase, "Packet contains no CRYPTO frame", ) def test_connect_with_no_transport_parameters(self): def patch(client): """ Patch client's TLS initialization to clear TLS extensions. """ real_initialize = client._initialize def patched_initialize(peer_cid: bytes): real_initialize(peer_cid) client.tls.handshake_extensions = [] client._initialize = patched_initialize with client_and_server(client_patch=patch) as (client, server): self.assertEqual( server._close_event.reason_phrase, "No QUIC transport parameters received", ) def test_connect_with_compatible_version_negotiation_1(self): """ The client only supports version 1. The server sets the Negotiated Version to version 1. """ with client_and_server( client_options={ "supported_versions": [QuicProtocolVersion.VERSION_1], }, ) as (client, server): # check handshake completed self.check_handshake(client=client, server=server) self.assertEqual(client._version, QuicProtocolVersion.VERSION_1) self.assertEqual(server._version, QuicProtocolVersion.VERSION_1) def test_connect_with_compatible_version_negotiation_1_to_2(self): """ The client originally connects using version 1 but prefers version 2. The server sets the Negotiated Version to version 2. """ with client_and_server( client_options={ "original_version": QuicProtocolVersion.VERSION_1, "supported_versions": [ QuicProtocolVersion.VERSION_2, QuicProtocolVersion.VERSION_1, ], }, ) as (client, server): # check handshake completed self.check_handshake(client=client, server=server) self.assertEqual(client._version, QuicProtocolVersion.VERSION_2) self.assertEqual(server._version, QuicProtocolVersion.VERSION_2) def test_connect_with_compatible_version_negotiation_2(self): """ The client only supports version 2. The server sets the Negotiated Version to version 2. """ with client_and_server( client_options={ "supported_versions": [QuicProtocolVersion.VERSION_2], }, ) as (client, server): # check handshake completed self.check_handshake(client=client, server=server) self.assertEqual(client._version, QuicProtocolVersion.VERSION_2) self.assertEqual(server._version, QuicProtocolVersion.VERSION_2) def test_connect_with_compatible_version_negotiation_2_to_1(self): """ The client originally connects using version 2 but prefers version 1. The server sets the Negotiated Version to version 1. 
""" with client_and_server( client_options={ "original_version": QuicProtocolVersion.VERSION_2, "supported_versions": [ QuicProtocolVersion.VERSION_1, QuicProtocolVersion.VERSION_2, ], }, ) as (client, server): # check handshake completed self.check_handshake(client=client, server=server) self.assertEqual(client._version, QuicProtocolVersion.VERSION_1) self.assertEqual(server._version, QuicProtocolVersion.VERSION_1) def test_connect_with_quantum_readiness(self): with client_and_server(client_options={"quantum_readiness_test": True}) as ( client, server, ): stream_id = client.get_next_available_stream_id() client.send_stream_data(stream_id, b"hello") self.assertEqual(roundtrip(client, server), (1, 1)) received = None while True: event = server.next_event() if isinstance(event, events.StreamDataReceived): received = event.data elif event is None: break self.assertEqual(received, b"hello") def test_connect_with_0rtt(self): client_ticket = None ticket_store = SessionTicketStore() def save_session_ticket(ticket): nonlocal client_ticket client_ticket = ticket with client_and_server( client_kwargs={"session_ticket_handler": save_session_ticket}, server_kwargs={"session_ticket_handler": ticket_store.add}, ) as (client, server): pass with client_and_server( client_options={"session_ticket": client_ticket}, server_kwargs={"session_ticket_fetcher": ticket_store.pop}, handshake=False, ) as (client, server): client.connect(SERVER_ADDR, now=time.time()) stream_id = client.get_next_available_stream_id() client.send_stream_data(stream_id, b"hello") self.assertEqual(roundtrip(client, server), (1, 1)) event = server.next_event() self.assertEqual(type(event), events.ProtocolNegotiated) event = server.next_event() self.assertEqual(type(event), events.StreamDataReceived) self.assertEqual(event.data, b"hello") def test_connect_with_0rtt_bad_max_early_data(self): client_ticket = None ticket_store = SessionTicketStore() def patch(server): """ Patch server's TLS initialization to set an invalid max_early_data value. 
""" real_initialize = server._initialize def patched_initialize(peer_cid: bytes): real_initialize(peer_cid) server.tls._max_early_data = 12345 server._initialize = patched_initialize def save_session_ticket(ticket): nonlocal client_ticket client_ticket = ticket with client_and_server( client_kwargs={"session_ticket_handler": save_session_ticket}, server_kwargs={"session_ticket_handler": ticket_store.add}, server_patch=patch, ) as (client, server): # check handshake failed event = client.next_event() self.assertIsNone(event) def test_change_connection_id(self): with client_and_server() as (client, server): self.assertEqual( sequence_numbers(client._peer_cid_available), [1, 2, 3, 4, 5, 6, 7] ) # the client changes connection ID client.change_connection_id() self.assertEqual(transfer(client, server), 1) self.assertEqual( sequence_numbers(client._peer_cid_available), [2, 3, 4, 5, 6, 7] ) # the server provides a new connection ID self.assertEqual(transfer(server, client), 1) self.assertEqual( sequence_numbers(client._peer_cid_available), [2, 3, 4, 5, 6, 7, 8] ) def test_change_connection_id_retransmit_new_connection_id(self): with client_and_server() as (client, server): self.assertEqual( sequence_numbers(client._peer_cid_available), [1, 2, 3, 4, 5, 6, 7] ) # the client changes connection ID client.change_connection_id() self.assertEqual(transfer(client, server), 1) self.assertEqual( sequence_numbers(client._peer_cid_available), [2, 3, 4, 5, 6, 7] ) # the server provides a new connection ID, NEW_CONNECTION_ID is lost self.assertEqual(drop(server), 1) self.assertEqual( sequence_numbers(client._peer_cid_available), [2, 3, 4, 5, 6, 7] ) # NEW_CONNECTION_ID is retransmitted server._on_new_connection_id_delivery( QuicDeliveryState.LOST, server._host_cids[-1] ) self.assertEqual(transfer(server, client), 1) self.assertEqual( sequence_numbers(client._peer_cid_available), [2, 3, 4, 5, 6, 7, 8] ) def test_change_connection_id_retransmit_retire_connection_id(self): with client_and_server() as (client, server): self.assertEqual( sequence_numbers(client._peer_cid_available), [1, 2, 3, 4, 5, 6, 7] ) # the client changes connection ID, RETIRE_CONNECTION_ID is lost client.change_connection_id() self.assertEqual(drop(client), 1) self.assertEqual( sequence_numbers(client._peer_cid_available), [2, 3, 4, 5, 6, 7] ) # RETIRE_CONNECTION_ID is retransmitted client._on_retire_connection_id_delivery(QuicDeliveryState.LOST, 0) self.assertEqual(transfer(client, server), 1) # the server provides a new connection ID self.assertEqual(transfer(server, client), 1) self.assertEqual( sequence_numbers(client._peer_cid_available), [2, 3, 4, 5, 6, 7, 8] ) def test_get_next_available_stream_id(self): with client_and_server() as (client, server): # client stream_id = client.get_next_available_stream_id() self.assertEqual(stream_id, 0) client.send_stream_data(stream_id, b"hello") stream_id = client.get_next_available_stream_id() self.assertEqual(stream_id, 4) client.send_stream_data(stream_id, b"hello") stream_id = client.get_next_available_stream_id(is_unidirectional=True) self.assertEqual(stream_id, 2) client.send_stream_data(stream_id, b"hello") stream_id = client.get_next_available_stream_id(is_unidirectional=True) self.assertEqual(stream_id, 6) client.send_stream_data(stream_id, b"hello") # server stream_id = server.get_next_available_stream_id() self.assertEqual(stream_id, 1) server.send_stream_data(stream_id, b"hello") stream_id = server.get_next_available_stream_id() self.assertEqual(stream_id, 5) 
            server.send_stream_data(stream_id, b"hello")

            stream_id = server.get_next_available_stream_id(is_unidirectional=True)
            self.assertEqual(stream_id, 3)
            server.send_stream_data(stream_id, b"hello")

            stream_id = server.get_next_available_stream_id(is_unidirectional=True)
            self.assertEqual(stream_id, 7)
            server.send_stream_data(stream_id, b"hello")

    def test_datagram_frame(self):
        with client_and_server(
            client_options={"max_datagram_frame_size": 65536},
            server_options={"max_datagram_frame_size": 65536},
        ) as (client, server):
            # check handshake completed
            self.check_handshake(client=client, server=server, alpn_protocol=None)

            # send datagram
            client.send_datagram_frame(b"hello")
            self.assertEqual(transfer(client, server), 1)

            event = server.next_event()
            self.assertEqual(type(event), events.DatagramFrameReceived)
            self.assertEqual(event.data, b"hello")

    def test_datagram_frame_2(self):
        # payload which exactly fills an entire packet
        payload = b"Z" * 1170

        with client_and_server(
            client_options={"max_datagram_frame_size": 65536},
            server_options={"max_datagram_frame_size": 65536},
        ) as (client, server):
            # check handshake completed
            self.check_handshake(client=client, server=server, alpn_protocol=None)

            # queue 20 datagrams
            for i in range(20):
                client.send_datagram_frame(payload)

            # only 11 of the datagrams are sent due to congestion control
            self.assertEqual(transfer(client, server), 11)
            for i in range(11):
                event = server.next_event()
                self.assertEqual(type(event), events.DatagramFrameReceived)
                self.assertEqual(event.data, payload)

            # server sends ACK
            self.assertEqual(transfer(server, client), 1)

            # client sends remaining datagrams
            self.assertEqual(transfer(client, server), 9)
            for i in range(9):
                event = server.next_event()
                self.assertEqual(type(event), events.DatagramFrameReceived)
                self.assertEqual(event.data, payload)

    def test_decryption_error(self):
        with client_and_server() as (client, server):
            # mess with encryption key
            server._cryptos[tls.Epoch.ONE_RTT].send.setup(
                cipher_suite=tls.CipherSuite.AES_128_GCM_SHA256,
                secret=bytes(48),
                version=server._version,
            )

            # server sends close
            server.close(error_code=QuicErrorCode.NO_ERROR)
            for data, addr in server.datagrams_to_send(now=time.time()):
                client.receive_datagram(data, SERVER_ADDR, now=time.time())

    def test_tls_error(self):
        def patch(client):
            """
            Patch the client's TLS initialization to send an invalid TLS
            version.
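
            The CRYPTO_ERROR code asserted below encodes the resulting TLS
            alert: 0x0100 + 70 (protocol_version) == 326.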
""" real_initialize = client._initialize def patched_initialize(peer_cid: bytes): real_initialize(peer_cid) client.tls._supported_versions = [tls.TLS_VERSION_1_3_DRAFT_28] client._initialize = patched_initialize # handshake fails with client_and_server(client_patch=patch) as (client, server): timer_at = server.get_timer() server.handle_timer(timer_at) event = server.next_event() self.assertEqual(type(event), events.ConnectionTerminated) self.assertEqual(event.error_code, 326) self.assertEqual(event.frame_type, QuicFrameType.CRYPTO) self.assertEqual(event.reason_phrase, "No supported protocol version") def test_receive_datagram_garbage(self): client = create_standalone_client(self) datagram = binascii.unhexlify("c00000000080") client.receive_datagram(datagram, SERVER_ADDR, now=time.time()) def test_receive_datagram_reserved_bits_non_zero(self): client = create_standalone_client(self) builder = QuicPacketBuilder( host_cid=client._peer_cid.cid, is_client=False, max_datagram_size=SMALLEST_MAX_DATAGRAM_SIZE, peer_cid=client.host_cid, version=client._version, ) crypto = CryptoPair() crypto.setup_initial( client._peer_cid.cid, is_client=False, version=client._version ) crypto.encrypt_packet_real = crypto.encrypt_packet def encrypt_packet(plain_header, plain_payload, packet_number): # mess with reserved bits plain_header = bytes([plain_header[0] | 0x0C]) + plain_header[1:] return crypto.encrypt_packet_real( plain_header, plain_payload, packet_number ) crypto.encrypt_packet = encrypt_packet builder.start_packet(QuicPacketType.INITIAL, crypto) buf = builder.start_frame(QuicFrameType.PADDING) buf.push_bytes(bytes(builder.remaining_flight_space)) for datagram in builder.flush()[0]: client.receive_datagram(datagram, SERVER_ADDR, now=time.time()) self.assertEqual(drop(client), 1) self.assertEqual( client._close_event, events.ConnectionTerminated( error_code=QuicErrorCode.PROTOCOL_VIOLATION, frame_type=QuicFrameType.PADDING, reason_phrase="Reserved bits must be zero", ), ) def test_receive_datagram_wrong_version(self): client = create_standalone_client(self) builder = QuicPacketBuilder( host_cid=client._peer_cid.cid, is_client=False, max_datagram_size=SMALLEST_MAX_DATAGRAM_SIZE, peer_cid=client.host_cid, version=0x1A2A3A4A, ) crypto = CryptoPair() crypto.setup_initial( client._peer_cid.cid, is_client=False, version=client._version ) builder.start_packet(QuicPacketType.INITIAL, crypto) buf = builder.start_frame(QuicFrameType.PADDING) buf.push_bytes(bytes(builder.remaining_flight_space)) for datagram in builder.flush()[0]: client.receive_datagram(datagram, SERVER_ADDR, now=time.time()) self.assertEqual(drop(client), 0) self.assertPacketDropped(client, "unsupported_version") def test_receive_datagram_retry(self): client = create_standalone_client(self) client.receive_datagram( encode_quic_retry( version=client._version, source_cid=binascii.unhexlify("85abb547bf28be97"), destination_cid=client.host_cid, original_destination_cid=client._peer_cid.cid, retry_token=bytes(16), ), SERVER_ADDR, now=time.time(), ) self.assertEqual(drop(client), 1) def test_receive_datagram_retry_wrong_destination_cid(self): client = create_standalone_client(self) # The client does not reply to a retry packet with a wrong destination CID. 
client.receive_datagram( encode_quic_retry( version=client._version, source_cid=binascii.unhexlify("85abb547bf28be97"), destination_cid=binascii.unhexlify("c98343fe8f5f0ff4"), original_destination_cid=client._peer_cid.cid, retry_token=bytes(16), ), SERVER_ADDR, now=time.time(), ) self.assertEqual(drop(client), 0) self.assertPacketDropped(client, "unknown_connection_id") def test_receive_datagram_retry_wrong_integrity_tag(self): client = create_standalone_client(self) client.receive_datagram( encode_quic_retry( version=client._version, source_cid=binascii.unhexlify("85abb547bf28be97"), destination_cid=client.host_cid, original_destination_cid=client._peer_cid.cid, retry_token=bytes(16), )[0:-16] + bytes(16), SERVER_ADDR, now=time.time(), ) self.assertEqual(drop(client), 0) def test_handle_ack_frame_ecn(self): client = create_standalone_client(self) client._handle_ack_frame( client_receive_context(client), QuicFrameType.ACK_ECN, Buffer(data=b"\x00\x02\x00\x00\x00\x00\x00"), ) def test_handle_connection_close_frame(self): with client_and_server() as (client, server): server.close( error_code=QuicErrorCode.PROTOCOL_VIOLATION, frame_type=QuicFrameType.ACK, reason_phrase="illegal ACK frame", ) self.assertEqual(roundtrip(server, client), (1, 0)) self.assertEqual( client._close_event, events.ConnectionTerminated( error_code=QuicErrorCode.PROTOCOL_VIOLATION, frame_type=QuicFrameType.ACK, reason_phrase="illegal ACK frame", ), ) def test_handle_connection_close_frame_app(self): with client_and_server() as (client, server): server.close(error_code=QuicErrorCode.NO_ERROR, reason_phrase="goodbye") self.assertEqual(roundtrip(server, client), (1, 0)) self.assertEqual( client._close_event, events.ConnectionTerminated( error_code=QuicErrorCode.NO_ERROR, frame_type=None, reason_phrase="goodbye", ), ) def test_handle_connection_close_frame_app_not_utf8(self): client = create_standalone_client(self) client._handle_connection_close_frame( client_receive_context(client), QuicFrameType.APPLICATION_CLOSE, Buffer(data=binascii.unhexlify("0008676f6f6462798200")), ) self.assertEqual( client._close_event, events.ConnectionTerminated( error_code=QuicErrorCode.NO_ERROR, frame_type=None, reason_phrase="", ), ) def test_handle_crypto_frame_over_largest_offset(self): with client_and_server() as (client, server): # client receives offset + length > 2^62 - 1 with self.assertRaises(QuicConnectionError) as cm: client._handle_crypto_frame( client_receive_context(client), QuicFrameType.CRYPTO, Buffer(data=encode_uint_var(UINT_VAR_MAX) + encode_uint_var(1)), ) self.assertEqual( cm.exception.error_code, QuicErrorCode.FRAME_ENCODING_ERROR ) self.assertEqual(cm.exception.frame_type, QuicFrameType.CRYPTO) self.assertEqual( cm.exception.reason_phrase, "offset + length cannot exceed 2^62 - 1" ) def test_excessive_crypto_buffering(self): with client_and_server() as (client, server): # Client receives data that causes more than 512K of buffering; note that # because the stream buffer is a single buffer and not a set of fragments, # the total buffering size depends not on how much data is received, but # on how much buffering is needed. We send fragments of only 100 bytes # at offsets 10000, 20000, 30000 etc. highest_good_offset = 0 with self.assertRaises(QuicConnectionError) as cm: # We don't start at zero as we want to force buffering, not cause # a TLS error.
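# Sketch of the expected arithmetic (assuming MAX_PENDING_CRYPTO is 512 KiB,
# i.e. 524288 bytes): buffering a 100-byte fragment at offset N needs a buffer
# of roughly N bytes, so when stepping by 10000 the last offset that still
# fits is (524288 // 10000) * 10000 == 520000, which is what the final
# (MAX_PENDING_CRYPTO // 10000) * 10000 assertion checks.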
for offset in range(10000, 1000000, 10000): client._handle_crypto_frame( client_receive_context(client), QuicFrameType.CRYPTO, Buffer( data=encode_uint_var(offset) + encode_uint_var(100) + b"\x00" * 100 ), ) highest_good_offset = offset self.assertEqual( cm.exception.error_code, QuicErrorCode.CRYPTO_BUFFER_EXCEEDED ) self.assertEqual(cm.exception.frame_type, QuicFrameType.CRYPTO) self.assertEqual(highest_good_offset, (MAX_PENDING_CRYPTO // 10000) * 10000) def test_handle_data_blocked_frame(self): with client_and_server() as (client, server): # client receives DATA_BLOCKED: 12345 client._handle_data_blocked_frame( client_receive_context(client), QuicFrameType.DATA_BLOCKED, Buffer(data=encode_uint_var(12345)), ) def test_handle_datagram_frame(self): client = create_standalone_client(self, max_datagram_frame_size=6) client._handle_datagram_frame( client_receive_context(client), QuicFrameType.DATAGRAM, Buffer(data=b"hello"), ) self.assertEqual( client.next_event(), events.DatagramFrameReceived(data=b"hello") ) def test_handle_datagram_frame_not_allowed(self): client = create_standalone_client(self, max_datagram_frame_size=None) with self.assertRaises(QuicConnectionError) as cm: client._handle_datagram_frame( client_receive_context(client), QuicFrameType.DATAGRAM, Buffer(data=b"hello"), ) self.assertEqual(cm.exception.error_code, QuicErrorCode.PROTOCOL_VIOLATION) self.assertEqual(cm.exception.frame_type, QuicFrameType.DATAGRAM) self.assertEqual(cm.exception.reason_phrase, "Unexpected DATAGRAM frame") def test_handle_datagram_frame_too_large(self): client = create_standalone_client(self, max_datagram_frame_size=5) with self.assertRaises(QuicConnectionError) as cm: client._handle_datagram_frame( client_receive_context(client), QuicFrameType.DATAGRAM, Buffer(data=b"hello"), ) self.assertEqual(cm.exception.error_code, QuicErrorCode.PROTOCOL_VIOLATION) self.assertEqual(cm.exception.frame_type, QuicFrameType.DATAGRAM) self.assertEqual(cm.exception.reason_phrase, "Unexpected DATAGRAM frame") def test_handle_datagram_frame_with_length(self): client = create_standalone_client(self, max_datagram_frame_size=7) client._handle_datagram_frame( client_receive_context(client), QuicFrameType.DATAGRAM_WITH_LENGTH, Buffer(data=b"\x05hellojunk"), ) self.assertEqual( client.next_event(), events.DatagramFrameReceived(data=b"hello") ) def test_handle_datagram_frame_with_length_not_allowed(self): client = create_standalone_client(self, max_datagram_frame_size=None) with self.assertRaises(QuicConnectionError) as cm: client._handle_datagram_frame( client_receive_context(client), QuicFrameType.DATAGRAM_WITH_LENGTH, Buffer(data=b"\x05hellojunk"), ) self.assertEqual(cm.exception.error_code, QuicErrorCode.PROTOCOL_VIOLATION) self.assertEqual(cm.exception.frame_type, QuicFrameType.DATAGRAM_WITH_LENGTH) self.assertEqual(cm.exception.reason_phrase, "Unexpected DATAGRAM frame") def test_handle_datagram_frame_with_length_too_large(self): client = create_standalone_client(self, max_datagram_frame_size=6) with self.assertRaises(QuicConnectionError) as cm: client._handle_datagram_frame( client_receive_context(client), QuicFrameType.DATAGRAM_WITH_LENGTH, Buffer(data=b"\x05hellojunk"), ) self.assertEqual(cm.exception.error_code, QuicErrorCode.PROTOCOL_VIOLATION) self.assertEqual(cm.exception.frame_type, QuicFrameType.DATAGRAM_WITH_LENGTH) self.assertEqual(cm.exception.reason_phrase, "Unexpected DATAGRAM frame") def test_handle_handshake_done_not_allowed(self): with client_and_server() as (client, server): # server receives 
HANDSHAKE_DONE frame with self.assertRaises(QuicConnectionError) as cm: server._handle_handshake_done_frame( client_receive_context(server), QuicFrameType.HANDSHAKE_DONE, Buffer(data=b""), ) self.assertEqual(cm.exception.error_code, QuicErrorCode.PROTOCOL_VIOLATION) self.assertEqual(cm.exception.frame_type, QuicFrameType.HANDSHAKE_DONE) self.assertEqual( cm.exception.reason_phrase, "Clients must not send HANDSHAKE_DONE frames", ) def test_handle_max_data_frame(self): with client_and_server() as (client, server): self.assertEqual(client._remote_max_data, 1048576) # client receives MAX_DATA raising limit client._handle_max_data_frame( client_receive_context(client), QuicFrameType.MAX_DATA, Buffer(data=encode_uint_var(1048577)), ) self.assertEqual(client._remote_max_data, 1048577) def test_handle_max_stream_data_frame(self): with client_and_server() as (client, server): # client creates bidirectional stream 0 stream = client._get_or_create_stream_for_send(stream_id=0) self.assertEqual(stream.max_stream_data_remote, 1048576) # client receives MAX_STREAM_DATA raising limit client._handle_max_stream_data_frame( client_receive_context(client), QuicFrameType.MAX_STREAM_DATA, Buffer(data=b"\x00" + encode_uint_var(1048577)), ) self.assertEqual(stream.max_stream_data_remote, 1048577) # client receives MAX_STREAM_DATA lowering limit client._handle_max_stream_data_frame( client_receive_context(client), QuicFrameType.MAX_STREAM_DATA, Buffer(data=b"\x00" + encode_uint_var(1048575)), ) self.assertEqual(stream.max_stream_data_remote, 1048577) def test_handle_max_stream_data_frame_receive_only(self): with client_and_server() as (client, server): # server creates unidirectional stream 3 server.send_stream_data(stream_id=3, data=b"hello") # client receives MAX_STREAM_DATA: 3, 1 with self.assertRaises(QuicConnectionError) as cm: client._handle_max_stream_data_frame( client_receive_context(client), QuicFrameType.MAX_STREAM_DATA, Buffer(data=b"\x03\x01"), ) self.assertEqual(cm.exception.error_code, QuicErrorCode.STREAM_STATE_ERROR) self.assertEqual(cm.exception.frame_type, QuicFrameType.MAX_STREAM_DATA) self.assertEqual(cm.exception.reason_phrase, "Stream is receive-only") def test_handle_max_streams_bidi_frame(self): with client_and_server() as (client, server): self.assertEqual(client._remote_max_streams_bidi, 128) # client receives MAX_STREAMS_BIDI raising limit client._handle_max_streams_bidi_frame( client_receive_context(client), QuicFrameType.MAX_STREAMS_BIDI, Buffer(data=encode_uint_var(129)), ) self.assertEqual(client._remote_max_streams_bidi, 129) # client receives MAX_STREAMS_BIDI lowering limit client._handle_max_streams_bidi_frame( client_receive_context(client), QuicFrameType.MAX_STREAMS_BIDI, Buffer(data=encode_uint_var(127)), ) self.assertEqual(client._remote_max_streams_bidi, 129) # client receives invalid MAX_STREAMS_BIDI with self.assertRaises(QuicConnectionError) as cm: client._handle_max_streams_bidi_frame( client_receive_context(client), QuicFrameType.MAX_STREAMS_BIDI, Buffer(data=encode_uint_var(STREAM_COUNT_MAX + 1)), ) self.assertEqual( cm.exception.error_code, QuicErrorCode.FRAME_ENCODING_ERROR, ) self.assertEqual(cm.exception.frame_type, QuicFrameType.MAX_STREAMS_BIDI) self.assertEqual( cm.exception.reason_phrase, "Maximum Streams cannot exceed 2^60" ) def test_handle_max_streams_uni_frame(self): with client_and_server() as (client, server): self.assertEqual(client._remote_max_streams_uni, 128) # client receives MAX_STREAMS_UNI raising limit client._handle_max_streams_uni_frame( 
client_receive_context(client), QuicFrameType.MAX_STREAMS_UNI, Buffer(data=encode_uint_var(129)), ) self.assertEqual(client._remote_max_streams_uni, 129) # client receives MAX_STREAMS_UNI lowering limit client._handle_max_streams_uni_frame( client_receive_context(client), QuicFrameType.MAX_STREAMS_UNI, Buffer(data=encode_uint_var(127)), ) self.assertEqual(client._remote_max_streams_uni, 129) # client receives invalid MAX_STREAMS_UNI with self.assertRaises(QuicConnectionError) as cm: client._handle_max_streams_uni_frame( client_receive_context(client), QuicFrameType.MAX_STREAMS_UNI, Buffer(data=encode_uint_var(STREAM_COUNT_MAX + 1)), ) self.assertEqual( cm.exception.error_code, QuicErrorCode.FRAME_ENCODING_ERROR, ) self.assertEqual(cm.exception.frame_type, QuicFrameType.MAX_STREAMS_UNI) self.assertEqual( cm.exception.reason_phrase, "Maximum Streams cannot exceed 2^60" ) def test_handle_new_connection_id_duplicate(self): with client_and_server() as (client, server): buf = new_connection_id(sequence_number=7) # client receives NEW_CONNECTION_ID client._handle_new_connection_id_frame( client_receive_context(client), QuicFrameType.NEW_CONNECTION_ID, buf, ) self.assertEqual(client._peer_cid.sequence_number, 0) self.assertEqual( sequence_numbers(client._peer_cid_available), [1, 2, 3, 4, 5, 6, 7] ) def test_handle_new_connection_id_over_limit(self): with client_and_server() as (client, server): buf = new_connection_id(sequence_number=8) # client receives NEW_CONNECTION_ID with self.assertRaises(QuicConnectionError) as cm: client._handle_new_connection_id_frame( client_receive_context(client), QuicFrameType.NEW_CONNECTION_ID, buf, ) self.assertEqual( cm.exception.error_code, QuicErrorCode.CONNECTION_ID_LIMIT_ERROR ) self.assertEqual(cm.exception.frame_type, QuicFrameType.NEW_CONNECTION_ID) self.assertEqual( cm.exception.reason_phrase, "Too many active connection IDs" ) def test_handle_new_connection_id_with_retire_prior_to(self): with client_and_server() as (client, server): buf = new_connection_id(sequence_number=8, retire_prior_to=2, capacity=42) # client receives NEW_CONNECTION_ID client._handle_new_connection_id_frame( client_receive_context(client), QuicFrameType.NEW_CONNECTION_ID, buf, ) self.assertEqual(client._peer_cid.sequence_number, 2) self.assertEqual( sequence_numbers(client._peer_cid_available), [3, 4, 5, 6, 7, 8] ) def test_handle_new_connection_id_with_retire_prior_to_lower(self): with client_and_server() as (client, server): buf = new_connection_id(sequence_number=80, retire_prior_to=80) # client receives NEW_CONNECTION_ID client._handle_new_connection_id_frame( client_receive_context(client), QuicFrameType.NEW_CONNECTION_ID, buf, ) self.assertEqual(client._peer_cid.sequence_number, 80) self.assertEqual(sequence_numbers(client._peer_cid_available), []) buf = new_connection_id(sequence_number=30, retire_prior_to=30) # client receives NEW_CONNECTION_ID client._handle_new_connection_id_frame( client_receive_context(client), QuicFrameType.NEW_CONNECTION_ID, buf, ) self.assertEqual(client._peer_cid.sequence_number, 80) self.assertEqual(sequence_numbers(client._peer_cid_available), []) def test_handle_excessive_new_connection_id_retires(self): with client_and_server() as (client, server): for i in range(25): sequence_number = 8 + i buf = new_connection_id( sequence_number=sequence_number, retire_prior_to=sequence_number ) # client receives NEW_CONNECTION_ID client._handle_new_connection_id_frame( client_receive_context(client), QuicFrameType.NEW_CONNECTION_ID, buf, ) # So far, so good!
We should be at the (default) limit of 4*8 pending # retirements. self.assertEqual(len(client._retire_connection_ids), 32) # Now we will go one too many! sequence_number = 8 + 25 buf = new_connection_id( sequence_number=sequence_number, retire_prior_to=sequence_number ) with self.assertRaises(QuicConnectionError) as cm: client._handle_new_connection_id_frame( client_receive_context(client), QuicFrameType.NEW_CONNECTION_ID, buf, ) self.assertEqual( cm.exception.error_code, QuicErrorCode.CONNECTION_ID_LIMIT_ERROR ) self.assertEqual(cm.exception.frame_type, QuicFrameType.NEW_CONNECTION_ID) self.assertEqual( cm.exception.reason_phrase, "Too many pending retired connection IDs" ) def test_handle_new_connection_id_with_connection_id_invalid(self): with client_and_server() as (client, server): buf = new_connection_id( sequence_number=8, retire_prior_to=2, connection_id=bytes(21) ) # client receives NEW_CONNECTION_ID with self.assertRaises(QuicConnectionError) as cm: client._handle_new_connection_id_frame( client_receive_context(client), QuicFrameType.NEW_CONNECTION_ID, buf, ) self.assertEqual( cm.exception.error_code, QuicErrorCode.FRAME_ENCODING_ERROR, ) self.assertEqual(cm.exception.frame_type, QuicFrameType.NEW_CONNECTION_ID) self.assertEqual( cm.exception.reason_phrase, "Length must be greater than 0 and less than 20", ) def test_handle_new_connection_id_with_retire_prior_to_invalid(self): with client_and_server() as (client, server): buf = new_connection_id(sequence_number=8, retire_prior_to=9) # client receives NEW_CONNECTION_ID with self.assertRaises(QuicConnectionError) as cm: client._handle_new_connection_id_frame( client_receive_context(client), QuicFrameType.NEW_CONNECTION_ID, buf, ) self.assertEqual( cm.exception.error_code, QuicErrorCode.PROTOCOL_VIOLATION, ) self.assertEqual(cm.exception.frame_type, QuicFrameType.NEW_CONNECTION_ID) self.assertEqual( cm.exception.reason_phrase, "Retire Prior To is greater than Sequence Number", ) def test_handle_new_token_frame(self): new_token = None def token_handler(token): nonlocal new_token new_token = token with client_and_server(client_kwargs={"token_handler": token_handler}) as ( client, server, ): # client receives NEW_TOKEN client._handle_new_token_frame( client_receive_context(client), QuicFrameType.NEW_TOKEN, Buffer(data=binascii.unhexlify("080102030405060708")), ) self.assertEqual(new_token, binascii.unhexlify("0102030405060708")) def test_handle_new_token_frame_from_client(self): with client_and_server() as (client, server): # server receives NEW_TOKEN with self.assertRaises(QuicConnectionError) as cm: server._handle_new_token_frame( client_receive_context(client), QuicFrameType.NEW_TOKEN, Buffer(data=binascii.unhexlify("080102030405060708")), ) self.assertEqual(cm.exception.error_code, QuicErrorCode.PROTOCOL_VIOLATION) self.assertEqual(cm.exception.frame_type, QuicFrameType.NEW_TOKEN) self.assertEqual( cm.exception.reason_phrase, "Clients must not send NEW_TOKEN frames" ) def test_handle_path_challenge_frame(self): with client_and_server() as (client, server): # client changes address and sends some data client.send_stream_data(0, b"01234567") for data, addr in client.datagrams_to_send(now=time.time()): server.receive_datagram(data, ("1.2.3.4", 2345), now=time.time()) # check paths self.assertEqual(len(server._network_paths), 2) self.assertEqual(server._network_paths[0].addr, ("1.2.3.4", 2345)) self.assertFalse(server._network_paths[0].is_validated) self.assertEqual(server._network_paths[1].addr, ("1.2.3.4", 1234)) 
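# the original handshake path (port 1234) is already validated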
self.assertTrue(server._network_paths[1].is_validated) # server sends PATH_CHALLENGE and receives PATH_RESPONSE for data, addr in server.datagrams_to_send(now=time.time()): client.receive_datagram(data, SERVER_ADDR, now=time.time()) for data, addr in client.datagrams_to_send(now=time.time()): server.receive_datagram(data, ("1.2.3.4", 2345), now=time.time()) # check paths self.assertEqual(server._network_paths[0].addr, ("1.2.3.4", 2345)) self.assertTrue(server._network_paths[0].is_validated) self.assertEqual(server._network_paths[1].addr, ("1.2.3.4", 1234)) self.assertTrue(server._network_paths[1].is_validated) def test_handle_path_challenge_response_on_different_path(self): with client_and_server() as (client, server): # client changes address and sends some data client.send_stream_data(0, b"01234567") for data, addr in client.datagrams_to_send(now=time.time()): server.receive_datagram(data, ("1.2.3.4", 2345), now=time.time()) # check paths self.assertEqual(len(server._network_paths), 2) self.assertEqual(server._network_paths[0].addr, ("1.2.3.4", 2345)) self.assertFalse(server._network_paths[0].is_validated) self.assertEqual(server._network_paths[1].addr, ("1.2.3.4", 1234)) self.assertTrue(server._network_paths[1].is_validated) # server sends PATH_CHALLENGE and receives PATH_RESPONSE on the 1234 # path instead of the expected 2345 path. for data, addr in server.datagrams_to_send(now=time.time()): client.receive_datagram(data, SERVER_ADDR, now=time.time()) for data, addr in client.datagrams_to_send(now=time.time()): server.receive_datagram(data, ("1.2.3.4", 1234), now=time.time()) # check paths; note that the order is backwards from the prior test # as receiving on 1234 promotes it to first in the list self.assertEqual(server._network_paths[0].addr, ("1.2.3.4", 1234)) self.assertTrue(server._network_paths[0].is_validated) self.assertEqual(server._network_paths[1].addr, ("1.2.3.4", 2345)) self.assertTrue(server._network_paths[1].is_validated) def test_local_path_challenges_are_bounded(self): with client_and_server() as (client, server): for i in range(MAX_LOCAL_CHALLENGES + 2): server._add_local_challenge( int.to_bytes(i, 8, "big"), QuicNetworkPath(f"1.2.3.{i}") ) self.assertEqual(len(server._local_challenges), MAX_LOCAL_CHALLENGES) for i in range(2, MAX_LOCAL_CHALLENGES + 2): self.assertEqual( server._local_challenges[int.to_bytes(i, 8, "big")].addr, f"1.2.3.{i}", ) def test_handle_path_response_frame_bad(self): with client_and_server() as (client, server): # server receives unsolicited PATH_RESPONSE with self.assertRaises(QuicConnectionError) as cm: server._handle_path_response_frame( client_receive_context(client), QuicFrameType.PATH_RESPONSE, Buffer(data=b"\x11\x22\x33\x44\x55\x66\x77\x88"), ) self.assertEqual(cm.exception.error_code, QuicErrorCode.PROTOCOL_VIOLATION) self.assertEqual(cm.exception.frame_type, QuicFrameType.PATH_RESPONSE) def test_handle_padding_frame(self): client = create_standalone_client(self) # no more padding buf = Buffer(data=b"") client._handle_padding_frame( client_receive_context(client), QuicFrameType.PADDING, buf ) self.assertEqual(buf.tell(), 0) # padding until end buf = Buffer(data=bytes(10)) client._handle_padding_frame( client_receive_context(client), QuicFrameType.PADDING, buf ) self.assertEqual(buf.tell(), 10) # padding then something else buf = Buffer(data=bytes(10) + b"\x01") client._handle_padding_frame( client_receive_context(client), QuicFrameType.PADDING, buf ) self.assertEqual(buf.tell(), 10) def test_handle_reset_stream_frame(self): stream_id = 0 
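# The RESET_STREAM payloads built inline below follow the RFC 9000 section
# 19.4 layout: Stream ID (varint), Application Protocol Error Code (varint),
# Final Size (varint). A minimal sketch of a helper that could build them
# (hypothetical, not part of the test suite):
#
#     def encode_reset_stream(stream_id: int, error_code: int, final_size: int) -> bytes:
#         return (
#             encode_uint_var(stream_id)
#             + encode_uint_var(error_code)
#             + encode_uint_var(final_size)
#         )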
with client_and_server() as (client, server): # client creates bidirectional stream client.send_stream_data(stream_id=stream_id, data=b"hello") consume_events(client) # client receives RESET_STREAM client._handle_reset_stream_frame( client_receive_context(client), QuicFrameType.RESET_STREAM, Buffer( data=encode_uint_var(stream_id) + encode_uint_var(QuicErrorCode.INTERNAL_ERROR) + encode_uint_var(0) ), ) event = client.next_event() self.assertEqual(type(event), events.StreamReset) self.assertEqual(event.error_code, QuicErrorCode.INTERNAL_ERROR) self.assertEqual(event.stream_id, stream_id) def test_handle_reset_stream_frame_final_size_error(self): stream_id = 0 with client_and_server() as (client, server): # client creates bidirectional stream client.send_stream_data(stream_id=stream_id, data=b"hello") consume_events(client) # client receives RESET_STREAM at offset 8 client._handle_reset_stream_frame( client_receive_context(client), QuicFrameType.RESET_STREAM, Buffer( data=encode_uint_var(stream_id) + encode_uint_var(QuicErrorCode.NO_ERROR) + encode_uint_var(8) ), ) event = client.next_event() self.assertEqual(type(event), events.StreamReset) self.assertEqual(event.error_code, QuicErrorCode.NO_ERROR) self.assertEqual(event.stream_id, stream_id) # client receives RESET_STREAM at offset 5 with self.assertRaises(QuicConnectionError) as cm: client._handle_reset_stream_frame( client_receive_context(client), QuicFrameType.RESET_STREAM, Buffer( data=encode_uint_var(stream_id) + encode_uint_var(QuicErrorCode.NO_ERROR) + encode_uint_var(5) ), ) self.assertEqual(cm.exception.error_code, QuicErrorCode.FINAL_SIZE_ERROR) self.assertEqual(cm.exception.frame_type, QuicFrameType.RESET_STREAM) self.assertEqual(cm.exception.reason_phrase, "Cannot change final size") def test_handle_reset_stream_frame_over_max_data(self): stream_id = 0 with client_and_server() as (client, server): # client creates bidirectional stream client.send_stream_data(stream_id=stream_id, data=b"hello") consume_events(client) # artificially raise received data counter client._local_max_data.used = client._local_max_data.value # client receives RESET_STREAM frame with self.assertRaises(QuicConnectionError) as cm: client._handle_reset_stream_frame( client_receive_context(client), QuicFrameType.RESET_STREAM, Buffer( data=encode_uint_var(stream_id) + encode_uint_var(QuicErrorCode.NO_ERROR) + encode_uint_var(1) ), ) self.assertEqual(cm.exception.error_code, QuicErrorCode.FLOW_CONTROL_ERROR) self.assertEqual(cm.exception.frame_type, QuicFrameType.RESET_STREAM) self.assertEqual(cm.exception.reason_phrase, "Over connection data limit") def test_handle_reset_stream_frame_over_max_stream_data(self): stream_id = 0 with client_and_server() as (client, server): # client creates bidirectional stream client.send_stream_data(stream_id=stream_id, data=b"hello") consume_events(client) # client receives STREAM frame with self.assertRaises(QuicConnectionError) as cm: client._handle_reset_stream_frame( client_receive_context(client), QuicFrameType.RESET_STREAM, Buffer( data=encode_uint_var(stream_id) + encode_uint_var(QuicErrorCode.NO_ERROR) + encode_uint_var(client._local_max_stream_data_bidi_local + 1) ), ) self.assertEqual(cm.exception.error_code, QuicErrorCode.FLOW_CONTROL_ERROR) self.assertEqual(cm.exception.frame_type, QuicFrameType.RESET_STREAM) self.assertEqual(cm.exception.reason_phrase, "Over stream data limit") def test_handle_reset_stream_frame_send_only(self): with client_and_server() as (client, server): # client creates unidirectional stream 
2 client.send_stream_data(stream_id=2, data=b"hello") # client receives RESET_STREAM with self.assertRaises(QuicConnectionError) as cm: client._handle_reset_stream_frame( client_receive_context(client), QuicFrameType.RESET_STREAM, Buffer(data=binascii.unhexlify("021100")), ) self.assertEqual(cm.exception.error_code, QuicErrorCode.STREAM_STATE_ERROR) self.assertEqual(cm.exception.frame_type, QuicFrameType.RESET_STREAM) self.assertEqual(cm.exception.reason_phrase, "Stream is send-only") def test_handle_reset_stream_frame_twice(self): stream_id = 3 reset_stream_data = ( encode_uint_var(QuicFrameType.RESET_STREAM) + encode_uint_var(stream_id) + encode_uint_var(QuicErrorCode.INTERNAL_ERROR) + encode_uint_var(0) ) with client_and_server() as (client, server): # server creates unidirectional stream server.send_stream_data(stream_id=stream_id, data=b"hello") roundtrip(server, client) consume_events(client) # client receives RESET_STREAM client._payload_received(client_receive_context(client), reset_stream_data) event = client.next_event() self.assertEqual(type(event), events.StreamReset) self.assertEqual(event.error_code, QuicErrorCode.INTERNAL_ERROR) self.assertEqual(event.stream_id, stream_id) # stream gets discarded self.assertEqual(drop(client), 0) # client receives RESET_STREAM again client._payload_received(client_receive_context(client), reset_stream_data) event = client.next_event() self.assertIsNone(event) def test_handle_retire_connection_id_frame(self): with client_and_server() as (client, server): self.assertEqual( sequence_numbers(client._host_cids), [0, 1, 2, 3, 4, 5, 6, 7] ) # client receives RETIRE_CONNECTION_ID client._handle_retire_connection_id_frame( client_receive_context(client), QuicFrameType.RETIRE_CONNECTION_ID, Buffer(data=b"\x02"), ) self.assertEqual( sequence_numbers(client._host_cids), [0, 1, 3, 4, 5, 6, 7, 8] ) def test_handle_retire_connection_id_frame_current_cid(self): with client_and_server() as (client, server): self.assertEqual( sequence_numbers(client._host_cids), [0, 1, 2, 3, 4, 5, 6, 7] ) # client receives RETIRE_CONNECTION_ID for the current CID with self.assertRaises(QuicConnectionError) as cm: client._handle_retire_connection_id_frame( client_receive_context(client), QuicFrameType.RETIRE_CONNECTION_ID, Buffer(data=b"\x00"), ) self.assertEqual(cm.exception.error_code, QuicErrorCode.PROTOCOL_VIOLATION) self.assertEqual( cm.exception.frame_type, QuicFrameType.RETIRE_CONNECTION_ID ) self.assertEqual( cm.exception.reason_phrase, "Cannot retire current connection ID" ) self.assertEqual( sequence_numbers(client._host_cids), [0, 1, 2, 3, 4, 5, 6, 7] ) def test_handle_retire_connection_id_frame_invalid_sequence_number(self): with client_and_server() as (client, server): self.assertEqual( sequence_numbers(client._host_cids), [0, 1, 2, 3, 4, 5, 6, 7] ) # client receives RETIRE_CONNECTION_ID with self.assertRaises(QuicConnectionError) as cm: client._handle_retire_connection_id_frame( client_receive_context(client), QuicFrameType.RETIRE_CONNECTION_ID, Buffer(data=b"\x08"), ) self.assertEqual(cm.exception.error_code, QuicErrorCode.PROTOCOL_VIOLATION) self.assertEqual( cm.exception.frame_type, QuicFrameType.RETIRE_CONNECTION_ID ) self.assertEqual( cm.exception.reason_phrase, "Cannot retire unknown connection ID" ) self.assertEqual( sequence_numbers(client._host_cids), [0, 1, 2, 3, 4, 5, 6, 7] ) def test_handle_stop_sending_frame(self): with client_and_server() as (client, server): # client creates bidirectional stream 0 client.send_stream_data(stream_id=0, data=b"hello") 
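# The raw payload below is a STOP_SENDING frame body (RFC 9000 section 19.5):
# Stream ID (varint) followed by Application Protocol Error Code (varint), so
# b"\x00\x11" targets stream 0 with error code 0x11, matching the event
# assertions that follow.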
# client receives STOP_SENDING client._handle_stop_sending_frame( client_receive_context(client), QuicFrameType.STOP_SENDING, Buffer(data=b"\x00\x11"), ) # check events self.assertEqual(type(client.next_event()), events.ProtocolNegotiated) self.assertEqual(type(client.next_event()), events.HandshakeCompleted) for i in range(7): self.assertEqual(type(client.next_event()), events.ConnectionIdIssued) event = client.next_event() self.assertEqual(type(event), events.StopSendingReceived) self.assertEqual(event.stream_id, 0) self.assertEqual(event.error_code, 0x11) self.assertIsNone(client.next_event()) def test_handle_stop_sending_frame_receive_only(self): with client_and_server() as (client, server): # server creates unidirectional stream 3 server.send_stream_data(stream_id=3, data=b"hello") # client receives STOP_SENDING with self.assertRaises(QuicConnectionError) as cm: client._handle_stop_sending_frame( client_receive_context(client), QuicFrameType.STOP_SENDING, Buffer(data=b"\x03\x11"), ) self.assertEqual(cm.exception.error_code, QuicErrorCode.STREAM_STATE_ERROR) self.assertEqual(cm.exception.frame_type, QuicFrameType.STOP_SENDING) self.assertEqual(cm.exception.reason_phrase, "Stream is receive-only") def test_handle_stream_frame_final_size_error(self): with client_and_server() as (client, server): frame_type = QuicFrameType.STREAM_BASE | 7 stream_id = 1 # client receives FIN at offset 8 client._handle_stream_frame( client_receive_context(client), frame_type, Buffer( data=encode_uint_var(stream_id) + encode_uint_var(8) + encode_uint_var(0) ), ) # client receives FIN at offset 5 with self.assertRaises(QuicConnectionError) as cm: client._handle_stream_frame( client_receive_context(client), frame_type, Buffer( data=encode_uint_var(stream_id) + encode_uint_var(5) + encode_uint_var(0) ), ) self.assertEqual(cm.exception.error_code, QuicErrorCode.FINAL_SIZE_ERROR) self.assertEqual(cm.exception.frame_type, frame_type) self.assertEqual(cm.exception.reason_phrase, "Cannot change final size") def test_handle_stream_frame_over_largest_offset(self): with client_and_server() as (client, server): # client receives offset + length > 2^62 - 1 frame_type = QuicFrameType.STREAM_BASE | 6 stream_id = 1 with self.assertRaises(QuicConnectionError) as cm: client._handle_stream_frame( client_receive_context(client), frame_type, Buffer( data=encode_uint_var(stream_id) + encode_uint_var(UINT_VAR_MAX) + encode_uint_var(1) ), ) self.assertEqual( cm.exception.error_code, QuicErrorCode.FRAME_ENCODING_ERROR ) self.assertEqual(cm.exception.frame_type, frame_type) self.assertEqual( cm.exception.reason_phrase, "offset + length cannot exceed 2^62 - 1" ) def test_handle_stream_frame_over_max_data(self): with client_and_server() as (client, server): # artificially raise received data counter client._local_max_data.used = client._local_max_data.value # client receives STREAM frame frame_type = QuicFrameType.STREAM_BASE | 4 stream_id = 1 with self.assertRaises(QuicConnectionError) as cm: client._handle_stream_frame( client_receive_context(client), frame_type, Buffer(data=encode_uint_var(stream_id) + encode_uint_var(1)), ) self.assertEqual(cm.exception.error_code, QuicErrorCode.FLOW_CONTROL_ERROR) self.assertEqual(cm.exception.frame_type, frame_type) self.assertEqual(cm.exception.reason_phrase, "Over connection data limit") def test_handle_stream_frame_over_max_stream_data(self): with client_and_server() as (client, server): # client receives STREAM frame frame_type = QuicFrameType.STREAM_BASE | 4 stream_id = 1 with 
self.assertRaises(QuicConnectionError) as cm: client._handle_stream_frame( client_receive_context(client), frame_type, Buffer( data=encode_uint_var(stream_id) + encode_uint_var(client._local_max_stream_data_bidi_remote + 1) ), ) self.assertEqual(cm.exception.error_code, QuicErrorCode.FLOW_CONTROL_ERROR) self.assertEqual(cm.exception.frame_type, frame_type) self.assertEqual(cm.exception.reason_phrase, "Over stream data limit") def test_handle_stream_frame_over_max_streams(self): with client_and_server() as (client, server): # client receives STREAM frame with self.assertRaises(QuicConnectionError) as cm: client._handle_stream_frame( client_receive_context(client), QuicFrameType.STREAM_BASE, Buffer( data=encode_uint_var(client._local_max_stream_data_uni * 4 + 3) ), ) self.assertEqual(cm.exception.error_code, QuicErrorCode.STREAM_LIMIT_ERROR) self.assertEqual(cm.exception.frame_type, QuicFrameType.STREAM_BASE) self.assertEqual(cm.exception.reason_phrase, "Too many streams open") def test_handle_stream_frame_send_only(self): with client_and_server() as (client, server): # client creates unidirectional stream 2 client.send_stream_data(stream_id=2, data=b"hello") # client receives STREAM frame with self.assertRaises(QuicConnectionError) as cm: client._handle_stream_frame( client_receive_context(client), QuicFrameType.STREAM_BASE, Buffer(data=b"\x02"), ) self.assertEqual(cm.exception.error_code, QuicErrorCode.STREAM_STATE_ERROR) self.assertEqual(cm.exception.frame_type, QuicFrameType.STREAM_BASE) self.assertEqual(cm.exception.reason_phrase, "Stream is send-only") def test_handle_stream_frame_wrong_initiator(self): with client_and_server() as (client, server): # client receives STREAM frame with self.assertRaises(QuicConnectionError) as cm: client._handle_stream_frame( client_receive_context(client), QuicFrameType.STREAM_BASE, Buffer(data=b"\x00"), ) self.assertEqual(cm.exception.error_code, QuicErrorCode.STREAM_STATE_ERROR) self.assertEqual(cm.exception.frame_type, QuicFrameType.STREAM_BASE) self.assertEqual(cm.exception.reason_phrase, "Wrong stream initiator") def test_handle_stream_data_blocked_frame(self): with client_and_server() as (client, server): # client creates bidirectional stream 0 client.send_stream_data(stream_id=0, data=b"hello") # client receives STREAM_DATA_BLOCKED client._handle_stream_data_blocked_frame( client_receive_context(client), QuicFrameType.STREAM_DATA_BLOCKED, Buffer(data=b"\x00\x01"), ) def test_handle_stream_data_blocked_frame_send_only(self): with client_and_server() as (client, server): # client creates unidirectional stream 2 client.send_stream_data(stream_id=2, data=b"hello") # client receives STREAM_DATA_BLOCKED with self.assertRaises(QuicConnectionError) as cm: client._handle_stream_data_blocked_frame( client_receive_context(client), QuicFrameType.STREAM_DATA_BLOCKED, Buffer(data=b"\x02\x01"), ) self.assertEqual(cm.exception.error_code, QuicErrorCode.STREAM_STATE_ERROR) self.assertEqual(cm.exception.frame_type, QuicFrameType.STREAM_DATA_BLOCKED) self.assertEqual(cm.exception.reason_phrase, "Stream is send-only") def test_handle_streams_blocked_uni_frame(self): with client_and_server() as (client, server): # client receives STREAMS_BLOCKED_UNI: 0 client._handle_streams_blocked_frame( client_receive_context(client), QuicFrameType.STREAMS_BLOCKED_UNI, Buffer(data=b"\x00"), ) # client receives invalid STREAMS_BLOCKED_UNI with self.assertRaises(QuicConnectionError) as cm: client._handle_streams_blocked_frame( client_receive_context(client), 
QuicFrameType.STREAMS_BLOCKED_UNI, Buffer(data=encode_uint_var(STREAM_COUNT_MAX + 1)), ) self.assertEqual( cm.exception.error_code, QuicErrorCode.FRAME_ENCODING_ERROR, ) self.assertEqual(cm.exception.frame_type, QuicFrameType.STREAMS_BLOCKED_UNI) self.assertEqual( cm.exception.reason_phrase, "Maximum Streams cannot exceed 2^60" ) def test_parse_transport_parameters(self): client = create_standalone_client(self) data = encode_transport_parameters( QuicTransportParameters( original_destination_connection_id=client.original_destination_connection_id ) ) client._parse_transport_parameters(data) def test_parse_transport_parameters_idle_time(self): # Remote idle of 10s and local idle of 60s. client = create_standalone_client(self) data = encode_transport_parameters( # Note the timeout parameter here is in milliseconds. QuicTransportParameters( original_destination_connection_id=client.original_destination_connection_id, max_idle_timeout=10000, ) ) client._parse_transport_parameters(data) self.assertAlmostEqual(client._remote_max_idle_timeout, 10.0) self.assertAlmostEqual(client._idle_timeout(), 10.0) # Remote idle of 120s and local idle of 60s. client = create_standalone_client(self) data = encode_transport_parameters( QuicTransportParameters( original_destination_connection_id=client.original_destination_connection_id, max_idle_timeout=120000, ) ) client._parse_transport_parameters(data) self.assertAlmostEqual(client._remote_max_idle_timeout, 120.0) self.assertAlmostEqual(client._idle_timeout(), 60.0) # Remote idle of 1ms and local idle of 60s. # # Very low values are clamped to 3 * PTO; we use the default initial RTT of 0.1 # and since RTT is not initialized get_probe_timeout() will return 2 * the # initial RTT as the PTO, i.e. 0.2, so 3 * PTO == 0.6. client = create_standalone_client(self) data = encode_transport_parameters( QuicTransportParameters( original_destination_connection_id=client.original_destination_connection_id, max_idle_timeout=1, ) ) client._parse_transport_parameters(data) self.assertAlmostEqual(client._remote_max_idle_timeout, 0.001) self.assertAlmostEqual(client._idle_timeout(), 0.6) def test_parse_transport_parameters_malformed(self): client = create_standalone_client(self) with self.assertRaises(QuicConnectionError) as cm: client._parse_transport_parameters(b"0") self.assertEqual( cm.exception.error_code, QuicErrorCode.TRANSPORT_PARAMETER_ERROR ) self.assertEqual(cm.exception.frame_type, QuicFrameType.CRYPTO) self.assertEqual( cm.exception.reason_phrase, "Could not parse QUIC transport parameters" ) def test_parse_transport_parameters_with_bad_ack_delay_exponent(self): client = create_standalone_client(self) data = encode_transport_parameters( QuicTransportParameters( ack_delay_exponent=21, original_destination_connection_id=client.original_destination_connection_id, ) ) with self.assertRaises(QuicConnectionError) as cm: client._parse_transport_parameters(data) self.assertEqual( cm.exception.error_code, QuicErrorCode.TRANSPORT_PARAMETER_ERROR ) self.assertEqual(cm.exception.frame_type, QuicFrameType.CRYPTO) self.assertEqual(cm.exception.reason_phrase, "ack_delay_exponent must be <= 20") def test_parse_transport_parameters_with_bad_active_connection_id_limit(self): client = create_standalone_client(self) for active_connection_id_limit in [0, 1]: data = encode_transport_parameters( QuicTransportParameters( active_connection_id_limit=active_connection_id_limit, original_destination_connection_id=client.original_destination_connection_id, ) ) with 
self.assertRaises(QuicConnectionError) as cm: client._parse_transport_parameters(data) self.assertEqual( cm.exception.error_code, QuicErrorCode.TRANSPORT_PARAMETER_ERROR ) self.assertEqual(cm.exception.frame_type, QuicFrameType.CRYPTO) self.assertEqual( cm.exception.reason_phrase, "active_connection_id_limit must be no less than 2", ) def test_parse_transport_parameters_with_bad_max_ack_delay(self): client = create_standalone_client(self) data = encode_transport_parameters( QuicTransportParameters( max_ack_delay=2**14, original_destination_connection_id=client.original_destination_connection_id, ) ) with self.assertRaises(QuicConnectionError) as cm: client._parse_transport_parameters(data) self.assertEqual( cm.exception.error_code, QuicErrorCode.TRANSPORT_PARAMETER_ERROR ) self.assertEqual(cm.exception.frame_type, QuicFrameType.CRYPTO) self.assertEqual(cm.exception.reason_phrase, "max_ack_delay must be < 2^14") def test_parse_transport_parameters_with_bad_max_udp_payload_size(self): client = create_standalone_client(self) data = encode_transport_parameters( QuicTransportParameters( max_udp_payload_size=1199, original_destination_connection_id=client.original_destination_connection_id, ) ) with self.assertRaises(QuicConnectionError) as cm: client._parse_transport_parameters(data) self.assertEqual( cm.exception.error_code, QuicErrorCode.TRANSPORT_PARAMETER_ERROR ) self.assertEqual(cm.exception.frame_type, QuicFrameType.CRYPTO) self.assertEqual( cm.exception.reason_phrase, "max_udp_payload_size must be >= 1200" ) def test_parse_transport_parameters_with_bad_initial_source_connection_id(self): client = create_standalone_client(self) client._initial_source_connection_id = binascii.unhexlify("0011223344556677") data = encode_transport_parameters( QuicTransportParameters( initial_source_connection_id=binascii.unhexlify("1122334455667788"), original_destination_connection_id=client.original_destination_connection_id, ) ) with self.assertRaises(QuicConnectionError) as cm: client._parse_transport_parameters(data) self.assertEqual( cm.exception.error_code, QuicErrorCode.TRANSPORT_PARAMETER_ERROR ) self.assertEqual(cm.exception.frame_type, QuicFrameType.CRYPTO) self.assertEqual( cm.exception.reason_phrase, "initial_source_connection_id does not match" ) def test_parse_transport_parameters_with_bad_version_information_1(self): server = create_standalone_server(self) data = encode_transport_parameters( QuicTransportParameters( version_information=QuicVersionInformation( chosen_version=QuicProtocolVersion.VERSION_1, available_versions=[QuicProtocolVersion.VERSION_2], ) ) ) with self.assertRaises(QuicConnectionError) as cm: server._parse_transport_parameters(data) self.assertEqual( cm.exception.error_code, QuicErrorCode.TRANSPORT_PARAMETER_ERROR ) self.assertEqual(cm.exception.frame_type, QuicFrameType.CRYPTO) self.assertEqual( cm.exception.reason_phrase, "version_information's chosen_version is not included in " "available_versions", ) def test_parse_transport_parameters_with_bad_version_information_2(self): server = create_standalone_server(self) data = encode_transport_parameters( QuicTransportParameters( version_information=QuicVersionInformation( chosen_version=QuicProtocolVersion.VERSION_1, available_versions=[ QuicProtocolVersion.VERSION_1, QuicProtocolVersion.VERSION_2, ], ) ) ) server._crypto_packet_version = QuicProtocolVersion.VERSION_2 with self.assertRaises(QuicConnectionError) as cm: server._parse_transport_parameters(data) self.assertEqual( cm.exception.error_code, 
QuicErrorCode.VERSION_NEGOTIATION_ERROR ) self.assertEqual(cm.exception.frame_type, QuicFrameType.CRYPTO) self.assertEqual( cm.exception.reason_phrase, "version_information's chosen_version does not match the version in use", ) def test_parse_transport_parameters_with_server_only_parameter(self): server = create_standalone_server(self) for active_connection_id_limit in [0, 1]: data = encode_transport_parameters( QuicTransportParameters( active_connection_id_limit=active_connection_id_limit, original_destination_connection_id=bytes(8), ) ) with self.assertRaises(QuicConnectionError) as cm: server._parse_transport_parameters(data) self.assertEqual( cm.exception.error_code, QuicErrorCode.TRANSPORT_PARAMETER_ERROR ) self.assertEqual(cm.exception.frame_type, QuicFrameType.CRYPTO) self.assertEqual( cm.exception.reason_phrase, "original_destination_connection_id is not allowed for clients", ) def test_payload_received_empty(self): with client_and_server() as (client, server): # client receives empty payload with self.assertRaises(QuicConnectionError) as cm: client._payload_received(client_receive_context(client), b"") self.assertEqual(cm.exception.error_code, QuicErrorCode.PROTOCOL_VIOLATION) self.assertEqual(cm.exception.frame_type, QuicFrameType.PADDING) self.assertEqual(cm.exception.reason_phrase, "Packet contains no frames") def test_payload_received_padding_only(self): with client_and_server() as (client, server): # client receives padding only is_ack_eliciting, is_probing = client._payload_received( client_receive_context(client), b"\x00" * SMALLEST_MAX_DATAGRAM_SIZE ) self.assertFalse(is_ack_eliciting) self.assertTrue(is_probing) def test_payload_received_malformed_frame_type(self): with client_and_server() as (client, server): # client receives a malformed frame type with self.assertRaises(QuicConnectionError) as cm: client._payload_received(client_receive_context(client), b"\xff") self.assertEqual( cm.exception.error_code, QuicErrorCode.FRAME_ENCODING_ERROR ) self.assertEqual(cm.exception.frame_type, None) self.assertEqual(cm.exception.reason_phrase, "Malformed frame type") def test_payload_received_unknown_frame_type(self): with client_and_server() as (client, server): # client receives unknown frame type with self.assertRaises(QuicConnectionError) as cm: client._payload_received(client_receive_context(client), b"\x1f") self.assertEqual( cm.exception.error_code, QuicErrorCode.FRAME_ENCODING_ERROR ) self.assertEqual(cm.exception.frame_type, 0x1F) self.assertEqual(cm.exception.reason_phrase, "Unknown frame type") def test_payload_received_unexpected_frame_type(self): with client_and_server() as (client, server): # client receives CRYPTO frame in 0-RTT with self.assertRaises(QuicConnectionError) as cm: client._payload_received( client_receive_context(client, epoch=tls.Epoch.ZERO_RTT), b"\x06" ) self.assertEqual(cm.exception.error_code, QuicErrorCode.PROTOCOL_VIOLATION) self.assertEqual(cm.exception.frame_type, QuicFrameType.CRYPTO) self.assertEqual(cm.exception.reason_phrase, "Unexpected frame type") def test_payload_received_malformed_frame(self): with client_and_server() as (client, server): # client receives malformed TRANSPORT_CLOSE frame with self.assertRaises(QuicConnectionError) as cm: client._payload_received( client_receive_context(client), b"\x1c\x00\x01" ) self.assertEqual( cm.exception.error_code, QuicErrorCode.FRAME_ENCODING_ERROR ) self.assertEqual(cm.exception.frame_type, 0x1C) self.assertEqual(cm.exception.reason_phrase, "Failed to parse frame") def 
test_send_max_data_blocked_by_cc(self): with client_and_server() as (client, server): # Check congestion control. We do not check the congestion # window too strictly as its exact value depends on the size # of our ACKs, which depends on the execution time. self.assertEqual(client._loss.bytes_in_flight, 0) self.assertGreaterEqual(client._loss.congestion_window, 13530) self.assertLessEqual(client._loss.congestion_window, 13540) # artificially raise received data counter client._local_max_data.used = client._local_max_data.value self.assertEqual(server._remote_max_data, 1048576) # artificially raise bytes in flight client._loss._cc.bytes_in_flight = client._loss.congestion_window # MAX_DATA is not sent due to congestion control self.assertEqual(drop(client), 0) def test_send_max_data_retransmit(self): with client_and_server() as (client, server): # artificially raise received data counter client._local_max_data.used = client._local_max_data.value self.assertEqual(client._local_max_data.sent, 1048576) self.assertEqual(client._local_max_data.used, 1048576) self.assertEqual(client._local_max_data.value, 1048576) self.assertEqual(server._remote_max_data, 1048576) # MAX_DATA is sent and lost self.assertEqual(drop(client), 1) self.assertEqual(client._local_max_data.sent, 2097152) self.assertEqual(client._local_max_data.used, 1048576) self.assertEqual(client._local_max_data.value, 2097152) self.assertEqual(server._remote_max_data, 1048576) # MAX_DATA loss is detected client._on_connection_limit_delivery( QuicDeliveryState.LOST, client._local_max_data ) self.assertEqual(client._local_max_data.sent, 0) self.assertEqual(client._local_max_data.used, 1048576) self.assertEqual(client._local_max_data.value, 2097152) # MAX_DATA is retransmitted and acked roundtrip_until_done(client, server) self.assertEqual(client._local_max_data.sent, 2097152) self.assertEqual(client._local_max_data.used, 1048576) self.assertEqual(client._local_max_data.value, 2097152) self.assertEqual(server._remote_max_data, 2097152) def test_send_max_stream_data_retransmit(self): with client_and_server() as (client, server): # client creates bidirectional stream 0 stream = client._get_or_create_stream_for_send(stream_id=0) client.send_stream_data(0, b"hello") self.assertEqual(stream.max_stream_data_local, 1048576) self.assertEqual(stream.max_stream_data_local_sent, 1048576) self.assertEqual(roundtrip(client, server), (1, 1)) # server sends data, just before raising MAX_STREAM_DATA server.send_stream_data(0, b"Z" * 524288) # 1048576 // 2 for i in range(10): roundtrip(server, client) self.assertEqual(stream.max_stream_data_local, 1048576) self.assertEqual(stream.max_stream_data_local_sent, 1048576) # server sends one more byte server.send_stream_data(0, b"Z") self.assertEqual(transfer(server, client), 1) # MAX_STREAM_DATA is sent and lost self.assertEqual(drop(client), 1) self.assertEqual(stream.max_stream_data_local, 2097152) self.assertEqual(stream.max_stream_data_local_sent, 2097152) client._on_max_stream_data_delivery(QuicDeliveryState.LOST, stream) self.assertEqual(stream.max_stream_data_local, 2097152) self.assertEqual(stream.max_stream_data_local_sent, 0) # MAX_STREAM_DATA is retransmitted and acked roundtrip_until_done(client, server) self.assertEqual(stream.max_stream_data_local, 2097152) self.assertEqual(stream.max_stream_data_local_sent, 2097152) def test_send_max_streams_retransmit(self): with client_and_server() as (client, server): # client opens 65 streams client.send_stream_data(4 * 64, b"Z") self.assertEqual(transfer(client,
server), 1) self.assertEqual(client._remote_max_streams_bidi, 128) self.assertEqual(server._local_max_streams_bidi.sent, 128) self.assertEqual(server._local_max_streams_bidi.used, 65) self.assertEqual(server._local_max_streams_bidi.value, 128) # MAX_STREAMS is sent and lost self.assertEqual(drop(server), 1) self.assertEqual(client._remote_max_streams_bidi, 128) self.assertEqual(server._local_max_streams_bidi.sent, 256) self.assertEqual(server._local_max_streams_bidi.used, 65) self.assertEqual(server._local_max_streams_bidi.value, 256) # MAX_STREAMS loss is detected server._on_connection_limit_delivery( QuicDeliveryState.LOST, server._local_max_streams_bidi ) self.assertEqual(client._remote_max_streams_bidi, 128) self.assertEqual(server._local_max_streams_bidi.sent, 0) self.assertEqual(server._local_max_streams_bidi.used, 65) self.assertEqual(server._local_max_streams_bidi.value, 256) # MAX_STREAMS is retransmitted and acked roundtrip_until_done(server, client) self.assertEqual(client._remote_max_streams_bidi, 256) self.assertEqual(server._local_max_streams_bidi.sent, 256) self.assertEqual(server._local_max_streams_bidi.used, 65) self.assertEqual(server._local_max_streams_bidi.value, 256) def test_send_ping(self): with client_and_server() as (client, server): consume_events(client) # client sends ping, server ACKs it client.send_ping(uid=12345) self.assertEqual(roundtrip(client, server), (1, 1)) # check event event = client.next_event() self.assertEqual(type(event), events.PingAcknowledged) self.assertEqual(event.uid, 12345) def test_send_ping_retransmit(self): with client_and_server() as (client, server): consume_events(client) # client sends another ping, PING is lost client.send_ping(uid=12345) self.assertEqual(drop(client), 1) # PING is retransmitted and acked client._on_ping_delivery(QuicDeliveryState.LOST, (12345,)) self.assertEqual(roundtrip(client, server), (1, 1)) # check event event = client.next_event() self.assertEqual(type(event), events.PingAcknowledged) self.assertEqual(event.uid, 12345) def test_send_reset_stream(self): with client_and_server() as (client, server): # client creates bidirectional stream client.send_stream_data(0, b"hello") self.assertEqual(roundtrip(client, server), (1, 1)) # client resets stream client.reset_stream(0, QuicErrorCode.NO_ERROR) self.assertEqual(roundtrip(client, server), (1, 1)) def test_send_stop_sending(self): with client_and_server() as (client, server): # check handshake completed self.check_handshake(client=client, server=server) # client creates bidirectional stream client.send_stream_data(0, b"hello") self.assertEqual(roundtrip(client, server), (1, 1)) # client sends STOP_SENDING frame client.stop_stream(0, QuicErrorCode.NO_ERROR) self.assertEqual(roundtrip(client, server), (1, 1)) # client receives STREAM_RESET frame event = client.next_event() self.assertEqual(type(event), events.StreamReset) self.assertEqual(event.error_code, QuicErrorCode.NO_ERROR) self.assertEqual(event.stream_id, 0) def test_send_stop_sending_uni_stream(self): with client_and_server() as (client, server): # check handshake completed self.check_handshake(client=client, server=server) # client sends STOP_SENDING frame with self.assertRaises(ValueError) as cm: client.stop_stream(2, QuicErrorCode.NO_ERROR) self.assertEqual( str(cm.exception), "Cannot stop receiving on a local-initiated unidirectional stream", ) def test_send_stop_sending_unknown_stream(self): with client_and_server() as (client, server): # check handshake completed self.check_handshake(client=client, 
server=server) # client sends STOP_SENDING frame with self.assertRaises(ValueError) as cm: client.stop_stream(0, QuicErrorCode.NO_ERROR) self.assertEqual( str(cm.exception), "Cannot stop receiving on an unknown stream" ) def test_send_stream_data_over_max_streams_bidi(self): with client_and_server() as (client, server): # create streams for i in range(128): stream_id = i * 4 client.send_stream_data(stream_id, b"") self.assertFalse(client._streams[stream_id].is_blocked) self.assertEqual(len(client._streams_blocked_bidi), 0) self.assertEqual(len(client._streams_blocked_uni), 0) self.assertEqual(roundtrip(client, server), (0, 0)) # create one too many -> STREAMS_BLOCKED stream_id = 128 * 4 client.send_stream_data(stream_id, b"") self.assertTrue(client._streams[stream_id].is_blocked) self.assertEqual(len(client._streams_blocked_bidi), 1) self.assertEqual(len(client._streams_blocked_uni), 0) self.assertEqual(roundtrip(client, server), (1, 1)) # peer raises max streams client._handle_max_streams_bidi_frame( client_receive_context(client), QuicFrameType.MAX_STREAMS_BIDI, Buffer(data=encode_uint_var(129)), ) self.assertFalse(client._streams[stream_id].is_blocked) def test_send_stream_data_over_max_streams_uni(self): with client_and_server() as (client, server): # create streams for i in range(128): stream_id = i * 4 + 2 client.send_stream_data(stream_id, b"") self.assertFalse(client._streams[stream_id].is_blocked) self.assertEqual(len(client._streams_blocked_bidi), 0) self.assertEqual(len(client._streams_blocked_uni), 0) self.assertEqual(roundtrip(client, server), (0, 0)) # create one too many -> STREAMS_BLOCKED stream_id = 128 * 4 + 2 client.send_stream_data(stream_id, b"") self.assertTrue(client._streams[stream_id].is_blocked) self.assertEqual(len(client._streams_blocked_bidi), 0) self.assertEqual(len(client._streams_blocked_uni), 1) self.assertEqual(roundtrip(client, server), (1, 1)) # peer raises max streams client._handle_max_streams_uni_frame( client_receive_context(client), QuicFrameType.MAX_STREAMS_UNI, Buffer(data=encode_uint_var(129)), ) self.assertFalse(client._streams[stream_id].is_blocked) def test_send_stream_data_peer_initiated(self): with client_and_server() as (client, server): # server creates bidirectional stream server.send_stream_data(1, b"hello") self.assertEqual(roundtrip(server, client), (1, 1)) # server creates unidirectional stream server.send_stream_data(3, b"hello") self.assertEqual(roundtrip(server, client), (1, 1)) # client creates bidirectional stream client.send_stream_data(0, b"hello") self.assertEqual(roundtrip(client, server), (1, 1)) # client sends data on server-initiated bidirectional stream client.send_stream_data(1, b"hello") self.assertEqual(roundtrip(client, server), (1, 1)) # client creates unidirectional stream client.send_stream_data(2, b"hello") self.assertEqual(roundtrip(client, server), (1, 1)) # client tries to reset server-initiated unidirectional stream with self.assertRaises(ValueError) as cm: client.reset_stream(3, QuicErrorCode.NO_ERROR) self.assertEqual( str(cm.exception), "Cannot send data on peer-initiated unidirectional stream", ) # client tries to reset unknown server-initiated bidirectional stream with self.assertRaises(ValueError) as cm: client.reset_stream(5, QuicErrorCode.NO_ERROR) self.assertEqual( str(cm.exception), "Cannot send data on unknown peer-initiated stream" ) # client tries to send data on server-initiated unidirectional stream with self.assertRaises(ValueError) as cm: client.send_stream_data(3, b"hello") self.assertEqual( 
str(cm.exception), "Cannot send data on peer-initiated unidirectional stream", ) # client tries to send data on unknown server-initiated bidirectional stream with self.assertRaises(ValueError) as cm: client.send_stream_data(5, b"hello") self.assertEqual( str(cm.exception), "Cannot send data on unknown peer-initiated stream" ) def test_stream_direction(self): with client_and_server() as (client, server): for off in [0, 4, 8]: # Client-Initiated, Bidirectional self.assertTrue(client._stream_can_receive(off)) self.assertTrue(client._stream_can_send(off)) self.assertTrue(server._stream_can_receive(off)) self.assertTrue(server._stream_can_send(off)) # Server-Initiated, Bidirectional self.assertTrue(client._stream_can_receive(off + 1)) self.assertTrue(client._stream_can_send(off + 1)) self.assertTrue(server._stream_can_receive(off + 1)) self.assertTrue(server._stream_can_send(off + 1)) # Client-Initiated, Unidirectional self.assertFalse(client._stream_can_receive(off + 2)) self.assertTrue(client._stream_can_send(off + 2)) self.assertTrue(server._stream_can_receive(off + 2)) self.assertFalse(server._stream_can_send(off + 2)) # Server-Initiated, Unidirectional self.assertTrue(client._stream_can_receive(off + 3)) self.assertFalse(client._stream_can_send(off + 3)) self.assertFalse(server._stream_can_receive(off + 3)) self.assertTrue(server._stream_can_send(off + 3)) def test_version_negotiation_fail(self): client = create_standalone_client(self) # no common version, no retry client.receive_datagram( encode_quic_version_negotiation( source_cid=client._peer_cid.cid, destination_cid=client.host_cid, supported_versions=[0x1A2A3A4A], ), SERVER_ADDR, now=time.time(), ) self.assertEqual(drop(client), 0) event = client.next_event() self.assertEqual(type(event), events.ConnectionTerminated) self.assertEqual(event.error_code, QuicErrorCode.INTERNAL_ERROR) self.assertEqual(event.frame_type, QuicFrameType.PADDING) self.assertEqual( event.reason_phrase, "Could not find a common protocol version" ) def test_version_negotiation_ignore(self): client = create_standalone_client(self) # version negotiation contains the client's version client.receive_datagram( encode_quic_version_negotiation( source_cid=client._peer_cid.cid, destination_cid=client.host_cid, supported_versions=[client._version], ), SERVER_ADDR, now=time.time(), ) self.assertEqual(drop(client), 0) def test_version_negotiation_ignore_server(self): server = create_standalone_server(self) # Servers do not expect version negotiation packets. 
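# (Only clients act on Version Negotiation packets; the server drops the
# datagram, with "unexpected_packet" as the trigger asserted below.)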
server.receive_datagram( encode_quic_version_negotiation( source_cid=server._peer_cid.cid, destination_cid=server.host_cid, supported_versions=[QuicProtocolVersion.VERSION_1], ), CLIENT_ADDR, now=time.time(), ) self.assertPacketDropped(server, "unexpected_packet") def test_version_negotiation_ok(self): client = create_standalone_client( self, supported_versions=[ QuicProtocolVersion.VERSION_1, 0x1A2A3A4A, ], ) # found a common version, retry client.receive_datagram( encode_quic_version_negotiation( source_cid=client._peer_cid.cid, destination_cid=client.host_cid, supported_versions=[0x1A2A3A4A], ), SERVER_ADDR, now=time.time(), ) self.assertEqual(drop(client), 1) def test_write_connection_close_early(self): client = create_standalone_client(self) builder = QuicPacketBuilder( host_cid=client.host_cid, is_client=True, max_datagram_size=SMALLEST_MAX_DATAGRAM_SIZE, peer_cid=client._peer_cid.cid, version=client._version, ) crypto = CryptoPair() crypto.setup_initial(client.host_cid, is_client=True, version=client._version) builder.start_packet(QuicPacketType.INITIAL, crypto) client._write_connection_close_frame( builder=builder, epoch=tls.Epoch.INITIAL, error_code=123, frame_type=None, reason_phrase="some reason", ) self.assertEqual( builder.quic_logger_frames, [ { "error_code": QuicErrorCode.APPLICATION_ERROR, "error_space": "transport", "frame_type": "connection_close", "raw_error_code": QuicErrorCode.APPLICATION_ERROR, "reason": "", "trigger_frame_type": QuicFrameType.PADDING, } ], ) class QuicNetworkPathTest(TestCase): def test_can_send(self): path = QuicNetworkPath(("1.2.3.4", 1234)) self.assertFalse(path.is_validated) # initially, cannot send any data self.assertTrue(path.can_send(0)) self.assertFalse(path.can_send(1)) # receive some data path.bytes_received += 1 self.assertTrue(path.can_send(0)) self.assertTrue(path.can_send(1)) self.assertTrue(path.can_send(2)) self.assertTrue(path.can_send(3)) self.assertFalse(path.can_send(4)) # send some data path.bytes_sent += 3 self.assertTrue(path.can_send(0)) self.assertFalse(path.can_send(1)) ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1720306884.0 aioquic-1.2.0/tests/test_crypto_v1.py0000644000175100001770000003415100000000000020457 0ustar00runnerdocker00000000000000import binascii from unittest import TestCase, skipIf from aioquic.buffer import Buffer from aioquic.quic.crypto import ( INITIAL_CIPHER_SUITE, CryptoError, CryptoPair, derive_key_iv_hp, ) from aioquic.quic.packet import PACKET_FIXED_BIT, QuicProtocolVersion from aioquic.tls import CipherSuite from .utils import SKIP_TESTS PROTOCOL_VERSION = QuicProtocolVersion.VERSION_1 # https://datatracker.ietf.org/doc/html/rfc9001#appendix-A.5 CHACHA20_CLIENT_PACKET_NUMBER = 654360564 CHACHA20_CLIENT_PLAIN_HEADER = binascii.unhexlify("4200bff4") CHACHA20_CLIENT_PLAIN_PAYLOAD = binascii.unhexlify("01") CHACHA20_CLIENT_ENCRYPTED_PACKET = binascii.unhexlify( "4cfe4189655e5cd55c41f69080575d7999c25a5bfb" ) # https://datatracker.ietf.org/doc/html/rfc9001#appendix-A.2 LONG_CLIENT_PACKET_NUMBER = 2 LONG_CLIENT_PLAIN_HEADER = binascii.unhexlify( "c300000001088394c8f03e5157080000449e00000002" ) LONG_CLIENT_PLAIN_PAYLOAD = binascii.unhexlify( "060040f1010000ed0303ebf8fa56f12939b9584a3896472ec40bb863cfd3e868" "04fe3a47f06a2b69484c00000413011302010000c000000010000e00000b6578" "616d706c652e636f6dff01000100000a00080006001d00170018001000070005" "04616c706e000500050100000000003300260024001d00209370b2c9caa47fba" 
"baf4559fedba753de171fa71f50f1ce15d43e994ec74d748002b000302030400" "0d0010000e0403050306030203080408050806002d00020101001c0002400100" "3900320408ffffffffffffffff05048000ffff07048000ffff08011001048000" "75300901100f088394c8f03e51570806048000ffff" ) + bytes(917) LONG_CLIENT_ENCRYPTED_PACKET = binascii.unhexlify( "c000000001088394c8f03e5157080000449e7b9aec34d1b1c98dd7689fb8ec11" "d242b123dc9bd8bab936b47d92ec356c0bab7df5976d27cd449f63300099f399" "1c260ec4c60d17b31f8429157bb35a1282a643a8d2262cad67500cadb8e7378c" "8eb7539ec4d4905fed1bee1fc8aafba17c750e2c7ace01e6005f80fcb7df6212" "30c83711b39343fa028cea7f7fb5ff89eac2308249a02252155e2347b63d58c5" "457afd84d05dfffdb20392844ae812154682e9cf012f9021a6f0be17ddd0c208" "4dce25ff9b06cde535d0f920a2db1bf362c23e596d11a4f5a6cf3948838a3aec" "4e15daf8500a6ef69ec4e3feb6b1d98e610ac8b7ec3faf6ad760b7bad1db4ba3" "485e8a94dc250ae3fdb41ed15fb6a8e5eba0fc3dd60bc8e30c5c4287e53805db" "059ae0648db2f64264ed5e39be2e20d82df566da8dd5998ccabdae053060ae6c" "7b4378e846d29f37ed7b4ea9ec5d82e7961b7f25a9323851f681d582363aa5f8" "9937f5a67258bf63ad6f1a0b1d96dbd4faddfcefc5266ba6611722395c906556" "be52afe3f565636ad1b17d508b73d8743eeb524be22b3dcbc2c7468d54119c74" "68449a13d8e3b95811a198f3491de3e7fe942b330407abf82a4ed7c1b311663a" "c69890f4157015853d91e923037c227a33cdd5ec281ca3f79c44546b9d90ca00" "f064c99e3dd97911d39fe9c5d0b23a229a234cb36186c4819e8b9c5927726632" "291d6a418211cc2962e20fe47feb3edf330f2c603a9d48c0fcb5699dbfe58964" "25c5bac4aee82e57a85aaf4e2513e4f05796b07ba2ee47d80506f8d2c25e50fd" "14de71e6c418559302f939b0e1abd576f279c4b2e0feb85c1f28ff18f58891ff" "ef132eef2fa09346aee33c28eb130ff28f5b766953334113211996d20011a198" "e3fc433f9f2541010ae17c1bf202580f6047472fb36857fe843b19f5984009dd" "c324044e847a4f4a0ab34f719595de37252d6235365e9b84392b061085349d73" "203a4a13e96f5432ec0fd4a1ee65accdd5e3904df54c1da510b0ff20dcc0c77f" "cb2c0e0eb605cb0504db87632cf3d8b4dae6e705769d1de354270123cb11450e" "fc60ac47683d7b8d0f811365565fd98c4c8eb936bcab8d069fc33bd801b03ade" "a2e1fbc5aa463d08ca19896d2bf59a071b851e6c239052172f296bfb5e724047" "90a2181014f3b94a4e97d117b438130368cc39dbb2d198065ae3986547926cd2" "162f40a29f0c3c8745c0f50fba3852e566d44575c29d39a03f0cda721984b6f4" "40591f355e12d439ff150aab7613499dbd49adabc8676eef023b15b65bfc5ca0" "6948109f23f350db82123535eb8a7433bdabcb909271a6ecbcb58b936a88cd4e" "8f2e6ff5800175f113253d8fa9ca8885c2f552e657dc603f252e1a8e308f76f0" "be79e2fb8f5d5fbbe2e30ecadd220723c8c0aea8078cdfcb3868263ff8f09400" "54da48781893a7e49ad5aff4af300cd804a6b6279ab3ff3afb64491c85194aab" "760d58a606654f9f4400e8b38591356fbf6425aca26dc85244259ff2b19c41b9" "f96f3ca9ec1dde434da7d2d392b905ddf3d1f9af93d1af5950bd493f5aa731b4" "056df31bd267b6b90a079831aaf579be0a39013137aac6d404f518cfd4684064" "7e78bfe706ca4cf5e9c5453e9f7cfd2b8b4c8d169a44e55c88d4a9a7f9474241" "e221af44860018ab0856972e194cd934" ) # https://datatracker.ietf.org/doc/html/rfc9001#appendix-A.3 LONG_SERVER_PACKET_NUMBER = 1 LONG_SERVER_PLAIN_HEADER = binascii.unhexlify( "c1000000010008f067a5502a4262b50040750001" ) LONG_SERVER_PLAIN_PAYLOAD = binascii.unhexlify( "02000000000600405a020000560303eefce7f7b37ba1d1632e96677825ddf739" "88cfc79825df566dc5430b9a045a1200130100002e00330024001d00209d3c94" "0d89690b84d08a60993c144eca684d1081287c834d5311bcf32bb9da1a002b00" "020304" ) LONG_SERVER_ENCRYPTED_PACKET = binascii.unhexlify( "cf000000010008f067a5502a4262b5004075c0d95a482cd0991cd25b0aac406a" "5816b6394100f37a1c69797554780bb38cc5a99f5ede4cf73c3ec2493a1839b3" "dbcba3f6ea46c5b7684df3548e7ddeb9c3bf9c73cc3f3bded74b562bfb19fb84" 
"022f8ef4cdd93795d77d06edbb7aaf2f58891850abbdca3d20398c276456cbc4" "2158407dd074ee" ) SHORT_SERVER_PACKET_NUMBER = 3 SHORT_SERVER_PLAIN_HEADER = binascii.unhexlify("41b01fd24a586a9cf30003") SHORT_SERVER_PLAIN_PAYLOAD = binascii.unhexlify( "06003904000035000151805a4bebf5000020b098c8dc4183e4c182572e10ac3e" "2b88897e0524c8461847548bd2dffa2c0ae60008002a0004ffffffff" ) SHORT_SERVER_ENCRYPTED_PACKET = binascii.unhexlify( "5db01fd24a586a9cf33dec094aaec6d6b4b7a5e15f5a3f05d06cf1ad0355c19d" "cce0807eecf7bf1c844a66e1ecd1f74b2a2d69bfd25d217833edd973246597bd" "5107ea15cb1e210045396afa602fe23432f4ab24ce251b" ) class CryptoTest(TestCase): """ Test vectors from: https://datatracker.ietf.org/doc/html/rfc9001#appendix-A """ def create_crypto(self, is_client): pair = CryptoPair() pair.setup_initial( cid=binascii.unhexlify("8394c8f03e515708"), is_client=is_client, version=PROTOCOL_VERSION, ) return pair def test_derive_key_iv_hp(self): # https://datatracker.ietf.org/doc/html/rfc9001#appendix-A.1 # client secret = binascii.unhexlify( "c00cf151ca5be075ed0ebfb5c80323c42d6b7db67881289af4008f1f6c357aea" ) key, iv, hp = derive_key_iv_hp( cipher_suite=INITIAL_CIPHER_SUITE, secret=secret, version=PROTOCOL_VERSION, ) self.assertEqual(key, binascii.unhexlify("1f369613dd76d5467730efcbe3b1a22d")) self.assertEqual(iv, binascii.unhexlify("fa044b2f42a3fd3b46fb255c")) self.assertEqual(hp, binascii.unhexlify("9f50449e04a0e810283a1e9933adedd2")) # server secret = binascii.unhexlify( "3c199828fd139efd216c155ad844cc81fb82fa8d7446fa7d78be803acdda951b" ) key, iv, hp = derive_key_iv_hp( cipher_suite=INITIAL_CIPHER_SUITE, secret=secret, version=PROTOCOL_VERSION, ) self.assertEqual(key, binascii.unhexlify("cf3a5331653c364c88f0f379b6067e37")) self.assertEqual(iv, binascii.unhexlify("0ac1493ca1905853b0bba03e")) self.assertEqual(hp, binascii.unhexlify("c206b8d9b9f0f37644430b490eeaa314")) @skipIf("chacha20" in SKIP_TESTS, "Skipping chacha20 tests") def test_derive_key_iv_hp_chacha20(self): # https://datatracker.ietf.org/doc/html/rfc9001#appendix-A.5 # server secret = binascii.unhexlify( "9ac312a7f877468ebe69422748ad00a15443f18203a07d6060f688f30f21632b" ) key, iv, hp = derive_key_iv_hp( cipher_suite=CipherSuite.CHACHA20_POLY1305_SHA256, secret=secret, version=PROTOCOL_VERSION, ) self.assertEqual( key, binascii.unhexlify( "c6d98ff3441c3fe1b2182094f69caa2ed4b716b65488960a7a984979fb23e1c8" ), ) self.assertEqual(iv, binascii.unhexlify("e0459b3474bdd0e44a41c144")) self.assertEqual( hp, binascii.unhexlify( "25a282b9e82f06f21f488917a4fc8f1b73573685608597d0efcb076b0ab7a7a4" ), ) @skipIf("chacha20" in SKIP_TESTS, "Skipping chacha20 tests") def test_decrypt_chacha20(self): pair = CryptoPair() pair.recv.setup( cipher_suite=CipherSuite.CHACHA20_POLY1305_SHA256, secret=binascii.unhexlify( "9ac312a7f877468ebe69422748ad00a15443f18203a07d6060f688f30f21632b" ), version=PROTOCOL_VERSION, ) plain_header, plain_payload, packet_number = pair.decrypt_packet( CHACHA20_CLIENT_ENCRYPTED_PACKET, 1, CHACHA20_CLIENT_PACKET_NUMBER ) self.assertEqual(plain_header, CHACHA20_CLIENT_PLAIN_HEADER) self.assertEqual(plain_payload, CHACHA20_CLIENT_PLAIN_PAYLOAD) self.assertEqual(packet_number, CHACHA20_CLIENT_PACKET_NUMBER) def test_decrypt_long_client(self): pair = self.create_crypto(is_client=False) plain_header, plain_payload, packet_number = pair.decrypt_packet( LONG_CLIENT_ENCRYPTED_PACKET, 18, 0 ) self.assertEqual(plain_header, LONG_CLIENT_PLAIN_HEADER) self.assertEqual(plain_payload, LONG_CLIENT_PLAIN_PAYLOAD) self.assertEqual(packet_number, LONG_CLIENT_PACKET_NUMBER) 
def test_decrypt_long_server(self): pair = self.create_crypto(is_client=True) plain_header, plain_payload, packet_number = pair.decrypt_packet( LONG_SERVER_ENCRYPTED_PACKET, 18, 0 ) self.assertEqual(plain_header, LONG_SERVER_PLAIN_HEADER) self.assertEqual(plain_payload, LONG_SERVER_PLAIN_PAYLOAD) self.assertEqual(packet_number, LONG_SERVER_PACKET_NUMBER) def test_decrypt_no_key(self): pair = CryptoPair() with self.assertRaises(CryptoError): pair.decrypt_packet(LONG_SERVER_ENCRYPTED_PACKET, 18, 0) def test_decrypt_short_server(self): pair = CryptoPair() pair.recv.setup( cipher_suite=INITIAL_CIPHER_SUITE, secret=binascii.unhexlify( "310281977cb8c1c1c1212d784b2d29e5a6489e23de848d370a5a2f9537f3a100" ), version=PROTOCOL_VERSION, ) plain_header, plain_payload, packet_number = pair.decrypt_packet( SHORT_SERVER_ENCRYPTED_PACKET, 9, 0 ) self.assertEqual(plain_header, SHORT_SERVER_PLAIN_HEADER) self.assertEqual(plain_payload, SHORT_SERVER_PLAIN_PAYLOAD) self.assertEqual(packet_number, SHORT_SERVER_PACKET_NUMBER) @skipIf("chacha20" in SKIP_TESTS, "Skipping chacha20 tests") def test_encrypt_chacha20(self): pair = CryptoPair() pair.send.setup( cipher_suite=CipherSuite.CHACHA20_POLY1305_SHA256, secret=binascii.unhexlify( "9ac312a7f877468ebe69422748ad00a15443f18203a07d6060f688f30f21632b" ), version=PROTOCOL_VERSION, ) packet = pair.encrypt_packet( CHACHA20_CLIENT_PLAIN_HEADER, CHACHA20_CLIENT_PLAIN_PAYLOAD, CHACHA20_CLIENT_PACKET_NUMBER, ) self.assertEqual(packet, CHACHA20_CLIENT_ENCRYPTED_PACKET) def test_encrypt_long_client(self): pair = self.create_crypto(is_client=True) packet = pair.encrypt_packet( LONG_CLIENT_PLAIN_HEADER, LONG_CLIENT_PLAIN_PAYLOAD, LONG_CLIENT_PACKET_NUMBER, ) self.assertEqual(packet, LONG_CLIENT_ENCRYPTED_PACKET) def test_encrypt_long_server(self): pair = self.create_crypto(is_client=False) packet = pair.encrypt_packet( LONG_SERVER_PLAIN_HEADER, LONG_SERVER_PLAIN_PAYLOAD, LONG_SERVER_PACKET_NUMBER, ) self.assertEqual(packet, LONG_SERVER_ENCRYPTED_PACKET) def test_encrypt_short_server(self): pair = CryptoPair() pair.send.setup( cipher_suite=INITIAL_CIPHER_SUITE, secret=binascii.unhexlify( "310281977cb8c1c1c1212d784b2d29e5a6489e23de848d370a5a2f9537f3a100" ), version=PROTOCOL_VERSION, ) packet = pair.encrypt_packet( SHORT_SERVER_PLAIN_HEADER, SHORT_SERVER_PLAIN_PAYLOAD, SHORT_SERVER_PACKET_NUMBER, ) self.assertEqual(packet, SHORT_SERVER_ENCRYPTED_PACKET) def test_key_update(self): pair1 = self.create_crypto(is_client=True) pair2 = self.create_crypto(is_client=False) def create_packet(key_phase, packet_number): buf = Buffer(capacity=100) buf.push_uint8(PACKET_FIXED_BIT | key_phase << 2 | 1) buf.push_bytes(binascii.unhexlify("8394c8f03e515708")) buf.push_uint16(packet_number) return buf.data, b"\x00\x01\x02\x03" def send(sender, receiver, packet_number=0): plain_header, plain_payload = create_packet( key_phase=sender.key_phase, packet_number=packet_number ) encrypted = sender.encrypt_packet( plain_header, plain_payload, packet_number ) recov_header, recov_payload, recov_packet_number = receiver.decrypt_packet( encrypted, len(plain_header) - 2, 0 ) self.assertEqual(recov_header, plain_header) self.assertEqual(recov_payload, plain_payload) self.assertEqual(recov_packet_number, packet_number) # roundtrip send(pair1, pair2, 0) send(pair2, pair1, 0) self.assertEqual(pair1.key_phase, 0) self.assertEqual(pair2.key_phase, 0) # pair 1 key update pair1.update_key() # roundtrip send(pair1, pair2, 1) send(pair2, pair1, 1) self.assertEqual(pair1.key_phase, 1) self.assertEqual(pair2.key_phase, 1) # 
pair 2 key update pair2.update_key() # roundtrip send(pair2, pair1, 2) send(pair1, pair2, 2) self.assertEqual(pair1.key_phase, 0) self.assertEqual(pair2.key_phase, 0) # pair 1 key - update, but not next to send pair1.update_key() # roundtrip send(pair2, pair1, 3) send(pair1, pair2, 3) self.assertEqual(pair1.key_phase, 1) self.assertEqual(pair2.key_phase, 1) ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1720306884.0 aioquic-1.2.0/tests/test_crypto_v2.py0000644000175100001770000003415100000000000020460 0ustar00runnerdocker00000000000000import binascii from unittest import TestCase, skipIf from aioquic.buffer import Buffer from aioquic.quic.crypto import ( INITIAL_CIPHER_SUITE, CryptoError, CryptoPair, derive_key_iv_hp, ) from aioquic.quic.packet import PACKET_FIXED_BIT, QuicProtocolVersion from aioquic.tls import CipherSuite from .utils import SKIP_TESTS PROTOCOL_VERSION = QuicProtocolVersion.VERSION_2 # https://datatracker.ietf.org/doc/html/rfc9369#appendix-A.5 CHACHA20_CLIENT_PACKET_NUMBER = 654360564 CHACHA20_CLIENT_PLAIN_HEADER = binascii.unhexlify("4200bff4") CHACHA20_CLIENT_PLAIN_PAYLOAD = binascii.unhexlify("01") CHACHA20_CLIENT_ENCRYPTED_PACKET = binascii.unhexlify( "5558b1c60ae7b6b932bc27d786f4bc2bb20f2162ba" ) # https://datatracker.ietf.org/doc/html/rfc9369#appendix-A.2 LONG_CLIENT_PACKET_NUMBER = 2 LONG_CLIENT_PLAIN_HEADER = binascii.unhexlify( "d36b3343cf088394c8f03e5157080000449e00000002" ) LONG_CLIENT_PLAIN_PAYLOAD = binascii.unhexlify( "060040f1010000ed0303ebf8fa56f12939b9584a3896472ec40bb863cfd3e868" "04fe3a47f06a2b69484c00000413011302010000c000000010000e00000b6578" "616d706c652e636f6dff01000100000a00080006001d00170018001000070005" "04616c706e000500050100000000003300260024001d00209370b2c9caa47fba" "baf4559fedba753de171fa71f50f1ce15d43e994ec74d748002b000302030400" "0d0010000e0403050306030203080408050806002d00020101001c0002400100" "3900320408ffffffffffffffff05048000ffff07048000ffff08011001048000" "75300901100f088394c8f03e51570806048000ffff" ) + bytes(917) LONG_CLIENT_ENCRYPTED_PACKET = binascii.unhexlify( "d76b3343cf088394c8f03e5157080000449ea0c95e82ffe67b6abcdb4298b485" "dd04de806071bf03dceebfa162e75d6c96058bdbfb127cdfcbf903388e99ad04" "9f9a3dd4425ae4d0992cfff18ecf0fdb5a842d09747052f17ac2053d21f57c5d" "250f2c4f0e0202b70785b7946e992e58a59ac52dea6774d4f03b55545243cf1a" "12834e3f249a78d395e0d18f4d766004f1a2674802a747eaa901c3f10cda5500" "cb9122faa9f1df66c392079a1b40f0de1c6054196a11cbea40afb6ef5253cd68" "18f6625efce3b6def6ba7e4b37a40f7732e093daa7d52190935b8da58976ff33" "12ae50b187c1433c0f028edcc4c2838b6a9bfc226ca4b4530e7a4ccee1bfa2a3" "d396ae5a3fb512384b2fdd851f784a65e03f2c4fbe11a53c7777c023462239dd" "6f7521a3f6c7d5dd3ec9b3f233773d4b46d23cc375eb198c63301c21801f6520" "bcfb7966fc49b393f0061d974a2706df8c4a9449f11d7f3d2dcbb90c6b877045" "636e7c0c0fe4eb0f697545460c806910d2c355f1d253bc9d2452aaa549e27a1f" "ac7cf4ed77f322e8fa894b6a83810a34b361901751a6f5eb65a0326e07de7c12" "16ccce2d0193f958bb3850a833f7ae432b65bc5a53975c155aa4bcb4f7b2c4e5" "4df16efaf6ddea94e2c50b4cd1dfe06017e0e9d02900cffe1935e0491d77ffb4" "fdf85290fdd893d577b1131a610ef6a5c32b2ee0293617a37cbb08b847741c3b" "8017c25ca9052ca1079d8b78aebd47876d330a30f6a8c6d61dd1ab5589329de7" "14d19d61370f8149748c72f132f0fc99f34d766c6938597040d8f9e2bb522ff9" "9c63a344d6a2ae8aa8e51b7b90a4a806105fcbca31506c446151adfeceb51b91" "abfe43960977c87471cf9ad4074d30e10d6a7f03c63bd5d4317f68ff325ba3bd" "80bf4dc8b52a0ba031758022eb025cdd770b44d6d6cf0670f4e990b22347a7db" 
"848265e3e5eb72dfe8299ad7481a408322cac55786e52f633b2fb6b614eaed18" "d703dd84045a274ae8bfa73379661388d6991fe39b0d93debb41700b41f90a15" "c4d526250235ddcd6776fc77bc97e7a417ebcb31600d01e57f32162a8560cacc" "7e27a096d37a1a86952ec71bd89a3e9a30a2a26162984d7740f81193e8238e61" "f6b5b984d4d3dfa033c1bb7e4f0037febf406d91c0dccf32acf423cfa1e70710" "10d3f270121b493ce85054ef58bada42310138fe081adb04e2bd901f2f13458b" "3d6758158197107c14ebb193230cd1157380aa79cae1374a7c1e5bbcb80ee23e" "06ebfde206bfb0fcbc0edc4ebec309661bdd908d532eb0c6adc38b7ca7331dce" "8dfce39ab71e7c32d318d136b6100671a1ae6a6600e3899f31f0eed19e3417d1" "34b90c9058f8632c798d4490da4987307cba922d61c39805d072b589bd52fdf1" "e86215c2d54e6670e07383a27bbffb5addf47d66aa85a0c6f9f32e59d85a44dd" "5d3b22dc2be80919b490437ae4f36a0ae55edf1d0b5cb4e9a3ecabee93dfc6e3" "8d209d0fa6536d27a5d6fbb17641cde27525d61093f1b28072d111b2b4ae5f89" "d5974ee12e5cf7d5da4d6a31123041f33e61407e76cffcdcfd7e19ba58cf4b53" "6f4c4938ae79324dc402894b44faf8afbab35282ab659d13c93f70412e85cb19" "9a37ddec600545473cfb5a05e08d0b209973b2172b4d21fb69745a262ccde96b" "a18b2faa745b6fe189cf772a9f84cbfc" ) # https://datatracker.ietf.org/doc/html/rfc9369#appendix-A.3 LONG_SERVER_PACKET_NUMBER = 1 LONG_SERVER_PLAIN_HEADER = binascii.unhexlify( "d16b3343cf0008f067a5502a4262b50040750001" ) LONG_SERVER_PLAIN_PAYLOAD = binascii.unhexlify( "02000000000600405a020000560303eefce7f7b37ba1d1632e96677825ddf739" "88cfc79825df566dc5430b9a045a1200130100002e00330024001d00209d3c94" "0d89690b84d08a60993c144eca684d1081287c834d5311bcf32bb9da1a002b00" "020304" ) LONG_SERVER_ENCRYPTED_PACKET = binascii.unhexlify( "dc6b3343cf0008f067a5502a4262b5004075d92faaf16f05d8a4398c47089698" "baeea26b91eb761d9b89237bbf87263017915358230035f7fd3945d88965cf17" "f9af6e16886c61bfc703106fbaf3cb4cfa52382dd16a393e42757507698075b2" "c984c707f0a0812d8cd5a6881eaf21ceda98f4bd23f6fe1a3e2c43edd9ce7ca8" "4bed8521e2e140" ) SHORT_SERVER_PACKET_NUMBER = 3 SHORT_SERVER_PLAIN_HEADER = binascii.unhexlify("41b01fd24a586a9cf30003") SHORT_SERVER_PLAIN_PAYLOAD = binascii.unhexlify( "06003904000035000151805a4bebf5000020b098c8dc4183e4c182572e10ac3e" "2b88897e0524c8461847548bd2dffa2c0ae60008002a0004ffffffff" ) SHORT_SERVER_ENCRYPTED_PACKET = binascii.unhexlify( "59b01fd24a586a9cf3be262d3eb9b42ada03644d223dae08cbffd5bddab1cf02" "c33711d0cf5cdc785ce55a4d95c6a82e117ba937080ac6d063915f8c4ee28bd3" "d86949197c48e8550aa32612f9af806a6c20d6d10ed08f" ) class CryptoTest(TestCase): """ Test vectors from: https://datatracker.ietf.org/doc/html/rfc9001#appendix-A """ def create_crypto(self, is_client): pair = CryptoPair() pair.setup_initial( cid=binascii.unhexlify("8394c8f03e515708"), is_client=is_client, version=PROTOCOL_VERSION, ) return pair def test_derive_key_iv_hp(self): # https://datatracker.ietf.org/doc/html/rfc9369#appendix-A.1 # client secret = binascii.unhexlify( "14ec9d6eb9fd7af83bf5a668bc17a7e283766aade7ecd0891f70f9ff7f4bf47b" ) key, iv, hp = derive_key_iv_hp( cipher_suite=INITIAL_CIPHER_SUITE, secret=secret, version=PROTOCOL_VERSION, ) self.assertEqual(key, binascii.unhexlify("8b1a0bc121284290a29e0971b5cd045d")) self.assertEqual(iv, binascii.unhexlify("91f73e2351d8fa91660e909f")) self.assertEqual(hp, binascii.unhexlify("45b95e15235d6f45a6b19cbcb0294ba9")) # server secret = binascii.unhexlify( "0263db1782731bf4588e7e4d93b7463907cb8cd8200b5da55a8bd488eafc37c1" ) key, iv, hp = derive_key_iv_hp( cipher_suite=INITIAL_CIPHER_SUITE, secret=secret, version=PROTOCOL_VERSION, ) self.assertEqual(key, binascii.unhexlify("82db637861d55e1d011f19ea71d5d2a7")) self.assertEqual(iv, 
binascii.unhexlify("dd13c276499c0249d3310652")) self.assertEqual(hp, binascii.unhexlify("edf6d05c83121201b436e16877593c3a")) @skipIf("chacha20" in SKIP_TESTS, "Skipping chacha20 tests") def test_derive_key_iv_hp_chacha20(self): # https://datatracker.ietf.org/doc/html/rfc9369#appendix-A.5 # server secret = binascii.unhexlify( "9ac312a7f877468ebe69422748ad00a15443f18203a07d6060f688f30f21632b" ) key, iv, hp = derive_key_iv_hp( cipher_suite=CipherSuite.CHACHA20_POLY1305_SHA256, secret=secret, version=PROTOCOL_VERSION, ) self.assertEqual( key, binascii.unhexlify( "3bfcddd72bcf02541d7fa0dd1f5f9eeea817e09a6963a0e6c7df0f9a1bab90f2" ), ) self.assertEqual(iv, binascii.unhexlify("a6b5bc6ab7dafce30ffff5dd")) self.assertEqual( hp, binascii.unhexlify( "d659760d2ba434a226fd37b35c69e2da8211d10c4f12538787d65645d5d1b8e2" ), ) @skipIf("chacha20" in SKIP_TESTS, "Skipping chacha20 tests") def test_decrypt_chacha20(self): pair = CryptoPair() pair.recv.setup( cipher_suite=CipherSuite.CHACHA20_POLY1305_SHA256, secret=binascii.unhexlify( "9ac312a7f877468ebe69422748ad00a15443f18203a07d6060f688f30f21632b" ), version=PROTOCOL_VERSION, ) plain_header, plain_payload, packet_number = pair.decrypt_packet( CHACHA20_CLIENT_ENCRYPTED_PACKET, 1, CHACHA20_CLIENT_PACKET_NUMBER ) self.assertEqual(plain_header, CHACHA20_CLIENT_PLAIN_HEADER) self.assertEqual(plain_payload, CHACHA20_CLIENT_PLAIN_PAYLOAD) self.assertEqual(packet_number, CHACHA20_CLIENT_PACKET_NUMBER) def test_decrypt_long_client(self): pair = self.create_crypto(is_client=False) plain_header, plain_payload, packet_number = pair.decrypt_packet( LONG_CLIENT_ENCRYPTED_PACKET, 18, 0 ) self.assertEqual(plain_header, LONG_CLIENT_PLAIN_HEADER) self.assertEqual(plain_payload, LONG_CLIENT_PLAIN_PAYLOAD) self.assertEqual(packet_number, LONG_CLIENT_PACKET_NUMBER) def test_decrypt_long_server(self): pair = self.create_crypto(is_client=True) plain_header, plain_payload, packet_number = pair.decrypt_packet( LONG_SERVER_ENCRYPTED_PACKET, 18, 0 ) self.assertEqual(plain_header, LONG_SERVER_PLAIN_HEADER) self.assertEqual(plain_payload, LONG_SERVER_PLAIN_PAYLOAD) self.assertEqual(packet_number, LONG_SERVER_PACKET_NUMBER) def test_decrypt_no_key(self): pair = CryptoPair() with self.assertRaises(CryptoError): pair.decrypt_packet(LONG_SERVER_ENCRYPTED_PACKET, 18, 0) def test_decrypt_short_server(self): pair = CryptoPair() pair.recv.setup( cipher_suite=INITIAL_CIPHER_SUITE, secret=binascii.unhexlify( "310281977cb8c1c1c1212d784b2d29e5a6489e23de848d370a5a2f9537f3a100" ), version=PROTOCOL_VERSION, ) plain_header, plain_payload, packet_number = pair.decrypt_packet( SHORT_SERVER_ENCRYPTED_PACKET, 9, 0 ) self.assertEqual(plain_header, SHORT_SERVER_PLAIN_HEADER) self.assertEqual(plain_payload, SHORT_SERVER_PLAIN_PAYLOAD) self.assertEqual(packet_number, SHORT_SERVER_PACKET_NUMBER) @skipIf("chacha20" in SKIP_TESTS, "Skipping chacha20 tests") def test_encrypt_chacha20(self): pair = CryptoPair() pair.send.setup( cipher_suite=CipherSuite.CHACHA20_POLY1305_SHA256, secret=binascii.unhexlify( "9ac312a7f877468ebe69422748ad00a15443f18203a07d6060f688f30f21632b" ), version=PROTOCOL_VERSION, ) packet = pair.encrypt_packet( CHACHA20_CLIENT_PLAIN_HEADER, CHACHA20_CLIENT_PLAIN_PAYLOAD, CHACHA20_CLIENT_PACKET_NUMBER, ) self.assertEqual(packet, CHACHA20_CLIENT_ENCRYPTED_PACKET) def test_encrypt_long_client(self): pair = self.create_crypto(is_client=True) packet = pair.encrypt_packet( LONG_CLIENT_PLAIN_HEADER, LONG_CLIENT_PLAIN_PAYLOAD, LONG_CLIENT_PACKET_NUMBER, ) self.assertEqual(packet, 
LONG_CLIENT_ENCRYPTED_PACKET) def test_encrypt_long_server(self): pair = self.create_crypto(is_client=False) packet = pair.encrypt_packet( LONG_SERVER_PLAIN_HEADER, LONG_SERVER_PLAIN_PAYLOAD, LONG_SERVER_PACKET_NUMBER, ) self.assertEqual(packet, LONG_SERVER_ENCRYPTED_PACKET) def test_encrypt_short_server(self): pair = CryptoPair() pair.send.setup( cipher_suite=INITIAL_CIPHER_SUITE, secret=binascii.unhexlify( "310281977cb8c1c1c1212d784b2d29e5a6489e23de848d370a5a2f9537f3a100" ), version=PROTOCOL_VERSION, ) packet = pair.encrypt_packet( SHORT_SERVER_PLAIN_HEADER, SHORT_SERVER_PLAIN_PAYLOAD, SHORT_SERVER_PACKET_NUMBER, ) self.assertEqual(packet, SHORT_SERVER_ENCRYPTED_PACKET) def test_key_update(self): pair1 = self.create_crypto(is_client=True) pair2 = self.create_crypto(is_client=False) def create_packet(key_phase, packet_number): buf = Buffer(capacity=100) buf.push_uint8(PACKET_FIXED_BIT | key_phase << 2 | 1) buf.push_bytes(binascii.unhexlify("8394c8f03e515708")) buf.push_uint16(packet_number) return buf.data, b"\x00\x01\x02\x03" def send(sender, receiver, packet_number=0): plain_header, plain_payload = create_packet( key_phase=sender.key_phase, packet_number=packet_number ) encrypted = sender.encrypt_packet( plain_header, plain_payload, packet_number ) recov_header, recov_payload, recov_packet_number = receiver.decrypt_packet( encrypted, len(plain_header) - 2, 0 ) self.assertEqual(recov_header, plain_header) self.assertEqual(recov_payload, plain_payload) self.assertEqual(recov_packet_number, packet_number) # roundtrip send(pair1, pair2, 0) send(pair2, pair1, 0) self.assertEqual(pair1.key_phase, 0) self.assertEqual(pair2.key_phase, 0) # pair 1 key update pair1.update_key() # roundtrip send(pair1, pair2, 1) send(pair2, pair1, 1) self.assertEqual(pair1.key_phase, 1) self.assertEqual(pair2.key_phase, 1) # pair 2 key update pair2.update_key() # roundtrip send(pair2, pair1, 2) send(pair1, pair2, 2) self.assertEqual(pair1.key_phase, 0) self.assertEqual(pair2.key_phase, 0) # pair 1 key - update, but not next to send pair1.update_key() # roundtrip send(pair2, pair1, 3) send(pair1, pair2, 3) self.assertEqual(pair1.key_phase, 1) self.assertEqual(pair2.key_phase, 1) ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1720306884.0 aioquic-1.2.0/tests/test_h0.py0000644000175100001770000001565600000000000017051 0ustar00runnerdocker00000000000000from unittest import TestCase from aioquic.h0.connection import H0_ALPN, H0Connection from aioquic.h3.events import DataReceived, HeadersReceived from aioquic.quic.events import StreamDataReceived from .test_connection import client_and_server, transfer def h0_client_and_server(): return client_and_server( client_options={"alpn_protocols": H0_ALPN}, server_options={"alpn_protocols": H0_ALPN}, ) def h0_transfer(quic_sender, h0_receiver): quic_receiver = h0_receiver._quic transfer(quic_sender, quic_receiver) # process QUIC events http_events = [] event = quic_receiver.next_event() while event is not None: http_events.extend(h0_receiver.handle_event(event)) event = quic_receiver.next_event() return http_events class H0ConnectionTest(TestCase): def test_connect(self): with h0_client_and_server() as (quic_client, quic_server): h0_client = H0Connection(quic_client) h0_server = H0Connection(quic_server) # send request stream_id = quic_client.get_next_available_stream_id() h0_client.send_headers( stream_id=stream_id, headers=[ (b":method", b"GET"), (b":scheme", b"https"), (b":authority", b"localhost"), (b":path", b"/"), ], ) 
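# With H0Connection (HTTP/0.9 over QUIC) a request is just the request line
# `<method> <path>\r\n` written on a client-initiated bidirectional stream,
# and a response is the raw body bytes -- hence the empty response headers
# asserted in these tests. A toy request-line parser for that wire format
# (illustrative only, not H0Connection's internal implementation):


def parse_h0_request_line(data: bytes):
    # Split off the request line, then split it into method and path.
    line, _, rest = data.partition(b"\r\n")
    method, _, path = line.partition(b" ")
    return [(b":method", method), (b":path", path)], rest


headers, remainder = parse_h0_request_line(b"GET /\r\n")
assert headers == [(b":method", b"GET"), (b":path", b"/")]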
            h0_client.send_data(stream_id=stream_id, data=b"", end_stream=True)

            # receive request
            events = h0_transfer(quic_client, h0_server)
            self.assertEqual(len(events), 2)
            self.assertTrue(isinstance(events[0], HeadersReceived))
            self.assertEqual(
                events[0].headers, [(b":method", b"GET"), (b":path", b"/")]
            )
            self.assertEqual(events[0].stream_id, stream_id)
            self.assertEqual(events[0].stream_ended, False)
            self.assertTrue(isinstance(events[1], DataReceived))
            self.assertEqual(events[1].data, b"")
            self.assertEqual(events[1].stream_id, stream_id)
            self.assertEqual(events[1].stream_ended, True)

            # send response
            h0_server.send_headers(
                stream_id=stream_id,
                headers=[
                    (b":status", b"200"),
                    (b"content-type", b"text/html; charset=utf-8"),
                ],
            )
            h0_server.send_data(
                stream_id=stream_id,
                data=b"hello",
                end_stream=True,
            )

            # receive response
            events = h0_transfer(quic_server, h0_client)
            self.assertEqual(len(events), 2)
            self.assertTrue(isinstance(events[0], HeadersReceived))
            self.assertEqual(events[0].headers, [])
            self.assertEqual(events[0].stream_id, stream_id)
            self.assertEqual(events[0].stream_ended, False)
            self.assertTrue(isinstance(events[1], DataReceived))
            self.assertEqual(events[1].data, b"hello")
            self.assertEqual(events[1].stream_id, stream_id)
            self.assertEqual(events[1].stream_ended, True)

    def test_headers_only(self):
        with h0_client_and_server() as (quic_client, quic_server):
            h0_client = H0Connection(quic_client)
            h0_server = H0Connection(quic_server)

            # send request
            stream_id = quic_client.get_next_available_stream_id()
            h0_client.send_headers(
                stream_id=stream_id,
                headers=[
                    (b":method", b"HEAD"),
                    (b":scheme", b"https"),
                    (b":authority", b"localhost"),
                    (b":path", b"/"),
                ],
                end_stream=True,
            )

            # receive request
            events = h0_transfer(quic_client, h0_server)
            self.assertEqual(len(events), 2)
            self.assertTrue(isinstance(events[0], HeadersReceived))
            self.assertEqual(
                events[0].headers, [(b":method", b"HEAD"), (b":path", b"/")]
            )
            self.assertEqual(events[0].stream_id, stream_id)
            self.assertEqual(events[0].stream_ended, False)
            self.assertTrue(isinstance(events[1], DataReceived))
            self.assertEqual(events[1].data, b"")
            self.assertEqual(events[1].stream_id, stream_id)
            self.assertEqual(events[1].stream_ended, True)

            # send response
            h0_server.send_headers(
                stream_id=stream_id,
                headers=[
                    (b":status", b"200"),
                    (b"content-type", b"text/html; charset=utf-8"),
                ],
                end_stream=True,
            )

            # receive response
            events = h0_transfer(quic_server, h0_client)
            self.assertEqual(len(events), 2)
            self.assertTrue(isinstance(events[0], HeadersReceived))
            self.assertEqual(events[0].headers, [])
            self.assertEqual(events[0].stream_id, stream_id)
            self.assertEqual(events[0].stream_ended, False)
            self.assertTrue(isinstance(events[1], DataReceived))
            self.assertEqual(events[1].data, b"")
            self.assertEqual(events[1].stream_id, stream_id)
            self.assertEqual(events[1].stream_ended, True)

    def test_fragmented_request(self):
        with h0_client_and_server() as (quic_client, quic_server):
            h0_server = H0Connection(quic_server)
            stream_id = 0

            # receive first fragment of the request
            events = h0_server.handle_event(
                StreamDataReceived(
                    data=b"GET /012", end_stream=False, stream_id=stream_id
                )
            )
            self.assertEqual(len(events), 0)

            # receive second fragment of the request
            events = h0_server.handle_event(
                StreamDataReceived(
                    data=b"34567890", end_stream=False, stream_id=stream_id
                )
            )

            # receive final fragment of the request
            events = h0_server.handle_event(
                StreamDataReceived(
                    data=b"123456\r\n", end_stream=True, stream_id=stream_id
                )
            )
            self.assertEqual(len(events), 2)
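# The fragmented-request test here hinges on incremental buffering: no
# events are emitted until the terminating "\r\n" arrives. A minimal sketch
# of such a buffer under that assumption (hypothetical helper, not
# aioquic's implementation):


class RequestLineBuffer:
    def __init__(self) -> None:
        self._buffer = b""

    def feed(self, data: bytes):
        # Return the completed request line once "\r\n" has arrived,
        # otherwise keep buffering and return None.
        self._buffer += data
        line, sep, rest = self._buffer.partition(b"\r\n")
        if not sep:
            return None
        self._buffer = rest
        return line


buf = RequestLineBuffer()
assert buf.feed(b"GET /012") is None
assert buf.feed(b"34567890") is None
assert buf.feed(b"123456\r\n") == b"GET /01234567890123456"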
self.assertTrue(isinstance(events[0], HeadersReceived)) self.assertEqual( events[0].headers, [(b":method", b"GET"), (b":path", b"/01234567890123456")], ) self.assertEqual(events[0].stream_id, stream_id) self.assertEqual(events[0].stream_ended, False) self.assertTrue(isinstance(events[1], DataReceived)) self.assertEqual(events[1].data, b"") self.assertEqual(events[1].stream_id, stream_id) self.assertEqual(events[1].stream_ended, True) ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1720306884.0 aioquic-1.2.0/tests/test_h3.py0000644000175100001770000021125400000000000017044 0ustar00runnerdocker00000000000000import binascii import contextlib import copy from unittest import TestCase from aioquic.buffer import Buffer, encode_uint_var from aioquic.h3.connection import ( H3_ALPN, ErrorCode, FrameType, FrameUnexpected, H3Connection, MessageError, Setting, SettingsError, StreamType, encode_frame, encode_settings, parse_settings, validate_push_promise_headers, validate_request_headers, validate_response_headers, validate_trailers, ) from aioquic.h3.events import DataReceived, HeadersReceived, PushPromiseReceived from aioquic.h3.exceptions import InvalidStreamTypeError, NoAvailablePushIDError from aioquic.quic.configuration import QuicConfiguration from aioquic.quic.events import StreamDataReceived from aioquic.quic.logger import QuicLogger from .test_connection import client_and_server, transfer DUMMY_SETTINGS = { Setting.QPACK_MAX_TABLE_CAPACITY: 4096, Setting.QPACK_BLOCKED_STREAMS: 16, Setting.DUMMY: 1, } QUIC_CONFIGURATION_OPTIONS = {"alpn_protocols": H3_ALPN} def h3_client_and_server(options=QUIC_CONFIGURATION_OPTIONS): return client_and_server( client_options=options, server_options=options, ) @contextlib.contextmanager def h3_fake_client_and_server(options=QUIC_CONFIGURATION_OPTIONS): quic_client = FakeQuicConnection( configuration=QuicConfiguration(is_client=True, **options) ) quic_server = FakeQuicConnection( configuration=QuicConfiguration(is_client=False, **options) ) # exchange transport parameters quic_client._remote_max_datagram_frame_size = ( quic_server.configuration.max_datagram_frame_size ) quic_server._remote_max_datagram_frame_size = ( quic_client.configuration.max_datagram_frame_size ) yield quic_client, quic_server def h3_transfer(quic_sender, h3_receiver): quic_receiver = h3_receiver._quic if hasattr(quic_sender, "stream_queue"): quic_receiver._events.extend(quic_sender.stream_queue) quic_sender.stream_queue.clear() else: transfer(quic_sender, quic_receiver) # process QUIC events http_events = [] event = quic_receiver.next_event() while event is not None: http_events.extend(h3_receiver.handle_event(event)) event = quic_receiver.next_event() return http_events class FakeQuicConnection: def __init__(self, configuration): self.closed = None self.configuration = configuration self.stream_queue = [] self._events = [] self._next_stream_bidi = 0 if configuration.is_client else 1 self._next_stream_uni = 2 if configuration.is_client else 3 self._quic_logger = QuicLogger().start_trace( is_client=configuration.is_client, odcid=b"" ) self._remote_max_datagram_frame_size = None def close(self, error_code, reason_phrase): self.closed = (error_code, reason_phrase) def get_next_available_stream_id(self, is_unidirectional=False): if is_unidirectional: stream_id = self._next_stream_uni self._next_stream_uni += 4 else: stream_id = self._next_stream_bidi self._next_stream_bidi += 4 return stream_id def next_event(self): try: return 
self._events.pop(0) except IndexError: return None def send_stream_data(self, stream_id, data, end_stream=False): # chop up data into individual bytes for c in data: self.stream_queue.append( StreamDataReceived( data=bytes([c]), end_stream=False, stream_id=stream_id ) ) if end_stream: self.stream_queue.append( StreamDataReceived(data=b"", end_stream=end_stream, stream_id=stream_id) ) class H3ConnectionTest(TestCase): maxDiff = None def _make_request(self, h3_client, h3_server): quic_client = h3_client._quic quic_server = h3_server._quic # send request stream_id = quic_client.get_next_available_stream_id() h3_client.send_headers( stream_id=stream_id, headers=[ (b":method", b"GET"), (b":scheme", b"https"), (b":authority", b"localhost"), (b":path", b"/"), (b"x-foo", b"client"), ], ) h3_client.send_data(stream_id=stream_id, data=b"", end_stream=True) # receive request events = h3_transfer(quic_client, h3_server) self.assertEqual( events, [ HeadersReceived( headers=[ (b":method", b"GET"), (b":scheme", b"https"), (b":authority", b"localhost"), (b":path", b"/"), (b"x-foo", b"client"), ], stream_id=stream_id, stream_ended=False, ), DataReceived(data=b"", stream_id=stream_id, stream_ended=True), ], ) # send response h3_server.send_headers( stream_id=stream_id, headers=[ (b":status", b"200"), (b"content-type", b"text/html; charset=utf-8"), (b"x-foo", b"server"), ], ) h3_server.send_data( stream_id=stream_id, data=b"hello", end_stream=True, ) # receive response events = h3_transfer(quic_server, h3_client) self.assertEqual( events, [ HeadersReceived( headers=[ (b":status", b"200"), (b"content-type", b"text/html; charset=utf-8"), (b"x-foo", b"server"), ], stream_id=stream_id, stream_ended=False, ), DataReceived( data=b"hello", stream_id=stream_id, stream_ended=True, ), ], ) def test_handle_control_frame_headers(self): """ We should not receive HEADERS on the control stream. """ quic_server = FakeQuicConnection( configuration=QuicConfiguration(is_client=False) ) h3_server = H3Connection(quic_server) self.assertIsNotNone(h3_server.sent_settings) self.assertIsNone(h3_server.received_settings) # receive SETTINGS h3_server.handle_event( StreamDataReceived( stream_id=2, data=encode_uint_var(StreamType.CONTROL) + encode_frame(FrameType.SETTINGS, encode_settings(DUMMY_SETTINGS)), end_stream=False, ) ) self.assertIsNone(quic_server.closed) self.assertIsNotNone(h3_server.sent_settings) self.assertEqual(h3_server.received_settings, DUMMY_SETTINGS) # receive unexpected HEADERS h3_server.handle_event( StreamDataReceived( stream_id=2, data=encode_frame(FrameType.HEADERS, b""), end_stream=False, ) ) self.assertEqual( quic_server.closed, (ErrorCode.H3_FRAME_UNEXPECTED, "Invalid frame type on control stream"), ) def test_handle_control_frame_max_push_id_from_client_before_settings(self): """ A server should not receive MAX_PUSH_ID before SETTINGS. """ quic_server = FakeQuicConnection( configuration=QuicConfiguration(is_client=False) ) h3_server = H3Connection(quic_server) # receive unexpected MAX_PUSH_ID h3_server.handle_event( StreamDataReceived( stream_id=2, data=encode_uint_var(StreamType.CONTROL) + encode_frame(FrameType.MAX_PUSH_ID, b""), end_stream=False, ) ) self.assertEqual( quic_server.closed, (ErrorCode.H3_MISSING_SETTINGS, ""), ) def test_handle_control_frame_max_push_id_from_server(self): """ A client should not receive MAX_PUSH_ID on the control stream. 
""" quic_client = FakeQuicConnection( configuration=QuicConfiguration(is_client=True) ) h3_client = H3Connection(quic_client) # receive SETTINGS h3_client.handle_event( StreamDataReceived( stream_id=3, data=encode_uint_var(StreamType.CONTROL) + encode_frame(FrameType.SETTINGS, encode_settings(DUMMY_SETTINGS)), end_stream=False, ) ) self.assertIsNone(quic_client.closed) # receive unexpected MAX_PUSH_ID h3_client.handle_event( StreamDataReceived( stream_id=3, data=encode_frame(FrameType.MAX_PUSH_ID, b""), end_stream=False, ) ) self.assertEqual( quic_client.closed, (ErrorCode.H3_FRAME_UNEXPECTED, "Servers must not send MAX_PUSH_ID"), ) def test_handle_control_settings_twice(self): """ We should not receive HEADERS on the control stream. """ quic_server = FakeQuicConnection( configuration=QuicConfiguration(is_client=False) ) h3_server = H3Connection(quic_server) # receive SETTINGS h3_server.handle_event( StreamDataReceived( stream_id=2, data=encode_uint_var(StreamType.CONTROL) + encode_frame(FrameType.SETTINGS, encode_settings(DUMMY_SETTINGS)), end_stream=False, ) ) self.assertIsNone(quic_server.closed) # receive unexpected SETTINGS h3_server.handle_event( StreamDataReceived( stream_id=2, data=encode_frame(FrameType.SETTINGS, encode_settings(DUMMY_SETTINGS)), end_stream=False, ) ) self.assertEqual( quic_server.closed, (ErrorCode.H3_FRAME_UNEXPECTED, "SETTINGS have already been received"), ) def test_handle_control_stream_close(self): """ Closing the control stream is not allowed. """ quic_client = FakeQuicConnection( configuration=QuicConfiguration(is_client=True) ) h3_client = H3Connection(quic_client) # receive SETTINGS h3_client.handle_event( StreamDataReceived( stream_id=3, data=encode_uint_var(StreamType.CONTROL) + encode_frame(FrameType.SETTINGS, encode_settings(DUMMY_SETTINGS)), end_stream=False, ) ) self.assertIsNone(quic_client.closed) # receive unexpected FIN h3_client.handle_event( StreamDataReceived( stream_id=3, data=b"", end_stream=True, ) ) self.assertEqual( quic_client.closed, ( ErrorCode.H3_CLOSED_CRITICAL_STREAM, "Closing control stream is not allowed", ), ) def test_handle_control_stream_duplicate(self): """ We must only receive a single control stream. """ quic_server = FakeQuicConnection( configuration=QuicConfiguration(is_client=False) ) h3_server = H3Connection(quic_server) # receive a first control stream h3_server.handle_event( StreamDataReceived( stream_id=2, data=encode_uint_var(StreamType.CONTROL), end_stream=False ) ) # receive a second control stream h3_server.handle_event( StreamDataReceived( stream_id=6, data=encode_uint_var(StreamType.CONTROL), end_stream=False ) ) self.assertEqual( quic_server.closed, ( ErrorCode.H3_STREAM_CREATION_ERROR, "Only one control stream is allowed", ), ) def test_handle_push_frame_wrong_frame_type(self): """ We should not received SETTINGS on a push stream. """ quic_client = FakeQuicConnection( configuration=QuicConfiguration(is_client=True) ) h3_client = H3Connection(quic_client) h3_client.handle_event( StreamDataReceived( stream_id=15, data=encode_uint_var(StreamType.PUSH) + encode_uint_var(0) # push ID + encode_frame(FrameType.SETTINGS, b""), end_stream=False, ) ) self.assertEqual( quic_client.closed, (ErrorCode.H3_FRAME_UNEXPECTED, "Invalid frame type on push stream"), ) def test_handle_qpack_decoder_duplicate(self): """ We must only receive a single QPACK decoder stream. 
""" quic_client = FakeQuicConnection( configuration=QuicConfiguration(is_client=True) ) h3_client = H3Connection(quic_client) # receive a first decoder stream h3_client.handle_event( StreamDataReceived( stream_id=11, data=encode_uint_var(StreamType.QPACK_DECODER), end_stream=False, ) ) # receive a second decoder stream h3_client.handle_event( StreamDataReceived( stream_id=15, data=encode_uint_var(StreamType.QPACK_DECODER), end_stream=False, ) ) self.assertEqual( quic_client.closed, ( ErrorCode.H3_STREAM_CREATION_ERROR, "Only one QPACK decoder stream is allowed", ), ) def test_handle_qpack_decoder_stream_error(self): """ Receiving garbage on the QPACK decoder stream triggers an exception. """ quic_client = FakeQuicConnection( configuration=QuicConfiguration(is_client=True) ) h3_client = H3Connection(quic_client) h3_client.handle_event( StreamDataReceived( stream_id=11, data=encode_uint_var(StreamType.QPACK_DECODER) + b"\x00", end_stream=False, ) ) self.assertEqual(quic_client.closed, (ErrorCode.QPACK_DECODER_STREAM_ERROR, "")) def test_handle_qpack_encoder_duplicate(self): """ We must only receive a single QPACK encoder stream. """ quic_client = FakeQuicConnection( configuration=QuicConfiguration(is_client=True) ) h3_client = H3Connection(quic_client) # receive a first encoder stream h3_client.handle_event( StreamDataReceived( stream_id=11, data=encode_uint_var(StreamType.QPACK_ENCODER), end_stream=False, ) ) # receive a second encoder stream h3_client.handle_event( StreamDataReceived( stream_id=15, data=encode_uint_var(StreamType.QPACK_ENCODER), end_stream=False, ) ) self.assertEqual( quic_client.closed, ( ErrorCode.H3_STREAM_CREATION_ERROR, "Only one QPACK encoder stream is allowed", ), ) def test_handle_qpack_encoder_stream_error(self): """ Receiving garbage on the QPACK encoder stream triggers an exception. """ quic_client = FakeQuicConnection( configuration=QuicConfiguration(is_client=True) ) h3_client = H3Connection(quic_client) h3_client.handle_event( StreamDataReceived( stream_id=7, data=encode_uint_var(StreamType.QPACK_ENCODER) + b"\x00", end_stream=False, ) ) self.assertEqual(quic_client.closed, (ErrorCode.QPACK_ENCODER_STREAM_ERROR, "")) def test_handle_request_frame_bad_headers(self): """ We should not receive HEADERS which cannot be decoded. """ quic_server = FakeQuicConnection( configuration=QuicConfiguration(is_client=False) ) h3_server = H3Connection(quic_server) h3_server.handle_event( StreamDataReceived( stream_id=0, data=encode_frame(FrameType.HEADERS, b""), end_stream=False ) ) self.assertEqual(quic_server.closed, (ErrorCode.QPACK_DECOMPRESSION_FAILED, "")) def test_handle_request_frame_data_before_headers(self): """ We should not receive DATA before receiving headers. """ quic_server = FakeQuicConnection( configuration=QuicConfiguration(is_client=False) ) h3_server = H3Connection(quic_server) h3_server.handle_event( StreamDataReceived( stream_id=0, data=encode_frame(FrameType.DATA, b""), end_stream=False ) ) self.assertEqual( quic_server.closed, ( ErrorCode.H3_FRAME_UNEXPECTED, "DATA frame is not allowed in this state", ), ) def test_handle_request_frame_headers_after_trailers(self): """ We should not receive HEADERS after receiving trailers. 
""" with h3_fake_client_and_server() as (quic_client, quic_server): h3_client = H3Connection(quic_client) h3_server = H3Connection(quic_server) stream_id = quic_client.get_next_available_stream_id() h3_client.send_headers( stream_id=stream_id, headers=[ (b":method", b"GET"), (b":scheme", b"https"), (b":authority", b"localhost"), (b":path", b"/"), ], ) h3_client.send_headers( stream_id=stream_id, headers=[(b"x-some-trailer", b"foo")], end_stream=True, ) h3_transfer(quic_client, h3_server) h3_server.handle_event( StreamDataReceived( stream_id=0, data=encode_frame(FrameType.HEADERS, b""), end_stream=False, ) ) self.assertEqual( quic_server.closed, ( ErrorCode.H3_FRAME_UNEXPECTED, "HEADERS frame is not allowed in this state", ), ) def test_handle_request_frame_push_promise_from_client(self): """ A server should not receive PUSH_PROMISE on a request stream. """ quic_server = FakeQuicConnection( configuration=QuicConfiguration(is_client=False) ) h3_server = H3Connection(quic_server) h3_server.handle_event( StreamDataReceived( stream_id=0, data=encode_frame(FrameType.PUSH_PROMISE, b""), end_stream=False, ) ) self.assertEqual( quic_server.closed, (ErrorCode.H3_FRAME_UNEXPECTED, "Clients must not send PUSH_PROMISE"), ) def test_handle_request_frame_wrong_frame_type(self): quic_server = FakeQuicConnection( configuration=QuicConfiguration(is_client=False) ) h3_server = H3Connection(quic_server) h3_server.handle_event( StreamDataReceived( stream_id=0, data=encode_frame(FrameType.SETTINGS, b""), end_stream=False, ) ) self.assertEqual( quic_server.closed, (ErrorCode.H3_FRAME_UNEXPECTED, "Invalid frame type on request stream"), ) def test_request(self): with h3_client_and_server() as (quic_client, quic_server): h3_client = H3Connection(quic_client) h3_server = H3Connection(quic_server) # make first request self._make_request(h3_client, h3_server) # make second request self._make_request(h3_client, h3_server) # make third request -> dynamic table self._make_request(h3_client, h3_server) def test_request_headers_only(self): with h3_client_and_server() as (quic_client, quic_server): h3_client = H3Connection(quic_client) h3_server = H3Connection(quic_server) # send request stream_id = quic_client.get_next_available_stream_id() h3_client.send_headers( stream_id=stream_id, headers=[ (b":method", b"HEAD"), (b":scheme", b"https"), (b":authority", b"localhost"), (b":path", b"/"), (b"x-foo", b"client"), ], end_stream=True, ) # receive request events = h3_transfer(quic_client, h3_server) self.assertEqual( events, [ HeadersReceived( headers=[ (b":method", b"HEAD"), (b":scheme", b"https"), (b":authority", b"localhost"), (b":path", b"/"), (b"x-foo", b"client"), ], stream_id=stream_id, stream_ended=True, ) ], ) # send response h3_server.send_headers( stream_id=stream_id, headers=[ (b":status", b"200"), (b"content-type", b"text/html; charset=utf-8"), (b"x-foo", b"server"), ], end_stream=True, ) # receive response events = h3_transfer(quic_server, h3_client) self.assertEqual( events, [ HeadersReceived( headers=[ (b":status", b"200"), (b"content-type", b"text/html; charset=utf-8"), (b"x-foo", b"server"), ], stream_id=stream_id, stream_ended=True, ) ], ) def test_request_fragmented_frame(self): with h3_fake_client_and_server() as (quic_client, quic_server): h3_client = H3Connection(quic_client) h3_server = H3Connection(quic_server) # send request stream_id = quic_client.get_next_available_stream_id() h3_client.send_headers( stream_id=stream_id, headers=[ (b":method", b"GET"), (b":scheme", b"https"), (b":authority", 
b"localhost"), (b":path", b"/"), (b"x-foo", b"client"), ], ) h3_client.send_data(stream_id=stream_id, data=b"hello", end_stream=True) # receive request events = h3_transfer(quic_client, h3_server) self.assertEqual( events, [ HeadersReceived( headers=[ (b":method", b"GET"), (b":scheme", b"https"), (b":authority", b"localhost"), (b":path", b"/"), (b"x-foo", b"client"), ], stream_id=stream_id, stream_ended=False, ), DataReceived(data=b"h", stream_id=0, stream_ended=False), DataReceived(data=b"e", stream_id=0, stream_ended=False), DataReceived(data=b"l", stream_id=0, stream_ended=False), DataReceived(data=b"l", stream_id=0, stream_ended=False), DataReceived(data=b"o", stream_id=0, stream_ended=False), DataReceived(data=b"", stream_id=0, stream_ended=True), ], ) # send push promise push_stream_id = h3_server.send_push_promise( stream_id=stream_id, headers=[ (b":method", b"GET"), (b":scheme", b"https"), (b":authority", b"localhost"), (b":path", b"/app.txt"), ], ) self.assertEqual(push_stream_id, 15) # send response h3_server.send_headers( stream_id=stream_id, headers=[ (b":status", b"200"), (b"content-type", b"text/html; charset=utf-8"), ], end_stream=False, ) h3_server.send_data(stream_id=stream_id, data=b"html", end_stream=True) #  fulfill push promise h3_server.send_headers( stream_id=push_stream_id, headers=[(b":status", b"200"), (b"content-type", b"text/plain")], end_stream=False, ) h3_server.send_data(stream_id=push_stream_id, data=b"text", end_stream=True) # receive push promise / response events = h3_transfer(quic_server, h3_client) self.assertEqual( events, [ PushPromiseReceived( headers=[ (b":method", b"GET"), (b":scheme", b"https"), (b":authority", b"localhost"), (b":path", b"/app.txt"), ], push_id=0, stream_id=stream_id, ), HeadersReceived( headers=[ (b":status", b"200"), (b"content-type", b"text/html; charset=utf-8"), ], stream_id=0, stream_ended=False, ), DataReceived(data=b"h", stream_id=0, stream_ended=False), DataReceived(data=b"t", stream_id=0, stream_ended=False), DataReceived(data=b"m", stream_id=0, stream_ended=False), DataReceived(data=b"l", stream_id=0, stream_ended=False), DataReceived(data=b"", stream_id=0, stream_ended=True), HeadersReceived( headers=[ (b":status", b"200"), (b"content-type", b"text/plain"), ], stream_id=15, stream_ended=False, push_id=0, ), DataReceived( data=b"t", stream_id=15, stream_ended=False, push_id=0 ), DataReceived( data=b"e", stream_id=15, stream_ended=False, push_id=0 ), DataReceived( data=b"x", stream_id=15, stream_ended=False, push_id=0 ), DataReceived( data=b"t", stream_id=15, stream_ended=False, push_id=0 ), DataReceived(data=b"", stream_id=15, stream_ended=True, push_id=0), ], ) def test_request_with_server_push(self): with h3_client_and_server() as (quic_client, quic_server): h3_client = H3Connection(quic_client) h3_server = H3Connection(quic_server) # send request stream_id = quic_client.get_next_available_stream_id() h3_client.send_headers( stream_id=stream_id, headers=[ (b":method", b"GET"), (b":scheme", b"https"), (b":authority", b"localhost"), (b":path", b"/"), ], end_stream=True, ) # receive request events = h3_transfer(quic_client, h3_server) self.assertEqual( events, [ HeadersReceived( headers=[ (b":method", b"GET"), (b":scheme", b"https"), (b":authority", b"localhost"), (b":path", b"/"), ], stream_id=stream_id, stream_ended=True, ) ], ) # send push promises push_stream_id_css = h3_server.send_push_promise( stream_id=stream_id, headers=[ (b":method", b"GET"), (b":scheme", b"https"), (b":authority", b"localhost"), (b":path", 
b"/app.css"), ], ) self.assertEqual(push_stream_id_css, 15) push_stream_id_js = h3_server.send_push_promise( stream_id=stream_id, headers=[ (b":method", b"GET"), (b":scheme", b"https"), (b":authority", b"localhost"), (b":path", b"/app.js"), ], ) self.assertEqual(push_stream_id_js, 19) # send response h3_server.send_headers( stream_id=stream_id, headers=[ (b":status", b"200"), (b"content-type", b"text/html; charset=utf-8"), ], end_stream=False, ) h3_server.send_data( stream_id=stream_id, data=b"hello", end_stream=True, ) #  fulfill push promises h3_server.send_headers( stream_id=push_stream_id_css, headers=[(b":status", b"200"), (b"content-type", b"text/css")], end_stream=False, ) h3_server.send_data( stream_id=push_stream_id_css, data=b"body { color: pink }", end_stream=True, ) h3_server.send_headers( stream_id=push_stream_id_js, headers=[ (b":status", b"200"), (b"content-type", b"application/javascript"), ], end_stream=False, ) h3_server.send_data( stream_id=push_stream_id_js, data=b"alert('howdee');", end_stream=True ) # receive push promises, response and push responses events = h3_transfer(quic_server, h3_client) self.assertEqual( events, [ PushPromiseReceived( headers=[ (b":method", b"GET"), (b":scheme", b"https"), (b":authority", b"localhost"), (b":path", b"/app.css"), ], push_id=0, stream_id=stream_id, ), PushPromiseReceived( headers=[ (b":method", b"GET"), (b":scheme", b"https"), (b":authority", b"localhost"), (b":path", b"/app.js"), ], push_id=1, stream_id=stream_id, ), HeadersReceived( headers=[ (b":status", b"200"), (b"content-type", b"text/html; charset=utf-8"), ], stream_id=stream_id, stream_ended=False, ), DataReceived( data=b"hello", stream_id=stream_id, stream_ended=True, ), HeadersReceived( headers=[(b":status", b"200"), (b"content-type", b"text/css")], push_id=0, stream_id=push_stream_id_css, stream_ended=False, ), DataReceived( data=b"body { color: pink }", push_id=0, stream_id=push_stream_id_css, stream_ended=True, ), HeadersReceived( headers=[ (b":status", b"200"), (b"content-type", b"application/javascript"), ], push_id=1, stream_id=push_stream_id_js, stream_ended=False, ), DataReceived( data=b"alert('howdee');", push_id=1, stream_id=push_stream_id_js, stream_ended=True, ), ], ) def test_request_with_server_push_max_push_id(self): with h3_client_and_server() as (quic_client, quic_server): h3_client = H3Connection(quic_client) h3_server = H3Connection(quic_server) # send request stream_id = quic_client.get_next_available_stream_id() h3_client.send_headers( stream_id=stream_id, headers=[ (b":method", b"GET"), (b":scheme", b"https"), (b":authority", b"localhost"), (b":path", b"/"), ], end_stream=True, ) # receive request events = h3_transfer(quic_client, h3_server) self.assertEqual( events, [ HeadersReceived( headers=[ (b":method", b"GET"), (b":scheme", b"https"), (b":authority", b"localhost"), (b":path", b"/"), ], stream_id=stream_id, stream_ended=True, ) ], ) # send push promise on a server-initiated stream with self.assertRaises(InvalidStreamTypeError): h3_server.send_push_promise( stream_id=1, headers=[ (b":method", b"GET"), (b":scheme", b"https"), (b":authority", b"localhost"), (b":path", b"/bad.css"), ], ) # send push promises for i in range(0, 8): h3_server.send_push_promise( stream_id=stream_id, headers=[ (b":method", b"GET"), (b":scheme", b"https"), (b":authority", b"localhost"), (b":path", "/{}.css".format(i).encode("ascii")), ], ) # send one too many with self.assertRaises(NoAvailablePushIDError): h3_server.send_push_promise( stream_id=stream_id, headers=[ 
(b":method", b"GET"), (b":scheme", b"https"), (b":authority", b"localhost"), (b":path", b"/8.css"), ], ) def test_send_data_after_trailers(self): """ We should not send DATA after trailers. """ quic_client = FakeQuicConnection( configuration=QuicConfiguration(is_client=True) ) h3_client = H3Connection(quic_client) stream_id = quic_client.get_next_available_stream_id() h3_client.send_headers( stream_id=stream_id, headers=[ (b":method", b"GET"), (b":scheme", b"https"), (b":authority", b"localhost"), (b":path", b"/"), ], ) h3_client.send_headers( stream_id=stream_id, headers=[(b"x-some-trailer", b"foo")], end_stream=False ) with self.assertRaises(FrameUnexpected): h3_client.send_data(stream_id=stream_id, data=b"hello", end_stream=False) def test_send_data_before_headers(self): """ We should not send DATA before headers. """ quic_client = FakeQuicConnection( configuration=QuicConfiguration(is_client=True) ) h3_client = H3Connection(quic_client) stream_id = quic_client.get_next_available_stream_id() with self.assertRaises(FrameUnexpected): h3_client.send_data(stream_id=stream_id, data=b"hello", end_stream=False) def test_send_headers_after_trailers(self): """ We should not send HEADERS after trailers. """ quic_client = FakeQuicConnection( configuration=QuicConfiguration(is_client=True) ) h3_client = H3Connection(quic_client) stream_id = quic_client.get_next_available_stream_id() h3_client.send_headers( stream_id=stream_id, headers=[ (b":method", b"GET"), (b":scheme", b"https"), (b":authority", b"localhost"), (b":path", b"/"), ], ) h3_client.send_headers( stream_id=stream_id, headers=[(b"x-some-trailer", b"foo")], end_stream=False ) with self.assertRaises(FrameUnexpected): h3_client.send_headers( stream_id=stream_id, headers=[(b"x-other-trailer", b"foo")], end_stream=False, ) def test_blocked_stream(self): quic_client = FakeQuicConnection( configuration=QuicConfiguration(is_client=True) ) h3_client = H3Connection(quic_client) h3_client.handle_event( StreamDataReceived( stream_id=3, data=binascii.unhexlify( "0004170150000680020000074064091040bcc0000000faceb00c" ), end_stream=False, ) ) h3_client.handle_event( StreamDataReceived(stream_id=7, data=b"\x02", end_stream=False) ) h3_client.handle_event( StreamDataReceived(stream_id=11, data=b"\x03", end_stream=False) ) h3_client.handle_event( StreamDataReceived( stream_id=0, data=binascii.unhexlify("01040280d910"), end_stream=False ) ) h3_client.handle_event( StreamDataReceived( stream_id=0, data=binascii.unhexlify( "00408d796f752072656163686564206d766673742e6e65742c20726561636820" "746865202f6563686f20656e64706f696e7420666f7220616e206563686f2072" "6573706f6e7365207175657279202f3c6e756d6265723e20656e64706f696e74" "7320666f722061207661726961626c652073697a6520726573706f6e73652077" "6974682072616e646f6d206279746573" ), end_stream=True, ) ) self.assertEqual( h3_client.handle_event( StreamDataReceived( stream_id=7, data=binascii.unhexlify( "3fe101c696d07abe941094cb6d0a08017d403971966e32ca98b46f" ), end_stream=False, ) ), [ HeadersReceived( headers=[ (b":status", b"200"), (b"date", b"Mon, 22 Jul 2019 06:33:33 GMT"), ], stream_id=0, stream_ended=False, ), DataReceived( data=( b"you reached mvfst.net, reach the /echo endpoint for an " b"echo response query / endpoints for a variable " b"size response with random bytes" ), stream_id=0, stream_ended=True, ), ], ) def test_blocked_stream_trailer(self): quic_client = FakeQuicConnection( configuration=QuicConfiguration(is_client=True) ) h3_client = H3Connection(quic_client) h3_client.handle_event( 
StreamDataReceived( stream_id=3, data=binascii.unhexlify( "0004170150000680020000074064091040bcc0000000faceb00c" ), end_stream=False, ) ) h3_client.handle_event( StreamDataReceived(stream_id=7, data=b"\x02", end_stream=False) ) h3_client.handle_event( StreamDataReceived(stream_id=11, data=b"\x03", end_stream=False) ) self.assertEqual( h3_client.handle_event( StreamDataReceived( stream_id=0, data=binascii.unhexlify( "011b0000d95696d07abe941094cb6d0a08017d403971966e32ca98b46f" ), end_stream=False, ) ), [ HeadersReceived( headers=[ (b":status", b"200"), (b"date", b"Mon, 22 Jul 2019 06:33:33 GMT"), ], stream_id=0, stream_ended=False, ) ], ) self.assertEqual( h3_client.handle_event( StreamDataReceived( stream_id=0, data=binascii.unhexlify( "00408d796f752072656163686564206d766673742e6e65742c20726561636820" "746865202f6563686f20656e64706f696e7420666f7220616e206563686f2072" "6573706f6e7365207175657279202f3c6e756d6265723e20656e64706f696e74" "7320666f722061207661726961626c652073697a6520726573706f6e73652077" "6974682072616e646f6d206279746573" ), end_stream=False, ) ), [ DataReceived( data=( b"you reached mvfst.net, reach the /echo endpoint for an " b"echo response query / endpoints for a variable " b"size response with random bytes" ), stream_id=0, stream_ended=False, ) ], ) self.assertEqual( h3_client.handle_event( StreamDataReceived( stream_id=0, data=binascii.unhexlify("0103028010"), end_stream=True ) ), [], ) self.assertEqual( h3_client.handle_event( StreamDataReceived( stream_id=7, data=binascii.unhexlify("6af2b20f49564d833505b38294e7"), end_stream=False, ) ), [ HeadersReceived( headers=[(b"x-some-trailer", b"foo")], stream_id=0, stream_ended=True, push_id=None, ) ], ) def test_uni_stream_grease(self): with h3_client_and_server() as (quic_client, quic_server): h3_server = H3Connection(quic_server) quic_client.send_stream_data( 14, b"\xff\xff\xff\xff\xff\xff\xff\xfeGREASE is the word" ) self.assertEqual(h3_transfer(quic_client, h3_server), []) def test_request_with_trailers(self): with h3_client_and_server() as (quic_client, quic_server): h3_client = H3Connection(quic_client) h3_server = H3Connection(quic_server) # send request with trailers stream_id = quic_client.get_next_available_stream_id() h3_client.send_headers( stream_id=stream_id, headers=[ (b":method", b"GET"), (b":scheme", b"https"), (b":authority", b"localhost"), (b":path", b"/"), ], end_stream=False, ) h3_client.send_headers( stream_id=stream_id, headers=[(b"x-some-trailer", b"foo")], end_stream=True, ) # receive request events = h3_transfer(quic_client, h3_server) self.assertEqual( events, [ HeadersReceived( headers=[ (b":method", b"GET"), (b":scheme", b"https"), (b":authority", b"localhost"), (b":path", b"/"), ], stream_id=stream_id, stream_ended=False, ), HeadersReceived( headers=[(b"x-some-trailer", b"foo")], stream_id=stream_id, stream_ended=True, ), ], ) # send response h3_server.send_headers( stream_id=stream_id, headers=[ (b":status", b"200"), (b"content-type", b"text/html; charset=utf-8"), ], end_stream=False, ) h3_server.send_data( stream_id=stream_id, data=b"hello", end_stream=False, ) h3_server.send_headers( stream_id=stream_id, headers=[(b"x-some-trailer", b"bar")], end_stream=True, ) # receive response events = h3_transfer(quic_server, h3_client) self.assertEqual( events, [ HeadersReceived( headers=[ (b":status", b"200"), (b"content-type", b"text/html; charset=utf-8"), ], stream_id=stream_id, stream_ended=False, ), DataReceived( data=b"hello", stream_id=stream_id, stream_ended=False, ), HeadersReceived( 
headers=[(b"x-some-trailer", b"bar")], stream_id=stream_id, stream_ended=True, ), ], ) def test_uni_stream_type(self): with h3_client_and_server() as (quic_client, quic_server): h3_server = H3Connection(quic_server) # unknown stream type 9 stream_id = quic_client.get_next_available_stream_id(is_unidirectional=True) self.assertEqual(stream_id, 2) quic_client.send_stream_data(stream_id, b"\x09") self.assertEqual(h3_transfer(quic_client, h3_server), []) self.assertEqual(list(h3_server._stream.keys()), [2]) self.assertEqual(h3_server._stream[2].buffer, b"") self.assertEqual(h3_server._stream[2].stream_type, 9) # unknown stream type 64, one byte at a time stream_id = quic_client.get_next_available_stream_id(is_unidirectional=True) self.assertEqual(stream_id, 6) quic_client.send_stream_data(stream_id, b"\x40") self.assertEqual(h3_transfer(quic_client, h3_server), []) self.assertEqual(list(h3_server._stream.keys()), [2, 6]) self.assertEqual(h3_server._stream[2].buffer, b"") self.assertEqual(h3_server._stream[2].stream_type, 9) self.assertEqual(h3_server._stream[6].buffer, b"\x40") self.assertEqual(h3_server._stream[6].stream_type, None) quic_client.send_stream_data(stream_id, b"\x40") self.assertEqual(h3_transfer(quic_client, h3_server), []) self.assertEqual(list(h3_server._stream.keys()), [2, 6]) self.assertEqual(h3_server._stream[2].buffer, b"") self.assertEqual(h3_server._stream[2].stream_type, 9) self.assertEqual(h3_server._stream[6].buffer, b"") self.assertEqual(h3_server._stream[6].stream_type, 64) def test_validate_settings_h3_datagram_invalid_value(self): quic_server = FakeQuicConnection( configuration=QuicConfiguration(is_client=False) ) h3_server = H3Connection(quic_server) # receive SETTINGS with an invalid H3_DATAGRAM value settings = copy.copy(DUMMY_SETTINGS) settings[Setting.H3_DATAGRAM] = 2 h3_server.handle_event( StreamDataReceived( stream_id=2, data=encode_uint_var(StreamType.CONTROL) + encode_frame(FrameType.SETTINGS, encode_settings(settings)), end_stream=False, ) ) self.assertEqual( quic_server.closed, ( ErrorCode.H3_SETTINGS_ERROR, "H3_DATAGRAM setting must be 0 or 1", ), ) def test_validate_settings_h3_datagram_without_transport_parameter(self): quic_server = FakeQuicConnection( configuration=QuicConfiguration(is_client=False) ) h3_server = H3Connection(quic_server) # receive SETTINGS with H3_DATAGRAM=1 but no max_datagram_frame_size TP settings = copy.copy(DUMMY_SETTINGS) settings[Setting.H3_DATAGRAM] = 1 h3_server.handle_event( StreamDataReceived( stream_id=2, data=encode_uint_var(StreamType.CONTROL) + encode_frame(FrameType.SETTINGS, encode_settings(settings)), end_stream=False, ) ) self.assertEqual( quic_server.closed, ( ErrorCode.H3_SETTINGS_ERROR, "H3_DATAGRAM requires max_datagram_frame_size transport parameter", ), ) def test_validate_settings_enable_connect_protocol_invalid_value(self): quic_server = FakeQuicConnection( configuration=QuicConfiguration(is_client=False) ) h3_server = H3Connection(quic_server) # receive SETTINGS with an invalid ENABLE_CONNECT_PROTOCOL value settings = copy.copy(DUMMY_SETTINGS) settings[Setting.ENABLE_CONNECT_PROTOCOL] = 2 h3_server.handle_event( StreamDataReceived( stream_id=2, data=encode_uint_var(StreamType.CONTROL) + encode_frame(FrameType.SETTINGS, encode_settings(settings)), end_stream=False, ) ) self.assertEqual( quic_server.closed, ( ErrorCode.H3_SETTINGS_ERROR, "ENABLE_CONNECT_PROTOCOL setting must be 0 or 1", ), ) def test_validate_settings_enable_webtransport_invalid_value(self): quic_server = FakeQuicConnection( 
configuration=QuicConfiguration(is_client=False) ) h3_server = H3Connection(quic_server) # receive SETTINGS with an invalid ENABLE_WEBTRANSPORT value settings = copy.copy(DUMMY_SETTINGS) settings[Setting.ENABLE_WEBTRANSPORT] = 2 h3_server.handle_event( StreamDataReceived( stream_id=2, data=encode_uint_var(StreamType.CONTROL) + encode_frame(FrameType.SETTINGS, encode_settings(settings)), end_stream=False, ) ) self.assertEqual( quic_server.closed, ( ErrorCode.H3_SETTINGS_ERROR, "ENABLE_WEBTRANSPORT setting must be 0 or 1", ), ) def test_validate_settings_enable_webtransport_without_h3_datagram(self): quic_server = FakeQuicConnection( configuration=QuicConfiguration(is_client=False) ) h3_server = H3Connection(quic_server) # receive SETTINGS requesting WebTransport, but DATAGRAM was not offered settings = copy.copy(DUMMY_SETTINGS) settings[Setting.ENABLE_WEBTRANSPORT] = 1 h3_server.handle_event( StreamDataReceived( stream_id=2, data=encode_uint_var(StreamType.CONTROL) + encode_frame(FrameType.SETTINGS, encode_settings(settings)), end_stream=False, ) ) self.assertEqual( quic_server.closed, ( ErrorCode.H3_SETTINGS_ERROR, "ENABLE_WEBTRANSPORT requires H3_DATAGRAM", ), ) def _content_length_template(self, content_length, data, end_at, expected_closed): quic_client = FakeQuicConnection( configuration=QuicConfiguration(is_client=True) ) quic_server = FakeQuicConnection( configuration=QuicConfiguration(is_client=False) ) h3_client = H3Connection(quic_client) h3_server = H3Connection(quic_server) end_stream = False stream_id = quic_client.get_next_available_stream_id() if end_at == "headers": end_stream = True # Headers headers = [ (b":method", b"GET"), (b":scheme", b"https"), (b":authority", b"localhost"), (b":path", b"/"), ] if content_length is not None: headers.append((b"content-length", content_length)) h3_client.send_headers( stream_id=stream_id, headers=headers, end_stream=end_stream, ) if not end_stream: # Data if end_at == "data": end_stream = True h3_client.send_data(stream_id=stream_id, data=data, end_stream=end_stream) if not end_stream: # Trailers assert end_at == "trailers" h3_client.send_headers( stream_id=stream_id, headers=[ (b"x-foo", b"hi"), ], end_stream=True, ) h3_transfer(quic_client, h3_server) self.assertEqual(quic_server.closed, expected_closed)
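# _content_length_template() drives a single request end-to-end: `end_at`
# selects which frame closes the stream ("headers", "data" or "trailers"),
# and `expected_closed` is the close state the server must end up in --
# None when the message is acceptable, or an (ErrorCode, reason_phrase)
# pair when the connection has to be torn down.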
""" self._content_length_template( b"20", b"hello world", "data", ( ErrorCode(ErrorCode.H3_MESSAGE_ERROR), "content-length does not match data size", ), ) def test_content_length_does_not_match_when_headers_end_stream(self): """ If a content-length is specified and wrong, we must reject it. """ self._content_length_template( b"20", b"", "headers", ( ErrorCode(ErrorCode.H3_MESSAGE_ERROR), "content-length does not match data size", ), ) def test_content_length_does_not_match_when_trailers_end_stream(self): """ If a content-length is specified and wrong, we must reject it. """ self._content_length_template( b"20", b"hello world", "trailers", ( ErrorCode(ErrorCode.H3_MESSAGE_ERROR), "content-length does not match data size", ), ) class H3ParserTest(TestCase): def test_parse_settings_duplicate_identifier(self): buf = Buffer(capacity=1024) buf.push_uint_var(1) buf.push_uint_var(123) buf.push_uint_var(1) buf.push_uint_var(456) with self.assertRaises(SettingsError) as cm: parse_settings(buf.data) self.assertEqual( cm.exception.reason_phrase, "Setting identifier 0x1 is included twice" ) def test_parse_settings_reserved_identifier(self): buf = Buffer(capacity=1024) buf.push_uint_var(0) buf.push_uint_var(123) with self.assertRaises(SettingsError) as cm: parse_settings(buf.data) self.assertEqual( cm.exception.reason_phrase, "Setting identifier 0x0 is reserved" ) def test_validate_push_promise_headers(self): # OK validate_push_promise_headers( [ (b":method", b"GET"), (b":scheme", b"https"), (b":path", b"/"), (b":authority", b"localhost"), ] ) validate_push_promise_headers( [ (b":method", b"GET"), (b":scheme", b"https"), (b":path", b"/"), (b":authority", b"localhost"), (b"x-foo", b"bar"), ] ) # invalid pseudo-header with self.assertRaises(MessageError) as cm: validate_push_promise_headers([(b":status", b"foo")]) self.assertEqual( cm.exception.reason_phrase, "Pseudo-header b':status' is not valid" ) # duplicate pseudo-header with self.assertRaises(MessageError) as cm: validate_push_promise_headers( [ (b":method", b"GET"), (b":method", b"POST"), ] ) self.assertEqual( cm.exception.reason_phrase, "Pseudo-header b':method' is included twice" ) # pseudo-header after regular headers with self.assertRaises(MessageError) as cm: validate_push_promise_headers( [ (b":method", b"GET"), (b":scheme", b"https"), (b":path", b"/"), (b"x-foo", b"bar"), (b":authority", b"foo"), ] ) self.assertEqual( cm.exception.reason_phrase, "Pseudo-header b':authority' is not allowed after regular headers", ) # missing pseudo-headers with self.assertRaises(MessageError) as cm: validate_push_promise_headers( [ (b":method", b"GET"), (b":scheme", b"https"), (b":path", b"/"), ] ) self.assertEqual( cm.exception.reason_phrase, "Pseudo-headers [b':authority'] are missing", ) def test_validate_request_headers(self): # OK validate_request_headers( [ (b":method", b"GET"), (b":scheme", b"https"), (b":path", b"/"), (b":authority", b"localhost"), ] ) validate_request_headers( [ (b":method", b"GET"), (b":scheme", b"https"), (b":path", b"/"), (b":authority", b"localhost"), (b"x-foo", b"bar"), ] ) # uppercase header with self.assertRaises(MessageError) as cm: validate_request_headers([(b"X-Foo", b"foo")]) self.assertEqual( cm.exception.reason_phrase, "Header b'X-Foo' contains invalid characters" ) # header with too small a value with self.assertRaises(MessageError) as cm: validate_request_headers([(b"x-\x00foo", b"foo")]) self.assertEqual( cm.exception.reason_phrase, "Header b'x-\\x00foo' contains invalid characters", ) # header with too big a value 
# header with non-initial colon with self.assertRaises(MessageError) as cm: validate_request_headers([(b"x-f:oo", b"foo")]) self.assertEqual( cm.exception.reason_phrase, "Header b'x-f:oo' contains a non-initial colon" ) # good transfer-encoding: "trailers" is the only accepted value, so this must not raise validate_request_headers([(b"transfer-encoding", b"trailers")]) # bad transfer-encoding with self.assertRaises(MessageError) as cm: validate_request_headers([(b"transfer-encoding", b"bogus")]) self.assertEqual( cm.exception.reason_phrase, "The only valid value for transfer-encoding is trailers", ) # value with a forbidden NUL, LF, or CR: for prefix in [b"\x00", b"\x0a", b"\x0d"]: with self.assertRaises(MessageError) as cm: validate_request_headers([(b"x-foo", prefix + b"foo")]) self.assertEqual( cm.exception.reason_phrase, "Header b'x-foo' value has forbidden characters", ) # value with an initial TAB or SP: for prefix in [b"\x09", b"\x20"]: with self.assertRaises(MessageError) as cm: validate_request_headers([(b"x-foo", prefix + b"foo")]) self.assertEqual( cm.exception.reason_phrase, "Header b'x-foo' value starts with whitespace", ) # value with a final TAB or SP: for suffix in [b"\x09", b"\x20"]: with self.assertRaises(MessageError) as cm: validate_request_headers([(b"x-foo", b"foo" + suffix)]) self.assertEqual( cm.exception.reason_phrase, "Header b'x-foo' value ends with whitespace", ) # invalid pseudo-header with self.assertRaises(MessageError) as cm: validate_request_headers([(b":status", b"foo")]) self.assertEqual( cm.exception.reason_phrase, "Pseudo-header b':status' is not valid" ) # duplicate pseudo-header with self.assertRaises(MessageError) as cm: validate_request_headers( [ (b":method", b"GET"), (b":method", b"POST"), ] ) self.assertEqual( cm.exception.reason_phrase, "Pseudo-header b':method' is included twice" ) # pseudo-header after regular headers with self.assertRaises(MessageError) as cm: validate_request_headers( [ (b":method", b"GET"), (b":scheme", b"https"), (b":path", b"/"), (b"x-foo", b"bar"), (b":authority", b"foo"), ] ) self.assertEqual( cm.exception.reason_phrase, "Pseudo-header b':authority' is not allowed after regular headers", ) # missing pseudo-headers with self.assertRaises(MessageError) as cm: validate_request_headers([(b":method", b"GET")]) self.assertEqual( cm.exception.reason_phrase, "Pseudo-headers [b':authority'] are missing", ) # empty :authority pseudo-header for http/https for scheme in [b"http", b"https"]: with self.assertRaises(MessageError) as cm: validate_request_headers( [ (b":method", b"GET"), (b":scheme", scheme), (b":authority", b""), (b":path", b"/"), ] ) self.assertEqual( cm.exception.reason_phrase, "Pseudo-header b':authority' cannot be empty", ) # empty :path pseudo-header for http/https for scheme in [b"http", b"https"]: with self.assertRaises(MessageError) as cm: validate_request_headers( [ (b":method", b"GET"), (b":scheme", scheme), (b":authority", b"localhost"), (b":path", b""), ] ) self.assertEqual( cm.exception.reason_phrase, "Pseudo-header b':path' cannot be empty" ) def test_validate_response_headers(self): # OK validate_response_headers([(b":status", b"200")]) validate_response_headers( [ (b":status", b"200"), (b"x-foo", b"bar"), ] ) # invalid pseudo-header with self.assertRaises(MessageError) as cm: validate_response_headers([(b":method", b"GET")]) self.assertEqual( cm.exception.reason_phrase, "Pseudo-header b':method' is not valid" ) # duplicate pseudo-header with self.assertRaises(MessageError) as cm: validate_response_headers( [ (b":status", b"200"), (b":status", b"501"), ] ) self.assertEqual( cm.exception.reason_phrase, "Pseudo-header b':status' is included twice" ) def test_validate_trailers(self): # OK validate_trailers([(b"x-foo", b"bar")]) # invalid pseudo-header with self.assertRaises(MessageError) as cm: validate_trailers([(b":status", b"foo")]) self.assertEqual( cm.exception.reason_phrase, "Pseudo-header b':status' is not valid" ) # pseudo-header after regular headers with self.assertRaises(MessageError) as cm: validate_trailers( [ (b"x-foo", b"bar"), (b":authority", b"foo"), ] ) self.assertEqual( cm.exception.reason_phrase, "Pseudo-header b':authority' is not allowed after regular headers", )
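# Taken together, the validators above enforce the RFC 9114 field rules:
# pseudo-headers come before regular fields and appear at most once;
# requests and push promises require :method, :scheme, :authority and
# :path; responses require :status; trailers allow no pseudo-headers.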
validate_response_headers([(b":method", b"GET")]) self.assertEqual( cm.exception.reason_phrase, "Pseudo-header b':method' is not valid" ) # duplicate pseudo-header with self.assertRaises(MessageError) as cm: validate_response_headers( [ (b":status", b"200"), (b":status", b"501"), ] ) self.assertEqual( cm.exception.reason_phrase, "Pseudo-header b':status' is included twice" ) def test_validate_trailers(self): # OK validate_trailers([(b"x-foo", b"bar")]) # invalid pseudo-header with self.assertRaises(MessageError) as cm: validate_trailers([(b":status", b"foo")]) self.assertEqual( cm.exception.reason_phrase, "Pseudo-header b':status' is not valid" ) # pseudo-header after regular headers with self.assertRaises(MessageError) as cm: validate_trailers( [ (b"x-foo", b"bar"), (b":authority", b"foo"), ] ) self.assertEqual( cm.exception.reason_phrase, "Pseudo-header b':authority' is not allowed after regular headers", ) ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1720306884.0 aioquic-1.2.0/tests/test_logger.py0000644000175100001770000000325300000000000020007 0ustar00runnerdocker00000000000000import json import os import tempfile from unittest import TestCase from aioquic.quic.logger import QuicFileLogger, QuicLogger SINGLE_TRACE = { "qlog_format": "JSON", "qlog_version": "0.3", "traces": [ { "common_fields": { "ODCID": "0000000000000000", }, "events": [], "vantage_point": {"name": "aioquic", "type": "client"}, } ], } class QuicLoggerTest(TestCase): def test_empty(self): logger = QuicLogger() self.assertEqual( logger.to_dict(), {"qlog_format": "JSON", "qlog_version": "0.3", "traces": []}, ) def test_single_trace(self): logger = QuicLogger() trace = logger.start_trace(is_client=True, odcid=bytes(8)) logger.end_trace(trace) self.assertEqual(logger.to_dict(), SINGLE_TRACE) class QuicFileLoggerTest(TestCase): def test_invalid_path(self): with self.assertRaises(ValueError) as cm: QuicFileLogger("this_path_should_not_exist") self.assertEqual( str(cm.exception), "QUIC log output directory 'this_path_should_not_exist' does not exist", ) def test_single_trace(self): with tempfile.TemporaryDirectory() as dirpath: logger = QuicFileLogger(dirpath) trace = logger.start_trace(is_client=True, odcid=bytes(8)) logger.end_trace(trace) filepath = os.path.join(dirpath, "0000000000000000.qlog") self.assertTrue(os.path.exists(filepath)) with open(filepath, "r") as fp: data = json.load(fp) self.assertEqual(data, SINGLE_TRACE) ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1720306884.0 aioquic-1.2.0/tests/test_packet.py0000644000175100001770000005127700000000000020010 0ustar00runnerdocker00000000000000import binascii from unittest import TestCase from aioquic.buffer import Buffer, BufferReadError from aioquic.quic import packet from aioquic.quic.packet import ( QuicPacketType, QuicPreferredAddress, QuicProtocolVersion, QuicTransportParameters, QuicVersionInformation, decode_packet_number, encode_quic_retry, encode_quic_version_negotiation, get_retry_integrity_tag, pull_quic_header, pull_quic_preferred_address, pull_quic_transport_parameters, push_quic_preferred_address, push_quic_transport_parameters, ) from .test_crypto_v1 import LONG_CLIENT_ENCRYPTED_PACKET as CLIENT_INITIAL_V1 from .test_crypto_v1 import LONG_SERVER_ENCRYPTED_PACKET as SERVER_INITIAL_V1 from .test_crypto_v2 import LONG_CLIENT_ENCRYPTED_PACKET as CLIENT_INITIAL_V2 from .test_crypto_v2 import LONG_SERVER_ENCRYPTED_PACKET as SERVER_INITIAL_V2 class 
PacketTest(TestCase): def test_decode_packet_number(self): # expected = 0 for i in range(0, 256): self.assertEqual(decode_packet_number(i, 8, expected=0), i) # expected = 128 self.assertEqual(decode_packet_number(0, 8, expected=128), 256) for i in range(1, 256): self.assertEqual(decode_packet_number(i, 8, expected=128), i) # expected = 129 self.assertEqual(decode_packet_number(0, 8, expected=129), 256) self.assertEqual(decode_packet_number(1, 8, expected=129), 257) for i in range(2, 256): self.assertEqual(decode_packet_number(i, 8, expected=129), i) # expected = 256 for i in range(0, 128): self.assertEqual(decode_packet_number(i, 8, expected=256), 256 + i) for i in range(129, 256): self.assertEqual(decode_packet_number(i, 8, expected=256), i) def test_pull_empty(self): buf = Buffer(data=b"") with self.assertRaises(BufferReadError): pull_quic_header(buf, host_cid_length=8) def test_pull_initial_client_v1(self): buf = Buffer(data=CLIENT_INITIAL_V1) header = pull_quic_header(buf, host_cid_length=8) self.assertEqual(header.version, QuicProtocolVersion.VERSION_1) self.assertEqual(header.packet_type, QuicPacketType.INITIAL) self.assertEqual(header.packet_length, 1200) self.assertEqual(header.destination_cid, binascii.unhexlify("8394c8f03e515708")) self.assertEqual(header.source_cid, b"") self.assertEqual(header.token, b"") self.assertEqual(header.integrity_tag, b"") self.assertEqual(buf.tell(), 18) def test_pull_initial_client_v1_truncated(self): buf = Buffer(data=CLIENT_INITIAL_V1[0:100]) with self.assertRaises(ValueError) as cm: pull_quic_header(buf, host_cid_length=8) self.assertEqual(str(cm.exception), "Packet payload is truncated") def test_pull_initial_client_v2(self): buf = Buffer(data=CLIENT_INITIAL_V2) header = pull_quic_header(buf, host_cid_length=8) self.assertEqual(header.version, QuicProtocolVersion.VERSION_2) self.assertEqual(header.packet_type, QuicPacketType.INITIAL) self.assertEqual(header.packet_length, 1200) self.assertEqual(header.destination_cid, binascii.unhexlify("8394c8f03e515708")) self.assertEqual(header.source_cid, b"") self.assertEqual(header.token, b"") self.assertEqual(header.integrity_tag, b"") self.assertEqual(buf.tell(), 18) def test_pull_initial_server_v1(self): buf = Buffer(data=SERVER_INITIAL_V1) header = pull_quic_header(buf, host_cid_length=8) self.assertEqual(header.version, QuicProtocolVersion.VERSION_1) self.assertEqual(header.packet_type, QuicPacketType.INITIAL) self.assertEqual(header.packet_length, 135) self.assertEqual(header.destination_cid, b"") self.assertEqual(header.source_cid, binascii.unhexlify("f067a5502a4262b5")) self.assertEqual(header.token, b"") self.assertEqual(header.integrity_tag, b"") self.assertEqual(buf.tell(), 18) def test_pull_initial_server_v2(self): buf = Buffer(data=SERVER_INITIAL_V2) header = pull_quic_header(buf, host_cid_length=8) self.assertEqual(header.version, QuicProtocolVersion.VERSION_2) self.assertEqual(header.packet_type, QuicPacketType.INITIAL) self.assertEqual(header.packet_length, 135) self.assertEqual(header.destination_cid, b"") self.assertEqual(header.source_cid, binascii.unhexlify("f067a5502a4262b5")) self.assertEqual(header.token, b"") self.assertEqual(header.integrity_tag, b"") self.assertEqual(buf.tell(), 18) def test_pull_retry_v1(self): # https://datatracker.ietf.org/doc/html/rfc9001#appendix-A.4 original_destination_cid = binascii.unhexlify("8394c8f03e515708") data = binascii.unhexlify( "ff000000010008f067a5502a4262b5746f6b656e04a265ba2eff4d829058fb3f0f2496ba" ) buf = Buffer(data=data) header = 
pull_quic_header(buf) self.assertEqual(header.version, QuicProtocolVersion.VERSION_1) self.assertEqual(header.packet_type, QuicPacketType.RETRY) self.assertEqual(header.packet_length, 36) self.assertEqual(header.destination_cid, b"") self.assertEqual(header.source_cid, binascii.unhexlify("f067a5502a4262b5")) self.assertEqual(header.token, b"token") self.assertEqual( header.integrity_tag, binascii.unhexlify("04a265ba2eff4d829058fb3f0f2496ba") ) self.assertEqual(buf.tell(), 36) # check integrity self.assertEqual( get_retry_integrity_tag( buf.data_slice(0, 20), original_destination_cid, version=header.version ), header.integrity_tag, ) # serialize encoded = encode_quic_retry( version=header.version, source_cid=header.source_cid, destination_cid=header.destination_cid, original_destination_cid=original_destination_cid, retry_token=header.token, # This value is arbitrary, we set it to match the value in the RFC. unused=0xF, ) self.assertEqual(encoded, data) def test_pull_retry_v2(self): # https://datatracker.ietf.org/doc/html/rfc9369#appendix-A.4 original_destination_cid = binascii.unhexlify("8394c8f03e515708") data = binascii.unhexlify( "cf6b3343cf0008f067a5502a4262b5746f6b656ec8646ce8bfe33952d955543665dcc7b6" ) buf = Buffer(data=data) header = pull_quic_header(buf) self.assertEqual(header.version, QuicProtocolVersion.VERSION_2) self.assertEqual(header.packet_type, QuicPacketType.RETRY) self.assertEqual(header.packet_length, 36) self.assertEqual(header.destination_cid, b"") self.assertEqual(header.source_cid, binascii.unhexlify("f067a5502a4262b5")) self.assertEqual(header.token, b"token") self.assertEqual( header.integrity_tag, binascii.unhexlify("c8646ce8bfe33952d955543665dcc7b6") ) self.assertEqual(buf.tell(), 36) # check integrity self.assertEqual( get_retry_integrity_tag( buf.data_slice(0, 20), original_destination_cid, version=header.version ), header.integrity_tag, ) # serialize encoded = encode_quic_retry( version=header.version, source_cid=header.source_cid, destination_cid=header.destination_cid, original_destination_cid=original_destination_cid, retry_token=header.token, # This value is arbitrary, we set it to match the value in the RFC. unused=0xF, ) self.assertEqual(encoded, data) def test_pull_version_negotiation(self): data = binascii.unhexlify( "ea00000000089aac5a49ba87a84908f92f4336fa951ba14547471600000001" ) buf = Buffer(data=data) header = pull_quic_header(buf, host_cid_length=8) self.assertEqual(header.version, QuicProtocolVersion.NEGOTIATION) self.assertEqual(header.packet_type, QuicPacketType.VERSION_NEGOTIATION) self.assertEqual(header.packet_length, 31) self.assertEqual(header.destination_cid, binascii.unhexlify("9aac5a49ba87a849")) self.assertEqual(header.source_cid, binascii.unhexlify("f92f4336fa951ba1")) self.assertEqual(header.token, b"") self.assertEqual(header.integrity_tag, b"") self.assertEqual( header.supported_versions, [0x45474716, QuicProtocolVersion.VERSION_1] ) self.assertEqual(buf.tell(), 31) encoded = encode_quic_version_negotiation( destination_cid=header.destination_cid, source_cid=header.source_cid, supported_versions=header.supported_versions, ) # The first byte may differ as it is random. 
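# (RFC 8999 defines the seven low bits of a Version Negotiation packet's
# first byte as unused, so senders may set them arbitrarily -- which is why
# the comparison below starts at the second byte.)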
self.assertEqual(encoded[1:], data[1:]) def test_pull_long_header_dcid_too_long(self): buf = Buffer( data=binascii.unhexlify( "c6ff0000161500000000000000000000000000000000000000000000004" "01c514f99ec4bbf1f7a30f9b0c94fef717f1c1d07fec24c99a864da7ede" ) ) with self.assertRaises(ValueError) as cm: pull_quic_header(buf, host_cid_length=8) self.assertEqual(str(cm.exception), "Destination CID is too long (21 bytes)") def test_pull_long_header_scid_too_long(self): buf = Buffer( data=binascii.unhexlify( "c2ff0000160015000000000000000000000000000000000000000000004" "01cfcee99ec4bbf1f7a30f9b0c9417b8c263cdd8cc972a4439d68a46320" ) ) with self.assertRaises(ValueError) as cm: pull_quic_header(buf, host_cid_length=8) self.assertEqual(str(cm.exception), "Source CID is too long (21 bytes)") def test_pull_long_header_no_fixed_bit(self): buf = Buffer(data=b"\x80\xff\x00\x00\x11\x00\x00") with self.assertRaises(ValueError) as cm: pull_quic_header(buf, host_cid_length=8) self.assertEqual(str(cm.exception), "Packet fixed bit is zero") def test_pull_long_header_too_short(self): buf = Buffer(data=b"\xc0\x00") with self.assertRaises(BufferReadError): pull_quic_header(buf, host_cid_length=8) def test_pull_short_header(self): buf = Buffer( data=binascii.unhexlify("5df45aa7b59c0e1ad6e668f5304cd4fd1fb3799327") ) header = pull_quic_header(buf, host_cid_length=8) self.assertEqual(header.version, None) self.assertEqual(header.packet_type, QuicPacketType.ONE_RTT) self.assertEqual(header.packet_length, 21) self.assertEqual(header.destination_cid, binascii.unhexlify("f45aa7b59c0e1ad6")) self.assertEqual(header.source_cid, b"") self.assertEqual(header.token, b"") self.assertEqual(header.integrity_tag, b"") self.assertEqual(buf.tell(), 9) def test_pull_short_header_no_fixed_bit(self): buf = Buffer(data=b"\x00") with self.assertRaises(ValueError) as cm: pull_quic_header(buf, host_cid_length=8) self.assertEqual(str(cm.exception), "Packet fixed bit is zero") class ParamsTest(TestCase): maxDiff = None def test_params(self): data = binascii.unhexlify( "010267100210cc2fd6e7d97a53ab5be85b28d75c8008030247e404048005fff" "a05048000ffff06048000ffff0801060a01030b0119" ) # parse buf = Buffer(data=data) params = pull_quic_transport_parameters(buf) self.assertEqual( params, QuicTransportParameters( max_idle_timeout=10000, stateless_reset_token=b"\xcc/\xd6\xe7\xd9zS\xab[\xe8[(\xd7\\\x80\x08", max_udp_payload_size=2020, initial_max_data=393210, initial_max_stream_data_bidi_local=65535, initial_max_stream_data_bidi_remote=65535, initial_max_stream_data_uni=None, initial_max_streams_bidi=6, initial_max_streams_uni=None, ack_delay_exponent=3, max_ack_delay=25, ), ) # serialize buf = Buffer(capacity=len(data)) push_quic_transport_parameters(buf, params) self.assertEqual(len(buf.data), len(data)) def test_params_disable_active_migration(self): data = binascii.unhexlify("0c00") # parse buf = Buffer(data=data) params = pull_quic_transport_parameters(buf) self.assertEqual(params, QuicTransportParameters(disable_active_migration=True)) # serialize buf = Buffer(capacity=len(data)) push_quic_transport_parameters(buf, params) self.assertEqual(buf.data, data) def test_params_max_ack_delay(self): data = binascii.unhexlify("0b010a") # parse buf = Buffer(data=data) params = pull_quic_transport_parameters(buf) self.assertEqual(params, QuicTransportParameters(max_ack_delay=10)) # serialize buf = Buffer(capacity=len(data)) push_quic_transport_parameters(buf, params) self.assertEqual(buf.data, data) def test_params_max_ack_delay_length_mismatch(self): buf = 
Buffer(data=binascii.unhexlify("0b020a")) with self.assertRaises(ValueError) as cm: pull_quic_transport_parameters(buf) self.assertEqual(str(cm.exception), "Transport parameter length does not match") def test_params_preferred_address(self): data = binascii.unhexlify( "0d3b8ba27b8611532400890200000000f03c91fffe69a45411531262c4518d6" "3013f0c287ed3573efa9095603746b2e02d45480ba6643e5c6e7d48ecb4" ) # parse buf = Buffer(data=data) params = pull_quic_transport_parameters(buf) self.assertEqual( params, QuicTransportParameters( preferred_address=QuicPreferredAddress( ipv4_address=("139.162.123.134", 4435), ipv6_address=("2400:8902::f03c:91ff:fe69:a454", 4435), connection_id=b"b\xc4Q\x8dc\x01?\x0c(~\xd3W>\xfa\x90\x95`7", stateless_reset_token=b"F\xb2\xe0-EH\x0b\xa6d>\\n}H\xec\xb4", ), ), ) # serialize buf = Buffer(capacity=len(data)) push_quic_transport_parameters(buf, params) self.assertEqual(buf.data, data) def test_params_unknown(self): data = binascii.unhexlify("8000ff000100") # parse buf = Buffer(data=data) params = pull_quic_transport_parameters(buf) self.assertEqual(params, QuicTransportParameters()) def test_params_version_information(self): data = binascii.unhexlify("110c00000001000000016b3343cf") # parse buf = Buffer(data=data) params = pull_quic_transport_parameters(buf) self.assertEqual( params, QuicTransportParameters( version_information=QuicVersionInformation( chosen_version=QuicProtocolVersion.VERSION_1, available_versions=[ QuicProtocolVersion.VERSION_1, QuicProtocolVersion.VERSION_2, ], ), ), ) # serialize buf = Buffer(capacity=len(data)) push_quic_transport_parameters(buf, params) self.assertEqual(buf.data, data) def test_params_version_information_available_version_0(self): buf = Buffer(data=binascii.unhexlify("11080000000100000000")) with self.assertRaises(ValueError) as cm: pull_quic_transport_parameters(buf) self.assertEqual( str(cm.exception), "Version Information must not contain version 0" ) def test_params_version_information_chosen_version_0(self): buf = Buffer(data=binascii.unhexlify("110400000000")) with self.assertRaises(ValueError) as cm: pull_quic_transport_parameters(buf) self.assertEqual( str(cm.exception), "Version Information must not contain version 0" ) def test_params_version_information_length_not_divisible_by_four(self): buf = Buffer(data=binascii.unhexlify("11050000000100")) with self.assertRaises(ValueError) as cm: pull_quic_transport_parameters(buf) self.assertEqual(str(cm.exception), "Transport parameter length does not match") def test_params_version_information_truncated(self): buf = Buffer(data=binascii.unhexlify("110800000000")) with self.assertRaises(ValueError) as cm: pull_quic_transport_parameters(buf) self.assertEqual(str(cm.exception), "Read out of bounds") def test_preferred_address_ipv4_only(self): data = binascii.unhexlify( "8ba27b8611530000000000000000000000000000000000001262c4518d63013" "f0c287ed3573efa9095603746b2e02d45480ba6643e5c6e7d48ecb4" ) # parse buf = Buffer(data=data) preferred_address = pull_quic_preferred_address(buf) self.assertEqual( preferred_address, QuicPreferredAddress( ipv4_address=("139.162.123.134", 4435), ipv6_address=None, connection_id=b"b\xc4Q\x8dc\x01?\x0c(~\xd3W>\xfa\x90\x95`7", stateless_reset_token=b"F\xb2\xe0-EH\x0b\xa6d>\\n}H\xec\xb4", ), ) # serialize buf = Buffer(capacity=len(data)) push_quic_preferred_address(buf, preferred_address) self.assertEqual(buf.data, data) def test_preferred_address_ipv6_only(self): data = binascii.unhexlify( "0000000000002400890200000000f03c91fffe69a45411531262c4518d63013" 
"f0c287ed3573efa9095603746b2e02d45480ba6643e5c6e7d48ecb4" ) # parse buf = Buffer(data=data) preferred_address = pull_quic_preferred_address(buf) self.assertEqual( preferred_address, QuicPreferredAddress( ipv4_address=None, ipv6_address=("2400:8902::f03c:91ff:fe69:a454", 4435), connection_id=b"b\xc4Q\x8dc\x01?\x0c(~\xd3W>\xfa\x90\x95`7", stateless_reset_token=b"F\xb2\xe0-EH\x0b\xa6d>\\n}H\xec\xb4", ), ) # serialize buf = Buffer(capacity=len(data)) push_quic_preferred_address(buf, preferred_address) self.assertEqual(buf.data, data) class FrameTest(TestCase): def test_ack_frame(self): data = b"\x00\x02\x00\x00" # parse buf = Buffer(data=data) rangeset, delay = packet.pull_ack_frame(buf) self.assertEqual(list(rangeset), [range(0, 1)]) self.assertEqual(delay, 2) # serialize buf = Buffer(capacity=len(data)) packet.push_ack_frame(buf, rangeset, delay) self.assertEqual(buf.data, data) def test_ack_frame_with_one_range(self): data = b"\x02\x02\x01\x00\x00\x00" # parse buf = Buffer(data=data) rangeset, delay = packet.pull_ack_frame(buf) self.assertEqual(list(rangeset), [range(0, 1), range(2, 3)]) self.assertEqual(delay, 2) # serialize buf = Buffer(capacity=len(data)) packet.push_ack_frame(buf, rangeset, delay) self.assertEqual(buf.data, data) def test_ack_frame_with_one_range_2(self): data = b"\x05\x02\x01\x00\x00\x03" # parse buf = Buffer(data=data) rangeset, delay = packet.pull_ack_frame(buf) self.assertEqual(list(rangeset), [range(0, 4), range(5, 6)]) self.assertEqual(delay, 2) # serialize buf = Buffer(capacity=len(data)) packet.push_ack_frame(buf, rangeset, delay) self.assertEqual(buf.data, data) def test_ack_frame_with_one_range_3(self): data = b"\x05\x02\x01\x00\x01\x02" # parse buf = Buffer(data=data) rangeset, delay = packet.pull_ack_frame(buf) self.assertEqual(list(rangeset), [range(0, 3), range(5, 6)]) self.assertEqual(delay, 2) # serialize buf = Buffer(capacity=len(data)) packet.push_ack_frame(buf, rangeset, delay) self.assertEqual(buf.data, data) def test_ack_frame_with_two_ranges(self): data = b"\x04\x02\x02\x00\x00\x00\x00\x00" # parse buf = Buffer(data=data) rangeset, delay = packet.pull_ack_frame(buf) self.assertEqual(list(rangeset), [range(0, 1), range(2, 3), range(4, 5)]) self.assertEqual(delay, 2) # serialize buf = Buffer(capacity=len(data)) packet.push_ack_frame(buf, rangeset, delay) self.assertEqual(buf.data, data) ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1720306884.0 aioquic-1.2.0/tests/test_packet_builder.py0000644000175100001770000006270700000000000021516 0ustar00runnerdocker00000000000000from typing import List from unittest import TestCase from aioquic.quic.configuration import SMALLEST_MAX_DATAGRAM_SIZE from aioquic.quic.crypto import CryptoPair from aioquic.quic.packet import QuicFrameType, QuicPacketType, QuicProtocolVersion from aioquic.quic.packet_builder import ( QuicPacketBuilder, QuicPacketBuilderStop, QuicSentPacket, ) from aioquic.tls import Epoch def create_builder(is_client=False): return QuicPacketBuilder( host_cid=bytes(8), is_client=is_client, max_datagram_size=SMALLEST_MAX_DATAGRAM_SIZE, packet_number=0, peer_cid=bytes(8), peer_token=b"", spin_bit=False, version=QuicProtocolVersion.VERSION_1, ) def create_crypto(): crypto = CryptoPair() crypto.setup_initial( bytes(8), is_client=True, version=QuicProtocolVersion.VERSION_1 ) return crypto def datagram_sizes(datagrams: List[bytes]) -> List[int]: return [len(x) for x in datagrams] class QuicPacketBuilderTest(TestCase): def test_long_header_empty(self): 
builder = create_builder() crypto = create_crypto() builder.start_packet(QuicPacketType.INITIAL, crypto) self.assertEqual(builder.remaining_flight_space, 1156) self.assertTrue(builder.packet_is_empty) # check datagrams datagrams, packets = builder.flush() self.assertEqual(len(datagrams), 0) self.assertEqual(packets, []) # check builder self.assertEqual(builder.packet_number, 0) def test_long_header_initial_client(self): builder = create_builder(is_client=True) crypto = create_crypto() # INITIAL, fully padded builder.start_packet(QuicPacketType.INITIAL, crypto) self.assertEqual(builder.remaining_flight_space, 1156) buf = builder.start_frame(QuicFrameType.CRYPTO) buf.push_bytes(bytes(100)) self.assertFalse(builder.packet_is_empty) # INITIAL, empty builder.start_packet(QuicPacketType.INITIAL, crypto) self.assertTrue(builder.packet_is_empty) # check datagrams datagrams, packets = builder.flush() self.assertEqual(datagram_sizes(datagrams), [1200]) self.assertEqual( packets, [ QuicSentPacket( epoch=Epoch.INITIAL, in_flight=True, is_ack_eliciting=True, is_crypto_packet=True, packet_number=0, packet_type=QuicPacketType.INITIAL, sent_bytes=145, ) ], ) # check builder self.assertEqual(builder.packet_number, 1) def test_long_header_initial_client_2(self): builder = create_builder(is_client=True) crypto = create_crypto() # INITIAL, full length builder.start_packet(QuicPacketType.INITIAL, crypto) self.assertEqual(builder.remaining_flight_space, 1156) buf = builder.start_frame(QuicFrameType.CRYPTO) buf.push_bytes(bytes(builder.remaining_flight_space)) self.assertFalse(builder.packet_is_empty) # INITIAL, full length builder.start_packet(QuicPacketType.INITIAL, crypto) self.assertEqual(builder.remaining_flight_space, 1156) buf = builder.start_frame(QuicFrameType.CRYPTO) buf.push_bytes(bytes(100)) self.assertFalse(builder.packet_is_empty) # INITIAL, empty builder.start_packet(QuicPacketType.INITIAL, crypto) self.assertTrue(builder.packet_is_empty) # check datagrams datagrams, packets = builder.flush() self.assertEqual(datagram_sizes(datagrams), [1200, 1200]) self.assertEqual( packets, [ QuicSentPacket( epoch=Epoch.INITIAL, in_flight=True, is_ack_eliciting=True, is_crypto_packet=True, packet_number=0, packet_type=QuicPacketType.INITIAL, sent_bytes=1200, ), QuicSentPacket( epoch=Epoch.INITIAL, in_flight=True, is_ack_eliciting=True, is_crypto_packet=True, packet_number=1, packet_type=QuicPacketType.INITIAL, sent_bytes=145, ), ], ) # check builder self.assertEqual(builder.packet_number, 2) def test_long_header_initial_client_zero_rtt(self): builder = create_builder(is_client=True) crypto = create_crypto() # INITIAL builder.start_packet(QuicPacketType.INITIAL, crypto) self.assertEqual(builder.remaining_flight_space, 1156) buf = builder.start_frame(QuicFrameType.CRYPTO) buf.push_bytes(bytes(613)) self.assertFalse(builder.packet_is_empty) # 0-RTT builder.start_packet(QuicPacketType.ZERO_RTT, crypto) self.assertEqual(builder.remaining_flight_space, 499) buf = builder.start_frame(QuicFrameType.STREAM_BASE) buf.push_bytes(bytes(100)) self.assertFalse(builder.packet_is_empty) # check datagrams datagrams, packets = builder.flush() self.assertEqual(datagram_sizes(datagrams), [1200]) self.assertEqual( packets, [ QuicSentPacket( epoch=Epoch.INITIAL, in_flight=True, is_ack_eliciting=True, is_crypto_packet=True, packet_number=0, packet_type=QuicPacketType.INITIAL, sent_bytes=658, ), QuicSentPacket( epoch=Epoch.ONE_RTT, in_flight=True, is_ack_eliciting=True, is_crypto_packet=False, packet_number=1, 
packet_type=QuicPacketType.ZERO_RTT, sent_bytes=144, ), ], ) # check builder self.assertEqual(builder.packet_number, 2) def test_long_header_initial_server(self): builder = create_builder() crypto = create_crypto() # INITIAL with ACK + CRYPTO + PADDING builder.start_packet(QuicPacketType.INITIAL, crypto) self.assertEqual(builder.remaining_flight_space, 1156) buf = builder.start_frame(QuicFrameType.ACK) buf.push_bytes(bytes(16)) buf = builder.start_frame(QuicFrameType.CRYPTO) buf.push_bytes(bytes(100)) self.assertFalse(builder.packet_is_empty) # INITIAL, empty builder.start_packet(QuicPacketType.INITIAL, crypto) self.assertTrue(builder.packet_is_empty) # HANDSHAKE with CRYPTO builder.start_packet(QuicPacketType.HANDSHAKE, crypto) self.assertEqual(builder.remaining_flight_space, 995) buf = builder.start_frame(QuicFrameType.CRYPTO) buf.push_bytes(bytes(994)) self.assertFalse(builder.packet_is_empty) # HANDSHAKE with CRYPTO builder.start_packet(QuicPacketType.HANDSHAKE, crypto) self.assertEqual(builder.remaining_flight_space, 1157) buf = builder.start_frame(QuicFrameType.CRYPTO) buf.push_bytes(bytes(800)) self.assertFalse(builder.packet_is_empty) # HANDSHAKE, empty builder.start_packet(QuicPacketType.HANDSHAKE, crypto) self.assertTrue(builder.packet_is_empty) # check datagrams datagrams, packets = builder.flush() self.assertEqual(datagram_sizes(datagrams), [1200, 844]) self.assertEqual( packets, [ QuicSentPacket( epoch=Epoch.INITIAL, in_flight=True, is_ack_eliciting=True, is_crypto_packet=True, packet_number=0, packet_type=QuicPacketType.INITIAL, sent_bytes=162, ), QuicSentPacket( epoch=Epoch.HANDSHAKE, in_flight=True, is_ack_eliciting=True, is_crypto_packet=True, packet_number=1, packet_type=QuicPacketType.HANDSHAKE, sent_bytes=1038, ), QuicSentPacket( epoch=Epoch.HANDSHAKE, in_flight=True, is_ack_eliciting=True, is_crypto_packet=True, packet_number=2, packet_type=QuicPacketType.HANDSHAKE, sent_bytes=844, ), ], ) # check builder self.assertEqual(builder.packet_number, 3) def test_long_header_initial_server_without_handshake(self): builder = create_builder() crypto = create_crypto() # INITIAL builder.start_packet(QuicPacketType.INITIAL, crypto) self.assertEqual(builder.remaining_flight_space, 1156) buf = builder.start_frame(QuicFrameType.CRYPTO) buf.push_bytes(bytes(100)) self.assertFalse(builder.packet_is_empty) # INITIAL, empty builder.start_packet(QuicPacketType.INITIAL, crypto) self.assertTrue(builder.packet_is_empty) # HANDSHAKE, empty builder.start_packet(QuicPacketType.HANDSHAKE, crypto) self.assertTrue(builder.packet_is_empty) # check datagrams datagrams, packets = builder.flush() self.assertEqual(datagram_sizes(datagrams), [1200]) self.assertEqual( packets, [ QuicSentPacket( epoch=Epoch.INITIAL, in_flight=True, is_ack_eliciting=True, is_crypto_packet=True, packet_number=0, packet_type=QuicPacketType.INITIAL, sent_bytes=145, ) ], ) # check builder self.assertEqual(builder.packet_number, 1) def test_long_header_ping_only(self): """ The payload is too short to provide enough data for header protection, so padding needs to be applied. 
""" builder = create_builder() crypto = create_crypto() # HANDSHAKE, with only a PING frame builder.start_packet(QuicPacketType.HANDSHAKE, crypto) builder.start_frame(QuicFrameType.PING) self.assertFalse(builder.packet_is_empty) # check datagrams datagrams, packets = builder.flush() self.assertEqual(len(datagrams), 1) self.assertEqual(len(datagrams[0]), 45) self.assertEqual( packets, [ QuicSentPacket( epoch=Epoch.HANDSHAKE, in_flight=True, is_ack_eliciting=True, is_crypto_packet=False, packet_number=0, packet_type=QuicPacketType.HANDSHAKE, sent_bytes=45, ) ], ) def test_long_header_then_short_header(self): builder = create_builder() crypto = create_crypto() # INITIAL, full length builder.start_packet(QuicPacketType.INITIAL, crypto) self.assertEqual(builder.remaining_flight_space, 1156) buf = builder.start_frame(QuicFrameType.CRYPTO) buf.push_bytes(bytes(builder.remaining_flight_space)) self.assertFalse(builder.packet_is_empty) # INITIAL, empty builder.start_packet(QuicPacketType.INITIAL, crypto) self.assertTrue(builder.packet_is_empty) # ONE_RTT, full length builder.start_packet(QuicPacketType.ONE_RTT, crypto) self.assertEqual(builder.remaining_flight_space, 1173) buf = builder.start_frame(QuicFrameType.STREAM_BASE) buf.push_bytes(bytes(builder.remaining_flight_space)) self.assertFalse(builder.packet_is_empty) # ONE_RTT, empty builder.start_packet(QuicPacketType.ONE_RTT, crypto) self.assertTrue(builder.packet_is_empty) # check datagrams datagrams, packets = builder.flush() self.assertEqual(len(datagrams), 2) self.assertEqual(len(datagrams[0]), 1200) self.assertEqual(len(datagrams[1]), 1200) self.assertEqual( packets, [ QuicSentPacket( epoch=Epoch.INITIAL, in_flight=True, is_ack_eliciting=True, is_crypto_packet=True, packet_number=0, packet_type=QuicPacketType.INITIAL, sent_bytes=1200, ), QuicSentPacket( epoch=Epoch.ONE_RTT, in_flight=True, is_ack_eliciting=True, is_crypto_packet=False, packet_number=1, packet_type=QuicPacketType.ONE_RTT, sent_bytes=1200, ), ], ) # check builder self.assertEqual(builder.packet_number, 2) def test_long_header_then_long_header_then_short_header(self): builder = create_builder() crypto = create_crypto() # INITIAL builder.start_packet(QuicPacketType.INITIAL, crypto) self.assertEqual(builder.remaining_flight_space, 1156) buf = builder.start_frame(QuicFrameType.CRYPTO) buf.push_bytes(bytes(199)) self.assertFalse(builder.packet_is_empty) # INITIAL, empty builder.start_packet(QuicPacketType.INITIAL, crypto) self.assertTrue(builder.packet_is_empty) # HANDSHAKE builder.start_packet(QuicPacketType.HANDSHAKE, crypto) self.assertEqual(builder.remaining_flight_space, 913) buf = builder.start_frame(QuicFrameType.CRYPTO) buf.push_bytes(bytes(299)) self.assertFalse(builder.packet_is_empty) self.assertEqual(builder.remaining_flight_space, 613) # HANDSHAKE, empty builder.start_packet(QuicPacketType.HANDSHAKE, crypto) self.assertTrue(builder.packet_is_empty) # ONE_RTT, padded builder.start_packet(QuicPacketType.ONE_RTT, crypto) self.assertEqual(builder.remaining_flight_space, 586) buf = builder.start_frame(QuicFrameType.CRYPTO) buf.push_bytes(bytes(300)) self.assertFalse(builder.packet_is_empty) # check datagrams datagrams, packets = builder.flush() self.assertEqual(len(datagrams), 1) self.assertEqual(len(datagrams[0]), 1200) self.assertEqual( packets, [ QuicSentPacket( epoch=Epoch.INITIAL, in_flight=True, is_ack_eliciting=True, is_crypto_packet=True, packet_number=0, packet_type=QuicPacketType.INITIAL, sent_bytes=244, ), QuicSentPacket( epoch=Epoch.HANDSHAKE, in_flight=True, 
is_ack_eliciting=True, is_crypto_packet=True, packet_number=1, packet_type=QuicPacketType.HANDSHAKE, sent_bytes=343, ), QuicSentPacket( epoch=Epoch.ONE_RTT, in_flight=True, is_ack_eliciting=True, is_crypto_packet=True, packet_number=2, packet_type=QuicPacketType.ONE_RTT, sent_bytes=613, # includes padding ), ], ) # check builder self.assertEqual(builder.packet_number, 3) def test_short_header_empty(self): builder = create_builder() crypto = create_crypto() builder.start_packet(QuicPacketType.ONE_RTT, crypto) self.assertEqual(builder.remaining_flight_space, 1173) self.assertTrue(builder.packet_is_empty) # check datagrams datagrams, packets = builder.flush() self.assertEqual(datagrams, []) self.assertEqual(packets, []) # check builder self.assertEqual(builder.packet_number, 0) def test_short_header_full_length(self): builder = create_builder() crypto = create_crypto() # ONE_RTT, full length builder.start_packet(QuicPacketType.ONE_RTT, crypto) self.assertEqual(builder.remaining_flight_space, 1173) buf = builder.start_frame(QuicFrameType.CRYPTO) buf.push_bytes(bytes(builder.remaining_flight_space)) self.assertFalse(builder.packet_is_empty) # check datagrams datagrams, packets = builder.flush() self.assertEqual(len(datagrams), 1) self.assertEqual(len(datagrams[0]), 1200) self.assertEqual( packets, [ QuicSentPacket( epoch=Epoch.ONE_RTT, in_flight=True, is_ack_eliciting=True, is_crypto_packet=True, packet_number=0, packet_type=QuicPacketType.ONE_RTT, sent_bytes=1200, ) ], ) # check builder self.assertEqual(builder.packet_number, 1) def test_short_header_max_flight_bytes(self): """ max_flight_bytes limits sent data. """ builder = create_builder() builder.max_flight_bytes = 1000 crypto = create_crypto() builder.start_packet(QuicPacketType.ONE_RTT, crypto) self.assertEqual(builder.remaining_flight_space, 973) buf = builder.start_frame(QuicFrameType.CRYPTO) buf.push_bytes(bytes(builder.remaining_flight_space)) self.assertFalse(builder.packet_is_empty) with self.assertRaises(QuicPacketBuilderStop): builder.start_packet(QuicPacketType.ONE_RTT, crypto) builder.start_frame(QuicFrameType.CRYPTO) # check datagrams datagrams, packets = builder.flush() self.assertEqual(len(datagrams), 1) self.assertEqual(len(datagrams[0]), 1000) self.assertEqual( packets, [ QuicSentPacket( epoch=Epoch.ONE_RTT, in_flight=True, is_ack_eliciting=True, is_crypto_packet=True, packet_number=0, packet_type=QuicPacketType.ONE_RTT, sent_bytes=1000, ), ], ) # check builder self.assertEqual(builder.packet_number, 1) def test_short_header_max_flight_bytes_zero(self): """ max_flight_bytes = 0 only allows ACKs and CONNECTION_CLOSE. Check CRYPTO is not allowed. """ builder = create_builder() builder.max_flight_bytes = 0 crypto = create_crypto() with self.assertRaises(QuicPacketBuilderStop): builder.start_packet(QuicPacketType.ONE_RTT, crypto) builder.start_frame(QuicFrameType.CRYPTO) # check datagrams datagrams, packets = builder.flush() self.assertEqual(len(datagrams), 0) # check builder self.assertEqual(builder.packet_number, 0) def test_short_header_max_flight_bytes_zero_ack(self): """ max_flight_bytes = 0 only allows ACKs and CONNECTION_CLOSE. Check ACK is allowed. 
""" builder = create_builder() builder.max_flight_bytes = 0 crypto = create_crypto() builder.start_packet(QuicPacketType.ONE_RTT, crypto) buf = builder.start_frame(QuicFrameType.ACK) buf.push_bytes(bytes(64)) with self.assertRaises(QuicPacketBuilderStop): builder.start_packet(QuicPacketType.ONE_RTT, crypto) builder.start_frame(QuicFrameType.CRYPTO) # check datagrams datagrams, packets = builder.flush() self.assertEqual(len(datagrams), 1) self.assertEqual(len(datagrams[0]), 92) self.assertEqual( packets, [ QuicSentPacket( epoch=Epoch.ONE_RTT, in_flight=False, is_ack_eliciting=False, is_crypto_packet=False, packet_number=0, packet_type=QuicPacketType.ONE_RTT, sent_bytes=92, ), ], ) # check builder self.assertEqual(builder.packet_number, 1) def test_short_header_max_total_bytes_1(self): """ max_total_bytes doesn't allow any packets. """ builder = create_builder() builder.max_total_bytes = 11 crypto = create_crypto() with self.assertRaises(QuicPacketBuilderStop): builder.start_packet(QuicPacketType.ONE_RTT, crypto) # check datagrams datagrams, packets = builder.flush() self.assertEqual(datagrams, []) self.assertEqual(packets, []) # check builder self.assertEqual(builder.packet_number, 0) def test_short_header_max_total_bytes_2(self): """ max_total_bytes allows a short packet. """ builder = create_builder() builder.max_total_bytes = 800 crypto = create_crypto() builder.start_packet(QuicPacketType.ONE_RTT, crypto) self.assertEqual(builder.remaining_flight_space, 773) buf = builder.start_frame(QuicFrameType.CRYPTO) buf.push_bytes(bytes(builder.remaining_flight_space)) self.assertFalse(builder.packet_is_empty) with self.assertRaises(QuicPacketBuilderStop): builder.start_packet(QuicPacketType.ONE_RTT, crypto) # check datagrams datagrams, packets = builder.flush() self.assertEqual(len(datagrams), 1) self.assertEqual(len(datagrams[0]), 800) self.assertEqual( packets, [ QuicSentPacket( epoch=Epoch.ONE_RTT, in_flight=True, is_ack_eliciting=True, is_crypto_packet=True, packet_number=0, packet_type=QuicPacketType.ONE_RTT, sent_bytes=800, ) ], ) # check builder self.assertEqual(builder.packet_number, 1) def test_short_header_max_total_bytes_3(self): builder = create_builder() builder.max_total_bytes = 2000 crypto = create_crypto() builder.start_packet(QuicPacketType.ONE_RTT, crypto) self.assertEqual(builder.remaining_flight_space, 1173) buf = builder.start_frame(QuicFrameType.CRYPTO) buf.push_bytes(bytes(builder.remaining_flight_space)) self.assertFalse(builder.packet_is_empty) builder.start_packet(QuicPacketType.ONE_RTT, crypto) self.assertEqual(builder.remaining_flight_space, 773) buf = builder.start_frame(QuicFrameType.CRYPTO) buf.push_bytes(bytes(builder.remaining_flight_space)) self.assertFalse(builder.packet_is_empty) with self.assertRaises(QuicPacketBuilderStop): builder.start_packet(QuicPacketType.ONE_RTT, crypto) # check datagrams datagrams, packets = builder.flush() self.assertEqual(len(datagrams), 2) self.assertEqual(len(datagrams[0]), 1200) self.assertEqual(len(datagrams[1]), 800) self.assertEqual( packets, [ QuicSentPacket( epoch=Epoch.ONE_RTT, in_flight=True, is_ack_eliciting=True, is_crypto_packet=True, packet_number=0, packet_type=QuicPacketType.ONE_RTT, sent_bytes=1200, ), QuicSentPacket( epoch=Epoch.ONE_RTT, in_flight=True, is_ack_eliciting=True, is_crypto_packet=True, packet_number=1, packet_type=QuicPacketType.ONE_RTT, sent_bytes=800, ), ], ) # check builder self.assertEqual(builder.packet_number, 2) def test_short_header_ping_only(self): """ The payload is too short to provide enough 
data for header protection, so padding needs to be applied. """ builder = create_builder() crypto = create_crypto() # ONE_RTT, with only a PING frame builder.start_packet(QuicPacketType.ONE_RTT, crypto) builder.start_frame(QuicFrameType.PING) self.assertFalse(builder.packet_is_empty) # check datagrams datagrams, packets = builder.flush() self.assertEqual(len(datagrams), 1) self.assertEqual(len(datagrams[0]), 29) self.assertEqual( packets, [ QuicSentPacket( epoch=Epoch.ONE_RTT, in_flight=True, is_ack_eliciting=True, is_crypto_packet=False, packet_number=0, packet_type=QuicPacketType.ONE_RTT, sent_bytes=29, ) ], )
aioquic-1.2.0/tests/test_rangeset.py
from unittest import TestCase from aioquic.quic.rangeset import RangeSet class RangeSetTest(TestCase): def test_add_single_duplicate(self): rangeset = RangeSet() rangeset.add(0) self.assertEqual(list(rangeset), [range(0, 1)]) rangeset.add(0) self.assertEqual(list(rangeset), [range(0, 1)]) def test_add_single_ordered(self): rangeset = RangeSet() rangeset.add(0) self.assertEqual(list(rangeset), [range(0, 1)]) rangeset.add(1) self.assertEqual(list(rangeset), [range(0, 2)]) rangeset.add(2) self.assertEqual(list(rangeset), [range(0, 3)]) def test_add_single_merge(self): rangeset = RangeSet() rangeset.add(0) self.assertEqual(list(rangeset), [range(0, 1)]) rangeset.add(2) self.assertEqual(list(rangeset), [range(0, 1), range(2, 3)]) rangeset.add(1) self.assertEqual(list(rangeset), [range(0, 3)]) def test_add_single_reverse(self): rangeset = RangeSet() rangeset.add(2) self.assertEqual(list(rangeset), [range(2, 3)]) rangeset.add(1) self.assertEqual(list(rangeset), [range(1, 3)]) rangeset.add(0) self.assertEqual(list(rangeset), [range(0, 3)]) def test_add_range_ordered(self): rangeset = RangeSet() rangeset.add(0, 2) self.assertEqual(list(rangeset), [range(0, 2)]) rangeset.add(2, 4) self.assertEqual(list(rangeset), [range(0, 4)]) rangeset.add(4, 6) self.assertEqual(list(rangeset), [range(0, 6)]) def test_add_range_merge(self): rangeset = RangeSet() rangeset.add(0, 2) self.assertEqual(list(rangeset), [range(0, 2)]) rangeset.add(3, 5) self.assertEqual(list(rangeset), [range(0, 2), range(3, 5)]) rangeset.add(2, 3) self.assertEqual(list(rangeset), [range(0, 5)]) def test_add_range_overlap(self): rangeset = RangeSet() rangeset.add(0, 2) self.assertEqual(list(rangeset), [range(0, 2)]) rangeset.add(3, 5) self.assertEqual(list(rangeset), [range(0, 2), range(3, 5)]) rangeset.add(1, 5) self.assertEqual(list(rangeset), [range(0, 5)]) def test_add_range_overlap_2(self): rangeset = RangeSet() rangeset.add(2, 4) rangeset.add(6, 8) rangeset.add(10, 12) rangeset.add(16, 18) self.assertEqual( list(rangeset), [range(2, 4), range(6, 8), range(10, 12), range(16, 18)] ) rangeset.add(1, 15) self.assertEqual(list(rangeset), [range(1, 15), range(16, 18)]) def test_add_range_reverse(self): rangeset = RangeSet() rangeset.add(6, 8) self.assertEqual(list(rangeset), [range(6, 8)]) rangeset.add(3, 5) self.assertEqual(list(rangeset), [range(3, 5), range(6, 8)]) rangeset.add(0, 2) self.assertEqual(list(rangeset), [range(0, 2), range(3, 5), range(6, 8)]) def test_add_range_unordered_contiguous(self): rangeset = RangeSet() rangeset.add(0, 2) self.assertEqual(list(rangeset), [range(0, 2)]) rangeset.add(4, 6) self.assertEqual(list(rangeset), [range(0, 2), range(4, 6)]) rangeset.add(2, 4)
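# adding the bridging range [2, 4) merges [0, 2) and [4, 6) into a single
# contiguous range, which the assertion below verifies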
self.assertEqual(list(rangeset), [range(0, 6)]) def test_add_range_unordered_sparse(self): rangeset = RangeSet() rangeset.add(0, 2) self.assertEqual(list(rangeset), [range(0, 2)]) rangeset.add(6, 8) self.assertEqual(list(rangeset), [range(0, 2), range(6, 8)]) rangeset.add(3, 5) self.assertEqual(list(rangeset), [range(0, 2), range(3, 5), range(6, 8)]) def test_subtract(self): rangeset = RangeSet() rangeset.add(0, 10) rangeset.add(20, 30) rangeset.subtract(0, 3) self.assertEqual(list(rangeset), [range(3, 10), range(20, 30)]) def test_subtract_no_change(self): rangeset = RangeSet() rangeset.add(5, 10) rangeset.add(15, 20) rangeset.add(25, 30) rangeset.subtract(0, 5) self.assertEqual(list(rangeset), [range(5, 10), range(15, 20), range(25, 30)]) rangeset.subtract(10, 15) self.assertEqual(list(rangeset), [range(5, 10), range(15, 20), range(25, 30)]) def test_subtract_overlap(self): rangeset = RangeSet() rangeset.add(1, 4) rangeset.add(6, 8) rangeset.add(10, 20) rangeset.add(30, 40) self.assertEqual( list(rangeset), [range(1, 4), range(6, 8), range(10, 20), range(30, 40)] ) rangeset.subtract(0, 2) self.assertEqual( list(rangeset), [range(2, 4), range(6, 8), range(10, 20), range(30, 40)] ) rangeset.subtract(3, 11) self.assertEqual(list(rangeset), [range(2, 3), range(11, 20), range(30, 40)]) def test_subtract_split(self): rangeset = RangeSet() rangeset.add(0, 10) rangeset.subtract(2, 5) self.assertEqual(list(rangeset), [range(0, 2), range(5, 10)]) def test_bool(self): with self.assertRaises(NotImplementedError): bool(RangeSet()) def test_contains(self): rangeset = RangeSet() self.assertFalse(0 in rangeset) rangeset = RangeSet([range(0, 1)]) self.assertTrue(0 in rangeset) self.assertFalse(1 in rangeset) rangeset = RangeSet([range(0, 1), range(3, 6)]) self.assertTrue(0 in rangeset) self.assertFalse(1 in rangeset) self.assertFalse(2 in rangeset) self.assertTrue(3 in rangeset) self.assertTrue(4 in rangeset) self.assertTrue(5 in rangeset) self.assertFalse(6 in rangeset) def test_eq(self): r0 = RangeSet([range(0, 1)]) r1 = RangeSet([range(1, 2), range(3, 4)]) r2 = RangeSet([range(3, 4), range(1, 2)]) self.assertTrue(r0 == r0) self.assertFalse(r0 == r1) self.assertFalse(r0 == 0) self.assertTrue(r1 == r1) self.assertFalse(r1 == r0) self.assertTrue(r1 == r2) self.assertFalse(r1 == 0) self.assertTrue(r2 == r2) self.assertTrue(r2 == r1) self.assertFalse(r2 == r0) self.assertFalse(r2 == 0) def test_len(self): rangeset = RangeSet() self.assertEqual(len(rangeset), 0) rangeset = RangeSet([range(0, 1)]) self.assertEqual(len(rangeset), 1) def test_shift(self): rangeset = RangeSet([range(1, 2), range(3, 4)]) r = rangeset.shift() self.assertEqual(r, range(1, 2)) self.assertEqual(list(rangeset), [range(3, 4)]) def test_repr(self): rangeset = RangeSet([range(1, 2), range(3, 4)]) self.assertEqual(repr(rangeset), "RangeSet([range(1, 2), range(3, 4)])") aioquic-1.2.0/tests/test_recovery.py from unittest import TestCase from aioquic.quic.congestion.base import QuicRttMonitor, create_congestion_control from aioquic.quic.recovery import QuicPacketPacer class QuicCongestionControlTest(TestCase): def test_create_unknown_congestion_control(self): with self.assertRaises(Exception) as cm: create_congestion_control("bogus", max_datagram_size=1280) self.assertEqual( str(cm.exception), "Unknown congestion control algorithm: bogus" ) class
QuicPacketPacerTest(TestCase): def setUp(self): self.pacer = QuicPacketPacer(max_datagram_size=1280) def test_no_measurement(self): self.assertIsNone(self.pacer.next_send_time(now=0.0)) self.pacer.update_after_send(now=0.0) self.assertIsNone(self.pacer.next_send_time(now=0.0)) self.pacer.update_after_send(now=0.0) def test_with_measurement(self): self.assertIsNone(self.pacer.next_send_time(now=0.0)) self.pacer.update_after_send(now=0.0) self.pacer.update_rate(congestion_window=1280000, smoothed_rtt=0.05) self.assertEqual(self.pacer.bucket_max, 0.0008) self.assertEqual(self.pacer.bucket_time, 0.0) self.assertEqual(self.pacer.packet_time, 0.00005) # 16 packets for i in range(16): self.assertIsNone(self.pacer.next_send_time(now=1.0)) self.pacer.update_after_send(now=1.0) self.assertAlmostEqual(self.pacer.next_send_time(now=1.0), 1.00005) # 2 packets for i in range(2): self.assertIsNone(self.pacer.next_send_time(now=1.00005)) self.pacer.update_after_send(now=1.00005) self.assertAlmostEqual(self.pacer.next_send_time(now=1.00005), 1.0001) # 1 packet self.assertIsNone(self.pacer.next_send_time(now=1.0001)) self.pacer.update_after_send(now=1.0001) self.assertAlmostEqual(self.pacer.next_send_time(now=1.0001), 1.00015) # 2 packets for i in range(2): self.assertIsNone(self.pacer.next_send_time(now=1.00015)) self.pacer.update_after_send(now=1.00015) self.assertAlmostEqual(self.pacer.next_send_time(now=1.00015), 1.0002) class QuicRttMonitorTest(TestCase): def test_monitor(self): monitor = QuicRttMonitor() self.assertFalse(monitor.is_rtt_increasing(rtt=10, now=1000)) self.assertEqual(monitor._samples, [10, 0.0, 0.0, 0.0, 0.0]) self.assertFalse(monitor._ready) # not taken into account self.assertFalse(monitor.is_rtt_increasing(rtt=11, now=1000)) self.assertEqual(monitor._samples, [10, 0.0, 0.0, 0.0, 0.0]) self.assertFalse(monitor._ready) self.assertFalse(monitor.is_rtt_increasing(rtt=11, now=1001)) self.assertEqual(monitor._samples, [10, 11, 0.0, 0.0, 0.0]) self.assertFalse(monitor._ready) self.assertFalse(monitor.is_rtt_increasing(rtt=12, now=1002)) self.assertEqual(monitor._samples, [10, 11, 12, 0.0, 0.0]) self.assertFalse(monitor._ready) self.assertFalse(monitor.is_rtt_increasing(rtt=13, now=1003)) self.assertEqual(monitor._samples, [10, 11, 12, 13, 0.0]) self.assertFalse(monitor._ready) # we now have enough samples self.assertFalse(monitor.is_rtt_increasing(rtt=14, now=1004)) self.assertEqual(monitor._samples, [10, 11, 12, 13, 14]) self.assertTrue(monitor._ready) self.assertFalse(monitor.is_rtt_increasing(rtt=20, now=1005)) self.assertEqual(monitor._increases, 0) self.assertFalse(monitor.is_rtt_increasing(rtt=30, now=1006)) self.assertEqual(monitor._increases, 0) self.assertFalse(monitor.is_rtt_increasing(rtt=40, now=1007)) self.assertEqual(monitor._increases, 0) self.assertFalse(monitor.is_rtt_increasing(rtt=50, now=1008)) self.assertEqual(monitor._increases, 0) self.assertFalse(monitor.is_rtt_increasing(rtt=60, now=1009)) self.assertEqual(monitor._increases, 1) self.assertFalse(monitor.is_rtt_increasing(rtt=70, now=1010)) self.assertEqual(monitor._increases, 2) self.assertFalse(monitor.is_rtt_increasing(rtt=80, now=1011)) self.assertEqual(monitor._increases, 3) self.assertFalse(monitor.is_rtt_increasing(rtt=90, now=1012)) self.assertEqual(monitor._increases, 4) self.assertTrue(monitor.is_rtt_increasing(rtt=100, now=1013)) self.assertEqual(monitor._increases, 5) ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1720306884.0 
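# Editor's note: a minimal standalone sketch (not part of aioquic's API) of the
# pacing arithmetic that QuicPacketPacerTest.test_with_measurement above
# exercises. It assumes the pacer spaces datagrams at the rate
# congestion_window / smoothed_rtt, so each full-size datagram "costs"
# max_datagram_size / rate seconds; the helper name is illustrative only.
def expected_packet_time(max_datagram_size, congestion_window, smoothed_rtt):
    rate = congestion_window / smoothed_rtt  # bytes per second
    return max_datagram_size / rate  # seconds per full-size datagram

# With the test's numbers: 1280 / (1280000 / 0.05) == 0.00005, matching
# pacer.packet_time. The observed bucket_max of 0.0008 is 16 * packet_time,
# i.e. the token bucket lets a burst of 16 datagrams go out back-to-back
# before pacing delays kick in.
assert abs(expected_packet_time(1280, 1280000, 0.05) - 0.00005) < 1e-12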
aioquic-1.2.0/tests/test_recovery_cubic.py0000644000175100001770000004561400000000000021542 0ustar00runnerdocker00000000000000import math from unittest import TestCase from aioquic import tls from aioquic.quic.congestion.base import K_INITIAL_WINDOW, K_MINIMUM_WINDOW from aioquic.quic.congestion.cubic import ( K_CUBIC_C, K_CUBIC_LOSS_REDUCTION_FACTOR, CubicCongestionControl, better_cube_root, ) from aioquic.quic.packet import QuicPacketType from aioquic.quic.packet_builder import QuicSentPacket from aioquic.quic.rangeset import RangeSet from aioquic.quic.recovery import QuicPacketRecovery, QuicPacketSpace def send_probe(): pass def W_cubic(t, K, W_max): return K_CUBIC_C * (t - K) ** 3 + (W_max) class QuicPacketRecoveryCubicTest(TestCase): def setUp(self): self.INITIAL_SPACE = QuicPacketSpace() self.HANDSHAKE_SPACE = QuicPacketSpace() self.ONE_RTT_SPACE = QuicPacketSpace() self.recovery = QuicPacketRecovery( congestion_control_algorithm="cubic", initial_rtt=0.1, max_datagram_size=1280, peer_completed_address_validation=True, send_probe=send_probe, ) self.recovery.spaces = [ self.INITIAL_SPACE, self.HANDSHAKE_SPACE, self.ONE_RTT_SPACE, ] def test_better_cube_root(self): self.assertAlmostEqual(better_cube_root(8), 2) self.assertAlmostEqual(better_cube_root(-8), -2) self.assertAlmostEqual(better_cube_root(0), 0) self.assertAlmostEqual(better_cube_root(27), 3) def test_discard_space(self): self.recovery.discard_space(self.INITIAL_SPACE) def test_on_ack_received_ack_eliciting(self): packet = QuicSentPacket( epoch=tls.Epoch.ONE_RTT, in_flight=True, is_ack_eliciting=True, is_crypto_packet=False, packet_number=0, packet_type=QuicPacketType.ONE_RTT, sent_bytes=1280, sent_time=0.0, ) space = self.ONE_RTT_SPACE #  packet sent self.recovery.on_packet_sent(packet=packet, space=space) self.assertEqual(self.recovery.bytes_in_flight, 1280) self.assertEqual(space.ack_eliciting_in_flight, 1) self.assertEqual(len(space.sent_packets), 1) # packet ack'd self.recovery.on_ack_received( ack_rangeset=RangeSet([range(0, 1)]), ack_delay=0.0, now=10.0, space=space, ) self.assertEqual(self.recovery.bytes_in_flight, 0) self.assertEqual(space.ack_eliciting_in_flight, 0) self.assertEqual(len(space.sent_packets), 0) # check RTT self.assertTrue(self.recovery._rtt_initialized) self.assertEqual(self.recovery._rtt_latest, 10.0) self.assertEqual(self.recovery._rtt_min, 10.0) self.assertEqual(self.recovery._rtt_smoothed, 10.0) def test_on_ack_received_non_ack_eliciting(self): packet = QuicSentPacket( epoch=tls.Epoch.ONE_RTT, in_flight=True, is_ack_eliciting=False, is_crypto_packet=False, packet_number=0, packet_type=QuicPacketType.ONE_RTT, sent_bytes=1280, sent_time=123.45, ) space = self.ONE_RTT_SPACE #  packet sent self.recovery.on_packet_sent(packet=packet, space=space) self.assertEqual(self.recovery.bytes_in_flight, 1280) self.assertEqual(space.ack_eliciting_in_flight, 0) self.assertEqual(len(space.sent_packets), 1) # packet ack'd self.recovery.on_ack_received( ack_rangeset=RangeSet([range(0, 1)]), ack_delay=0.0, now=10.0, space=space, ) self.assertEqual(self.recovery.bytes_in_flight, 0) self.assertEqual(space.ack_eliciting_in_flight, 0) self.assertEqual(len(space.sent_packets), 0) # check RTT self.assertFalse(self.recovery._rtt_initialized) self.assertEqual(self.recovery._rtt_latest, 0.0) self.assertEqual(self.recovery._rtt_min, math.inf) self.assertEqual(self.recovery._rtt_smoothed, 0.0) def test_on_packet_lost_crypto(self): packet = QuicSentPacket( epoch=tls.Epoch.INITIAL, in_flight=True, is_ack_eliciting=True, 
is_crypto_packet=True, packet_number=0, packet_type=QuicPacketType.INITIAL, sent_bytes=1280, sent_time=0.0, ) space = self.INITIAL_SPACE self.recovery.on_packet_sent(packet=packet, space=space) self.assertEqual(self.recovery.bytes_in_flight, 1280) self.assertEqual(space.ack_eliciting_in_flight, 1) self.assertEqual(len(space.sent_packets), 1) self.recovery._detect_loss(space=space, now=1.0) self.assertEqual(self.recovery.bytes_in_flight, 0) self.assertEqual(space.ack_eliciting_in_flight, 0) self.assertEqual(len(space.sent_packets), 0) def test_packet_expired(self): packet = QuicSentPacket( epoch=tls.Epoch.ONE_RTT, in_flight=True, is_ack_eliciting=True, is_crypto_packet=False, packet_number=0, packet_type=QuicPacketType.ONE_RTT, sent_bytes=1280, sent_time=0.0, ) cubic = CubicCongestionControl(1440) cubic.on_packet_sent(packet=packet) cubic.on_packets_expired(packets=[packet]) self.assertEqual(cubic.bytes_in_flight, 0) def test_log_data(self): cubic = CubicCongestionControl(1440) self.assertEqual( cubic.get_log_data(), { "cwnd": cubic.congestion_window, "bytes_in_flight": cubic.bytes_in_flight, "cubic-wmax": cubic._W_max, }, ) cubic._W_max = 5000 cubic.ssthresh = 5000 self.assertEqual( cubic.get_log_data(), { "cwnd": cubic.congestion_window, "ssthresh": cubic.ssthresh, "bytes_in_flight": cubic.bytes_in_flight, "cubic-wmax": cubic._W_max, }, ) def test_congestion_avoidance(self): """ Check if the cubic implementation respects the mathematical formula defined in the rfc 9438 """ max_datagram_size = 1440 n = 400 # number of ms to check W_max = 5 # starting W_max K = better_cube_root(W_max * (1 - K_CUBIC_LOSS_REDUCTION_FACTOR) / K_CUBIC_C) cwnd = W_max * K_CUBIC_LOSS_REDUCTION_FACTOR correct = [] test_range = range(n) for i in test_range: correct.append(W_cubic(i / 1000, K, W_max) * max_datagram_size) cubic = CubicCongestionControl(max_datagram_size) cubic.rtt = 0 cubic._W_max = W_max * max_datagram_size cubic._starting_congestion_avoidance = True cubic.congestion_window = cwnd * max_datagram_size cubic.ssthresh = cubic.congestion_window cubic._W_est = 0 results = [] for i in test_range: cwnd = cubic.congestion_window // max_datagram_size # number of segments # simulate the reception of cwnd packets (a full window of acks) for _ in range(int(cwnd)): packet = QuicSentPacket(None, True, True, True, 0, 0) packet.sent_bytes = 0 # won't affect results cubic.on_packet_acked(packet=packet, now=(i / 1000)) results.append(cubic.congestion_window) for i in test_range: # check if it is almost equal to the value of W_cubic self.assertTrue( correct[i] * 0.99 <= results[i] <= 1.01 * correct[i], f"Error at {i}ms, Result={results[i]}, Expected={correct[i]}", ) def test_reset_idle(self): packet = QuicSentPacket( epoch=tls.Epoch.ONE_RTT, in_flight=True, is_ack_eliciting=True, is_crypto_packet=False, packet_number=0, packet_type=QuicPacketType.ONE_RTT, sent_bytes=1280, sent_time=10.0, ) max_datagram_size = 1440 cubic = CubicCongestionControl(1440) # set last received at time 1 cubic.last_ack = 1 # receive a packet after 9s of idle time cubic.on_packet_sent(packet=packet) cubic.on_packets_expired(packets=[packet]) self.assertEqual(cubic.congestion_window, K_INITIAL_WINDOW * max_datagram_size) self.assertIsNone(cubic.ssthresh) self.assertTrue(cubic._first_slow_start) self.assertFalse(cubic._starting_congestion_avoidance) self.assertEqual(cubic.K, 0.0) self.assertEqual(cubic._W_est, 0) self.assertEqual(cubic._cwnd_epoch, 0) self.assertEqual(cubic._t_epoch, 0.0) self.assertEqual(cubic._W_max, K_INITIAL_WINDOW * 
max_datagram_size) def test_reno_friendly_region(self): cubic = CubicCongestionControl(1440) cubic._W_max = 5000 # set the target number of bytes to 5000 cubic._cwnd_epoch = 2880 # a cwnd of 1440 bytes when we had congestion cubic._starting_congestion_avoidance = False cubic._first_slow_start = False cubic.ssthresh = 2880 cubic._t_epoch = 5 # set an arbitrarily high W_est, # meaning that cubic would underperform compared to reno cubic._W_est = 100000 # calculate K W_max_segments = cubic._W_max / cubic._max_datagram_size cwnd_epoch_segments = cubic._cwnd_epoch / cubic._max_datagram_size cubic.K = better_cube_root((W_max_segments - cwnd_epoch_segments) / K_CUBIC_C) packet = QuicSentPacket( epoch=tls.Epoch.ONE_RTT, in_flight=True, is_ack_eliciting=True, is_crypto_packet=False, packet_number=0, packet_type=QuicPacketType.ONE_RTT, sent_bytes=1280, sent_time=0.0, ) previous_cwnd = cubic.congestion_window cubic.on_packet_acked(now=10, packet=packet) # congestion window should be equal to W_est (Reno estimated window) self.assertAlmostEqual( cubic.congestion_window, 100000 + cubic.additive_increase_factor * (packet.sent_bytes / previous_cwnd), ) def test_convex_region(self): cubic = CubicCongestionControl(1440) cubic._W_max = 5000 # set the target number of bytes to 5000 cubic._cwnd_epoch = 2880 # a cwnd of 1440 bytes when we had congestion cubic._starting_congestion_avoidance = False cubic._first_slow_start = False cubic.ssthresh = 2880 cubic._t_epoch = 5 cubic._W_est = 0 # calculate K W_max_segments = cubic._W_max / cubic._max_datagram_size cwnd_epoch_segments = cubic._cwnd_epoch / cubic._max_datagram_size cubic.K = better_cube_root((W_max_segments - cwnd_epoch_segments) / K_CUBIC_C) packet = QuicSentPacket( epoch=tls.Epoch.ONE_RTT, in_flight=True, is_ack_eliciting=True, is_crypto_packet=False, packet_number=0, packet_type=QuicPacketType.ONE_RTT, sent_bytes=1280, sent_time=0.0, ) previous_cwnd = cubic.congestion_window cubic.on_packet_acked(now=10, packet=packet) # elapsed time + basic rtt target = int(previous_cwnd * 1.5) expected = int( previous_cwnd + ((target - previous_cwnd) * (cubic._max_datagram_size / previous_cwnd)) ) # congestion window should grow towards the convex region target self.assertAlmostEqual(cubic.congestion_window, expected) def test_concave_region(self): cubic = CubicCongestionControl(1440) cubic._W_max = 25000 # set the target number of bytes to 25000 cubic._cwnd_epoch = 2880 # a cwnd of 1440 bytes when we had congestion cubic._starting_congestion_avoidance = False cubic._first_slow_start = False cubic.ssthresh = 2880 cubic._t_epoch = 5 cubic._W_est = 0 # calculate K W_max_segments = cubic._W_max / cubic._max_datagram_size cwnd_epoch_segments = cubic._cwnd_epoch / cubic._max_datagram_size cubic.K = better_cube_root((W_max_segments - cwnd_epoch_segments) / K_CUBIC_C) packet = QuicSentPacket( epoch=tls.Epoch.ONE_RTT, in_flight=True, is_ack_eliciting=True, is_crypto_packet=False, packet_number=0, packet_type=QuicPacketType.ONE_RTT, sent_bytes=1280, sent_time=0.0, ) previous_cwnd = cubic.congestion_window cubic.on_packet_acked(now=6, packet=packet) # elapsed time + basic rtt target = cubic.W_cubic(1 + 0.02) expected = int( previous_cwnd + ((target - previous_cwnd) * (cubic._max_datagram_size / previous_cwnd)) ) self.assertAlmostEqual(cubic.congestion_window, expected) def test_increasing_rtt(self): cubic = CubicCongestionControl(1440) # get some low rtt for i in range(10): cubic.on_rtt_measurement(now=i + 1, rtt=1) # rtt increase (because of congestion for example) for i
in range(10): cubic.on_rtt_measurement(now=100 + i, rtt=1000) self.assertEqual(cubic.ssthresh, cubic.congestion_window) def test_increasing_rtt_exiting_slow_start(self): packet = QuicSentPacket( epoch=tls.Epoch.ONE_RTT, in_flight=True, is_ack_eliciting=True, is_crypto_packet=False, packet_number=0, packet_type=QuicPacketType.ONE_RTT, sent_bytes=1280, sent_time=200.0, ) cubic = CubicCongestionControl(1440) # get some low rtt for i in range(10): cubic.on_rtt_measurement(now=i + 1, rtt=1) # rtt increase (because of congestion for example) for i in range(10): cubic.on_rtt_measurement(now=100 + i, rtt=1000) previous_cwnd = cubic.congestion_window self.assertFalse(cubic._starting_congestion_avoidance) cubic.on_packet_acked(packet=packet, now=220) self.assertFalse(cubic._first_slow_start) self.assertEqual(cubic._W_max, previous_cwnd) self.assertEqual(cubic._t_epoch, 220) self.assertEqual(cubic._cwnd_epoch, previous_cwnd) self.assertEqual( cubic._W_est, previous_cwnd + cubic.additive_increase_factor * (packet.sent_bytes / previous_cwnd), ) # calculate K W_max_segments = previous_cwnd / cubic._max_datagram_size cwnd_epoch_segments = previous_cwnd / cubic._max_datagram_size K = better_cube_root((W_max_segments - cwnd_epoch_segments) / K_CUBIC_C) self.assertEqual(cubic.K, K) def test_packet_lost(self): packet = QuicSentPacket( epoch=tls.Epoch.ONE_RTT, in_flight=True, is_ack_eliciting=True, is_crypto_packet=False, packet_number=0, packet_type=QuicPacketType.ONE_RTT, sent_bytes=1280, sent_time=200.0, ) packet2 = QuicSentPacket( epoch=tls.Epoch.ONE_RTT, in_flight=True, is_ack_eliciting=True, is_crypto_packet=False, packet_number=0, packet_type=QuicPacketType.ONE_RTT, sent_bytes=1280, sent_time=240.0, ) cubic = CubicCongestionControl(1440) previous_cwnd = cubic.congestion_window cubic.on_packets_lost(now=210, packets=[packet]) self.assertEqual(cubic._congestion_recovery_start_time, 210) self.assertEqual(cubic._W_max, previous_cwnd) self.assertEqual(cubic.ssthresh, K_MINIMUM_WINDOW * cubic._max_datagram_size) self.assertEqual( cubic.congestion_window, K_MINIMUM_WINDOW * cubic._max_datagram_size ) self.assertTrue(cubic._starting_congestion_avoidance) previous_cwnd = cubic.congestion_window W_max = cubic._W_max cubic.on_packet_acked(now=250, packet=packet) self.assertFalse(cubic._starting_congestion_avoidance) self.assertFalse(cubic._first_slow_start) self.assertEqual(cubic._t_epoch, 250) self.assertEqual(cubic._cwnd_epoch, previous_cwnd) self.assertEqual( cubic._W_est, previous_cwnd + cubic.additive_increase_factor * (packet2.sent_bytes / previous_cwnd), ) # calculate K W_max_segments = W_max / cubic._max_datagram_size cwnd_epoch_segments = previous_cwnd / cubic._max_datagram_size K = better_cube_root((W_max_segments - cwnd_epoch_segments) / K_CUBIC_C) self.assertEqual(cubic.K, K) def test_lost_with_W_max(self): packet = QuicSentPacket( epoch=tls.Epoch.ONE_RTT, in_flight=True, is_ack_eliciting=True, is_crypto_packet=False, packet_number=0, packet_type=QuicPacketType.ONE_RTT, sent_bytes=1280, sent_time=200.0, ) cubic = CubicCongestionControl(1440) cubic._W_max = 100000 previous_cwnd = cubic.congestion_window cubic.on_packets_lost(now=210, packets=[packet]) # test when W_max was much more than cwnd # and a loss occur self.assertEqual( cubic._W_max, previous_cwnd * (1 + K_CUBIC_LOSS_REDUCTION_FACTOR) / 2 ) def test_cwnd_target(self): cubic = CubicCongestionControl(1440) cubic._W_max = 25000 # set the target number of bytes to 25000 cubic._cwnd_epoch = 2880 # a cwnd of 1440 bytes when we had congestion 
cubic._starting_congestion_avoidance = False cubic._first_slow_start = False cubic.ssthresh = 2880 cubic._t_epoch = 5 cubic.congestion_window = 100000 cubic._W_est = 0 # calculate K W_max_segments = cubic._W_max / cubic._max_datagram_size cwnd_epoch_segments = cubic._cwnd_epoch / cubic._max_datagram_size cubic.K = better_cube_root((W_max_segments - cwnd_epoch_segments) / K_CUBIC_C) packet = QuicSentPacket( epoch=tls.Epoch.ONE_RTT, in_flight=True, is_ack_eliciting=True, is_crypto_packet=False, packet_number=0, packet_type=QuicPacketType.ONE_RTT, sent_bytes=1280, sent_time=0.0, ) previous_cwnd = cubic.congestion_window cubic.on_packet_acked(now=6, packet=packet) # elapsed time + basic rtt target = previous_cwnd expected = int( previous_cwnd + ((target - previous_cwnd) * (cubic._max_datagram_size / previous_cwnd)) ) self.assertAlmostEqual(cubic.congestion_window, expected) aioquic-1.2.0/tests/test_recovery_reno.py import math from unittest import TestCase from aioquic import tls from aioquic.quic.packet import QuicPacketType from aioquic.quic.packet_builder import QuicSentPacket from aioquic.quic.rangeset import RangeSet from aioquic.quic.recovery import QuicPacketRecovery, QuicPacketSpace def send_probe(): pass class QuicPacketRecoveryRenoTest(TestCase): def setUp(self): self.INITIAL_SPACE = QuicPacketSpace() self.HANDSHAKE_SPACE = QuicPacketSpace() self.ONE_RTT_SPACE = QuicPacketSpace() self.recovery = QuicPacketRecovery( congestion_control_algorithm="reno", initial_rtt=0.1, max_datagram_size=1280, peer_completed_address_validation=True, send_probe=send_probe, ) self.recovery.spaces = [ self.INITIAL_SPACE, self.HANDSHAKE_SPACE, self.ONE_RTT_SPACE, ] def test_discard_space(self): self.recovery.discard_space(self.INITIAL_SPACE) def test_on_ack_received_ack_eliciting(self): packet = QuicSentPacket( epoch=tls.Epoch.ONE_RTT, in_flight=True, is_ack_eliciting=True, is_crypto_packet=False, packet_number=0, packet_type=QuicPacketType.ONE_RTT, sent_bytes=1280, sent_time=0.0, ) space = self.ONE_RTT_SPACE # packet sent self.recovery.on_packet_sent(packet=packet, space=space) self.assertEqual(self.recovery.bytes_in_flight, 1280) self.assertEqual(space.ack_eliciting_in_flight, 1) self.assertEqual(len(space.sent_packets), 1) # packet ack'd self.recovery.on_ack_received( ack_rangeset=RangeSet([range(0, 1)]), ack_delay=0.0, now=10.0, space=space, ) self.assertEqual(self.recovery.bytes_in_flight, 0) self.assertEqual(space.ack_eliciting_in_flight, 0) self.assertEqual(len(space.sent_packets), 0) # check RTT self.assertTrue(self.recovery._rtt_initialized) self.assertEqual(self.recovery._rtt_latest, 10.0) self.assertEqual(self.recovery._rtt_min, 10.0) self.assertEqual(self.recovery._rtt_smoothed, 10.0) def test_on_ack_received_non_ack_eliciting(self): packet = QuicSentPacket( epoch=tls.Epoch.ONE_RTT, in_flight=True, is_ack_eliciting=False, is_crypto_packet=False, packet_number=0, packet_type=QuicPacketType.ONE_RTT, sent_bytes=1280, sent_time=123.45, ) space = self.ONE_RTT_SPACE # packet sent self.recovery.on_packet_sent(packet=packet, space=space) self.assertEqual(self.recovery.bytes_in_flight, 1280) self.assertEqual(space.ack_eliciting_in_flight, 0) self.assertEqual(len(space.sent_packets), 1) # packet ack'd self.recovery.on_ack_received( ack_rangeset=RangeSet([range(0, 1)]), ack_delay=0.0, now=10.0, space=space, )
self.assertEqual(self.recovery.bytes_in_flight, 0) self.assertEqual(space.ack_eliciting_in_flight, 0) self.assertEqual(len(space.sent_packets), 0) # check RTT self.assertFalse(self.recovery._rtt_initialized) self.assertEqual(self.recovery._rtt_latest, 0.0) self.assertEqual(self.recovery._rtt_min, math.inf) self.assertEqual(self.recovery._rtt_smoothed, 0.0) def test_on_packet_lost_crypto(self): packet = QuicSentPacket( epoch=tls.Epoch.INITIAL, in_flight=True, is_ack_eliciting=True, is_crypto_packet=True, packet_number=0, packet_type=QuicPacketType.INITIAL, sent_bytes=1280, sent_time=0.0, ) space = self.INITIAL_SPACE self.recovery.on_packet_sent(packet=packet, space=space) self.assertEqual(self.recovery.bytes_in_flight, 1280) self.assertEqual(space.ack_eliciting_in_flight, 1) self.assertEqual(len(space.sent_packets), 1) self.recovery._detect_loss(space=space, now=1.0) self.assertEqual(self.recovery.bytes_in_flight, 0) self.assertEqual(space.ack_eliciting_in_flight, 0) self.assertEqual(len(space.sent_packets), 0) ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1720306884.0 aioquic-1.2.0/tests/test_retry.py0000644000175100001770000000236500000000000017700 0ustar00runnerdocker00000000000000from unittest import TestCase from aioquic.quic.retry import QuicRetryTokenHandler class QuicRetryTokenHandlerTest(TestCase): def test_retry_token(self): addr = ("127.0.0.1", 1234) original_destination_connection_id = b"\x08\x07\x06\05\x04\x03\x02\x01" retry_source_connection_id = b"abcdefgh" handler = QuicRetryTokenHandler() # create token token = handler.create_token( addr, original_destination_connection_id, retry_source_connection_id ) self.assertIsNotNone(token) self.assertEqual(len(token), 256) # validate token - ok self.assertEqual( handler.validate_token(addr, token), (original_destination_connection_id, retry_source_connection_id), ) # validate token - empty with self.assertRaises(ValueError) as cm: handler.validate_token(addr, b"") self.assertEqual( str(cm.exception), "Ciphertext length must be equal to key size." 
) # validate token - wrong address with self.assertRaises(ValueError) as cm: handler.validate_token(("1.2.3.4", 12345), token) self.assertEqual(str(cm.exception), "Remote address does not match.") ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1720306884.0 aioquic-1.2.0/tests/test_stream.py0000644000175100001770000007313200000000000020026 0ustar00runnerdocker00000000000000from unittest import TestCase from aioquic.quic.events import StreamDataReceived, StreamReset from aioquic.quic.packet import QuicErrorCode, QuicStreamFrame from aioquic.quic.packet_builder import QuicDeliveryState from aioquic.quic.stream import FinalSizeError, QuicStream class QuicStreamTest(TestCase): def test_receiver_empty(self): stream = QuicStream(stream_id=0) self.assertEqual(bytes(stream.receiver._buffer), b"") self.assertEqual(list(stream.receiver._ranges), []) self.assertEqual(stream.receiver._buffer_start, 0) # empty self.assertEqual( stream.receiver.handle_frame(QuicStreamFrame(offset=0, data=b"")), None ) self.assertEqual(bytes(stream.receiver._buffer), b"") self.assertEqual(list(stream.receiver._ranges), []) self.assertEqual(stream.receiver._buffer_start, 0) def test_receiver_ordered(self): stream = QuicStream(stream_id=0) # add data at start self.assertEqual( stream.receiver.handle_frame(QuicStreamFrame(offset=0, data=b"01234567")), StreamDataReceived(data=b"01234567", end_stream=False, stream_id=0), ) self.assertEqual(bytes(stream.receiver._buffer), b"") self.assertEqual(list(stream.receiver._ranges), []) self.assertEqual(stream.receiver._buffer_start, 8) self.assertEqual(stream.receiver.highest_offset, 8) self.assertFalse(stream.receiver.is_finished) # add more data self.assertEqual( stream.receiver.handle_frame(QuicStreamFrame(offset=8, data=b"89012345")), StreamDataReceived(data=b"89012345", end_stream=False, stream_id=0), ) self.assertEqual(bytes(stream.receiver._buffer), b"") self.assertEqual(list(stream.receiver._ranges), []) self.assertEqual(stream.receiver._buffer_start, 16) self.assertEqual(stream.receiver.highest_offset, 16) self.assertFalse(stream.receiver.is_finished) # add data and fin self.assertEqual( stream.receiver.handle_frame( QuicStreamFrame(offset=16, data=b"67890123", fin=True) ), StreamDataReceived(data=b"67890123", end_stream=True, stream_id=0), ) self.assertEqual(bytes(stream.receiver._buffer), b"") self.assertEqual(list(stream.receiver._ranges), []) self.assertEqual(stream.receiver._buffer_start, 24) self.assertEqual(stream.receiver.highest_offset, 24) self.assertTrue(stream.receiver.is_finished) def test_receiver_unordered(self): stream = QuicStream(stream_id=0) # add data at offset 8 self.assertEqual( stream.receiver.handle_frame(QuicStreamFrame(offset=8, data=b"89012345")), None, ) self.assertEqual( bytes(stream.receiver._buffer), b"\x00\x00\x00\x00\x00\x00\x00\x0089012345" ) self.assertEqual(list(stream.receiver._ranges), [range(8, 16)]) self.assertEqual(stream.receiver._buffer_start, 0) self.assertEqual(stream.receiver.highest_offset, 16) # add data at offset 0 self.assertEqual( stream.receiver.handle_frame(QuicStreamFrame(offset=0, data=b"01234567")), StreamDataReceived(data=b"0123456789012345", end_stream=False, stream_id=0), ) self.assertEqual(bytes(stream.receiver._buffer), b"") self.assertEqual(list(stream.receiver._ranges), []) self.assertEqual(stream.receiver._buffer_start, 16) self.assertEqual(stream.receiver.highest_offset, 16) def test_receiver_offset_only(self): stream = QuicStream(stream_id=0) # add data at offset 0 
self.assertEqual( stream.receiver.handle_frame(QuicStreamFrame(offset=0, data=b"")), None ) self.assertEqual(bytes(stream.receiver._buffer), b"") self.assertEqual(list(stream.receiver._ranges), []) self.assertEqual(stream.receiver._buffer_start, 0) self.assertEqual(stream.receiver.highest_offset, 0) # add data at offset 8 self.assertEqual( stream.receiver.handle_frame(QuicStreamFrame(offset=8, data=b"")), None ) self.assertEqual( bytes(stream.receiver._buffer), b"\x00\x00\x00\x00\x00\x00\x00\x00" ) self.assertEqual(list(stream.receiver._ranges), []) self.assertEqual(stream.receiver._buffer_start, 0) self.assertEqual(stream.receiver.highest_offset, 8) def test_receiver_already_fully_consumed(self): stream = QuicStream(stream_id=0) # add data at offset 0 self.assertEqual( stream.receiver.handle_frame(QuicStreamFrame(offset=0, data=b"01234567")), StreamDataReceived(data=b"01234567", end_stream=False, stream_id=0), ) self.assertEqual(bytes(stream.receiver._buffer), b"") self.assertEqual(list(stream.receiver._ranges), []) self.assertEqual(stream.receiver._buffer_start, 8) # add data again at offset 0 self.assertEqual( stream.receiver.handle_frame(QuicStreamFrame(offset=0, data=b"01234567")), None, ) self.assertEqual(bytes(stream.receiver._buffer), b"") self.assertEqual(list(stream.receiver._ranges), []) self.assertEqual(stream.receiver._buffer_start, 8) # add data again at offset 0 self.assertEqual( stream.receiver.handle_frame(QuicStreamFrame(offset=0, data=b"01")), None ) self.assertEqual(bytes(stream.receiver._buffer), b"") self.assertEqual(list(stream.receiver._ranges), []) self.assertEqual(stream.receiver._buffer_start, 8) def test_receiver_already_partially_consumed(self): stream = QuicStream(stream_id=0) self.assertEqual( stream.receiver.handle_frame(QuicStreamFrame(offset=0, data=b"01234567")), StreamDataReceived(data=b"01234567", end_stream=False, stream_id=0), ) self.assertEqual( stream.receiver.handle_frame( QuicStreamFrame(offset=0, data=b"0123456789012345") ), StreamDataReceived(data=b"89012345", end_stream=False, stream_id=0), ) self.assertEqual(bytes(stream.receiver._buffer), b"") self.assertEqual(list(stream.receiver._ranges), []) self.assertEqual(stream.receiver._buffer_start, 16) def test_receiver_already_partially_consumed_2(self): stream = QuicStream(stream_id=0) self.assertEqual( stream.receiver.handle_frame(QuicStreamFrame(offset=0, data=b"01234567")), StreamDataReceived(data=b"01234567", end_stream=False, stream_id=0), ) self.assertEqual( stream.receiver.handle_frame(QuicStreamFrame(offset=16, data=b"abcdefgh")), None, ) self.assertEqual( stream.receiver.handle_frame( QuicStreamFrame(offset=2, data=b"23456789012345") ), StreamDataReceived(data=b"89012345abcdefgh", end_stream=False, stream_id=0), ) self.assertEqual(bytes(stream.receiver._buffer), b"") self.assertEqual(list(stream.receiver._ranges), []) self.assertEqual(stream.receiver._buffer_start, 24) def test_receiver_fin(self): stream = QuicStream(stream_id=0) self.assertEqual( stream.receiver.handle_frame(QuicStreamFrame(offset=0, data=b"01234567")), StreamDataReceived(data=b"01234567", end_stream=False, stream_id=0), ) self.assertEqual( stream.receiver.handle_frame( QuicStreamFrame(offset=8, data=b"89012345", fin=True) ), StreamDataReceived(data=b"89012345", end_stream=True, stream_id=0), ) def test_receiver_fin_out_of_order(self): stream = QuicStream(stream_id=0) # add data at offset 8 with FIN self.assertEqual( stream.receiver.handle_frame( QuicStreamFrame(offset=8, data=b"89012345", fin=True) ), None, ) 
self.assertEqual(stream.receiver.highest_offset, 16) self.assertFalse(stream.receiver.is_finished) # add data at offset 0 self.assertEqual( stream.receiver.handle_frame(QuicStreamFrame(offset=0, data=b"01234567")), StreamDataReceived(data=b"0123456789012345", end_stream=True, stream_id=0), ) self.assertEqual(stream.receiver.highest_offset, 16) self.assertTrue(stream.receiver.is_finished) def test_receiver_fin_then_data(self): stream = QuicStream(stream_id=0) stream.receiver.handle_frame(QuicStreamFrame(offset=0, data=b"0123", fin=True)) # data beyond final size with self.assertRaises(FinalSizeError) as cm: stream.receiver.handle_frame(QuicStreamFrame(offset=0, data=b"01234567")) self.assertEqual(str(cm.exception), "Data received beyond final size") # final size would be lowered with self.assertRaises(FinalSizeError) as cm: stream.receiver.handle_frame( QuicStreamFrame(offset=0, data=b"01", fin=True) ) self.assertEqual(str(cm.exception), "Cannot change final size") def test_receiver_fin_twice(self): stream = QuicStream(stream_id=0) self.assertEqual( stream.receiver.handle_frame(QuicStreamFrame(offset=0, data=b"01234567")), StreamDataReceived(data=b"01234567", end_stream=False, stream_id=0), ) self.assertEqual( stream.receiver.handle_frame( QuicStreamFrame(offset=8, data=b"89012345", fin=True) ), StreamDataReceived(data=b"89012345", end_stream=True, stream_id=0), ) self.assertEqual( stream.receiver.handle_frame( QuicStreamFrame(offset=8, data=b"89012345", fin=True) ), StreamDataReceived(data=b"", end_stream=True, stream_id=0), ) def test_receiver_fin_without_data(self): stream = QuicStream(stream_id=0) self.assertEqual( stream.receiver.handle_frame(QuicStreamFrame(offset=0, data=b"", fin=True)), StreamDataReceived(data=b"", end_stream=True, stream_id=0), ) def test_receiver_reset(self): stream = QuicStream(stream_id=0) self.assertEqual( stream.receiver.handle_reset(final_size=4), StreamReset(error_code=QuicErrorCode.NO_ERROR, stream_id=0), ) self.assertTrue(stream.receiver.is_finished) def test_receiver_reset_after_fin(self): stream = QuicStream(stream_id=0) stream.receiver.handle_frame(QuicStreamFrame(offset=0, data=b"0123", fin=True)) self.assertEqual( stream.receiver.handle_reset(final_size=4), StreamReset(error_code=QuicErrorCode.NO_ERROR, stream_id=0), ) def test_receiver_reset_twice(self): stream = QuicStream(stream_id=0) self.assertEqual( stream.receiver.handle_reset(final_size=4), StreamReset(error_code=QuicErrorCode.NO_ERROR, stream_id=0), ) self.assertEqual( stream.receiver.handle_reset(final_size=4), StreamReset(error_code=QuicErrorCode.NO_ERROR, stream_id=0), ) def test_receiver_reset_twice_final_size_error(self): stream = QuicStream(stream_id=0) self.assertEqual( stream.receiver.handle_reset(final_size=4), StreamReset(error_code=QuicErrorCode.NO_ERROR, stream_id=0), ) with self.assertRaises(FinalSizeError) as cm: stream.receiver.handle_reset(final_size=5) self.assertEqual(str(cm.exception), "Cannot change final size") def test_receiver_stop(self): stream = QuicStream() # stop is requested stream.receiver.stop(QuicErrorCode.NO_ERROR) self.assertTrue(stream.receiver.stop_pending) # stop is sent frame = stream.receiver.get_stop_frame() self.assertEqual(frame.error_code, QuicErrorCode.NO_ERROR) self.assertFalse(stream.receiver.stop_pending) # stop is acklowledged stream.receiver.on_stop_sending_delivery(QuicDeliveryState.ACKED) self.assertFalse(stream.receiver.stop_pending) def test_receiver_stop_lost(self): stream = QuicStream() # stop is requested 
stream.receiver.stop(QuicErrorCode.NO_ERROR) self.assertTrue(stream.receiver.stop_pending) # stop is sent frame = stream.receiver.get_stop_frame() self.assertEqual(frame.error_code, QuicErrorCode.NO_ERROR) self.assertFalse(stream.receiver.stop_pending) # stop is lost stream.receiver.on_stop_sending_delivery(QuicDeliveryState.LOST) self.assertTrue(stream.receiver.stop_pending) # stop is sent again frame = stream.receiver.get_stop_frame() self.assertEqual(frame.error_code, QuicErrorCode.NO_ERROR) self.assertFalse(stream.receiver.stop_pending) # stop is acklowledged stream.receiver.on_stop_sending_delivery(QuicDeliveryState.ACKED) self.assertFalse(stream.receiver.stop_pending) def test_sender_data(self): stream = QuicStream() self.assertEqual(stream.sender.next_offset, 0) # nothing to send yet frame = stream.sender.get_frame(8) self.assertIsNone(frame) # write data stream.sender.write(b"0123456789012345") self.assertEqual(list(stream.sender._pending), [range(0, 16)]) self.assertEqual(stream.sender.next_offset, 0) # send a chunk frame = stream.sender.get_frame(8) self.assertEqual(frame.data, b"01234567") self.assertFalse(frame.fin) self.assertEqual(frame.offset, 0) self.assertEqual(list(stream.sender._pending), [range(8, 16)]) self.assertEqual(stream.sender.next_offset, 8) # send another chunk frame = stream.sender.get_frame(8) self.assertEqual(frame.data, b"89012345") self.assertFalse(frame.fin) self.assertEqual(frame.offset, 8) self.assertEqual(list(stream.sender._pending), []) self.assertEqual(stream.sender.next_offset, 16) # nothing more to send frame = stream.sender.get_frame(8) self.assertIsNone(frame) self.assertEqual(list(stream.sender._pending), []) self.assertEqual(stream.sender.next_offset, 16) # first chunk gets acknowledged stream.sender.on_data_delivery(QuicDeliveryState.ACKED, 0, 8, False) self.assertFalse(stream.sender.is_finished) # second chunk gets acknowledged stream.sender.on_data_delivery(QuicDeliveryState.ACKED, 8, 16, False) self.assertFalse(stream.sender.is_finished) def test_sender_data_and_fin(self): stream = QuicStream() # nothing to send yet frame = stream.sender.get_frame(8) self.assertIsNone(frame) # write data and EOF stream.sender.write(b"0123456789012345", end_stream=True) self.assertEqual(list(stream.sender._pending), [range(0, 16)]) self.assertEqual(stream.sender.next_offset, 0) # send a chunk frame = stream.sender.get_frame(8) self.assertEqual(frame.data, b"01234567") self.assertFalse(frame.fin) self.assertEqual(frame.offset, 0) self.assertEqual(stream.sender.next_offset, 8) # send another chunk frame = stream.sender.get_frame(8) self.assertEqual(frame.data, b"89012345") self.assertTrue(frame.fin) self.assertEqual(frame.offset, 8) self.assertEqual(stream.sender.next_offset, 16) # nothing more to send frame = stream.sender.get_frame(8) self.assertIsNone(frame) self.assertEqual(stream.sender.next_offset, 16) # first chunk gets acknowledged stream.sender.on_data_delivery(QuicDeliveryState.ACKED, 0, 8, False) self.assertFalse(stream.sender.is_finished) # second chunk gets acknowledged stream.sender.on_data_delivery(QuicDeliveryState.ACKED, 8, 16, True) self.assertTrue(stream.sender.is_finished) def test_sender_data_and_fin_ack_out_of_order(self): stream = QuicStream() # nothing to send yet frame = stream.sender.get_frame(8) self.assertIsNone(frame) # write data and EOF stream.sender.write(b"0123456789012345", end_stream=True) self.assertEqual(list(stream.sender._pending), [range(0, 16)]) self.assertEqual(stream.sender.next_offset, 0) # send a chunk frame = 
stream.sender.get_frame(8) self.assertEqual(frame.data, b"01234567") self.assertFalse(frame.fin) self.assertEqual(frame.offset, 0) self.assertEqual(stream.sender.next_offset, 8) # send another chunk frame = stream.sender.get_frame(8) self.assertEqual(frame.data, b"89012345") self.assertTrue(frame.fin) self.assertEqual(frame.offset, 8) self.assertEqual(stream.sender.next_offset, 16) # nothing more to send frame = stream.sender.get_frame(8) self.assertIsNone(frame) self.assertEqual(stream.sender.next_offset, 16) # second chunk gets acknowledged stream.sender.on_data_delivery(QuicDeliveryState.ACKED, 8, 16, True) self.assertFalse(stream.sender.is_finished) # first chunk gets acknowledged stream.sender.on_data_delivery(QuicDeliveryState.ACKED, 0, 8, False) self.assertTrue(stream.sender.is_finished) def test_sender_data_lost(self): stream = QuicStream() # nothing to send yet frame = stream.sender.get_frame(8) self.assertIsNone(frame) # write data and EOF stream.sender.write(b"0123456789012345", end_stream=True) self.assertEqual(list(stream.sender._pending), [range(0, 16)]) self.assertEqual(stream.sender.next_offset, 0) # send a chunk self.assertEqual( stream.sender.get_frame(8), QuicStreamFrame(data=b"01234567", fin=False, offset=0), ) self.assertEqual(list(stream.sender._pending), [range(8, 16)]) self.assertEqual(stream.sender.next_offset, 8) # send another chunk self.assertEqual( stream.sender.get_frame(8), QuicStreamFrame(data=b"89012345", fin=True, offset=8), ) self.assertEqual(list(stream.sender._pending), []) self.assertEqual(stream.sender.next_offset, 16) # nothing more to send self.assertIsNone(stream.sender.get_frame(8)) self.assertEqual(list(stream.sender._pending), []) self.assertEqual(stream.sender.next_offset, 16) # a chunk gets lost stream.sender.on_data_delivery(QuicDeliveryState.LOST, 0, 8, False) self.assertEqual(list(stream.sender._pending), [range(0, 8)]) self.assertEqual(stream.sender.next_offset, 0) # send chunk again self.assertEqual( stream.sender.get_frame(8), QuicStreamFrame(data=b"01234567", fin=False, offset=0), ) self.assertEqual(list(stream.sender._pending), []) self.assertEqual(stream.sender.next_offset, 16) def test_sender_data_lost_fin(self): stream = QuicStream() # nothing to send yet frame = stream.sender.get_frame(8) self.assertIsNone(frame) # write data and EOF stream.sender.write(b"0123456789012345", end_stream=True) self.assertEqual(list(stream.sender._pending), [range(0, 16)]) self.assertEqual(stream.sender.next_offset, 0) # send a chunk self.assertEqual( stream.sender.get_frame(8), QuicStreamFrame(data=b"01234567", fin=False, offset=0), ) self.assertEqual(list(stream.sender._pending), [range(8, 16)]) self.assertEqual(stream.sender.next_offset, 8) # send another chunk self.assertEqual( stream.sender.get_frame(8), QuicStreamFrame(data=b"89012345", fin=True, offset=8), ) self.assertEqual(list(stream.sender._pending), []) self.assertEqual(stream.sender.next_offset, 16) # nothing more to send self.assertIsNone(stream.sender.get_frame(8)) self.assertEqual(list(stream.sender._pending), []) self.assertEqual(stream.sender.next_offset, 16) # a chunk gets lost stream.sender.on_data_delivery(QuicDeliveryState.LOST, 8, 16, True) self.assertEqual(list(stream.sender._pending), [range(8, 16)]) self.assertEqual(stream.sender.next_offset, 8) # send chunk again self.assertEqual( stream.sender.get_frame(8), QuicStreamFrame(data=b"89012345", fin=True, offset=8), ) self.assertEqual(list(stream.sender._pending), []) self.assertEqual(stream.sender.next_offset, 16) # first chunk 
gets acknowledged stream.sender.on_data_delivery(QuicDeliveryState.ACKED, 0, 8, False) self.assertFalse(stream.sender.is_finished) # second chunk gets acknowledged stream.sender.on_data_delivery(QuicDeliveryState.ACKED, 8, 16, True) self.assertTrue(stream.sender.is_finished) def test_sender_blocked(self): stream = QuicStream() max_offset = 12 # nothing to send yet frame = stream.sender.get_frame(8, max_offset) self.assertIsNone(frame) self.assertEqual(list(stream.sender._pending), []) self.assertEqual(stream.sender.next_offset, 0) # write data, send a chunk stream.sender.write(b"0123456789012345") frame = stream.sender.get_frame(8) self.assertEqual(frame.data, b"01234567") self.assertFalse(frame.fin) self.assertEqual(frame.offset, 0) self.assertEqual(list(stream.sender._pending), [range(8, 16)]) self.assertEqual(stream.sender.next_offset, 8) # send is limited by peer frame = stream.sender.get_frame(8, max_offset) self.assertEqual(frame.data, b"8901") self.assertFalse(frame.fin) self.assertEqual(frame.offset, 8) self.assertEqual(list(stream.sender._pending), [range(12, 16)]) self.assertEqual(stream.sender.next_offset, 12) # unable to send, blocked frame = stream.sender.get_frame(8, max_offset) self.assertIsNone(frame) self.assertEqual(list(stream.sender._pending), [range(12, 16)]) self.assertEqual(stream.sender.next_offset, 12) # write more data, still blocked stream.sender.write(b"abcdefgh") frame = stream.sender.get_frame(8, max_offset) self.assertIsNone(frame) self.assertEqual(list(stream.sender._pending), [range(12, 24)]) self.assertEqual(stream.sender.next_offset, 12) # peer raises limit, send some data max_offset += 8 frame = stream.sender.get_frame(8, max_offset) self.assertEqual(frame.data, b"2345abcd") self.assertFalse(frame.fin) self.assertEqual(frame.offset, 12) self.assertEqual(list(stream.sender._pending), [range(20, 24)]) self.assertEqual(stream.sender.next_offset, 20) # peer raises limit again, send remaining data max_offset += 8 frame = stream.sender.get_frame(8, max_offset) self.assertEqual(frame.data, b"efgh") self.assertFalse(frame.fin) self.assertEqual(frame.offset, 20) self.assertEqual(list(stream.sender._pending), []) self.assertEqual(stream.sender.next_offset, 24) # nothing more to send frame = stream.sender.get_frame(8, max_offset) self.assertIsNone(frame) def test_sender_fin_only(self): stream = QuicStream() # nothing to send yet self.assertTrue(stream.sender.buffer_is_empty) frame = stream.sender.get_frame(8) self.assertIsNone(frame) # write EOF stream.sender.write(b"", end_stream=True) self.assertFalse(stream.sender.buffer_is_empty) frame = stream.sender.get_frame(8) self.assertEqual(frame.data, b"") self.assertTrue(frame.fin) self.assertEqual(frame.offset, 0) # nothing more to send self.assertFalse(stream.sender.buffer_is_empty) # FIXME? 
frame = stream.sender.get_frame(8) self.assertIsNone(frame) self.assertTrue(stream.sender.buffer_is_empty) # EOF is acknowledged stream.sender.on_data_delivery(QuicDeliveryState.ACKED, 0, 0, True) self.assertTrue(stream.sender.is_finished) def test_sender_fin_only_despite_blocked(self): stream = QuicStream() # nothing to send yet self.assertTrue(stream.sender.buffer_is_empty) frame = stream.sender.get_frame(8) self.assertIsNone(frame) # write EOF stream.sender.write(b"", end_stream=True) self.assertFalse(stream.sender.buffer_is_empty) frame = stream.sender.get_frame(8) self.assertEqual(frame.data, b"") self.assertTrue(frame.fin) self.assertEqual(frame.offset, 0) # nothing more to send self.assertFalse(stream.sender.buffer_is_empty) # FIXME? frame = stream.sender.get_frame(8) self.assertIsNone(frame) self.assertTrue(stream.sender.buffer_is_empty) def test_sender_fin_then_ack(self): stream = QuicStream() # send some data stream.sender.write(b"data") frame = stream.sender.get_frame(8) self.assertEqual(frame.data, b"data") # data is acknowledged stream.sender.on_data_delivery(QuicDeliveryState.ACKED, 0, 4, False) self.assertFalse(stream.sender.is_finished) # write EOF stream.sender.write(b"", end_stream=True) self.assertFalse(stream.sender.buffer_is_empty) frame = stream.sender.get_frame(8) self.assertEqual(frame.data, b"") self.assertTrue(frame.fin) self.assertEqual(frame.offset, 4) # EOF is acknowledged stream.sender.on_data_delivery(QuicDeliveryState.ACKED, 4, 4, True) self.assertTrue(stream.sender.is_finished) def test_sender_reset(self): stream = QuicStream() # send some data and EOF stream.sender.write(b"data", end_stream=True) frame = stream.sender.get_frame(8) self.assertEqual(frame.data, b"data") self.assertTrue(frame.fin) self.assertEqual(frame.offset, 0) # reset is requested stream.sender.reset(QuicErrorCode.NO_ERROR) self.assertTrue(stream.sender.buffer_is_empty) self.assertTrue(stream.sender.reset_pending) # reset is sent reset = stream.sender.get_reset_frame() self.assertEqual(reset.error_code, QuicErrorCode.NO_ERROR) self.assertEqual(reset.final_size, 4) self.assertFalse(stream.sender.reset_pending) self.assertFalse(stream.sender.is_finished) # data and EOF are acknowledged stream.sender.on_data_delivery(QuicDeliveryState.ACKED, 0, 4, True) self.assertTrue(stream.sender.buffer_is_empty) self.assertFalse(stream.sender.is_finished) # reset is acklowledged stream.sender.on_reset_delivery(QuicDeliveryState.ACKED) self.assertFalse(stream.sender.reset_pending) self.assertTrue(stream.sender.is_finished) def test_sender_reset_lost(self): stream = QuicStream() # reset is requested stream.sender.reset(QuicErrorCode.NO_ERROR) self.assertTrue(stream.sender.buffer_is_empty) self.assertTrue(stream.sender.reset_pending) # reset is sent reset = stream.sender.get_reset_frame() self.assertEqual(reset.error_code, QuicErrorCode.NO_ERROR) self.assertEqual(reset.final_size, 0) self.assertFalse(stream.sender.reset_pending) # reset is lost stream.sender.on_reset_delivery(QuicDeliveryState.LOST) self.assertTrue(stream.sender.reset_pending) self.assertFalse(stream.sender.is_finished) # reset is sent again reset = stream.sender.get_reset_frame() self.assertEqual(reset.error_code, QuicErrorCode.NO_ERROR) self.assertEqual(reset.final_size, 0) self.assertFalse(stream.sender.reset_pending) # reset is acklowledged stream.sender.on_reset_delivery(QuicDeliveryState.ACKED) self.assertFalse(stream.sender.reset_pending) self.assertTrue(stream.sender.buffer_is_empty) self.assertTrue(stream.sender.is_finished) def 
test_sender_reset_with_data_lost(self): stream = QuicStream() # send some data and EOF stream.sender.write(b"data", end_stream=True) frame = stream.sender.get_frame(8) self.assertEqual(frame.data, b"data") self.assertTrue(frame.fin) self.assertEqual(frame.offset, 0) # reset is requested stream.sender.reset(QuicErrorCode.NO_ERROR) self.assertTrue(stream.sender.buffer_is_empty) self.assertTrue(stream.sender.reset_pending) # reset is sent reset = stream.sender.get_reset_frame() self.assertEqual(reset.error_code, QuicErrorCode.NO_ERROR) self.assertEqual(reset.final_size, 4) self.assertFalse(stream.sender.reset_pending) self.assertFalse(stream.sender.is_finished) # data and EOF are lost stream.sender.on_data_delivery(QuicDeliveryState.LOST, 0, 4, True) self.assertTrue(stream.sender.buffer_is_empty) self.assertFalse(stream.sender.is_finished) # reset is acklowledged stream.sender.on_reset_delivery(QuicDeliveryState.ACKED) self.assertFalse(stream.sender.reset_pending) self.assertTrue(stream.sender.buffer_is_empty) self.assertTrue(stream.sender.is_finished) ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1720306884.0 aioquic-1.2.0/tests/test_tls.py0000644000175100001770000020231100000000000017326 0ustar00runnerdocker00000000000000import binascii import datetime import ssl from functools import partial from unittest import TestCase from unittest.mock import patch from aioquic import tls from aioquic.buffer import Buffer, BufferReadError from aioquic.quic.configuration import QuicConfiguration from aioquic.tls import ( Certificate, CertificateRequest, CertificateVerify, ClientHello, Context, EncryptedExtensions, Finished, NewSessionTicket, ServerHello, State, load_pem_x509_certificates, pull_block, pull_certificate, pull_certificate_request, pull_certificate_verify, pull_client_hello, pull_encrypted_extensions, pull_finished, pull_new_session_ticket, pull_server_hello, pull_server_name, push_certificate, push_certificate_request, push_certificate_verify, push_client_hello, push_encrypted_extensions, push_finished, push_new_session_ticket, push_server_hello, push_server_name, verify_certificate, ) from cryptography.exceptions import UnsupportedAlgorithm from cryptography.hazmat.primitives import serialization from cryptography.hazmat.primitives.asymmetric import ec from .utils import ( SERVER_CACERTFILE, SERVER_CERTFILE, SERVER_KEYFILE, generate_ec_certificate, generate_ed448_certificate, generate_ed25519_certificate, generate_rsa_certificate, load, ) CERTIFICATE_DATA = load("tls_certificate.bin")[11:-2] CERTIFICATE_VERIFY_SIGNATURE = load("tls_certificate_verify.bin")[-384:] CLIENT_QUIC_TRANSPORT_PARAMETERS = binascii.unhexlify( b"ff0000110031000500048010000000060004801000000007000480100000000" b"4000481000000000100024258000800024064000a00010a" ) SERVER_QUIC_TRANSPORT_PARAMETERS = binascii.unhexlify( b"ff00001104ff000011004500050004801000000006000480100000000700048" b"010000000040004810000000001000242580002001000000000000000000000" b"000000000000000800024064000a00010a" ) SERVER_QUIC_TRANSPORT_PARAMETERS_2 = binascii.unhexlify( b"0057000600048000ffff000500048000ffff00020010c5ac410fbdd4fe6e2c1" b"42279f231e8e0000a000103000400048005fffa000b000119000100026710ff" b"42000c5c067f27e39321c63e28e7c90003000247e40008000106" ) SERVER_QUIC_TRANSPORT_PARAMETERS_3 = binascii.unhexlify( b"0054000200100dcb50a442513295b4679baf04cb5effff8a0009c8afe72a6397" b"255407000600048000ffff0008000106000400048005fffa000500048000ffff" b"0003000247e4000a000103000100026710000b000119" ) 
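# Editor's note: a small self-contained sketch of the length-prefixed block
# encoding that pull_block/push_block (exercised by BufferTest below) implement
# for TLS vectors. It assumes pull_block yields the parsed block length; the
# demo function itself is illustrative and not part of the upstream test suite.
def _block_roundtrip_demo():
    buf = Buffer(capacity=16)
    with tls.push_block(buf, 2):  # reserve a 2-byte length prefix
        buf.push_bytes(b"abc")  # block body: 3 bytes
    buf.seek(0)
    with pull_block(buf, 2) as length:  # parse the 2-byte length prefix
        assert length == 3
        assert buf.pull_bytes(length) == b"abc"


_block_roundtrip_demo()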
class BufferTest(TestCase): def test_pull_block_truncated(self): buf = Buffer(capacity=0) with self.assertRaises(BufferReadError): with pull_block(buf, 1): pass def corrupt_hello_version(data: bytes) -> bytes: """ Corrupt a ClientHello or ServerHello's protocol version. """ return data[:4] + b"\xff\xff" + data[6:] def create_buffers(): return { tls.Epoch.INITIAL: Buffer(capacity=4096), tls.Epoch.HANDSHAKE: Buffer(capacity=4096), tls.Epoch.ONE_RTT: Buffer(capacity=4096), } def merge_buffers(buffers): return b"".join(x.data for x in buffers.values()) def reset_buffers(buffers): for k in buffers.keys(): buffers[k].seek(0) class ContextTest(TestCase): def assertClientHello(self, data: bytes): self.assertEqual(data[0], tls.HandshakeType.CLIENT_HELLO) self.assertGreaterEqual(len(data), 191) self.assertLessEqual(len(data), 564) def create_client( self, alpn_protocols=None, cadata=None, cafile=SERVER_CACERTFILE, **kwargs ): client = Context( alpn_protocols=alpn_protocols, cadata=cadata, cafile=cafile, is_client=True, **kwargs, ) client.handshake_extensions = [ ( tls.ExtensionType.QUIC_TRANSPORT_PARAMETERS, CLIENT_QUIC_TRANSPORT_PARAMETERS, ) ] self.assertEqual(client.state, State.CLIENT_HANDSHAKE_START) return client def create_server(self, alpn_protocols=None, **kwargs): configuration = QuicConfiguration(is_client=False) configuration.load_cert_chain(SERVER_CERTFILE, SERVER_KEYFILE) server = Context( alpn_protocols=alpn_protocols, is_client=False, max_early_data=0xFFFFFFFF, **kwargs, ) server.certificate = configuration.certificate server.certificate_private_key = configuration.private_key server.handshake_extensions = [ ( tls.ExtensionType.QUIC_TRANSPORT_PARAMETERS, SERVER_QUIC_TRANSPORT_PARAMETERS, ) ] self.assertEqual(server.state, State.SERVER_EXPECT_CLIENT_HELLO) return server def handshake_with_client_input_corruption( self, corrupt_client_input, expected_exception, ): client = self.create_client() server = self.create_server() # Send client hello. client_buf = create_buffers() client.handle_message(b"", client_buf) self.assertEqual(client.state, State.CLIENT_EXPECT_SERVER_HELLO) server_input = merge_buffers(client_buf) reset_buffers(client_buf) # Handle client hello. # # send server hello, encrypted extensions, certificate, certificate verify, # finished. server_buf = create_buffers() server.handle_message(server_input, server_buf) self.assertEqual(server.state, State.SERVER_EXPECT_FINISHED) client_input = merge_buffers(server_buf) reset_buffers(server_buf) # Mess with compression method. client_input = corrupt_client_input(client_input) # Handle server hello, encrypted extensions, certificate, certificate verify, # finished. 
with self.assertRaises(expected_exception.__class__) as cm: client.handle_message(client_input, client_buf) self.assertEqual(str(cm.exception), str(expected_exception)) def test_client_unexpected_message(self): client = self.create_client() client.state = State.CLIENT_EXPECT_SERVER_HELLO with self.assertRaises(tls.AlertUnexpectedMessage): client.handle_message(b"\x00\x00\x00\x00", create_buffers()) client.state = State.CLIENT_EXPECT_ENCRYPTED_EXTENSIONS with self.assertRaises(tls.AlertUnexpectedMessage): client.handle_message(b"\x00\x00\x00\x00", create_buffers()) client.state = State.CLIENT_EXPECT_CERTIFICATE_REQUEST_OR_CERTIFICATE with self.assertRaises(tls.AlertUnexpectedMessage): client.handle_message(b"\x00\x00\x00\x00", create_buffers()) client.state = State.CLIENT_EXPECT_CERTIFICATE with self.assertRaises(tls.AlertUnexpectedMessage): client.handle_message(b"\x00\x00\x00\x00", create_buffers()) client.state = State.CLIENT_EXPECT_CERTIFICATE_VERIFY with self.assertRaises(tls.AlertUnexpectedMessage): client.handle_message(b"\x00\x00\x00\x00", create_buffers()) client.state = State.CLIENT_EXPECT_FINISHED with self.assertRaises(tls.AlertUnexpectedMessage): client.handle_message(b"\x00\x00\x00\x00", create_buffers()) client.state = State.CLIENT_POST_HANDSHAKE with self.assertRaises(tls.AlertUnexpectedMessage): client.handle_message(b"\x00\x00\x00\x00", create_buffers()) def test_client_bad_hello_buffer_read_error(self): buf = Buffer(capacity=100) buf.push_uint8(tls.HandshakeType.SERVER_HELLO) with tls.push_block(buf, 3): pass self.handshake_with_client_input_corruption( # Receive a malformed ServerHello lambda x: buf.data, tls.AlertDecodeError("Could not parse TLS message"), ) def test_client_bad_hello_compression_method(self): self.handshake_with_client_input_corruption( # Mess with compression method. lambda x: x[:41] + b"\xff" + x[42:], tls.AlertIllegalParameter( "ServerHello has a compression method we did not advertise" ), ) def test_client_bad_hello_version(self): self.handshake_with_client_input_corruption( # Mess with supported version. lambda x: x[:48] + b"\xff\xff" + x[50:], tls.AlertIllegalParameter("ServerHello has a version we did not advertise"), ) def test_client_bad_certificate_verify_algorithm(self): self.handshake_with_client_input_corruption( # Mess with certificate verify. lambda x: x[:-440] + b"\xff\xff" + x[-438:], tls.AlertDecryptError( "CertificateVerify has a signature algorithm we did not advertise" ), ) def test_client_bad_certificate_verify_data(self): self.handshake_with_client_input_corruption( # Mess with certificate verify. lambda x: x[:-56] + bytes(4) + x[-52:], tls.AlertDecryptError(), ) def test_client_bad_finished_verify_data(self): self.handshake_with_client_input_corruption( # Mess with finished verify data. 
lambda x: x[:-4] + bytes(4), tls.AlertDecryptError(), ) def test_server_unexpected_message(self): server = self.create_server() server.state = State.SERVER_EXPECT_CLIENT_HELLO with self.assertRaises(tls.AlertUnexpectedMessage): server.handle_message(b"\x00\x00\x00\x00", create_buffers()) server.state = State.SERVER_EXPECT_CERTIFICATE with self.assertRaises(tls.AlertUnexpectedMessage): server.handle_message(b"\x00\x00\x00\x00", create_buffers()) server.state = State.SERVER_EXPECT_CERTIFICATE_VERIFY with self.assertRaises(tls.AlertUnexpectedMessage): server.handle_message(b"\x00\x00\x00\x00", create_buffers()) server.state = State.SERVER_EXPECT_FINISHED with self.assertRaises(tls.AlertUnexpectedMessage): server.handle_message(b"\x00\x00\x00\x00", create_buffers()) server.state = State.SERVER_POST_HANDSHAKE with self.assertRaises(tls.AlertUnexpectedMessage): server.handle_message(b"\x00\x00\x00\x00", create_buffers()) def _server_fail_hello(self, client, server): # Send client hello. client_buf = create_buffers() client.handle_message(b"", client_buf) self.assertEqual(client.state, State.CLIENT_EXPECT_SERVER_HELLO) server_input = merge_buffers(client_buf) reset_buffers(client_buf) # Handle client hello. server_buf = create_buffers() server.handle_message(server_input, server_buf) def test_server_unsupported_cipher_suite(self): client = self.create_client(cipher_suites=[tls.CipherSuite.AES_128_GCM_SHA256]) server = self.create_server(cipher_suites=[tls.CipherSuite.AES_256_GCM_SHA384]) with self.assertRaises(tls.AlertHandshakeFailure) as cm: self._server_fail_hello(client, server) self.assertEqual(str(cm.exception), "No supported cipher suite") def test_server_unsupported_signature_algorithm(self): client = self.create_client() client._signature_algorithms = [tls.SignatureAlgorithm.ED448] server = self.create_server() with self.assertRaises(tls.AlertHandshakeFailure) as cm: self._server_fail_hello(client, server) self.assertEqual(str(cm.exception), "No supported signature algorithm") def test_server_unsupported_version(self): client = self.create_client() client._supported_versions = [tls.TLS_VERSION_1_2] server = self.create_server() with self.assertRaises(tls.AlertProtocolVersion) as cm: self._server_fail_hello(client, server) self.assertEqual(str(cm.exception), "No supported protocol version") def test_server_bad_finished_verify_data(self): client = self.create_client() server = self.create_server() # Send client hello. client_buf = create_buffers() client.handle_message(b"", client_buf) self.assertEqual(client.state, State.CLIENT_EXPECT_SERVER_HELLO) server_input = merge_buffers(client_buf) reset_buffers(client_buf) # Handle client hello. # # Send server hello, encrypted extensions, certificate, certificate verify, # finished. server_buf = create_buffers() server.handle_message(server_input, server_buf) self.assertEqual(server.state, State.SERVER_EXPECT_FINISHED) client_input = merge_buffers(server_buf) reset_buffers(server_buf) # Handle server hello, encrypted extensions, certificate, certificate verify, # finished. # # Send finished. client.handle_message(client_input, client_buf) self.assertEqual(client.state, State.CLIENT_POST_HANDSHAKE) server_input = merge_buffers(client_buf) reset_buffers(client_buf) # Mess with finished verify data. server_input = server_input[:-4] + bytes(4) # Handle finished. with self.assertRaises(tls.AlertDecryptError): server.handle_message(server_input, server_buf) def _handshake(self, client, server): # Send client hello. 
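        # (A sketch of the driver pattern used by all these handshake tests:
        # handle_message(b"", ...) makes the client emit its first flight, and
        # each peer is then fed the other's output, with merge_buffers() /
        # reset_buffers() shuttling the bytes across the per-epoch buffers.)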
client_buf = create_buffers() client.handle_message(b"", client_buf) self.assertEqual(client.state, State.CLIENT_EXPECT_SERVER_HELLO) server_input = merge_buffers(client_buf) self.assertClientHello(server_input) reset_buffers(client_buf) # Handle client hello. # # Send server hello, encrypted extensions, certificate, certificate verify, # finished, (session ticket). server_buf = create_buffers() server.handle_message(server_input, server_buf) self.assertEqual(server.state, State.SERVER_EXPECT_FINISHED) client_input = merge_buffers(server_buf) self.assertGreaterEqual(len(client_input), 587) self.assertLessEqual(len(client_input), 2316) reset_buffers(server_buf) # Handle server hello, encrypted extensions, certificate, certificate verify, # finished, (session ticket). # # Send finished. client.handle_message(client_input, client_buf) self.assertEqual(client.state, State.CLIENT_POST_HANDSHAKE) server_input = merge_buffers(client_buf) self.assertEqual(len(server_input), 52) reset_buffers(client_buf) # Handle finished. server.handle_message(server_input, server_buf) self.assertEqual(server.state, State.SERVER_POST_HANDSHAKE) client_input = merge_buffers(server_buf) self.assertEqual(len(client_input), 0) # check keys match self.assertEqual(client._dec_key, server._enc_key) self.assertEqual(client._enc_key, server._dec_key) # check cipher suite self.assertEqual( client.key_schedule.cipher_suite, tls.CipherSuite.AES_256_GCM_SHA384 ) self.assertEqual( server.key_schedule.cipher_suite, tls.CipherSuite.AES_256_GCM_SHA384 ) def test_handshake(self): client = self.create_client() server = self.create_server() self._handshake(client, server) # check ALPN matches self.assertEqual(client.alpn_negotiated, None) self.assertEqual(server.alpn_negotiated, None) def test_handshake_with_certificate_request_no_certificate(self): # The server requests a certificate, but the client has none. client = self.create_client() server = self.create_server() server._request_client_certificate = True # Send client hello. client_buf = create_buffers() client.handle_message(b"", client_buf) self.assertEqual(client.state, State.CLIENT_EXPECT_SERVER_HELLO) server_input = merge_buffers(client_buf) self.assertClientHello(server_input) reset_buffers(client_buf) # Handle client hello. # # Send server hello, encrypted extensions, certificate request, certificate, # certificate verify, finished. server_buf = create_buffers() server.handle_message(server_input, server_buf) self.assertEqual(server.state, State.SERVER_EXPECT_CERTIFICATE) client_input = merge_buffers(server_buf) self.assertGreaterEqual(len(client_input), 587) self.assertLessEqual(len(client_input), 2316) reset_buffers(server_buf) # Handle server hello, encrypted extensions, certificate request, certificate, # certificate verify, finished. # # Send certificate, finished. client.handle_message(client_input, client_buf) self.assertEqual(client.state, State.CLIENT_POST_HANDSHAKE) server_input = merge_buffers(client_buf) self.assertEqual(len(server_input), 60) reset_buffers(client_buf) # Handle certificate, finished. 
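        # (The client has nothing to present, so the Certificate message in
        # this flight is empty; that is why this flight is 60 bytes rather
        # than the 52-byte Finished-only flight seen in _handshake above.)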
server.handle_message(server_input, server_buf) self.assertEqual(server.state, State.SERVER_POST_HANDSHAKE) client_input = merge_buffers(server_buf) self.assertEqual(len(client_input), 0) # check keys match self.assertEqual(client._dec_key, server._enc_key) self.assertEqual(client._enc_key, server._dec_key) # check cipher suite self.assertEqual( client.key_schedule.cipher_suite, tls.CipherSuite.AES_256_GCM_SHA384 ) self.assertEqual( server.key_schedule.cipher_suite, tls.CipherSuite.AES_256_GCM_SHA384 ) def test_handshake_with_certificate_request_with_certificate(self): # The server requests a certificate, and the client has one. client = self.create_client() client.certificate, client.certificate_private_key = generate_rsa_certificate( common_name="client.example.com" ) server = self.create_server() server._request_client_certificate = True # Send client hello. client_buf = create_buffers() client.handle_message(b"", client_buf) self.assertEqual(client.state, State.CLIENT_EXPECT_SERVER_HELLO) server_input = merge_buffers(client_buf) self.assertClientHello(server_input) reset_buffers(client_buf) # Handle client hello. # # Send server hello, encrypted extensions, certificate request, certificate, # certificate verify, finished. server_buf = create_buffers() server.handle_message(server_input, server_buf) self.assertEqual(server.state, State.SERVER_EXPECT_CERTIFICATE) client_input = merge_buffers(server_buf) self.assertGreaterEqual(len(client_input), 587) self.assertLessEqual(len(client_input), 2316) reset_buffers(server_buf) # Handle server hello, encrypted extensions, certificate request, certificate, # certificate verify, finished. # # Send certificate, certificate verify, finished. client.handle_message(client_input, client_buf) self.assertEqual(client.state, State.CLIENT_POST_HANDSHAKE) server_input = merge_buffers(client_buf) self.assertGreaterEqual(len(server_input), 1042) self.assertLessEqual(len(server_input), 1043) reset_buffers(client_buf) # Handle certificate, certificate verify, finished. 
server.handle_message(server_input, server_buf) self.assertEqual(server.state, State.SERVER_POST_HANDSHAKE) client_input = merge_buffers(server_buf) self.assertEqual(len(client_input), 0) # check keys match self.assertEqual(client._dec_key, server._enc_key) self.assertEqual(client._enc_key, server._dec_key) # check cipher suite self.assertEqual( client.key_schedule.cipher_suite, tls.CipherSuite.AES_256_GCM_SHA384 ) self.assertEqual( server.key_schedule.cipher_suite, tls.CipherSuite.AES_256_GCM_SHA384 ) def _test_handshake_with_certificate(self, certificate, private_key): server = self.create_server() server.certificate = certificate server.certificate_private_key = private_key client = self.create_client( cadata=server.certificate.public_bytes(serialization.Encoding.PEM), cafile=None, ) self._handshake(client, server) # check ALPN matches self.assertEqual(client.alpn_negotiated, None) self.assertEqual(server.alpn_negotiated, None) def test_handshake_with_ec_certificate_secp256r1(self): self._test_handshake_with_certificate( *generate_ec_certificate(common_name="example.com", curve=ec.SECP256R1) ) def test_handshake_with_ec_certificate_secp384r1(self): self._test_handshake_with_certificate( *generate_ec_certificate(common_name="example.com", curve=ec.SECP384R1) ) def test_handshake_with_ed25519_certificate(self): self._test_handshake_with_certificate( *generate_ed25519_certificate(common_name="example.com") ) def test_handshake_with_ed448_certificate(self): self._test_handshake_with_certificate( *generate_ed448_certificate(common_name="example.com") ) def test_handshake_with_alpn(self): client = self.create_client(alpn_protocols=["hq-interop"]) server = self.create_server(alpn_protocols=["hq-interop", "h3"]) self._handshake(client, server) # check ALPN matches self.assertEqual(client.alpn_negotiated, "hq-interop") self.assertEqual(server.alpn_negotiated, "hq-interop") def test_handshake_with_alpn_fail(self): client = self.create_client(alpn_protocols=["hq-interop"]) server = self.create_server(alpn_protocols=["h3"]) with self.assertRaises(tls.AlertHandshakeFailure) as cm: self._handshake(client, server) self.assertEqual(str(cm.exception), "No common ALPN protocols") def test_handshake_with_rsa_pkcs1_sha1_signature(self): client = self.create_client() client._signature_algorithms = [tls.SignatureAlgorithm.RSA_PKCS1_SHA1] server = self.create_server() self._handshake(client, server) def test_handshake_with_rsa_pkcs1_sha256_signature(self): client = self.create_client() client._signature_algorithms = [tls.SignatureAlgorithm.RSA_PKCS1_SHA256] server = self.create_server() self._handshake(client, server) def test_handshake_with_rsa_pkcs1_sha384_signature(self): client = self.create_client() client._signature_algorithms = [tls.SignatureAlgorithm.RSA_PKCS1_SHA384] server = self.create_server() self._handshake(client, server) def test_handshake_with_rsa_pss_rsae_sha256_signature(self): client = self.create_client() client._signature_algorithms = [tls.SignatureAlgorithm.RSA_PSS_RSAE_SHA256] server = self.create_server() self._handshake(client, server) def test_handshake_with_rsa_pss_rsae_sha384_signature(self): client = self.create_client() client._signature_algorithms = [tls.SignatureAlgorithm.RSA_PSS_RSAE_SHA384] server = self.create_server() self._handshake(client, server) def test_handshake_with_certificate_error(self): client = self.create_client(cafile=None) server = self.create_server() with self.assertRaises(tls.AlertBadCertificate) as cm: self._handshake(client, server) 
self.assertEqual(str(cm.exception), "unable to get local issuer certificate") def test_handshake_with_certificate_no_verify(self): client = self.create_client(cafile=None, verify_mode=ssl.CERT_NONE) server = self.create_server() self._handshake(client, server) def test_handshake_with_grease_group(self): client = self.create_client() client._supported_groups = [tls.Group.GREASE, tls.Group.SECP256R1] server = self.create_server() self._handshake(client, server) def test_handshake_with_secp256r1_group(self): client = self.create_client() client._supported_groups = [tls.Group.SECP256R1] server = self.create_server() self._handshake(client, server) def test_handshake_with_secp384r1_group(self): client = self.create_client() client._supported_groups = [tls.Group.SECP384R1] server = self.create_server() self._handshake(client, server) def test_handshake_with_x25519(self): client = self.create_client() client._supported_groups = [tls.Group.X25519] server = self.create_server() try: self._handshake(client, server) except UnsupportedAlgorithm as exc: self.skipTest(str(exc)) def test_handshake_with_x448(self): client = self.create_client() client._supported_groups = [tls.Group.X448] server = self.create_server() try: self._handshake(client, server) except UnsupportedAlgorithm as exc: self.skipTest(str(exc)) def test_session_ticket(self): client_tickets = [] server_tickets = [] def client_new_ticket(ticket): client_tickets.append(ticket) def server_get_ticket(label): for t in server_tickets: if t.ticket == label: return t return None def server_new_ticket(ticket): server_tickets.append(ticket) def first_handshake(): client = self.create_client() client.new_session_ticket_cb = client_new_ticket server = self.create_server() server.new_session_ticket_cb = server_new_ticket self._handshake(client, server) # check session resumption was not used self.assertFalse(client.session_resumed) self.assertFalse(server.session_resumed) # check tickets match self.assertEqual(len(client_tickets), 1) self.assertEqual(len(server_tickets), 1) self.assertEqual(client_tickets[0].ticket, server_tickets[0].ticket) self.assertEqual( client_tickets[0].resumption_secret, server_tickets[0].resumption_secret ) def second_handshake(): client = self.create_client() client.session_ticket = client_tickets[0] server = self.create_server() server.get_session_ticket_cb = server_get_ticket # Send client hello with pre_shared_key. client_buf = create_buffers() client.handle_message(b"", client_buf) self.assertEqual(client.state, State.CLIENT_EXPECT_SERVER_HELLO) server_input = merge_buffers(client_buf) self.assertClientHello(server_input) reset_buffers(client_buf) # Handle client hello. # # Send server hello, encrypted extensions, finished. server_buf = create_buffers() server.handle_message(server_input, server_buf) self.assertEqual(server.state, State.SERVER_EXPECT_FINISHED) client_input = merge_buffers(server_buf) self.assertEqual(len(client_input), 275) reset_buffers(server_buf) # Handle server hello, encrypted extensions, certificate, # certificate verify, finished. # # Send finished. client.handle_message(client_input, client_buf) self.assertEqual(client.state, State.CLIENT_POST_HANDSHAKE) server_input = merge_buffers(client_buf) self.assertEqual(len(server_input), 52) reset_buffers(client_buf) # Handle finished. # # Send new_session_ticket. 
server.handle_message(server_input, server_buf) self.assertEqual(server.state, State.SERVER_POST_HANDSHAKE) client_input = merge_buffers(server_buf) self.assertEqual(len(client_input), 0) reset_buffers(server_buf) # check keys match self.assertEqual(client._dec_key, server._enc_key) self.assertEqual(client._enc_key, server._dec_key) # check session resumption was used self.assertTrue(client.session_resumed) self.assertTrue(server.session_resumed) def second_handshake_bad_binder(): client = self.create_client() client.session_ticket = client_tickets[0] server = self.create_server() server.get_session_ticket_cb = server_get_ticket # send client hello with pre_shared_key client_buf = create_buffers() client.handle_message(b"", client_buf) self.assertEqual(client.state, State.CLIENT_EXPECT_SERVER_HELLO) server_input = merge_buffers(client_buf) self.assertClientHello(server_input) reset_buffers(client_buf) # tamper with binder server_input = server_input[:-4] + bytes(4) # handle client hello # send server hello, encrypted extensions, finished server_buf = create_buffers() with self.assertRaises(tls.AlertHandshakeFailure) as cm: server.handle_message(server_input, server_buf) self.assertEqual(str(cm.exception), "PSK validation failed") def second_handshake_bad_pre_shared_key(): client = self.create_client() client.session_ticket = client_tickets[0] server = self.create_server() server.get_session_ticket_cb = server_get_ticket # send client hello with pre_shared_key client_buf = create_buffers() client.handle_message(b"", client_buf) self.assertEqual(client.state, State.CLIENT_EXPECT_SERVER_HELLO) server_input = merge_buffers(client_buf) self.assertClientHello(server_input) reset_buffers(client_buf) # handle client hello # send server hello, encrypted extensions, finished server_buf = create_buffers() server.handle_message(server_input, server_buf) self.assertEqual(server.state, State.SERVER_EXPECT_FINISHED) # tamper with pre_shared_key index buf = server_buf[tls.Epoch.INITIAL] buf.seek(buf.tell() - 1) buf.push_uint8(1) client_input = merge_buffers(server_buf) self.assertEqual(len(client_input), 275) reset_buffers(server_buf) # handle server hello, which must be rejected with self.assertRaises(tls.AlertIllegalParameter): client.handle_message(client_input, client_buf) first_handshake() second_handshake() second_handshake_bad_binder() second_handshake_bad_pre_shared_key() class TlsTest(TestCase): def test_pull_block_incomplete_read(self): """ If a block is not read until its end, an alert should be raised.
""" buf = Buffer(data=bytes([2, 0, 0])) with self.assertRaises(tls.AlertDecodeError) as cm: with pull_block(buf, 1): buf.pull_bytes(1) self.assertEqual(str(cm.exception), "extra bytes at the end of a block") def test_pull_client_hello(self): buf = Buffer(data=load("tls_client_hello.bin")) hello = pull_client_hello(buf) self.assertTrue(buf.eof()) self.assertEqual( hello.random, binascii.unhexlify( "18b2b23bf3e44b5d52ccfe7aecbc5ff14eadc3d349fabf804d71f165ae76e7d5" ), ) self.assertEqual( hello.legacy_session_id, binascii.unhexlify( "9aee82a2d186c1cb32a329d9dcfe004a1a438ad0485a53c6bfcf55c132a23235" ), ) self.assertEqual( hello.cipher_suites, [ tls.CipherSuite.AES_256_GCM_SHA384, tls.CipherSuite.AES_128_GCM_SHA256, tls.CipherSuite.CHACHA20_POLY1305_SHA256, ], ) self.assertEqual(hello.legacy_compression_methods, [tls.CompressionMethod.NULL]) # extensions self.assertEqual(hello.alpn_protocols, None) self.assertEqual( hello.key_share, [ ( tls.Group.SECP256R1, binascii.unhexlify( "047bfea344467535054263b75def60cffa82405a211b68d1eb8d1d944e67aef8" "93c7665a5473d032cfaf22a73da28eb4aacae0017ed12557b5791f98a1e84f15" "b0" ), ) ], ) self.assertEqual( hello.psk_key_exchange_modes, [tls.PskKeyExchangeMode.PSK_DHE_KE] ) self.assertEqual(hello.server_name, None) self.assertEqual( hello.signature_algorithms, [ tls.SignatureAlgorithm.RSA_PSS_RSAE_SHA256, tls.SignatureAlgorithm.ECDSA_SECP256R1_SHA256, tls.SignatureAlgorithm.RSA_PKCS1_SHA256, tls.SignatureAlgorithm.RSA_PKCS1_SHA1, ], ) self.assertEqual(hello.supported_groups, [tls.Group.SECP256R1]) self.assertEqual( hello.supported_versions, [ tls.TLS_VERSION_1_3, tls.TLS_VERSION_1_3_DRAFT_28, tls.TLS_VERSION_1_3_DRAFT_27, tls.TLS_VERSION_1_3_DRAFT_26, ], ) self.assertEqual( hello.other_extensions, [ ( tls.ExtensionType.QUIC_TRANSPORT_PARAMETERS_DRAFT, CLIENT_QUIC_TRANSPORT_PARAMETERS, ) ], ) def test_pull_client_hello_with_alpn(self): buf = Buffer(data=load("tls_client_hello_with_alpn.bin")) hello = pull_client_hello(buf) self.assertTrue(buf.eof()) self.assertEqual( hello.random, binascii.unhexlify( "ed575c6fbd599c4dfaabd003dca6e860ccdb0e1782c1af02e57bf27cb6479b76" ), ) self.assertEqual(hello.legacy_session_id, b"") self.assertEqual( hello.cipher_suites, [ tls.CipherSuite.AES_128_GCM_SHA256, tls.CipherSuite.AES_256_GCM_SHA384, tls.CipherSuite.CHACHA20_POLY1305_SHA256, tls.CipherSuite.EMPTY_RENEGOTIATION_INFO_SCSV, ], ) self.assertEqual(hello.legacy_compression_methods, [tls.CompressionMethod.NULL]) # extensions self.assertEqual(hello.alpn_protocols, ["h3-19"]) self.assertEqual(hello.early_data, False) self.assertEqual( hello.key_share, [ ( tls.Group.SECP256R1, binascii.unhexlify( "048842315c437bb0ce2929c816fee4e942ec5cb6db6a6b9bf622680188ebb0d4" "b652e69033f71686aa01cbc79155866e264c9f33f45aa16b0dfa10a222e3a669" "22" ), ) ], ) self.assertEqual( hello.psk_key_exchange_modes, [tls.PskKeyExchangeMode.PSK_DHE_KE] ) self.assertEqual(hello.server_name, "cloudflare-quic.com") self.assertEqual( hello.signature_algorithms, [ tls.SignatureAlgorithm.ECDSA_SECP256R1_SHA256, tls.SignatureAlgorithm.ECDSA_SECP384R1_SHA384, tls.SignatureAlgorithm.ECDSA_SECP521R1_SHA512, tls.SignatureAlgorithm.ED25519, tls.SignatureAlgorithm.ED448, tls.SignatureAlgorithm.RSA_PSS_PSS_SHA256, tls.SignatureAlgorithm.RSA_PSS_PSS_SHA384, tls.SignatureAlgorithm.RSA_PSS_PSS_SHA512, tls.SignatureAlgorithm.RSA_PSS_RSAE_SHA256, tls.SignatureAlgorithm.RSA_PSS_RSAE_SHA384, tls.SignatureAlgorithm.RSA_PSS_RSAE_SHA512, tls.SignatureAlgorithm.RSA_PKCS1_SHA256, tls.SignatureAlgorithm.RSA_PKCS1_SHA384, 
tls.SignatureAlgorithm.RSA_PKCS1_SHA512, ], ) self.assertEqual( hello.supported_groups, [ tls.Group.SECP256R1, tls.Group.X25519, tls.Group.SECP384R1, tls.Group.SECP521R1, ], ) self.assertEqual(hello.supported_versions, [tls.TLS_VERSION_1_3]) # serialize buf = Buffer(1000) push_client_hello(buf, hello) self.assertEqual(len(buf.data), len(load("tls_client_hello_with_alpn.bin"))) def test_pull_client_hello_with_psk(self): buf = Buffer(data=load("tls_client_hello_with_psk.bin")) hello = pull_client_hello(buf) self.assertEqual(hello.early_data, True) self.assertEqual( hello.pre_shared_key, tls.OfferedPsks( identities=[ ( binascii.unhexlify( "fab3dc7d79f35ea53e9adf21150e601591a750b80cde0cd167fef6e0cdbc032a" "c4161fc5c5b66679de49524bd5624c50d71ba3e650780a4bfe402d6a06a00525" "0b5dc52085233b69d0dd13924cc5c713a396784ecafc59f5ea73c1585d79621b" "8a94e4f2291b17427d5185abf4a994fca74ee7a7f993a950c71003fc7cf8" ), 2067156378, ) ], binders=[ binascii.unhexlify( "1788ad43fdff37cfc628f24b6ce7c8c76180705380da17da32811b5bae4e78" "d7aaaf65a9b713872f2bb28818ca1a6b01" ) ], ), ) self.assertTrue(buf.eof()) # serialize buf = Buffer(1000) push_client_hello(buf, hello) self.assertEqual(buf.data, load("tls_client_hello_with_psk.bin")) def test_pull_client_hello_with_psk_and_other_extension(self): buf = Buffer(capacity=1000) # Prepare PSK. psk_buf = Buffer(capacity=100) tls.push_offered_psks( psk_buf, tls.OfferedPsks( identities=[], binders=[], ), ) # Write a ClientHello with an extension *after* PSK. hello = ClientHello( random=binascii.unhexlify( "18b2b23bf3e44b5d52ccfe7aecbc5ff14eadc3d349fabf804d71f165ae76e7d5" ), legacy_session_id=binascii.unhexlify( "9aee82a2d186c1cb32a329d9dcfe004a1a438ad0485a53c6bfcf55c132a23235" ), cipher_suites=[tls.CipherSuite.AES_256_GCM_SHA384], legacy_compression_methods=[tls.CompressionMethod.NULL], key_share=[ ( tls.Group.SECP256R1, binascii.unhexlify( "047bfea344467535054263b75def60cffa82405a211b68d1eb8d1d944e67aef8" "93c7665a5473d032cfaf22a73da28eb4aacae0017ed12557b5791f98a1e84f15" "b0" ), ) ], psk_key_exchange_modes=[tls.PskKeyExchangeMode.PSK_DHE_KE], signature_algorithms=[tls.SignatureAlgorithm.RSA_PSS_RSAE_SHA256], supported_groups=[tls.Group.SECP256R1], supported_versions=[tls.TLS_VERSION_1_3], other_extensions=[ ( tls.ExtensionType.PRE_SHARED_KEY, psk_buf.data, ), ( tls.ExtensionType.QUIC_TRANSPORT_PARAMETERS_DRAFT, CLIENT_QUIC_TRANSPORT_PARAMETERS, ), ], ) push_client_hello(buf, hello) # Try reading it back. 
buf.seek(0) with self.assertRaises(tls.AlertIllegalParameter) as cm: pull_client_hello(buf) self.assertEqual(str(cm.exception), "PreSharedKey is not the last extension") def test_pull_client_hello_with_sni(self): buf = Buffer(data=load("tls_client_hello_with_sni.bin")) hello = pull_client_hello(buf) self.assertTrue(buf.eof()) self.assertEqual( hello.random, binascii.unhexlify( "987d8934140b0a42cc5545071f3f9f7f61963d7b6404eb674c8dbe513604346b" ), ) self.assertEqual( hello.legacy_session_id, binascii.unhexlify( "26b19bdd30dbf751015a3a16e13bd59002dfe420b799d2a5cd5e11b8fa7bcb66" ), ) self.assertEqual( hello.cipher_suites, [ tls.CipherSuite.AES_256_GCM_SHA384, tls.CipherSuite.AES_128_GCM_SHA256, tls.CipherSuite.CHACHA20_POLY1305_SHA256, ], ) self.assertEqual(hello.legacy_compression_methods, [tls.CompressionMethod.NULL]) # extensions self.assertEqual(hello.alpn_protocols, None) self.assertEqual( hello.key_share, [ ( tls.Group.SECP256R1, binascii.unhexlify( "04b62d70f907c814cd65d0f73b8b991f06b70c77153f548410a191d2b19764a2" "ecc06065a480efa9e1f10c8da6e737d5bfc04be3f773e20a0c997f51b5621280" "40" ), ) ], ) self.assertEqual( hello.psk_key_exchange_modes, [tls.PskKeyExchangeMode.PSK_DHE_KE] ) self.assertEqual(hello.server_name, "cloudflare-quic.com") self.assertEqual( hello.signature_algorithms, [ tls.SignatureAlgorithm.RSA_PSS_RSAE_SHA256, tls.SignatureAlgorithm.ECDSA_SECP256R1_SHA256, tls.SignatureAlgorithm.RSA_PKCS1_SHA256, tls.SignatureAlgorithm.RSA_PKCS1_SHA1, ], ) self.assertEqual(hello.supported_groups, [tls.Group.SECP256R1]) self.assertEqual( hello.supported_versions, [ tls.TLS_VERSION_1_3, tls.TLS_VERSION_1_3_DRAFT_28, tls.TLS_VERSION_1_3_DRAFT_27, tls.TLS_VERSION_1_3_DRAFT_26, ], ) self.assertEqual( hello.other_extensions, [ ( tls.ExtensionType.QUIC_TRANSPORT_PARAMETERS_DRAFT, CLIENT_QUIC_TRANSPORT_PARAMETERS, ) ], ) # serialize buf = Buffer(1000) push_client_hello(buf, hello) self.assertEqual(buf.data, load("tls_client_hello_with_sni.bin")) def test_pull_client_hello_with_unexpected_version(self): buf = Buffer(data=corrupt_hello_version(load("tls_client_hello.bin"))) with self.assertRaises(tls.AlertDecodeError) as cm: pull_client_hello(buf) self.assertEqual(str(cm.exception), "ClientHello version is not 1.2") def test_push_client_hello(self): hello = ClientHello( random=binascii.unhexlify( "18b2b23bf3e44b5d52ccfe7aecbc5ff14eadc3d349fabf804d71f165ae76e7d5" ), legacy_session_id=binascii.unhexlify( "9aee82a2d186c1cb32a329d9dcfe004a1a438ad0485a53c6bfcf55c132a23235" ), cipher_suites=[ tls.CipherSuite.AES_256_GCM_SHA384, tls.CipherSuite.AES_128_GCM_SHA256, tls.CipherSuite.CHACHA20_POLY1305_SHA256, ], legacy_compression_methods=[tls.CompressionMethod.NULL], key_share=[ ( tls.Group.SECP256R1, binascii.unhexlify( "047bfea344467535054263b75def60cffa82405a211b68d1eb8d1d944e67aef8" "93c7665a5473d032cfaf22a73da28eb4aacae0017ed12557b5791f98a1e84f15" "b0" ), ) ], psk_key_exchange_modes=[tls.PskKeyExchangeMode.PSK_DHE_KE], signature_algorithms=[ tls.SignatureAlgorithm.RSA_PSS_RSAE_SHA256, tls.SignatureAlgorithm.ECDSA_SECP256R1_SHA256, tls.SignatureAlgorithm.RSA_PKCS1_SHA256, tls.SignatureAlgorithm.RSA_PKCS1_SHA1, ], supported_groups=[tls.Group.SECP256R1], supported_versions=[ tls.TLS_VERSION_1_3, tls.TLS_VERSION_1_3_DRAFT_28, tls.TLS_VERSION_1_3_DRAFT_27, tls.TLS_VERSION_1_3_DRAFT_26, ], other_extensions=[ ( tls.ExtensionType.QUIC_TRANSPORT_PARAMETERS_DRAFT, CLIENT_QUIC_TRANSPORT_PARAMETERS, ) ], ) buf = Buffer(1000) push_client_hello(buf, hello) self.assertEqual(buf.data, load("tls_client_hello.bin")) 
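    # (Most of the pull_*/push_* tests in this class share a recurring
    # round-trip pattern: parse a captured handshake message, re-serialize it,
    # and require byte-for-byte equality. A minimal sketch, using the same
    # helpers as the surrounding tests:
    #
    #     buf = Buffer(data=load("tls_server_hello.bin"))
    #     hello = pull_server_hello(buf)   # parse the captured message
    #     out = Buffer(1000)
    #     push_server_hello(out, hello)    # write it back out
    #     assert out.data == load("tls_server_hello.bin")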
def test_pull_server_hello(self): buf = Buffer(data=load("tls_server_hello.bin")) hello = pull_server_hello(buf) self.assertTrue(buf.eof()) self.assertEqual( hello.random, binascii.unhexlify( "ada85271d19680c615ea7336519e3fdf6f1e26f3b1075ee1de96ffa8884e8280" ), ) self.assertEqual( hello.legacy_session_id, binascii.unhexlify( "9aee82a2d186c1cb32a329d9dcfe004a1a438ad0485a53c6bfcf55c132a23235" ), ) self.assertEqual(hello.cipher_suite, tls.CipherSuite.AES_256_GCM_SHA384) self.assertEqual(hello.compression_method, tls.CompressionMethod.NULL) self.assertEqual( hello.key_share, ( tls.Group.SECP256R1, binascii.unhexlify( "048b27d0282242d84b7fcc02a9c4f13eca0329e3c7029aa34a33794e6e7ba189" "5cca1c503bf0378ac6937c354912116ff3251026bca1958d7f387316c83ae6cf" "b2" ), ), ) self.assertEqual(hello.pre_shared_key, None) self.assertEqual(hello.supported_version, tls.TLS_VERSION_1_3) def test_pull_server_hello_with_psk(self): buf = Buffer(data=load("tls_server_hello_with_psk.bin")) hello = pull_server_hello(buf) self.assertTrue(buf.eof()) self.assertEqual( hello.random, binascii.unhexlify( "ccbaaf04fc1bd5143b2cc6b97520cf37d91470dbfc8127131a7bf0f941e3a137" ), ) self.assertEqual( hello.legacy_session_id, binascii.unhexlify( "9483e7e895d0f4cec17086b0849601c0632662cd764e828f2f892f4c4b7771b0" ), ) self.assertEqual(hello.cipher_suite, tls.CipherSuite.AES_256_GCM_SHA384) self.assertEqual(hello.compression_method, tls.CompressionMethod.NULL) self.assertEqual( hello.key_share, ( tls.Group.SECP256R1, binascii.unhexlify( "0485d7cecbebfc548fc657bf51b8e8da842a4056b164a27f7702ca318c16e488" "18b6409593b15c6649d6f459387a53128b164178adc840179aad01d36ce95d62" "76" ), ), ) self.assertEqual(hello.pre_shared_key, 0) self.assertEqual(hello.supported_version, tls.TLS_VERSION_1_3) # serialize buf = Buffer(1000) push_server_hello(buf, hello) self.assertEqual(buf.data, load("tls_server_hello_with_psk.bin")) def test_pull_server_hello_with_unexpected_version(self): buf = Buffer(data=corrupt_hello_version(load("tls_server_hello.bin"))) with self.assertRaises(tls.AlertDecodeError) as cm: pull_server_hello(buf) self.assertEqual(str(cm.exception), "ServerHello version is not 1.2") def test_pull_server_hello_with_unknown_extension(self): buf = Buffer(data=load("tls_server_hello_with_unknown_extension.bin")) hello = pull_server_hello(buf) self.assertTrue(buf.eof()) self.assertEqual( hello, ServerHello( random=binascii.unhexlify( "ada85271d19680c615ea7336519e3fdf6f1e26f3b1075ee1de96ffa8884e8280" ), legacy_session_id=binascii.unhexlify( "9aee82a2d186c1cb32a329d9dcfe004a1a438ad0485a53c6bfcf55c132a23235" ), cipher_suite=tls.CipherSuite.AES_256_GCM_SHA384, compression_method=tls.CompressionMethod.NULL, key_share=( tls.Group.SECP256R1, binascii.unhexlify( "048b27d0282242d84b7fcc02a9c4f13eca0329e3c7029aa34a33794e6e7ba189" "5cca1c503bf0378ac6937c354912116ff3251026bca1958d7f387316c83ae6cf" "b2" ), ), supported_version=tls.TLS_VERSION_1_3, other_extensions=[(12345, b"foo")], ), ) # serialize buf = Buffer(1000) push_server_hello(buf, hello) self.assertEqual(buf.data, load("tls_server_hello_with_unknown_extension.bin")) def test_push_server_hello(self): hello = ServerHello( random=binascii.unhexlify( "ada85271d19680c615ea7336519e3fdf6f1e26f3b1075ee1de96ffa8884e8280" ), legacy_session_id=binascii.unhexlify( "9aee82a2d186c1cb32a329d9dcfe004a1a438ad0485a53c6bfcf55c132a23235" ), cipher_suite=tls.CipherSuite.AES_256_GCM_SHA384, compression_method=tls.CompressionMethod.NULL, key_share=( tls.Group.SECP256R1, binascii.unhexlify( 
"048b27d0282242d84b7fcc02a9c4f13eca0329e3c7029aa34a33794e6e7ba189" "5cca1c503bf0378ac6937c354912116ff3251026bca1958d7f387316c83ae6cf" "b2" ), ), supported_version=tls.TLS_VERSION_1_3, ) buf = Buffer(1000) push_server_hello(buf, hello) self.assertEqual(buf.data, load("tls_server_hello.bin")) def test_pull_new_session_ticket(self): buf = Buffer(data=load("tls_new_session_ticket.bin")) new_session_ticket = pull_new_session_ticket(buf) self.assertIsNotNone(new_session_ticket) self.assertTrue(buf.eof()) self.assertEqual( new_session_ticket, NewSessionTicket( ticket_lifetime=86400, ticket_age_add=3303452425, ticket_nonce=b"", ticket=binascii.unhexlify( "dbe6f1a77a78c0426bfa607cd0d02b350247d90618704709596beda7e962cc81" ), max_early_data_size=0xFFFFFFFF, ), ) # serialize buf = Buffer(100) push_new_session_ticket(buf, new_session_ticket) self.assertEqual(buf.data, load("tls_new_session_ticket.bin")) def test_pull_new_session_ticket_with_unknown_extension(self): buf = Buffer(data=load("tls_new_session_ticket_with_unknown_extension.bin")) new_session_ticket = pull_new_session_ticket(buf) self.assertIsNotNone(new_session_ticket) self.assertTrue(buf.eof()) self.assertEqual( new_session_ticket, NewSessionTicket( ticket_lifetime=86400, ticket_age_add=3303452425, ticket_nonce=b"", ticket=binascii.unhexlify( "dbe6f1a77a78c0426bfa607cd0d02b350247d90618704709596beda7e962cc81" ), max_early_data_size=0xFFFFFFFF, other_extensions=[(12345, b"foo")], ), ) # serialize buf = Buffer(100) push_new_session_ticket(buf, new_session_ticket) self.assertEqual( buf.data, load("tls_new_session_ticket_with_unknown_extension.bin") ) def test_encrypted_extensions(self): data = load("tls_encrypted_extensions.bin") buf = Buffer(data=data) extensions = pull_encrypted_extensions(buf) self.assertIsNotNone(extensions) self.assertTrue(buf.eof()) self.assertEqual( extensions, EncryptedExtensions( other_extensions=[ ( tls.ExtensionType.QUIC_TRANSPORT_PARAMETERS_DRAFT, SERVER_QUIC_TRANSPORT_PARAMETERS, ) ] ), ) # serialize buf = Buffer(capacity=100) push_encrypted_extensions(buf, extensions) self.assertEqual(buf.data, data) def test_encrypted_extensions_with_alpn(self): data = load("tls_encrypted_extensions_with_alpn.bin") buf = Buffer(data=data) extensions = pull_encrypted_extensions(buf) self.assertIsNotNone(extensions) self.assertTrue(buf.eof()) self.assertEqual( extensions, EncryptedExtensions( alpn_protocol="hq-20", other_extensions=[ (tls.ExtensionType.SERVER_NAME, b""), ( tls.ExtensionType.QUIC_TRANSPORT_PARAMETERS_DRAFT, SERVER_QUIC_TRANSPORT_PARAMETERS_2, ), ], ), ) # serialize buf = Buffer(115) push_encrypted_extensions(buf, extensions) self.assertTrue(buf.eof()) def test_pull_encrypted_extensions_with_alpn_and_early_data(self): buf = Buffer(data=load("tls_encrypted_extensions_with_alpn_and_early_data.bin")) extensions = pull_encrypted_extensions(buf) self.assertIsNotNone(extensions) self.assertTrue(buf.eof()) self.assertEqual( extensions, EncryptedExtensions( alpn_protocol="hq-20", early_data=True, other_extensions=[ (tls.ExtensionType.SERVER_NAME, b""), ( tls.ExtensionType.QUIC_TRANSPORT_PARAMETERS_DRAFT, SERVER_QUIC_TRANSPORT_PARAMETERS_3, ), ], ), ) # serialize buf = Buffer(116) push_encrypted_extensions(buf, extensions) self.assertTrue(buf.eof()) def test_pull_certificate(self): buf = Buffer(data=load("tls_certificate.bin")) certificate = pull_certificate(buf) self.assertTrue(buf.eof()) self.assertEqual(certificate.request_context, b"") self.assertEqual(certificate.certificates, [(CERTIFICATE_DATA, b"")]) def 
test_push_certificate(self): certificate = Certificate( request_context=b"", certificates=[(CERTIFICATE_DATA, b"")] ) buf = Buffer(1600) push_certificate(buf, certificate) self.assertEqual(buf.data, load("tls_certificate.bin")) def test_pull_certificate_request(self): buf = Buffer(data=load("tls_certificate_request.bin")) certificate_request = pull_certificate_request(buf) self.assertTrue(buf.eof()) self.assertEqual(certificate_request.request_context, b"") self.assertEqual( certificate_request.signature_algorithms, [ tls.SignatureAlgorithm.RSA_PSS_RSAE_SHA256, tls.SignatureAlgorithm.ECDSA_SECP256R1_SHA256, tls.SignatureAlgorithm.RSA_PKCS1_SHA256, tls.SignatureAlgorithm.RSA_PKCS1_SHA1, ], ) self.assertEqual(certificate_request.other_extensions, [(12345, b"foo")]) def test_push_certificate_request(self): certificate_request = CertificateRequest( request_context=b"", signature_algorithms=[ tls.SignatureAlgorithm.RSA_PSS_RSAE_SHA256, tls.SignatureAlgorithm.ECDSA_SECP256R1_SHA256, tls.SignatureAlgorithm.RSA_PKCS1_SHA256, tls.SignatureAlgorithm.RSA_PKCS1_SHA1, ], other_extensions=[(12345, b"foo")], ) buf = Buffer(400) push_certificate_request(buf, certificate_request) self.assertEqual(buf.data, load("tls_certificate_request.bin")) def test_pull_certificate_verify(self): buf = Buffer(data=load("tls_certificate_verify.bin")) verify = pull_certificate_verify(buf) self.assertTrue(buf.eof()) self.assertEqual(verify.algorithm, tls.SignatureAlgorithm.RSA_PSS_RSAE_SHA256) self.assertEqual(verify.signature, CERTIFICATE_VERIFY_SIGNATURE) def test_push_certificate_verify(self): verify = CertificateVerify( algorithm=tls.SignatureAlgorithm.RSA_PSS_RSAE_SHA256, signature=CERTIFICATE_VERIFY_SIGNATURE, ) buf = Buffer(400) push_certificate_verify(buf, verify) self.assertEqual(buf.data, load("tls_certificate_verify.bin")) def test_pull_finished(self): buf = Buffer(data=load("tls_finished.bin")) finished = pull_finished(buf) self.assertTrue(buf.eof()) self.assertEqual( finished.verify_data, binascii.unhexlify( "f157923234ff9a4921aadb2e0ec7b1a30fce73fb9ec0c4276f9af268f408ec68" ), ) def test_push_finished(self): finished = Finished( verify_data=binascii.unhexlify( "f157923234ff9a4921aadb2e0ec7b1a30fce73fb9ec0c4276f9af268f408ec68" ) ) buf = Buffer(128) push_finished(buf, finished) self.assertEqual(buf.data, load("tls_finished.bin")) def test_pull_server_name(self): buf = Buffer(data=b"\x00\x12\x00\x00\x0fwww.example.com") self.assertEqual(pull_server_name(buf), "www.example.com") def test_pull_server_name_with_bad_name_type(self): buf = Buffer(data=b"\x00\x12\xff\x00\x0fwww.example.com") with self.assertRaises(tls.AlertIllegalParameter) as cm: pull_server_name(buf) self.assertEqual(str(cm.exception), "ServerName has an unknown name type 255") def test_push_server_name(self): buf = Buffer(128) push_server_name(buf, "www.example.com") self.assertEqual(buf.data, b"\x00\x12\x00\x00\x0fwww.example.com") class VerifyCertificateTest(TestCase): def test_verify_certificate_chain(self): with open(SERVER_CERTFILE, "rb") as fp: certificate = load_pem_x509_certificates(fp.read())[0] with patch("aioquic.tls.utcnow") as mock_utcnow: mock_utcnow.return_value = certificate.not_valid_before_utc # fail with self.assertRaises(tls.AlertBadCertificate) as cm: verify_certificate(certificate=certificate, server_name="localhost") self.assertEqual( str(cm.exception), "unable to get local issuer certificate" ) # ok verify_certificate( cafile=SERVER_CACERTFILE, certificate=certificate, server_name="localhost", ) def 
test_verify_certificate_chain_self_signed(self): certificate, _ = generate_ec_certificate( alternative_names=["localhost"], common_name="localhost" ) with patch("aioquic.tls.utcnow") as mock_utcnow: mock_utcnow.return_value = certificate.not_valid_before_utc # fail with self.assertRaises(tls.AlertBadCertificate) as cm: verify_certificate(certificate=certificate, server_name="localhost") self.assertIn( str(cm.exception), ( "self signed certificate", "self-signed certificate", ), ) # ok verify_certificate( cadata=certificate.public_bytes(serialization.Encoding.PEM), certificate=certificate, server_name="localhost", ) def test_verify_dates(self): certificate, _ = generate_ec_certificate( alternative_names=["example.com"], common_name="example.com" ) cadata = certificate.public_bytes(serialization.Encoding.PEM) # too early with patch("aioquic.tls.utcnow") as mock_utcnow: mock_utcnow.return_value = ( certificate.not_valid_before_utc - datetime.timedelta(seconds=1) ) with self.assertRaises(tls.AlertCertificateExpired) as cm: verify_certificate( cadata=cadata, certificate=certificate, server_name="example.com" ) self.assertEqual(str(cm.exception), "Certificate is not valid yet") # valid with patch("aioquic.tls.utcnow") as mock_utcnow: mock_utcnow.return_value = certificate.not_valid_before_utc verify_certificate( cadata=cadata, certificate=certificate, server_name="example.com" ) with patch("aioquic.tls.utcnow") as mock_utcnow: mock_utcnow.return_value = certificate.not_valid_after_utc verify_certificate( cadata=cadata, certificate=certificate, server_name="example.com" ) # too late with patch("aioquic.tls.utcnow") as mock_utcnow: mock_utcnow.return_value = ( certificate.not_valid_after_utc + datetime.timedelta(seconds=1) ) with self.assertRaises(tls.AlertCertificateExpired) as cm: verify_certificate( cadata=cadata, certificate=certificate, server_name="example.com" ) self.assertEqual(str(cm.exception), "Certificate is no longer valid") def test_verify_subject_no_subjaltname(self): certificate, _ = generate_ec_certificate(common_name="example.com") cadata = certificate.public_bytes(serialization.Encoding.PEM) with patch("aioquic.tls.utcnow") as mock_utcnow: mock_utcnow.return_value = certificate.not_valid_before_utc # certificates with no SubjectAltName are rejected with self.assertRaises(tls.AlertBadCertificate) as cm: verify_certificate( cadata=cadata, certificate=certificate, server_name="example.com" ) self.assertEqual( str(cm.exception), "Certificate does not contain any `subjectAltName`s."
) def test_verify_subject_with_subjaltname(self): certificate, _ = generate_ec_certificate( alternative_names=["*.example.com", "example.com"], common_name="example.com", ) cadata = certificate.public_bytes(serialization.Encoding.PEM) with patch("aioquic.tls.utcnow") as mock_utcnow: mock_utcnow.return_value = certificate.not_valid_before_utc # valid verify_certificate( cadata=cadata, certificate=certificate, server_name="example.com" ) verify_certificate( cadata=cadata, certificate=certificate, server_name="test.example.com" ) # invalid with self.assertRaises(tls.AlertBadCertificate) as cm: verify_certificate( cadata=cadata, certificate=certificate, server_name="acme.com" ) self.assertEqual( str(cm.exception), "hostname 'acme.com' doesn't match either of " "DNSPattern(pattern=b'*.example.com'), " "DNSPattern(pattern=b'example.com')", ) def test_verify_subject_with_subjaltname_ipaddress(self): certificate, _ = generate_ec_certificate( alternative_names=["1.2.3.4"], common_name="1.2.3.4", ) cadata = certificate.public_bytes(serialization.Encoding.PEM) with patch("aioquic.tls.utcnow") as mock_utcnow: mock_utcnow.return_value = certificate.not_valid_before_utc # valid verify_certificate( cadata=cadata, certificate=certificate, server_name="1.2.3.4" ) # invalid with self.assertRaises(tls.AlertBadCertificate) as cm: verify_certificate( cadata=cadata, certificate=certificate, server_name="8.8.8.8" ) self.assertEqual( str(cm.exception), "hostname '8.8.8.8' doesn't match " "IPAddressPattern(pattern=IPv4Address('1.2.3.4'))", ) def test_pull_greased_alpn_list(self): """Test pulling a list of ALPNs with an ASCII item, an undecodable binary value such as greasing might give us, a valid UTF-8 encoding, and another ASCII item. We should only return the ASCII values. We currently only accept ASCII ALPNs, even though technically ALPNs are arbitrary byte values, as our API is a list of strings.
""" # the buffer is equivalent to "H2", b'\xff\xff', "é" in UTF-8, "H3" buf = Buffer(data=binascii.unhexlify("000c02483202ffff02c3a9024833")) self.assertEqual( tls.pull_list(buf, 2, partial(tls.pull_alpn_protocol, buf)), ["H2", "H3"] ) ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1720306884.0 aioquic-1.2.0/tests/test_webtransport.py0000644000175100001770000002474300000000000021271 0ustar00runnerdocker00000000000000from unittest import TestCase from aioquic.h3.connection import H3_ALPN, ErrorCode, H3Connection from aioquic.h3.events import ( DatagramReceived, HeadersReceived, WebTransportStreamDataReceived, ) from aioquic.h3.exceptions import InvalidStreamTypeError from aioquic.quic.configuration import QuicConfiguration from aioquic.quic.events import DatagramFrameReceived from .test_h3 import ( FakeQuicConnection, h3_client_and_server, h3_fake_client_and_server, h3_transfer, ) QUIC_CONFIGURATION_OPTIONS = { "alpn_protocols": H3_ALPN, "max_datagram_frame_size": 65536, } class WebTransportTest(TestCase): def _make_session(self, h3_client, h3_server): quic_client = h3_client._quic quic_server = h3_server._quic # send request stream_id = quic_client.get_next_available_stream_id() h3_client.send_headers( stream_id=stream_id, headers=[ (b":method", b"CONNECT"), (b":scheme", b"https"), (b":authority", b"localhost"), (b":path", b"/"), (b":protocol", b"webtransport"), ], ) # receive request events = h3_transfer(quic_client, h3_server) self.assertEqual( events, [ HeadersReceived( headers=[ (b":method", b"CONNECT"), (b":scheme", b"https"), (b":authority", b"localhost"), (b":path", b"/"), (b":protocol", b"webtransport"), ], stream_id=stream_id, stream_ended=False, push_id=None, ) ], ) # send response h3_server.send_headers( stream_id=stream_id, headers=[ (b":status", b"200"), ], ) # receive response events = h3_transfer(quic_server, h3_client) self.assertEqual( events, [ HeadersReceived( headers=[ (b":status", b"200"), ], stream_id=stream_id, stream_ended=False, ), ], ) return stream_id def test_bidirectional_stream(self): with h3_client_and_server(QUIC_CONFIGURATION_OPTIONS) as ( quic_client, quic_server, ): h3_client = H3Connection(quic_client, enable_webtransport=True) h3_server = H3Connection(quic_server, enable_webtransport=True) # create session session_id = self._make_session(h3_client, h3_server) # send data on bidirectional stream stream_id = h3_client.create_webtransport_stream(session_id) quic_client.send_stream_data(stream_id, b"foo", end_stream=True) # receive data events = h3_transfer(quic_client, h3_server) self.assertEqual( events, [ WebTransportStreamDataReceived( data=b"foo", session_id=session_id, stream_ended=True, stream_id=stream_id, ) ], ) def test_bidirectional_stream_fragmented_frame(self): with h3_fake_client_and_server(QUIC_CONFIGURATION_OPTIONS) as ( quic_client, quic_server, ): h3_client = H3Connection(quic_client, enable_webtransport=True) h3_server = H3Connection(quic_server, enable_webtransport=True) # create session session_id = self._make_session(h3_client, h3_server) # send data on bidirectional stream stream_id = h3_client.create_webtransport_stream(session_id) quic_client.send_stream_data(stream_id, b"foo", end_stream=True) # receive data events = h3_transfer(quic_client, h3_server) self.assertEqual( events, [ WebTransportStreamDataReceived( data=b"f", session_id=session_id, stream_ended=False, stream_id=stream_id, ), WebTransportStreamDataReceived( data=b"o", session_id=session_id, stream_ended=False, 
stream_id=stream_id, ), WebTransportStreamDataReceived( data=b"o", session_id=session_id, stream_ended=False, stream_id=stream_id, ), WebTransportStreamDataReceived( data=b"", session_id=session_id, stream_ended=True, stream_id=stream_id, ), ], ) def test_bidirectional_stream_server_initiated(self): with h3_client_and_server(QUIC_CONFIGURATION_OPTIONS) as ( quic_client, quic_server, ): h3_client = H3Connection(quic_client, enable_webtransport=True) h3_server = H3Connection(quic_server, enable_webtransport=True) # create session session_id = self._make_session(h3_client, h3_server) # send data on bidirectional stream stream_id = h3_server.create_webtransport_stream(session_id) quic_server.send_stream_data(stream_id, b"foo", end_stream=True) # receive data events = h3_transfer(quic_server, h3_client) self.assertEqual( events, [ WebTransportStreamDataReceived( data=b"foo", session_id=session_id, stream_ended=True, stream_id=stream_id, ) ], ) def test_unidirectional_stream(self): with h3_client_and_server(QUIC_CONFIGURATION_OPTIONS) as ( quic_client, quic_server, ): h3_client = H3Connection(quic_client, enable_webtransport=True) h3_server = H3Connection(quic_server, enable_webtransport=True) # create session session_id = self._make_session(h3_client, h3_server) # send data on unidirectional stream stream_id = h3_client.create_webtransport_stream( session_id, is_unidirectional=True ) quic_client.send_stream_data(stream_id, b"foo", end_stream=True) # receive data events = h3_transfer(quic_client, h3_server) self.assertEqual( events, [ WebTransportStreamDataReceived( data=b"foo", session_id=session_id, stream_ended=True, stream_id=stream_id, ) ], ) def test_unidirectional_stream_fragmented_frame(self): with h3_fake_client_and_server(QUIC_CONFIGURATION_OPTIONS) as ( quic_client, quic_server, ): h3_client = H3Connection(quic_client, enable_webtransport=True) h3_server = H3Connection(quic_server, enable_webtransport=True) # create session session_id = self._make_session(h3_client, h3_server) # send data on unidirectional stream stream_id = h3_client.create_webtransport_stream( session_id, is_unidirectional=True ) quic_client.send_stream_data(stream_id, b"foo", end_stream=True) # receive data events = h3_transfer(quic_client, h3_server) self.assertEqual( events, [ WebTransportStreamDataReceived( data=b"f", session_id=session_id, stream_ended=False, stream_id=stream_id, ), WebTransportStreamDataReceived( data=b"o", session_id=session_id, stream_ended=False, stream_id=stream_id, ), WebTransportStreamDataReceived( data=b"o", session_id=session_id, stream_ended=False, stream_id=stream_id, ), WebTransportStreamDataReceived( data=b"", session_id=session_id, stream_ended=True, stream_id=stream_id, ), ], ) def test_datagram(self): with h3_client_and_server(QUIC_CONFIGURATION_OPTIONS) as ( quic_client, quic_server, ): h3_client = H3Connection(quic_client, enable_webtransport=True) h3_server = H3Connection(quic_server, enable_webtransport=True) # create session session_id = self._make_session(h3_client, h3_server) # send datagram on a server-initiated stream with self.assertRaises(InvalidStreamTypeError): h3_client.send_datagram(data=b"foo", stream_id=1) # send datagram h3_client.send_datagram(data=b"foo", stream_id=session_id) # receive datagram events = h3_transfer(quic_client, h3_server) self.assertEqual( events, [DatagramReceived(data=b"foo", stream_id=session_id)], ) def test_handle_datagram_truncated(self): quic_server = FakeQuicConnection( configuration=QuicConfiguration(is_client=False) ) h3_server = 
H3Connection(quic_server) # receive a datagram with a truncated session ID h3_server.handle_event(DatagramFrameReceived(data=b"\xff")) self.assertEqual( quic_server.closed, (ErrorCode.H3_DATAGRAM_ERROR, "Could not parse quarter stream ID"), )
aioquic-1.2.0/tests/tls_certificate.bin [binary test fixture; contents omitted]
aioquic-1.2.0/tests/tls_certificate_request.bin [binary test fixture; contents omitted]
aioquic-1.2.0/tests/tls_certificate_verify.bin [binary test fixture; contents omitted]
aioquic-1.2.0/tests/tls_client_hello.bin [binary test fixture; contents omitted]
aioquic-1.2.0/tests/tls_client_hello_with_alpn.bin [binary test fixture; contents omitted]
aioquic-1.2.0/tests/tls_client_hello_with_psk.bin [binary test fixture; contents omitted]
aioquic-1.2.0/tests/tls_client_hello_with_sni.bin [binary test fixture; contents omitted]
aioquic-1.2.0/tests/tls_encrypted_extensions.bin [binary test fixture; contents omitted]
aioquic-1.2.0/tests/tls_encrypted_extensions_with_alpn.bin [binary test fixture; contents omitted]
aioquic-1.2.0/tests/tls_encrypted_extensions_with_alpn_and_early_data.bin [binary test fixture; contents omitted]
aioquic-1.2.0/tests/tls_finished.bin [binary test fixture; contents omitted]
aioquic-1.2.0/tests/tls_new_session_ticket.bin [binary test fixture; contents omitted]
aioquic-1.2.0/tests/tls_new_session_ticket_with_unknown_extension.bin [binary test fixture; contents omitted]
aioquic-1.2.0/tests/tls_server_hello.bin [binary test fixture; contents omitted]
aioquic-1.2.0/tests/tls_server_hello_with_psk.bin [binary test fixture; contents omitted]
aioquic-1.2.0/tests/tls_server_hello_with_unknown_extension.bin [binary test fixture; contents omitted]
././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1720306884.0 aioquic-1.2.0/tests/utils.py0000644000175100001770000000644600000000000016634 0ustar00runnerdocker00000000000000import asyncio import datetime import functools import ipaddress import logging import os from cryptography import x509 from cryptography.hazmat.primitives import hashes from cryptography.hazmat.primitives.asymmetric import ec, ed448, ed25519, rsa def asynctest(coro): @functools.wraps(coro) def wrap(*args, **kwargs): asyncio.run(coro(*args, **kwargs)) return wrap def dns_name_or_ip_address(name): try: ip = ipaddress.ip_address(name) except ValueError: return x509.DNSName(name) else: return x509.IPAddress(ip) def generate_certificate(*, alternative_names, common_name, hash_algorithm, key): subject = issuer = x509.Name( [x509.NameAttribute(x509.NameOID.COMMON_NAME, common_name)] ) builder = ( x509.CertificateBuilder() .subject_name(subject) .issuer_name(issuer) .public_key(key.public_key()) .serial_number(x509.random_serial_number()) .not_valid_before(datetime.datetime.now(datetime.timezone.utc)) .not_valid_after( datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(days=10) ) ) if alternative_names: builder = builder.add_extension(
x509.SubjectAlternativeName( [dns_name_or_ip_address(name) for name in alternative_names] ), critical=False, ) cert = builder.sign(key, hash_algorithm) return cert, key def generate_ec_certificate(common_name, alternative_names=[], curve=ec.SECP256R1): key = ec.generate_private_key(curve=curve()) return generate_certificate( alternative_names=alternative_names, common_name=common_name, hash_algorithm=hashes.SHA256(), key=key, ) def generate_ed25519_certificate(common_name, alternative_names=[]): key = ed25519.Ed25519PrivateKey.generate() return generate_certificate( alternative_names=alternative_names, common_name=common_name, hash_algorithm=None, key=key, ) def generate_ed448_certificate(common_name, alternative_names=[]): key = ed448.Ed448PrivateKey.generate() return generate_certificate( alternative_names=alternative_names, common_name=common_name, hash_algorithm=None, key=key, ) def generate_rsa_certificate(common_name, alternative_names=[]): key = rsa.generate_private_key(public_exponent=65537, key_size=2048) return generate_certificate( alternative_names=alternative_names, common_name=common_name, hash_algorithm=hashes.SHA256(), key=key, ) def load(name: str) -> bytes: path = os.path.join(os.path.dirname(__file__), name) with open(path, "rb") as fp: return fp.read() SERVER_CACERTFILE = os.path.join(os.path.dirname(__file__), "pycacert.pem") SERVER_CERTFILE = os.path.join(os.path.dirname(__file__), "ssl_cert.pem") SERVER_CERTFILE_WITH_CHAIN = os.path.join( os.path.dirname(__file__), "ssl_cert_with_chain.pem" ) SERVER_KEYFILE = os.path.join(os.path.dirname(__file__), "ssl_key.pem") SERVER_COMBINEDFILE = os.path.join(os.path.dirname(__file__), "ssl_combined.pem") SKIP_TESTS = frozenset(os.environ.get("AIOQUIC_SKIP_TESTS", "").split(",")) if os.environ.get("AIOQUIC_DEBUG"): logging.basicConfig(level=logging.DEBUG)
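# A sketch of how the helpers above are typically combined in the tests (for
# illustration only; `serialization` would additionally need to be imported
# from cryptography.hazmat.primitives):
#
#     cert, key = generate_ec_certificate(
#         common_name="localhost", alternative_names=["localhost"]
#     )
#     cadata = cert.public_bytes(serialization.Encoding.PEM)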