`_
.. raw:: html
websockets for enterprise
Available as part of the Tidelift Subscription
The maintainers of websockets and thousands of other packages are working with Tidelift to deliver commercial support and maintenance for the open source dependencies you use to build your applications. Save time, reduce risk, and improve code health, while paying the maintainers of the exact dependencies you use. Learn more.
(If you contribute to websockets
and would like to become an official support provider, let me know.)
Why should I use ``websockets``?
--------------------------------
The development of ``websockets`` is shaped by four principles:

1. **Simplicity**: all you need to understand is ``msg = await ws.recv()`` and
   ``await ws.send(msg)``; ``websockets`` takes care of managing connections
   so you can focus on your application.

2. **Robustness**: ``websockets`` is built for production; for example, it was
   the only library to `handle backpressure correctly`_ before the issue
   became widely known in the Python community.

3. **Quality**: ``websockets`` is heavily tested. Continuous integration fails
   if branch coverage drops below 100%. It also passes the industry-standard
   `Autobahn Testsuite`_.

4. **Performance**: memory use is configurable. An extension written in C
   accelerates expensive operations. It's pre-compiled for Linux, macOS and
   Windows and packaged in the wheel format for each system and Python version.
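
The two calls named under **Simplicity** can be sketched without a network. This is only an illustration: ``FakeConnection`` and ``echo_once`` are hypothetical names that are not part of ``websockets``; the fake object merely mimics the ``recv()`` and ``send()`` coroutines of a real connection.

```python
import asyncio


class FakeConnection:
    """Hypothetical in-memory stand-in for a connection; not part of websockets."""

    def __init__(self, incoming):
        self._incoming = list(incoming)
        self.sent = []

    async def recv(self):
        # Hand back the next queued message, like ``await ws.recv()``.
        return self._incoming.pop(0)

    async def send(self, msg):
        # Record the outgoing message, like ``await ws.send(msg)``.
        self.sent.append(msg)


async def echo_once(ws):
    """Receive one message and send it straight back."""
    msg = await ws.recv()
    await ws.send(msg)
    return msg
```

Any object exposing the same two coroutines, including a real connection, could be passed to ``echo_once``.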

Documentation is a first-class concern in the project. Head over to `Read the
Docs`_ and see for yourself.

.. _Read the Docs: https://websockets.readthedocs.io/
.. _handle backpressure correctly: https://vorpus.org/blog/some-thoughts-on-asynchronous-api-design-in-a-post-asyncawait-world/#websocket-servers
.. _Autobahn Testsuite: https://github.com/aaugustin/websockets/blob/master/compliance/README.rst

Why shouldn't I use ``websockets``?
-----------------------------------
* If you prefer callbacks over coroutines: ``websockets`` was created to
  provide the best coroutine-based API to manage WebSocket connections in
  Python. Pick another library for a callback-based API.
* If you're looking for a mixed HTTP / WebSocket library: ``websockets`` aims
  at being an excellent implementation of :rfc:`6455`: The WebSocket Protocol
  and :rfc:`7692`: Compression Extensions for WebSocket. Its support for HTTP
  is minimal: just enough for an HTTP health check.
* If you want to use Python 2: ``websockets`` builds upon ``asyncio`` which
  only works on Python 3. ``websockets`` requires Python ≥ 3.6.1.
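
The minimal HTTP support mentioned above could be used along these lines. This is a sketch under assumptions, not a documented recipe: the ``/healthz`` path and the ``health_check`` name are hypothetical, and the coroutine is assumed to be passed as the ``process_request`` argument of ``websockets.serve``, returning ``None`` to continue the WebSocket handshake or a ``(status, headers, body)`` tuple to answer over plain HTTP.

```python
import asyncio
import http


async def health_check(path, request_headers):
    """Answer the health check path with a plain 200; let other paths proceed."""
    if path == "/healthz":  # hypothetical endpoint name
        # (status, headers, body) tuple: an HTTP response is sent instead
        # of completing the WebSocket handshake.
        return http.HTTPStatus.OK, [], b"OK\n"
    # None: carry on with the normal WebSocket handshake.
    return None
```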
What else?
----------
Bug reports, patches and suggestions are welcome!

To report a security vulnerability, please use the `Tidelift security
contact`_. Tidelift will coordinate the fix and disclosure.

.. _Tidelift security contact: https://tidelift.com/security

For anything else, please open an issue_ or send a `pull request`_.

.. _issue: https://github.com/aaugustin/websockets/issues/new
.. _pull request: https://github.com/aaugustin/websockets/compare/

Participants must uphold the `Contributor Covenant code of conduct`_.

.. _Contributor Covenant code of conduct: https://github.com/aaugustin/websockets/blob/master/CODE_OF_CONDUCT.md

``websockets`` is released under the `BSD license`_.

.. _BSD license: https://github.com/aaugustin/websockets/blob/master/LICENSE
# websockets-8.1/setup.cfg
[bdist_wheel]
python-tag = py36.py37
[metadata]
license_file = LICENSE
[flake8]
ignore = E731,F403,F405,W503
max-line-length = 88
[isort]
combine_as_imports = True
force_grid_wrap = 0
include_trailing_comma = True
known_standard_library = asyncio
line_length = 88
lines_after_imports = 2
multi_line_output = 3
[coverage:run]
branch = True
omit = */__main__.py
source =
    websockets
    tests
[coverage:paths]
source =
    src/websockets
    .tox/*/lib/python*/site-packages/websockets
[egg_info]
tag_build =
tag_date = 0
# websockets-8.1/setup.py
import pathlib
import re
import sys
import setuptools
root_dir = pathlib.Path(__file__).parent
description = "An implementation of the WebSocket Protocol (RFC 6455 & 7692)"
long_description = (root_dir / 'README.rst').read_text(encoding='utf-8')
# PyPI disables the "raw" directive.
long_description = re.sub(
r"^\.\. raw:: html.*?^(?=\w)",
"",
long_description,
flags=re.DOTALL | re.MULTILINE,
)
exec((root_dir / 'src' / 'websockets' / 'version.py').read_text(encoding='utf-8'))
if sys.version_info[:3] < (3, 6, 1):
    raise Exception("websockets requires Python >= 3.6.1.")
packages = ['websockets', 'websockets.extensions']
ext_modules = [
setuptools.Extension(
'websockets.speedups',
sources=['src/websockets/speedups.c'],
optional=not (root_dir / '.cibuildwheel').exists(),
)
]
setuptools.setup(
name='websockets',
version=version,
description=description,
long_description=long_description,
url='https://github.com/aaugustin/websockets',
author='Aymeric Augustin',
author_email='aymeric.augustin@m4x.org',
license='BSD',
classifiers=[
'Development Status :: 5 - Production/Stable',
'Environment :: Web Environment',
'Intended Audience :: Developers',
'License :: OSI Approved :: BSD License',
'Operating System :: OS Independent',
'Programming Language :: Python',
'Programming Language :: Python :: 3',
'Programming Language :: Python :: 3.6',
'Programming Language :: Python :: 3.7',
'Programming Language :: Python :: 3.8',
],
package_dir={'': 'src'},
package_data={'websockets': ['py.typed']},
packages=packages,
ext_modules=ext_modules,
include_package_data=True,
zip_safe=False,
python_requires='>=3.6.1',
test_loader='unittest:TestLoader',
)
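
The effect of the ``re.sub`` call in ``setup.py`` can be checked on a small input. The sample text below is made up for illustration and is not the real README; only the pattern and flags are taken from the script above.

```python
import re

# Made-up README fragment containing a "raw" directive block.
sample = (
    "websockets\n"
    "==========\n"
    "\n"
    ".. raw:: html\n"
    "\n"
    "    <hr>\n"
    "    <p>Sponsored block</p>\n"
    "\n"
    "What is websockets?\n"
    "-------------------\n"
)

# Same pattern as in setup.py: drop everything from the directive up to
# the next line that starts with a word character.
cleaned = re.sub(
    r"^\.\. raw:: html.*?^(?=\w)",
    "",
    sample,
    flags=re.DOTALL | re.MULTILINE,
)
```

The indented lines and blank lines of the directive are removed because none of them starts with a word character; the match stops just before the next unindented line.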
# websockets-8.1/src/websockets/__init__.py
# This relies on each of the submodules having an __all__ variable.
from .auth import * # noqa
from .client import * # noqa
from .exceptions import * # noqa
from .protocol import * # noqa
from .server import * # noqa
from .typing import * # noqa
from .uri import * # noqa
from .version import version as __version__ # noqa
__all__ = [
"AbortHandshake",
"basic_auth_protocol_factory",
"BasicAuthWebSocketServerProtocol",
"connect",
"ConnectionClosed",
"ConnectionClosedError",
"ConnectionClosedOK",
"Data",
"DuplicateParameter",
"ExtensionHeader",
"ExtensionParameter",
"InvalidHandshake",
"InvalidHeader",
"InvalidHeaderFormat",
"InvalidHeaderValue",
"InvalidMessage",
"InvalidOrigin",
"InvalidParameterName",
"InvalidParameterValue",
"InvalidState",
"InvalidStatusCode",
"InvalidUpgrade",
"InvalidURI",
"NegotiationError",
"Origin",
"parse_uri",
"PayloadTooBig",
"ProtocolError",
"RedirectHandshake",
"SecurityError",
"serve",
"Subprotocol",
"unix_connect",
"unix_serve",
"WebSocketClientProtocol",
"WebSocketCommonProtocol",
"WebSocketException",
"WebSocketProtocolError",
"WebSocketServer",
"WebSocketServerProtocol",
"WebSocketURI",
]
# websockets-8.1/src/websockets/__main__.py
import argparse
import asyncio
import os
import signal
import sys
import threading
from typing import Any, Set
from .client import connect
from .exceptions import ConnectionClosed, format_close
if sys.platform == "win32":

    def win_enable_vt100() -> None:
        """
        Enable VT-100 for console output on Windows.

        See also https://bugs.python.org/issue29059.

        """
        import ctypes

        STD_OUTPUT_HANDLE = ctypes.c_uint(-11)
        INVALID_HANDLE_VALUE = ctypes.c_uint(-1)
        ENABLE_VIRTUAL_TERMINAL_PROCESSING = 0x004

        handle = ctypes.windll.kernel32.GetStdHandle(STD_OUTPUT_HANDLE)
        if handle == INVALID_HANDLE_VALUE:
            raise RuntimeError("unable to obtain stdout handle")

        cur_mode = ctypes.c_uint()
        if ctypes.windll.kernel32.GetConsoleMode(handle, ctypes.byref(cur_mode)) == 0:
            raise RuntimeError("unable to query current console mode")

        # ctypes ints lack support for the required bit-OR operation.
        # Temporarily convert to Py int, do the OR and convert back.
        py_int_mode = int.from_bytes(cur_mode, sys.byteorder)
        new_mode = ctypes.c_uint(py_int_mode | ENABLE_VIRTUAL_TERMINAL_PROCESSING)

        if ctypes.windll.kernel32.SetConsoleMode(handle, new_mode) == 0:
            raise RuntimeError("unable to set console mode")


def exit_from_event_loop_thread(
    loop: asyncio.AbstractEventLoop, stop: "asyncio.Future[None]"
) -> None:
    loop.stop()
    if not stop.done():
        # When exiting the thread that runs the event loop, raise
        # KeyboardInterrupt in the main thread to exit the program.
        try:
            ctrl_c = signal.CTRL_C_EVENT  # Windows
        except AttributeError:
            ctrl_c = signal.SIGINT  # POSIX
        os.kill(os.getpid(), ctrl_c)


def print_during_input(string: str) -> None:
    sys.stdout.write(
        # Save cursor position
        "\N{ESC}7"
        # Add a new line
        "\N{LINE FEED}"
        # Move cursor up
        "\N{ESC}[A"
        # Insert blank line, scroll last line down
        "\N{ESC}[L"
        # Print string in the inserted blank line
        f"{string}\N{LINE FEED}"
        # Restore cursor position
        "\N{ESC}8"
        # Move cursor down
        "\N{ESC}[B"
    )
    sys.stdout.flush()


def print_over_input(string: str) -> None:
    sys.stdout.write(
        # Move cursor to beginning of line
        "\N{CARRIAGE RETURN}"
        # Delete current line
        "\N{ESC}[K"
        # Print string
        f"{string}\N{LINE FEED}"
    )
    sys.stdout.flush()


async def run_client(
    uri: str,
    loop: asyncio.AbstractEventLoop,
    inputs: "asyncio.Queue[str]",
    stop: "asyncio.Future[None]",
) -> None:
    try:
        websocket = await connect(uri)
    except Exception as exc:
        print_over_input(f"Failed to connect to {uri}: {exc}.")
        exit_from_event_loop_thread(loop, stop)
        return
    else:
        print_during_input(f"Connected to {uri}.")

    try:
        while True:
            incoming: asyncio.Future[Any] = asyncio.ensure_future(websocket.recv())
            outgoing: asyncio.Future[Any] = asyncio.ensure_future(inputs.get())
            done: Set[asyncio.Future[Any]]
            pending: Set[asyncio.Future[Any]]
            done, pending = await asyncio.wait(
                [incoming, outgoing, stop], return_when=asyncio.FIRST_COMPLETED
            )

            # Cancel pending tasks to avoid leaking them.
            if incoming in pending:
                incoming.cancel()
            if outgoing in pending:
                outgoing.cancel()

            if incoming in done:
                try:
                    message = incoming.result()
                except ConnectionClosed:
                    break
                else:
                    if isinstance(message, str):
                        print_during_input("< " + message)
                    else:
                        print_during_input("< (binary) " + message.hex())

            if outgoing in done:
                message = outgoing.result()
                await websocket.send(message)

            if stop in done:
                break

    finally:
        await websocket.close()
        close_status = format_close(websocket.close_code, websocket.close_reason)

        print_over_input(f"Connection closed: {close_status}.")

        exit_from_event_loop_thread(loop, stop)


def main() -> None:
    # If we're on Windows, enable VT100 terminal support.
    if sys.platform == "win32":
        try:
            win_enable_vt100()
        except RuntimeError as exc:
            sys.stderr.write(
                f"Unable to set terminal to VT100 mode. This is only "
                f"supported since Win10 anniversary update. Expect "
                f"weird symbols on the terminal.\nError: {exc}\n"
            )
            sys.stderr.flush()

    try:
        import readline  # noqa
    except ImportError:  # Windows has no `readline` normally
        pass

    # Parse command line arguments.
    parser = argparse.ArgumentParser(
        prog="python -m websockets",
        description="Interactive WebSocket client.",
        add_help=False,
    )
    parser.add_argument("uri", metavar="")
    args = parser.parse_args()

    # Create an event loop that will run in a background thread.
    loop = asyncio.new_event_loop()

    # Create a queue of user inputs. There's no need to limit its size.
    inputs: asyncio.Queue[str] = asyncio.Queue(loop=loop)

    # Create a stop condition when receiving SIGINT or SIGTERM.
    stop: asyncio.Future[None] = loop.create_future()

    # Schedule the task that will manage the connection.
    asyncio.ensure_future(run_client(args.uri, loop, inputs, stop), loop=loop)

    # Start the event loop in a background thread.
    thread = threading.Thread(target=loop.run_forever)
    thread.start()

    # Read from stdin in the main thread in order to receive signals.
    try:
        while True:
            # Since there's no size limit, put_nowait is identical to put.
            message = input("> ")
            loop.call_soon_threadsafe(inputs.put_nowait, message)
    except (KeyboardInterrupt, EOFError):  # ^C, ^D
        loop.call_soon_threadsafe(stop.set_result, None)

    # Wait for the event loop to terminate.
    thread.join()


if __name__ == "__main__":
    main()
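
The threading arrangement used by ``main()`` (event loop in a background thread, main thread handing it work through ``call_soon_threadsafe``) can be reduced to the following sketch. ``run_in_background`` and ``on_message`` are hypothetical names introduced for illustration; the real client schedules ``inputs.put_nowait`` and ``stop.set_result`` rather than a list append.

```python
import asyncio
import threading


def run_in_background(messages):
    """Hand ``messages`` from the calling thread to a loop running elsewhere."""
    loop = asyncio.new_event_loop()
    received = []

    def on_message(message):
        # Runs inside the event loop thread.
        received.append(message)

    # Start the event loop in a background thread.
    thread = threading.Thread(target=loop.run_forever)
    thread.start()

    # The calling thread must not touch the loop directly; it schedules
    # callbacks with call_soon_threadsafe, which run in FIFO order.
    for message in messages:
        loop.call_soon_threadsafe(on_message, message)

    # Ask the loop to stop after the callbacks queued so far, then wait.
    loop.call_soon_threadsafe(loop.stop)
    thread.join()
    loop.close()
    return received
```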
# websockets-8.1/src/websockets/auth.py
"""
:mod:`websockets.auth` provides HTTP Basic Authentication according to
:rfc:`7235` and :rfc:`7617`.
"""
import functools
import http
from typing import Any, Awaitable, Callable, Iterable, Optional, Tuple, Type, Union
from .exceptions import InvalidHeader
from .headers import build_www_authenticate_basic, parse_authorization_basic
from .http import Headers
from .server import HTTPResponse, WebSocketServerProtocol
__all__ = ["BasicAuthWebSocketServerProtocol", "basic_auth_protocol_factory"]
Credentials = Tuple[str, str]
def is_credentials(value: Any) -> bool:
    try:
        username, password = value
    except (TypeError, ValueError):
        return False
    else:
        return isinstance(username, str) and isinstance(password, str)


class BasicAuthWebSocketServerProtocol(WebSocketServerProtocol):
    """
    WebSocket server protocol that enforces HTTP Basic Auth.

    """

    def __init__(
        self,
        *args: Any,
        realm: str,
        check_credentials: Callable[[str, str], Awaitable[bool]],
        **kwargs: Any,
    ) -> None:
        self.realm = realm
        self.check_credentials = check_credentials
        super().__init__(*args, **kwargs)

    async def process_request(
        self, path: str, request_headers: Headers
    ) -> Optional[HTTPResponse]:
        """
        Check HTTP Basic Auth and return an HTTP 401 response if needed.

        If authentication succeeds, the username of the authenticated user is
        stored in the ``username`` attribute.

        """
        try:
            authorization = request_headers["Authorization"]
        except KeyError:
            return (
                http.HTTPStatus.UNAUTHORIZED,
                [("WWW-Authenticate", build_www_authenticate_basic(self.realm))],
                b"Missing credentials\n",
            )

        try:
            username, password = parse_authorization_basic(authorization)
        except InvalidHeader:
            return (
                http.HTTPStatus.UNAUTHORIZED,
                [("WWW-Authenticate", build_www_authenticate_basic(self.realm))],
                b"Unsupported credentials\n",
            )

        if not await self.check_credentials(username, password):
            return (
                http.HTTPStatus.UNAUTHORIZED,
                [("WWW-Authenticate", build_www_authenticate_basic(self.realm))],
                b"Invalid credentials\n",
            )

        self.username = username

        return await super().process_request(path, request_headers)


def basic_auth_protocol_factory(
    realm: str,
    credentials: Optional[Union[Credentials, Iterable[Credentials]]] = None,
    check_credentials: Optional[Callable[[str, str], Awaitable[bool]]] = None,
    create_protocol: Type[
        BasicAuthWebSocketServerProtocol
    ] = BasicAuthWebSocketServerProtocol,
) -> Callable[[Any], BasicAuthWebSocketServerProtocol]:
    """
    Protocol factory that enforces HTTP Basic Auth.

    ``basic_auth_protocol_factory`` is designed to integrate with
    :func:`~websockets.server.serve` like this::

        websockets.serve(
            ...,
            create_protocol=websockets.basic_auth_protocol_factory(
                realm="my dev server",
                credentials=("hello", "iloveyou"),
            )
        )

    ``realm`` indicates the scope of protection. It should contain only ASCII
    characters because the encoding of non-ASCII characters is undefined.
    Refer to section 2.2 of :rfc:`7235` for details.

    ``credentials`` defines hard coded authorized credentials. It can be a
    ``(username, password)`` pair or a list of such pairs.

    ``check_credentials`` defines a coroutine that checks whether credentials
    are authorized. This coroutine receives ``username`` and ``password``
    arguments and returns a :class:`bool`.

    One of ``credentials`` or ``check_credentials`` must be provided but not
    both.

    By default, ``basic_auth_protocol_factory`` creates a factory for building
    :class:`BasicAuthWebSocketServerProtocol` instances. You can override this
    with the ``create_protocol`` parameter.

    :param realm: scope of protection
    :param credentials: hard coded credentials
    :param check_credentials: coroutine that verifies credentials
    :raises TypeError: if the credentials argument has the wrong type

    """
    if (credentials is None) == (check_credentials is None):
        raise TypeError("provide either credentials or check_credentials")

    if credentials is not None:
        if is_credentials(credentials):

            async def check_credentials(username: str, password: str) -> bool:
                return (username, password) == credentials

        elif isinstance(credentials, Iterable):
            credentials_list = list(credentials)

            if all(is_credentials(item) for item in credentials_list):
                credentials_dict = dict(credentials_list)

                async def check_credentials(username: str, password: str) -> bool:
                    return credentials_dict.get(username) == password

            else:
                raise TypeError(f"invalid credentials argument: {credentials}")

        else:
            raise TypeError(f"invalid credentials argument: {credentials}")

    return functools.partial(
        create_protocol, realm=realm, check_credentials=check_credentials
    )
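
A custom ``check_credentials`` coroutine, as accepted by ``basic_auth_protocol_factory`` above, might look like the sketch below. The ``USERS`` dictionary is a hypothetical in-memory store, and the constant-time comparison with ``hmac.compare_digest`` is an addition for illustration; the module itself compares with ``==``.

```python
import asyncio
import hmac

# Hypothetical in-memory user store, for illustration only.
USERS = {"hello": "iloveyou"}


async def check_credentials(username, password):
    """Return True if the (username, password) pair is authorized."""
    expected = USERS.get(username)
    if expected is None:
        # Unknown user.
        return False
    # Compare in constant time to avoid leaking how many characters match.
    return hmac.compare_digest(expected.encode(), password.encode())
```

Such a coroutine would be passed as ``check_credentials=check_credentials`` instead of the ``credentials`` argument, since only one of the two may be provided.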
# websockets-8.1/src/websockets/client.py
"""
:mod:`websockets.client` defines the WebSocket client APIs.
"""
import asyncio
import collections.abc
import functools
import logging
import warnings
from types import TracebackType
from typing import Any, Generator, List, Optional, Sequence, Tuple, Type, cast
from .exceptions import (
InvalidHandshake,
InvalidHeader,
InvalidMessage,
InvalidStatusCode,
NegotiationError,
RedirectHandshake,
SecurityError,
)
from .extensions.base import ClientExtensionFactory, Extension
from .extensions.permessage_deflate import ClientPerMessageDeflateFactory
from .handshake import build_request, check_response
from .headers import (
build_authorization_basic,
build_extension,
build_subprotocol,
parse_extension,
parse_subprotocol,
)
from .http import USER_AGENT, Headers, HeadersLike, read_response
from .protocol import WebSocketCommonProtocol
from .typing import ExtensionHeader, Origin, Subprotocol
from .uri import WebSocketURI, parse_uri
__all__ = ["connect", "unix_connect", "WebSocketClientProtocol"]
logger = logging.getLogger(__name__)
class WebSocketClientProtocol(WebSocketCommonProtocol):
"""
:class:`~asyncio.Protocol` subclass implementing a WebSocket client.
This class inherits most of its methods from
:class:`~websockets.protocol.WebSocketCommonProtocol`.
"""
is_client = True
side = "client"
def __init__(
self,
*,
origin: Optional[Origin] = None,
extensions: Optional[Sequence[ClientExtensionFactory]] = None,
subprotocols: Optional[Sequence[Subprotocol]] = None,
extra_headers: Optional[HeadersLike] = None,
**kwargs: Any,
) -> None:
self.origin = origin
self.available_extensions = extensions
self.available_subprotocols = subprotocols
self.extra_headers = extra_headers
super().__init__(**kwargs)
def write_http_request(self, path: str, headers: Headers) -> None:
"""
Write request line and headers to the HTTP request.
"""
self.path = path
self.request_headers = headers
logger.debug("%s > GET %s HTTP/1.1", self.side, path)
logger.debug("%s > %r", self.side, headers)
# Since the path and headers only contain ASCII characters,
# we can keep this simple.
request = f"GET {path} HTTP/1.1\r\n"
request += str(headers)
self.transport.write(request.encode())
async def read_http_response(self) -> Tuple[int, Headers]:
"""
Read status line and headers from the HTTP response.
If the response contains a body, it may be read from ``self.reader``
after this coroutine returns.
:raises ~websockets.exceptions.InvalidMessage: if the HTTP message is
malformed or isn't an HTTP/1.1 GET response
"""
try:
status_code, reason, headers = await read_response(self.reader)
except Exception as exc:
raise InvalidMessage("did not receive a valid HTTP response") from exc
logger.debug("%s < HTTP/1.1 %d %s", self.side, status_code, reason)
logger.debug("%s < %r", self.side, headers)
self.response_headers = headers
return status_code, self.response_headers
@staticmethod
def process_extensions(
headers: Headers,
available_extensions: Optional[Sequence[ClientExtensionFactory]],
) -> List[Extension]:
"""
Handle the Sec-WebSocket-Extensions HTTP response header.
Check that each extension is supported, as well as its parameters.
Return the list of accepted extensions.
Raise :exc:`~websockets.exceptions.InvalidHandshake` to abort the
connection.
:rfc:`6455` leaves the rules up to the specification of each
extension.
To provide this level of flexibility, for each extension accepted by
the server, we check for a match with each extension available in the
client configuration. If no match is found, an exception is raised.
If several variants of the same extension are accepted by the server,
it may be configured several times, which won't make sense in general.
Extensions must implement their own requirements. For this purpose,
the list of previously accepted extensions is provided.
Other requirements, for example related to mandatory extensions or the
order of extensions, may be implemented by overriding this method.
"""
accepted_extensions: List[Extension] = []
header_values = headers.get_all("Sec-WebSocket-Extensions")
if header_values:
if available_extensions is None:
raise InvalidHandshake("no extensions supported")
parsed_header_values: List[ExtensionHeader] = sum(
[parse_extension(header_value) for header_value in header_values], []
)
for name, response_params in parsed_header_values:
for extension_factory in available_extensions:
# Skip non-matching extensions based on their name.
if extension_factory.name != name:
continue
# Skip non-matching extensions based on their params.
try:
extension = extension_factory.process_response_params(
response_params, accepted_extensions
)
except NegotiationError:
continue
# Add matching extension to the final list.
accepted_extensions.append(extension)
# Break out of the loop once we have a match.
break
# If we didn't break from the loop, no extension in our list
# matched what the server sent. Fail the connection.
else:
raise NegotiationError(
f"Unsupported extension: "
f"name = {name}, params = {response_params}"
)
return accepted_extensions
@staticmethod
def process_subprotocol(
headers: Headers, available_subprotocols: Optional[Sequence[Subprotocol]]
) -> Optional[Subprotocol]:
"""
Handle the Sec-WebSocket-Protocol HTTP response header.
Check that it contains exactly one supported subprotocol.
Return the selected subprotocol.
"""
subprotocol: Optional[Subprotocol] = None
header_values = headers.get_all("Sec-WebSocket-Protocol")
if header_values:
if available_subprotocols is None:
raise InvalidHandshake("no subprotocols supported")
parsed_header_values: Sequence[Subprotocol] = sum(
[parse_subprotocol(header_value) for header_value in header_values], []
)
if len(parsed_header_values) > 1:
subprotocols = ", ".join(parsed_header_values)
raise InvalidHandshake(f"multiple subprotocols: {subprotocols}")
subprotocol = parsed_header_values[0]
if subprotocol not in available_subprotocols:
raise NegotiationError(f"unsupported subprotocol: {subprotocol}")
return subprotocol
async def handshake(
self,
wsuri: WebSocketURI,
origin: Optional[Origin] = None,
available_extensions: Optional[Sequence[ClientExtensionFactory]] = None,
available_subprotocols: Optional[Sequence[Subprotocol]] = None,
extra_headers: Optional[HeadersLike] = None,
) -> None:
"""
Perform the client side of the opening handshake.
:param origin: sets the Origin HTTP header
:param available_extensions: list of supported extensions in the order
in which they should be used
:param available_subprotocols: list of supported subprotocols in order
of decreasing preference
:param extra_headers: sets additional HTTP request headers; it must be
a :class:`~websockets.http.Headers` instance, a
:class:`~collections.abc.Mapping`, or an iterable of ``(name,
value)`` pairs
:raises ~websockets.exceptions.InvalidHandshake: if the handshake
fails
"""
request_headers = Headers()
if wsuri.port == (443 if wsuri.secure else 80): # pragma: no cover
request_headers["Host"] = wsuri.host
else:
request_headers["Host"] = f"{wsuri.host}:{wsuri.port}"
if wsuri.user_info:
request_headers["Authorization"] = build_authorization_basic(
*wsuri.user_info
)
if origin is not None:
request_headers["Origin"] = origin
key = build_request(request_headers)
if available_extensions is not None:
extensions_header = build_extension(
[
(extension_factory.name, extension_factory.get_request_params())
for extension_factory in available_extensions
]
)
request_headers["Sec-WebSocket-Extensions"] = extensions_header
if available_subprotocols is not None:
protocol_header = build_subprotocol(available_subprotocols)
request_headers["Sec-WebSocket-Protocol"] = protocol_header
if extra_headers is not None:
if isinstance(extra_headers, Headers):
extra_headers = extra_headers.raw_items()
elif isinstance(extra_headers, collections.abc.Mapping):
extra_headers = extra_headers.items()
for name, value in extra_headers:
request_headers[name] = value
request_headers.setdefault("User-Agent", USER_AGENT)
self.write_http_request(wsuri.resource_name, request_headers)
status_code, response_headers = await self.read_http_response()
if status_code in (301, 302, 303, 307, 308):
if "Location" not in response_headers:
raise InvalidHeader("Location")
raise RedirectHandshake(response_headers["Location"])
elif status_code != 101:
raise InvalidStatusCode(status_code)
check_response(response_headers, key)
self.extensions = self.process_extensions(
response_headers, available_extensions
)
self.subprotocol = self.process_subprotocol(
response_headers, available_subprotocols
)
self.connection_open()
class Connect:
"""
Connect to the WebSocket server at the given ``uri``.
Awaiting :func:`connect` yields a :class:`WebSocketClientProtocol` which
can then be used to send and receive messages.
:func:`connect` can also be used as an asynchronous context manager. In
that case, the connection is closed when exiting the context.
:func:`connect` is a wrapper around the event loop's
:meth:`~asyncio.loop.create_connection` method. Unknown keyword arguments
are passed to :meth:`~asyncio.loop.create_connection`.
For example, you can set the ``ssl`` keyword argument to a
:class:`~ssl.SSLContext` to enforce some TLS settings. When connecting to
a ``wss://`` URI, if this argument isn't provided explicitly,
:func:`ssl.create_default_context` is called to create a context.
You can connect to a different host and port from those found in ``uri``
by setting ``host`` and ``port`` keyword arguments. This only changes the
destination of the TCP connection. The host name from ``uri`` is still
used in the TLS handshake for secure connections and in the ``Host`` HTTP
header.
The ``create_protocol`` parameter allows customizing the
:class:`~asyncio.Protocol` that manages the connection. It should be a
callable or class accepting the same arguments as
:class:`WebSocketClientProtocol` and returning an instance of
:class:`WebSocketClientProtocol` or a subclass. It defaults to
:class:`WebSocketClientProtocol`.
The behavior of ``ping_interval``, ``ping_timeout``, ``close_timeout``,
``max_size``, ``max_queue``, ``read_limit``, and ``write_limit`` is
described in :class:`~websockets.protocol.WebSocketCommonProtocol`.
:func:`connect` also accepts the following optional arguments:
* ``compression`` is a shortcut to configure compression extensions;
by default it enables the "permessage-deflate" extension; set it to
``None`` to disable compression
* ``origin`` sets the Origin HTTP header
* ``extensions`` is a list of supported extensions in order of
decreasing preference
* ``subprotocols`` is a list of supported subprotocols in order of
decreasing preference
* ``extra_headers`` sets additional HTTP request headers; it can be a
:class:`~websockets.http.Headers` instance, a
:class:`~collections.abc.Mapping`, or an iterable of ``(name, value)``
pairs
:raises ~websockets.uri.InvalidURI: if ``uri`` is invalid
:raises ~websockets.exceptions.InvalidHandshake: if the opening handshake
fails
"""
MAX_REDIRECTS_ALLOWED = 10
def __init__(
self,
uri: str,
*,
path: Optional[str] = None,
create_protocol: Optional[Type[WebSocketClientProtocol]] = None,
ping_interval: float = 20,
ping_timeout: float = 20,
close_timeout: Optional[float] = None,
max_size: int = 2 ** 20,
max_queue: int = 2 ** 5,
read_limit: int = 2 ** 16,
write_limit: int = 2 ** 16,
loop: Optional[asyncio.AbstractEventLoop] = None,
legacy_recv: bool = False,
klass: Optional[Type[WebSocketClientProtocol]] = None,
timeout: Optional[float] = None,
compression: Optional[str] = "deflate",
origin: Optional[Origin] = None,
extensions: Optional[Sequence[ClientExtensionFactory]] = None,
subprotocols: Optional[Sequence[Subprotocol]] = None,
extra_headers: Optional[HeadersLike] = None,
**kwargs: Any,
) -> None:
# Backwards compatibility: close_timeout used to be called timeout.
if timeout is None:
timeout = 10
else:
warnings.warn("rename timeout to close_timeout", DeprecationWarning)
# If both are specified, timeout is ignored.
if close_timeout is None:
close_timeout = timeout
# Backwards compatibility: create_protocol used to be called klass.
if klass is None:
klass = WebSocketClientProtocol
else:
warnings.warn("rename klass to create_protocol", DeprecationWarning)
# If both are specified, klass is ignored.
if create_protocol is None:
create_protocol = klass
if loop is None:
loop = asyncio.get_event_loop()
wsuri = parse_uri(uri)
if wsuri.secure:
kwargs.setdefault("ssl", True)
elif kwargs.get("ssl") is not None:
raise ValueError(
"connect() received a ssl argument for a ws:// URI, "
"use a wss:// URI to enable TLS"
)
if compression == "deflate":
if extensions is None:
extensions = []
if not any(
extension_factory.name == ClientPerMessageDeflateFactory.name
for extension_factory in extensions
):
extensions = list(extensions) + [
ClientPerMessageDeflateFactory(client_max_window_bits=True)
]
elif compression is not None:
raise ValueError(f"unsupported compression: {compression}")
factory = functools.partial(
create_protocol,
ping_interval=ping_interval,
ping_timeout=ping_timeout,
close_timeout=close_timeout,
max_size=max_size,
max_queue=max_queue,
read_limit=read_limit,
write_limit=write_limit,
loop=loop,
host=wsuri.host,
port=wsuri.port,
secure=wsuri.secure,
legacy_recv=legacy_recv,
origin=origin,
extensions=extensions,
subprotocols=subprotocols,
extra_headers=extra_headers,
)
if path is None:
host: Optional[str]
port: Optional[int]
if kwargs.get("sock") is None:
host, port = wsuri.host, wsuri.port
else:
# If sock is given, host and port shouldn't be specified.
host, port = None, None
# If host and port are given, override values from the URI.
host = kwargs.pop("host", host)
port = kwargs.pop("port", port)
create_connection = functools.partial(
loop.create_connection, factory, host, port, **kwargs
)
else:
create_connection = functools.partial(
loop.create_unix_connection, factory, path, **kwargs
)
# This is a coroutine function.
self._create_connection = create_connection
self._wsuri = wsuri
def handle_redirect(self, uri: str) -> None:
# Update the state of this instance to connect to a new URI.
old_wsuri = self._wsuri
new_wsuri = parse_uri(uri)
# Forbid TLS downgrade.
if old_wsuri.secure and not new_wsuri.secure:
raise SecurityError("redirect from WSS to WS")
same_origin = (
old_wsuri.host == new_wsuri.host and old_wsuri.port == new_wsuri.port
)
# Rewrite the host and port arguments for cross-origin redirects.
# This preserves connection overrides with the host and port
# arguments if the redirect points to the same host and port.
if not same_origin:
# Replace the host and port argument passed to the protocol factory.
factory = self._create_connection.args[0]
factory = functools.partial(
factory.func,
*factory.args,
**dict(factory.keywords, host=new_wsuri.host, port=new_wsuri.port),
)
# Replace the host and port argument passed to create_connection.
self._create_connection = functools.partial(
self._create_connection.func,
*(factory, new_wsuri.host, new_wsuri.port),
**self._create_connection.keywords,
)
# Set the new WebSocket URI. This suffices for same-origin redirects.
self._wsuri = new_wsuri
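# The partial-rewriting step in handle_redirect() can be tried in isolation.
# This is a minimal sketch, not the library's code: open_connection and
# make_protocol are hypothetical stand-ins for loop.create_connection and the
# protocol factory, and the host names are made up for illustration.

```python
import functools


def open_connection(factory, host, port, **kwargs):
    # Hypothetical stand-in for loop.create_connection: it only reports its arguments.
    return factory(), host, port, kwargs


def make_protocol(host=None, port=None, secure=False):
    # Hypothetical stand-in for the protocol factory.
    return {"host": host, "port": port, "secure": secure}


factory = functools.partial(make_protocol, host="a.example", port=80, secure=False)
create_connection = functools.partial(
    open_connection, factory, "a.example", 80, ssl=None
)

# Rebuild the inner partial, overriding host and port and keeping other keywords.
old_factory = create_connection.args[0]
new_factory = functools.partial(
    old_factory.func,
    *old_factory.args,
    **dict(old_factory.keywords, host="b.example", port=8080),
)

# Rebuild the outer partial with the new factory, host and port.
create_connection = functools.partial(
    create_connection.func,
    new_factory,
    "b.example",
    8080,
    **create_connection.keywords,
)

protocol, host, port, kwargs = create_connection()
```

# functools.partial objects are immutable in practice, which is why both the
# inner and the outer partial are rebuilt instead of being modified in place.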
# async with connect(...)
async def __aenter__(self) -> WebSocketClientProtocol:
return await self
async def __aexit__(
self,
exc_type: Optional[Type[BaseException]],
exc_value: Optional[BaseException],
traceback: Optional[TracebackType],
) -> None:
await self.ws_client.close()
# await connect(...)
def __await__(self) -> Generator[Any, None, WebSocketClientProtocol]:
# Create a suitable iterator by calling __await__ on a coroutine.
return self.__await_impl__().__await__()
async def __await_impl__(self) -> WebSocketClientProtocol:
for redirects in range(self.MAX_REDIRECTS_ALLOWED):
transport, protocol = await self._create_connection()
# https://github.com/python/typeshed/pull/2756
transport = cast(asyncio.Transport, transport)
protocol = cast(WebSocketClientProtocol, protocol)
try:
try:
await protocol.handshake(
self._wsuri,
origin=protocol.origin,
available_extensions=protocol.available_extensions,
available_subprotocols=protocol.available_subprotocols,
extra_headers=protocol.extra_headers,
)
except Exception:
protocol.fail_connection()
await protocol.wait_closed()
raise
else:
self.ws_client = protocol
return protocol
except RedirectHandshake as exc:
self.handle_redirect(exc.uri)
else:
raise SecurityError("too many redirects")
# yield from connect(...)
__iter__ = __await__
connect = Connect
def unix_connect(path: str, uri: str = "ws://localhost/", **kwargs: Any) -> Connect:
"""
Similar to :func:`connect`, but for connecting to a Unix socket.
This function calls the event loop's
:meth:`~asyncio.loop.create_unix_connection` method.
It is only available on Unix.
It's mainly useful for debugging servers listening on Unix sockets.
:param path: file system path to the Unix socket
:param uri: WebSocket URI
"""
return connect(uri=uri, path=path, **kwargs)
# websockets-8.1/src/websockets/exceptions.py
"""
:mod:`websockets.exceptions` defines the following exception hierarchy:
* :exc:`WebSocketException`
* :exc:`ConnectionClosed`
* :exc:`ConnectionClosedError`
* :exc:`ConnectionClosedOK`
* :exc:`InvalidHandshake`
* :exc:`SecurityError`
* :exc:`InvalidMessage`
* :exc:`InvalidHeader`
* :exc:`InvalidHeaderFormat`
* :exc:`InvalidHeaderValue`
* :exc:`InvalidOrigin`
* :exc:`InvalidUpgrade`
* :exc:`InvalidStatusCode`
* :exc:`NegotiationError`
* :exc:`DuplicateParameter`
* :exc:`InvalidParameterName`
* :exc:`InvalidParameterValue`
* :exc:`AbortHandshake`
* :exc:`RedirectHandshake`
* :exc:`InvalidState`
* :exc:`InvalidURI`
* :exc:`PayloadTooBig`
* :exc:`ProtocolError`
"""
import http
from typing import Optional
from .http import Headers, HeadersLike
__all__ = [
"WebSocketException",
"ConnectionClosed",
"ConnectionClosedError",
"ConnectionClosedOK",
"InvalidHandshake",
"SecurityError",
"InvalidMessage",
"InvalidHeader",
"InvalidHeaderFormat",
"InvalidHeaderValue",
"InvalidOrigin",
"InvalidUpgrade",
"InvalidStatusCode",
"NegotiationError",
"DuplicateParameter",
"InvalidParameterName",
"InvalidParameterValue",
"AbortHandshake",
"RedirectHandshake",
"InvalidState",
"InvalidURI",
"PayloadTooBig",
"ProtocolError",
"WebSocketProtocolError",
]
class WebSocketException(Exception):
"""
Base class for all exceptions defined by :mod:`websockets`.
"""
CLOSE_CODES = {
1000: "OK",
1001: "going away",
1002: "protocol error",
1003: "unsupported type",
# 1004 is reserved
1005: "no status code [internal]",
1006: "connection closed abnormally [internal]",
1007: "invalid data",
1008: "policy violation",
1009: "message too big",
1010: "extension required",
1011: "unexpected error",
1015: "TLS failure [internal]",
}
def format_close(code: int, reason: str) -> str:
"""
Display a human-readable version of the close code and reason.
"""
if 3000 <= code < 4000:
explanation = "registered"
elif 4000 <= code < 5000:
explanation = "private use"
else:
explanation = CLOSE_CODES.get(code, "unknown")
result = f"code = {code} ({explanation}), "
if reason:
result += f"reason = {reason}"
else:
result += "no reason"
return result
class ConnectionClosed(WebSocketException):
"""
Raised when trying to interact with a closed connection.
Provides the connection close code and reason in its ``code`` and
``reason`` attributes respectively.
"""
def __init__(self, code: int, reason: str) -> None:
self.code = code
self.reason = reason
super().__init__(format_close(code, reason))
class ConnectionClosedError(ConnectionClosed):
"""
Like :exc:`ConnectionClosed`, when the connection terminated with an error.
This means the close code is different from 1000 (OK) and 1001 (going away).
"""
def __init__(self, code: int, reason: str) -> None:
assert code != 1000 and code != 1001
super().__init__(code, reason)
class ConnectionClosedOK(ConnectionClosed):
"""
Like :exc:`ConnectionClosed`, when the connection terminated properly.
This means the close code is 1000 (OK) or 1001 (going away).
"""
def __init__(self, code: int, reason: str) -> None:
assert code == 1000 or code == 1001
super().__init__(code, reason)
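# The assertions in the two subclasses above imply a simple rule for picking
# one of them from a close code. A minimal sketch, assuming a hypothetical
# helper named closed_exception_name that returns the class name as a string
# so that it runs without importing this module:

```python
def closed_exception_name(code: int) -> str:
    """Return the name of the exception class matching a close code."""
    # 1000 (OK) and 1001 (going away) are the only codes treated as a clean close.
    if code in (1000, 1001):
        return "ConnectionClosedOK"
    # Every other code, including 1006 (abnormal closure), is treated as an error.
    return "ConnectionClosedError"


names = [closed_exception_name(code) for code in (1000, 1001, 1006, 4000)]
```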
class InvalidHandshake(WebSocketException):
"""
Raised during the handshake when the WebSocket connection fails.
"""
class SecurityError(InvalidHandshake):
"""
Raised when a handshake request or response breaks a security rule.
Security limits are hard coded.
"""
class InvalidMessage(InvalidHandshake):
"""
Raised when a handshake request or response is malformed.
"""
class InvalidHeader(InvalidHandshake):
"""
Raised when an HTTP header doesn't have a valid format or value.
"""
def __init__(self, name: str, value: Optional[str] = None) -> None:
self.name = name
self.value = value
if value is None:
message = f"missing {name} header"
elif value == "":
message = f"empty {name} header"
else:
message = f"invalid {name} header: {value}"
super().__init__(message)
class InvalidHeaderFormat(InvalidHeader):
"""
Raised when an HTTP header cannot be parsed.
The format of the header doesn't match the grammar for that header.
"""
def __init__(self, name: str, error: str, header: str, pos: int) -> None:
self.name = name
error = f"{error} at {pos} in {header}"
super().__init__(name, error)
class InvalidHeaderValue(InvalidHeader):
"""
Raised when an HTTP header has a wrong value.
The format of the header is correct but a value isn't acceptable.
"""
class InvalidOrigin(InvalidHeader):
"""
Raised when the Origin header in a request isn't allowed.
"""
def __init__(self, origin: Optional[str]) -> None:
super().__init__("Origin", origin)
class InvalidUpgrade(InvalidHeader):
"""
Raised when the Upgrade or Connection header isn't correct.
"""
class InvalidStatusCode(InvalidHandshake):
"""
Raised when a handshake response status code is invalid.
The integer status code is available in the ``status_code`` attribute.
"""
def __init__(self, status_code: int) -> None:
self.status_code = status_code
message = f"server rejected WebSocket connection: HTTP {status_code}"
super().__init__(message)
class NegotiationError(InvalidHandshake):
"""
Raised when negotiating an extension fails.
"""
class DuplicateParameter(NegotiationError):
"""
Raised when a parameter name is repeated in an extension header.
"""
def __init__(self, name: str) -> None:
self.name = name
message = f"duplicate parameter: {name}"
super().__init__(message)
class InvalidParameterName(NegotiationError):
"""
Raised when a parameter name in an extension header is invalid.
"""
def __init__(self, name: str) -> None:
self.name = name
message = f"invalid parameter name: {name}"
super().__init__(message)
class InvalidParameterValue(NegotiationError):
"""
Raised when a parameter value in an extension header is invalid.
"""
def __init__(self, name: str, value: Optional[str]) -> None:
self.name = name
self.value = value
if value is None:
message = f"missing value for parameter {name}"
elif value == "":
message = f"empty value for parameter {name}"
else:
message = f"invalid value for parameter {name}: {value}"
super().__init__(message)
class AbortHandshake(InvalidHandshake):
"""
Raised to abort the handshake on purpose and return an HTTP response.
This exception is an implementation detail.
The public API is :meth:`~server.WebSocketServerProtocol.process_request`.
"""
def __init__(
self, status: http.HTTPStatus, headers: HeadersLike, body: bytes = b""
) -> None:
self.status = status
self.headers = Headers(headers)
self.body = body
message = f"HTTP {status}, {len(self.headers)} headers, {len(body)} bytes"
super().__init__(message)
class RedirectHandshake(InvalidHandshake):
"""
Raised when a handshake gets redirected.
This exception is an implementation detail.
"""
def __init__(self, uri: str) -> None:
self.uri = uri
def __str__(self) -> str:
return f"redirect to {self.uri}"
class InvalidState(WebSocketException, AssertionError):
"""
Raised when an operation is forbidden in the current state.
This exception is an implementation detail.
It should never be raised in normal circumstances.
"""
class InvalidURI(WebSocketException):
"""
Raised when connecting to a URI that isn't a valid WebSocket URI.
"""
def __init__(self, uri: str) -> None:
self.uri = uri
message = f"{uri} isn't a valid URI"
super().__init__(message)
class PayloadTooBig(WebSocketException):
"""
Raised when receiving a frame with a payload exceeding the maximum size.
"""
class ProtocolError(WebSocketException):
"""
Raised when the other side breaks the protocol.
"""
WebSocketProtocolError = ProtocolError # for backwards compatibility
# websockets-8.1/src/websockets/extensions/__init__.py (empty)
# websockets-8.1/src/websockets/extensions/base.py
"""
:mod:`websockets.extensions.base` defines abstract classes for implementing
extensions.
See `section 9 of RFC 6455`_.
.. _section 9 of RFC 6455: http://tools.ietf.org/html/rfc6455#section-9
"""
from typing import List, Optional, Sequence, Tuple
from ..framing import Frame
from ..typing import ExtensionName, ExtensionParameter
__all__ = ["Extension", "ClientExtensionFactory", "ServerExtensionFactory"]
class Extension:
"""
Abstract class for extensions.
"""
@property
def name(self) -> ExtensionName:
"""
Extension identifier.
"""
def decode(self, frame: Frame, *, max_size: Optional[int] = None) -> Frame:
"""
Decode an incoming frame.
:param frame: incoming frame
:param max_size: maximum payload size in bytes
"""
def encode(self, frame: Frame) -> Frame:
"""
Encode an outgoing frame.
:param frame: outgoing frame
"""
class ClientExtensionFactory:
"""
Abstract class for client-side extension factories.
"""
@property
def name(self) -> ExtensionName:
"""
Extension identifier.
"""
def get_request_params(self) -> List[ExtensionParameter]:
"""
Build request parameters.
Return a list of ``(name, value)`` pairs.
"""
def process_response_params(
self,
params: Sequence[ExtensionParameter],
accepted_extensions: Sequence[Extension],
) -> Extension:
"""
Process response parameters received from the server.
:param params: list of ``(name, value)`` pairs.
:param accepted_extensions: list of previously accepted extensions.
:raises ~websockets.exceptions.NegotiationError: if parameters aren't
acceptable
"""
class ServerExtensionFactory:
"""
Abstract class for server-side extension factories.
"""
@property
def name(self) -> ExtensionName:
"""
Extension identifier.
"""
def process_request_params(
self,
params: Sequence[ExtensionParameter],
accepted_extensions: Sequence[Extension],
) -> Tuple[List[ExtensionParameter], Extension]:
"""
Process request parameters received from the client.
To accept the offer, return a 2-tuple containing:
- response parameters: a list of ``(name, value)`` pairs
- an extension: an instance of a subclass of :class:`Extension`
:param params: list of ``(name, value)`` pairs.
:param accepted_extensions: list of previously accepted extensions.
:raises ~websockets.exceptions.NegotiationError: to reject the offer,
if parameters aren't acceptable
"""
# websockets-8.1/src/websockets/extensions/permessage_deflate.py
"""
:mod:`websockets.extensions.permessage_deflate` implements the Compression
Extensions for WebSocket as specified in :rfc:`7692`.
"""
import zlib
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
from ..exceptions import (
DuplicateParameter,
InvalidParameterName,
InvalidParameterValue,
NegotiationError,
PayloadTooBig,
)
from ..framing import CTRL_OPCODES, OP_CONT, Frame
from ..typing import ExtensionName, ExtensionParameter
from .base import ClientExtensionFactory, Extension, ServerExtensionFactory
__all__ = [
"PerMessageDeflate",
"ClientPerMessageDeflateFactory",
"ServerPerMessageDeflateFactory",
]
_EMPTY_UNCOMPRESSED_BLOCK = b"\x00\x00\xff\xff"
_MAX_WINDOW_BITS_VALUES = [str(bits) for bits in range(8, 16)]
class PerMessageDeflate(Extension):
"""
Per-Message Deflate extension.
"""
name = ExtensionName("permessage-deflate")
def __init__(
self,
remote_no_context_takeover: bool,
local_no_context_takeover: bool,
remote_max_window_bits: int,
local_max_window_bits: int,
compress_settings: Optional[Dict[Any, Any]] = None,
) -> None:
"""
Configure the Per-Message Deflate extension.
"""
if compress_settings is None:
compress_settings = {}
assert remote_no_context_takeover in [False, True]
assert local_no_context_takeover in [False, True]
assert 8 <= remote_max_window_bits <= 15
assert 8 <= local_max_window_bits <= 15
assert "wbits" not in compress_settings
self.remote_no_context_takeover = remote_no_context_takeover
self.local_no_context_takeover = local_no_context_takeover
self.remote_max_window_bits = remote_max_window_bits
self.local_max_window_bits = local_max_window_bits
self.compress_settings = compress_settings
if not self.remote_no_context_takeover:
self.decoder = zlib.decompressobj(wbits=-self.remote_max_window_bits)
if not self.local_no_context_takeover:
self.encoder = zlib.compressobj(
wbits=-self.local_max_window_bits, **self.compress_settings
)
# To handle continuation frames properly, we must keep track of
# whether that initial frame was encoded.
self.decode_cont_data = False
# There's no need for self.encode_cont_data because we always encode
# outgoing frames, so it would always be True.
def __repr__(self) -> str:
return (
f"PerMessageDeflate("
f"remote_no_context_takeover={self.remote_no_context_takeover}, "
f"local_no_context_takeover={self.local_no_context_takeover}, "
f"remote_max_window_bits={self.remote_max_window_bits}, "
f"local_max_window_bits={self.local_max_window_bits})"
)
def decode(self, frame: Frame, *, max_size: Optional[int] = None) -> Frame:
"""
Decode an incoming frame.
"""
# Skip control frames.
if frame.opcode in CTRL_OPCODES:
return frame
# Handle continuation data frames:
# - skip if the initial data frame wasn't encoded
# - reset "decode continuation data" flag if it's a final frame
if frame.opcode == OP_CONT:
if not self.decode_cont_data:
return frame
if frame.fin:
self.decode_cont_data = False
# Handle text and binary data frames:
# - skip if the frame isn't encoded
# - set "decode continuation data" flag if it's a non-final frame
else:
if not frame.rsv1:
return frame
if not frame.fin: # frame.rsv1 is True at this point
self.decode_cont_data = True
# Re-initialize per-message decoder.
if self.remote_no_context_takeover:
self.decoder = zlib.decompressobj(wbits=-self.remote_max_window_bits)
# Uncompress compressed frames. Protect against zip bombs by
# preventing zlib from decompressing more than max_length bytes
# (except when the limit is disabled with max_size = None).
data = frame.data
if frame.fin:
data += _EMPTY_UNCOMPRESSED_BLOCK
max_length = 0 if max_size is None else max_size
data = self.decoder.decompress(data, max_length)
if self.decoder.unconsumed_tail:
raise PayloadTooBig(
f"Uncompressed payload length exceeds size limit (? > {max_size} bytes)"
)
# Allow garbage collection of the decoder if it won't be reused.
if frame.fin and self.remote_no_context_takeover:
del self.decoder
return frame._replace(data=data, rsv1=False)
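# The zip bomb protection in decode() relies on two documented features of
# zlib decompression objects: the max_length argument of decompress() and the
# unconsumed_tail attribute. A minimal standalone sketch, with an arbitrary
# 1 KiB limit and a 1 MiB payload chosen only for illustration:

```python
import zlib

# Compress 1 MiB of zeros as a raw deflate stream (negative wbits, no header).
compressor = zlib.compressobj(wbits=-15)
payload = compressor.compress(b"\x00" * 2**20) + compressor.flush()

max_size = 1024
decoder = zlib.decompressobj(wbits=-15)

# decompress() returns at most max_size bytes when max_length is non-zero.
chunk = decoder.decompress(payload, max_size)

# Input that could not be processed within the limit stays in unconsumed_tail;
# decode() treats this as a payload that exceeds the size limit.
too_big = bool(decoder.unconsumed_tail)
```

# Passing max_length=0 disables the limit, which is how max_size=None is
# handled above.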
def encode(self, frame: Frame) -> Frame:
"""
Encode an outgoing frame.
"""
# Skip control frames.
if frame.opcode in CTRL_OPCODES:
return frame
# Since we always encode and never fragment messages, there's no logic
# similar to decode() here at this time.
if frame.opcode != OP_CONT:
# Re-initialize per-message encoder.
if self.local_no_context_takeover:
self.encoder = zlib.compressobj(
wbits=-self.local_max_window_bits, **self.compress_settings
)
# Compress data frames.
data = self.encoder.compress(frame.data) + self.encoder.flush(zlib.Z_SYNC_FLUSH)
if frame.fin and data.endswith(_EMPTY_UNCOMPRESSED_BLOCK):
data = data[:-4]
# Allow garbage collection of the encoder if it won't be reused.
if frame.fin and self.local_no_context_takeover:
del self.encoder
return frame._replace(data=data, rsv1=True)
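# The trailer handling in encode() and decode() can be checked with zlib
# alone. A minimal sketch, assuming 15 window bits and context takeover on
# both sides; the message content is made up and no Frame objects are used:

```python
import zlib

EMPTY_UNCOMPRESSED_BLOCK = b"\x00\x00\xff\xff"

message = b"hello hello hello hello"

# Sender: raw deflate, sync flush, then drop the 4-byte trailer.
encoder = zlib.compressobj(wbits=-15)
data = encoder.compress(message) + encoder.flush(zlib.Z_SYNC_FLUSH)
ends_with_trailer = data.endswith(EMPTY_UNCOMPRESSED_BLOCK)
wire = data[:-4]

# Receiver: put the trailer back before decompressing.
decoder = zlib.decompressobj(wbits=-15)
decoded = decoder.decompress(wire + EMPTY_UNCOMPRESSED_BLOCK)

# With context takeover, the same encoder and decoder serve the next message.
data2 = encoder.compress(message) + encoder.flush(zlib.Z_SYNC_FLUSH)
decoded2 = decoder.decompress(data2[:-4] + EMPTY_UNCOMPRESSED_BLOCK)
```

# Reusing the compression context lets later messages refer back to earlier
# data, at the cost of keeping the zlib state in memory between messages.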
def _build_parameters(
server_no_context_takeover: bool,
client_no_context_takeover: bool,
server_max_window_bits: Optional[int],
client_max_window_bits: Optional[Union[int, bool]],
) -> List[ExtensionParameter]:
"""
Build a list of ``(name, value)`` pairs for some compression parameters.
"""
params: List[ExtensionParameter] = []
if server_no_context_takeover:
params.append(("server_no_context_takeover", None))
if client_no_context_takeover:
params.append(("client_no_context_takeover", None))
if server_max_window_bits:
params.append(("server_max_window_bits", str(server_max_window_bits)))
if client_max_window_bits is True: # only in handshake requests
params.append(("client_max_window_bits", None))
elif client_max_window_bits:
params.append(("client_max_window_bits", str(client_max_window_bits)))
return params
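# The (name, value) pairs built above end up in a Sec-WebSocket-Extensions
# header. A minimal sketch of that serialization, assuming a hypothetical
# helper named format_extension; it is not the function this package uses and
# it skips quoting rules for unusual values:

```python
from typing import List, Optional, Tuple

ExtensionParameter = Tuple[str, Optional[str]]


def format_extension(name: str, params: List[ExtensionParameter]) -> str:
    """Join an extension name and its parameters with "; " separators."""
    parts = [name]
    for key, value in params:
        # A parameter without a value is sent as a bare name.
        parts.append(key if value is None else f"{key}={value}")
    return "; ".join(parts)


offer = format_extension(
    "permessage-deflate",
    [("client_max_window_bits", None), ("server_max_window_bits", "12")],
)
```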
def _extract_parameters(
params: Sequence[ExtensionParameter], *, is_server: bool
) -> Tuple[bool, bool, Optional[int], Optional[Union[int, bool]]]:
"""
Extract compression parameters from a list of ``(name, value)`` pairs.
If ``is_server`` is ``True``, ``client_max_window_bits`` may be provided
without a value. This is only allowed in handshake requests.
"""
server_no_context_takeover: bool = False
client_no_context_takeover: bool = False
server_max_window_bits: Optional[int] = None
client_max_window_bits: Optional[Union[int, bool]] = None
for name, value in params:
if name == "server_no_context_takeover":
if server_no_context_takeover:
raise DuplicateParameter(name)
if value is None:
server_no_context_takeover = True
else:
raise InvalidParameterValue(name, value)
elif name == "client_no_context_takeover":
if client_no_context_takeover:
raise DuplicateParameter(name)
if value is None:
client_no_context_takeover = True
else:
raise InvalidParameterValue(name, value)
elif name == "server_max_window_bits":
if server_max_window_bits is not None:
raise DuplicateParameter(name)
if value in _MAX_WINDOW_BITS_VALUES:
server_max_window_bits = int(value)
else:
raise InvalidParameterValue(name, value)
elif name == "client_max_window_bits":
if client_max_window_bits is not None:
raise DuplicateParameter(name)
if is_server and value is None: # only in handshake requests
client_max_window_bits = True
elif value in _MAX_WINDOW_BITS_VALUES:
client_max_window_bits = int(value)
else:
raise InvalidParameterValue(name, value)
else:
raise InvalidParameterName(name)
return (
server_no_context_takeover,
client_no_context_takeover,
server_max_window_bits,
client_max_window_bits,
)
class ClientPerMessageDeflateFactory(ClientExtensionFactory):
"""
Client-side extension factory for the Per-Message Deflate extension.
Parameters behave as described in `section 7.1 of RFC 7692`_. Set them to
``True`` to include them in the negotiation offer without a value or to an
integer value to include them with this value.
.. _section 7.1 of RFC 7692: https://tools.ietf.org/html/rfc7692#section-7.1
:param server_no_context_takeover: defaults to ``False``
:param client_no_context_takeover: defaults to ``False``
:param server_max_window_bits: optional, defaults to ``None``
:param client_max_window_bits: optional, defaults to ``None``
:param compress_settings: optional, keyword arguments for
:func:`zlib.compressobj`, excluding ``wbits``
"""
name = ExtensionName("permessage-deflate")
def __init__(
self,
server_no_context_takeover: bool = False,
client_no_context_takeover: bool = False,
server_max_window_bits: Optional[int] = None,
client_max_window_bits: Optional[Union[int, bool]] = None,
compress_settings: Optional[Dict[str, Any]] = None,
) -> None:
"""
Configure the Per-Message Deflate extension factory.
"""
if not (server_max_window_bits is None or 8 <= server_max_window_bits <= 15):
raise ValueError("server_max_window_bits must be between 8 and 15")
if not (
client_max_window_bits is None
or client_max_window_bits is True
or 8 <= client_max_window_bits <= 15
):
raise ValueError("client_max_window_bits must be between 8 and 15")
if compress_settings is not None and "wbits" in compress_settings:
raise ValueError(
"compress_settings must not include wbits, "
"set client_max_window_bits instead"
)
self.server_no_context_takeover = server_no_context_takeover
self.client_no_context_takeover = client_no_context_takeover
self.server_max_window_bits = server_max_window_bits
self.client_max_window_bits = client_max_window_bits
self.compress_settings = compress_settings
def get_request_params(self) -> List[ExtensionParameter]:
"""
Build request parameters.
"""
return _build_parameters(
self.server_no_context_takeover,
self.client_no_context_takeover,
self.server_max_window_bits,
self.client_max_window_bits,
)
def process_response_params(
self,
params: Sequence[ExtensionParameter],
accepted_extensions: Sequence["Extension"],
) -> PerMessageDeflate:
"""
Process response parameters.
Return an extension instance.
"""
if any(other.name == self.name for other in accepted_extensions):
raise NegotiationError(f"received duplicate {self.name}")
# Request parameters are available in instance variables.
# Load response parameters in local variables.
(
server_no_context_takeover,
client_no_context_takeover,
server_max_window_bits,
client_max_window_bits,
) = _extract_parameters(params, is_server=False)
# After comparing the request and the response, the final
# configuration must be available in the local variables.
# server_no_context_takeover
#
# Req. Resp. Result
# ------ ------ --------------------------------------------------
# False False False
# False True True
# True False Error!
# True True True
if self.server_no_context_takeover:
if not server_no_context_takeover:
raise NegotiationError("expected server_no_context_takeover")
# client_no_context_takeover
#
# Req. Resp. Result
# ------ ------ --------------------------------------------------
# False False False
# False True True
# True False True - must change value
# True True True
if self.client_no_context_takeover:
if not client_no_context_takeover:
client_no_context_takeover = True
# server_max_window_bits
# Req. Resp. Result
# ------ ------ --------------------------------------------------
# None None None
# None 8≤M≤15 M
# 8≤N≤15 None Error!
# 8≤N≤15 8≤M≤N M
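# 8≤N≤15 N<M≤15 Error!
if self.server_max_window_bits is None:
    pass
else:
    if server_max_window_bits is None:
        raise NegotiationError("expected server_max_window_bits")
    elif server_max_window_bits > self.server_max_window_bits:
        raise NegotiationError("unsupported server_max_window_bits")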
# client_max_window_bits
# Req. Resp. Result
# ------ ------ --------------------------------------------------
# None None None
# None 8≤M≤15 Error!
# True None None
# True 8≤M≤15 M
# 8≤N≤15 None N - must change value
# 8≤N≤15 8≤M≤N M
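# 8≤N≤15 N<M≤15 Error!
if self.client_max_window_bits is None:
    if client_max_window_bits is not None:
        raise NegotiationError("unexpected client_max_window_bits")
elif self.client_max_window_bits is True:
    pass
else:
    if client_max_window_bits is None:
        client_max_window_bits = self.client_max_window_bits
    elif client_max_window_bits > self.client_max_window_bits:
        raise NegotiationError("unsupported client_max_window_bits")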
return PerMessageDeflate(
server_no_context_takeover, # remote_no_context_takeover
client_no_context_takeover, # local_no_context_takeover
server_max_window_bits or 15, # remote_max_window_bits
client_max_window_bits or 15, # local_max_window_bits
self.compress_settings,
)
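# The client_max_window_bits table in process_response_params() can be read
# as a small pure function. A minimal sketch, assuming a hypothetical helper
# named negotiate_client_max_window_bits and using ValueError as a stand-in
# for NegotiationError so that it runs without importing this package:

```python
from typing import Optional, Union


def negotiate_client_max_window_bits(
    requested: Optional[Union[int, bool]], received: Optional[int]
) -> Optional[int]:
    """Combine the value sent in the request with the one in the response."""
    if requested is None:
        # The parameter wasn't offered, so the server must not send it.
        if received is not None:
            raise ValueError("unexpected client_max_window_bits")
        return None
    if requested is True:
        # Offered without a value: accept whatever the server chose.
        return received
    if received is None:
        # The server didn't answer: keep the requested value.
        return requested
    if received > requested:
        raise ValueError("unsupported client_max_window_bits")
    return received


results = [
    negotiate_client_max_window_bits(None, None),
    negotiate_client_max_window_bits(True, None),
    negotiate_client_max_window_bits(True, 10),
    negotiate_client_max_window_bits(12, None),
    negotiate_client_max_window_bits(12, 10),
]
```

# A result of None means that no limit was negotiated; the extension is then
# created with the default of 15 window bits, as in the return statement above.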
class ServerPerMessageDeflateFactory(ServerExtensionFactory):
"""
Server-side extension factory for the Per-Message Deflate extension.
Parameters behave as described in `section 7.1 of RFC 7692`_. Set them to
``True`` to include them in the negotiation offer without a value or to an
integer value to include them with this value.
.. _section 7.1 of RFC 7692: https://tools.ietf.org/html/rfc7692#section-7.1
:param server_no_context_takeover: defaults to ``False``
:param client_no_context_takeover: defaults to ``False``
:param server_max_window_bits: optional, defaults to ``None``
:param client_max_window_bits: optional, defaults to ``None``
:param compress_settings: optional, keyword arguments for
:func:`zlib.compressobj`, excluding ``wbits``
"""
name = ExtensionName("permessage-deflate")
def __init__(
self,
server_no_context_takeover: bool = False,
client_no_context_takeover: bool = False,
server_max_window_bits: Optional[int] = None,
client_max_window_bits: Optional[int] = None,
compress_settings: Optional[Dict[str, Any]] = None,
) -> None:
"""
Configure the Per-Message Deflate extension factory.
"""
if not (server_max_window_bits is None or 8 <= server_max_window_bits <= 15):
raise ValueError("server_max_window_bits must be between 8 and 15")
if not (client_max_window_bits is None or 8 <= client_max_window_bits <= 15):
raise ValueError("client_max_window_bits must be between 8 and 15")
if compress_settings is not None and "wbits" in compress_settings:
raise ValueError(
"compress_settings must not include wbits, "
"set server_max_window_bits instead"
)
self.server_no_context_takeover = server_no_context_takeover
self.client_no_context_takeover = client_no_context_takeover
self.server_max_window_bits = server_max_window_bits
self.client_max_window_bits = client_max_window_bits
self.compress_settings = compress_settings
def process_request_params(
self,
params: Sequence[ExtensionParameter],
accepted_extensions: Sequence["Extension"],
) -> Tuple[List[ExtensionParameter], PerMessageDeflate]:
"""
Process request parameters.
Return response params and an extension instance.
"""
if any(other.name == self.name for other in accepted_extensions):
raise NegotiationError(f"skipped duplicate {self.name}")
# Load request parameters in local variables.
(
server_no_context_takeover,
client_no_context_takeover,
server_max_window_bits,
client_max_window_bits,
) = _extract_parameters(params, is_server=True)
# Configuration parameters are available in instance variables.
# After comparing the request and the configuration, the response must
# be available in the local variables.
# server_no_context_takeover
#
# Config Req. Resp.
# ------ ------ --------------------------------------------------
# False False False
# False True True
# True False True - must change value to True
# True True True
if self.server_no_context_takeover:
if not server_no_context_takeover:
server_no_context_takeover = True
# client_no_context_takeover
#
# Config Req. Resp.
# ------ ------ --------------------------------------------------
# False False False
# False True True (or False)
# True False True - must change value to True
# True True True (or False)
if self.client_no_context_takeover:
if not client_no_context_takeover:
client_no_context_takeover = True
# server_max_window_bits
# Config Req. Resp.
# ------ ------ --------------------------------------------------
# None None None
# None 8≤M≤15 M
# 8≤N≤15 None N - must change value
# 8≤N≤15 8≤M≤N M
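# 8≤N≤15 N<M≤15 N - must change value
if self.server_max_window_bits is None:
    pass
else:
    if server_max_window_bits is None:
        server_max_window_bits = self.server_max_window_bits
    elif server_max_window_bits > self.server_max_window_bits:
        server_max_window_bits = self.server_max_window_bits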
# client_max_window_bits
# Config Req. Resp.
# ------ ------ --------------------------------------------------
# None None None
# None True None - must change value
# None 8≤M≤15 M (or None)
# 8≤N≤15 None Error!
# 8≤N≤15 True N - must change value
# 8≤N≤15 8≤M≤N M (or None)
# 8≤N≤15 N<M≤15 N