pax_global_header00006660000000000000000000000064140005643350014512gustar00rootroot0000000000000052 comment=dc6ce15776f2ea3d5ac54119985f679b232984d6 .coveragerc000066400000000000000000000000611400056433500131740ustar00rootroot00000000000000[run] relative_files = True source= aioopenssl .github/000077500000000000000000000000001400056433500124165ustar00rootroot00000000000000.github/workflows/000077500000000000000000000000001400056433500144535ustar00rootroot00000000000000.github/workflows/main.yaml000066400000000000000000000037731400056433500162750ustar00rootroot00000000000000name: CI on: push: branches: - devel - master - "release-*" pull_request: branches: - devel - master - "release-*" workflow_dispatch: jobs: mypy: runs-on: ubuntu-latest name: 'typecheck: mypy' steps: - uses: actions/checkout@v2 - uses: actions/setup-python@v2 with: python-version: '3.7' - name: Install run: | set -euo pipefail pip install mypy pip install . - name: Typecheck run: | python -m mypy --config mypy.ini -p aioopenssl linting: runs-on: ubuntu-latest name: 'lint: flake8' steps: - uses: actions/checkout@v2 - uses: actions/setup-python@v2 with: python-version: '3.7' - name: Install run: | set -euo pipefail pip install flake8 - name: Linting run: | python -m flake8 aioopenssl tests test: needs: - mypy - linting runs-on: ubuntu-latest strategy: matrix: python-version: - '3.5' - '3.6' - '3.7' - '3.8' - '3.9' name: 'unit: py${{ matrix.python-version }}' steps: - uses: actions/checkout@v2 - uses: actions/setup-python@v2 with: python-version: '${{ matrix.python-version }}' - name: Install run: | set -euo pipefail pip install nose coveralls pyOpenSSL pip install . 
- name: Run tests run: | set -euo pipefail python -m nose --with-cover --cover-package aioopenssl tests - name: Coveralls uses: AndreMiras/coveralls-python-action@develop with: parallel: true flag-name: python-${{ matrix.python-version }} finish: needs: test runs-on: ubuntu-latest name: Finalize steps: - name: Finalize Coveralls interaction uses: AndreMiras/coveralls-python-action@develop with: parallel-finished: true .gitignore000066400000000000000000000000271400056433500130450ustar00rootroot00000000000000__pycache__ *.egg-info COPYING000066400000000000000000000251421400056433500121150ustar00rootroot00000000000000 Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. 
"Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. 
Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. 
You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. 
Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. 
In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. 
Copyright [yyyy] [name of copyright owner] Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. MANIFEST.in000066400000000000000000000000201400056433500126040ustar00rootroot00000000000000include COPYING Makefile000066400000000000000000000003731400056433500125210ustar00rootroot00000000000000SPHINXBUILD ?= sphinx-build-3 docs-html: cd docs; $(MAKE) SPHINXBUILD=$(SPHINXBUILD) html docs-view-html: docs-html xdg-open docs/sphinx-data/build/html/index.html docs-clean: cd docs; $(MAKE) SPHINXBUILD=$(SPHINXBUILD) clean .PHONY: docs-html README.rst000066400000000000000000000034551400056433500125540ustar00rootroot00000000000000OpenSSL Transport for asyncio ############################# .. image:: https://github.com/horazont/aioopenssl/workflows/CI/badge.svg :target: https://github.com/horazont/aioopenssl/actions?query=workflow%3ACI+branch%3Adevel .. image:: https://coveralls.io/repos/github/horazont/aioopenssl/badge.svg?branch=devel :target: https://coveralls.io/github/horazont/aioopenssl?branch=devel ``aioopenssl`` provides a `asyncio `_ Transport which uses `PyOpenSSL `_ instead of the built-in ssl module. The transport has two main advantages compared to the original: * The TLS handshake can be deferred by passing ``use_starttls=True`` and later calling the ``starttls()`` coroutine method. This is useful for protocols with a `STARTTLS `_ feature. * A coroutine can be called during the TLS handshake; this can be used to defer the certificate check to a later point, allowing e.g. 
to get user feedback before the ``starttls()`` method returns. This allows to ask users for certificate trust without the application layer protocol interfering or starting to communicate with the unverified peer. .. note:: Use this module at your own risk. It has lower test coverage than I’d like it to have; it has been exported from aioxmpp on request, where it undergoes implicit testing. If you find bugs, please report them. If possible, add regression tests while you’re at it. If you find security-critical bugs, please follow the procedure announced in the `aioxmpp readme `_.` Documentation ------------- Official documentation can be built with sphinx and is available online `on our servers `_. aioopenssl/000077500000000000000000000000001400056433500132325ustar00rootroot00000000000000aioopenssl/__init__.py000066400000000000000000001023241400056433500153450ustar00rootroot00000000000000""" # NOQA :mod:`aioopenssl` --- A transport for asyncio using :mod:`OpenSSL` ################################################################## This package provides a socket-based :class:`asyncio.Transport` which uses :mod:`OpenSSL` to create a TLS connection. Optionally, the TLS handshake can be deferred and performed later using :meth:`STARTTLSTransport.starttls`. .. note:: Use this module at your own risk. It has lower test coverage than I’d like it to have; it has been exported from aioxmpp on request, where it undergoes implicit testing. If you find bugs, please report them. If possible, add regression tests while you’re at it. If you find security-critical bugs, please follow the procedure announced in the `aioxmpp readme `_. The following function can be used to create a connection using the :class:`STARTTLSTransport`, which itself is documented below: .. autofunction:: create_starttls_connection The transport implementation is documented below: .. 
autoclass:: STARTTLSTransport(loop, rawsock, protocol, ssl_context_factory, [waiter=None], [use_starttls=False], [post_handshake_callback=None], [peer_hostname=None], [server_hostname=None]) :members: """ import asyncio import logging import socket import typing from enum import Enum from .version import __version__, version_info, version # noqa:F401 from .utils import SendWrap import OpenSSL.SSL logger = logging.getLogger(__name__) class _State(Enum): RAW_OPEN = 0x0000 # noqa:E221 RAW_EOF_RECEIVED = 0x0001 # noqa:E221 TLS_HANDSHAKING = 0x0300 # noqa:E221 TLS_OPEN = 0x0100 # noqa:E221 TLS_EOF_RECEIVED = 0x0101 # noqa:E221 TLS_SHUTTING_DOWN = 0x0102 # noqa:E221 TLS_SHUT_DOWN = 0x0103 # noqa:E221 CLOSED = 0x0003 # noqa:E221 @property def eof_received(self) -> bool: return bool(self.value & 0x0001) @property def tls_started(self) -> bool: return bool(self.value & 0x0100) @property def tls_handshaking(self) -> bool: return bool(self.value & 0x0200) @property def is_writable(self) -> bool: return not bool(self.value & 0x0002) @property def is_open(self) -> bool: return (self.value & 0x3) == 0 SSLContextFactory = typing.Callable[[asyncio.Transport], OpenSSL.SSL.Context] PostHandshakeCallback = typing.Callable[ ["STARTTLSTransport"], typing.Coroutine[typing.Any, typing.Any, None], ] class STARTTLSTransport(asyncio.Transport): """ Create a new :class:`asyncio.Transport` which supports TLS and the deferred starting of TLS using the :meth:`starttls` method. `loop` must be a :class:`asyncio.BaseEventLoop` with support for :meth:`BaseEventLoop.add_reader` as well as removal and the writer complements. `rawsock` must be a :class:`socket.socket` which will be used as the socket for the transport. `protocol` must be a :class:`asyncio.Protocol` which will be fed the data the transport receives. `ssl_context_factory` must be a callable accepting a single positional argument which returns a :class:`OpenSSL.SSL.Context`. The transport will be passed as the argument to the factory. 
The returned context will be used to create the :class:`OpenSSL.SSL.Connection` when TLS is enabled on the transport. If the callable is :data:`None`, a `ssl_context` must be supplied to :meth:`starttls` and `use_starttls` must be true. `use_starttls` must be a boolean value. If it is true, TLS is not enabled immediately. Instead, the user must call :meth:`starttls` to enable TLS on the transport. Until that point, the transport is unencrypted. If it is false, the TLS handshake is started immediately. This is roughly equivalent to calling :meth:`starttls` immediately. `peer_hostname` must be either a :class:`str` or :data:`None`. It may be used by certificate validators and must be the host name this transport actually connected to. That might be (e.g. in the case of XMPP) different from the actual domain name the transport communicates with (and for which the service must have a valid certificate). This host name may be used by certificate validators implementing e.g. DANE. `server_hostname` must be either a :class:`str` or :data:`None`. It may be used by certificate validators anrd must be the host name for which the peer must have a valid certificate (if host name based certificate validation is performed). `server_hostname` is also passed via the TLS Server Name Indication (SNI) extension if it is given. If host names are to be converted to :class:`bytes` by the transport, they are encoded using the ``utf-8`` codec. If `waiter` is not :data:`None`, it must be a :class:`asyncio.Future`. After the stream has been established, the futures result is set to a value of :data:`None`. If any errors occur, the exception is set on the future. If `use_starttls` is true, the future is fulfilled immediately after construction, as there is no blocking process which needs to take place. If `use_starttls` is false and thus TLS negotiation starts right away, the future is fulfilled when TLS negotiation is complete. `post_handshake_callback` may be a coroutine or :data:`None`. 
If it is not :data:`None`, it is called asynchronously after the TLS handshake and blocks the completion of the TLS handshake until it returns. It can be used to perform blocking post-handshake certificate verification, e.g. using DANE. The coroutine must not return a value. If it encounters an error, an appropriate exception should be raised, which will propagate out of :meth:`starttls` and/or passed to the `waiter` future. """ MAX_SIZE = 256 * 1024 def __init__( self, loop: asyncio.BaseEventLoop, rawsock: socket.socket, protocol: asyncio.Protocol, ssl_context_factory: typing.Optional[SSLContextFactory] = None, waiter: typing.Optional[asyncio.Future] = None, use_starttls: bool = False, post_handshake_callback: typing.Optional[ PostHandshakeCallback ] = None, peer_hostname: typing.Optional[str] = None, server_hostname: typing.Optional[str] = None): if not use_starttls and not ssl_context_factory: raise ValueError("Cannot have STARTTLS disabled (i.e. immediate " "TLS connection) and without SSL context.") super().__init__() self._rawsock = rawsock self._raw_fd = rawsock.fileno() self._trace_logger = logger.getChild( "trace.fd={}".format(self._raw_fd) ) self._sock = rawsock # type: typing.Union[socket.socket, OpenSSL.SSL.Connection] # noqa self._send_wrap = SendWrap(self._sock) self._protocol = protocol self._loop = loop self._extra = { "socket": rawsock, } # type: typing.Dict[str, typing.Any] self._waiter = waiter self._conn_lost = 0 self._buffer = bytearray() self._ssl_context_factory = ssl_context_factory self._extra.update( sslcontext=None, ssl_object=None, peername=self._rawsock.getpeername(), peer_hostname=peer_hostname, server_hostname=server_hostname ) # this is a list set of tasks which will also be cancelled if the # _waiter is cancelled self._chained_pending = set() # type: typing.Set[asyncio.Future] self._paused = False self._closing = False self._tls_conn = None # type: typing.Optional[OpenSSL.SSL.Connection] self._tls_read_wants_write = False 
self._tls_write_wants_read = False self._tls_post_handshake_callback = post_handshake_callback self._state = None # type: typing.Optional[_State] if not use_starttls: assert ssl_context_factory is not None self._ssl_context = ssl_context_factory(self) self._extra.update( sslcontext=self._ssl_context, ) self._initiate_tls() else: self._initiate_raw() def _waiter_done(self, fut: asyncio.Future) -> None: self._trace_logger.debug("_waiter future done (%r)", fut) for chained in self._chained_pending: self._trace_logger.debug("cancelling chained %r", chained) chained.cancel() self._chained_pending.clear() def _invalid_transition( self, via: typing.Optional[str] = None, to: typing.Optional[_State] = None) -> None: via_text = (" via {}".format(via)) if via is not None else "" to_text = (" to {}".format(to)) if to is not None else "" msg = "Invalid state transition (from {}{}{})".format( self._state, via_text, to_text ) logger.error(msg) raise RuntimeError(msg) def _invalid_state( self, what: str, exc: typing.Type[Exception] = RuntimeError, ) -> Exception: msg = "{what} (invalid in state {state}, closing={closing})".format( what=what, state=self._state, closing=self._closing) logger.error(msg) # raising is optional :) return exc(msg) def _fatal_error( self, exc: BaseException, msg: str) -> None: if not isinstance(exc, (BrokenPipeError, ConnectionResetError)): self._loop.call_exception_handler({ "message": msg, "exception": exc, "transport": self, "protocol": self._protocol }) self._force_close(exc) def _force_close( self, exc: typing.Optional[BaseException], ) -> None: self._trace_logger.debug("_force_close called") self._remove_rw() if self._state == _State.CLOSED: raise self._invalid_state("_force_close called") self._state = _State.CLOSED if self._buffer: self._buffer.clear() if self._waiter is not None and not self._waiter.done(): self._waiter.set_exception( exc or ConnectionError("_force_close() called"), ) self._loop.remove_reader(self._raw_fd) 
self._loop.remove_writer(self._raw_fd) self._loop.call_soon(self._call_connection_lost_and_clean_up, exc) def _remove_rw(self) -> None: self._trace_logger.debug("clearing readers/writers") self._loop.remove_reader(self._raw_fd) self._loop.remove_writer(self._raw_fd) def _call_connection_lost_and_clean_up( self, exc: Exception, ) -> None: """ Clean up all resources and call the protocols connection lost method. """ self._state = _State.CLOSED try: self._protocol.connection_lost(exc) finally: self._rawsock.close() if self._tls_conn is not None: self._tls_conn.set_app_data(None) self._tls_conn = None self._rawsock = None # type:ignore self._protocol = None # type:ignore def _initiate_raw(self) -> None: if self._state is not None: self._invalid_transition(via="_initiate_raw", to=_State.RAW_OPEN) self._state = _State.RAW_OPEN self._loop.add_reader(self._raw_fd, self._read_ready) self._loop.call_soon(self._protocol.connection_made, self) if self._waiter is not None: self._loop.call_soon(self._waiter.set_result, None) self._waiter = None def _initiate_tls(self) -> None: self._trace_logger.debug("_initiate_tls called") if self._state is not None and self._state != _State.RAW_OPEN: self._invalid_transition(via="_initiate_tls", to=_State.TLS_HANDSHAKING) self._tls_was_starttls = (self._state == _State.RAW_OPEN) self._state = _State.TLS_HANDSHAKING self._tls_conn = OpenSSL.SSL.Connection( self._ssl_context, self._sock) self._tls_conn.set_connect_state() self._tls_conn.set_app_data(self) try: self._tls_conn.set_tlsext_host_name( self._extra["server_hostname"].encode("IDNA")) except KeyError: pass self._sock = self._tls_conn self._send_wrap = SendWrap(self._sock) self._extra.update( ssl_object=self._tls_conn ) self._tls_do_handshake() def _tls_do_handshake(self) -> None: assert self._tls_conn is not None self._trace_logger.debug("_tls_do_handshake called") if self._state != _State.TLS_HANDSHAKING: raise self._invalid_state("_tls_do_handshake called") try: 
self._tls_conn.do_handshake() except OpenSSL.SSL.WantReadError: self._trace_logger.debug( "registering reader for _tls_do_handshake") self._loop.add_reader(self._raw_fd, self._tls_do_handshake) return except OpenSSL.SSL.WantWriteError: self._trace_logger.debug( "registering writer for _tls_do_handshake") self._loop.add_writer(self._raw_fd, self._tls_do_handshake) return except Exception as exc: self._remove_rw() self._fatal_error(exc, "Fatal error on tls handshake") if self._waiter is not None: self._waiter.set_exception(exc) return except BaseException as exc: self._remove_rw() if self._waiter is not None: self._waiter.set_exception(exc) raise self._remove_rw() # handshake complete self._trace_logger.debug("handshake complete") self._extra.update( peercert=self._tls_conn.get_peer_certificate() ) if self._tls_post_handshake_callback: self._trace_logger.debug("post handshake scheduled via callback") task = asyncio.ensure_future( self._tls_post_handshake_callback(self) ) task.add_done_callback(self._tls_post_handshake_done) self._chained_pending.add(task) self._tls_post_handshake_callback = None else: self._tls_post_handshake(None) def _tls_post_handshake_done( self, task: asyncio.Future, ) -> None: self._chained_pending.discard(task) try: task.result() except asyncio.CancelledError: # canceled due to closure or something similar pass except BaseException as err: self._tls_post_handshake(err) else: self._tls_post_handshake(None) def _tls_post_handshake( self, exc: typing.Optional[BaseException], ) -> None: self._trace_logger.debug("_tls_post_handshake called") if exc is not None: if self._waiter is not None and not self._waiter.done(): self._waiter.set_exception(exc) self._fatal_error(exc, "Fatal error on post-handshake callback") return self._tls_read_wants_write = False self._tls_write_wants_read = False self._state = _State.TLS_OPEN self._loop.add_reader(self._raw_fd, self._read_ready) if not self._tls_was_starttls: 
self._loop.call_soon(self._protocol.connection_made, self) if self._waiter is not None: self._loop.call_soon(self._waiter.set_result, None) def _tls_do_shutdown(self) -> None: self._trace_logger.debug("_tls_do_shutdown called") if self._state != _State.TLS_SHUTTING_DOWN: raise self._invalid_state("_tls_do_shutdown called") assert isinstance(self._sock, OpenSSL.SSL.Connection) try: self._sock.shutdown() except OpenSSL.SSL.WantReadError: self._trace_logger.debug("registering reader for _tls_shutdown") self._loop.add_reader(self._raw_fd, self._tls_shutdown) return except OpenSSL.SSL.WantWriteError: self._trace_logger.debug("registering writer for _tls_shutdown") self._loop.add_writer(self._raw_fd, self._tls_shutdown) return except Exception as exc: # force_close will take care of removing rw handlers self._fatal_error(exc, "Fatal error on tls shutdown") return except BaseException: self._remove_rw() raise self._remove_rw() self._state = _State.TLS_SHUT_DOWN # continue to raw shut down self._raw_shutdown() def _tls_shutdown(self) -> None: self._state = _State.TLS_SHUTTING_DOWN self._tls_do_shutdown() def _raw_shutdown(self) -> None: self._remove_rw() try: self._rawsock.shutdown(socket.SHUT_RDWR) except OSError: # we cannot do anything anyway if this fails pass self._force_close(None) def _read_ready(self) -> None: assert self._state is not None if self._state.tls_started and self._tls_write_wants_read: self._tls_write_wants_read = False self._write_ready() if self._buffer: self._trace_logger.debug("_read_ready: add writer for more" " data") self._loop.add_writer(self._raw_fd, self._write_ready) if self._state.eof_received: # no further reading return try: data = self._sock.recv(self.MAX_SIZE) except (BlockingIOError, InterruptedError, OpenSSL.SSL.WantReadError): pass except OpenSSL.SSL.WantWriteError: assert self._state.tls_started self._tls_read_wants_write = True self._trace_logger.debug("_read_ready: swap reader for writer") self._loop.remove_reader(self._raw_fd) 
self._loop.add_writer(self._raw_fd, self._write_ready) except OpenSSL.SSL.SysCallError as exc: if self._state in (_State.TLS_SHUT_DOWN, _State.TLS_SHUTTING_DOWN, _State.CLOSED): self._trace_logger.debug( "_read_ready: ignoring syscall exception during shutdown: " "%s", exc, ) else: self._fatal_error(exc, "Fatal read error on STARTTLS transport") except Exception as err: self._fatal_error(err, "Fatal read error on STARTTLS transport") return else: if data: self._protocol.data_received(data) else: keep_open = False try: keep_open = bool(self._protocol.eof_received()) finally: self._eof_received(keep_open) def _write_ready(self) -> None: assert self._state is not None if self._tls_read_wants_write: self._tls_read_wants_write = False self._read_ready() if not self._paused and not self._state.eof_received: self._trace_logger.debug("_write_ready: add reader for more" " data") self._loop.add_reader(self._raw_fd, self._read_ready) # do not send data during handshake! if self._buffer and self._state != _State.TLS_HANDSHAKING: try: nsent = self._send_wrap.send(self._buffer) except (BlockingIOError, InterruptedError, OpenSSL.SSL.WantWriteError): nsent = 0 except OpenSSL.SSL.WantReadError: nsent = 0 assert self._state.tls_started self._tls_write_wants_read = True self._trace_logger.debug( "_write_ready: swap writer for reader") self._loop.remove_writer(self._raw_fd) self._loop.add_reader(self._raw_fd, self._read_ready) except OpenSSL.SSL.SysCallError as exc: if self._state in (_State.TLS_SHUT_DOWN, _State.TLS_SHUTTING_DOWN, _State.CLOSED): self._trace_logger.debug( "_write_ready: ignoring syscall exception during " "shutdown: %s", exc, ) else: self._fatal_error(exc, "Fatal write error on STARTTLS " "transport") except Exception as err: self._fatal_error(err, "Fatal write error on STARTTLS " "transport") return if nsent: del self._buffer[:nsent] if not self._buffer: if not self._tls_read_wants_write: self._trace_logger.debug("_write_ready: nothing more to write," " removing 
writer") self._loop.remove_writer(self._raw_fd) if self._closing: if self._state.tls_started: self._tls_shutdown() else: self._raw_shutdown() def _eof_received(self, keep_open: bool) -> None: assert self._state is not None self._trace_logger.debug("_eof_received: removing reader") self._loop.remove_reader(self._raw_fd) if self._state.tls_started: assert self._tls_conn is not None if self._tls_conn.get_shutdown() & OpenSSL.SSL.RECEIVED_SHUTDOWN: # proper TLS shutdown going on if keep_open: self._state = _State.TLS_EOF_RECEIVED else: self._tls_shutdown() else: if keep_open: self._trace_logger.warning( "result of eof_received() ignored as shut down is" " improper", ) self._fatal_error( ConnectionError("Underlying transport closed"), "unexpected eof_received" ) else: if keep_open: self._state = _State.RAW_EOF_RECEIVED else: self._raw_shutdown() # public API def abort(self) -> None: """ Immediately close the stream, without sending remaining buffers or performing a proper shutdown. """ if self._state == _State.CLOSED: self._invalid_state("abort() called") return self._force_close(None) def can_write_eof(self) -> bool: """ Return :data:`False`. .. note:: Writing of EOF (i.e. closing the sending direction of the stream) is theoretically possible. However, it was deemed by the author that the case is rare enough to neglect it for the sake of implementation simplicity. """ return False def close(self) -> None: """ Close the stream. This performs a proper stream shutdown, except if the stream is currently performing a TLS handshake. In that case, calling :meth:`close` is equivalent to calling :meth:`abort`. Otherwise, the transport waits until all buffers are transmitted. 
""" if self._state == _State.CLOSED: self._invalid_state("close() called") return if self._state == _State.TLS_HANDSHAKING: # hard-close self._force_close(None) elif self._state == _State.TLS_SHUTTING_DOWN: # shut down in progress, nothing to do pass elif self._buffer: # there is data to be send left, first wait for it to transmit ... self._closing = True elif self._state is not None and self._state.tls_started: # normal TLS state, nothing left to transmit, shut down self._tls_shutdown() else: # normal non-TLS state, nothing left to transmit, close self._raw_shutdown() def get_extra_info( self, name: str, default: typing.Optional[typing.Any] = None, ) -> typing.Any: """ The following extra information is available: * ``socket``: the underlying :mod:`socket` object * ``sslcontext``: the :class:`OpenSSL.SSL.Context` object to use (this may be :data:`None` until :meth:`starttls` has been called) * ``ssl_object``: :class:`OpenSSL.SSL.Connection` object (:data:`None` if TLS is not enabled (yet)) * ``peername``: return value of :meth:`socket.Socket.getpeername` * ``peer_hostname``: The `peer_hostname` value passed to the constructor. * ``server_hostname``: The `server_hostname` value passed to the constructor. """ return self._extra.get(name, default) async def starttls( self, ssl_context: typing.Optional[OpenSSL.SSL.Context] = None, post_handshake_callback: typing.Optional[ PostHandshakeCallback ] = None, ) -> None: """ Start a TLS stream on top of the socket. This is an invalid operation if the stream is not in RAW_OPEN state. If `ssl_context` is set, it overrides the `ssl_context` passed to the constructor. If `post_handshake_callback` is set, it overrides the `post_handshake_callback` passed to the constructor. .. versionchanged:: 0.4 This method is now a barrier with respect to reads and writes: before the handshake is completed (including the post handshake callback, if any), no data is received or sent. 
""" if self._state != _State.RAW_OPEN or self._closing: raise self._invalid_state("starttls() called") if ssl_context is not None: self._ssl_context = ssl_context self._extra.update( sslcontext=ssl_context ) else: assert self._ssl_context_factory is not None self._ssl_context = self._ssl_context_factory(self) if post_handshake_callback is not None: self._tls_post_handshake_callback = post_handshake_callback self._waiter = asyncio.Future() self._waiter.add_done_callback(self._waiter_done) self._initiate_tls() try: await self._waiter finally: self._waiter = None def write(self, data: typing.Union[bytes, bytearray, memoryview]) -> None: """ Write data to the transport. This is an invalid operation if the stream is not writable, that is, if it is closed. During TLS negotiation, the data is buffered. """ if not isinstance(data, (bytes, bytearray, memoryview)): raise TypeError('data argument must be byte-ish (%r)', type(data)) if (self._state is None or not self._state.is_writable or self._closing): raise self._invalid_state("write() called") if not data: return if not self._buffer: self._loop.add_writer(self._raw_fd, self._write_ready) self._buffer.extend(data) def write_eof(self) -> None: """ Writing the EOF has not been implemented, for the sake of simplicity. """ raise NotImplementedError("Cannot write_eof() on STARTTLS transport") def can_starttls(self) -> bool: """ Return :data:`True`. 
""" return True def is_closing(self) -> bool: return (self._state == _State.TLS_SHUTTING_DOWN or self._state == _State.CLOSED) async def create_starttls_connection( loop: asyncio.BaseEventLoop, protocol_factory: typing.Callable[[], asyncio.Protocol], host: typing.Optional[str] = None, port: typing.Optional[int] = None, *, sock: typing.Optional[socket.socket] = None, ssl_context_factory: typing.Optional[SSLContextFactory] = None, use_starttls: bool = False, local_addr: typing.Any = None, **kwargs # type: typing.Any ) -> typing.Tuple[asyncio.Transport, asyncio.Protocol]: """ Create a connection which can later be upgraded to use TLS. .. versionchanged:: 0.4 The `local_addr` argument was added. :param loop: The event loop to use. :type loop: :class:`asyncio.BaseEventLoop` :param protocol_factory: Factory for the protocol for the connection :param host: The host name or address to connect to :type host: :class:`str` or :data:`None` :param port: The port to connect to :type port: :class:`int` or :data:`None` :param sock: A socket to wrap (conflicts with `host` and `port`) :type sock: :class:`socket.socket` :param ssl_context_factory: Function which returns a :class:`OpenSSL.SSL.Context` to use for TLS operations :param use_starttls: Flag to control whether TLS is negotiated right away or deferredly. :type use_starttls: :class:`bool` :param local_addr: Address to bind to This is roughly a copy of the asyncio implementation of :meth:`asyncio.BaseEventLoop.create_connection`. It returns a pair ``(transport, protocol)``, where `transport` is a newly created :class:`STARTTLSTransport` instance. Further keyword arguments are forwarded to the constructor of :class:`STARTTLSTransport`. `loop` must be a :class:`asyncio.BaseEventLoop`, with support for :meth:`asyncio.BaseEventLoop.add_reader` and the corresponding writer and removal functions for sockets. This is typically a selector type event loop. 
`protocol_factory` must be a callable which (without any arguments) returns a :class:`asyncio.Protocol` which will be connected to the STARTTLS transport. `host` and `port` must be a hostname and a port number, or both :data:`None`. Both must be :data:`None`, if and only if `sock` is not :data:`None`. In that case, `sock` is used instead of a newly created socket. `sock` is put into non-blocking mode and must be a stream socket. If `use_starttls` is :data:`True`, no TLS handshake will be performed initially. Instead, the connection is established without any transport-layer security. It is expected that the :meth:`STARTTLSTransport.starttls` method is used when the application protocol requires TLS. If `use_starttls` is :data:`False`, the TLS handshake is initiated right away. `local_addr` may be an address to bind this side of the socket to. If omitted or :data:`None`, the local address is assigned by the operating system. This coroutine returns when the stream is established. If `use_starttls` is :data:`False`, this means that the full TLS handshake has to be finished for this coroutine to return. Otherwise, no TLS handshake takes place. It must be invoked using the :meth:`STARTTLSTransport.starttls` coroutine. 
""" if host is not None and port is not None: host_addrs = await loop.getaddrinfo( host, port, type=socket.SOCK_STREAM, ) exceptions = [] for family, type, proto, cname, address in host_addrs: sock = None try: sock = socket.socket(family=family, type=type, proto=proto) sock.setblocking(False) if local_addr is not None: sock.bind(local_addr) await loop.sock_connect(sock, address) except OSError as exc: if sock is not None: sock.close() exceptions.append(exc) else: break else: if len(exceptions) == 1: raise exceptions[0] model = str(exceptions[0]) if all(str(exc) == model for exc in exceptions): raise exceptions[0] try: from aioxmpp.errors import MultiOSError # type:ignore except ImportError: MultiOSError = OSError raise MultiOSError( "could not connect to [{}]:{}".format(host, port), exceptions, ) elif sock is None: raise ValueError("sock must not be None if host and/or port are None") else: sock.setblocking(False) protocol = protocol_factory() waiter = asyncio.Future(loop=loop) # type: asyncio.Future[None] transport = STARTTLSTransport(loop, sock, protocol, ssl_context_factory=ssl_context_factory, waiter=waiter, use_starttls=use_starttls, **kwargs) await waiter return transport, protocol aioopenssl/utils.py000066400000000000000000000015771400056433500147560ustar00rootroot00000000000000import typing import OpenSSL.SSL class SendWrap: def __init__(self, sock: OpenSSL.SSL.Connection): self.__sock = sock self.__cached_write = None # type: typing.Optional[typing.Tuple[bytes, typing.Any]] # noqa def send(self, buf: typing.Union[bytes, memoryview]) -> int: if self.__cached_write is not None: as_bytes, prev_buf = self.__cached_write if prev_buf is not buf: raise ValueError( "this looks like a mistake: the previous send received a " "different buffer object" ) self.__cached_write = None else: as_bytes = bytes(buf) try: return self.__sock.send(as_bytes) except (OpenSSL.SSL.WantWriteError, OpenSSL.SSL.WantReadError): self.__cached_write = as_bytes, buf raise 
aioopenssl/version.py000066400000000000000000000002761400056433500152760ustar00rootroot00000000000000version_info = (0, 6, 0, None)
__version__ = ".".join(map(str, version_info[:3])) + (
    "-"+version_info[3] if version_info[3] is not None else ""  # type:ignore
)
version = __version__
docs/000077500000000000000000000000001400056433500120065ustar00rootroot00000000000000docs/Makefile000066400000000000000000000127371400056433500134600ustar00rootroot00000000000000# Makefile for Sphinx documentation
#

# You can set these variables from the command line.
SPHINXOPTS    =
SPHINXBUILD   = sphinx-build
PAPER         =
BUILDDIR      = sphinx-data/build

# Internal variables.
PAPEROPT_a4     = -D latex_paper_size=a4
PAPEROPT_letter = -D latex_paper_size=letter
ALLSPHINXOPTS   = -d $(BUILDDIR)/doctrees $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) .
# the i18n builder cannot share the environment and doctrees with the others
I18NSPHINXOPTS  = $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) .

# Every target below is a command, not a file; texinfo and info included.
.PHONY: help clean html dirhtml singlehtml pickle json htmlhelp qthelp devhelp epub latex latexpdf text man texinfo info changes linkcheck doctest gettext

help:
	@echo "Please use \`make <target>' where <target> is one of"
	@echo "  html       to make standalone HTML files"
	@echo "  dirhtml    to make HTML files named index.html in directories"
	@echo "  singlehtml to make a single large HTML file"
	@echo "  pickle     to make pickle files"
	@echo "  json       to make JSON files"
	@echo "  htmlhelp   to make HTML files and a HTML help project"
	@echo "  qthelp     to make HTML files and a qthelp project"
	@echo "  devhelp    to make HTML files and a Devhelp project"
	@echo "  epub       to make an epub"
	@echo "  latex      to make LaTeX files, you can set PAPER=a4 or PAPER=letter"
	@echo "  latexpdf   to make LaTeX files and run them through pdflatex"
	@echo "  text       to make text files"
	@echo "  man        to make manual pages"
	@echo "  texinfo    to make Texinfo files"
	@echo "  info       to make Texinfo files and run them through makeinfo"
	@echo "  gettext    to make PO message catalogs"
	@echo "  changes    to make an overview of all changed/added/deprecated items"
	@echo "  linkcheck  to check all external links for integrity"
	@echo "  doctest    to run all doctests embedded in the documentation (if enabled)"

clean:
	-rm -rf $(BUILDDIR)/*

html:
	$(SPHINXBUILD) -b html $(ALLSPHINXOPTS) $(BUILDDIR)/html
	@echo
	@echo "Build finished. The HTML pages are in $(BUILDDIR)/html."

dirhtml:
	$(SPHINXBUILD) -b dirhtml $(ALLSPHINXOPTS) $(BUILDDIR)/dirhtml
	@echo
	@echo "Build finished. The HTML pages are in $(BUILDDIR)/dirhtml."

singlehtml:
	$(SPHINXBUILD) -b singlehtml $(ALLSPHINXOPTS) $(BUILDDIR)/singlehtml
	@echo
	@echo "Build finished. The HTML page is in $(BUILDDIR)/singlehtml."

pickle:
	$(SPHINXBUILD) -b pickle $(ALLSPHINXOPTS) $(BUILDDIR)/pickle
	@echo
	@echo "Build finished; now you can process the pickle files."

json:
	$(SPHINXBUILD) -b json $(ALLSPHINXOPTS) $(BUILDDIR)/json
	@echo
	@echo "Build finished; now you can process the JSON files."

htmlhelp:
	$(SPHINXBUILD) -b htmlhelp $(ALLSPHINXOPTS) $(BUILDDIR)/htmlhelp
	@echo
	@echo "Build finished; now you can run HTML Help Workshop with the" \
	      ".hhp project file in $(BUILDDIR)/htmlhelp."

# The qthelp and devhelp output files are named after the Sphinx project
# ("aioopenssl"), so the hints below use that name.
qthelp:
	$(SPHINXBUILD) -b qthelp $(ALLSPHINXOPTS) $(BUILDDIR)/qthelp
	@echo
	@echo "Build finished; now you can run "qcollectiongenerator" with the" \
	      ".qhcp project file in $(BUILDDIR)/qthelp, like this:"
	@echo "# qcollectiongenerator $(BUILDDIR)/qthelp/aioopenssl.qhcp"
	@echo "To view the help file:"
	@echo "# assistant -collectionFile $(BUILDDIR)/qthelp/aioopenssl.qhc"

devhelp:
	$(SPHINXBUILD) -b devhelp $(ALLSPHINXOPTS) $(BUILDDIR)/devhelp
	@echo
	@echo "Build finished."
	@echo "To view the help file:"
	@echo "# mkdir -p $$HOME/.local/share/devhelp/aioopenssl"
	@echo "# ln -s $(BUILDDIR)/devhelp $$HOME/.local/share/devhelp/aioopenssl"
	@echo "# devhelp"

epub:
	$(SPHINXBUILD) -b epub $(ALLSPHINXOPTS) $(BUILDDIR)/epub
	@echo
	@echo "Build finished. The epub file is in $(BUILDDIR)/epub."
latex:
	$(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex
	@echo
	@echo "Build finished; the LaTeX files are in $(BUILDDIR)/latex."
	@echo "Run \`make' in that directory to run these through (pdf)latex" \
	      "(use \`make latexpdf' here to do that automatically)."

latexpdf:
	$(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex
	@echo "Running LaTeX files through pdflatex..."
	$(MAKE) -C $(BUILDDIR)/latex all-pdf
	@echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex."

text:
	$(SPHINXBUILD) -b text $(ALLSPHINXOPTS) $(BUILDDIR)/text
	@echo
	@echo "Build finished. The text files are in $(BUILDDIR)/text."

man:
	$(SPHINXBUILD) -b man $(ALLSPHINXOPTS) $(BUILDDIR)/man
	@echo
	@echo "Build finished. The manual pages are in $(BUILDDIR)/man."

texinfo:
	$(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo
	@echo
	@echo "Build finished. The Texinfo files are in $(BUILDDIR)/texinfo."
	@echo "Run \`make' in that directory to run these through makeinfo" \
	      "(use \`make info' here to do that automatically)."

# Recurse via $(MAKE) (as latexpdf does) so flags such as -j/-n and the
# jobserver are passed on to the sub-make.
info:
	$(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo
	@echo "Running Texinfo files through makeinfo..."
	$(MAKE) -C $(BUILDDIR)/texinfo info
	@echo "makeinfo finished; the Info files are in $(BUILDDIR)/texinfo."

gettext:
	$(SPHINXBUILD) -b gettext $(I18NSPHINXOPTS) $(BUILDDIR)/locale
	@echo
	@echo "Build finished. The message catalogs are in $(BUILDDIR)/locale."

changes:
	$(SPHINXBUILD) -b changes $(ALLSPHINXOPTS) $(BUILDDIR)/changes
	@echo
	@echo "The overview file is in $(BUILDDIR)/changes."

linkcheck:
	$(SPHINXBUILD) -b linkcheck $(ALLSPHINXOPTS) $(BUILDDIR)/linkcheck
	@echo
	@echo "Link check complete; look for any errors in the above output " \
	      "or in $(BUILDDIR)/linkcheck/output.txt."

doctest:
	$(SPHINXBUILD) -b doctest $(ALLSPHINXOPTS) $(BUILDDIR)/doctest
	@echo "Testing of doctests in the sources finished, look at the " \
	      "results in $(BUILDDIR)/doctest/output.txt."
docs/conf.py000066400000000000000000000204351400056433500133110ustar00rootroot00000000000000#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# aioopenssl documentation build configuration file, created by
# sphinx-quickstart on Mon Dec  1 08:14:58 2014.
#
# This file is execfile()d with the current directory set to its containing dir.
#
# Note that not all possible configuration values are present in this
# autogenerated file.
#
# All configuration values have a default; values that are commented out
# serve to show the default.

import os
import sys

import alabaster

# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
sys.path.insert(0, os.path.abspath('..'))

# -- General configuration -----------------------------------------------------

# If your documentation needs a minimal Sphinx version, state it here.
#needs_sphinx = '1.0'

# Add any Sphinx extension module names here, as strings. They can be extensions
# coming with Sphinx (named 'sphinx.ext.*') or your custom ones.
extensions = [
    'sphinx.ext.autodoc',
    'sphinx.ext.intersphinx',
    'sphinx.ext.viewcode',
    'sphinx.ext.autosummary']

# Add any paths that contain templates here, relative to this directory.
templates_path = ['sphinx-data/templates']

# The suffix of source filenames.
source_suffix = '.rst'

# The encoding of source files.
#source_encoding = 'utf-8-sig'

# The master toctree document.
master_doc = 'index'

# General information about the project.
project = 'aioopenssl'
copyright = '2016, Jonas Wielicki'

# Imported only after sys.path was extended above, so that the checkout's
# package is documented rather than an installed copy.
import aioopenssl

# The version info for the project you're documenting, acts as replacement for
# |version| and |release|, also used in various other places throughout the
# built documents.
#
# The short X.Y version.
version = ".".join(map(str, aioopenssl.version_info[:2]))
# The full version, including alpha/beta/rc tags.
release = aioopenssl.__version__

# The language for content autogenerated by Sphinx. Refer to documentation
# for a list of supported languages.
#language = None

# There are two options for replacing |today|: either, you set today to some
# non-false value, then it is used:
#today = ''
# Else, today_fmt is used as the format for a strftime call.
#today_fmt = '%B %d, %Y'

# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
exclude_patterns = ['sphinx-data/build']

# The reST default role (used for this markup: `text`) to use for all documents.
#default_role = None

# If true, '()' will be appended to :func: etc. cross-reference text.
#add_function_parentheses = True

# If true, the current module name will be prepended to all description
# unit titles (such as .. function::).
#add_module_names = True

# If true, sectionauthor and moduleauthor directives will be shown in the
# output. They are ignored by default.
#show_authors = False

# The name of the Pygments (syntax highlighting) style to use.
pygments_style = 'sphinx'

# A list of ignored prefixes for module index sorting.
#modindex_common_prefix = []


# -- Options for HTML output ---------------------------------------------------

# The theme to use for HTML and HTML Help pages.  See the documentation for
# a list of builtin themes.
html_theme_path = [alabaster.get_path()]
html_theme = 'alabaster'

# Theme options are theme-specific and customize the look and feel of a theme
# further.  For a list of options available for each theme, see the
# documentation.
html_theme_options = {
    "github_button": "true",
    "github_repo": "aioopenssl",
    "github_user": "horazont",
    "font_size": "12pt",
}

html_sidebars = {
    '**': [
        'about.html',
        'localtoc.html',
        'navigation.html',
        'relations.html',
        'searchbox.html',
        'donate.html',
    ]
}

# Add any paths that contain custom themes here, relative to this directory.
#html_theme_path = []

# The name for this set of Sphinx documents.  If None, it defaults to
# "<project> v<release> documentation".
#html_title = None

# A shorter title for the navigation bar.  Default is the same as html_title.
#html_short_title = None

# The name of an image file (relative to this directory) to place at the top
# of the sidebar.
#html_logo = None

# The name of an image file (within the static path) to use as favicon of the
# docs.  This file should be a Windows icon file (.ico) being 16x16 or 32x32
# pixels large.
#html_favicon = None

# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ['sphinx-data/static']

# If not '', a 'Last updated on:' timestamp is inserted at every page bottom,
# using the given strftime format.
#html_last_updated_fmt = '%b %d, %Y'

# If true, SmartyPants will be used to convert quotes and dashes to
# typographically correct entities.
#html_use_smartypants = True

# Additional templates that should be rendered to pages, maps page names to
# template names.
#html_additional_pages = {}

# If false, no module index is generated.
#html_domain_indices = True

# If false, no index is generated.
#html_use_index = True

# If true, the index is split into individual pages for each letter.
#html_split_index = False

# If true, links to the reST sources are added to the pages.
#html_show_sourcelink = True

# If true, "Created using Sphinx" is shown in the HTML footer. Default is True.
#html_show_sphinx = True

# If true, "(C) Copyright ..." is shown in the HTML footer. Default is True.
#html_show_copyright = True

# If true, an OpenSearch description file will be output, and all pages will
# contain a <link> tag referring to it.  The value of this option must be the
# base URL from which the finished HTML is served.
#html_use_opensearch = '' # This is the file name suffix for HTML files (e.g. ".xhtml"). #html_file_suffix = None # Output file base name for HTML help builder. htmlhelp_basename = 'aioopenssldoc' # -- Options for LaTeX output -------------------------------------------------- latex_elements = { # The paper size ('letterpaper' or 'a4paper'). #'papersize': 'letterpaper', # The font size ('10pt', '11pt' or '12pt'). #'pointsize': '10pt', # Additional stuff for the LaTeX preamble. #'preamble': '', } # Grouping the document tree into LaTeX files. List of tuples # (source start file, target name, title, author, documentclass [howto/manual]). latex_documents = [ ('index', 'aioopenssl.tex', 'aioopenssl Documentation', 'Jonas Wielicki', 'manual'), ] # The name of an image file (relative to this directory) to place at the top of # the title page. #latex_logo = None # For "manual" documents, if this is true, then toplevel headings are parts, # not chapters. #latex_use_parts = False # If true, show page references after internal links. #latex_show_pagerefs = False # If true, show URL addresses after external links. #latex_show_urls = False # Documents to append as an appendix to all manuals. #latex_appendices = [] # If false, no module index is generated. #latex_domain_indices = True # -- Options for manual page output -------------------------------------------- # One entry per manual page. List of tuples # (source start file, name, description, authors, manual section). man_pages = [ ('index', 'aioopenssl', 'aioopenssl Documentation', ['Jonas Wielicki'], 1) ] # If true, show URL addresses after external links. #man_show_urls = False # -- Options for Texinfo output ------------------------------------------------ # Grouping the document tree into Texinfo files. 
List of tuples # (source start file, target name, title, author, # dir menu entry, description, category) texinfo_documents = [ ('index', 'aioopenssl', 'aioopenssl Documentation', 'Jonas Wielicki', 'aioopenssl', 'One line description of project.', 'Miscellaneous'), ] # Documents to append as an appendix to all manuals. #texinfo_appendices = [] # If false, no module index is generated. #texinfo_domain_indices = True # How to display URL addresses: 'footnote', 'no', or 'inline'. #texinfo_show_urls = 'footnote' # Example configuration for intersphinx: refer to the Python standard library. intersphinx_mapping = { 'https://docs.python.org/3/': None, 'https://pyopenssl.readthedocs.org/en/latest/': None, } docs/index.rst000066400000000000000000000001671400056433500136530ustar00rootroot00000000000000.. automodule:: aioopenssl Indices and tables ################## * :ref:`genindex` * :ref:`modindex` * :ref:`search` docs/sphinx-data/000077500000000000000000000000001400056433500142265ustar00rootroot00000000000000docs/sphinx-data/.gitignore000066400000000000000000000000061400056433500162120ustar00rootroot00000000000000build mypy.ini000066400000000000000000000010701400056433500125530ustar00rootroot00000000000000[mypy] python_version = 3.7 #warn_return_any = True warn_unused_configs = True disallow_untyped_calls = True disallow_untyped_defs = True disallow_incomplete_defs = True #check_untyped_defs = True disallow_untyped_decorators = True #disallow_any_unimported = True #disallow_any_expr = True #disallow_any_decorated = True disallow_any_explicit = False #disallow_any_generics = True disallow_subclassing_any = True no_implicit_optional = True warn_redundant_casts = True warn_unused_ignores = True warn_unreachable = True [mypy-OpenSSL.*] ignore_missing_imports = True setup.py000066400000000000000000000024271400056433500125750ustar00rootroot00000000000000#!/usr/bin/env python3 import os.path import runpy from setuptools import setup here = 
os.path.abspath(os.path.dirname(__file__)) with open(os.path.join(here, "README.rst"), encoding="utf-8") as f: long_description = f.read() version_mod = runpy.run_path("aioopenssl/version.py") setup( name="aioopenssl", version=version_mod["__version__"], description="TLS-capable transport using OpenSSL for asyncio", long_description=long_description, url="https://github.com/horazont/aioopenssl", author="Jonas Wielicki", author_email="jonas@wielicki.name", license="Apache 2.0", classifiers=[ "Development Status :: 4 - Beta", "Intended Audience :: Developers", "Operating System :: POSIX", "License :: OSI Approved :: Apache Software License", "Programming Language :: Python :: 3 :: Only", "Programming Language :: Python :: 3.5", "Programming Language :: Python :: 3.6", "Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Topic :: Communications :: Chat", ], keywords="openssl asyncio library transport starttls", install_requires=[ "PyOpenSSL", ], packages=["aioopenssl"], ) tests/000077500000000000000000000000001400056433500122205ustar00rootroot00000000000000tests/__init__.py000066400000000000000000000000001400056433500143170ustar00rootroot00000000000000tests/ssl.pem000066400000000000000000000060631400056433500135310ustar00rootroot00000000000000-----BEGIN CERTIFICATE----- MIID6jCCAtKgAwIBAgIJAPA6oph8Ud/CMA0GCSqGSIb3DQEBCwUAMG0xCzAJBgNV BAYTAkdCMRMwEQYDVQQKDApQcm9zb2R5IElNMSswKQYDVQQLDCJodHRwOi8vcHJv c29keS5pbS9kb2MvY2VydGlmaWNhdGVzMRwwGgYDVQQDDBNFeGFtcGxlIGNlcnRp ZmljYXRlMB4XDTE2MTAxMzE0NDAyMFoXDTE3MTAxMzE0NDAyMFowbTELMAkGA1UE BhMCR0IxEzARBgNVBAoMClByb3NvZHkgSU0xKzApBgNVBAsMImh0dHA6Ly9wcm9z b2R5LmltL2RvYy9jZXJ0aWZpY2F0ZXMxHDAaBgNVBAMME0V4YW1wbGUgY2VydGlm aWNhdGUwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQCvxVfO9zVx3che YBGkem+wlkttxh3G2yMoyVGP1laIeNUG2iRZC+jC3nRZar9FoAp8vFC7uhSWUXtf ErhQCJTGhPv2a1unSgDE8sYMKneHx5ETNkbamH73nf3+8VYHCImrJwzHQkfEIQh+ 
UEZOHIMeaQw0xqLdBp9hc3AqrMZZ5w+hfMtdE0n/s4jANHmX+d0LKQKy/sDu0b87 3p0Rt92fte6NSwxT/k7TOW5cc7enZBIsV/R1dsyBG/2uzYOfiMBkcIPJMU4EhlQE bCIwSwLw0dGnC0edgi/dZu2Wxh+2bmQlHGVBHVHiC52MrGizwmvPrzdJik8VfZ0S c63f5imJAgMBAAGjgYwwgYkwDAYDVR0TBAUwAwEB/zB5BgNVHREEcjBwgglsb2Nh bGhvc3SgJAYIKwYBBQUHCAegGBYWX3htcHAtY2xpZW50LmxvY2FsaG9zdKAkBggr BgEFBQcIB6AYFhZfeG1wcC1zZXJ2ZXIubG9jYWxob3N0oBcGCCsGAQUFBwgFoAsM CWxvY2FsaG9zdDANBgkqhkiG9w0BAQsFAAOCAQEAViCNra0Otkonr7jOJ1r21RzI r8mna8ZXUB9aCHk9A/2bBLqAjoJ93ccVETgCOJdF2HTt8s0T+Hy1xdMc1yvqn7MQ 3MUL1DpudseT2hgCb/n6Tq6Wy0qNv9pCkVXRZ/71IgVRl+bjpt9l9kMxZGkm4ihC BiqsLTnXRNiaFggnhk8ayDVVRn+XrO7506Zaj4kONhUQRz83UWAnCcQLQE4V7glv H6xHc3xQIczKLB7SG+xvQqqdd6qK9EcN0WPVczocloiQPmN+VHZLnLmFp5XiJpJS KJFkY+5TkKg70QdXbkVjVikieb7W4Qg+n0VMxI6nMEpFAP8Zsr1HV+hOxkegvA== -----END CERTIFICATE----- -----BEGIN PRIVATE KEY----- MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQCvxVfO9zVx3che YBGkem+wlkttxh3G2yMoyVGP1laIeNUG2iRZC+jC3nRZar9FoAp8vFC7uhSWUXtf ErhQCJTGhPv2a1unSgDE8sYMKneHx5ETNkbamH73nf3+8VYHCImrJwzHQkfEIQh+ UEZOHIMeaQw0xqLdBp9hc3AqrMZZ5w+hfMtdE0n/s4jANHmX+d0LKQKy/sDu0b87 3p0Rt92fte6NSwxT/k7TOW5cc7enZBIsV/R1dsyBG/2uzYOfiMBkcIPJMU4EhlQE bCIwSwLw0dGnC0edgi/dZu2Wxh+2bmQlHGVBHVHiC52MrGizwmvPrzdJik8VfZ0S c63f5imJAgMBAAECggEAF2aJcydcOSWSqGXX03LtbbAEqL+odTH1c1eiASlC6fZU Pg5KqoQ4X8En8kOQ5V8SJlsX0HZMiVqdtyGR4i3SSL+pn7vZPuNOSOodTb9VnIEI ImumcTG+LY8eIpPUpSkQ0vVm8Rw027qeG9rqEToghqrrkhcj1ZMtItcwhq3r1koB LphD08jhU27WEaal8o+P+xHSWeyMBbA7mYiHHSS29JXfHwRUueWCOm/V3E9HYyvK 6pBdezs85fDiT7xPECA7RGEfPjWs/IQB/xrKfqSF9hgcw3dBvYsf90vZmlEtNgFr fudwrsyYBTyreG/WQGuFoySXxVjbAO2+BuztJRnBVQKBgQDfg5p3KoLJf1COh0Dd m7jQkufdvpYh0HJKAlf0L+hcPR8p1x6OtILNZRQDRSeR3AXkwe1X72cfLjzauWDz 9aPG0xAX/Scg4/RCaCStvQbt3jlJoUGbzANyLbsgr4uzUTor8tJCSVeUkqoOZ4wX 1d4im5n4BRQARo1a/KHgruDUowKBgQDJUVXF83XqcUz0/EhjHc4iMV9aa+NbPv4Y miRgtoIhX2VWZJAy6gfthiUsxDqj1HzWDqEvLoQZJYRRnzs+svtDMh/aSRBkDERM 2LpOpXfDbYz5pKsmhzUrDpg0HrBq1bzJe8LjBt/5wehyvWWZXFuksszHB6Y2HA3V 
zCVNysy/4wKBgDHbrmoknnYKI1MX0p1cbjaAfp5VNDIoyEXADhSXVzK0I652oQde NstQX8128KO1u87Sf3odGhi3fLWhooHo6naggDeJrd/FWagyiPQEdXY8GvVUtkjl kmM21kYtQnFmjh5dlQ9aQuIOcUazTGnIuDtqEEdmApcpJcEFF6sB938XAoGBAKTI EWgc1liWcsJYObv/anlsZjsx8f/++KcUjFApMyyz09O6LpmuG90cnxMcb08oHyJr CR4AehnHLp5Msoaoo6elwJLAmUz7CXDJOP4kzHnHEsxIL9sycp+Sq5yFvitEVemp hqSjPPNG98frJN05zr1pqNoEddT2c5CbL7GUHLG3AoGBAMfXr2oIjqjURqeo5Fo5 9a3PpyqwMQe1P56H9L7oMGkVPmQfCyBFuQ4rqX/debCJnIOYF5j2Fq1MqS9siC62 n23IG45CtKPeMz6HDeNSbIniJY8GrCS/L3ihllEnqz76kBAWn5+OXou4Psjro2W/ 7G0FfD08tedhC9HX7GrB5Af3 -----END PRIVATE KEY----- tests/test_e2e.py000066400000000000000000000560141400056433500143120ustar00rootroot00000000000000import asyncio import functools import logging import os import pathlib import ssl import socket import threading import unittest import unittest.mock import OpenSSL.SSL import aioopenssl PORT = int(os.environ.get("AIOOPENSSL_TEST_PORT", "12345")) KEYFILE = pathlib.Path(__file__).parent / "ssl.pem" def blocking(meth): @functools.wraps(meth) def wrapper(*args, **kwargs): loop = asyncio.get_event_loop() return loop.run_until_complete( asyncio.wait_for(meth(*args, **kwargs), 1) ) return wrapper class TestSSLConnection(unittest.TestCase): TRY_PORTS = list(range(10000, 10010)) @blocking async def setUp(self): self.loop = asyncio.get_event_loop() self.server = None self.server_ctx = ssl.create_default_context( ssl.Purpose.CLIENT_AUTH ) self.server_ctx.load_cert_chain(str(KEYFILE)) await self._replace_server() self.inbound_queue = asyncio.Queue() async def _shutdown_server(self): self.server.close() while not self.inbound_queue.empty(): reader, writer = await self.inbound_queue.get() writer.close() await self.server.wait_closed() self.server = None async def _replace_server(self): if self.server is not None: await self._shutdown_server() self.server = await asyncio.start_server( self._server_accept, host="127.0.0.1", port=PORT, ssl=self.server_ctx, ) @blocking async def tearDown(self): await self._shutdown_server() 
def _server_accept(self, reader, writer): self.inbound_queue.put_nowait( (reader, writer) ) def _stream_reader_proto(self): reader = asyncio.StreamReader(loop=self.loop) proto = asyncio.StreamReaderProtocol(reader) proto.aioopenssl_test_reader = reader return proto async def _connect(self, *args, **kwargs): transport, reader_proto = \ await aioopenssl.create_starttls_connection( asyncio.get_event_loop(), self._stream_reader_proto, *args, **kwargs ) reader = reader_proto.aioopenssl_test_reader del reader_proto.aioopenssl_test_reader writer = asyncio.StreamWriter(transport, reader_proto, reader, self.loop) return transport, reader, writer @blocking async def test_send_and_receive_data(self): _, c_reader, c_writer = await self._connect( host="127.0.0.1", port=PORT, ssl_context_factory=lambda transport: OpenSSL.SSL.Context( OpenSSL.SSL.SSLv23_METHOD ), server_hostname="localhost", use_starttls=False, ) s_reader, s_writer = await self.inbound_queue.get() c_writer.write(b"foobar") s_writer.write(b"fnord") await asyncio.gather(s_writer.drain(), c_writer.drain()) c_read, s_read = await asyncio.gather( c_reader.readexactly(5), s_reader.readexactly(6), ) self.assertEqual( s_read, b"foobar" ) self.assertEqual( c_read, b"fnord" ) @blocking async def test_send_large_data(self): _, c_reader, c_writer = await self._connect( host="127.0.0.1", port=PORT, ssl_context_factory=lambda transport: OpenSSL.SSL.Context( OpenSSL.SSL.SSLv23_METHOD ), server_hostname="localhost", use_starttls=False, ) s_reader, s_writer = await self.inbound_queue.get() data = bytearray(2**17) c_writer.write(data) s_writer.write(b"foobar") await asyncio.gather(s_writer.drain(), c_writer.drain()) c_read, s_read = await asyncio.gather( c_reader.readexactly(6), s_reader.readexactly(len(data)), ) self.assertEqual( s_read, data, ) self.assertEqual( c_read, b"foobar", ) @blocking async def test_recv_large_data(self): _, c_reader, c_writer = await self._connect( host="127.0.0.1", port=PORT, ssl_context_factory=lambda 
transport: OpenSSL.SSL.Context( OpenSSL.SSL.SSLv23_METHOD ), server_hostname="localhost", use_starttls=False, ) s_reader, s_writer = await self.inbound_queue.get() data = bytearray(2**17) s_writer.write(data) c_writer.write(b"foobar") await asyncio.gather(s_writer.drain(), c_writer.drain()) c_read, s_read = await asyncio.gather( c_reader.readexactly(len(data)), s_reader.readexactly(6), ) self.assertEqual( c_read, data, ) self.assertEqual( s_read, b"foobar", ) @blocking async def test_send_recv_large_data(self): _, c_reader, c_writer = await self._connect( host="127.0.0.1", port=PORT, ssl_context_factory=lambda transport: OpenSSL.SSL.Context( OpenSSL.SSL.SSLv23_METHOD ), server_hostname="localhost", use_starttls=False, ) s_reader, s_writer = await self.inbound_queue.get() data1 = bytearray(2**17) data2 = bytearray(2**17) s_writer.write(data1) c_writer.write(data2) await asyncio.gather(s_writer.drain(), c_writer.drain()) c_read, s_read = await asyncio.gather( c_reader.readexactly(len(data1)), s_reader.readexactly(len(data2)), ) self.assertEqual( c_read, data1, ) self.assertEqual( s_read, data2, ) @blocking async def test_abort(self): c_transport, c_reader, c_writer = await self._connect( host="127.0.0.1", port=PORT, ssl_context_factory=lambda transport: OpenSSL.SSL.Context( OpenSSL.SSL.SSLv23_METHOD ), server_hostname="localhost", use_starttls=False, ) s_reader, s_writer = await self.inbound_queue.get() c_transport.abort() with self.assertRaises(ConnectionError): await asyncio.gather(c_writer.drain()) @blocking async def test_local_addr(self): last_exc = None used_port = None for port in self.TRY_PORTS: try: c_transport, c_reader, c_writer = await self._connect( host="127.0.0.1", port=PORT, ssl_context_factory=lambda transport: OpenSSL.SSL.Context( OpenSSL.SSL.SSLv23_METHOD ), server_hostname="localhost", use_starttls=False, local_addr=("127.0.0.1", port) ) except OSError as exc: last_exc = exc continue used_port = port break else: raise last_exc s_reader, s_writer = 
await self.inbound_queue.get() sock = s_writer.transport.get_extra_info("socket") peer_addr = sock.getpeername() self.assertEqual(peer_addr, ("127.0.0.1", used_port)) @blocking async def test_starttls(self): c_transport, c_reader, c_writer = await self._connect( host="127.0.0.1", port=PORT, ssl_context_factory=lambda transport: OpenSSL.SSL.Context( OpenSSL.SSL.SSLv23_METHOD ), server_hostname="localhost", use_starttls=True, ) await c_transport.starttls() s_reader, s_writer = await self.inbound_queue.get() c_writer.write(b"foobar") s_writer.write(b"fnord") await asyncio.gather(s_writer.drain(), c_writer.drain()) c_read, s_read = await asyncio.gather( c_reader.readexactly(5), s_reader.readexactly(6), ) self.assertEqual( s_read, b"foobar" ) self.assertEqual( c_read, b"fnord" ) @blocking async def test_renegotiation(self): self.server_ctx = ssl.create_default_context( ssl.Purpose.CLIENT_AUTH ) if hasattr(ssl, "OP_NO_TLSv1_3"): # Need to forbid TLS v1.3, since TLSv1.3+ does not support # renegotiation self.server_ctx.options |= ssl.OP_NO_TLSv1_3 self.server_ctx.load_cert_chain(str(KEYFILE)) await self._replace_server() def factory(_): ctx = OpenSSL.SSL.Context(OpenSSL.SSL.SSLv23_METHOD) if hasattr(OpenSSL.SSL, "OP_NO_TLSv1_3"): # Need to forbid TLS v1.3, since TLSv1.3+ does not support # renegotiation ctx.set_options(OpenSSL.SSL.OP_NO_TLSv1_3) return ctx c_transport, c_reader, c_writer = await self._connect( host="127.0.0.1", port=PORT, ssl_context_factory=factory, server_hostname="localhost", use_starttls=False, ) s_reader, s_writer = await self.inbound_queue.get() ssl_sock = c_transport.get_extra_info("ssl_object") c_writer.write(b"foobar") s_writer.write(b"fnord") await asyncio.gather(s_writer.drain(), c_writer.drain()) c_read, s_read = await asyncio.gather( c_reader.readexactly(5), s_reader.readexactly(6), ) self.assertEqual( s_read, b"foobar" ) self.assertEqual( c_read, b"fnord" ) ssl_sock.renegotiate() @blocking async def 
test_post_handshake_exception_is_propagated(self): class FooException(Exception): pass async def post_handshake_callback(transport): raise FooException() c_transport, c_reader, c_writer = await self._connect( host="127.0.0.1", port=PORT, ssl_context_factory=lambda transport: OpenSSL.SSL.Context( OpenSSL.SSL.SSLv23_METHOD ), server_hostname="localhost", use_starttls=True, post_handshake_callback=post_handshake_callback, ) with self.assertRaises(FooException): await c_transport.starttls() @blocking async def test_no_data_is_sent_if_handshake_crashes(self): class FooException(Exception): pass async def post_handshake_callback(transport): await asyncio.sleep(0.5) raise FooException() c_transport, c_reader, c_writer = await self._connect( host="127.0.0.1", port=PORT, ssl_context_factory=lambda transport: OpenSSL.SSL.Context( OpenSSL.SSL.SSLv23_METHOD ), server_hostname="localhost", use_starttls=True, post_handshake_callback=post_handshake_callback, ) starttls_task = asyncio.ensure_future(c_transport.starttls()) # ensure that handshake is in progress... 
await asyncio.sleep(0.2) c_transport.write(b"foobar") with self.assertRaises(FooException): await starttls_task s_reader, s_writer = await self.inbound_queue.get() with self.assertRaises(Exception) as ctx: await asyncio.wait_for( s_reader.readexactly(6), timeout=0.1, ) exc = ctx.exception # using type(None) as default, since that will always be false in the # isinstance check below incomplete_read_exc_type = getattr( asyncio.streams, "IncompleteReadError", getattr(asyncio, "IncompleteReadError", type(None)) ) if isinstance(exc, incomplete_read_exc_type): self.assertFalse(exc.partial) elif not isinstance(exc, ConnectionResetError): raise exc @blocking async def test_no_data_is_received_if_handshake_crashes(self): class FooException(Exception): pass async def post_handshake_callback(transport): await asyncio.sleep(0.5) raise FooException() c_transport, c_reader, c_writer = await self._connect( host="127.0.0.1", port=PORT, ssl_context_factory=lambda transport: OpenSSL.SSL.Context( OpenSSL.SSL.SSLv23_METHOD ), server_hostname="localhost", use_starttls=True, post_handshake_callback=post_handshake_callback, ) starttls_task = asyncio.ensure_future(c_transport.starttls()) s_reader, s_writer = await self.inbound_queue.get() self.assertFalse(starttls_task.done()) s_writer.write(b"fnord") with self.assertRaises(FooException): await c_reader.readexactly(5) with self.assertRaises(FooException): await starttls_task @blocking async def test_data_is_sent_after_handshake(self): async def post_handshake_callback(transport): await asyncio.sleep(0.5) c_transport, c_reader, c_writer = await self._connect( host="127.0.0.1", port=PORT, ssl_context_factory=lambda transport: OpenSSL.SSL.Context( OpenSSL.SSL.SSLv23_METHOD ), server_hostname="localhost", use_starttls=True, post_handshake_callback=post_handshake_callback, ) starttls_task = asyncio.ensure_future(c_transport.starttls()) # ensure that handshake is in progress... 
await asyncio.sleep(0.2) c_transport.write(b"foobar") await starttls_task s_reader, s_writer = await self.inbound_queue.get() s_recv = await asyncio.wait_for( s_reader.readexactly(6), timeout=0.1, ) self.assertEqual(s_recv, b"foobar") @blocking async def test_no_data_is_received_after_handshake(self): async def post_handshake_callback(transport): await asyncio.sleep(0.5) c_transport, c_reader, c_writer = await self._connect( host="127.0.0.1", port=PORT, ssl_context_factory=lambda transport: OpenSSL.SSL.Context( OpenSSL.SSL.SSLv23_METHOD ), server_hostname="localhost", use_starttls=True, post_handshake_callback=post_handshake_callback, ) starttls_task = asyncio.ensure_future(c_transport.starttls()) s_reader, s_writer = await self.inbound_queue.get() self.assertFalse(starttls_task.done()) s_writer.write(b"fnord") with self.assertRaises(asyncio.TimeoutError): await asyncio.wait_for( c_reader.readexactly(5), timeout=0.1, ) await starttls_task c_recv = await c_reader.readexactly(5) self.assertEqual(c_recv, b"fnord") @blocking async def test_close_during_handshake(self): cancelled = None async def post_handshake_callback(transport): nonlocal cancelled try: await asyncio.sleep(0.5) cancelled = False except asyncio.CancelledError: cancelled = True c_transport, c_reader, c_writer = await self._connect( host="127.0.0.1", port=PORT, ssl_context_factory=lambda transport: OpenSSL.SSL.Context( OpenSSL.SSL.SSLv23_METHOD ), server_hostname="localhost", use_starttls=True, post_handshake_callback=post_handshake_callback, ) starttls_task = asyncio.ensure_future(c_transport.starttls()) # ensure that handshake is in progress... 
await asyncio.sleep(0.2) c_transport.close() with self.assertRaises(ConnectionError): await starttls_task self.assertTrue(cancelled) class ServerThread(threading.Thread): def __init__(self, ctx, port, loop, queue): super().__init__() self._logger = logging.getLogger("ServerThread") self._ctx = ctx self._socket = socket.socket( socket.AF_INET, socket.SOCK_STREAM, 0, ) self._socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) self._socket.bind(("127.0.0.1", port)) self._socket.settimeout(0.5) self._socket.listen(0) self._loop = loop self._queue = queue self.stopped = False def _push(self, arg): self._loop.call_soon_threadsafe( self._queue.put_nowait, arg, ) def run(self): self._logger.info("ready") while not self.stopped: try: client, addr = self._socket.accept() except socket.timeout: self._logger.debug("no connection yet, cycling") continue self._logger.debug("connection accepted from %s", addr) try: wrapped = OpenSSL.SSL.Connection(self._ctx, client) wrapped.set_accept_state() wrapped.do_handshake() except Exception as exc: try: wrapped.close() except: # NOQA pass try: client.shutdown(socket.SHUT_RDWR) client.close() except: # NOQA pass self._push((False, exc)) else: self._push((True, wrapped)) self._logger.info("shutting down") self._socket.shutdown(socket.SHUT_RDWR) self._socket.close() class TestSSLConnectionThreadServer(unittest.TestCase): TRY_PORTS = list(range(10000, 10010)) @blocking async def setUp(self): self.loop = asyncio.get_event_loop() self.thread = None ctx = OpenSSL.SSL.Context(OpenSSL.SSL.SSLv23_METHOD) ctx.use_certificate_chain_file(str(KEYFILE)) ctx.use_privatekey_file(str(KEYFILE)) self._replace_thread(ctx) def _replace_thread(self, ctx): if self.thread is not None: self.thread.stopped = True self.thread.join() self.thread = None self.inbound_queue = asyncio.Queue() self.thread = ServerThread( ctx, PORT+1, self.loop, self.inbound_queue, ) self.thread.start() @blocking async def tearDown(self): self.thread.stopped = True 
self.thread.join() async def _get_inbound(self): ok, data = await self.inbound_queue.get() if not ok: raise data return data async def recv_thread(self, sock, *argv): return await self.loop.run_in_executor( None, sock.recv, *argv ) async def send_thread(self, sock, *argv): return await self.loop.run_in_executor( None, sock.send, *argv ) def _stream_reader_proto(self): reader = asyncio.StreamReader(loop=self.loop) proto = asyncio.StreamReaderProtocol(reader) proto.aioopenssl_test_reader = reader return proto async def _connect(self, *args, **kwargs): transport, reader_proto = \ await aioopenssl.create_starttls_connection( asyncio.get_event_loop(), self._stream_reader_proto, *args, **kwargs ) reader = reader_proto.aioopenssl_test_reader del reader_proto.aioopenssl_test_reader writer = asyncio.StreamWriter(transport, reader_proto, reader, self.loop) return transport, reader, writer @blocking async def test_connect_send_recv_close(self): c_transport, c_reader, c_writer = await self._connect( host="127.0.0.1", port=PORT+1, ssl_context_factory=lambda transport: OpenSSL.SSL.Context( OpenSSL.SSL.SSLv23_METHOD ), server_hostname="localhost", use_starttls=False, ) sock = await self._get_inbound() c_writer.write(b"foobar") await self.send_thread(sock, b"fnord") await asyncio.gather(c_writer.drain()) c_read, s_read = await asyncio.gather( c_reader.readexactly(5), self.recv_thread(sock, 6) ) self.assertEqual( s_read, b"foobar" ) self.assertEqual( c_read, b"fnord" ) c_transport.close() await asyncio.sleep(0.1) sock.close() @blocking async def test_renegotiate(self): ctx = OpenSSL.SSL.Context(OpenSSL.SSL.SSLv23_METHOD) # Need to forbid TLS v1.3, since TLSv1.3+ does not support # renegotiation ctx.set_options( getattr(OpenSSL.SSL, "OP_NO_TLSv1_3", getattr(ssl, "OP_NO_TLSv1_3", 0)) ) ctx.use_certificate_chain_file(str(KEYFILE)) ctx.use_privatekey_file(str(KEYFILE)) self._replace_thread(ctx) c_transport, c_reader, c_writer = await self._connect( host="127.0.0.1", port=PORT+1, 
ssl_context_factory=lambda transport: OpenSSL.SSL.Context( OpenSSL.SSL.SSLv23_METHOD ), server_hostname="localhost", use_starttls=False, ) sock = await self._get_inbound() c_writer.write(b"foobar") await self.send_thread(sock, b"fnord") await asyncio.gather(c_writer.drain()) c_read, s_read = await asyncio.gather( c_reader.readexactly(5), self.recv_thread(sock, 6) ) self.assertEqual( s_read, b"foobar" ) self.assertEqual( c_read, b"fnord" ) try: sock.renegotiate() except OpenSSL.SSL.Error as exc: (argv,), = exc.args if (argv[1] == "SSL_renegotiate" and argv[2] == "wrong ssl version"): raise RuntimeError( "You are a PyOpenSSL version which uses TLSv1.3, but has" " no way to turn it off. Update PyOpenSSL." ) raise c_writer.write(b"baz") await asyncio.gather( c_writer.drain(), self.loop.run_in_executor(None, sock.do_handshake) ) s_read, = await asyncio.gather( self.recv_thread(sock, 6) ) self.assertEqual(s_read, b"baz") c_transport.close() await asyncio.sleep(0.1) sock.close() tests/test_utils.py000066400000000000000000000261771400056433500150060ustar00rootroot00000000000000import contextlib import unittest import unittest.mock import OpenSSL.SSL import aioopenssl.utils as utils class TestSendWrap(unittest.TestCase): def setUp(self): self.default_send = lambda x: len(x) self.sock = unittest.mock.Mock(["send"]) self.sock.send.side_effect = self.default_send self.ww = utils.SendWrap(self.sock) def test_send_calls_send(self): data = unittest.mock.sentinel.data self.sock.send.side_effect = None self.sock.send.return_value = unittest.mock.sentinel.send_result with contextlib.ExitStack() as stack: bytes_ = stack.enter_context( unittest.mock.patch("aioopenssl.utils.bytes", create=True) ) result = self.ww.send(data) bytes_.assert_called_once_with(data) self.sock.send.assert_called_once_with(bytes_()) self.assertEqual( result, unittest.mock.sentinel.send_result, ) def test_send_propagates_exceptions_from_send(self): data = bytearray() self.sock.send.side_effect = 
OpenSSL.SSL.Error with self.assertRaises(OpenSSL.SSL.Error): self.ww.send(data) def test_send_propagates_want_read_from_send(self): data = bytearray() self.sock.send.side_effect = OpenSSL.SSL.WantReadError with self.assertRaises(OpenSSL.SSL.WantReadError): self.ww.send(data) def test_send_propagates_want_send_from_send(self): data = bytearray() self.sock.send.side_effect = OpenSSL.SSL.WantWriteError with self.assertRaises(OpenSSL.SSL.WantWriteError): self.ww.send(data) def test_send_after_want_send_passes_cached_bytes(self): data = unittest.mock.sentinel.data with contextlib.ExitStack() as stack: bytes_ = stack.enter_context( unittest.mock.patch("aioopenssl.utils.bytes", create=True) ) self.sock.send.side_effect = OpenSSL.SSL.WantWriteError with self.assertRaises(OpenSSL.SSL.WantWriteError): self.ww.send(data) bytes_.assert_called_once_with(data) self.sock.send.assert_called_once_with(bytes_()) bytes_.reset_mock() self.sock.send.reset_mock() self.sock.send.side_effect = None self.sock.send.return_value = unittest.mock.sentinel.send_result result = self.ww.send(data) bytes_.assert_not_called() self.sock.send.assert_called_once_with(bytes_()) self.assertEqual(result, unittest.mock.sentinel.send_result) def test_send_after_want_send_works_normally(self): data = unittest.mock.sentinel.data data2 = unittest.mock.sentinel.data2 with contextlib.ExitStack() as stack: bytes_ = stack.enter_context( unittest.mock.patch("aioopenssl.utils.bytes", create=True) ) self.sock.send.side_effect = OpenSSL.SSL.WantWriteError with self.assertRaises(OpenSSL.SSL.WantWriteError): self.ww.send(data) bytes_.assert_called_once_with(data) self.sock.send.assert_called_once_with(bytes_()) bytes_.reset_mock() self.sock.send.reset_mock() self.sock.send.side_effect = None self.sock.send.return_value = unittest.mock.sentinel.send_result1 result1 = self.ww.send(data) bytes_.assert_not_called() self.sock.send.assert_called_once_with(bytes_()) bytes_.reset_mock() bytes_.return_value = 
unittest.mock.sentinel.new_bytes self.sock.send.reset_mock() self.sock.send.return_value = unittest.mock.sentinel.send_result2 result2 = self.ww.send(data2) bytes_.assert_called_once_with(data2) self.sock.send.assert_called_once_with( unittest.mock.sentinel.new_bytes ) bytes_.reset_mock() self.sock.send.reset_mock() self.assertEqual(result1, unittest.mock.sentinel.send_result1) self.assertEqual(result2, unittest.mock.sentinel.send_result2) def test_send_after_want_send_rejects_subsequent_call_if_different_buffer(self): # NOQA data = unittest.mock.sentinel.data data2 = unittest.mock.sentinel.data2 with contextlib.ExitStack() as stack: bytes_ = stack.enter_context( unittest.mock.patch("aioopenssl.utils.bytes", create=True) ) self.sock.send.side_effect = OpenSSL.SSL.WantWriteError with self.assertRaises(OpenSSL.SSL.WantWriteError): self.ww.send(data) bytes_.assert_called_once_with(data) self.sock.send.assert_called_once_with(bytes_()) bytes_.reset_mock() self.sock.send.reset_mock() self.sock.send.side_effect = None self.sock.send.return_value = unittest.mock.sentinel.send_result1 with self.assertRaisesRegex( ValueError, "this looks like a mistake: the previous send received a " "different buffer object"): self.ww.send(data2) bytes_.assert_not_called() self.sock.send.assert_not_called() def test_send_after_want_read_passes_cached_bytes(self): data = unittest.mock.sentinel.data with contextlib.ExitStack() as stack: bytes_ = stack.enter_context( unittest.mock.patch("aioopenssl.utils.bytes", create=True) ) self.sock.send.side_effect = OpenSSL.SSL.WantReadError with self.assertRaises(OpenSSL.SSL.WantReadError): self.ww.send(data) bytes_.assert_called_once_with(data) self.sock.send.assert_called_once_with(bytes_()) bytes_.reset_mock() self.sock.send.reset_mock() self.sock.send.side_effect = None self.sock.send.return_value = unittest.mock.sentinel.send_result result = self.ww.send(data) bytes_.assert_not_called() self.sock.send.assert_called_once_with(bytes_()) 
self.assertEqual(result, unittest.mock.sentinel.send_result) def test_send_after_want_read_works_normally(self): data = unittest.mock.sentinel.data data2 = unittest.mock.sentinel.data2 with contextlib.ExitStack() as stack: bytes_ = stack.enter_context( unittest.mock.patch("aioopenssl.utils.bytes", create=True) ) self.sock.send.side_effect = OpenSSL.SSL.WantReadError with self.assertRaises(OpenSSL.SSL.WantReadError): self.ww.send(data) bytes_.assert_called_once_with(data) self.sock.send.assert_called_once_with(bytes_()) bytes_.reset_mock() self.sock.send.reset_mock() self.sock.send.side_effect = None self.sock.send.return_value = unittest.mock.sentinel.send_result1 result1 = self.ww.send(data) bytes_.assert_not_called() self.sock.send.assert_called_once_with(bytes_()) bytes_.reset_mock() bytes_.return_value = unittest.mock.sentinel.new_bytes self.sock.send.reset_mock() self.sock.send.return_value = unittest.mock.sentinel.send_result2 result2 = self.ww.send(data2) bytes_.assert_called_once_with(data2) self.sock.send.assert_called_once_with( unittest.mock.sentinel.new_bytes ) bytes_.reset_mock() self.sock.send.reset_mock() self.assertEqual(result1, unittest.mock.sentinel.send_result1) self.assertEqual(result2, unittest.mock.sentinel.send_result2) def test_send_after_want_read_rejects_subsequent_call_if_different_buffer(self): # NOQA data = unittest.mock.sentinel.data data2 = unittest.mock.sentinel.data2 with contextlib.ExitStack() as stack: bytes_ = stack.enter_context( unittest.mock.patch("aioopenssl.utils.bytes", create=True) ) self.sock.send.side_effect = OpenSSL.SSL.WantReadError with self.assertRaises(OpenSSL.SSL.WantReadError): self.ww.send(data) bytes_.assert_called_once_with(data) self.sock.send.assert_called_once_with(bytes_()) bytes_.reset_mock() self.sock.send.reset_mock() self.sock.send.side_effect = None self.sock.send.return_value = unittest.mock.sentinel.send_result1 with self.assertRaisesRegex( ValueError, "this looks like a mistake: the previous send 
received a " "different buffer object"): self.ww.send(data2) bytes_.assert_not_called() self.sock.send.assert_not_called() def test_send_with_several_want_read_send_errors(self): data = unittest.mock.sentinel.data data2 = unittest.mock.sentinel.data2 with contextlib.ExitStack() as stack: bytes_ = stack.enter_context( unittest.mock.patch("aioopenssl.utils.bytes", create=True) ) self.sock.send.side_effect = OpenSSL.SSL.WantReadError with self.assertRaises(OpenSSL.SSL.WantReadError): self.ww.send(data) bytes_.assert_called_once_with(data) self.sock.send.assert_called_once_with(bytes_()) bytes_.reset_mock() self.sock.send.reset_mock() self.sock.send.side_effect = OpenSSL.SSL.WantWriteError with self.assertRaises(OpenSSL.SSL.WantWriteError): self.ww.send(data) bytes_.assert_not_called() self.sock.send.assert_called_once_with(bytes_()) bytes_.reset_mock() self.sock.send.reset_mock() self.sock.send.side_effect = OpenSSL.SSL.WantWriteError with self.assertRaises(OpenSSL.SSL.WantWriteError): self.ww.send(data) bytes_.assert_not_called() self.sock.send.assert_called_once_with(bytes_()) bytes_.reset_mock() self.sock.send.reset_mock() self.sock.send.side_effect = None self.sock.send.return_value = unittest.mock.sentinel.send_result1 result1 = self.ww.send(data) bytes_.assert_not_called() self.sock.send.assert_called_once_with(bytes_()) bytes_.reset_mock() bytes_.return_value = unittest.mock.sentinel.new_bytes self.sock.send.reset_mock() self.sock.send.return_value = unittest.mock.sentinel.send_result2 result2 = self.ww.send(data2) bytes_.assert_called_once_with(data2) self.sock.send.assert_called_once_with( unittest.mock.sentinel.new_bytes ) bytes_.reset_mock() self.sock.send.reset_mock() self.assertEqual(result1, unittest.mock.sentinel.send_result1) self.assertEqual(result2, unittest.mock.sentinel.send_result2)