pax_global_header00006660000000000000000000000064147546777740014546gustar00rootroot0000000000000052 comment=c5724ce0b082f23f1ab77acfcdaa2870f53e7ec5 asyncssh-2.20.0/000077500000000000000000000000001475467777400134625ustar00rootroot00000000000000asyncssh-2.20.0/.coveragerc000066400000000000000000000003411475467777400156010ustar00rootroot00000000000000[run] branch = True relative_files = True source = asyncssh tests [report] exclude_lines = if TYPE_CHECKING: pragma: no cover raise NotImplementedError partial_branches = pragma: no branch for .* asyncssh-2.20.0/.github/000077500000000000000000000000001475467777400150225ustar00rootroot00000000000000asyncssh-2.20.0/.github/workflows/000077500000000000000000000000001475467777400170575ustar00rootroot00000000000000asyncssh-2.20.0/.github/workflows/run_tests.yml000066400000000000000000000116341475467777400216350ustar00rootroot00000000000000name: Run tests on: [push, pull_request] jobs: run-tests: name: Run tests strategy: fail-fast: false matrix: os: [ubuntu-latest, macos-latest, windows-latest] python-version: ["3.8", "3.9", "3.10", "3.11", "3.12", "3.13"] include: - os: macos-latest python-version: "3.10" openssl-version: "3" - os: macos-latest python-version: "3.11" openssl-version: "3" - os: macos-latest python-version: "3.12" openssl-version: "3" - os: macos-latest python-version: "3.13" openssl-version: "3" runs-on: ${{ matrix.os }} env: liboqs_version: '0.10.1' nettle_version: nettle_3.8.1_release_20220727 steps: - name: Checkout asyncssh uses: actions/checkout@v4 with: path: asyncssh - name: Checkout liboqs if: ${{ runner.os != 'macOS' }} uses: actions/checkout@v4 with: repository: open-quantum-safe/liboqs ref: ${{ env.liboqs_version }} path: liboqs - uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} cache: pip cache-dependency-path: | asyncssh/setup.py asyncssh/tox.ini - name: Set up ccache for liboqs (Linux) uses: hendrikmuhs/ccache-action@v1.2 if: ${{ runner.os == 'Linux' }} with: key: liboqs-cache-${{ matrix.os }} - name: Install Linux dependencies if: ${{ runner.os == 'Linux' }} run: | sudo apt update sudo apt install -y --no-install-recommends libnettle8 libsodium-dev libssl-dev libkrb5-dev ssh cmake ninja-build - name: Install macOS dependencies if: ${{ runner.os == 'macOS' }} run: brew install nettle liboqs libsodium openssl - name: Provide OpenSSL 3 if: ${{ runner.os == 'macOS' && matrix.openssl-version == '3' }} run: echo "/usr/local/opt/openssl@3/bin" >> $GITHUB_PATH - name: Install nettle (Windows) if: ${{ runner.os == 'Windows' }} shell: pwsh run: | curl -fLO https://github.com/ShiftMediaProject/nettle/releases/download/${{ env.nettle_version }}/libnettle_${{ env.nettle_version }}_msvc17.zip Expand-Archive libnettle_${{ env.nettle_version }}_msvc17.zip nettle cp nettle\bin\x64\*.dll "$env:Python_ROOT_DIR" - name: Install liboqs (Linux) if: ${{ runner.os == 'Linux' }} working-directory: liboqs run: | cmake -GNinja -Bbuild . -DCMAKE_INSTALL_PREFIX=/usr -DBUILD_SHARED_LIBS=ON -DOQS_BUILD_ONLY_LIB=ON -DOQS_DIST_BUILD=ON -DCMAKE_C_COMPILER_LAUNCHER=ccache cmake --build build sudo cmake --install build - name: Initialize MSVC environment uses: ilammy/msvc-dev-cmd@v1 - name: Install liboqs (Windows) if: ${{ runner.os == 'Windows' }} shell: pwsh working-directory: liboqs run: | cmake -GNinja -Bbuild . 
-DBUILD_SHARED_LIBS=ON -DOQS_BUILD_ONLY_LIB=ON -DOQS_DIST_BUILD=ON cmake --build build cp build\bin\oqs.dll "$env:Python_ROOT_DIR" - name: Install Python dependencies run: pip install tox - name: Run tests shell: python working-directory: asyncssh run: | import os, sys, platform, subprocess V = sys.version_info p = platform.system().lower() subprocess.run( ['tox', 'run', '-e', f'py{V.major}{V.minor}-{p}', '--', '-ra'], check=True) - name: Upload coverage data uses: actions/upload-artifact@v4 with: name: coverage-${{ matrix.os }}-${{ matrix.python-version }} path: asyncssh/.coverage.* include-hidden-files: true retention-days: 1 merge-coverage: runs-on: ubuntu-latest needs: run-tests if: ${{ always() }} steps: - name: Merge coverage uses: actions/upload-artifact/merge@v4 with: name: coverage pattern: coverage-* include-hidden-files: true report-coverage: name: Report coverage runs-on: ubuntu-latest needs: merge-coverage if: ${{ always() }} steps: - uses: actions/checkout@v4 - uses: actions/setup-python@v5 - uses: actions/download-artifact@v4 with: name: coverage - name: Install dependencies run: | sudo apt install -y sqlite3 pip install tox - name: Report coverage run: | shopt -s nullglob for f in .coverage.*-windows; do sqlite3 "$f" "update file set path = replace(path, '\\', '/');" done tox -e report - uses: codecov/codecov-action@v4 with: files: coverage.xml token: ${{ secrets.CODECOV_TOKEN }} asyncssh-2.20.0/.gitignore000066400000000000000000000001421475467777400154470ustar00rootroot00000000000000.*.swp MANIFEST __pycache__/ *.py[cod] asyncssh.egg-info build/ dist/ docs/Makefile docs/_build/ asyncssh-2.20.0/.readthedocs.yaml000066400000000000000000000002501475467777400167060ustar00rootroot00000000000000version: 2 build: os: ubuntu-22.04 tools: python: "3.11" python: install: - requirements: docs/requirements.txt sphinx: configuration: docs/conf.py asyncssh-2.20.0/CONTRIBUTING.rst000066400000000000000000000072371475467777400161340ustar00rootroot00000000000000Contributing to AsyncSSH ======================== Input on AsyncSSH is extremely welcome. Below are some recommendations of the best ways to contribute. Asking questions ---------------- If you have a general question about how to use AsyncSSH, you are welcome to post it to the end-user mailing list at `asyncssh-users@googlegroups.com `_. If you have a question related to the development of AsyncSSH, you can post it to the development mailing list at `asyncssh-dev@googlegroups.com `_. You are also welcome to use the AsyncSSH `issue tracker `_ to ask questions. Reporting bugs -------------- Please use the `issue tracker `_ to report any bugs you find. Before creating a new issue, please check the currently open issues to see if your problem has already been reported. If you create a new issue, please include the version of AsyncSSH you are using, information about the OS you are running on and the installed version of Python and any other libraries that are involved. Please also include detailed information about how to reproduce the problem, including any traceback information you were able to collect or other relevant output. If you have sample code which exhibits the problem, feel free to include that as well. If possible, please test against the latest version of AsyncSSH. Also, if you are testing code in something other than the master branch, it would be helpful to know if you also see the problem in master. 
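As an illustration (not an official requirement), a small snippet along these lines can be used to gather the AsyncSSH, Python, and OS version details requested above so they can be pasted into a bug report:

.. code:: python

    # Collect environment details for a bug report. This is only a
    # convenience example; apart from asyncssh.__version__, none of the
    # names used here are AsyncSSH APIs.
    import platform

    import asyncssh
    import cryptography

    print('AsyncSSH:    ', asyncssh.__version__)
    print('Python:      ', platform.python_version())
    print('OS:          ', platform.platform())
    print('cryptography:', cryptography.__version__)

Running this in the same environment where the problem occurs makes it easy to include accurate version information with the issue.
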
Requesting feature enhancements ------------------------------- The `issue tracker `_ should also be used to post feature enhancement requests. While I can't make any promises about what features will be added in the future, suggestions are always welcome! Contributing code ----------------- Before submitting a pull request, please create an issue on the `issue tracker `_ explaining what functionality you'd like to contribute and how it could be used. Discussing the approach you'd like to take up front will make it far more likely I'll be able to accept your changes, or explain what issues might prevent that before you spend a lot of effort. If you find a typo or other small bug in the code, you're welcome to submit a patch without filing an issue first, but for anything larger than a few lines I strongly recommend coordinating up front. Any code you submit will need to be provided with a compatible license. AsyncSSH code is currently released under the `Eclipse Public License v2.0 `_. Before submitting a pull request, make sure to indicate that you are ok with releasing your code under this license and how you'd like to be listed in the contributors list. Branches -------- There are two long-lived branches in AsyncSSH: * The master branch is intended to contain the latest stable version of the code. All official versions of AsyncSSH are released from this branch, and each release has a corresponding tag added matching its release number. * The develop branch is intended to contain new features and bug fixes ready to be tested before being added to an official release. APIs in the develop branch may be subject to change until they are migrated back to master, and there's no guarantee of backward compatibility in this branch. However, pulling from this branch will provide early access to new functionality and a chance to influence this functionality before it is released. Also, all pull requests should be submitted against this branch. asyncssh-2.20.0/COPYRIGHT000066400000000000000000000011431475467777400147540ustar00rootroot00000000000000Copyright (c) 2013-2018 by Ron Frederick and others. This program and the accompanying materials are made available under the terms of the Eclipse Public License v2.0 which accompanies this distribution and is available at: http://www.eclipse.org/legal/epl-2.0/ This program may also be made available under the following secondary licenses when the conditions for such availability set forth in the Eclipse Public License v2.0 are satisfied: GNU General Public License, Version 2.0, or any later versions of that license SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later asyncssh-2.20.0/LICENSE000066400000000000000000000335661475467777400145040ustar00rootroot00000000000000Eclipse Public License - v 2.0 THE ACCOMPANYING PROGRAM IS PROVIDED UNDER THE TERMS OF THIS ECLIPSE PUBLIC LICENSE ("AGREEMENT"). ANY USE, REPRODUCTION OR DISTRIBUTION OF THE PROGRAM CONSTITUTES RECIPIENT'S ACCEPTANCE OF THIS AGREEMENT. 1. DEFINITIONS "Contribution" means: a) in the case of the initial Contributor, the initial content Distributed under this Agreement, and b) in the case of each subsequent Contributor: i) changes to the Program, and ii) additions to the Program; where such changes and/or additions to the Program originate from and are Distributed by that particular Contributor. A Contribution "originates" from a Contributor if it was added to the Program by such Contributor itself or anyone acting on such Contributor's behalf. 
Contributions do not include changes or additions to the Program that are not Modified Works. "Contributor" means any person or entity that Distributes the Program. "Licensed Patents" mean patent claims licensable by a Contributor which are necessarily infringed by the use or sale of its Contribution alone or when combined with the Program. "Program" means the Contributions Distributed in accordance with this Agreement. "Recipient" means anyone who receives the Program under this Agreement or any Secondary License (as applicable), including Contributors. "Derivative Works" shall mean any work, whether in Source Code or other form, that is based on (or derived from) the Program and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. "Modified Works" shall mean any work in Source Code or other form that results from an addition to, deletion from, or modification of the contents of the Program, including, for purposes of clarity any new file in Source Code form that contains any contents of the Program. Modified Works shall not include works that contain only declarations, interfaces, types, classes, structures, or files of the Program solely in each case in order to link to, bind by name, or subclass the Program or Modified Works thereof. "Distribute" means the acts of a) distributing or b) making available in any manner that enables the transfer of a copy. "Source Code" means the form of a Program preferred for making modifications, including but not limited to software source code, documentation source, and configuration files. "Secondary License" means either the GNU General Public License, Version 2.0, or any later versions of that license, including any exceptions or additional permissions as identified by the initial Contributor. 2. GRANT OF RIGHTS a) Subject to the terms of this Agreement, each Contributor hereby grants Recipient a non-exclusive, worldwide, royalty-free copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, Distribute and sublicense the Contribution of such Contributor, if any, and such Derivative Works. b) Subject to the terms of this Agreement, each Contributor hereby grants Recipient a non-exclusive, worldwide, royalty-free patent license under Licensed Patents to make, use, sell, offer to sell, import and otherwise transfer the Contribution of such Contributor, if any, in Source Code or other form. This patent license shall apply to the combination of the Contribution and the Program if, at the time the Contribution is added by the Contributor, such addition of the Contribution causes such combination to be covered by the Licensed Patents. The patent license shall not apply to any other combinations which include the Contribution. No hardware per se is licensed hereunder. c) Recipient understands that although each Contributor grants the licenses to its Contributions set forth herein, no assurances are provided by any Contributor that the Program does not infringe the patent or other intellectual property rights of any other entity. Each Contributor disclaims any liability to Recipient for claims brought by any other entity based on infringement of intellectual property rights or otherwise. As a condition to exercising the rights and licenses granted hereunder, each Recipient hereby assumes sole responsibility to secure any other intellectual property rights needed, if any. 
For example, if a third party patent license is required to allow Recipient to Distribute the Program, it is Recipient's responsibility to acquire that license before distributing the Program. d) Each Contributor represents that to its knowledge it has sufficient copyright rights in its Contribution, if any, to grant the copyright license set forth in this Agreement. e) Notwithstanding the terms of any Secondary License, no Contributor makes additional grants to any Recipient (other than those set forth in this Agreement) as a result of such Recipient's receipt of the Program under the terms of a Secondary License (if permitted under the terms of Section 3). 3. REQUIREMENTS 3.1 If a Contributor Distributes the Program in any form, then: a) the Program must also be made available as Source Code, in accordance with section 3.2, and the Contributor must accompany the Program with a statement that the Source Code for the Program is available under this Agreement, and informs Recipients how to obtain it in a reasonable manner on or through a medium customarily used for software exchange; and b) the Contributor may Distribute the Program under a license different than this Agreement, provided that such license: i) effectively disclaims on behalf of all other Contributors all warranties and conditions, express and implied, including warranties or conditions of title and non-infringement, and implied warranties or conditions of merchantability and fitness for a particular purpose; ii) effectively excludes on behalf of all other Contributors all liability for damages, including direct, indirect, special, incidental and consequential damages, such as lost profits; iii) does not attempt to limit or alter the recipients' rights in the Source Code under section 3.2; and iv) requires any subsequent distribution of the Program by any party to be under a license that satisfies the requirements of this section 3. 3.2 When the Program is Distributed as Source Code: a) it must be made available under this Agreement, or if the Program (i) is combined with other material in a separate file or files made available under a Secondary License, and (ii) the initial Contributor attached to the Source Code the notice described in Exhibit A of this Agreement, then the Program may be made available under the terms of such Secondary Licenses, and b) a copy of this Agreement must be included with each copy of the Program. 3.3 Contributors may not remove or alter any copyright, patent, trademark, attribution notices, disclaimers of warranty, or limitations of liability ("notices") contained within the Program from any copy of the Program which they Distribute, provided that Contributors may add their own appropriate notices. 4. COMMERCIAL DISTRIBUTION Commercial distributors of software may accept certain responsibilities with respect to end users, business partners and the like. While this license is intended to facilitate the commercial use of the Program, the Contributor who includes the Program in a commercial product offering should do so in a manner which does not create potential liability for other Contributors. 
Therefore, if a Contributor includes the Program in a commercial product offering, such Contributor ("Commercial Contributor") hereby agrees to defend and indemnify every other Contributor ("Indemnified Contributor") against any losses, damages and costs (collectively "Losses") arising from claims, lawsuits and other legal actions brought by a third party against the Indemnified Contributor to the extent caused by the acts or omissions of such Commercial Contributor in connection with its distribution of the Program in a commercial product offering. The obligations in this section do not apply to any claims or Losses relating to any actual or alleged intellectual property infringement. In order to qualify, an Indemnified Contributor must: a) promptly notify the Commercial Contributor in writing of such claim, and b) allow the Commercial Contributor to control, and cooperate with the Commercial Contributor in, the defense and any related settlement negotiations. The Indemnified Contributor may participate in any such claim at its own expense. For example, a Contributor might include the Program in a commercial product offering, Product X. That Contributor is then a Commercial Contributor. If that Commercial Contributor then makes performance claims, or offers warranties related to Product X, those performance claims and warranties are such Commercial Contributor's responsibility alone. Under this section, the Commercial Contributor would have to defend claims against the other Contributors related to those performance claims and warranties, and if a court requires any other Contributor to pay any damages as a result, the Commercial Contributor must pay those damages. 5. NO WARRANTY EXCEPT AS EXPRESSLY SET FORTH IN THIS AGREEMENT, AND TO THE EXTENT PERMITTED BY APPLICABLE LAW, THE PROGRAM IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER EXPRESS OR IMPLIED INCLUDING, WITHOUT LIMITATION, ANY WARRANTIES OR CONDITIONS OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY OR FITNESS FOR A PARTICULAR PURPOSE. Each Recipient is solely responsible for determining the appropriateness of using and distributing the Program and assumes all risks associated with its exercise of rights under this Agreement, including but not limited to the risks and costs of program errors, compliance with applicable laws, damage to or loss of data, programs or equipment, and unavailability or interruption of operations. 6. DISCLAIMER OF LIABILITY EXCEPT AS EXPRESSLY SET FORTH IN THIS AGREEMENT, AND TO THE EXTENT PERMITTED BY APPLICABLE LAW, NEITHER RECIPIENT NOR ANY CONTRIBUTORS SHALL HAVE ANY LIABILITY FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING WITHOUT LIMITATION LOST PROFITS), HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OR DISTRIBUTION OF THE PROGRAM OR THE EXERCISE OF ANY RIGHTS GRANTED HEREUNDER, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. 7. GENERAL If any provision of this Agreement is invalid or unenforceable under applicable law, it shall not affect the validity or enforceability of the remainder of the terms of this Agreement, and without further action by the parties hereto, such provision shall be reformed to the minimum extent necessary to make such provision valid and enforceable. 
If Recipient institutes patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Program itself (excluding combinations of the Program with other software or hardware) infringes such Recipient's patent(s), then such Recipient's rights granted under Section 2(b) shall terminate as of the date such litigation is filed. All Recipient's rights under this Agreement shall terminate if it fails to comply with any of the material terms or conditions of this Agreement and does not cure such failure in a reasonable period of time after becoming aware of such noncompliance. If all Recipient's rights under this Agreement terminate, Recipient agrees to cease use and distribution of the Program as soon as reasonably practicable. However, Recipient's obligations under this Agreement and any licenses granted by Recipient relating to the Program shall continue and survive. Everyone is permitted to copy and distribute copies of this Agreement, but in order to avoid inconsistency the Agreement is copyrighted and may only be modified in the following manner. The Agreement Steward reserves the right to publish new versions (including revisions) of this Agreement from time to time. No one other than the Agreement Steward has the right to modify this Agreement. The Eclipse Foundation is the initial Agreement Steward. The Eclipse Foundation may assign the responsibility to serve as the Agreement Steward to a suitable separate entity. Each new version of the Agreement will be given a distinguishing version number. The Program (including Contributions) may always be Distributed subject to the version of the Agreement under which it was received. In addition, after a new version of the Agreement is published, Contributor may elect to Distribute the Program (including its Contributions) under the new version. Except as expressly stated in Sections 2(a) and 2(b) above, Recipient receives no rights or licenses to the intellectual property of any Contributor under this Agreement, whether expressly, by implication, estoppel or otherwise. All rights in the Program not expressly granted under this Agreement are reserved. Nothing in this Agreement is intended to be enforceable by any entity that is not a Contributor or Recipient. No third-party beneficiary rights are created under this Agreement. Exhibit A - Form of Secondary Licenses Notice "This Source Code may also be made available under the following Secondary Licenses when the conditions for such availability set forth in the Eclipse Public License, v. 2.0 are satisfied: {name license(s), version(s), and exceptions or additional permissions here}." Simply including a copy of this Agreement, including this Exhibit A is not sufficient to license the Source Code under Secondary Licenses. If it is not possible or desirable to put the notice in a particular file, then You may include the notice in a location (such as a LICENSE file in a relevant directory) where a recipient would be likely to look for such a notice. You may add additional accurate notices of copyright ownership.asyncssh-2.20.0/MANIFEST.in000066400000000000000000000001501475467777400152140ustar00rootroot00000000000000include CONTRIBUTING.rst COPYRIGHT LICENSE README.rst pylintrc tox.ini include examples/*.py tests/*.py asyncssh-2.20.0/README.rst000066400000000000000000000175741475467777400151670ustar00rootroot00000000000000.. 
image:: https://readthedocs.org/projects/asyncssh/badge/?version=latest :target: https://asyncssh.readthedocs.io/en/latest/?badge=latest :alt: Documentation Status .. image:: https://img.shields.io/pypi/v/asyncssh.svg :target: https://pypi.python.org/pypi/asyncssh/ :alt: AsyncSSH PyPI Project AsyncSSH: Asynchronous SSH for Python ===================================== AsyncSSH is a Python package which provides an asynchronous client and server implementation of the SSHv2 protocol on top of the Python 3.6+ asyncio framework. .. code:: python import asyncio, asyncssh, sys async def run_client(): async with asyncssh.connect('localhost') as conn: result = await conn.run('echo "Hello!"', check=True) print(result.stdout, end='') try: asyncio.get_event_loop().run_until_complete(run_client()) except (OSError, asyncssh.Error) as exc: sys.exit('SSH connection failed: ' + str(exc)) Check out the `examples`__ to get started! __ http://asyncssh.readthedocs.io/en/stable/#client-examples Features -------- * Full support for SSHv2, SFTP, and SCP client and server functions * Shell, command, and subsystem channels * Environment variables, terminal type, and window size * Direct and forwarded TCP/IP channels * OpenSSH-compatible direct and forwarded UNIX domain socket channels * OpenSSH-compatible TUN/TAP channels and packet forwarding * Local and remote TCP/IP port forwarding * Local and remote UNIX domain socket forwarding * Dynamic TCP/IP port forwarding via SOCKS * X11 forwarding support on both the client and the server * SFTP protocol version 3 with OpenSSH extensions * Experimental support for SFTP versions 4-6, when requested * SCP protocol support, including third-party remote to remote copies * Multiple simultaneous sessions on a single SSH connection * Multiple SSH connections in a single event loop * Byte and string based I/O with settable encoding * A variety of `key exchange`__, `encryption`__, and `MAC`__ algorithms * Including post-quantum kex algorithms ML-KEM and SNTRUP * Support for `gzip compression`__ * Including OpenSSH variant to delay compression until after auth * User and host-based public key, password, and keyboard-interactive authentication methods * Many types and formats of `public keys and certificates`__ * Including OpenSSH-compatible support for U2F and FIDO2 security keys * Including PKCS#11 support for accessing PIV security tokens * Including support for X.509 certificates as defined in RFC 6187 * Support for accessing keys managed by `ssh-agent`__ on UNIX systems * Including agent forwarding support on both the client and the server * Support for accessing keys managed by PuTTY's Pageant agent on Windows * Support for accessing host keys via OpenSSH's ssh-keysign * OpenSSH-style `known_hosts file`__ support * OpenSSH-style `authorized_keys file`__ support * Partial support for `OpenSSH-style configuration files`__ * Compatibility with OpenSSH "Encrypt then MAC" option for better security * Time and byte-count based session key renegotiation * Designed to be easy to extend to support new forms of key exchange, authentication, encryption, and compression algorithms __ http://asyncssh.readthedocs.io/en/stable/api.html#key-exchange-algorithms __ http://asyncssh.readthedocs.io/en/stable/api.html#encryption-algorithms __ http://asyncssh.readthedocs.io/en/stable/api.html#mac-algorithms __ http://asyncssh.readthedocs.io/en/stable/api.html#compression-algorithms __ http://asyncssh.readthedocs.io/en/stable/api.html#public-key-support __ 
http://asyncssh.readthedocs.io/en/stable/api.html#ssh-agent-support __ http://asyncssh.readthedocs.io/en/stable/api.html#known-hosts __ http://asyncssh.readthedocs.io/en/stable/api.html#authorized-keys __ http://asyncssh.readthedocs.io/en/stable/api.html#config-file-support License ------- This package is released under the following terms: Copyright (c) 2013-2024 by Ron Frederick and others. This program and the accompanying materials are made available under the terms of the Eclipse Public License v2.0 which accompanies this distribution and is available at: http://www.eclipse.org/legal/epl-2.0/ This program may also be made available under the following secondary licenses when the conditions for such availability set forth in the Eclipse Public License v2.0 are satisfied: GNU General Public License, Version 2.0, or any later versions of that license SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later For more information about this license, please see the `Eclipse Public License FAQ `_. Prerequisites ------------- To use AsyncSSH 2.0 or later, you need the following: * Python 3.6 or later * cryptography (PyCA) 3.1 or later Installation ------------ Install AsyncSSH by running: :: pip install asyncssh Optional Extras ^^^^^^^^^^^^^^^ There are some optional modules you can install to enable additional functionality: * Install bcrypt from https://pypi.python.org/pypi/bcrypt if you want support for OpenSSH private key encryption. * Install fido2 from https://pypi.org/project/fido2 if you want support for key exchange and authentication with U2F/FIDO2 security keys. * Install python-pkcs11 from https://pypi.org/project/python-pkcs11 if you want support for accessing PIV keys on PKCS#11 security tokens. * Install gssapi from https://pypi.python.org/pypi/gssapi if you want support for GSSAPI key exchange and authentication on UNIX. * Install liboqs from https://github.com/open-quantum-safe/liboqs if you want support for the OpenSSH post-quantum key exchange algorithms based on ML-KEM and SNTRUP. * Install libsodium from https://github.com/jedisct1/libsodium and libnacl from https://pypi.python.org/pypi/libnacl if you have a version of OpenSSL older than 1.1.1b installed and you want support for Curve25519 key exchange, Ed25519 keys and certificates, or the Chacha20-Poly1305 cipher. * Install libnettle from http://www.lysator.liu.se/~nisse/nettle/ if you want support for UMAC cryptographic hashes. * Install pyOpenSSL from https://pypi.python.org/pypi/pyOpenSSL if you want support for X.509 certificate authentication. * Install pywin32 from https://pypi.python.org/pypi/pywin32 if you want support for using the Pageant agent or support for GSSAPI key exchange and authentication on Windows. AsyncSSH defines the following optional PyPI extra packages to make it easy to install any or all of these dependencies: | bcrypt | fido2 | gssapi | libnacl | pkcs11 | pyOpenSSL | pywin32 For example, to install bcrypt, fido2, gssapi, libnacl, pkcs11, and pyOpenSSL on UNIX, you can run: :: pip install 'asyncssh[bcrypt,fido2,gssapi,libnacl,pkcs11,pyOpenSSL]' To install bcrypt, fido2, libnacl, pkcs11, pyOpenSSL, and pywin32 on Windows, you can run: :: pip install 'asyncssh[bcrypt,fido2,libnacl,pkcs11,pyOpenSSL,pywin32]' Note that you will still need to manually install the libsodium library listed above for libnacl to work correctly and/or libnettle for UMAC support. Unfortunately, since liboqs, libsodium, and libnettle are not Python packages, they cannot be directly installed using pip. 
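As a quick, purely illustrative sanity check (not part of AsyncSSH itself), a snippet like the following can report which of the optional Python modules above are importable in the current environment. The module names used here are assumed to be the usual import names of the packages listed above (libnacl, python-pkcs11, and pyOpenSSL import as ``libnacl``, ``pkcs11``, and ``OpenSSL``):

.. code:: python

    # Report which optional AsyncSSH dependencies can be imported.
    import importlib

    for name in ('bcrypt', 'fido2', 'gssapi', 'libnacl', 'pkcs11', 'OpenSSL'):
        try:
            importlib.import_module(name)
            print(name, 'is available')
        except (ImportError, OSError):
            # libnacl can fail with OSError when the underlying libsodium
            # shared library can't be found, so treat that as unavailable too
            print(name, 'is not available')

This also gives an easy way to confirm that the manual installation of the libsodium library mentioned above actually succeeded.
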
Installing the development branch ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ If you would like to install the development branch of asyncssh directly from Github, you can use the following command to do this: :: pip install git+https://github.com/ronf/asyncssh@develop Mailing Lists ------------- Three mailing lists are available for AsyncSSH: * `asyncssh-announce@googlegroups.com`__: Project announcements * `asyncssh-dev@googlegroups.com`__: Development discussions * `asyncssh-users@googlegroups.com`__: End-user discussions __ http://groups.google.com/d/forum/asyncssh-announce __ http://groups.google.com/d/forum/asyncssh-dev __ http://groups.google.com/d/forum/asyncssh-users asyncssh-2.20.0/asyncssh/000077500000000000000000000000001475467777400153155ustar00rootroot00000000000000asyncssh-2.20.0/asyncssh/__init__.py000066400000000000000000000200631475467777400174270ustar00rootroot00000000000000# Copyright (c) 2013-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation """An asynchronous SSH2 library for Python""" from .version import __author__, __author_email__, __url__, __version__ # pylint: disable=wildcard-import from .constants import * # pylint: enable=wildcard-import from .agent import SSHAgentClient, SSHAgentKeyPair, connect_agent from .auth_keys import SSHAuthorizedKeys from .auth_keys import import_authorized_keys, read_authorized_keys from .channel import SSHClientChannel, SSHServerChannel from .channel import SSHTCPChannel, SSHUNIXChannel, SSHTunTapChannel from .client import SSHClient from .config import ConfigParseError from .forward import SSHForwarder from .connection import SSHAcceptor, SSHClientConnection, SSHServerConnection from .connection import SSHClientConnectionOptions, SSHServerConnectionOptions from .connection import SSHAcceptHandler from .connection import create_connection, create_server, connect, listen from .connection import connect_reverse, listen_reverse, get_server_host_key from .connection import get_server_auth_methods, run_client, run_server from .editor import SSHLineEditorChannel from .known_hosts import SSHKnownHosts from .known_hosts import import_known_hosts, read_known_hosts from .known_hosts import match_known_hosts from .listener import SSHListener from .logging import logger, set_log_level, set_sftp_log_level, set_debug_level from .misc import BytesOrStr from .misc import Error, DisconnectError, ChannelOpenError, ChannelListenError from .misc import ConnectionLost, CompressionError, HostKeyNotVerifiable from .misc import KeyExchangeFailed, IllegalUserName, MACError from .misc import PermissionDenied, ProtocolError, ProtocolNotSupported from .misc import ServiceNotAvailable, PasswordChangeRequired from .misc import BreakReceived, SignalReceived, TerminalSizeChanged from .pbe import KeyEncryptionError from .pkcs11 import load_pkcs11_keys from .process import SSHServerProcessFactory from .process import SSHClientProcess, SSHServerProcess from .process 
import SSHCompletedProcess, ProcessError from .process import TimeoutError # pylint: disable=redefined-builtin from .process import DEVNULL, PIPE, STDOUT from .public_key import SSHKey, SSHKeyPair, SSHCertificate from .public_key import KeyGenerationError, KeyImportError, KeyExportError from .public_key import generate_private_key, import_private_key from .public_key import import_public_key, import_certificate from .public_key import read_private_key, read_public_key, read_certificate from .public_key import read_private_key_list, read_public_key_list from .public_key import read_certificate_list from .public_key import load_keypairs, load_public_keys, load_certificates from .public_key import load_resident_keys from .rsa import set_default_skip_rsa_key_validation from .scp import scp from .session import DataType, SSHClientSession, SSHServerSession from .session import SSHTCPSession, SSHUNIXSession, SSHTunTapSession from .server import SSHServer from .sftp import SFTPClient, SFTPClientFile, SFTPServer, SFTPError from .sftp import SFTPEOFError, SFTPNoSuchFile, SFTPPermissionDenied from .sftp import SFTPFailure, SFTPBadMessage, SFTPNoConnection from .sftp import SFTPInvalidHandle, SFTPNoSuchPath, SFTPFileAlreadyExists from .sftp import SFTPWriteProtect, SFTPNoMedia, SFTPNoSpaceOnFilesystem from .sftp import SFTPQuotaExceeded, SFTPUnknownPrincipal, SFTPLockConflict from .sftp import SFTPDirNotEmpty, SFTPNotADirectory, SFTPInvalidFilename from .sftp import SFTPLinkLoop, SFTPCannotDelete, SFTPInvalidParameter from .sftp import SFTPFileIsADirectory, SFTPByteRangeLockConflict from .sftp import SFTPByteRangeLockRefused, SFTPDeletePending from .sftp import SFTPFileCorrupt, SFTPOwnerInvalid, SFTPGroupInvalid from .sftp import SFTPNoMatchingByteRangeLock from .sftp import SFTPConnectionLost, SFTPOpUnsupported from .sftp import SFTPAttrs, SFTPVFSAttrs, SFTPName, SFTPLimits from .sftp import SEEK_SET, SEEK_CUR, SEEK_END from .stream import SSHSocketSessionFactory, SSHServerSessionFactory from .stream import SFTPServerFactory, SSHReader, SSHWriter from .subprocess import SSHSubprocessReadPipe, SSHSubprocessWritePipe from .subprocess import SSHSubprocessProtocol, SSHSubprocessTransport # Import these explicitly to trigger register calls in them from . 
import sk_eddsa, sk_ecdsa, eddsa, ecdsa, rsa, dsa, kex_dh, kex_rsa __all__ = [ '__author__', '__author_email__', '__url__', '__version__', 'BreakReceived', 'BytesOrStr', 'ChannelListenError', 'ChannelOpenError', 'CompressionError', 'ConfigParseError', 'ConnectionLost', 'DEVNULL', 'DataType', 'DisconnectError', 'Error', 'HostKeyNotVerifiable', 'IllegalUserName', 'KeyEncryptionError', 'KeyExchangeFailed', 'KeyExportError', 'KeyGenerationError', 'KeyImportError', 'MACError', 'PIPE', 'PasswordChangeRequired', 'PermissionDenied', 'ProcessError', 'ProtocolError', 'ProtocolNotSupported', 'SEEK_CUR', 'SEEK_END', 'SEEK_SET', 'SFTPAttrs', 'SFTPBadMessage', 'SFTPByteRangeLockConflict', 'SFTPByteRangeLockRefused', 'SFTPCannotDelete', 'SFTPClient', 'SFTPClientFile', 'SFTPConnectionLost', 'SFTPDeletePending', 'SFTPDirNotEmpty', 'SFTPEOFError', 'SFTPError', 'SFTPFailure', 'SFTPFileAlreadyExists', 'SFTPFileCorrupt', 'SFTPFileIsADirectory', 'SFTPGroupInvalid', 'SFTPInvalidFilename', 'SFTPInvalidHandle', 'SFTPInvalidParameter', 'SFTPLimits', 'SFTPLinkLoop', 'SFTPLockConflict', 'SFTPName', 'SFTPNoConnection', 'SFTPNoMatchingByteRangeLock', 'SFTPNoMedia', 'SFTPNoSpaceOnFilesystem', 'SFTPNoSuchFile', 'SFTPNoSuchPath', 'SFTPNotADirectory', 'SFTPOpUnsupported', 'SFTPOwnerInvalid', 'SFTPPermissionDenied', 'SFTPQuotaExceeded', 'SFTPServer', 'SFTPServerFactory', 'SFTPUnknownPrincipal', 'SFTPVFSAttrs', 'SFTPWriteProtect', 'SSHAcceptHandler', 'SSHAcceptor', 'SSHAgentClient', 'SSHAgentKeyPair', 'SSHAuthorizedKeys', 'SSHCertificate', 'SSHClient', 'SSHClientChannel', 'SSHClientConnection', 'SSHClientConnectionOptions', 'SSHClientProcess', 'SSHClientSession', 'SSHCompletedProcess', 'SSHForwarder', 'SSHKey', 'SSHKeyPair', 'SSHKnownHosts', 'SSHLineEditorChannel', 'SSHListener', 'SSHReader', 'SSHServer', 'SSHServerChannel', 'SSHServerConnection', 'SSHServerConnectionOptions', 'SSHServerProcess', 'SSHServerProcessFactory', 'SSHServerSession', 'SSHServerSessionFactory', 'SSHSocketSessionFactory', 'SSHSubprocessProtocol', 'SSHSubprocessReadPipe', 'SSHSubprocessTransport', 'SSHSubprocessWritePipe', 'SSHTCPChannel', 'SSHTCPSession', 'SSHTunTapChannel', 'SSHTunTapSession', 'SSHUNIXChannel', 'SSHUNIXSession', 'SSHWriter', 'STDOUT', 'ServiceNotAvailable', 'SignalReceived', 'TerminalSizeChanged', 'TimeoutError', 'connect', 'connect_agent', 'connect_reverse', 'create_connection', 'create_server', 'generate_private_key', 'get_server_auth_methods', 'get_server_host_key', 'import_authorized_keys', 'import_certificate', 'import_known_hosts', 'import_private_key', 'import_public_key', 'listen', 'listen_reverse', 'load_certificates', 'load_keypairs', 'load_pkcs11_keys', 'load_public_keys', 'load_resident_keys', 'logger', 'match_known_hosts', 'read_authorized_keys', 'read_certificate', 'read_certificate_list', 'read_known_hosts', 'read_private_key', 'read_private_key_list', 'read_public_key', 'read_public_key_list', 'run_client', 'run_server', 'scp', 'set_debug_level', 'set_default_skip_rsa_key_validation', 'set_log_level', 'set_sftp_log_level' ] asyncssh-2.20.0/asyncssh/agent.py000066400000000000000000000552411475467777400167740ustar00rootroot00000000000000# Copyright (c) 2016-2024 by Ron Frederick and others. 
# # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation """SSH agent client""" import asyncio import os import sys from types import TracebackType from typing import TYPE_CHECKING, List, Optional, Sequence, Tuple, Type, Union from typing_extensions import Protocol, Self from .listener import SSHForwardListener from .misc import async_context_manager, maybe_wait_closed from .packet import Byte, String, UInt32, PacketDecodeError, SSHPacket from .public_key import KeyPairListArg, SSHCertificate, SSHKeyPair from .public_key import load_default_keypairs, load_keypairs if TYPE_CHECKING: from tempfile import TemporaryDirectory class AgentReader(Protocol): """Protocol for reading from an SSH agent""" async def readexactly(self, n: int) -> bytes: """Read exactly n bytes from the SSH agent""" class AgentWriter(Protocol): """Protocol for writing to an SSH agent""" def write(self, data: bytes) -> None: """Write bytes to the SSH agent""" def close(self) -> None: """Close connection to the SSH agent""" async def wait_closed(self) -> None: """Wait for the connection to the SSH agent to close""" if sys.platform == 'win32': # pragma: no cover from .agent_win32 import open_agent else: from .agent_unix import open_agent class _SupportsOpenAgentConnection(Protocol): """A class that supports open_agent_connection""" async def open_agent_connection(self) -> Tuple[AgentReader, AgentWriter]: """Open a forwarded ssh-agent connection back to the client""" _AgentPath = Union[str, _SupportsOpenAgentConnection] # Client request message numbers SSH_AGENTC_REQUEST_IDENTITIES = 11 SSH_AGENTC_SIGN_REQUEST = 13 SSH_AGENTC_ADD_IDENTITY = 17 SSH_AGENTC_REMOVE_IDENTITY = 18 SSH_AGENTC_REMOVE_ALL_IDENTITIES = 19 SSH_AGENTC_ADD_SMARTCARD_KEY = 20 SSH_AGENTC_REMOVE_SMARTCARD_KEY = 21 SSH_AGENTC_LOCK = 22 SSH_AGENTC_UNLOCK = 23 SSH_AGENTC_ADD_ID_CONSTRAINED = 25 SSH_AGENTC_ADD_SMARTCARD_KEY_CONSTRAINED = 26 SSH_AGENTC_EXTENSION = 27 # Agent response message numbers SSH_AGENT_FAILURE = 5 SSH_AGENT_SUCCESS = 6 SSH_AGENT_IDENTITIES_ANSWER = 12 SSH_AGENT_SIGN_RESPONSE = 14 SSH_AGENT_EXTENSION_FAILURE = 28 # SSH agent constraint numbers SSH_AGENT_CONSTRAIN_LIFETIME = 1 SSH_AGENT_CONSTRAIN_CONFIRM = 2 SSH_AGENT_CONSTRAIN_EXTENSION = 255 # SSH agent signature flags SSH_AGENT_RSA_SHA2_256 = 2 SSH_AGENT_RSA_SHA2_512 = 4 class SSHAgentKeyPair(SSHKeyPair): """Surrogate for a key managed by the SSH agent""" _key_type = 'agent' def __init__(self, agent: 'SSHAgentClient', algorithm: bytes, public_data: bytes, comment: bytes): is_cert = algorithm.endswith(b'-cert-v01@openssh.com') if is_cert: if algorithm.startswith(b'sk-'): sig_algorithm = algorithm[:-21] + b'@openssh.com' else: sig_algorithm = algorithm[:-21] else: sig_algorithm = algorithm # Neither Pageant nor the Win10 OpenSSH agent seems to support the # ssh-agent protocol flags used to request RSA SHA2 signatures yet if sig_algorithm == b'ssh-rsa' and sys.platform != 'win32': sig_algorithms: Sequence[bytes] = \ 
(b'rsa-sha2-256', b'rsa-sha2-512', b'ssh-rsa') else: sig_algorithms = (sig_algorithm,) if is_cert: host_key_algorithms: Sequence[bytes] = (algorithm,) else: host_key_algorithms = sig_algorithms super().__init__(algorithm, sig_algorithm, sig_algorithms, host_key_algorithms, public_data, comment) self._agent = agent self._is_cert = is_cert self._flags = 0 @property def has_cert(self) -> bool: """ Return if this key pair has an associated cert""" return self._is_cert @property def has_x509_chain(self) -> bool: """ Return if this key pair has an associated X.509 cert chain""" return False def set_certificate(self, cert: SSHCertificate) -> None: """Set certificate to use with this key""" super().set_certificate(cert) self._is_cert = True def set_sig_algorithm(self, sig_algorithm: bytes) -> None: """Set the signature algorithm to use when signing data""" super().set_sig_algorithm(sig_algorithm) if sig_algorithm in (b'rsa-sha2-256', b'x509v3-rsa2048-sha256'): self._flags |= SSH_AGENT_RSA_SHA2_256 elif sig_algorithm == b'rsa-sha2-512': self._flags |= SSH_AGENT_RSA_SHA2_512 async def sign_async(self, data: bytes) -> bytes: """Asynchronously sign a block of data with this private key""" return await self._agent.sign(self.key_public_data, data, self._flags) async def remove(self) -> None: """Remove this key pair from the agent""" await self._agent.remove_keys([self]) class SSHAgentClient: """SSH agent client""" def __init__(self, agent_path: _AgentPath): self._agent_path = agent_path self._reader: Optional[AgentReader] = None self._writer: Optional[AgentWriter] = None self._lock = asyncio.Lock() async def __aenter__(self) -> Self: """Allow SSHAgentClient to be used as an async context manager""" return self async def __aexit__(self, exc_type: Optional[Type[BaseException]], exc_value: Optional[BaseException], traceback: Optional[TracebackType]) -> bool: """Wait for connection close when used as an async context manager""" await self._cleanup() return False async def _cleanup(self) -> None: """Clean up this SSH agent client""" self.close() await self.wait_closed() @staticmethod def encode_constraints(lifetime: Optional[int], confirm: bool) -> bytes: """Encode key constraints""" result = b'' if lifetime: result += Byte(SSH_AGENT_CONSTRAIN_LIFETIME) + UInt32(lifetime) if confirm: result += Byte(SSH_AGENT_CONSTRAIN_CONFIRM) return result async def connect(self) -> None: """Connect to the SSH agent""" if isinstance(self._agent_path, str): self._reader, self._writer = await open_agent(self._agent_path) else: self._reader, self._writer = \ await self._agent_path.open_agent_connection() async def _make_request(self, msgtype: int, *args: bytes) -> \ Tuple[int, SSHPacket]: """Send an SSH agent request""" async with self._lock: try: if not self._writer: await self.connect() reader = self._reader writer = self._writer assert reader is not None assert writer is not None payload = Byte(msgtype) + b''.join(args) writer.write(UInt32(len(payload)) + payload) resplen = int.from_bytes((await reader.readexactly(4)), 'big') resp = SSHPacket(await reader.readexactly(resplen)) resptype = resp.get_byte() return resptype, resp except (OSError, EOFError, PacketDecodeError) as exc: await self._cleanup() raise ValueError(str(exc)) from None async def get_keys(self, identities: Optional[Sequence[bytes]] = None) -> \ Sequence[SSHKeyPair]: """Request the available client keys This method is a coroutine which returns a list of client keys available in the ssh-agent. 
:param identities: (optional) A list of allowed byte string identities to return. If empty, all identities on the SSH agent will be returned. :returns: A list of :class:`SSHKeyPair` objects """ resptype, resp = \ await self._make_request(SSH_AGENTC_REQUEST_IDENTITIES) if resptype == SSH_AGENT_IDENTITIES_ANSWER: result: List[SSHKeyPair] = [] num_keys = resp.get_uint32() for _ in range(num_keys): key_blob = resp.get_string() comment = resp.get_string() if identities and key_blob not in identities: continue packet = SSHPacket(key_blob) algorithm = packet.get_string() result.append(SSHAgentKeyPair(self, algorithm, key_blob, comment)) resp.check_end() return result else: raise ValueError(f'Unknown SSH agent response: {resptype}') async def sign(self, key_blob: bytes, data: bytes, flags: int = 0) -> bytes: """Sign a block of data with the requested key""" resptype, resp = await self._make_request(SSH_AGENTC_SIGN_REQUEST, String(key_blob), String(data), UInt32(flags)) if resptype == SSH_AGENT_SIGN_RESPONSE: sig = resp.get_string() resp.check_end() return sig elif resptype == SSH_AGENT_FAILURE: raise ValueError('Unable to sign with requested key') else: raise ValueError(f'Unknown SSH agent response: {resptype}') async def add_keys(self, keylist: KeyPairListArg = (), passphrase: Optional[str] = None, lifetime: Optional[int] = None, confirm: bool = False) -> None: """Add keys to the agent This method adds a list of local private keys and optional matching certificates to the agent. :param keylist: (optional) The list of keys to add. If not specified, an attempt will be made to load keys from the files :file:`.ssh/id_ed25519_sk`, :file:`.ssh/id_ecdsa_sk`, :file:`.ssh/id_ed448`, :file:`.ssh/id_ed25519`, :file:`.ssh/id_ecdsa`, :file:`.ssh/id_rsa` and :file:`.ssh/id_dsa` in the user's home directory with optional matching certificates loaded from the files :file:`.ssh/id_ed25519_sk-cert.pub`, :file:`.ssh/id_ecdsa_sk-cert.pub`, :file:`.ssh/id_ed448-cert.pub`, :file:`.ssh/id_ed25519-cert.pub`, :file:`.ssh/id_ecdsa-cert.pub`, :file:`.ssh/id_rsa-cert.pub`, and :file:`.ssh/id_dsa-cert.pub`. Failures when adding keys are ignored in this case, as the agent may not recognize some of these key types. :param passphrase: (optional) The passphrase to use to decrypt the keys. :param lifetime: (optional) The time in seconds after which the keys should be automatically deleted, or `None` to store these keys indefinitely (the default). :param confirm: (optional) Whether or not to require confirmation for each private key operation which uses these keys, defaulting to `False`. 
:type keylist: *see* :ref:`SpecifyingPrivateKeys` :type passphrase: `str` :type lifetime: `int` or `None` :type confirm: `bool` :raises: :exc:`ValueError` if the keys cannot be added """ if keylist: keypairs = load_keypairs(keylist, passphrase) ignore_failures = False else: keypairs = load_default_keypairs(passphrase) ignore_failures = True base_constraints = self.encode_constraints(lifetime, confirm) provider = os.environ.get('SSH_SK_PROVIDER') or 'internal' sk_constraints = Byte(SSH_AGENT_CONSTRAIN_EXTENSION) + \ String('sk-provider@openssh.com') + \ String(provider) for keypair in keypairs: constraints = base_constraints if keypair.algorithm.startswith(b'sk-'): constraints += sk_constraints msgtype = SSH_AGENTC_ADD_ID_CONSTRAINED if constraints else \ SSH_AGENTC_ADD_IDENTITY comment = keypair.get_comment_bytes() resptype, resp = \ await self._make_request(msgtype, keypair.get_agent_private_key(), String(comment or b''), constraints) if resptype == SSH_AGENT_SUCCESS: resp.check_end() elif resptype == SSH_AGENT_FAILURE: if not ignore_failures: raise ValueError('Unable to add key') else: raise ValueError(f'Unknown SSH agent response: {resptype}') async def add_smartcard_keys(self, provider: str, pin: Optional[str] = None, lifetime: Optional[int] = None, confirm: bool = False) -> None: """Store keys associated with a smart card in the agent :param provider: The name of the smart card provider :param pin: (optional) The PIN to use to unlock the smart card :param lifetime: (optional) The time in seconds after which the keys should be automatically deleted, or `None` to store these keys indefinitely (the default). :param confirm: (optional) Whether or not to require confirmation for each private key operation which uses these keys, defaulting to `False`. :type provider: `str` :type pin: `str` or `None` :type lifetime: `int` or `None` :type confirm: `bool` :raises: :exc:`ValueError` if the keys cannot be added """ constraints = self.encode_constraints(lifetime, confirm) msgtype = SSH_AGENTC_ADD_SMARTCARD_KEY_CONSTRAINED \ if constraints else SSH_AGENTC_ADD_SMARTCARD_KEY resptype, resp = await self._make_request(msgtype, String(provider), String(pin or ''), constraints) if resptype == SSH_AGENT_SUCCESS: resp.check_end() elif resptype == SSH_AGENT_FAILURE: raise ValueError('Unable to add keys') else: raise ValueError(f'Unknown SSH agent response: {resptype}') async def remove_keys(self, keylist: Sequence[SSHKeyPair]) -> None: """Remove a key stored in the agent :param keylist: The list of keys to remove. 
:type keylist: `list` of :class:`SSHKeyPair` :raises: :exc:`ValueError` if any keys are not found """ for keypair in keylist: resptype, resp = \ await self._make_request(SSH_AGENTC_REMOVE_IDENTITY, String(keypair.public_data)) if resptype == SSH_AGENT_SUCCESS: resp.check_end() elif resptype == SSH_AGENT_FAILURE: raise ValueError('Key not found') else: raise ValueError(f'Unknown SSH agent response: {resptype}') async def remove_smartcard_keys(self, provider: str, pin: Optional[str] = None) -> None: """Remove keys associated with a smart card stored in the agent :param provider: The name of the smart card provider :param pin: (optional) The PIN to use to unlock the smart card :type provider: `str` :type pin: `str` or `None` :raises: :exc:`ValueError` if the keys are not found """ resptype, resp = \ await self._make_request(SSH_AGENTC_REMOVE_SMARTCARD_KEY, String(provider), String(pin or '')) if resptype == SSH_AGENT_SUCCESS: resp.check_end() elif resptype == SSH_AGENT_FAILURE: raise ValueError('Keys not found') else: raise ValueError(f'Unknown SSH agent response: {resptype}') async def remove_all(self) -> None: """Remove all keys stored in the agent :raises: :exc:`ValueError` if the keys can't be removed """ resptype, resp = \ await self._make_request(SSH_AGENTC_REMOVE_ALL_IDENTITIES) if resptype == SSH_AGENT_SUCCESS: resp.check_end() elif resptype == SSH_AGENT_FAILURE: raise ValueError('Unable to remove all keys') else: raise ValueError(f'Unknown SSH agent response: {resptype}') async def lock(self, passphrase: str) -> None: """Lock the agent using the specified passphrase .. note:: The lock and unlock actions don't appear to be supported on the Windows 10 OpenSSH agent. :param passphrase: The passphrase required to later unlock the agent :type passphrase: `str` :raises: :exc:`ValueError` if the agent can't be locked """ resptype, resp = await self._make_request(SSH_AGENTC_LOCK, String(passphrase)) if resptype == SSH_AGENT_SUCCESS: resp.check_end() elif resptype == SSH_AGENT_FAILURE: raise ValueError('Unable to lock SSH agent') else: raise ValueError(f'Unknown SSH agent response: {resptype}') async def unlock(self, passphrase: str) -> None: """Unlock the agent using the specified passphrase .. note:: The lock and unlock actions don't appear to be supported on the Windows 10 OpenSSH agent. :param passphrase: The passphrase to use to unlock the agent :type passphrase: `str` :raises: :exc:`ValueError` if the agent can't be unlocked """ resptype, resp = await self._make_request(SSH_AGENTC_UNLOCK, String(passphrase)) if resptype == SSH_AGENT_SUCCESS: resp.check_end() elif resptype == SSH_AGENT_FAILURE: raise ValueError('Unable to unlock SSH agent') else: raise ValueError(f'Unknown SSH agent response: {resptype}') async def query_extensions(self) -> Sequence[str]: """Return a list of extensions supported by the agent :returns: A list of strings of supported extension names """ resptype, resp = await self._make_request(SSH_AGENTC_EXTENSION, String('query')) if resptype == SSH_AGENT_SUCCESS: result = [] while resp: exttype = resp.get_string() try: exttype_str = exttype.decode('utf-8') except UnicodeDecodeError: raise ValueError('Invalid extension type name') from None result.append(exttype_str) return result elif resptype == SSH_AGENT_FAILURE: return [] else: raise ValueError(f'Unknown SSH agent response: {resptype}') def close(self) -> None: """Close the SSH agent connection This method closes the connection to the ssh-agent. 
Any attempts to use this :class:`SSHAgentClient` or the key pairs it previously returned will result in an error. """ if self._writer: self._writer.close() async def wait_closed(self) -> None: """Wait for this agent connection to close This method is a coroutine which can be called to block until the connection to the agent has finished closing. """ if self._writer: await maybe_wait_closed(self._writer) self._reader = None self._writer = None class SSHAgentListener: """Listener used to forward agent connections""" def __init__(self, tempdir: 'TemporaryDirectory[str]', path: str, unix_listener: SSHForwardListener): self._tempdir = tempdir self._path = path self._unix_listener = unix_listener def get_path(self) -> str: """Return the path being listened on""" return self._path def close(self) -> None: """Close the agent listener""" self._unix_listener.close() self._tempdir.cleanup() @async_context_manager async def connect_agent(agent_path: _AgentPath = '') -> 'SSHAgentClient': """Make a connection to the SSH agent This function attempts to connect to an ssh-agent process listening on a UNIX domain socket at `agent_path`. If not provided, it will attempt to get the path from the `SSH_AUTH_SOCK` environment variable. If the connection is successful, an :class:`SSHAgentClient` object is returned that has methods on it you can use to query the ssh-agent. If no path is specified and the environment variable is not set or the connection to the agent fails, an error is raised. :param agent_path: (optional) The path to use to contact the ssh-agent process, or the :class:`SSHServerConnection` to forward the agent request over. :type agent_path: `str` or :class:`SSHServerConnection` :returns: An :class:`SSHAgentClient` :raises: :exc:`OSError` or :exc:`ChannelOpenError` if the connection to the agent can't be opened """ if not agent_path: agent_path = os.environ.get('SSH_AUTH_SOCK', '') agent = SSHAgentClient(agent_path) await agent.connect() return agent asyncssh-2.20.0/asyncssh/agent_unix.py000066400000000000000000000022541475467777400200330ustar00rootroot00000000000000# Copyright (c) 2016-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation """SSH agent support code for UNIX""" import asyncio import errno from typing import TYPE_CHECKING, Tuple if TYPE_CHECKING: # pylint: disable=cyclic-import from .agent import AgentReader, AgentWriter async def open_agent(agent_path: str) -> Tuple['AgentReader', 'AgentWriter']: """Open a connection to ssh-agent""" if not agent_path: raise OSError(errno.ENOENT, 'Agent not found') return await asyncio.open_unix_connection(agent_path) asyncssh-2.20.0/asyncssh/agent_win32.py000066400000000000000000000120251475467777400200070ustar00rootroot00000000000000# Copyright (c) 2016-2024 by Ron Frederick and others. 
# # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation """SSH agent support code for Windows""" # Some of the imports below won't be found when running pylint on UNIX # pylint: disable=import-error import asyncio import ctypes import ctypes.wintypes import errno from typing import TYPE_CHECKING, Tuple, Union, cast from .misc import open_file if TYPE_CHECKING: # pylint: disable=cyclic-import from .agent import AgentReader, AgentWriter try: import mmapfile import win32api import win32con import win32ui _pywin32_available = True except ImportError: _pywin32_available = False _AGENT_COPYDATA_ID = 0x804e50ba _AGENT_MAX_MSGLEN = 8192 _AGENT_NAME = 'Pageant' _DEFAULT_OPENSSH_PATH = r'\\.\pipe\openssh-ssh-agent' def _find_agent_window() -> 'win32ui.PyCWnd': """Find and return the Pageant window""" if _pywin32_available: try: return win32ui.FindWindow(_AGENT_NAME, _AGENT_NAME) except win32ui.error: raise OSError(errno.ENOENT, 'Agent not found') from None else: raise OSError(errno.ENOENT, 'PyWin32 not installed') from None class _CopyDataStruct(ctypes.Structure): """Windows COPYDATASTRUCT argument for WM_COPYDATA message""" _fields_ = (('dwData', ctypes.wintypes.LPARAM), ('cbData', ctypes.wintypes.DWORD), ('lpData', ctypes.c_char_p)) class _PageantTransport: """Transport to connect to Pageant agent on Windows""" def __init__(self) -> None: self._mapname = f'{_AGENT_NAME}{win32api.GetCurrentThreadId():08x}' try: self._mapfile = mmapfile.mmapfile('', self._mapname, _AGENT_MAX_MSGLEN, 0, 0) except mmapfile.error as exc: raise OSError(errno.EIO, str(exc)) from None self._cds = _CopyDataStruct(_AGENT_COPYDATA_ID, len(self._mapname) + 1, self._mapname.encode()) self._writing = False def write(self, data: bytes) -> None: """Write request data to Pageant agent""" if not self._writing: self._mapfile.seek(0) self._writing = True try: self._mapfile.write(data) except ValueError as exc: raise OSError(errno.EIO, str(exc)) from None async def readexactly(self, n: int) -> bytes: """Read response data from Pageant agent""" if self._writing: cwnd = _find_agent_window() if not cwnd.SendMessage(win32con.WM_COPYDATA, 0, cast(int, self._cds)): raise OSError(errno.EIO, 'Unable to send agent request') self._writing = False self._mapfile.seek(0) result = self._mapfile.read(n) if len(result) != n: raise asyncio.IncompleteReadError(result, n) return result def close(self) -> None: """Close the connection to Pageant""" if self._mapfile: self._mapfile.close() async def wait_closed(self) -> None: """Wait for the transport to close""" class _W10OpenSSHTransport: """Transport to connect to OpenSSH agent on Windows 10""" def __init__(self, agent_path: str): self._agentfile = open_file(agent_path, 'r+b') async def readexactly(self, n: int) -> bytes: """Read response data from OpenSSH agent""" result = self._agentfile.read(n) if len(result) != n: raise asyncio.IncompleteReadError(result, n) return result def write(self, data: bytes) -> None: """Write request data 
to OpenSSH agent""" self._agentfile.write(data) def close(self) -> None: """Close the connection to OpenSSH""" if self._agentfile: self._agentfile.close() async def wait_closed(self) -> None: """Wait for the transport to close""" async def open_agent(agent_path: str) -> Tuple['AgentReader', 'AgentWriter']: """Open a connection to the Pageant or Windows 10 OpenSSH agent""" transport: Union[None, _PageantTransport, _W10OpenSSHTransport] = None if not agent_path: try: _find_agent_window() transport = _PageantTransport() except OSError: agent_path = _DEFAULT_OPENSSH_PATH if not transport: transport = _W10OpenSSHTransport(agent_path) return transport, transport asyncssh-2.20.0/asyncssh/asn1.py000066400000000000000000000575671475467777400165550ustar00rootroot00000000000000# Copyright (c) 2013-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation """Utilities for encoding and decoding ASN.1 DER data The der_encode function takes a Python value and encodes it in DER format, returning a byte string. In addition to supporting standard Python types, BitString can be used to encode a DER bit string, ObjectIdentifier can be used to encode OIDs, values can be wrapped in a TaggedDERObject to set an alternate DER tag on them, and non-standard types can be encoded by placing them in a RawDERObject. The der_decode function takes a byte string in DER format and decodes it into the corresponding Python values. 
""" from typing import Dict, FrozenSet, Sequence, Set, Tuple, Type, TypeVar, Union from typing import cast _DERClass = Type['DERType'] _DERClassVar = TypeVar('_DERClassVar', bound='_DERClass') # ASN.1 object classes UNIVERSAL = 0x00 APPLICATION = 0x01 CONTEXT_SPECIFIC = 0x02 PRIVATE = 0x03 # ASN.1 universal object tags END_OF_CONTENT = 0x00 BOOLEAN = 0x01 INTEGER = 0x02 BIT_STRING = 0x03 OCTET_STRING = 0x04 NULL = 0x05 OBJECT_IDENTIFIER = 0x06 UTF8_STRING = 0x0c SEQUENCE = 0x10 SET = 0x11 IA5_STRING = 0x16 _asn1_class = ('Universal', 'Application', 'Context-specific', 'Private') _der_class_by_tag: Dict[int, _DERClass] = {} _der_class_by_type: Dict[Union[object, _DERClass], _DERClass] = {} def _encode_identifier(asn1_class: int, constructed: bool, tag: int) -> bytes: """Encode a DER object's identifier""" if asn1_class not in (UNIVERSAL, APPLICATION, CONTEXT_SPECIFIC, PRIVATE): raise ASN1EncodeError('Invalid ASN.1 class') flags = (asn1_class << 6) | (0x20 if constructed else 0x00) if tag < 0x20: identifier = [flags | tag] else: identifier = [tag & 0x7f] while tag >= 0x80: tag >>= 7 identifier.append(0x80 | (tag & 0x7f)) identifier.append(flags | 0x1f) return bytes(identifier[::-1]) class ASN1Error(ValueError): """ASN.1 coding error""" class ASN1EncodeError(ASN1Error): """ASN.1 DER encoding error""" class ASN1DecodeError(ASN1Error): """ASN.1 DER decoding error""" class DERType: """Parent class for classes which use DERTag decorator""" identifier: bytes = b'' @staticmethod def encode(value: object) -> bytes: """Encode value as a DER byte string""" raise NotImplementedError @classmethod def decode(cls, constructed: bool, content: bytes) -> object: """Decode a DER byte string into an object""" raise NotImplementedError class DERTag: """A decorator used by classes which convert values to/from DER Classes which convert Python values to and from DER format should use the DERTag decorator to indicate what DER tag value they understand. When DER data is decoded, the tag is looked up in the list to see which class to call to perform the decoding. Classes which convert existing Python types to and from DER format can specify the list of types they understand in the optional "types" argument. Otherwise, conversion is expected to be to and from the new class being defined. """ def __init__(self, tag: int, types: Sequence[object] = (), constructed: bool = False): self._tag = tag self._types = types self._identifier = _encode_identifier(UNIVERSAL, constructed, tag) def __call__(self, cls: _DERClassVar) -> _DERClassVar: cls.identifier = self._identifier _der_class_by_tag[self._tag] = cls if self._types: for t in self._types: _der_class_by_type[t] = cls else: _der_class_by_type[cls] = cls return cls class RawDERObject: """A class which can encode a DER object of an arbitrary type This object is initialized with an ASN.1 class, tag, and a byte string representing the already encoded data. Such objects will never have the constructed flag set, since that is represented here as a TaggedDERObject. 
""" def __init__(self, tag: int, content: bytes, asn1_class: int): self.asn1_class = asn1_class self.tag = tag self.content = content def __repr__(self) -> str: return f'RawDERObject({_asn1_class[self.asn1_class]}, ' \ f'{self.tag}, {self.content!r})' def __eq__(self, other: object) -> bool: if not isinstance(other, RawDERObject): # pragma: no cover return NotImplemented return (self.asn1_class == other.asn1_class and self.tag == other.tag and self.content == other.content) def __hash__(self) -> int: return hash((self.asn1_class, self.tag, self.content)) def encode_identifier(self) -> bytes: """Encode the DER identifier for this object as a byte string""" return _encode_identifier(self.asn1_class, False, self.tag) @staticmethod def encode(value: object) -> bytes: """Encode the content for this object as a DER byte string""" return cast('RawDERObject', value).content class TaggedDERObject: """An explicitly tagged DER object This object provides a way to wrap an ASN.1 object with an explicit tag. The value (including the tag representing its actual type) is then encoded as part of its value. By default, the ASN.1 class for these objects is CONTEXT_SPECIFIC, and the DER encoding always marks these values as constructed. """ def __init__(self, tag: int, value: object, asn1_class: int = CONTEXT_SPECIFIC): self.asn1_class = asn1_class self.tag = tag self.value = value def __repr__(self) -> str: if self.asn1_class == CONTEXT_SPECIFIC: return f'TaggedDERObject({self.tag}, {self.value!r})' else: return f'TaggedDERObject({_asn1_class[self.asn1_class]}, ' \ f'{self.tag}, {self.value!r})' def __eq__(self, other: object) -> bool: if not isinstance(other, TaggedDERObject): # pragma: no cover return NotImplemented return (self.asn1_class == other.asn1_class and self.tag == other.tag and self.value == other.value) def __hash__(self) -> int: return hash((self.asn1_class, self.tag, self.value)) def encode_identifier(self) -> bytes: """Encode the DER identifier for this object as a byte string""" return _encode_identifier(self.asn1_class, True, self.tag) @staticmethod def encode(value: object) -> bytes: """Encode the content for this object as a DER byte string""" return der_encode(cast('TaggedDERObject', value).value) @DERTag(NULL, (type(None),)) class _Null(DERType): """A null value""" @staticmethod def encode(value: object) -> bytes: """Encode a DER null value""" # pylint: disable=unused-argument return b'' @classmethod def decode(cls, constructed: bool, content: bytes) -> None: """Decode a DER null value""" if constructed: raise ASN1DecodeError('NULL should not be constructed') if content: raise ASN1DecodeError('NULL should not have associated content') return None @DERTag(BOOLEAN, (bool,)) class _Boolean(DERType): """A boolean value""" @staticmethod def encode(value: object) -> bytes: """Encode a DER boolean value""" return b'\xff' if value else b'\0' @classmethod def decode(cls, constructed: bool, content: bytes) -> bool: """Decode a DER boolean value""" if constructed: raise ASN1DecodeError('BOOLEAN should not be constructed') if content not in {b'\x00', b'\xff'}: raise ASN1DecodeError('BOOLEAN content must be 0x00 or 0xff') return bool(content[0]) @DERTag(INTEGER, (int,)) class _Integer(DERType): """An integer value""" @staticmethod def encode(value: object) -> bytes: """Encode a DER integer value""" i = cast(int, value) l = i.bit_length() l = l // 8 + 1 if l % 8 == 0 else (l + 7) // 8 result = i.to_bytes(l, 'big', signed=True) return result[1:] if result.startswith(b'\xff\x80') else result 
@classmethod def decode(cls, constructed: bool, content: bytes) -> int: """Decode a DER integer value""" if constructed: raise ASN1DecodeError('INTEGER should not be constructed') return int.from_bytes(content, 'big', signed=True) @DERTag(OCTET_STRING, (bytes, bytearray)) class _OctetString(DERType): """An octet string value""" @staticmethod def encode(value: object) -> bytes: """Encode a DER octet string""" return cast(bytes, value) @classmethod def decode(cls, constructed: bool, content: bytes) -> bytes: """Decode a DER octet string""" if constructed: raise ASN1DecodeError('OCTET STRING should not be constructed') return content @DERTag(UTF8_STRING, (str,)) class _UTF8String(DERType): """A UTF-8 string value""" @staticmethod def encode(value: object) -> bytes: """Encode a DER UTF-8 string""" return cast(str, value).encode('utf-8') @classmethod def decode(cls, constructed: bool, content: bytes) -> str: """Decode a DER UTF-8 string""" if constructed: raise ASN1DecodeError('UTF8 STRING should not be constructed') return content.decode('utf-8') @DERTag(SEQUENCE, (list, tuple), constructed=True) class _Sequence(DERType): """A sequence of values""" @staticmethod def encode(value: object) -> bytes: """Encode a sequence of DER values""" seq_value = cast(Sequence[object], value) return b''.join(der_encode(item) for item in seq_value) @classmethod def decode(cls, constructed: bool, content: bytes) -> Sequence[object]: """Decode a sequence of DER values""" if not constructed: raise ASN1DecodeError('SEQUENCE should always be constructed') offset = 0 length = len(content) value = [] while offset < length: item, consumed = der_decode_partial(content[offset:]) value.append(item) offset += consumed return tuple(value) @DERTag(SET, (set, frozenset), constructed=True) class _Set(DERType): """A set of DER values""" @staticmethod def encode(value: object) -> bytes: """Encode a set of DER values""" set_value = cast(Union[FrozenSet[object], Set[object]], value) return b''.join(sorted(der_encode(item) for item in set_value)) @classmethod def decode(cls, constructed: bool, content: bytes) -> FrozenSet[object]: """Decode a set of DER values""" if not constructed: raise ASN1DecodeError('SET should always be constructed') offset = 0 length = len(content) value = set() while offset < length: item, consumed = der_decode_partial(content[offset:]) value.add(item) offset += consumed return frozenset(value) @DERTag(BIT_STRING) class BitString(DERType): """A string of bits This object can be initialized either with a byte string and an optional count of the number of least-significant bits in the last byte which should not be included in the value, or with a string consisting only of the digits '0' and '1'. An optional 'named' flag can also be set, indicating that the BitString was specified with named bits, indicating that the proper DER encoding of it should strip any trailing zeroes. 
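For example (an illustrative sketch), BitString('10110') stores the value as a single byte with three unused bits, and str() of the resulting object yields '10110' again.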
""" def __init__(self, value: object, unused: int = 0, named: bool = False): if unused < 0 or unused > 7: raise ASN1EncodeError('Unused bit count must be between 0 and 7') if isinstance(value, bytes): if unused: if not value: raise ASN1EncodeError('Can\'t have unused bits with empty ' 'value') elif value[-1] & ((1 << unused) - 1): raise ASN1EncodeError('Unused bits in value should be ' 'zero') elif isinstance(value, str): if unused: raise ASN1EncodeError('Unused bit count should not be set ' 'when providing a string') used = len(value) % 8 unused = 8 - used if used else 0 value += unused * '0' value = bytes(int(value[i:i+8], 2) for i in range(0, len(value), 8)) else: raise ASN1EncodeError('Unexpected type of bit string value') if named: while value and not value[-1] & (1 << unused): unused += 1 if unused == 8: value = value[:-1] unused = 0 self.value = value self.unused = unused def __str__(self) -> str: result = ''.join(bin(b)[2:].zfill(8) for b in self.value) if self.unused: result = result[:-self.unused] return result def __repr__(self) -> str: return f"BitString('{self}')" def __eq__(self, other: object) -> bool: if not isinstance(other, BitString): # pragma: no cover return NotImplemented return self.value == other.value and self.unused == other.unused def __hash__(self) -> int: return hash((self.value, self.unused)) @staticmethod def encode(value: object) -> bytes: """Encode a DER bit string""" bitstr_value = cast('BitString', value) return bytes((bitstr_value.unused,)) + bitstr_value.value @classmethod def decode(cls, constructed: bool, content: bytes) -> 'BitString': """Decode a DER bit string""" if constructed: raise ASN1DecodeError('BIT STRING should not be constructed') if not content or content[0] > 7: raise ASN1DecodeError('Invalid unused bit count') return cls(content[1:], unused=content[0]) @DERTag(IA5_STRING) class IA5String(DERType): """An ASCII string value""" def __init__(self, value: Union[bytes, bytearray]): self.value = value def __str__(self) -> str: return self.value.decode('ascii') def __repr__(self) -> str: return f'IA5String({self.value!r})' def __eq__(self, other: object) -> bool: # pragma: no cover if not isinstance(other, IA5String): return NotImplemented return self.value == other.value def __hash__(self) -> int: return hash(self.value) @staticmethod def encode(value: object) -> bytes: """Encode a DER IA5 string""" # ASN.1 defines this type as only containing ASCII characters, but # some tools expecting ASN.1 allow IA5Strings to contain other # characters, so we leave it up to the caller to pass in a byte # string which has already done the appropriate encoding of any # non-ASCII characters. return cast('IA5String', value).value @classmethod def decode(cls, constructed: bool, content: bytes) -> 'IA5String': """Decode a DER IA5 string""" if constructed: raise ASN1DecodeError('IA5 STRING should not be constructed') # As noted in the encode method above, the decoded value for this # type is a byte string, leaving the decoding of any non-ASCII # characters up to the caller. return cls(content) @DERTag(OBJECT_IDENTIFIER) class ObjectIdentifier(DERType): """An object identifier (OID) value This object can be initialized from a string of dot-separated integer values, representing a hierarchical namespace. All OIDs show have at least two components, with the first being between 0 and 2 (indicating ITU-T, ISO, or joint assignment). 
In cases where the first component is 0 or 1, the second component must be in the range 0 to 39 due to the way these first two components are encoded. """ def __init__(self, value: str): self.value = value def __str__(self) -> str: return self.value def __repr__(self) -> str: return f"ObjectIdentifier('{self.value}')" def __eq__(self, other: object) -> bool: if not isinstance(other, ObjectIdentifier): # pragma: no cover return NotImplemented return self.value == other.value def __hash__(self) -> int: return hash(self.value) @staticmethod def encode(value: object) -> bytes: """Encode a DER object identifier""" def _bytes(component: int) -> bytes: """Convert a single element of an OID to a DER byte string""" if component < 0: raise ASN1EncodeError('Components of object identifier must ' 'be greater than or equal to 0') result = [component & 0x7f] while component >= 0x80: component >>= 7 result.append(0x80 | (component & 0x7f)) return bytes(result[::-1]) oid_value = cast('ObjectIdentifier', value) try: components = [int(c) for c in oid_value.value.split('.')] except ValueError: raise ASN1EncodeError('Component values must be ' 'integers') from None if len(components) < 2: raise ASN1EncodeError('Object identifiers must have at least two ' 'components') elif components[0] < 0 or components[0] > 2: raise ASN1EncodeError('First component of object identifier must ' 'be between 0 and 2') elif components[0] < 2 and (components[1] < 0 or components[1] > 39): raise ASN1EncodeError('Second component of object identifier must ' 'be between 0 and 39') components[0:2] = [components[0]*40 + components[1]] return b''.join(_bytes(c) for c in components) @classmethod def decode(cls, constructed: bool, content: bytes) -> 'ObjectIdentifier': """Decode a DER object identifier""" if constructed: raise ASN1DecodeError('OBJECT IDENTIFIER should not be ' 'constructed') if not content: raise ASN1DecodeError('Empty object identifier') b = content[0] components = list(divmod(b, 40)) if b < 80 else [2, b-80] component = 0 for b in content[1:]: if b == 0x80 and component == 0: raise ASN1DecodeError('Invalid component') elif b < 0x80: components.append(component | b) component = 0 else: component |= b & 0x7f component <<= 7 if component: raise ASN1DecodeError('Incomplete component') return cls('.'.join(str(c) for c in components)) def der_encode(value: object) -> bytes: """Encode a value in DER format This function takes a Python value and encodes it in DER format. The following mapping of types is used: NoneType -> NULL bool -> BOOLEAN int -> INTEGER bytes, bytearray -> OCTET STRING str -> UTF8 STRING list, tuple -> SEQUENCE set, frozenset -> SET BitString -> BIT STRING ObjectIdentifier -> OBJECT IDENTIFIER An explicitly tagged DER object can be encoded by passing in a TaggedDERObject which specifies the ASN.1 class, tag, and value to encode. Other types can be encoded by passing in a RawDERObject which specifies the ASN.1 class, tag, and raw content octets to encode. 
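As a small added illustration of explicit tagging (the values here are chosen only for the example):

    >>> der_decode(der_encode(TaggedDERObject(0, 23)))
    TaggedDERObject(0, 23)

Here the INTEGER 23 is wrapped under context-specific tag 0 and is recovered as a TaggedDERObject when decoded.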
""" t = type(value) if t in (RawDERObject, TaggedDERObject): value = cast(Union[RawDERObject, TaggedDERObject], value) identifier = value.encode_identifier() content = value.encode(value) elif t in _der_class_by_type: cls = _der_class_by_type[t] identifier = cls.identifier content = cls.encode(value) else: raise ASN1EncodeError(f'Cannot DER encode type {t.__name__}') length = len(content) if length < 0x80: len_bytes = bytes((length,)) else: len_bytes = length.to_bytes((length.bit_length() + 7) // 8, 'big') len_bytes = bytes((0x80 | len(len_bytes),)) + len_bytes return identifier + len_bytes + content def der_decode_partial(data: bytes) -> Tuple[object, int]: """Decode a value in DER format and return the number of bytes consumed""" if len(data) < 2: raise ASN1DecodeError('Incomplete data') tag = data[0] asn1_class, constructed, tag = tag >> 6, bool(tag & 0x20), tag & 0x1f offset = 1 if tag == 0x1f: tag = 0 for b in data[offset:]: offset += 1 if b < 0x80: tag |= b break else: tag |= b & 0x7f tag <<= 7 else: raise ASN1DecodeError('Incomplete tag') if offset >= len(data): raise ASN1DecodeError('Incomplete data') length = data[offset] offset += 1 if length > 0x80: len_size = length & 0x7f length = int.from_bytes(data[offset:offset+len_size], 'big') offset += len_size elif length == 0x80: raise ASN1DecodeError('Indefinite length not allowed') end = offset + length content = data[offset:end] if end > len(data): raise ASN1DecodeError('Incomplete data') if asn1_class == UNIVERSAL and tag in _der_class_by_tag: cls = _der_class_by_tag[tag] value = cls.decode(constructed, content) elif constructed: value = TaggedDERObject(tag, der_decode(content), asn1_class) else: value = RawDERObject(tag, content, asn1_class) return value, end def der_decode(data: bytes) -> object: """Decode a value in DER format This function takes a byte string in DER format and converts it to a corresponding set of Python objects. The following mapping of ASN.1 tags to Python types is used: NULL -> NoneType BOOLEAN -> bool INTEGER -> int OCTET STRING -> bytes UTF8 STRING -> str SEQUENCE -> tuple SET -> frozenset BIT_STRING -> BitString OBJECT IDENTIFIER -> ObjectIdentifier Explicitly tagged objects are returned as type TaggedDERObject, with fields holding the object class, tag, and tagged value. Other object tags are returned as type RawDERObject, with fields holding the object class, tag, and raw content octets. If partial_ok is True, this function returns a tuple of the decoded value and number of bytes consumed. Otherwise, all data bytes must be consumed and only the decoded value is returned. """ value, end = der_decode_partial(data) if end < len(data): raise ASN1DecodeError('Data contains unexpected bytes at end') return value asyncssh-2.20.0/asyncssh/auth.py000066400000000000000000001015351475467777400166350ustar00rootroot00000000000000# Copyright (c) 2013-2024 by Ron Frederick and others. 
# # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation """SSH authentication handlers""" from typing import TYPE_CHECKING, Awaitable, Dict, List, Optional from typing import Sequence, Tuple, Type, Union, cast from .constants import DEFAULT_LANG from .gss import GSSBase, GSSError from .logging import SSHLogger from .misc import ProtocolError, PasswordChangeRequired, get_symbol_names from .misc import run_in_executor from .packet import Boolean, String, UInt32, SSHPacket, SSHPacketHandler from .public_key import SigningKey from .saslprep import saslprep, SASLPrepError if TYPE_CHECKING: import asyncio # pylint: disable=cyclic-import from .connection import SSHConnection, SSHClientConnection from .connection import SSHServerConnection KbdIntPrompts = Sequence[Tuple[str, bool]] KbdIntNewChallenge = Tuple[str, str, str, KbdIntPrompts] KbdIntChallenge = Union[bool, KbdIntNewChallenge] KbdIntResponse = Sequence[str] PasswordChangeResponse = Tuple[str, str] # SSH message values for GSS auth MSG_USERAUTH_GSSAPI_RESPONSE = 60 MSG_USERAUTH_GSSAPI_TOKEN = 61 MSG_USERAUTH_GSSAPI_EXCHANGE_COMPLETE = 63 MSG_USERAUTH_GSSAPI_ERROR = 64 MSG_USERAUTH_GSSAPI_ERRTOK = 65 MSG_USERAUTH_GSSAPI_MIC = 66 # SSH message values for public key auth MSG_USERAUTH_PK_OK = 60 # SSH message values for keyboard-interactive auth MSG_USERAUTH_INFO_REQUEST = 60 MSG_USERAUTH_INFO_RESPONSE = 61 # SSH message values for password auth MSG_USERAUTH_PASSWD_CHANGEREQ = 60 _auth_methods: List[bytes] = [] _client_auth_handlers: Dict[bytes, Type['ClientAuth']] = {} _server_auth_handlers: Dict[bytes, Type['ServerAuth']] = {} class Auth(SSHPacketHandler): """Parent class for authentication""" def __init__(self, conn: 'SSHConnection', coro: Awaitable[None]): self._conn = conn self._logger = conn.logger self._coro: Optional['asyncio.Task[None]'] = conn.create_task(coro) def send_packet(self, pkttype: int, *args: bytes, trivial: bool = True) -> None: """Send an auth packet""" self._conn.send_userauth_packet(pkttype, *args, handler=self, trivial=trivial) @property def logger(self) -> SSHLogger: """A logger associated with this authentication handler""" return self._logger def create_task(self, coro: Awaitable[None]) -> None: """Create an asynchronous auth task""" self.cancel() self._coro = self._conn.create_task(coro) def cancel(self) -> None: """Cancel any authentication in progress""" if self._coro: # pragma: no branch self._coro.cancel() self._coro = None class ClientAuth(Auth): """Parent class for client authentication""" _conn: 'SSHClientConnection' def __init__(self, conn: 'SSHClientConnection', method: bytes): self._method = method super().__init__(conn, self._start()) async def _start(self) -> None: """Abstract method for starting client authentication""" # Provided by subclass raise NotImplementedError def auth_succeeded(self) -> None: """Callback when auth succeeds""" def auth_failed(self) -> None: """Callback when auth fails""" async def send_request(self, *args: 
bytes, key: Optional[SigningKey] = None, trivial: bool = True) -> None: """Send a user authentication request""" await self._conn.send_userauth_request(self._method, *args, key=key, trivial=trivial) class _ClientNullAuth(ClientAuth): """Client side implementation of null auth""" async def _start(self) -> None: """Start client null authentication""" await self.send_request() class _ClientGSSKexAuth(ClientAuth): """Client side implementation of GSS key exchange auth""" async def _start(self) -> None: """Start client GSS key exchange authentication""" if self._conn.gss_kex_auth_requested(): self.logger.debug1('Trying GSS key exchange auth') await self.send_request(key=self._conn.get_gss_context(), trivial=False) else: self._conn.try_next_auth(next_method=True) class _ClientGSSMICAuth(ClientAuth): """Client side implementation of GSS MIC auth""" _handler_names = get_symbol_names(globals(), 'MSG_USERAUTH_GSSAPI_') def __init__(self, conn: 'SSHClientConnection', method: bytes): super().__init__(conn, method) self._gss: Optional[GSSBase] = None self._got_error = False async def _start(self) -> None: """Start client GSS MIC authentication""" if self._conn.gss_mic_auth_requested(): self.logger.debug1('Trying GSS MIC auth') self._gss = self._conn.get_gss_context() self._gss.reset() mechs = b''.join(String(mech) for mech in self._gss.mechs) await self.send_request(UInt32(len(self._gss.mechs)), mechs) else: self._conn.try_next_auth(next_method=True) def _finish(self) -> None: """Finish client GSS MIC authentication""" assert self._gss is not None if self._gss.provides_integrity: data = self._conn.get_userauth_request_data(self._method) self.send_packet(MSG_USERAUTH_GSSAPI_MIC, String(self._gss.sign(data)), trivial=False) else: self.send_packet(MSG_USERAUTH_GSSAPI_EXCHANGE_COMPLETE) async def _process_response(self, _pkttype: int, _pktid: int, packet: SSHPacket) -> None: """Process a GSS response from the server""" mech = packet.get_string() packet.check_end() assert self._gss is not None if mech not in self._gss.mechs: raise ProtocolError('Mechanism mismatch') try: token = await run_in_executor(self._gss.step) assert token is not None self.send_packet(MSG_USERAUTH_GSSAPI_TOKEN, String(token)) if self._gss.complete: self._finish() except GSSError as exc: if exc.token: self.send_packet(MSG_USERAUTH_GSSAPI_ERRTOK, String(exc.token)) self._conn.try_next_auth(next_method=True) async def _process_token(self, _pkttype: int, _pktid: int, packet: SSHPacket) -> None: """Process a GSS token from the server""" token: Optional[bytes] = packet.get_string() packet.check_end() assert self._gss is not None try: token = await run_in_executor(self._gss.step, token) if token: self.send_packet(MSG_USERAUTH_GSSAPI_TOKEN, String(token)) if self._gss.complete: self._finish() except GSSError as exc: if exc.token: self.send_packet(MSG_USERAUTH_GSSAPI_ERRTOK, String(exc.token)) self._conn.try_next_auth(next_method=True) def _process_error(self, _pkttype: int, _pktid: int, packet: SSHPacket) -> None: """Process a GSS error from the server""" _ = packet.get_uint32() # major_status _ = packet.get_uint32() # minor_status msg = packet.get_string() _ = packet.get_string() # lang packet.check_end() self.logger.debug1('GSS error from server: %s', msg) self._got_error = True async def _process_error_token(self, _pkttype: int, _pktid: int, packet: SSHPacket) -> None: """Process a GSS error token from the server""" token = packet.get_string() packet.check_end() assert self._gss is not None try: await run_in_executor(self._gss.step, token) 
except GSSError as exc: if not self._got_error: # pragma: no cover self.logger.debug1('GSS error from server: %s', str(exc)) _packet_handlers = { MSG_USERAUTH_GSSAPI_RESPONSE: _process_response, MSG_USERAUTH_GSSAPI_TOKEN: _process_token, MSG_USERAUTH_GSSAPI_ERROR: _process_error, MSG_USERAUTH_GSSAPI_ERRTOK: _process_error_token } class _ClientHostBasedAuth(ClientAuth): """Client side implementation of host based auth""" async def _start(self) -> None: """Start client host based authentication""" keypair, client_host, client_username = \ await self._conn.host_based_auth_requested() if keypair is None: self._conn.try_next_auth(next_method=True) return self.logger.debug1('Trying host based auth of user %s on host %s ' 'with %s host key', client_username, client_host, keypair.algorithm) try: await self.send_request(String(keypair.algorithm), String(keypair.public_data), String(client_host), String(client_username), key=keypair) except ValueError as exc: self.logger.debug1('Host based auth failed: %s', str(exc)) self._conn.try_next_auth() class _ClientPublicKeyAuth(ClientAuth): """Client side implementation of public key auth""" _handler_names = get_symbol_names(globals(), 'MSG_USERAUTH_PK_') async def _start(self) -> None: """Start client public key authentication""" self._keypair = await self._conn.public_key_auth_requested() if self._keypair is None: self._conn.try_next_auth(next_method=True) return self.logger.debug1('Trying public key auth with %s key', self._keypair.algorithm) await self.send_request(Boolean(False), String(self._keypair.algorithm), String(self._keypair.public_data)) async def _send_signed_request(self) -> None: """Send signed public key request""" assert self._keypair is not None self.logger.debug1('Signing request with %s key', self._keypair.algorithm) try: await self.send_request(Boolean(True), String(self._keypair.algorithm), String(self._keypair.public_data), key=self._keypair, trivial=False) except ValueError as exc: self.logger.debug1('Public key auth failed: %s', str(exc)) self._conn.try_next_auth() def _process_public_key_ok(self, _pkttype: int, _pktid: int, packet: SSHPacket) -> None: """Process a public key ok response""" algorithm = packet.get_string() key_data = packet.get_string() packet.check_end() assert self._keypair is not None if (algorithm != self._keypair.algorithm or key_data != self._keypair.public_data): raise ProtocolError('Key mismatch') self.create_task(self._send_signed_request()) _packet_handlers = { MSG_USERAUTH_PK_OK: _process_public_key_ok } class _ClientKbdIntAuth(ClientAuth): """Client side implementation of keyboard-interactive auth""" _handler_names = get_symbol_names(globals(), 'MSG_USERAUTH_INFO_') async def _start(self) -> None: """Start client keyboard interactive authentication""" submethods = await self._conn.kbdint_auth_requested() if submethods is None: self._conn.try_next_auth(next_method=True) return self.logger.debug1('Trying keyboard-interactive auth') await self.send_request(String(''), String(submethods)) async def _receive_challenge(self, name: str, instruction: str, lang: str, prompts: KbdIntPrompts) -> None: """Receive and respond to a keyboard interactive challenge""" responses = \ await self._conn.kbdint_challenge_received(name, instruction, lang, prompts) if responses is None: self._conn.try_next_auth(next_method=True) return self.send_packet(MSG_USERAUTH_INFO_RESPONSE, UInt32(len(responses)), b''.join(String(r) for r in responses), trivial=not responses) def _process_info_request(self, _pkttype: int, _pktid: int, 
packet: SSHPacket) -> None: """Process a keyboard interactive authentication request""" name_bytes = packet.get_string() instruction_bytes = packet.get_string() lang_bytes = packet.get_string() try: name = name_bytes.decode('utf-8') instruction = instruction_bytes.decode('utf-8') lang = lang_bytes.decode('ascii') except UnicodeDecodeError: raise ProtocolError('Invalid keyboard interactive ' 'info request') from None num_prompts = packet.get_uint32() prompts = [] for _ in range(num_prompts): prompt_bytes = packet.get_string() echo = packet.get_boolean() try: prompt = prompt_bytes.decode('utf-8') except UnicodeDecodeError: raise ProtocolError('Invalid keyboard interactive ' 'info request') from None prompts.append((prompt, echo)) self.create_task(self._receive_challenge(name, instruction, lang, prompts)) _packet_handlers = { MSG_USERAUTH_INFO_REQUEST: _process_info_request } class _ClientPasswordAuth(ClientAuth): """Client side implementation of password auth""" _handler_names = get_symbol_names(globals(), 'MSG_USERAUTH_PASSWD_') def __init__(self, conn: 'SSHClientConnection', method: bytes): super().__init__(conn, method) self._password_change = False async def _start(self) -> None: """Start client password authentication""" password = await self._conn.password_auth_requested() if password is None: self._conn.try_next_auth(next_method=True) return self.logger.debug1('Trying password auth') await self.send_request(Boolean(False), String(password), trivial=False) async def _change_password(self, prompt: str, lang: str) -> None: """Start password change""" result = await self._conn.password_change_requested(prompt, lang) if result == NotImplemented: # Password change not supported - move on to the next auth method self._conn.try_next_auth(next_method=True) return self.logger.debug1('Trying to chsnge password') old_password, new_password = cast(PasswordChangeResponse, result) self._password_change = True await self.send_request(Boolean(True), String(old_password.encode('utf-8')), String(new_password.encode('utf-8')), trivial=False) def auth_succeeded(self) -> None: if self._password_change: self._password_change = False self._conn.password_changed() def auth_failed(self) -> None: if self._password_change: self._password_change = False self._conn.password_change_failed() def _process_password_change(self, _pkttype: int, _pktid: int, packet: SSHPacket) -> None: """Process a password change request""" prompt_bytes = packet.get_string() lang_bytes = packet.get_string() try: prompt = prompt_bytes.decode('utf-8') lang = lang_bytes.decode('ascii') except UnicodeDecodeError: raise ProtocolError('Invalid password change request') from None self.auth_failed() self.create_task(self._change_password(prompt, lang)) _packet_handlers = { MSG_USERAUTH_PASSWD_CHANGEREQ: _process_password_change } class ServerAuth(Auth): """Parent class for server authentication""" _conn: 'SSHServerConnection' def __init__(self, conn: 'SSHServerConnection', username: str, method: bytes, packet: SSHPacket): self._username = username self._method = method super().__init__(conn, self._start(packet)) @classmethod def supported(cls, conn: 'SSHServerConnection') -> bool: """Return whether this authentication method is supported""" raise NotImplementedError async def _start(self, packet: SSHPacket) -> None: """Abstract method for starting server authentication""" # Provided by subclass raise NotImplementedError def send_failure(self, partial_success: bool = False) -> None: """Send a user authentication failure response""" 
self._conn.send_userauth_failure(partial_success) def send_success(self) -> None: """Send a user authentication success response""" self._conn.send_userauth_success() class _ServerNullAuth(ServerAuth): """Server side implementation of null auth""" @classmethod def supported(cls, conn: 'SSHServerConnection') -> bool: """Return that null authentication is never a supported auth mode""" return False async def _start(self, packet: SSHPacket) -> None: """Supported always returns false, so we never get here""" class _ServerGSSKexAuth(ServerAuth): """Server side implementation of GSS key exchange auth""" def __init__(self, conn: 'SSHServerConnection', username: str, method: bytes, packet: SSHPacket): super().__init__(conn, username, method, packet) self._gss = conn.get_gss_context() @classmethod def supported(cls, conn: 'SSHServerConnection') -> bool: """Return whether GSS key exchange authentication is supported""" return conn.gss_kex_auth_supported() async def _start(self, packet: SSHPacket) -> None: """Start server GSS key exchange authentication""" mic = packet.get_string() packet.check_end() self.logger.debug1('Trying GSS key exchange auth') data = self._conn.get_userauth_request_data(self._method) if (self._gss.complete and self._gss.verify(data, mic) and (await self._conn.validate_gss_principal(self._username, self._gss.user, self._gss.host))): self.send_success() else: self.send_failure() class _ServerGSSMICAuth(ServerAuth): """Server side implementation of GSS MIC auth""" _handler_names = get_symbol_names(globals(), 'MSG_USERAUTH_GSSAPI_') def __init__(self, conn: 'SSHServerConnection', username: str, method: bytes, packet: SSHPacket) -> None: super().__init__(conn, username, method, packet) self._gss = conn.get_gss_context() @classmethod def supported(cls, conn: 'SSHServerConnection') -> bool: """Return whether GSS MIC authentication is supported""" return conn.gss_mic_auth_supported() async def _start(self, packet: SSHPacket) -> None: """Start server GSS MIC authentication""" mechs = set() n = packet.get_uint32() for _ in range(n): mechs.add(packet.get_string()) packet.check_end() match = None for mech in self._gss.mechs: if mech in mechs: match = mech break if not match: self.send_failure() return self.logger.debug1('Trying GSS MIC auth') self._gss.reset() self.send_packet(MSG_USERAUTH_GSSAPI_RESPONSE, String(match)) async def _finish(self) -> None: """Finish server GSS MIC authentication""" if (await self._conn.validate_gss_principal(self._username, self._gss.user, self._gss.host)): self.send_success() else: self.send_failure() async def _process_token(self, _pkttype: int, _pktid: int, packet: SSHPacket) -> None: """Process a GSS token from the client""" token: Optional[bytes] = packet.get_string() packet.check_end() try: token = await run_in_executor(self._gss.step, token) if token: self.send_packet(MSG_USERAUTH_GSSAPI_TOKEN, String(token)) except GSSError as exc: self.send_packet(MSG_USERAUTH_GSSAPI_ERROR, UInt32(exc.maj_code), UInt32(exc.min_code), String(str(exc)), String(DEFAULT_LANG)) if exc.token: self.send_packet(MSG_USERAUTH_GSSAPI_ERRTOK, String(exc.token)) self.send_failure() def _process_exchange_complete(self, _pkttype: int, _pktid: int, packet: SSHPacket) -> None: """Process a GSS exchange complete message from the client""" packet.check_end() if self._gss.complete and not self._gss.provides_integrity: self.create_task(self._finish()) else: self.send_failure() async def _process_error_token(self, _pkttype: int, _pktid: int, packet: SSHPacket) -> None: """Process a GSS 
error token from the client""" token = packet.get_string() packet.check_end() try: await run_in_executor(self._gss.step, token) except GSSError as exc: self.logger.debug1('GSS error from client: %s', str(exc)) def _process_mic(self, _pkttype: int, _pktid: int, packet: SSHPacket) -> None: """Process a GSS MIC from the client""" mic = packet.get_string() packet.check_end() data = self._conn.get_userauth_request_data(self._method) if (self._gss.complete and self._gss.provides_integrity and self._gss.verify(data, mic)): self.create_task(self._finish()) else: self.send_failure() _packet_handlers = { MSG_USERAUTH_GSSAPI_TOKEN: _process_token, MSG_USERAUTH_GSSAPI_EXCHANGE_COMPLETE: _process_exchange_complete, MSG_USERAUTH_GSSAPI_ERRTOK: _process_error_token, MSG_USERAUTH_GSSAPI_MIC: _process_mic } class _ServerHostBasedAuth(ServerAuth): """Server side implementation of host based auth""" @classmethod def supported(cls, conn: 'SSHServerConnection') -> bool: """Return whether host based authentication is supported""" return conn.host_based_auth_supported() async def _start(self, packet: SSHPacket) -> None: """Start server host based authentication""" algorithm = packet.get_string() key_data = packet.get_string() client_host_bytes = packet.get_string() client_username_bytes = packet.get_string() msg = packet.get_consumed_payload() signature = packet.get_string() packet.check_end() try: client_host = client_host_bytes.decode('utf-8') client_username = saslprep(client_username_bytes.decode('utf-8')) except (UnicodeDecodeError, SASLPrepError): raise ProtocolError('Invalid host-based auth request') from None self.logger.debug1('Verifying host based auth of user %s ' 'on host %s with %s host key', client_username, client_host, algorithm) if (await self._conn.validate_host_based_auth(self._username, key_data, client_host, client_username, msg, signature)): self.send_success() else: self.send_failure() class _ServerPublicKeyAuth(ServerAuth): """Server side implementation of public key auth""" @classmethod def supported(cls, conn: 'SSHServerConnection') -> bool: """Return whether public key authentication is supported""" return conn.public_key_auth_supported() async def _start(self, packet: SSHPacket) -> None: """Start server public key authentication""" sig_present = packet.get_boolean() algorithm = packet.get_string() key_data = packet.get_string() if sig_present: msg = packet.get_consumed_payload() signature = packet.get_string() else: msg = b'' signature = b'' packet.check_end() if sig_present: self.logger.debug1('Verifying request with %s key', algorithm) else: self.logger.debug1('Trying public key auth with %s key', algorithm) if (await self._conn.validate_public_key(self._username, key_data, msg, signature)): if sig_present: self.send_success() else: self.send_packet(MSG_USERAUTH_PK_OK, String(algorithm), String(key_data)) else: self.send_failure() class _ServerKbdIntAuth(ServerAuth): """Server side implementation of keyboard-interactive auth""" _handler_names = get_symbol_names(globals(), 'MSG_USERAUTH_INFO_') @classmethod def supported(cls, conn: 'SSHServerConnection') -> bool: """Return whether keyboard interactive authentication is supported""" return conn.kbdint_auth_supported() async def _start(self, packet: SSHPacket) -> None: """Start server keyboard interactive authentication""" lang_bytes = packet.get_string() submethods_bytes = packet.get_string() packet.check_end() try: lang = lang_bytes.decode('ascii') submethods = submethods_bytes.decode('utf-8') except UnicodeDecodeError: raise 
ProtocolError('Invalid keyboard interactive ' 'auth request') from None self.logger.debug1('Trying keyboard-interactive auth') challenge = await self._conn.get_kbdint_challenge(self._username, lang, submethods) self._send_challenge(challenge) def _send_challenge(self, challenge: KbdIntChallenge) -> None: """Send a keyboard interactive authentication request""" if isinstance(challenge, (tuple, list)): name, instruction, lang, prompts = challenge num_prompts = len(prompts) prompts_bytes = (String(prompt) + Boolean(echo) for prompt, echo in prompts) self.send_packet(MSG_USERAUTH_INFO_REQUEST, String(name), String(instruction), String(lang), UInt32(num_prompts), *prompts_bytes) elif challenge: self.send_success() else: self.send_failure() async def _validate_response(self, responses: KbdIntResponse) -> None: """Validate a keyboard interactive authentication response""" next_challenge = \ await self._conn.validate_kbdint_response(self._username, responses) self._send_challenge(next_challenge) def _process_info_response(self, _pkttype: int, _pktid: int, packet: SSHPacket) -> None: """Process a keyboard interactive authentication response""" num_responses = packet.get_uint32() responses = [] for _ in range(num_responses): response_bytes = packet.get_string() try: response = response_bytes.decode('utf-8') except UnicodeDecodeError: raise ProtocolError('Invalid keyboard interactive ' 'info response') from None responses.append(response) packet.check_end() self.create_task(self._validate_response(responses)) _packet_handlers = { MSG_USERAUTH_INFO_RESPONSE: _process_info_response } class _ServerPasswordAuth(ServerAuth): """Server side implementation of password auth""" @classmethod def supported(cls, conn: 'SSHServerConnection') -> bool: """Return whether password authentication is supported""" return conn.password_auth_supported() async def _start(self, packet: SSHPacket) -> None: """Start server password authentication""" password_change = packet.get_boolean() password_bytes = packet.get_string() new_password_bytes = packet.get_string() if password_change else b'' packet.check_end() try: password = saslprep(password_bytes.decode('utf-8')) new_password = saslprep(new_password_bytes.decode('utf-8')) except (UnicodeDecodeError, SASLPrepError): raise ProtocolError('Invalid password auth request') from None try: if password_change: self.logger.debug1('Trying to change password') result = await self._conn.change_password(self._username, password, new_password) else: self.logger.debug1('Trying password auth') result = \ await self._conn.validate_password(self._username, password) if result: self.send_success() else: self.send_failure() except PasswordChangeRequired as exc: self.send_packet(MSG_USERAUTH_PASSWD_CHANGEREQ, String(exc.prompt), String(exc.lang)) def register_auth_method(alg: bytes, client_handler: Type[ClientAuth], server_handler: Type[ServerAuth]) -> None: """Register an authentication method""" _auth_methods.append(alg) _client_auth_handlers[alg] = client_handler _server_auth_handlers[alg] = server_handler def get_supported_client_auth_methods() -> Sequence[bytes]: """Return a list of supported client auth methods""" return [method for method in _client_auth_handlers if method != b'none'] def lookup_client_auth(conn: 'SSHClientConnection', method: bytes) -> Optional[ClientAuth]: """Look up the client authentication method to use""" if method in _auth_methods: return _client_auth_handlers[method](conn, method) else: return None def get_supported_server_auth_methods(conn: 
'SSHServerConnection') -> \ Sequence[bytes]: """Return a list of supported server auth methods""" auth_methods = [] for method in _auth_methods: if _server_auth_handlers[method].supported(conn): auth_methods.append(method) return auth_methods def lookup_server_auth(conn: 'SSHServerConnection', username: str, method: bytes, packet: SSHPacket) -> \ Optional[ServerAuth]: """Look up the server authentication method to use""" handler = _server_auth_handlers.get(method) if handler and handler.supported(conn): return handler(conn, username, method, packet) else: conn.send_userauth_failure(False) return None _auth_method_list = ( (b'none', _ClientNullAuth, _ServerNullAuth), (b'gssapi-keyex', _ClientGSSKexAuth, _ServerGSSKexAuth), (b'gssapi-with-mic', _ClientGSSMICAuth, _ServerGSSMICAuth), (b'hostbased', _ClientHostBasedAuth, _ServerHostBasedAuth), (b'publickey', _ClientPublicKeyAuth, _ServerPublicKeyAuth), (b'keyboard-interactive', _ClientKbdIntAuth, _ServerKbdIntAuth), (b'password', _ClientPasswordAuth, _ServerPasswordAuth) ) for _args in _auth_method_list: register_auth_method(*_args) asyncssh-2.20.0/asyncssh/auth_keys.py000066400000000000000000000270661475467777400176760ustar00rootroot00000000000000# Copyright (c) 2015-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation """Parser for SSH authorized_keys files""" from typing import Dict, List, Mapping, Optional, Sequence from typing import Set, Tuple, Union, cast try: # pylint: disable=unused-import from .crypto import X509Name, X509NamePattern _x509_available = True except ImportError: # pragma: no cover _x509_available = False from .misc import ip_address, read_file from .pattern import HostPatternList, WildcardPatternList from .public_key import KeyImportError, SSHKey from .public_key import SSHX509Certificate, SSHX509CertificateChain from .public_key import import_public_key, import_certificate from .public_key import import_certificate_subject _EntryOptions = Mapping[str, object] class _SSHAuthorizedKeyEntry: """An entry in an SSH authorized_keys list""" def __init__(self, line: str): self.key: Optional[SSHKey] = None self.cert: Optional[SSHX509Certificate] = None self.options: Dict[str, object] = {} try: self._import_key_or_cert(line) return except KeyImportError: pass line = self._parse_options(line) self._import_key_or_cert(line) def _import_key_or_cert(self, line: str) -> None: """Import key or certificate in this entry""" try: self.key = import_public_key(line) return except KeyImportError: pass try: self.cert = cast(SSHX509Certificate, import_certificate(line)) if ('cert-authority' in self.options and self.cert.subject != self.cert.issuer): raise ValueError('X.509 cert-authority entries must ' 'contain a root CA certificate') return except KeyImportError: pass if 'cert-authority' not in self.options: try: self.key = None self.cert = None self._add_subject('subject', import_certificate_subject(line)) return except KeyImportError: 
pass raise KeyImportError('Unrecognized key, certificate, or subject') def _set_string(self, option: str, value: str) -> None: """Set an option with a string value""" self.options[option] = value def _add_environment(self, option: str, value: str) -> None: """Add an environment key/value pair""" if value.startswith('=') or '=' not in value: raise ValueError('Invalid environment entry in authorized_keys') name, value = value.split('=', 1) cast(Dict[str, str], self.options.setdefault(option, {}))[name] = value def _add_from(self, option: str, value: str) -> None: """Add a from host pattern""" from_patterns = cast(List[HostPatternList], self.options.setdefault(option, [])) from_patterns.append(HostPatternList(value)) def _add_permitopen(self, option: str, value: str) -> None: """Add a permitopen host/port pair""" try: host, port_str = value.rsplit(':', 1) if host.startswith('[') and host.endswith(']'): host = host[1:-1] port = None if port_str == '*' else int(port_str) except ValueError: raise ValueError(f'Illegal permitopen value: {value}') from None permitted_opens = cast(Set[Tuple[str, Optional[int]]], self.options.setdefault(option, set())) permitted_opens.add((host, port)) def _add_principals(self, option: str, value: str) -> None: """Add a principals wildcard pattern list""" principal_patterns = cast(List[WildcardPatternList], self.options.setdefault(option, [])) principal_patterns.append(WildcardPatternList(value)) def _add_subject(self, option: str, value: str) -> None: """Add an X.509 subject pattern""" if _x509_available: # pragma: no branch subject_patterns = cast(List[X509NamePattern], self.options.setdefault(option, [])) subject_patterns.append(X509NamePattern(value)) _handlers = { 'command': _set_string, 'environment': _add_environment, 'from': _add_from, 'permitopen': _add_permitopen, 'principals': _add_principals, 'subject': _add_subject } def _add_option(self) -> None: """Add an option value""" if self._option.startswith('='): raise ValueError('Missing option name in authorized_keys') if '=' in self._option: option, value = self._option.split('=', 1) handler = self._handlers.get(option) if handler: handler(self, option, value) else: values = cast(List[str], self.options.setdefault(option, [])) values.append(value) else: self.options[self._option] = True def _parse_options(self, line: str) -> str: """Parse options in this entry""" self._option = '' idx = 0 quoted = False escaped = False for idx, ch in enumerate(line): if escaped: self._option += ch escaped = False elif ch == '\\': escaped = True elif ch == '"': quoted = not quoted elif quoted: self._option += ch elif ch in ' \t': break elif ch == ',': self._add_option() self._option = '' else: self._option += ch self._add_option() if quoted: raise ValueError('Unbalanced quote in authorized_keys') elif escaped: raise ValueError('Unbalanced backslash in authorized_keys') return line[idx:].strip() def match_options(self, client_host: str, client_addr: str, cert_principals: Optional[Sequence[str]], cert_subject: Optional['X509Name'] = None) -> bool: """Match "from", "principals" and "subject" options in entry""" from_patterns = cast(List[HostPatternList], self.options.get('from')) if from_patterns: client_ip = ip_address(client_addr) if not all(pattern.matches(client_host, client_addr, client_ip) for pattern in from_patterns): return False principal_patterns = cast(List[WildcardPatternList], self.options.get('principals')) if cert_principals is not None and principal_patterns is not None: if not all(any(pattern.matches(principal) 
for principal in cert_principals) for pattern in principal_patterns): return False subject_patterns = cast(List['X509NamePattern'], self.options.get('subject')) if cert_subject is not None and subject_patterns is not None: if not all(pattern.matches(cert_subject) for pattern in subject_patterns): return False return True class SSHAuthorizedKeys: """An SSH authorized keys list""" def __init__(self, authorized_keys: Optional[str] = None): self._user_entries: List[_SSHAuthorizedKeyEntry] = [] self._ca_entries: List[_SSHAuthorizedKeyEntry] = [] self._x509_entries: List[_SSHAuthorizedKeyEntry] = [] if authorized_keys: self.load(authorized_keys) def load(self, authorized_keys: str) -> None: """Load authorized keys data into this object""" for line in authorized_keys.splitlines(): line = line.strip() if not line or line.startswith('#'): continue try: entry = _SSHAuthorizedKeyEntry(line) except KeyImportError: continue if entry.key: if 'cert-authority' in entry.options: self._ca_entries.append(entry) else: self._user_entries.append(entry) else: self._x509_entries.append(entry) if (not self._user_entries and not self._ca_entries and not self._x509_entries): raise ValueError('No valid entries found') def validate(self, key: SSHKey, client_host: str, client_addr: str, cert_principals: Optional[Sequence[str]] = None, ca: bool = False) -> Optional[Mapping[str, object]]: """Return whether a public key or CA is valid for authentication""" for entry in self._ca_entries if ca else self._user_entries: if (entry.key == key and entry.match_options(client_host, client_addr, cert_principals)): return entry.options return None def validate_x509(self, cert: SSHX509CertificateChain, client_host: str, client_addr: str) -> Tuple[Optional[_EntryOptions], Optional[SSHX509Certificate]]: """Return whether an X.509 certificate is valid for authentication""" for entry in self._x509_entries: if (entry.cert and 'cert-authority' not in entry.options and (cert.key != entry.cert.key or cert.subject != entry.cert.subject)): continue # pragma: no cover (work around bug in coverage tool) if entry.match_options(client_host, client_addr, cert.user_principals, cert.subject): return entry.options, entry.cert return None, None def import_authorized_keys(data: str) -> SSHAuthorizedKeys: """Import SSH authorized keys This function imports public keys and associated options in OpenSSH authorized keys format. :param data: The key data to import. :type data: `str` :returns: An :class:`SSHAuthorizedKeys` object """ return SSHAuthorizedKeys(data) def read_authorized_keys(filelist: Union[str, Sequence[str]]) -> \ SSHAuthorizedKeys: """Read SSH authorized keys from a file or list of files This function reads public keys and associated options in OpenSSH authorized_keys format from a file or list of files. :param filelist: The file or list of files to read the keys from. :type filelist: `str` or `list` of `str` :returns: An :class:`SSHAuthorizedKeys` object """ authorized_keys = SSHAuthorizedKeys() if isinstance(filelist, str): files: Sequence[str] = [filelist] else: files = filelist for filename in files: authorized_keys.load(read_file(filename, 'r')) return authorized_keys asyncssh-2.20.0/asyncssh/channel.py000066400000000000000000002366451475467777400173150ustar00rootroot00000000000000# Copyright (c) 2013-2024 by Ron Frederick and others. 
# # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation """SSH channel and session handlers""" import asyncio import binascii import codecs import inspect import re import signal as _signal from types import MappingProxyType from typing import TYPE_CHECKING, Any, AnyStr, Awaitable, Callable from typing import Dict, Generic, Iterable, List, Mapping, Optional from typing import Set, Tuple, Union, cast from . import constants from .constants import DEFAULT_LANG, EXTENDED_DATA_STDERR from .constants import MSG_CHANNEL_OPEN, MSG_CHANNEL_WINDOW_ADJUST from .constants import MSG_CHANNEL_DATA, MSG_CHANNEL_EXTENDED_DATA from .constants import MSG_CHANNEL_EOF, MSG_CHANNEL_CLOSE, MSG_CHANNEL_REQUEST from .constants import MSG_CHANNEL_SUCCESS, MSG_CHANNEL_FAILURE from .constants import OPEN_CONNECT_FAILED, PTY_OP_RESERVED, PTY_OP_END from .constants import OPEN_REQUEST_X11_FORWARDING_FAILED from .constants import OPEN_REQUEST_PTY_FAILED, OPEN_REQUEST_SESSION_FAILED from .editor import SSHLineEditorChannel, SSHLineEditorSession from .logging import SSHLogger from .misc import ChannelOpenError, EnvMap, MaybeAwait, ProtocolError from .misc import TermModes, TermSize, TermSizeArg from .misc import decode_env, encode_env, get_symbol_names, map_handler_name from .packet import Boolean, Byte, String, UInt32, SSHPacket, SSHPacketHandler from .session import SSHSession, SSHClientSession, SSHServerSession from .session import SSHTCPSession, SSHUNIXSession, SSHTunTapSession from .session import SSHSessionFactory, SSHClientSessionFactory from .session import SSHTCPSessionFactory, SSHUNIXSessionFactory from .session import SSHTunTapSessionFactory from .stream import DataType from .tuntap import SSH_TUN_MODE_POINTTOPOINT, SSH_TUN_UNIT_ANY from .tuntap import SSH_TUN_AF_INET, SSH_TUN_AF_INET6 if TYPE_CHECKING: # pylint: disable=cyclic-import from .connection import SSHConnection, SSHClientConnection from .connection import SSHServerConnection _const_dict: Mapping[str, int] = constants.__dict__ _pty_mode_names = get_symbol_names(_const_dict, 'PTY_', 4) _data_type_names = get_symbol_names(_const_dict, 'EXTENDED_DATA_', 14) _signal_regex = re.compile(r'SIG[^_]') _signal_numbers = {k[3:]: int(v) for (k, v) in vars(_signal).items() if _signal_regex.match(k)} _signal_names = {v: k for (k, v) in _signal_numbers.items()} _ExitSignal = Tuple[str, bool, str, str] _RequestHandler = Optional[Callable[[SSHPacket], Optional[bool]]] class SSHChannel(Generic[AnyStr], SSHPacketHandler): """Parent class for SSH channels""" _handler_names = get_symbol_names(globals(), 'MSG_CHANNEL_') _read_datatypes: Set[int] = set() _write_datatypes: Set[int] = set() def __init__(self, conn: 'SSHConnection', loop: asyncio.AbstractEventLoop, encoding: Optional[str], errors: str, window: int, max_pktsize: int): """Initialize an SSH channel If encoding is set, data sent and received will be in the form of strings, converted on the wire to bytes using the specified encoding. 
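# --- Illustrative usage sketch (not part of the AsyncSSH sources): how the channel
# encoding chosen at session creation shows up in the public API. The host name and
# command below are placeholders; a reachable server and configured credentials are
# assumed.

import asyncssh

async def encoding_demo() -> None:
    async with asyncssh.connect('example.com') as conn:
        text = await conn.run('echo hello')                # default encoding='utf-8'
        assert isinstance(text.stdout, str)                # data exchanged as str

        raw = await conn.run('echo hello', encoding=None)  # no encoding
        assert isinstance(raw.stdout, bytes)               # data exchanged as bytes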
If encoding is None, data sent and received must be provided as bytes. Window specifies the initial receive window size. Max_pktsize specifies the maximum length of a single data packet. """ self._conn: Optional['SSHConnection'] = conn self._loop = loop self._session: Optional[SSHSession[AnyStr]] = None self._extra: Dict[str, object] = {'connection': conn} self._encoding: Optional[str] self._errors: str self._send_high_water: int self._send_low_water: int self._env: Dict[bytes, bytes] = {} self._str_env: Optional[Dict[str, str]] = None self._command: Optional[str] = None self._subsystem: Optional[str] = None self._send_state = 'closed' self._send_chan: Optional[int] = None self._send_window: int = 0 self._send_pktsize: int = 0 self._send_paused = False self._send_buf: List[Tuple[bytearray, DataType]] = [] self._send_buf_len = 0 self._recv_state = 'closed' self._init_recv_window = window self._recv_window = window self._recv_pktsize = max_pktsize self._recv_paused: Union[bool, str] = 'starting' self._recv_buf: List[Tuple[bytes, DataType]] = [] self._request_queue: List[Tuple[str, SSHPacket, bool]] = [] self._open_waiter: Optional[asyncio.Future[SSHPacket]] = None self._request_waiters: List[asyncio.Future[bool]] = [] self._close_event = asyncio.Event() self._recv_chan: Optional[int] = conn.add_channel(self) self._logger = conn.logger.get_child(context=f'chan={self._recv_chan}') self.set_encoding(encoding, errors) self.set_write_buffer_limits() @property def logger(self) -> SSHLogger: """A logger associated with this channel""" return self._logger def get_connection(self) -> 'SSHConnection': """Return the connection used by this channel""" assert self._conn is not None return self._conn def get_loop(self) -> asyncio.AbstractEventLoop: """Return the event loop used by this channel""" return self._loop def get_encoding(self) -> Tuple[Optional[str], str]: """Return the encoding used by this channel""" return self._encoding, self._errors def set_encoding(self, encoding: Optional[str], errors: str = 'strict') -> None: """Set the encoding on this channel""" self._encoding = encoding self._errors = errors if encoding: self._encoder: Optional[codecs.IncrementalEncoder] = \ codecs.getincrementalencoder(encoding)(errors) self._decoder: Optional[codecs.IncrementalDecoder] = \ codecs.getincrementaldecoder(encoding)(errors) else: self._encoder = None self._decoder = None def get_recv_window(self) -> int: """Return the configured receive window for this channel""" return self._init_recv_window def get_read_datatypes(self) -> Set[int]: """Return the legal read data types for this channel""" return self._read_datatypes def get_write_datatypes(self) -> Set[int]: """Return the legal write data types for this channel""" return self._write_datatypes def _cleanup(self, exc: Optional[Exception] = None) -> None: """Clean up this channel""" if self._open_waiter: if not self._open_waiter.cancelled(): # pragma: no branch self._open_waiter.set_exception( ChannelOpenError(OPEN_CONNECT_FAILED, 'SSH connection closed')) self._open_waiter = None if self._request_waiters: for waiter in self._request_waiters: if not waiter.cancelled(): # pragma: no cover if exc: waiter.set_exception(exc) else: waiter.set_result(False) self._request_waiters = [] if self._session is not None: self._session.connection_lost(exc) self._session = None self._close_event.set() if self._conn: # pragma: no branch self.logger.info('Channel closed%s', ': ' + str(exc) if exc else '') self._conn.detach_x11_listener(self) assert self._recv_chan is not 
None self._conn.remove_channel(self._recv_chan) self._send_chan = None self._recv_chan = None self._conn = None def _close_send(self) -> None: """Discard unsent data and close the channel for sending""" # Discard unsent data self._send_buf = [] self._send_buf_len = 0 if self._send_state != 'closed': self.send_packet(MSG_CHANNEL_CLOSE) self._send_chan = None self._send_state = 'closed' def _discard_recv(self) -> None: """Discard unreceived data and clean up if close received""" # Discard unreceived data self._recv_buf = [] self._recv_paused = False # If recv is close_pending, we know send is already closed if self._recv_state == 'close_pending': self._recv_state = 'closed' self._loop.call_soon(self._cleanup) async def _start_reading(self) -> None: """Start processing data on a new connection""" # If owner of the channel didn't explicitly pause it at # startup, begin processing incoming data. if self._recv_paused == 'starting': self.logger.debug2('Reading from channel started') self._recv_paused = False self._flush_recv_buf() def _pause_resume_writing(self) -> None: """Pause or resume writing based on send buffer low/high water marks""" if self._send_paused: if self._send_buf_len <= self._send_low_water: self.logger.debug2('Writing from session resumed') self._send_paused = False assert self._session is not None self._session.resume_writing() else: if self._send_buf_len > self._send_high_water: self.logger.debug2('Writing from session paused') self._send_paused = True assert self._session is not None self._session.pause_writing() def _flush_send_buf(self) -> None: """Flush as much data in send buffer as the send window allows""" while self._send_buf and self._send_window: pktsize = min(self._send_window, self._send_pktsize) buf, datatype = self._send_buf[0] if len(buf) > pktsize: data = buf[:pktsize] del buf[:pktsize] else: data = buf del self._send_buf[0] self._send_buf_len -= len(data) self._send_window -= len(data) if datatype is None: self.send_packet(MSG_CHANNEL_DATA, String(data)) else: self.send_packet(MSG_CHANNEL_EXTENDED_DATA, UInt32(datatype), String(data)) self._pause_resume_writing() if not self._send_buf: if self._send_state == 'eof_pending': self.send_packet(MSG_CHANNEL_EOF) self._send_state = 'eof' elif self._send_state == 'close_pending': self._close_send() def _flush_recv_buf(self, exc: Optional[Exception] = None) -> None: """Flush as much data in the recv buffer as the application allows""" while self._recv_buf and not self._recv_paused: self._deliver_data(*self._recv_buf.pop(0)) if not self._recv_buf and self._recv_paused != 'starting': if self._encoding and not exc and \ self._recv_state in ('eof_pending', 'close_pending'): try: assert self._decoder is not None self._decoder.decode(b'', True) except UnicodeDecodeError as unicode_exc: raise ProtocolError(str(unicode_exc)) from None if self._recv_state == 'eof_pending': self._recv_state = 'eof' assert self._session is not None if (not self._session.eof_received() and self._send_state == 'open'): self.write_eof() if not self._recv_buf and self._recv_state == 'close_pending': self._recv_state = 'closed' self._loop.call_soon(self._cleanup, exc) def _deliver_data(self, data: bytes, datatype: DataType) -> None: """Deliver incoming data to the session""" self._recv_window -= len(data) if self._recv_window < self._init_recv_window / 2: adjust = self._init_recv_window - self._recv_window self.logger.debug2('Sending window adjust of %d bytes, ' 'new window %d', adjust, self._init_recv_window) 
self.send_packet(MSG_CHANNEL_WINDOW_ADJUST, UInt32(adjust)) self._recv_window = self._init_recv_window if self._encoding: try: assert self._decoder is not None decoded_data = cast(AnyStr, self._decoder.decode(data)) except UnicodeDecodeError as unicode_exc: raise ProtocolError(str(unicode_exc)) from None else: decoded_data = cast(AnyStr, data) if self._session is not None: self._session.data_received(decoded_data, datatype) def _accept_data(self, data: bytes, datatype: DataType = None) -> None: """Accept new data on the channel This method accepts new data on the channel, immediately delivering it to the session if it hasn't paused reading. If it has paused, data is buffered until reading is resumed. Data sent after the channel has been closed by the session is dropped. """ if not data: return if self._send_state in {'close_pending', 'closed'}: return if self._recv_paused: self._recv_buf.append((data, datatype)) else: self._deliver_data(data, datatype) def _service_next_request(self) -> None: """Process next item on channel request queue""" request, packet, _ = self._request_queue[0] name = '_process_' + map_handler_name(request) + '_request' handler = cast(_RequestHandler, getattr(self, name, None)) if handler: result = cast(Optional[bool], handler(packet)) else: self.logger.debug1('Received unknown channel request: %s', request) result = False if result is not None: self._report_response(result) def _report_response(self, result: bool) -> None: """Report back the response to a previously issued channel request""" request, _, want_reply = self._request_queue.pop(0) if want_reply and self._send_state not in {'close_pending', 'closed'}: if result: self.send_packet(MSG_CHANNEL_SUCCESS) else: self.send_packet(MSG_CHANNEL_FAILURE) if result and request in {'shell', 'exec', 'subsystem'}: assert self._session is not None self._session.session_started() self.resume_reading() if self._request_queue: self._service_next_request() def process_connection_close(self, exc: Optional[Exception]) -> None: """Process the SSH connection closing""" self.logger.info('Closing channel due to connection close') self._send_state = 'closed' self._close_send() self._cleanup(exc) def process_open(self, send_chan: int, send_window: int, send_pktsize: int, session: MaybeAwait[SSHSession[AnyStr]]) -> None: """Process a channel open request""" self._send_chan = send_chan self._send_window = send_window self._send_pktsize = send_pktsize self.logger.debug2(' Initial send window %d, packet size %d', send_window, send_pktsize) assert self._conn is not None self._conn.create_task(self._finish_open_request(session), self.logger) def _wrap_session(self, session: SSHSession[AnyStr]) -> \ Tuple['SSHChannel[AnyStr]', SSHSession[AnyStr]]: """Hook to optionally wrap channel and session objects""" # By default, return the original channel and session objects return self, session async def _finish_open_request( self, result: MaybeAwait[SSHSession[AnyStr]]) -> None: """Finish processing a channel open request""" try: if inspect.isawaitable(result): session = await cast(Awaitable[SSHSession[AnyStr]], result) else: session = cast(SSHSession[AnyStr], result) if not self._conn: raise ChannelOpenError(OPEN_CONNECT_FAILED, 'SSH connection closed') chan, self._session = self._wrap_session(session) self.logger.debug2(' Initial recv window %d, packet size %d', self._recv_window, self._recv_pktsize) assert self._send_chan is not None assert self._recv_chan is not None self._conn.send_channel_open_confirmation(self._send_chan, self._recv_chan, 
self._recv_window, self._recv_pktsize) self._send_state = 'open' self._recv_state = 'open' self._session.connection_made(chan) except ChannelOpenError as exc: if self._conn: assert self._send_chan is not None self._conn.send_channel_open_failure(self._send_chan, exc.code, exc.reason, exc.lang) self._loop.call_soon(self._cleanup) def process_open_confirmation(self, send_chan: int, send_window: int, send_pktsize: int, packet: SSHPacket) -> None: """Process a channel open confirmation""" if not self._open_waiter: raise ProtocolError('Channel not being opened') self._send_chan = send_chan self._send_window = send_window self._send_pktsize = send_pktsize self.logger.debug2(' Initial send window %d, packet size %d', send_window, send_pktsize) self._send_state = 'open' self._recv_state = 'open' if not self._open_waiter.cancelled(): # pragma: no branch self._open_waiter.set_result(packet) self._open_waiter = None def process_open_failure(self, code: int, reason: str, lang: str) -> None: """Process a channel open failure""" if not self._open_waiter: raise ProtocolError('Channel not being opened') if not self._open_waiter.cancelled(): # pragma: no branch self._open_waiter.set_exception( ChannelOpenError(code, reason, lang)) self._open_waiter = None self._loop.call_soon(self._cleanup) def _process_window_adjust(self, _pkttype: int, _pktid: int, packet: SSHPacket) -> None: """Process a send window adjustment""" if self._recv_state not in {'open', 'eof_pending', 'eof'}: raise ProtocolError('Channel not open') adjust = packet.get_uint32() packet.check_end() self._send_window += adjust self.logger.debug2('Received window adjust of %d bytes, ' 'new window %d', adjust, self._send_window) self._flush_send_buf() def _process_data(self, _pkttype: int, _pktid: int, packet: SSHPacket) -> None: """Process incoming data""" if self._recv_state != 'open': raise ProtocolError('Channel not open for sending') data = packet.get_string() packet.check_end() datalen = len(data) if datalen > self._recv_window: raise ProtocolError('Window exceeded') self.logger.debug2('Received %d data byte%s', datalen, 's' if datalen > 1 else '') self._accept_data(data) def _process_extended_data(self, _pkttype: int, _pktid: int, packet: SSHPacket) -> None: """Process incoming extended data""" if self._recv_state != 'open': raise ProtocolError('Channel not open for sending') datatype = packet.get_uint32() data = packet.get_string() packet.check_end() if datatype not in self._read_datatypes: raise ProtocolError('Invalid extended data type') datalen = len(data) if datalen > self._recv_window: raise ProtocolError('Window exceeded') self.logger.debug2('Received %d data byte%s from %s', datalen, 's' if datalen > 1 else '', _data_type_names[datatype]) self._accept_data(data, datatype) def _process_eof(self, _pkttype: int, _pktid: int, packet: SSHPacket) -> None: """Process an incoming end of file""" if self._recv_state != 'open': raise ProtocolError('Channel not open for sending') packet.check_end() self.logger.debug2('Received EOF') self._recv_state = 'eof_pending' self._flush_recv_buf() def _process_close(self, _pkttype: int, _pktid: int, packet: SSHPacket) -> None: """Process an incoming channel close""" if self._recv_state not in {'open', 'eof_pending', 'eof'}: raise ProtocolError('Channel not open') packet.check_end() self.logger.info('Received channel close') self._close_send() self._recv_state = 'close_pending' self._flush_recv_buf() def _process_request(self, _pkttype: int, _pktid: int, packet: SSHPacket) -> None: """Process an incoming 
channel request""" if self._recv_state not in {'open', 'eof_pending', 'eof'}: raise ProtocolError('Channel not open') request_bytes = packet.get_string() want_reply = packet.get_boolean() try: request = request_bytes.decode('ascii') except UnicodeDecodeError: raise ProtocolError('Invalid channel request') from None self._request_queue.append((request, packet, want_reply)) if len(self._request_queue) == 1: self._service_next_request() def _process_response(self, pkttype: int, _pktid: int, packet: SSHPacket) -> None: """Process a success or failure response""" packet.check_end() if self._request_waiters: waiter = self._request_waiters.pop(0) if not waiter.cancelled(): # pragma: no branch waiter.set_result(pkttype == MSG_CHANNEL_SUCCESS) else: raise ProtocolError('Unexpected channel response') def _process_keepalive_at_openssh_dot_com_request( self, packet: SSHPacket) -> bool: """Process an incoming OpenSSH keepalive request""" packet.check_end() self.logger.debug2('Received OpenSSH keepalive channel request') return False _packet_handlers = { MSG_CHANNEL_WINDOW_ADJUST: _process_window_adjust, MSG_CHANNEL_DATA: _process_data, MSG_CHANNEL_EXTENDED_DATA: _process_extended_data, MSG_CHANNEL_EOF: _process_eof, MSG_CHANNEL_CLOSE: _process_close, MSG_CHANNEL_REQUEST: _process_request, MSG_CHANNEL_SUCCESS: _process_response, MSG_CHANNEL_FAILURE: _process_response } async def _open(self, chantype: bytes, *args: bytes) -> SSHPacket: """Make a request to open the channel""" if self._send_state != 'closed': raise OSError('Channel already open') self._open_waiter = self._loop.create_future() self.logger.debug2(' Initial recv window %d, packet size %d', self._recv_window, self._recv_pktsize) assert self._conn is not None assert self._recv_chan is not None self._conn.send_packet(MSG_CHANNEL_OPEN, String(chantype), UInt32(self._recv_chan), UInt32(self._recv_window), UInt32(self._recv_pktsize), *args, handler=self) return await self._open_waiter def send_packet(self, pkttype: int, *args: bytes) -> None: """Send a packet on the channel""" if self._send_chan is None: # pragma: no cover return payload = UInt32(self._send_chan) + b''.join(args) assert self._conn is not None self._conn.send_packet(pkttype, payload, handler=self) def _send_request(self, request: bytes, *args: bytes, want_reply: bool = False) -> None: """Send a channel request""" self.send_packet(MSG_CHANNEL_REQUEST, String(request), Boolean(want_reply), *args) async def _make_request(self, request: bytes, *args: bytes) -> Optional[bool]: """Make a channel request and wait for the response""" if self._send_chan is None: return False waiter = self._loop.create_future() self._request_waiters.append(waiter) self._send_request(request, *args, want_reply=True) return await waiter def abort(self) -> None: """Forcibly close the channel This method can be called to forcibly close the channel, after which no more data can be sent or received. Any unsent buffered data and any incoming data in flight will be discarded. """ self.logger.info('Aborting channel') if self._send_state not in {'close_pending', 'closed'}: # Send an immediate close, discarding unsent data self._close_send() if self._recv_state != 'closed': # Discard unreceived data self._discard_recv() def close(self) -> None: """Cleanly close the channel This method can be called to cleanly close the channel, after which no more data can be sent or received. Any unsent buffered data will be flushed asynchronously before the channel is closed. 
""" self.logger.info('Closing channel') if self._send_state not in {'close_pending', 'closed'}: # Send a close only after sending unsent data self._send_state = 'close_pending' self._flush_send_buf() if self._recv_state != 'closed': # Discard unreceived data self._discard_recv() def is_closing(self) -> bool: """Return if the channel is closing or is closed""" return self._send_state != 'open' async def wait_closed(self) -> None: """Wait for this channel to close This method is a coroutine which can be called to block until this channel has finished closing. """ await self._close_event.wait() def get_extra_info(self, name: str, default: Any = None) -> Any: """Get additional information about the channel This method returns extra information about the channel once it is established. Supported values include `'connection'` to return the SSH connection this channel is running over plus all of the values supported on that connection. For TCP channels, the values `'local_peername'` and `'remote_peername'` are added to return the local and remote host and port information for the tunneled TCP connection. For UNIX channels, the values `'local_peername'` and `'remote_peername'` are added to return the local and remote path information for the tunneled UNIX domain socket connection. Since UNIX domain sockets provide no "source" address, only one of these will be filled in. See :meth:`get_extra_info() ` on :class:`SSHClientConnection` for more information. Additional information stored on the channel by calling :meth:`set_extra_info` can also be returned here. """ return self._extra.get(name, self._conn.get_extra_info(name, default) if self._conn else default) def set_extra_info(self, **kwargs: Any) -> None: """Store additional information associated with the channel This method allows extra information to be associated with the channel. The information to store should be passed in as keyword parameters and can later be returned by calling :meth:`get_extra_info` with one of the keywords as the name to retrieve. """ self._extra.update(**kwargs) def can_write_eof(self) -> bool: """Return whether the channel supports :meth:`write_eof` This method always returns `True`. """ # pylint: disable=no-self-use return True def get_write_buffer_size(self) -> int: """Return the current size of the channel's output buffer This method returns how many bytes are currently in the channel's output buffer waiting to be written. """ return self._send_buf_len def set_write_buffer_limits(self, high: Optional[int] = None, low: Optional[int] = None) -> None: """Set the high- and low-water limits for write flow control This method sets the limits used when deciding when to call the :meth:`pause_writing() ` and :meth:`resume_writing() ` methods on SSH sessions. Writing will be paused when the write buffer size exceeds the high-water mark, and resumed when the write buffer size equals or drops below the low-water mark. """ if high is None: high = 4*low if low is not None else 65536 if low is None: low = high // 4 if not 0 <= low <= high: raise ValueError(f'high (high) must be >= low ({low}) ' 'must be >= 0') self.logger.debug1('Set write buffer limits: low-water=%d, ' 'high-water=%d', low, high) self._send_high_water = high self._send_low_water = low self._pause_resume_writing() def write(self, data: AnyStr, datatype: DataType = None) -> None: """Write data on the channel This method can be called to send data on the channel. 
If an encoding was specified when the channel was created, the data should be provided as a string and will be converted using that encoding. Otherwise, the data should be provided as bytes. An extended data type can optionally be provided. For instance, this is used from a :class:`SSHServerSession` to write data to `stderr`. :param data: The data to send on the channel :param datatype: (optional) The extended data type of the data, from :ref:`extended data types ` :type data: `str` or `bytes` :type datatype: `int` :raises: :exc:`OSError` if the channel isn't open for sending or the extended data type is not valid for this type of channel """ if self._send_state != 'open': raise BrokenPipeError('Channel not open for sending') if datatype is not None and datatype not in self._write_datatypes: raise OSError('Invalid extended data type') if not data: return if self._encoding: assert self._encoder is not None encoded_data = self._encoder.encode(cast(str, data)) else: encoded_data = cast(bytes, data) datalen = len(encoded_data) if datatype: typename = f' to {_data_type_names[datatype]}' else: typename = '' self.logger.debug2('Sending %d data byte%s%s', datalen, 's' if datalen > 1 else '', typename) self._send_buf.append((bytearray(encoded_data), datatype)) self._send_buf_len += datalen self._flush_send_buf() def writelines(self, list_of_data: Iterable[AnyStr], datatype: DataType = None) -> None: """Write a list of data bytes on the channel This method can be called to write a list (or any iterable) of data bytes to the channel. It is functionally equivalent to calling :meth:`write` on each element in the list. :param list_of_data: The data to send on the channel :param datatype: (optional) The extended data type of the data, from :ref:`extended data types ` :type list_of_data: iterable of `str` or `bytes` :type datatype: `int` :raises: :exc:`OSError` if the channel isn't open for sending or the extended data type is not valid for this type of channel """ if self._encoding: data = cast(AnyStr, ''.join(cast(Iterable[str], list_of_data))) else: data = cast(AnyStr, b''.join(cast(Iterable[bytes], list_of_data))) return self.write(data, datatype) def write_eof(self) -> None: """Write EOF on the channel This method sends an end-of-file indication on the channel, after which no more data can be sent. The channel remains open, though, and data may still be sent in the other direction. :raises: :exc:`OSError` if the channel isn't open for sending """ self.logger.debug2('Sending EOF') if self._send_state == 'open': self._send_state = 'eof_pending' self._flush_send_buf() def pause_reading(self) -> None: """Pause delivery of incoming data This method is used to temporarily suspend delivery of incoming channel data. After this call, incoming data will no longer be delivered until :meth:`resume_reading` is called. Data will be buffered locally up to the configured SSH channel window size, but window updates will no longer be sent, eventually causing back pressure on the remote system. .. note:: Channel close notifications are not suspended by this call. If the remote system closes the channel while delivery is suspended, the channel will be closed even though some buffered data may not have been delivered. """ self.logger.debug2('Reading from channel paused') self._recv_paused = True def resume_reading(self) -> None: """Resume delivery of incoming data This method can be called to resume delivery of incoming data which was suspended by a call to :meth:`pause_reading`.
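# --- Illustrative usage sketch (not part of the AsyncSSH sources): a client session
# that applies back pressure by pausing delivery and resuming it later. The buffering
# threshold and class name are arbitrary and purely for illustration.

import asyncssh

class ThrottledSession(asyncssh.SSHClientSession):
    def connection_made(self, chan: asyncssh.SSHClientChannel) -> None:
        self._chan = chan
        self._pending = 0

    def data_received(self, data: str, datatype) -> None:
        self._pending += len(data)
        if self._pending > 1_000_000:
            self._chan.pause_reading()    # window updates stop; remote side backs off

    def process_pending(self) -> None:
        self._pending = 0
        self._chan.resume_reading()       # buffered data is delivered immediately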
As soon as this method is called, any buffered data will be delivered immediately. A pending end-of-file notification may also be delivered if one was queued while reading was paused. """ if self._recv_paused: self.logger.debug2('Reading from channel resumed') self._recv_paused = False self._flush_recv_buf() def get_environment(self) -> Mapping[str, str]: """Return the environment for this session This method returns the environment set by the client when the session was opened. Keys and values are of type `str` and this object only provides access to keys and values sent as valid UTF-8 strings. Use :meth:`get_environment_bytes` if you need to access environment variables with keys or values containing binary data or non-UTF-8 encodings. On the server, calls to this method should only be made after :meth:`session_started ` has been called on the :class:`SSHServerSession`. When using the stream-based API, calls to this can be made at any time after the handler function has started up. :returns: A dictionary containing the environment variables set by the client """ if self._str_env is None: self._str_env = dict(decode_env(self._env)) return MappingProxyType(self._str_env) def get_environment_bytes(self) -> Mapping[bytes, bytes]: """Return the environment for this session This method returns the environment set by the client when the session was opened. Keys and values are of type `bytes` and can include arbitrary binary data, with the exception of NUL (\0) bytes. On the server, calls to this method should only be made after :meth:`session_started ` has been called on the :class:`SSHServerSession`. When using the stream-based API, calls to this can be made at any time after the handler function has started up. :returns: A dictionary containing the environment variables set by the client """ return MappingProxyType(self._env) def get_command(self) -> Optional[str]: """Return the command the client requested to execute, if any This method returns the command the client requested to execute when the session was opened, if any. If the client did not request that a command be executed, this method will return `None`. On the server, calls to this method should only be made after :meth:`session_started ` has been called on the :class:`SSHServerSession`. When using the stream-based API, calls to this can be made at any time after the handler function has started up. """ return self._command def get_subsystem(self) -> Optional[str]: """Return the subsystem the client requested to open, if any This method returns the subsystem the client requested to open when the session was opened, if any. If the client did not request that a subsystem be opened, this method will return `None`. On the server, calls to this method should only be made after :meth:`session_started ` has been called on the :class:`SSHServerSession`. When using the stream-based API, calls to this can be made at any time after the handler function has started up. 
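# --- Illustrative usage sketch (not part of the AsyncSSH sources): a server session
# inspecting what the client asked for. The class name and responses are placeholders;
# these getters are safe to call once session_started() has fired.

import asyncssh

class InspectingServerSession(asyncssh.SSHServerSession):
    def connection_made(self, chan: asyncssh.SSHServerChannel) -> None:
        self._chan = chan

    def shell_requested(self) -> bool:
        return True

    def exec_requested(self, command: str) -> bool:
        return True

    def session_started(self) -> None:
        command = self._chan.get_command()        # None for an interactive shell
        env = self._chan.get_environment()        # str keys/values (valid UTF-8 only)
        self._chan.write(f'command={command!r}, TERM={env.get("TERM")!r}\n')
        self._chan.exit(0)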
""" return self._subsystem class SSHClientChannel(SSHChannel, Generic[AnyStr]): """SSH client channel""" _conn: 'SSHClientConnection' _session: SSHClientSession[AnyStr] _read_datatypes = {EXTENDED_DATA_STDERR} def __init__(self, conn: 'SSHClientConnection', loop: asyncio.AbstractEventLoop, encoding: Optional[str], errors: str, window: int, max_pktsize: int): super().__init__(conn, loop, encoding, errors, window, max_pktsize) self._exit_status: Optional[int] = None self._exit_signal: Optional[_ExitSignal] = None async def create(self, session_factory: SSHClientSessionFactory[AnyStr], command: Optional[str], subsystem: Optional[str], env: Dict[bytes, bytes], request_pty: bool, term_type: Optional[str], term_size: TermSizeArg, term_modes: TermModes, x11_forwarding: Union[bool, str], x11_display: Optional[str], x11_auth_path: Optional[str], x11_single_connection: bool, agent_forwarding: bool) -> SSHClientSession[AnyStr]: """Create an SSH client session""" self.logger.info('Requesting new SSH session') packet = await self._open(b'session') # Client sessions should have no extra data in the open confirmation packet.check_end() self._session = session_factory() self._session.connection_made(self) self._env = env self._command = command self._subsystem = subsystem for key, value in env.items(): self.logger.debug1(' Env: %s=%s', key, value) if not isinstance(key, (bytes, str)): key = str(key) if not isinstance(value, (bytes, str)): value = str(value) self._send_request(b'env', String(key), String(value)) if request_pty: self.logger.debug1(' Terminal type: %s', term_type or 'None') if not term_size: width = height = pixwidth = pixheight = 0 elif len(term_size) == 2: width, height = cast(Tuple[int, int], term_size) pixwidth = pixheight = 0 self.logger.debug1(' Terminal size: %sx%s', width, height) elif len(term_size) == 4: width, height, pixwidth, pixheight = cast(TermSize, term_size) self.logger.debug1(' Terminal size: %sx%s (%sx%s pixels)', width, height, pixwidth, pixheight) else: raise ValueError('If set, terminal size must be a tuple of ' '2 or 4 integers') modes = b'' for mode, mode_value in term_modes.items(): if mode <= PTY_OP_END or mode >= PTY_OP_RESERVED: raise ValueError(f'Invalid pty mode: {mode}') name = _pty_mode_names.get(mode, str(mode)) self.logger.debug2(' Mode %s: %d', name, mode_value) modes += Byte(mode) + UInt32(mode_value) modes += Byte(PTY_OP_END) if not (await self._make_request(b'pty-req', String(term_type or ''), UInt32(width), UInt32(height), UInt32(pixwidth), UInt32(pixheight), String(modes))): self.close() raise ChannelOpenError(OPEN_REQUEST_PTY_FAILED, 'PTY request failed') if x11_forwarding: self.logger.debug1(' X11 forwarding enabled') try: attach_result: Optional[Tuple[bytes, bytes, int]] = \ await self._conn.attach_x11_listener( self, x11_display, x11_auth_path, x11_single_connection) except ValueError as exc: if x11_forwarding != 'ignore_failure': raise ChannelOpenError(OPEN_REQUEST_X11_FORWARDING_FAILED, str(exc)) from None else: attach_result = None self.logger.info(' X11 forwarding attach failure ignored') if attach_result: auth_proto, remote_auth, screen = attach_result result = await self._make_request( b'x11-req', Boolean(x11_single_connection), String(auth_proto), String(binascii.b2a_hex(remote_auth)), UInt32(screen)) if not result: if self._conn: # pragma: no branch self._conn.detach_x11_listener(self) if x11_forwarding != 'ignore_failure': raise ChannelOpenError( OPEN_REQUEST_X11_FORWARDING_FAILED, 'X11 forwarding request failed') else: self.logger.info( 
' X11 forwarding request failure ignored') if agent_forwarding: self.logger.debug1(' Agent forwarding enabled') self._send_request(b'auth-agent-req@openssh.com') if command: self.logger.info(' Command: %s', command) result = await self._make_request(b'exec', String(command)) elif subsystem: self.logger.info(' Subsystem: %s', subsystem) result = await self._make_request(b'subsystem', String(subsystem)) else: self.logger.info(' Interactive shell requested') result = await self._make_request(b'shell') if not result: self.close() raise ChannelOpenError(OPEN_REQUEST_SESSION_FAILED, 'Session request failed') self._session.session_started() self._conn.create_task(self._start_reading(), self.logger) return self._session def _process_xon_xoff_request(self, packet: SSHPacket) -> bool: """Process a request to set up XON/XOFF processing""" client_can_do = packet.get_boolean() packet.check_end() self.logger.info('Received XON/XOFF flow control %s request', 'enable' if client_can_do else 'disable') self._session.xon_xoff_requested(client_can_do) return True def _process_exit_status_request(self, packet: SSHPacket) -> bool: """Process a request to deliver exit status""" status = packet.get_uint32() & 0xff packet.check_end() self.logger.info('Received exit status %d', status) self._exit_status = status self._session.exit_status_received(status) return True def _process_exit_signal_request(self, packet: SSHPacket) -> bool: """Process a request to deliver an exit signal""" signal_bytes = packet.get_string() core_dumped = packet.get_boolean() msg_bytes = packet.get_string() lang_bytes = packet.get_string() packet.check_end() try: signal = signal_bytes.decode('ascii') msg = msg_bytes.decode('utf-8') lang = lang_bytes.decode('ascii') except UnicodeDecodeError: raise ProtocolError('Invalid exit signal request') from None self.logger.info('Received exit signal %s', signal) self.logger.debug1(' Core dumped: %s', core_dumped) self.logger.debug1(' Message: %s', msg) self._exit_signal = (signal, core_dumped, msg, lang) self._session.exit_signal_received(signal, core_dumped, msg, lang) return True def get_exit_status(self) -> Optional[int]: """Return the session's exit status This method returns the exit status of the session if one has been sent. If an exit signal was sent, this method returns -1 and the exit signal information can be collected by calling :meth:`get_exit_signal`. If neither has been sent, this method returns `None`. """ if self._exit_status is not None: return self._exit_status elif self._exit_signal: return -1 else: return None def get_exit_signal(self) -> Optional[_ExitSignal]: """Return the session's exit signal, if one was sent This method returns information about the exit signal sent on this session. If an exit signal was sent, a tuple is returned containing the signal name, a boolean for whether a core dump occurred, a message associated with the signal, and the language the message was in. Otherwise, this method returns `None`. """ return self._exit_signal def get_returncode(self) -> Optional[int]: """Return the session's exit status or signal This method returns the exit status of the session if one has been sent. If an exit signal was sent, this method returns the negative of the numeric value of that signal, matching the behavior of :meth:`asyncio.SubprocessTransport.get_returncode`. If neither has been sent, this method returns `None`. 
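# --- Illustrative usage sketch (not part of the AsyncSSH sources): checking how a
# session ended. `conn` is assumed to be an established SSHClientConnection; the
# command is a placeholder that simply exits with status 1.

import asyncssh

async def report_exit(conn: asyncssh.SSHClientConnection) -> None:
    chan, session = await conn.create_session(asyncssh.SSHClientSession, 'false')
    await chan.wait_closed()

    if chan.get_exit_status() == -1:
        info = chan.get_exit_signal()          # (signal, core_dumped, msg, lang)
        print('terminated by signal', info[0] if info else '?')
    else:
        print('exit status', chan.get_exit_status())   # 1 for `false`
        print('return code', chan.get_returncode())    # same value for a normal exit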
:returns: `int` or `None` """ if self._exit_status is not None: return self._exit_status elif self._exit_signal: return -_signal_numbers.get(self._exit_signal[0], 99) else: return None def change_terminal_size(self, width: int, height: int, pixwidth: int = 0, pixheight: int = 0) -> None: """Change the terminal window size for this session This method changes the width and height of the terminal associated with this session. :param width: The width of the terminal in characters :param height: The height of the terminal in characters :param pixwidth: (optional) The width of the terminal in pixels :param pixheight: (optional) The height of the terminal in pixels :type width: `int` :type height: `int` :type pixwidth: `int` :type pixheight: `int` """ if pixwidth or pixheight: self.logger.info('Sending window size change: %sx%s (%sx%s pixels)', width, height, pixwidth, pixheight) else: self.logger.info('Sending window size change: %sx%s', width, height) self._send_request(b'window-change', UInt32(width), UInt32(height), UInt32(pixwidth), UInt32(pixheight)) def send_break(self, msec: int) -> None: """Send a break to the remote process This method requests that the server perform a break operation on the remote process or service as described in :rfc:`4335`. :param msec: The duration of the break in milliseconds :type msec: `int` :raises: :exc:`OSError` if the channel is not open """ self.logger.info('Sending %d msec break', msec) self._send_request(b'break', UInt32(msec)) def send_signal(self, signal: Union[str, int]) -> None: """Send a signal to the remote process This method can be called to deliver a signal to the remote process or service. Signal names should be as described in section 6.10 of :rfc:`RFC 4254 <4254#section-6.10>`, or can be integer values as defined in the :mod:`signal` module, in which case they will be translated to their corresponding signal name before being sent. .. note:: OpenSSH's SSH server implementation prior to version 7.9 does not support this message, so attempts to use :meth:`send_signal`, :meth:`terminate`, or :meth:`kill` with an older OpenSSH SSH server will end up being ignored. This was tracked in OpenSSH `bug 1424`__. __ https://bugzilla.mindrot.org/show_bug.cgi?id=1424 :param signal: The signal to deliver :type signal: `str` or `int` :raises: | :exc:`OSError` if the channel is not open | :exc:`ValueError` if the signal number is unknown """ if isinstance(signal, int): try: signal = _signal_names[signal] except KeyError: raise ValueError(f'Unknown signal: {signal}') from None self.logger.info('Sending %s signal', signal) self._send_request(b'signal', String(signal)) def terminate(self) -> None: """Terminate the remote process This method can be called to terminate the remote process or service by sending it a `TERM` signal. :raises: :exc:`OSError` if the channel is not open .. note:: If your server-side runs on OpenSSH, this might be ineffective; for more details, see the note in :meth:`send_signal` """ self.send_signal('TERM') def kill(self) -> None: """Forcibly kill the remote process This method can be called to forcibly stop the remote process or service by sending it a `KILL` signal. :raises: :exc:`OSError` if the channel is not open .. 
note:: If your server-side runs on OpenSSH, this might be ineffective; for more details, see the note in :meth:`send_signal` """ self.send_signal('KILL') class SSHServerChannel(SSHChannel, Generic[AnyStr]): """SSH server channel""" _conn: 'SSHServerConnection' _session: SSHServerSession[AnyStr] _write_datatypes = {EXTENDED_DATA_STDERR} def __init__(self, conn: 'SSHServerConnection', loop: asyncio.AbstractEventLoop, allow_pty: bool, line_editor: bool, line_echo: bool, line_history: int, max_line_length: int, encoding: Optional[str], errors: str, window: int, max_pktsize: int): """Initialize an SSH server channel""" super().__init__(conn, loop, encoding, errors, window, max_pktsize) env_opt = cast(EnvMap, conn.get_key_option('environment', {})) self._env = dict(encode_env(env_opt)) self._allow_pty = allow_pty self._line_editor = line_editor self._line_echo = line_echo self._line_history = line_history self._max_line_length = max_line_length self._term_type: Optional[str] = None self._term_size = (0, 0, 0, 0) self._term_modes: TermModes = {} self._x11_display: Optional[str] = None self.logger.info('New SSH session requested') def _wrap_session(self, session: SSHSession[AnyStr]) -> \ Tuple[SSHChannel[AnyStr], SSHSession[AnyStr]]: """Wrap a line editor around the session if enabled""" if self._line_editor: server_chan = cast(SSHServerChannel[str], self) server_session = cast(SSHServerSession[str], session) editor_chan = SSHLineEditorChannel(server_chan, server_session, self._line_echo, self._line_history, self._max_line_length) editor_session = SSHLineEditorSession(editor_chan, server_session) chan = cast(SSHChannel[AnyStr], editor_chan) session = cast(SSHSession[AnyStr], editor_session) else: chan = self return chan, session def _process_pty_req_request(self, packet: SSHPacket) -> bool: """Process a request to open a pseudo-terminal""" term_type_bytes = packet.get_string() width = packet.get_uint32() height = packet.get_uint32() pixwidth = packet.get_uint32() pixheight = packet.get_uint32() modes = packet.get_string() packet.check_end() if not self._allow_pty or \ not self._conn.check_key_permission('pty') or \ not self._conn.check_certificate_permission('pty'): self.logger.info('PTY request denied: PTY not permitted') return False try: term_type = term_type_bytes.decode('ascii') except UnicodeDecodeError: raise ProtocolError('Invalid pty request') from None term_size = (width, height, pixwidth, pixheight) term_modes = {} self.logger.debug1(' Terminal type: %s', term_type) if pixwidth or pixheight: self.logger.debug1(' Terminal size: %sx%s (%sx%s pixels)', width, height, pixwidth, pixheight) else: self.logger.debug1(' Terminal size: %sx%s', width, height) idx = 0 while idx < len(modes): mode = modes[idx] idx += 1 if mode == PTY_OP_END or mode >= PTY_OP_RESERVED: break if idx+4 <= len(modes): name = _pty_mode_names.get(mode, str(mode)) value = int.from_bytes(modes[idx:idx+4], 'big') self.logger.debug2(' Mode %s: %s', name, value) term_modes[mode] = value idx += 4 else: raise ProtocolError('Invalid pty modes string') result = self._session.pty_requested(term_type, term_size, term_modes) if result: self.logger.info(' PTY created') if self._line_editor: self.logger.info(' Line editor enabled') self._term_type = term_type self._term_size = term_size self._term_modes = term_modes else: self.logger.info(' PTY creation failed') return result def _process_x11_req_request(self, packet: SSHPacket) -> Optional[bool]: """Process request to enable X11 forwarding""" _ = packet.get_boolean() # 
single_connection auth_proto = packet.get_string() auth_data = packet.get_string() screen = packet.get_uint32() packet.check_end() try: auth_data = binascii.a2b_hex(auth_data) except binascii.Error: self.logger.debug1(' Invalid X11 auth data') return False self._conn.create_task(self._finish_x11_req_request(auth_proto, auth_data, screen), self.logger) return None async def _finish_x11_req_request(self, auth_proto: bytes, auth_data: bytes, screen: int) -> None: """Finish processing request to enable X11 forwarding""" self._x11_display = await self._conn.attach_x11_listener( self, auth_proto, auth_data, screen) if self._x11_display: self.logger.debug1(' X11 forwarding enabled') self._report_response(True) else: self.logger.debug1(' X11 forwarding failed') self._report_response(False) def _process_auth_agent_req_at_openssh_dot_com_request( self, packet: SSHPacket) -> None: """Process a request to enable ssh-agent forwarding""" packet.check_end() self._conn.create_task(self._finish_agent_req_request(), self.logger) async def _finish_agent_req_request(self) -> None: """Finish processing request to enable agent forwarding""" if await self._conn.create_agent_listener(): self.logger.debug1(' Agent forwarding enabled') self._report_response(True) else: self.logger.debug1(' Agent forwarding failed') self._report_response(False) def _process_env_request(self, packet: SSHPacket) -> bool: """Process a request to set an environment variable""" key = packet.get_string() value = packet.get_string() packet.check_end() self.logger.debug1(' Env: %s=%s', key, value) self._env[key] = value return True def _start_session(self, command: Optional[str] = None, subsystem: Optional[str] = None) -> bool: """Tell the session what type of channel is being requested""" forced_command = \ cast(str, self._conn.get_certificate_option('force-command')) if forced_command is None: forced_command = cast(str, self._conn.get_key_option('command')) if forced_command is not None: self.logger.info(' Forced command override: %s', forced_command) command = forced_command if command is not None: self._command = command result = self._session.exec_requested(command) elif subsystem is not None: self._subsystem = subsystem result = self._session.subsystem_requested(subsystem) else: result = self._session.shell_requested() return result def _process_shell_request(self, packet: SSHPacket) -> bool: """Process a request to open a shell""" packet.check_end() self.logger.info(' Interactive shell requested') return self._start_session() def _process_exec_request(self, packet: SSHPacket) -> bool: """Process a request to execute a command""" command_bytes = packet.get_string() packet.check_end() try: command = command_bytes.decode('utf-8') except UnicodeDecodeError: return False self.logger.info(' Command: %s', command) return self._start_session(command=command) def _process_subsystem_request(self, packet: SSHPacket) -> bool: """Process a request to open a subsystem""" subsystem_bytes = packet.get_string() packet.check_end() try: subsystem = subsystem_bytes.decode('ascii') except UnicodeDecodeError: return False self.logger.info(' Subsystem: %s', subsystem) return self._start_session(subsystem=subsystem) def _process_window_change_request(self, packet: SSHPacket) -> bool: """Process a request to change the window size""" width = packet.get_uint32() height = packet.get_uint32() pixwidth = packet.get_uint32() pixheight = packet.get_uint32() packet.check_end() if pixwidth or pixheight: self.logger.info('Received window change: %sx%s (%sx%s 
pixels)', width, height, pixwidth, pixheight) else: self.logger.info('Received window change: %sx%s', width, height) self._term_size = (width, height, pixwidth, pixheight) self._session.terminal_size_changed(width, height, pixwidth, pixheight) return True def _process_signal_request(self, packet: SSHPacket) -> bool: """Process a request to send a signal""" signal_bytes = packet.get_string() packet.check_end() try: signal = signal_bytes.decode('ascii') except UnicodeDecodeError: return False self.logger.info('Received %s signal', signal) self._session.signal_received(signal) return True def _process_break_request(self, packet: SSHPacket) -> bool: """Process a request to send a break""" msec = packet.get_uint32() packet.check_end() self.logger.info('Received %d msec break', msec) return self._session.break_received(msec) def get_terminal_type(self) -> Optional[str]: """Return the terminal type for this session This method returns the terminal type set by the client when the session was opened. If the client didn't request a pseudo-terminal, this method will return `None`. Calls to this method should only be made after :meth:`session_started ` has been called on the :class:`SSHServerSession`. When using the stream-based API, calls to this can be made at any time after the handler function has started up. :returns: A `str` containing the terminal type or `None` if no pseudo-terminal was requested """ return self._term_type def get_terminal_size(self) -> TermSize: """Return terminal size information for this session This method returns the latest terminal size information set by the client. If the client didn't set any terminal size information, all values returned will be zero. Calls to this method should only be made after :meth:`session_started ` has been called on the :class:`SSHServerSession`. When using the stream-based API, calls to this can be made at any time after the handler function has started up. Also see :meth:`terminal_size_changed() ` or the :exc:`TerminalSizeChanged` exception for how to get notified when the terminal size changes. :returns: A tuple of four `int` values containing the width and height of the terminal in characters and the width and height of the terminal in pixels """ return self._term_size def get_terminal_mode(self, mode: int) -> Optional[int]: """Return the requested TTY mode for this session This method looks up the value of a POSIX terminal mode set by the client when the session was opened. If the client didn't request a pseudo-terminal or didn't set the requested TTY mode opcode, this method will return `None`. Calls to this method should only be made after :meth:`session_started ` has been called on the :class:`SSHServerSession`. When using the stream-based API, calls to this can be made at any time after the handler function has started up. :param mode: POSIX terminal mode taken from :ref:`POSIX terminal modes ` to look up :type mode: `int` :returns: An `int` containing the value of the requested POSIX terminal mode or `None` if the requested mode was not set """ return self._term_modes.get(mode) def get_terminal_modes(self) -> TermModes: """Return the TTY modes for this session This method returns a mapping of all the POSIX terminal modes set by the client when the session was opened. If the client didn't request a pseudo-terminal, this method will return an empty mapping. Calls to this method should only be made after :meth:`session_started ` has been called on the :class:`SSHServerSession`. 
When using the stream-based API, calls to this can be made at any time after the handler function has started up. :returns: A mapping containing all the POSIX terminal modes set by the client or an empty mapping if no pseudo-terminal was requested """ return MappingProxyType(self._term_modes) def get_x11_display(self) -> Optional[str]: """Return the display to use for X11 forwarding When X11 forwarding has been requested by the client, this method returns the X11 display which should be used to open a forwarded connection. If the client did not request X11 forwarding, this method returns `None`. :returns: A `str` containing the X11 display or `None` if X11 forwarding was not requested """ return self._x11_display def get_agent_path(self) -> Optional[str]: """Return the path of the ssh-agent listening socket When agent forwarding has been requested by the client, this method returns the path of the listening socket which should be used to open a forwarded agent connection. If the client did not request agent forwarding, this method returns `None`. :returns: A `str` containing the ssh-agent socket path or `None` if agent forwarding was not requested """ return self._conn.get_agent_path() def set_xon_xoff(self, client_can_do: bool) -> None: """Set whether the client should enable XON/XOFF flow control This method can be called to tell the client whether or not to enable XON/XOFF flow control, indicating that it should intercept Control-S and Control-Q coming from its local terminal to pause and resume output, respectively. Applications should set client_can_do to `True` to enable this functionality or to `False` to tell the client to forward Control-S and Control-Q through as normal input. :param client_can_do: Whether or not the client should enable XON/XOFF flow control :type client_can_do: `bool` """ self.logger.info('Sending XON/XOFF flow control %s', 'enable' if client_can_do else 'disable') self._send_request(b'xon-xoff', Boolean(client_can_do)) def write_stderr(self, data: AnyStr) -> None: """Write output to stderr This method can be called to send output to the client which is intended to be displayed on stderr. If an encoding was specified when the channel was created, the data should be provided as a string and will be converted using that encoding. Otherwise, the data should be provided as bytes. :param data: The data to send to stderr :type data: `str` or `bytes` :raises: :exc:`OSError` if the channel isn't open for sending """ self.write(data, EXTENDED_DATA_STDERR) def writelines_stderr(self, list_of_data: Iterable[AnyStr]) -> None: """Write a list of data bytes to stderr This method can be called to write a list (or any iterable) of data bytes to the channel. It is functionally equivalent to calling :meth:`write_stderr` on each element in the list. """ self.writelines(list_of_data, EXTENDED_DATA_STDERR) def exit(self, status: int) -> None: """Send exit status and close the channel This method can be called to report an exit status for the process back to the client and close the channel. A zero exit status is generally returned when the operation was successful. After reporting the status, the channel is closed.
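# --- Illustrative usage sketch (not part of the AsyncSSH sources): a server session
# reporting an error on stderr and closing the channel with a nonzero status. The
# class name and message are placeholders.

import asyncssh

class RejectingServerSession(asyncssh.SSHServerSession):
    def connection_made(self, chan: asyncssh.SSHServerChannel) -> None:
        self._chan = chan

    def exec_requested(self, command: str) -> bool:
        return True

    def session_started(self) -> None:
        self._chan.write_stderr('unsupported command\n')   # shows up on the client's stderr
        self._chan.exit(127)                               # status is masked to 8 bits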
:param status: The exit status to report to the client :type status: `int` :raises: :exc:`OSError` if the channel isn't open """ status &= 0xff if self._send_state not in {'close_pending', 'closed'}: self.logger.info('Sending exit status %d', status) self._send_request(b'exit-status', UInt32(status)) self.close() def exit_with_signal(self, signal: str, core_dumped: bool = False, msg: str = '', lang: str = DEFAULT_LANG) -> None: """Send exit signal and close the channel This method can be called to report that the process terminated abnormally with a signal. A more detailed error message may also be provided, along with an indication of whether or not the process dumped core. After reporting the signal, the channel is closed. :param signal: The signal which caused the process to exit :param core_dumped: (optional) Whether or not the process dumped core :param msg: (optional) Details about what error occurred :param lang: (optional) The language the error message is in :type signal: `str` :type core_dumped: `bool` :type msg: `str` :type lang: `str` :raises: :exc:`OSError` if the channel isn't open """ self.logger.info('Sending exit signal %s', signal) self.logger.debug1(' Core dumped: %s', core_dumped) self.logger.debug1(' Message: %s', msg) if self._send_state not in {'close_pending', 'closed'}: self._send_request(b'exit-signal', String(signal), Boolean(core_dumped), String(msg), String(lang)) self.close() class SSHForwardChannel(SSHChannel, Generic[AnyStr]): """SSH channel for forwarding TCP and UNIX domain connections""" async def _finish_open_request( self, result: MaybeAwait[SSHSession[AnyStr]]) -> None: """Finish processing a forward channel open request""" await super()._finish_open_request(result) if self._session is not None: self._session.session_started() self.resume_reading() async def _open_forward(self, session_factory: SSHSessionFactory[AnyStr], chantype: bytes, *args: bytes) -> \ SSHSession[AnyStr]: """Open a forward channel""" packet = await super()._open(chantype, *args) # Forward channels should have no extra data in the open confirmation packet.check_end() self._session = session_factory() self._session.connection_made(self) self._session.session_started() assert self._conn is not None self._conn.create_task(self._start_reading(), self.logger) return self._session class SSHTCPChannel(SSHForwardChannel, Generic[AnyStr]): """SSH TCP channel""" async def _open_tcp(self, session_factory: SSHTCPSessionFactory[AnyStr], chantype: bytes, host: str, port: int, orig_host: str, orig_port: int) -> SSHTCPSession[AnyStr]: """Open a TCP channel""" self.set_extra_info(peername=('', 0), local_peername=(orig_host, orig_port), remote_peername=(host, port)) return cast(SSHTCPSession[AnyStr], await self._open_forward(session_factory, chantype, String(host), UInt32(port), String(orig_host), UInt32(orig_port))) async def connect(self, session_factory: SSHTCPSessionFactory[AnyStr], host: str, port: int, orig_host: str, orig_port: int) -> \ SSHTCPSession[AnyStr]: """Create a new outbound TCP session""" return (await self._open_tcp(session_factory, b'direct-tcpip', host, port, orig_host, orig_port)) async def accept(self, session_factory: SSHTCPSessionFactory[AnyStr], host: str, port: int, orig_host: str, orig_port: int) -> SSHTCPSession[AnyStr]: """Create a new forwarded TCP session""" return (await self._open_tcp(session_factory, b'forwarded-tcpip', host, port, orig_host, orig_port)) def set_inbound_peer_names(self, dest_host: str, dest_port: int, orig_host: str, orig_port: int) -> None: """Set
local and remote peer names for inbound connections""" self.set_extra_info(peername=('', 0), local_peername=(dest_host, dest_port), remote_peername=(orig_host, orig_port)) class SSHUNIXChannel(SSHForwardChannel, Generic[AnyStr]): """SSH UNIX channel""" async def _open_unix(self, session_factory: SSHUNIXSessionFactory[AnyStr], chantype: bytes, path: str, *args: bytes) -> SSHUNIXSession[AnyStr]: """Open a UNIX channel""" self.set_extra_info(local_peername='', remote_peername=path) return cast(SSHUNIXSession[AnyStr], await self._open_forward(session_factory, chantype, String(path), *args)) async def connect(self, session_factory: SSHUNIXSessionFactory[AnyStr], path: str) -> SSHUNIXSession[AnyStr]: """Create a new outbound UNIX session""" # OpenSSH appears to have a bug which requires an originator # host and port to be sent after the path name to connect to # when opening a direct streamlocal channel. return await self._open_unix(session_factory, b'direct-streamlocal@openssh.com', path, String(''), UInt32(0)) async def accept(self, session_factory: SSHUNIXSessionFactory[AnyStr], path: str) -> SSHUNIXSession[AnyStr]: """Create a new forwarded UNIX session""" return await self._open_unix(session_factory, b'forwarded-streamlocal@openssh.com', path, String('')) def set_inbound_peer_names(self, dest_path: str) -> None: """Set local and remote peer names for inbound connections""" self.set_extra_info(local_peername=dest_path, remote_peername='') class SSHTunTapChannel(SSHForwardChannel[bytes]): """SSH TunTap channel""" def __init__(self, conn: 'SSHConnection', loop: asyncio.AbstractEventLoop, encoding: Optional[str], errors: str, window: int, max_pktsize: int): super().__init__(conn, loop, encoding, errors, window, max_pktsize) self._mode: Optional[int] = None def _accept_data(self, data: bytes, datatype: DataType = None) -> None: """Strip off address family on incoming packets in TUN mode""" if self._mode == SSH_TUN_MODE_POINTTOPOINT: data = data[4:] super()._accept_data(data, datatype) def write(self, data: bytes, datatype: DataType = None) -> None: """Add address family in outbound packets in TUN mode""" if self._mode == SSH_TUN_MODE_POINTTOPOINT: version = data[0] >> 4 family = SSH_TUN_AF_INET if version == 4 else SSH_TUN_AF_INET6 data = UInt32(family) + data super().write(data, datatype) async def open(self, session_factory: SSHTunTapSessionFactory, mode: int, unit: Optional[int]) -> SSHTunTapSession: """Open a TUN/TAP channel""" self._mode = mode if unit is None: unit = SSH_TUN_UNIT_ANY return cast(SSHTunTapSession, await self._open_forward(session_factory, b'tun@openssh.com', UInt32(mode), UInt32(unit))) def set_mode(self, mode: int) -> None: """Set mode for inbound connections""" self._mode = mode class SSHX11Channel(SSHForwardChannel[bytes]): """SSH X11 channel""" async def open(self, session_factory: SSHTCPSessionFactory[bytes], orig_host: str, orig_port: int) -> SSHTCPSession[bytes]: """Open an SSH X11 channel""" self.set_extra_info(local_peername=(orig_host, orig_port), remote_peername=('', 0)) return cast(SSHTCPSession[bytes], await self._open_forward(session_factory, b'x11', String(orig_host), UInt32(orig_port))) def set_inbound_peer_names(self, orig_host: str, orig_port: int) -> None: """Set local and remote peer name for inbound connections""" self.set_extra_info(local_peername=('', 0), remote_peername=(orig_host, orig_port)) class SSHAgentChannel(SSHForwardChannel[bytes]): """SSH agent channel""" async def open(self, session_factory: SSHUNIXSessionFactory[bytes]) -> \ 
SSHUNIXSession[bytes]: """Open an SSH agent channel""" return cast(SSHUNIXSession[bytes], await self._open_forward(session_factory, b'auth-agent@openssh.com')) asyncssh-2.20.0/asyncssh/client.py000066400000000000000000000422661475467777400171570ustar00rootroot00000000000000# Copyright (c) 2013-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation """SSH client protocol handler""" from typing import TYPE_CHECKING, Optional from .auth import KbdIntPrompts, KbdIntResponse, PasswordChangeResponse from .misc import MaybeAwait from .public_key import KeyPairListArg, SSHKey if TYPE_CHECKING: # pylint: disable=cyclic-import from .connection import SSHClientConnection class SSHClient: """SSH client protocol handler Applications may subclass this when implementing an SSH client to receive callbacks when certain events occur on the SSH connection. Whenever a new SSH client connection is opened, a corresponding SSHClient object is created and the method :meth:`connection_made` is called, passing in the :class:`SSHClientConnection` object. When the connection is closed, the method :meth:`connection_lost` is called with an exception representing the reason for the disconnect, or `None` if the connection was closed cleanly. For simple password or public key based authentication, nothing needs to be defined here if the password or client keys are passed in when the connection is created. However, to prompt interactively or otherwise dynamically select these values, the methods :meth:`password_auth_requested` and/or :meth:`public_key_auth_requested` can be defined. Keyboard-interactive authentication is also supported via :meth:`kbdint_auth_requested` and :meth:`kbdint_challenge_received`. If the server sends an authentication banner, the method :meth:`auth_banner_received` will be called. If the server requires a password change, the method :meth:`password_change_requested` will be called, followed by either :meth:`password_changed` or :meth:`password_change_failed` depending on whether the password change is successful. .. note:: The authentication callbacks described here can be defined as coroutines. However, they may be cancelled if they are running when the SSH connection is closed by the server. If they attempt to catch the CancelledError exception to perform cleanup, they should make sure to re-raise it to allow AsyncSSH to finish its own cleanup. """ # pylint: disable=no-self-use,unused-argument def connection_made(self, conn: 'SSHClientConnection') -> None: """Called when a connection is made This method is called as soon as the TCP connection completes. The `conn` parameter should be stored if needed for later use. :param conn: The connection which was successfully opened :type conn: :class:`SSHClientConnection` """ def connection_lost(self, exc: Optional[Exception]) -> None: """Called when a connection is lost or closed This method is called when a connection is closed. 
If the connection is shut down cleanly, *exc* will be `None`. Otherwise, it will be an exception explaining the reason for the disconnect. :param exc: The exception which caused the connection to close, or `None` if the connection closed cleanly :type exc: :class:`Exception` """ def debug_msg_received(self, msg: str, lang: str, always_display: bool) -> None: """A debug message was received on this connection This method is called when the other end of the connection sends a debug message. Applications should implement this method if they wish to process these debug messages. :param msg: The debug message sent :param lang: The language the message is in :param always_display: Whether or not to display the message :type msg: `str` :type lang: `str` :type always_display: `bool` """ def validate_host_public_key(self, host: str, addr: str, port: int, key: SSHKey) -> bool: """Return whether key is an authorized key for this host Server host key validation can be supported by passing known host keys in the `known_hosts` argument of :func:`create_connection`. However, for more flexibility in matching on the allowed set of keys, this method can be implemented by the application to do the matching itself. It should return `True` if the specified key is a valid host key for the server being connected to. By default, this method returns `False` for all host keys. .. note:: This function only needs to report whether the public key provided is a valid key for this host. If it is, AsyncSSH will verify that the server possesses the corresponding private key before allowing the validation to succeed. :param host: The hostname of the target host :param addr: The IP address of the target host :param port: The port number on the target host :param key: The public key sent by the server :type host: `str` :type addr: `str` :type port: `int` :type key: :class:`SSHKey` *public key* :returns: A `bool` indicating if the specified key is a valid key for the target host """ return False # pragma: no cover def validate_host_ca_key(self, host: str, addr: str, port: int, key: SSHKey) -> bool: """Return whether key is an authorized CA key for this host Server host certificate validation can be supported by passing known host CA keys in the `known_hosts` argument of :func:`create_connection`. However, for more flexibility in matching on the allowed set of keys, this method can be implemented by the application to do the matching itself. It should return `True` if the specified key is a valid certificate authority key for the server being connected to. By default, this method returns `False` for all CA keys. .. note:: This function only needs to report whether the public key provided is a valid CA key for this host. If it is, AsyncSSH will verify that the certificate is valid, that the host is one of the valid principals for the certificate, and that the server possesses the private key corresponding to the public key in the certificate before allowing the validation to succeed. 
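For example, a client could pin a single trusted CA key by comparing fingerprints. This is only a minimal sketch; the `MyClient` class and the fingerprint value shown are placeholders, with the latter standing in for the real CA key fingerprint::

    class MyClient(asyncssh.SSHClient):
        # SHA-256 fingerprint of the trusted CA key (placeholder)
        TRUSTED_CA_FINGERPRINT = 'SHA256:...'

        def validate_host_ca_key(self, host, addr, port, key):
            # Accept host certificates only if they were signed
            # by the pinned CA key
            return key.get_fingerprint() == self.TRUSTED_CA_FINGERPRINT
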
:param host: The hostname of the target host :param addr: The IP address of the target host :param port: The port number on the target host :param key: The public key which signed the certificate sent by the server :type host: `str` :type addr: `str` :type port: `int` :type key: :class:`SSHKey` *public key* :returns: A `bool` indicating if the specified key is a valid CA key for the target host """ return False # pragma: no cover def auth_banner_received(self, msg: str, lang: str) -> None: """An incoming authentication banner was received This method is called when the server sends a banner to display during authentication. Applications should implement this method if they wish to do something with the banner. :param msg: The message the server wanted to display :param lang: The language the message is in :type msg: `str` :type lang: `str` """ def begin_auth(self, username: str) -> None: """Begin client authentication This method is called when client authentication is about to begin. Applications may store the username passed here to be used in future authentication callbacks. """ def auth_completed(self) -> None: """Authentication was completed successfully This method is called when authentication has completed successfully. Applications may use this method to create whatever client sessions and direct TCP/IP or UNIX domain connections are needed and/or set up listeners for incoming TCP/IP or UNIX domain connections coming from the server. However, :func:`create_connection` now blocks until authentication is complete, so any code which wishes to use the SSH connection can simply follow that call and doesn't need to be performed in a callback. """ def public_key_auth_requested(self) -> \ MaybeAwait[Optional[KeyPairListArg]]: """Public key authentication has been requested This method should return a private key corresponding to the user that authentication is being attempted for. This method may be called multiple times and can return a different key to try each time it is called. When there are no keys left to try, it should return `None` to indicate that some other authentication method should be tried. If client keys were provided when the connection was opened, they will be tried before this method is called. If blocking operations need to be performed to determine the key to authenticate with, this method may be defined as a coroutine. :returns: A key as described in :ref:`SpecifyingPrivateKeys` or `None` to move on to another authentication method """ return None # pragma: no cover def password_auth_requested(self) -> MaybeAwait[Optional[str]]: """Password authentication has been requested This method should return a string containing the password corresponding to the user that authentication is being attempted for. It may be called multiple times and can return a different password to try each time, but most servers have a limit on the number of attempts allowed. When there's no password left to try, this method should return `None` to indicate that some other authentication method should be tried. If a password was provided when the connection was opened, it will be tried before this method is called. If blocking operations need to be performed to determine the password to authenticate with, this method may be defined as a coroutine.
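For example, the password could be looked up asynchronously. This is only a minimal sketch; the `MyClient` class, `fetch_password`, and the host name passed to it are placeholders for application-specific code::

    class MyClient(asyncssh.SSHClient):
        _password_tried = False

        async def password_auth_requested(self):
            # Only offer the stored password once, then fall back
            # to other authentication methods
            if self._password_tried:
                return None

            self._password_tried = True
            return await fetch_password('myhost')
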
:returns: A string containing the password to authenticate with or `None` to move on to another authentication method """ return None # pragma: no cover def password_change_requested(self, prompt: str, lang: str) -> \ MaybeAwait[PasswordChangeResponse]: """A password change has been requested This method is called when password authentication was attempted and the user's password was expired on the server. To request a password change, this method should return a tuple of two strings containing the old and new passwords. Otherwise, it should return `NotImplemented`. If blocking operations need to be performed to determine the passwords to authenticate with, this method may be defined as a coroutine. By default, this method returns `NotImplemented`. :param prompt: The prompt requesting that the user enter a new password :param lang: The language that the prompt is in :type prompt: `str` :type lang: `str` :returns: A tuple of two strings containing the old and new passwords or `NotImplemented` if password changes aren't supported """ return NotImplemented # pragma: no cover def password_changed(self) -> None: """The requested password change was successful This method is called to indicate that a requested password change was successful. It is generally followed by a call to :meth:`auth_completed` since this means authentication was also successful. """ def password_change_failed(self) -> None: """The requested password change has failed This method is called to indicate that a requested password change failed, generally because the requested new password doesn't meet the password criteria on the remote system. After this method is called, other forms of authentication will automatically be attempted. """ def kbdint_auth_requested(self) -> MaybeAwait[Optional[str]]: """Keyboard-interactive authentication has been requested This method should return a string containing a comma-separated list of submethods that the server should use for keyboard-interactive authentication. An empty string can be returned to let the server pick the type of keyboard-interactive authentication to perform. If keyboard-interactive authentication is not supported, `None` should be returned. By default, keyboard-interactive authentication is supported if a password was provided when the :class:`SSHClient` was created and it hasn't been sent yet. If the challenge is not a password challenge, this authentication will fail. This method and the :meth:`kbdint_challenge_received` method can be overridden if other forms of challenge should be supported. If blocking operations need to be performed to determine the submethods to request, this method may be defined as a coroutine. :returns: A string containing the submethods the server should use for authentication or `None` to move on to another authentication method """ return NotImplemented # pragma: no cover def kbdint_challenge_received(self, name: str, instructions: str, lang: str, prompts: KbdIntPrompts) -> \ MaybeAwait[Optional[KbdIntResponse]]: """A keyboard-interactive auth challenge has been received This method is called when the server sends a keyboard-interactive authentication challenge. The return value should be a list of strings of the same length as the number of prompts provided if the challenge can be answered, or `None` to indicate that some other form of authentication should be attempted. If blocking operations need to be performed to determine the responses to authenticate with, this method may be defined as a coroutine.
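For example, a client could answer a one-time verification code prompt. This is only a minimal sketch; the `MyClient` class and `get_verification_code` are placeholders for application-specific code::

    class MyClient(asyncssh.SSHClient):
        def kbdint_challenge_received(self, name, instructions,
                                      lang, prompts):
            responses = []

            for prompt, _echo in prompts:
                if 'verification code' in prompt.lower():
                    responses.append(get_verification_code())
                else:
                    # Unrecognized prompt -- fall back to another
                    # authentication method
                    return None

            return responses
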
By default, this method will look for a challenge consisting of a single 'Password:' prompt, and call the method :meth:`password_auth_requested` to provide the response. It will also ignore challenges with no prompts (generally used to provide instructions). Any other form of challenge will cause this method to return `None` to move on to another authentication method. :param name: The name of the challenge :param instructions: Instructions to the user about how to respond to the challenge :param lang: The language the challenge is in :param prompts: The challenges the user should respond to and whether or not the responses should be echoed when they are entered :type name: `str` :type instructions: `str` :type lang: `str` :type prompts: `list` of tuples of `str` and `bool` :returns: List of string responses to the challenge or `None` to move on to another authentication method """ return None # pragma: no cover asyncssh-2.20.0/asyncssh/compression.py000066400000000000000000000101021475467777400202220ustar00rootroot00000000000000# Copyright (c) 2013-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation """SSH compression handlers""" from typing import Callable, List, Optional import zlib _cmp_algs: List[bytes] = [] _default_cmp_algs: List[bytes] = [] _cmp_params = {} _cmp_compressors = {} _cmp_decompressors = {} class Compressor: """Base class for data compressor""" def compress(self, data: bytes) -> Optional[bytes]: """Compress data""" raise NotImplementedError class Decompressor: """Base class for data decompressor""" def decompress(self, data: bytes) -> Optional[bytes]: """Decompress data""" raise NotImplementedError _CompressorType = Callable[[], Optional[Compressor]] _DecompressorType = Callable[[], Optional[Decompressor]] def _none() -> None: """Compressor/decompressor for no compression""" return None class _ZLibCompress(Compressor): """Wrapper class to force a sync flush and handle exceptions""" def __init__(self) -> None: self._comp = zlib.compressobj() def compress(self, data: bytes) -> Optional[bytes]: """Compress data using zlib compression with sync flush""" try: return self._comp.compress(data) + \ self._comp.flush(zlib.Z_SYNC_FLUSH) except zlib.error: # pragma: no cover return None class _ZLibDecompress(Decompressor): """Wrapper class to handle exceptions""" def __init__(self) -> None: self._decomp = zlib.decompressobj() def decompress(self, data: bytes) -> Optional[bytes]: """Decompress data using zlib compression""" try: return self._decomp.decompress(data) except zlib.error: # pragma: no cover return None def register_compression_alg(alg: bytes, compressor: _CompressorType, decompressor: _DecompressorType, after_auth: bool, default: bool) -> None: """Register a compression algorithm""" _cmp_algs.append(alg) if default: _default_cmp_algs.append(alg) _cmp_params[alg] = after_auth _cmp_compressors[alg] = compressor _cmp_decompressors[alg] = decompressor def 
get_compression_algs() -> List[bytes]: """Return supported compression algorithms""" return _cmp_algs def get_default_compression_algs() -> List[bytes]: """Return default compression algorithms""" return _default_cmp_algs def get_compression_params(alg: bytes) -> bool: """Get parameters of a compression algorithm This function returns whether or not a compression algorithm should be delayed until after authentication completes. """ return _cmp_params[alg] def get_compressor(alg: bytes) -> Optional[Compressor]: """Return an instance of a compressor This function returns an object that can be used for data compression. """ return _cmp_compressors[alg]() def get_decompressor(alg: bytes) -> Optional[Decompressor]: """Return an instance of a decompressor This function returns an object that can be used for data decompression. """ return _cmp_decompressors[alg]() register_compression_alg(b'none', _none, _none, False, True) register_compression_alg(b'zlib@openssh.com', _ZLibCompress, _ZLibDecompress, True, True) register_compression_alg(b'zlib', _ZLibCompress, _ZLibDecompress, False, False) asyncssh-2.20.0/asyncssh/config.py000066400000000000000000000624711475467777400171460ustar00rootroot00000000000000# Copyright (c) 2020-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation """Parser for OpenSSH config files""" import os import re import shlex import socket import subprocess from hashlib import sha1 from pathlib import Path, PurePath from subprocess import DEVNULL from typing import Callable, Dict, List, NoReturn, Optional, Sequence from typing import Set, Tuple, Union, cast from .constants import DEFAULT_PORT from .logging import logger from .misc import DefTuple, FilePath, ip_address from .pattern import HostPatternList, WildcardPatternList ConfigPaths = Union[None, FilePath, Sequence[FilePath]] _token_pattern = re.compile(r'%(.)') _env_pattern = re.compile(r'\${(.*)}') def _exec(cmd: str) -> bool: """Execute a command and return if exit status is 0""" return subprocess.run(cmd, check=False, shell=True, stdin=DEVNULL, stdout=DEVNULL, stderr=DEVNULL).returncode == 0 class ConfigParseError(ValueError): """Configuration parsing exception""" class SSHConfig: """Settings from an OpenSSH config file""" _conditionals = {'match'} _no_split: Set[str] = set() _percent_expand = {'AuthorizedKeysFile'} _handlers: Dict[str, Tuple[str, Callable]] = {} def __init__(self, last_config: Optional['SSHConfig'], reload: bool, canonical: bool, final: bool): if last_config: self._last_options = last_config.get_options(reload) else: self._last_options = {} self._canonical = canonical self._final = True if final else None self._default_path = Path('~', '.ssh').expanduser() self._path = Path() self._line_no = 0 self._matching = True self._options = self._last_options.copy() self._tokens: Dict[str, str] = {} self.loaded = False def _error(self, reason: str) -> NoReturn: """Raise a configuration parsing error""" raise 
ConfigParseError(f'{self._path} line {self._line_no}: {reason}') def _match_val(self, match: str) -> object: """Return the value to match against in a match condition""" raise NotImplementedError def _set_tokens(self) -> None: """Set the tokens available for percent expansion""" raise NotImplementedError def _expand_token(self, match): """Expand a percent token reference""" try: token = match.group(1) return self._tokens[token] except KeyError: if token == 'd': raise ConfigParseError('Home directory is ' 'not available') from None elif token == 'i': raise ConfigParseError('User id not available') from None else: raise ConfigParseError('Invalid token expansion: ' + token) from None @staticmethod def _expand_env(match): """Expand an environment variable reference""" try: var = match.group(1) return os.environ[var] except KeyError: raise ConfigParseError('Invalid environment expansion: ' + var) from None def _expand_val(self, value: str) -> str: """Perform percent token and environment expansion on a string""" return _env_pattern.sub(self._expand_env, _token_pattern.sub(self._expand_token, value)) def _include(self, option: str, args: List[str]) -> None: """Read config from a list of other config files""" # pylint: disable=unused-argument old_path = self._path for pattern in args: path = Path(pattern).expanduser() if path.anchor: pattern = str(Path(*path.parts[1:])) path = Path(path.anchor) else: path = self._default_path paths = list(path.glob(pattern)) if not paths: logger.debug1(f'Config pattern "{pattern}" matched no files') for path in paths: self.parse(path) self._path = old_path args.clear() def _match(self, option: str, args: List[str]) -> None: """Begin a conditional block""" # pylint: disable=unused-argument matching = True while args: match = args.pop(0).lower() if match[0] == '!': match = match[1:] negated = True else: negated = False if match == 'final' and self._final is None: self._final = False if match == 'all': result = True elif match == 'canonical': result = self._canonical elif match == 'final': result = cast(bool, self._final) else: match_val = self._match_val(match) if match != 'exec' and match_val is None: self._error(f'Invalid match condition {match}') try: arg = args.pop(0) except IndexError: self._error(f'Missing {match} match pattern') if matching: if match == 'exec': result = _exec(arg) elif match in ('address', 'localaddress'): host_pat = HostPatternList(arg) ip = ip_address(cast(str, match_val)) \ if match_val else None result = host_pat.matches(None, match_val, ip) else: wild_pat = WildcardPatternList(arg) result = wild_pat.matches(match_val) if matching and result == negated: matching = False self._matching = matching def _set_bool(self, option: str, args: List[str]) -> None: """Set a boolean config option""" value_str = args.pop(0).lower() if value_str in ('yes', 'true'): value = True elif value_str in ('no', 'false'): value = False else: self._error(f'Invalid {option} boolean value: {value_str}') if option not in self._options: self._options[option] = value def _set_bool_or_str(self, option: str, args: List[str]) -> None: """Set a boolean or string config option""" value_str = args.pop(0) value_lower = value_str.lower() if value_lower in ('yes', 'true'): value: Union[bool, str] = True elif value_lower in ('no', 'false'): value = False else: value = value_str if option not in self._options: self._options[option] = value def _set_int(self, option: str, args: List[str]) -> None: """Set an integer config option""" value_str = args.pop(0) try: value = 
int(value_str) except ValueError: self._error(f'Invalid {option} integer value: {value_str}') if option not in self._options: self._options[option] = value def _set_string(self, option: str, args: List[str]) -> None: """Set a string config option""" value_str = args.pop(0) if value_str.lower() == 'none': value = None else: value = value_str if option not in self._options: self._options[option] = value def _append_string(self, option: str, args: List[str]) -> None: """Append a string config option to a list""" value_str = args.pop(0) if value_str.lower() != 'none': if option in self._options: cast(List[str], self._options[option]).append(value_str) else: self._options[option] = [value_str] else: if option not in self._options: self._options[option] = [] def _set_string_list(self, option: str, args: List[str]) -> None: """Set whitespace-separated string config options as a list""" if option not in self._options: if len(args) == 1 and args[0].lower() == 'none': self._options[option] = [] else: self._options[option] = args[:] args.clear() def _append_string_list(self, option: str, args: List[str]) -> None: """Append whitespace-separated string config options to a list""" if option in self._options: cast(List[str], self._options[option]).extend(args) else: self._options[option] = args[:] args.clear() def _set_address_family(self, option: str, args: List[str]) -> None: """Set an address family config option""" value_str = args.pop(0).lower() if value_str == 'any': value = socket.AF_UNSPEC elif value_str == 'inet': value = socket.AF_INET elif value_str == 'inet6': value = socket.AF_INET6 else: self._error(f'Invalid {option} value: {value_str}') if option not in self._options: self._options[option] = value def _set_canonicalize_host(self, option: str, args: List[str]) -> None: """Set a canonicalize host config option""" value_str = args.pop(0).lower() if value_str in ('yes', 'true'): value: Union[bool, str] = True elif value_str in ('no', 'false'): value = False elif value_str == 'always': value = value_str else: self._error(f'Invalid {option} value: {value_str}') if option not in self._options: self._options[option] = value def _set_rekey_limits(self, option: str, args: List[str]) -> None: """Set rekey limits config option""" byte_limit: Union[str, Tuple[()]] = args.pop(0).lower() if byte_limit == 'default': byte_limit = () if args: time_limit: Optional[Union[str, Tuple[()]]] = args.pop(0).lower() if time_limit == 'none': time_limit = None else: time_limit = () if option not in self._options: self._options[option] = byte_limit, time_limit def has_match_final(self) -> bool: """Return whether this config includes a 'Match final' block""" return self._final is not None def parse(self, path: Path) -> None: """Parse an OpenSSH config file and return matching declarations""" self._path = path self._line_no = 0 self._matching = True self._tokens = {'%': '%'} logger.debug1('Reading config from "%s"', path) with open(path) as file: for line in file: self._line_no += 1 line = line.strip() if not line or line[0] == '#': continue try: split_args = shlex.split(line) except ValueError as exc: self._error(str(exc)) args = [] loption = '' allow_equal = True for i, arg in enumerate(split_args, 1): if arg.startswith('='): if len(arg) > 1: args.append(arg[1:]) elif not allow_equal: args.extend(split_args[i-1:]) break elif arg.endswith('='): args.append(arg[:-1]) elif '=' in arg: arg, val = arg.split('=', 1) args.append(arg) args.append(val) else: args.append(arg) if i == 1: loption = args.pop(0).lower() allow_equal 
= loption in self._conditionals if loption in self._no_split: args = [line.lstrip()[len(loption):].strip()] if not self._matching and loption not in self._conditionals: continue try: option, handler = self._handlers[loption] except KeyError: continue if not args: self._error(f'Missing {option} value') handler(self, option, args) if args: self._error(f'Extra data at end: {" ".join(args)}') self._set_tokens() for option in self._percent_expand: try: value = self._options[option] except KeyError: pass else: if isinstance(value, list): value = [self._expand_val(item) for item in value] elif isinstance(value, str): value = self._expand_val(value) self._options[option] = value def get_options(self, reload: bool) -> Dict[str, object]: """Return options to base a new config object on""" return self._last_options.copy() if reload else self._options.copy() @classmethod def load(cls, last_config: Optional['SSHConfig'], config_paths: ConfigPaths, reload: bool, canonical: bool, final: bool, *args: object) -> 'SSHConfig': """Load a list of OpenSSH config files into a config object""" config = cls(last_config, reload, canonical, final, *args) if config_paths: if isinstance(config_paths, (str, PurePath)): paths: Sequence[FilePath] = [config_paths] else: paths = config_paths for path in paths: config.parse(Path(path)) config.loaded = True return config def get(self, option: str, default: object = None) -> object: """Get the value of a config option""" return self._options.get(option, default) def get_compression_algs(self) -> DefTuple[str]: """Return the compression algorithms to use""" compression = self.get('Compression') if compression is None: return () elif compression: return 'zlib@openssh.com,zlib,none' else: return 'none,zlib@openssh.com,zlib' class SSHClientConfig(SSHConfig): """Settings from an OpenSSH client config file""" _conditionals = {'host', 'match'} _no_split = {'proxycommand', 'remotecommand'} _percent_expand = {'CertificateFile', 'ForwardAgent', 'IdentityAgent', 'IdentityFile', 'ProxyCommand', 'RemoteCommand'} def __init__(self, last_config: 'SSHConfig', reload: bool, canonical: bool, final: bool, local_user: str, user: str, host: str, port: int) -> None: super().__init__(last_config, reload, canonical, final) self._local_user = local_user self._orig_host = host if user != (): self._options['User'] = user if port != (): self._options['Port'] = port def _match_val(self, match: str) -> object: """Return the value to match against in a match condition""" if match == 'host': return self._options.get('Hostname', self._orig_host) elif match == 'originalhost': return self._orig_host elif match == 'localuser': return self._local_user elif match == 'user': return self._options.get('User', self._local_user) elif match == 'tagged': return self._options.get('Tag', '') else: return None def _match_host(self, option: str, args: List[str]) -> None: """Begin a conditional block matching on host""" # pylint: disable=unused-argument pattern = ','.join(args) self._matching = WildcardPatternList(pattern).matches(self._orig_host) args.clear() def _set_hostname(self, option: str, args: List[str]) -> None: """Set hostname config option""" value = args.pop(0) if option not in self._options: self._tokens['h'] = \ cast(str, self._options.get(option, self._orig_host)) self._options[option] = self._expand_val(value) def _set_request_tty(self, option: str, args: List[str]) -> None: """Set a pseudo-terminal request config option""" value_str = args.pop(0).lower() if value_str in ('yes', 'true'): value: Union[bool, 
str] = True elif value_str in ('no', 'false'): value = False elif value_str in ('force', 'auto'): value = value_str else: self._error(f'Invalid {option} value: {value_str}') if option not in self._options: self._options[option] = value def _set_tokens(self) -> None: """Set the tokens available for percent expansion""" local_host = socket.gethostname() idx = local_host.find('.') short_local_host = local_host if idx < 0 else local_host[:idx] host = cast(str, self._options.get('Hostname', self._orig_host)) port = str(self._options.get('Port', DEFAULT_PORT)) user = cast(str, self._options.get('User') or self._local_user) home = os.path.expanduser('~') conn_info = ''.join((local_host, host, port, user)) conn_hash = sha1(conn_info.encode('utf-8')).hexdigest() self._tokens.update({'C': conn_hash, 'h': host, 'L': short_local_host, 'l': local_host, 'n': self._orig_host, 'p': port, 'r': user, 'u': self._local_user}) if home != '~': self._tokens['d'] = home if hasattr(os, 'getuid'): self._tokens['i'] = str(os.getuid()) _handlers = {option.lower(): (option, handler) for option, handler in ( ('Host', _match_host), ('Match', SSHConfig._match), ('Include', SSHConfig._include), ('AddressFamily', SSHConfig._set_address_family), ('BindAddress', SSHConfig._set_string), ('CanonicalDomains', SSHConfig._set_string_list), ('CanonicalizeFallbackLocal', SSHConfig._set_bool), ('CanonicalizeHostname', SSHConfig._set_canonicalize_host), ('CanonicalizeMaxDots', SSHConfig._set_int), ('CanonicalizePermittedCNAMEs', SSHConfig._set_string_list), ('CASignatureAlgorithms', SSHConfig._set_string), ('CertificateFile', SSHConfig._append_string), ('ChallengeResponseAuthentication', SSHConfig._set_bool), ('Ciphers', SSHConfig._set_string), ('Compression', SSHConfig._set_bool), ('ConnectTimeout', SSHConfig._set_int), ('EnableSSHKeySign', SSHConfig._set_bool), ('ForwardAgent', SSHConfig._set_bool_or_str), ('ForwardX11Trusted', SSHConfig._set_bool), ('GlobalKnownHostsFile', SSHConfig._set_string_list), ('GSSAPIAuthentication', SSHConfig._set_bool), ('GSSAPIDelegateCredentials', SSHConfig._set_bool), ('GSSAPIKeyExchange', SSHConfig._set_bool), ('HostbasedAuthentication', SSHConfig._set_bool), ('HostKeyAlgorithms', SSHConfig._set_string), ('Hostname', _set_hostname), ('HostKeyAlias', SSHConfig._set_string), ('IdentitiesOnly', SSHConfig._set_bool), ('IdentityAgent', SSHConfig._set_string), ('IdentityFile', SSHConfig._append_string), ('KbdInteractiveAuthentication', SSHConfig._set_bool), ('KexAlgorithms', SSHConfig._set_string), ('MACs', SSHConfig._set_string), ('PasswordAuthentication', SSHConfig._set_bool), ('PKCS11Provider', SSHConfig._set_string), ('PreferredAuthentications', SSHConfig._set_string), ('Port', SSHConfig._set_int), ('ProxyCommand', SSHConfig._set_string), ('ProxyJump', SSHConfig._set_string), ('PubkeyAuthentication', SSHConfig._set_bool), ('RekeyLimit', SSHConfig._set_rekey_limits), ('RemoteCommand', SSHConfig._set_string), ('RequestTTY', _set_request_tty), ('SendEnv', SSHConfig._append_string_list), ('ServerAliveCountMax', SSHConfig._set_int), ('ServerAliveInterval', SSHConfig._set_int), ('SetEnv', SSHConfig._set_string_list), ('Tag', SSHConfig._set_string), ('TCPKeepAlive', SSHConfig._set_bool), ('User', SSHConfig._set_string), ('UserKnownHostsFile', SSHConfig._set_string_list) )} class SSHServerConfig(SSHConfig): """Settings from an OpenSSH server config file""" def __init__(self, last_config: 'SSHConfig', reload: bool, canonical: bool, final: bool, local_addr: str, local_port: int, user: str, host: str, addr: str) 
-> None: super().__init__(last_config, reload, canonical, final) self._local_addr = local_addr self._local_port = local_port self._user = user self._host = host or addr self._addr = addr def _match_val(self, match: str) -> object: """Return the value to match against in a match condition""" if match == 'localaddress': return self._local_addr elif match == 'localport': return str(self._local_port) elif match == 'user': return self._user elif match == 'host': return self._host elif match == 'address': return self._addr else: return None def _set_tokens(self) -> None: """Set the tokens available for percent expansion""" self._tokens.update({'u': self._user}) _handlers = {option.lower(): (option, handler) for option, handler in ( ('Match', SSHConfig._match), ('Include', SSHConfig._include), ('AddressFamily', SSHConfig._set_address_family), ('AuthorizedKeysFile', SSHConfig._set_string_list), ('AllowAgentForwarding', SSHConfig._set_bool), ('BindAddress', SSHConfig._set_string), ('CanonicalDomains', SSHConfig._set_string_list), ('CanonicalizeFallbackLocal', SSHConfig._set_bool), ('CanonicalizeHostname', SSHConfig._set_canonicalize_host), ('CanonicalizeMaxDots', SSHConfig._set_int), ('CanonicalizePermittedCNAMEs', SSHConfig._set_string_list), ('CASignatureAlgorithms', SSHConfig._set_string), ('ChallengeResponseAuthentication', SSHConfig._set_bool), ('Ciphers', SSHConfig._set_string), ('ClientAliveCountMax', SSHConfig._set_int), ('ClientAliveInterval', SSHConfig._set_int), ('Compression', SSHConfig._set_bool), ('GSSAPIAuthentication', SSHConfig._set_bool), ('GSSAPIKeyExchange', SSHConfig._set_bool), ('HostbasedAuthentication', SSHConfig._set_bool), ('HostCertificate', SSHConfig._append_string), ('HostKey', SSHConfig._append_string), ('KbdInteractiveAuthentication', SSHConfig._set_bool), ('KexAlgorithms', SSHConfig._set_string), ('LoginGraceTime', SSHConfig._set_int), ('MACs', SSHConfig._set_string), ('PasswordAuthentication', SSHConfig._set_bool), ('PermitTTY', SSHConfig._set_bool), ('Port', SSHConfig._set_int), ('PubkeyAuthentication', SSHConfig._set_bool), ('RekeyLimit', SSHConfig._set_rekey_limits), ('TCPKeepAlive', SSHConfig._set_bool), ('UseDNS', SSHConfig._set_bool) )} asyncssh-2.20.0/asyncssh/connection.py000066400000000000000000014420131475467777400200330ustar00rootroot00000000000000# Copyright (c) 2013-2024 by Ron Frederick and others. 
# # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation """SSH connection handlers""" import asyncio import functools import getpass import inspect import io import ipaddress import os import shlex import socket import sys import tempfile import time from collections import OrderedDict from functools import partial from pathlib import Path from types import TracebackType from typing import TYPE_CHECKING, Any, AnyStr, Awaitable, Callable, Dict from typing import Generic, List, Mapping, Optional, Sequence, Set, Tuple from typing import Type, TypeVar, Union, cast from typing_extensions import Protocol, Self from .agent import SSHAgentClient, SSHAgentListener from .auth import Auth, ClientAuth, KbdIntChallenge, KbdIntPrompts from .auth import KbdIntResponse, PasswordChangeResponse from .auth import get_supported_client_auth_methods, lookup_client_auth from .auth import get_supported_server_auth_methods, lookup_server_auth from .auth_keys import SSHAuthorizedKeys, read_authorized_keys from .channel import SSHChannel, SSHClientChannel, SSHServerChannel from .channel import SSHTCPChannel, SSHUNIXChannel, SSHTunTapChannel from .channel import SSHX11Channel, SSHAgentChannel from .client import SSHClient from .compression import Compressor, Decompressor, get_compression_algs from .compression import get_default_compression_algs, get_compression_params from .compression import get_compressor, get_decompressor from .config import ConfigPaths, SSHConfig, SSHClientConfig, SSHServerConfig from .constants import DEFAULT_LANG, DEFAULT_PORT from .constants import DISC_BY_APPLICATION from .constants import EXTENDED_DATA_STDERR from .constants import MSG_DISCONNECT, MSG_IGNORE, MSG_UNIMPLEMENTED, MSG_DEBUG from .constants import MSG_SERVICE_REQUEST, MSG_SERVICE_ACCEPT, MSG_EXT_INFO from .constants import MSG_CHANNEL_OPEN, MSG_CHANNEL_OPEN_CONFIRMATION from .constants import MSG_CHANNEL_OPEN_FAILURE from .constants import MSG_CHANNEL_FIRST, MSG_CHANNEL_LAST from .constants import MSG_KEXINIT, MSG_NEWKEYS, MSG_KEX_FIRST, MSG_KEX_LAST from .constants import MSG_USERAUTH_REQUEST, MSG_USERAUTH_FAILURE from .constants import MSG_USERAUTH_SUCCESS, MSG_USERAUTH_BANNER from .constants import MSG_USERAUTH_FIRST, MSG_USERAUTH_LAST from .constants import MSG_GLOBAL_REQUEST, MSG_REQUEST_SUCCESS from .constants import MSG_REQUEST_FAILURE from .constants import OPEN_ADMINISTRATIVELY_PROHIBITED, OPEN_CONNECT_FAILED from .constants import OPEN_UNKNOWN_CHANNEL_TYPE from .encryption import Encryption, get_encryption_algs from .encryption import get_default_encryption_algs from .encryption import get_encryption_params, get_encryption from .forward import SSHForwarder from .gss import GSSBase, GSSClient, GSSServer, GSSError from .kex import Kex, get_kex_algs, get_default_kex_algs from .kex import expand_kex_algs, get_kex from .keysign import KeySignPath, SSHKeySignKeyPair from .keysign import find_keysign, get_keysign_keys from .known_hosts import KnownHostsArg, 
match_known_hosts from .listener import ListenKey, SSHListener from .listener import SSHTCPClientListener, SSHUNIXClientListener from .listener import TCPListenerFactory, UNIXListenerFactory from .listener import create_tcp_forward_listener, create_unix_forward_listener from .listener import create_socks_listener from .logging import SSHLogger, logger from .mac import get_mac_algs, get_default_mac_algs from .misc import BytesOrStr, BytesOrStrDict, DefTuple, Env, EnvSeq, FilePath from .misc import HostPort, IPNetwork, MaybeAwait, OptExcInfo, Options, SockAddr from .misc import ChannelListenError, ChannelOpenError, CompressionError from .misc import DisconnectError, ConnectionLost, HostKeyNotVerifiable from .misc import KeyExchangeFailed, IllegalUserName, MACError from .misc import PasswordChangeRequired, PermissionDenied, ProtocolError from .misc import ProtocolNotSupported, ServiceNotAvailable from .misc import TermModesArg, TermSizeArg from .misc import async_context_manager, construct_disc_error, encode_env from .misc import get_symbol_names, ip_address, lookup_env, map_handler_name from .misc import parse_byte_count, parse_time_interval, split_args from .packet import Boolean, Byte, NameList, String, UInt32, PacketDecodeError from .packet import SSHPacket, SSHPacketHandler, SSHPacketLogger from .pattern import WildcardPattern, WildcardPatternList from .pkcs11 import load_pkcs11_keys from .process import PIPE, ProcessSource, ProcessTarget from .process import SSHServerProcessFactory, SSHCompletedProcess from .process import SSHClientProcess, SSHServerProcess from .public_key import CERT_TYPE_HOST, CERT_TYPE_USER, KeyImportError from .public_key import CertListArg, IdentityListArg, KeyListArg, SigningKey from .public_key import KeyPairListArg, X509CertPurposes, SSHKey, SSHKeyPair from .public_key import SSHCertificate, SSHOpenSSHCertificate from .public_key import SSHX509Certificate, SSHX509CertificateChain from .public_key import decode_ssh_public_key, decode_ssh_certificate from .public_key import get_public_key_algs, get_default_public_key_algs from .public_key import get_certificate_algs, get_default_certificate_algs from .public_key import get_x509_certificate_algs from .public_key import get_default_x509_certificate_algs from .public_key import load_keypairs, load_default_keypairs from .public_key import load_public_keys, load_default_host_public_keys from .public_key import load_certificates from .public_key import load_identities, load_default_identities from .saslprep import saslprep, SASLPrepError from .server import SSHServer from .session import DataType, SSHClientSession, SSHServerSession from .session import SSHTCPSession, SSHUNIXSession, SSHTunTapSession from .session import SSHClientSessionFactory, SSHTCPSessionFactory from .session import SSHUNIXSessionFactory, SSHTunTapSessionFactory from .sftp import MIN_SFTP_VERSION, SFTPClient, SFTPServer from .sftp import start_sftp_client from .stream import SSHReader, SSHWriter, SFTPServerFactory from .stream import SSHSocketSessionFactory, SSHServerSessionFactory from .stream import SSHClientStreamSession, SSHServerStreamSession from .stream import SSHTCPStreamSession, SSHUNIXStreamSession from .stream import SSHTunTapStreamSession from .subprocess import SSHSubprocessTransport, SSHSubprocessProtocol from .subprocess import SubprocessFactory, SSHSubprocessWritePipe from .tuntap import SSH_TUN_MODE_POINTTOPOINT, SSH_TUN_MODE_ETHERNET from .tuntap import SSH_TUN_UNIT_ANY, create_tuntap from .version import __version__ from .x11 
import SSHX11ClientForwarder from .x11 import SSHX11ClientListener, SSHX11ServerListener from .x11 import create_x11_client_listener, create_x11_server_listener if TYPE_CHECKING: # pylint: disable=unused-import from .crypto import X509NamePattern _ClientFactory = Callable[[], SSHClient] _ServerFactory = Callable[[], SSHServer] _ProtocolFactory = Union[_ClientFactory, _ServerFactory] _Conn = TypeVar('_Conn', bound='SSHConnection') _Options = TypeVar('_Options', bound='SSHConnectionOptions') _ServerHostKeysHandler = Optional[Callable[[List[SSHKey], List[SSHKey], List[SSHKey], List[SSHKey]], MaybeAwait[None]]] class _TunnelProtocol(Protocol): """Base protocol for connections to tunnel SSH over""" def close(self) -> None: """Close this tunnel""" class _TunnelConnectorProtocol(_TunnelProtocol, Protocol): """Protocol to open a connection to tunnel an SSH connection over""" async def create_connection( self, session_factory: SSHTCPSessionFactory[bytes], remote_host: str, remote_port: int) -> \ Tuple[SSHTCPChannel[bytes], SSHTCPSession[bytes]]: """Create an outbound tunnel connection""" class _TunnelListenerProtocol(_TunnelProtocol, Protocol): """Protocol to open a listener to tunnel SSH connections over""" async def create_server(self, session_factory: TCPListenerFactory, listen_host: str, listen_port: int) -> SSHListener: """Create an inbound tunnel listener""" _AcceptHandler = Optional[Callable[['SSHConnection'], MaybeAwait[None]]] _ErrorHandler = Optional[Callable[['SSHConnection', Optional[Exception]], None]] _OpenHandler = Callable[[SSHPacket], Tuple[SSHClientChannel, SSHClientSession]] _PacketHandler = Callable[[SSHPacket], None] _AlgsArg = DefTuple[Union[str, Sequence[str]]] _AuthArg = DefTuple[bool] _AuthKeysArg = DefTuple[Union[None, str, List[str], SSHAuthorizedKeys]] _ClientHostKey = Union[SSHKeyPair, SSHKeySignKeyPair] _ClientKeysArg = Union[KeyListArg, KeyPairListArg] _CNAMEArg = DefTuple[Union[Sequence[str], Sequence[Tuple[str, str]]]] _GlobalRequest = Tuple[Optional[_PacketHandler], SSHPacket, bool] _GlobalRequestResult = Tuple[int, SSHPacket] _KeyOrCertOptions = Mapping[str, object] _ListenerArg = Union[bool, SSHListener] _ProxyCommand = Optional[Union[str, Sequence[str]]] _RequestPTY = Union[bool, str] _TCPServerHandlerFactory = Callable[[str, int], SSHSocketSessionFactory] _UNIXServerHandlerFactory = Callable[[], SSHSocketSessionFactory] _TunnelConnector = Union[None, str, _TunnelConnectorProtocol] _TunnelListener = Union[None, str, _TunnelListenerProtocol] _VersionArg = DefTuple[BytesOrStr] SSHAcceptHandler = Callable[[str, int], MaybeAwait[bool]] # SSH service names _USERAUTH_SERVICE = b'ssh-userauth' _CONNECTION_SERVICE = b'ssh-connection' # Max banner and version line length and count _MAX_BANNER_LINES = 1024 _MAX_BANNER_LINE_LEN = 8192 _MAX_VERSION_LINE_LEN = 255 # Max allowed username length _MAX_USERNAME_LEN = 1024 # Default rekey parameters _DEFAULT_REKEY_BYTES = 1 << 30 # 1 GiB _DEFAULT_REKEY_SECONDS = 3600 # 1 hour # Default login timeout _DEFAULT_LOGIN_TIMEOUT = 120 # 2 minutes # Default keepalive interval and count max _DEFAULT_KEEPALIVE_INTERVAL = 0 # disabled by default _DEFAULT_KEEPALIVE_COUNT_MAX = 3 # Default channel parameters _DEFAULT_WINDOW = 2*1024*1024 # 2 MiB _DEFAULT_MAX_PKTSIZE = 32768 # 32 kiB # Default line editor parameters _DEFAULT_LINE_HISTORY = 1000 # 1000 lines _DEFAULT_MAX_LINE_LENGTH = 1024 # 1024 characters async def _canonicalize_host(loop: asyncio.AbstractEventLoop, options: 'SSHConnectionOptions') -> Optional[str]: """Canonicalize a host 
name""" host = options.host if not options.canonicalize_hostname or not options.canonical_domains: logger.info('Host canonicalization disabled') return None if host.count('.') > options.canonicalize_max_dots: logger.info('Host canonicalization skipped due to max dots') return None try: ipaddress.ip_address(host) except ValueError: pass else: logger.info('Hostname canonicalization skipped on IP address') return None logger.debug1('Beginning hostname canonicalization') for domain in options.canonical_domains: logger.debug1(' Checking domain %s', domain) canon_host = f'{host}.{domain}' try: addrinfo = await loop.getaddrinfo( canon_host, 0, flags=socket.AI_CANONNAME) except socket.gaierror: continue cname = addrinfo[0][3] if cname and cname != canon_host: logger.debug1(' Checking CNAME rules for hostname %s ' 'with CNAME %s', canon_host, cname) for patterns in options.canonicalize_permitted_cnames: host_pat, cname_pat = map(WildcardPatternList, patterns) if host_pat.matches(canon_host) and cname_pat.matches(cname): logger.info('Hostname canonicalization to CNAME ' 'applied: %s -> %s', options.host, cname) return cname logger.info('Hostname canonicalization applied: %s -> %s', options.host, canon_host) return canon_host if not options.canonicalize_fallback_local: logger.info('Hostname canonicalization failed (fallback disabled)') raise OSError(f'Unable to canonicalize hostname "{host}"') logger.info('Hostname canonicalization failed, using local resolver') return None async def _open_proxy( loop: asyncio.AbstractEventLoop, command: Sequence[str], conn_factory: Callable[[], _Conn]) -> _Conn: """Open a tunnel running a proxy command""" class _ProxyCommandTunnel(asyncio.SubprocessProtocol): """SSH proxy command tunnel""" def __init__(self) -> None: super().__init__() self._transport: Optional[asyncio.SubprocessTransport] = None self._stdin: Optional[asyncio.WriteTransport] = None self._conn = conn_factory() self._close_event = asyncio.Event() def get_extra_info(self, name: str, default: Any = None) -> Any: """Return extra information associated with this tunnel""" assert self._transport is not None return self._transport.get_extra_info(name, default) def get_conn(self) -> _Conn: """Return the connection associated with this tunnel""" return self._conn def connection_made(self, transport: asyncio.BaseTransport) -> None: """Handle startup of the subprocess""" self._transport = cast(asyncio.SubprocessTransport, transport) self._stdin = cast(asyncio.WriteTransport, self._transport.get_pipe_transport(0)) self._conn.connection_made(cast(asyncio.BaseTransport, self)) def pipe_data_received(self, fd: int, data: bytes) -> None: """Handle data received from this tunnel""" # pylint: disable=unused-argument self._conn.data_received(data) def pipe_connection_lost(self, fd: int, exc: Optional[Exception]) -> None: """Handle when this tunnel is closed""" # pylint: disable=unused-argument self._conn.connection_lost(exc) def write(self, data: bytes) -> None: """Write data to this tunnel""" assert self._stdin is not None self._stdin.write(data) def abort(self) -> None: """Forcibly close this tunnel""" self.close() def close(self) -> None: """Close this tunnel""" if self._transport: # pragma: no cover self._transport.close() self._close_event.set() _, tunnel = await loop.subprocess_exec(_ProxyCommandTunnel, *command) return cast(_Conn, cast(_ProxyCommandTunnel, tunnel).get_conn()) async def _open_tunnel(tunnels: object, options: _Options, config: DefTuple[ConfigPaths]) -> \ Optional['SSHClientConnection']: """Parse 
and open connection to tunnel over""" username: DefTuple[str] port: DefTuple[int] if isinstance(tunnels, str): conn: Optional[SSHClientConnection] = None for tunnel in tunnels.split(','): if '@' in tunnel: username, host = tunnel.rsplit('@', 1) else: username, host = (), tunnel if ':' in host: host, port_str = host.rsplit(':', 1) port = int(port_str) else: port = () last_conn = conn conn = await connect(host, port, username=username, passphrase=options.passphrase, tunnel=conn, config=config) conn.set_tunnel(last_conn) if options.canonicalize_hostname != 'always': options.canonicalize_hostname = False return conn else: return None async def _connect(options: _Options, config: DefTuple[ConfigPaths], loop: asyncio.AbstractEventLoop, flags: int, sock: Optional[socket.socket], conn_factory: Callable[[], _Conn], msg: str) -> _Conn: """Make outbound TCP or SSH tunneled connection""" options.waiter = loop.create_future() canon_host = await _canonicalize_host(loop, options) host = canon_host if canon_host else options.host canonical = bool(canon_host) final = options.config.has_match_final() if canonical or final: options.update(host=host, reload=True, canonical=canonical, final=final) host = options.host port = options.port tunnel = options.tunnel family = options.family local_addr = options.local_addr proxy_command = options.proxy_command free_conn = True new_tunnel = await _open_tunnel(tunnel, options, config) tunnel: _TunnelConnectorProtocol try: if sock: logger.info('%s already-connected socket', msg) _, session = await loop.create_connection(conn_factory, sock=sock) conn = cast(_Conn, session) elif new_tunnel: new_tunnel.logger.info('%s %s via %s', msg, (host, port), tunnel) # pylint: disable=broad-except try: _, tunnel_session = await new_tunnel.create_connection( cast(SSHTCPSessionFactory[bytes], conn_factory), host, port) except Exception: new_tunnel.close() await new_tunnel.wait_closed() raise else: conn = cast(_Conn, tunnel_session) conn.set_tunnel(new_tunnel) elif tunnel: tunnel_logger = getattr(tunnel, 'logger', logger) tunnel_logger.info('%s %s via SSH tunnel', msg, (host, port)) _, tunnel_session = await tunnel.create_connection( cast(SSHTCPSessionFactory[bytes], conn_factory), host, port) conn = cast(_Conn, tunnel_session) elif proxy_command: conn = await _open_proxy(loop, proxy_command, conn_factory) else: logger.info('%s %s', msg, (host, port)) _, session = await loop.create_connection( conn_factory, host, port, family=family, flags=flags, local_addr=local_addr) conn = cast(_Conn, session) except asyncio.CancelledError: options.waiter.cancel() raise conn.set_extra_info(host=host, port=port) try: await options.waiter free_conn = False return conn finally: if free_conn: conn.abort() await conn.wait_closed() async def _listen(options: _Options, config: DefTuple[ConfigPaths], loop: asyncio.AbstractEventLoop, flags: int, backlog: int, sock: Optional[socket.socket], reuse_address: bool, reuse_port: bool, conn_factory: Callable[[], _Conn], msg: str) -> 'SSHAcceptor': """Make inbound TCP or SSH tunneled listener""" def tunnel_factory(_orig_host: str, _orig_port: int) -> SSHTCPSession: """Ignore original host and port""" return cast(SSHTCPSession, conn_factory()) host = options.host port = options.port tunnel = options.tunnel family = options.family new_tunnel = await _open_tunnel(tunnel, options, config) tunnel: _TunnelListenerProtocol if sock: logger.info('%s already-connected socket', msg) server: asyncio.AbstractServer = await loop.create_server( conn_factory, sock=sock, 
backlog=backlog, reuse_address=reuse_address, reuse_port=reuse_port) elif new_tunnel: new_tunnel.logger.info('%s %s via %s', msg, (host, port), tunnel) # pylint: disable=broad-except try: tunnel_server = await new_tunnel.create_server( tunnel_factory, host, port) except Exception: new_tunnel.close() await new_tunnel.wait_closed() raise else: tunnel_server.set_tunnel(new_tunnel) server = cast(asyncio.AbstractServer, tunnel_server) elif tunnel: tunnel_logger = getattr(tunnel, 'logger', logger) tunnel_logger.info('%s %s via SSH tunnel', msg, (host, port)) tunnel_server = await tunnel.create_server(tunnel_factory, host, port) server = cast(asyncio.AbstractServer, tunnel_server) else: logger.info('%s %s', msg, (host, port)) server = await loop.create_server( conn_factory, host, port, family=family, flags=flags, backlog=backlog, reuse_address=reuse_address, reuse_port=reuse_port) return SSHAcceptor(server, options) def _validate_version(version: DefTuple[BytesOrStr]) -> bytes: """Validate requested SSH version""" if version == (): version = b'AsyncSSH_' + __version__.encode('ascii') else: if isinstance(version, str): version = version.encode('ascii') else: assert isinstance(version, bytes) # Version including 'SSH-2.0-' and CRLF must be 255 chars or less if len(version) > 245: raise ValueError('Version string is too long') for b in version: if b < 0x20 or b > 0x7e: raise ValueError('Version string must be printable ASCII') return version def _expand_algs(alg_type: str, algs: str, possible_algs: List[bytes], default_algs: List[bytes], strict_match: bool) -> Sequence[bytes]: """Expand the set of allowed algorithms""" if algs[:1] in '^+-': prefix = algs[:1] algs = algs[1:] else: prefix = '' matched: List[bytes] = [] for pat in algs.split(','): pattern = WildcardPattern(pat) matches = [alg for alg in possible_algs if pattern.matches(alg.decode('ascii'))] if not matches and strict_match: raise ValueError(f'"{pat}" matches no valid {alg_type} algorithms') matched.extend(matches) if prefix == '^': return matched + default_algs elif prefix == '+': return default_algs + matched elif prefix == '-': return [alg for alg in default_algs if alg not in matched] else: return matched def _select_algs(alg_type: str, algs: _AlgsArg, config_algs: _AlgsArg, possible_algs: List[bytes], default_algs: List[bytes], none_value: Optional[bytes] = None) -> Sequence[bytes]: """Select a set of allowed algorithms""" if algs == (): algs = config_algs strict_match = False else: strict_match = True if algs in ((), 'default'): return default_algs elif algs: if isinstance(algs, str): expanded_algs = _expand_algs(alg_type, algs, possible_algs, default_algs, strict_match) else: expanded_algs = [alg.encode('ascii') for alg in algs] result: List[bytes] = [] for alg in expanded_algs: if alg not in possible_algs: raise ValueError(f'{alg.decode("ascii")} is not a valid ' f'{alg_type} algorithm') if alg not in result: result.append(alg) return result elif none_value: return [none_value] else: raise ValueError(f'No {alg_type} algorithms selected') def _select_host_key_algs(algs: _AlgsArg, config_algs: _AlgsArg, default_algs: List[bytes]) -> Sequence[bytes]: """Select a set of allowed host key algorithms""" possible_algs = (get_x509_certificate_algs() + get_certificate_algs() + get_public_key_algs()) return _select_algs('host key', algs, config_algs, possible_algs, default_algs) def _validate_algs(config: SSHConfig, kex_algs_arg: _AlgsArg, enc_algs_arg: _AlgsArg, mac_algs_arg: _AlgsArg, cmp_algs_arg: _AlgsArg, sig_algs_arg: _AlgsArg, 
allow_x509: bool) -> \ Tuple[Sequence[bytes], Sequence[bytes], Sequence[bytes], Sequence[bytes], Sequence[bytes]]: """Validate requested algorithms""" kex_algs = _select_algs('key exchange', kex_algs_arg, cast(_AlgsArg, config.get('KexAlgorithms', ())), get_kex_algs(), get_default_kex_algs()) enc_algs = _select_algs('encryption', enc_algs_arg, cast(_AlgsArg, config.get('Ciphers', ())), get_encryption_algs(), get_default_encryption_algs()) mac_algs = _select_algs('MAC', mac_algs_arg, cast(_AlgsArg, config.get('MACs', ())), get_mac_algs(), get_default_mac_algs()) cmp_algs = _select_algs('compression', cmp_algs_arg, cast(_AlgsArg, config.get_compression_algs()), get_compression_algs(), get_default_compression_algs(), b'none') allowed_sig_algs = get_x509_certificate_algs() if allow_x509 else [] allowed_sig_algs = allowed_sig_algs + get_public_key_algs() default_sig_algs = get_default_x509_certificate_algs() if allow_x509 else [] default_sig_algs = allowed_sig_algs + get_default_public_key_algs() sig_algs = _select_algs('signature', sig_algs_arg, cast(_AlgsArg, config.get('CASignatureAlgorithms', ())), allowed_sig_algs, default_sig_algs) return kex_algs, enc_algs, mac_algs, cmp_algs, sig_algs class SSHAcceptor: """SSH acceptor This class in a wrapper around an :class:`asyncio.Server` listener which provides the ability to update the the set of SSH client or server connection options associated with that listener. This is accomplished by calling the :meth:`update` method, which takes the same keyword arguments as the :class:`SSHClientConnectionOptions` and :class:`SSHServerConnectionOptions` classes. In addition, this class supports all of the methods supported by :class:`asyncio.Server` to control accepting of new connections. """ def __init__(self, server: asyncio.AbstractServer, options: 'SSHConnectionOptions'): self._server = server self._options = options async def __aenter__(self) -> Self: return self async def __aexit__(self, _exc_type: Optional[Type[BaseException]], _exc_value: Optional[BaseException], _traceback: Optional[TracebackType]) -> bool: self.close() await self.wait_closed() return False def __getattr__(self, name: str) -> Any: return getattr(self._server, name) def get_addresses(self) -> List[Tuple]: """Return socket addresses being listened on This method returns the socket addresses being listened on. It returns tuples of the form returned by :meth:`socket.getsockname`. If the listener was created using a hostname, the host's resolved IPs will be returned. If the requested listening port was `0`, the selected listening ports will be returned. :returns: A list of socket addresses being listened on """ if hasattr(self._server, 'get_addresses'): return self._server.get_addresses() else: return [sock.getsockname() for sock in self.sockets] def get_port(self) -> int: """Return the port number being listened on This method returns the port number being listened on. If it is listening on multiple sockets with different port numbers, this function will return `0`. In that case, :meth:`get_addresses` can be used to retrieve the full list of listening addresses and ports. :returns: The port number being listened on, if there's only one """ if hasattr(self._server, 'get_port'): return self._server.get_port() else: ports = {addr[1] for addr in self.get_addresses()} return ports.pop() if len(ports) == 1 else 0 def close(self) -> None: """Stop listening for new connections This method can be called to stop listening for new SSH connections. Existing connections will remain open. 
""" self._server.close() async def wait_closed(self) -> None: """Wait for this listener to close This method is a coroutine which waits for this listener to be closed. """ await self._server.wait_closed() def update(self, **kwargs: object) -> None: """Update options on an SSH listener Acceptors started by :func:`listen` support options defined in :class:`SSHServerConnectionOptions`. Acceptors started by :func:`listen_reverse` support options defined in :class:`SSHClientConnectionOptions`. Changes apply only to SSH client/server connections accepted after the change is made. Previously accepted connections will continue to use the options set when they were accepted. """ self._options.update(**kwargs) class SSHConnection(SSHPacketHandler, asyncio.Protocol): """Parent class for SSH connections""" _handler_names = get_symbol_names(globals(), 'MSG_') next_conn = 0 # Next connection number, for logging @staticmethod def _get_next_conn() -> int: """Return the next available connection number (for logging)""" next_conn = SSHConnection.next_conn SSHConnection.next_conn += 1 return next_conn def __init__(self, loop: asyncio.AbstractEventLoop, options: 'SSHConnectionOptions', acceptor: _AcceptHandler, error_handler: _ErrorHandler, wait: Optional[str], server: bool): self._loop = loop self._options = options self._protocol_factory = options.protocol_factory self._acceptor = acceptor self._error_handler = error_handler self._server = server self._wait = wait self._waiter = options.waiter if wait else None self._transport: Optional[asyncio.Transport] = None self._local_addr = '' self._local_port = 0 self._peer_host = '' self._peer_addr = '' self._peer_port = 0 self._tcp_keepalive = options.tcp_keepalive self._owner: Optional[Union[SSHClient, SSHServer]] = None self._extra: Dict[str, object] = {} self._inpbuf = b'' self._packet = b'' self._pktlen = 0 self._banner_lines = 0 self._version = options.version self._client_version = b'' self._server_version = b'' self._client_kexinit = b'' self._server_kexinit = b'' self._session_id = b'' self._send_seq = 0 self._send_encryption: Optional[Encryption] = None self._send_enchdrlen = 5 self._send_blocksize = 8 self._compressor: Optional[Compressor] = None self._compress_after_auth = False self._deferred_packets: List[Tuple[int, Sequence[bytes]]] = [] self._recv_handler = self._recv_version self._recv_seq = 0 self._recv_encryption: Optional[Encryption] = None self._recv_blocksize = 8 self._recv_macsize = 0 self._decompressor: Optional[Decompressor] = None self._decompress_after_auth = False self._next_recv_encryption: Optional[Encryption] = None self._next_recv_blocksize = 0 self._next_recv_macsize = 0 self._next_decompressor: Optional[Decompressor] = None self._next_decompress_after_auth = False self._trusted_host_keys: Optional[Set[SSHKey]] = set() self._trusted_host_key_algs: List[bytes] = [] self._trusted_ca_keys: Optional[Set[SSHKey]] = set() self._revoked_host_keys: Set[SSHKey] = set() self._x509_trusted_certs = options.x509_trusted_certs self._x509_trusted_cert_paths = options.x509_trusted_cert_paths self._x509_revoked_certs: Set[SSHX509Certificate] = set() self._x509_trusted_subjects: Sequence['X509NamePattern'] = [] self._x509_revoked_subjects: Sequence['X509NamePattern'] = [] self._x509_purposes = options.x509_purposes self._kex_algs = options.kex_algs self._enc_algs = options.encryption_algs self._mac_algs = options.mac_algs self._cmp_algs = options.compression_algs self._sig_algs = options.signature_algs self._host_based_auth = options.host_based_auth 
self._public_key_auth = options.public_key_auth self._kbdint_auth = options.kbdint_auth self._password_auth = options.password_auth self._kex: Optional[Kex] = None self._kexinit_sent = False self._kex_complete = False self._ignore_first_kex = False self._strict_kex = False self._gss: Optional[GSSBase] = None self._gss_kex = False self._gss_auth = False self._gss_kex_auth = False self._gss_mic_auth = False self._preferred_auth: Optional[Sequence[bytes]] = None self._rekey_bytes = options.rekey_bytes self._rekey_seconds = options.rekey_seconds self._rekey_bytes_sent = 0 self._rekey_time = 0. self._keepalive_count = 0 self._keepalive_count_max = options.keepalive_count_max self._keepalive_interval = options.keepalive_interval self._keepalive_timer: Optional[asyncio.TimerHandle] = None self._tunnel: Optional[_TunnelProtocol] = None self._enc_alg_cs = b'' self._enc_alg_sc = b'' self._mac_alg_cs = b'' self._mac_alg_sc = b'' self._cmp_alg_cs = b'' self._cmp_alg_sc = b'' self._can_send_ext_info = False self._extensions_to_send: 'OrderedDict[bytes, bytes]' = OrderedDict() self._can_recv_ext_info = False self._server_sig_algs: Set[bytes] = set() self._next_service: Optional[bytes] = None self._agent: Optional[SSHAgentClient] = None self._auth: Optional[Auth] = None self._auth_in_progress = False self._auth_complete = False self._auth_final = False self._auth_methods = [b'none'] self._auth_was_trivial = True self._username = '' self._channels: Dict[int, SSHChannel] = {} self._next_recv_chan = 0 self._global_request_queue: List[_GlobalRequest] = [] self._global_request_waiters: \ 'List[asyncio.Future[_GlobalRequestResult]]' = [] self._local_listeners: Dict[ListenKey, SSHListener] = {} self._x11_listener: Union[None, SSHX11ClientListener, SSHX11ServerListener] = None self._tasks: Set[asyncio.Task[None]] = set() self._close_event = asyncio.Event() self._server_host_key_algs: Optional[Sequence[bytes]] = None self._logger = logger.get_child( context=f'conn={self._get_next_conn()}') self._login_timer: Optional[asyncio.TimerHandle] if options.login_timeout: self._login_timer = self._loop.call_later( options.login_timeout, self._login_timer_callback) else: self._login_timer = None self._disable_trivial_auth = False async def __aenter__(self) -> Self: """Allow SSHConnection to be used as an async context manager""" return self async def __aexit__(self, _exc_type: Optional[Type[BaseException]], _exc_value: Optional[BaseException], _traceback: Optional[TracebackType]) -> bool: """Wait for connection close when used as an async context manager""" if not self._loop.is_closed(): # pragma: no branch self.close() await self.wait_closed() return False @property def logger(self) -> SSHLogger: """A logger associated with this connection""" return self._logger def _cleanup(self, exc: Optional[Exception]) -> None: """Clean up this connection""" self._cancel_keepalive_timer() for chan in list(self._channels.values()): chan.process_connection_close(exc) for listener in list(self._local_listeners.values()): listener.close() while self._global_request_waiters: self._process_global_response(MSG_REQUEST_FAILURE, 0, SSHPacket(b'')) if self._auth: self._auth.cancel() self._auth = None if self._error_handler: self._error_handler(self, exc) self._acceptor = None self._error_handler = None if self._wait and self._waiter and not self._waiter.cancelled(): if exc: self._waiter.set_exception(exc) else: # pragma: no cover self._waiter.set_result(None) self._wait = None if self._owner: # pragma: no branch self._owner.connection_lost(exc) 
self._owner = None self._cancel_login_timer() self._close_event.set() self._inpbuf = b'' if self._tunnel: self._tunnel.close() self._tunnel = None def _cancel_login_timer(self) -> None: """Cancel the login timer""" if self._login_timer: self._login_timer.cancel() self._login_timer = None def _login_timer_callback(self) -> None: """Close the connection if authentication hasn't completed yet""" self._login_timer = None self.connection_lost(ConnectionLost('Login timeout expired')) def _cancel_keepalive_timer(self) -> None: """Cancel the keepalive timer""" if self._keepalive_timer: self._keepalive_timer.cancel() self._keepalive_timer = None def _set_keepalive_timer(self) -> None: """Set the keepalive timer""" if self._keepalive_interval: self._keepalive_timer = self._loop.call_later( self._keepalive_interval, self._keepalive_timer_callback) def _reset_keepalive_timer(self) -> None: """Reset the keepalive timer""" if self._auth_complete: self._cancel_keepalive_timer() self._set_keepalive_timer() async def _make_keepalive_request(self) -> None: """Send keepalive request""" self.logger.debug1('Sending keepalive request') await self._make_global_request(b'keepalive@openssh.com') if self._keepalive_timer: self.logger.debug1('Got keepalive response') self._keepalive_count = 0 def _keepalive_timer_callback(self) -> None: """Handle keepalive check""" self._keepalive_count += 1 if self._keepalive_count > self._keepalive_count_max: self.connection_lost( ConnectionLost(('Server' if self.is_client() else 'Client') + ' not responding to keepalive')) else: self._set_keepalive_timer() self.create_task(self._make_keepalive_request()) def _force_close(self, exc: Optional[Exception]) -> None: """Force this connection to close immediately""" if not self._transport: return self._loop.call_soon(self._transport.abort) self._transport = None self._loop.call_soon(self._cleanup, exc) def _reap_task(self, task_logger: Optional[SSHLogger], task: 'asyncio.Task[None]') -> None: """Collect result of an async task, reporting errors""" self._tasks.discard(task) # pylint: disable=broad-except try: task.result() except asyncio.CancelledError: pass except DisconnectError as exc: self._send_disconnect(exc.code, exc.reason, exc.lang) self._force_close(exc) except Exception: self.internal_error(error_logger=task_logger) def create_task(self, coro: Awaitable[None], task_logger: Optional[SSHLogger] = None) -> \ 'asyncio.Task[None]': """Create an asynchronous task which catches and reports errors""" task = asyncio.ensure_future(coro) task.add_done_callback(partial(self._reap_task, task_logger)) self._tasks.add(task) return task def is_client(self) -> bool: """Return if this is a client connection""" return not self._server def is_server(self) -> bool: """Return if this is a server connection""" return self._server def is_closed(self): """Return whether the connection is closed""" return self._close_event.is_set() def get_owner(self) -> Optional[Union[SSHClient, SSHServer]]: """Return the SSHClient or SSHServer which owns this connection""" return self._owner def get_hash_prefix(self) -> bytes: """Return the bytes used in calculating unique connection hashes This methods returns a packetized version of the client and server version and kexinit strings which is needed to perform key exchange hashes. 
""" return b''.join((String(self._client_version), String(self._server_version), String(self._client_kexinit), String(self._server_kexinit))) def set_tunnel(self, tunnel: Optional[_TunnelProtocol]) -> None: """Set tunnel used to open this connection""" self._tunnel = tunnel def _match_known_hosts(self, known_hosts: KnownHostsArg, host: str, addr: str, port: Optional[int]) -> None: """Determine the set of trusted host keys and certificates""" trusted_host_keys, trusted_ca_keys, revoked_host_keys, \ trusted_x509_certs, revoked_x509_certs, \ trusted_x509_subjects, revoked_x509_subjects = \ match_known_hosts(known_hosts, host, addr, port) assert self._trusted_host_keys is not None for key in trusted_host_keys: self._trusted_host_keys.add(key) if key.algorithm not in self._trusted_host_key_algs: self._trusted_host_key_algs.extend(key.sig_algorithms) self._trusted_ca_keys = set(trusted_ca_keys) self._revoked_host_keys = set(revoked_host_keys) if self._x509_trusted_certs is not None: self._x509_trusted_certs = list(self._x509_trusted_certs) self._x509_trusted_certs.extend(trusted_x509_certs) self._x509_revoked_certs = set(revoked_x509_certs) self._x509_trusted_subjects = trusted_x509_subjects self._x509_revoked_subjects = revoked_x509_subjects def _validate_openssh_host_certificate( self, host: str, addr: str, port: int, cert: SSHOpenSSHCertificate) -> SSHKey: """Validate an OpenSSH host certificate""" if self._trusted_ca_keys is not None: if cert.signing_key in self._revoked_host_keys: raise ValueError('Host CA key is revoked') if not self._owner: # pragma: no cover raise ValueError('Connection closed') if cert.signing_key not in self._trusted_ca_keys and \ not self._owner.validate_host_ca_key(host, addr, port, cert.signing_key): raise ValueError('Host CA key is not trusted') cert.validate(CERT_TYPE_HOST, host) return cert.key def _validate_x509_host_certificate_chain( self, host: str, cert: SSHX509CertificateChain) -> SSHKey: """Validate an X.509 host certificate""" if (self._x509_revoked_subjects and any(pattern.matches(cert.subject) for pattern in self._x509_revoked_subjects)): raise ValueError('X.509 subject name is revoked') if (self._x509_trusted_subjects and not any(pattern.matches(cert.subject) for pattern in self._x509_trusted_subjects)): raise ValueError('X.509 subject name is not trusted') # Only validate hostname against X.509 certificate host # principals when there are no X.509 trusted subject # entries matched in known_hosts. 
if self._x509_trusted_subjects: host = '' assert self._x509_trusted_certs is not None cert.validate_chain(self._x509_trusted_certs, self._x509_trusted_cert_paths, self._x509_revoked_certs, self._x509_purposes, host_principal=host) return cert.key def _validate_host_key(self, host: str, addr: str, port: int, key_data: bytes) -> SSHKey: """Validate and return a trusted host key""" try: cert = decode_ssh_certificate(key_data) except KeyImportError: pass else: if cert.is_x509_chain: return self._validate_x509_host_certificate_chain( host, cast(SSHX509CertificateChain, cert)) else: return self._validate_openssh_host_certificate( host, addr, port, cast(SSHOpenSSHCertificate, cert)) try: key = decode_ssh_public_key(key_data) except KeyImportError: pass else: if self._trusted_host_keys is not None: if key in self._revoked_host_keys: raise ValueError('Host key is revoked') if not self._owner: # pragma: no cover raise ValueError('Connection closed') if key not in self._trusted_host_keys and \ not self._owner.validate_host_public_key(host, addr, port, key): raise ValueError('Host key is not trusted') return key raise ValueError('Unable to decode host key') def connection_made(self, transport: asyncio.BaseTransport) -> None: """Handle a newly opened connection""" self._transport = cast(asyncio.Transport, transport) sock = cast(socket.socket, transport.get_extra_info('socket')) if sock: sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1 if self._tcp_keepalive else 0) if sock.family in (socket.AF_INET, socket.AF_INET6): sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) sockname = cast(SockAddr, transport.get_extra_info('sockname')) if sockname: # pragma: no branch self._local_addr, self._local_port = sockname[:2] peername = cast(SockAddr, transport.get_extra_info('peername')) if peername: # pragma: no branch self._peer_addr, self._peer_port = peername[:2] self._owner = self._protocol_factory() # pylint: disable=broad-except try: self._connection_made() self._owner.connection_made(self) # type: ignore self._send_version() except Exception: self._loop.call_soon(self.internal_error, sys.exc_info()) def connection_lost(self, exc: Optional[Exception] = None) -> None: """Handle the closing of a connection""" if exc is None and self._transport: exc = ConnectionLost('Connection lost') self._force_close(exc) def internal_error(self, exc_info: Optional[OptExcInfo] = None, error_logger: Optional[SSHLogger] = None) -> None: """Handle a fatal error in connection processing""" if not exc_info: exc_info = sys.exc_info() if not error_logger: error_logger = self.logger error_logger.debug1('Uncaught exception', exc_info=exc_info) self._force_close(cast(Exception, exc_info[1])) def session_started(self) -> None: """Handle session start when opening tunneled SSH connection""" # pylint: disable=arguments-differ def data_received(self, data: bytes, datatype: DataType = None) -> None: """Handle incoming data on the connection""" # pylint: disable=unused-argument self._inpbuf += data self._recv_data() # pylint: enable=arguments-differ def eof_received(self) -> None: """Handle an incoming end of file on the connection""" self.connection_lost(None) def pause_writing(self) -> None: """Handle a request from the transport to pause writing data""" # Do nothing with this for now def resume_writing(self) -> None: """Handle a request from the transport to resume writing data""" # Do nothing with this for now def add_channel(self, chan: SSHChannel[AnyStr]) -> int: """Add a new channel, returning its channel number""" if 
not self._transport: raise ChannelOpenError(OPEN_CONNECT_FAILED, 'SSH connection closed') while self._next_recv_chan in self._channels: # pragma: no cover self._next_recv_chan = (self._next_recv_chan + 1) & 0xffffffff recv_chan = self._next_recv_chan self._next_recv_chan = (self._next_recv_chan + 1) & 0xffffffff self._channels[recv_chan] = chan return recv_chan def remove_channel(self, recv_chan: int) -> None: """Remove the channel with the specified channel number""" del self._channels[recv_chan] def get_gss_context(self) -> GSSBase: """Return the GSS context associated with this connection""" assert self._gss is not None return self._gss def enable_gss_kex_auth(self) -> None: """Enable GSS key exchange authentication""" self._gss_kex_auth = self._gss_auth def _choose_alg(self, alg_type: str, local_algs: Sequence[bytes], remote_algs: Sequence[bytes]) -> bytes: """Choose a common algorithm from the client & server lists This method returns the earliest algorithm on the client's list which is supported by the server. """ if self.is_client(): client_algs, server_algs = local_algs, remote_algs else: client_algs, server_algs = remote_algs, local_algs for alg in client_algs: if alg in server_algs: return alg raise KeyExchangeFailed( f'No matching {alg_type} algorithm found, sent ' f'{b",".join(local_algs).decode("ascii")} and received ' f'{b",".join(remote_algs).decode("ascii")}') def _get_extra_kex_algs(self) -> List[bytes]: """Return the extra kex algs to add""" if self.is_client(): return [b'ext-info-c', b'kex-strict-c-v00@openssh.com'] else: return [b'ext-info-s', b'kex-strict-s-v00@openssh.com'] def _send(self, data: bytes) -> None: """Send data to the SSH connection""" if self._transport: try: self._transport.write(data) except ConnectionError: # pragma: no cover pass def _send_version(self) -> None: """Start the SSH handshake""" version = b'SSH-2.0-' + self._version self.logger.debug1('Sending version %s', version) if self.is_client(): self._client_version = version self.set_extra_info(client_version=version.decode('ascii')) else: self._server_version = version self.set_extra_info(server_version=version.decode('ascii')) self._send(version + b'\r\n') def _recv_data(self) -> None: """Parse received data""" self._reset_keepalive_timer() # pylint: disable=broad-except try: while self._inpbuf and self._recv_handler(): pass except DisconnectError as exc: self._send_disconnect(exc.code, exc.reason, exc.lang) self._force_close(exc) except Exception: self.internal_error() def _recv_version(self) -> bool: """Receive and parse the remote SSH version""" idx = self._inpbuf.find(b'\n', 0, _MAX_BANNER_LINE_LEN) if idx < 0: if len(self._inpbuf) >= _MAX_BANNER_LINE_LEN: self._force_close(ProtocolError('Banner line too long')) return False version = self._inpbuf[:idx] if version.endswith(b'\r'): version = version[:-1] self._inpbuf = self._inpbuf[idx+1:] if (version.startswith(b'SSH-2.0-') or (self.is_client() and version.startswith(b'SSH-1.99-'))): if len(version) > _MAX_VERSION_LINE_LEN: self._force_close(ProtocolError('Version too long')) # Accept version 2.0, or 1.99 if we're a client if self.is_server(): self._client_version = version self.set_extra_info(client_version=version.decode('ascii')) else: self._server_version = version self.set_extra_info(server_version=version.decode('ascii')) self.logger.debug1('Received version %s', version) self._send_kexinit() self._kexinit_sent = True self._recv_handler = self._recv_pkthdr elif self.is_client() and not version.startswith(b'SSH-'): # As a client, 
ignore the line if it doesn't appear to be a version self._banner_lines += 1 if self._banner_lines > _MAX_BANNER_LINES: self._force_close(ProtocolError('Too many banner lines')) return False else: # Otherwise, reject the unknown version self._force_close(ProtocolNotSupported('Unsupported SSH version')) return False return True def _recv_pkthdr(self) -> bool: """Receive and parse an SSH packet header""" if len(self._inpbuf) < self._recv_blocksize: return False self._packet = self._inpbuf[:self._recv_blocksize] self._inpbuf = self._inpbuf[self._recv_blocksize:] if self._recv_encryption: self._packet, pktlen = \ self._recv_encryption.decrypt_header(self._recv_seq, self._packet, 4) else: pktlen = self._packet[:4] self._pktlen = int.from_bytes(pktlen, 'big') self._recv_handler = self._recv_packet return True def _recv_packet(self) -> bool: """Receive the remainder of an SSH packet and process it""" rem = 4 + self._pktlen + self._recv_macsize - self._recv_blocksize if len(self._inpbuf) < rem: return False seq = self._recv_seq rest = self._inpbuf[:rem-self._recv_macsize] mac = self._inpbuf[rem-self._recv_macsize:rem] if self._recv_encryption: packet_data = self._recv_encryption.decrypt_packet( seq, self._packet, rest, 4, mac) if not packet_data: raise MACError('MAC verification failed') else: packet_data = self._packet[4:] + rest self._inpbuf = self._inpbuf[rem:] self._packet = b'' orig_payload = packet_data[1:-packet_data[0]] if self._decompressor and (self._auth_complete or not self._decompress_after_auth): payload = self._decompressor.decompress(orig_payload) if payload is None: raise CompressionError('Decompression failed') else: payload = orig_payload packet = SSHPacket(payload) pkttype = packet.get_byte() handler: SSHPacketHandler = self skip_reason = '' exc_reason = '' if MSG_KEX_FIRST <= pkttype <= MSG_KEX_LAST: if self._kex: if self._ignore_first_kex: # pragma: no cover skip_reason = 'ignored first kex' self._ignore_first_kex = False else: handler = self._kex else: skip_reason = 'kex not in progress' exc_reason = 'Key exchange not in progress' elif self._strict_kex and not self._recv_encryption and \ MSG_IGNORE <= pkttype <= MSG_DEBUG: skip_reason = 'strict kex violation' exc_reason = 'Strict key exchange violation: ' \ f'unexpected packet type {pkttype} received' elif MSG_USERAUTH_FIRST <= pkttype <= MSG_USERAUTH_LAST: if self._auth: handler = self._auth else: skip_reason = 'auth not in progress' exc_reason = 'Authentication not in progress' elif pkttype > MSG_KEX_LAST and not self._recv_encryption: skip_reason = 'invalid request before kex complete' exc_reason = 'Invalid request before key exchange was complete' elif pkttype > MSG_USERAUTH_LAST and not self._auth_complete: skip_reason = 'invalid request before auth complete' exc_reason = 'Invalid request before authentication was complete' elif MSG_CHANNEL_FIRST <= pkttype <= MSG_CHANNEL_LAST: try: recv_chan = packet.get_uint32() except PacketDecodeError: skip_reason = 'incomplete channel request' exc_reason = 'Incomplete channel request received' else: try: handler = self._channels[recv_chan] except KeyError: skip_reason = 'invalid channel number' exc_reason = f'Invalid channel number {recv_chan} received' handler.log_received_packet(pkttype, seq, packet, skip_reason) if not skip_reason: try: result = handler.process_packet(pkttype, seq, packet) except PacketDecodeError as exc: raise ProtocolError(str(exc)) from None if inspect.isawaitable(result): # Buffer received data until current packet is processed self._recv_handler = lambda: 
False task = self.create_task(result) task.add_done_callback(functools.partial( self._finish_recv_packet, pkttype, seq, is_async=True)) return False elif not result: if self._strict_kex and not self._recv_encryption: exc_reason = 'Strict key exchange violation: ' \ f'unexpected packet type {pkttype} received' else: self.logger.debug1('Unknown packet type %d received', pkttype) self.send_packet(MSG_UNIMPLEMENTED, UInt32(seq)) if exc_reason: raise ProtocolError(exc_reason) self._finish_recv_packet(pkttype, seq) return True def _finish_recv_packet(self, pkttype: int, seq: int, _task: Optional[asyncio.Task] = None, is_async: bool = False) -> None: """Finish processing a packet""" if pkttype > MSG_USERAUTH_LAST: self._auth_final = True if self._transport: if self._recv_seq == 0xffffffff and not self._recv_encryption: raise ProtocolError('Sequence rollover before kex complete') if pkttype == MSG_NEWKEYS and self._strict_kex: self._recv_seq = 0 else: self._recv_seq = (seq + 1) & 0xffffffff self._recv_handler = self._recv_pkthdr if is_async and self._inpbuf: self._recv_data() def send_packet(self, pkttype: int, *args: bytes, handler: Optional[SSHPacketLogger] = None) -> None: """Send an SSH packet""" if (self._auth_complete and self._kex_complete and (self._rekey_bytes_sent >= self._rekey_bytes or (self._rekey_seconds and time.monotonic() >= self._rekey_time))): self._send_kexinit() self._kexinit_sent = True if (((pkttype in {MSG_DEBUG, MSG_SERVICE_REQUEST, MSG_SERVICE_ACCEPT} or pkttype > MSG_KEX_LAST) and not self._kex_complete) or (pkttype == MSG_USERAUTH_BANNER and not (self._auth_in_progress or self._auth_complete)) or (pkttype > MSG_USERAUTH_LAST and not self._auth_complete)): self._deferred_packets.append((pkttype, args)) return # If we're encrypting and we have no data outstanding, insert an # ignore packet into the stream if self._send_encryption and pkttype > MSG_KEX_LAST: self.send_packet(MSG_IGNORE, String(b'')) orig_payload = Byte(pkttype) + b''.join(args) if self._compressor and (self._auth_complete or not self._compress_after_auth): payload = self._compressor.compress(orig_payload) if payload is None: # pragma: no cover raise CompressionError('Compression failed') else: payload = orig_payload padlen = -(self._send_enchdrlen + len(payload)) % self._send_blocksize if padlen < 4: padlen += self._send_blocksize packet = Byte(padlen) + payload + os.urandom(padlen) pktlen = len(packet) hdr = UInt32(pktlen) seq = self._send_seq if self._send_encryption: packet, mac = self._send_encryption.encrypt_packet(seq, hdr, packet) else: packet = hdr + packet mac = b'' self._send(packet + mac) if self._send_seq == 0xffffffff and not self._send_encryption: self._send_seq = 0 raise ProtocolError('Sequence rollover before kex complete') if pkttype == MSG_NEWKEYS and self._strict_kex: self._send_seq = 0 else: self._send_seq = (seq + 1) & 0xffffffff if self._kex_complete: self._rekey_bytes_sent += pktlen if not handler: handler = self handler.log_sent_packet(pkttype, seq, orig_payload) def _send_deferred_packets(self) -> None: """Send packets deferred due to key exchange or auth""" deferred_packets = self._deferred_packets self._deferred_packets = [] for pkttype, args in deferred_packets: self.send_packet(pkttype, *args) def _send_disconnect(self, code: int, reason: str, lang: str) -> None: """Send a disconnect packet""" self.logger.info('Sending disconnect: %s (%d)', reason, code) self.send_packet(MSG_DISCONNECT, UInt32(code), String(reason), String(lang)) def _send_kexinit(self) -> None: """Start a key 
exchange""" self._kex_complete = False self._rekey_bytes_sent = 0 if self._rekey_seconds: self._rekey_time = time.monotonic() + self._rekey_seconds if self._gss_kex: assert self._gss is not None gss_mechs = self._gss.mechs else: gss_mechs = [] kex_algs = expand_kex_algs(self._kex_algs, gss_mechs, bool(self._server_host_key_algs)) + \ self._get_extra_kex_algs() host_key_algs = self._server_host_key_algs or [b'null'] self.logger.debug1('Requesting key exchange') self.logger.debug2(' Key exchange algs: %s', kex_algs) self.logger.debug2(' Host key algs: %s', host_key_algs) self.logger.debug2(' Encryption algs: %s', self._enc_algs) self.logger.debug2(' MAC algs: %s', self._mac_algs) self.logger.debug2(' Compression algs: %s', self._cmp_algs) cookie = os.urandom(16) kex_algs = NameList(kex_algs) host_key_algs = NameList(host_key_algs) enc_algs = NameList(self._enc_algs) mac_algs = NameList(self._mac_algs) cmp_algs = NameList(self._cmp_algs) langs = NameList([]) packet = b''.join((Byte(MSG_KEXINIT), cookie, kex_algs, host_key_algs, enc_algs, enc_algs, mac_algs, mac_algs, cmp_algs, cmp_algs, langs, langs, Boolean(False), UInt32(0))) if self.is_server(): self._server_kexinit = packet else: self._client_kexinit = packet self.send_packet(MSG_KEXINIT, packet[1:]) def _send_ext_info(self) -> None: """Send extension information""" packet = UInt32(len(self._extensions_to_send)) self.logger.debug2('Sending extension info') for name, value in self._extensions_to_send.items(): packet += String(name) + String(value) self.logger.debug2(' %s: %s', name, value) self.send_packet(MSG_EXT_INFO, packet) def send_newkeys(self, k: bytes, h: bytes) -> None: """Finish a key exchange and send a new keys message""" if not self._session_id: first_kex = True self._session_id = h else: first_kex = False enc_keysize_cs, enc_ivsize_cs, enc_blocksize_cs, \ mac_keysize_cs, mac_hashsize_cs, etm_cs = \ get_encryption_params(self._enc_alg_cs, self._mac_alg_cs) enc_keysize_sc, enc_ivsize_sc, enc_blocksize_sc, \ mac_keysize_sc, mac_hashsize_sc, etm_sc = \ get_encryption_params(self._enc_alg_sc, self._mac_alg_sc) if mac_keysize_cs == 0: self._mac_alg_cs = self._enc_alg_cs if mac_keysize_sc == 0: self._mac_alg_sc = self._enc_alg_sc cmp_after_auth_cs = get_compression_params(self._cmp_alg_cs) cmp_after_auth_sc = get_compression_params(self._cmp_alg_sc) self.logger.debug2(' Client to server:') self.logger.debug2(' Encryption alg: %s', self._enc_alg_cs) self.logger.debug2(' MAC alg: %s', self._mac_alg_cs) self.logger.debug2(' Compression alg: %s', self._cmp_alg_cs) self.logger.debug2(' Server to client:') self.logger.debug2(' Encryption alg: %s', self._enc_alg_sc) self.logger.debug2(' MAC alg: %s', self._mac_alg_sc) self.logger.debug2(' Compression alg: %s', self._cmp_alg_sc) assert self._kex is not None iv_cs = self._kex.compute_key(k, h, b'A', self._session_id, enc_ivsize_cs) iv_sc = self._kex.compute_key(k, h, b'B', self._session_id, enc_ivsize_sc) enc_key_cs = self._kex.compute_key(k, h, b'C', self._session_id, enc_keysize_cs) enc_key_sc = self._kex.compute_key(k, h, b'D', self._session_id, enc_keysize_sc) mac_key_cs = self._kex.compute_key(k, h, b'E', self._session_id, mac_keysize_cs) mac_key_sc = self._kex.compute_key(k, h, b'F', self._session_id, mac_keysize_sc) self._kex = None next_enc_cs = get_encryption(self._enc_alg_cs, enc_key_cs, iv_cs, self._mac_alg_cs, mac_key_cs, etm_cs) next_enc_sc = get_encryption(self._enc_alg_sc, enc_key_sc, iv_sc, self._mac_alg_sc, mac_key_sc, etm_sc) self.send_packet(MSG_NEWKEYS) 
self._extensions_to_send[b'global-requests-ok'] = b'' if self.is_client(): self._send_encryption = next_enc_cs self._send_enchdrlen = 1 if etm_cs else 5 self._send_blocksize = max(8, enc_blocksize_cs) self._compressor = get_compressor(self._cmp_alg_cs) self._compress_after_auth = cmp_after_auth_cs self._next_recv_encryption = next_enc_sc self._next_recv_blocksize = max(8, enc_blocksize_sc) self._next_recv_macsize = mac_hashsize_sc self._next_decompressor = get_decompressor(self._cmp_alg_sc) self._next_decompress_after_auth = cmp_after_auth_sc self.set_extra_info( send_cipher=self._enc_alg_cs.decode('ascii'), send_mac=self._mac_alg_cs.decode('ascii'), send_compression=self._cmp_alg_cs.decode('ascii'), recv_cipher=self._enc_alg_sc.decode('ascii'), recv_mac=self._mac_alg_sc.decode('ascii'), recv_compression=self._cmp_alg_sc.decode('ascii')) if first_kex: if self._wait == 'kex' and self._waiter and \ not self._waiter.cancelled(): self._waiter.set_result(None) self._wait = None return else: self._extensions_to_send[b'server-sig-algs'] = \ b','.join(self._sig_algs) self._send_encryption = next_enc_sc self._send_enchdrlen = 1 if etm_sc else 5 self._send_blocksize = max(8, enc_blocksize_sc) self._compressor = get_compressor(self._cmp_alg_sc) self._compress_after_auth = cmp_after_auth_sc self._next_recv_encryption = next_enc_cs self._next_recv_blocksize = max(8, enc_blocksize_cs) self._next_recv_macsize = mac_hashsize_cs self._next_decompressor = get_decompressor(self._cmp_alg_cs) self._next_decompress_after_auth = cmp_after_auth_cs self.set_extra_info( send_cipher=self._enc_alg_sc.decode('ascii'), send_mac=self._mac_alg_sc.decode('ascii'), send_compression=self._cmp_alg_sc.decode('ascii'), recv_cipher=self._enc_alg_cs.decode('ascii'), recv_mac=self._mac_alg_cs.decode('ascii'), recv_compression=self._cmp_alg_cs.decode('ascii')) if self._can_send_ext_info: self._send_ext_info() self._can_send_ext_info = False self._kex_complete = True if first_kex: if self.is_client(): self.send_service_request(_USERAUTH_SERVICE) else: self._next_service = _USERAUTH_SERVICE self._send_deferred_packets() def send_service_request(self, service: bytes) -> None: """Send a service request""" self.logger.debug2('Requesting service %s', service) self._next_service = service self.send_packet(MSG_SERVICE_REQUEST, String(service)) def _get_userauth_request_packet(self, method: bytes, args: Tuple[bytes, ...]) -> bytes: """Get packet data for a user authentication request""" return b''.join((Byte(MSG_USERAUTH_REQUEST), String(self._username), String(_CONNECTION_SERVICE), String(method)) + args) def get_userauth_request_data(self, method: bytes, *args: bytes) -> bytes: """Get signature data for a user authentication request""" return (String(self._session_id) + self._get_userauth_request_packet(method, args)) def send_userauth_packet(self, pkttype: int, *args: bytes, handler: Optional[SSHPacketLogger] = None, trivial: bool = True) -> None: """Send a user authentication packet""" self._auth_was_trivial &= trivial self.send_packet(pkttype, *args, handler=handler) async def send_userauth_request(self, method: bytes, *args: bytes, key: Optional[SigningKey] = None, trivial: bool = True) -> None: """Send a user authentication request""" packet = self._get_userauth_request_packet(method, args) if key: data = String(self._session_id) + packet sign_async: Optional[Callable[[bytes], Awaitable[bytes]]] = \ getattr(key, 'sign_async', None) if sign_async: # pylint: disable=not-callable sig = await sign_async(data) elif getattr(key, 
'use_executor', False): sig = await self._loop.run_in_executor(None, key.sign, data) else: sig = key.sign(data) packet += String(sig) self.send_userauth_packet(MSG_USERAUTH_REQUEST, packet[1:], trivial=trivial) def send_userauth_failure(self, partial_success: bool) -> None: """Send a user authentication failure response""" methods = get_supported_server_auth_methods( cast(SSHServerConnection, self)) self.logger.debug2('Remaining auth methods: %s', methods or 'None') self._auth = None self.send_packet(MSG_USERAUTH_FAILURE, NameList(methods), Boolean(partial_success)) def send_userauth_success(self) -> None: """Send a user authentication success response""" self.logger.info('Auth for user %s succeeded', self._username) self.send_packet(MSG_USERAUTH_SUCCESS) self._auth = None self._auth_in_progress = False self._auth_complete = True self._next_service = None self.set_extra_info(username=self._username) self._send_deferred_packets() self._cancel_login_timer() self._set_keepalive_timer() if self._owner: # pragma: no branch self._owner.auth_completed() if self._acceptor: result = self._acceptor(self) if inspect.isawaitable(result): assert result is not None self.create_task(result) self._acceptor = None self._error_handler = None if self._wait == 'auth' and self._waiter and \ not self._waiter.cancelled(): self._waiter.set_result(None) self._wait = None return # This method is only in SSHServerConnection # pylint: disable=no-member cast(SSHServerConnection, self).send_server_host_keys() def send_channel_open_confirmation(self, send_chan: int, recv_chan: int, recv_window: int, recv_pktsize: int, *result_args: bytes) -> None: """Send a channel open confirmation""" self.send_packet(MSG_CHANNEL_OPEN_CONFIRMATION, UInt32(send_chan), UInt32(recv_chan), UInt32(recv_window), UInt32(recv_pktsize), *result_args) def send_channel_open_failure(self, send_chan: int, code: int, reason: str, lang: str) -> None: """Send a channel open failure""" self.send_packet(MSG_CHANNEL_OPEN_FAILURE, UInt32(send_chan), UInt32(code), String(reason), String(lang)) def _send_global_request(self, request: bytes, *args: bytes, want_reply: bool = False) -> None: """Send a global request""" self.send_packet(MSG_GLOBAL_REQUEST, String(request), Boolean(want_reply), *args) async def _make_global_request(self, request: bytes, *args: bytes) -> Tuple[int, SSHPacket]: """Send a global request and wait for the response""" if not self._transport: return MSG_REQUEST_FAILURE, SSHPacket(b'') waiter: 'asyncio.Future[_GlobalRequestResult]' = \ self._loop.create_future() self._global_request_waiters.append(waiter) self._send_global_request(request, *args, want_reply=True) return await waiter def _report_global_response(self, result: Union[bool, bytes]) -> None: """Report back the response to a previously issued global request""" _, _, want_reply = self._global_request_queue.pop(0) if want_reply: # pragma: no branch if result: response = b'' if result is True else cast(bytes, result) self.send_packet(MSG_REQUEST_SUCCESS, response) else: self.send_packet(MSG_REQUEST_FAILURE) if self._global_request_queue: self._service_next_global_request() def _service_next_global_request(self) -> None: """Process next item on global request queue""" handler, packet, _ = self._global_request_queue[0] if callable(handler): handler(packet) else: self._report_global_response(False) def _connection_made(self) -> None: """Handle the opening of a new connection""" raise NotImplementedError def _process_disconnect(self, _pkttype: int, _pktid: int, packet: SSHPacket) -> 
None: """Process a disconnect message""" code = packet.get_uint32() reason_bytes = packet.get_string() lang_bytes = packet.get_string() packet.check_end() try: reason = reason_bytes.decode('utf-8') lang = lang_bytes.decode('ascii') except UnicodeDecodeError: raise ProtocolError('Invalid disconnect message') from None self.logger.debug1('Received disconnect: %s (%d)', reason, code) if code != DISC_BY_APPLICATION or self._wait: exc: Optional[Exception] = construct_disc_error(code, reason, lang) else: exc = None self._force_close(exc) def _process_ignore(self, _pkttype: int, _pktid: int, packet: SSHPacket) -> None: """Process an ignore message""" # Work around missing payload bytes in an ignore message # in some Cisco SSH servers if b'Cisco' not in self._server_version: # pragma: no branch _ = packet.get_string() # data packet.check_end() def _process_unimplemented(self, _pkttype: int, _pktid: int, packet: SSHPacket) -> None: """Process an unimplemented message response""" # pylint: disable=no-self-use _ = packet.get_uint32() # seq packet.check_end() def _process_debug(self, _pkttype: int, _pktid: int, packet: SSHPacket) -> None: """Process a debug message""" always_display = packet.get_boolean() msg_bytes = packet.get_string() lang_bytes = packet.get_string() packet.check_end() try: msg = msg_bytes.decode('utf-8') lang = lang_bytes.decode('ascii') except UnicodeDecodeError: raise ProtocolError('Invalid debug message') from None self.logger.debug1('Received debug message: %s%s', msg, ' (always display)' if always_display else '') if self._owner: # pragma: no branch self._owner.debug_msg_received(msg, lang, always_display) def _process_service_request(self, _pkttype: int, _pktid: int, packet: SSHPacket) -> None: """Process a service request""" service = packet.get_string() packet.check_end() if self.is_client(): raise ProtocolError('Unexpected service request received') if not self._recv_encryption: raise ProtocolError('Service request received before kex complete') if service != self._next_service: raise ServiceNotAvailable('Unexpected service in service request') self.logger.debug2('Accepting request for service %s', service) self.send_packet(MSG_SERVICE_ACCEPT, String(service)) self._next_service = None if service == _USERAUTH_SERVICE: # pragma: no branch self._auth_in_progress = True self._can_recv_ext_info = False self._send_deferred_packets() def _process_service_accept(self, _pkttype: int, _pktid: int, packet: SSHPacket) -> None: """Process a service accept response""" service = packet.get_string() packet.check_end() if self.is_server(): raise ProtocolError('Unexpected service accept received') if not self._recv_encryption: raise ProtocolError('Service accept received before kex complete') if service != self._next_service: raise ServiceNotAvailable('Unexpected service in service accept') self.logger.debug2('Request for service %s accepted', service) self._next_service = None if service == _USERAUTH_SERVICE: # pragma: no branch self.logger.info('Beginning auth for user %s', self._username) self._auth_in_progress = True if self._owner: # pragma: no branch self._owner.begin_auth(self._username) # This method is only in SSHClientConnection # pylint: disable=no-member cast('SSHClientConnection', self).try_next_auth() def _process_ext_info(self, _pkttype: int, _pktid: int, packet: SSHPacket) -> None: """Process extension information""" if not self._can_recv_ext_info: raise ProtocolError('Unexpected ext_info received') extensions: Dict[bytes, bytes] = {} self.logger.debug2('Received extension 
info') num_extensions = packet.get_uint32() for _ in range(num_extensions): name = packet.get_string() value = packet.get_string() extensions[name] = value self.logger.debug2(' %s: %s', name, value) packet.check_end() if self.is_client(): self._server_sig_algs = \ set(extensions.get(b'server-sig-algs', b'').split(b',')) async def _process_kexinit(self, _pkttype: int, _pktid: int, packet: SSHPacket) -> None: """Process a key exchange request""" if self._kex: raise ProtocolError('Key exchange already in progress') _ = packet.get_bytes(16) # cookie peer_kex_algs = packet.get_namelist() peer_host_key_algs = packet.get_namelist() enc_algs_cs = packet.get_namelist() enc_algs_sc = packet.get_namelist() mac_algs_cs = packet.get_namelist() mac_algs_sc = packet.get_namelist() cmp_algs_cs = packet.get_namelist() cmp_algs_sc = packet.get_namelist() _ = packet.get_namelist() # lang_cs _ = packet.get_namelist() # lang_sc first_kex_follows = packet.get_boolean() _ = packet.get_uint32() # reserved packet.check_end() if self.is_server(): self._client_kexinit = packet.get_consumed_payload() if not self._session_id: if b'ext-info-c' in peer_kex_algs: self._can_send_ext_info = True if b'kex-strict-c-v00@openssh.com' in peer_kex_algs: self._strict_kex = True else: self._server_kexinit = packet.get_consumed_payload() if not self._session_id: if b'ext-info-s' in peer_kex_algs: self._can_send_ext_info = True if b'kex-strict-s-v00@openssh.com' in peer_kex_algs: self._strict_kex = True if self._strict_kex and not self._recv_encryption and \ self._recv_seq != 0: raise ProtocolError('Strict key exchange violation: ' 'KEXINIT was not the first packet') if self._kexinit_sent: self._kexinit_sent = False else: self._send_kexinit() if self._gss: self._gss.reset() if self._gss_kex: assert self._gss is not None gss_mechs = self._gss.mechs else: gss_mechs = [] kex_algs = expand_kex_algs(self._kex_algs, gss_mechs, bool(self._server_host_key_algs)) self.logger.debug1('Received key exchange request') self.logger.debug2(' Key exchange algs: %s', peer_kex_algs) self.logger.debug2(' Host key algs: %s', peer_host_key_algs) self.logger.debug2(' Client to server:') self.logger.debug2(' Encryption algs: %s', enc_algs_cs) self.logger.debug2(' MAC algs: %s', mac_algs_cs) self.logger.debug2(' Compression algs: %s', cmp_algs_cs) self.logger.debug2(' Server to client:') self.logger.debug2(' Encryption algs: %s', enc_algs_sc) self.logger.debug2(' MAC algs: %s', mac_algs_sc) self.logger.debug2(' Compression algs: %s', cmp_algs_sc) kex_alg = self._choose_alg('key exchange', kex_algs, peer_kex_algs) self._kex = get_kex(self, kex_alg) self._ignore_first_kex = (first_kex_follows and self._kex.algorithm != peer_kex_algs[0]) if self.is_server(): # This method is only in SSHServerConnection # pylint: disable=no-member if (not cast(SSHServerConnection, self).choose_server_host_key( peer_host_key_algs) and not kex_alg.startswith(b'gss-')): raise KeyExchangeFailed('Unable to find compatible ' 'server host key') self._enc_alg_cs = self._choose_alg('encryption', self._enc_algs, enc_algs_cs) self._enc_alg_sc = self._choose_alg('encryption', self._enc_algs, enc_algs_sc) self._mac_alg_cs = self._choose_alg('MAC', self._mac_algs, mac_algs_cs) self._mac_alg_sc = self._choose_alg('MAC', self._mac_algs, mac_algs_sc) self._cmp_alg_cs = self._choose_alg('compression', self._cmp_algs, cmp_algs_cs) self._cmp_alg_sc = self._choose_alg('compression', self._cmp_algs, cmp_algs_sc) self.logger.debug1('Beginning key exchange') self.logger.debug2(' Key exchange alg: %s', 
self._kex.algorithm) await self._kex.start() def _process_newkeys(self, _pkttype: int, _pktid: int, packet: SSHPacket) -> None: """Process a new keys message, finishing a key exchange""" packet.check_end() if self._next_recv_encryption: self._recv_encryption = self._next_recv_encryption self._recv_blocksize = self._next_recv_blocksize self._recv_macsize = self._next_recv_macsize self._decompressor = self._next_decompressor self._decompress_after_auth = self._next_decompress_after_auth self._next_recv_encryption = None self._can_recv_ext_info = True else: raise ProtocolError('New keys not negotiated') self.logger.debug1('Completed key exchange') def _process_userauth_request(self, _pkttype: int, _pktid: int, packet: SSHPacket) -> None: """Process a user authentication request""" username_bytes = packet.get_string() service = packet.get_string() method = packet.get_string() if len(username_bytes) >= _MAX_USERNAME_LEN: raise IllegalUserName('Username too long') if service != _CONNECTION_SERVICE: raise ServiceNotAvailable('Unexpected service in auth request') try: username = saslprep(username_bytes.decode('utf-8')) except (UnicodeDecodeError, SASLPrepError) as exc: raise IllegalUserName(str(exc)) from None if self.is_client(): raise ProtocolError('Unexpected userauth request') elif self._auth_complete: # Silently ignore additional auth requests after auth succeeds, # until the client sends a non-auth message if self._auth_final: raise ProtocolError('Unexpected userauth request') else: if username != self._username: self.logger.info('Beginning auth for user %s', username) self._username = username begin_auth = True else: begin_auth = False self.create_task(self._finish_userauth(begin_auth, method, packet)) async def _finish_userauth(self, begin_auth: bool, method: bytes, packet: SSHPacket) -> None: """Finish processing a user authentication request""" if not self._owner: # pragma: no cover return if begin_auth: # This method is only in SSHServerConnection # pylint: disable=no-member await cast(SSHServerConnection, self).reload_config() result = cast(SSHServer, self._owner).begin_auth(self._username) if inspect.isawaitable(result): result = await cast(Awaitable[bool], result) if not result: self.send_userauth_success() return if not self._owner: # pragma: no cover return if self._auth: self._auth.cancel() self._auth = lookup_server_auth(cast(SSHServerConnection, self), self._username, method, packet) def _process_userauth_failure(self, _pkttype: int, _pktid: int, packet: SSHPacket) -> None: """Process a user authentication failure response""" auth_methods = packet.get_namelist() partial_success = packet.get_boolean() packet.check_end() self.logger.debug2('Remaining auth methods: %s', auth_methods or 'None') if self._wait == 'auth_methods' and self._waiter and \ not self._waiter.cancelled(): self._waiter.set_result(None) self._auth_methods = list(auth_methods) self._wait = None return if self._preferred_auth: self.logger.debug2('Preferred auth methods: %s', self._preferred_auth or 'None') auth_methods = [method for method in self._preferred_auth if method in auth_methods] self._auth_methods = list(auth_methods) if self.is_client() and self._auth: auth = cast(ClientAuth, self._auth) if partial_success: # pragma: no cover # Partial success not implemented yet auth.auth_succeeded() else: auth.auth_failed() # This method is only in SSHClientConnection # pylint: disable=no-member cast(SSHClientConnection, self).try_next_auth() else: raise ProtocolError('Unexpected userauth failure response') def 
_process_userauth_success(self, _pkttype: int, _pktid: int, packet: SSHPacket) -> None: """Process a user authentication success response""" packet.check_end() if self.is_client() and self._auth: auth = cast(ClientAuth, self._auth) if self._auth_was_trivial and self._disable_trivial_auth: raise PermissionDenied('Trivial auth disabled') self.logger.info('Auth for user %s succeeded', self._username) if self._wait == 'auth_methods' and self._waiter and \ not self._waiter.cancelled(): self._waiter.set_result(None) self._auth_methods = [b'none'] self._wait = None return auth.auth_succeeded() auth.cancel() self._auth = None self._auth_in_progress = False self._auth_complete = True self._can_recv_ext_info = False if self._agent: self._agent.close() self.set_extra_info(username=self._username) self._cancel_login_timer() self._send_deferred_packets() self._set_keepalive_timer() if self._owner: # pragma: no branch self._owner.auth_completed() if self._acceptor: result = self._acceptor(self) if inspect.isawaitable(result): assert result is not None self.create_task(result) self._acceptor = None self._error_handler = None if self._wait == 'auth' and self._waiter and \ not self._waiter.cancelled(): self._waiter.set_result(None) self._wait = None else: raise ProtocolError('Unexpected userauth success response') def _process_userauth_banner(self, _pkttype: int, _pktid: int, packet: SSHPacket) -> None: """Process a user authentication banner message""" msg_bytes = packet.get_string() lang_bytes = packet.get_string() # Work around an extra NUL byte appearing in the user # auth banner message in some versions of cryptlib if b'cryptlib' in self._server_version and \ packet.get_remaining_payload() == b'\0': # pragma: no cover packet.get_byte() packet.check_end() try: msg = msg_bytes.decode('utf-8') lang = lang_bytes.decode('ascii') except UnicodeDecodeError: raise ProtocolError('Invalid userauth banner') from None self.logger.debug1('Received authentication banner') if self.is_client(): cast(SSHClient, self._owner).auth_banner_received(msg, lang) else: raise ProtocolError('Unexpected userauth banner') def _process_global_request(self, _pkttype: int, _pktid: int, packet: SSHPacket) -> None: """Process a global request""" request_bytes = packet.get_string() want_reply = packet.get_boolean() try: request = request_bytes.decode('ascii') except UnicodeDecodeError: raise ProtocolError('Invalid global request') from None name = '_process_' + map_handler_name(request) + '_global_request' handler = cast(Optional[_PacketHandler], getattr(self, name, None)) if not handler: self.logger.debug1('Received unknown global request: %s', request) self._global_request_queue.append((handler, packet, want_reply)) if len(self._global_request_queue) == 1: self._service_next_global_request() def _process_global_response(self, pkttype: int, _pktid: int, packet: SSHPacket) -> None: """Process a global response""" if self._global_request_waiters: waiter = self._global_request_waiters.pop(0) if not waiter.cancelled(): # pragma: no branch waiter.set_result((pkttype, packet)) else: raise ProtocolError('Unexpected global response') def _process_channel_open(self, _pkttype: int, _pktid: int, packet: SSHPacket) -> None: """Process a channel open request""" chantype_bytes = packet.get_string() send_chan = packet.get_uint32() send_window = packet.get_uint32() send_pktsize = packet.get_uint32() # Work around an off-by-one error in dropbear introduced in # https://github.com/mkj/dropbear/commit/49263b5 if b'dropbear' in self._client_version and 
self._compressor: send_pktsize -= 1 try: chantype = chantype_bytes.decode('ascii') except UnicodeDecodeError: raise ProtocolError('Invalid channel open request') from None try: name = '_process_' + map_handler_name(chantype) + '_open' handler = cast(Optional[_OpenHandler], getattr(self, name, None)) if callable(handler): chan, session = handler(packet) chan.process_open(send_chan, send_window, send_pktsize, session) else: raise ChannelOpenError(OPEN_UNKNOWN_CHANNEL_TYPE, 'Unknown channel type') except ChannelOpenError as exc: self.logger.debug1('Open failed for channel type %s: %s', chantype, exc.reason) self.send_channel_open_failure(send_chan, exc.code, exc.reason, exc.lang) def _process_channel_open_confirmation(self, _pkttype: int, _pktid: int, packet: SSHPacket) -> None: """Process a channel open confirmation response""" recv_chan = packet.get_uint32() send_chan = packet.get_uint32() send_window = packet.get_uint32() send_pktsize = packet.get_uint32() # Work around an off-by-one error in dropbear introduced in # https://github.com/mkj/dropbear/commit/49263b5 if b'dropbear' in self._server_version and self._compressor: send_pktsize -= 1 chan = self._channels.get(recv_chan) if chan: chan.process_open_confirmation(send_chan, send_window, send_pktsize, packet) else: self.logger.debug1('Received open confirmation for unknown ' 'channel %d', recv_chan) raise ProtocolError('Invalid channel number') def _process_channel_open_failure(self, _pkttype: int, _pktid: int, packet: SSHPacket) -> None: """Process a channel open failure response""" recv_chan = packet.get_uint32() code = packet.get_uint32() reason_bytes = packet.get_string() lang_bytes = packet.get_string() packet.check_end() try: reason = reason_bytes.decode('utf-8') lang = lang_bytes.decode('ascii') except UnicodeDecodeError: raise ProtocolError('Invalid channel open failure') from None chan = self._channels.get(recv_chan) if chan: chan.process_open_failure(code, reason, lang) else: self.logger.debug1('Received open failure for unknown ' 'channel %d', recv_chan) raise ProtocolError('Invalid channel number') def _process_keepalive_at_openssh_dot_com_global_request( self, packet: SSHPacket) -> None: """Process an incoming OpenSSH keepalive request""" packet.check_end() self.logger.debug2('Received OpenSSH keepalive request') self._report_global_response(True) _packet_handlers = { MSG_DISCONNECT: _process_disconnect, MSG_IGNORE: _process_ignore, MSG_UNIMPLEMENTED: _process_unimplemented, MSG_DEBUG: _process_debug, MSG_SERVICE_REQUEST: _process_service_request, MSG_SERVICE_ACCEPT: _process_service_accept, MSG_EXT_INFO: _process_ext_info, MSG_KEXINIT: _process_kexinit, MSG_NEWKEYS: _process_newkeys, MSG_USERAUTH_REQUEST: _process_userauth_request, MSG_USERAUTH_FAILURE: _process_userauth_failure, MSG_USERAUTH_SUCCESS: _process_userauth_success, MSG_USERAUTH_BANNER: _process_userauth_banner, MSG_GLOBAL_REQUEST: _process_global_request, MSG_REQUEST_SUCCESS: _process_global_response, MSG_REQUEST_FAILURE: _process_global_response, MSG_CHANNEL_OPEN: _process_channel_open, MSG_CHANNEL_OPEN_CONFIRMATION: _process_channel_open_confirmation, MSG_CHANNEL_OPEN_FAILURE: _process_channel_open_failure } def abort(self) -> None: """Forcibly close the SSH connection This method closes the SSH connection immediately, without waiting for pending operations to complete and without sending an explicit SSH disconnect message. Buffered data waiting to be sent will be lost and no more data will be received. 
When the connection is closed, :meth:`connection_lost() ` on the associated :class:`SSHClient` object will be called with the value `None`. """ self.logger.info('Aborting connection') self._force_close(None) def close(self) -> None: """Cleanly close the SSH connection This method calls :meth:`disconnect` with the reason set to indicate that the connection was closed explicitly by the application. """ self.logger.info('Closing connection') self.disconnect(DISC_BY_APPLICATION, 'Disconnected by application') async def wait_closed(self) -> None: """Wait for this connection to close This method is a coroutine which can be called to block until this connection has finished closing. """ if self._agent: await self._agent.wait_closed() await self._close_event.wait() def disconnect(self, code: int, reason: str, lang: str = DEFAULT_LANG) -> None: """Disconnect the SSH connection This method sends a disconnect message and closes the SSH connection after buffered data waiting to be written has been sent. No more data will be received. When the connection is fully closed, :meth:`connection_lost() ` on the associated :class:`SSHClient` or :class:`SSHServer` object will be called with the value `None`. :param code: The reason for the disconnect, from :ref:`disconnect reason codes ` :param reason: A human readable reason for the disconnect :param lang: The language the reason is in :type code: `int` :type reason: `str` :type lang: `str` """ for chan in list(self._channels.values()): chan.close() self._send_disconnect(code, reason, lang) self._force_close(None) def get_extra_info(self, name: str, default: Any = None) -> Any: """Get additional information about the connection This method returns extra information about the connection once it is established. Supported values include everything supported by a socket transport plus: | host | port | username | client_version | server_version | send_cipher | send_mac | send_compression | recv_cipher | recv_mac | recv_compression See :meth:`get_extra_info() ` in :class:`asyncio.BaseTransport` for more information. Additional information stored on the connection by calling :meth:`set_extra_info` can also be returned here. """ return self._extra.get(name, self._transport.get_extra_info(name, default) if self._transport else default) def set_extra_info(self, **kwargs: Any) -> None: """Store additional information associated with the connection This method allows extra information to be associated with the connection. The information to store should be passed in as keyword parameters and can later be returned by calling :meth:`get_extra_info` with one of the keywords as the name to retrieve. """ self._extra.update(**kwargs) def set_keepalive(self, interval: Union[None, float, str] = None, count_max: Optional[int] = None) -> None: """Set keep-alive timer on this connection This method sets the parameters of the keepalive timer on the connection. If *interval* is set to a non-zero value, keep-alive requests will be sent whenever the connection is idle, and if a response is not received after *count_max* attempts, the connection is closed. :param interval: (optional) The time in seconds to wait before sending a keep-alive message if no data has been received. This defaults to 0, which disables sending these messages. :param count_max: (optional) The maximum number of keepalive messages which will be sent without getting a response before closing the connection. This defaults to 3, but only applies when *interval* is non-zero.
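For example (a sketch; `conn` here stands for an already-established connection and the values are arbitrary), the following enables keepalive probes every 30 seconds and closes the connection after 3 unanswered probes::

    conn.set_keepalive(interval=30, count_max=3)
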
:type interval: `int`, `float`, or `str` :type count_max: `int` """ if interval is not None: if isinstance(interval, str): interval = parse_time_interval(interval) if interval < 0: raise ValueError('Keepalive interval cannot be negative') self._keepalive_interval = interval if count_max is not None: if count_max < 0: raise ValueError('Keepalive count max cannot be negative') self._keepalive_count_max = count_max self._reset_keepalive_timer() def send_debug(self, msg: str, lang: str = DEFAULT_LANG, always_display: bool = False) -> None: """Send a debug message on this connection This method can be called to send a debug message to the other end of the connection. :param msg: The debug message to send :param lang: The language the message is in :param always_display: Whether or not to display the message :type msg: `str` :type lang: `str` :type always_display: `bool` """ self.logger.debug1('Sending debug message: %s%s', msg, ' (always display)' if always_display else '') self.send_packet(MSG_DEBUG, Boolean(always_display), String(msg), String(lang)) def create_tcp_channel(self, encoding: Optional[str] = None, errors: str = 'strict', window: int = _DEFAULT_WINDOW, max_pktsize: int = _DEFAULT_MAX_PKTSIZE) -> \ SSHTCPChannel: """Create an SSH TCP channel for a new direct TCP connection This method can be called by :meth:`connection_requested() ` to create an :class:`SSHTCPChannel` with the desired encoding, Unicode error handling strategy, window, and max packet size for a newly created SSH direct connection. :param encoding: (optional) The Unicode encoding to use for data exchanged on the connection. This defaults to `None`, allowing the application to send and receive raw bytes. :param errors: (optional) The error handling strategy to apply on encode/decode errors :param window: (optional) The receive window size for this session :param max_pktsize: (optional) The maximum packet size for this session :type encoding: `str` or `None` :type errors: `str` :type window: `int` :type max_pktsize: `int` :returns: :class:`SSHTCPChannel` """ return SSHTCPChannel(self, self._loop, encoding, errors, window, max_pktsize) def create_unix_channel(self, encoding: Optional[str] = None, errors: str = 'strict', window: int = _DEFAULT_WINDOW, max_pktsize: int = _DEFAULT_MAX_PKTSIZE) -> \ SSHUNIXChannel: """Create an SSH UNIX channel for a new direct UNIX domain connection This method can be called by :meth:`unix_connection_requested() ` to create an :class:`SSHUNIXChannel` with the desired encoding, Unicode error handling strategy, window, and max packet size for a newly created SSH direct UNIX domain socket connection. :param encoding: (optional) The Unicode encoding to use for data exchanged on the connection. This defaults to `None`, allowing the application to send and receive raw bytes. 
:param errors: (optional) The error handling strategy to apply on encode/decode errors :param window: (optional) The receive window size for this session :param max_pktsize: (optional) The maximum packet size for this session :type encoding: `str` or `None` :type errors: `str` :type window: `int` :type max_pktsize: `int` :returns: :class:`SSHUNIXChannel` """ return SSHUNIXChannel(self, self._loop, encoding, errors, window, max_pktsize) def create_tuntap_channel(self, window: int = _DEFAULT_WINDOW, max_pktsize: int = _DEFAULT_MAX_PKTSIZE) -> \ SSHTunTapChannel: """Create a channel to use for TUN/TAP forwarding This method can be called by :meth:`tun_requested() ` or :meth:`tap_requested() ` to create an :class:`SSHTunTapChannel` with the desired window and max packet size for a newly created TUN/TAP tunnel. :param window: (optional) The receive window size for this session :param max_pktsize: (optional) The maximum packet size for this session :type window: `int` :type max_pktsize: `int` :returns: :class:`SSHTunTapChannel` """ return SSHTunTapChannel(self, self._loop, None, 'strict', window, max_pktsize) def create_x11_channel( self, window: int = _DEFAULT_WINDOW, max_pktsize: int = _DEFAULT_MAX_PKTSIZE) -> SSHX11Channel: """Create an SSH X11 channel to use in X11 forwarding""" return SSHX11Channel(self, self._loop, None, 'strict', window, max_pktsize) def create_agent_channel( self, window: int = _DEFAULT_WINDOW, max_pktsize: int = _DEFAULT_MAX_PKTSIZE) -> SSHAgentChannel: """Create an SSH agent channel to use in agent forwarding""" return SSHAgentChannel(self, self._loop, None, 'strict', window, max_pktsize) async def create_connection( self, session_factory: SSHTCPSessionFactory[AnyStr], remote_host: str, remote_port: int, orig_host: str = '', orig_port: int = 0, *, encoding: Optional[str] = None, errors: str = 'strict', window: int = _DEFAULT_WINDOW, max_pktsize: int = _DEFAULT_MAX_PKTSIZE) -> \ Tuple[SSHTCPChannel[AnyStr], SSHTCPSession[AnyStr]]: """Create an SSH direct or forwarded TCP connection""" raise NotImplementedError async def create_unix_connection( self, session_factory: SSHUNIXSessionFactory[AnyStr], remote_path: str, *, encoding: Optional[str] = None, errors: str = 'strict', window: int = _DEFAULT_WINDOW, max_pktsize: int = _DEFAULT_MAX_PKTSIZE) -> \ Tuple[SSHUNIXChannel[AnyStr], SSHUNIXSession[AnyStr]]: """Create an SSH direct or forwarded UNIX domain socket connection""" raise NotImplementedError async def forward_connection( self, dest_host: str, dest_port: int) -> SSHForwarder: """Forward a tunneled TCP connection This method is a coroutine which can be returned by a `session_factory` to forward connections tunneled over SSH to the specified destination host and port. :param dest_host: The hostname or address to forward the connections to :param dest_port: The port number to forward the connections to :type dest_host: `str` or `None` :type dest_port: `int` :returns: :class:`asyncio.BaseProtocol` """ try: _, peer = await self._loop.create_connection(SSHForwarder, dest_host, dest_port) self.logger.info(' Forwarding TCP connection to %s', (dest_host, dest_port)) except OSError as exc: raise ChannelOpenError(OPEN_CONNECT_FAILED, str(exc)) from None return SSHForwarder(cast(SSHForwarder, peer)) async def forward_unix_connection(self, dest_path: str) -> SSHForwarder: """Forward a tunneled UNIX domain socket connection This method is a coroutine which can be returned by a `session_factory` to forward connections tunneled over SSH to the specified destination path. 
:param dest_path: The path to forward the connection to :type dest_path: `str` :returns: :class:`asyncio.BaseProtocol` """ try: _, peer = \ await self._loop.create_unix_connection(SSHForwarder, dest_path) self.logger.info(' Forwarding UNIX connection to %s', dest_path) except OSError as exc: raise ChannelOpenError(OPEN_CONNECT_FAILED, str(exc)) from None return SSHForwarder(cast(SSHForwarder, peer)) @async_context_manager async def forward_local_port( self, listen_host: str, listen_port: int, dest_host: str, dest_port: int, accept_handler: Optional[SSHAcceptHandler] = None) -> SSHListener: """Set up local port forwarding This method is a coroutine which attempts to set up port forwarding from a local listening port to a remote host and port via the SSH connection. If the request is successful, the return value is an :class:`SSHListener` object which can be used later to shut down the port forwarding. :param listen_host: The hostname or address on the local host to listen on :param listen_port: The port number on the local host to listen on :param dest_host: The hostname or address to forward the connections to :param dest_port: The port number to forward the connections to :param accept_handler: A `callable` or coroutine which takes arguments of the original host and port of the client and decides whether or not to allow connection forwarding, returning `True` to accept the connection and begin forwarding or `False` to reject and close it. :type listen_host: `str` :type listen_port: `int` :type dest_host: `str` :type dest_port: `int` :type accept_handler: `callable` or coroutine :returns: :class:`SSHListener` :raises: :exc:`OSError` if the listener can't be opened """ async def tunnel_connection( session_factory: SSHTCPSessionFactory[bytes], orig_host: str, orig_port: int) -> \ Tuple[SSHTCPChannel[bytes], SSHTCPSession[bytes]]: """Forward a local connection over SSH""" if accept_handler: result = accept_handler(orig_host, orig_port) if inspect.isawaitable(result): result = await cast(Awaitable[bool], result) if not result: self.logger.info('Request for TCP forwarding from ' '%s to %s denied by application', (orig_host, orig_port), (dest_host, dest_port)) raise ChannelOpenError(OPEN_ADMINISTRATIVELY_PROHIBITED, 'Connection forwarding denied') return (await self.create_connection(session_factory, dest_host, dest_port, orig_host, orig_port)) if (listen_host, listen_port) == (dest_host, dest_port): self.logger.info('Creating local TCP forwarder on %s', (listen_host, listen_port)) else: self.logger.info('Creating local TCP forwarder from %s to %s', (listen_host, listen_port), (dest_host, dest_port)) try: listener = await create_tcp_forward_listener(self, self._loop, tunnel_connection, listen_host, listen_port) except OSError as exc: self.logger.debug1('Failed to create local TCP listener: %s', exc) raise if listen_port == 0: listen_port = listener.get_port() if dest_port == 0: dest_port = listen_port self._local_listeners[listen_host, listen_port] = listener return listener @async_context_manager async def forward_local_path(self, listen_path: str, dest_path: str) -> SSHListener: """Set up local UNIX domain socket forwarding This method is a coroutine which attempts to set up UNIX domain socket forwarding from a local listening path to a remote path via the SSH connection. If the request is successful, the return value is an :class:`SSHListener` object which can be used later to shut down the UNIX domain socket forwarding. 
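For illustration (a sketch only; the socket paths are placeholders and `conn` stands for an established :class:`SSHClientConnection`)::

    listener = await conn.forward_local_path('/tmp/app.sock', '/var/run/app.sock')
    await listener.wait_closed()
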
:param listen_path: The path on the local host to listen on :param dest_path: The path on the remote host to forward the connections to :type listen_path: `str` :type dest_path: `str` :returns: :class:`SSHListener` :raises: :exc:`OSError` if the listener can't be opened """ async def tunnel_connection( session_factory: SSHUNIXSessionFactory[bytes]) -> \ Tuple[SSHUNIXChannel[bytes], SSHUNIXSession[bytes]]: """Forward a local connection over SSH""" return await self.create_unix_connection(session_factory, dest_path) self.logger.info('Creating local UNIX forwarder from %s to %s', listen_path, dest_path) try: listener = await create_unix_forward_listener(self, self._loop, tunnel_connection, listen_path) except OSError as exc: self.logger.debug1('Failed to create local UNIX listener: %s', exc) raise self._local_listeners[listen_path] = listener return listener def forward_tuntap(self, mode: int, unit: Optional[int]) -> SSHForwarder: """Set up TUN/TAP forwarding""" try: transport, peer = create_tuntap(SSHForwarder, mode, unit) interface = transport.get_extra_info('interface') self.logger.info(' Forwarding layer %d traffic to %s', 3 if mode == SSH_TUN_MODE_POINTTOPOINT else 2, interface) except OSError as exc: raise ChannelOpenError(OPEN_CONNECT_FAILED, str(exc)) from None return SSHForwarder(cast(SSHForwarder, peer), extra={'interface': interface}) def close_forward_listener(self, listen_key: ListenKey) -> None: """Mark a local forwarding listener as closed""" self._local_listeners.pop(listen_key, None) def detach_x11_listener(self, chan: SSHChannel[AnyStr]) -> None: """Detach a session from a local X11 listener""" raise NotImplementedError class SSHClientConnection(SSHConnection): """SSH client connection This class represents an SSH client connection. Once authentication is successful on a connection, new client sessions can be opened by calling :meth:`create_session`. Direct TCP connections can be opened by calling :meth:`create_connection`. Remote listeners for forwarded TCP connections can be opened by calling :meth:`create_server`. Direct UNIX domain socket connections can be opened by calling :meth:`create_unix_connection`. Remote listeners for forwarded UNIX domain socket connections can be opened by calling :meth:`create_unix_server`. TCP port forwarding can be set up by calling :meth:`forward_local_port` or :meth:`forward_remote_port`. UNIX domain socket forwarding can be set up by calling :meth:`forward_local_path` or :meth:`forward_remote_path`. Mixed forwarding from a TCP port to a UNIX domain socket or vice-versa can be set up by calling :meth:`forward_local_port_to_path`, :meth:`forward_local_path_to_port`, :meth:`forward_remote_port_to_path`, or :meth:`forward_remote_path_to_port`. 
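A minimal usage sketch (the host name and command are placeholders, and default authentication is assumed to succeed)::

    async with asyncssh.connect('server.example.org') as conn:
        result = await conn.run('uptime', check=True)
        print(result.stdout, end='')
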
""" _options: 'SSHClientConnectionOptions' _owner: SSHClient _x11_listener: Optional[SSHX11ClientListener] def __init__(self, loop: asyncio.AbstractEventLoop, options: 'SSHClientConnectionOptions', acceptor: _AcceptHandler = None, error_handler: _ErrorHandler = None, wait: Optional[str] = None): super().__init__(loop, options, acceptor, error_handler, wait, server=False) self._host = options.host self._port = options.port self._known_hosts = options.known_hosts self._host_key_alias = options.host_key_alias self._server_host_key_algs: Optional[Sequence[bytes]] = None self._server_host_key: Optional[SSHKey] = None self._server_host_keys_handler = options.server_host_keys_handler self._username = options.username self._password = options.password self._client_host_keys: List[_ClientHostKey] = [] self._client_keys: List[SSHKeyPair] = \ list(options.client_keys) if options.client_keys else [] self._saved_rsa_key: Optional[_ClientHostKey] = None if options.preferred_auth != (): self._preferred_auth = [method.encode('ascii') for method in options.preferred_auth] else: self._preferred_auth = get_supported_client_auth_methods() self._disable_trivial_auth = options.disable_trivial_auth if options.agent_path is not None: self._agent = SSHAgentClient(options.agent_path) self._agent_identities = options.agent_identities self._agent_forward_path = options.agent_forward_path self._get_agent_keys = bool(self._agent) self._pkcs11_provider = options.pkcs11_provider self._pkcs11_pin = options.pkcs11_pin self._get_pkcs11_keys = bool(self._pkcs11_provider) gss_host = options.gss_host if options.gss_host != () else options.host if gss_host: try: self._gss = GSSClient(gss_host, options.gss_store, options.gss_delegate_creds) self._gss_kex = options.gss_kex self._gss_auth = options.gss_auth self._gss_mic_auth = self._gss_auth except GSSError: pass self._kbdint_password_auth = False self._remote_listeners: \ Dict[ListenKey, Union[SSHTCPClientListener, SSHUNIXClientListener]] = {} self._dynamic_remote_listeners: Dict[str, SSHTCPClientListener] = {} def _connection_made(self) -> None: """Handle the opening of a new connection""" assert self._transport is not None if not self._host: if self._peer_addr: self._host = self._peer_addr self._port = self._peer_port else: remote_peer = self.get_extra_info('remote_peername') self._host, self._port = cast(HostPort, remote_peer) if self._options.client_host_keysign: sock = cast(socket.socket, self._transport.get_extra_info('socket')) self._client_host_keys = list(get_keysign_keys( self._options.client_host_keysign, sock.fileno(), self._options.client_host_pubkeys)) elif self._options.client_host_keypairs: self._client_host_keys = list(self._options.client_host_keypairs) else: self._client_host_keys = [] if self._known_hosts is None: self._trusted_host_keys = None self._trusted_ca_keys = None else: if not self._known_hosts: default_known_hosts = Path('~', '.ssh', 'known_hosts').expanduser() if (default_known_hosts.is_file() and os.access(default_known_hosts, os.R_OK)): self._known_hosts = str(default_known_hosts) else: self._known_hosts = b'' port = self._port if self._port != DEFAULT_PORT else None self._match_known_hosts(cast(KnownHostsArg, self._known_hosts), self._host_key_alias or self._host, self._peer_addr, port) default_host_key_algs = [] if self._options.server_host_key_algs != 'default': if self._trusted_host_key_algs: default_host_key_algs = self._trusted_host_key_algs if self._trusted_ca_keys: default_host_key_algs = \ get_default_certificate_algs() + 
default_host_key_algs if not default_host_key_algs: default_host_key_algs = \ get_default_certificate_algs() + get_default_public_key_algs() if self._x509_trusted_certs is not None: if self._x509_trusted_certs or self._x509_trusted_cert_paths: default_host_key_algs = \ get_default_x509_certificate_algs() + default_host_key_algs self._server_host_key_algs = _select_host_key_algs( self._options.server_host_key_algs, cast(DefTuple[str], self._options.config.get( 'HostKeyAlgorithms', ())), default_host_key_algs) self.logger.info('Connected to SSH server at %s', (self._host, self._port)) if self._options.proxy_command: proxy_command = ' '.join(shlex.quote(arg) for arg in self._options.proxy_command) self.logger.info(' Proxy command: %s', proxy_command) else: self.logger.info(' Local address: %s', (self._local_addr, self._local_port)) self.logger.info(' Peer address: %s', (self._peer_addr, self._peer_port)) def _cleanup(self, exc: Optional[Exception]) -> None: """Clean up this client connection""" if self._agent: self._agent.close() if self._remote_listeners: for tcp_listener in list(self._remote_listeners.values()): tcp_listener.close() self._remote_listeners = {} self._dynamic_remote_listeners = {} if exc is None: self.logger.info('Connection closed') elif isinstance(exc, ConnectionLost): self.logger.info(str(exc)) else: self.logger.info('Connection failure: ' + str(exc)) super()._cleanup(exc) def _choose_signature_alg(self, keypair: _ClientHostKey) -> bool: """Choose signature algorithm to use for key-based authentication""" if self._server_sig_algs: for alg in keypair.sig_algorithms: if keypair.use_webauthn and not alg.startswith(b'webauthn-'): continue if alg in self._sig_algs and alg in self._server_sig_algs: keypair.set_sig_algorithm(alg) return True return keypair.sig_algorithms[-1] in self._sig_algs def validate_server_host_key(self, key_data: bytes) -> SSHKey: """Validate and return the server's host key""" try: host_key = self._validate_host_key( self._host_key_alias or self._host, self._peer_addr, self._port, key_data) except ValueError as exc: host = self._host if self._host_key_alias: host += f' with alias {self._host_key_alias}' raise HostKeyNotVerifiable(f'{exc} for host {host}') from None self._server_host_key = host_key return host_key def get_server_host_key(self) -> Optional[SSHKey]: """Return the server host key used in the key exchange This method returns the server host key used to complete the key exchange with the server. If GSS key exchange is used, `None` is returned. :returns: An :class:`SSHKey` public key or `None` """ return self._server_host_key def get_server_auth_methods(self) -> Sequence[str]: """Return the authentication methods available on the server This method returns the auth methods available to authenticate to the server.
:returns: `list` of `str` """ return [method.decode('ascii') for method in self._auth_methods] def try_next_auth(self, *, next_method: bool = False) -> None: """Attempt client authentication using the next compatible method""" if self._auth: self._auth.cancel() self._auth = None if next_method: self._auth_methods.pop(0) while self._auth_methods: self._auth = lookup_client_auth(self, self._auth_methods[0]) if self._auth: return self._auth_methods.pop(0) self.logger.info('Auth failed for user %s', self._username) self._force_close(PermissionDenied('Permission denied for user ' f'{self._username} on host ' f'{self._host}')) def gss_kex_auth_requested(self) -> bool: """Return whether to allow GSS key exchange authentication or not""" if self._gss_kex_auth: self._gss_kex_auth = False return True else: return False def gss_mic_auth_requested(self) -> bool: """Return whether to allow GSS MIC authentication or not""" if self._gss_mic_auth: self._gss_mic_auth = False return True else: return False async def host_based_auth_requested(self) -> \ Tuple[Optional[_ClientHostKey], str, str]: """Return a host key, host, and user to authenticate with""" if not self._host_based_auth: return None, '', '' key: Optional[_ClientHostKey] while True: if self._saved_rsa_key: key = self._saved_rsa_key key.algorithm = key.sig_algorithm + b'-cert-v01@openssh.com' self._saved_rsa_key = None else: try: key = self._client_host_keys.pop(0) except IndexError: key = None break assert key is not None if self._choose_signature_alg(key): if key.algorithm == b'ssh-rsa-cert-v01@openssh.com' and \ key.sig_algorithm != b'ssh-rsa': self._saved_rsa_key = key break client_host = self._options.client_host if client_host is None: sockname = cast(SockAddr, self.get_extra_info('sockname')) if sockname: try: client_host, _ = await self._loop.getnameinfo( sockname, socket.NI_NUMERICSERV) except socket.gaierror: client_host = sockname[0] else: client_host = '' # Add a trailing '.' to the client host to be compatible with # ssh-keysign from OpenSSH if self._options.client_host_keysign and client_host[-1:] != '.': client_host += '.' return key, client_host, self._options.client_username async def public_key_auth_requested(self) -> Optional[SSHKeyPair]: """Return a client key pair to authenticate with""" if not self._public_key_auth: return None if self._get_agent_keys: assert self._agent is not None try: agent_keys = await self._agent.get_keys(self._agent_identities) self._client_keys[:0] = list(agent_keys) except ValueError: pass self._get_agent_keys = False if self._get_pkcs11_keys: assert self._pkcs11_provider is not None pkcs11_keys = await self._loop.run_in_executor( None, load_pkcs11_keys, self._pkcs11_provider, self._pkcs11_pin) self._client_keys[:0] = list(pkcs11_keys) self._get_pkcs11_keys = False while True: if not self._client_keys: result = self._owner.public_key_auth_requested() if inspect.isawaitable(result): result = await cast(Awaitable[KeyPairListArg], result) if not result: return None result: KeyPairListArg self._client_keys = list(load_keypairs(result)) # OpenSSH versions before 7.8 didn't support RSA SHA-2 # signature names in certificate key types, requiring the # use of ssh-rsa-cert-v01@openssh.com as the key type even # when using SHA-2 signatures. However, OpenSSH 8.8 and # later reject ssh-rsa-cert-v01@openssh.com as a key type # by default, requiring that the RSA SHA-2 version of the key # type be used.
This makes it difficult to use RSA keys with # certificates without knowing the version of the remote # server and which key types it will accept. # # The code below works around this by trying multiple key # types during public key and host-based authentication when # using SHA-2 signatures with RSA keys signed by certificates. if self._saved_rsa_key: key = self._saved_rsa_key key.algorithm = key.sig_algorithm + b'-cert-v01@openssh.com' self._saved_rsa_key = None else: key = self._client_keys.pop(0) if self._choose_signature_alg(key): if key.algorithm == b'ssh-rsa-cert-v01@openssh.com' and \ key.sig_algorithm != b'ssh-rsa': self._saved_rsa_key = key return key async def password_auth_requested(self) -> Optional[str]: """Return a password to authenticate with""" if not self._password_auth and not self._kbdint_password_auth: return None if self._password is not None: password: Optional[str] = self._password if callable(password): password = cast(Callable[[], Optional[str]], password)() if inspect.isawaitable(password): password = await cast(Awaitable[Optional[str]], password) else: password = cast(Optional[str], password) self._password = None else: result = self._owner.password_auth_requested() if inspect.isawaitable(result): password = await cast(Awaitable[Optional[str]], result) else: password = cast(Optional[str], result) return password async def password_change_requested(self, prompt: str, lang: str) -> Tuple[str, str]: """Return a password to authenticate with and what to change it to""" result = self._owner.password_change_requested(prompt, lang) if inspect.isawaitable(result): result = await cast(Awaitable[PasswordChangeResponse], result) return cast(PasswordChangeResponse, result) def password_changed(self) -> None: """Report a successful password change""" self._owner.password_changed() def password_change_failed(self) -> None: """Report a failed password change""" self._owner.password_change_failed() async def kbdint_auth_requested(self) -> Optional[str]: """Return the list of supported keyboard-interactive auth methods If keyboard-interactive auth is not supported in the client but a password was provided when the connection was opened, this will allow sending the password via keyboard-interactive auth. 
""" if not self._kbdint_auth: return None result = self._owner.kbdint_auth_requested() if inspect.isawaitable(result): result = await cast(Awaitable[Optional[str]], result) if result is NotImplemented: if self._password is not None and not self._kbdint_password_auth: self._kbdint_password_auth = True result = '' else: result = None return cast(Optional[str], result) async def kbdint_challenge_received( self, name: str, instructions: str, lang: str, prompts: KbdIntPrompts) -> Optional[KbdIntResponse]: """Return responses to a keyboard-interactive auth challenge""" if self._kbdint_password_auth: if not prompts: # Silently drop any empty challenges used to print messages response: Optional[KbdIntResponse] = [] elif len(prompts) == 1: prompt = prompts[0][0].lower() if 'password' in prompt or 'passcode' in prompt: password = await self.password_auth_requested() response = [password] if password is not None else None else: response = None else: response = None else: result = self._owner.kbdint_challenge_received(name, instructions, lang, prompts) if inspect.isawaitable(result): response = await cast(Awaitable[KbdIntResponse], result) else: response = cast(KbdIntResponse, result) return response def _process_session_open(self, _packet: SSHPacket) -> \ Tuple[SSHServerChannel, SSHServerSession]: """Process an inbound session open request These requests are disallowed on an SSH client. """ # pylint: disable=no-self-use raise ChannelOpenError(OPEN_ADMINISTRATIVELY_PROHIBITED, 'Session open forbidden on client') def _process_direct_tcpip_open(self, _packet: SSHPacket) -> \ Tuple[SSHTCPChannel[bytes], SSHTCPSession[bytes]]: """Process an inbound direct TCP/IP channel open request These requests are disallowed on an SSH client. """ # pylint: disable=no-self-use raise ChannelOpenError(OPEN_ADMINISTRATIVELY_PROHIBITED, 'Direct TCP/IP open forbidden on client') def _process_forwarded_tcpip_open(self, packet: SSHPacket) -> \ Tuple[SSHTCPChannel, MaybeAwait[SSHTCPSession]]: """Process an inbound forwarded TCP/IP channel open request""" dest_host_bytes = packet.get_string() dest_port = packet.get_uint32() orig_host_bytes = packet.get_string() orig_port = packet.get_uint32() packet.check_end() try: dest_host = dest_host_bytes.decode('utf-8') orig_host = orig_host_bytes.decode('utf-8') except UnicodeDecodeError: raise ProtocolError('Invalid forwarded TCP/IP channel ' 'open request') from None # Some buggy servers send back a port of `0` instead of the actual # listening port when reporting connections which arrive on a listener # set up on a dynamic port. This lookup attempts to work around that. 
listener = cast(SSHTCPClientListener[bytes], self._remote_listeners.get((dest_host, dest_port)) or self._dynamic_remote_listeners.get(dest_host)) if listener: chan, session = listener.process_connection(orig_host, orig_port) self.logger.info('Accepted forwarded TCP connection on %s', (dest_host, dest_port)) self.logger.info(' Client address: %s', (orig_host, orig_port)) return chan, session else: raise ChannelOpenError(OPEN_CONNECT_FAILED, 'No such listener') async def close_client_tcp_listener(self, listen_host: str, listen_port: int) -> None: """Close a remote TCP/IP listener""" await self._make_global_request( b'cancel-tcpip-forward', String(listen_host), UInt32(listen_port)) self.logger.info('Closed remote TCP listener on %s', (listen_host, listen_port)) listener = self._remote_listeners.get((listen_host, listen_port)) if listener: if self._dynamic_remote_listeners.get(listen_host) == listener: del self._dynamic_remote_listeners[listen_host] del self._remote_listeners[listen_host, listen_port] def _process_direct_streamlocal_at_openssh_dot_com_open( self, _packet: SSHPacket) -> \ Tuple[SSHUNIXChannel, SSHUNIXSession]: """Process an inbound direct UNIX domain channel open request These requests are disallowed on an SSH client. """ # pylint: disable=no-self-use raise ChannelOpenError(OPEN_ADMINISTRATIVELY_PROHIBITED, 'Direct UNIX domain socket open ' 'forbidden on client') def _process_tun_at_openssh_dot_com_open( self, _packet: SSHPacket) -> \ Tuple[SSHTunTapChannel, SSHTunTapSession]: """Process an inbound TUN/TAP open request These requests are disallowed on an SSH client. """ # pylint: disable=no-self-use raise ChannelOpenError(OPEN_ADMINISTRATIVELY_PROHIBITED, 'TUN/TAP request forbidden on client') def _process_forwarded_streamlocal_at_openssh_dot_com_open( self, packet: SSHPacket) -> \ Tuple[SSHUNIXChannel, MaybeAwait[SSHUNIXSession]]: """Process an inbound forwarded UNIX domain channel open request""" dest_path_bytes = packet.get_string() _ = packet.get_string() # reserved packet.check_end() try: dest_path = dest_path_bytes.decode('utf-8') except UnicodeDecodeError: raise ProtocolError('Invalid forwarded UNIX domain channel ' 'open request') from None listener = cast(SSHUNIXClientListener[bytes], self._remote_listeners.get(dest_path)) if listener: chan, session = listener.process_connection() self.logger.info('Accepted remote UNIX connection on %s', dest_path) return chan, session else: raise ChannelOpenError(OPEN_CONNECT_FAILED, 'No such listener') async def close_client_unix_listener(self, listen_path: str) -> None: """Close a remote UNIX domain socket listener""" await self._make_global_request( b'cancel-streamlocal-forward@openssh.com', String(listen_path)) self.logger.info('Closed UNIX listener on %s', listen_path) if listen_path in self._remote_listeners: del self._remote_listeners[listen_path] def _process_x11_open(self, packet: SSHPacket) -> \ Tuple[SSHX11Channel, Awaitable[SSHX11ClientForwarder]]: """Process an inbound X11 channel open request""" orig_host_bytes = packet.get_string() orig_port = packet.get_uint32() packet.check_end() try: orig_host = orig_host_bytes.decode('utf-8') except UnicodeDecodeError: raise ProtocolError('Invalid forwarded X11 channel ' 'open request') from None if self._x11_listener: self.logger.info('Accepted X11 connection') self.logger.info(' Client address: %s', (orig_host, orig_port)) chan = self.create_x11_channel() chan.set_inbound_peer_names(orig_host, orig_port) return chan, self._x11_listener.forward_connection() else: raise 
ChannelOpenError(OPEN_CONNECT_FAILED, 'X11 forwarding disabled') def _process_auth_agent_at_openssh_dot_com_open( self, packet: SSHPacket) -> \ Tuple[SSHUNIXChannel, Awaitable[SSHForwarder]]: """Process an inbound auth agent channel open request""" packet.check_end() if self._agent_forward_path: self.logger.info('Accepted SSH agent connection') return (self.create_unix_channel(), self.forward_unix_connection(self._agent_forward_path)) else: raise ChannelOpenError(OPEN_CONNECT_FAILED, 'Auth agent forwarding disabled') def _process_hostkeys_00_at_openssh_dot_com_global_request( self, packet: SSHPacket) -> None: """Process a list of accepted server host keys""" self.create_task(self._finish_hostkeys(packet)) async def _finish_hostkeys(self, packet: SSHPacket) -> None: """Finish processing hostkeys global request""" if not self._server_host_keys_handler: self.logger.debug1('Ignoring server host key message: no handler') self._report_global_response(False) return if self._trusted_host_keys is None: self.logger.info('Server host key not verified: handler disabled') self._report_global_response(False) return added = [] removed = list(self._trusted_host_keys) retained = [] revoked = [] prove = [] while packet: try: key_data = packet.get_string() key = decode_ssh_public_key(key_data) if key in self._revoked_host_keys: revoked.append(key) elif key in self._trusted_host_keys: retained.append(key) removed.remove(key) else: prove.append((key, String(key_data))) except KeyImportError: pass if prove: pkttype, packet = await self._make_global_request( b'hostkeys-prove-00@openssh.com', b''.join(key_str for _, key_str in prove)) if pkttype == MSG_REQUEST_SUCCESS: prefix = String('hostkeys-prove-00@openssh.com') + \ String(self._session_id) for key, key_str in prove: sig = packet.get_string() if key.verify(prefix + key_str, sig): added.append(key) else: self.logger.debug1('Server host key validation failed') else: self.logger.debug1('Server host key prove request failed') packet.check_end() self.logger.info(f'Server host key report: {len(added)} added, ' f'{len(removed)} removed, {len(retained)} retained, ' f'{len(revoked)} revoked') result = self._server_host_keys_handler(added, removed, retained, revoked) if inspect.isawaitable(result): assert result is not None await result self._report_global_response(True) async def attach_x11_listener(self, chan: SSHClientChannel[AnyStr], display: Optional[str], auth_path: Optional[str], single_connection: bool) -> \ Tuple[bytes, bytes, int]: """Attach a channel to a local X11 display""" if not display: display = os.environ.get('DISPLAY') if not display: raise ValueError('X11 display not set') if not self._x11_listener: self._x11_listener = await create_x11_client_listener( self._loop, display, auth_path) return self._x11_listener.attach(display, chan, single_connection) def detach_x11_listener(self, chan: SSHChannel[AnyStr]) -> None: """Detach a session from a local X11 listener""" if self._x11_listener: if self._x11_listener.detach(chan): self._x11_listener = None async def create_session(self, session_factory: SSHClientSessionFactory, command: DefTuple[Optional[str]] = (), *, subsystem: DefTuple[Optional[str]]= (), env: DefTuple[Env] = (), send_env: DefTuple[Optional[EnvSeq]] = (), request_pty: DefTuple[Union[bool, str]] = (), term_type: DefTuple[Optional[str]] = (), term_size: DefTuple[TermSizeArg] = (), term_modes: DefTuple[TermModesArg] = (), x11_forwarding: DefTuple[Union[int, str]] = (), x11_display: DefTuple[Optional[str]] = (), x11_auth_path: 
DefTuple[Optional[str]] = (), x11_single_connection: DefTuple[bool] = (), encoding: DefTuple[Optional[str]] = (), errors: DefTuple[str] = (), window: DefTuple[int] = (), max_pktsize: DefTuple[int] = ()) -> \ Tuple[SSHClientChannel, SSHClientSession]: """Create an SSH client session This method is a coroutine which can be called to create an SSH client session used to execute a command, start a subsystem such as sftp, or if no command or subsystem is specified run an interactive shell. Optional arguments allow terminal and environment information to be provided. By default, this class expects string data in its send and receive functions, which it encodes on the SSH connection in UTF-8 (ISO 10646) format. An optional encoding argument can be passed in to select a different encoding, or `None` can be passed in if the application wishes to send and receive raw bytes. When an encoding is set, an optional errors argument can be passed in to select what Unicode error handling strategy to use. Other optional arguments include the SSH receive window size and max packet size which default to 2 MB and 32 KB, respectively. :param session_factory: A `callable` which returns an :class:`SSHClientSession` object that will be created to handle activity on this session :param command: (optional) The remote command to execute. By default, an interactive shell is started if no command or subsystem is provided. :param subsystem: (optional) The name of a remote subsystem to start up. :param env: (optional) The environment variables to set for this session. Keys and values passed in here will be converted to Unicode strings encoded as UTF-8 (ISO 10646) for transmission. .. note:: Many SSH servers restrict which environment variables a client is allowed to set. The server's configuration may need to be edited before environment variables can be successfully set in the remote environment. :param send_env: (optional) A list of environment variable names to pull from `os.environ` and set for this session. Wildcards patterns using `'*'` and `'?'` are allowed, and all variables with matching names will be sent with whatever value is set in the local environment. If a variable is present in both env and send_env, the value from env will be used. :param request_pty: (optional) Whether or not to request a pseudo-terminal (PTY) for this session. This defaults to `True`, which means to request a PTY whenever the `term_type` is set. Other possible values include `False` to never request a PTY, `'force'` to always request a PTY even without `term_type` being set, or `'auto'` to request a TTY when `term_type` is set but only when starting an interactive shell. :param term_type: (optional) The terminal type to set for this session. :param term_size: (optional) The terminal width and height in characters and optionally the width and height in pixels. :param term_modes: (optional) POSIX terminal modes to set for this session, where keys are taken from :ref:`POSIX terminal modes ` with values defined in section 8 of :rfc:`RFC 4254 <4254#section-8>`. :param x11_forwarding: (optional) Whether or not to request X11 forwarding for this session, defaulting to `False`. If set to `True`, X11 forwarding will be requested and a failure will raise :exc:`ChannelOpenError`. It can also be set to `'ignore_failure'` to attempt X11 forwarding but ignore failures. :param x11_display: (optional) The display that X11 connections should be forwarded to, defaulting to the value in the environment variable `DISPLAY`. 
:param x11_auth_path: (optional) The path to the Xauthority file to read X11 authentication data from, defaulting to the value in the environment variable `XAUTHORITY` or the file :file:`.Xauthority` in the user's home directory if that's not set. :param x11_single_connection: (optional) Whether or not to limit X11 forwarding to a single connection, defaulting to `False`. :param encoding: (optional) The Unicode encoding to use for data exchanged on this session. :param errors: (optional) The error handling strategy to apply on Unicode encode/decode errors. :param window: (optional) The receive window size for this session. :param max_pktsize: (optional) The maximum packet size for this session. :type session_factory: `callable` :type command: `str` :type subsystem: `str` :type env: `dict` with `bytes` or `str` keys and values :type send_env: `list` of `bytes` or `str` :type request_pty: `bool`, `'force'`, or `'auto'` :type term_type: `str` :type term_size: `tuple` of 2 or 4 `int` values :type term_modes: `dict` with `int` keys and values :type x11_forwarding: `bool` or `'ignore_failure'` :type x11_display: `str` :type x11_auth_path: `str` :type x11_single_connection: `bool` :type encoding: `str` or `None` :type errors: `str` :type window: `int` :type max_pktsize: `int` :returns: an :class:`SSHClientChannel` and :class:`SSHClientSession` :raises: :exc:`ChannelOpenError` if the session can't be opened """ if command == (): command = self._options.command if subsystem == (): subsystem = self._options.subsystem if env == (): env = self._options.env if send_env == (): send_env = self._options.send_env if request_pty == (): request_pty = self._options.request_pty if term_type == (): term_type = self._options.term_type if term_size == (): term_size = self._options.term_size if term_modes == (): term_modes = self._options.term_modes if x11_forwarding == (): x11_forwarding = self._options.x11_forwarding if x11_display == (): x11_display = self._options.x11_display if x11_auth_path == (): x11_auth_path = self._options.x11_auth_path if x11_single_connection == (): x11_single_connection = self._options.x11_single_connection if encoding == (): encoding = self._options.encoding if errors == (): errors = self._options.errors if window == (): window = self._options.window if max_pktsize == (): max_pktsize = self._options.max_pktsize new_env: Dict[bytes, bytes] = {} if send_env: new_env.update(lookup_env(send_env)) if env: new_env.update(encode_env(env)) if request_pty == 'force': request_pty = True elif request_pty == 'auto': request_pty = bool(term_type and not (command or subsystem)) elif request_pty: request_pty = bool(term_type) command: Optional[str] subsystem: Optional[str] request_pty: bool term_type: Optional[str] term_size: TermSizeArg term_modes: TermModesArg x11_forwarding: Union[bool, str] x11_display: Optional[str] x11_auth_path: Optional[str] x11_single_connection: bool encoding: Optional[str] errors: str window: int max_pktsize: int chan = SSHClientChannel(self, self._loop, encoding, errors, window, max_pktsize) session = await chan.create(session_factory, command, subsystem, new_env, request_pty, term_type, term_size, term_modes or {}, x11_forwarding, x11_display, x11_auth_path, x11_single_connection, bool(self._agent_forward_path)) return chan, session async def open_session(self, *args: object, **kwargs: object) -> \ Tuple[SSHWriter, SSHReader, SSHReader]: """Open an SSH client session This method is a coroutine wrapper around :meth:`create_session` designed to provide a "high-level" 
stream interface for creating an SSH client session. Instead of taking a `session_factory` argument for constructing an object which will handle activity on the session via callbacks, it returns an :class:`SSHWriter` and two :class:`SSHReader` objects representing stdin, stdout, and stderr which can be used to perform I/O on the session. With the exception of `session_factory`, all of the arguments to :meth:`create_session` are supported and have the same meaning. """ chan, session = await self.create_session( SSHClientStreamSession, *args, **kwargs) # type: ignore session: SSHClientStreamSession return (SSHWriter(session, chan), SSHReader(session, chan), SSHReader(session, chan, EXTENDED_DATA_STDERR)) # pylint: disable=redefined-builtin @async_context_manager # type: ignore async def create_process(self, *args: object, input: Optional[AnyStr] = None, stdin: ProcessSource = PIPE, stdout: ProcessTarget = PIPE, stderr: ProcessTarget = PIPE, bufsize: int = io.DEFAULT_BUFFER_SIZE, send_eof: bool = True, recv_eof: bool = True, **kwargs: object) -> SSHClientProcess[AnyStr]: """Create a process on the remote system This method is a coroutine wrapper around :meth:`create_session` which can be used to execute a command, start a subsystem, or start an interactive shell, optionally redirecting stdin, stdout, and stderr to and from files or pipes attached to other local and remote processes. By default, the stdin, stdout, and stderr arguments default to the special value `PIPE` which means that they can be read and written interactively via stream objects which are members of the :class:`SSHClientProcess` object this method returns. If other file-like objects are provided as arguments, input or output will automatically be redirected to them. The special value `DEVNULL` can be used to provide no input or discard all output, and the special value `STDOUT` can be provided as `stderr` to send its output to the same stream as `stdout`. In addition to the arguments below, all arguments to :meth:`create_session` except for `session_factory` are supported and have the same meaning. :param input: (optional) Input data to feed to standard input of the remote process. If specified, this argument takes precedence over stdin. Data should be a `str` if encoding is set, or `bytes` if not. :param stdin: (optional) A filename, file-like object, file descriptor, socket, or :class:`SSHReader` to feed to standard input of the remote process, or `DEVNULL` to provide no input. :param stdout: (optional) A filename, file-like object, file descriptor, socket, or :class:`SSHWriter` to feed standard output of the remote process to, or `DEVNULL` to discard this output. :param stderr: (optional) A filename, file-like object, file descriptor, socket, or :class:`SSHWriter` to feed standard error of the remote process to, `DEVNULL` to discard this output, or `STDOUT` to feed standard error to the same place as stdout. :param bufsize: (optional) Buffer size to use when feeding data from a file to stdin :param send_eof: Whether or not to send EOF to the channel when EOF is received from stdin, defaulting to `True`. If set to `False`, the channel will remain open after EOF is received on stdin, and multiple sources can be redirected to the channel. :param recv_eof: Whether or not to send EOF to stdout and stderr when EOF is received from the channel, defaulting to `True`. If set to `False`, the redirect targets of stdout and stderr will remain open after EOF is received on the channel and can be used for multiple redirects. 
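As a brief interactive sketch (`bc` is just an example of a line-oriented remote command, and `conn` stands for an established connection)::

    process = await conn.create_process('bc')
    process.stdin.write('2 + 2\n')
    answer = await process.stdout.readline()
    process.stdin.write_eof()
    await process.wait()
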
:type input: `str` or `bytes` :type bufsize: `int` :type send_eof: `bool` :type recv_eof: `bool` :returns: :class:`SSHClientProcess` :raises: :exc:`ChannelOpenError` if the channel can't be opened """ chan, process = await self.create_session( SSHClientProcess, *args, **kwargs) # type: ignore new_stdin: Optional[ProcessSource] = stdin process: SSHClientProcess if input: chan.write(input) chan.write_eof() new_stdin = None await process.redirect(new_stdin, stdout, stderr, bufsize, send_eof, recv_eof) return process async def create_subprocess(self, protocol_factory: SubprocessFactory, command: DefTuple[Optional[str]] = (), bufsize: int = io.DEFAULT_BUFFER_SIZE, input: Optional[AnyStr] = None, stdin: ProcessSource = PIPE, stdout: ProcessTarget = PIPE, stderr: ProcessTarget = PIPE, encoding: Optional[str] = None, **kwargs: object) -> \ Tuple[SSHSubprocessTransport, SSHSubprocessProtocol]: """Create a subprocess on the remote system This method is a coroutine wrapper around :meth:`create_session` which can be used to execute a command, start a subsystem, or start an interactive shell, optionally redirecting stdin, stdout, and stderr to and from files or pipes attached to other local and remote processes similar to :meth:`create_process`. However, instead of performing interactive I/O using :class:`SSHReader` and :class:`SSHWriter` objects, the caller provides a function which returns an object which conforms to the :class:`asyncio.SubprocessProtocol` and this call returns that and an :class:`SSHSubprocessTransport` object which conforms to :class:`asyncio.SubprocessTransport`. With the exception of the addition of `protocol_factory`, all of the arguments are the same as :meth:`create_process`. :param protocol_factory: A `callable` which returns an :class:`SSHSubprocessProtocol` object that will be created to handle activity on this session. :type protocol_factory: `callable` :returns: an :class:`SSHSubprocessTransport` and :class:`SSHSubprocessProtocol` :raises: :exc:`ChannelOpenError` if the channel can't be opened """ def transport_factory() -> SSHSubprocessTransport: """Return a subprocess transport""" return SSHSubprocessTransport(protocol_factory) _, transport = await self.create_session(transport_factory, command, encoding=encoding, **kwargs) # type: ignore new_stdin: Optional[ProcessSource] = stdin transport: SSHSubprocessTransport if input: stdin_pipe = cast(SSHSubprocessWritePipe, transport.get_pipe_transport(0)) stdin_pipe.write(input) stdin_pipe.write_eof() new_stdin = None await transport.redirect(new_stdin, stdout, stderr, bufsize) return transport, transport.get_protocol() # pylint: enable=redefined-builtin async def run(self, *args: object, check: bool = False, timeout: Optional[float] = None, **kwargs: object) -> SSHCompletedProcess: """Run a command on the remote system and collect its output This method is a coroutine wrapper around :meth:`create_process` which can be used to run a process to completion when no interactivity is needed. All of the arguments to :meth:`create_process` can be passed in to provide input or redirect stdin, stdout, and stderr, but this method waits until the process exits and returns an :class:`SSHCompletedProcess` object with the exit status or signal information and the output to stdout and stderr (if not redirected). If the check argument is set to `True`, a non-zero exit status from the remote process will trigger the :exc:`ProcessError` exception to be raised. 
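For example (a sketch; the command and timeout are arbitrary)::

    result = await conn.run('ls -l /tmp', check=True, timeout=10)
    print(result.stdout, end='')
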
In addition to the arguments below, all arguments to :meth:`create_process` are supported and have the same meaning. If a timeout is specified and it expires before the process exits, the :exc:`TimeoutError` exception will be raised. By default, no timeout is set and this call will wait indefinitely. :param check: (optional) Whether or not to raise :exc:`ProcessError` when a non-zero exit status is returned :param timeout: Amount of time in seconds to wait for process to exit or `None` to wait indefinitely :type check: `bool` :type timeout: `int`, `float`, or `None` :returns: :class:`SSHCompletedProcess` :raises: | :exc:`ChannelOpenError` if the session can't be opened | :exc:`ProcessError` if checking non-zero exit status | :exc:`TimeoutError` if the timeout expires before exit """ process = await self.create_process(*args, **kwargs) # type: ignore return await process.wait(check, timeout) async def create_connection( self, session_factory: SSHTCPSessionFactory[AnyStr], remote_host: str, remote_port: int, orig_host: str = '', orig_port: int = 0, *, encoding: Optional[str] = None, errors: str = 'strict', window: int = _DEFAULT_WINDOW, max_pktsize: int = _DEFAULT_MAX_PKTSIZE) -> \ Tuple[SSHTCPChannel[AnyStr], SSHTCPSession[AnyStr]]: """Create an SSH TCP direct connection This method is a coroutine which can be called to request that the server open a new outbound TCP connection to the specified destination host and port. If the connection is successfully opened, a new SSH channel will be opened with data being handled by a :class:`SSHTCPSession` object created by `session_factory`. Optional arguments include the host and port of the original client opening the connection when performing TCP port forwarding. By default, this class expects data to be sent and received as raw bytes. However, an optional encoding argument can be passed in to select the encoding to use, allowing the application to send and receive string data. When encoding is set, an optional errors argument can be passed in to select what Unicode error handling strategy to use. Other optional arguments include the SSH receive window size and max packet size which default to 2 MB and 32 KB, respectively.
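As a sketch (`MyTCPSession` stands for a hypothetical application-supplied :class:`SSHTCPSession` subclass, and the host and port are placeholders)::

    chan, session = await conn.create_connection(MyTCPSession, 'db.internal', 5432)
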
:param session_factory: A `callable` which returns an :class:`SSHTCPSession` object that will be created to handle activity on this session :param remote_host: The remote hostname or address to connect to :param remote_port: The remote port number to connect to :param orig_host: (optional) The hostname or address of the client requesting the connection :param orig_port: (optional) The port number of the client requesting the connection :param encoding: (optional) The Unicode encoding to use for data exchanged on the connection :param errors: (optional) The error handling strategy to apply on encode/decode errors :param window: (optional) The receive window size for this session :param max_pktsize: (optional) The maximum packet size for this session :type session_factory: `callable` :type remote_host: `str` :type remote_port: `int` :type orig_host: `str` :type orig_port: `int` :type encoding: `str` or `None` :type errors: `str` :type window: `int` :type max_pktsize: `int` :returns: an :class:`SSHTCPChannel` and :class:`SSHTCPSession` :raises: :exc:`ChannelOpenError` if the connection can't be opened """ self.logger.info('Opening direct TCP connection to %s', (remote_host, remote_port)) self.logger.info(' Client address: %s', (orig_host, orig_port)) chan = self.create_tcp_channel(encoding, errors, window, max_pktsize) session = await chan.connect(session_factory, remote_host, remote_port, orig_host, orig_port) return chan, session async def open_connection(self, *args: object, **kwargs: object) -> \ Tuple[SSHReader, SSHWriter]: """Open an SSH TCP direct connection This method is a coroutine wrapper around :meth:`create_connection` designed to provide a "high-level" stream interface for creating an SSH TCP direct connection. Instead of taking a `session_factory` argument for constructing an object which will handle activity on the session via callbacks, it returns :class:`SSHReader` and :class:`SSHWriter` objects which can be used to perform I/O on the connection. With the exception of `session_factory`, all of the arguments to :meth:`create_connection` are supported and have the same meaning here. :returns: an :class:`SSHReader` and :class:`SSHWriter` :raises: :exc:`ChannelOpenError` if the connection can't be opened """ chan, session = await self.create_connection( SSHTCPStreamSession, *args, **kwargs) # type: ignore session: SSHTCPStreamSession return SSHReader(session, chan), SSHWriter(session, chan) @async_context_manager async def create_server( self, session_factory: TCPListenerFactory[AnyStr], listen_host: str, listen_port: int, *, encoding: Optional[str] = None, errors: str = 'strict', window: int = _DEFAULT_WINDOW, max_pktsize: int = _DEFAULT_MAX_PKTSIZE) -> SSHListener: """Create a remote SSH TCP listener This method is a coroutine which can be called to request that the server listen on the specified remote address and port for incoming TCP connections. If the request is successful, the return value is an :class:`SSHListener` object which can be used later to shut down the listener. If the request fails, `None` is returned. 
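An illustrative sketch (the listen port is arbitrary, and `MyTCPSession` stands in for a hypothetical :class:`SSHTCPSession` subclass)::

    def session_factory(orig_host, orig_port):
        return MyTCPSession()

    listener = await conn.create_server(session_factory, '', 8888)
    await listener.wait_closed()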
:param session_factory: A `callable` or coroutine which takes arguments of the original host and port of the client and decides whether to accept the connection or not, either returning an :class:`SSHTCPSession` object used to handle activity on that connection or raising :exc:`ChannelOpenError` to indicate that the connection should not be accepted :param listen_host: The hostname or address on the remote host to listen on :param listen_port: The port number on the remote host to listen on :param encoding: (optional) The Unicode encoding to use for data exchanged on the connection :param errors: (optional) The error handling strategy to apply on encode/decode errors :param window: (optional) The receive window size for this session :param max_pktsize: (optional) The maximum packet size for this session :type session_factory: `callable` or coroutine :type listen_host: `str` :type listen_port: `int` :type encoding: `str` or `None` :type errors: `str` :type window: `int` :type max_pktsize: `int` :returns: :class:`SSHListener` :raises: :class:`ChannelListenError` if the listener can't be opened """ listen_host = listen_host.lower() self.logger.info('Creating remote TCP listener on %s', (listen_host, listen_port)) pkttype, packet = await self._make_global_request( b'tcpip-forward', String(listen_host), UInt32(listen_port)) if pkttype == MSG_REQUEST_SUCCESS: if listen_port == 0: listen_port = packet.get_uint32() dynamic = True else: # OpenSSH 6.8 introduced a bug which causes the reply # to contain an extra uint32 value of 0 when non-dynamic # ports are requested, causing the check_end() call below # to fail. This check works around this problem. if len(packet.get_remaining_payload()) == 4: # pragma: no cover packet.get_uint32() dynamic = False packet.check_end() listener = SSHTCPClientListener[AnyStr](self, session_factory, listen_host, listen_port, encoding, errors, window, max_pktsize) if dynamic: self.logger.debug1('Assigning dynamic port %d', listen_port) self._dynamic_remote_listeners[listen_host] = listener self._remote_listeners[listen_host, listen_port] = listener return listener else: packet.check_end() self.logger.debug1('Failed to create remote TCP listener') raise ChannelListenError('Failed to create remote TCP listener') @async_context_manager async def start_server(self, handler_factory: _TCPServerHandlerFactory, *args: object, **kwargs: object) -> SSHListener: """Start a remote SSH TCP listener This method is a coroutine wrapper around :meth:`create_server` designed to provide a "high-level" stream interface for creating remote SSH TCP listeners. Instead of taking a `session_factory` argument for constructing an object which will handle activity on the session via callbacks, it takes a `handler_factory` which returns a `callable` or coroutine that will be passed :class:`SSHReader` and :class:`SSHWriter` objects which can be used to perform I/O on each new connection which arrives. Like :meth:`create_server`, `handler_factory` can also raise :exc:`ChannelOpenError` if the connection should not be accepted. With the exception of `handler_factory` replacing `session_factory`, all of the arguments to :meth:`create_server` are supported and have the same meaning here. 
:param handler_factory: A `callable` or coroutine which takes arguments of the original host and port of the client and decides whether to accept the connection or not, either returning a callback or coroutine used to handle activity on that connection or raising :exc:`ChannelOpenError` to indicate that the connection should not be accepted :type handler_factory: `callable` or coroutine :returns: :class:`SSHListener` :raises: :class:`ChannelListenError` if the listener can't be opened """ def session_factory(orig_host: str, orig_port: int) -> SSHTCPSession: """Return a TCP stream session handler""" return SSHTCPStreamSession(handler_factory(orig_host, orig_port)) return await self.create_server(session_factory, *args, **kwargs) # type: ignore async def create_unix_connection( self, session_factory: SSHUNIXSessionFactory[AnyStr], remote_path: str, *, encoding: Optional[str] = None, errors: str = 'strict', window: int = _DEFAULT_WINDOW, max_pktsize: int = _DEFAULT_MAX_PKTSIZE) -> \ Tuple[SSHUNIXChannel[AnyStr], SSHUNIXSession[AnyStr]]: """Create an SSH UNIX domain socket direct connection This method is a coroutine which can be called to request that the server open a new outbound UNIX domain socket connection to the specified destination path. If the connection is successfully opened, a new SSH channel will be opened with data being handled by a :class:`SSHUNIXSession` object created by `session_factory`. By default, this class expects data to be sent and received as raw bytes. However, an optional encoding argument can be passed in to select the encoding to use, allowing the application to send and receive string data. When encoding is set, an optional errors argument can be passed in to select what Unicode error handling strategy to use. Other optional arguments include the SSH receive window size and max packet size which default to 2 MB and 32 KB, respectively. :param session_factory: A `callable` which returns an :class:`SSHUNIXSession` object that will be created to handle activity on this session :param remote_path: The remote path to connect to :param encoding: (optional) The Unicode encoding to use for data exchanged on the connection :param errors: (optional) The error handling strategy to apply on encode/decode errors :param window: (optional) The receive window size for this session :param max_pktsize: (optional) The maximum packet size for this session :type session_factory: `callable` :type remote_path: `str` :type encoding: `str` or `None` :type errors: `str` :type window: `int` :type max_pktsize: `int` :returns: an :class:`SSHUNIXChannel` and :class:`SSHUNIXSession` :raises: :exc:`ChannelOpenError` if the connection can't be opened """ self.logger.info('Opening direct UNIX connection to %s', remote_path) chan = self.create_unix_channel(encoding, errors, window, max_pktsize) session = await chan.connect(session_factory, remote_path) return chan, session async def open_unix_connection(self, *args: object, **kwargs: object) -> \ Tuple[SSHReader, SSHWriter]: """Open an SSH UNIX domain socket direct connection This method is a coroutine wrapper around :meth:`create_unix_connection` designed to provide a "high-level" stream interface for creating an SSH UNIX domain socket direct connection. Instead of taking a `session_factory` argument for constructing an object which will handle activity on the session via callbacks, it returns :class:`SSHReader` and :class:`SSHWriter` objects which can be used to perform I/O on the connection. 
With the exception of `session_factory`, all of the arguments to :meth:`create_unix_connection` are supported and have the same meaning here. :returns: an :class:`SSHReader` and :class:`SSHWriter` :raises: :exc:`ChannelOpenError` if the connection can't be opened """ chan, session = \ await self.create_unix_connection(SSHUNIXStreamSession, *args, **kwargs) # type: ignore session: SSHUNIXStreamSession return SSHReader(session, chan), SSHWriter(session, chan) @async_context_manager async def create_unix_server( self, session_factory: UNIXListenerFactory[AnyStr], listen_path: str, *, encoding: Optional[str] = None, errors: str = 'strict', window: int = _DEFAULT_WINDOW, max_pktsize: int = _DEFAULT_MAX_PKTSIZE) -> SSHListener: """Create a remote SSH UNIX domain socket listener This method is a coroutine which can be called to request that the server listen on the specified remote path for incoming UNIX domain socket connections. If the request is successful, the return value is an :class:`SSHListener` object which can be used later to shut down the listener. If the request fails, `None` is returned. :param session_factory: A `callable` or coroutine which decides whether to accept the connection or not, either returning an :class:`SSHUNIXSession` object used to handle activity on that connection or raising :exc:`ChannelOpenError` to indicate that the connection should not be accepted :param listen_path: The path on the remote host to listen on :param encoding: (optional) The Unicode encoding to use for data exchanged on the connection :param errors: (optional) The error handling strategy to apply on encode/decode errors :param window: (optional) The receive window size for this session :param max_pktsize: (optional) The maximum packet size for this session :type session_factory: `callable` :type listen_path: `str` :type encoding: `str` or `None` :type errors: `str` :type window: `int` :type max_pktsize: `int` :returns: :class:`SSHListener` :raises: :class:`ChannelListenError` if the listener can't be opened """ self.logger.info('Creating remote UNIX listener on %s', listen_path) pkttype, packet = await self._make_global_request( b'streamlocal-forward@openssh.com', String(listen_path)) packet.check_end() if pkttype == MSG_REQUEST_SUCCESS: listener = SSHUNIXClientListener[AnyStr](self, session_factory, listen_path, encoding, errors, window, max_pktsize) self._remote_listeners[listen_path] = listener return listener else: self.logger.debug1('Failed to create remote UNIX listener') raise ChannelListenError('Failed to create remote UNIX listener') @async_context_manager async def start_unix_server( self, handler_factory: _UNIXServerHandlerFactory, *args: object, **kwargs: object) -> SSHListener: """Start a remote SSH UNIX domain socket listener This method is a coroutine wrapper around :meth:`create_unix_server` designed to provide a "high-level" stream interface for creating remote SSH UNIX domain socket listeners. Instead of taking a `session_factory` argument for constructing an object which will handle activity on the session via callbacks, it takes a `handler_factory` which returns a `callable` or coroutine that will be passed :class:`SSHReader` and :class:`SSHWriter` objects which can be used to perform I/O on each new connection which arrives. Like :meth:`create_unix_server`, `handler_factory` can also raise :exc:`ChannelOpenError` if the connection should not be accepted. 
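An illustrative sketch (the listen path is hypothetical; the handler echoes back whatever it reads)::

    async def handle_connection(reader, writer):
        writer.write(await reader.read())
        writer.close()

    def handler_factory():
        return handle_connection

    listener = await conn.start_unix_server(handler_factory, '/tmp/echo.sock')
    await listener.wait_closed()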
With the exception of `handler_factory` replacing `session_factory`, all of the arguments to :meth:`create_unix_server` are supported and have the same meaning here. :param handler_factory: A `callable` or coroutine which decides whether to accept the UNIX domain socket connection or not, either returning a callback or coroutine used to handle activity on that connection or raising :exc:`ChannelOpenError` to indicate that the connection should not be accepted :type handler_factory: `callable` or coroutine :returns: :class:`SSHListener` :raises: :class:`ChannelListenError` if the listener can't be opened """ def session_factory() -> SSHUNIXStreamSession: """Return a UNIX domain socket stream session handler""" return SSHUNIXStreamSession(handler_factory()) return await self.create_unix_server(session_factory, *args, **kwargs) # type: ignore async def create_ssh_connection(self, client_factory: _ClientFactory, host: str, port: DefTuple[int] = (), **kwargs: object) -> \ Tuple['SSHClientConnection', SSHClient]: """Create a tunneled SSH client connection This method is a coroutine which can be called to open an SSH client connection to the requested host and port tunneled inside this already established connection. It takes all the same arguments as :func:`create_connection` but requests that the upstream SSH server open the connection rather than connecting directly. """ return (await create_connection(client_factory, host, port, tunnel=self, **kwargs)) # type: ignore @async_context_manager async def connect_ssh(self, host: str, port: DefTuple[int] = (), **kwargs: object) -> 'SSHClientConnection': """Make a tunneled SSH client connection This method is a coroutine which can be called to open an SSH client connection to the requested host and port tunneled inside this already established connection. It takes all the same arguments as :func:`connect` but requests that the upstream SSH server open the connection rather than connecting directly. """ return await connect(host, port, tunnel=self, **kwargs) # type: ignore @async_context_manager async def connect_reverse_ssh(self, host: str, port: DefTuple[int] = (), **kwargs: object) -> 'SSHServerConnection': """Make a tunneled reverse direction SSH connection This method is a coroutine which can be called to open an SSH client connection to the requested host and port tunneled inside this already established connection. It takes all the same arguments as :func:`connect` but requests that the upstream SSH server open the connection rather than connecting directly. """ return await connect_reverse(host, port, tunnel=self, **kwargs) # type: ignore @async_context_manager async def listen_ssh(self, host: str = '', port: DefTuple[int] = (), **kwargs: object) -> SSHAcceptor: """Create a tunneled SSH listener This method is a coroutine which can be called to open a remote SSH listener on the requested host and port tunneled inside this already established connection. It takes all the same arguments as :func:`listen` but requests that the upstream SSH server open the listener rather than listening directly via TCP/IP. """ return await listen(host, port, tunnel=self, **kwargs) # type: ignore @async_context_manager async def listen_reverse_ssh(self, host: str = '', port: DefTuple[int] = (), **kwargs: object) -> SSHAcceptor: """Create a tunneled reverse direction SSH listener This method is a coroutine which can be called to open a remote SSH listener on the requested host and port tunneled inside this already established connection. 
It takes all the same arguments as :func:`listen_reverse` but requests that the upstream SSH server open the listener rather than listening directly via TCP/IP. """ return await listen_reverse(host, port, tunnel=self, **kwargs) # type: ignore async def create_tun( self, session_factory: SSHTunTapSessionFactory, remote_unit: Optional[int] = None, *, window: int = _DEFAULT_WINDOW, max_pktsize: int = _DEFAULT_MAX_PKTSIZE) -> \ Tuple[SSHTunTapChannel, SSHTunTapSession]: """Create an SSH layer 3 tunnel This method is a coroutine which can be called to request that the server open a new outbound layer 3 tunnel to the specified remote TUN device. If the tunnel is successfully opened, a new SSH channel will be opened with data being handled by a :class:`SSHTunTapSession` object created by `session_factory`. Optional arguments include the SSH receive window size and max packet size which default to 2 MB and 32 KB, respectively. :param session_factory: A `callable` which returns an :class:`SSHTunTapSession` object that will be created to handle activity on this session :param remote_unit: The remote TUN device to connect to :param window: (optional) The receive window size for this session :param max_pktsize: (optional) The maximum packet size for this session :type session_factory: `callable` :type remote_unit: `int` or `None` :type window: `int` :type max_pktsize: `int` :returns: an :class:`SSHTunTapChannel` and :class:`SSHTunTapSession` :raises: :exc:`ChannelOpenError` if the connection can't be opened """ self.logger.info('Opening layer 3 tunnel to remote unit %s', 'any' if remote_unit is None else str(remote_unit)) chan = self.create_tuntap_channel(window, max_pktsize) session = await chan.open(session_factory, SSH_TUN_MODE_POINTTOPOINT, remote_unit) return chan, session async def create_tap( self, session_factory: SSHTunTapSessionFactory, remote_unit: Optional[int] = None, *, window: int = _DEFAULT_WINDOW, max_pktsize: int = _DEFAULT_MAX_PKTSIZE) -> \ Tuple[SSHTunTapChannel, SSHTunTapSession]: """Create an SSH layer 2 tunnel This method is a coroutine which can be called to request that the server open a new outbound layer 2 tunnel to the specified remote TAP device. If the tunnel is successfully opened, a new SSH channel will be opened with data being handled by a :class:`SSHTunTapSession` object created by `session_factory`. Optional arguments include the SSH receive window size and max packet size which default to 2 MB and 32 KB, respectively.
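An illustrative sketch (this assumes the server permits tunnel forwarding; the session class shown here is hypothetical)::

    class MyTapSession(asyncssh.SSHTunTapSession):
        def connection_made(self, chan):
            self._chan = chan

        def data_received(self, data, datatype):
            # Each call delivers one Ethernet frame as bytes
            print('Received frame of', len(data), 'bytes')

    chan, session = await conn.create_tap(MyTapSession)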
:param session_factory: A `callable` which returns an :class:`SSHTunTapSession` object that will be created to handle activity on this session :param remote_unit: The remote TAP device to connect to :param window: (optional) The receive window size for this session :param max_pktsize: (optional) The maximum packet size for this session :type session_factory: `callable` :type remote_unit: `int` or `None` :type window: `int` :type max_pktsize: `int` :returns: an :class:`SSHTunTapChannel` and :class:`SSHTunTapSession` :raises: :exc:`ChannelOpenError` if the connection can't be opened """ self.logger.info('Opening layer 2 tunnel to remote unit %s', 'any' if remote_unit is None else str(remote_unit)) chan = self.create_tuntap_channel(window, max_pktsize) session = await chan.open(session_factory, SSH_TUN_MODE_ETHERNET, remote_unit) return chan, session async def open_tun(self, *args: object, **kwargs: object) -> \ Tuple[SSHReader, SSHWriter]: """Open an SSH layer 3 tunnel This method is a coroutine wrapper around :meth:`create_tun` designed to provide a "high-level" stream interface for creating an SSH layer 3 tunnel. Instead of taking a `session_factory` argument for constructing an object which will handle activity on the session via callbacks, it returns :class:`SSHReader` and :class:`SSHWriter` objects which can be used to perform I/O on the tunnel. With the exception of `session_factory`, all of the arguments to :meth:`create_tun` are supported and have the same meaning here. :returns: an :class:`SSHReader` and :class:`SSHWriter` :raises: :exc:`ChannelOpenError` if the connection can't be opened """ chan, session = await self.create_tun(SSHTunTapStreamSession, *args, **kwargs) # type: ignore session: SSHTunTapStreamSession return SSHReader(session, chan), SSHWriter(session, chan) async def open_tap(self, *args: object, **kwargs: object) -> \ Tuple[SSHReader, SSHWriter]: """Open an SSH layer 2 tunnel This method is a coroutine wrapper around :meth:`create_tap` designed to provide a "high-level" stream interface for creating an SSH layer 2 tunnel. Instead of taking a `session_factory` argument for constructing an object which will handle activity on the session via callbacks, it returns :class:`SSHReader` and :class:`SSHWriter` objects which can be used to perform I/O on the tunnel. With the exception of `session_factory`, all of the arguments to :meth:`create_tap` are supported and have the same meaning here. :returns: an :class:`SSHReader` and :class:`SSHWriter` :raises: :exc:`ChannelOpenError` if the connection can't be opened """ chan, session = await self.create_tap(SSHTunTapStreamSession, *args, **kwargs) # type: ignore session: SSHTunTapStreamSession return SSHReader(session, chan), SSHWriter(session, chan) @async_context_manager async def forward_local_port_to_path( self, listen_host: str, listen_port: int, dest_path: str, accept_handler: Optional[SSHAcceptHandler] = None) -> SSHListener: """Set up local TCP port forwarding to a remote UNIX domain socket This method is a coroutine which attempts to set up port forwarding from a local TCP listening port to a remote UNIX domain path via the SSH connection. If the request is successful, the return value is an :class:`SSHListener` object which can be used later to shut down the port forwarding.
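An illustrative sketch (the listen port and destination socket path are hypothetical)::

    listener = await conn.forward_local_port_to_path(
        'localhost', 8080, '/var/run/app.sock')
    await listener.wait_closed()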
:param listen_host: The hostname or address on the local host to listen on :param listen_port: The port number on the local host to listen on :param dest_path: The path on the remote host to forward the connections to :param accept_handler: A `callable` or coroutine which takes arguments of the original host and port of the client and decides whether or not to allow connection forwarding, returning `True` to accept the connection and begin forwarding or `False` to reject and close it. :type listen_host: `str` :type listen_port: `int` :type dest_path: `str` :type accept_handler: `callable` or coroutine :returns: :class:`SSHListener` :raises: :exc:`OSError` if the listener can't be opened """ async def tunnel_connection( session_factory: SSHUNIXSessionFactory[bytes], orig_host: str, orig_port: int) -> \ Tuple[SSHUNIXChannel[bytes], SSHUNIXSession[bytes]]: """Forward a local connection over SSH""" if accept_handler: result = accept_handler(orig_host, orig_port) if inspect.isawaitable(result): result = await cast(Awaitable[bool], result) if not result: self.logger.info('Request for TCP forwarding from ' '%s to %s denied by application', (orig_host, orig_port), dest_path) raise ChannelOpenError(OPEN_ADMINISTRATIVELY_PROHIBITED, 'Connection forwarding denied') return (await self.create_unix_connection(session_factory, dest_path)) self.logger.info('Creating local TCP forwarder from %s to %s', (listen_host, listen_port), dest_path) try: listener = await create_tcp_forward_listener(self, self._loop, tunnel_connection, listen_host, listen_port) except OSError as exc: self.logger.debug1('Failed to create local TCP listener: %s', exc) raise if listen_port == 0: listen_port = listener.get_port() self._local_listeners[listen_host, listen_port] = listener return listener @async_context_manager async def forward_local_path_to_port(self, listen_path: str, dest_host: str, dest_port: int) -> SSHListener: """Set up local UNIX domain socket forwarding to a remote TCP port This method is a coroutine which attempts to set up UNIX domain socket forwarding from a local listening path to a remote host and port via the SSH connection. If the request is successful, the return value is an :class:`SSHListener` object which can be used later to shut down the UNIX domain socket forwarding. 
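An illustrative sketch (the local socket path and destination host and port are hypothetical)::

    listener = await conn.forward_local_path_to_port(
        '/tmp/web.sock', 'localhost', 80)
    await listener.wait_closed()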
:param listen_path: The path on the local host to listen on :param dest_host: The hostname or address to forward the connections to :param dest_port: The port number to forward the connections to :type listen_path: `str` :type dest_host: `str` :type dest_port: `int` :returns: :class:`SSHListener` :raises: :exc:`OSError` if the listener can't be opened """ async def tunnel_connection( session_factory: SSHTCPSessionFactory[bytes]) -> \ Tuple[SSHTCPChannel[bytes], SSHTCPSession[bytes]]: """Forward a local connection over SSH""" return await self.create_connection(session_factory, dest_host, dest_port, '', 0) self.logger.info('Creating local UNIX forwarder from %s to %s', listen_path, (dest_host, dest_port)) try: listener = await create_unix_forward_listener(self, self._loop, tunnel_connection, listen_path) except OSError as exc: self.logger.debug1('Failed to create local UNIX listener: %s', exc) raise self._local_listeners[listen_path] = listener return listener @async_context_manager async def forward_remote_port(self, listen_host: str, listen_port: int, dest_host: str, dest_port: int) -> SSHListener: """Set up remote port forwarding This method is a coroutine which attempts to set up port forwarding from a remote listening port to a local host and port via the SSH connection. If the request is successful, the return value is an :class:`SSHListener` object which can be used later to shut down the port forwarding. If the request fails, `None` is returned. :param listen_host: The hostname or address on the remote host to listen on :param listen_port: The port number on the remote host to listen on :param dest_host: The hostname or address to forward connections to :param dest_port: The port number to forward connections to :type listen_host: `str` :type listen_port: `int` :type dest_host: `str` :type dest_port: `int` :returns: :class:`SSHListener` :raises: :class:`ChannelListenError` if the listener can't be opened """ def session_factory(_orig_host: str, _orig_port: int) -> Awaitable[SSHTCPSession]: """Return an SSHTCPSession used to do remote port forwarding""" return cast(Awaitable[SSHTCPSession], self.forward_connection(dest_host, dest_port)) self.logger.info('Creating remote TCP forwarder from %s to %s', (listen_host, listen_port), (dest_host, dest_port)) return await self.create_server(session_factory, listen_host, listen_port) @async_context_manager async def forward_remote_path(self, listen_path: str, dest_path: str) -> SSHListener: """Set up remote UNIX domain socket forwarding This method is a coroutine which attempts to set up UNIX domain socket forwarding from a remote listening path to a local path via the SSH connection. If the request is successful, the return value is an :class:`SSHListener` object which can be used later to shut down the port forwarding. If the request fails, `None` is returned. 
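An illustrative sketch (both socket paths are hypothetical)::

    listener = await conn.forward_remote_path(
        '/remote/run/app.sock', '/local/run/app.sock')
    await listener.wait_closed()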
:param listen_path: The path on the remote host to listen on :param dest_path: The path on the local host to forward connections to :type listen_path: `str` :type dest_path: `str` :returns: :class:`SSHListener` :raises: :class:`ChannelListenError` if the listener can't be opened """ def session_factory() -> Awaitable[SSHUNIXSession[bytes]]: """Return an SSHUNIXSession used to do remote path forwarding""" return cast(Awaitable[SSHUNIXSession[bytes]], self.forward_unix_connection(dest_path)) self.logger.info('Creating remote UNIX forwarder from %s to %s', listen_path, dest_path) return await self.create_unix_server(session_factory, listen_path) @async_context_manager async def forward_remote_port_to_path(self, listen_host: str, listen_port: int, dest_path: str) -> SSHListener: """Set up remote TCP port forwarding to a local UNIX domain socket This method is a coroutine which attempts to set up port forwarding from a remote TCP listening port to a local UNIX domain socket path via the SSH connection. If the request is successful, the return value is an :class:`SSHListener` object which can be used later to shut down the port forwarding. If the request fails, `None` is returned. :param listen_host: The hostname or address on the remote host to listen on :param listen_port: The port number on the remote host to listen on :param dest_path: The path on the local host to forward connections to :type listen_host: `str` :type listen_port: `int` :type dest_path: `str` :returns: :class:`SSHListener` :raises: :class:`ChannelListenError` if the listener can't be opened """ def session_factory(_orig_host: str, _orig_port: int) -> Awaitable[SSHUNIXSession]: """Return an SSHTCPSession used to do remote port forwarding""" return cast(Awaitable[SSHUNIXSession], self.forward_unix_connection(dest_path)) self.logger.info('Creating remote TCP forwarder from %s to %s', (listen_host, listen_port), dest_path) return await self.create_server(session_factory, listen_host, listen_port) @async_context_manager async def forward_remote_path_to_port(self, listen_path: str, dest_host: str, dest_port: int) -> SSHListener: """Set up remote UNIX domain socket forwarding to a local TCP port This method is a coroutine which attempts to set up UNIX domain socket forwarding from a remote listening path to a local TCP host and port via the SSH connection. If the request is successful, the return value is an :class:`SSHListener` object which can be used later to shut down the port forwarding. If the request fails, `None` is returned. 
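An illustrative sketch (the remote socket path and the local destination are hypothetical)::

    listener = await conn.forward_remote_path_to_port(
        '/remote/run/web.sock', 'localhost', 8080)
    await listener.wait_closed()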
:param listen_path: The path on the remote host to listen on :param dest_host: The hostname or address to forward connections to :param dest_port: The port number to forward connections to :type listen_path: `str` :type dest_host: `str` :type dest_port: `int` :returns: :class:`SSHListener` :raises: :class:`ChannelListenError` if the listener can't be opened """ def session_factory() -> Awaitable[SSHTCPSession[bytes]]: """Return an SSHUNIXSession used to do remote path forwarding""" return cast(Awaitable[SSHTCPSession[bytes]], self.forward_connection(dest_host, dest_port)) self.logger.info('Creating remote UNIX forwarder from %s to %s', listen_path, (dest_host, dest_port)) return await self.create_unix_server(session_factory, listen_path) @async_context_manager async def forward_socks(self, listen_host: str, listen_port: int) -> SSHListener: """Set up local port forwarding via SOCKS This method is a coroutine which attempts to set up dynamic port forwarding via SOCKS on the specified local host and port. Each SOCKS request contains the destination host and port to connect to and triggers a request to tunnel traffic to the requested host and port via the SSH connection. If the request is successful, the return value is an :class:`SSHListener` object which can be used later to shut down the port forwarding. :param listen_host: The hostname or address on the local host to listen on :param listen_port: The port number on the local host to listen on :type listen_host: `str` :type listen_port: `int` :returns: :class:`SSHListener` :raises: :exc:`OSError` if the listener can't be opened """ async def tunnel_socks(session_factory: SSHTCPSessionFactory[bytes], dest_host: str, dest_port: int, orig_host: str, orig_port: int) -> \ Tuple[SSHTCPChannel[bytes], SSHTCPSession[bytes]]: """Forward a local SOCKS connection over SSH""" return await self.create_connection(session_factory, dest_host, dest_port, orig_host, orig_port) self.logger.info('Creating local SOCKS forwarder on %s', (listen_host, listen_port)) try: listener = await create_socks_listener(self, self._loop, tunnel_socks, listen_host, listen_port) except OSError as exc: self.logger.debug1('Failed to create local SOCKS listener: %s', exc) raise if listen_port == 0: listen_port = listener.get_port() self._local_listeners[listen_host, listen_port] = listener return listener @async_context_manager async def forward_tun(self, local_unit: Optional[int] = None, remote_unit: Optional[int] = None) -> SSHForwarder: """Set up layer 3 forwarding This method is a coroutine which attempts to set up layer 3 packet forwarding between local and remote TUN devices. If the request is successful, the return value is an :class:`SSHForwarder` object which can be used later to shut down the forwarding. 
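An illustrative sketch (the unit numbers are hypothetical, opening the local TUN device generally requires elevated privileges, and the server must permit tunnel forwarding)::

    async with conn.forward_tun(local_unit=0, remote_unit=0):
        await asyncio.sleep(60)    # keep the layer 3 tunnel up for a while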
:param local_unit: The unit number of the local TUN device to use :param remote_unit: The unit number of the remote TUN device to use :type local_unit: `int` or `None` :type remote_unit: `int` or `None` :returns: :class:`SSHForwarder` :raises: | :exc:`OSError` if the local TUN device can't be opened | :exc:`ChannelOpenError` if the SSH channel can't be opened """ def session_factory() -> SSHTunTapSession: """Return an SSHTunTapSession used to do layer 3 forwarding""" return cast(SSHTunTapSession, self.forward_tuntap(SSH_TUN_MODE_POINTTOPOINT, local_unit)) _, peer = await self.create_tun(session_factory, remote_unit) return cast(SSHForwarder, peer) @async_context_manager async def forward_tap(self, local_unit: Optional[int] = None, remote_unit: Optional[int] = None) -> SSHForwarder: """Set up layer 2 forwarding This method is a coroutine which attempts to set up layer 2 packet forwarding between local and remote TAP devices. If the request is successful, the return value is an :class:`SSHForwarder` object which can be used later to shut down the forwarding. :param local_unit: The unit number of the local TAP device to use :param remote_unit: The unit number of the remote TAP device to use :type local_unit: `int` or `None` :type remote_unit: `int` or `None` :returns: :class:`SSHForwarder` :raises: | :exc:`OSError` if the local TAP device can't be opened | :exc:`ChannelOpenError` if the SSH channel can't be opened """ def session_factory() -> SSHTunTapSession: """Return an SSHTunTapSession used to do layer 2 forwarding""" return cast(SSHTunTapSession, self.forward_tuntap(SSH_TUN_MODE_ETHERNET, local_unit)) _, peer = await self.create_tap(session_factory, remote_unit) return cast(SSHForwarder, peer) @async_context_manager async def start_sftp_client(self, env: DefTuple[Env] = (), send_env: DefTuple[Optional[EnvSeq]] = (), path_encoding: Optional[str] = 'utf-8', path_errors = 'strict', sftp_version = MIN_SFTP_VERSION) -> SFTPClient: """Start an SFTP client This method is a coroutine which attempts to start a secure file transfer session. If it succeeds, it returns an :class:`SFTPClient` object which can be used to copy and access files on the remote host. An optional Unicode encoding can be specified for sending and receiving pathnames, defaulting to UTF-8 with strict error checking. If an encoding of `None` is specified, pathnames will be left as bytes rather than being converted to & from strings. :param env: (optional) The environment variables to set for this SFTP session. Keys and values passed in here will be converted to Unicode strings encoded as UTF-8 (ISO 10646) for transmission. .. note:: Many SSH servers restrict which environment variables a client is allowed to set. The server's configuration may need to be edited before environment variables can be successfully set in the remote environment. :param send_env: (optional) A list of environment variable names to pull from `os.environ` and set for this SFTP session. Wildcard patterns using `'*'` and `'?'` are allowed, and all variables with matching names will be sent with whatever value is set in the local environment. If a variable is present in both env and send_env, the value from env will be used. :param path_encoding: The Unicode encoding to apply when sending and receiving remote pathnames :param path_errors: The error handling strategy to apply on encode/decode errors :param sftp_version: (optional) The maximum version of the SFTP protocol to support, currently either 3 or 4, defaulting to 3.
:type env: `dict` with `str` keys and values :type send_env: `list` of `str` :type path_encoding: `str` or `None` :type path_errors: `str` :type sftp_version: `int` :returns: :class:`SFTPClient` :raises: :exc:`SFTPError` if the session can't be opened """ writer, reader, _ = await self.open_session(subsystem='sftp', env=env, send_env=send_env, encoding=None) return await start_sftp_client(self, self._loop, reader, writer, path_encoding, path_errors, sftp_version) class SSHServerConnection(SSHConnection): """SSH server connection This class represents an SSH server connection. During authentication, :meth:`send_auth_banner` can be called to send an authentication banner to the client. Once authenticated, :class:`SSHServer` objects wishing to create session objects with non-default channel properties can call :meth:`create_server_channel` from their :meth:`session_requested() ` method and return a tuple of the :class:`SSHServerChannel` object returned from that and either an :class:`SSHServerSession` object or a coroutine which returns an :class:`SSHServerSession`. Similarly, :class:`SSHServer` objects wishing to create TCP connection objects with non-default channel properties can call :meth:`create_tcp_channel` from their :meth:`connection_requested() ` method and return a tuple of the :class:`SSHTCPChannel` object returned from that and either an :class:`SSHTCPSession` object or a coroutine which returns an :class:`SSHTCPSession`. :class:`SSHServer` objects wishing to create UNIX domain socket connection objects with non-default channel properties can call :meth:`create_unix_channel` from the :meth:`unix_connection_requested() ` method and return a tuple of the :class:`SSHUNIXChannel` object returned from that and either an :class:`SSHUNIXSession` object or a coroutine which returns an :class:`SSHUNIXSession`. 
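As an illustration of the first pattern above, a sketch of a server returning a custom session channel (`MyServerSession` stands in for a hypothetical :class:`SSHServerSession` subclass)::

    class MyServer(asyncssh.SSHServer):
        def connection_made(self, conn):
            self._conn = conn

        def session_requested(self):
            chan = self._conn.create_server_channel(encoding='utf-8')
            return chan, MyServerSession()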
""" _options: 'SSHServerConnectionOptions' _owner: SSHServer _x11_listener: Optional[SSHX11ServerListener] def __init__(self, loop: asyncio.AbstractEventLoop, options: 'SSHServerConnectionOptions', acceptor: _AcceptHandler = None, error_handler: _ErrorHandler = None, wait: Optional[str] = None): super().__init__(loop, options, acceptor, error_handler, wait, server=True) self._options = options self._server_host_keys = options.server_host_keys self._all_server_host_keys = options.all_server_host_keys self._server_host_key_algs = list(options.server_host_keys.keys()) self._known_client_hosts = options.known_client_hosts self._trust_client_host = options.trust_client_host self._authorized_client_keys = options.authorized_client_keys self._allow_pty = options.allow_pty self._line_editor = options.line_editor self._line_echo = options.line_echo self._line_history = options.line_history self._max_line_length = options.max_line_length self._rdns_lookup = options.rdns_lookup self._x11_forwarding = options.x11_forwarding self._x11_auth_path = options.x11_auth_path self._agent_forwarding = options.agent_forwarding self._process_factory = options.process_factory self._session_factory = options.session_factory self._encoding = options.encoding self._errors = options.errors self._sftp_factory = options.sftp_factory self._sftp_version = options.sftp_version self._allow_scp = options.allow_scp self._window = options.window self._max_pktsize = options.max_pktsize if options.gss_host: try: self._gss = GSSServer(options.gss_host, options.gss_store) self._gss_kex = options.gss_kex self._gss_auth = options.gss_auth self._gss_mic_auth = self._gss_auth except GSSError: pass self._server_host_key: Optional[SSHKeyPair] = None self._key_options: _KeyOrCertOptions = {} self._cert_options: Optional[_KeyOrCertOptions] = None self._kbdint_password_auth = False self._agent_listener: Optional[SSHAgentListener] = None def _cleanup(self, exc: Optional[Exception]) -> None: """Clean up this server connection""" if self._agent_listener: self._agent_listener.close() self._agent_listener = None super()._cleanup(exc) def _connection_made(self) -> None: """Handle the opening of a new connection""" self.logger.info('Accepted SSH client connection') if self._options.proxy_command: proxy_command = ' '.join(shlex.quote(arg) for arg in self._options.proxy_command) self.logger.info(' Proxy command: %s', proxy_command) else: self.logger.info(' Local address: %s', (self._local_addr, self._local_port)) self.logger.info(' Peer address: %s', (self._peer_addr, self._peer_port)) async def reload_config(self) -> None: """Re-evaluate config with updated match options""" if self._rdns_lookup: self._peer_host, _ = await self._loop.getnameinfo( (self._peer_addr, self._peer_port), socket.NI_NUMERICSERV) options = await SSHServerConnectionOptions.construct( options=self._options, reload=True, accept_addr=self._local_addr, accept_port=self._local_port, username=self._username, client_host=self._peer_host, client_addr=self._peer_addr) self._options = options self._host_based_auth = options.host_based_auth self._public_key_auth = options.public_key_auth self._kbdint_auth = options.kbdint_auth self._password_auth = options.password_auth self._authorized_client_keys = options.authorized_client_keys self._allow_pty = options.allow_pty self._x11_forwarding = options.x11_forwarding self._agent_forwarding = options.agent_forwarding self._rekey_bytes = options.rekey_bytes self._rekey_seconds = options.rekey_seconds self._keepalive_count_max = 
options.keepalive_count_max self._keepalive_interval = options.keepalive_interval def choose_server_host_key(self, peer_host_key_algs: Sequence[bytes]) -> bool: """Choose the server host key to use Given a list of host key algorithms supported by the client, select the first compatible server host key we have and return whether or not we were able to find a match. """ for alg in peer_host_key_algs: keypair = self._server_host_keys.get(alg) if keypair: if alg != keypair.algorithm: keypair.set_sig_algorithm(alg) self._server_host_key = keypair return True return False def get_server_host_key(self) -> Optional[SSHKeyPair]: """Return the chosen server host key This method returns a keypair object containing the chosen server host key and a corresponding public key or certificate. """ return self._server_host_key def send_server_host_keys(self) -> None: """Send list of available server host keys""" if self._all_server_host_keys: self.logger.info('Sending server host keys') keys = [String(key) for key in self._all_server_host_keys.keys()] self._send_global_request(b'hostkeys-00@openssh.com', *keys) else: self.logger.info('Sending server host keys disabled') def gss_kex_auth_supported(self) -> bool: """Return whether GSS key exchange authentication is supported""" if self._gss_kex_auth: assert self._gss is not None return self._gss.complete else: return False def gss_mic_auth_supported(self) -> bool: """Return whether GSS MIC authentication is supported""" return self._gss_mic_auth async def validate_gss_principal(self, username: str, user_principal: str, host_principal: str) -> bool: """Validate the GSS principal name for the specified user Return whether the user principal acquired during GSS authentication is valid for the specified user. """ result = self._owner.validate_gss_principal(username, user_principal, host_principal) if inspect.isawaitable(result): result = await cast(Awaitable[bool], result) return cast(bool, result) def host_based_auth_supported(self) -> bool: """Return whether or not host based authentication is supported""" return (self._host_based_auth and (bool(self._known_client_hosts) or self._owner.host_based_auth_supported())) async def validate_host_based_auth(self, username: str, key_data: bytes, client_host: str, client_username: str, msg: bytes, signature: bytes) -> bool: """Validate host based authentication for the specified host and user""" # Remove a trailing '.' 
from the client host if present if client_host[-1:] == '.': client_host = client_host[:-1] if self._trust_client_host: resolved_host = client_host else: peername = cast(SockAddr, self.get_extra_info('peername')) try: resolved_host, _ = await self._loop.getnameinfo( peername, socket.NI_NUMERICSERV) except socket.gaierror: resolved_host = peername[0] if resolved_host != client_host: self.logger.info('Client host mismatch: received %s, ' 'resolved %s', client_host, resolved_host) if self._known_client_hosts: self._match_known_hosts(self._known_client_hosts, resolved_host, self._peer_addr, None) try: key = self._validate_host_key(resolved_host, self._peer_addr, self._peer_port, key_data) except ValueError as exc: self.logger.debug1('Invalid host key: %s', exc) return False if not key.verify(String(self._session_id) + msg, signature): self.logger.debug1('Invalid host-based auth signature') return False result = self._owner.validate_host_based_user(username, client_host, client_username) if inspect.isawaitable(result): result = await cast(Awaitable[bool], result) return cast(bool, result) async def _validate_openssh_certificate( self, username: str, cert: SSHOpenSSHCertificate) -> \ Optional[SSHKey]: """Validate an OpenSSH client certificate for the specified user""" options: Optional[_KeyOrCertOptions] = None if self._authorized_client_keys: options = self._authorized_client_keys.validate( cert.signing_key, self._peer_host, self._peer_addr, cert.principals, ca=True) if options is None: result = self._owner.validate_ca_key(username, cert.signing_key) if inspect.isawaitable(result): result = await cast(Awaitable[bool], result) if not result: return None options = {} self._key_options = options cert_user = None if self.get_key_option('principals') else username try: cert.validate(CERT_TYPE_USER, cert_user) except ValueError: return None allowed_addresses = cast(Sequence[IPNetwork], cert.options.get('source-address')) if allowed_addresses: ip = ip_address(self._peer_addr) if not any(ip in network for network in allowed_addresses): return None self._cert_options = cert.options cert.key.set_touch_required( not (self.get_key_option('no-touch-required', False) and self.get_certificate_option('no-touch-required', False))) return cert.key async def _validate_x509_certificate_chain( self, username: str, cert: SSHX509CertificateChain) -> \ Optional[SSHKey]: """Validate an X.509 client certificate for the specified user""" if not self._authorized_client_keys: return None options, trusted_cert = \ self._authorized_client_keys.validate_x509( cert, self._peer_host, self._peer_addr) if options is None: return None self._key_options = options if self.get_key_option('principals'): username = '' assert self._x509_trusted_certs is not None trusted_certs = list(self._x509_trusted_certs) if trusted_cert: trusted_certs += [trusted_cert] try: cert.validate_chain(trusted_certs, self._x509_trusted_cert_paths, set(), self._x509_purposes, user_principal=username) except ValueError: return None return cert.key async def _validate_client_certificate( self, username: str, key_data: bytes) -> Optional[SSHKey]: """Validate a client certificate for the specified user""" try: cert = decode_ssh_certificate(key_data) except KeyImportError: return None if cert.is_x509_chain: return await self._validate_x509_certificate_chain( username, cast(SSHX509CertificateChain, cert)) else: return await self._validate_openssh_certificate( username, cast(SSHOpenSSHCertificate, cert)) async def _validate_client_public_key(self, username: str, 
key_data: bytes) -> Optional[SSHKey]: """Validate a client public key for the specified user""" try: key = decode_ssh_public_key(key_data) except KeyImportError: return None options: Optional[_KeyOrCertOptions] = None if self._authorized_client_keys: options = self._authorized_client_keys.validate( key, self._peer_host, self._peer_addr) if options is None: result = self._owner.validate_public_key(username, key) if inspect.isawaitable(result): result = await cast(Awaitable[bool], result) if not result: return None options = {} self._key_options = options key.set_touch_required( not self.get_key_option('no-touch-required', False)) return key def public_key_auth_supported(self) -> bool: """Return whether or not public key authentication is supported""" return (self._public_key_auth and (bool(self._authorized_client_keys) or self._owner.public_key_auth_supported())) async def validate_public_key(self, username: str, key_data: bytes, msg: bytes, signature: bytes) -> bool: """Validate the public key or certificate for the specified user This method validates that the public key or certificate provided is allowed for the specified user. If msg and signature are provided, the key is used to also validate the message signature. It returns `True` when the key is allowed and the signature (if present) is valid. Otherwise, it returns `False`. """ key = ((await self._validate_client_certificate(username, key_data)) or (await self._validate_client_public_key(username, key_data))) if key is None: return False elif msg: return key.verify(String(self._session_id) + msg, signature) else: return True def password_auth_supported(self) -> bool: """Return whether or not password authentication is supported""" return self._password_auth and self._owner.password_auth_supported() async def validate_password(self, username: str, password: str) -> bool: """Return whether password is valid for this user""" result = self._owner.validate_password(username, password) if inspect.isawaitable(result): result = await cast(Awaitable[bool], result) return cast(bool, result) async def change_password(self, username: str, old_password: str, new_password: str) -> bool: """Handle a password change request for a user""" result = self._owner.change_password(username, old_password, new_password) if inspect.isawaitable(result): result = await cast(Awaitable[bool], result) return cast(bool, result) def kbdint_auth_supported(self) -> bool: """Return whether or not keyboard-interactive authentication is supported""" result = self._kbdint_auth and self._owner.kbdint_auth_supported() if result is True: return True elif (result is NotImplemented and self._owner.password_auth_supported()): self._kbdint_password_auth = True return True else: return False async def get_kbdint_challenge(self, username: str, lang: str, submethods: str) -> KbdIntChallenge: """Return a keyboard-interactive auth challenge""" if self._kbdint_password_auth: challenge: KbdIntChallenge = ('', '', DEFAULT_LANG, (('Password:', False),)) else: result = self._owner.get_kbdint_challenge(username, lang, submethods) if inspect.isawaitable(result): challenge = await cast(Awaitable[KbdIntChallenge], result) else: challenge = cast(KbdIntChallenge, result) return challenge async def validate_kbdint_response(self, username: str, responses: KbdIntResponse) -> \ KbdIntChallenge: """Return whether the keyboard-interactive response is valid for this user""" next_challenge: KbdIntChallenge if self._kbdint_password_auth: if len(responses) != 1: return False try: pw_result = 
self._owner.validate_password( username, responses[0]) if inspect.isawaitable(pw_result): next_challenge = await cast(Awaitable[bool], pw_result) else: next_challenge = cast(bool, pw_result) except PasswordChangeRequired: # Don't support password change requests for now in # keyboard-interactive auth next_challenge = False else: result = self._owner.validate_kbdint_response(username, responses) if inspect.isawaitable(result): next_challenge = await cast(Awaitable[KbdIntChallenge], result) else: next_challenge = cast(KbdIntChallenge, result) return next_challenge def _process_session_open(self, packet: SSHPacket) -> \ Tuple[SSHServerChannel, SSHServerSession]: """Process an incoming session open request""" packet.check_end() chan: SSHServerChannel session: SSHServerSession if self._process_factory or self._session_factory or self._sftp_factory: chan = self.create_server_channel(self._encoding, self._errors, self._window, self._max_pktsize) if self._process_factory: session = SSHServerProcess(self._process_factory, self._sftp_factory, self._sftp_version, self._allow_scp) else: session = SSHServerStreamSession(self._session_factory, self._sftp_factory, self._sftp_version, self._allow_scp) else: result = self._owner.session_requested() if not result: raise ChannelOpenError(OPEN_CONNECT_FAILED, 'Session refused') if isinstance(result, tuple): chan, result = result else: chan = self.create_server_channel(self._encoding, self._errors, self._window, self._max_pktsize) if callable(result): session = SSHServerStreamSession(result) else: session = cast(SSHServerSession, result) return chan, session def _process_direct_tcpip_open(self, packet: SSHPacket) -> \ Tuple[SSHTCPChannel[bytes], SSHTCPSession[bytes]]: """Process an incoming direct TCP/IP open request""" dest_host_bytes = packet.get_string() dest_port = packet.get_uint32() orig_host_bytes = packet.get_string() orig_port = packet.get_uint32() packet.check_end() try: dest_host = dest_host_bytes.decode('utf-8') orig_host = orig_host_bytes.decode('utf-8') except UnicodeDecodeError: raise ProtocolError('Invalid direct TCP/IP channel ' 'open request') from None if not self.check_key_permission('port-forwarding') or \ not self.check_certificate_permission('port-forwarding'): raise ChannelOpenError(OPEN_ADMINISTRATIVELY_PROHIBITED, 'Port forwarding not permitted') permitted_opens = cast(Set[Tuple[str, int]], self.get_key_option('permitopen')) if permitted_opens and \ (dest_host, dest_port) not in permitted_opens and \ (dest_host, None) not in permitted_opens: raise ChannelOpenError(OPEN_ADMINISTRATIVELY_PROHIBITED, 'Port forwarding not permitted to ' f'{dest_host} port {dest_port}') result = self._owner.connection_requested(dest_host, dest_port, orig_host, orig_port) if not result: raise ChannelOpenError(OPEN_CONNECT_FAILED, 'Connection refused') if result is True: result = cast(SSHTCPSession[bytes], self.forward_connection(dest_host, dest_port)) if isinstance(result, tuple): chan, result = result else: chan = self.create_tcp_channel() session: SSHTCPSession[bytes] if callable(result): session = SSHTCPStreamSession[bytes](result) else: session = cast(SSHTCPSession[bytes], result) self.logger.info('Accepted direct TCP connection request to %s', (dest_host, dest_port)) self.logger.info(' Client address: %s', (orig_host, orig_port)) chan.set_inbound_peer_names(dest_host, dest_port, orig_host, orig_port) return chan, session def _process_tcpip_forward_global_request(self, packet: SSHPacket) -> None: """Process an incoming TCP/IP port forwarding request""" 
listen_host_bytes = packet.get_string() listen_port = packet.get_uint32() packet.check_end() try: listen_host = listen_host_bytes.decode('utf-8').lower() except UnicodeDecodeError: raise ProtocolError('Invalid TCP/IP forward request') from None if not self.check_key_permission('port-forwarding') or \ not self.check_certificate_permission('port-forwarding'): self.logger.info('Request for TCP listener on %s denied: port ' 'forwarding not permitted', (listen_host, listen_port)) self._report_global_response(False) return self.create_task(self._finish_port_forward(listen_host, listen_port)) async def _finish_port_forward(self, listen_host: str, listen_port: int) -> None: """Finish processing a TCP/IP port forwarding request""" listener = self._owner.server_requested(listen_host, listen_port) try: if inspect.isawaitable(listener): listener = await cast(Awaitable[_ListenerArg], listener) if listener is True: listener = await self.forward_local_port( listen_host, listen_port, listen_host, listen_port) elif callable(listener): listener = await self.forward_local_port( listen_host, listen_port, listen_host, listen_port, listener) except OSError: self.logger.debug1('Failed to create TCP listener') self._report_global_response(False) return if not listener: self.logger.info('Request for TCP listener on %s denied by ' 'application', (listen_host, listen_port)) self._report_global_response(False) return listener: SSHListener result: Union[bool, bytes] if listen_port == 0: listen_port = listener.get_port() result = UInt32(listen_port) else: result = True self.logger.info('Created TCP listener on %s', (listen_host, listen_port)) self._local_listeners[listen_host, listen_port] = listener self._report_global_response(result) def _process_cancel_tcpip_forward_global_request( self, packet: SSHPacket) -> None: """Process a request to cancel TCP/IP port forwarding""" listen_host_bytes = packet.get_string() listen_port = packet.get_uint32() packet.check_end() try: listen_host = listen_host_bytes.decode('utf-8').lower() except UnicodeDecodeError: raise ProtocolError('Invalid TCP/IP cancel ' 'forward request') from None try: listener = self._local_listeners.pop((listen_host, listen_port)) except KeyError: raise ProtocolError('TCP/IP listener not found') from None self.logger.info('Closed TCP listener on %s', (listen_host, listen_port)) listener.close() self._report_global_response(True) def _process_direct_streamlocal_at_openssh_dot_com_open( self, packet: SSHPacket) -> \ Tuple[SSHUNIXChannel[bytes], SSHUNIXSession[bytes]]: """Process an incoming direct UNIX domain socket open request""" dest_path_bytes = packet.get_string() # OpenSSH appears to have a bug which sends this extra data _ = packet.get_string() # originator _ = packet.get_uint32() # originator_port packet.check_end() try: dest_path = dest_path_bytes.decode('utf-8') except UnicodeDecodeError: raise ProtocolError('Invalid direct UNIX domain channel ' 'open request') from None if not self.check_key_permission('port-forwarding') or \ not self.check_certificate_permission('port-forwarding'): raise ChannelOpenError(OPEN_ADMINISTRATIVELY_PROHIBITED, 'Port forwarding not permitted') result = self._owner.unix_connection_requested(dest_path) if not result: raise ChannelOpenError(OPEN_CONNECT_FAILED, 'Connection refused') if result is True: result = cast(SSHUNIXSession[bytes], self.forward_unix_connection(dest_path)) if isinstance(result, tuple): chan, result = result else: chan = self.create_unix_channel() session: SSHUNIXSession[bytes] if callable(result): 
session = SSHUNIXStreamSession[bytes](result) else: session = cast(SSHUNIXSession[bytes], result) self.logger.info('Accepted direct UNIX connection on %s', dest_path) chan.set_inbound_peer_names(dest_path) return chan, session def _process_streamlocal_forward_at_openssh_dot_com_global_request( self, packet: SSHPacket) -> None: """Process an incoming UNIX domain socket forwarding request""" listen_path_bytes = packet.get_string() packet.check_end() try: listen_path = listen_path_bytes.decode('utf-8') except UnicodeDecodeError: raise ProtocolError('Invalid UNIX domain socket ' 'forward request') from None if not self.check_key_permission('port-forwarding') or \ not self.check_certificate_permission('port-forwarding'): self.logger.info('Request for UNIX listener on %s denied: port ' 'forwarding not permitted', listen_path) self._report_global_response(False) return self.create_task(self._finish_path_forward(listen_path)) async def _finish_path_forward(self, listen_path: str) -> None: """Finish processing a UNIX domain socket forwarding request""" listener = self._owner.unix_server_requested(listen_path) try: if inspect.isawaitable(listener): listener = await cast(Awaitable[_ListenerArg], listener) if listener is True: listener = await self.forward_local_path(listen_path, listen_path) except OSError: self.logger.debug1('Failed to create UNIX listener') self._report_global_response(False) return if not listener: self.logger.info('Request for UNIX listener on %s denied by ' 'application', listen_path) self._report_global_response(False) return self.logger.info('Created UNIX listener on %s', listen_path) self._local_listeners[listen_path] = cast(SSHListener, listener) self._report_global_response(True) def _process_cancel_streamlocal_forward_at_openssh_dot_com_global_request( self, packet: SSHPacket) -> None: """Process a request to cancel UNIX domain socket forwarding""" listen_path_bytes = packet.get_string() packet.check_end() try: listen_path = listen_path_bytes.decode('utf-8') except UnicodeDecodeError: raise ProtocolError('Invalid UNIX domain cancel ' 'forward request') from None try: listener = self._local_listeners.pop(listen_path) except KeyError: raise ProtocolError('UNIX domain listener not found') from None self.logger.info('Closed UNIX listener on %s', listen_path) listener.close() self._report_global_response(True) def _process_tun_at_openssh_dot_com_open( self, packet: SSHPacket) -> \ Tuple[SSHTunTapChannel, SSHTunTapSession]: """Process an incoming TUN/TAP open request""" mode = packet.get_uint32() unit: Optional[int] = packet.get_uint32() packet.check_end() if unit == SSH_TUN_UNIT_ANY: unit = None if mode == SSH_TUN_MODE_POINTTOPOINT: result = self._owner.tun_requested(unit) elif mode == SSH_TUN_MODE_ETHERNET: result = self._owner.tap_requested(unit) else: result = False if not result: raise ChannelOpenError(OPEN_CONNECT_FAILED, 'Connection refused') if result is True: result = cast(SSHTunTapSession, self.forward_tuntap(mode, unit)) if isinstance(result, tuple): chan, result = result else: chan = self.create_tuntap_channel() session: SSHTunTapSession if callable(result): session = SSHTunTapStreamSession(result) else: session = cast(SSHTunTapSession, result) self.logger.info('Accepted layer %d tunnel request to unit %s', 3 if mode == SSH_TUN_MODE_POINTTOPOINT else 2, 'any' if unit == SSH_TUN_UNIT_ANY else str(unit)) chan.set_mode(mode) return chan, session def _process_hostkeys_prove_00_at_openssh_dot_com_global_request( self, packet: SSHPacket) -> None: """Prove the server has 
private keys for all requested host keys""" prefix = String('hostkeys-prove-00@openssh.com') + \ String(self._session_id) signatures = [] while packet: try: key_data = packet.get_string() key = self._all_server_host_keys[key_data] signatures.append(String(key.sign(prefix + String(key_data)))) except (KeyError, KeyImportError): self._report_global_response(False) return self._report_global_response(b''.join(signatures)) async def attach_x11_listener(self, chan: SSHServerChannel[AnyStr], auth_proto: bytes, auth_data: bytes, screen: int) -> Optional[str]: """Attach a channel to a remote X11 display""" if (not self._x11_forwarding or not self.check_key_permission('X11-forwarding') or not self.check_certificate_permission('X11-forwarding')): self.logger.info('X11 forwarding request denied: X11 ' 'forwarding not permitted') return None if not self._x11_listener: self._x11_listener = await create_x11_server_listener( self, self._loop, self._x11_auth_path, auth_proto, auth_data) if self._x11_listener: return self._x11_listener.attach(chan, screen) else: return None def detach_x11_listener(self, chan: SSHChannel[AnyStr]) -> None: """Detach a session from a remote X11 listener""" if self._x11_listener: if self._x11_listener.detach(chan): self._x11_listener = None async def create_agent_listener(self) -> bool: """Create a listener for forwarding ssh-agent connections""" if (not self._agent_forwarding or not self.check_key_permission('agent-forwarding') or not self.check_certificate_permission('agent-forwarding')): self.logger.info('Agent forwarding request denied: Agent ' 'forwarding not permitted') return False if self._agent_listener: return True try: tempdir = tempfile.TemporaryDirectory(prefix='asyncssh-') path = str(Path(tempdir.name, 'agent')) unix_listener = await create_unix_forward_listener( self, self._loop, self.create_agent_connection, path) self._agent_listener = SSHAgentListener(tempdir, path, unix_listener) return True except OSError: return False def get_agent_path(self) -> Optional[str]: """Return the path of the ssh-agent listener, if one exists""" if self._agent_listener: return self._agent_listener.get_path() else: return None def send_auth_banner(self, msg: str, lang: str = DEFAULT_LANG) -> None: """Send an authentication banner to the client This method can be called to send an authentication banner to the client, displaying information while authentication is in progress. It is an error to call this method after the authentication is complete. :param msg: The message to display :param lang: The language the message is in :type msg: `str` :type lang: `str` :raises: :exc:`OSError` if authentication is already completed """ if self._auth_complete: raise OSError('Authentication already completed') self.logger.debug1('Sending authentication banner') self.send_packet(MSG_USERAUTH_BANNER, String(msg), String(lang)) def set_authorized_keys(self, authorized_keys: _AuthKeysArg) -> None: """Set the keys trusted for client public key authentication This method can be called to set the trusted user and CA keys for client public key authentication. It should generally be called from the :meth:`begin_auth ` method of :class:`SSHServer` to set the appropriate keys for the user attempting to authenticate. 
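           For example (a minimal sketch -- the server class and the
           per-user file layout shown here are illustrative, not part of
           this API)::

               class MySSHServer(asyncssh.SSHServer):
                   def connection_made(self, conn):
                       self._conn = conn

                   def begin_auth(self, username):
                       self._conn.set_authorized_keys(
                           'authorized_keys/%s' % username)
                       return True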
:param authorized_keys: The keys to trust for client public key authentication :type authorized_keys: *see* :ref:`SpecifyingAuthorizedKeys` """ if isinstance(authorized_keys, (str, list)): authorized_keys = read_authorized_keys(authorized_keys) self._authorized_client_keys = authorized_keys def get_key_option(self, option: str, default: object = None) -> object: """Return option from authorized_keys If a client key or certificate was presented during authentication, this method returns the value of the requested option in the corresponding authorized_keys entry if it was set. Otherwise, it returns the default value provided. The following standard options are supported: | command (string) | environment (dictionary of name/value pairs) | from (list of host patterns) | no-touch-required (boolean) | permitopen (list of host/port tuples) | principals (list of usernames) Non-standard options are also supported and will return the value `True` if the option is present without a value or return a list of strings containing the values associated with each occurrence of that option name. If the option is not present, the specified default value is returned. :param option: The name of the option to look up. :param default: The default value to return if the option is not present. :type option: `str` :returns: The value of the option in authorized_keys, if set """ return self._key_options.get(option, default) def check_key_permission(self, permission: str) -> bool: """Check permissions in authorized_keys If a client key or certificate was presented during authentication, this method returns whether the specified permission is allowed by the corresponding authorized_keys entry. By default, all permissions are granted, but they can be revoked by specifying an option starting with 'no-' without a value. The following standard options are supported: | X11-forwarding | agent-forwarding | port-forwarding | pty | user-rc AsyncSSH internally enforces X11-forwarding, agent-forwarding, port-forwarding and pty permissions but ignores user-rc since it does not implement that feature. Non-standard permissions can also be checked, as long as the option follows the convention of starting with 'no-'. :param permission: The name of the permission to check (without the 'no-'). :type permission: `str` :returns: A `bool` indicating if the permission is granted. """ return not self._key_options.get('no-' + permission, False) def get_certificate_option(self, option: str, default: object = None) -> object: """Return option from user certificate If a user certificate was presented during authentication, this method returns the value of the requested option in the certificate if it was set. Otherwise, it returns the default value provided. The following options are supported: | force-command (string) | no-touch-required (boolean) | source-address (list of CIDR-style IP network addresses) :param option: The name of the option to look up. :param default: The default value to return if the option is not present. :type option: `str` :returns: The value of the option in the user certificate, if set """ if self._cert_options is not None: return self._cert_options.get(option, default) else: return default def check_certificate_permission(self, permission: str) -> bool: """Check permissions in user certificate If a user certificate was presented during authentication, this method returns whether the specified permission was granted in the certificate. Otherwise, it acts as if all permissions are granted and returns `True`. 
The following permissions are supported: | X11-forwarding | agent-forwarding | port-forwarding | pty | user-rc AsyncSSH internally enforces agent-forwarding, port-forwarding and pty permissions but ignores the other values since it does not implement those features. :param permission: The name of the permission to check (without the 'permit-'). :type permission: `str` :returns: A `bool` indicating if the permission is granted. """ if self._cert_options is not None: return cast(bool, self._cert_options.get('permit-' + permission, False)) else: return True def create_server_channel(self, encoding: Optional[str] = '', errors: str = '', window: int = 0, max_pktsize: int = 0) -> SSHServerChannel: """Create an SSH server channel for a new SSH session This method can be called by :meth:`session_requested() ` to create an :class:`SSHServerChannel` with the desired encoding, Unicode error handling strategy, window, and max packet size for a newly created SSH server session. :param encoding: (optional) The Unicode encoding to use for data exchanged on the session, defaulting to UTF-8 (ISO 10646) format. If `None` is passed in, the application can send and receive raw bytes. :param errors: (optional) The error handling strategy to apply on encode/decode errors :param window: (optional) The receive window size for this session :param max_pktsize: (optional) The maximum packet size for this session :type encoding: `str` or `None` :type errors: `str` :type window: `int` :type max_pktsize: `int` :returns: :class:`SSHServerChannel` """ return SSHServerChannel(self, self._loop, self._allow_pty, self._line_editor, self._line_echo, self._line_history, self._max_line_length, self._encoding if encoding == '' else encoding, self._errors if errors == '' else errors, window or self._window, max_pktsize or self._max_pktsize) async def create_connection( self, session_factory: SSHTCPSessionFactory[AnyStr], remote_host: str, remote_port: int, orig_host: str = '', orig_port: int = 0, *, encoding: Optional[str] = None, errors: str = 'strict', window: int = _DEFAULT_WINDOW, max_pktsize: int = _DEFAULT_MAX_PKTSIZE) -> \ Tuple[SSHTCPChannel[AnyStr], SSHTCPSession[AnyStr]]: """Create an SSH TCP forwarded connection This method is a coroutine which can be called to notify the client about a new inbound TCP connection arriving on the specified remote host and port. If the connection is successfully opened, a new SSH channel will be opened with data being handled by a :class:`SSHTCPSession` object created by `session_factory`. Optional arguments include the host and port of the original client opening the connection when performing TCP port forwarding. By default, this class expects data to be sent and received as raw bytes. However, an optional encoding argument can be passed in to select the encoding to use, allowing the application to send and receive string data. When encoding is set, an optional errors argument can be passed in to select what Unicode error handling strategy to use. Other optional arguments include the SSH receive window size and max packet size which default to 2 MB and 32 KB, respectively. 
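           As a rough sketch (the session class, host, and port below are
           illustrative), a server could notify the client of a connection
           received on a forwarded port like this::

               class PrintSession(asyncssh.SSHTCPSession):
                   def connection_made(self, chan):
                       self._chan = chan

                   def data_received(self, data, datatype):
                       print('Received:', data)

               chan, session = await conn.create_connection(
                   PrintSession, 'localhost', 8080)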
:param session_factory: A `callable` which returns an :class:`SSHTCPSession` object that will be created to handle activity on this session :param remote_host: The hostname or address the connection was received on :param remote_port: The port number the connection was received on :param orig_host: (optional) The hostname or address of the client requesting the connection :param orig_port: (optional) The port number of the client requesting the connection :param encoding: (optional) The Unicode encoding to use for data exchanged on the connection :param errors: (optional) The error handling strategy to apply on encode/decode errors :param window: (optional) The receive window size for this session :param max_pktsize: (optional) The maximum packet size for this session :type session_factory: `callable` :type remote_host: `str` :type remote_port: `int` :type orig_host: `str` :type orig_port: `int` :type encoding: `str` or `None` :type errors: `str` :type window: `int` :type max_pktsize: `int` :returns: an :class:`SSHTCPChannel` and :class:`SSHTCPSession` """ self.logger.info('Opening forwarded TCP connection to %s', (remote_host, remote_port)) self.logger.info(' Client address: %s', (orig_host, orig_port)) chan = self.create_tcp_channel(encoding, errors, window, max_pktsize) session = await chan.accept(session_factory, remote_host, remote_port, orig_host, orig_port) return chan, session async def open_connection(self, *args: object, **kwargs: object) -> \ Tuple[SSHReader, SSHWriter]: """Open an SSH TCP forwarded connection This method is a coroutine wrapper around :meth:`create_connection` designed to provide a "high-level" stream interface for creating an SSH TCP forwarded connection. Instead of taking a `session_factory` argument for constructing an object which will handle activity on the session via callbacks, it returns :class:`SSHReader` and :class:`SSHWriter` objects which can be used to perform I/O on the connection. With the exception of `session_factory`, all of the arguments to :meth:`create_connection` are supported and have the same meaning here. :returns: an :class:`SSHReader` and :class:`SSHWriter` """ chan, session = await self.create_connection( SSHTCPStreamSession, *args, **kwargs) # type: ignore session: SSHTCPStreamSession return SSHReader(session, chan), SSHWriter(session, chan) async def create_unix_connection( self, session_factory: SSHUNIXSessionFactory[AnyStr], remote_path: str, *, encoding: Optional[str] = None, errors: str = 'strict', window: int = _DEFAULT_WINDOW, max_pktsize: int = _DEFAULT_MAX_PKTSIZE) -> \ Tuple[SSHUNIXChannel[AnyStr], SSHUNIXSession[AnyStr]]: """Create an SSH UNIX domain socket forwarded connection This method is a coroutine which can be called to notify the client about a new inbound UNIX domain socket connection arriving on the specified remote path. If the connection is successfully opened, a new SSH channel will be opened with data being handled by a :class:`SSHUNIXSession` object created by `session_factory`. By default, this class expects data to be sent and received as raw bytes. However, an optional encoding argument can be passed in to select the encoding to use, allowing the application to send and receive string data. When encoding is set, an optional errors argument can be passed in to select what Unicode error handling strategy to use. Other optional arguments include the SSH receive window size and max packet size which default to 2 MB and 32 KB, respectively. 
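           As a rough sketch (the socket path here is illustrative), the
           stream-based wrapper described in :meth:`open_unix_connection`
           can be used instead of providing a session factory::

               reader, writer = await conn.open_unix_connection(
                   '/tmp/example.sock')
               writer.write(b'data for the client')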
:param session_factory: A `callable` which returns an :class:`SSHUNIXSession` object that will be created to handle activity on this session :param remote_path: The path the connection was received on :param encoding: (optional) The Unicode encoding to use for data exchanged on the connection :param errors: (optional) The error handling strategy to apply on encode/decode errors :param window: (optional) The receive window size for this session :param max_pktsize: (optional) The maximum packet size for this session :type session_factory: `callable` :type remote_path: `str` :type encoding: `str` or `None` :type errors: `str` :type window: `int` :type max_pktsize: `int` :returns: an :class:`SSHTCPChannel` and :class:`SSHUNIXSession` """ self.logger.info('Opening forwarded UNIX connection to %s', remote_path) chan = self.create_unix_channel(encoding, errors, window, max_pktsize) session = await chan.accept(session_factory, remote_path) return chan, session async def open_unix_connection(self, *args: object, **kwargs: object) -> \ Tuple[SSHReader, SSHWriter]: """Open an SSH UNIX domain socket forwarded connection This method is a coroutine wrapper around :meth:`create_unix_connection` designed to provide a "high-level" stream interface for creating an SSH UNIX domain socket forwarded connection. Instead of taking a `session_factory` argument for constructing an object which will handle activity on the session via callbacks, it returns :class:`SSHReader` and :class:`SSHWriter` objects which can be used to perform I/O on the connection. With the exception of `session_factory`, all of the arguments to :meth:`create_unix_connection` are supported and have the same meaning here. :returns: an :class:`SSHReader` and :class:`SSHWriter` """ chan, session = \ await self.create_unix_connection( SSHUNIXStreamSession, *args, **kwargs) # type: ignore session: SSHUNIXStreamSession return SSHReader(session, chan), SSHWriter(session, chan) async def create_x11_connection( self, session_factory: SSHTCPSessionFactory[bytes], orig_host: str = '', orig_port: int = 0, *, window: int = _DEFAULT_WINDOW, max_pktsize: int = _DEFAULT_MAX_PKTSIZE) -> \ Tuple[SSHX11Channel, SSHTCPSession[bytes]]: """Create an SSH X11 forwarded connection""" self.logger.info('Opening forwarded X11 connection') chan = self.create_x11_channel(window, max_pktsize) session = await chan.open(session_factory, orig_host, orig_port) return chan, session async def create_agent_connection( self, session_factory: SSHUNIXSessionFactory[bytes], *, window:int = _DEFAULT_WINDOW, max_pktsize: int = _DEFAULT_MAX_PKTSIZE) -> \ Tuple[SSHAgentChannel, SSHUNIXSession[bytes]]: """Create a forwarded ssh-agent connection back to the client""" if not self._agent_listener: raise ChannelOpenError(OPEN_ADMINISTRATIVELY_PROHIBITED, 'Agent forwarding not permitted') self.logger.info('Opening forwarded agent connection') chan = self.create_agent_channel(window, max_pktsize) session = await chan.open(session_factory) return chan, session async def open_agent_connection(self) -> \ Tuple[SSHReader[bytes], SSHWriter[bytes]]: """Open a forwarded ssh-agent connection back to the client""" chan, session = \ await self.create_agent_connection(SSHUNIXStreamSession) session: SSHUNIXStreamSession[bytes] return SSHReader[bytes](session, chan), SSHWriter[bytes](session, chan) class SSHConnectionOptions(Options, Generic[_Options]): """SSH connection options""" config: SSHConfig waiter: Optional[asyncio.Future] protocol_factory: _ProtocolFactory version: bytes host: str port: int 
    tunnel: object
    proxy_command: Optional[Sequence[str]]
    family: int
    local_addr: HostPort
    tcp_keepalive: bool
    canonicalize_hostname: Union[bool, str]
    canonical_domains: Sequence[str]
    canonicalize_fallback_local: bool
    canonicalize_max_dots: int
    canonicalize_permitted_cnames: Sequence[Tuple[str, str]]
    kex_algs: Sequence[bytes]
    encryption_algs: Sequence[bytes]
    mac_algs: Sequence[bytes]
    compression_algs: Sequence[bytes]
    signature_algs: Sequence[bytes]
    host_based_auth: bool
    public_key_auth: bool
    kbdint_auth: bool
    password_auth: bool
    x509_trusted_certs: Optional[Sequence[SSHX509Certificate]]
    x509_trusted_cert_paths: Sequence[FilePath]
    x509_purposes: Union[None, str, Sequence[str]]
    rekey_bytes: int
    rekey_seconds: float
    connect_timeout: Optional[float]
    login_timeout: float
    keepalive_interval: float
    keepalive_count_max: int

    def __init__(self, options: Optional[_Options] = None, **kwargs: object):
        last_config = options.config if options else None

        super().__init__(options=options, last_config=last_config, **kwargs)

    @classmethod
    async def construct(cls, options: Optional[_Options] = None,
                        **kwargs: object) -> _Options:
        """Construct a new options object from within an async task"""

        loop = asyncio.get_event_loop()

        return cast(_Options, await loop.run_in_executor(
            None, functools.partial(cls, options, loop=loop, **kwargs)))

    # pylint: disable=arguments-differ
    def prepare(self, config: SSHConfig, # type: ignore
                protocol_factory: _ProtocolFactory, version: _VersionArg,
                host: str, port: DefTuple[int], tunnel: object,
                passphrase: Optional[BytesOrStr],
                proxy_command: DefTuple[_ProxyCommand], family: DefTuple[int],
                local_addr: DefTuple[HostPort], tcp_keepalive: DefTuple[bool],
                canonicalize_hostname: DefTuple[Union[bool, str]],
                canonical_domains: DefTuple[Sequence[str]],
                canonicalize_fallback_local: DefTuple[bool],
                canonicalize_max_dots: DefTuple[int],
                canonicalize_permitted_cnames: _CNAMEArg,
                kex_algs: _AlgsArg, encryption_algs: _AlgsArg,
                mac_algs: _AlgsArg, compression_algs: _AlgsArg,
                signature_algs: _AlgsArg, host_based_auth: _AuthArg,
                public_key_auth: _AuthArg, kbdint_auth: _AuthArg,
                password_auth: _AuthArg, x509_trusted_certs: CertListArg,
                x509_trusted_cert_paths: Sequence[FilePath],
                x509_purposes: X509CertPurposes,
                rekey_bytes: DefTuple[Union[int, str]],
                rekey_seconds: DefTuple[Union[float, str]],
                connect_timeout: Optional[Union[float, str]],
                login_timeout: Union[float, str],
                keepalive_interval: Union[float, str],
                keepalive_count_max: int) -> None:
        """Prepare common connection configuration options"""

        def _split_cname_patterns(
                patterns: Union[str, Tuple[str, str]]) -> Tuple[str, str]:
            """Split CNAME patterns"""

            if isinstance(patterns, str):
                domains = patterns.split(':')

                if len(domains) == 2:
                    patterns = cast(Tuple[str, str], tuple(domains))
                else:
                    raise ValueError('CNAME rules must contain two patterns')

            return patterns

        self.config = config
        self.protocol_factory = protocol_factory
        self.version = _validate_version(version)

        self.host = cast(str, config.get('Hostname', host))
        self.port = cast(int, port if port != () else
                         config.get('Port', DEFAULT_PORT))

        self.tunnel = tunnel if tunnel != () else config.get('ProxyJump')
        self.passphrase = passphrase

        if proxy_command == ():
            proxy_command = cast(Optional[str], config.get('ProxyCommand'))

        if isinstance(proxy_command, str):
            proxy_command = split_args(proxy_command)

        self.proxy_command = proxy_command

        self.family = cast(int, family if family != () else
                           config.get('AddressFamily', socket.AF_UNSPEC))

        bind_addr = config.get('BindAddress')

        self.local_addr = cast(HostPort,
local_addr if local_addr != () else (bind_addr, 0) if bind_addr else None) self.tcp_keepalive = cast(bool, tcp_keepalive if tcp_keepalive != () else config.get('TCPKeepAlive', True)) self.canonicalize_hostname = \ cast(Union[bool, str], canonicalize_hostname if canonicalize_hostname != () else config.get('CanonicalizeHostname', False)) self.canonical_domains = \ cast(Sequence[str], canonical_domains if canonical_domains != () else config.get('CanonicalDomains', ())) self.canonicalize_fallback_local = \ cast(bool, canonicalize_fallback_local \ if canonicalize_fallback_local != () else config.get('CanonicalizeFallbackLocal', True)) self.canonicalize_max_dots = \ cast(int, canonicalize_max_dots if canonicalize_max_dots != () else config.get('CanonicalizeMaxDots', 1)) permitted_cnames = \ cast(Sequence[str], canonicalize_permitted_cnames if canonicalize_permitted_cnames != () else config.get('CanonicalizePermittedCNAMEs', ())) self.canonicalize_permitted_cnames = \ [_split_cname_patterns(patterns) for patterns in permitted_cnames] self.kex_algs, self.encryption_algs, self.mac_algs, \ self.compression_algs, self.signature_algs = \ _validate_algs(config, kex_algs, encryption_algs, mac_algs, compression_algs, signature_algs, x509_trusted_certs is not None) self.host_based_auth = \ cast(bool, host_based_auth if host_based_auth != () else config.get('HostbasedAuthentication', True)) self.public_key_auth = \ cast(bool, public_key_auth if public_key_auth != () else config.get('PubkeyAuthentication', True)) self.kbdint_auth = \ cast(bool, kbdint_auth if kbdint_auth != () else config.get('KbdInteractiveAuthentication', config.get('ChallengeResponseAuthentication', True))) self.password_auth = \ cast(bool, password_auth if password_auth != () else config.get('PasswordAuthentication', True)) if x509_trusted_certs is not None: certs = load_certificates(x509_trusted_certs) for cert in certs: if not cert.is_x509: raise ValueError('OpenSSH certificates not allowed ' 'in X.509 trusted certs') x509_trusted_certs = cast(Sequence[SSHX509Certificate], certs) if x509_trusted_cert_paths: for path in x509_trusted_cert_paths: if not Path(path).is_dir(): raise ValueError('X.509 trusted certificate path not ' f'a directory: {path}') self.x509_trusted_certs = x509_trusted_certs self.x509_trusted_cert_paths = x509_trusted_cert_paths self.x509_purposes = x509_purposes config_rekey_bytes, config_rekey_seconds = \ cast(Tuple[DefTuple[int], DefTuple[int]], config.get('RekeyLimit', ((), ()))) if rekey_bytes == (): rekey_bytes = config_rekey_bytes if rekey_bytes == (): rekey_bytes = _DEFAULT_REKEY_BYTES elif isinstance(rekey_bytes, str): rekey_bytes = parse_byte_count(rekey_bytes) if cast(int, rekey_bytes) <= 0: raise ValueError('Rekey bytes cannot be negative or zero') if rekey_seconds == (): rekey_seconds = config_rekey_seconds if rekey_seconds == (): rekey_seconds = _DEFAULT_REKEY_SECONDS elif isinstance(rekey_seconds, str): rekey_seconds = parse_time_interval(rekey_seconds) if rekey_seconds and cast(float, rekey_seconds) <= 0: raise ValueError('Rekey seconds cannot be negative or zero') if isinstance(connect_timeout, str): connect_timeout = parse_time_interval(connect_timeout) if connect_timeout and connect_timeout < 0: raise ValueError('Connect timeout cannot be negative') if isinstance(login_timeout, str): login_timeout = parse_time_interval(login_timeout) if login_timeout and login_timeout < 0: raise ValueError('Login timeout cannot be negative') if isinstance(keepalive_interval, str): keepalive_interval = 
parse_time_interval(keepalive_interval) if keepalive_interval and keepalive_interval < 0: raise ValueError('Keepalive interval cannot be negative') if keepalive_count_max <= 0: raise ValueError('Keepalive count max cannot be negative or zero') self.rekey_bytes = cast(int, rekey_bytes) self.rekey_seconds = cast(float, rekey_seconds) self.connect_timeout = connect_timeout or None self.login_timeout = login_timeout self.keepalive_interval = keepalive_interval self.keepalive_count_max = keepalive_count_max class SSHClientConnectionOptions(SSHConnectionOptions): """SSH client connection options The following options are available to control the establishment of SSH client connections: :param client_factory: (optional) A `callable` which returns an :class:`SSHClient` object that will be created for each new connection. :param proxy_command: (optional) A string or list of strings specifying a command and arguments to run to make a connection to the SSH server. Data will be forwarded to this process over stdin/stdout instead of opening a TCP connection. If specified as a string, standard shell quoting will be applied when splitting the command and its arguments. :param known_hosts: (optional) The list of keys which will be used to validate the server host key presented during the SSH handshake. If this is not specified, the keys will be looked up in the file :file:`.ssh/known_hosts`. If this is explicitly set to `None`, server host key validation will be disabled. :param host_key_alias: (optional) An alias to use instead of the real host name when looking up a host key in known_hosts and when validating host certificates. :param server_host_key_algs: (optional) A list of server host key algorithms to use instead of the default of those present in known_hosts when performing the SSH handshake, taken from :ref:`server host key algorithms `. This is useful when using the validate_host_public_key callback to validate server host keys, since AsyncSSH can not determine which server host key algorithms are preferred. This argument can also be set to 'default' to specify that the client should always send its default list of supported algorithms to avoid leaking information about what algorithms are present for the server in known_hosts. .. note:: The 'default' keyword should be used with caution, as it can result in a host key mismatch if the client trusts only a subset of the host keys the server might return. :param server_host_keys_handler: (optional) A `callable` or coroutine handler function which if set will be called when a global request from the server is received which provides an updated list of server host keys. The handler takes four arguments (added, removed, retained, and revoked), each of which is a list of SSHKey public keys, reflecting differences between what the server reported and what is currently matching in known_hosts. .. note:: This handler will only be called when known host checking is enabled and the check succeeded. :param x509_trusted_certs: (optional) A list of certificates which should be trusted for X.509 server certificate authentication. If no trusted certificates are specified, an attempt will be made to load them from the file :file:`.ssh/ca-bundle.crt`. If this argument is explicitly set to `None`, X.509 server certificate authentication will not be performed. .. note:: X.509 certificates to trust can also be provided through a :ref:`known_hosts ` file if they are converted into OpenSSH format. This allows their trust to be limited to only specific host names. 
       :param x509_trusted_cert_paths: (optional)
           A list of path names to "hash directories" containing
           certificates which should be trusted for X.509 server
           certificate authentication. Each certificate should be in a
           separate file with a name of the form *hash.number*, where
           *hash* is the OpenSSL hash value of the certificate subject
           name and *number* is an integer counting up from zero if
           multiple certificates have the same hash. If no paths are
           specified, an attempt will be made to use the directory
           :file:`.ssh/crt` as a certificate hash directory.
       :param x509_purposes: (optional)
           A list of purposes allowed in the ExtendedKeyUsage of a
           certificate used for X.509 server certificate authentication,
           defaulting to 'secureShellServer'. If this argument is
           explicitly set to `None`, the server certificate's
           ExtendedKeyUsage will not be checked.
       :param username: (optional)
           Username to authenticate as on the server. If not specified,
           the currently logged in user on the local machine will be used.
       :param password: (optional)
           The password to use for client password authentication or
           keyboard-interactive authentication which prompts for a
           password, or a `callable` or coroutine which returns the
           password to use. If this is not specified or set to `None`,
           client password authentication will not be performed.
       :param client_host_keysign: (optional)
           Whether or not to use `ssh-keysign` to sign host-based
           authentication requests. If set to `True`, an attempt will be
           made to find `ssh-keysign` in its typical locations. If set to
           a string, that will be used as the `ssh-keysign` path. When set,
           client_host_keys should be a list of public keys. Otherwise,
           client_host_keys should be a list of private keys with optional
           paired certificates.
       :param client_host_keys: (optional)
           A list of keys to use to authenticate this client via host-based
           authentication. If `client_host_keysign` is set and no host keys
           or certificates are specified, an attempt will be made to find
           them in their typical locations. If `client_host_keysign` is not
           set, host private keys must be specified explicitly or host-based
           authentication will not be performed.
       :param client_host_certs: (optional)
           A list of optional certificates which can be paired with the
           provided client host keys.
       :param client_host: (optional)
           The local hostname to use when performing host-based
           authentication. If not specified, the hostname associated with
           the local IP address of the SSH connection will be used.
       :param client_username: (optional)
           The local username to use when performing host-based
           authentication. If not specified, the username of the currently
           logged in user will be used.
       :param client_keys: (optional)
           A list of keys which will be used to authenticate this client
           via public key authentication. These keys will be used after
           trying keys from a PKCS11 provider or an ssh-agent, if either
           of those are configured. If no client keys are specified, an
           attempt will be made to load them from the files
           :file:`.ssh/id_ed25519_sk`, :file:`.ssh/id_ecdsa_sk`,
           :file:`.ssh/id_ed448`, :file:`.ssh/id_ed25519`,
           :file:`.ssh/id_ecdsa`, :file:`.ssh/id_rsa`, and
           :file:`.ssh/id_dsa` in the user's home directory, with optional
           certificates loaded from the files
           :file:`.ssh/id_ed25519_sk-cert.pub`,
           :file:`.ssh/id_ecdsa_sk-cert.pub`, :file:`.ssh/id_ed448-cert.pub`,
           :file:`.ssh/id_ed25519-cert.pub`, :file:`.ssh/id_ecdsa-cert.pub`,
           :file:`.ssh/id_rsa-cert.pub`, and :file:`.ssh/id_dsa-cert.pub`.
If this argument is explicitly set to `None`, client public key authentication will not be performed. :param client_certs: (optional) A list of optional certificates which can be paired with the provided client keys. :param passphrase: (optional) The passphrase to use to decrypt client keys if they are encrypted, or a `callable` or coroutine which takes a filename as a parameter and returns the passphrase to use to decrypt that file. If not specified, only unencrypted client keys can be loaded. If the keys passed into client_keys are already loaded, this argument is ignored. .. note:: A callable or coroutine passed in as a passphrase will be called on all filenames configured as client keys or client host keys each time an SSHClientConnectionOptions object is instantiated, even if the keys aren't encrypted or aren't ever used for authentication. :param ignore_encrypted: (optional) Whether or not to ignore encrypted keys when no passphrase is specified. This defaults to `True` when keys are specified via the IdentityFile config option, causing encrypted keys in the config to be ignored when no passphrase is specified. Note that encrypted keys loaded into an SSH agent can still be used when this option is set. :param host_based_auth: (optional) Whether or not to allow host-based authentication. By default, host-based authentication is enabled if client host keys are made available. :param public_key_auth: (optional) Whether or not to allow public key authentication. By default, public key authentication is enabled if client keys are made available. :param kbdint_auth: (optional) Whether or not to allow keyboard-interactive authentication. By default, keyboard-interactive authentication is enabled if a password is specified or if callbacks to respond to challenges are made available. :param password_auth: (optional) Whether or not to allow password authentication. By default, password authentication is enabled if a password is specified or if callbacks to provide a password are made available. :param gss_host: (optional) The principal name to use for the host in GSS key exchange and authentication. If not specified, this value will be the same as the `host` argument. If this argument is explicitly set to `None`, GSS key exchange and authentication will not be performed. :param gss_store: (optional) The GSS credential store from which to acquire credentials. :param gss_kex: (optional) Whether or not to allow GSS key exchange. By default, GSS key exchange is enabled. :param gss_auth: (optional) Whether or not to allow GSS authentication. By default, GSS authentication is enabled. :param gss_delegate_creds: (optional) Whether or not to forward GSS credentials to the server being accessed. By default, GSS credential delegation is disabled. :param preferred_auth: A list of authentication methods the client should attempt to use in order of preference. By default, the preferred list is gssapi-keyex, gssapi-with-mic, hostbased, publickey, keyboard-interactive, and then password. This list may be limited by which auth methods are implemented by the client and which methods the server accepts. :param disable_trivial_auth: (optional) Whether or not to allow "trivial" forms of auth where the client is not actually challenged for credentials. Setting this will cause the connection to fail if a server does not perform some non-trivial form of auth during the initial SSH handshake. If not specified, all forms of auth supported by the server are allowed, including none. 
:param agent_path: (optional) The path of a UNIX domain socket to use to contact an ssh-agent process which will perform the operations needed for client public key authentication, or the :class:`SSHServerConnection` to use to forward ssh-agent requests over. If this is not specified and the environment variable `SSH_AUTH_SOCK` is set, its value will be used as the path. If this argument is explicitly set to `None`, an ssh-agent will not be used. :param agent_identities: (optional) A list of identities used to restrict which SSH agent keys may be used. These may be specified as byte strings in binary SSH format or as public keys or certificates (*see* :ref:`SpecifyingPublicKeys` and :ref:`SpecifyingCertificates`). If set to `None`, all keys loaded into the SSH agent will be made available for use. This is the default. :param agent_forwarding: (optional) Whether or not to allow forwarding of ssh-agent requests from processes running on the server. This argument can also be set to the path of a UNIX domain socket in cases where forwarded agent requests should be sent to a different path than client agent requests. By default, forwarding ssh-agent requests from the server is not allowed. :param pkcs11_provider: (optional) The path of a shared library which should be used as a PKCS#11 provider for accessing keys on PIV security tokens. By default, no local security tokens will be accessed. :param pkcs11_pin: (optional) The PIN to use when accessing security tokens via PKCS#11. .. note:: If your application opens multiple SSH connections using PKCS#11 keys, you should consider calling :func:`load_pkcs11_keys` explicitly instead of using these arguments. This allows you to pay the cost of loading the key information from the security tokens only once. You can then pass the returned keys via the `client_keys` argument to any calls that need them. Calling :func:`load_pkcs11_keys` explicitly also gives you the ability to load keys from multiple tokens with different PINs and to select which tokens to load keys from and which keys on those tokens to load. :param client_version: (optional) An ASCII string to advertise to the SSH server as the version of this client, defaulting to `'AsyncSSH'` and its version number. :param kex_algs: (optional) A list of allowed key exchange algorithms in the SSH handshake, taken from :ref:`key exchange algorithms `. :param encryption_algs: (optional) A list of encryption algorithms to use during the SSH handshake, taken from :ref:`encryption algorithms `. :param mac_algs: (optional) A list of MAC algorithms to use during the SSH handshake, taken from :ref:`MAC algorithms `. :param compression_algs: (optional) A list of compression algorithms to use during the SSH handshake, taken from :ref:`compression algorithms `, or `None` to disable compression. The client prefers to disable compression, but will enable it if the server requires it. :param signature_algs: (optional) A list of public key signature algorithms to use during the SSH handshake, taken from :ref:`signature algorithms `. :param rekey_bytes: (optional) The number of bytes which can be sent before the SSH session key is renegotiated, defaulting to 1 GB. :param rekey_seconds: (optional) The maximum time in seconds before the SSH session key is renegotiated, defaulting to 1 hour. :param connect_timeout: (optional) The maximum time in seconds allowed to complete an outbound SSH connection. 
           This includes the time to establish the TCP connection and the
           time to perform the initial SSH protocol handshake, key exchange,
           and authentication. This is disabled by default, relying on the
           system's default TCP connect timeout and AsyncSSH's login timeout.
       :param login_timeout: (optional)
           The maximum time in seconds allowed for authentication to
           complete, defaulting to 2 minutes. Setting this to 0 will
           disable the login timeout.

           .. note:: This timeout only applies after the SSH TCP connection
                     is established. To set a timeout which includes
                     establishing the TCP connection, use the
                     `connect_timeout` argument above.
       :param keepalive_interval: (optional)
           The time in seconds to wait before sending a keepalive message
           if no data has been received from the server. This defaults to
           0, which disables sending these messages.
       :param keepalive_count_max: (optional)
           The maximum number of keepalive messages which will be sent
           without getting a response before disconnecting from the server.
           This defaults to 3, but only applies when keepalive_interval is
           non-zero.
       :param tcp_keepalive: (optional)
           Whether or not to enable keepalive probes at the TCP level to
           detect broken connections, defaulting to `True`.
       :param canonicalize_hostname: (optional)
           Whether or not to enable hostname canonicalization, defaulting
           to `False`, in which case hostnames are passed as-is to the
           system resolver. If set to `True`, requests that don't involve
           a proxy tunnel or command will attempt to canonicalize the
           hostname using canonical_domains and rules in
           canonicalize_permitted_cnames. If set to `'always'`, hostname
           canonicalization is also applied to proxied requests.
       :param canonical_domains: (optional)
           When canonicalize_hostname is set, this specifies the list of
           domain suffixes in which to search for the hostname.
       :param canonicalize_fallback_local: (optional)
           Whether or not to fall back to looking up the hostname against
           the system resolver's search domains when no matches are found
           in canonical_domains, defaulting to `True`.
       :param canonicalize_max_dots: (optional)
           The maximum number of dots which can appear in a hostname before
           hostname canonicalization is disabled, defaulting to 1. Hostnames
           with more than this number of dots are treated as already being
           fully qualified and passed as-is to the system resolver.
       :param canonicalize_permitted_cnames: (optional)
           Patterns to match against to decide whether hostname
           canonicalization should return a CNAME. This argument contains
           a list of pairs of wildcard pattern lists. The first pattern is
           matched against the hostname found after adding one of the
           search domains from canonical_domains and the second pattern is
           matched against the associated CNAME. If a match can be found
           in the list for both patterns, the CNAME is returned as the
           canonical hostname. The default is an empty list, preventing
           CNAMEs from being returned.
       :param command: (optional)
           The default remote command to execute on client sessions.
           An interactive shell is started if no command or subsystem is
           specified.
       :param subsystem: (optional)
           The default remote subsystem to start on client sessions.
       :param env: (optional)
           The default environment variables to set for client sessions.
           Keys and values passed in here will be converted to Unicode
           strings encoded as UTF-8 (ISO 10646) for transmission.

           .. note:: Many SSH servers restrict which environment variables
                     a client is allowed to set. The server's configuration
                     may need to be edited before environment variables can
                     be successfully set in the remote environment.
:param send_env: (optional) A list of environment variable names to pull from `os.environ` and set by default for client sessions. Wildcards patterns using `'*'` and `'?'` are allowed, and all variables with matching names will be sent with whatever value is set in the local environment. If a variable is present in both env and send_env, the value from env will be used. :param request_pty: (optional) Whether or not to request a pseudo-terminal (PTY) by default for client sessions. This defaults to `True`, which means to request a PTY whenever the `term_type` is set. Other possible values include `False` to never request a PTY, `'force'` to always request a PTY even without `term_type` being set, or `'auto'` to request a TTY when `term_type` is set but only when starting an interactive shell. :param term_type: (optional) The default terminal type to set for client sessions. :param term_size: (optional) The terminal width and height in characters and optionally the width and height in pixels to set for client sessions. :param term_modes: (optional) POSIX terminal modes to set for client sessions, where keys are taken from :ref:`POSIX terminal modes ` with values defined in section 8 of :rfc:`RFC 4254 <4254#section-8>`. :param x11_forwarding: (optional) Whether or not to request X11 forwarding for client sessions, defaulting to `False`. If set to `True`, X11 forwarding will be requested and a failure will raise :exc:`ChannelOpenError`. It can also be set to `'ignore_failure'` to attempt X11 forwarding but ignore failures. :param x11_display: (optional) The display that X11 connections should be forwarded to, defaulting to the value in the environment variable `DISPLAY`. :param x11_auth_path: (optional) The path to the Xauthority file to read X11 authentication data from, defaulting to the value in the environment variable `XAUTHORITY` or the file :file:`.Xauthority` in the user's home directory if that's not set. :param x11_single_connection: (optional) Whether or not to limit X11 forwarding to a single connection, defaulting to `False`. :param encoding: (optional) The default Unicode encoding to use for data exchanged on client sessions. :param errors: (optional) The default error handling strategy to apply on Unicode encode/decode errors. :param window: (optional) The default receive window size to set for client sessions. :param max_pktsize: (optional) The default maximum packet size to set for client sessions. :param config: (optional) Paths to OpenSSH client configuration files to load. This configuration will be used as a fallback to override the defaults for settings which are not explicitly specified using AsyncSSH's configuration options. .. note:: Specifying configuration files when creating an :class:`SSHClientConnectionOptions` object will cause the config file to be read and parsed at the time of creation of the object, including evaluation of any conditional blocks. If you want the config to be parsed for every new connection, this argument should be added to the connect or listen calls instead. However, if you want to save the parsing overhead and your configuration doesn't depend on conditions that would change between calls, this argument may be an option. :param options: (optional) A previous set of options to use as the base to incrementally build up a configuration. When an option is not explicitly specified, its value will be pulled from this options object (if present) before falling back to the default value. 
:type client_factory: `callable` returning :class:`SSHClient` :type proxy_command: `str` or `list` of `str` :type known_hosts: *see* :ref:`SpecifyingKnownHosts` :type host_key_alias: `str` :type server_host_key_algs: `str` or `list` of `str` :type server_host_keys_handler: `callable` or coroutine :type x509_trusted_certs: *see* :ref:`SpecifyingCertificates` :type x509_trusted_cert_paths: `list` of `str` :type x509_purposes: *see* :ref:`SpecifyingX509Purposes` :type username: `str` :type password: `str` :type client_host_keysign: `bool` or `str` :type client_host_keys: *see* :ref:`SpecifyingPrivateKeys` or :ref:`SpecifyingPublicKeys` :type client_host_certs: *see* :ref:`SpecifyingCertificates` :type client_host: `str` :type client_username: `str` :type client_keys: *see* :ref:`SpecifyingPrivateKeys` :type client_certs: *see* :ref:`SpecifyingCertificates` :type passphrase: `str` or `bytes` :type ignore_encrypted: `bool` :type host_based_auth: `bool` :type public_key_auth: `bool` :type kbdint_auth: `bool` :type password_auth: `bool` :type gss_host: `str` :type gss_store: `str`, `bytes`, or a `dict` with `str` or `bytes` keys and values :type gss_kex: `bool` :type gss_auth: `bool` :type gss_delegate_creds: `bool` :type preferred_auth: `str` or `list` of `str` :type disable_trivial_auth: `bool` :type agent_path: `str` :type agent_identities: *see* :ref:`SpecifyingPublicKeys` and :ref:`SpecifyingCertificates` :type agent_forwarding: `bool` or `str` :type pkcs11_provider: `str` or `None` :type pkcs11_pin: `str` :type client_version: `str` :type kex_algs: `str` or `list` of `str` :type encryption_algs: `str` or `list` of `str` :type mac_algs: `str` or `list` of `str` :type compression_algs: `str` or `list` of `str` :type signature_algs: `str` or `list` of `str` :type rekey_bytes: *see* :ref:`SpecifyingByteCounts` :type rekey_seconds: *see* :ref:`SpecifyingTimeIntervals` :type connect_timeout: *see* :ref:`SpecifyingTimeIntervals` :type login_timeout: *see* :ref:`SpecifyingTimeIntervals` :type keepalive_interval: *see* :ref:`SpecifyingTimeIntervals` :type keepalive_count_max: `int` :type tcp_keepalive: `bool` :type canonicalize_hostname: `bool` or `'always'` :type canonical_domains: `list` of `str` :type canonicalize_fallback_local: `bool` :type canonicalize_max_dots: `int` :type canonicalize_permitted_cnames: `list` of `tuple` of 2 `str` values :type command: `str` :type subsystem: `str` :type env: `dict` with `str` keys and values :type send_env: `list` of `str` :type request_pty: `bool`, `'force'`, or `'auto'` :type term_type: `str` :type term_size: `tuple` of 2 or 4 `int` values :type term_modes: `dict` with `int` keys and values :type x11_forwarding: `bool` or `'ignore_failure'` :type x11_display: `str` :type x11_auth_path: `str` :type x11_single_connection: `bool` :type encoding: `str` or `None` :type errors: `str` :type window: `int` :type max_pktsize: `int` :type config: `list` of `str` :type options: :class:`SSHClientConnectionOptions` """ config: SSHClientConfig client_factory: _ClientFactory client_version: bytes known_hosts: KnownHostsArg host_key_alias: Optional[str] server_host_key_algs: Union[str, Sequence[str]] server_host_keys_handler: _ServerHostKeysHandler username: str password: Optional[str] client_host_keysign: Optional[str] client_host_keypairs: Sequence[SSHKeyPair] client_host_pubkeys: Sequence[Union[SSHKey, SSHCertificate]] client_host: Optional[str] client_username: str client_keys: Optional[Sequence[SSHKeyPair]] client_certs: Sequence[FilePath] ignore_encrypted: bool 
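    # Rough usage sketch (the host name, username, key path, and interval
    # below are placeholders, not defaults):
    #
    #     options = asyncssh.SSHClientConnectionOptions(
    #         username='user', client_keys=['~/.ssh/id_ed25519'],
    #         keepalive_interval=30)
    #     conn = await asyncssh.connect('host.example.com', options=options)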
gss_host: DefTuple[Optional[str]] gss_store: Optional[Dict[BytesOrStr, BytesOrStr]] gss_kex: bool gss_auth: bool gss_delegate_creds: bool preferred_auth: Sequence[str] disable_trivial_auth: bool agent_path: Optional[str] agent_identities: Optional[Sequence[bytes]] agent_forward_path: Optional[str] pkcs11_provider: Optional[str] pkcs11_pin: Optional[str] command: Optional[str] subsystem: Optional[str] env: Env send_env: Optional[EnvSeq] request_pty: _RequestPTY term_type: Optional[str] term_size: TermSizeArg term_modes: TermModesArg x11_forwarding: Union[bool, str] x11_display: Optional[str] x11_auth_path: Optional[str] x11_single_connection: bool encoding: Optional[str] errors: str window: int max_pktsize: int # pylint: disable=arguments-differ def prepare(self, # type: ignore loop: Optional[asyncio.AbstractEventLoop] = None, last_config: Optional[SSHConfig] = None, config: DefTuple[ConfigPaths] = None, reload: bool = False, canonical: bool = False, final: bool = False, client_factory: Optional[_ClientFactory] = None, client_version: _VersionArg = (), host: str = '', port: DefTuple[int] = (), tunnel: object = (), proxy_command: DefTuple[_ProxyCommand] = (), family: DefTuple[int] = (), local_addr: DefTuple[HostPort] = (), tcp_keepalive: DefTuple[bool] = (), canonicalize_hostname: DefTuple[Union[bool, str]] = (), canonical_domains: DefTuple[Sequence[str]] = (), canonicalize_fallback_local: DefTuple[bool] = (), canonicalize_max_dots: DefTuple[int] = (), canonicalize_permitted_cnames: DefTuple[Sequence[str]] = (), kex_algs: _AlgsArg = (), encryption_algs: _AlgsArg = (), mac_algs: _AlgsArg = (), compression_algs: _AlgsArg = (), signature_algs: _AlgsArg = (), host_based_auth: _AuthArg = (), public_key_auth: _AuthArg = (), kbdint_auth: _AuthArg = (), password_auth: _AuthArg = (), x509_trusted_certs: CertListArg = (), x509_trusted_cert_paths: Sequence[FilePath] = (), x509_purposes: X509CertPurposes = 'secureShellServer', rekey_bytes: DefTuple[Union[int, str]] = (), rekey_seconds: DefTuple[Union[float, str]] = (), connect_timeout: DefTuple[Optional[Union[float, str]]] = (), login_timeout: Union[float, str] = _DEFAULT_LOGIN_TIMEOUT, keepalive_interval: DefTuple[Union[float, str]] = (), keepalive_count_max: DefTuple[int] = (), known_hosts: KnownHostsArg = (), host_key_alias: DefTuple[Optional[str]] = (), server_host_key_algs: _AlgsArg = (), server_host_keys_handler: _ServerHostKeysHandler = None, username: DefTuple[str] = (), password: Optional[str] = None, client_host_keysign: DefTuple[KeySignPath] = (), client_host_keys: Optional[_ClientKeysArg] = None, client_host_certs: Sequence[FilePath] = (), client_host: Optional[str] = None, client_username: DefTuple[str] = (), client_keys: _ClientKeysArg = (), client_certs: Sequence[FilePath] = (), passphrase: Optional[BytesOrStr] = None, ignore_encrypted: DefTuple[bool] = (), gss_host: DefTuple[Optional[str]] = (), gss_store: Optional[Union[BytesOrStr, BytesOrStrDict]] = None, gss_kex: DefTuple[bool] = (), gss_auth: DefTuple[bool] = (), gss_delegate_creds: DefTuple[bool] = (), preferred_auth: DefTuple[Union[str, Sequence[str]]] = (), disable_trivial_auth: bool = False, agent_path: DefTuple[Optional[str]] = (), agent_identities: DefTuple[Optional[IdentityListArg]] = (), agent_forwarding: DefTuple[Union[bool, str]] = (), pkcs11_provider: DefTuple[Optional[str]] = (), pkcs11_pin: Optional[str] = None, command: DefTuple[Optional[str]] = (), subsystem: Optional[str] = None, env: DefTuple[Env] = (), send_env: DefTuple[Optional[EnvSeq]] = (), request_pty: 
DefTuple[_RequestPTY] = (), term_type: Optional[str] = None, term_size: TermSizeArg = None, term_modes: TermModesArg = None, x11_forwarding: DefTuple[Union[bool, str]] = (), x11_display: Optional[str] = None, x11_auth_path: Optional[str] = None, x11_single_connection: bool = False, encoding: Optional[str] = 'utf-8', errors: str = 'strict', window: int = _DEFAULT_WINDOW, max_pktsize: int = _DEFAULT_MAX_PKTSIZE) -> None: """Prepare client connection configuration options""" try: local_username = getpass.getuser() except KeyError: raise ValueError('Unknown local username: set one of ' 'LOGNAME, USER, LNAME, or USERNAME in ' 'the environment') from None if config == () and (not last_config or not last_config.loaded): default_config = Path('~', '.ssh', 'config').expanduser() config = [default_config] if os.access(default_config, os.R_OK) else [] config = SSHClientConfig.load(last_config, config, reload, canonical, final, local_username, username, host, port) if x509_trusted_certs == (): default_x509_certs = Path('~', '.ssh', 'ca-bundle.crt').expanduser() if os.access(default_x509_certs, os.R_OK): x509_trusted_certs = str(default_x509_certs) if x509_trusted_cert_paths == (): default_x509_cert_path = Path('~', '.ssh', 'crt').expanduser() if default_x509_cert_path.is_dir(): x509_trusted_cert_paths = [str(default_x509_cert_path)] if connect_timeout == (): connect_timeout = cast(Optional[Union[float, str]], config.get('ConnectTimeout', None)) connect_timeout: Optional[Union[float, str]] if keepalive_interval == (): keepalive_interval = \ cast(Union[float, str], config.get('ServerAliveInterval', _DEFAULT_KEEPALIVE_INTERVAL)) keepalive_interval: Union[float, str] if keepalive_count_max == (): keepalive_count_max = \ cast(int, config.get('ServerAliveCountMax', _DEFAULT_KEEPALIVE_COUNT_MAX)) keepalive_count_max: int super().prepare(config, client_factory or SSHClient, client_version, host, port, tunnel, passphrase, proxy_command, family, local_addr, tcp_keepalive, canonicalize_hostname, canonical_domains, canonicalize_fallback_local, canonicalize_max_dots, canonicalize_permitted_cnames, kex_algs, encryption_algs, mac_algs, compression_algs, signature_algs, host_based_auth, public_key_auth, kbdint_auth, password_auth, x509_trusted_certs, x509_trusted_cert_paths, x509_purposes, rekey_bytes, rekey_seconds, connect_timeout, login_timeout, keepalive_interval, keepalive_count_max) if known_hosts != (): self.known_hosts = known_hosts else: user_known_hosts = \ cast(List[str], config.get('UserKnownHostsFile', ())) if user_known_hosts == []: self.known_hosts = None else: self.known_hosts = list(user_known_hosts) + \ cast(List[str], config.get('GlobalKnownHostsFile', [])) self.host_key_alias = \ cast(Optional[str], host_key_alias if host_key_alias != () else config.get('HostKeyAlias')) self.server_host_key_algs = server_host_key_algs # Just validate the input here -- the actual server host key # selection is done later, after the known_hosts lookup is done. 
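        # For reference, a caller might pass e.g.
        # server_host_key_algs=['ssh-ed25519', 'rsa-sha2-256'] or set
        # "HostKeyAlgorithms" in an OpenSSH config file; the algorithm
        # names here are examples only, not defaults.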
_select_host_key_algs(server_host_key_algs, cast(DefTuple[str], config.get('HostKeyAlgorithms', ())), []) self.server_host_keys_handler = server_host_keys_handler self.username = saslprep(cast(str, username if username != () else config.get('User', local_username))) self.password = password if client_host_keysign == (): client_host_keysign = \ cast(bool, config.get('EnableSSHKeySign', False)) if client_host_keysign: client_host_keysign = find_keysign(client_host_keysign) if client_host_keys: self.client_host_pubkeys = \ load_public_keys(cast(KeyListArg, client_host_keys)) else: self.client_host_pubkeys = load_default_host_public_keys() else: client_host_keysign = None self.client_host_keypairs = \ load_keypairs(cast(KeyPairListArg, client_host_keys), passphrase, client_host_certs, loop=loop) self.client_host_keysign = client_host_keysign self.client_host = client_host self.client_username = saslprep(cast(str, client_username if client_username != () else local_username)) self.gss_host = gss_host if isinstance(gss_store, (bytes, str)): self.gss_store = {'ccache': gss_store} else: self.gss_store = gss_store self.gss_kex = cast(bool, gss_kex if gss_kex != () else config.get('GSSAPIKeyExchange', True)) self.gss_auth = cast(bool, gss_auth if gss_auth != () else config.get('GSSAPIAuthentication', True)) self.gss_delegate_creds = cast(bool, gss_delegate_creds if gss_delegate_creds != () else config.get('GSSAPIDelegateCredentials', False)) if preferred_auth == (): preferred_auth = \ cast(str, config.get('PreferredAuthentications', ())) if isinstance(preferred_auth, str): preferred_auth = preferred_auth.split(',') preferred_auth: Sequence[str] self.preferred_auth = preferred_auth self.disable_trivial_auth = disable_trivial_auth if agent_path == (): agent_path = cast(DefTuple[str], config.get('IdentityAgent', ())) if agent_path == (): agent_path = os.environ.get('SSH_AUTH_SOCK', '') agent_path = str(Path(agent_path).expanduser()) if agent_path else '' if pkcs11_provider == (): pkcs11_provider = \ cast(Optional[str], config.get('PKCS11Provider')) pkcs11_provider: Optional[str] if ignore_encrypted == (): ignore_encrypted = client_keys == () ignore_encrypted: bool if client_keys == (): client_keys = cast(_ClientKeysArg, config.get('IdentityFile', ())) if client_certs == (): client_certs = \ cast(Sequence[FilePath], config.get('CertificateFile', ())) identities_only = cast(bool, config.get('IdentitiesOnly')) if agent_identities == (): if identities_only: agent_identities = cast(KeyListArg, client_keys) else: agent_identities = None if agent_identities: self.agent_identities = load_identities(agent_identities, identities_only) elif agent_identities == (): self.agent_identities = load_default_identities() else: self.agent_identities = None if client_keys: self.client_keys = \ load_keypairs(cast(KeyPairListArg, client_keys), passphrase, client_certs, identities_only, ignore_encrypted, loop=loop) elif client_keys is not None: self.client_keys = load_default_keypairs(passphrase, client_certs) else: self.client_keys = None if self.client_keys is not None: self.agent_path = agent_path self.pkcs11_provider = pkcs11_provider self.pkcs11_pin = pkcs11_pin else: self.agent_path = None self.pkcs11_provider = None self.pkcs11_pin = None if agent_forwarding == (): agent_forwarding = cast(Union[bool, str], config.get('ForwardAgent', False)) agent_forwarding: Union[bool, str] if not agent_forwarding: self.agent_forward_path = None elif agent_forwarding is True: self.agent_forward_path = agent_path else: 
self.agent_forward_path = agent_forwarding self.command = cast(Optional[str], command if command != () else config.get('RemoteCommand')) self.subsystem = subsystem self.env = cast(Env, env if env != () else config.get('SetEnv')) self.send_env = cast(Optional[EnvSeq], send_env if send_env != () else config.get('SendEnv')) self.request_pty = cast(_RequestPTY, request_pty if request_pty != () else config.get('RequestTTY', True)) self.term_type = term_type self.term_size = term_size self.term_modes = term_modes self.x11_forwarding = cast(Union[bool, str], x11_forwarding if x11_forwarding != () else config.get('ForwardX11Trusted') and 'ignore_failure') self.x11_display = x11_display self.x11_auth_path = x11_auth_path self.x11_single_connection = x11_single_connection self.encoding = encoding self.errors = errors self.window = window self.max_pktsize = max_pktsize class SSHServerConnectionOptions(SSHConnectionOptions): """SSH server connection options The following options are available to control the acceptance of SSH server connections: :param server_factory: A `callable` which returns an :class:`SSHServer` object that will be created for each new connection. :param proxy_command: (optional) A string or list of strings specifying a command and arguments to run when using :func:`connect_reverse` to make a reverse direction connection to an SSH client. Data will be forwarded to this process over stdin/stdout instead of opening a TCP connection. If specified as a string, standard shell quoting will be applied when splitting the command and its arguments. :param server_host_keys: (optional) A list of private keys and optional certificates which can be used by the server as a host key. Either this argument or `gss_host` must be specified. If this is not specified, only GSS-based key exchange will be supported. :param server_host_certs: (optional) A list of optional certificates which can be paired with the provided server host keys. :param send_server_host_keys: (optional) Whether or not to send a list of the allowed server host keys for clients to use to update their known hosts list for the server. .. note:: Enabling this option will allow multiple server host keys of the same type to be configured. Only the first key of each type will be actively used during key exchange, but the others will be reported as reserved keys that clients should begin to trust, to allow for future key rotation. If this option is disabled, specifying multiple server host keys of the same type is treated as a configuration error. :param passphrase: (optional) The passphrase to use to decrypt server host keys if they are encrypted, or a `callable` or coroutine which takes a filename as a parameter and returns the passphrase to use to decrypt that file. If not specified, only unencrypted server host keys can be loaded. If the keys passed into server_host_keys are already loaded, this argument is ignored. .. note:: A callable or coroutine passed in as a passphrase will be called on all filenames configured as server host keys each time an SSHServerConnectionOptions object is instantiated, even if the keys aren't encrypted or aren't ever used for server validation. :param known_client_hosts: (optional) A list of client hosts which should be trusted to perform host-based client authentication. If this is not specified, host-based client authentication will not be performed. :param trust_client_host: (optional) Whether or not to use the hostname provided by the client when performing host-based authentication.
By default, the client-provided hostname is not trusted and is instead determined by doing a reverse lookup of the IP address the client connected from. :param authorized_client_keys: (optional) A list of authorized user and CA public keys which should be trusted for certificate-based client public key authentication. :param x509_trusted_certs: (optional) A list of certificates which should be trusted for X.509 client certificate authentication. If this argument is explicitly set to `None`, X.509 client certificate authentication will not be performed. .. note:: X.509 certificates to trust can also be provided through an :ref:`authorized_keys ` file if they are converted into OpenSSH format. This allows their trust to be limited to only specific client IPs or user names and allows SSH functions to be restricted when these certificates are used. :param x509_trusted_cert_paths: (optional) A list of path names to "hash directories" containing certificates which should be trusted for X.509 client certificate authentication. Each certificate should be in a separate file with a name of the form *hash.number*, where *hash* is the OpenSSL hash value of the certificate subject name and *number* is an integer counting up from zero if multiple certificates have the same hash. :param x509_purposes: (optional) A list of purposes allowed in the ExtendedKeyUsage of a certificate used for X.509 client certificate authentication, defaulting to 'secureShellClient'. If this argument is explicitly set to `None`, the client certificate's ExtendedKeyUsage will not be checked. :param host_based_auth: (optional) Whether or not to allow host-based authentication. By default, host-based authentication is enabled if known client host keys are specified or if callbacks to validate client host keys are made available. :param public_key_auth: (optional) Whether or not to allow public key authentication. By default, public key authentication is enabled if authorized client keys are specified or if callbacks to validate client keys are made available. :param kbdint_auth: (optional) Whether or not to allow keyboard-interactive authentication. By default, keyboard-interactive authentication is enabled if the callbacks to generate challenges are made available. :param password_auth: (optional) Whether or not to allow password authentication. By default, password authentication is enabled if callbacks to validate a password are made available. :param gss_host: (optional) The principal name to use for the host in GSS key exchange and authentication. If not specified, the value returned by :func:`socket.gethostname` will be used if it is a fully qualified name. Otherwise, the value returned by :func:`socket.getfqdn` will be used. If this argument is explicitly set to `None`, GSS key exchange and authentication will not be performed. :param gss_store: (optional) The GSS credential store from which to acquire credentials. :param gss_kex: (optional) Whether or not to allow GSS key exchange. By default, GSS key exchange is enabled. :param gss_auth: (optional) Whether or not to allow GSS authentication. By default, GSS authentication is enabled.
:param allow_pty: (optional) Whether or not to allow allocation of a pseudo-tty in sessions, defaulting to `True` :param line_editor: (optional) Whether or not to enable input line editing on sessions which have a pseudo-tty allocated, defaulting to `True` :param line_echo: (bool) Whether or not to echo completed input lines when they are entered, rather than waiting for the application to read and echo them, defaulting to `True`. Setting this to `False` and performing the echo in the application can better synchronize input and output, especially when there are input prompts. :param line_history: (int) The number of lines of input line history to store in the line editor when it is enabled, defaulting to 1000 :param max_line_length: (int) The maximum number of characters allowed in an input line when the line editor is enabled, defaulting to 1024 :param rdns_lookup: (optional) Whether or not to perform reverse DNS lookups on the client's IP address to enable hostname-based matches in authorized key file "from" options and "Match Host" config options, defaulting to `False`. :param x11_forwarding: (optional) Whether or not to allow forwarding of X11 connections back to the client when the client supports it, defaulting to `False` :param x11_auth_path: (optional) The path to the Xauthority file to write X11 authentication data to, defaulting to the value in the environment variable `XAUTHORITY` or the file :file:`.Xauthority` in the user's home directory if that's not set :param agent_forwarding: (optional) Whether or not to allow forwarding of ssh-agent requests back to the client when the client supports it, defaulting to `True` :param process_factory: (optional) A `callable` or coroutine handler function which takes an AsyncSSH :class:`SSHServerProcess` argument that will be called each time a new shell, exec, or subsystem other than SFTP is requested by the client. If set, this takes precedence over the `session_factory` argument. :param session_factory: (optional) A `callable` or coroutine handler function which takes AsyncSSH stream objects for stdin, stdout, and stderr that will be called each time a new shell, exec, or subsystem other than SFTP is requested by the client. If not specified, sessions are rejected by default unless the :meth:`session_requested() ` method is overridden on the :class:`SSHServer` object returned by `server_factory` to make this decision. :param encoding: (optional) The Unicode encoding to use for data exchanged on sessions on this server, defaulting to UTF-8 (ISO 10646) format. If `None` is passed in, the application can send and receive raw bytes. :param errors: (optional) The error handling strategy to apply on Unicode encode/decode errors of data exchanged on sessions on this server, defaulting to 'strict'. :param sftp_factory: (optional) A `callable` which returns an :class:`SFTPServer` object that will be created each time an SFTP session is requested by the client, or `True` to use the base :class:`SFTPServer` class to handle SFTP requests. If not specified, SFTP sessions are rejected by default. :param sftp_version: (optional) The maximum version of the SFTP protocol to support, currently either 3 or 4, defaulting to 3. :param allow_scp: (optional) Whether or not to allow incoming scp requests to be accepted. This option can only be used in conjunction with `sftp_factory`. If not specified, scp requests will be passed as regular commands to the `process_factory` or `session_factory`. 
:param window: (optional) The receive window size for sessions on this server :param max_pktsize: (optional) The maximum packet size for sessions on this server :param server_version: (optional) An ASCII string to advertise to SSH clients as the version of this server, defaulting to `'AsyncSSH'` and its version number. :param kex_algs: (optional) A list of allowed key exchange algorithms in the SSH handshake, taken from :ref:`key exchange algorithms `. :param encryption_algs: (optional) A list of encryption algorithms to use during the SSH handshake, taken from :ref:`encryption algorithms `. :param mac_algs: (optional) A list of MAC algorithms to use during the SSH handshake, taken from :ref:`MAC algorithms `. :param compression_algs: (optional) A list of compression algorithms to use during the SSH handshake, taken from :ref:`compression algorithms `, or `None` to disable compression. The server defaults to allowing either no compression or compression after auth, depending on what the client requests. :param signature_algs: (optional) A list of public key signature algorithms to use during the SSH handshake, taken from :ref:`signature algorithms `. :param rekey_bytes: (optional) The number of bytes which can be sent before the SSH session key is renegotiated, defaulting to 1 GB. :param rekey_seconds: (optional) The maximum time in seconds before the SSH session key is renegotiated, defaulting to 1 hour. :param connect_timeout: (optional) The maximum time in seconds allowed to complete an outbound SSH connection. This includes the time to establish the TCP connection and the time to perform the initial SSH protocol handshake, key exchange, and authentication. This is disabled by default, relying on the system's default TCP connect timeout and AsyncSSH's login timeout. :param login_timeout: (optional) The maximum time in seconds allowed for authentication to complete, defaulting to 2 minutes. Setting this to 0 will disable the login timeout. .. note:: This timeout only applies after the SSH TCP connection is established. To set a timeout which includes establishing the TCP connection, use the `connect_timeout` argument above. :param keepalive_interval: (optional) The time in seconds to wait before sending a keepalive message if no data has been received from the client. This defaults to 0, which disables sending these messages. :param keepalive_count_max: (optional) The maximum number of keepalive messages which will be sent without getting a response before disconnecting a client. This defaults to 3, but only applies when keepalive_interval is non-zero. :param tcp_keepalive: (optional) Whether or not to enable keepalive probes at the TCP level to detect broken connections, defaulting to `True`. :param canonicalize_hostname: (optional) Whether or not to enable hostname canonicalization, defaulting to `False`, in which case hostnames are passed as-is to the system resolver. If set to `True`, requests that don't involve a proxy tunnel or command will attempt to canonicalize the hostname using canonical_domains and rules in canonicalize_permitted_cnames. If set to `'always'`, hostname canonicalization is also applied to proxied requests. :param canonical_domains: (optional) When canonicalize_hostname is set, this specifies a list of domain suffixes in which to search for the hostname.
:param canonicalize_fallback_local: (optional) Whether or not to fall back to looking up the hostname against the system resolver's search domains when no matches are found in canonical_domains, defaulting to `True`. :param canonicalize_max_dots: (optional) The maximum number of dots which can appear in a hostname before hostname canonicalization is disabled, defaulting to 1. Hostnames with more than this number of dots are treated as already being fully qualified and passed as-is to the system resolver. :param config: (optional) Paths to OpenSSH server configuration files to load. This configuration will be used as a fallback to override the defaults for settings which are not explicitly specified using AsyncSSH's configuration options. .. note:: Specifying configuration files when creating an :class:`SSHServerConnectionOptions` object will cause the config file to be read and parsed at the time of creation of the object, including evaluation of any conditional blocks. If you want the config to be parsed for every new connection, this argument should be added to the connect or listen calls instead. However, if you want to save the parsing overhead and your configuration doesn't depend on conditions that would change between calls, this argument may be an option. :param options: (optional) A previous set of options to use as the base to incrementally build up a configuration. When an option is not explicitly specified, its value will be pulled from this options object (if present) before falling back to the default value. :type server_factory: `callable` returning :class:`SSHServer` :type proxy_command: `str` or `list` of `str` :type family: `socket.AF_UNSPEC`, `socket.AF_INET`, or `socket.AF_INET6` :type server_host_keys: *see* :ref:`SpecifyingPrivateKeys` :type server_host_certs: *see* :ref:`SpecifyingCertificates` :type send_server_host_keys: `bool` :type passphrase: `str` or `bytes` :type known_client_hosts: *see* :ref:`SpecifyingKnownHosts` :type trust_client_host: `bool` :type authorized_client_keys: *see* :ref:`SpecifyingAuthorizedKeys` :type x509_trusted_certs: *see* :ref:`SpecifyingCertificates` :type x509_trusted_cert_paths: `list` of `str` :type x509_purposes: *see* :ref:`SpecifyingX509Purposes` :type host_based_auth: `bool` :type public_key_auth: `bool` :type kbdint_auth: `bool` :type password_auth: `bool` :type gss_host: `str` :type gss_store: `str`, `bytes`, or a `dict` with `str` or `bytes` keys and values :type gss_kex: `bool` :type gss_auth: `bool` :type allow_pty: `bool` :type line_editor: `bool` :type line_echo: `bool` :type line_history: `int` :type max_line_length: `int` :type rdns_lookup: `bool` :type x11_forwarding: `bool` :type x11_auth_path: `str` :type agent_forwarding: `bool` :type process_factory: `callable` or coroutine :type session_factory: `callable` or coroutine :type encoding: `str` or `None` :type errors: `str` :type sftp_factory: `callable` :type sftp_version: `int` :type allow_scp: `bool` :type window: `int` :type max_pktsize: `int` :type server_version: `str` :type kex_algs: `str` or `list` of `str` :type encryption_algs: `str` or `list` of `str` :type mac_algs: `str` or `list` of `str` :type compression_algs: `str` or `list` of `str` :type signature_algs: `str` or `list` of `str` :type rekey_bytes: *see* :ref:`SpecifyingByteCounts` :type rekey_seconds: *see* :ref:`SpecifyingTimeIntervals` :type connect_timeout: *see* :ref:`SpecifyingTimeIntervals` :type login_timeout: *see* :ref:`SpecifyingTimeIntervals` :type keepalive_interval: *see*
:ref:`SpecifyingTimeIntervals` :type keepalive_count_max: `int` :type tcp_keepalive: `bool` :type canonicalize_hostname: `bool` or `'always'` :type canonical_domains: `list` of `str` :type canonicalize_fallback_local: `bool` :type canonicalize_max_dots: `int` :type canonicalize_permitted_cnames: `list` of `tuple` of 2 `str` values :type config: `list` of `str` :type options: :class:`SSHServerConnectionOptions` """ config: SSHServerConfig server_factory: _ServerFactory server_version: bytes server_host_keys: 'OrderedDict[bytes, SSHKeyPair]' all_server_host_keys: 'OrderedDict[bytes, SSHKeyPair]' send_server_host_keys: bool known_client_hosts: KnownHostsArg trust_client_host: bool authorized_client_keys: DefTuple[Optional[SSHAuthorizedKeys]] gss_host: Optional[str] gss_store: Optional[Dict[BytesOrStr, BytesOrStr]] gss_kex: bool gss_auth: bool allow_pty: bool line_editor: bool line_echo: bool line_history: int max_line_length: int rdns_lookup: bool x11_forwarding: bool x11_auth_path: Optional[str] agent_forwarding: bool process_factory: Optional[SSHServerProcessFactory] session_factory: Optional[SSHServerSessionFactory] encoding: Optional[str] errors: str sftp_factory: Optional[SFTPServerFactory] sftp_version: int allow_scp: bool window: int max_pktsize: int # pylint: disable=arguments-differ def prepare(self, # type: ignore loop: Optional[asyncio.AbstractEventLoop] = None, last_config: Optional[SSHConfig] = None, config: DefTuple[ConfigPaths] = None, reload: bool = False, canonical: bool = False, final: bool = False, accept_addr: str = '', accept_port: int = 0, username: str = '', client_host: str = '', client_addr: str = '', server_factory: Optional[_ServerFactory] = None, server_version: _VersionArg = (), host: str = '', port: DefTuple[int] = (), tunnel: object = (), proxy_command: DefTuple[_ProxyCommand] = (), family: DefTuple[int] = (), local_addr: DefTuple[HostPort] = (), tcp_keepalive: DefTuple[bool] = (), canonicalize_hostname: DefTuple[Union[bool, str]] = (), canonical_domains: DefTuple[Sequence[str]] = (), canonicalize_fallback_local: DefTuple[bool] = (), canonicalize_max_dots: DefTuple[int] = (), canonicalize_permitted_cnames: DefTuple[Sequence[str]] = (), kex_algs: _AlgsArg = (), encryption_algs: _AlgsArg = (), mac_algs: _AlgsArg = (), compression_algs: _AlgsArg = (), signature_algs: _AlgsArg = (), host_based_auth: _AuthArg = (), public_key_auth: _AuthArg = (), kbdint_auth: _AuthArg = (), password_auth: _AuthArg = (), x509_trusted_certs: CertListArg = (), x509_trusted_cert_paths: Sequence[FilePath] = (), x509_purposes: X509CertPurposes = 'secureShellClient', rekey_bytes: DefTuple[Union[int, str]] = (), rekey_seconds: DefTuple[Union[float, str]] = (), connect_timeout: Optional[Union[float, str]] = None, login_timeout: DefTuple[Union[float, str]] = (), keepalive_interval: DefTuple[Union[float, str]] = (), keepalive_count_max: DefTuple[int] = (), server_host_keys: KeyPairListArg = (), server_host_certs: CertListArg = (), send_server_host_keys: bool = False, passphrase: Optional[BytesOrStr] = None, known_client_hosts: KnownHostsArg = None, trust_client_host: bool = False, authorized_client_keys: _AuthKeysArg = (), gss_host: DefTuple[Optional[str]] = (), gss_store: Optional[Union[BytesOrStr, BytesOrStrDict]] = None, gss_kex: DefTuple[bool] = (), gss_auth: DefTuple[bool] = (), allow_pty: DefTuple[bool] = (), line_editor: bool = True, line_echo: bool = True, line_history: int = _DEFAULT_LINE_HISTORY, max_line_length: int = _DEFAULT_MAX_LINE_LENGTH, rdns_lookup: DefTuple[bool] = (), 
x11_forwarding: bool = False, x11_auth_path: Optional[str] = None, agent_forwarding: DefTuple[bool] = (), process_factory: Optional[SSHServerProcessFactory] = None, session_factory: Optional[SSHServerSessionFactory] = None, encoding: Optional[str] = 'utf-8', errors: str = 'strict', sftp_factory: Optional[SFTPServerFactory] = None, sftp_version: int = MIN_SFTP_VERSION, allow_scp: bool = False, window: int = _DEFAULT_WINDOW, max_pktsize: int = _DEFAULT_MAX_PKTSIZE) -> None: """Prepare server connection configuration options""" config = SSHServerConfig.load(last_config, config, reload, canonical, final, accept_addr, accept_port, username, client_host, client_addr) if login_timeout == (): login_timeout = \ cast(Union[float, str], config.get('LoginGraceTime', _DEFAULT_LOGIN_TIMEOUT)) login_timeout: Union[float, str] if keepalive_interval == (): keepalive_interval = \ cast(Union[float, str], config.get('ClientAliveInterval', _DEFAULT_KEEPALIVE_INTERVAL)) keepalive_interval: Union[float, str] if keepalive_count_max == (): keepalive_count_max = \ cast(int, config.get('ClientAliveCountMax', _DEFAULT_KEEPALIVE_COUNT_MAX)) keepalive_count_max: int super().prepare(config, server_factory or SSHServer, server_version, host, port, tunnel, passphrase, proxy_command, family, local_addr, tcp_keepalive, canonicalize_hostname, canonical_domains, canonicalize_fallback_local, canonicalize_max_dots, canonicalize_permitted_cnames, kex_algs, encryption_algs, mac_algs, compression_algs, signature_algs, host_based_auth, public_key_auth, kbdint_auth, password_auth, x509_trusted_certs, x509_trusted_cert_paths, x509_purposes, rekey_bytes, rekey_seconds, connect_timeout, login_timeout, keepalive_interval, keepalive_count_max) if server_host_keys == (): server_host_keys = cast(Sequence[str], config.get('HostKey')) if server_host_certs == (): server_host_certs = cast(Sequence[str], config.get('HostCertificate', ())) server_keys = load_keypairs(server_host_keys, passphrase, server_host_certs, loop=loop) self.server_host_keys = OrderedDict() self.all_server_host_keys = OrderedDict() for keypair in server_keys: for alg in keypair.host_key_algorithms: if alg in self.server_host_keys and not send_server_host_keys: raise ValueError('Multiple keys of type ' f'{alg.decode("ascii")} found: ' 'Enable send_server_host_keys to ' 'allow reserved keys to be configured') if alg not in self.server_host_keys: self.server_host_keys[alg] = keypair if send_server_host_keys: self.all_server_host_keys[keypair.public_data] = keypair self.known_client_hosts = known_client_hosts self.trust_client_host = trust_client_host if authorized_client_keys == () and reload: authorized_client_keys = \ cast(List[str], config.get('AuthorizedKeysFile')) if isinstance(authorized_client_keys, (str, list)): self.authorized_client_keys = \ read_authorized_keys(authorized_client_keys) else: self.authorized_client_keys = authorized_client_keys if gss_host == (): gss_host = socket.gethostname() if '.' 
not in gss_host: gss_host = socket.getfqdn() gss_host: Optional[str] self.gss_host = gss_host if isinstance(gss_store, (bytes, str)): self.gss_store = {'ccache': gss_store} else: self.gss_store = gss_store self.gss_kex = cast(bool, gss_kex if gss_kex != () else config.get('GSSAPIKeyExchange', True)) self.gss_auth = cast(bool, gss_auth if gss_auth != () else config.get('GSSAPIAuthentication', True)) if not server_keys and not gss_host: raise ValueError('No server host keys provided') self.allow_pty = cast(bool, allow_pty if allow_pty != () else config.get('PermitTTY', True)) self.line_editor = line_editor self.line_echo = line_echo self.line_history = line_history self.max_line_length = max_line_length self.rdns_lookup = cast(bool, rdns_lookup if rdns_lookup != () else config.get('UseDNS', False)) self.x11_forwarding = x11_forwarding self.x11_auth_path = x11_auth_path self.agent_forwarding = cast(bool, agent_forwarding if agent_forwarding != () else config.get('AllowAgentForwarding', True)) self.process_factory = process_factory self.session_factory = session_factory self.encoding = encoding self.errors = errors self.sftp_factory = SFTPServer if sftp_factory is True else sftp_factory self.sftp_version = sftp_version self.allow_scp = allow_scp self.window = window self.max_pktsize = max_pktsize @async_context_manager async def run_client(sock: socket.socket, config: DefTuple[ConfigPaths] = (), options: Optional[SSHClientConnectionOptions] = None, **kwargs: object) -> SSHClientConnection: """Start an SSH client connection on an already-connected socket This function is a coroutine which starts an SSH client on an existing already-connected socket. It can be used instead of :func:`connect` when a socket is connected outside of asyncio. :param sock: An existing already-connected socket to run an SSH client on, instead of opening up a new connection. :param config: (optional) Paths to OpenSSH client configuration files to load. This configuration will be used as a fallback to override the defaults for settings which are not explicitly specified using AsyncSSH's configuration options. If no paths are specified and no config paths were set when constructing the `options` argument (if any), an attempt will be made to load the configuration from the file :file:`.ssh/config`. If this argument is explicitly set to `None`, no new configuration files will be loaded, but any configuration loaded when constructing the `options` argument will still apply. See :ref:`SupportedClientConfigOptions` for details on what configuration options are currently supported. :param options: (optional) Options to use when establishing the SSH client connection. These options can be specified either through this parameter or as direct keyword arguments to this function. 
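As an illustrative sketch only (the host name and port below are placeholders and error handling is omitted), a socket connected outside of asyncio could be handed to this function like this::

    import asyncio
    import socket

    import asyncssh

    async def main() -> None:
        # Placeholder address -- the socket is connected outside of asyncio
        sock = socket.create_connection(('host.example.com', 22))

        async with asyncssh.run_client(sock) as conn:
            result = await conn.run('echo hello')
            print(result.stdout, end='')

    asyncio.run(main())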
:type sock: :class:`socket.socket` :type config: `list` of `str` :type options: :class:`SSHClientConnectionOptions` :returns: :class:`SSHClientConnection` """ def conn_factory() -> SSHClientConnection: """Return an SSH client connection factory""" return SSHClientConnection(loop, new_options, wait='auth') loop = asyncio.get_event_loop() new_options = await SSHClientConnectionOptions.construct( options, config=config, **kwargs) return await asyncio.wait_for( _connect(new_options, config, loop, 0, sock, conn_factory, 'Starting SSH client on'), timeout=new_options.connect_timeout) @async_context_manager async def run_server(sock: socket.socket, config: DefTuple[ConfigPaths] = (), options: Optional[SSHServerConnectionOptions] = None, **kwargs: object) -> SSHServerConnection: """Start an SSH server connection on an already-connected socket This function is a coroutine which starts an SSH server on an existing already-connected TCP socket. It can be used instead of :func:`listen` when connections are accepted outside of asyncio. :param sock: An existing already-connected socket to run SSH over, instead of opening up a new connection. :param config: (optional) Paths to OpenSSH server configuration files to load. This configuration will be used as a fallback to override the defaults for settings which are not explicitly specified using AsyncSSH's configuration options. By default, no OpenSSH configuration files will be loaded. See :ref:`SupportedServerConfigOptions` for details on what configuration options are currently supported. :param options: (optional) Options to use when starting the reverse-direction SSH server. These options can be specified either through this parameter or as direct keyword arguments to this function. :type sock: :class:`socket.socket` :type config: `list` of `str` :type options: :class:`SSHServerConnectionOptions` :returns: :class:`SSHServerConnection` """ def conn_factory() -> SSHServerConnection: """Return an SSH server connection factory""" return SSHServerConnection(loop, new_options, wait='auth') loop = asyncio.get_event_loop() new_options = await SSHServerConnectionOptions.construct( options, config=config, **kwargs) return await asyncio.wait_for( _connect(new_options, config, loop, 0, sock, conn_factory, 'Starting SSH server on'), timeout=new_options.connect_timeout) @async_context_manager async def connect(host = '', port: DefTuple[int] = (), *, tunnel: DefTuple[_TunnelConnector] = (), family: DefTuple[int] = (), flags: int = 0, local_addr: DefTuple[HostPort] = (), sock: Optional[socket.socket] = None, config: DefTuple[ConfigPaths] = (), options: Optional[SSHClientConnectionOptions] = None, **kwargs: object) -> SSHClientConnection: """Make an SSH client connection This function is a coroutine which can be run to create an outbound SSH client connection to the specified host and port. When successful, the following steps occur: 1. The connection is established and an instance of :class:`SSHClientConnection` is created to represent it. 2. The `client_factory` is called without arguments and should return an instance of :class:`SSHClient` or a subclass. 3. The client object is tied to the connection and its :meth:`connection_made() ` method is called. 4. The SSH handshake and authentication process is initiated, calling methods on the client object if needed. 5. When authentication completes successfully, the client's :meth:`auth_completed() ` method is called. 6. The coroutine returns the :class:`SSHClientConnection`. 
At this point, the connection is ready for sessions to be opened or port forwarding to be set up. If an error occurs, it will be raised as an exception and the partially open connection and client objects will be cleaned up. :param host: (optional) The hostname or address to connect to. :param port: (optional) The port number to connect to. If not specified, the default SSH port is used. :param tunnel: (optional) An existing SSH client connection that this new connection should be tunneled over. If set, a direct TCP/IP tunnel will be opened over this connection to the requested host and port rather than connecting directly via TCP. A string of the form [user@]host[:port] may also be specified, in which case a connection will be made to that host and then used as a tunnel. A comma-separated list may also be specified to establish a tunnel through multiple hosts. .. note:: When specifying tunnel as a string, any config options in the call will apply only when opening a connection to the final destination host and port. However, settings to use when opening tunnels may be specified via a configuration file. To get more control of config options used to open the tunnel, :func:`connect` can be called explicitly, and the resulting client connection can be passed as the tunnel argument. :param family: (optional) The address family to use when creating the socket. By default, the address family is automatically selected based on the host. :param flags: (optional) The flags to pass to getaddrinfo() when looking up the host address :param local_addr: (optional) The host and port to bind the socket to before connecting :param sock: (optional) An existing already-connected socket to run SSH over, instead of opening up a new connection. When this is specified, none of host, port family, flags, or local_addr should be specified. :param config: (optional) Paths to OpenSSH client configuration files to load. This configuration will be used as a fallback to override the defaults for settings which are not explicitly specified using AsyncSSH's configuration options. If no paths are specified and no config paths were set when constructing the `options` argument (if any), an attempt will be made to load the configuration from the file :file:`.ssh/config`. If this argument is explicitly set to `None`, no new configuration files will be loaded, but any configuration loaded when constructing the `options` argument will still apply. See :ref:`SupportedClientConfigOptions` for details on what configuration options are currently supported. :param options: (optional) Options to use when establishing the SSH client connection. These options can be specified either through this parameter or as direct keyword arguments to this function. 
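For example, a minimal sketch (the destination host and command below are placeholders; authentication and known-hosts checking rely on AsyncSSH's defaults) might look like this::

    import asyncio

    import asyncssh

    async def main() -> None:
        # Placeholder destination -- authentication uses AsyncSSH's defaults
        async with asyncssh.connect('host.example.com') as conn:
            result = await conn.run('echo hello', check=True)
            print(result.stdout, end='')

    asyncio.run(main())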
:type host: `str` :type port: `int` :type tunnel: :class:`SSHClientConnection` or `str` :type family: `socket.AF_UNSPEC`, `socket.AF_INET`, or `socket.AF_INET6` :type flags: flags to pass to :meth:`getaddrinfo() ` :type local_addr: tuple of `str` and `int` :type sock: :class:`socket.socket` or `None` :type config: `list` of `str` :type options: :class:`SSHClientConnectionOptions` :returns: :class:`SSHClientConnection` """ def conn_factory() -> SSHClientConnection: """Return an SSH client connection factory""" return SSHClientConnection(loop, new_options, wait='auth') loop = asyncio.get_event_loop() new_options = await SSHClientConnectionOptions.construct( options, config=config, host=host, port=port, tunnel=tunnel, family=family, local_addr=local_addr, **kwargs) return await asyncio.wait_for( _connect(new_options, config, loop, flags, sock, conn_factory, 'Opening SSH connection to'), timeout=new_options.connect_timeout) @async_context_manager async def connect_reverse( host = '', port: DefTuple[int] = (), *, tunnel: DefTuple[_TunnelConnector] = (), family: DefTuple[int] = (), flags: int = 0, local_addr: DefTuple[HostPort] = (), sock: Optional[socket.socket] = None, config: DefTuple[ConfigPaths] = (), options: Optional[SSHServerConnectionOptions] = None, **kwargs: object) -> SSHServerConnection: """Create a reverse direction SSH connection This function is a coroutine which behaves similar to :func:`connect`, making an outbound TCP connection to a remote server. However, instead of starting up an SSH client which runs on that outbound connection, this function starts up an SSH server, expecting the remote system to start up a reverse-direction SSH client. Arguments to this function are the same as :func:`connect`, except that the `options` are of type :class:`SSHServerConnectionOptions` instead of :class:`SSHClientConnectionOptions`. :param host: (optional) The hostname or address to connect to. :param port: (optional) The port number to connect to. If not specified, the default SSH port is used. :param tunnel: (optional) An existing SSH client connection that this new connection should be tunneled over. If set, a direct TCP/IP tunnel will be opened over this connection to the requested host and port rather than connecting directly via TCP. A string of the form [user@]host[:port] may also be specified, in which case a connection will be made to that host and then used as a tunnel. A comma-separated list may also be specified to establish a tunnel through multiple hosts. .. note:: When specifying tunnel as a string, any config options in the call will apply only when opening a connection to the final destination host and port. However, settings to use when opening tunnels may be specified via a configuration file. To get more control of config options used to open the tunnel, :func:`connect` can be called explicitly, and the resulting client connection can be passed as the tunnel argument. :param family: (optional) The address family to use when creating the socket. By default, the address family is automatically selected based on the host. :param flags: (optional) The flags to pass to getaddrinfo() when looking up the host address :param local_addr: (optional) The host and port to bind the socket to before connecting :param sock: (optional) An existing already-connected socket to run SSH over, instead of opening up a new connection. When this is specified, none of host, port family, flags, or local_addr should be specified. 
:param config: (optional) Paths to OpenSSH server configuration files to load. This configuration will be used as a fallback to override the defaults for settings which are not explicitly specified using AsyncSSH's configuration options. By default, no OpenSSH configuration files will be loaded. See :ref:`SupportedServerConfigOptions` for details on what configuration options are currently supported. :param options: (optional) Options to use when starting the reverse-direction SSH server. These options can be specified either through this parameter or as direct keyword arguments to this function. :type host: `str` :type port: `int` :type tunnel: :class:`SSHClientConnection` or `str` :type family: `socket.AF_UNSPEC`, `socket.AF_INET`, or `socket.AF_INET6` :type flags: flags to pass to :meth:`getaddrinfo() ` :type local_addr: tuple of `str` and `int` :type sock: :class:`socket.socket` or `None` :type config: `list` of `str` :type options: :class:`SSHServerConnectionOptions` :returns: :class:`SSHServerConnection` """ def conn_factory() -> SSHServerConnection: """Return an SSH server connection factory""" return SSHServerConnection(loop, new_options, wait='auth') loop = asyncio.get_event_loop() new_options = await SSHServerConnectionOptions.construct( options, config=config, host=host, port=port, tunnel=tunnel, family=family, local_addr=local_addr, **kwargs) return await asyncio.wait_for( _connect(new_options, config, loop, flags, sock, conn_factory, 'Opening reverse SSH connection to'), timeout=new_options.connect_timeout) @async_context_manager async def listen(host = '', port: DefTuple[int] = (), *, tunnel: DefTuple[_TunnelListener] = (), family: DefTuple[int] = (), flags:int = socket.AI_PASSIVE, backlog: int = 100, sock: Optional[socket.socket] = None, reuse_address: bool = False, reuse_port: bool = False, acceptor: _AcceptHandler = None, error_handler: _ErrorHandler = None, config: DefTuple[ConfigPaths] = (), options: Optional[SSHServerConnectionOptions] = None, **kwargs: object) -> SSHAcceptor: """Start an SSH server This function is a coroutine which can be run to create an SSH server listening on the specified host and port. The return value is an :class:`SSHAcceptor` which can be used to shut down the listener. :param host: (optional) The hostname or address to listen on. If not specified, listeners are created for all addresses. :param port: (optional) The port number to listen on. If not specified, the default SSH port is used. :param tunnel: (optional) An existing SSH client connection that this new connection should be tunneled over. If set, a direct TCP/IP tunnel will be opened over this connection to the requested host and port rather than connecting directly via TCP. A string of the form [user@]host[:port] may also be specified, in which case a connection will be made to that host and then used as a tunnel. A comma-separated list may also be specified to establish a tunnel through multiple hosts. .. note:: When specifying tunnel as a string, any config options in the call will apply only when opening a connection to the final destination host and port. However, settings to use when opening tunnels may be specified via a configuration file. To get more control of config options used to open the tunnel, :func:`connect` can be called explicitly, and the resulting client connection can be passed as the tunnel argument. :param family: (optional) The address family to use when creating the server. By default, the address families are automatically selected based on the host. 
:param flags: (optional) The flags to pass to getaddrinfo() when looking up the host :param backlog: (optional) The maximum number of queued connections allowed on listeners :param sock: (optional) A pre-existing socket to use instead of creating and binding a new socket. When this is specified, host and port should not be specified. :param reuse_address: (optional) Whether or not to reuse a local socket in the TIME_WAIT state without waiting for its natural timeout to expire. If not specified, this will be automatically set to `True` on UNIX. :param reuse_port: (optional) Whether or not to allow this socket to be bound to the same port other existing sockets are bound to, so long as they all set this flag when being created. If not specified, the default is to not allow this. This option is not supported on Windows or Python versions prior to 3.4.4. :param acceptor: (optional) A `callable` or coroutine which will be called when the SSH handshake completes on an accepted connection, taking the :class:`SSHServerConnection` as an argument. :param error_handler: (optional) A `callable` which will be called whenever the SSH handshake fails on an accepted connection. It is called with the failed :class:`SSHServerConnection` and an exception object describing the failure. If not specified, failed handshakes result in the connection object being silently cleaned up. :param config: (optional) Paths to OpenSSH server configuration files to load. This configuration will be used as a fallback to override the defaults for settings which are not explicitly specified using AsyncSSH's configuration options. By default, no OpenSSH configuration files will be loaded. See :ref:`SupportedServerConfigOptions` for details on what configuration options are currently supported. :param options: (optional) Options to use when accepting SSH server connections. These options can be specified either through this parameter or as direct keyword arguments to this function. 
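For example, a minimal sketch (the port, host key path, and authorized keys path below are placeholders) might look like this::

    import asyncio

    import asyncssh

    async def handle_client(process: asyncssh.SSHServerProcess) -> None:
        # Placeholder session handler called for each shell or exec request
        process.stdout.write('Hello from the example server\n')
        process.exit(0)

    async def main() -> None:
        # 'ssh_host_key' and 'authorized_keys' are placeholder file paths
        await asyncssh.listen('', 8022, server_host_keys=['ssh_host_key'],
                              authorized_client_keys='authorized_keys',
                              process_factory=handle_client)

        await asyncio.Event().wait()    # keep the listener running

    asyncio.run(main())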
:type host: `str` :type port: `int` :type tunnel: :class:`SSHClientConnection` or `str` :type family: `socket.AF_UNSPEC`, `socket.AF_INET`, or `socket.AF_INET6` :type flags: flags to pass to :meth:`getaddrinfo() ` :type backlog: `int` :type sock: :class:`socket.socket` or `None` :type reuse_address: `bool` :type reuse_port: `bool` :type acceptor: `callable` or coroutine :type error_handler: `callable` :type config: `list` of `str` :type options: :class:`SSHServerConnectionOptions` :returns: :class:`SSHAcceptor` """ def conn_factory() -> SSHServerConnection: """Return an SSH server connection factory""" return SSHServerConnection(loop, new_options, acceptor, error_handler) loop = asyncio.get_event_loop() new_options = await SSHServerConnectionOptions.construct( options, config=config, host=host, port=port, tunnel=tunnel, family=family, **kwargs) # pylint: disable=attribute-defined-outside-init new_options.proxy_command = None return await asyncio.wait_for( _listen(new_options, config, loop, flags, backlog, sock, reuse_address, reuse_port, conn_factory, 'Creating SSH listener on'), timeout=new_options.connect_timeout) @async_context_manager async def listen_reverse(host = '', port: DefTuple[int] = (), *, tunnel: DefTuple[_TunnelListener] = (), family: DefTuple[int] = (), flags: int = socket.AI_PASSIVE, backlog: int = 100, sock: Optional[socket.socket] = None, reuse_address: bool = False, reuse_port: bool = False, acceptor: _AcceptHandler = None, error_handler: _ErrorHandler = None, config: DefTuple[ConfigPaths] = (), options: Optional[SSHClientConnectionOptions] = None, **kwargs: object) -> SSHAcceptor: """Create a reverse-direction SSH listener This function is a coroutine which behaves similar to :func:`listen`, creating a listener which accepts inbound connections on the specified host and port. However, instead of starting up an SSH server on each inbound connection, it starts up a reverse-direction SSH client, expecting the remote system making the connection to start up a reverse-direction SSH server. Arguments to this function are the same as :func:`listen`, except that the `options` are of type :class:`SSHClientConnectionOptions` instead of :class:`SSHServerConnectionOptions`. The return value is an :class:`SSHAcceptor` which can be used to shut down the reverse listener. :param host: (optional) The hostname or address to listen on. If not specified, listeners are created for all addresses. :param port: (optional) The port number to listen on. If not specified, the default SSH port is used. :param tunnel: (optional) An existing SSH client connection that this new connection should be tunneled over. If set, a direct TCP/IP tunnel will be opened over this connection to the requested host and port rather than connecting directly via TCP. A string of the form [user@]host[:port] may also be specified, in which case a connection will be made to that host and then used as a tunnel. A comma-separated list may also be specified to establish a tunnel through multiple hosts. .. note:: When specifying tunnel as a string, any config options in the call will apply only when opening a connection to the final destination host and port. However, settings to use when opening tunnels may be specified via a configuration file. To get more control of config options used to open the tunnel, :func:`connect` can be called explicitly, and the resulting client connection can be passed as the tunnel argument. :param family: (optional) The address family to use when creating the server. 
By default, the address families are automatically selected based on the host. :param flags: (optional) The flags to pass to getaddrinfo() when looking up the host :param backlog: (optional) The maximum number of queued connections allowed on listeners :param sock: (optional) A pre-existing socket to use instead of creating and binding a new socket. When this is specified, host and port should not be specified. :param reuse_address: (optional) Whether or not to reuse a local socket in the TIME_WAIT state without waiting for its natural timeout to expire. If not specified, this will be automatically set to `True` on UNIX. :param reuse_port: (optional) Whether or not to allow this socket to be bound to the same port other existing sockets are bound to, so long as they all set this flag when being created. If not specified, the default is to not allow this. This option is not supported on Windows or Python versions prior to 3.4.4. :param acceptor: (optional) A `callable` or coroutine which will be called when the SSH handshake completes on an accepted connection, taking the :class:`SSHClientConnection` as an argument. :param error_handler: (optional) A `callable` which will be called whenever the SSH handshake fails on an accepted connection. It is called with the failed :class:`SSHClientConnection` and an exception object describing the failure. If not specified, failed handshakes result in the connection object being silently cleaned up. :param config: (optional) Paths to OpenSSH client configuration files to load. This configuration will be used as a fallback to override the defaults for settings which are not explicitly specified using AsyncSSH's configuration options. If no paths are specified and no config paths were set when constructing the `options` argument (if any), an attempt will be made to load the configuration from the file :file:`.ssh/config`. If this argument is explicitly set to `None`, no new configuration files will be loaded, but any configuration loaded when constructing the `options` argument will still apply. See :ref:`SupportedClientConfigOptions` for details on what configuration options are currently supported. :param options: (optional) Options to use when starting reverse-direction SSH clients. These options can be specified either through this parameter or as direct keyword arguments to this function.
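For example, a minimal sketch (the port below is a placeholder, and known-hosts checking is disabled purely for illustration) might look like this::

    import asyncio

    import asyncssh

    async def handle_connection(conn: asyncssh.SSHClientConnection) -> None:
        # Placeholder acceptor: run a command once the reverse-direction
        # SSH handshake with the connecting peer completes
        result = await conn.run('echo hello')
        print(result.stdout, end='')

    async def main() -> None:
        # known_hosts=None disables host key checking -- illustration only
        await asyncssh.listen_reverse(port=8022, acceptor=handle_connection,
                                      known_hosts=None)

        await asyncio.Event().wait()    # keep accepting connections

    asyncio.run(main())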
:type host: `str` :type port: `int` :type tunnel: :class:`SSHClientConnection` or `str` :type family: `socket.AF_UNSPEC`, `socket.AF_INET`, or `socket.AF_INET6` :type flags: flags to pass to :meth:`getaddrinfo() ` :type backlog: `int` :type sock: :class:`socket.socket` or `None` :type reuse_address: `bool` :type reuse_port: `bool` :type acceptor: `callable` or coroutine :type error_handler: `callable` :type config: `list` of `str` :type options: :class:`SSHClientConnectionOptions` :returns: :class:`SSHAcceptor` """ def conn_factory() -> SSHClientConnection: """Return an SSH client connection factory""" return SSHClientConnection(loop, new_options, acceptor, error_handler) loop = asyncio.get_event_loop() new_options = await SSHClientConnectionOptions.construct( options, config=config, host=host, port=port, tunnel=tunnel, family=family, **kwargs) # pylint: disable=attribute-defined-outside-init new_options.proxy_command = None return await asyncio.wait_for( _listen(new_options, config, loop, flags, backlog, sock, reuse_address, reuse_port, conn_factory, 'Creating reverse direction SSH listener on'), timeout=new_options.connect_timeout) async def create_connection(client_factory: _ClientFactory, host = '', port: DefTuple[int] = (), **kwargs: object) -> \ Tuple[SSHClientConnection, SSHClient]: """Create an SSH client connection This is a coroutine which wraps around :func:`connect`, providing backward compatibility with older AsyncSSH releases. The only differences are that the `client_factory` argument is the first positional argument in this call rather than being a keyword argument or specified via an :class:`SSHClientConnectionOptions` object and the return value is a tuple of an :class:`SSHClientConnection` and :class:`SSHClient` rather than just the connection, mirroring :meth:`asyncio.AbstractEventLoop.create_connection`. :returns: An :class:`SSHClientConnection` and :class:`SSHClient` """ conn = await connect(host, port, client_factory=client_factory, **kwargs) # type: ignore return conn, cast(SSHClient, conn.get_owner()) @async_context_manager async def create_server(server_factory: _ServerFactory, host = '', port: DefTuple[int] = (), **kwargs: object) -> SSHAcceptor: """Create an SSH server This is a coroutine which wraps around :func:`listen`, providing backward compatibility with older AsyncSSH releases. The only difference is that the `server_factory` argument is the first positional argument in this call rather than being a keyword argument or specified via an :class:`SSHServerConnectionOptions` object, mirroring :meth:`asyncio.AbstractEventLoop.create_server`. """ return await listen(host, port, server_factory=server_factory, **kwargs) # type: ignore async def get_server_host_key( host = '', port: DefTuple[int] = (), *, tunnel: DefTuple[_TunnelConnector] = (), proxy_command: DefTuple[_ProxyCommand] = (), family: DefTuple[int] = (), flags: int = 0, local_addr: DefTuple[HostPort] = (), sock: Optional[socket.socket] = None, client_version: DefTuple[BytesOrStr] = (), kex_algs: _AlgsArg = (), server_host_key_algs: _AlgsArg = (), config: DefTuple[ConfigPaths] = (), options: Optional[SSHClientConnectionOptions] = None) \ -> Optional[SSHKey]: """Retrieve an SSH server's host key This is a coroutine which can be run to connect to an SSH server and return the server host key presented during the SSH handshake. A list of server host key algorithms can be provided to specify which host key types the server is allowed to choose from. 
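A short sketch of typical usage (the host name below is a placeholder) might look like this::

    import asyncio

    import asyncssh

    async def main() -> None:
        # Placeholder host -- print the fingerprint of the key it presents
        key = await asyncssh.get_server_host_key('host.example.com')

        if key:
            print(key.get_fingerprint())

    asyncio.run(main())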
If the key exchange is successful, the server host key sent during the handshake is returned. .. note:: Not all key exchange methods involve the server presenting a host key. If something like GSS key exchange is used without a server host key, this method may return `None` even when the handshake completes. :param host: (optional) The hostname or address to connect to :param port: (optional) The port number to connect to. If not specified, the default SSH port is used. :param tunnel: (optional) An existing SSH client connection that this new connection should be tunneled over. If set, a direct TCP/IP tunnel will be opened over this connection to the requested host and port rather than connecting directly via TCP. A string of the form [user@]host[:port] may also be specified, in which case a connection will be made to that host and then used as a tunnel. A comma-separated list may also be specified to establish a tunnel through multiple hosts. .. note:: When specifying tunnel as a string, any config options in the call will apply only when opening a connection to the final destination host and port. However, settings to use when opening tunnels may be specified via a configuration file. To get more control of config options used to open the tunnel, :func:`connect` can be called explicitly, and the resulting client connection can be passed as the tunnel argument. :param proxy_command: (optional) A string or list of strings specifying a command and arguments to run to make a connection to the SSH server. Data will be forwarded to this process over stdin/stdout instead of opening a TCP connection. If specified as a string, standard shell quoting will be applied when splitting the command and its arguments. :param family: (optional) The address family to use when creating the socket. By default, the address family is automatically selected based on the host. :param flags: (optional) The flags to pass to getaddrinfo() when looking up the host address :param local_addr: (optional) The host and port to bind the socket to before connecting :param sock: (optional) An existing already-connected socket to run SSH over, instead of opening up a new connection. When this is specified, none of host, port family, flags, or local_addr should be specified. :param client_version: (optional) An ASCII string to advertise to the SSH server as the version of this client, defaulting to `'AsyncSSH'` and its version number. :param kex_algs: (optional) A list of allowed key exchange algorithms in the SSH handshake, taken from :ref:`key exchange algorithms ` :param server_host_key_algs: (optional) A list of server host key algorithms to allow during the SSH handshake, taken from :ref:`server host key algorithms `. :param config: (optional) Paths to OpenSSH client configuration files to load. This configuration will be used as a fallback to override the defaults for settings which are not explicitly specified using AsyncSSH's configuration options. If no paths are specified and no config paths were set when constructing the `options` argument (if any), an attempt will be made to load the configuration from the file :file:`.ssh/config`. If this argument is explicitly set to `None`, no new configuration files will be loaded, but any configuration loaded when constructing the `options` argument will still apply. See :ref:`SupportedClientConfigOptions` for details on what configuration options are currently supported. 
:param options: (optional) Options to use when establishing the SSH client connection used to retrieve the server host key. These options can be specified either through this parameter or as direct keyword arguments to this function. :type host: `str` :type port: `int` :type tunnel: :class:`SSHClientConnection` or `str` :type proxy_command: `str` or `list` of `str` :type family: `socket.AF_UNSPEC`, `socket.AF_INET`, or `socket.AF_INET6` :type flags: flags to pass to :meth:`getaddrinfo() ` :type local_addr: tuple of `str` and `int` :type sock: :class:`socket.socket` or `None` :type client_version: `str` :type kex_algs: `str` or `list` of `str` :type server_host_key_algs: `str` or `list` of `str` :type config: `list` of `str` :type options: :class:`SSHClientConnectionOptions` :returns: An :class:`SSHKey` public key or `None` """ def conn_factory() -> SSHClientConnection: """Return an SSH client connection factory""" return SSHClientConnection(loop, new_options, wait='kex') loop = asyncio.get_event_loop() new_options = await SSHClientConnectionOptions.construct( options, config=config, host=host, port=port, tunnel=tunnel, proxy_command=proxy_command, family=family, local_addr=local_addr, known_hosts=None, server_host_key_algs=server_host_key_algs, x509_trusted_certs=None, x509_trusted_cert_paths=None, x509_purposes='any', gss_host=None, kex_algs=kex_algs, client_version=client_version) conn = await asyncio.wait_for( _connect(new_options, config, loop, flags, sock, conn_factory, 'Fetching server host key from'), timeout=new_options.connect_timeout) server_host_key = conn.get_server_host_key() conn.abort() await conn.wait_closed() return server_host_key async def get_server_auth_methods( host = '', port: DefTuple[int] = (), username: DefTuple[str] = (), *, tunnel: DefTuple[_TunnelConnector] = (), proxy_command: DefTuple[_ProxyCommand] = (), family: DefTuple[int] = (), flags: int = 0, local_addr: DefTuple[HostPort] = (), sock: Optional[socket.socket] = None, client_version: DefTuple[BytesOrStr] = (), kex_algs: _AlgsArg = (), server_host_key_algs: _AlgsArg = (), config: DefTuple[ConfigPaths] = (), options: Optional[SSHClientConnectionOptions] = None) -> Sequence[str]: """Retrieve an SSH server's allowed auth methods This is a coroutine which can be run to connect to an SSH server and return the auth methods available to authenticate to it. .. note:: The key exchange with the server must complete successfully before the list of available auth methods can be returned, so be sure to specify any arguments needed to complete the key exchange. Also, auth methods may vary by user, so you may want to specify the specific user you would like to get auth methods for. :param host: (optional) The hostname or address to connect to :param port: (optional) The port number to connect to. If not specified, the default SSH port is used. :param username: (optional) Username to authenticate as on the server. If not specified, the currently logged in user on the local machine will be used. :param tunnel: (optional) An existing SSH client connection that this new connection should be tunneled over. If set, a direct TCP/IP tunnel will be opened over this connection to the requested host and port rather than connecting directly via TCP. A string of the form [user@]host[:port] may also be specified, in which case a connection will be made to that host and then used as a tunnel. A comma-separated list may also be specified to establish a tunnel through multiple hosts. .. 
note:: When specifying tunnel as a string, any config options in the call will apply only when opening a connection to the final destination host and port. However, settings to use when opening tunnels may be specified via a configuration file. To get more control of config options used to open the tunnel, :func:`connect` can be called explicitly, and the resulting client connection can be passed as the tunnel argument. :param proxy_command: (optional) A string or list of strings specifying a command and arguments to run to make a connection to the SSH server. Data will be forwarded to this process over stdin/stdout instead of opening a TCP connection. If specified as a string, standard shell quoting will be applied when splitting the command and its arguments. :param family: (optional) The address family to use when creating the socket. By default, the address family is automatically selected based on the host. :param flags: (optional) The flags to pass to getaddrinfo() when looking up the host address :param local_addr: (optional) The host and port to bind the socket to before connecting :param sock: (optional) An existing already-connected socket to run SSH over, instead of opening up a new connection. When this is specified, none of host, port family, flags, or local_addr should be specified. :param client_version: (optional) An ASCII string to advertise to the SSH server as the version of this client, defaulting to `'AsyncSSH'` and its version number. :param kex_algs: (optional) A list of allowed key exchange algorithms in the SSH handshake, taken from :ref:`key exchange algorithms ` :param server_host_key_algs: (optional) A list of server host key algorithms to allow during the SSH handshake, taken from :ref:`server host key algorithms `. :param config: (optional) Paths to OpenSSH client configuration files to load. This configuration will be used as a fallback to override the defaults for settings which are not explicitly specified using AsyncSSH's configuration options. If no paths are specified and no config paths were set when constructing the `options` argument (if any), an attempt will be made to load the configuration from the file :file:`.ssh/config`. If this argument is explicitly set to `None`, no new configuration files will be loaded, but any configuration loaded when constructing the `options` argument will still apply. See :ref:`SupportedClientConfigOptions` for details on what configuration options are currently supported. :param options: (optional) Options to use when establishing the SSH client connection used to retrieve the server host key. These options can be specified either through this parameter or as direct keyword arguments to this function. 
:type host: `str` :type port: `int` :type tunnel: :class:`SSHClientConnection` or `str` :type proxy_command: `str` or `list` of `str` :type family: `socket.AF_UNSPEC`, `socket.AF_INET`, or `socket.AF_INET6` :type flags: flags to pass to :meth:`getaddrinfo() ` :type local_addr: tuple of `str` and `int` :type sock: :class:`socket.socket` or `None` :type client_version: `str` :type kex_algs: `str` or `list` of `str` :type server_host_key_algs: `str` or `list` of `str` :type config: `list` of `str` :type options: :class:`SSHClientConnectionOptions` :returns: a `list` of `str` """ def conn_factory() -> SSHClientConnection: """Return an SSH client connection factory""" return SSHClientConnection(loop, new_options, wait='auth_methods') loop = asyncio.get_event_loop() new_options = await SSHClientConnectionOptions.construct( options, config=config, host=host, port=port, username=username, tunnel=tunnel, proxy_command=proxy_command, family=family, local_addr=local_addr, known_hosts=None, server_host_key_algs=server_host_key_algs, x509_trusted_certs=None, x509_trusted_cert_paths=None, x509_purposes='any', gss_host=None, kex_algs=kex_algs, client_version=client_version) conn = await asyncio.wait_for( _connect(new_options, config, loop, flags, sock, conn_factory, 'Fetching server auth methods from'), timeout=new_options.connect_timeout) server_auth_methods = conn.get_server_auth_methods() conn.abort() await conn.wait_closed() return server_auth_methods asyncssh-2.20.0/asyncssh/constants.py000066400000000000000000000321111475467777400177010ustar00rootroot00000000000000# Copyright (c) 2013-2021 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation """SSH constants""" # Default language for error messages DEFAULT_LANG = 'en-US' # Default SSH listening port DEFAULT_PORT = 22 # SSH message codes MSG_DISCONNECT = 1 MSG_IGNORE = 2 MSG_UNIMPLEMENTED = 3 MSG_DEBUG = 4 MSG_SERVICE_REQUEST = 5 MSG_SERVICE_ACCEPT = 6 MSG_EXT_INFO = 7 MSG_KEXINIT = 20 MSG_NEWKEYS = 21 MSG_KEX_FIRST = 30 MSG_KEX_LAST = 49 MSG_USERAUTH_REQUEST = 50 MSG_USERAUTH_FAILURE = 51 MSG_USERAUTH_SUCCESS = 52 MSG_USERAUTH_BANNER = 53 MSG_USERAUTH_FIRST = 60 MSG_USERAUTH_LAST = 79 MSG_GLOBAL_REQUEST = 80 MSG_REQUEST_SUCCESS = 81 MSG_REQUEST_FAILURE = 82 MSG_CHANNEL_OPEN = 90 MSG_CHANNEL_OPEN_CONFIRMATION = 91 MSG_CHANNEL_OPEN_FAILURE = 92 MSG_CHANNEL_WINDOW_ADJUST = 93 MSG_CHANNEL_DATA = 94 MSG_CHANNEL_EXTENDED_DATA = 95 MSG_CHANNEL_EOF = 96 MSG_CHANNEL_CLOSE = 97 MSG_CHANNEL_REQUEST = 98 MSG_CHANNEL_SUCCESS = 99 MSG_CHANNEL_FAILURE = 100 # Messages 90-92 are excluded here as they relate to opening a new channel MSG_CHANNEL_FIRST = 93 MSG_CHANNEL_LAST = 127 # SSH disconnect reason codes DISC_HOST_NOT_ALLOWED_TO_CONNECT = 1 DISC_PROTOCOL_ERROR = 2 DISC_KEY_EXCHANGE_FAILED = 3 DISC_RESERVED = 4 DISC_MAC_ERROR = 5 DISC_COMPRESSION_ERROR = 6 DISC_SERVICE_NOT_AVAILABLE = 7 DISC_PROTOCOL_VERSION_NOT_SUPPORTED = 8 DISC_HOST_KEY_NOT_VERIFIABLE = 9 
DISC_CONNECTION_LOST = 10 DISC_BY_APPLICATION = 11 DISC_TOO_MANY_CONNECTIONS = 12 DISC_AUTH_CANCELLED_BY_USER = 13 DISC_NO_MORE_AUTH_METHODS_AVAILABLE = 14 DISC_ILLEGAL_USER_NAME = 15 DISC_HOST_KEY_NOT_VERIFYABLE = 9 # Error in naming, left here to not # break backward compatibility # SSH channel open failure reason codes OPEN_ADMINISTRATIVELY_PROHIBITED = 1 OPEN_CONNECT_FAILED = 2 OPEN_UNKNOWN_CHANNEL_TYPE = 3 OPEN_RESOURCE_SHORTAGE = 4 # Internal failure reason codes OPEN_REQUEST_X11_FORWARDING_FAILED = 0xfffffffd OPEN_REQUEST_PTY_FAILED = 0xfffffffe OPEN_REQUEST_SESSION_FAILED = 0xffffffff # SFTPv3-v5 packet types FXP_INIT = 1 FXP_VERSION = 2 FXP_OPEN = 3 FXP_CLOSE = 4 FXP_READ = 5 FXP_WRITE = 6 FXP_LSTAT = 7 FXP_FSTAT = 8 FXP_SETSTAT = 9 FXP_FSETSTAT = 10 FXP_OPENDIR = 11 FXP_READDIR = 12 FXP_REMOVE = 13 FXP_MKDIR = 14 FXP_RMDIR = 15 FXP_REALPATH = 16 FXP_STAT = 17 FXP_RENAME = 18 FXP_READLINK = 19 FXP_SYMLINK = 20 FXP_STATUS = 101 FXP_HANDLE = 102 FXP_DATA = 103 FXP_NAME = 104 FXP_ATTRS = 105 FXP_EXTENDED = 200 FXP_EXTENDED_REPLY = 201 # SFTPv6 packet types FXP_LINK = 21 FXP_BLOCK = 22 FXP_UNBLOCK = 23 # SFTPv3 open flags FXF_READ = 0x00000001 FXF_WRITE = 0x00000002 FXF_APPEND = 0x00000004 FXF_CREAT = 0x00000008 FXF_TRUNC = 0x00000010 FXF_EXCL = 0x00000020 # SFTPv4 open flags FXF_TEXT = 0x00000040 # SFTPv5 open flags FXF_ACCESS_DISPOSITION = 0x00000007 FXF_CREATE_NEW = 0x00000000 FXF_CREATE_TRUNCATE = 0x00000001 FXF_OPEN_EXISTING = 0x00000002 FXF_OPEN_OR_CREATE = 0x00000003 FXF_TRUNCATE_EXISTING = 0x00000004 FXF_APPEND_DATA = 0x00000008 FXF_APPEND_DATA_ATOMIC = 0x00000010 FXF_TEXT_MODE = 0x00000020 FXF_BLOCK_READ = 0x00000040 FXF_BLOCK_WRITE = 0x00000080 FXF_BLOCK_DELETE = 0x00000100 # SFTPv6 open flags FXF_BLOCK_ADVISORY = 0x00000200 FXF_NOFOLLOW = 0x00000400 FXF_DELETE_ON_CLOSE = 0x00000800 FXF_ACCESS_AUDIT_ALARM_INFO = 0x00001000 FXF_ACCESS_BACKUP = 0x00002000 FXF_BACKUP_STREAM = 0x00004000 FXF_OVERRIDE_OWNER = 0x00008000 # SFTPv5-v6 ACE mask values used in desired-access ACE4_READ_DATA = 0x00000001 ACE4_WRITE_DATA = 0x00000002 ACE4_APPEND_DATA = 0x00000004 ACE4_READ_ATTRIBUTES = 0x00000080 ACE4_WRITE_ATTRIBUTES = 0x00000100 # SFTPv3 attribute flags FILEXFER_ATTR_SIZE = 0x00000001 FILEXFER_ATTR_UIDGID = 0x00000002 FILEXFER_ATTR_PERMISSIONS = 0x00000004 FILEXFER_ATTR_ACMODTIME = 0x00000008 FILEXFER_ATTR_EXTENDED = 0x80000000 FILEXFER_ATTR_DEFINED_V3 = 0x8000000f # SFTPv4 attribute flags FILEXFER_ATTR_ACCESSTIME = 0x00000008 FILEXFER_ATTR_CREATETIME = 0x00000010 FILEXFER_ATTR_MODIFYTIME = 0x00000020 FILEXFER_ATTR_ACL = 0x00000040 FILEXFER_ATTR_OWNERGROUP = 0x00000080 FILEXFER_ATTR_SUBSECOND_TIMES = 0x00000100 FILEXFER_ATTR_DEFINED_V4 = 0x800001fd # SFTPv5 attribute flags FILEXFER_ATTR_BITS = 0x00000200 FILEXFER_ATTR_DEFINED_V5 = 0x800003fd # SFTPv6 attribute flags FILEXFER_ATTR_ALLOCATION_SIZE = 0x00000400 FILEXFER_ATTR_TEXT_HINT = 0x00000800 FILEXFER_ATTR_MIME_TYPE = 0x00001000 FILEXFER_ATTR_LINK_COUNT = 0x00002000 FILEXFER_ATTR_UNTRANSLATED_NAME = 0x00004000 FILEXFER_ATTR_CTIME = 0x00008000 FILEXFER_ATTR_DEFINED_V6 = 0x8000fffd # SFTPv4 file types FILEXFER_TYPE_REGULAR = 1 FILEXFER_TYPE_DIRECTORY = 2 FILEXFER_TYPE_SYMLINK = 3 FILEXFER_TYPE_SPECIAL = 4 FILEXFER_TYPE_UNKNOWN = 5 # SFTPv5 file types FILEXFER_TYPE_SOCKET = 6 FILEXFER_TYPE_CHAR_DEVICE = 7 FILEXFER_TYPE_BLOCK_DEVICE = 8 FILEXFER_TYPE_FIFO = 9 # SFTPv5 attrib bits FILEXFER_ATTR_BITS_READONLY = 0x00000001 FILEXFER_ATTR_BITS_SYSTEM = 0x00000002 FILEXFER_ATTR_BITS_HIDDEN = 0x00000004 FILEXFER_ATTR_BITS_CASE_INSENSITIVE = 
0x00000008 FILEXFER_ATTR_BITS_ARCHIVE = 0x00000010 FILEXFER_ATTR_BITS_ENCRYPTED = 0x00000020 FILEXFER_ATTR_BITS_COMPRESSED = 0x00000040 FILEXFER_ATTR_BITS_SPARSE = 0x00000080 FILEXFER_ATTR_BITS_APPEND_ONLY = 0x00000100 FILEXFER_ATTR_BITS_IMMUTABLE = 0x00000200 FILEXFER_ATTR_BITS_SYNC = 0x00000400 # SFTPv6 attrib bits FILEXFER_ATTR_BITS_TRANSLATION_ERR = 0x00000800 # SFTPv6 text hint flags FILEXFER_ATTR_KNOWN_TEXT = 0 FILEXFER_ATTR_GUESSED_TEXT = 1 FILEXFER_ATTR_KNOWN_BINARY = 2 FILEXFER_ATTR_GUESSED_BINARY = 3 # SFTPv5 rename flags FXR_OVERWRITE = 0x00000001 FXR_ATOMIC = 0x00000002 FXR_NATIVE = 0x00000004 # SFTPv6 realpath control byte FXRP_NO_CHECK = 1 FXRP_STAT_IF_EXISTS = 2 FXRP_STAT_ALWAYS = 3 # OpenSSH statvfs attribute flags FXE_STATVFS_ST_RDONLY = 0x1 FXE_STATVFS_ST_NOSUID = 0x2 # SFTPv3 error codes FX_OK = 0 FX_EOF = 1 FX_NO_SUCH_FILE = 2 FX_PERMISSION_DENIED = 3 FX_FAILURE = 4 FX_BAD_MESSAGE = 5 FX_NO_CONNECTION = 6 FX_CONNECTION_LOST = 7 FX_OP_UNSUPPORTED = 8 FX_V3_END = FX_OP_UNSUPPORTED # SFTPv4 error codes FX_INVALID_HANDLE = 9 FX_NO_SUCH_PATH = 10 FX_FILE_ALREADY_EXISTS = 11 FX_WRITE_PROTECT = 12 FX_NO_MEDIA = 13 FX_V4_END = FX_NO_MEDIA # SFTPv5 error codes FX_NO_SPACE_ON_FILESYSTEM = 14 FX_QUOTA_EXCEEDED = 15 FX_UNKNOWN_PRINCIPAL = 16 FX_LOCK_CONFLICT = 17 FX_V5_END = FX_LOCK_CONFLICT # SFTPv6 error codes FX_DIR_NOT_EMPTY = 18 FX_NOT_A_DIRECTORY = 19 FX_INVALID_FILENAME = 20 FX_LINK_LOOP = 21 FX_CANNOT_DELETE = 22 FX_INVALID_PARAMETER = 23 FX_FILE_IS_A_DIRECTORY = 24 FX_BYTE_RANGE_LOCK_CONFLICT = 25 FX_BYTE_RANGE_LOCK_REFUSED = 26 FX_DELETE_PENDING = 27 FX_FILE_CORRUPT = 28 FX_OWNER_INVALID = 29 FX_GROUP_INVALID = 30 FX_NO_MATCHING_BYTE_RANGE_LOCK = 31 FX_V6_END = FX_NO_MATCHING_BYTE_RANGE_LOCK # SSH channel data type codes EXTENDED_DATA_STDERR = 1 # SSH pty mode opcodes PTY_OP_END = 0 PTY_VINTR = 1 PTY_VQUIT = 2 PTY_VERASE = 3 PTY_VKILL = 4 PTY_VEOF = 5 PTY_VEOL = 6 PTY_VEOL2 = 7 PTY_VSTART = 8 PTY_VSTOP = 9 PTY_VSUSP = 10 PTY_VDSUSP = 11 PTY_VREPRINT = 12 PTY_WERASE = 13 PTY_VLNEXT = 14 PTY_VFLUSH = 15 PTY_VSWTCH = 16 PTY_VSTATUS = 17 PTY_VDISCARD = 18 PTY_IGNPAR = 30 PTY_PARMRK = 31 PTY_INPCK = 32 PTY_ISTRIP = 33 PTY_INLCR = 34 PTY_IGNCR = 35 PTY_ICRNL = 36 PTY_IUCLC = 37 PTY_IXON = 38 PTY_IXANY = 39 PTY_IXOFF = 40 PTY_IMAXBEL = 41 PTY_IUTF8 = 42 PTY_ISIG = 50 PTY_ICANON = 51 PTY_XCASE = 52 PTY_ECHO = 53 PTY_ECHOE = 54 PTY_ECHOK = 55 PTY_ECHONL = 56 PTY_NOFLSH = 57 PTY_TOSTOP = 58 PTY_IEXTEN = 59 PTY_ECHOCTL = 60 PTY_ECHOKE = 61 PTY_PENDIN = 62 PTY_OPOST = 70 PTY_OLCUC = 71 PTY_ONLCR = 72 PTY_OCRNL = 73 PTY_ONOCR = 74 PTY_ONLRET = 75 PTY_CS7 = 90 PTY_CS8 = 91 PTY_PARENB = 92 PTY_PARODD = 93 PTY_OP_ISPEED = 128 PTY_OP_OSPEED = 129 PTY_OP_RESERVED = 160 asyncssh-2.20.0/asyncssh/crypto/000077500000000000000000000000001475467777400166355ustar00rootroot00000000000000asyncssh-2.20.0/asyncssh/crypto/__init__.py000066400000000000000000000051061475467777400207500ustar00rootroot00000000000000# Copyright (c) 2014-2021 by Ron Frederick and others. 
# # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation """A shim for accessing cryptographic primitives needed by asyncssh""" from .cipher import BasicCipher, GCMCipher, register_cipher, get_cipher_params from .dsa import DSAPrivateKey, DSAPublicKey from .dh import DH from .ec import ECDSAPrivateKey, ECDSAPublicKey, ECDH from .ed import ed25519_available, ed448_available from .ed import curve25519_available, curve448_available from .ed import EdDSAPrivateKey, EdDSAPublicKey, Curve25519DH, Curve448DH from .ec_params import lookup_ec_curve_by_params from .kdf import pbkdf2_hmac from .misc import CryptoKey, PyCAKey from .rsa import RSAPrivateKey, RSAPublicKey from .pq import mlkem_available, sntrup_available, PQDH # Import chacha20-poly1305 cipher if available from .chacha import ChachaCipher, chacha_available # Import umac cryptographic hash if available try: from .umac import umac32, umac64, umac96, umac128 except (ImportError, AttributeError, OSError): # pragma: no cover pass # Import X.509 certificate support if available try: from .x509 import X509Certificate, X509Name, X509NamePattern from .x509 import generate_x509_certificate, import_x509_certificate except (ImportError, AttributeError): # pragma: no cover pass __all__ = [ 'BasicCipher', 'ChachaCipher', 'CryptoKey', 'Curve25519DH', 'Curve448DH', 'DH', 'DSAPrivateKey', 'DSAPublicKey', 'ECDH', 'ECDSAPrivateKey', 'ECDSAPublicKey', 'EdDSAPrivateKey', 'EdDSAPublicKey', 'GCMCipher', 'PQDH', 'PyCAKey', 'RSAPrivateKey', 'RSAPublicKey', 'chacha_available', 'curve25519_available', 'curve448_available', 'X509Certificate', 'X509Name', 'X509NamePattern', 'ed25519_available', 'ed448_available', 'generate_x509_certificate', 'get_cipher_params', 'import_x509_certificate', 'lookup_ec_curve_by_params', 'mlkem_available', 'pbkdf2_hmac', 'register_cipher', 'sntrup_available', 'umac32', 'umac64', 'umac96', 'umac128' ] asyncssh-2.20.0/asyncssh/crypto/chacha.py000066400000000000000000000127051475467777400204230ustar00rootroot00000000000000# Copyright (c) 2015-2021 by Ron Frederick and others. 
# # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation """Chacha20-Poly1305 symmetric encryption handler""" from ctypes import c_ulonglong, create_string_buffer from typing import Optional, Tuple from cryptography.exceptions import InvalidSignature from cryptography.hazmat.backends.openssl import backend from cryptography.hazmat.primitives.ciphers import Cipher from cryptography.hazmat.primitives.ciphers.algorithms import ChaCha20 from cryptography.hazmat.primitives.poly1305 import Poly1305 from .cipher import register_cipher if backend.poly1305_supported(): _CTR_0 = (0).to_bytes(8, 'little') _CTR_1 = (1).to_bytes(8, 'little') _POLY1305_KEYBYTES = 32 def chacha20(key: bytes, data: bytes, nonce: bytes, ctr: int) -> bytes: """Encrypt/decrypt a block of data with the ChaCha20 cipher""" return Cipher(ChaCha20(key, (_CTR_1 if ctr else _CTR_0) + nonce), mode=None).encryptor().update(data) def poly1305_key(key: bytes, nonce: bytes) -> bytes: """Derive a Poly1305 key""" return chacha20(key, _POLY1305_KEYBYTES * b'\0', nonce, 0) def poly1305(key: bytes, data: bytes, nonce: bytes) -> bytes: """Compute a Poly1305 tag for a block of data""" return Poly1305.generate_tag(poly1305_key(key, nonce), data) def poly1305_verify(key: bytes, data: bytes, nonce: bytes, tag: bytes) -> bool: """Verify a Poly1305 tag for a block of data""" try: Poly1305.verify_tag(poly1305_key(key, nonce), data, tag) return True except InvalidSignature: return False chacha_available = True else: # pragma: no cover try: from libnacl import nacl _chacha20 = nacl.crypto_stream_chacha20 _chacha20_xor_ic = nacl.crypto_stream_chacha20_xor_ic _POLY1305_BYTES = nacl.crypto_onetimeauth_poly1305_bytes() _POLY1305_KEYBYTES = nacl.crypto_onetimeauth_poly1305_keybytes() _poly1305 = nacl.crypto_onetimeauth_poly1305 _poly1305_verify = nacl.crypto_onetimeauth_poly1305_verify def chacha20(key: bytes, data: bytes, nonce: bytes, ctr: int) -> bytes: """Encrypt/decrypt a block of data with the ChaCha20 cipher""" datalen = len(data) result = create_string_buffer(datalen) ull_datalen = c_ulonglong(datalen) ull_ctr = c_ulonglong(ctr) _chacha20_xor_ic(result, data, ull_datalen, nonce, ull_ctr, key) return result.raw def poly1305_key(key: bytes, nonce: bytes) -> bytes: """Derive a Poly1305 key""" polykey = create_string_buffer(_POLY1305_KEYBYTES) ull_polykeylen = c_ulonglong(_POLY1305_KEYBYTES) _chacha20(polykey, ull_polykeylen, nonce, key) return polykey.raw def poly1305(key: bytes, data: bytes, nonce: bytes) -> bytes: """Compute a Poly1305 tag for a block of data""" tag = create_string_buffer(_POLY1305_BYTES) ull_datalen = c_ulonglong(len(data)) polykey = poly1305_key(key, nonce) _poly1305(tag, data, ull_datalen, polykey) return tag.raw def poly1305_verify(key: bytes, data: bytes, nonce: bytes, tag: bytes) -> bool: """Verify a Poly1305 tag for a block of data""" ull_datalen = c_ulonglong(len(data)) polykey = poly1305_key(key, nonce) return _poly1305_verify(tag, data, ull_datalen, 
polykey) == 0 chacha_available = True except (ImportError, OSError, AttributeError): chacha_available = False class ChachaCipher: """Shim for Chacha20-Poly1305 symmetric encryption""" def __init__(self, key: bytes): keylen = len(key) // 2 self._key = key[:keylen] self._adkey = key[keylen:] def encrypt_and_sign(self, header: bytes, data: bytes, nonce: bytes) -> Tuple[bytes, bytes]: """Encrypt and sign a block of data""" header = chacha20(self._adkey, header, nonce, 0) data = chacha20(self._key, data, nonce, 1) tag = poly1305(self._key, header + data, nonce) return header + data, tag def decrypt_header(self, header: bytes, nonce: bytes) -> bytes: """Decrypt header data""" return chacha20(self._adkey, header, nonce, 0) def verify_and_decrypt(self, header: bytes, data: bytes, nonce: bytes, tag: bytes) -> Optional[bytes]: """Verify the signature of and decrypt a block of data""" if poly1305_verify(self._key, header + data, nonce, tag): return chacha20(self._key, data, nonce, 1) else: return None if chacha_available: # pragma: no branch register_cipher('chacha20-poly1305', 64, 0, 1) asyncssh-2.20.0/asyncssh/crypto/cipher.py000066400000000000000000000143001475467777400204570ustar00rootroot00000000000000# Copyright (c) 2014-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation """A shim around PyCA for accessing symmetric ciphers needed by AsyncSSH""" from types import ModuleType from typing import Any, MutableMapping, Optional, Tuple import warnings from cryptography.exceptions import InvalidTag from cryptography.hazmat.primitives.ciphers import Cipher, CipherContext from cryptography.hazmat.primitives.ciphers.aead import AESGCM from cryptography.hazmat.primitives.ciphers.modes import CBC, CTR import cryptography.hazmat.primitives.ciphers.algorithms as _algs _decrepit_algs: Optional[ModuleType] try: import cryptography.hazmat.decrepit.ciphers.algorithms as _decrepit_algs except ImportError: # pragma: no cover _decrepit_algs = None _CipherAlgs = Tuple[Any, Any, int] _CipherParams = Tuple[int, int, int] _GCM_MAC_SIZE = 16 _cipher_algs: MutableMapping[str, _CipherAlgs] = {} _cipher_params: MutableMapping[str, _CipherParams] = {} class BasicCipher: """Shim for basic ciphers""" def __init__(self, cipher_name: str, key: bytes, iv: bytes): cipher, mode, initial_bytes = _cipher_algs[cipher_name] self._cipher = Cipher(cipher(key), mode(iv) if mode else None) self._initial_bytes = initial_bytes self._encryptor: Optional[CipherContext] = None self._decryptor: Optional[CipherContext] = None def encrypt(self, data: bytes) -> bytes: """Encrypt a block of data""" if not self._encryptor: self._encryptor = self._cipher.encryptor() if self._initial_bytes: assert self._encryptor is not None self._encryptor.update(self._initial_bytes * b'\0') assert self._encryptor is not None return self._encryptor.update(data) def decrypt(self, data: bytes) -> bytes: """Decrypt a block of data""" if not self._decryptor: 
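            # Lazily create the decryptor on first use. For ciphers
            # registered with a non-zero initial_bytes value (the
            # arcfour128/arcfour256 entries in the algorithm table below),
            # that many bytes of keystream are generated and discarded
            # before any real data is decrypted, matching the RFC 4345
            # requirement to skip the initial RC4 keystream.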
self._decryptor = self._cipher.decryptor() if self._initial_bytes: assert self._decryptor is not None self._decryptor.update(self._initial_bytes * b'\0') assert self._decryptor is not None return self._decryptor.update(data) class GCMCipher: """Shim for GCM ciphers""" def __init__(self, cipher_name: str, key: bytes, iv: bytes): self._cipher = _cipher_algs[cipher_name][0] self._key = key self._iv = iv def _update_iv(self) -> None: """Update the IV after each encrypt/decrypt operation""" invocation = int.from_bytes(self._iv[4:], 'big') invocation = (invocation + 1) & 0xffffffffffffffff self._iv = self._iv[:4] + invocation.to_bytes(8, 'big') def encrypt_and_sign(self, header: bytes, data: bytes) -> Tuple[bytes, bytes]: """Encrypt and sign a block of data""" data = AESGCM(self._key).encrypt(self._iv, data, header) self._update_iv() return header + data[:-_GCM_MAC_SIZE], data[-_GCM_MAC_SIZE:] def verify_and_decrypt(self, header: bytes, data: bytes, mac: bytes) -> Optional[bytes]: """Verify the signature of and decrypt a block of data""" try: decrypted_data: Optional[bytes] = \ AESGCM(self._key).decrypt(self._iv, data + mac, header) except InvalidTag: decrypted_data = None self._update_iv() return decrypted_data def register_cipher(cipher_name: str, key_size: int, iv_size: int, block_size: int) -> None: """Register a symmetric cipher""" _cipher_params[cipher_name] = (key_size, iv_size, block_size) def get_cipher_params(cipher_name: str) -> _CipherParams: """Get parameters of a symmetric cipher""" return _cipher_params[cipher_name] _cipher_alg_list = ( ('aes128-cbc', 'AES', CBC, 0, 16, 16, 16), ('aes192-cbc', 'AES', CBC, 0, 24, 16, 16), ('aes256-cbc', 'AES', CBC, 0, 32, 16, 16), ('aes128-ctr', 'AES', CTR, 0, 16, 16, 16), ('aes192-ctr', 'AES', CTR, 0, 24, 16, 16), ('aes256-ctr', 'AES', CTR, 0, 32, 16, 16), ('aes128-gcm', None, None, 0, 16, 12, 16), ('aes256-gcm', None, None, 0, 32, 12, 16), ('arcfour', 'ARC4', None, 0, 16, 1, 1), ('arcfour40', 'ARC4', None, 0, 5, 1, 1), ('arcfour128', 'ARC4', None, 1536, 16, 1, 1), ('arcfour256', 'ARC4', None, 1536, 32, 1, 1), ('blowfish-cbc', 'Blowfish', CBC, 0, 16, 8, 8), ('cast128-cbc', 'CAST5', CBC, 0, 16, 8, 8), ('des-cbc', 'TripleDES', CBC, 0, 8, 8, 8), ('des2-cbc', 'TripleDES', CBC, 0, 16, 8, 8), ('des3-cbc', 'TripleDES', CBC, 0, 24, 8, 8), ('seed-cbc', 'SEED', CBC, 0, 16, 16, 16) ) with warnings.catch_warnings(): warnings.simplefilter('ignore') for _cipher_name, _alg, _mode, _initial_bytes, \ _key_size, _iv_size, _block_size in _cipher_alg_list: if _alg: try: _cipher = getattr(_algs, _alg) except AttributeError as exc: # pragma: no cover if _decrepit_algs: try: _cipher = getattr(_decrepit_algs, _alg) except AttributeError: raise exc from None else: raise else: _cipher = None _cipher_algs[_cipher_name] = (_cipher, _mode, _initial_bytes) register_cipher(_cipher_name, _key_size, _iv_size, _block_size) asyncssh-2.20.0/asyncssh/crypto/dh.py000066400000000000000000000030131475467777400175770ustar00rootroot00000000000000# Copyright (c) 2022 by Ron Frederick and others. 
# # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation """A shim around PyCA for Diffie Hellman key exchange""" from cryptography.hazmat.primitives.asymmetric import dh class DH: """A shim around PyCA for Diffie Hellman key exchange""" def __init__(self, g: int, p: int): self._pn = dh.DHParameterNumbers(p, g) self._priv_key = self._pn.parameters().generate_private_key() def get_public(self) -> int: """Return the public key to send in the handshake""" pub_key = self._priv_key.public_key() return pub_key.public_numbers().y def get_shared(self, peer_public: int) -> int: """Return the shared key from the peer's public key""" peer_key = dh.DHPublicNumbers(peer_public, self._pn).public_key() shared_key = self._priv_key.exchange(peer_key) return int.from_bytes(shared_key, 'big') asyncssh-2.20.0/asyncssh/crypto/dsa.py000066400000000000000000000073431475467777400177650ustar00rootroot00000000000000# Copyright (c) 2014-2023 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation """A shim around PyCA for DSA public and private keys""" from typing import Optional, cast from cryptography.exceptions import InvalidSignature from cryptography.hazmat.primitives.asymmetric import dsa from .misc import CryptoKey, PyCAKey, hashes # Short variable names are used here, matching names in the spec # pylint: disable=invalid-name class _DSAKey(CryptoKey): """Base class for shim around PyCA for DSA keys""" def __init__(self, pyca_key: PyCAKey, params: dsa.DSAParameterNumbers, pub: dsa.DSAPublicNumbers, priv: Optional[dsa.DSAPrivateNumbers] = None): super().__init__(pyca_key) self._params = params self._pub = pub self._priv = priv @property def p(self) -> int: """Return the DSA public modulus""" return self._params.p @property def q(self) -> int: """Return the DSA sub-group order""" return self._params.q @property def g(self) -> int: """Return the DSA generator""" return self._params.g @property def y(self) -> int: """Return the DSA public value""" return self._pub.y @property def x(self) -> Optional[int]: """Return the DSA private value""" return self._priv.x if self._priv else None class DSAPrivateKey(_DSAKey): """A shim around PyCA for DSA private keys""" @classmethod def construct(cls, p: int, q: int, g: int, y: int, x: int) -> 'DSAPrivateKey': """Construct a DSA private key""" params = dsa.DSAParameterNumbers(p, q, g) pub = dsa.DSAPublicNumbers(y, 
params) priv = dsa.DSAPrivateNumbers(x, pub) priv_key = priv.private_key() return cls(priv_key, params, pub, priv) @classmethod def generate(cls, key_size: int) -> 'DSAPrivateKey': """Generate a new DSA private key""" priv_key = dsa.generate_private_key(key_size) priv = priv_key.private_numbers() pub = priv.public_numbers params = pub.parameter_numbers return cls(priv_key, params, pub, priv) def sign(self, data: bytes, hash_name: str = '') -> bytes: """Sign a block of data""" priv_key = cast('dsa.DSAPrivateKey', self.pyca_key) return priv_key.sign(data, hashes[hash_name]()) class DSAPublicKey(_DSAKey): """A shim around PyCA for DSA public keys""" @classmethod def construct(cls, p: int, q: int, g: int, y: int) -> 'DSAPublicKey': """Construct a DSA public key""" params = dsa.DSAParameterNumbers(p, q, g) pub = dsa.DSAPublicNumbers(y, params) pub_key = pub.public_key() return cls(pub_key, params, pub) def verify(self, data: bytes, sig: bytes, hash_name: str = '') -> bool: """Verify the signature on a block of data""" try: pub_key = cast('dsa.DSAPublicKey', self.pyca_key) pub_key.verify(sig, data, hashes[hash_name]()) return True except InvalidSignature: return False asyncssh-2.20.0/asyncssh/crypto/ec.py000066400000000000000000000146421475467777400176050ustar00rootroot00000000000000# Copyright (c) 2015-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation """A shim around PyCA for elliptic curve keys and key exchange""" from typing import Mapping, Optional, Type, cast from cryptography.exceptions import InvalidSignature from cryptography.hazmat.primitives.asymmetric import ec from cryptography.hazmat.primitives.serialization import Encoding from cryptography.hazmat.primitives.serialization import PublicFormat from .misc import CryptoKey, PyCAKey, hashes # Short variable names are used here, matching names in the spec # pylint: disable=invalid-name _curves: Mapping[bytes, Type[ec.EllipticCurve]] = { b'1.3.132.0.10': ec.SECP256K1, b'nistp256': ec.SECP256R1, b'nistp384': ec.SECP384R1, b'nistp521': ec.SECP521R1 } class _ECKey(CryptoKey): """Base class for shim around PyCA for EC keys""" def __init__(self, pyca_key: PyCAKey, curve_id: bytes, pub: ec.EllipticCurvePublicNumbers, point: bytes, priv: Optional[ec.EllipticCurvePrivateNumbers] = None): super().__init__(pyca_key) self._curve_id = curve_id self._pub = pub self._point = point self._priv = priv @classmethod def lookup_curve(cls, curve_id: bytes) -> Type[ec.EllipticCurve]: """Look up curve and hash algorithm""" try: return _curves[curve_id] except KeyError: # pragma: no cover, other curves not registered raise ValueError(f'Unknown EC curve {curve_id.decode()}') from None @property def curve_id(self) -> bytes: """Return the EC curve name""" return self._curve_id @property def x(self) -> int: """Return the EC public x coordinate""" return self._pub.x @property def y(self) -> int: """Return the EC public y coordinate""" return self._pub.y 
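    # Note: x and y above are the affine coordinates of the public point,
    # while public_value below carries the same point in X9.62 uncompressed
    # form (0x04 followed by the fixed-width big-endian encodings of x and
    # y), which is the representation used in the SSH wire encoding.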
@property def d(self) -> Optional[int]: """Return the EC private value as an integer""" return self._priv.private_value if self._priv else None @property def public_value(self) -> bytes: """Return the EC public point value encoded as a byte string""" return self._point @property def private_value(self) -> Optional[bytes]: """Return the EC private value encoded as a byte string""" if self._priv: keylen = (self._pub.curve.key_size + 7) // 8 return self._priv.private_value.to_bytes(keylen, 'big') else: return None class ECDSAPrivateKey(_ECKey): """A shim around PyCA for ECDSA private keys""" @classmethod def construct(cls, curve_id: bytes, public_value: bytes, private_value: int) -> 'ECDSAPrivateKey': """Construct an ECDSA private key""" curve = cls.lookup_curve(curve_id) priv_key = ec.derive_private_key(private_value, curve()) priv = priv_key.private_numbers() pub = priv.public_numbers return cls(priv_key, curve_id, pub, public_value, priv) @classmethod def generate(cls, curve_id: bytes) -> 'ECDSAPrivateKey': """Generate a new ECDSA private key""" curve = cls.lookup_curve(curve_id) priv_key = ec.generate_private_key(curve()) priv = priv_key.private_numbers() pub_key = priv_key.public_key() pub = pub_key.public_numbers() public_value = pub_key.public_bytes(Encoding.X962, PublicFormat.UncompressedPoint) return cls(priv_key, curve_id, pub, public_value, priv) def sign(self, data: bytes, hash_name: str = '') -> bytes: """Sign a block of data""" # pylint: disable=unused-argument priv_key = cast('ec.EllipticCurvePrivateKey', self.pyca_key) return priv_key.sign(data, ec.ECDSA(hashes[hash_name]())) class ECDSAPublicKey(_ECKey): """A shim around PyCA for ECDSA public keys""" @classmethod def construct(cls, curve_id: bytes, public_value: bytes) -> 'ECDSAPublicKey': """Construct an ECDSA public key""" curve = cls.lookup_curve(curve_id) pub_key = ec.EllipticCurvePublicKey.from_encoded_point(curve(), public_value) pub = pub_key.public_numbers() return cls(pub_key, curve_id, pub, public_value) def verify(self, data: bytes, sig: bytes, hash_name: str = '') -> bool: """Verify the signature on a block of data""" try: pub_key = cast('ec.EllipticCurvePublicKey', self.pyca_key) pub_key.verify(sig, data, ec.ECDSA(hashes[hash_name]())) return True except InvalidSignature: return False class ECDH: """A shim around PyCA for ECDH key exchange""" def __init__(self, curve_id: bytes): try: curve = _curves[curve_id] except KeyError: # pragma: no cover, other curves not registered raise ValueError(f'Unknown EC curve {curve_id.decode()}') from None self._priv_key = ec.generate_private_key(curve()) def get_public(self) -> bytes: """Return the public key to send in the handshake""" pub_key = self._priv_key.public_key() return pub_key.public_bytes(Encoding.X962, PublicFormat.UncompressedPoint) def get_shared_bytes(self, peer_public: bytes) -> bytes: """Return the shared key from the peer's public key as bytes""" peer_key = ec.EllipticCurvePublicKey.from_encoded_point( self._priv_key.curve, peer_public) return self._priv_key.exchange(ec.ECDH(), peer_key) def get_shared(self, peer_public: bytes) -> int: """Return the shared key from the peer's public key""" return int.from_bytes(self.get_shared_bytes(peer_public), 'big') asyncssh-2.20.0/asyncssh/crypto/ec_params.py000066400000000000000000000113041475467777400211400ustar00rootroot00000000000000# Copyright (c) 2013-2021 by Ron Frederick and others. 
# # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation """Functions for looking up named elliptic curves by their parameters""" _curve_param_map = {} # Short variable names are used here, matching names in the spec # pylint: disable=invalid-name def register_prime_curve(curve_id: bytes, p: int, a: int, b: int, point: bytes, n: int) -> None: """Register an elliptic curve prime domain This function registers an elliptic curve prime domain by specifying the SSH identifier for the curve and the set of parameters describing the curve, generator point, and order. This allows EC keys encoded with explicit parameters to be mapped back into their SSH curve IDs. """ _curve_param_map[p, a % p, b % p, point, n] = curve_id def lookup_ec_curve_by_params(p: int, a: int, b: int, point: bytes, n: int) -> bytes: """Look up an elliptic curve by its parameters This function looks up an elliptic curve by its parameters and returns the curve's name. """ try: return _curve_param_map[p, a % p, b % p, point, n] except (KeyError, ValueError): raise ValueError('Unknown elliptic curve parameters') from None # pylint: disable=line-too-long register_prime_curve(b'nistp521', 6864797660130609714981900799081393217269435300143305409394463459185543183397656052122559640661454554977296311391480858037121987999716643812574028291115057151, -3, 1093849038073734274511112390766805569936207598951683748994586394495953116150735016013708737573759623248592132296706313309438452531591012912142327488478985984, b'\x04\x00\xc6\x85\x8e\x06\xb7\x04\x04\xe9\xcd\x9e>\xcbf#\x95\xb4B\x9cd\x819\x05?\xb5!\xf8(\xaf`kM=\xba\xa1K^w\xef\xe7Y(\xfe\x1d\xc1\'\xa2\xff\xa8\xde3H\xb3\xc1\x85jB\x9b\xf9~~1\xc2\xe5\xbdf\x01\x189)jx\x9a;\xc0\x04\\\x8a_\xb4,}\x1b\xd9\x98\xf5DIW\x9bDh\x17\xaf\xbd\x17\'>f,\x97\xeer\x99^\xf4&@\xc5P\xb9\x01?\xad\x07a5 and others. 
# # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation """A shim around PyCA and libnacl for Edwards-curve keys and key exchange""" import ctypes import os from typing import Dict, Optional, Union, cast from cryptography.exceptions import InvalidSignature from cryptography.hazmat.backends.openssl import backend from cryptography.hazmat.primitives.asymmetric import ed25519, ed448 from cryptography.hazmat.primitives.asymmetric import x25519, x448 from cryptography.hazmat.primitives.serialization import Encoding from cryptography.hazmat.primitives.serialization import PrivateFormat from cryptography.hazmat.primitives.serialization import PublicFormat from cryptography.hazmat.primitives.serialization import NoEncryption from .misc import CryptoKey, PyCAKey _EdPrivateKey = Union[ed25519.Ed25519PrivateKey, ed448.Ed448PrivateKey] _EdPublicKey = Union[ed25519.Ed25519PublicKey, ed448.Ed448PublicKey] ed25519_available = backend.ed25519_supported() ed448_available = backend.ed448_supported() curve25519_available = backend.x25519_supported() curve448_available = backend.x448_supported() if ed25519_available or ed448_available: # pragma: no branch class _EdDSAKey(CryptoKey): """Base class for shim around PyCA for EdDSA keys""" def __init__(self, pyca_key: PyCAKey, pub: bytes, priv: Optional[bytes] = None): super().__init__(pyca_key) self._pub = pub self._priv = priv @property def public_value(self) -> bytes: """Return the public value encoded as a byte string""" return self._pub @property def private_value(self) -> Optional[bytes]: """Return the private value encoded as a byte string""" return self._priv class EdDSAPrivateKey(_EdDSAKey): """A shim around PyCA for EdDSA private keys""" _priv_classes: Dict[bytes, object] = {} if ed25519_available: # pragma: no branch _priv_classes[b'ed25519'] = ed25519.Ed25519PrivateKey if ed448_available: # pragma: no branch _priv_classes[b'ed448'] = ed448.Ed448PrivateKey @classmethod def construct(cls, curve_id: bytes, priv: bytes) -> 'EdDSAPrivateKey': """Construct an EdDSA private key""" priv_cls = cast('_EdPrivateKey', cls._priv_classes[curve_id]) priv_key = priv_cls.from_private_bytes(priv) pub_key = priv_key.public_key() pub = pub_key.public_bytes(Encoding.Raw, PublicFormat.Raw) return cls(priv_key, pub, priv) @classmethod def generate(cls, curve_id: bytes) -> 'EdDSAPrivateKey': """Generate a new EdDSA private key""" priv_cls = cast('_EdPrivateKey', cls._priv_classes[curve_id]) priv_key = priv_cls.generate() priv = priv_key.private_bytes(Encoding.Raw, PrivateFormat.Raw, NoEncryption()) pub_key = priv_key.public_key() pub = pub_key.public_bytes(Encoding.Raw, PublicFormat.Raw) return cls(priv_key, pub, priv) def sign(self, data: bytes, hash_name: str = '') -> bytes: """Sign a block of data""" # pylint: disable=unused-argument priv_key = cast('_EdPrivateKey', self.pyca_key) return priv_key.sign(data) class EdDSAPublicKey(_EdDSAKey): """A shim around PyCA for EdDSA public keys""" _pub_classes: 
Dict[bytes, object] = { b'ed25519': ed25519.Ed25519PublicKey, b'ed448': ed448.Ed448PublicKey } @classmethod def construct(cls, curve_id: bytes, pub: bytes) -> 'EdDSAPublicKey': """Construct an EdDSA public key""" pub_cls = cast('_EdPublicKey', cls._pub_classes[curve_id]) pub_key = pub_cls.from_public_bytes(pub) return cls(pub_key, pub) def verify(self, data: bytes, sig: bytes, hash_name: str = '') -> bool: """Verify the signature on a block of data""" # pylint: disable=unused-argument try: pub_key = cast('_EdPublicKey', self.pyca_key) pub_key.verify(sig, data) return True except InvalidSignature: return False else: # pragma: no cover class _EdDSANaclKey: """Base class for shim around libnacl for EdDSA keys""" def __init__(self, pub: bytes, priv: Optional[bytes] = None): self._pub = pub self._priv = priv @property def public_value(self) -> bytes: """Return the public value encoded as a byte string""" return self._pub @property def private_value(self) -> Optional[bytes]: """Return the private value encoded as a byte string""" return self._priv[:-len(self._pub)] if self._priv else None class EdDSAPrivateKey(_EdDSANaclKey): # type: ignore """A shim around libnacl for EdDSA private keys""" @classmethod def construct(cls, curve_id: bytes, priv: bytes) -> 'EdDSAPrivateKey': """Construct an EdDSA private key""" # pylint: disable=unused-argument return cls(*_ed25519_construct_keypair(priv)) @classmethod def generate(cls, curve_id: str) -> 'EdDSAPrivateKey': """Generate a new EdDSA private key""" # pylint: disable=unused-argument return cls(*_ed25519_generate_keypair()) def sign(self, data: bytes, hash_name: str = '') -> bytes: """Sign a block of data""" # pylint: disable=unused-argument assert self._priv is not None return _ed25519_sign(data, self._priv)[:-len(data)] class EdDSAPublicKey(_EdDSANaclKey): # type: ignore """A shim around libnacl for EdDSA public keys""" @classmethod def construct(cls, curve_id: bytes, pub: bytes) -> 'EdDSAPublicKey': """Construct an EdDSA public key""" # pylint: disable=unused-argument if len(pub) != _ED25519_PUBLIC_BYTES: raise ValueError('Invalid EdDSA public key') return cls(pub) def verify(self, data: bytes, sig: bytes, hash_name: str = '') -> bool: """Verify the signature on a block of data""" # pylint: disable=unused-argument try: return _ed25519_verify(sig + data, self._pub) == data except ValueError: return False try: import libnacl _ED25519_PUBLIC_BYTES = libnacl.crypto_sign_ed25519_PUBLICKEYBYTES _ed25519_construct_keypair = libnacl.crypto_sign_seed_keypair _ed25519_generate_keypair = libnacl.crypto_sign_keypair _ed25519_sign = libnacl.crypto_sign _ed25519_verify = libnacl.crypto_sign_open ed25519_available = True except (ImportError, OSError, AttributeError): pass if curve25519_available: # pragma: no branch class Curve25519DH: """Curve25519 Diffie Hellman implementation based on PyCA""" def __init__(self) -> None: self._priv_key = x25519.X25519PrivateKey.generate() def get_public(self) -> bytes: """Return the public key to send in the handshake""" return self._priv_key.public_key().public_bytes(Encoding.Raw, PublicFormat.Raw) def get_shared_bytes(self, peer_public: bytes) -> bytes: """Return the shared key from the peer's public key as bytes""" peer_key = x25519.X25519PublicKey.from_public_bytes(peer_public) return self._priv_key.exchange(peer_key) def get_shared(self, peer_public: bytes) -> int: """Return the shared key from the peer's public key""" return int.from_bytes(self.get_shared_bytes(peer_public), 'big') else: # pragma: no cover class 
Curve25519DH: # type: ignore """Curve25519 Diffie Hellman implementation based on libnacl""" def __init__(self) -> None: self._private = os.urandom(_CURVE25519_SCALARBYTES) def get_public(self) -> bytes: """Return the public key to send in the handshake""" public = ctypes.create_string_buffer(_CURVE25519_BYTES) if _curve25519_base(public, self._private) != 0: # This error is never returned by libsodium raise ValueError('Curve25519 failed') # pragma: no cover return public.raw def get_shared_bytes(self, peer_public: bytes) -> bytes: """Return the shared key from the peer's public key as bytes""" if len(peer_public) != _CURVE25519_BYTES: raise ValueError('Invalid curve25519 public key size') shared = ctypes.create_string_buffer(_CURVE25519_BYTES) if _curve25519(shared, self._private, peer_public) != 0: raise ValueError('Curve25519 failed') return shared.raw def get_shared(self, peer_public: bytes) -> int: """Return the shared key from the peer's public key""" return int.from_bytes(self.get_shared_bytes(peer_public), 'big') try: from libnacl import nacl _CURVE25519_BYTES = nacl.crypto_scalarmult_curve25519_bytes() _CURVE25519_SCALARBYTES = \ nacl.crypto_scalarmult_curve25519_scalarbytes() _curve25519 = nacl.crypto_scalarmult_curve25519 _curve25519_base = nacl.crypto_scalarmult_curve25519_base curve25519_available = True except (ImportError, OSError, AttributeError): pass class Curve448DH: """Curve448 Diffie Hellman implementation based on PyCA""" def __init__(self) -> None: self._priv_key = x448.X448PrivateKey.generate() def get_public(self) -> bytes: """Return the public key to send in the handshake""" return self._priv_key.public_key().public_bytes(Encoding.Raw, PublicFormat.Raw) def get_shared(self, peer_public: bytes) -> int: """Return the shared key from the peer's public key""" peer_key = x448.X448PublicKey.from_public_bytes(peer_public) shared = self._priv_key.exchange(peer_key) return int.from_bytes(shared, 'big') asyncssh-2.20.0/asyncssh/crypto/kdf.py000066400000000000000000000022071475467777400177540ustar00rootroot00000000000000# Copyright (c) 2017-2021 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation """A shim around PyCA for key derivation functions""" from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC from .misc import hashes def pbkdf2_hmac(hash_name: str, passphrase: bytes, salt: bytes, count: int, key_size: int) -> bytes: """A shim around PyCA for PBKDF2 HMAC key derivation""" return PBKDF2HMAC(hashes[hash_name](), key_size, salt, count).derive(passphrase) asyncssh-2.20.0/asyncssh/crypto/misc.py000066400000000000000000000045411475467777400201460ustar00rootroot00000000000000# Copyright (c) 2017-2023 by Ron Frederick and others. 
# # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation """Miscellaneous PyCA utility classes and functions""" from typing import Callable, Mapping, Union from cryptography.hazmat.primitives.asymmetric import dsa, ec, rsa from cryptography.hazmat.primitives.asymmetric import ed25519, ed448 from cryptography.hazmat.primitives.hashes import HashAlgorithm from cryptography.hazmat.primitives.hashes import MD5, SHA1, SHA224 from cryptography.hazmat.primitives.hashes import SHA256, SHA384, SHA512 PyCAPrivateKey = Union[dsa.DSAPrivateKey, rsa.RSAPrivateKey, ec.EllipticCurvePrivateKey, ed25519.Ed25519PrivateKey, ed448.Ed448PrivateKey] PyCAPublicKey = Union[dsa.DSAPublicKey, rsa.RSAPublicKey, ec.EllipticCurvePublicKey, ed25519.Ed25519PublicKey, ed448.Ed448PublicKey] PyCAKey = Union[PyCAPrivateKey, PyCAPublicKey] hashes: Mapping[str, Callable[[], HashAlgorithm]] = { str(h.name): h for h in (MD5, SHA1, SHA224, SHA256, SHA384, SHA512) } class CryptoKey: """Base class for PyCA private/public keys""" def __init__(self, pyca_key: PyCAKey): self._pyca_key = pyca_key @property def pyca_key(self) -> PyCAKey: """Return the PyCA object associated with this key""" return self._pyca_key def sign(self, data: bytes, hash_name: str = '') -> bytes: """Sign a block of data""" # pylint: disable=no-self-use raise RuntimeError # pragma: no cover def verify(self, data: bytes, sig: bytes, hash_name: str = '') -> bool: """Verify the signature on a block of data""" # pylint: disable=no-self-use raise RuntimeError # pragma: no cover asyncssh-2.20.0/asyncssh/crypto/pq.py000066400000000000000000000067151475467777400176400ustar00rootroot00000000000000# Copyright (c) 2022-2024 by Ron Frederick and others. 
# # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation """A shim around liboqs for Streamlined NTRU Prime post-quantum encryption""" import ctypes import ctypes.util from typing import Mapping, Tuple _pq_algs: Mapping[bytes, Tuple[int, int, int, int, str]] = { b'mlkem768': (1184, 2400, 1088, 32, 'KEM_ml_kem_768'), b'mlkem1024': (1568, 3168, 1568, 32, 'KEM_ml_kem_1024'), b'sntrup761': (1158, 1763, 1039, 32, 'KEM_ntruprime_sntrup761') } mlkem_available = False sntrup_available = False for lib in ('oqs', 'liboqs'): _oqs_lib = ctypes.util.find_library(lib) if _oqs_lib: # pragma: no branch break else: # pragma: no cover _oqs_lib = None if _oqs_lib: # pragma: no branch _oqs = ctypes.cdll.LoadLibrary(_oqs_lib) mlkem_available = (hasattr(_oqs, 'OQS_KEM_ml_kem_768_keypair') or hasattr(_oqs, 'OQS_KEM_ml_kem_768_ipd_keypair')) sntrup_available = hasattr(_oqs, 'OQS_KEM_ntruprime_sntrup761_keypair') class PQDH: """A shim around liboqs for post-quantum key exchange algorithms""" def __init__(self, alg_name: bytes): try: self.pubkey_bytes, self.privkey_bytes, \ self.ciphertext_bytes, self.secret_bytes, \ oqs_name = _pq_algs[alg_name] except KeyError: # pragma: no cover, other algs not registered raise ValueError(f'Unknown PQ algorithm {oqs_name}') from None if not hasattr(_oqs, 'OQS_' + oqs_name + '_keypair'): # pragma: no cover oqs_name += '_ipd' self._keypair = getattr(_oqs, 'OQS_' + oqs_name + '_keypair') self._encaps = getattr(_oqs, 'OQS_' + oqs_name + '_encaps') self._decaps = getattr(_oqs, 'OQS_' + oqs_name + '_decaps') def keypair(self) -> Tuple[bytes, bytes]: """Make a new key pair""" pubkey = ctypes.create_string_buffer(self.pubkey_bytes) privkey = ctypes.create_string_buffer(self.privkey_bytes) self._keypair(pubkey, privkey) return pubkey.raw, privkey.raw def encaps(self, pubkey: bytes) -> Tuple[bytes, bytes]: """Generate a random secret and encrypt it with a public key""" if len(pubkey) != self.pubkey_bytes: raise ValueError('Invalid public key') ciphertext = ctypes.create_string_buffer(self.ciphertext_bytes) secret = ctypes.create_string_buffer(self.secret_bytes) self._encaps(ciphertext, secret, pubkey) return secret.raw, ciphertext.raw def decaps(self, ciphertext: bytes, privkey: bytes) -> bytes: """Decrypt an encrypted secret using a private key""" if len(ciphertext) != self.ciphertext_bytes: raise ValueError('Invalid ciphertext') secret = ctypes.create_string_buffer(self.secret_bytes) self._decaps(secret, ciphertext, privkey) return secret.raw asyncssh-2.20.0/asyncssh/crypto/rsa.py000066400000000000000000000120201475467777400177670ustar00rootroot00000000000000# Copyright (c) 2014-2023 by Ron Frederick and others. 
# # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation """A shim around PyCA for RSA public and private keys""" from typing import Optional, cast from cryptography.exceptions import InvalidSignature from cryptography.hazmat.primitives.asymmetric.padding import MGF1, OAEP from cryptography.hazmat.primitives.asymmetric.padding import PKCS1v15 from cryptography.hazmat.primitives.asymmetric import rsa from .misc import CryptoKey, PyCAKey, hashes # Short variable names are used here, matching names in the spec # pylint: disable=invalid-name class _RSAKey(CryptoKey): """Base class for shim around PyCA for RSA keys""" def __init__(self, pyca_key: PyCAKey, pub: rsa.RSAPublicNumbers, priv: Optional[rsa.RSAPrivateNumbers] = None): super().__init__(pyca_key) self._pub = pub self._priv = priv @property def n(self) -> int: """Return the RSA public modulus""" return self._pub.n @property def e(self) -> int: """Return the RSA public exponent""" return self._pub.e @property def d(self) -> Optional[int]: """Return the RSA private exponent""" return self._priv.d if self._priv else None @property def p(self) -> Optional[int]: """Return the RSA first private prime""" return self._priv.p if self._priv else None @property def q(self) -> Optional[int]: """Return the RSA second private prime""" return self._priv.q if self._priv else None @property def dmp1(self) -> Optional[int]: """Return d modulo p-1""" return self._priv.dmp1 if self._priv else None @property def dmq1(self) -> Optional[int]: """Return q modulo p-1""" return self._priv.dmq1 if self._priv else None @property def iqmp(self) -> Optional[int]: """Return the inverse of q modulo p""" return self._priv.iqmp if self._priv else None class RSAPrivateKey(_RSAKey): """A shim around PyCA for RSA private keys""" @classmethod def construct(cls, n: int, e: int, d: int, p: int, q: int, dmp1: int, dmq1: int, iqmp: int, skip_validation: bool) -> 'RSAPrivateKey': """Construct an RSA private key""" pub = rsa.RSAPublicNumbers(e, n) priv = rsa.RSAPrivateNumbers(p, q, d, dmp1, dmq1, iqmp, pub) priv_key = priv.private_key( unsafe_skip_rsa_key_validation=skip_validation) return cls(priv_key, pub, priv) @classmethod def generate(cls, key_size: int, exponent: int) -> 'RSAPrivateKey': """Generate a new RSA private key""" priv_key = rsa.generate_private_key(exponent, key_size) priv = priv_key.private_numbers() pub = priv.public_numbers return cls(priv_key, pub, priv) def decrypt(self, data: bytes, hash_name: str) -> Optional[bytes]: """Decrypt a block of data""" try: hash_alg = hashes[hash_name]() priv_key = cast('rsa.RSAPrivateKey', self.pyca_key) return priv_key.decrypt(data, OAEP(MGF1(hash_alg), hash_alg, None)) except ValueError: return None def sign(self, data: bytes, hash_name: str = '') -> bytes: """Sign a block of data""" priv_key = cast('rsa.RSAPrivateKey', self.pyca_key) return priv_key.sign(data, PKCS1v15(), hashes[hash_name]()) class RSAPublicKey(_RSAKey): """A shim around PyCA for RSA public 
keys""" @classmethod def construct(cls, n: int, e: int) -> 'RSAPublicKey': """Construct an RSA public key""" pub = rsa.RSAPublicNumbers(e, n) pub_key = pub.public_key() return cls(pub_key, pub) def encrypt(self, data: bytes, hash_name: str) -> Optional[bytes]: """Encrypt a block of data""" try: hash_alg = hashes[hash_name]() pub_key = cast('rsa.RSAPublicKey', self.pyca_key) return pub_key.encrypt(data, OAEP(MGF1(hash_alg), hash_alg, None)) except ValueError: return None def verify(self, data: bytes, sig: bytes, hash_name: str = '') -> bool: """Verify the signature on a block of data""" try: pub_key = cast('rsa.RSAPublicKey', self.pyca_key) pub_key.verify(sig, data, PKCS1v15(), hashes[hash_name]()) return True except InvalidSignature: return False asyncssh-2.20.0/asyncssh/crypto/umac.py000066400000000000000000000104471475467777400201420ustar00rootroot00000000000000# Copyright (c) 2016-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation """UMAC cryptographic hash (RFC 4418) wrapper for Nettle library""" import binascii import ctypes import ctypes.util from typing import TYPE_CHECKING, Callable, Optional if TYPE_CHECKING: _ByteArray = ctypes.Array[ctypes.c_char] _SetKey = Callable[[_ByteArray, bytes], None] _SetNonce = Callable[[_ByteArray, ctypes.c_size_t, bytes], None] _Update = Callable[[_ByteArray, ctypes.c_size_t, bytes], None] _Digest = Callable[[_ByteArray, ctypes.c_size_t, _ByteArray], None] _New = Callable[[bytes, Optional[bytes], Optional[bytes]], object] _UMAC_BLOCK_SIZE = 1024 _UMAC_DEFAULT_CTX_SIZE = 4096 def _build_umac(size: int) -> '_New': """Function to build UMAC wrapper for a specific digest size""" _name = f'umac{size}' _prefix = f'nettle_{_name}_' try: _context_size: int = getattr(_nettle, _prefix + '_ctx_size')() except AttributeError: _context_size = _UMAC_DEFAULT_CTX_SIZE _set_key: _SetKey = getattr(_nettle, _prefix + 'set_key') _set_nonce: _SetNonce = getattr(_nettle, _prefix + 'set_nonce') _update: _Update = getattr(_nettle, _prefix + 'update') _digest: _Digest = getattr(_nettle, _prefix + 'digest') class _UMAC: """Wrapper for UMAC cryptographic hash This class supports the cryptographic hash API defined in PEP 452. 
""" name = _name block_size = _UMAC_BLOCK_SIZE digest_size = size // 8 def __init__(self, ctx: '_ByteArray', nonce: Optional[bytes] = None, msg: Optional[bytes] = None): self._ctx = ctx if nonce: self.set_nonce(nonce) if msg: self.update(msg) @classmethod def new(cls, key: bytes, msg: Optional[bytes] = None, nonce: Optional[bytes] = None) -> '_UMAC': """Construct a new UMAC hash object""" ctx = ctypes.create_string_buffer(_context_size) _set_key(ctx, key) return cls(ctx, nonce, msg) def copy(self) -> '_UMAC': """Return a new hash object with this object's state""" ctx = ctypes.create_string_buffer(self._ctx.raw) return self.__class__(ctx) def set_nonce(self, nonce: bytes) -> None: """Reset the nonce associated with this object""" _set_nonce(self._ctx, ctypes.c_size_t(len(nonce)), nonce) def update(self, msg: bytes) -> None: """Add the data in msg to the hash""" _update(self._ctx, ctypes.c_size_t(len(msg)), msg) def digest(self) -> bytes: """Return the hash and increment nonce to begin a new message .. note:: The hash is reset and the nonce is incremented when this function is called. This doesn't match the behavior defined in PEP 452. """ result = ctypes.create_string_buffer(self.digest_size) _digest(self._ctx, ctypes.c_size_t(self.digest_size), result) return result.raw def hexdigest(self) -> str: """Return the digest as a string of hexadecimal digits""" return binascii.b2a_hex(self.digest()).decode('ascii') return _UMAC.new for lib in ('nettle', 'libnettle', 'libnettle-6'): _nettle_lib = ctypes.util.find_library(lib) if _nettle_lib: # pragma: no branch break else: # pragma: no cover _nettle_lib = None if _nettle_lib: # pragma: no branch _nettle = ctypes.cdll.LoadLibrary(_nettle_lib) umac32, umac64, umac96, umac128 = map(_build_umac, (32, 64, 96, 128)) asyncssh-2.20.0/asyncssh/crypto/x509.py000066400000000000000000000360071475467777400177220ustar00rootroot00000000000000# Copyright (c) 2017-2021 by Ron Frederick and others. 
# # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation """A shim around PyCA and PyOpenSSL for X.509 certificates""" from datetime import datetime, timezone import re import sys from typing import Iterable, List, Optional, Sequence, Set, Union, cast from cryptography.hazmat.primitives.serialization import Encoding from cryptography.hazmat.primitives.serialization import PublicFormat from cryptography import x509 from OpenSSL import crypto from ..asn1 import IA5String, der_decode, der_encode from ..misc import ip_address from .misc import PyCAKey, PyCAPrivateKey, PyCAPublicKey, hashes _Comment = Union[None, bytes, str] _Principals = Union[str, Sequence[str]] _Purposes = Union[None, str, Sequence[str]] _PurposeOIDs = Union[None, Set[x509.ObjectIdentifier]] _GeneralNameList = List[x509.GeneralName] _NameInit = Union[str, x509.Name, Iterable[x509.RelativeDistinguishedName]] _purpose_to_oid = { 'serverAuth': x509.ExtendedKeyUsageOID.SERVER_AUTH, 'clientAuth': x509.ExtendedKeyUsageOID.CLIENT_AUTH, 'secureShellClient': x509.ObjectIdentifier('1.3.6.1.5.5.7.3.21'), 'secureShellServer': x509.ObjectIdentifier('1.3.6.1.5.5.7.3.22')} _purpose_any = '2.5.29.37.0' _nscomment_oid = x509.ObjectIdentifier('2.16.840.1.113730.1.13') _datetime_min = datetime.fromtimestamp(0, timezone.utc).replace(microsecond=1) _datetime_32bit_max = datetime.fromtimestamp(2**31 - 1, timezone.utc) if sys.platform == 'win32': # pragma: no cover # Windows' datetime.max is year 9999, but timestamps that large don't work _datetime_max = datetime.max.replace(year=2999, tzinfo=timezone.utc) else: _datetime_max = datetime.max.replace(tzinfo=timezone.utc) def _to_generalized_time(t: int) -> datetime: """Convert a timestamp value to a datetime""" if t <= 0: return _datetime_min else: try: return datetime.fromtimestamp(t, timezone.utc) except (OSError, OverflowError): try: # Work around a bug in cryptography which shows up on # systems with a small time_t. 
datetime.fromtimestamp(_datetime_max.timestamp() - 1, timezone.utc) return _datetime_max except (OSError, OverflowError): # pragma: no cover return _datetime_32bit_max def _to_purpose_oids(purposes: _Purposes) -> _PurposeOIDs: """Convert a list of purposes to purpose OIDs""" if isinstance(purposes, str): purposes = [p.strip() for p in purposes.split(',')] if not purposes or 'any' in purposes or _purpose_any in purposes: purpose_oids = None else: purpose_oids = {_purpose_to_oid.get(p) or x509.ObjectIdentifier(p) for p in purposes} return purpose_oids def _encode_user_principals(principals: _Principals) -> _GeneralNameList: """Encode user principals as e-mail addresses""" if isinstance(principals, str): principals = [p.strip() for p in principals.split(',')] return [x509.RFC822Name(name) for name in principals] def _encode_host_principals(principals: _Principals) -> _GeneralNameList: """Encode host principals as DNS names or IP addresses""" def _encode_host(name: str) -> x509.GeneralName: """Encode a host principal as a DNS name or IP address""" try: return x509.IPAddress(ip_address(name)) except ValueError: return x509.DNSName(name) if isinstance(principals, str): principals = [p.strip() for p in principals.split(',')] return [_encode_host(name) for name in principals] class X509Name(x509.Name): """A shim around PyCA for X.509 distinguished names""" _escape = re.compile(r'([,+\\])') _unescape = re.compile(r'\\([,+\\])') _split_rdn = re.compile(r'(?:[^+\\]+|\\.)+') _split_name = re.compile(r'(?:[^,\\]+|\\.)+') _attrs = ( ('C', x509.NameOID.COUNTRY_NAME), ('ST', x509.NameOID.STATE_OR_PROVINCE_NAME), ('L', x509.NameOID.LOCALITY_NAME), ('O', x509.NameOID.ORGANIZATION_NAME), ('OU', x509.NameOID.ORGANIZATIONAL_UNIT_NAME), ('CN', x509.NameOID.COMMON_NAME), ('DC', x509.NameOID.DOMAIN_COMPONENT)) _to_oid = dict(_attrs) _from_oid = {v: k for k, v in _attrs} def __init__(self, name: _NameInit): if isinstance(name, str): rdns = self._parse_name(name) elif isinstance(name, x509.Name): rdns = name.rdns else: rdns = name super().__init__(rdns) def __str__(self) -> str: return ','.join(self._format_rdn(rdn) for rdn in self.rdns) def _format_rdn(self, rdn: x509.RelativeDistinguishedName) -> str: """Format an X.509 RelativeDistinguishedName as a string""" return '+'.join(sorted(self._format_attr(nameattr) for nameattr in rdn)) def _format_attr(self, nameattr: x509.NameAttribute) -> str: """Format an X.509 NameAttribute as a string""" attr = self._from_oid.get(nameattr.oid) or nameattr.oid.dotted_string return attr + '=' + self._escape.sub(r'\\\1', cast(str, nameattr.value)) def _parse_name(self, name: str) -> \ Iterable[x509.RelativeDistinguishedName]: """Parse an X.509 distinguished name""" return [self._parse_rdn(rdn) for rdn in self._split_name.findall(name)] def _parse_rdn(self, rdn: str) -> x509.RelativeDistinguishedName: """Parse an X.509 relative distinguished name""" return x509.RelativeDistinguishedName( self._parse_nameattr(av) for av in self._split_rdn.findall(rdn)) def _parse_nameattr(self, av: str) -> x509.NameAttribute: """Parse an X.509 name attribute/value pair""" try: attr, value = av.split('=', 1) except ValueError: raise ValueError('Invalid X.509 name attribute: ' + av) from None try: attr = attr.strip() oid = self._to_oid.get(attr) or x509.ObjectIdentifier(attr) except ValueError: raise ValueError('Unknown X.509 attribute: ' + attr) from None return x509.NameAttribute(oid, self._unescape.sub(r'\1', value)) class X509NamePattern: """Match X.509 distinguished names""" def __init__(self, 
pattern: str): if pattern.endswith(',*'): self._pattern = X509Name(pattern[:-2]) self._prefix_len: Optional[int] = len(self._pattern.rdns) else: self._pattern = X509Name(pattern) self._prefix_len = None def __eq__(self, other: object) -> bool: # This isn't protected access - both objects are _RSAKey instances # pylint: disable=protected-access if not isinstance(other, X509NamePattern): # pragma: no cover return NotImplemented return (self._pattern == other._pattern and self._prefix_len == other._prefix_len) def __hash__(self) -> int: return hash((self._pattern, self._prefix_len)) def matches(self, name: X509Name) -> bool: """Return whether an X.509 name matches this pattern""" return self._pattern.rdns == name.rdns[:self._prefix_len] class X509Certificate: """A shim around PyCA and PyOpenSSL for X.509 certificates""" def __init__(self, cert: x509.Certificate, data: bytes): self.data = data self.subject = X509Name(cert.subject) self.issuer = X509Name(cert.issuer) self.key_data = cert.public_key().public_bytes( Encoding.DER, PublicFormat.SubjectPublicKeyInfo) self.openssl_cert = crypto.X509.from_cryptography(cert) self.subject_hash = hex(self.openssl_cert.get_subject().hash())[2:] self.issuer_hash = hex(self.openssl_cert.get_issuer().hash())[2:] try: self.purposes: Optional[Set[bytes]] = \ set(cert.extensions.get_extension_for_class( x509.ExtendedKeyUsage).value) except x509.ExtensionNotFound: self.purposes = None try: sans = cert.extensions.get_extension_for_class( x509.SubjectAlternativeName).value self.user_principals = sans.get_values_for_type(x509.RFC822Name) self.host_principals = sans.get_values_for_type(x509.DNSName) + \ [str(ip) for ip in sans.get_values_for_type(x509.IPAddress)] except x509.ExtensionNotFound: cn = cert.subject.get_attributes_for_oid(x509.NameOID.COMMON_NAME) principals = [cast(str, attr.value) for attr in cn] self.user_principals = principals self.host_principals = principals try: comment = cert.extensions.get_extension_for_oid(_nscomment_oid) comment_der = cast(x509.UnrecognizedExtension, comment.value).value self.comment: Optional[bytes] = \ cast(IA5String, der_decode(comment_der)).value except x509.ExtensionNotFound: self.comment = None def __eq__(self, other: object) -> bool: if not isinstance(other, X509Certificate): # pragma: no cover return NotImplemented return self.data == other.data def __hash__(self) -> int: return hash(self.data) def validate(self, trust_store: Sequence['X509Certificate'], purposes: _Purposes, user_principal: str, host_principal: str) -> None: """Validate an X.509 certificate""" purpose_oids = _to_purpose_oids(purposes) if purpose_oids and self.purposes and not purpose_oids & self.purposes: raise ValueError('Certificate purpose mismatch') if user_principal and user_principal not in self.user_principals: raise ValueError('Certificate user principal mismatch') if host_principal and host_principal not in self.host_principals: raise ValueError('Certificate host principal mismatch') x509_store = crypto.X509Store() for c in trust_store: x509_store.add_cert(c.openssl_cert) try: x509_ctx = crypto.X509StoreContext(x509_store, self.openssl_cert, None) x509_ctx.verify_certificate() except crypto.X509StoreContextError as exc: raise ValueError(f'X.509 chain validation error: {exc}') from None def generate_x509_certificate(signing_key: PyCAKey, key: PyCAKey, subject: _NameInit, issuer: Optional[_NameInit], serial: Optional[int], valid_after: int, valid_before: int, ca: bool, ca_path_len: Optional[int], purposes: _Purposes, user_principals: 
_Principals, host_principals: _Principals, hash_name: str, comment: _Comment) -> X509Certificate: """Generate a new X.509 certificate""" builder = x509.CertificateBuilder() subject = X509Name(subject) issuer = X509Name(issuer) if issuer else subject self_signed = subject == issuer builder = builder.subject_name(subject) builder = builder.issuer_name(issuer) if serial is None: serial = x509.random_serial_number() builder = builder.serial_number(serial) builder = builder.not_valid_before(_to_generalized_time(valid_after)) builder = builder.not_valid_after(_to_generalized_time(valid_before)) builder = builder.public_key(cast(PyCAPublicKey, key)) if ca: basic_constraints = x509.BasicConstraints(ca=True, path_length=ca_path_len) key_usage = x509.KeyUsage(digital_signature=False, content_commitment=False, key_encipherment=False, data_encipherment=False, key_agreement=False, key_cert_sign=True, crl_sign=True, encipher_only=False, decipher_only=False) else: basic_constraints = x509.BasicConstraints(ca=False, path_length=None) key_usage = x509.KeyUsage(digital_signature=True, content_commitment=False, key_encipherment=True, data_encipherment=False, key_agreement=True, key_cert_sign=False, crl_sign=False, encipher_only=False, decipher_only=False) builder = builder.add_extension(basic_constraints, critical=True) if ca or not self_signed: builder = builder.add_extension(key_usage, critical=True) purpose_oids = _to_purpose_oids(purposes) if purpose_oids: builder = builder.add_extension(x509.ExtendedKeyUsage(purpose_oids), critical=False) skid = x509.SubjectKeyIdentifier.from_public_key(cast(PyCAPublicKey, key)) builder = builder.add_extension(skid, critical=False) if not self_signed: issuer_pk = cast(PyCAPrivateKey, signing_key).public_key() akid = x509.AuthorityKeyIdentifier.from_issuer_public_key(issuer_pk) builder = builder.add_extension(akid, critical=False) sans = _encode_user_principals(user_principals) + \ _encode_host_principals(host_principals) if sans: builder = builder.add_extension(x509.SubjectAlternativeName(sans), critical=False) if comment: if isinstance(comment, str): comment_bytes = comment.encode('utf-8') else: comment_bytes = comment comment_bytes = der_encode(IA5String(comment_bytes)) builder = builder.add_extension( x509.UnrecognizedExtension(_nscomment_oid, comment_bytes), critical=False) try: hash_alg = hashes[hash_name]() if hash_name else None except KeyError: raise ValueError('Unknown hash algorithm') from None cert = builder.sign(cast(PyCAPrivateKey, signing_key), hash_alg) # type: ignore data = cert.public_bytes(Encoding.DER) return X509Certificate(cert, data) def import_x509_certificate(data: bytes) -> X509Certificate: """Construct an X.509 certificate from DER data""" cert = x509.load_der_x509_certificate(data) return X509Certificate(cert, data) asyncssh-2.20.0/asyncssh/dsa.py000066400000000000000000000203431475467777400164400ustar00rootroot00000000000000# Copyright (c) 2013-2023 by Ron Frederick and others. 
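# Usage sketch (editorial addition, not part of the library): the
# generate_x509_certificate() and import_x509_certificate() helpers defined
# above in asyncssh/crypto/x509.py operate directly on PyCA key objects.
# The example below builds a self-signed host certificate for a hypothetical
# name 'ssh.example.com' and round-trips it through DER; a real deployment
# would instead sign with a CA key and later call validate() against a trust
# store of CA certificates.  It assumes the cryptography and pyOpenSSL
# dependencies are installed.

from cryptography.hazmat.primitives.asymmetric import ec

from asyncssh.crypto.x509 import (generate_x509_certificate,
                                  import_x509_certificate)

priv = ec.generate_private_key(ec.SECP256R1())
pub = priv.public_key()

cert = generate_x509_certificate(
    priv, pub,                        # signing key and subject key
    subject='OU=hosts,CN=ssh.example.com',
    issuer=None,                      # None -> self-signed (issuer == subject)
    serial=None,                      # None -> random serial number
    valid_after=0,                    # clamped to the minimum supported time
    valid_before=2**31 - 1,           # far-future expiration
    ca=False, ca_path_len=None,
    purposes='secureShellServer',
    user_principals=[],
    host_principals='ssh.example.com',
    hash_name='sha256',
    comment='example host certificate')

copy = import_x509_certificate(cert.data)

assert copy == cert                   # equality is defined over the DER data
assert 'ssh.example.com' in copy.host_principals
assert copy.comment == b'example host certificate'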
# # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation """DSA public key encryption handler""" from typing import Optional, Tuple, Union, cast from .asn1 import ASN1DecodeError, ObjectIdentifier, der_encode, der_decode from .crypto import DSAPrivateKey, DSAPublicKey from .misc import all_ints from .packet import MPInt, String, SSHPacket from .public_key import SSHKey, SSHOpenSSHCertificateV01, KeyExportError from .public_key import register_public_key_alg, register_certificate_alg from .public_key import register_x509_certificate_alg _PrivateKeyArgs = Tuple[int, int, int, int, int] _PublicKeyArgs = Tuple[int, int, int, int] class _DSAKey(SSHKey): """Handler for DSA public key encryption""" _key: Union[DSAPrivateKey, DSAPublicKey] algorithm = b'ssh-dss' default_x509_hash = 'sha256' pem_name = b'DSA' pkcs8_oid = ObjectIdentifier('1.2.840.10040.4.1') sig_algorithms = (algorithm,) x509_algorithms = (b'x509v3-' + algorithm,) all_sig_algorithms = set(sig_algorithms) def __eq__(self, other: object) -> bool: # This isn't protected access - both objects are _DSAKey instances # pylint: disable=protected-access return (isinstance(other, type(self)) and self._key.p == other._key.p and self._key.q == other._key.q and self._key.g == other._key.g and self._key.y == other._key.y and self._key.x == other._key.x) def __hash__(self) -> int: return hash((self._key.p, self._key.q, self._key.g, self._key.y, self._key.x)) @classmethod def generate(cls, algorithm: bytes) -> '_DSAKey': # type: ignore """Generate a new DSA private key""" # pylint: disable=arguments-differ,unused-argument return cls(DSAPrivateKey.generate(key_size=1024)) @classmethod def make_private(cls, key_params: object) -> SSHKey: """Construct a DSA private key""" p, q, g, y, x = cast(_PrivateKeyArgs, key_params) return cls(DSAPrivateKey.construct(p, q, g, y, x)) @classmethod def make_public(cls, key_params: object) -> SSHKey: """Construct a DSA public key""" p, q, g, y = cast(_PublicKeyArgs, key_params) return cls(DSAPublicKey.construct(p, q, g, y)) @classmethod def decode_pkcs1_private(cls, key_data: object) -> \ Optional[_PrivateKeyArgs]: """Decode a PKCS#1 format DSA private key""" if (isinstance(key_data, tuple) and len(key_data) == 6 and all_ints(key_data) and key_data[0] == 0): return cast(_PrivateKeyArgs, key_data[1:]) else: return None @classmethod def decode_pkcs1_public(cls, key_data: object) -> \ Optional[_PublicKeyArgs]: """Decode a PKCS#1 format DSA public key""" if (isinstance(key_data, tuple) and len(key_data) == 4 and all_ints(key_data)): y, p, q, g = key_data return p, q, g, y else: return None @classmethod def decode_pkcs8_private(cls, alg_params: object, data: bytes) -> Optional[_PrivateKeyArgs]: """Decode a PKCS#8 format DSA private key""" try: x = der_decode(data) except ASN1DecodeError: return None if (isinstance(alg_params, tuple) and len(alg_params) == 3 and all_ints(alg_params) and isinstance(x, int)): p, q, g = alg_params y: int = pow(g, x, p) 
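            # (editorial note) A PKCS#8 DSA private key stores only the
            # private exponent x, with the domain parameters (p, q, g)
            # carried in the AlgorithmIdentifier, so the public value is
            # recomputed here as y = g**x mod p before returning the full
            # (p, q, g, y, x) tuple.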
return p, q, g, y, x else: return None @classmethod def decode_pkcs8_public(cls, alg_params: object, data: bytes) -> Optional[_PublicKeyArgs]: """Decode a PKCS#8 format DSA public key""" try: y = der_decode(data) except ASN1DecodeError: return None if (isinstance(alg_params, tuple) and len(alg_params) == 3 and all_ints(alg_params) and isinstance(y, int)): p, q, g = alg_params return p, q, g, y else: return None @classmethod def decode_ssh_private(cls, packet: SSHPacket) -> _PrivateKeyArgs: """Decode an SSH format DSA private key""" p = packet.get_mpint() q = packet.get_mpint() g = packet.get_mpint() y = packet.get_mpint() x = packet.get_mpint() return p, q, g, y, x @classmethod def decode_ssh_public(cls, packet: SSHPacket) -> _PublicKeyArgs: """Decode an SSH format DSA public key""" p = packet.get_mpint() q = packet.get_mpint() g = packet.get_mpint() y = packet.get_mpint() return p, q, g, y def encode_pkcs1_private(self) -> object: """Encode a PKCS#1 format DSA private key""" if not self._key.x: raise KeyExportError('Key is not private') return (0, self._key.p, self._key.q, self._key.g, self._key.y, self._key.x) def encode_pkcs1_public(self) -> object: """Encode a PKCS#1 format DSA public key""" return (self._key.y, self._key.p, self._key.q, self._key.g) def encode_pkcs8_private(self) -> Tuple[object, object]: """Encode a PKCS#8 format DSA private key""" if not self._key.x: raise KeyExportError('Key is not private') return (self._key.p, self._key.q, self._key.g), der_encode(self._key.x) def encode_pkcs8_public(self) -> Tuple[object, object]: """Encode a PKCS#8 format DSA public key""" return (self._key.p, self._key.q, self._key.g), der_encode(self._key.y) def encode_ssh_private(self) -> bytes: """Encode an SSH format DSA private key""" if not self._key.x: raise KeyExportError('Key is not private') return b''.join((MPInt(self._key.p), MPInt(self._key.q), MPInt(self._key.g), MPInt(self._key.y), MPInt(self._key.x))) def encode_ssh_public(self) -> bytes: """Encode an SSH format DSA public key""" return b''.join((MPInt(self._key.p), MPInt(self._key.q), MPInt(self._key.g), MPInt(self._key.y))) def encode_agent_cert_private(self) -> bytes: """Encode DSA certificate private key data for agent""" if not self._key.x: raise KeyExportError('Key is not private') return MPInt(self._key.x) def sign_ssh(self, data: bytes, sig_algorithm: bytes) -> bytes: """Compute an SSH-encoded signature of the specified data""" # pylint: disable=unused-argument if not self._key.x: raise ValueError('Private key needed for signing') sig = der_decode(self._key.sign(data, 'sha1')) r, s = cast(Tuple[int, int], sig) return String(r.to_bytes(20, 'big') + s.to_bytes(20, 'big')) def verify_ssh(self, data: bytes, sig_algorithm: bytes, packet: SSHPacket) -> bool: """Verify an SSH-encoded signature of the specified data""" # pylint: disable=unused-argument sig = packet.get_string() packet.check_end() if len(sig) != 40: return False r = int.from_bytes(sig[:20], 'big') s = int.from_bytes(sig[20:], 'big') return self._key.verify(data, der_encode((r, s)), 'sha1') register_public_key_alg(b'ssh-dss', _DSAKey, False) register_certificate_alg(1, b'ssh-dss', b'ssh-dss-cert-v01@openssh.com', _DSAKey, SSHOpenSSHCertificateV01, False) for alg in _DSAKey.x509_algorithms: register_x509_certificate_alg(alg, False) asyncssh-2.20.0/asyncssh/ecdsa.py000066400000000000000000000276641475467777400167650ustar00rootroot00000000000000# Copyright (c) 2013-2024 by Ron Frederick and others. 
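# Worked example (editorial addition, not part of the module's API): unlike
# the DER signatures produced by PyCA, the ssh-dss wire format used in
# sign_ssh() and verify_ssh() above is a single SSH string holding r and s
# as fixed-width 20-byte big-endian integers (40 bytes total).  The helpers
# below mirror that packing/unpacking using only the standard library.

from typing import Tuple


def pack_dss_sig(r: int, s: int) -> bytes:
    """Pack (r, s) into the 40-byte ssh-dss signature blob"""

    return r.to_bytes(20, 'big') + s.to_bytes(20, 'big')


def unpack_dss_sig(blob: bytes) -> Tuple[int, int]:
    """Split a 40-byte ssh-dss signature blob back into (r, s)"""

    if len(blob) != 40:
        raise ValueError('Invalid ssh-dss signature')

    return int.from_bytes(blob[:20], 'big'), int.from_bytes(blob[20:], 'big')


assert unpack_dss_sig(pack_dss_sig(12345, 67890)) == (12345, 67890)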
# # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation """ECDSA public key encryption handler""" from typing import Dict, Optional, Tuple, Union, cast from .asn1 import ASN1DecodeError, BitString, ObjectIdentifier, TaggedDERObject from .asn1 import der_encode, der_decode from .crypto import CryptoKey, ECDSAPrivateKey, ECDSAPublicKey from .crypto import lookup_ec_curve_by_params from .packet import MPInt, String, SSHPacket from .public_key import SSHKey, SSHOpenSSHCertificateV01 from .public_key import KeyImportError, KeyExportError from .public_key import register_public_key_alg, register_certificate_alg from .public_key import register_x509_certificate_alg _PrivateKeyArgs = Tuple[bytes, Union[bytes, int], bytes] _PublicKeyArgs = Tuple[bytes, bytes] # OID for EC prime fields PRIME_FIELD = ObjectIdentifier('1.2.840.10045.1.1') _hash_algs = {b'1.3.132.0.10': 'sha256', b'nistp256': 'sha256', b'nistp384': 'sha384', b'nistp521': 'sha512'} _alg_oids: Dict[bytes, ObjectIdentifier] = {} _alg_oid_map: Dict[ObjectIdentifier, bytes] = {} class _ECKey(SSHKey): """Handler for elliptic curve public key encryption""" _key: Union[ECDSAPrivateKey, ECDSAPublicKey] default_x509_hash = 'sha256' pem_name = b'EC' pkcs8_oid = ObjectIdentifier('1.2.840.10045.2.1') def __init__(self, key: CryptoKey): super().__init__(key) self.algorithm = b'ecdsa-sha2-' + self._key.curve_id self.sig_algorithms = (self.algorithm,) self.x509_algorithms = (b'x509v3-' + self.algorithm,) self.all_sig_algorithms = set(self.sig_algorithms) self._alg_oid = _alg_oids[self._key.curve_id] self._hash_alg = _hash_algs[self._key.curve_id] def __eq__(self, other: object) -> bool: # This isn't protected access - both objects are _ECKey instances # pylint: disable=protected-access return (isinstance(other, type(self)) and self._key.curve_id == other._key.curve_id and self._key.x == other._key.x and self._key.y == other._key.y and self._key.d == other._key.d) def __hash__(self) -> int: return hash((self._key.curve_id, self._key.x, self._key.y, self._key.d)) @classmethod def _lookup_curve(cls, alg_params: object) -> bytes: """Look up an EC curve matching the specified parameters""" if isinstance(alg_params, ObjectIdentifier): try: curve_id = _alg_oid_map[alg_params] except KeyError: raise KeyImportError('Unknown elliptic curve OID ' f'{alg_params}') from None elif (isinstance(alg_params, tuple) and len(alg_params) >= 5 and alg_params[0] == 1 and isinstance(alg_params[1], tuple) and len(alg_params[1]) == 2 and alg_params[1][0] == PRIME_FIELD and isinstance(alg_params[2], tuple) and len(alg_params[2]) >= 2 and isinstance(alg_params[3], bytes) and isinstance(alg_params[2][0], bytes) and isinstance(alg_params[2][1], bytes) and isinstance(alg_params[4], int)): p = alg_params[1][1] a = int.from_bytes(alg_params[2][0], 'big') b = int.from_bytes(alg_params[2][1], 'big') point = alg_params[3] n = alg_params[4] try: curve_id = lookup_ec_curve_by_params(p, a, b, point, n) except ValueError 
as exc: raise KeyImportError(str(exc)) from None else: raise KeyImportError('Invalid EC curve parameters') return curve_id @classmethod def generate(cls, algorithm: bytes) -> '_ECKey': # type: ignore """Generate a new EC private key""" # pylint: disable=arguments-differ # Strip 'ecdsa-sha2-' prefix of algorithm to get curve_id return cls(ECDSAPrivateKey.generate(algorithm[11:])) @classmethod def make_private(cls, key_params: object) -> SSHKey: """Construct an EC private key""" curve_id, private_value, public_value = \ cast(_PrivateKeyArgs, key_params) if isinstance(private_value, bytes): private_value = int.from_bytes(private_value, 'big') return cls(ECDSAPrivateKey.construct(curve_id, public_value, private_value)) @classmethod def make_public(cls, key_params: object) -> SSHKey: """Construct an EC public key""" curve_id, public_value = cast(_PublicKeyArgs, key_params) return cls(ECDSAPublicKey.construct(curve_id, public_value)) @classmethod def decode_pkcs1_private(cls, key_data: object) -> \ Optional[_PrivateKeyArgs]: """Decode a PKCS#1 format EC private key""" if (isinstance(key_data, tuple) and len(key_data) > 2 and key_data[0] == 1 and isinstance(key_data[1], bytes) and isinstance(key_data[2], TaggedDERObject) and key_data[2].tag == 0): alg_params = key_data[2].value private_key = key_data[1] if (len(key_data) > 3 and isinstance(key_data[3], TaggedDERObject) and key_data[3].tag == 1 and isinstance(key_data[3].value, BitString) and key_data[3].value.unused == 0): public_key: bytes = key_data[3].value.value else: public_key = b'' return cls._lookup_curve(alg_params), private_key, public_key else: return None @classmethod def decode_pkcs1_public(cls, key_data: object) -> \ Optional[_PublicKeyArgs]: """Decode a PKCS#1 format EC public key""" # pylint: disable=unused-argument raise KeyImportError('PKCS#1 not supported for EC public keys') @classmethod def decode_pkcs8_private(cls, alg_params: object, data: bytes) -> Optional[_PrivateKeyArgs]: """Decode a PKCS#8 format EC private key""" try: key_data = der_decode(data) except ASN1DecodeError: key_data = None if (isinstance(key_data, tuple) and len(key_data) > 1 and key_data[0] == 1 and isinstance(key_data[1], bytes)): private_key = key_data[1] if (len(key_data) > 2 and isinstance(key_data[2], TaggedDERObject) and key_data[2].tag == 1 and isinstance(key_data[2].value, BitString) and key_data[2].value.unused == 0): public_key = key_data[2].value.value else: public_key = b'' return cls._lookup_curve(alg_params), private_key, public_key else: return None @classmethod def decode_pkcs8_public(cls, alg_params: object, data: bytes) -> Optional[_PublicKeyArgs]: """Decode a PKCS#8 format EC public key""" if isinstance(alg_params, ObjectIdentifier): return cls._lookup_curve(alg_params), data else: return None @classmethod def decode_ssh_private(cls, packet: SSHPacket) -> _PrivateKeyArgs: """Decode an SSH format EC private key""" curve_id = packet.get_string() public_key = packet.get_string() private_key = packet.get_mpint() return curve_id, private_key, public_key @classmethod def decode_ssh_public(cls, packet: SSHPacket) -> _PublicKeyArgs: """Decode an SSH format EC public key""" curve_id = packet.get_string() public_key = packet.get_string() return curve_id, public_key def encode_public_tagged(self) -> object: """Encode an EC public key blob as a tagged bitstring""" return TaggedDERObject(1, BitString(self._key.public_value)) def encode_pkcs1_private(self) -> object: """Encode a PKCS#1 format EC private key""" if not self._key.private_value: raise 
KeyExportError('Key is not private') return (1, self._key.private_value, TaggedDERObject(0, self._alg_oid), self.encode_public_tagged()) def encode_pkcs1_public(self) -> object: """Encode a PKCS#1 format EC public key""" raise KeyExportError('PKCS#1 is not supported for EC public keys') def encode_pkcs8_private(self) -> Tuple[object, object]: """Encode a PKCS#8 format EC private key""" if not self._key.private_value: raise KeyExportError('Key is not private') return self._alg_oid, der_encode((1, self._key.private_value, self.encode_public_tagged())) def encode_pkcs8_public(self) -> Tuple[object, object]: """Encode a PKCS#8 format EC public key""" return self._alg_oid, self._key.public_value def encode_ssh_private(self) -> bytes: """Encode an SSH format EC private key""" if not self._key.d: raise KeyExportError('Key is not private') return b''.join((String(self._key.curve_id), String(self._key.public_value), MPInt(self._key.d))) def encode_ssh_public(self) -> bytes: """Encode an SSH format EC public key""" return b''.join((String(self._key.curve_id), String(self._key.public_value))) def encode_agent_cert_private(self) -> bytes: """Encode ECDSA certificate private key data for agent""" if not self._key.d: raise KeyExportError('Key is not private') return MPInt(self._key.d) def sign_ssh(self, data: bytes, sig_algorithm: bytes) -> bytes: """Compute an SSH-encoded signature of the specified data""" # pylint: disable=unused-argument if not self._key.private_value: raise ValueError('Private key needed for signing') sig = der_decode(self._key.sign(data, self._hash_alg)) r, s = cast(Tuple[int, int], sig) return String(MPInt(r) + MPInt(s)) def verify_ssh(self, data: bytes, sig_algorithm: bytes, packet: SSHPacket) -> bool: """Verify an SSH-encoded signature of the specified data""" # pylint: disable=unused-argument sig = packet.get_string() packet.check_end() packet = SSHPacket(sig) r = packet.get_mpint() s = packet.get_mpint() packet.check_end() return self._key.verify(data, der_encode((r, s)), self._hash_alg) for _curve_id, _oid_str in ((b'nistp521', '1.3.132.0.35'), (b'nistp384', '1.3.132.0.34'), (b'nistp256', '1.2.840.10045.3.1.7'), (b'1.3.132.0.10', '1.3.132.0.10')): _algorithm = b'ecdsa-sha2-' + _curve_id _cert_algorithm = _algorithm + b'-cert-v01@openssh.com' _x509_algorithm = b'x509v3-' + _algorithm _oid = ObjectIdentifier(_oid_str) _alg_oids[_curve_id] = _oid _alg_oid_map[_oid] = _curve_id register_public_key_alg(_algorithm, _ECKey, True, (_algorithm,)) register_certificate_alg(1, _algorithm, _cert_algorithm, _ECKey, SSHOpenSSHCertificateV01, True) register_x509_certificate_alg(_x509_algorithm, True) asyncssh-2.20.0/asyncssh/eddsa.py000066400000000000000000000162271475467777400167570ustar00rootroot00000000000000# Copyright (c) 2019-2021 by Ron Frederick and others. 
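# Worked example (editorial addition, not part of the module's API): the
# ecdsa-sha2-* signature blob built in sign_ssh() above is simply the two
# signature integers encoded as SSH mpint values and concatenated, rather
# than the fixed-width packing used by ssh-dss.  The round trip below uses
# the same MPInt/SSHPacket helpers imported by this module to show that
# structure with arbitrary illustrative values.

from asyncssh.packet import MPInt, SSHPacket

r, s = 0x1234, 0x5678
blob = MPInt(r) + MPInt(s)            # payload wrapped in String(...) above

packet = SSHPacket(blob)
assert packet.get_mpint() == r
assert packet.get_mpint() == s
packet.check_end()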
# # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation """EdDSA public key encryption handler""" from typing import Optional, Tuple, Union, cast from .asn1 import ASN1DecodeError, ObjectIdentifier, der_encode, der_decode from .crypto import EdDSAPrivateKey, EdDSAPublicKey from .crypto import ed25519_available, ed448_available from .packet import String, SSHPacket from .public_key import OMIT, SSHKey, SSHOpenSSHCertificateV01 from .public_key import KeyImportError, KeyExportError from .public_key import register_public_key_alg, register_certificate_alg from .public_key import register_x509_certificate_alg _PrivateKeyArgs = Tuple[bytes] _PublicKeyArgs = Tuple[bytes] class _EdKey(SSHKey): """Handler for EdDSA public key encryption""" _key: Union[EdDSAPrivateKey, EdDSAPublicKey] algorithm = b'' def __eq__(self, other: object) -> bool: # This isn't protected access - both objects are _EdKey instances # pylint: disable=protected-access return (isinstance(other, type(self)) and self._key.public_value == other._key.public_value and self._key.private_value == other._key.private_value) def __hash__(self) -> int: return hash((self._key.public_value, self._key.private_value)) @classmethod def generate(cls, algorithm: bytes) -> '_EdKey': # type: ignore """Generate a new EdDSA private key""" # pylint: disable=arguments-differ # Strip 'ssh-' prefix of algorithm to get curve_id return cls(EdDSAPrivateKey.generate(algorithm[4:])) @classmethod def make_private(cls, key_params: object) -> SSHKey: """Construct an EdDSA private key""" try: private_value, = cast(_PrivateKeyArgs, key_params) return cls(EdDSAPrivateKey.construct(cls.algorithm[4:], private_value)) except (TypeError, ValueError): raise KeyImportError('Invalid EdDSA private key') from None @classmethod def make_public(cls, key_params: object) -> SSHKey: """Construct an EdDSA public key""" try: public_value, = cast(_PublicKeyArgs, key_params) return cls(EdDSAPublicKey.construct(cls.algorithm[4:], public_value)) except (TypeError, ValueError): raise KeyImportError('Invalid EdDSA public key') from None @classmethod def decode_pkcs8_private(cls, alg_params: object, data: bytes) -> Optional[_PrivateKeyArgs]: """Decode a PKCS#8 format EdDSA private key""" # pylint: disable=unused-argument try: return (cast(bytes, der_decode(data)),) except ASN1DecodeError: return None @classmethod def decode_pkcs8_public(cls, alg_params: object, data: bytes) -> Optional[_PublicKeyArgs]: """Decode a PKCS#8 format EdDSA public key""" # pylint: disable=unused-argument return (data,) @classmethod def decode_ssh_private(cls, packet: SSHPacket) -> _PrivateKeyArgs: """Decode an SSH format EdDSA private key""" public_value = packet.get_string() private_value = packet.get_string() return (private_value[:-len(public_value)],) @classmethod def decode_ssh_public(cls, packet: SSHPacket) -> _PublicKeyArgs: """Decode an SSH format EdDSA public key""" public_value = packet.get_string() return (public_value,) def 
encode_pkcs8_private(self) -> Tuple[object, object]: """Encode a PKCS#8 format EdDSA private key""" if not self._key.private_value: raise KeyExportError('Key is not private') return OMIT, der_encode(self._key.private_value) def encode_pkcs8_public(self) -> Tuple[object, object]: """Encode a PKCS#8 format EdDSA public key""" return OMIT, self._key.public_value def encode_ssh_private(self) -> bytes: """Encode an SSH format EdDSA private key""" if self._key.private_value is None: raise KeyExportError('Key is not private') return b''.join((String(self._key.public_value), String(self._key.private_value + self._key.public_value))) def encode_ssh_public(self) -> bytes: """Encode an SSH format EdDSA public key""" return String(self._key.public_value) def encode_agent_cert_private(self) -> bytes: """Encode EdDSA certificate private key data for agent""" return self.encode_ssh_private() def sign_ssh(self, data: bytes, sig_algorithm: bytes) -> bytes: """Compute an SSH-encoded signature of the specified data""" # pylint: disable=unused-argument if not self._key.private_value: raise ValueError('Private key needed for signing') return String(self._key.sign(data)) def verify_ssh(self, data: bytes, sig_algorithm: bytes, packet: SSHPacket) -> bool: """Verify an SSH-encoded signature of the specified data""" # pylint: disable=unused-argument sig = packet.get_string() packet.check_end() return self._key.verify(data, sig) class _Ed25519Key(_EdKey): """Handler for Curve25519 public key encryption""" algorithm = b'ssh-ed25519' pkcs8_oid = ObjectIdentifier('1.3.101.112') sig_algorithms = (algorithm,) x509_algorithms = (b'x509v3-' + algorithm,) all_sig_algorithms = set(sig_algorithms) class _Ed448Key(_EdKey): """Handler for Curve448 public key encryption""" algorithm = b'ssh-ed448' pkcs8_oid = ObjectIdentifier('1.3.101.113') sig_algorithms = (algorithm,) x509_algorithms = (b'x509v3-' + algorithm,) all_sig_algorithms = set(sig_algorithms) if ed25519_available: # pragma: no branch register_public_key_alg(b'ssh-ed25519', _Ed25519Key, True) register_certificate_alg(1, b'ssh-ed25519', b'ssh-ed25519-cert-v01@openssh.com', _Ed25519Key, SSHOpenSSHCertificateV01, True) for alg in _Ed25519Key.x509_algorithms: register_x509_certificate_alg(alg, True) if ed448_available: # pragma: no branch register_public_key_alg(b'ssh-ed448', _Ed448Key, True) register_certificate_alg(1, b'ssh-ed448', b'ssh-ed448-cert-v01@openssh.com', _Ed448Key, SSHOpenSSHCertificateV01, True) for alg in _Ed448Key.x509_algorithms: register_x509_certificate_alg(alg, True) asyncssh-2.20.0/asyncssh/editor.py000066400000000000000000000733011475467777400171610ustar00rootroot00000000000000# Copyright (c) 2016-2023 by Ron Frederick and others. 
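# Worked example (editorial addition, not part of the module's API): in the
# SSH private key encoding used above, an Ed25519 key is stored as two SSH
# strings -- the 32-byte public value, followed by a 64-byte string holding
# the private seed with the public value appended -- which is why
# decode_ssh_private() slices off len(public_value) bytes to recover the
# seed.  The sketch below reproduces that framing with stand-in byte strings
# and the same String/SSHPacket helpers used by this module.

from asyncssh.packet import SSHPacket, String

public = bytes(range(32))             # stand-in for the real public value
private = bytes(range(32, 64))        # stand-in for the 32-byte private seed

blob = String(public) + String(private + public)

packet = SSHPacket(blob)
pub = packet.get_string()
priv_and_pub = packet.get_string()
packet.check_end()

assert pub == public
assert priv_and_pub[:-len(pub)] == private   # the slice done in decode above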
# # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation """Input line editor""" import re from functools import partial from typing import TYPE_CHECKING, Callable, Dict, List from typing import Optional, Set, Tuple, Union, cast from unicodedata import east_asian_width from .session import DataType if TYPE_CHECKING: # pylint: disable=cyclic-import from .channel import SSHServerChannel from .session import SSHServerSession _CharDict = Dict[str, object] _CharHandler = Callable[['SSHLineEditor'], None] _KeyHandler = Callable[[str, int], Union[bool, Tuple[str, int]]] _DEFAULT_WIDTH = 80 _ansi_terminals = ('ansi', 'cygwin', 'linux', 'putty', 'screen', 'teraterm', 'cit80', 'vt100', 'vt102', 'vt220', 'vt320', 'xterm', 'xterm-color', 'xterm-16color', 'xterm-256color', 'rxvt', 'rxvt-color') def _is_wide(ch: str) -> bool: """Return display width of character""" return east_asian_width(ch) in 'WF' class SSHLineEditor: """Input line editor""" def __init__(self, chan: 'SSHServerChannel[str]', session: 'SSHServerSession[str]', line_echo: bool, history_size: int, max_line_length: int, term_type: str, width: int): self._chan = chan self._session = session self._line_echo = line_echo self._line_pending = False self._history_size = history_size if history_size > 0 else 0 self._max_line_length = max_line_length self._wrap = term_type in _ansi_terminals self._width = width or _DEFAULT_WIDTH self._line_mode = True self._echo = True self._start_column = 0 self._end_column = 0 self._cursor = 0 self._left_pos = 0 self._right_pos = 0 self._pos = 0 self._line = '' self._bell_rung = False self._early_wrap: Set[int] = set() self._outbuf: List[str] = [] self._keymap: _CharDict = {} self._key_state = self._keymap self._erased = '' self._history: List[str] = [] self._history_index = 0 for func, keys in self._keylist: for key in keys: self._add_key(key, func) self._build_printable() def _add_key(self, key: str, func: _CharHandler) -> None: """Add a key to the keymap""" keymap = self._keymap for ch in key[:-1]: if ch not in keymap: keymap[ch] = {} keymap = cast(_CharDict, keymap[ch]) keymap[key[-1]] = func def _del_key(self, key: str) -> None: """Delete a key from the keymap""" keymap = self._keymap for ch in key[:-1]: if ch not in keymap: return keymap = cast(_CharDict, keymap[ch]) keymap.pop(key[-1], None) def _build_printable(self) -> None: """Build a regex of printable ASCII non-registered keys""" def _escape(c: int) -> str: """Backslash escape special characters in regex character range""" ch = chr(c) return ('\\' if (ch in '-&|[]\\^~') else '') + ch def _is_printable(ch: str) -> bool: """Return if character is printable and has no handler""" return ch.isprintable() and ch not in keys pat: List[str] = [] keys = self._keymap.keys() start = ord(' ') limit = 0x10000 while start < limit: while start < limit and not _is_printable(chr(start)): start += 1 end = start while _is_printable(chr(end)): end += 1 pat.append(_escape(start)) if start != end - 1: 
pat.append('-' + _escape(end - 1)) start = end + 1 self._printable = re.compile('[' + ''.join(pat) + ']*') def _char_width(self, pos: int) -> int: """Return width of character at specified position""" return 1 + _is_wide(self._line[pos]) + ((pos + 1) in self._early_wrap) def _determine_column(self, data: str, column: int, pos: Optional[int] = None) -> Tuple[str, int]: """Determine new output column after output occurs""" escaped = False offset = pos last_wrap_pos = pos wrapped_data = [] for ch in data: if ch == '\b': column -= 1 elif ch == '\x1b': escaped = True elif escaped: if ch == 'm': escaped = False else: if _is_wide(ch) and (column % self._width) == self._width - 1: column += 1 if pos is not None: assert last_wrap_pos is not None assert offset is not None wrapped_data.append(data[last_wrap_pos - offset: pos - offset]) last_wrap_pos = pos self._early_wrap.add(pos) else: if pos is not None: self._early_wrap.discard(pos) column += 1 + _is_wide(ch) if pos is not None: pos += 1 if pos is not None: assert last_wrap_pos is not None assert offset is not None wrapped_data.append(data[last_wrap_pos - offset:]) return ' '.join(wrapped_data), column else: return data, column def _output(self, data: str, pos: Optional[int] = None) -> None: """Generate output and calculate new output column""" idx = data.rfind('\n') if idx >= 0: self._outbuf.append(data[:idx+1]) tail = data[idx+1:] self._cursor = 0 else: tail = data data, self._cursor = self._determine_column(tail, self._cursor, pos) self._outbuf.append(data) if self._cursor and self._cursor % self._width == 0: self._outbuf.append(' \b') def _ring_bell(self) -> None: """Ring the terminal bell""" if not self._bell_rung: self._outbuf.append('\a') self._bell_rung = True def _update_input_window(self, new_pos: int) -> int: """Update visible input window when not wrapping onto multiple lines""" line_len = len(self._line) if new_pos < self._left_pos: self._left_pos = new_pos else: if new_pos < line_len: new_pos += 1 pos = self._pos column = self._cursor while pos < new_pos: column += self._char_width(pos) pos += 1 if column >= self._width: while column >= self._width: column -= self._char_width(self._left_pos) self._left_pos += 1 else: while self._left_pos > 0: column += self._char_width(self._left_pos) if column < self._width: self._left_pos -= 1 else: break column = self._start_column self._right_pos = self._left_pos while self._right_pos < line_len: ch_width = self._char_width(self._right_pos) if column + ch_width < self._width: self._right_pos += 1 column += ch_width else: break return column def _move_cursor(self, column: int) -> None: """Move the cursor to selected position in input line""" start_row = self._cursor // self._width start_col = self._cursor % self._width end_row = column // self._width end_col = column % self._width if end_row < start_row: self._outbuf.append('\x1b[' + str(start_row-end_row) + 'A') elif end_row > start_row: self._outbuf.append('\x1b[' + str(end_row-start_row) + 'B') if end_col > start_col: self._outbuf.append('\x1b[' + str(end_col-start_col) + 'C') elif end_col < start_col: self._outbuf.append('\x1b[' + str(start_col-end_col) + 'D') self._cursor = column def _move_back(self, column: int) -> None: """Move the cursor backward to selected position in input line""" if self._wrap: self._move_cursor(column) else: self._outbuf.append('\b' * (self._cursor - column)) self._cursor = column def _clear_to_end(self) -> None: """Clear any remaining characters from previous input line""" column = self._cursor remaining = 
self._end_column - column if remaining > 0: self._outbuf.append(' ' * remaining) self._cursor = self._end_column if self._cursor % self._width == 0: self._outbuf.append(' \b') self._move_back(column) self._end_column = column def _erase_input(self) -> None: """Erase current input line""" self._move_cursor(self._start_column) self._clear_to_end() self._early_wrap.clear() def _draw_input(self) -> None: """Draw current input line""" if self._line and self._echo: if self._wrap: self._output(self._line[:self._pos], 0) column = self._cursor self._output(self._line[self._pos:], self._pos) else: self._update_input_window(self._pos) self._output(self._line[self._left_pos:self._pos]) column = self._cursor self._output(self._line[self._pos:self._right_pos]) self._end_column = self._cursor self._move_back(column) def _reposition(self, new_pos: int, new_column: int) -> None: """Reposition the cursor to selected position in input""" if self._echo: if self._wrap: self._move_cursor(new_column) else: self._update_input(self._pos, self._cursor, new_pos) self._pos = new_pos def _update_input(self, pos: int, column: int, new_pos: int) -> None: """Update selected portion of current input line""" if self._echo: if self._wrap: if pos in self._early_wrap: column -= 1 self._move_cursor(column) prev_wrap = new_pos in self._early_wrap self._output(self._line[pos:new_pos], pos) column = self._cursor self._output(self._line[new_pos:], new_pos) column += (new_pos in self._early_wrap) - prev_wrap else: self._update_input_window(new_pos) self._move_back(self._start_column) self._output(self._line[self._left_pos:new_pos]) column = self._cursor self._output(self._line[new_pos:self._right_pos]) self._clear_to_end() self._move_back(column) self._pos = new_pos def _reset_line(self) -> None: """Reset input line to empty""" self._line = '' self._left_pos = 0 self._right_pos = 0 self._pos = 0 self._start_column = self._cursor self._end_column = self._cursor def _reset_pending(self) -> None: """Reset a pending echoed line if any""" if self._line_pending: self._erase_input() self._reset_line() self._line_pending = False def _insert_printable(self, data: str) -> None: """Insert data into the input line""" line_len = len(self._line) data_len = len(data) if self._max_line_length: if line_len + data_len > self._max_line_length: self._ring_bell() data_len = self._max_line_length - line_len data = data[:data_len] if data: pos = self._pos new_pos = pos + data_len self._line = self._line[:pos] + data + self._line[pos:] self._update_input(pos, self._cursor, new_pos) def _end_line(self) -> None: """End the current input line and send it to the session""" line = self._line need_wrap = (self._echo and not self._wrap and (self._left_pos > 0 or self._right_pos < len(line))) if self._line_echo or need_wrap: if need_wrap: self._output('\b' * (self._cursor - self._start_column) + line) else: self._move_to_end() self._output('\r\n') self._reset_line() else: self._move_to_end() self._line_pending = True if self._echo and self._history_size and line: self._history.append(line) self._history = self._history[-self._history_size:] self._history_index = len(self._history) self._session.data_received(line + '\n', None) def _eof_or_delete(self) -> None: """Erase character to the right, or send EOF if input line is empty""" if not self._line: self._session.soft_eof_received() else: self._erase_right() def _erase_left(self) -> None: """Erase character to the left""" if self._pos > 0: pos = self._pos - 1 column = self._cursor - self._char_width(pos) 
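            # (editorial note) _char_width() reports how many display
            # columns the erased character occupied (two for East Asian
            # wide characters, plus one if it triggered an early wrap), so
            # the cursor column is moved back by that amount before the
            # line is redrawn from the erase position.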
self._line = self._line[:pos] + self._line[pos+1:] self._update_input(pos, column, pos) else: self._ring_bell() def _erase_right(self) -> None: """Erase character to the right""" if self._pos < len(self._line): pos = self._pos self._line = self._line[:pos] + self._line[pos+1:] self._update_input(pos, self._cursor, pos) else: self._ring_bell() def _erase_line(self) -> None: """Erase entire input line""" self._erased = self._line self._line = '' self._update_input(0, self._start_column, 0) def _erase_to_end(self) -> None: """Erase to end of input line""" pos = self._pos self._erased = self._line[pos:] self._line = self._line[:pos] self._update_input(pos, self._cursor, pos) def _handle_key(self, key: str, handler: _KeyHandler) -> None: """Call an external key handler""" result = handler(self._line, self._pos) if result is True: if key.isprintable(): self._insert_printable(key) else: self._ring_bell() elif result is False: self._ring_bell() else: line, new_pos = cast(Tuple[str, int], result) if new_pos < 0: self._session.signal_received(line) else: self._line = line self._update_input(0, self._start_column, new_pos) def _history_prev(self) -> None: """Replace input with previous line in history""" if self._history_index > 0: self._history_index -= 1 self._line = self._history[self._history_index] self._update_input(0, self._start_column, len(self._line)) else: self._ring_bell() def _history_next(self) -> None: """Replace input with next line in history""" if self._history_index < len(self._history): self._history_index += 1 if self._history_index < len(self._history): self._line = self._history[self._history_index] else: self._line = '' self._update_input(0, self._start_column, len(self._line)) else: self._ring_bell() def _move_left(self) -> None: """Move left in input line""" if self._pos > 0: pos = self._pos - 1 column = self._cursor - self._char_width(pos) self._reposition(pos, column) else: self._ring_bell() def _move_right(self) -> None: """Move right in input line""" if self._pos < len(self._line): pos = self._pos column = self._cursor + self._char_width(pos) self._reposition(pos + 1, column) else: self._ring_bell() def _move_to_start(self) -> None: """Move to start of input line""" self._reposition(0, self._start_column) def _move_to_end(self) -> None: """Move to end of input line""" self._reposition(len(self._line), self._end_column) def _redraw(self) -> None: """Redraw input line""" self._erase_input() self._draw_input() def _insert_erased(self) -> None: """Insert previously erased input""" self._insert_printable(self._erased) def _send_break(self) -> None: """Send break to session""" self._session.break_received(0) _keylist = ((_end_line, ('\n', '\r', '\x1bOM')), (_eof_or_delete, ('\x04',)), (_erase_left, ('\x08', '\x7f')), (_erase_right, ('\x1b[3~',)), (_erase_line, ('\x15',)), (_erase_to_end, ('\x0b',)), (_history_prev, ('\x10', '\x1b[A', '\x1bOA')), (_history_next, ('\x0e', '\x1b[B', '\x1bOB')), (_move_left, ('\x02', '\x1b[D', '\x1bOD')), (_move_right, ('\x06', '\x1b[C', '\x1bOC')), (_move_to_start, ('\x01', '\x1b[H', '\x1b[1~')), (_move_to_end, ('\x05', '\x1b[F', '\x1b[4~')), (_redraw, ('\x12',)), (_insert_erased, ('\x19',)), (_send_break, ('\x03', '\x1b[33~'))) def register_key(self, key: str, handler: _KeyHandler) -> None: """Register a handler to be called when a key is pressed""" self._add_key(key, partial(SSHLineEditor._handle_key, key=key, handler=handler)) self._build_printable() def unregister_key(self, key: str) -> None: """Remove the handler associated with a key""" 
self._del_key(key) self._build_printable() def set_input(self, line: str, pos: int) -> None: """Set input line and cursor position""" self._reset_pending() self._line = line self._update_input(0, self._start_column, pos) def set_line_mode(self, line_mode: bool) -> None: """Enable/disable input line editing""" self._reset_pending() if self._line and not line_mode: data = self._line self._erase_input() self._line = '' self._session.data_received(data, None) self._line_mode = line_mode def set_echo(self, echo: bool) -> None: """Enable/disable echoing of input in line mode""" self._reset_pending() if self._echo and not echo: self._erase_input() self._echo = False elif echo and not self._echo: self._echo = True self._draw_input() def set_width(self, width: int) -> None: """Set terminal line width""" self._reset_pending() self._width = width or _DEFAULT_WIDTH if self._wrap: _, self._cursor = self._determine_column(self._line, self._start_column, 0) self._redraw() def process_input(self, data: str, datatype: DataType) -> None: """Process input from channel""" if self._line_mode: data_len = len(data) idx = 0 while idx < data_len: self._reset_pending() ch = data[idx] idx += 1 if ch in self._key_state: key_state = self._key_state[ch] if callable(key_state): try: cast(_CharHandler, key_state)(self) finally: self._key_state = self._keymap else: self._key_state = cast(_CharDict, key_state) elif self._key_state == self._keymap and ch.isprintable(): match = self._printable.match(data, idx - 1) assert match is not None match = match[0] if match: self._insert_printable(match) idx += len(match) - 1 else: self._insert_printable(ch) else: self._key_state = self._keymap self._ring_bell() self._bell_rung = False if self._outbuf: self._chan.write(''.join(self._outbuf)) self._outbuf.clear() else: self._session.data_received(data, datatype) def process_output(self, data: str) -> None: """Process output to channel""" if self._line_pending: if data.startswith(self._line): self._start_column = self._cursor data = data[len(self._line):] else: self._erase_input() self._reset_line() self._line_pending = False data = data.replace('\n', '\r\n') self._erase_input() self._output(data) if not self._wrap: self._cursor %= self._width self._start_column = self._cursor self._end_column = self._cursor self._draw_input() self._chan.write(''.join(self._outbuf)) self._outbuf.clear() class SSHLineEditorChannel: """Input line editor channel wrapper When creating server channels with `line_editor` set to `True`, this class is wrapped around the channel, providing the caller with the ability to enable and disable input line editing and echoing. .. note:: Line editing is only available when a pseudo-terminal is requested on the server channel and the character encoding on the channel is not set to `None`. 
""" def __init__(self, orig_chan: 'SSHServerChannel[str]', orig_session: 'SSHServerSession[str]', line_echo: bool, history_size: int, max_line_length: int): self._orig_chan = orig_chan self._orig_session = orig_session self._line_echo = line_echo self._history_size = history_size self._max_line_length = max_line_length self._editor: Optional[SSHLineEditor] = None def __getattr__(self, attr: str): """Delegate most channel functions to original channel""" return getattr(self._orig_chan, attr) def create_editor(self) -> Optional[SSHLineEditor]: """Create input line editor if encoding and terminal type are set""" encoding, _ = self._orig_chan.get_encoding() term_type = self._orig_chan.get_terminal_type() width = self._orig_chan.get_terminal_size()[0] if encoding and term_type: self._editor = SSHLineEditor( self._orig_chan, self._orig_session, self._line_echo, self._history_size, self._max_line_length, term_type, width) return self._editor def register_key(self, key: str, handler: _KeyHandler) -> None: """Register a handler to be called when a key is pressed This method registers a handler function which will be called when a user presses the specified key while inputting a line. The handler will be called with arguments of the current input line and cursor position, and updated versions of these two values should be returned as a tuple. The handler can also return a tuple of a signal name and negative cursor position to cause a signal to be delivered on the channel. In this case, the current input line is left unchanged but the signal is delivered before processing any additional input. This can be used to define "hot keys" that trigger actions unrelated to editing the input. If the registered key is printable text, returning `True` will insert that text at the current cursor position, acting as if no handler was registered for that key. This is useful if you want to perform a special action in some cases but not others, such as based on the current cursor position. Returning `False` will ring the bell and leave the input unchanged, indicating the requested action could not be performed. :param key: The key sequence to look for :param handler: The handler function to call when the key is pressed :type key: `str` :type handler: `callable` """ assert self._editor is not None self._editor.register_key(key, handler) def unregister_key(self, key: str) -> None: """Remove the handler associated with a key This method removes a handler function associated with the specified key. If the key sequence is printable, this will cause it to return to being inserted at the current position when pressed. Otherwise, it will cause the bell to ring to signal the key is not understood. :param key: The key sequence to look for :type key: `str` """ assert self._editor is not None self._editor.unregister_key(key) def clear_input(self) -> None: """Clear input line This method clears the current input line. """ assert self._editor is not None self._editor.set_input('', 0) def set_input(self, line: str, pos: int) -> None: """Clear input line This method sets the current input line and cursor position. :param line: The new input line :param pos: The new cursor position within the input line :type line: `str` :type pos: `int` """ assert self._editor is not None self._editor.set_input(line, pos) def set_line_mode(self, line_mode: bool) -> None: """Enable/disable input line editing This method enabled or disables input line editing. 
When set, only full lines of input are sent to the session, and each line of input can be edited before it is sent. :param line_mode: Whether or not to process input a line at a time :type line_mode: `bool` """ self._orig_chan.logger.info('%s line editor', 'Enabling' if line_mode else 'Disabling') assert self._editor is not None self._editor.set_line_mode(line_mode) def set_echo(self, echo: bool) -> None: """Enable/disable echoing of input in line mode This method enables or disables echoing of input data when input line editing is enabled. :param echo: Whether or not input to echo input as it is entered :type echo: `bool` """ self._orig_chan.logger.info('%s echo', 'Enabling' if echo else 'Disabling') assert self._editor is not None self._editor.set_echo(echo) def write(self, data: str, datatype: DataType = None) -> None: """Process data written to the channel""" if self._editor and datatype is None: self._editor.process_output(data) else: self._orig_chan.write(data, datatype) class SSHLineEditorSession: """Input line editor session wrapper""" def __init__(self, chan: SSHLineEditorChannel, orig_session: 'SSHServerSession[str]'): self._chan = chan self._orig_session = orig_session self._editor: Optional[SSHLineEditor] = None def __getattr__(self, attr: str): """Delegate most channel functions to original session""" return getattr(self._orig_session, attr) def session_started(self) -> None: """Start a session for this newly opened server channel""" self._editor = self._chan.create_editor() self._orig_session.session_started() def terminal_size_changed(self, width: int, height: int, pixwidth: int, pixheight: int) -> None: """The terminal size has changed""" if self._editor: self._editor.set_width(width) self._orig_session.terminal_size_changed(width, height, pixwidth, pixheight) def data_received(self, data: str, datatype: DataType) -> None: """Process data received from the channel""" if self._editor: self._editor.process_input(data, datatype) else: self._orig_session.data_received(data, datatype) def eof_received(self) -> Optional[bool]: """Process EOF received from the channel""" if self._editor: self._editor.set_line_mode(False) return self._orig_session.eof_received() asyncssh-2.20.0/asyncssh/encryption.py000066400000000000000000000253741475467777400200740ustar00rootroot00000000000000# Copyright (c) 2013-2023 by Ron Frederick and others. 
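# [Editorial illustration -- not part of the asyncssh sources]
# A minimal usage sketch for the line-editor channel API defined in
# editor.py above (register_key(), set_echo(), set_input()).  It assumes
# the client requested a pseudo-terminal so the SSHLineEditorChannel
# wrapper is active; the host key path, port number, key sequence and
# prompt text below are arbitrary placeholders, not asyncssh defaults.

import asyncio
import asyncssh

def insert_flag(line: str, pos: int):
    """Key handler: splice a canned string into the input at the cursor"""

    new_line = line[:pos] + '--verbose' + line[pos:]
    return new_line, pos + len('--verbose')

async def handle_client(process: asyncssh.SSHServerProcess) -> None:
    editor_chan = process.channel                  # SSHLineEditorChannel

    editor_chan.register_key('\x14', insert_flag)  # Ctrl-T is unassigned

    editor_chan.set_echo(False)                    # e.g. reading a secret
    process.stdout.write('Token: ')
    token = (await process.stdin.readline()).rstrip('\n')
    editor_chan.set_echo(True)

    process.stdout.write(f'\nGot {len(token)} characters\n')
    process.exit(0)

async def run_demo() -> None:
    server = await asyncssh.listen('', 8022,
                                   server_host_keys=['ssh_host_key'],
                                   process_factory=handle_client)
    await server.wait_closed()

# asyncio.run(run_demo())   # uncommenting this starts a real listener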
# # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation """Symmetric key encryption handlers""" from typing import Dict, List, Optional, Tuple, Type from .crypto import BasicCipher, GCMCipher, ChachaCipher, get_cipher_params from .mac import MAC, get_mac_params, get_mac from .packet import UInt64 _EncParams = Tuple[int, int, int, int, int, bool] _EncParamsMap = Dict[bytes, Tuple[Type['Encryption'], str]] _enc_algs: List[bytes] = [] _default_enc_algs: List[bytes] = [] _enc_params: _EncParamsMap = {} class Encryption: """Parent class for SSH packet encryption objects""" @classmethod def new(cls, cipher_name: str, key: bytes, iv: bytes, mac_alg: bytes = b'', mac_key: bytes = b'', etm: bool = False) -> 'Encryption': """Construct a new SSH packet encryption object""" raise NotImplementedError @classmethod def get_mac_params(cls, mac_alg: bytes) -> Tuple[int, int, bool]: """Get parameters of the MAC algorithm used with this encryption""" return get_mac_params(mac_alg) def encrypt_packet(self, seq: int, header: bytes, packet: bytes) -> Tuple[bytes, bytes]: """Encrypt and sign an SSH packet""" raise NotImplementedError def decrypt_header(self, seq: int, first_block: bytes, header_len: int) -> Tuple[bytes, bytes]: """Decrypt an SSH packet header""" raise NotImplementedError def decrypt_packet(self, seq: int, first: bytes, rest: bytes, header_len: int, mac: bytes) -> Optional[bytes]: """Verify the signature of and decrypt an SSH packet""" raise NotImplementedError class BasicEncryption(Encryption): """Shim for basic encryption""" def __init__(self, cipher: BasicCipher, mac: MAC): self._cipher = cipher self._mac = mac @classmethod def new(cls, cipher_name: str, key: bytes, iv: bytes, mac_alg: bytes = b'', mac_key: bytes = b'', etm: bool = False) -> 'BasicEncryption': """Construct a new SSH packet encryption object for basic ciphers""" cipher = BasicCipher(cipher_name, key, iv) mac = get_mac(mac_alg, mac_key) if etm: return ETMEncryption(cipher, mac) else: return cls(cipher, mac) def encrypt_packet(self, seq: int, header: bytes, packet: bytes) -> Tuple[bytes, bytes]: """Encrypt and sign an SSH packet""" packet = header + packet mac = self._mac.sign(seq, packet) if self._mac else b'' return self._cipher.encrypt(packet), mac def decrypt_header(self, seq: int, first_block: bytes, header_len: int) -> Tuple[bytes, bytes]: """Decrypt an SSH packet header""" first_block = self._cipher.decrypt(first_block) return first_block, first_block[:header_len] def decrypt_packet(self, seq: int, first: bytes, rest: bytes, header_len: int, mac: bytes) -> Optional[bytes]: """Verify the signature of and decrypt an SSH packet""" packet = first + self._cipher.decrypt(rest) if self._mac.verify(seq, packet, mac): return packet[header_len:] else: return None class ETMEncryption(BasicEncryption): """Shim for encrypt-then-mac encryption""" def encrypt_packet(self, seq: int, header: bytes, packet: bytes) -> Tuple[bytes, bytes]: """Encrypt and sign an SSH 
packet""" packet = header + self._cipher.encrypt(packet) return packet, self._mac.sign(seq, packet) def decrypt_header(self, seq: int, first_block: bytes, header_len: int) -> Tuple[bytes, bytes]: """Decrypt an SSH packet header""" return first_block, first_block[:header_len] def decrypt_packet(self, seq: int, first: bytes, rest: bytes, header_len: int, mac: bytes) -> Optional[bytes]: """Verify the signature of and decrypt an SSH packet""" packet = first + rest if self._mac.verify(seq, packet, mac): return self._cipher.decrypt(packet[header_len:]) else: return None class GCMEncryption(Encryption): """Shim for GCM encryption""" def __init__(self, cipher: GCMCipher): self._cipher = cipher @classmethod def new(cls, cipher_name: str, key: bytes, iv: bytes, mac_alg: bytes = b'', mac_key: bytes = b'', etm: bool = False) -> 'GCMEncryption': """Construct a new SSH packet encryption object for GCM ciphers""" return cls(GCMCipher(cipher_name, key, iv)) @classmethod def get_mac_params(cls, mac_alg: bytes) -> Tuple[int, int, bool]: """Get parameters of the MAC algorithm used with this encryption""" return 0, 16, True def encrypt_packet(self, seq: int, header: bytes, packet: bytes) -> Tuple[bytes, bytes]: """Encrypt and sign an SSH packet""" return self._cipher.encrypt_and_sign(header, packet) def decrypt_header(self, seq: int, first_block: bytes, header_len: int) -> Tuple[bytes, bytes]: """Decrypt an SSH packet header""" return first_block, first_block[:header_len] def decrypt_packet(self, seq: int, first: bytes, rest: bytes, header_len: int, mac: bytes) -> Optional[bytes]: """Verify the signature of and decrypt an SSH packet""" return self._cipher.verify_and_decrypt(first[:header_len], first[header_len:] + rest, mac) class ChachaEncryption(Encryption): """Shim for chacha20-poly1305 encryption""" def __init__(self, cipher: ChachaCipher): self._cipher = cipher @classmethod def new(cls, cipher_name: str, key: bytes, iv: bytes, mac_alg: bytes = b'', mac_key: bytes = b'', etm: bool = False) -> 'ChachaEncryption': """Construct a new SSH packet encryption object for Chacha ciphers""" return cls(ChachaCipher(key)) @classmethod def get_mac_params(cls, mac_alg: bytes) -> Tuple[int, int, bool]: """Get parameters of the MAC algorithm used with this encryption""" return 0, 16, True def encrypt_packet(self, seq: int, header: bytes, packet: bytes) -> Tuple[bytes, bytes]: """Encrypt and sign an SSH packet""" return self._cipher.encrypt_and_sign(header, packet, UInt64(seq)) def decrypt_header(self, seq: int, first_block: bytes, header_len: int) -> Tuple[bytes, bytes]: """Decrypt an SSH packet header""" return (first_block, self._cipher.decrypt_header(first_block[:header_len], UInt64(seq))) def decrypt_packet(self, seq: int, first: bytes, rest: bytes, header_len: int, mac: bytes) -> Optional[bytes]: """Verify the signature of and decrypt an SSH packet""" return self._cipher.verify_and_decrypt(first[:header_len], first[header_len:] + rest, UInt64(seq), mac) def register_encryption_alg(enc_alg: bytes, encryption: Type[Encryption], cipher_name: str, default: bool) -> None: """Register an encryption algorithm""" try: get_cipher_params(cipher_name) except KeyError: pass else: _enc_algs.append(enc_alg) if default: _default_enc_algs.append(enc_alg) _enc_params[enc_alg] = (encryption, cipher_name) def get_encryption_algs() -> List[bytes]: """Return supported encryption algorithms""" return _enc_algs def get_default_encryption_algs() -> List[bytes]: """Return default encryption algorithms""" return _default_enc_algs def 
get_encryption_params(enc_alg: bytes, mac_alg: bytes = b'') -> _EncParams: """Get parameters of an encryption and MAC algorithm""" encryption, cipher_name = _enc_params[enc_alg] enc_keysize, enc_ivsize, enc_blocksize = get_cipher_params(cipher_name) mac_keysize, mac_hashsize, etm = encryption.get_mac_params(mac_alg) return (enc_keysize, enc_ivsize, enc_blocksize, mac_keysize, mac_hashsize, etm) def get_encryption(enc_alg: bytes, key: bytes, iv: bytes, mac_alg: bytes = b'', mac_key: bytes = b'', etm: bool = False) -> Encryption: """Return an object which can encrypt and decrypt SSH packets""" encryption, cipher_name = _enc_params[enc_alg] return encryption.new(cipher_name, key, iv, mac_alg, mac_key, etm) _enc_alg_list = ( (b'chacha20-poly1305@openssh.com', ChachaEncryption, 'chacha20-poly1305', True), (b'aes256-gcm@openssh.com', GCMEncryption, 'aes256-gcm', True), (b'aes128-gcm@openssh.com', GCMEncryption, 'aes128-gcm', True), (b'aes256-ctr', BasicEncryption, 'aes256-ctr', True), (b'aes192-ctr', BasicEncryption, 'aes192-ctr', True), (b'aes128-ctr', BasicEncryption, 'aes128-ctr', True), (b'aes256-cbc', BasicEncryption, 'aes256-cbc', False), (b'aes192-cbc', BasicEncryption, 'aes192-cbc', False), (b'aes128-cbc', BasicEncryption, 'aes128-cbc', False), (b'3des-cbc', BasicEncryption, 'des3-cbc', False), (b'blowfish-cbc', BasicEncryption, 'blowfish-cbc', False), (b'cast128-cbc', BasicEncryption, 'cast128-cbc', False), (b'seed-cbc@ssh.com', BasicEncryption, 'seed-cbc', False), (b'arcfour256', BasicEncryption, 'arcfour256', False), (b'arcfour128', BasicEncryption, 'arcfour128', False), (b'arcfour', BasicEncryption, 'arcfour', False) ) for _enc_alg_args in _enc_alg_list: register_encryption_alg(*_enc_alg_args) asyncssh-2.20.0/asyncssh/forward.py000066400000000000000000000157151475467777400173440ustar00rootroot00000000000000# Copyright (c) 2013-2024 by Ron Frederick and others. 
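# [Editorial illustration -- not part of the asyncssh sources]
# A hedged sketch of how the shims registered in encryption.py above fit
# together.  These objects are normally created and driven by asyncssh's
# connection layer; calling get_encryption()/get_encryption_params()
# directly like this is purely illustrative, and the algorithm names and
# the toy 5-byte header / 11-byte payload are arbitrary choices (16 bytes
# in total, so the whole packet fits in the first cipher block).

import os

from asyncssh.encryption import get_encryption, get_encryption_params

enc_alg, mac_alg = b'aes256-ctr', b'hmac-sha2-256'

(enc_keysize, enc_ivsize, enc_blocksize,
 mac_keysize, _mac_hashsize, etm) = get_encryption_params(enc_alg, mac_alg)

key, iv = os.urandom(enc_keysize), os.urandom(enc_ivsize)
mac_key = os.urandom(mac_keysize)

send = get_encryption(enc_alg, key, iv, mac_alg, mac_key, etm)
recv = get_encryption(enc_alg, key, iv, mac_alg, mac_key, etm)

seq, header, payload = 0, bytes(5), b'x' * 11
ciphertext, mac = send.encrypt_packet(seq, header, payload)

first, _hdr = recv.decrypt_header(seq, ciphertext[:enc_blocksize], len(header))
assert recv.decrypt_packet(seq, first, ciphertext[enc_blocksize:],
                           len(header), mac) == payload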
# # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation """SSH port forwarding handlers""" import asyncio import socket from types import TracebackType from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Optional from typing import Type, cast from typing_extensions import Self from .misc import ChannelOpenError, SockAddr if TYPE_CHECKING: # pylint: disable=cyclic-import from .connection import SSHConnection SSHForwarderCoro = Callable[..., Awaitable] class SSHForwarder(asyncio.BaseProtocol): """SSH port forwarding connection handler""" def __init__(self, peer: Optional['SSHForwarder'] = None, extra: Optional[Dict[str, Any]] = None): self._peer = peer self._transport: Optional[asyncio.Transport] = None self._inpbuf = b'' self._eof_received = False if peer: peer.set_peer(self) if extra is None: extra = {} self._extra = extra async def __aenter__(self) -> Self: return self async def __aexit__(self, _exc_type: Optional[Type[BaseException]], _exc_value: Optional[BaseException], _traceback: Optional[TracebackType]) -> bool: self.close() return False def get_extra_info(self, name: str, default: Any = None) -> Any: """Get additional information about the forwarder This method returns extra information about the forwarder. Currently, the only information available is the value ``interface`` for TUN/TAP forwarders, returning the name of the local TUN/TAP network interface created for this forwarder. 
""" return self._extra.get(name, default) def set_peer(self, peer: 'SSHForwarder') -> None: """Set the peer forwarder to exchange data with""" self._peer = peer def write(self, data: bytes) -> None: """Write data to the transport""" assert self._transport is not None self._transport.write(data) def write_eof(self) -> None: """Write end of file to the transport""" assert self._transport is not None try: self._transport.write_eof() except OSError: # pragma: no cover pass def was_eof_received(self) -> bool: """Return whether end of file has been received or not""" return self._eof_received def pause_reading(self) -> None: """Pause reading from the transport""" assert self._transport is not None self._transport.pause_reading() def resume_reading(self) -> None: """Resume reading on the transport""" assert self._transport is not None self._transport.resume_reading() def connection_made(self, transport: asyncio.BaseTransport) -> None: """Handle a newly opened connection""" self._transport = cast(Optional['asyncio.Transport'], transport) sock = cast(socket.socket, transport.get_extra_info('socket')) if sock and sock.family in {socket.AF_INET, socket.AF_INET6}: sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) def connection_lost(self, exc: Optional[Exception]) -> None: """Handle an incoming connection close""" # pylint: disable=unused-argument self.close() def session_started(self) -> None: """Handle session start""" def data_received(self, data: bytes, datatype: Optional[int] = None) -> None: """Handle incoming data from the transport""" # pylint: disable=unused-argument if self._peer: self._peer.write(data) else: self._inpbuf += data def eof_received(self) -> bool: """Handle an incoming end of file from the transport""" self._eof_received = True if self._peer: self._peer.write_eof() return not self._peer.was_eof_received() else: return True def pause_writing(self) -> None: """Pause writing by asking peer to pause reading""" if self._peer: # pragma: no branch self._peer.pause_reading() def resume_writing(self) -> None: """Resume writing by asking peer to resume reading""" if self._peer: # pragma: no branch self._peer.resume_reading() def close(self) -> None: """Close this port forwarder""" if self._transport: self._transport.close() self._transport = None if self._peer: peer = self._peer self._peer = None peer.close() class SSHLocalForwarder(SSHForwarder): """Local forwarding connection handler""" def __init__(self, conn: 'SSHConnection', coro: SSHForwarderCoro): super().__init__() self._conn = conn self._coro = coro async def _forward(self, *args: object) -> None: """Begin local forwarding""" def session_factory() -> SSHForwarder: """Return an SSH forwarder""" return SSHForwarder(self) try: await self._coro(session_factory, *args) except ChannelOpenError as exc: self.connection_lost(exc) return assert self._peer is not None if self._inpbuf: self._peer.write(self._inpbuf) self._inpbuf = b'' if self._eof_received: self._peer.write_eof() def forward(self, *args: object) -> None: """Start a task to begin local forwarding""" self._conn.create_task(self._forward(*args)) class SSHLocalPortForwarder(SSHLocalForwarder): """Local TCP port forwarding connection handler""" def connection_made(self, transport: asyncio.BaseTransport) -> None: """Handle a newly opened connection""" super().connection_made(transport) peername = cast(SockAddr, transport.get_extra_info('peername')) if peername: # pragma: no branch orig_host, orig_port = peername[:2] self.forward(orig_host, orig_port) class 
SSHLocalPathForwarder(SSHLocalForwarder): """Local UNIX domain socket forwarding connection handler""" def connection_made(self, transport: asyncio.BaseTransport) -> None: """Handle a newly opened connection""" super().connection_made(transport) self.forward() asyncssh-2.20.0/asyncssh/gss.py000066400000000000000000000041231475467777400164630ustar00rootroot00000000000000# Copyright (c) 2017-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation """GSSAPI wrapper""" import sys from typing import Optional from .misc import BytesOrStrDict try: # pylint: disable=unused-import if sys.platform == 'win32': # pragma: no cover from .gss_win32 import GSSBase, GSSClient, GSSServer, GSSError else: from .gss_unix import GSSBase, GSSClient, GSSServer, GSSError gss_available = True except ImportError: # pragma: no cover gss_available = False class GSSError(ValueError): # type: ignore """Stub class for reporting that GSS is not available""" def __init__(self, maj_code: int, min_code: int, token: Optional[bytes] = None): super().__init__('GSS not available') self.maj_code = maj_code self.min_code = min_code self.token = token class GSSBase: # type: ignore """Base class for reporting that GSS is not available""" class GSSClient(GSSBase): # type: ignore """Stub client class for reporting that GSS is not available""" def __init__(self, _host: str, _store: Optional[BytesOrStrDict], _delegate_creds: bool): raise GSSError(0, 0) class GSSServer(GSSBase): # type: ignore """Stub client class for reporting that GSS is not available""" def __init__(self, _host: str, _store: Optional[BytesOrStrDict]): raise GSSError(0, 0) asyncssh-2.20.0/asyncssh/gss_unix.py000066400000000000000000000120761475467777400175340ustar00rootroot00000000000000# Copyright (c) 2017-2022 by Ron Frederick and others. 
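# [Editorial illustration -- not part of the asyncssh sources]
# The forwarders in forward.py above are internal plumbing behind the
# public port-forwarding API.  A hedged sketch of that public API follows;
# the host names and port numbers are placeholders.  Each TCP client that
# connects to local port 8080 is handled by an SSHLocalPortForwarder,
# which opens a direct-tcpip channel to port 80 on 'internalhost'.

import asyncio
import asyncssh

async def tunnel() -> None:
    async with asyncssh.connect('gateway.example.com') as conn:
        listener = await conn.forward_local_port('', 8080, 'internalhost', 80)
        await listener.wait_closed()

# asyncio.run(tunnel())   # uncommenting this opens a real SSH connection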
# # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation """GSSAPI wrapper for UNIX""" from typing import Optional, Sequence, SupportsBytes, cast from gssapi import Credentials, Name, NameType, OID from gssapi import RequirementFlag, SecurityContext from gssapi.exceptions import GSSError from .asn1 import OBJECT_IDENTIFIER from .misc import BytesOrStrDict def _mech_to_oid(mech: OID) -> bytes: """Return a DER-encoded OID corresponding to the requested GSS mechanism""" mech_bytes = bytes(cast(SupportsBytes, mech)) return bytes((OBJECT_IDENTIFIER, len(mech_bytes))) + mech_bytes class GSSBase: """GSS base class""" def __init__(self, host: str, store: Optional[BytesOrStrDict]): if '@' in host: self._host = Name(host) else: self._host = Name('host@' + host, NameType.hostbased_service) self._store = store self._mechs = [_mech_to_oid(mech) for mech in self._creds.mechs] self._ctx: Optional[SecurityContext] = None @property def _creds(self) -> Credentials: """Abstract method to construct GSS credentials""" raise NotImplementedError def _init_context(self) -> None: """Abstract method to construct GSS security context""" raise NotImplementedError @property def mechs(self) -> Sequence[bytes]: """Return GSS mechanisms available for this host""" return self._mechs @property def complete(self) -> bool: """Return whether or not GSS negotiation is complete""" return self._ctx.complete if self._ctx else False @property def provides_mutual_auth(self) -> bool: """Return whether or not this context provides mutual authentication""" assert self._ctx is not None return bool(self._ctx.actual_flags & RequirementFlag.mutual_authentication) @property def provides_integrity(self) -> bool: """Return whether or not this context provides integrity protection""" assert self._ctx is not None return bool(self._ctx.actual_flags & RequirementFlag.integrity) @property def user(self) -> str: """Return user principal associated with this context""" assert self._ctx is not None return str(self._ctx.initiator_name) @property def host(self) -> str: """Return host principal associated with this context""" assert self._ctx is not None return str(self._ctx.target_name) def reset(self) -> None: """Reset GSS security context""" self._ctx = None def step(self, token: Optional[bytes] = None) -> Optional[bytes]: """Perform next step in GSS security exchange""" if not self._ctx: self._init_context() assert self._ctx is not None return self._ctx.step(token) def sign(self, data: bytes) -> bytes: """Sign a block of data""" assert self._ctx is not None return self._ctx.get_signature(data) def verify(self, data: bytes, sig: bytes) -> bool: """Verify a signature for a block of data""" assert self._ctx is not None try: self._ctx.verify_signature(data, sig) return True except GSSError: return False class GSSClient(GSSBase): """GSS client""" def __init__(self, host: str, store: Optional[BytesOrStrDict], delegate_creds: bool): super().__init__(host, store) flags = 
RequirementFlag.mutual_authentication | \ RequirementFlag.integrity if delegate_creds: flags |= RequirementFlag.delegate_to_peer self._flags = flags @property def _creds(self) -> Credentials: """Abstract method to construct GSS credentials""" return Credentials(usage='initiate', store=self._store) def _init_context(self) -> None: """Construct GSS client security context""" self._ctx = SecurityContext(name=self._host, creds=self._creds, flags=self._flags) class GSSServer(GSSBase): """GSS server""" @property def _creds(self) -> Credentials: """Abstract method to construct GSS credentials""" return Credentials(name=self._host, usage='accept', store=self._store) def _init_context(self) -> None: """Construct GSS server security context""" self._ctx = SecurityContext(creds=self._creds) asyncssh-2.20.0/asyncssh/gss_win32.py000066400000000000000000000137061475467777400175140ustar00rootroot00000000000000# Copyright (c) 2017-2023 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation """GSSAPI wrapper for Windows""" # Some of the imports below won't be found when running pylint on UNIX # pylint: disable=import-error from typing import Optional, Sequence, Union from sspi import ClientAuth, ServerAuth from sspi import error as SSPIError from sspicon import ISC_REQ_DELEGATE, ISC_REQ_INTEGRITY, ISC_REQ_MUTUAL_AUTH from sspicon import ISC_RET_INTEGRITY, ISC_RET_MUTUAL_AUTH from sspicon import ASC_REQ_INTEGRITY, ASC_REQ_MUTUAL_AUTH from sspicon import ASC_RET_INTEGRITY, ASC_RET_MUTUAL_AUTH from sspicon import SECPKG_ATTR_NATIVE_NAMES from .asn1 import ObjectIdentifier, der_encode from .misc import BytesOrStrDict _krb5_oid = der_encode(ObjectIdentifier('1.2.840.113554.1.2.2')) class GSSBase: """GSS base class""" # Overridden in child classes _mutual_auth_flag = 0 _integrity_flag = 0 def __init__(self, host: str): if '@' in host: self._host = host else: self._host = 'host/' + host self._ctx: Optional[Union[ClientAuth, ServerAuth]] = None self._init_token: Optional[bytes] = None @property def mechs(self) -> Sequence[bytes]: """Return GSS mechanisms available for this host""" return [_krb5_oid] @property def complete(self) -> bool: """Return whether or not GSS negotiation is complete""" assert self._ctx is not None return self._ctx.authenticated @property def provides_mutual_auth(self) -> bool: """Return whether or not this context provides mutual authentication""" assert self._ctx is not None return bool(self._ctx.ctxt_attr & self._mutual_auth_flag) @property def provides_integrity(self) -> bool: """Return whether or not this context provides integrity protection""" assert self._ctx is not None return bool(self._ctx.ctxt_attr & self._integrity_flag) @property def user(self) -> str: """Return user principal associated with this context""" assert self._ctx is not None names = self._ctx.ctxt.QueryContextAttributes(SECPKG_ATTR_NATIVE_NAMES) return names[0] @property def host(self) -> str: """Return host principal 
           associated with this context"""

        assert self._ctx is not None

        names = self._ctx.ctxt.QueryContextAttributes(SECPKG_ATTR_NATIVE_NAMES)

        return names[1]

    def reset(self) -> None:
        """Reset GSS security context"""

        assert self._ctx is not None

        self._ctx.reset()
        self._init_token = None

    def step(self, token: Optional[bytes] = None) -> Optional[bytes]:
        """Perform next step in GSS security exchange"""

        assert self._ctx is not None

        if self._init_token:
            token = self._init_token
            self._init_token = None
            return token

        try:
            _, buf = self._ctx.authorize(token)
            return buf[0].Buffer
        except SSPIError as exc:
            raise GSSError(details=exc.strerror) from None

    def sign(self, data: bytes) -> bytes:
        """Sign a block of data"""

        assert self._ctx is not None

        try:
            return self._ctx.sign(data)
        except SSPIError as exc:  # pragma: no cover
            raise GSSError(details=exc.strerror) from None

    def verify(self, data: bytes, sig: bytes) -> bool:
        """Verify a signature for a block of data"""

        assert self._ctx is not None

        try:
            self._ctx.verify(data, sig)
            return True
        except SSPIError:
            return False


class GSSClient(GSSBase):
    """GSS client"""

    _mutual_auth_flag = ISC_RET_MUTUAL_AUTH
    _integrity_flag = ISC_RET_INTEGRITY

    def __init__(self, host: str, store: Optional[BytesOrStrDict],
                 delegate_creds: bool):
        if store is not None:  # pragma: no cover
            raise GSSError(details='GSS store not supported on Windows')

        super().__init__(host)

        flags = ISC_REQ_MUTUAL_AUTH | ISC_REQ_INTEGRITY

        if delegate_creds:
            flags |= ISC_REQ_DELEGATE

        try:
            self._ctx = ClientAuth('Kerberos', targetspn=self._host,
                                   scflags=flags)
        except SSPIError as exc:  # pragma: no cover
            raise GSSError(1, 1, details=exc.strerror) from None

        self._init_token = self.step(None)


class GSSServer(GSSBase):
    """GSS server"""

    _mutual_auth_flag = ASC_RET_MUTUAL_AUTH
    _integrity_flag = ASC_RET_INTEGRITY

    def __init__(self, host: str, store: Optional[BytesOrStrDict]):
        if store is not None:  # pragma: no cover
            raise GSSError(details='GSS store not supported on Windows')

        super().__init__(host)

        flags = ASC_REQ_MUTUAL_AUTH | ASC_REQ_INTEGRITY

        try:
            self._ctx = ServerAuth('Kerberos', spn=self._host, scflags=flags)
        except SSPIError as exc:
            raise GSSError(1, 1, details=exc.strerror) from None


class GSSError(Exception):
    """Class for reporting GSS errors"""

    def __init__(self, maj_code: int = 0, min_code: int = 0,
                 token: Optional[bytes] = None, details: str = ''):
        super().__init__(details)

        self.maj_code = maj_code
        self.min_code = min_code
        self.token = token
asyncssh-2.20.0/asyncssh/kex.py000066400000000000000000000105561475467777400164630ustar00rootroot00000000000000# Copyright (c) 2013-2024 by Ron Frederick and others.
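# [Editorial illustration -- not part of the asyncssh sources]
# The GSSBase/GSSClient/GSSServer wrappers above are driven indirectly
# through connection options rather than used directly.  A hedged sketch
# of Kerberos-based authentication from the client side follows; the host
# name is a placeholder, and it assumes a usable ticket cache (or SSPI
# logon session on Windows) is already present.

import asyncio
import asyncssh

async def kerberos_login() -> None:
    async with asyncssh.connect('server.example.com',
                                gss_host='server.example.com',
                                gss_delegate_creds=True) as conn:
        result = await conn.run('klist', check=True)
        print(result.stdout, end='')

# asyncio.run(kerberos_login())   # requires valid Kerberos credentials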
# # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation """SSH key exchange handlers""" import binascii from hashlib import md5 from typing import TYPE_CHECKING, Dict, List, Sequence, Tuple, Type from .logging import SSHLogger from .misc import HashType from .packet import SSHPacketHandler if TYPE_CHECKING: # pylint: disable=cyclic-import from .connection import SSHConnection _KexAlgList = List[bytes] _KexAlgMap = Dict[bytes, Tuple[Type['Kex'], HashType, object]] _kex_algs: _KexAlgList = [] _default_kex_algs:_KexAlgList = [] _kex_handlers: _KexAlgMap = {} _gss_kex_algs: _KexAlgList = [] _default_gss_kex_algs: _KexAlgList = [] _gss_kex_handlers: _KexAlgMap = {} class Kex(SSHPacketHandler): """Parent class for key exchange handlers""" def __init__(self, alg: bytes, conn: 'SSHConnection', hash_alg: HashType): self.algorithm = alg self._conn = conn self._logger = conn.logger self._hash_alg = hash_alg async def start(self) -> None: """Start key exchange""" raise NotImplementedError def send_packet(self, pkttype: int, *args: bytes) -> None: """Send a kex packet""" self._conn.send_packet(pkttype, *args, handler=self) @property def logger(self) -> SSHLogger: """A logger associated with this connection""" return self._logger def compute_key(self, k: bytes, h: bytes, x: bytes, session_id: bytes, keylen: int) -> bytes: """Compute keys from output of key exchange""" key = b'' while len(key) < keylen: hash_obj = self._hash_alg() hash_obj.update(k) hash_obj.update(h) hash_obj.update(key if key else x + session_id) key += hash_obj.digest() return key[:keylen] def register_kex_alg(alg: bytes, handler: Type[Kex], hash_alg: HashType, args: Tuple, default: bool) -> None: """Register a key exchange algorithm""" _kex_algs.append(alg) if default: _default_kex_algs.append(alg) _kex_handlers[alg] = (handler, hash_alg, args) def register_gss_kex_alg(alg: bytes, handler: Type[Kex], hash_alg: HashType, args: Tuple, default: bool) -> None: """Register a GSSAPI key exchange algorithm""" _gss_kex_algs.append(alg) if default: _default_gss_kex_algs.append(alg) _gss_kex_handlers[alg] = (handler, hash_alg, args) def get_kex_algs() -> List[bytes]: """Return supported key exchange algorithms""" return _gss_kex_algs + _kex_algs def get_default_kex_algs() -> List[bytes]: """Return default key exchange algorithms""" return _default_gss_kex_algs + _default_kex_algs def expand_kex_algs(kex_algs: Sequence[bytes], mechs: Sequence[bytes], host_key_available: bool) -> List[bytes]: """Add mechanisms to GSS entries in key exchange algorithm list""" expanded_kex_algs: List[bytes] = [] for alg in kex_algs: if alg.startswith(b'gss-'): for mech in mechs: suffix = b'-' + binascii.b2a_base64(md5(mech).digest())[:-1] expanded_kex_algs.append(alg + suffix) elif host_key_available: expanded_kex_algs.append(alg) return expanded_kex_algs def get_kex(conn: 'SSHConnection', alg: bytes) -> Kex: """Return a key exchange handler The function looks up a key exchange algorithm and 
returns a handler which can perform that type of key exchange. """ if alg.startswith(b'gss-'): alg = alg.rsplit(b'-', 1)[0] handler, hash_alg, args = _gss_kex_handlers[alg] else: handler, hash_alg, args = _kex_handlers[alg] return handler(alg, conn, hash_alg, *args) asyncssh-2.20.0/asyncssh/kex_dh.py000066400000000000000000001012531475467777400171330ustar00rootroot00000000000000# Copyright (c) 2013-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation """SSH Diffie-Hellman, ECDH, and Edwards DH key exchange handlers""" from hashlib import sha1, sha224, sha256, sha384, sha512 from typing import TYPE_CHECKING, Callable, Mapping, Optional, cast from typing_extensions import Protocol from .constants import DEFAULT_LANG from .crypto import Curve25519DH, Curve448DH, DH, ECDH, PQDH from .crypto import curve25519_available, curve448_available from .crypto import mlkem_available, sntrup_available from .gss import GSSError from .kex import Kex, register_kex_alg, register_gss_kex_alg from .misc import HashType, KeyExchangeFailed, ProtocolError from .misc import get_symbol_names, run_in_executor from .packet import Boolean, MPInt, String, UInt32, SSHPacket from .public_key import SigningKey, VerifyingKey if TYPE_CHECKING: # pylint: disable=cyclic-import from .connection import SSHConnection, SSHClientConnection from .connection import SSHServerConnection class DHKey(Protocol): """Protocol for performing Diffie-Hellman key exchange""" def get_public(self) -> bytes: """Return the public key to send to the peer""" def get_shared_bytes(self, peer_public: bytes) -> bytes: """Return the shared key from the peer's public key in bytes""" def get_shared(self, peer_public: bytes) -> int: """Return the shared key from the peer's public key""" _ECDHClass = Callable[..., DHKey] # pylint: disable=line-too-long # SSH KEX DH message values MSG_KEXDH_INIT = 30 MSG_KEXDH_REPLY = 31 # SSH KEX DH group exchange message values MSG_KEX_DH_GEX_REQUEST_OLD = 30 MSG_KEX_DH_GEX_GROUP = 31 MSG_KEX_DH_GEX_INIT = 32 MSG_KEX_DH_GEX_REPLY = 33 MSG_KEX_DH_GEX_REQUEST = 34 # SSH KEX ECDH message values MSG_KEX_ECDH_INIT = 30 MSG_KEX_ECDH_REPLY = 31 # SSH KEXGSS message values MSG_KEXGSS_INIT = 30 MSG_KEXGSS_CONTINUE = 31 MSG_KEXGSS_COMPLETE = 32 MSG_KEXGSS_HOSTKEY = 33 MSG_KEXGSS_ERROR = 34 MSG_KEXGSS_GROUPREQ = 40 MSG_KEXGSS_GROUP = 41 # SSH KEX group exchange key sizes KEX_DH_GEX_MIN_SIZE = 1024 KEX_DH_GEX_PREFERRED_SIZE = 2048 KEX_DH_GEX_MAX_SIZE = 8192 # SSH Diffie-Hellman group 1 parameters _group1_g = 2 _group1_p = 0xffffffffffffffffc90fdaa22168c234c4c6628b80dc1cd129024e088a67cc74020bbea63b139b22514a08798e3404ddef9519b3cd3a431b302b0a6df25f14374fe1356d6d51c245e485b576625e7ec6f44c42e9a637ed6b0bff5cb6f406b7edee386bfb5a899fa5ae9f24117c4b1fe649286651ece65381ffffffffffffffff # SSH Diffie-Hellman group 14 parameters _group14_g = 2 _group14_p = 
0xffffffffffffffffc90fdaa22168c234c4c6628b80dc1cd129024e088a67cc74020bbea63b139b22514a08798e3404ddef9519b3cd3a431b302b0a6df25f14374fe1356d6d51c245e485b576625e7ec6f44c42e9a637ed6b0bff5cb6f406b7edee386bfb5a899fa5ae9f24117c4b1fe649286651ece45b3dc2007cb8a163bf0598da48361c55d39a69163fa8fd24cf5f83655d23dca3ad961c62f356208552bb9ed529077096966d670c354e4abc9804f1746c08ca18217c32905e462e36ce3be39e772c180e86039b2783a2ec07a28fb5c55df06f4c52c9de2bcbf6955817183995497cea956ae515d2261898fa051015728e5a8aacaa68ffffffffffffffff # SSH Diffie-Hellman group 15 parameters _group15_g = 2 _group15_p = 0xffffffffffffffffc90fdaa22168c234c4c6628b80dc1cd129024e088a67cc74020bbea63b139b22514a08798e3404ddef9519b3cd3a431b302b0a6df25f14374fe1356d6d51c245e485b576625e7ec6f44c42e9a637ed6b0bff5cb6f406b7edee386bfb5a899fa5ae9f24117c4b1fe649286651ece45b3dc2007cb8a163bf0598da48361c55d39a69163fa8fd24cf5f83655d23dca3ad961c62f356208552bb9ed529077096966d670c354e4abc9804f1746c08ca18217c32905e462e36ce3be39e772c180e86039b2783a2ec07a28fb5c55df06f4c52c9de2bcbf6955817183995497cea956ae515d2261898fa051015728e5a8aaac42dad33170d04507a33a85521abdf1cba64ecfb850458dbef0a8aea71575d060c7db3970f85a6e1e4c7abf5ae8cdb0933d71e8c94e04a25619dcee3d2261ad2ee6bf12ffa06d98a0864d87602733ec86a64521f2b18177b200cbbe117577a615d6c770988c0bad946e208e24fa074e5ab3143db5bfce0fd108e4b82d120a93ad2caffffffffffffffff # SSH Diffie-Hellman group 16 parameters _group16_g = 2 _group16_p = 0xffffffffffffffffc90fdaa22168c234c4c6628b80dc1cd129024e088a67cc74020bbea63b139b22514a08798e3404ddef9519b3cd3a431b302b0a6df25f14374fe1356d6d51c245e485b576625e7ec6f44c42e9a637ed6b0bff5cb6f406b7edee386bfb5a899fa5ae9f24117c4b1fe649286651ece45b3dc2007cb8a163bf0598da48361c55d39a69163fa8fd24cf5f83655d23dca3ad961c62f356208552bb9ed529077096966d670c354e4abc9804f1746c08ca18217c32905e462e36ce3be39e772c180e86039b2783a2ec07a28fb5c55df06f4c52c9de2bcbf6955817183995497cea956ae515d2261898fa051015728e5a8aaac42dad33170d04507a33a85521abdf1cba64ecfb850458dbef0a8aea71575d060c7db3970f85a6e1e4c7abf5ae8cdb0933d71e8c94e04a25619dcee3d2261ad2ee6bf12ffa06d98a0864d87602733ec86a64521f2b18177b200cbbe117577a615d6c770988c0bad946e208e24fa074e5ab3143db5bfce0fd108e4b82d120a92108011a723c12a787e6d788719a10bdba5b2699c327186af4e23c1a946834b6150bda2583e9ca2ad44ce8dbbbc2db04de8ef92e8efc141fbecaa6287c59474e6bc05d99b2964fa090c3a2233ba186515be7ed1f612970cee2d7afb81bdd762170481cd0069127d5b05aa993b4ea988d8fddc186ffb7dc90a6c08f4df435c934063199ffffffffffffffff # SSH Diffie-Hellman group 17 parameters _group17_g = 2 _group17_p = 
0xffffffffffffffffc90fdaa22168c234c4c6628b80dc1cd129024e088a67cc74020bbea63b139b22514a08798e3404ddef9519b3cd3a431b302b0a6df25f14374fe1356d6d51c245e485b576625e7ec6f44c42e9a637ed6b0bff5cb6f406b7edee386bfb5a899fa5ae9f24117c4b1fe649286651ece45b3dc2007cb8a163bf0598da48361c55d39a69163fa8fd24cf5f83655d23dca3ad961c62f356208552bb9ed529077096966d670c354e4abc9804f1746c08ca18217c32905e462e36ce3be39e772c180e86039b2783a2ec07a28fb5c55df06f4c52c9de2bcbf6955817183995497cea956ae515d2261898fa051015728e5a8aaac42dad33170d04507a33a85521abdf1cba64ecfb850458dbef0a8aea71575d060c7db3970f85a6e1e4c7abf5ae8cdb0933d71e8c94e04a25619dcee3d2261ad2ee6bf12ffa06d98a0864d87602733ec86a64521f2b18177b200cbbe117577a615d6c770988c0bad946e208e24fa074e5ab3143db5bfce0fd108e4b82d120a92108011a723c12a787e6d788719a10bdba5b2699c327186af4e23c1a946834b6150bda2583e9ca2ad44ce8dbbbc2db04de8ef92e8efc141fbecaa6287c59474e6bc05d99b2964fa090c3a2233ba186515be7ed1f612970cee2d7afb81bdd762170481cd0069127d5b05aa993b4ea988d8fddc186ffb7dc90a6c08f4df435c93402849236c3fab4d27c7026c1d4dcb2602646dec9751e763dba37bdf8ff9406ad9e530ee5db382f413001aeb06a53ed9027d831179727b0865a8918da3edbebcf9b14ed44ce6cbaced4bb1bdb7f1447e6cc254b332051512bd7af426fb8f401378cd2bf5983ca01c64b92ecf032ea15d1721d03f482d7ce6e74fef6d55e702f46980c82b5a84031900b1c9e59e7c97fbec7e8f323a97a7e36cc88be0f1d45b7ff585ac54bd407b22b4154aacc8f6d7ebf48e1d814cc5ed20f8037e0a79715eef29be32806a1d58bb7c5da76f550aa3d8a1fbff0eb19ccb1a313d55cda56c9ec2ef29632387fe8d76e3c0468043e8f663f4860ee12bf2d5b0b7474d6e694f91e6dcc4024ffffffffffffffff # SSH Diffie-Hellman group 18 parameters _group18_g = 2 _group18_p = 0xffffffffffffffffc90fdaa22168c234c4c6628b80dc1cd129024e088a67cc74020bbea63b139b22514a08798e3404ddef9519b3cd3a431b302b0a6df25f14374fe1356d6d51c245e485b576625e7ec6f44c42e9a637ed6b0bff5cb6f406b7edee386bfb5a899fa5ae9f24117c4b1fe649286651ece45b3dc2007cb8a163bf0598da48361c55d39a69163fa8fd24cf5f83655d23dca3ad961c62f356208552bb9ed529077096966d670c354e4abc9804f1746c08ca18217c32905e462e36ce3be39e772c180e86039b2783a2ec07a28fb5c55df06f4c52c9de2bcbf6955817183995497cea956ae515d2261898fa051015728e5a8aaac42dad33170d04507a33a85521abdf1cba64ecfb850458dbef0a8aea71575d060c7db3970f85a6e1e4c7abf5ae8cdb0933d71e8c94e04a25619dcee3d2261ad2ee6bf12ffa06d98a0864d87602733ec86a64521f2b18177b200cbbe117577a615d6c770988c0bad946e208e24fa074e5ab3143db5bfce0fd108e4b82d120a92108011a723c12a787e6d788719a10bdba5b2699c327186af4e23c1a946834b6150bda2583e9ca2ad44ce8dbbbc2db04de8ef92e8efc141fbecaa6287c59474e6bc05d99b2964fa090c3a2233ba186515be7ed1f612970cee2d7afb81bdd762170481cd0069127d5b05aa993b4ea988d8fddc186ffb7dc90a6c08f4df435c93402849236c3fab4d27c7026c1d4dcb2602646dec9751e763dba37bdf8ff9406ad9e530ee5db382f413001aeb06a53ed9027d831179727b0865a8918da3edbebcf9b14ed44ce6cbaced4bb1bdb7f1447e6cc254b332051512bd7af426fb8f401378cd2bf5983ca01c64b92ecf032ea15d1721d03f482d7ce6e74fef6d55e702f46980c82b5a84031900b1c9e59e7c97fbec7e8f323a97a7e36cc88be0f1d45b7ff585ac54bd407b22b4154aacc8f6d7ebf48e1d814cc5ed20f8037e0a79715eef29be32806a1d58bb7c5da76f550aa3d8a1fbff0eb19ccb1a313d55cda56c9ec2ef29632387fe8d76e3c0468043e8f663f4860ee12bf2d5b0b7474d6e694f91e6dbe115974a3926f12fee5e438777cb6a932df8cd8bec4d073b931ba3bc832b68d9dd300741fa7bf8afc47ed2576f6936ba424663aab639c5ae4f5683423b4742bf1c978238f16cbe39d652de3fdb8befc848ad922222e04a4037c0713eb57a81a23f0c73473fc646cea306b4bcbc8862f8385ddfa9d4b7fa2c087e879683303ed5bdd3a062b3cf5b3a278a66d2a13f83f44f82ddf310ee074ab6a364597e899a0255dc164f31cc50846851df9ab48195ded7ea1b1d510bd7ee74d73faf36bc31ecfa268359046f4eb879f924009438b481c6cd7889a002ed
5ee382bc9190da6fc026e479558e4475677e9aa9e3050e2765694dfc81f56e880b96e7160c980dd98edd3dfffffffffffffffff _dh_gex_groups = ((1024, _group1_g, _group1_p), (2048, _group14_g, _group14_p), (3072, _group15_g, _group15_p), (4096, _group16_g, _group16_p), (6144, _group17_g, _group17_p), (8192, _group18_g, _group18_p)) # pylint: enable=line-too-long class _KexDHBase(Kex): """Abstract base class for Diffie-Hellman key exchange""" _init_type: int = 0 _reply_type: int = 0 def __init__(self, alg: bytes, conn: 'SSHConnection', hash_alg: HashType): super().__init__(alg, conn, hash_alg) self._dh: Optional[DH] = None self._g = 0 self._p = 0 self._e = 0 self._f = 0 self._gex_data = b'' def _init_group(self, g: int, p: int) -> None: """Initialize DH group parameters""" self._g = g self._p = p def _compute_hash(self, host_key_data: bytes, k: bytes) -> bytes: """Compute a hash of key information associated with the connection""" hash_obj = self._hash_alg() hash_obj.update(self._conn.get_hash_prefix()) hash_obj.update(String(host_key_data)) hash_obj.update(self._gex_data) hash_obj.update(self._format_client_key()) hash_obj.update(self._format_server_key()) hash_obj.update(k) return hash_obj.digest() def _parse_client_key(self, packet: SSHPacket) -> None: """Parse a DH client key""" if not self._p: raise ProtocolError('Kex DH p not specified') self._e = packet.get_mpint() def _parse_server_key(self, packet: SSHPacket) -> None: """Parse a DH server key""" if not self._p: raise ProtocolError('Kex DH p not specified') self._f = packet.get_mpint() def _format_client_key(self) -> bytes: """Format a DH client key""" return MPInt(self._e) def _format_server_key(self) -> bytes: """Format a DH server key""" return MPInt(self._f) def _send_init(self) -> None: """Send a DH init message""" self.send_packet(self._init_type, self._format_client_key()) def _send_reply(self, key_data: bytes, sig: bytes) -> None: """Send a DH reply message""" self.send_packet(self._reply_type, String(key_data), self._format_server_key(), String(sig)) def _perform_init(self) -> None: """Compute e and send init message""" self._dh = DH(self._g, self._p) self._e = self._dh.get_public() self._send_init() def _compute_client_shared(self) -> bytes: """Compute client shared key""" if not 1 <= self._f < self._p: raise ProtocolError('Kex DH f out of range') assert self._dh is not None return MPInt(self._dh.get_shared(self._f)) def _compute_server_shared(self) -> bytes: """Compute server shared key""" if not 1 <= self._e < self._p: raise ProtocolError('Kex DH e out of range') self._dh = DH(self._g, self._p) self._f = self._dh.get_public() return MPInt(self._dh.get_shared(self._e)) def _perform_reply(self, key: SigningKey, key_data: bytes) -> None: """Compute server shared key and send reply message""" k = self._compute_server_shared() h = self._compute_hash(key_data, k) self._send_reply(key_data, key.sign(h)) self._conn.send_newkeys(k, h) def _verify_reply(self, key: VerifyingKey, key_data: bytes, sig: bytes) -> None: """Verify a DH reply message""" k = self._compute_client_shared() h = self._compute_hash(key_data, k) if not key.verify(h, sig): raise KeyExchangeFailed('Key exchange hash mismatch') self._conn.send_newkeys(k, h) def _process_init(self, _pkttype: int, _pktid: int, packet: SSHPacket) -> None: """Process a DH init message""" if self._conn.is_client(): raise ProtocolError('Unexpected kex init msg') self._parse_client_key(packet) packet.check_end() server_conn = cast('SSHServerConnection', self._conn) host_key = server_conn.get_server_host_key() 
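        # The host key fetched here serves two purposes in _perform_reply()
        # below: its public data is hashed into the exchange hash, and the
        # private key signs that hash for the client to verify.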
assert host_key is not None self._perform_reply(host_key, host_key.public_data) def _process_reply(self, _pkttype: int, _pktid: int, packet: SSHPacket) -> None: """Process a DH reply message""" if self._conn.is_server(): raise ProtocolError('Unexpected kex reply msg') host_key_data = packet.get_string() self._parse_server_key(packet) sig = packet.get_string() packet.check_end() client_conn = cast('SSHClientConnection', self._conn) host_key = client_conn.validate_server_host_key(host_key_data) self._verify_reply(host_key, host_key_data, sig) async def start(self) -> None: """Start DH key exchange""" if self._conn.is_client(): self._perform_init() class _KexDH(_KexDHBase): """Handler for Diffie-Hellman key exchange""" _handler_names = get_symbol_names(globals(), 'MSG_KEXDH_') _init_type = MSG_KEXDH_INIT _reply_type = MSG_KEXDH_REPLY def __init__(self, alg: bytes, conn: 'SSHConnection', hash_alg: HashType, g: int, p: int): super().__init__(alg, conn, hash_alg) self._init_group(g, p) _packet_handlers: Mapping[int, Callable]= { MSG_KEXDH_INIT: _KexDHBase._process_init, MSG_KEXDH_REPLY: _KexDHBase._process_reply } class _KexDHGex(_KexDHBase): """Handler for Diffie-Hellman group exchange""" _handler_names = get_symbol_names(globals(), 'MSG_KEX_DH_GEX_') _init_type = MSG_KEX_DH_GEX_INIT _reply_type = MSG_KEX_DH_GEX_REPLY _request_type = MSG_KEX_DH_GEX_REQUEST _group_type = MSG_KEX_DH_GEX_GROUP def __init__(self, alg: bytes, conn: 'SSHConnection', hash_alg: HashType, preferred_size: Optional[int] = None, max_size: Optional[int] = None): super().__init__(alg, conn, hash_alg) self._pref_size = preferred_size self._max_size = max_size def _send_request(self) -> None: """Send a DH gex request message""" if self._pref_size and not self._max_size: # Send old request message for unit test pkttype = MSG_KEX_DH_GEX_REQUEST_OLD args = UInt32(self._pref_size) else: pkttype = self._request_type args = (UInt32(KEX_DH_GEX_MIN_SIZE) + UInt32(self._pref_size or KEX_DH_GEX_PREFERRED_SIZE) + UInt32(self._max_size or KEX_DH_GEX_MAX_SIZE)) self._gex_data = args self.send_packet(pkttype, args) def _process_request(self, pkttype: int, _pktid: int, packet: SSHPacket) -> None: """Process a DH gex request message""" if self._conn.is_client(): raise ProtocolError('Unexpected kex request msg') self._gex_data = packet.get_remaining_payload() if pkttype == MSG_KEX_DH_GEX_REQUEST_OLD: preferred_size = packet.get_uint32() max_size = KEX_DH_GEX_MAX_SIZE else: _ = packet.get_uint32() preferred_size = packet.get_uint32() max_size = packet.get_uint32() packet.check_end() g, p = _group1_g, _group1_p for gex_size, gex_g, gex_p in _dh_gex_groups: if gex_size > max_size: break else: g, p = gex_g, gex_p if gex_size >= preferred_size: break self._init_group(g, p) self._gex_data += MPInt(p) + MPInt(g) self.send_packet(self._group_type, MPInt(p), MPInt(g)) def _process_group(self, _pkttype: int, _pktid: int, packet: SSHPacket) -> None: """Process a DH gex group message""" if self._conn.is_server(): raise ProtocolError('Unexpected kex group msg') p = packet.get_mpint() g = packet.get_mpint() packet.check_end() self._init_group(g, p) self._gex_data += MPInt(p) + MPInt(g) self._perform_init() async def start(self) -> None: """Start DH group exchange""" if self._conn.is_client(): self._send_request() _packet_handlers: Mapping[int, Callable] = { MSG_KEX_DH_GEX_REQUEST_OLD: _process_request, MSG_KEX_DH_GEX_GROUP: _process_group, MSG_KEX_DH_GEX_INIT: _KexDHBase._process_init, MSG_KEX_DH_GEX_REPLY: _KexDHBase._process_reply, MSG_KEX_DH_GEX_REQUEST: 
_process_request } class _KexECDH(_KexDHBase): """Handler for elliptic curve Diffie-Hellman key exchange""" _handler_names = get_symbol_names(globals(), 'MSG_KEX_ECDH_') _init_type = MSG_KEX_ECDH_INIT _reply_type = MSG_KEX_ECDH_REPLY def __init__(self, alg: bytes, conn: 'SSHConnection', hash_alg: HashType, ecdh_class: _ECDHClass, *args: object): super().__init__(alg, conn, hash_alg) self._priv = ecdh_class(*args) pub = self._priv.get_public() if conn.is_client(): self._client_pub = pub else: self._server_pub = pub def _parse_client_key(self, packet: SSHPacket) -> None: """Parse an ECDH client key""" self._client_pub = packet.get_string() def _parse_server_key(self, packet: SSHPacket) -> None: """Parse an ECDH server key""" self._server_pub = packet.get_string() def _format_client_key(self) -> bytes: """Format an ECDH client key""" return String(self._client_pub) def _format_server_key(self) -> bytes: """Format an ECDH server key""" return String(self._server_pub) def _compute_client_shared(self) -> bytes: """Compute client shared key""" try: return MPInt(self._priv.get_shared(self._server_pub)) except ValueError: raise ProtocolError('Invalid ECDH server public key') from None def _compute_server_shared(self) -> bytes: """Compute server shared key""" try: return MPInt(self._priv.get_shared(self._client_pub)) except ValueError: raise ProtocolError('Invalid ECDH client public key') from None async def start(self) -> None: """Start ECDH key exchange""" if self._conn.is_client(): self._send_init() _packet_handlers: Mapping[int, Callable] = { MSG_KEX_ECDH_INIT: _KexDHBase._process_init, MSG_KEX_ECDH_REPLY: _KexDHBase._process_reply } class _KexHybridECDH(_KexECDH): """Handler for post-quantum key exchange""" def __init__(self, alg: bytes, conn: 'SSHConnection', hash_alg: HashType, pq_alg_name: bytes, ecdh_class: _ECDHClass, *args: object): super().__init__(alg, conn, hash_alg, ecdh_class, *args) self._pq = PQDH(pq_alg_name) if conn.is_client(): pq_pub, self._pq_priv = self._pq.keypair() self._client_pub = pq_pub + self._client_pub def _compute_client_shared(self) -> bytes: """Compute client shared key""" pq_ciphertext = self._server_pub[:self._pq.ciphertext_bytes] ec_pub = self._server_pub[self._pq.ciphertext_bytes:] try: pq_secret = self._pq.decaps(pq_ciphertext, self._pq_priv) except ValueError: raise ProtocolError('Invalid PQ server ciphertext') from None try: ec_shared = self._priv.get_shared_bytes(ec_pub) except ValueError: raise ProtocolError('Invalid ECDH server public key') from None return String(self._hash_alg(pq_secret + ec_shared).digest()) def _compute_server_shared(self) -> bytes: """Compute server shared key""" pq_pub = self._client_pub[:self._pq.pubkey_bytes] ec_pub = self._client_pub[self._pq.pubkey_bytes:] try: pq_secret, pq_ciphertext = self._pq.encaps(pq_pub) except ValueError: raise ProtocolError('Invalid PQ client public key') from None try: ec_shared = self._priv.get_shared_bytes(ec_pub) except ValueError: raise ProtocolError('Invalid ECDH client public key') from None self._server_pub = pq_ciphertext + self._server_pub return String(self._hash_alg(pq_secret + ec_shared).digest()) class _KexGSSBase(_KexDHBase): """Handler for GSS key exchange""" def __init__(self, alg: bytes, conn: 'SSHConnection', hash_alg: HashType, *args: object): super().__init__(alg, conn, hash_alg, *args) self._gss = conn.get_gss_context() self._token: Optional[bytes] = None self._host_key_data = b'' def _check_secure(self) -> None: """Check that GSS context is secure enough for key exchange""" if 
(not self._gss.provides_mutual_auth or not self._gss.provides_integrity): raise ProtocolError('GSS context not secure') def _send_init(self) -> None: """Send a GSS init message""" if not self._token: raise ProtocolError('Empty GSS token in init') self.send_packet(MSG_KEXGSS_INIT, String(self._token), self._format_client_key()) def _send_reply(self, key_data: bytes, sig: bytes) -> None: """Send a GSS reply message""" if self._token: token_data = Boolean(True) + String(self._token) else: token_data = Boolean(False) self.send_packet(MSG_KEXGSS_COMPLETE, self._format_server_key(), String(sig), token_data) def _send_continue(self) -> None: """Send a GSS continue message""" if not self._token: raise ProtocolError('Empty GSS token in continue') self.send_packet(MSG_KEXGSS_CONTINUE, String(self._token)) async def _process_token(self, token: Optional[bytes] = None) -> None: """Process a GSS token""" try: self._token = await run_in_executor(self._gss.step, token) except GSSError as exc: if self._conn.is_server(): self.send_packet(MSG_KEXGSS_ERROR, UInt32(exc.maj_code), UInt32(exc.min_code), String(str(exc)), String(DEFAULT_LANG)) if exc.token: self.send_packet(MSG_KEXGSS_CONTINUE, String(exc.token)) raise KeyExchangeFailed(str(exc)) from None async def _process_gss_init(self, _pkttype: int, _pktid: int, packet: SSHPacket) -> None: """Process a GSS init message""" if self._conn.is_client(): raise ProtocolError('Unexpected kexgss init msg') token = packet.get_string() self._parse_client_key(packet) packet.check_end() server_conn = cast('SSHServerConnection', self._conn) host_key = server_conn.get_server_host_key() if host_key: self._host_key_data = host_key.public_data self.send_packet(MSG_KEXGSS_HOSTKEY, String(self._host_key_data)) else: self._host_key_data = b'' await self._process_token(token) if self._gss.complete: self._check_secure() self._perform_reply(self._gss, self._host_key_data) self._conn.enable_gss_kex_auth() else: self._send_continue() async def _process_continue(self, _pkttype: int, _pktid: int, packet: SSHPacket) -> None: """Process a GSS continue message""" token = packet.get_string() packet.check_end() if self._conn.is_client() and self._gss.complete: raise ProtocolError('Unexpected kexgss continue msg') await self._process_token(token) if self._conn.is_server() and self._gss.complete: self._check_secure() self._perform_reply(self._gss, self._host_key_data) else: self._send_continue() async def _process_complete(self, _pkttype: int, _pktid: int, packet: SSHPacket) -> None: """Process a GSS complete message""" if self._conn.is_server(): raise ProtocolError('Unexpected kexgss complete msg') self._parse_server_key(packet) mic = packet.get_string() token_present = packet.get_boolean() token = packet.get_string() if token_present else None packet.check_end() if token: if self._gss.complete: raise ProtocolError('Non-empty token after complete') await self._process_token(token) if self._token: raise ProtocolError('Non-empty token after complete') if not self._gss.complete: raise ProtocolError('GSS exchange failed to complete') self._check_secure() self._verify_reply(self._gss, self._host_key_data, mic) self._conn.enable_gss_kex_auth() def _process_hostkey(self, _pkttype: int, _pktid: int, packet: SSHPacket) -> None: """Process a GSS hostkey message""" self._host_key_data = packet.get_string() packet.check_end() def _process_error(self, _pkttype: int, _pktid: int, packet: SSHPacket) -> None: """Process a GSS error message""" if self._conn.is_server(): raise ProtocolError('Unexpected kexgss 
error msg') _ = packet.get_uint32() # major_status _ = packet.get_uint32() # minor_status msg = packet.get_string() _ = packet.get_string() # lang packet.check_end() self._conn.logger.debug1('GSS error: %s', msg.decode('utf-8', errors='ignore')) async def start(self) -> None: """Start GSS key exchange""" if self._conn.is_client(): await self._process_token() await super().start() class _KexGSS(_KexGSSBase, _KexDH): """Handler for GSS key exchange""" _handler_names = get_symbol_names(globals(), 'MSG_KEXGSS_') _packet_handlers = { MSG_KEXGSS_INIT: _KexGSSBase._process_gss_init, MSG_KEXGSS_CONTINUE: _KexGSSBase._process_continue, MSG_KEXGSS_COMPLETE: _KexGSSBase._process_complete, MSG_KEXGSS_HOSTKEY: _KexGSSBase._process_hostkey, MSG_KEXGSS_ERROR: _KexGSSBase._process_error } class _KexGSSGex(_KexGSSBase, _KexDHGex): """Handler for GSS group exchange""" _handler_names = get_symbol_names(globals(), 'MSG_KEXGSS_') _request_type = MSG_KEXGSS_GROUPREQ _group_type = MSG_KEXGSS_GROUP _packet_handlers = { MSG_KEXGSS_INIT: _KexGSSBase._process_gss_init, MSG_KEXGSS_CONTINUE: _KexGSSBase._process_continue, MSG_KEXGSS_COMPLETE: _KexGSSBase._process_complete, MSG_KEXGSS_HOSTKEY: _KexGSSBase._process_hostkey, MSG_KEXGSS_ERROR: _KexGSSBase._process_error, MSG_KEXGSS_GROUPREQ: _KexDHGex._process_request, MSG_KEXGSS_GROUP: _KexDHGex._process_group } class _KexGSSECDH(_KexGSSBase, _KexECDH): """Handler for GSS ECDH key exchange""" _handler_names = get_symbol_names(globals(), 'MSG_KEXGSS_') _packet_handlers = { MSG_KEXGSS_INIT: _KexGSSBase._process_gss_init, MSG_KEXGSS_CONTINUE: _KexGSSBase._process_continue, MSG_KEXGSS_COMPLETE: _KexGSSBase._process_complete, MSG_KEXGSS_HOSTKEY: _KexGSSBase._process_hostkey, MSG_KEXGSS_ERROR: _KexGSSBase._process_error } if mlkem_available: # pragma: no branch if curve25519_available: # pragma: no branch register_kex_alg(b'mlkem768x25519-sha256', _KexHybridECDH, sha256, (b'mlkem768', Curve25519DH), True) register_kex_alg(b'mlkem768nistp256-sha256', _KexHybridECDH, sha256, (b'mlkem768', ECDH, b'nistp256'), True) register_kex_alg(b'mlkem1024nistp384-sha384', _KexHybridECDH, sha384, (b'mlkem1024', ECDH, b'nistp384'), True) if curve25519_available: # pragma: no branch if sntrup_available: # pragma: no branch register_kex_alg(b'sntrup761x25519-sha512', _KexHybridECDH, sha512, (b'sntrup761', Curve25519DH), True) register_kex_alg(b'sntrup761x25519-sha512@openssh.com', _KexHybridECDH, sha512, (b'sntrup761', Curve25519DH), True) register_kex_alg(b'curve25519-sha256', _KexECDH, sha256, (Curve25519DH,), True) register_kex_alg(b'curve25519-sha256@libssh.org', _KexECDH, sha256, (Curve25519DH,), True) register_gss_kex_alg(b'gss-curve25519-sha256', _KexGSSECDH, sha256, (Curve25519DH,), True) if curve448_available: # pragma: no branch register_kex_alg(b'curve448-sha512', _KexECDH, sha512, (Curve448DH,), True) register_gss_kex_alg(b'gss-curve448-sha512', _KexGSSECDH, sha512, (Curve448DH,), True) for _curve_id, _hash_name, _hash_alg, _default in ( (b'nistp521', b'sha512', sha512, True), (b'nistp384', b'sha384', sha384, True), (b'nistp256', b'sha256', sha256, True), (b'1.3.132.0.10', b'sha256', sha256, True)): register_kex_alg(b'ecdh-sha2-' + _curve_id, _KexECDH, _hash_alg, (ECDH, _curve_id), _default) register_gss_kex_alg(b'gss-' + _curve_id + b'-' + _hash_name, _KexGSSECDH, _hash_alg, (ECDH, _curve_id), _default) for _hash_name, _hash_alg, _default in ( (b'sha256', sha256, True), (b'sha224@ssh.com', sha224, False), (b'sha384@ssh.com', sha384, False), (b'sha512@ssh.com', sha512, False), 
(b'sha1', sha1, False)): register_kex_alg(b'diffie-hellman-group-exchange-' + _hash_name, _KexDHGex, _hash_alg, (), _default) if not _hash_name.endswith(b'@ssh.com'): register_gss_kex_alg(b'gss-gex-' + _hash_name, _KexGSSGex, _hash_alg, (), _default) for _name, _hash_alg, _g, _p, _default in ( (b'group14-sha256', sha256, _group14_g, _group14_p, True), (b'group15-sha512', sha512, _group15_g, _group15_p, True), (b'group16-sha512', sha512, _group16_g, _group16_p, True), (b'group17-sha512', sha512, _group17_g, _group17_p, True), (b'group18-sha512', sha512, _group18_g, _group18_p, True), (b'group14-sha256@ssh.com', sha256, _group14_g, _group14_p, True), (b'group14-sha224@ssh.com', sha224, _group14_g, _group14_p, False), (b'group15-sha256@ssh.com', sha256, _group15_g, _group15_p, False), (b'group15-sha384@ssh.com', sha384, _group15_g, _group15_p, False), (b'group16-sha384@ssh.com', sha384, _group16_g, _group16_p, False), (b'group16-sha512@ssh.com', sha512, _group16_g, _group16_p, False), (b'group18-sha512@ssh.com', sha512, _group18_g, _group18_p, False), (b'group14-sha1', sha1, _group14_g, _group14_p, True), (b'group1-sha1', sha1, _group1_g, _group1_p, False)): register_kex_alg(b'diffie-hellman-' + _name, _KexDH, _hash_alg, (_g, _p), _default) if not _name.endswith(b'@ssh.com'): register_gss_kex_alg(b'gss-' + _name, _KexGSS, _hash_alg, (_g, _p), _default) asyncssh-2.20.0/asyncssh/kex_rsa.py000066400000000000000000000135251475467777400173310ustar00rootroot00000000000000# Copyright (c) 2018-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation """RSA key exchange handler""" from hashlib import sha1, sha256 from typing import TYPE_CHECKING, Optional, cast from .kex import Kex, register_kex_alg from .misc import HashType, KeyExchangeFailed, ProtocolError from .misc import get_symbol_names, randrange from .packet import MPInt, String, SSHPacket from .public_key import KeyImportError, SSHKey from .public_key import decode_ssh_public_key, generate_private_key from .rsa import RSAKey if TYPE_CHECKING: # pylint: disable=cyclic-import from .connection import SSHConnection, SSHClientConnection from .connection import SSHServerConnection # SSH KEXRSA message values MSG_KEXRSA_PUBKEY = 30 MSG_KEXRSA_SECRET = 31 MSG_KEXRSA_DONE = 32 class _KexRSA(Kex): """Handler for RSA key exchange""" _handler_names = get_symbol_names(globals(), 'MSG_KEXRSA') def __init__(self, alg: bytes, conn: 'SSHConnection', hash_alg: HashType, key_size: int, hash_size: int): super().__init__(alg, conn, hash_alg) self._key_size = key_size self._k_limit = 1 << (key_size - 2*hash_size - 49) self._host_key_data = b'' self._trans_key: Optional[SSHKey] = None self._trans_key_data = b'' self._k = 0 self._encrypted_k = b'' async def start(self) -> None: """Start RSA key exchange""" if self._conn.is_server(): server_conn = cast('SSHServerConnection', self._conn) host_key = server_conn.get_server_host_key() assert host_key is not None 
self._host_key_data = host_key.public_data self._trans_key = generate_private_key( 'ssh-rsa', key_size=self._key_size) self._trans_key_data = self._trans_key.public_data self.send_packet(MSG_KEXRSA_PUBKEY, String(self._host_key_data), String(self._trans_key_data)) def _compute_hash(self) -> bytes: """Compute a hash of key information associated with the connection""" hash_obj = self._hash_alg() hash_obj.update(self._conn.get_hash_prefix()) hash_obj.update(String(self._host_key_data)) hash_obj.update(String(self._trans_key_data)) hash_obj.update(String(self._encrypted_k)) hash_obj.update(MPInt(self._k)) return hash_obj.digest() def _process_pubkey(self, _pkttype: int, _pktid: int, packet: SSHPacket) -> None: """Process a KEXRSA pubkey message""" if self._conn.is_server(): raise ProtocolError('Unexpected KEXRSA pubkey msg') self._host_key_data = packet.get_string() self._trans_key_data = packet.get_string() packet.check_end() try: pubkey = decode_ssh_public_key(self._trans_key_data) except KeyImportError: raise ProtocolError('Invalid KEXRSA pubkey msg') from None trans_key = cast(RSAKey, pubkey) self._k = randrange(self._k_limit) self._encrypted_k = \ cast(bytes, trans_key.encrypt(MPInt(self._k), self.algorithm)) self.send_packet(MSG_KEXRSA_SECRET, String(self._encrypted_k)) def _process_secret(self, _pkttype: int, _pktid: int, packet: SSHPacket) -> None: """Process a KEXRSA secret message""" if self._conn.is_client(): raise ProtocolError('Unexpected KEXRSA secret msg') self._encrypted_k = packet.get_string() packet.check_end() trans_key = cast(RSAKey, self._trans_key) decrypted_k = trans_key.decrypt(self._encrypted_k, self.algorithm) if not decrypted_k: raise KeyExchangeFailed('Key exchange decryption failed') packet = SSHPacket(decrypted_k) self._k = packet.get_mpint() packet.check_end() server_conn = cast('SSHServerConnection', self._conn) host_key = server_conn.get_server_host_key() assert host_key is not None h = self._compute_hash() sig = host_key.sign(h) self.send_packet(MSG_KEXRSA_DONE, String(sig)) self._conn.send_newkeys(MPInt(self._k), h) def _process_done(self, _pkttype: int, _pktid: int, packet: SSHPacket) -> None: """Process a KEXRSA done message""" if self._conn.is_server(): raise ProtocolError('Unexpected KEXRSA done msg') sig = packet.get_string() packet.check_end() client_conn = cast('SSHClientConnection', self._conn) host_key = client_conn.validate_server_host_key(self._host_key_data) h = self._compute_hash() if not host_key.verify(h, sig): raise KeyExchangeFailed('Key exchange hash mismatch') self._conn.send_newkeys(MPInt(self._k), h) _packet_handlers = { MSG_KEXRSA_PUBKEY: _process_pubkey, MSG_KEXRSA_SECRET: _process_secret, MSG_KEXRSA_DONE: _process_done } for _name, _hash_alg, _key_size, _hash_size, _default in ( (b'rsa2048-sha256', sha256, 2048, 256, True), (b'rsa1024-sha1', sha1, 1024, 160, False)): register_kex_alg(_name, _KexRSA, _hash_alg, (_key_size, _hash_size), _default) asyncssh-2.20.0/asyncssh/keysign.py000066400000000000000000000072451475467777400173500ustar00rootroot00000000000000# Copyright (c) 2018-2021 by Ron Frederick and others. 
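# --- Illustrative sketch (not part of the original sources) --------------
# The _KexRSA handler above draws its shared secret from
# randrange(1 << (key_size - 2*hash_size - 49)), i.e. K is bounded so that
# its mpint encoding still fits within what RSAES-OAEP with the chosen hash
# can encrypt under the transient RSA key (RFC 4432).  The helper below is
# purely illustrative and just works that bound out for the two algorithms
# registered above:

def _kexrsa_secret_bits(key_size: int, hash_size: int) -> int:
    """Maximum bit length of the shared secret K for an RSA kex algorithm."""
    return key_size - 2 * hash_size - 49

assert _kexrsa_secret_bits(2048, 256) == 1487   # rsa2048-sha256
assert _kexrsa_secret_bits(1024, 160) == 655    # rsa1024-sha1
# --------------------------------------------------------------------------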
# # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation """SSH keysign client""" import asyncio from pathlib import Path import subprocess from typing import Iterable, Sequence, Union, cast from .misc import FilePath from .packet import Byte, String, UInt32, PacketDecodeError, SSHPacket from .public_key import SSHKey, SSHKeyPair, SSHCertificate _KeySignKey = Union[SSHKey, SSHCertificate] KeySignPath = Union[None, bool, FilePath] KEYSIGN_VERSION = 2 _DEFAULT_KEYSIGN_DIRS = ('/opt/local/libexec', '/usr/local/libexec', '/usr/libexec', '/usr/libexec/openssh', '/usr/lib/openssh') class SSHKeySignKeyPair(SSHKeyPair): """Surrogate for a key where signing is done via ssh-keysign""" def __init__(self, keysign_path: str, sock_fd: int, key_or_cert: _KeySignKey): algorithm = key_or_cert.algorithm sig_algorithms = key_or_cert.sig_algorithms[:1] public_data = key_or_cert.public_data comment = key_or_cert.get_comment_bytes() super().__init__(algorithm, algorithm, sig_algorithms, sig_algorithms, public_data, comment) self._keysign_path = keysign_path self._sock_fd = sock_fd async def sign_async(self, data: bytes) -> bytes: """Use ssh-keysign to sign a block of data with this key""" proc = await asyncio.create_subprocess_exec( self._keysign_path, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE, pass_fds=[self._sock_fd]) request = String(Byte(KEYSIGN_VERSION) + UInt32(self._sock_fd) + String(data)) stdout, stderr = await proc.communicate(request) if stderr: error = stderr.decode().strip() raise ValueError(error) try: packet = SSHPacket(stdout) resp = packet.get_string() packet.check_end() packet = SSHPacket(resp) version = packet.get_byte() sig = packet.get_string() packet.check_end() if version != KEYSIGN_VERSION: raise ValueError('unexpected version') return sig except PacketDecodeError: raise ValueError('invalid response') from None def find_keysign(path: KeySignPath) -> str: """Return path to ssh-keysign executable""" if path is True: for keysign_dir in _DEFAULT_KEYSIGN_DIRS: path = Path(keysign_dir, 'ssh-keysign') if path.exists(): break else: raise ValueError('Keysign not found') else: if not path or not Path(cast(FilePath, path)).exists(): raise ValueError('Keysign not found') return str(path) def get_keysign_keys(keysign_path: str, sock_fd: int, keys: Iterable[_KeySignKey]) -> \ Sequence[SSHKeySignKeyPair]: """Return keypair objects which invoke ssh-keysign""" return [SSHKeySignKeyPair(keysign_path, sock_fd, key) for key in keys] asyncssh-2.20.0/asyncssh/known_hosts.py000066400000000000000000000332731475467777400202530ustar00rootroot00000000000000# Copyright (c) 2015-2024 by Ron Frederick and others. 
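# --- Illustrative usage sketch (not part of the original sources) --------
# The ssh-keysign helpers in keysign.py above are normally driven by the
# connection layer when host-based authentication is used.  A rough sketch
# of that flow; the function name, `sock` and `host_keys` are assumptions
# standing in for state the connection already holds:

from asyncssh.keysign import find_keysign, get_keysign_keys

async def _sign_with_keysign(sock, host_keys, data: bytes):
    keysign_path = find_keysign(True)    # probe the default libexec dirs
    keypairs = get_keysign_keys(keysign_path, sock.fileno(), host_keys)
    # Each sign_async() call runs ssh-keysign, passing it the socket fd
    return [await kp.sign_async(data) for kp in keypairs]
# --------------------------------------------------------------------------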
# # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation # Alexander Travov - proposed changes to add negated patterns, hashed # entries, and support for the revoked marker # Josh Yudaken - proposed change to split parsing and matching to avoid # parsing large known_hosts lists multiple times """Parser for SSH known_hosts files""" import binascii from hashlib import sha1 import hmac from typing import Callable, Dict, List, Optional from typing import Sequence, Tuple, Union, cast try: from .crypto import X509NamePattern _x509_available = True except ImportError: # pragma: no cover _x509_available = False from .misc import IPAddress, ip_address, read_file from .pattern import HostPatternList from .public_key import KeyImportError from .public_key import SSHKey, SSHCertificate, SSHX509Certificate from .public_key import import_public_key, import_certificate from .public_key import import_certificate_subject from .public_key import load_public_keys, load_certificates _HostPattern = Union['_PlainHost', '_HashedHost'] _HostEntry = Tuple[Optional[str], Optional[SSHKey], Optional[SSHX509Certificate], Optional['X509NamePattern']] _KnownHostsKeys = Sequence[SSHKey] _KnownHostsCerts = Sequence[SSHX509Certificate] _KnownHostsNames = Sequence['X509NamePattern'] _KnownHostsResult = Tuple[_KnownHostsKeys, _KnownHostsKeys, _KnownHostsKeys, _KnownHostsCerts, _KnownHostsCerts, _KnownHostsNames, _KnownHostsNames] _KnownHostsCallable = Callable[[str, str, Optional[int]], Sequence[str]] _KnownHostsListArg = Union[str, Sequence[str], 'X509NamePattern'] KnownHostsArg = Union[None, str, bytes, _KnownHostsCallable, 'SSHKnownHosts', _KnownHostsResult, Sequence[_KnownHostsListArg]] def _load_subject_names(names: Sequence[str]) -> Sequence['X509NamePattern']: """Load a list of X.509 subject name patterns""" if not _x509_available: # pragma: no cover return [] return list(map(X509NamePattern, names)) class _PlainHost: """A plain host entry in a known_hosts file""" def __init__(self, pattern: str): self._pattern = HostPatternList(pattern) def matches(self, host: str, addr: str, ip: Optional[IPAddress]) -> bool: """Return whether a host or address matches this host pattern list""" return self._pattern.matches(host, addr, ip) class _HashedHost: """A hashed host entry in a known_hosts file""" _HMAC_SHA1_MAGIC = '1' def __init__(self, pattern: str): try: magic, salt, hosthash = pattern[1:].split('|') self._salt = binascii.a2b_base64(salt) self._hosthash = binascii.a2b_base64(hosthash) except (ValueError, binascii.Error): raise ValueError( f'Invalid known hosts hash entry: {pattern}') from None if magic != self._HMAC_SHA1_MAGIC: # Only support HMAC SHA-1 for now raise ValueError( f'Invalid known hosts hash type: {magic}') from None def _match(self, value: str) -> bool: """Return whether this host hash matches a value""" hosthash = hmac.new(self._salt, value.encode(), sha1).digest() return hosthash == self._hosthash def matches(self, host: str, addr: str, 
_ip: Optional[IPAddress]) -> bool: """Return whether a host or address matches this host hash""" return self._match(host) or self._match(addr) class SSHKnownHosts: """An SSH known hosts list""" def __init__(self, known_hosts: Optional[str] = None): self._exact_entries: Dict[Optional[str], List[_HostEntry]] = {} self._pattern_entries: List[Tuple[_HostPattern, _HostEntry]] = [] if known_hosts: self.load(known_hosts) def load(self, known_hosts: str) -> None: """Load known hosts data into this object""" for line in known_hosts.splitlines(): line = line.strip() if not line or line.startswith('#'): continue marker: Optional[str] try: if line.startswith('@'): marker, pattern, data = line[1:].split(None, 2) else: marker = None pattern, data = line.split(None, 1) except ValueError: raise ValueError( f'Invalid known hosts entry: {line}') from None if marker not in (None, 'cert-authority', 'revoked'): raise ValueError( f'Invalid known hosts marker: {marker}') from None key: Optional[SSHKey] = None cert: Optional[SSHCertificate] = None subject: Optional['X509NamePattern'] = None try: key = import_public_key(data) except KeyImportError: try: cert = import_certificate(data) except KeyImportError: if not _x509_available: # pragma: no cover continue try: subject_text = import_certificate_subject(data) except KeyImportError: # Ignore keys in the file that we're unable to parse continue subject = X509NamePattern(subject_text) entry = (marker, key, cast(SSHX509Certificate, cert), subject) if any(c in pattern for c in '*?|/!'): self._add_pattern(pattern, entry) else: self._add_exact(pattern, entry) def _add_exact(self, pattern: str, entry: _HostEntry) -> None: """Add an exact match entry""" for host_pat in pattern.split(','): if host_pat not in self._exact_entries: self._exact_entries[host_pat] = [] self._exact_entries[host_pat].append(entry) def _add_pattern(self, pattern: str, entry: _HostEntry) -> None: """Add a pattern match entry""" if pattern.startswith('|'): host_pat: _HostPattern = _HashedHost(pattern) else: host_pat = _PlainHost(pattern) self._pattern_entries.append((host_pat, entry)) def _match(self, host: str, addr: str, port: Optional[int] = None) -> _KnownHostsResult: """Find host keys matching specified host, address, and port""" if addr: ip: Optional[IPAddress] = ip_address(addr) else: try: ip = ip_address(host) except ValueError: ip = None if port: host = f'[{host}]:{port}' if host else '' addr = f'[{addr}]:{port}' if addr else '' matches = [] matches += self._exact_entries.get(host, []) matches += self._exact_entries.get(addr, []) matches += (match for (entry, match) in self._pattern_entries if entry.matches(host, addr, ip)) host_keys: List[SSHKey] = [] ca_keys: List[SSHKey] = [] revoked_keys: List[SSHKey] = [] x509_certs: List[SSHX509Certificate] = [] revoked_certs: List[SSHX509Certificate] = [] x509_subjects: List['X509NamePattern'] = [] revoked_subjects: List['X509NamePattern'] = [] for marker, key, cert, subject in matches: if key: if marker == 'revoked': revoked_keys.append(key) elif marker == 'cert-authority': ca_keys.append(key) else: host_keys.append(key) elif cert: if marker == 'revoked': revoked_certs.append(cert) else: x509_certs.append(cert) else: assert subject is not None if marker == 'revoked': revoked_subjects.append(subject) else: x509_subjects.append(subject) return (host_keys, ca_keys, revoked_keys, x509_certs, revoked_certs, x509_subjects, revoked_subjects) def match(self, host: str, addr: str, port: Optional[int]) -> _KnownHostsResult: """Match a host, IP address, and 
port against known_hosts patterns If the port is not the default port and no match is found for it, the lookup is attempted again without a port number. :param host: The hostname of the target host :param addr: The IP address of the target host :param port: The port number on the target host, or `None` for the default :type host: `str` :type addr: `str` :type port: `int` :returns: A tuple of matching host keys, CA keys, and revoked keys """ host_keys, ca_keys, revoked_keys, x509_certs, revoked_certs, \ x509_subjects, revoked_subjects = self._match(host, addr, port) if port and not (host_keys or ca_keys or x509_certs or x509_subjects): host_keys, ca_keys, revoked_keys, x509_certs, revoked_certs, \ x509_subjects, revoked_subjects = self._match(host, addr) return (host_keys, ca_keys, revoked_keys, x509_certs, revoked_certs, x509_subjects, revoked_subjects) def import_known_hosts(data: str) -> SSHKnownHosts: """Import SSH known hosts This function imports known host patterns and keys in OpenSSH known hosts format. :param data: The known hosts data to import :type data: `str` :returns: An :class:`SSHKnownHosts` object """ return SSHKnownHosts(data) def read_known_hosts(filelist: Union[str, Sequence[str]]) -> SSHKnownHosts: """Read SSH known hosts from a file or list of files This function reads known host patterns and keys in OpenSSH known hosts format from a file or list of files. :param filelist: The file or list of files to read the known hosts from :type filelist: `str` or `list` of `str` :returns: An :class:`SSHKnownHosts` object """ known_hosts = SSHKnownHosts() if isinstance(filelist, str): filelist = [filelist] for filename in filelist: known_hosts.load(read_file(filename, 'r')) return known_hosts def match_known_hosts(known_hosts: KnownHostsArg, host: str, addr: str, port: Optional[int]) -> _KnownHostsResult: """Match a host, IP address, and port against a known_hosts list This function looks up a host, IP address, and port in a list of host patterns in OpenSSH `known_hosts` format and returns the host keys, CA keys, and revoked keys which match. The `known_hosts` argument can be any of the following: * a string containing the filename to load host patterns from * a byte string containing host pattern data to load * an already loaded :class:`SSHKnownHosts` object containing host patterns to match against * an alternate matching function which accepts a host, address, and port and returns lists of trusted host keys, trusted CA keys, and revoked keys to load * lists of trusted host keys, trusted CA keys, and revoked keys to load without doing any matching If the port is not the default port and no match is found for it, the lookup is attempted again without a port number. 
:param known_hosts: The host patterns to match against :param host: The hostname of the target host :param addr: The IP address of the target host :param port: The port number on the target host, or `None` for the default :type host: `str` :type addr: `str` :type port: `int` :returns: A tuple of matching host keys, CA keys, and revoked keys """ if isinstance(known_hosts, str) or \ (known_hosts and isinstance(known_hosts, list) and isinstance(known_hosts[0], str)): known_hosts = read_known_hosts(known_hosts) elif isinstance(known_hosts, bytes): known_hosts = import_known_hosts(known_hosts.decode()) if isinstance(known_hosts, SSHKnownHosts): known_hosts = known_hosts.match(host, addr, port) else: if callable(known_hosts): known_hosts = known_hosts(host, addr, port) result = cast(Sequence[str], known_hosts) result = (tuple(map(load_public_keys, result[:3])) + tuple(map(load_certificates, result[3:5])) + tuple(map(_load_subject_names, result[5:7]))) if len(result) == 3: # Provide backward compatibility for pre-X.509 releases result += ((), (), (), ()) known_hosts = cast(_KnownHostsResult, result) for cert in list(known_hosts[3]) + list(known_hosts[4]): if not cert.is_x509: raise ValueError('OpenSSH certificates not ' 'allowed in known hosts') from None return known_hosts asyncssh-2.20.0/asyncssh/listener.py000066400000000000000000000315741475467777400175260ustar00rootroot00000000000000# Copyright (c) 2013-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation """SSH listeners""" import asyncio import errno import socket from types import TracebackType from typing import TYPE_CHECKING, AnyStr, Callable, Generic, List, Optional from typing import Sequence, Set, Tuple, Type, Union from typing_extensions import Self from .forward import SSHForwarderCoro from .forward import SSHLocalPortForwarder, SSHLocalPathForwarder from .misc import HostPort, MaybeAwait from .session import SSHTCPSession, SSHUNIXSession from .socks import SSHSOCKSForwarder if TYPE_CHECKING: # pylint: disable=cyclic-import from .channel import SSHTCPChannel, SSHUNIXChannel from .connection import SSHConnection, SSHClientConnection _LocalListenerFactory = Callable[[], asyncio.BaseProtocol] ListenKey = Union[HostPort, str] TCPListenerFactory = Callable[[str, int], MaybeAwait[SSHTCPSession[AnyStr]]] UNIXListenerFactory = Callable[[], MaybeAwait[SSHUNIXSession[AnyStr]]] class SSHListener: """SSH listener for inbound connections""" def __init__(self) -> None: self._tunnel: Optional['SSHConnection'] = None async def __aenter__(self) -> Self: return self async def __aexit__(self, _exc_type: Optional[Type[BaseException]], _exc_value: Optional[BaseException], _traceback: Optional[TracebackType]) -> bool: self.close() await self.wait_closed() return False def get_port(self) -> int: """Return the port number being listened on This method returns the port number that the remote listener was bound to. 
When the requested remote listening port is `0` to indicate a dynamic port, this method can be called to determine what listening port was selected. This function only applies to TCP listeners. :returns: The port number being listened on """ # pylint: disable=no-self-use return 0 def set_tunnel(self, tunnel: 'SSHConnection') -> None: """Set tunnel associated with listener""" self._tunnel = tunnel def close(self) -> None: """Stop listening for new connections This method can be called to stop listening for connections. Existing connections will remain open. """ if self._tunnel: self._tunnel.close() async def wait_closed(self) -> None: """Wait for the listener to close This method is a coroutine which waits for the associated listeners to be closed. """ if self._tunnel: await self._tunnel.wait_closed() self._tunnel = None class SSHClientListener(SSHListener): """Client listener used to accept inbound forwarded connections""" def __init__(self, conn: 'SSHClientConnection', encoding: Optional[str], errors: str, window: int, max_pktsize: int): super().__init__() self._conn: Optional['SSHClientConnection'] = conn self._encoding = encoding self._errors = errors self._window = window self._max_pktsize = max_pktsize self._close_event = asyncio.Event() async def _close(self) -> None: """Close this listener""" self._close_event.set() self._conn = None def close(self) -> None: """Close this listener asynchronously""" super().close() if self._conn: self._conn.create_task(self._close()) async def wait_closed(self) -> None: """Wait for this listener to finish closing""" await super().wait_closed() await self._close_event.wait() class SSHTCPClientListener(SSHClientListener, Generic[AnyStr]): """Client listener used to accept inbound forwarded TCP connections""" def __init__(self, conn: 'SSHClientConnection', session_factory: TCPListenerFactory[AnyStr], listen_host: str, listen_port: int, encoding: Optional[str], errors: str, window: int, max_pktsize: int): super().__init__(conn, encoding, errors, window, max_pktsize) self._session_factory: TCPListenerFactory[AnyStr] = session_factory self._listen_host = listen_host self._listen_port = listen_port async def _close(self) -> None: """Close this listener""" if self._conn: # pragma: no branch await self._conn.close_client_tcp_listener(self._listen_host, self._listen_port) await super()._close() def process_connection(self, orig_host: str, orig_port: int) -> \ Tuple['SSHTCPChannel[AnyStr]', MaybeAwait[SSHTCPSession[AnyStr]]]: """Process a forwarded TCP connection""" assert self._conn is not None chan = self._conn.create_tcp_channel(self._encoding, self._errors, self._window, self._max_pktsize) chan.set_inbound_peer_names(self._listen_host, self._listen_port, orig_host, orig_port) return chan, self._session_factory(orig_host, orig_port) def get_addresses(self) -> List[Tuple]: """Return the socket addresses being listened on""" return [(self._listen_host, self._listen_port)] def get_port(self) -> int: """Return the port number being listened on""" return self._listen_port class SSHUNIXClientListener(SSHClientListener, Generic[AnyStr]): """Client listener used to accept inbound forwarded UNIX connections""" def __init__(self, conn: 'SSHClientConnection', session_factory: UNIXListenerFactory[AnyStr], listen_path: str, encoding: Optional[str], errors: str, window: int, max_pktsize: int): super().__init__(conn, encoding, errors, window, max_pktsize) self._session_factory: UNIXListenerFactory[AnyStr] = session_factory self._listen_path = listen_path async def 
_close(self) -> None: """Close this listener""" if self._conn: # pragma: no branch await self._conn.close_client_unix_listener(self._listen_path) await super()._close() def process_connection(self) -> \ Tuple['SSHUNIXChannel[AnyStr]', MaybeAwait[SSHUNIXSession[AnyStr]]]: """Process a forwarded UNIX connection""" assert self._conn is not None chan = self._conn.create_unix_channel(self._encoding, self._errors, self._window, self._max_pktsize) chan.set_inbound_peer_names(self._listen_path) return chan, self._session_factory() class SSHForwardListener(SSHListener): """A listener used when forwarding traffic from local ports""" def __init__(self, conn: 'SSHConnection', servers: Sequence[asyncio.AbstractServer], listen_key: ListenKey, listen_port: int = 0): super().__init__() self._conn: Optional['SSHConnection'] = conn self._servers = servers self._listen_key = listen_key self._listen_port = listen_port def get_port(self) -> int: """Return the port number being listened on""" return self._listen_port def close(self) -> None: """Close this listener""" if self._conn: # pragma: no branch self._conn.close_forward_listener(self._listen_key) for server in self._servers: server.close() self._conn = None async def wait_closed(self) -> None: """Wait for this listener to finish closing""" await super().wait_closed() for server in self._servers: await server.wait_closed() self._servers = [] async def create_tcp_local_listener( conn: 'SSHConnection', loop: asyncio.AbstractEventLoop, protocol_factory: _LocalListenerFactory, listen_host: str, listen_port: int) -> 'SSHForwardListener': """Create a listener to forward traffic from a local TCP port over SSH""" addrinfo = await loop.getaddrinfo(listen_host or None, listen_port, family=socket.AF_UNSPEC, type=socket.SOCK_STREAM, flags=socket.AI_PASSIVE) if not addrinfo: # pragma: no cover raise OSError('getaddrinfo() returned empty list') seen_addrinfo: Set[Tuple] = set() servers: List[asyncio.AbstractServer] = [] for addrinfo_entry in addrinfo: # Work around an issue where getaddrinfo() on some systems may # return duplicate results, causing bind to fail. 
if addrinfo_entry in seen_addrinfo: # pragma: no cover continue seen_addrinfo.add(addrinfo_entry) family, socktype, proto, _, sa = addrinfo_entry try: sock = socket.socket(family, socktype, proto) except OSError: # pragma: no cover continue sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, True) if family == socket.AF_INET6: try: sock.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, True) except AttributeError: # pragma: no cover pass if sa[1] == 0: sa = sa[:1] + (listen_port,) + sa[2:] # type: ignore try: sock.bind(sa) except (OSError, OverflowError) as exc: sock.close() for server in servers: server.close() if isinstance(exc, OverflowError): # pragma: no cover exc.errno = errno.EOVERFLOW # type: ignore exc.strerror = str(exc) # type: ignore # pylint: disable=no-member raise OSError(exc.errno, f'error while attempting ' # type: ignore f'to bind on address {sa!r}: ' f'{exc.strerror}') from None # type: ignore if listen_port == 0: listen_port = sock.getsockname()[1] conn.logger.debug1('Assigning dynamic port %d', listen_port) server = await loop.create_server(protocol_factory, sock=sock) servers.append(server) listen_key = listen_host or '', listen_port return SSHForwardListener(conn, servers, listen_key, listen_port) async def create_tcp_forward_listener(conn: 'SSHConnection', loop: asyncio.AbstractEventLoop, coro: SSHForwarderCoro, listen_host: str, listen_port: int) -> \ 'SSHForwardListener': """Create a listener to forward traffic from a local TCP port over SSH""" def protocol_factory() -> asyncio.BaseProtocol: """Start a port forwarder for each new local connection""" return SSHLocalPortForwarder(conn, coro) return await create_tcp_local_listener(conn, loop, protocol_factory, listen_host, listen_port) async def create_unix_forward_listener(conn: 'SSHConnection', loop: asyncio.AbstractEventLoop, coro: SSHForwarderCoro, listen_path: str) -> \ 'SSHForwardListener': """Create a listener to forward a local UNIX domain socket over SSH""" def protocol_factory() -> asyncio.BaseProtocol: """Start a path forwarder for each new local connection""" return SSHLocalPathForwarder(conn, coro) server = await loop.create_unix_server(protocol_factory, listen_path) return SSHForwardListener(conn, [server], listen_path) async def create_socks_listener(conn: 'SSHConnection', loop: asyncio.AbstractEventLoop, coro: SSHForwarderCoro, listen_host: str, listen_port: int) -> SSHForwardListener: """Create a SOCKS listener to forward traffic over SSH""" def protocol_factory() -> asyncio.BaseProtocol: """Start a port forwarder for each new SOCKS connection""" return SSHSOCKSForwarder(conn, coro) return await create_tcp_local_listener(conn, loop, protocol_factory, listen_host, listen_port) asyncssh-2.20.0/asyncssh/logging.py000066400000000000000000000170531475467777400173230ustar00rootroot00000000000000# Copyright (c) 2013-2024 by Ron Frederick and others. 
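# --- Illustrative sketch (not part of the original sources) --------------
# create_tcp_local_listener() above binds one socket per address family
# returned by getaddrinfo() and hands each pre-bound socket to
# loop.create_server(), so a dynamically assigned port picked for the first
# family can be reused for the rest.  A stripped-down version of that
# pattern (the helper name is illustrative; the deduplication and error
# handling in the real code are omitted):

import socket

async def _bind_all_families(loop, protocol_factory, host, port):
    servers, chosen_port = [], port
    for family, socktype, proto, _, sa in await loop.getaddrinfo(
            host or None, port, family=socket.AF_UNSPEC,
            type=socket.SOCK_STREAM, flags=socket.AI_PASSIVE):
        sock = socket.socket(family, socktype, proto)
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, True)
        if family == socket.AF_INET6:
            sock.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, True)
        sock.bind(sa[:1] + (chosen_port,) + sa[2:])
        chosen_port = chosen_port or sock.getsockname()[1]
        servers.append(await loop.create_server(protocol_factory, sock=sock))
    return servers, chosen_port
# --------------------------------------------------------------------------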
# # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Sam Crooks - initial implementation # Ron Frederick - minor cleanup """Logging functions""" import logging from typing import MutableMapping, Optional, Tuple, Union, cast _LogArg = object _ObjDict = MutableMapping[str, object] class SSHLogger(logging.LoggerAdapter): """Adapter to add context to AsyncSSH log messages""" _debug_level = 1 _pkg_logger = logging.getLogger(__package__ or 'asyncssh') def __init__(self, parent: logging.Logger = _pkg_logger, child: str = '', context: str = ''): self._context = context self._logger = parent.getChild(child) if child else parent super().__init__(self._logger, {}) def _extend_context(self, context: str) -> str: """Extend context provided by this logger""" if context: if self._context: context = self._context + ', ' + context else: context = self._context return context def get_child(self, child: str = '', context: str = '') -> 'SSHLogger': """Return child logger with optional added context""" return type(self)(self._logger, child, self._extend_context(context)) def log(self, level: int, msg: object, *args, **kwargs) -> None: """Log a message to the underlying logger""" def _item_text(item: _LogArg) -> str: """Convert a list item to text""" if isinstance(item, bytes): result = item.decode('utf-8', errors='backslashreplace') if not result.isprintable(): result = repr(result)[1:-1] elif not isinstance(item, str): result = str(item) else: result = item return result def _text(arg: _LogArg) -> _LogArg: """Convert a log argument to text""" result: _LogArg if isinstance(arg, list): result = ','.join(_item_text(item) for item in arg) elif isinstance(arg, tuple): host, port = arg if host: result = f'{host}, port {port}' if port else host else: result = f'port {port}' if port else 'dynamic port' elif isinstance(arg, bytes): result = _item_text(arg) else: result = arg return result log_args = [_text(arg) for arg in args] super().log(level, msg, *log_args, **kwargs) def process(self, msg: str, kwargs: _ObjDict) -> Tuple[str, _ObjDict]: """Add context to log message""" extra = cast(_ObjDict, kwargs.get('extra', {})) context = self._extend_context(cast(str, extra.get('context'))) context = '[' + context + '] ' if context else '' packet = cast(bytes, extra.get('packet')) pktdata = '' offset = 0 while packet: line = f'\n {offset:08x}:' for b in packet[:16]: line += f' {b:02x}' line += (62 - len(line)) * ' ' for b in packet[:16]: if b < 0x20 or b >= 0x80: c = '.' 
elif b == ord('%'): c = '%%' else: c = chr(b) line += c pktdata += line packet = packet[16:] offset += 16 return context + msg + pktdata, kwargs @classmethod def set_debug_level(cls, level: int) -> None: """Set AsyncSSH debug log level""" if level < 1 or level > 3: raise ValueError('Debug log level must be between 1 and 3') cls._debug_level = level def debug1(self, msg: str, *args: _LogArg, **kwargs: object) -> None: """Write a level 1 debug log message""" self.log(logging.DEBUG, msg, *args, **kwargs) def debug2(self, msg: str, *args: _LogArg, **kwargs: object) -> None: """Write a level 2 debug log message""" if self._debug_level >= 2: self.log(logging.DEBUG, msg, *args, **kwargs) def packet(self, pktid: Optional[int], packet: bytes, msg: str, *args: _LogArg, **kwargs: object) -> None: """Write a control packet debug log message""" if self._debug_level >= 3: kwargs.setdefault('extra', {}) extra = cast(_ObjDict, kwargs.get('extra')) if pktid is not None: extra.update(context=f'pktid={pktid}') extra.update(packet=packet) self.log(logging.DEBUG, msg, *args, **kwargs) def set_log_level(level: Union[int, str]) -> None: """Set the AsyncSSH log level This function sets the log level of the AsyncSSH logger. It defaults to `'NOTSET`', meaning that it will track the debug level set on the root Python logger. For additional control over the level of debug logging, see the function :func:`set_debug_level` for additional information. :param level: The log level to set, as defined by the `logging` module :type level: `int` or `str` """ logger.setLevel(level) def set_sftp_log_level(level: Union[int, str]) -> None: """Set the AsyncSSH SFTP/SCP log level This function sets the log level of the AsyncSSH SFTP/SCP logger. It defaults to `'NOTSET`', meaning that it will track the debug level set on the main AsyncSSH logger. For additional control over the level of debug logging, see the function :func:`set_debug_level` for additional information. :param level: The log level to set, as defined by the `logging` module :type level: `int` or `str` """ sftp_logger.setLevel(level) def set_debug_level(level: int) -> None: """Set the AsyncSSH debug log level This function sets the level of debugging logging done by the AsyncSSH logger, from the following options: ===== ==================================== Level Description ===== ==================================== 1 Minimal debug logging 2 Full debug logging 3 Full debug logging with packet dumps ===== ==================================== The debug level defaults to level 1 (minimal debug logging). .. note:: For this setting to have any effect, the effective log level of the AsyncSSH logger must be set to DEBUG. .. warning:: Extreme caution should be used when setting debug level to 3, as this can expose user passwords in clear text. This level should generally only be needed when tracking down issues with malformed or incomplete packets. :param level: The debug level to set, as defined above. :type level: `int` """ logger.set_debug_level(level) logger = SSHLogger() sftp_logger = logger.get_child('sftp') asyncssh-2.20.0/asyncssh/mac.py000066400000000000000000000157441475467777400164420ustar00rootroot00000000000000# Copyright (c) 2013-2021 by Ron Frederick and others. 
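# --- Illustrative usage sketch (not part of the original sources) --------
# The helpers above are the intended way for applications to turn up
# AsyncSSH debug output: the standard logging module decides where records
# go, while set_log_level()/set_debug_level() decide how much AsyncSSH
# emits.  A minimal setup, assuming the functions are used through the
# top-level asyncssh package as in the public API:

import logging
import asyncssh

logging.basicConfig(level=logging.DEBUG)
asyncssh.set_log_level('DEBUG')      # make the 'asyncssh' logger verbose
asyncssh.set_debug_level(2)          # full debug logging, no packet dumps
asyncssh.set_sftp_log_level('INFO')  # keep SFTP/SCP output quieter
# --------------------------------------------------------------------------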
# # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation """SSH message authentication handlers""" from hashlib import md5, sha1, sha224, sha256, sha384, sha512 import hmac from typing import Dict, Callable, List, Tuple from .packet import UInt32, UInt64 try: from .crypto import umac64, umac128 _umac_available = True except ImportError: # pragma: no cover _umac_available = False _MACAlgsArgs = Tuple[bytes, int, int, bool, Callable, Tuple, bool] _MACHandler = Tuple[Callable, int, Tuple] _MACParams = Tuple[int, int, bool] _OPENSSH = b'@openssh.com' _ETM = b'-etm' + _OPENSSH _mac_algs: List[bytes] = [] _default_mac_algs: List[bytes] = [] _mac_handler: Dict[bytes, _MACHandler] = {} _mac_params: Dict[bytes, _MACParams] = {} class MAC: """Parent class for SSH message authentication handlers""" def __init__(self, key: bytes, hash_size: int): self._key = key self._hash_size = hash_size def sign(self, seq: int, packet: bytes) -> bytes: """Compute a signature for a message""" raise NotImplementedError def verify(self, seq: int, packet: bytes, sig: bytes) -> bool: """Verify the signature of a message""" raise NotImplementedError class _NullMAC(MAC): """Null message authentication handler""" def sign(self, seq: int, packet: bytes) -> bytes: """Compute a signature for a message""" return b'' def verify(self, seq: int, packet: bytes, sig: bytes) -> bool: """Verify the signature of a message""" return sig == b'' class _HMAC(MAC): """HMAC-based message authentication handler""" def __init__(self, key: bytes, hash_size: int, hash_alg: Callable): super().__init__(key, hash_size) self._hash_alg = hash_alg def sign(self, seq: int, packet: bytes) -> bytes: """Compute a signature for a message""" data = UInt32(seq) + packet sig = hmac.new(self._key, data, self._hash_alg).digest() return sig[:self._hash_size] def verify(self, seq: int, packet: bytes, sig: bytes) -> bool: """Verify the signature of a message""" return hmac.compare_digest(self.sign(seq, packet), sig) class _UMAC(MAC): """UMAC-based message authentication handler""" def __init__(self, key: bytes, hash_size: int, umac_alg: Callable): super().__init__(key, hash_size) self._umac_alg = umac_alg def sign(self, seq: int, packet: bytes) -> bytes: """Compute a signature for a message""" return self._umac_alg(self._key, packet, UInt64(seq)).digest() def verify(self, seq: int, packet: bytes, sig: bytes) -> bool: """Verify the signature of a message""" return hmac.compare_digest(self.sign(seq, packet), sig) def register_mac_alg(mac_alg: bytes, key_size: int, hash_size: int, etm: bool, handler: Callable, args: Tuple, default: bool) -> None: """Register a MAC algorithm""" if mac_alg: _mac_algs.append(mac_alg) if default: _default_mac_algs.append(mac_alg) _mac_handler[mac_alg] = (handler, hash_size, args) _mac_params[mac_alg] = (key_size, hash_size, etm) def get_mac_algs() -> List[bytes]: """Return supported MAC algorithms""" return _mac_algs def get_default_mac_algs() -> List[bytes]: """Return 
default MAC algorithms""" return _default_mac_algs def get_mac_params(mac_alg: bytes) -> _MACParams: """Get parameters of a MAC algorithm This function returns the key and hash sizes of a MAC algorithm and whether or not to compute the MAC before or after encryption. """ return _mac_params[mac_alg] def get_mac(mac_alg: bytes, key: bytes) -> MAC: """Return a MAC handler This function returns a MAC object initialized with the specified key that can be used for data signing and verification. """ handler, hash_size, args = _mac_handler[mac_alg] return handler(key, hash_size, *args) _mac_algs_list: Tuple[_MACAlgsArgs, ...] = ( (b'', 0, 0, False, _NullMAC, (), True), ) if _umac_available: # pragma: no branch _mac_algs_list += ( (b'umac-64' + _ETM, 16, 8, True, _UMAC, (umac64,), True), (b'umac-128' + _ETM, 16, 16, True, _UMAC, (umac128,), True)) _mac_algs_list += ( (b'hmac-sha2-256' + _ETM, 32, 32, True, _HMAC, (sha256,), True), (b'hmac-sha2-512' + _ETM, 64, 64, True, _HMAC, (sha512,), True), (b'hmac-sha1' + _ETM, 20, 20, True, _HMAC, (sha1,), True), (b'hmac-md5' + _ETM, 16, 16, True, _HMAC, (md5,), False), (b'hmac-sha2-256-96' + _ETM, 32, 12, True, _HMAC, (sha256,), False), (b'hmac-sha2-512-96' + _ETM, 64, 12, True, _HMAC, (sha512,), False), (b'hmac-sha1-96' + _ETM, 20, 12, True, _HMAC, (sha1,), False), (b'hmac-md5-96' + _ETM, 16, 12, True, _HMAC, (md5,), False)) if _umac_available: # pragma: no branch _mac_algs_list += ( (b'umac-64' + _OPENSSH, 16, 8, False, _UMAC, (umac64,), True), (b'umac-128' + _OPENSSH, 16, 16, False, _UMAC, (umac128,), True)) _mac_algs_list += ( (b'hmac-sha2-256', 32, 32, False, _HMAC, (sha256,), True), (b'hmac-sha2-512', 64, 64, False, _HMAC, (sha512,), True), (b'hmac-sha1', 20, 20, False, _HMAC, (sha1,), True), (b'hmac-sha256-2@ssh.com', 32, 32, False, _HMAC, (sha256,), True), (b'hmac-sha224@ssh.com', 28, 28, False, _HMAC, (sha224,), True), (b'hmac-sha256@ssh.com', 16, 32, False, _HMAC, (sha256,), True), (b'hmac-sha384@ssh.com', 48, 48, False, _HMAC, (sha384,), True), (b'hmac-sha512@ssh.com', 64, 64, False, _HMAC, (sha512,), True), (b'hmac-md5', 16, 16, False, _HMAC, (md5,), False), (b'hmac-sha2-256-96', 32, 12, False, _HMAC, (sha256,), False), (b'hmac-sha2-512-96', 64, 12, False, _HMAC, (sha512,), False), (b'hmac-sha1-96', 20, 12, False, _HMAC, (sha1,), False), (b'hmac-md5-96', 16, 12, False, _HMAC, (md5,), False)) for _alg, _key_size, _hash_size, _etm, \ _mac_alg, _args, _default in _mac_algs_list: register_mac_alg(_alg, _key_size, _hash_size, _etm, _mac_alg, _args, _default) asyncssh-2.20.0/asyncssh/misc.py000066400000000000000000000645441475467777400166370ustar00rootroot00000000000000# Copyright (c) 2013-2024 by Ron Frederick and others. 
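# --- Illustrative sketch (not part of the original sources) --------------
# get_mac_params() and get_mac() above are what the transport layer calls
# once key exchange has produced integrity keys.  A minimal sign/verify
# round trip, assuming these internal helpers are imported from
# asyncssh.mac:

import os
from asyncssh.mac import get_mac, get_mac_params

_alg = b'hmac-sha2-256'
_key_size, _hash_size, _etm = get_mac_params(_alg)    # (32, 32, False)
_mac = get_mac(_alg, os.urandom(_key_size))

_packet = b'example payload'
_sig = _mac.sign(1, _packet)            # _hash_size bytes of HMAC-SHA-256
assert _mac.verify(1, _packet, _sig)
# --------------------------------------------------------------------------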
# # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation """Miscellaneous utility classes and functions""" import asyncio import fnmatch import functools import ipaddress import os import re import shlex import socket import sys from pathlib import Path, PurePath from random import SystemRandom from types import TracebackType from typing import Any, AsyncContextManager, Awaitable, Callable, Dict from typing import Generator, Generic, IO, Iterator, Mapping, Optional from typing import Sequence, Tuple, Type, TypeVar, Union, cast, overload from typing_extensions import Literal, Protocol from .constants import DEFAULT_LANG from .constants import DISC_COMPRESSION_ERROR, DISC_CONNECTION_LOST from .constants import DISC_HOST_KEY_NOT_VERIFIABLE, DISC_ILLEGAL_USER_NAME from .constants import DISC_KEY_EXCHANGE_FAILED, DISC_MAC_ERROR from .constants import DISC_NO_MORE_AUTH_METHODS_AVAILABLE from .constants import DISC_PROTOCOL_ERROR, DISC_PROTOCOL_VERSION_NOT_SUPPORTED from .constants import DISC_SERVICE_NOT_AVAILABLE if sys.platform != 'win32': # pragma: no branch import fcntl import struct import termios TermModes = Mapping[int, int] TermModesArg = Optional[TermModes] TermSize = Tuple[int, int, int, int] TermSizeArg = Union[None, Tuple[int, int], TermSize] class _Hash(Protocol): """Protocol for hashing data""" @property def digest_size(self) -> int: """Return the hash digest size""" @property def block_size(self) -> int: """Return the hash block size""" @property def name(self) -> str: """Return the hash name""" def digest(self) -> bytes: """Return the digest value as a bytes object""" def hexdigest(self) -> str: """Return the digest value as a string of hexadecimal digits""" def update(self, __data: bytes) -> None: """Update this hash object's state with the provided bytes""" class HashType(Protocol): """Protocol for returning the type of a hash function""" def __call__(self, __data: bytes = ...) -> _Hash: """Create a new hash object""" class _SupportsWaitClosed(Protocol): """A class that supports async wait_closed""" async def wait_closed(self) -> None: """Wait for transport to close""" _T = TypeVar('_T') DefTuple = Union[Tuple[()], _T] MaybeAwait = Union[_T, Awaitable[_T]] ExcInfo = Tuple[Type[BaseException], BaseException, TracebackType] OptExcInfo = Union[ExcInfo, Tuple[None, None, None]] BytesOrStr = Union[bytes, str] BytesOrStrDict = Dict[BytesOrStr, BytesOrStr] FilePath = Union[str, PurePath] HostPort = Tuple[str, int] IPAddress = Union[ipaddress.IPv4Address, ipaddress.IPv6Address] IPNetwork = Union[ipaddress.IPv4Network, ipaddress.IPv6Network] SockAddr = Union[Tuple[str, int], Tuple[str, int, int, int]] EnvMap = Mapping[BytesOrStr, BytesOrStr] EnvItems = Sequence[Tuple[BytesOrStr, BytesOrStr]] EnvSeq = Sequence[BytesOrStr] Env = Optional[Union[EnvMap, EnvItems, EnvSeq]] # Define a version of randrange which is based on SystemRandom(), so that # we get back numbers suitable for cryptographic use. 
_random = SystemRandom() randrange = _random.randrange _unit_pattern = re.compile(r'([A-Za-z])') _byte_units = {'': 1, 'k': 1024, 'm': 1024*1024, 'g': 1024*1024*1024} _time_units = {'': 1, 's': 1, 'm': 60, 'h': 60*60, 'd': 24*60*60, 'w': 7*24*60*60} def encode_env(env: Env) -> Iterator[Tuple[bytes, bytes]]: """Convert environemnt dict or list to bytes-based dictionary""" env = cast(Sequence[Tuple[BytesOrStr, BytesOrStr]], env.items() if isinstance(env, dict) else env) try: for item in env: if isinstance(item, (bytes, str)): if isinstance(item, str): item = item.encode('utf-8') key_bytes, value_bytes = item.split(b'=', 1) else: key, value = item key_bytes = key.encode('utf-8') \ if isinstance(key, str) else key value_bytes = value.encode('utf-8') \ if isinstance(value, str) else value yield key_bytes, value_bytes except (TypeError, ValueError) as exc: raise ValueError(f'Invalid environment value: {exc}') from None def lookup_env(patterns: EnvSeq) -> Iterator[Tuple[bytes, bytes]]: """Look up environemnt variables with wildcard matches""" for pattern in patterns: if isinstance(pattern, str): pattern = pattern.encode('utf-8') if os.supports_bytes_environ: for key_bytes, value_bytes in os.environb.items(): if fnmatch.fnmatch(key_bytes, pattern): yield key_bytes, value_bytes else: # pragma: no cover for key, value in os.environ.items(): key_bytes = key.encode('utf-8') value_bytes = value.encode('utf-8') if fnmatch.fnmatch(key_bytes, pattern): yield key_bytes, value_bytes def decode_env(env: Dict[bytes, bytes]) -> Iterator[Tuple[str, str]]: """Convert bytes-based environemnt dict to Unicode strings""" for key, value in env.items(): try: yield key.decode('utf-8'), value.decode('utf-8') except UnicodeDecodeError: pass def hide_empty(value: object, prefix: str = ', ') -> str: """Return a string with optional prefix if value is non-empty""" value = str(value) return prefix + value if value else '' def plural(length: int, label: str, suffix: str = 's') -> str: """Return a label with an optional plural suffix""" return f'{length} {label}{suffix if length != 1 else ""}' def all_ints(seq: Sequence[object]) -> bool: """Return if a sequence contains all integers""" return all(isinstance(i, int) for i in seq) def get_symbol_names(symbols: Mapping[str, int], prefix: str, strip_leading: int = 0) -> Mapping[int, str]: """Return a mapping from values to symbol names for logging""" return {value: name[strip_leading:] for name, value in symbols.items() if name.startswith(prefix)} # Punctuation to map when creating handler names _HANDLER_PUNCTUATION = (('@', '_at_'), ('.', '_dot_'), ('-', '_')) def map_handler_name(name: str) -> str: """Map punctuation so a string can be used as a handler name""" for old, new in _HANDLER_PUNCTUATION: name = name.replace(old, new) return name def _normalize_scoped_ip(addr: str) -> str: """Normalize scoped IP address The ipaddress module doesn't handle scoped addresses properly, so we normalize scoped IP addresses using socket.getaddrinfo before we pass them into ip_address/ip_network. 
""" try: addrinfo = socket.getaddrinfo(addr, None, family=socket.AF_UNSPEC, type=socket.SOCK_STREAM, flags=socket.AI_NUMERICHOST)[0] except socket.gaierror: return addr if addrinfo[0] == socket.AF_INET6: sa = addrinfo[4] addr = sa[0] idx = addr.find('%') if idx >= 0: # pragma: no cover addr = addr[:idx] ip = ipaddress.ip_address(addr) if ip.is_link_local: scope_id = cast(Tuple[str, int, int, int], sa)[3] addr = str(ipaddress.ip_address(int(ip) | (scope_id << 96))) return addr def ip_address(addr: str) -> IPAddress: """Wrapper for ipaddress.ip_address which supports scoped addresses""" return ipaddress.ip_address(_normalize_scoped_ip(addr)) def ip_network(addr: str) -> IPNetwork: """Wrapper for ipaddress.ip_network which supports scoped addresses""" idx = addr.find('/') if idx >= 0: addr, mask = addr[:idx], addr[idx:] else: mask = '' return ipaddress.ip_network(_normalize_scoped_ip(addr) + mask) def open_file(filename: FilePath, mode: str, buffering: int = -1) -> IO[bytes]: """Open a file with home directory expansion""" return open(Path(filename).expanduser(), mode, buffering=buffering) @overload def read_file(filename: FilePath) -> bytes: """Read from a binary file with home directory expansion""" @overload def read_file(filename: FilePath, mode: Literal['rb']) -> bytes: """Read from a binary file with home directory expansion""" @overload def read_file(filename: FilePath, mode: Literal['r']) -> str: """Read from a text file with home directory expansion""" def read_file(filename, mode = 'rb'): """Read from a file with home directory expansion""" with open_file(filename, mode) as f: return f.read() def write_file(filename: FilePath, data: bytes, mode: str = 'wb') -> int: """Write or append to a file with home directory expansion""" with open_file(filename, mode) as f: return f.write(data) def _parse_units(value: str, suffixes: Mapping[str, int], label: str) -> float: """Parse a series of integers followed by unit suffixes""" matches = _unit_pattern.split(value) if matches[-1]: matches.append('') else: matches.pop() try: return sum(float(matches[i]) * suffixes[matches[i+1].lower()] for i in range(0, len(matches), 2)) except KeyError: raise ValueError('Invalid ' + label) from None def parse_byte_count(value: str) -> int: """Parse a byte count with optional k, m, or g suffixes""" return int(_parse_units(value, _byte_units, 'byte count')) def parse_time_interval(value: str) -> float: """Parse a time interval with optional s, m, h, d, or w suffixes""" return _parse_units(value, _time_units, 'time interval') def split_args(command: str) -> Sequence[str]: """Split a command string into a list of arguments""" lex = shlex.shlex(command, posix=True) lex.whitespace_split = True if sys.platform == 'win32': # pragma: no cover lex.escape = [] return list(lex) _ACM = TypeVar('_ACM', bound=AsyncContextManager, covariant=True) class _ACMWrapper(Generic[_ACM]): """Async context manager wrapper""" def __init__(self, coro: Awaitable[_ACM]): self._coro = coro self._coro_result: Optional[_ACM] = None def __await__(self) -> Generator[Any, None, _ACM]: return self._coro.__await__() async def __aenter__(self) -> _ACM: self._coro_result = await self._coro return await self._coro_result.__aenter__() async def __aexit__(self, exc_type: Optional[Type[BaseException]], exc_value: Optional[BaseException], traceback: Optional[TracebackType]) -> Optional[bool]: assert self._coro_result is not None exit_result = await self._coro_result.__aexit__( exc_type, exc_value, traceback) self._coro_result = None return exit_result 
_ACMCoro = Callable[..., Awaitable[_ACM]] _ACMWrapperFunc = Callable[..., _ACMWrapper[_ACM]] def async_context_manager(coro: _ACMCoro[_ACM]) -> _ACMWrapperFunc[_ACM]: """Decorator for functions returning asynchronous context managers This decorator can be used on functions which return objects intended to be async context managers. The object returned by the function should implement __aenter__ and __aexit__ methods to run when the async context is entered and exited. This wrapper also allows the use of "await" on the function being decorated, to return the context manager without entering it. """ @functools.wraps(coro) def context_wrapper(*args, **kwargs) -> _ACMWrapper[_ACM]: """Return an async context manager wrapper for this coroutine""" return _ACMWrapper(coro(*args, **kwargs)) return context_wrapper async def maybe_wait_closed(writer: '_SupportsWaitClosed') -> None: """Wait for a StreamWriter to close, if Python version supports it Python 3.8 triggers a false error report about garbage collecting an open stream if a close is in progress when a StreamWriter is garbage collected. This can be avoided by calling wait_closed(), but that method is not available in Python releases prior to 3.7. This function wraps this call, ignoring the error if the method is not available. """ try: await writer.wait_closed() except AttributeError: # pragma: no cover pass async def run_in_executor(func: Callable[..., _T], *args: object) -> _T: """Run a function in an asyncio executor""" loop = asyncio.get_event_loop() return await loop.run_in_executor(None, func, *args) def set_terminal_size(tty: IO, width: int, height: int, pixwidth: int, pixheight: int) -> None: """Set the terminal size of a TTY""" fcntl.ioctl(tty, termios.TIOCSWINSZ, struct.pack('hhhh', height, width, pixwidth, pixheight)) class Options: """Container for configuration options""" kwargs: Dict[str, object] def __init__(self, options: Optional['Options'] = None, **kwargs: object): if options: if not isinstance(options, type(self)): raise TypeError(f'Invalid {type(self).__name__}, ' f'got {type(options).__name__}') self.kwargs = options.kwargs.copy() else: self.kwargs = {} self.kwargs.update(kwargs) self.prepare(**self.kwargs) def prepare(self, **kwargs: object) -> None: """Pre-process configuration options""" def update(self, **kwargs: object) -> None: """Update options based on keyword parameters passed in""" self.kwargs.update(kwargs) self.prepare(**self.kwargs) class _RecordMeta(type): """Metaclass for general-purpose record type""" __slots__: Dict[str, object] = {} def __new__(mcs: Type['_RecordMeta'], name: str, bases: Tuple[type, ...], ns: Dict[str, object]) -> '_RecordMeta': cls = cast(_RecordMeta, super().__new__(mcs, name, bases, ns)) if name != 'Record': fields = cast(Mapping[str, str], cls.__annotations__.keys()) defaults = {k: ns.get(k) for k in fields} cls.__slots__ = defaults return cls class Record(metaclass=_RecordMeta): """Generic Record class""" __slots__: Mapping[str, object] = {} def __init__(self, *args: object, **kwargs: object): for k, v in self.__slots__.items(): setattr(self, k, v) for k, v in zip(self.__slots__, args): setattr(self, k, v) for k, v in kwargs.items(): setattr(self, k, v) def __repr__(self) -> str: values = ', '.join(f'{k}={getattr(self, k)!r}' for k in self.__slots__) return f'{type(self).__name__}({values})' def __str__(self) -> str: values = ((k, self._format(k, getattr(self, k))) for k in self.__slots__) return ', '.join(f'{k}: {v}' for k, v in values if v is not None) def _format(self, k: str, 
v: object) -> Optional[str]: """Format a field as a string""" # pylint: disable=no-self-use,unused-argument return str(v) class Error(Exception): """General SSH error""" def __init__(self, code: int, reason: str, lang: str = DEFAULT_LANG): super().__init__(reason) self.code = code self.reason = reason self.lang = lang class DisconnectError(Error): """SSH disconnect error This exception is raised when a serious error occurs which causes the SSH connection to be disconnected. Exception codes should be taken from :ref:`disconnect reason codes `. See below for exception subclasses tied to specific disconnect reasons if you want to customize your handling by reason. :param code: Disconnect reason, taken from :ref:`disconnect reason codes ` :param reason: A human-readable reason for the disconnect :param lang: (optional) The language the reason is in :type code: `int` :type reason: `str` :type lang: `str` """ class CompressionError(DisconnectError): """SSH compression error This exception is raised when an error occurs while compressing or decompressing data sent on the SSH connection. :param reason: Details about the compression error :param lang: (optional) The language the reason is in :type reason: `str` :type lang: `str` """ def __init__(self, reason: str, lang: str = DEFAULT_LANG): super().__init__(DISC_COMPRESSION_ERROR, reason, lang) class ConnectionLost(DisconnectError): """SSH connection lost This exception is raised when the SSH connection to the remote system is unexpectedly lost. It can also occur as a result of the remote system failing to respond to keepalive messages or as a result of a login timeout, when those features are enabled. :param reason: Details about the connection failure :param lang: (optional) The language the reason is in :type reason: `str` :type lang: `str` """ def __init__(self, reason: str, lang: str = DEFAULT_LANG): super().__init__(DISC_CONNECTION_LOST, reason, lang) class HostKeyNotVerifiable(DisconnectError): """SSH host key not verifiable This exception is raised when the SSH server's host key or certificate is not verifiable. :param reason: Details about the host key verification failure :param lang: (optional) The language the reason is in :type reason: `str` :type lang: `str` """ def __init__(self, reason: str, lang: str = DEFAULT_LANG): super().__init__(DISC_HOST_KEY_NOT_VERIFIABLE, reason, lang) class IllegalUserName(DisconnectError): """SSH illegal user name This exception is raised when an error occurs while processing the username sent during the SSL handshake. :param reason: Details about the illegal username :param lang: (optional) The language the reason is in :type reason: `str` :type lang: `str` """ def __init__(self, reason: str, lang: str = DEFAULT_LANG): super().__init__(DISC_ILLEGAL_USER_NAME, reason, lang) class KeyExchangeFailed(DisconnectError): """SSH key exchange failed This exception is raised when the SSH key exchange fails. :param reason: Details about the connection failure :param lang: (optional) The language the reason is in :type reason: `str` :type lang: `str` """ def __init__(self, reason: str, lang: str = DEFAULT_LANG): super().__init__(DISC_KEY_EXCHANGE_FAILED, reason, lang) class MACError(DisconnectError): """SSH MAC error This exception is raised when an error occurs while processing the message authentication code (MAC) of a message on the SSH connection. 
       :param reason:
           Details about the MAC error
       :param lang: (optional)
           The language the reason is in

       :type reason: `str`
       :type lang: `str`

    """

    def __init__(self, reason: str, lang: str = DEFAULT_LANG):
        super().__init__(DISC_MAC_ERROR, reason, lang)


class PermissionDenied(DisconnectError):
    """SSH permission denied

       This exception is raised when there are no authentication methods
       remaining to complete SSH client authentication.

       :param reason:
           Details about the authentication failure
       :param lang: (optional)
           The language the reason is in

       :type reason: `str`
       :type lang: `str`

    """

    def __init__(self, reason: str, lang: str = DEFAULT_LANG):
        super().__init__(DISC_NO_MORE_AUTH_METHODS_AVAILABLE, reason, lang)


class ProtocolError(DisconnectError):
    """SSH protocol error

       This exception is raised when the SSH connection is disconnected
       due to an SSH protocol error being detected.

       :param reason:
           Details about the SSH protocol error detected
       :param lang: (optional)
           The language the reason is in

       :type reason: `str`
       :type lang: `str`

    """

    def __init__(self, reason: str, lang: str = DEFAULT_LANG):
        super().__init__(DISC_PROTOCOL_ERROR, reason, lang)


class ProtocolNotSupported(DisconnectError):
    """SSH protocol not supported

       This exception is raised when the remote system sends an SSH
       protocol version which is not supported.

       :param reason:
           Details about the unsupported SSH protocol version
       :param lang: (optional)
           The language the reason is in

       :type reason: `str`
       :type lang: `str`

    """

    def __init__(self, reason: str, lang: str = DEFAULT_LANG):
        super().__init__(DISC_PROTOCOL_VERSION_NOT_SUPPORTED, reason, lang)


class ServiceNotAvailable(DisconnectError):
    """SSH service not available

       This exception is raised when an unexpected service name is
       received during the SSH handshake.

       :param reason:
           Details about the unexpected SSH service
       :param lang: (optional)
           The language the reason is in

       :type reason: `str`
       :type lang: `str`

    """

    def __init__(self, reason: str, lang: str = DEFAULT_LANG):
        super().__init__(DISC_SERVICE_NOT_AVAILABLE, reason, lang)


class ChannelOpenError(Error):
    """SSH channel open error

       This exception is raised by connection handlers to report
       channel open failures.

       :param code:
           Channel open failure reason, taken from
           :ref:`channel open failure reason codes`
       :param reason:
           A human-readable reason for the channel open failure
       :param lang:
           The language the reason is in

       :type code: `int`
       :type reason: `str`
       :type lang: `str`

    """


class ChannelListenError(Exception):
    """SSH channel listen error

       This exception is raised to report failures in setting up
       remote SSH connection listeners.

       :param details:
           Details of the listen failure

       :type details: `str`

    """


class PasswordChangeRequired(Exception):
    """SSH password change required

       This exception is raised during password validation on the
       server to indicate that a password change is required. It
       should be raised when the password provided is valid but
       expired, to trigger the client to provide a new password.

       :param prompt:
           The prompt requesting that the user enter a new password
       :param lang:
           The language that the prompt is in

       :type prompt: `str`
       :type lang: `str`

    """

    def __init__(self, prompt: str, lang: str = DEFAULT_LANG):
        super().__init__(f'Password change required: {prompt}')

        self.prompt = prompt
        self.lang = lang


class BreakReceived(Exception):
    """SSH break request received

       This exception is raised on an SSH server stdin stream when the
       client sends a break on the channel.
:param msec: The duration of the break in milliseconds :type msec: `int` """ def __init__(self, msec: int): super().__init__(f'Break for {msec} msec') self.msec = msec class SignalReceived(Exception): """SSH signal request received This exception is raised on an SSH server stdin stream when the client sends a signal on the channel. :param signal: The name of the signal sent by the client :type signal: `str` """ def __init__(self, signal: str): super().__init__(f'Signal: {signal}') self.signal = signal class SoftEOFReceived(Exception): """SSH soft EOF request received This exception is raised on an SSH server stdin stream when the client sends an EOF from within the line editor on the channel. """ def __init__(self) -> None: super().__init__('Soft EOF') class TerminalSizeChanged(Exception): """SSH terminal size change notification received This exception is raised on an SSH server stdin stream when the client sends a terminal size change on the channel. :param width: The new terminal width :param height: The new terminal height :param pixwidth: The new terminal width in pixels :param pixheight: The new terminal height in pixels :type width: `int` :type height: `int` :type pixwidth: `int` :type pixheight: `int` """ def __init__(self, width: int, height: int, pixwidth: int, pixheight: int): super().__init__(f'Terminal size change: ({width}, {height}, ' f'{pixwidth}, {pixheight})') self.width = width self.height = height self.pixwidth = pixwidth self.pixheight = pixheight @property def term_size(self) -> TermSize: """Return terminal size as a tuple of 4 integers""" return self.width, self.height, self.pixwidth, self.pixheight _disc_error_map = { DISC_PROTOCOL_ERROR: ProtocolError, DISC_KEY_EXCHANGE_FAILED: KeyExchangeFailed, DISC_MAC_ERROR: MACError, DISC_COMPRESSION_ERROR: CompressionError, DISC_SERVICE_NOT_AVAILABLE: ServiceNotAvailable, DISC_PROTOCOL_VERSION_NOT_SUPPORTED: ProtocolNotSupported, DISC_HOST_KEY_NOT_VERIFIABLE: HostKeyNotVerifiable, DISC_CONNECTION_LOST: ConnectionLost, DISC_NO_MORE_AUTH_METHODS_AVAILABLE: PermissionDenied, DISC_ILLEGAL_USER_NAME: IllegalUserName } def construct_disc_error(code: int, reason: str, lang: str) -> DisconnectError: """Map disconnect error code to appropriate DisconnectError exception""" try: return _disc_error_map[code](reason, lang) except KeyError: return DisconnectError(code, f'{reason} (error {code})', lang) asyncssh-2.20.0/asyncssh/packet.py000066400000000000000000000151531475467777400171430ustar00rootroot00000000000000# Copyright (c) 2013-2024 by Ron Frederick and others. 
# # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation """SSH packet encoding and decoding functions""" from typing import Any, Awaitable, Callable, Iterable, Mapping, Optional from typing import Sequence, Union from .logging import SSHLogger from .misc import MaybeAwait, plural _LoggedPacket = Union[bytes, 'SSHPacket'] _PacketHandler = Callable[[Any, int, int, 'SSHPacket'], MaybeAwait[None]] class PacketDecodeError(ValueError): """Packet decoding error""" def Byte(value: int) -> bytes: """Encode a single byte""" return bytes((value,)) def Boolean(value: bool) -> bytes: """Encode a boolean value""" return Byte(bool(value)) def UInt16(value: int) -> bytes: """Encode a 16-bit integer value""" return value.to_bytes(2, 'big') def UInt32(value: int) -> bytes: """Encode a 32-bit integer value""" return value.to_bytes(4, 'big') def UInt64(value: int) -> bytes: """Encode a 64-bit integer value""" return value.to_bytes(8, 'big') def String(value: Union[bytes, str]) -> bytes: """Encode a byte string or UTF-8 string value""" if isinstance(value, str): value = value.encode('utf-8', errors='strict') return len(value).to_bytes(4, 'big') + value def MPInt(value: int) -> bytes: """Encode a multiple precision integer value""" l = value.bit_length() l += (l % 8 == 0 and value != 0 and value != -1 << (l - 1)) l = (l + 7) // 8 return l.to_bytes(4, 'big') + value.to_bytes(l, 'big', signed=True) def NameList(value: Iterable[bytes]) -> bytes: """Encode a comma-separated list of byte strings""" return String(b','.join(value)) class SSHPacket: """Decoder class for SSH packets""" def __init__(self, packet: bytes): self._packet = packet self._idx = 0 self._len = len(packet) def __bool__(self) -> bool: return self._idx != self._len def check_end(self) -> None: """Confirm that all of the data in the packet has been consumed""" if self: raise PacketDecodeError('Unexpected data at end of packet') def get_consumed_payload(self) -> bytes: """Return the portion of the packet consumed so far""" return self._packet[:self._idx] def get_remaining_payload(self) -> bytes: """Return the portion of the packet not yet consumed""" return self._packet[self._idx:] def get_full_payload(self) -> bytes: """Return the full packet""" return self._packet def get_bytes(self, size: int) -> bytes: """Extract the requested number of bytes from the packet""" if self._idx + size > self._len: raise PacketDecodeError('Incomplete packet') value = self._packet[self._idx:self._idx+size] self._idx += size return value def get_byte(self) -> int: """Extract a single byte from the packet""" return self.get_bytes(1)[0] def get_boolean(self) -> bool: """Extract a boolean from the packet""" return bool(self.get_byte()) def get_uint16(self) -> int: """Extract a 16-bit integer from the packet""" return int.from_bytes(self.get_bytes(2), 'big') def get_uint32(self) -> int: """Extract a 32-bit integer from the packet""" return int.from_bytes(self.get_bytes(4), 'big') def get_uint64(self) -> 
int: """Extract a 64-bit integer from the packet""" return int.from_bytes(self.get_bytes(8), 'big') def get_string(self) -> bytes: """Extract a UTF-8 string from the packet""" return self.get_bytes(self.get_uint32()) def get_mpint(self) -> int: """Extract a multiple precision integer from the packet""" return int.from_bytes(self.get_string(), 'big', signed=True) def get_namelist(self) -> Sequence[bytes]: """Extract a comma-separated list of byte strings from the packet""" namelist = self.get_string() return namelist.split(b',') if namelist else [] class SSHPacketLogger: """Parent class for SSH packet loggers""" _handler_names: Mapping[int, str] = {} @property def logger(self) -> SSHLogger: """The logger to use for packet logging""" raise NotImplementedError def _log_packet(self, msg: str, pkttype: int, pktid: Optional[int], packet: _LoggedPacket, note: str) -> None: """Log a sent/received packet""" if isinstance(packet, SSHPacket): packet = packet.get_full_payload() try: name = f'{self._handler_names[pkttype]} ({pkttype})' except KeyError: name = f'packet type {pkttype}' count = plural(len(packet), 'byte') if note: note = f' ({note})' self.logger.packet(pktid, packet, '%s %s, %s%s', msg, name, count, note) def log_sent_packet(self, pkttype: int, pktid: Optional[int], packet: _LoggedPacket, note: str = '') -> None: """Log a sent packet""" self._log_packet('Sent', pkttype, pktid, packet, note) def log_received_packet(self, pkttype: int, pktid: Optional[int], packet: _LoggedPacket, note: str = '') -> None: """Log a received packet""" self._log_packet('Received', pkttype, pktid, packet, note) class SSHPacketHandler(SSHPacketLogger): """Parent class for SSH packet handlers""" _packet_handlers: Mapping[int, _PacketHandler] = {} @property def logger(self) -> SSHLogger: """The logger associated with this packet handler""" raise NotImplementedError def process_packet(self, pkttype: int, pktid: int, packet: SSHPacket) -> Union[bool, Awaitable[None]]: """Log and process a received packet""" if pkttype in self._packet_handlers: return self._packet_handlers[pkttype](self, pkttype, pktid, packet) or True else: return False asyncssh-2.20.0/asyncssh/pattern.py000066400000000000000000000110661475467777400173500ustar00rootroot00000000000000# Copyright (c) 2015-2021 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation """Pattern matching for principal and host names""" from fnmatch import fnmatch from typing import Union from .misc import IPAddress, ip_network _HostPattern = Union['WildcardHostPattern', 'CIDRHostPattern'] _AnyPattern = Union['WildcardPattern', _HostPattern] class _BaseWildcardPattern: """A base class for matching '*' and '?' wildcards""" def __init__(self, pattern: str): # We need to escape square brackets in host patterns if we # want to use Python's fnmatch. 
self._pattern = ''.join('[[]' if ch == '[' else '[]]' if ch == ']' else ch for ch in pattern) def _matches(self, value: str) -> bool: """Return whether a wild card pattern matches a value""" return fnmatch(value, self._pattern) class WildcardPattern(_BaseWildcardPattern): """A pattern matcher for '*' and '?' wildcards""" def matches(self, value: str) -> bool: """Return whether a wild card pattern matches a value""" return super()._matches(value) class WildcardHostPattern(_BaseWildcardPattern): """Match a host name or address against a wildcard pattern""" def matches(self, host: str, addr: str, _ip: IPAddress) -> bool: """Return whether a host or address matches a wild card host pattern""" return (bool(host) and super()._matches(host)) or \ (bool(addr) and super()._matches(addr)) class CIDRHostPattern: """Match IPv4/v6 address against CIDR-style subnet pattern""" def __init__(self, pattern: str): self._network = ip_network(pattern) def matches(self, _host: str, _addr: str, ip: IPAddress) -> bool: """Return whether an IP address matches a CIDR address pattern""" return bool(ip) and ip in self._network class _PatternList: """Match against a list of comma-separated positive and negative patterns This class is a base class for building a pattern matcher that takes a set of comma-separated positive and negative patterns, returning `True` if one or more positive patterns match and no negative ones do. The pattern matching is done by objects returned by the build_pattern method. The arguments passed in when a match is performed will vary depending on what class build_pattern returns. """ def __init__(self, patterns: str): self._pos_patterns = [] self._neg_patterns = [] for pattern in patterns.split(','): if pattern.startswith('!'): negate = True pattern = pattern[1:] else: negate = False matcher = self.build_pattern(pattern) if negate: self._neg_patterns.append(matcher) else: self._pos_patterns.append(matcher) def build_pattern(self, pattern: str) -> _AnyPattern: """Abstract method to build a pattern object""" raise NotImplementedError def matches(self, *args) -> bool: """Match a set of values against positive & negative pattern lists""" pos_match = any(p.matches(*args) for p in self._pos_patterns) neg_match = any(p.matches(*args) for p in self._neg_patterns) return pos_match and not neg_match class WildcardPatternList(_PatternList): """Match names against wildcard patterns""" def build_pattern(self, pattern: str) -> WildcardPattern: """Build a wild card pattern""" return WildcardPattern(pattern) class HostPatternList(_PatternList): """Match host names & addresses against wildcard and CIDR patterns""" def build_pattern(self, pattern: str) -> _HostPattern: """Build a CIDR address or wild card host pattern""" try: return CIDRHostPattern(pattern) except ValueError: return WildcardHostPattern(pattern) asyncssh-2.20.0/asyncssh/pbe.py000066400000000000000000000503021475467777400164350ustar00rootroot00000000000000# Copyright (c) 2013-2021 by Ron Frederick and others. 
# # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation """Asymmetric key password based encryption functions""" from hashlib import md5, sha1 import os from typing import Callable, Dict, Sequence, Tuple, Union from .asn1 import ASN1DecodeError, ObjectIdentifier, der_encode, der_decode from .crypto import BasicCipher, get_cipher_params, pbkdf2_hmac from .misc import BytesOrStr, HashType _Cipher = Union[BasicCipher, '_RFC1423Pad'] _PKCS8CipherHandler = Callable[[object, BytesOrStr, Callable, str], _Cipher] _PKCS8Cipher = Tuple[_PKCS8CipherHandler, Callable, str] _PBES2CipherHandler = Callable[[Sequence, str, bytes], _Cipher] _PBES2Cipher = Tuple[_PBES2CipherHandler, str] _PBES2KDFHandler = Callable[[Sequence, BytesOrStr, int], bytes] _PBES2KDF = Tuple[_PBES2KDFHandler, Tuple[object, ...]] _ES1_MD5_DES = ObjectIdentifier('1.2.840.113549.1.5.3') _ES1_SHA1_DES = ObjectIdentifier('1.2.840.113549.1.5.10') _ES2 = ObjectIdentifier('1.2.840.113549.1.5.13') _P12_RC4_128 = ObjectIdentifier('1.2.840.113549.1.12.1.1') _P12_RC4_40 = ObjectIdentifier('1.2.840.113549.1.12.1.2') _P12_DES3 = ObjectIdentifier('1.2.840.113549.1.12.1.3') _P12_DES2 = ObjectIdentifier('1.2.840.113549.1.12.1.4') _ES2_CAST128 = ObjectIdentifier('1.2.840.113533.7.66.10') _ES2_DES3 = ObjectIdentifier('1.2.840.113549.3.7') _ES2_BF = ObjectIdentifier('1.3.6.1.4.1.3029.1.2') _ES2_DES = ObjectIdentifier('1.3.14.3.2.7') _ES2_AES128 = ObjectIdentifier('2.16.840.1.101.3.4.1.2') _ES2_AES192 = ObjectIdentifier('2.16.840.1.101.3.4.1.22') _ES2_AES256 = ObjectIdentifier('2.16.840.1.101.3.4.1.42') _ES2_PBKDF2 = ObjectIdentifier('1.2.840.113549.1.5.12') _ES2_SHA1 = ObjectIdentifier('1.2.840.113549.2.7') _ES2_SHA224 = ObjectIdentifier('1.2.840.113549.2.8') _ES2_SHA256 = ObjectIdentifier('1.2.840.113549.2.9') _ES2_SHA384 = ObjectIdentifier('1.2.840.113549.2.10') _ES2_SHA512 = ObjectIdentifier('1.2.840.113549.2.11') _pkcs1_cipher: Dict[bytes, str] = {} _pkcs1_dek_name: Dict[str, bytes] = {} _pkcs8_handler: Dict[ObjectIdentifier, _PKCS8Cipher] = {} _pkcs8_cipher_oid: Dict[Tuple[str, str], ObjectIdentifier] = {} _pbes2_cipher: Dict[ObjectIdentifier, _PBES2Cipher] = {} _pbes2_cipher_oid: Dict[str, ObjectIdentifier] = {} _pbes2_kdf: Dict[ObjectIdentifier, _PBES2KDF] = {} _pbes2_kdf_oid: Dict[str, ObjectIdentifier] = {} _pbes2_prf: Dict[ObjectIdentifier, str] = {} _pbes2_prf_oid: Dict[str, ObjectIdentifier] = {} class KeyEncryptionError(ValueError): """Key encryption error This exception is raised by key decryption functions when the data provided is not a valid encrypted private key. """ class _RFC1423Pad: """RFC 1423 padding functions This class implements RFC 1423 padding for encryption and decryption of data by block ciphers. On encryption, the data is padded by between 1 and the cipher's block size number of bytes, with the padding value being equal to the length of the padding. 
""" def __init__(self, cipher_name: str, block_size: int, key: bytes, iv: bytes): self._cipher = BasicCipher(cipher_name, key, iv) self._block_size = block_size def encrypt(self, data: bytes) -> bytes: """Pad data before encrypting it""" pad = self._block_size - (len(data) % self._block_size) data += pad * bytes((pad,)) return self._cipher.encrypt(data) def decrypt(self, data: bytes) -> bytes: """Remove padding from data after decrypting it""" data = self._cipher.decrypt(data) if data: pad = data[-1] if (1 <= pad <= self._block_size and data[-pad:] == pad * bytes((pad,))): return data[:-pad] raise KeyEncryptionError('Unable to decrypt key') def _pbkdf1(hash_alg: HashType, passphrase: BytesOrStr, salt: bytes, count: int, key_size: int) -> bytes: """PKCS#5 v1.5 key derivation function for password-based encryption This function implements the PKCS#5 v1.5 algorithm for deriving an encryption key from a passphrase and salt. The standard PBKDF1 function cannot generate more key bytes than the hash digest size, but 3DES uses a modified form of it which calls PBKDF1 recursively on the result to generate more key data. Support for this is implemented here. """ if isinstance(passphrase, str): passphrase = passphrase.encode('utf-8') key = passphrase + salt for _ in range(count): key = hash_alg(key).digest() if len(key) <= key_size: return key + _pbkdf1(hash_alg, key + passphrase, salt, count, key_size - len(key)) else: return key[:key_size] def _pbkdf_p12(hash_alg: HashType, passphrase: BytesOrStr, salt: bytes, count: int, key_size: int, idx: int) -> bytes: """PKCS#12 key derivation function for password-based encryption This function implements the PKCS#12 algorithm for deriving an encryption key from a passphrase and salt. """ def _make_block(data: bytes, v: int) -> bytes: """Make a block a multiple of v bytes long by repeating data""" l = len(data) size = ((l + v - 1) // v) * v return (((size + l - 1) // l) * data)[:size] v = hash_alg().block_size D = v * bytes((idx,)) if isinstance(passphrase, str): passphrase = passphrase.encode('utf-16be') I = bytearray(_make_block(salt, v) + _make_block(passphrase + b'\0\0', v)) key = b'' while len(key) < key_size: A = D + I for i in range(count): A = hash_alg(A).digest() B = int.from_bytes(_make_block(A, v), 'big') for i in range(0, len(I), v): x = (int.from_bytes(I[i:i+v], 'big') + B + 1) % (1 << v*8) I[i:i+v] = x.to_bytes(v, 'big') key += A return key[:key_size] def _pbes1(params: object, passphrase: BytesOrStr, hash_alg: HashType, cipher_name: str) -> _Cipher: """PKCS#5 v1.5 cipher selection function for password-based encryption This function implements the PKCS#5 v1.5 algorithm for password-based encryption. It returns a cipher object which can be used to encrypt or decrypt data based on the specified encryption parameters, passphrase, and salt. """ if (not isinstance(params, tuple) or len(params) != 2 or not isinstance(params[0], bytes) or not isinstance(params[1], int)): raise KeyEncryptionError('Invalid PBES1 encryption parameters') salt, count = params key_size, iv_size, block_size = get_cipher_params(cipher_name) key = _pbkdf1(hash_alg, passphrase, salt, count, key_size + iv_size) key, iv = key[:key_size], key[key_size:] return _RFC1423Pad(cipher_name, block_size, key, iv) def _pbe_p12(params: object, passphrase: BytesOrStr, hash_alg: HashType, cipher_name: str) -> _Cipher: """PKCS#12 cipher selection function for password-based encryption This function implements the PKCS#12 algorithm for password-based encryption. 
It returns a cipher object which can be used to encrypt or decrypt data based on the specified encryption parameters, passphrase, and salt. """ if (not isinstance(params, tuple) or len(params) != 2 or not isinstance(params[0], bytes) or not params[0] or not isinstance(params[1], int) or params[1] == 0): raise KeyEncryptionError('Invalid PBES1 PKCS#12 encryption parameters') salt, count = params key_size, iv_size, block_size = get_cipher_params(cipher_name) key = _pbkdf_p12(hash_alg, passphrase, salt, count, key_size, 1) if block_size == 1: cipher: _Cipher = BasicCipher(cipher_name, key, b'') else: iv = _pbkdf_p12(hash_alg, passphrase, salt, count, iv_size, 2) cipher = _RFC1423Pad(cipher_name, block_size, key, iv) return cipher def _pbes2_iv(enc_params: Sequence, cipher_name: str, key: bytes) -> _Cipher: """PKCS#5 v2.0 handler for PBES2 ciphers with an IV as a parameter This function returns the appropriate cipher object to use for PBES2 encryption for ciphers that have only an IV as an encryption parameter. """ _, iv_size, block_size = get_cipher_params(cipher_name) if len(enc_params) != 1 or not isinstance(enc_params[0], bytes): raise KeyEncryptionError('Invalid PBES2 encryption parameters') if len(enc_params[0]) != iv_size: raise KeyEncryptionError('Invalid length IV for PBES2 encryption') return _RFC1423Pad(cipher_name, block_size, key, enc_params[0]) def _pbes2_pbkdf2(kdf_params: Sequence, passphrase: BytesOrStr, default_key_size: int) -> bytes: """PKCS#5 v2.0 handler for PBKDF2 key derivation This function parses the PBKDF2 arguments from a PKCS#8 encrypted key and returns the encryption key to use for encryption. """ if (len(kdf_params) != 1 or not isinstance(kdf_params[0], tuple) or len(kdf_params[0]) < 2): raise KeyEncryptionError('Invalid PBES2 key derivation parameters') kdf_params = list(kdf_params[0]) if (not isinstance(kdf_params[0], bytes) or not isinstance(kdf_params[1], int)): raise KeyEncryptionError('Invalid PBES2 key derivation parameters') salt = kdf_params.pop(0) count = kdf_params.pop(0) if kdf_params and isinstance(kdf_params[0], int): key_size = kdf_params.pop(0) # pragma: no cover, used only by RC2 else: key_size = default_key_size if kdf_params: if (isinstance(kdf_params[0], tuple) and len(kdf_params[0]) == 2 and isinstance(kdf_params[0][0], ObjectIdentifier)): prf_alg = kdf_params[0][0] if prf_alg in _pbes2_prf: hash_name = _pbes2_prf[prf_alg] else: raise KeyEncryptionError('Unknown PBES2 pseudo-random ' 'function') else: raise KeyEncryptionError('Invalid PBES2 pseudo-random function ' 'parameters') else: hash_name = 'sha1' if isinstance(passphrase, str): passphrase = passphrase.encode('utf-8') return pbkdf2_hmac(hash_name, passphrase, salt, count, key_size) def _pbes2(params: object, passphrase: BytesOrStr) -> _Cipher: """PKCS#5 v2.0 cipher selection function for password-based encryption This function implements the PKCS#5 v2.0 algorithm for password-based encryption. It returns a cipher object which can be used to encrypt or decrypt data based on the specified encryption parameters and passphrase. 
""" if (not isinstance(params, tuple) or len(params) != 2 or not isinstance(params[0], tuple) or len(params[0]) < 1 or not isinstance(params[1], tuple) or len(params[1]) < 1): raise KeyEncryptionError('Invalid PBES2 encryption parameters') kdf_params = list(params[0]) kdf_alg = kdf_params.pop(0) if kdf_alg not in _pbes2_kdf: raise KeyEncryptionError('Unknown PBES2 key derivation function') enc_params = list(params[1]) enc_alg = enc_params.pop(0) if enc_alg not in _pbes2_cipher: raise KeyEncryptionError('Unknown PBES2 encryption algorithm') kdf_handler, kdf_args = _pbes2_kdf[kdf_alg] enc_handler, cipher_name = _pbes2_cipher[enc_alg] default_key_size, _, _ = get_cipher_params(cipher_name) key = kdf_handler(kdf_params, passphrase, default_key_size, *kdf_args) return enc_handler(enc_params, cipher_name, key) def register_pkcs1_cipher(pkcs1_cipher_name: str, pkcs1_dek_name: bytes, cipher_name: str) -> None: """Register a cipher used for PKCS#1 private key encryption""" _pkcs1_cipher[pkcs1_dek_name] = cipher_name _pkcs1_dek_name[pkcs1_cipher_name] = pkcs1_dek_name def register_pkcs8_cipher(pkcs8_cipher_name: str, hash_name: str, pkcs8_cipher_oid: ObjectIdentifier, handler: _PKCS8CipherHandler, hash_alg: HashType, cipher_name: str) -> None: """Register a cipher used for PKCS#8 private key encryption""" _pkcs8_handler[pkcs8_cipher_oid] = (handler, hash_alg, cipher_name) _pkcs8_cipher_oid[pkcs8_cipher_name, hash_name] = pkcs8_cipher_oid def register_pbes2_cipher(pbes2_cipher_name: str, pbes2_cipher_oid: ObjectIdentifier, handler: _PBES2CipherHandler, cipher_name: str) -> None: """Register a PBES2 encryption algorithm""" _pbes2_cipher[pbes2_cipher_oid] = (handler, cipher_name) _pbes2_cipher_oid[pbes2_cipher_name] = pbes2_cipher_oid def register_pbes2_kdf(kdf_name: str, kdf_oid: ObjectIdentifier, handler: _PBES2KDFHandler, *args: object) -> None: """Register a PBES2 key derivation function""" _pbes2_kdf[kdf_oid] = (handler, args) _pbes2_kdf_oid[kdf_name] = kdf_oid def register_pbes2_prf(hash_name: str, prf_oid: ObjectIdentifier) -> None: """Register a PBES2 pseudo-random function""" _pbes2_prf[prf_oid] = hash_name _pbes2_prf_oid[hash_name] = prf_oid def pkcs1_encrypt(data: bytes, pkcs1_cipher_name: str, passphrase: BytesOrStr) -> Tuple[bytes, bytes, bytes]: """Encrypt PKCS#1 key data This function encrypts PKCS#1 key data using the specified cipher and passphrase. Available ciphers include: aes128-cbc, aes192-cbc, aes256-cbc, des-cbc, des3-cbc """ if pkcs1_cipher_name in _pkcs1_dek_name: pkcs1_dek_name = _pkcs1_dek_name[pkcs1_cipher_name] cipher_name = _pkcs1_cipher[pkcs1_dek_name] key_size, iv_size, block_size = get_cipher_params(cipher_name) iv = os.urandom(iv_size) key = _pbkdf1(md5, passphrase, iv[:8], 1, key_size) cipher = _RFC1423Pad(cipher_name, block_size, key, iv) return pkcs1_dek_name, iv, cipher.encrypt(data) else: raise KeyEncryptionError('Unknown PKCS#1 encryption algorithm') def pkcs1_decrypt(data: bytes, pkcs1_dek_name: bytes, iv: bytes, passphrase: BytesOrStr) -> bytes: """Decrypt PKCS#1 key data This function decrypts PKCS#1 key data using the specified algorithm, initialization vector, and passphrase. The algorithm name and IV should be taken from the PEM DEK-Info header. 
""" if pkcs1_dek_name in _pkcs1_cipher: cipher_name = _pkcs1_cipher[pkcs1_dek_name] key_size, _, block_size = get_cipher_params(cipher_name) key = _pbkdf1(md5, passphrase, iv[:8], 1, key_size) cipher = _RFC1423Pad(cipher_name, block_size, key, iv) return cipher.decrypt(data) else: raise KeyEncryptionError('Unknown PKCS#1 encryption algorithm') def pkcs8_encrypt(data: bytes, pkcs8_cipher_name: str, hash_name: str, version: int, passphrase: BytesOrStr) -> bytes: """Encrypt PKCS#8 key data This function encrypts PKCS#8 key data using the specified cipher, hash, encryption version, and passphrase. Available ciphers include: aes128-cbc, aes192-cbc, aes256-cbc, blowfish-cbc, cast128-cbc, des-cbc, des2-cbc, des3-cbc, rc4-40, and rc4-128 Available hashes include: md5, sha1, sha256, sha384, sha512 Available versions include 1 for PBES1 and 2 for PBES2. Only some combinations of cipher, hash, and version are supported. """ if version == 1 and (pkcs8_cipher_name, hash_name) in _pkcs8_cipher_oid: pkcs8_cipher_oid = _pkcs8_cipher_oid[pkcs8_cipher_name, hash_name] handler, hash_alg, cipher_name = _pkcs8_handler[pkcs8_cipher_oid] alg = pkcs8_cipher_oid params: object = (os.urandom(8), 2048) cipher = handler(params, passphrase, hash_alg, cipher_name) elif version == 2 and pkcs8_cipher_name in _pbes2_cipher_oid: pbes2_cipher_oid = _pbes2_cipher_oid[pkcs8_cipher_name] _, cipher_name = _pbes2_cipher[pbes2_cipher_oid] _, iv_size, _ = get_cipher_params(cipher_name) kdf_params = [os.urandom(8), 2048] iv = os.urandom(iv_size) enc_params = (pbes2_cipher_oid, iv) if hash_name != 'sha1': if hash_name in _pbes2_prf_oid: kdf_params.append((_pbes2_prf_oid[hash_name], None)) else: raise KeyEncryptionError('Unknown PBES2 hash function') alg = _ES2 params = ((_ES2_PBKDF2, tuple(kdf_params)), enc_params) cipher = _pbes2(params, passphrase) else: raise KeyEncryptionError('Unknown PKCS#8 encryption algorithm') return der_encode(((alg, params), cipher.encrypt(data))) def pkcs8_decrypt(key_data: object, passphrase: BytesOrStr) -> object: """Decrypt PKCS#8 key data This function decrypts key data in PKCS#8 EncryptedPrivateKeyInfo format using the specified passphrase. 
""" if not isinstance(key_data, tuple) or len(key_data) != 2: raise KeyEncryptionError('Invalid PKCS#8 encrypted key format') alg_params, data = key_data if (not isinstance(alg_params, tuple) or len(alg_params) != 2 or not isinstance(data, bytes)): raise KeyEncryptionError('Invalid PKCS#8 encrypted key format') alg, params = alg_params if alg == _ES2: cipher = _pbes2(params, passphrase) elif alg in _pkcs8_handler: handler, hash_alg, cipher_name = _pkcs8_handler[alg] cipher = handler(params, passphrase, hash_alg, cipher_name) else: raise KeyEncryptionError('Unknown PKCS#8 encryption algorithm') try: return der_decode(cipher.decrypt(data)) except (ASN1DecodeError, UnicodeDecodeError): raise KeyEncryptionError('Invalid PKCS#8 encrypted key data') from None _pkcs1_cipher_list = ( ('aes128-cbc', b'AES-128-CBC', 'aes128-cbc'), ('aes192-cbc', b'AES-192-CBC', 'aes192-cbc'), ('aes256-cbc', b'AES-256-CBC', 'aes256-cbc'), ('des-cbc', b'DES-CBC', 'des-cbc'), ('des3-cbc', b'DES-EDE3-CBC', 'des3-cbc') ) _pkcs8_cipher_list = ( ('des-cbc', 'md5', _ES1_MD5_DES, _pbes1, md5, 'des-cbc'), ('des-cbc', 'sha1', _ES1_SHA1_DES, _pbes1, sha1, 'des-cbc'), ('des2-cbc','sha1', _P12_DES2, _pbe_p12, sha1, 'des2-cbc'), ('des3-cbc','sha1', _P12_DES3, _pbe_p12, sha1, 'des3-cbc'), ('rc4-40', 'sha1', _P12_RC4_40, _pbe_p12, sha1, 'arcfour40'), ('rc4-128', 'sha1', _P12_RC4_128, _pbe_p12, sha1, 'arcfour') ) _pbes2_cipher_list = ( ('aes128-cbc', _ES2_AES128, _pbes2_iv, 'aes128-cbc'), ('aes192-cbc', _ES2_AES192, _pbes2_iv, 'aes192-cbc'), ('aes256-cbc', _ES2_AES256, _pbes2_iv, 'aes256-cbc'), ('blowfish-cbc', _ES2_BF, _pbes2_iv, 'blowfish-cbc'), ('cast128-cbc', _ES2_CAST128, _pbes2_iv, 'cast128-cbc'), ('des-cbc', _ES2_DES, _pbes2_iv, 'des-cbc'), ('des3-cbc', _ES2_DES3, _pbes2_iv, 'des3-cbc') ) _pbes2_kdf_list = ( ('pbkdf2', _ES2_PBKDF2, _pbes2_pbkdf2), ) _pbes2_prf_list = ( ('sha1', _ES2_SHA1), ('sha224', _ES2_SHA224), ('sha256', _ES2_SHA256), ('sha384', _ES2_SHA384), ('sha512', _ES2_SHA512) ) for _pkcs1_cipher_args in _pkcs1_cipher_list: register_pkcs1_cipher(*_pkcs1_cipher_args) for _pkcs8_cipher_args in _pkcs8_cipher_list: register_pkcs8_cipher(*_pkcs8_cipher_args) for _pbes2_cipher_args in _pbes2_cipher_list: register_pbes2_cipher(*_pbes2_cipher_args) for _pbes2_kdf_args in _pbes2_kdf_list: register_pbes2_kdf(*_pbes2_kdf_args) for _pbes2_prf_args in _pbes2_prf_list: register_pbes2_prf(*_pbes2_prf_args) asyncssh-2.20.0/asyncssh/pkcs11.py000066400000000000000000000263271475467777400170030ustar00rootroot00000000000000# Copyright (c) 2020-2023 by Ron Frederick and others. 
# # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation """PKCS#11 smart card handler""" from types import TracebackType from typing import Dict, List, Optional, Sequence, Tuple, Type, Union, cast try: import pkcs11 from pkcs11 import Attribute, KeyType, Mechanism, ObjectClass from pkcs11 import PrivateKey, Token from pkcs11.util.rsa import encode_rsa_public_key from pkcs11.util.ec import encode_ec_public_key pkcs11_available = True except (ImportError, ModuleNotFoundError): # pragma: no cover pkcs11_available = False from .misc import BytesOrStr from .packet import MPInt, String from .public_key import SSHCertificate, SSHKey, SSHKeyPair from .public_key import import_certificate_chain, import_public_key _AttrDict = Dict['Attribute', Union[bool, bytes, str, 'ObjectClass']] _TokenID = Tuple[str, bytes] _SessionMap = Dict[_TokenID, 'SSHPKCS11Session'] if pkcs11_available: encoders = {KeyType.RSA: encode_rsa_public_key, KeyType.EC: encode_ec_public_key} mechanisms = {b'ssh-rsa': Mechanism.SHA1_RSA_PKCS, b'rsa-sha2-256': Mechanism.SHA256_RSA_PKCS, b'rsa-sha2-512': Mechanism.SHA512_RSA_PKCS, b'ssh-rsa-sha224@ssh.com': Mechanism.SHA224_RSA_PKCS, b'ssh-rsa-sha256@ssh.com': Mechanism.SHA256_RSA_PKCS, b'ssh-rsa-sha384@ssh.com': Mechanism.SHA384_RSA_PKCS, b'ssh-rsa-sha512@ssh.com': Mechanism.SHA512_RSA_PKCS, b'rsa1024-sha1': Mechanism.SHA1_RSA_PKCS, b'rsa2048-sha256': Mechanism.SHA256_RSA_PKCS, b'ecdsa-sha2-nistp256': Mechanism.ECDSA_SHA256, b'ecdsa-sha2-nistp384': Mechanism.ECDSA_SHA384, b'ecdsa-sha2-nistp521': Mechanism.ECDSA_SHA512} class SSHPKCS11KeyPair(SSHKeyPair): """Surrogate for a key accessed via a PKCS#11 provider""" _key_type = 'pkcs11' def __init__(self, session: 'SSHPKCS11Session', privkey: PrivateKey, pubkey: SSHKey, cert: Optional[SSHCertificate] = None): super().__init__(pubkey.algorithm, pubkey.algorithm, pubkey.sig_algorithms, pubkey.sig_algorithms, pubkey.public_data, privkey.label, cert, use_executor=True) self._session = session self._privkey = privkey def __del__(self) -> None: self._session.close() def sign(self, data: bytes) -> bytes: """Sign a block of data with this private key""" sig_algorithm = self.sig_algorithm if sig_algorithm.startswith(b'x509v3-'): sig_algorithm = sig_algorithm[7:] sig = self._privkey.sign(data, mechanism=mechanisms[sig_algorithm]) if self._privkey.key_type == KeyType.EC: length = len(sig) // 2 r = int.from_bytes(sig[:length], 'big') s = int.from_bytes(sig[length:], 'big') sig = MPInt(r) + MPInt(s) return String(sig_algorithm) + String(sig) class SSHPKCS11Session: """Work around PKCS#11 sessions not supporting simultaneous opens""" _sessions: _SessionMap = {} def __init__(self, token_id: _TokenID, token: Token, pin: Optional[str]): self._token_id = token_id self._session = token.open(user_pin=pin) self._refcount = 0 def __enter__(self) -> 'SSHPKCS11Session': """Allow SSHPKCS11Session to be used as a context manager""" return self def __exit__(self, _exc_type: Type[BaseException], 
_exc_value: BaseException, _traceback: TracebackType) -> None: """Drop one reference to the session when exiting""" self.close() @classmethod def open(cls, token: Token, pin: Optional[str]) -> 'SSHPKCS11Session': """Open a new session, or return an already-open one""" token_id = (token.manufacturer_id, token.serial) try: session = cls._sessions[token_id] except KeyError: session = cls(token_id, token, pin) cls._sessions[token_id] = session session._refcount += 1 return session def close(self) -> None: """Drop one reference to an open session""" self._refcount -= 1 if self._refcount == 0: self._session.close() del self._sessions[self._token_id] def get_keys(self, load_certs: bool, key_label: Optional[str], key_id: Optional[BytesOrStr]) -> \ Sequence[SSHPKCS11KeyPair]: """Return the private keys found on this token""" if isinstance(key_id, str): key_id = bytes.fromhex(key_id) key_attrs: _AttrDict = {Attribute.CLASS: ObjectClass.PRIVATE_KEY, Attribute.SIGN: True} if key_label is not None: key_attrs[Attribute.LABEL] = key_label if key_id is not None: key_attrs[Attribute.OBJECT_ID] = key_id cert_attrs: _AttrDict = {Attribute.CLASS: ObjectClass.CERTIFICATE} if load_certs: certs = [import_certificate_chain( cast(bytes, cert[Attribute.VALUE])) for cert in self._session.get_objects(cert_attrs)] certdict = {cert.key.public_data: cert for cert in certs if cert and 'Attest' not in str(cert.subject)} else: certdict = {} keys = [] for key in self._session.get_objects(key_attrs): privkey = cast(PrivateKey, key) encoder = encoders.get(privkey.key_type) if encoder: pubkey = import_public_key(encoder(privkey)) cert = certdict.get(pubkey.public_data) if cert: keys.append(SSHPKCS11KeyPair(self, privkey, pubkey, cert)) keys.append(SSHPKCS11KeyPair(self, privkey, pubkey)) self._refcount += len(keys) return keys def load_pkcs11_keys(provider: str, pin: Optional[str] = None, *, load_certs: bool = True, token_label: Optional[str] = None, token_serial: Optional[BytesOrStr] = None, key_label: Optional[str] = None, key_id: Optional[BytesOrStr] = None) -> \ Sequence[SSHPKCS11KeyPair]: """Load PIV keys and X.509 certificates from a PKCS#11 token This function loads a list of SSH keypairs with optional X.509 cerificates from attached PKCS#11 security tokens. The PKCS#11 provider must be specified, along with a user PIN if the tokens are set to require one. By default, this function loads both private key handles and the X.509 certificates associated with them, allowing for X.509 certificate based auth to SSH servers that support it. To disable loading of these certificates and perform only key-based authentication, load_certs may be set to `False`. If token_label and/or token_serial are specified, only tokens matching those values will be accessed. If key_label and/or key_id are specified, only keys matching those values will be loaded. Key IDs can be specified as either raw bytes or a string containing hex digits. .. note:: If you have an active asyncio event loop at the time you call this function, you may want to consider running it via a call to :meth:`asyncio.AbstractEventLoop.run_in_executor`. While retrieving the keys generally takes only a fraction of a second, calling this function directly could block asyncio event processing until it completes. :param provider: The path to the PKCS#11 provider's shared library. :param pin: (optional) The PIN to use when accessing tokens, if needed. :param load_certs: (optional) Whether or not to load X.509 certificates from the security tokens. 
:param token_label: (optional) A token label to match against. If set, only security tokens with this label will be accessed. :param token_serial: (optional) A token serial number to match against. If set, only security tokens with this serial number will be accessed. :param key_label: (optional) A key label to match against. If set, only keys with this label will be loaded. :param key_id: (optional) A key ID to match against. If set, only keys with this ID will be loaded. :type provider: `str` :type pin: `str` :type load_certs: `bool` :type token_label: `str` :type token_serial: `bytes` or `str` :type key_label: `str` :type key_id: `bytes` or `str` :returns: list of class:`SSHKeyPair` """ lib = pkcs11.lib(provider) keys: List[SSHPKCS11KeyPair] = [] if isinstance(token_serial, str): token_serial = token_serial.encode('utf-8') for token in lib.get_tokens(token_label=token_label, token_serial=token_serial): with SSHPKCS11Session.open(token, pin) as session: keys.extend(session.get_keys(load_certs, key_label, key_id)) return keys else: # pragma: no cover def load_pkcs11_keys(provider: str, pin: Optional[str] = None, *, load_certs: bool = True, token_label: Optional[str] = None, token_serial: Optional[BytesOrStr] = None, key_label: Optional[str] = None, key_id: Optional[BytesOrStr] = None) -> \ Sequence['SSHPKCS11KeyPair']: """Report that PKCS#11 support is not available""" raise ValueError('PKCS#11 support not available') from None asyncssh-2.20.0/asyncssh/process.py000066400000000000000000002002011475467777400173400ustar00rootroot00000000000000# Copyright (c) 2016-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation """SSH process handlers""" import asyncio from asyncio.subprocess import DEVNULL, PIPE, STDOUT import codecs import inspect import io import os from pathlib import PurePath import socket import stat from types import TracebackType from typing import Any, AnyStr, Awaitable, Callable, Dict, Generic, IO from typing import Iterable, List, Mapping, Optional, Set, TextIO from typing import Tuple, Type, TypeVar, Union, cast from typing_extensions import Protocol, Self from .channel import SSHChannel, SSHClientChannel, SSHServerChannel from .constants import DEFAULT_LANG, EXTENDED_DATA_STDERR from .logging import SSHLogger from .misc import BytesOrStr, Error, MaybeAwait, TermModes, TermSize from .misc import ProtocolError, Record, open_file, set_terminal_size from .misc import BreakReceived, SignalReceived, TerminalSizeChanged from .session import DataType from .stream import SSHReader, SSHWriter, SSHStreamSession from .stream import SSHClientStreamSession, SSHServerStreamSession from .stream import SFTPServerFactory _AnyStrContra = TypeVar('_AnyStrContra', bytes, str, contravariant=True) _File = Union[IO[bytes], '_AsyncFileProtocol[bytes]'] ProcessSource = Union[int, str, socket.socket, PurePath, SSHReader[bytes], asyncio.StreamReader, _File] ProcessTarget = Union[int, str, 
socket.socket, PurePath, SSHWriter[bytes], asyncio.StreamWriter, _File] SSHServerProcessFactory = Callable[['SSHServerProcess[AnyStr]'], MaybeAwait[None]] _QUEUE_LOW_WATER = 8 _QUEUE_HIGH_WATER = 16 class _AsyncFileProtocol(Protocol[AnyStr]): """Protocol for an async file""" async def read(self, n: int = -1) -> AnyStr: """Read from an async file""" async def write(self, data: AnyStr) -> None: """Write to an async file""" async def close(self) -> None: """Close an async file""" class _ReaderProtocol(Protocol): """A class that can be used as a reader in SSHProcess""" def pause_reading(self) -> None: """Pause reading""" def resume_reading(self) -> None: """Resume reading""" def close(self) -> None: """Stop forwarding data""" class _WriterProtocol(Protocol[_AnyStrContra]): """A class that can be used as a writer in SSHProcess""" def write(self, data: _AnyStrContra) -> None: """Write data""" def write_exception(self, exc: Exception) -> None: """Write exception (break, signal, terminal size change)""" return # pragma: no cover def write_eof(self) -> None: """Close output when end of file is received""" def close(self) -> None: """Stop forwarding data""" def _is_regular_file(file: IO[bytes]) -> bool: """Return if argument is a regular file or file-like object""" try: return stat.S_ISREG(os.fstat(file.fileno()).st_mode) except OSError: return True class _UnicodeReader(_ReaderProtocol, Generic[AnyStr]): """Handle buffering partial Unicode data""" def __init__(self, encoding: Optional[str], errors: str, textmode: bool = False): super().__init__() if encoding and not textmode: self._decoder: Optional[codecs.IncrementalDecoder] = \ codecs.getincrementaldecoder(encoding)(errors) else: self._decoder = None def decode(self, data: bytes, final: bool = False) -> AnyStr: """Decode Unicode bytes when reading from binary sources""" if self._decoder: try: decoded_data = cast(AnyStr, self._decoder.decode(data, final)) except UnicodeDecodeError as exc: raise ProtocolError(str(exc)) from None else: decoded_data = cast(AnyStr, data) return decoded_data def check_partial(self) -> None: """Check if there's partial Unicode data left at EOF""" self.decode(b'', True) def close(self) -> None: """Perform necessary cleanup on error (provided by derived classes)""" class _UnicodeWriter(_WriterProtocol[AnyStr]): """Handle encoding Unicode data before writing it""" def __init__(self, encoding: Optional[str], errors: str, textmode: bool = False): super().__init__() if encoding and not textmode: self._encoder: Optional[codecs.IncrementalEncoder] = \ codecs.getincrementalencoder(encoding)(errors) else: self._encoder = None def encode(self, data: AnyStr) -> bytes: """Encode Unicode bytes when writing to binary targets""" if self._encoder: assert self._encoder is not None encoded_data = cast(bytes, self._encoder.encode(cast(str, data))) else: encoded_data = cast(bytes, data) return encoded_data class _FileReader(_UnicodeReader[AnyStr]): """Forward data from a file""" def __init__(self, process: 'SSHProcess[AnyStr]', file: IO[bytes], bufsize: int, datatype: DataType, encoding: Optional[str], errors: str): super().__init__(encoding, errors, hasattr(file, 'encoding')) self._process: 'SSHProcess[AnyStr]' = process self._file = file self._bufsize = bufsize self._datatype = datatype self._paused = False def feed(self) -> None: """Feed file data""" while not self._paused: data = self._file.read(self._bufsize) if data: self._process.feed_data(self.decode(data), self._datatype) else: self.check_partial() 
self._process.feed_eof(self._datatype) break def pause_reading(self) -> None: """Pause reading from the file""" self._paused = True def resume_reading(self) -> None: """Resume reading from the file""" self._paused = False self.feed() def close(self) -> None: """Stop forwarding data from the file""" self._file.close() class _AsyncFileReader(_UnicodeReader[AnyStr]): """Forward data from an aiofile""" def __init__(self, process: 'SSHProcess[AnyStr]', file: _AsyncFileProtocol[bytes], bufsize: int, datatype: DataType, encoding: Optional[str], errors: str): super().__init__(encoding, errors, hasattr(file, 'encoding')) self._conn = process.channel.get_connection() self._process: 'SSHProcess[AnyStr]' = process self._file = file self._bufsize = bufsize self._datatype = datatype self._paused = False async def _feed(self) -> None: """Feed file data""" while not self._paused: data = await self._file.read(self._bufsize) if data: self._process.feed_data(self.decode(data), self._datatype) else: self.check_partial() self._process.feed_eof(self._datatype) break def feed(self) -> None: """Start feeding file data""" self._conn.create_task(self._feed()) def pause_reading(self) -> None: """Pause reading from the file""" self._paused = True def resume_reading(self) -> None: """Resume reading from the file""" self._paused = False self.feed() def close(self) -> None: """Stop forwarding data from the file""" self._conn.create_task(self._file.close()) class _FileWriter(_UnicodeWriter[AnyStr]): """Forward data to a file""" def __init__(self, file: IO[bytes], needs_close: bool, encoding: Optional[str], errors: str): super().__init__(encoding, errors, hasattr(file, 'encoding')) self._file = file self._needs_close = needs_close def write(self, data: AnyStr) -> None: """Write data to the file""" self._file.write(self.encode(data)) def write_eof(self) -> None: """Close output file when end of file is received""" self.close() def close(self) -> None: """Stop forwarding data to the file""" if self._needs_close: self._file.close() class _AsyncFileWriter(_UnicodeWriter[AnyStr]): """Forward data to an aiofile""" def __init__(self, process: 'SSHProcess[AnyStr]', file: _AsyncFileProtocol[bytes], needs_close: bool, datatype: Optional[int], encoding: Optional[str], errors: str): super().__init__(encoding, errors, hasattr(file, 'encoding')) self._process: 'SSHProcess[AnyStr]' = process self._file = file self._needs_close = needs_close self._datatype = datatype self._paused = False self._queue: asyncio.Queue[Optional[AnyStr]] = asyncio.Queue() self._write_task: Optional[asyncio.Task[None]] = \ process.channel.get_connection().create_task(self._writer()) async def _writer(self) -> None: """Process writes to the file""" while True: data = await self._queue.get() if data is None: self._queue.task_done() break await self._file.write(self.encode(data)) self._queue.task_done() if self._paused and self._queue.qsize() < _QUEUE_LOW_WATER: self._process.resume_feeding(self._datatype) self._paused = False if self._needs_close: await self._file.close() def write(self, data: AnyStr) -> None: """Write data to the file""" self._queue.put_nowait(data) if not self._paused and self._queue.qsize() >= _QUEUE_HIGH_WATER: self._paused = True self._process.pause_feeding(self._datatype) def write_eof(self) -> None: """Close output file when end of file is received""" self.close() def close(self) -> None: """Stop forwarding data to the file""" if self._write_task: self._write_task = None self._queue.put_nowait(None) 
self._process.add_cleanup_task(self._queue.join()) class _PipeReader(_UnicodeReader[AnyStr], asyncio.BaseProtocol): """Forward data from a pipe""" def __init__(self, process: 'SSHProcess[AnyStr]', datatype: DataType, encoding: Optional[str], errors: str): super().__init__(encoding, errors) self._process: 'SSHProcess[AnyStr]' = process self._datatype = datatype self._transport: Optional[asyncio.ReadTransport] = None def connection_made(self, transport: asyncio.BaseTransport) -> None: """Handle a newly opened pipe""" self._transport = cast(asyncio.ReadTransport, transport) def connection_lost(self, exc: Optional[Exception]) -> None: """Handle closing of the pipe""" self._process.feed_close(self._datatype) self.close() def data_received(self, data: bytes) -> None: """Forward data from the pipe""" self._process.feed_data(self.decode(data), self._datatype) def eof_received(self) -> None: """Forward EOF from the pipe""" self.check_partial() self._process.feed_eof(self._datatype) def pause_reading(self) -> None: """Pause reading from the pipe""" assert self._transport is not None self._transport.pause_reading() def resume_reading(self) -> None: """Resume reading from the pipe""" assert self._transport is not None self._transport.resume_reading() def close(self) -> None: """Stop forwarding data from the pipe""" assert self._transport is not None self._transport.close() class _PipeWriter(_UnicodeWriter[AnyStr], asyncio.BaseProtocol): """Forward data to a pipe""" def __init__(self, process: 'SSHProcess[AnyStr]', datatype: DataType, encoding: Optional[str], errors: str): super().__init__(encoding, errors) self._process: 'SSHProcess[AnyStr]' = process self._datatype = datatype self._transport: Optional[asyncio.WriteTransport] = None self._tty: Optional[IO] = None self._close_event = asyncio.Event() def connection_made(self, transport: asyncio.BaseTransport) -> None: """Handle a newly opened pipe""" self._transport = cast(asyncio.WriteTransport, transport) pipe = transport.get_extra_info('pipe') if isinstance(self._process, SSHServerProcess) and pipe.isatty(): self._tty = pipe set_terminal_size(pipe, *self._process.term_size) def connection_lost(self, exc: Optional[Exception]) -> None: """Handle closing of the pipe""" self._close_event.set() def pause_writing(self) -> None: """Pause writing to the pipe""" self._process.pause_feeding(self._datatype) def resume_writing(self) -> None: """Resume writing to the pipe""" self._process.resume_feeding(self._datatype) def write(self, data: AnyStr) -> None: """Write data to the pipe""" assert self._transport is not None self._transport.write(self.encode(data)) def write_exception(self, exc: Exception) -> None: """Write terminal size changes to the pipe if it is a TTY""" if isinstance(exc, TerminalSizeChanged) and self._tty: set_terminal_size(self._tty, *exc.term_size) def write_eof(self) -> None: """Write EOF to the pipe""" assert self._transport is not None self._transport.write_eof() def close(self) -> None: """Stop forwarding data to the pipe""" assert self._transport is not None self._transport.close() self._process.add_cleanup_task(self._close_event.wait()) class _ProcessReader(_ReaderProtocol, Generic[AnyStr]): """Forward data from another SSH process""" def __init__(self, process: 'SSHProcess[AnyStr]', datatype: DataType): super().__init__() self._process: 'SSHProcess[AnyStr]' = process self._datatype = datatype def pause_reading(self) -> None: """Pause reading from the other channel""" self._process.pause_feeding(self._datatype) def resume_reading(self) 
-> None: """Resume reading from the other channel""" self._process.resume_feeding(self._datatype) def close(self) -> None: """Stop forwarding data from the other channel""" self._process.clear_writer(self._datatype) class _ProcessWriter(_WriterProtocol[AnyStr]): """Forward data to another SSH process""" def __init__(self, process: 'SSHProcess[AnyStr]', datatype: DataType): super().__init__() self._process: 'SSHProcess[AnyStr]' = process self._datatype = datatype def write(self, data: AnyStr) -> None: """Write data to the other channel""" self._process.feed_data(data, self._datatype) def write_exception(self, exc: Exception) -> None: """Write an exception to the other channel""" cast(SSHClientProcess, self._process).feed_exception(exc) def write_eof(self) -> None: """Write EOF to the other channel""" self._process.feed_eof(self._datatype) def close(self) -> None: """Stop forwarding data to the other channel""" self._process.clear_reader(self._datatype) class _StreamReader(_UnicodeReader[AnyStr]): """Forward data from an asyncio stream""" def __init__(self, process: 'SSHProcess[AnyStr]', reader: asyncio.StreamReader, bufsize: int, datatype: DataType, encoding: Optional[str], errors: str): super().__init__(encoding, errors) self._process: 'SSHProcess[AnyStr]' = process self._conn = process.channel.get_connection() self._reader = reader self._bufsize = bufsize self._datatype = datatype self._paused = False async def _feed(self) -> None: """Feed stream data""" while not self._paused: data = await self._reader.read(self._bufsize) if data: self._process.feed_data(self.decode(data), self._datatype) else: self.check_partial() self._process.feed_eof(self._datatype) break def feed(self) -> None: """Start feeding stream data""" self._conn.create_task(self._feed()) def pause_reading(self) -> None: """Pause reading from the stream""" self._paused = True def resume_reading(self) -> None: """Resume reading from the stream""" self._paused = False self.feed() def close(self) -> None: """Ignore close -- the caller must clean up the associated transport""" class _StreamWriter(_UnicodeWriter[AnyStr]): """Forward data to an asyncio stream""" def __init__(self, process: 'SSHProcess[AnyStr]', writer: asyncio.StreamWriter, recv_eof: bool, datatype: Optional[int], encoding: Optional[str], errors: str): super().__init__(encoding, errors) self._process: 'SSHProcess[AnyStr]' = process self._writer = writer self._recv_eof = recv_eof self._datatype = datatype self._paused = False self._queue: asyncio.Queue[Optional[AnyStr]] = asyncio.Queue() self._write_task: Optional[asyncio.Task[None]] = \ process.channel.get_connection().create_task(self._feed()) async def _feed(self) -> None: """Feed data to the stream""" while True: data = await self._queue.get() if data is None: self._queue.task_done() break self._writer.write(self.encode(data)) await self._writer.drain() self._queue.task_done() if self._paused and self._queue.qsize() < _QUEUE_LOW_WATER: self._process.resume_feeding(self._datatype) self._paused = False if self._recv_eof: self._writer.write_eof() def write(self, data: AnyStr) -> None: """Write data to the stream""" self._queue.put_nowait(data) if not self._paused and self._queue.qsize() >= _QUEUE_HIGH_WATER: self._paused = True self._process.pause_feeding(self._datatype) def write_eof(self) -> None: """Write EOF to the stream""" self.close() def close(self) -> None: """Stop forwarding data to the stream""" if self._write_task: self._write_task = None self._queue.put_nowait(None) 
self._process.add_cleanup_task(self._queue.join()) class _DevNullWriter(_WriterProtocol[AnyStr]): """Discard data""" def write(self, data: AnyStr) -> None: """Discard data being written""" def write_eof(self) -> None: """Ignore end of file""" def close(self) -> None: """Ignore close""" class _StdoutWriter(_WriterProtocol[AnyStr]): """Forward data to an SSH process' stdout instead of stderr""" def __init__(self, process: 'SSHProcess[AnyStr]'): super().__init__() self._process: 'SSHProcess[AnyStr]' = process def write(self, data: AnyStr) -> None: """Pretend data was received on stdout""" self._process.data_received(data, None) def write_eof(self) -> None: """Ignore end of file""" def close(self) -> None: """Ignore close""" class ProcessError(Error): """SSH Process error This exception is raised when an :class:`SSHClientProcess` exits with a non-zero exit status and error checking is enabled. In addition to the usual error code, reason, and language, it contains the following fields: ============ ======================================= ================= Field Description Type ============ ======================================= ================= env The environment the client requested `str` or `None` to be set for the process command The command the client requested the `str` or `None` process to execute (if any) subsystem The subsystem the client requested the `str` or `None` process to open (if any) exit_status The exit status returned, or -1 if an `int` or `None` exit signal is sent exit_signal The exit signal sent (if any) in the `tuple` or `None` form of a tuple containing the signal name, a `bool` for whether a core dump occurred, a message associated with the signal, and the language the message was in returncode The exit status returned, or negative `int` or `None` of the signal number when an exit signal is sent stdout The output sent by the process to `str` or `bytes` stdout (if not redirected) stderr The output sent by the process to `str` or `bytes` stderr (if not redirected) ============ ======================================= ================= """ def __init__(self, env: Optional[Mapping[str, str]], command: Optional[str], subsystem: Optional[str], exit_status: Optional[int], exit_signal: Optional[Tuple[str, bool, str, str]], returncode: Optional[int], stdout: BytesOrStr, stderr: BytesOrStr, reason: str = '', lang: str = DEFAULT_LANG): self.env = env self.command = command self.subsystem = subsystem self.exit_status = exit_status self.exit_signal = exit_signal self.returncode = returncode self.stdout = stdout self.stderr = stderr if exit_signal: signal, core_dumped, msg, lang = exit_signal reason = 'Process exited with signal ' + signal + \ (': ' + msg if msg else '') + \ (' (core dumped)' if core_dumped else '') elif exit_status: reason = f'Process exited with non-zero exit status {exit_status}' super().__init__(exit_status or 0, reason, lang) # pylint: disable=redefined-builtin class TimeoutError(ProcessError, asyncio.TimeoutError): """SSH Process timeout error This exception is raised when a timeout occurs when calling the :meth:`wait ` method on :class:`SSHClientProcess` or the :meth:`run ` method on :class:`SSHClientConnection`. It is a subclass of :class:`ProcessError` and contains all of the fields documented there, including any output received on stdout and stderr prior to when the timeout occurred. It is also a subclass of :class:`asyncio.TimeoutError`, for code that might be expecting that. 
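As an illustrative sketch only (it assumes `asyncssh` has been imported, `conn` is an already-open :class:`SSHClientConnection`, and the command shown is arbitrary), output gathered before the timeout can be recovered from the exception::

    try:
        result = await conn.run('some-long-command', timeout=10)
    except asyncssh.TimeoutError as exc:
        # Output received before the timeout expired is still available
        print('Timed out; partial stdout:', exc.stdout)
        print('Partial stderr:', exc.stderr)
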
""" # pylint: enable=redefined-builtin class SSHCompletedProcess(Record): """Results from running an SSH process This object is returned by the :meth:`run ` method on :class:`SSHClientConnection` when the requested command has finished running. It contains the following fields: ============ ======================================= ================= Field Description Type ============ ======================================= ================= env The environment the client requested `dict` or `None` to be set for the process command The command the client requested the `str` or `None` process to execute (if any) subsystem The subsystem the client requested the `str` or `None` process to open (if any) exit_status The exit status returned, or -1 if an `int` exit signal is sent exit_signal The exit signal sent (if any) in the `tuple` or `None` form of a tuple containing the signal name, a `bool` for whether a core dump occurred, a message associated with the signal, and the language the message was in returncode The exit status returned, or negative `int` of the signal number when an exit signal is sent stdout The output sent by the process to `str` or `bytes` stdout (if not redirected) stderr The output sent by the process to `str` or `bytes` stderr (if not redirected) ============ ======================================= ================= """ env: Optional[Mapping[str, str]] command: Optional[str] subsystem: Optional[str] exit_status: Optional[int] exit_signal: Optional[Tuple[str, bool, str, str]] returncode: Optional[int] stdout: Optional[BytesOrStr] stderr: Optional[BytesOrStr] class SSHProcess(SSHStreamSession, Generic[AnyStr]): """SSH process handler""" def __init__(self, *args) -> None: super().__init__(*args) self._cleanup_tasks: List[Awaitable[None]] = [] self._readers: Dict[Optional[int], _ReaderProtocol] = {} self._send_eof: Dict[Optional[int], bool] = {} self._writers: Dict[Optional[int], _WriterProtocol[AnyStr]] = {} self._recv_eof: Dict[Optional[int], bool] = {} self._paused_write_streams: Set[Optional[int]] = set() async def __aenter__(self) -> Self: """Allow SSHProcess to be used as an async context manager""" return self async def __aexit__(self, _exc_type: Optional[Type[BaseException]], _exc_value: Optional[BaseException], _traceback: Optional[TracebackType]) -> bool: """Wait for a full channel close when exiting the async context""" self.close() await self.wait_closed() return False @property def channel(self) -> SSHChannel[AnyStr]: """The channel associated with the process""" assert self._chan is not None return self._chan @property def logger(self) -> SSHLogger: """The logger associated with the process""" assert self._chan is not None return self._chan.logger @property def command(self) -> Optional[str]: """The command the client requested to execute, if any If the client did not request that a command be executed, this property will be set to `None`. """ assert self._chan is not None return self._chan.get_command() @property def subsystem(self) -> Optional[str]: """The subsystem the client requested to open, if any If the client did not request that a subsystem be opened, this property will be set to `None`. 
""" assert self._chan is not None return self._chan.get_subsystem() @property def env(self) -> Mapping[str, str]: """A mapping containing the environment set by the client""" assert self._chan is not None return self._chan.get_environment() def get_extra_info(self, name: str, default: Any = None) -> Any: """Return additional information about this process This method returns extra information about the channel associated with this process. See :meth:`get_extra_info() ` on :class:`SSHClientChannel` for additional information. """ assert self._chan is not None return self._chan.get_extra_info(name, default) async def _create_reader(self, source: ProcessSource, bufsize: int, send_eof: bool, recv_eof: bool, datatype: DataType = None) -> None: """Create a reader to forward data to the SSH channel""" def pipe_factory() -> _PipeReader: """Return a pipe read handler""" return _PipeReader(self, datatype, self._encoding, self._errors) if source == PIPE: reader: Optional[_ReaderProtocol] = None elif source == DEVNULL: assert self._chan is not None self._chan.write_eof() reader = None elif isinstance(source, SSHReader): reader_stream, reader_datatype = source.get_redirect_info() reader_process = cast('SSHProcess[AnyStr]', reader_stream) writer = _ProcessWriter[AnyStr](self, datatype) reader_process.set_writer(writer, recv_eof, reader_datatype) reader = _ProcessReader(reader_process, reader_datatype) elif isinstance(source, asyncio.StreamReader): reader = _StreamReader(self, source, bufsize, datatype, self._encoding, self._errors) else: file: _File if isinstance(source, str): file = open_file(source, 'rb', buffering=bufsize) elif isinstance(source, PurePath): file = open_file(str(source), 'rb', buffering=bufsize) elif isinstance(source, int): file = os.fdopen(source, 'rb', buffering=bufsize) elif isinstance(source, socket.socket): file = os.fdopen(source.detach(), 'rb', buffering=bufsize) else: file = source if hasattr(file, 'read') and \ (inspect.iscoroutinefunction(file.read) or inspect.isgeneratorfunction(file.read)): reader = _AsyncFileReader(self, cast(_AsyncFileProtocol, file), bufsize, datatype, self._encoding, self._errors) elif _is_regular_file(cast(IO[bytes], file)): reader = _FileReader(self, cast(IO[bytes], file), bufsize, datatype, self._encoding, self._errors) else: if hasattr(source, 'buffer'): # If file was opened in text mode, remove that wrapper file = cast(TextIO, source).buffer assert self._loop is not None _, protocol = \ await self._loop.connect_read_pipe(pipe_factory, file) reader = cast(_PipeReader, protocol) self.set_reader(reader, send_eof, datatype) if isinstance(reader, (_FileReader, _AsyncFileReader, _StreamReader)): reader.feed() elif isinstance(reader, _ProcessReader): reader_process.feed_recv_buf(reader_datatype, writer) async def _create_writer(self, target: ProcessTarget, bufsize: int, send_eof: bool, recv_eof: bool, datatype: DataType = None) -> None: """Create a writer to forward data from the SSH channel""" def pipe_factory() -> _PipeWriter: """Return a pipe write handler""" return _PipeWriter(self, datatype, self._encoding, self._errors) if target == PIPE: writer: Optional[_WriterProtocol[AnyStr]] = None elif target == DEVNULL: writer = _DevNullWriter() elif target == STDOUT: writer = _StdoutWriter(self) elif isinstance(target, SSHWriter): writer_stream, writer_datatype = target.get_redirect_info() writer_process = cast('SSHProcess[AnyStr]', writer_stream) reader = _ProcessReader(self, datatype) writer_process.set_reader(reader, send_eof, writer_datatype) writer = 
_ProcessWriter[AnyStr](writer_process, writer_datatype) elif isinstance(target, asyncio.StreamWriter): writer = _StreamWriter(self, target, recv_eof, datatype, self._encoding, self._errors) else: file: _File needs_close = True if isinstance(target, str): file = open_file(target, 'wb', buffering=bufsize) elif isinstance(target, PurePath): file = open_file(str(target), 'wb', buffering=bufsize) elif isinstance(target, int): file = os.fdopen(target, 'wb', buffering=bufsize, closefd=recv_eof) elif isinstance(target, socket.socket): fd = target.detach() if recv_eof else target.fileno() file = os.fdopen(fd, 'wb', buffering=bufsize, closefd=recv_eof) else: file = target needs_close = recv_eof if hasattr(file, 'write') and \ (inspect.iscoroutinefunction(file.write) or inspect.isgeneratorfunction(file.write)): writer = _AsyncFileWriter( self, cast(_AsyncFileProtocol, file), needs_close, datatype, self._encoding, self._errors) elif _is_regular_file(cast(IO[bytes], file)): writer = _FileWriter(cast(IO[bytes], file), needs_close, self._encoding, self._errors) else: if hasattr(target, 'buffer'): # If file was opened in text mode, remove that wrapper file = cast(TextIO, target).buffer if not recv_eof: fd = os.dup(cast(IO[bytes], file).fileno()) file = os.fdopen(fd, 'wb', buffering=0) assert self._loop is not None _, protocol = \ await self._loop.connect_write_pipe(pipe_factory, file) writer = cast(_PipeWriter, protocol) self.set_writer(writer, recv_eof, datatype) if writer: self.feed_recv_buf(datatype, writer) def _should_block_drain(self, datatype: DataType) -> bool: """Return whether output is still being written to the channel""" return (datatype in self._readers or super()._should_block_drain(datatype)) def _should_pause_reading(self) -> bool: """Return whether to pause reading from the channel""" return bool(self._paused_write_streams) or \ super()._should_pause_reading() def add_cleanup_task(self, task: Awaitable) -> None: """Add a task to run when the process exits""" self._cleanup_tasks.append(task) def connection_lost(self, exc: Optional[Exception]) -> None: """Handle a close of the SSH channel""" super().connection_lost(exc) # type: ignore for reader in list(self._readers.values()): reader.close() for writer in list(self._writers.values()): writer.close() self._readers = {} self._writers = {} def data_received(self, data: AnyStr, datatype: DataType) -> None: """Handle incoming data from the SSH channel""" writer = self._writers.get(datatype) if writer: writer.write(data) else: super().data_received(data, datatype) def eof_received(self) -> bool: """Handle an incoming end of file from the SSH channel""" for datatype, writer in list(self._writers.items()): if self._recv_eof[datatype]: writer.write_eof() return super().eof_received() def pause_writing(self) -> None: """Pause forwarding data to the channel""" super().pause_writing() for reader in list(self._readers.values()): reader.pause_reading() def resume_writing(self) -> None: """Resume forwarding data to the channel""" super().resume_writing() for reader in list(self._readers.values()): reader.resume_reading() def feed_data(self, data: AnyStr, datatype: DataType) -> None: """Feed data to the channel""" assert self._chan is not None self._chan.write(data, datatype) def feed_eof(self, datatype: DataType) -> None: """Feed EOF to the channel""" if self._send_eof[datatype]: assert self._chan is not None self._chan.write_eof() self._readers[datatype].close() self.clear_reader(datatype) def feed_close(self, datatype: DataType) -> None: """Feed pipe 
close to the channel""" if datatype in self._readers: self.feed_eof(datatype) def feed_recv_buf(self, datatype: DataType, writer: _WriterProtocol[AnyStr]) -> None: """Feed current receive buffer to a newly set writer""" for buf in self._recv_buf[datatype]: if isinstance(buf, Exception): writer.write_exception(buf) else: writer.write(buf) self._recv_buf_len -= len(buf) self._recv_buf[datatype].clear() if self._eof_received: writer.write_eof() self._maybe_resume_reading() def pause_feeding(self, datatype: DataType) -> None: """Pause feeding data from the channel""" self._paused_write_streams.add(datatype) self._maybe_pause_reading() def resume_feeding(self, datatype: DataType) -> None: """Resume feeding data from the channel""" self._paused_write_streams.remove(datatype) self._maybe_resume_reading() def set_reader(self, reader: Optional[_ReaderProtocol], send_eof: bool, datatype: DataType) -> None: """Set a reader used to forward data to the channel""" old_reader = self._readers.get(datatype) if old_reader: old_reader.close() if reader: self._readers[datatype] = reader self._send_eof[datatype] = send_eof if self._write_paused: reader.pause_reading() elif old_reader: self.clear_reader(datatype) def clear_reader(self, datatype: DataType) -> None: """Clear a reader forwarding data to the channel""" del self._readers[datatype] del self._send_eof[datatype] self._unblock_drain(datatype) def set_writer(self, writer: Optional[_WriterProtocol[AnyStr]], recv_eof: bool, datatype: DataType) -> None: """Set a writer used to forward data from the channel""" old_writer = self._writers.get(datatype) if old_writer: old_writer.close() self.clear_writer(datatype) if writer: self._writers[datatype] = writer self._recv_eof[datatype] = recv_eof def clear_writer(self, datatype: DataType) -> None: """Clear a writer forwarding data from the channel""" if datatype in self._paused_write_streams: self.resume_feeding(datatype) del self._writers[datatype] def close(self) -> None: """Shut down the process""" assert self._chan is not None self._chan.close() def is_closing(self) -> bool: """Return if the channel is closing or is closed""" assert self._chan is not None return self._chan.is_closing() async def wait_closed(self) -> None: """Wait for the process to finish shutting down""" assert self._chan is not None await self._chan.wait_closed() for task in self._cleanup_tasks: await task self._cleanup_tasks = [] class SSHClientProcess(SSHProcess[AnyStr], SSHClientStreamSession[AnyStr]): """SSH client process handler""" _chan: SSHClientChannel[AnyStr] channel: SSHClientChannel[AnyStr] def __init__(self) -> None: super().__init__() self._stdin: Optional[SSHWriter[AnyStr]] = None self._stdout: Optional[SSHReader[AnyStr]] = None self._stderr: Optional[SSHReader[AnyStr]] = None def _collect_output(self, datatype: DataType = None) -> AnyStr: """Return output from the process""" recv_buf = self._recv_buf[datatype] if recv_buf and isinstance(recv_buf[-1], Exception): recv_buf, self._recv_buf[datatype] = recv_buf[:-1], recv_buf[-1:] else: self._recv_buf[datatype] = [] buf = cast(AnyStr, '' if self._encoding else b'') return buf.join(cast(Iterable[AnyStr], recv_buf)) def session_started(self) -> None: """Start a process for this newly opened client channel""" self._stdin = SSHWriter[AnyStr](self, self._chan) self._stdout = SSHReader[AnyStr](self, self._chan) self._stderr = SSHReader[AnyStr](self, self._chan, EXTENDED_DATA_STDERR) @property def exit_status(self) -> Optional[int]: """The exit status of the process""" return 
self._chan.get_exit_status() @property def exit_signal(self) -> Optional[Tuple[str, bool, str, str]]: """Exit signal information for the process""" return self._chan.get_exit_signal() @property def returncode(self) -> Optional[int]: """The exit status or negative exit signal number for the process""" return self._chan.get_returncode() @property def stdin(self) -> SSHWriter[AnyStr]: """The :class:`SSHWriter` to use to write to stdin of the process""" assert self._stdin is not None return self._stdin @property def stdout(self) -> SSHReader[AnyStr]: """The :class:`SSHReader` to use to read from stdout of the process""" assert self._stdout is not None return self._stdout @property def stderr(self) -> SSHReader[AnyStr]: """The :class:`SSHReader` to use to read from stderr of the process""" assert self._stderr is not None return self._stderr def feed_exception(self, exc: Exception) -> None: """Feed exception to the channel""" if isinstance(exc, TerminalSizeChanged): self._chan.change_terminal_size(exc.width, exc.height, exc.pixwidth, exc.pixheight) elif isinstance(exc, BreakReceived): self._chan.send_break(exc.msec) elif isinstance(exc, SignalReceived): # pragma: no branch self._chan.send_signal(exc.signal) async def redirect(self, stdin: Optional[ProcessSource] = None, stdout: Optional[ProcessTarget] = None, stderr: Optional[ProcessTarget] = None, bufsize: int =io.DEFAULT_BUFFER_SIZE, send_eof: bool = True, recv_eof: bool = True) -> None: """Perform I/O redirection for the process This method redirects data going to or from any or all of standard input, standard output, and standard error for the process. The `stdin` argument can be any of the following: * An :class:`SSHReader` object * An :class:`asyncio.StreamReader` object * A file object open for read * An `int` file descriptor open for read * A connected socket object * A string or :class:`PurePath ` containing the name of a file or device to open * `DEVNULL` to provide no input to standard input * `PIPE` to interactively write standard input The `stdout` and `stderr` arguments can be any of the following: * An :class:`SSHWriter` object * An :class:`asyncio.StreamWriter` object * A file object open for write * An `int` file descriptor open for write * A connected socket object * A string or :class:`PurePath ` containing the name of a file or device to open * `DEVNULL` to discard standard error output * `PIPE` to interactively read standard error output The `stderr` argument also accepts the value `STDOUT` to request that standard error output be delivered to stdout. File objects passed in can be associated with plain files, pipes, sockets, or ttys. The default value of `None` means to not change redirection for that stream. .. note:: While it is legal to use buffered I/O streams such as sys.stdin, sys.stdout, and sys.stderr as redirect targets, you must make sure buffers are flushed before redirection begins and that these streams are put back into blocking mode before attempting to go back using buffered I/O again. Also, no buffered I/O should be performed while redirection is active. .. note:: When passing in asyncio streams, it is the responsibility of the caller to close the associated transport when it is no longer needed. 
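A minimal usage sketch (the command and file name are placeholders, and `conn` is assumed to be an already-open :class:`SSHClientConnection`)::

    process = await conn.create_process('ls -l')
    # Forward anything the remote command writes to stdout into a local file
    await process.redirect_stdout('listing.txt')
    await process.wait()
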
:param stdin: Source of data to feed to standard input :param stdout: Target to feed data from standard output to :param stderr: Target to feed data from standard error to :param bufsize: Buffer size to use when forwarding data from a file :param send_eof: Whether or not to send EOF to the channel when EOF is received from stdin, defaulting to `True`. If set to `False`, the channel will remain open after EOF is received on stdin, and multiple sources can be redirected to the channel. :param recv_eof: Whether or not to send EOF to stdout and stderr when EOF is received from the channel, defaulting to `True`. If set to `False`, the redirect targets of stdout and stderr will remain open after EOF is received on the channel and can be used for multiple redirects. :type bufsize: `int` :type send_eof: `bool` :type recv_eof: `bool` """ if stdin: await self._create_reader(stdin, bufsize, send_eof, recv_eof) if stdout: await self._create_writer(stdout, bufsize, send_eof, recv_eof) if stderr: await self._create_writer(stderr, bufsize, send_eof, recv_eof, EXTENDED_DATA_STDERR) async def redirect_stdin(self, source: ProcessSource, bufsize: int = io.DEFAULT_BUFFER_SIZE, send_eof: bool = True) -> None: """Redirect standard input of the process""" await self.redirect(source, None, None, bufsize, send_eof, True) async def redirect_stdout(self, target: ProcessTarget, bufsize: int = io.DEFAULT_BUFFER_SIZE, recv_eof: bool = True) -> None: """Redirect standard output of the process""" await self.redirect(None, target, None, bufsize, True, recv_eof) async def redirect_stderr(self, target: ProcessTarget, bufsize: int = io.DEFAULT_BUFFER_SIZE, recv_eof: bool = True) -> None: """Redirect standard error of the process""" await self.redirect(None, None, target, bufsize, True, recv_eof) def collect_output(self) -> Tuple[AnyStr, AnyStr]: """Collect output from the process without blocking This method returns a tuple of the output that the process has written to stdout and stderr which has not yet been read. It is intended to be called instead of read() by callers that want to collect received data without blocking. :returns: A tuple of output to stdout and stderr """ return (self._collect_output(), self._collect_output(EXTENDED_DATA_STDERR)) # pylint: disable=redefined-builtin async def communicate(self, input: Optional[AnyStr] = None) -> \ Tuple[AnyStr, AnyStr]: """Send input to and/or collect output from the process This method is a coroutine which optionally provides input to the process and then waits for the process to exit, returning a tuple of the data written to stdout and stderr. :param input: Input data to feed to standard input of the process. Data should be a `str` if encoding is set, or `bytes` if not. :type input: `str` or `bytes` :returns: A tuple of output to stdout and stderr """ self._limit = 0 self._maybe_resume_reading() if input: self._chan.write(input) self._chan.write_eof() await self.wait_closed() return self.collect_output() # pylint: enable=redefined-builtin def change_terminal_size(self, width: int, height: int, pixwidth: int = 0, pixheight: int = 0) -> None: """Change the terminal window size for this process This method changes the width and height of the terminal associated with this process. 
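For example, assuming the session was created with a pseudo-terminal allocated, a resize to 120 columns by 40 rows might look like::

    process.change_terminal_size(120, 40)
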
:param width: The width of the terminal in characters :param height: The height of the terminal in characters :param pixwidth: (optional) The width of the terminal in pixels :param pixheight: (optional) The height of the terminal in pixels :type width: `int` :type height: `int` :type pixwidth: `int` :type pixheight: `int` :raises: :exc:`OSError` if the SSH channel is not open """ self._chan.change_terminal_size(width, height, pixwidth, pixheight) def send_break(self, msec: int) -> None: """Send a break to the process :param msec: The duration of the break in milliseconds :type msec: `int` :raises: :exc:`OSError` if the SSH channel is not open """ self._chan.send_break(msec) def send_signal(self, signal: str) -> None: """Send a signal to the process :param signal: The signal to deliver :type signal: `str` :raises: :exc:`OSError` if the SSH channel is not open """ self._chan.send_signal(signal) def terminate(self) -> None: """Terminate the process :raises: :exc:`OSError` if the SSH channel is not open """ self._chan.terminate() def kill(self) -> None: """Forcibly kill the process :raises: :exc:`OSError` if the SSH channel is not open """ self._chan.kill() async def wait(self, check: bool = False, timeout: Optional[float] = None) -> SSHCompletedProcess: """Wait for process to exit This method is a coroutine which waits for the process to exit. It returns an :class:`SSHCompletedProcess` object with the exit status or signal information and the output sent to stdout and stderr if those are redirected to pipes. If the check argument is set to `True`, a non-zero exit status from the process with trigger the :exc:`ProcessError` exception to be raised. If a timeout is specified and it expires before the process exits, the :exc:`TimeoutError` exception will be raised. By default, no timeout is set and this call will wait indefinitely. 
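A brief sketch of typical use (the command shown is a placeholder and `conn` is assumed to be an already-open :class:`SSHClientConnection`)::

    process = await conn.create_process('make build')

    try:
        result = await process.wait(check=True, timeout=60)
        print(result.stdout)
    except asyncssh.ProcessError as exc:
        print('Command failed with exit status', exc.exit_status)
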
:param check: Whether or not to raise an error on non-zero exit status :param timeout: Amount of time in seconds to wait for process to exit, or `None` to wait indefinitely :type check: `bool` :type timeout: `int`, `float`, or `None` :returns: :class:`SSHCompletedProcess` :raises: | :exc:`ProcessError` if check is set to `True` and the process returns a non-zero exit status | :exc:`TimeoutError` if the timeout expires before the process exits """ try: stdout_data, stderr_data = \ await asyncio.wait_for(self.communicate(), timeout) except asyncio.TimeoutError: stdout_data, stderr_data = self.collect_output() raise TimeoutError(self.env, self.command, self.subsystem, self.exit_status, self.exit_signal, self.returncode, stdout_data, stderr_data) from None if check and self.exit_status: raise ProcessError(self.env, self.command, self.subsystem, self.exit_status, self.exit_signal, self.returncode, stdout_data, stderr_data) else: return SSHCompletedProcess(self.env, self.command, self.subsystem, self.exit_status, self.exit_signal, self.returncode, stdout_data, stderr_data) class SSHServerProcess(SSHProcess[AnyStr], SSHServerStreamSession[AnyStr]): """SSH server process handler""" _chan: SSHServerChannel[AnyStr] channel: SSHServerChannel[AnyStr] def __init__(self, process_factory: SSHServerProcessFactory, sftp_factory: Optional[SFTPServerFactory], sftp_version: int, allow_scp: bool): super().__init__(self._start_process, sftp_factory, sftp_version, allow_scp) self._process_factory = process_factory self._stdin: Optional[SSHReader[AnyStr]] = None self._stdout: Optional[SSHWriter[AnyStr]] = None self._stderr: Optional[SSHWriter[AnyStr]] = None def _start_process(self, stdin: SSHReader[AnyStr], stdout: SSHWriter[AnyStr], stderr: SSHWriter[AnyStr]) -> MaybeAwait[None]: """Start a new server process""" self._stdin = stdin self._stdout = stdout self._stderr = stderr return self._process_factory(self) @property def term_type(self) -> Optional[str]: """The terminal type set by the client If the client didn't request a pseudo-terminal, this property will be set to `None`. """ return self._chan.get_terminal_type() @property def term_size(self) -> TermSize: """The terminal size set by the client This property contains a tuple of four `int` values representing the width and height of the terminal in characters followed by the width and height of the terminal in pixels. If the client hasn't set terminal size information, the values will be set to zero. """ return self._chan.get_terminal_size() @property def term_modes(self) -> TermModes: """A mapping containing the TTY modes set by the client If the client didn't request a pseudo-terminal, this property will be set to an empty mapping. 
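As an illustrative sketch, an individual mode such as the output speed can be looked up in this mapping (this assumes the opcode constant is imported from :mod:`asyncssh.constants`)::

    from asyncssh.constants import PTY_OP_OSPEED

    ospeed = process.term_modes.get(PTY_OP_OSPEED)
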
""" return self._chan.get_terminal_modes() @property def stdin(self) -> SSHReader[AnyStr]: """The :class:`SSHReader` to use to read from stdin of the process""" assert self._stdin is not None return self._stdin @property def stdout(self) -> SSHWriter[AnyStr]: """The :class:`SSHWriter` to use to write to stdout of the process""" assert self._stdout is not None return self._stdout @property def stderr(self) -> SSHWriter[AnyStr]: """The :class:`SSHWriter` to use to write to stderr of the process""" assert self._stderr is not None return self._stderr def exception_received(self, exc: Exception) -> None: """Handle an incoming exception on the channel""" writer = self._writers.get(None) if writer: writer.write_exception(exc) else: super().exception_received(exc) async def redirect(self, stdin: Optional[ProcessTarget] = None, stdout: Optional[ProcessSource] = None, stderr: Optional[ProcessSource] = None, bufsize: int = io.DEFAULT_BUFFER_SIZE, send_eof: bool = True, recv_eof: bool = True) -> None: """Perform I/O redirection for the process This method redirects data going to or from any or all of standard input, standard output, and standard error for the process. The `stdin` argument can be any of the following: * An :class:`SSHWriter` object * An :class:`asyncio.StreamWriter` object * A file object open for write * An `int` file descriptor open for write * A connected socket object * A string or :class:`PurePath ` containing the name of a file or device to open * `DEVNULL` to discard standard error output * `PIPE` to interactively read standard error output The `stdout` and `stderr` arguments can be any of the following: * An :class:`SSHReader` object * An :class:`asyncio.StreamReader` object * A file object open for read * An `int` file descriptor open for read * A connected socket object * A string or :class:`PurePath ` containing the name of a file or device to open * `DEVNULL` to provide no input to standard input * `PIPE` to interactively write standard input File objects passed in can be associated with plain files, pipes, sockets, or ttys. The default value of `None` means to not change redirection for that stream. .. note:: When passing in asyncio streams, it is the responsibility of the caller to close the associated transport when it is no longer needed. :param stdin: Target to feed data from standard input to :param stdout: Source of data to feed to standard output :param stderr: Source of data to feed to standard error :param bufsize: Buffer size to use when forwarding data from a file :param send_eof: Whether or not to send EOF to the channel when EOF is received from stdout or stderr, defaulting to `True`. If set to `False`, the channel will remain open after EOF is received on stdout or stderr, and multiple sources can be redirected to the channel. :param recv_eof: Whether or not to send EOF to stdin when EOF is received on the channel, defaulting to `True`. If set to `False`, the redirect target of stdin will remain open after EOF is received on the channel and can be used for multiple redirects. 
:type bufsize: `int` :type send_eof: `bool` :type recv_eof: `bool` """ if stdin: await self._create_writer(stdin, bufsize, send_eof, recv_eof) if stdout: await self._create_reader(stdout, bufsize, send_eof, recv_eof) if stderr: await self._create_reader(stderr, bufsize, send_eof, recv_eof, EXTENDED_DATA_STDERR) async def redirect_stdin(self, target: ProcessTarget, bufsize: int = io.DEFAULT_BUFFER_SIZE, recv_eof: bool = True) -> None: """Redirect standard input of the process""" await self.redirect(target, None, None, bufsize, True, recv_eof) async def redirect_stdout(self, source: ProcessSource, bufsize: int = io.DEFAULT_BUFFER_SIZE, send_eof: bool = True) -> None: """Redirect standard output of the process""" await self.redirect(None, source, None, bufsize, send_eof, True) async def redirect_stderr(self, source: ProcessSource, bufsize: int = io.DEFAULT_BUFFER_SIZE, send_eof: bool = True) -> None: """Redirect standard error of the process""" await self.redirect(None, None, source, bufsize, send_eof, True) def get_terminal_type(self) -> Optional[str]: """Return the terminal type set by the client for the process This method returns the terminal type set by the client when the process was started. If the client didn't request a pseudo-terminal, this method will return `None`. :returns: A `str` containing the terminal type or `None` if no pseudo-terminal was requested """ return self.term_type def get_terminal_size(self) -> Tuple[int, int, int, int]: """Return the terminal size set by the client for the process This method returns the latest terminal size information set by the client. If the client didn't set any terminal size information, all values returned will be zero. :returns: A tuple of four `int` values containing the width and height of the terminal in characters and the width and height of the terminal in pixels """ return self.term_size def get_terminal_mode(self, mode: int) -> Optional[int]: """Return the requested TTY mode for this session This method looks up the value of a POSIX terminal mode set by the client when the process was started. If the client didn't request a pseudo-terminal or didn't set the requested TTY mode opcode, this method will return `None`. :param mode: POSIX terminal mode taken from :ref:`POSIX terminal modes ` to look up :type mode: `int` :returns: An `int` containing the value of the requested POSIX terminal mode or `None` if the requested mode was not set """ return self.term_modes.get(mode) def exit(self, status: int) -> None: """Send exit status and close the channel This method can be called to report an exit status for the process back to the client and close the channel. :param status: The exit status to report to the client :type status: `int` """ self._chan.exit(status) def exit_with_signal(self, signal: str, core_dumped: bool = False, msg: str = '', lang: str = DEFAULT_LANG) -> None: """Send exit signal and close the channel This method can be called to report that the process terminated abnormally with a signal. A more detailed error message may also be provided, along with an indication of whether or not the process dumped core. After reporting the signal, the channel is closed.
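A brief sketch of a server-side handler using this call (the handler and the `do_work` helper are hypothetical, and the signal name is only an example)::

    async def handle_client(process):
        try:
            await do_work(process)   # hypothetical application logic
            process.exit(0)
        except Exception:
            process.exit_with_signal('TERM', False, 'internal error')
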
:param signal: The signal which caused the process to exit :param core_dumped: (optional) Whether or not the process dumped core :param msg: (optional) Details about what error occurred :param lang: (optional) The language the error message is in :type signal: `str` :type core_dumped: `bool` :type msg: `str` :type lang: `str` """ return self._chan.exit_with_signal(signal, core_dumped, msg, lang) asyncssh-2.20.0/asyncssh/public_key.py000066400000000000000000004201061475467777400200200ustar00rootroot00000000000000# Copyright (c) 2013-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation """SSH asymmetric encryption handlers""" import asyncio import binascii import inspect import os import re import time from datetime import datetime from hashlib import md5, sha1, sha256, sha384, sha512 from pathlib import Path, PurePath from typing import Callable, Dict, List, Mapping, Optional, Sequence, Set from typing import Tuple, Type, Union, cast from typing_extensions import Protocol from .crypto import ed25519_available, ed448_available from .encryption import Encryption from .sk import sk_available try: # pylint: disable=unused-import from .crypto import X509Certificate from .crypto import generate_x509_certificate, import_x509_certificate _x509_available = True except ImportError: # pragma: no cover _x509_available = False try: import bcrypt _bcrypt_available = hasattr(bcrypt, 'kdf') except ImportError: # pragma: no cover _bcrypt_available = False from .asn1 import ASN1DecodeError, BitString, ObjectIdentifier from .asn1 import der_encode, der_decode, der_decode_partial from .crypto import CryptoKey, PyCAKey from .encryption import get_encryption_params, get_encryption from .misc import BytesOrStr, DefTuple, FilePath, IPNetwork from .misc import ip_network, read_file, write_file, parse_time_interval from .packet import NameList, String, UInt32, UInt64 from .packet import PacketDecodeError, SSHPacket from .pbe import KeyEncryptionError, pkcs1_encrypt, pkcs8_encrypt from .pbe import pkcs1_decrypt, pkcs8_decrypt from .sk import SSH_SK_USER_PRESENCE_REQD, sk_get_resident _Comment = Optional[BytesOrStr] _CertPrincipals = Union[str, Sequence[str]] _Time = Union[int, float, datetime, str] _PubKeyAlgMap = Dict[bytes, Type['SSHKey']] _CertAlgMap = Dict[bytes, Tuple[Optional[Type['SSHKey']], Type['SSHCertificate']]] _CertSigAlgMap = Dict[bytes, bytes] _CertVersionMap = Dict[Tuple[bytes, int], Tuple[bytes, Type['SSHOpenSSHCertificate']]] _PEMMap = Dict[bytes, Type['SSHKey']] _PKCS8OIDMap = Dict[ObjectIdentifier, Type['SSHKey']] _SKAlgMap = Dict[int, Tuple[Type['SSHKey'], Tuple[object, ...]]] _OpenSSHCertOptions = Dict[str, object] _OpenSSHCertParams = Tuple[object, int, int, bytes, bytes, int, int, bytes, bytes] _OpenSSHCertEncoders = Sequence[Tuple[str, Callable[[object], bytes]]] _OpenSSHCertDecoders = Dict[bytes, Callable[[SSHPacket], object]] X509CertPurposes = Union[None, str, Sequence[str]] 
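# The key, certificate, and key pair list aliases defined below describe the
# forms accepted by higher-level APIs which take keys or certificates. As an
# illustrative sketch only (the host and file names here are hypothetical),
# a client connection might mix these forms freely:
#
#     await asyncssh.connect('host', client_keys=[
#         'id_ed25519',                       # a private key file name
#         ('my_key', 'my_key-cert.pub')])     # a (key, certificate) pair
#
# Each entry may also be an SSHKey or SSHKeyPair object already in memory.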
_IdentityArg = Union[bytes, FilePath, 'SSHKey', 'SSHCertificate'] IdentityListArg = Union[_IdentityArg, Sequence[_IdentityArg]] _KeyArg = Union[bytes, FilePath, 'SSHKey'] KeyListArg = Union[FilePath, Sequence[_KeyArg]] _CertArg = Union[bytes, FilePath, 'SSHCertificate'] CertListArg = Union[_CertArg, Sequence[_CertArg]] _KeyPairArg = Union['SSHKeyPair', _KeyArg, Tuple[_KeyArg, _CertArg]] KeyPairListArg = Union[_KeyPairArg, Sequence[_KeyPairArg]] # Default file names in .ssh directory to read private keys from _DEFAULT_KEY_FILES = ( ('id_ed25519_sk', ed25519_available and sk_available), ('id_ecdsa_sk', sk_available), ('id_ed448', ed448_available), ('id_ed25519', ed25519_available), ('id_ecdsa', True), ('id_rsa', True), ('id_dsa', True) ) # Default directories and file names to read host keys from _DEFAULT_HOST_KEY_DIRS = ('/opt/local/etc', '/opt/local/etc/ssh', '/usr/local/etc', '/usr/local/etc/ssh', '/etc', '/etc/ssh') _DEFAULT_HOST_KEY_FILES = ('ssh_host_ed448_key', 'ssh_host_ed25519_key', 'ssh_host_ecdsa_key', 'ssh_host_rsa_key', 'ssh_host_dsa_key') _hashes = {'md5': md5, 'sha1': sha1, 'sha256': sha256, 'sha384': sha384, 'sha512': sha512} _public_key_algs: List[bytes] = [] _default_public_key_algs: List[bytes] = [] _certificate_algs: List[bytes] = [] _default_certificate_algs: List[bytes] = [] _x509_certificate_algs: List[bytes] = [] _default_x509_certificate_algs: List[bytes] = [] _public_key_alg_map: _PubKeyAlgMap = {} _certificate_alg_map: _CertAlgMap = {} _certificate_sig_alg_map: _CertSigAlgMap = {} _certificate_version_map: _CertVersionMap = {} _pem_map: _PEMMap = {} _pkcs8_oid_map: _PKCS8OIDMap = {} _sk_alg_map: _SKAlgMap = {} _abs_date_pattern = re.compile(r'\d{8}') _abs_time_pattern = re.compile(r'\d{14}') _subject_pattern = re.compile(r'(?:Distinguished[ -_]?Name|Subject|DN)[=:]?\s?', re.IGNORECASE) # SSH certificate types CERT_TYPE_USER = 1 CERT_TYPE_HOST = 2 # Flag to omit second argument in alg_params OMIT = object() _OPENSSH_KEY_V1 = b'openssh-key-v1\0' _OPENSSH_SALT_LEN = 16 _OPENSSH_WRAP_LEN = 70 def _parse_time(t: _Time) -> int: """Parse a time value""" if isinstance(t, int): return t elif isinstance(t, float): return int(t) elif isinstance(t, datetime): return int(t.timestamp()) elif isinstance(t, str): if t == 'now': return int(time.time()) match = _abs_date_pattern.fullmatch(t) if match: return int(datetime.strptime(t, '%Y%m%d').timestamp()) match = _abs_time_pattern.fullmatch(t) if match: return int(datetime.strptime(t, '%Y%m%d%H%M%S').timestamp()) try: return int(time.time() + parse_time_interval(t)) except ValueError: pass raise ValueError('Unrecognized time value') def _wrap_base64(data: bytes, wrap: int = 64) -> bytes: """Break a Base64 value into multiple lines.""" data = binascii.b2a_base64(data)[:-1] return b'\n'.join(data[i:i+wrap] for i in range(0, len(data), wrap)) + b'\n' class KeyGenerationError(ValueError): """Key generation error This exception is raised by :func:`generate_private_key`, :meth:`generate_user_certificate() ` or :meth:`generate_host_certificate() ` when the requested parameters are unsupported. """ class KeyImportError(ValueError): """Key import error This exception is raised by key import functions when the data provided cannot be imported as a valid key. """ class KeyExportError(ValueError): """Key export error This exception is raised by key export functions when the requested format is unknown or encryption is requested for a format which doesn't support it. 
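As an illustrative sketch (`key` is assumed to be an :class:`SSHKey`, and the format name is deliberately invalid), this error can be caught as follows::

    try:
        data = key.export_private_key('no-such-format')
    except asyncssh.KeyExportError as exc:
        print('Export failed:', exc)
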
""" class SigningKey(Protocol): """Protocol for signing a block of data""" def sign(self, data: bytes) -> bytes: """Sign a block of data with a private key""" class VerifyingKey(Protocol): """Protocol for verifying a signature on a block of data""" def verify(self, data: bytes, sig: bytes) -> bool: """Verify a signature on a block of data with a public key""" class SSHKey: """Parent class which holds an asymmetric encryption key""" algorithm: bytes = b'' sig_algorithms: Sequence[bytes] = () cert_algorithms: Sequence[bytes] = () x509_algorithms: Sequence[bytes] = () all_sig_algorithms: Set[bytes] = set() default_x509_hash: str = '' pem_name: bytes = b'' pkcs8_oid: Optional[ObjectIdentifier] = None use_executor: bool = False use_webauthn: bool = False def __init__(self, key: Optional[CryptoKey] = None): self._key = key self._comment: Optional[bytes] = None self._filename: Optional[bytes] = None self._touch_required = False @classmethod def generate(cls, algorithm: bytes, **kwargs) -> 'SSHKey': """Generate a new SSH private key""" raise NotImplementedError @classmethod def make_private(cls, key_params: object) -> 'SSHKey': """Construct a private key""" raise NotImplementedError @classmethod def make_public(cls, key_params: object) -> 'SSHKey': """Construct a public key""" raise NotImplementedError @classmethod def decode_pkcs1_private(cls, key_data: object) -> object: """Decode a PKCS#1 format private key""" @classmethod def decode_pkcs1_public(cls, key_data: object) -> object: """Decode a PKCS#1 format public key""" @classmethod def decode_pkcs8_private(cls, alg_params: object, data: bytes) -> object: """Decode a PKCS#8 format private key""" @classmethod def decode_pkcs8_public(cls, alg_params: object, data: bytes) -> object: """Decode a PKCS#8 format public key""" @classmethod def decode_ssh_private(cls, packet: SSHPacket) -> object: """Decode an SSH format private key""" @classmethod def decode_ssh_public(cls, packet: SSHPacket) -> object: """Decode an SSH format public key""" @property def private_data(self) -> bytes: """Return private key data in OpenSSH binary format""" return String(self.algorithm) + self.encode_ssh_private() @property def public_data(self) -> bytes: """Return public key data in OpenSSH binary format""" return String(self.algorithm) + self.encode_ssh_public() @property def pyca_key(self) -> PyCAKey: """Return PyCA key for use in X.509 module""" assert self._key is not None return self._key.pyca_key def _generate_certificate(self, key: 'SSHKey', version: int, serial: int, cert_type: int, key_id: str, principals: _CertPrincipals, valid_after: _Time, valid_before: _Time, cert_options: _OpenSSHCertOptions, sig_alg_name: DefTuple[str], comment: DefTuple[_Comment]) -> \ 'SSHOpenSSHCertificate': """Generate a new SSH certificate""" if isinstance(principals, str): principals = [p.strip() for p in principals.split(',')] else: principals = list(principals) valid_after = _parse_time(valid_after) valid_before = _parse_time(valid_before) if valid_before <= valid_after: raise ValueError('Valid before time must be later than ' 'valid after time') if sig_alg_name == (): sig_alg = self.sig_algorithms[0] else: sig_alg = cast(str, sig_alg_name).encode() if comment == (): comment = key.get_comment_bytes() comment: _Comment try: algorithm, cert_handler = _certificate_version_map[key.algorithm, version] except KeyError: raise KeyGenerationError('Unknown certificate version') from None return cert_handler.generate(self, algorithm, key, serial, cert_type, key_id, principals, valid_after, 
valid_before, cert_options, sig_alg, comment) def _generate_x509_certificate(self, key: 'SSHKey', subject: str, issuer: Optional[str], serial: Optional[int], valid_after: _Time, valid_before: _Time, ca: bool, ca_path_len: Optional[int], purposes: X509CertPurposes, user_principals: _CertPrincipals, host_principals: _CertPrincipals, hash_name: DefTuple[str], comment: DefTuple[_Comment]) -> \ 'SSHX509Certificate': """Generate a new X.509 certificate""" if not _x509_available: # pragma: no cover raise KeyGenerationError('X.509 certificate generation ' 'requires PyOpenSSL') if not self.x509_algorithms: raise KeyGenerationError('X.509 certificate generation not ' 'supported for ' + self.get_algorithm() + ' keys') valid_after = _parse_time(valid_after) valid_before = _parse_time(valid_before) if valid_before <= valid_after: raise ValueError('Valid before time must be later than ' 'valid after time') if hash_name == (): hash_name = key.default_x509_hash if comment == (): comment = key.get_comment_bytes() hash_name: str comment: _Comment return SSHX509Certificate.generate(self, key, subject, issuer, serial, valid_after, valid_before, ca, ca_path_len, purposes, user_principals, host_principals, hash_name, comment) def get_algorithm(self) -> str: """Return the algorithm associated with this key""" return self.algorithm.decode('ascii') def has_comment(self) -> bool: """Return whether a comment is set for this key :returns: `bool` """ return bool(self._comment) def get_comment_bytes(self) -> Optional[bytes]: """Return the comment associated with this key as a byte string :returns: `bytes` or `None` """ return self._comment or self._filename def get_comment(self, encoding: str = 'utf-8', errors: str = 'strict') -> Optional[str]: """Return the comment associated with this key as a Unicode string :param encoding: The encoding to use to decode the comment as a Unicode string, defaulting to UTF-8 :param errors: The error handling scheme to use for Unicode decode errors :type encoding: `str` :type errors: `str` :returns: `str` or `None` :raises: :exc:`UnicodeDecodeError` if the comment cannot be decoded using the specified encoding """ comment = self.get_comment_bytes() return comment.decode(encoding, errors) if comment else None def set_comment(self, comment: _Comment, encoding: str = 'utf-8', errors: str = 'strict') -> None: """Set the comment associated with this key :param comment: The new comment to associate with this key :param encoding: The Unicode encoding to use to encode the comment, defaulting to UTF-8 :param errors: The error handling scheme to use for Unicode encode errors :type comment: `str`, `bytes`, or `None` :type encoding: `str` :type errors: `str` :raises: :exc:`UnicodeEncodeError` if the comment cannot be encoded using the specified encoding """ if isinstance(comment, str): comment = comment.encode(encoding, errors) self._comment = comment or None def get_filename(self) -> Optional[bytes]: """Return the filename associated with this key :returns: `bytes` or `None` """ return self._filename def set_filename(self, filename: Union[None, bytes, FilePath]) -> None: """Set the filename associated with this key :param filename: The new filename to associate with this key :type filename: `PurePath`, `str`, `bytes`, or `None` """ if isinstance(filename, PurePath): filename = str(filename) if isinstance(filename, str): filename = filename.encode('utf-8') self._filename = filename or None def get_fingerprint(self, hash_name: str = 'sha256') -> str: """Get the fingerprint of this key Available 
hashes include: md5, sha1, sha256, sha384, sha512 :param hash_name: (optional) The hash algorithm to use to construct the fingerprint. :type hash_name: `str` :returns: `str` :raises: :exc:`ValueError` if the hash name is invalid """ try: hash_alg = _hashes[hash_name] except KeyError: raise ValueError('Unknown hash algorithm') from None h = hash_alg(self.public_data) if hash_name == 'md5': fp = h.hexdigest() fp_text = ':'.join(fp[i:i+2] for i in range(0, len(fp), 2)) else: fpb = h.digest() fp_text = binascii.b2a_base64(fpb).decode('ascii')[:-1].strip('=') return hash_name.upper() + ':' + fp_text def set_touch_required(self, touch_required: bool) -> None: """Set whether touch is required when using a security key""" self._touch_required = touch_required def sign_raw(self, data: bytes, hash_name: str) -> bytes: """Return a raw signature of the specified data""" assert self._key is not None return self._key.sign(data, hash_name) def sign_ssh(self, data: bytes, sig_algorithm: bytes) -> bytes: """Abstract method to compute an SSH-encoded signature""" raise NotImplementedError def verify_ssh(self, data: bytes, sig_algorithm: bytes, packet: SSHPacket) -> bool: """Abstract method to verify an SSH-encoded signature""" raise NotImplementedError def sign(self, data: bytes, sig_algorithm: bytes) -> bytes: """Return an SSH-encoded signature of the specified data""" if sig_algorithm.startswith(b'x509v3-'): sig_algorithm = sig_algorithm[7:] if sig_algorithm not in self.all_sig_algorithms: raise ValueError('Unrecognized signature algorithm') return b''.join((String(sig_algorithm), self.sign_ssh(data, sig_algorithm))) def verify(self, data: bytes, sig: bytes) -> bool: """Verify an SSH signature of the specified data using this key""" try: packet = SSHPacket(sig) sig_algorithm = packet.get_string() if sig_algorithm not in self.all_sig_algorithms: return False return self.verify_ssh(data, sig_algorithm, packet) except PacketDecodeError: return False def encode_pkcs1_private(self) -> object: """Export parameters associated with a PKCS#1 private key""" # pylint: disable=no-self-use raise KeyExportError('PKCS#1 private key export not supported') def encode_pkcs1_public(self) -> object: """Export parameters associated with a PKCS#1 public key""" # pylint: disable=no-self-use raise KeyExportError('PKCS#1 public key export not supported') def encode_pkcs8_private(self) -> Tuple[object, object]: """Export parameters associated with a PKCS#8 private key""" # pylint: disable=no-self-use raise KeyExportError('PKCS#8 private key export not supported') def encode_pkcs8_public(self) -> Tuple[object, object]: """Export parameters associated with a PKCS#8 public key""" # pylint: disable=no-self-use raise KeyExportError('PKCS#8 public key export not supported') def encode_ssh_private(self) -> bytes: """Export parameters associated with an OpenSSH private key""" # pylint: disable=no-self-use raise KeyExportError('OpenSSH private key export not supported') def encode_ssh_public(self) -> bytes: """Export parameters associated with an OpenSSH public key""" # pylint: disable=no-self-use raise KeyExportError('OpenSSH public key export not supported') def encode_agent_cert_private(self) -> bytes: """Encode certificate private key data for agent""" raise NotImplementedError def convert_to_public(self) -> 'SSHKey': """Return public key corresponding to this key This method converts an :class:`SSHKey` object which contains a private key into one which contains only the corresponding public key. 
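        For example, a signature produced with the private key can be
        checked against the public copy returned here. A minimal sketch
        (assuming the module-level generate_private_key() helper defined
        later in this module; the key and data are illustrative)::

            key = generate_private_key('ssh-ed25519')   # illustrative key
            pubkey = key.convert_to_public()

            data = b'example data'
            sig = key.sign(data, b'ssh-ed25519')

            assert pubkey.verify(data, sig)
            print(pubkey.get_fingerprint())             # 'SHA256:...'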
If it is called on something which is already a public key, it has no effect. """ result = decode_ssh_public_key(self.public_data) result.set_comment(self._comment) result.set_filename(self._filename) return result def generate_user_certificate( self, user_key: 'SSHKey', key_id: str, version: int = 1, serial: int = 0, principals: _CertPrincipals = (), valid_after: _Time = 0, valid_before: _Time = 0xffffffffffffffff, force_command: Optional[str] = None, source_address: Optional[Sequence[str]] = None, permit_x11_forwarding: bool = True, permit_agent_forwarding: bool = True, permit_port_forwarding: bool = True, permit_pty: bool = True, permit_user_rc: bool = True, touch_required: bool = True, sig_alg: DefTuple[str] = (), comment: DefTuple[_Comment] = ()) -> 'SSHOpenSSHCertificate': """Generate a new SSH user certificate This method returns an SSH user certificate with the requested attributes signed by this private key. :param user_key: The user's public key. :param key_id: The key identifier associated with this certificate. :param version: (optional) The version of certificate to create, defaulting to 1. :param serial: (optional) The serial number of the certificate, defaulting to 0. :param principals: (optional) The user names this certificate is valid for. By default, it can be used with any user name. :param valid_after: (optional) The earliest time the certificate is valid for, defaulting to no restriction on when the certificate starts being valid. See :ref:`SpecifyingTimeValues` for allowed time specifications. :param valid_before: (optional) The latest time the certificate is valid for, defaulting to no restriction on when the certificate stops being valid. See :ref:`SpecifyingTimeValues` for allowed time specifications. :param force_command: (optional) The command (if any) to force a session to run when this certificate is used. :param source_address: (optional) A list of source addresses and networks for which the certificate is valid, defaulting to all addresses. :param permit_x11_forwarding: (optional) Whether or not to allow this user to use X11 forwarding, defaulting to `True`. :param permit_agent_forwarding: (optional) Whether or not to allow this user to use agent forwarding, defaulting to `True`. :param permit_port_forwarding: (optional) Whether or not to allow this user to use port forwarding, defaulting to `True`. :param permit_pty: (optional) Whether or not to allow this user to allocate a pseudo-terminal, defaulting to `True`. :param permit_user_rc: (optional) Whether or not to run the user rc file when this certificate is used, defaulting to `True`. :param touch_required: (optional) Whether or not to require the user to touch the security key when authenticating with it, defaulting to `True`. :param sig_alg: (optional) The algorithm to use when signing the new certificate. :param comment: The comment to associate with this certificate. By default, the comment will be set to the comment currently set on user_key. 
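           A usage sketch (the CA and user keys below are illustrative,
           created with the module-level generate_private_key() helper)::

               ca_key = generate_private_key('ssh-ed25519')
               user_key = generate_private_key('ssh-ed25519')

               cert = ca_key.generate_user_certificate(
                   user_key, 'user@example.com', principals=['alice'],
                   valid_after='now', valid_before='+52w',
                   force_command='/usr/bin/uptime')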
:type user_key: :class:`SSHKey` :type key_id: `str` :type version: `int` :type serial: `int` :type principals: `str` or `list` of `str` :type force_command: `str` or `None` :type source_address: list of ip_address and ip_network values :type permit_x11_forwarding: `bool` :type permit_agent_forwarding: `bool` :type permit_port_forwarding: `bool` :type permit_pty: `bool` :type permit_user_rc: `bool` :type touch_required: `bool` :type sig_alg: `str` :type comment: `str`, `bytes`, or `None` :returns: :class:`SSHCertificate` :raises: | :exc:`ValueError` if the validity times are invalid | :exc:`KeyGenerationError` if the requested certificate parameters are unsupported """ cert_options: _OpenSSHCertOptions = {} if force_command: cert_options['force-command'] = force_command if source_address: cert_options['source-address'] = [ip_network(addr) for addr in source_address] if permit_x11_forwarding: cert_options['permit-X11-forwarding'] = True if permit_agent_forwarding: cert_options['permit-agent-forwarding'] = True if permit_port_forwarding: cert_options['permit-port-forwarding'] = True if permit_pty: cert_options['permit-pty'] = True if permit_user_rc: cert_options['permit-user-rc'] = True if not touch_required: cert_options['no-touch-required'] = True return self._generate_certificate(user_key, version, serial, CERT_TYPE_USER, key_id, principals, valid_after, valid_before, cert_options, sig_alg, comment) def generate_host_certificate(self, host_key: 'SSHKey', key_id: str, version: int = 1, serial: int = 0, principals: _CertPrincipals = (), valid_after: _Time = 0, valid_before: _Time = 0xffffffffffffffff, sig_alg: DefTuple[str] = (), comment: DefTuple[_Comment] = ()) -> \ 'SSHOpenSSHCertificate': """Generate a new SSH host certificate This method returns an SSH host certificate with the requested attributes signed by this private key. :param host_key: The host's public key. :param key_id: The key identifier associated with this certificate. :param version: (optional) The version of certificate to create, defaulting to 1. :param serial: (optional) The serial number of the certificate, defaulting to 0. :param principals: (optional) The host names this certificate is valid for. By default, it can be used with any host name. :param valid_after: (optional) The earliest time the certificate is valid for, defaulting to no restriction on when the certificate starts being valid. See :ref:`SpecifyingTimeValues` for allowed time specifications. :param valid_before: (optional) The latest time the certificate is valid for, defaulting to no restriction on when the certificate stops being valid. See :ref:`SpecifyingTimeValues` for allowed time specifications. :param sig_alg: (optional) The algorithm to use when signing the new certificate. :param comment: The comment to associate with this certificate. By default, the comment will be set to the comment currently set on host_key. 
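           A usage sketch (key names are illustrative, mirroring the user
           certificate example above)::

               host_key = generate_private_key('ssh-ed25519')

               cert = ca_key.generate_host_certificate(
                   host_key, 'host.example.com',
                   principals=['host.example.com', '10.0.1.2'],
                   valid_before='+90d')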
:type host_key: :class:`SSHKey` :type key_id: `str` :type version: `int` :type serial: `int` :type principals: `str` or `list` of `str` :type sig_alg: `str` :type comment: `str`, `bytes`, or `None` :returns: :class:`SSHCertificate` :raises: | :exc:`ValueError` if the validity times are invalid | :exc:`KeyGenerationError` if the requested certificate parameters are unsupported """ if comment == (): comment = host_key.get_comment_bytes() return self._generate_certificate(host_key, version, serial, CERT_TYPE_HOST, key_id, principals, valid_after, valid_before, {}, sig_alg, comment) def generate_x509_user_certificate( self, user_key: 'SSHKey', subject: str, issuer: Optional[str] = None, serial: Optional[int] = None, principals: _CertPrincipals = (), valid_after: _Time = 0, valid_before: _Time = 0xffffffffffffffff, purposes: X509CertPurposes = 'secureShellClient', hash_alg: DefTuple[str] = (), comment: DefTuple[_Comment] = ()) -> 'SSHX509Certificate': """Generate a new X.509 user certificate This method returns an X.509 user certificate with the requested attributes signed by this private key. :param user_key: The user's public key. :param subject: The subject name in the certificate, expressed as a comma-separated list of X.509 `name=value` pairs. :param issuer: (optional) The issuer name in the certificate, expressed as a comma-separated list of X.509 `name=value` pairs. If not specified, the subject name will be used, creating a self-signed certificate. :param serial: (optional) The serial number of the certificate, defaulting to a random 64-bit value. :param principals: (optional) The user names this certificate is valid for. By default, it can be used with any user name. :param valid_after: (optional) The earliest time the certificate is valid for, defaulting to no restriction on when the certificate starts being valid. See :ref:`SpecifyingTimeValues` for allowed time specifications. :param valid_before: (optional) The latest time the certificate is valid for, defaulting to no restriction on when the certificate stops being valid. See :ref:`SpecifyingTimeValues` for allowed time specifications. :param purposes: (optional) The allowed purposes for this certificate or `None` to not restrict the certificate's purpose, defaulting to 'secureShellClient' :param hash_alg: (optional) The hash algorithm to use when signing the new certificate, defaulting to SHA256. :param comment: (optional) The comment to associate with this certificate. By default, the comment will be set to the comment currently set on user_key.
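           A usage sketch (names are illustrative; X.509 certificate
           generation additionally requires PyOpenSSL to be installed)::

               cert = ca_key.generate_x509_user_certificate(
                   user_key, 'OU=users,CN=alice', 'OU=ca,CN=Example CA',
                   principals=['alice'], valid_before='+52w')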
:type user_key: :class:`SSHKey` :type subject: `str` :type issuer: `str` :type serial: `int` :type principals: `str` or `list` of `str` :type purposes: `str`, `list` of `str`, or `None` :type hash_alg: `str` :type comment: `str`, `bytes`, or `None` :returns: :class:`SSHCertificate` :raises: | :exc:`ValueError` if the validity times are invalid | :exc:`KeyGenerationError` if the requested certificate parameters are unsupported """ return self._generate_x509_certificate(user_key, subject, issuer, serial, valid_after, valid_before, False, None, purposes, principals, (), hash_alg, comment) def generate_x509_host_certificate( self, host_key: 'SSHKey', subject: str, issuer: Optional[str] = None, serial: Optional[int] = None, principals: _CertPrincipals = (), valid_after: _Time = 0, valid_before: _Time = 0xffffffffffffffff, purposes: X509CertPurposes = 'secureShellServer', hash_alg: DefTuple[str] = (), comment: DefTuple[_Comment] = ()) -> 'SSHX509Certificate': """Generate a new X.509 host certificate This method returns an X.509 host certificate with the requested attributes signed by this private key. :param host_key: The host's public key. :param subject: The subject name in the certificate, expressed as a comma-separated list of X.509 `name=value` pairs. :param issuer: (optional) The issuer name in the certificate, expressed as a comma-separated list of X.509 `name=value` pairs. If not specified, the subject name will be used, creating a self-signed certificate. :param serial: (optional) The serial number of the certificate, defaulting to a random 64-bit value. :param principals: (optional) The host names this certificate is valid for. By default, it can be used with any host name. :param valid_after: (optional) The earliest time the certificate is valid for, defaulting to no restriction on when the certificate starts being valid. See :ref:`SpecifyingTimeValues` for allowed time specifications. :param valid_before: (optional) The latest time the certificate is valid for, defaulting to no restriction on when the certificate stops being valid. See :ref:`SpecifyingTimeValues` for allowed time specifications. :param purposes: (optional) The allowed purposes for this certificate or `None` to not restrict the certificate's purpose, defaulting to 'secureShellServer' :param hash_alg: (optional) The hash algorithm to use when signing the new certificate, defaulting to SHA256. :param comment: (optional) The comment to associate with this certificate. By default, the comment will be set to the comment currently set on host_key.
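           A usage sketch (names are illustrative, mirroring the X.509 user
           certificate example above)::

               cert = ca_key.generate_x509_host_certificate(
                   host_key, 'CN=host.example.com', 'OU=ca,CN=Example CA',
                   principals=['host.example.com'], valid_before='+90d')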
:type host_key: :class:`SSHKey` :type subject: `str` :type issuer: `str` :type serial: `int` :type principals: `str` or `list` of `str` :type purposes: `str`, `list` of `str`, or `None` :type hash_alg: `str` :type comment: `str`, `bytes`, or `None` :returns: :class:`SSHCertificate` :raises: | :exc:`ValueError` if the validity times are invalid | :exc:`KeyGenerationError` if the requested certificate parameters are unsupported """ return self._generate_x509_certificate(host_key, subject, issuer, serial, valid_after, valid_before, False, None, purposes, (), principals, hash_alg, comment) def generate_x509_ca_certificate(self, ca_key: 'SSHKey', subject: str, issuer: Optional[str] = None, serial: Optional[int] = None, valid_after: _Time = 0, valid_before: _Time = 0xffffffffffffffff, ca_path_len: Optional[int] = None, hash_alg: DefTuple[str] = (), comment: DefTuple[_Comment] = ()) -> \ 'SSHX509Certificate': """Generate a new X.509 CA certificate This method returns an X.509 CA certificate with the requested attributes signed by this private key. :param ca_key: The new CA's public key. :param subject: The subject name in the certificate, expressed as a comma-separated list of X.509 `name=value` pairs. :param issuer: (optional) The issuer name in the certificate, expressed as a comma-separated list of X.509 `name=value` pairs. If not specified, the subject name will be used, creating a self-signed certificate. :param serial: (optional) The serial number of the certificate, defaulting to a random 64-bit value. :param valid_after: (optional) The earliest time the certificate is valid for, defaulting to no restriction on when the certificate starts being valid. See :ref:`SpecifyingTimeValues` for allowed time specifications. :param valid_before: (optional) The latest time the certificate is valid for, defaulting to no restriction on when the certificate stops being valid. See :ref:`SpecifyingTimeValues` for allowed time specifications. :param ca_path_len: (optional) The maximum number of levels of intermediate CAs allowed below this new CA or `None` to not enforce a limit, defaulting to no limit. :param hash_alg: (optional) The hash algorithm to use when signing the new certificate, defaulting to SHA256. :param comment: (optional) The comment to associate with this certificate. By default, the comment will be set to the comment currently set on ca_key. :type ca_key: :class:`SSHKey` :type subject: `str` :type issuer: `str` :type serial: `int` :type ca_path_len: `int` or `None` :type hash_alg: `str` :type comment: `str`, `bytes`, or `None` :returns: :class:`SSHCertificate` :raises: | :exc:`ValueError` if the validity times are invalid | :exc:`KeyGenerationError` if the requested certificate parameters are unsupported """ return self._generate_x509_certificate(ca_key, subject, issuer, serial, valid_after, valid_before, True, ca_path_len, None, (), (), hash_alg, comment) def export_private_key(self, format_name: str = 'openssh', passphrase: Optional[BytesOrStr] = None, cipher_name: str = 'aes256-cbc', hash_name: str = 'sha256', pbe_version: int = 2, rounds: int = 128, ignore_few_rounds: bool = False) -> bytes: """Export a private key in the requested format This method returns this object's private key encoded in the requested format. If a passphrase is specified, the key will be exported in encrypted form. Available formats include: pkcs1-der, pkcs1-pem, pkcs8-der, pkcs8-pem, openssh By default, openssh format will be used. Encryption is supported in pkcs1-pem, pkcs8-der, pkcs8-pem, and openssh formats.
For pkcs1-pem, only the cipher can be specified. For pkcs8-der and pkcs8-pem, cipher, hash and PBE version can be specified. For openssh, cipher and rounds can be specified. Available ciphers for pkcs1-pem are: aes128-cbc, aes192-cbc, aes256-cbc, des-cbc, des3-cbc Available ciphers for pkcs8-der and pkcs8-pem are: aes128-cbc, aes192-cbc, aes256-cbc, blowfish-cbc, cast128-cbc, des-cbc, des2-cbc, des3-cbc, rc4-40, rc4-128 Available ciphers for openssh format include the following :ref:`encryption algorithms `. Available hashes include: md5, sha1, sha256, sha384, sha512 Available PBE versions include 1 for PBES1 and 2 for PBES2. Not all combinations of cipher, hash, and version are supported. The default cipher is aes256. In the pkcs8 formats, the default hash is sha256 and default version is PBES2. In openssh format, the default number of rounds is 128. .. note:: The openssh format uses bcrypt for encryption, but unlike the traditional bcrypt cost factor used in password hashing which scales logarithmically, the encryption strength here scales linearly with the rounds value. Since the cipher is rekeyed 64 times per round, the default rounds value of 128 corresponds to 8192 total iterations, which is the equivalent of a bcrypt cost factor of 13. :param format_name: (optional) The format to export the key in. :param passphrase: (optional) A passphrase to encrypt the private key with. :param cipher_name: (optional) The cipher to use for private key encryption. :param hash_name: (optional) The hash to use for private key encryption. :param pbe_version: (optional) The PBE version to use for private key encryption. :param rounds: (optional) The number of KDF rounds to apply to the passphrase. :type format_name: `str` :type passphrase: `str` or `bytes` :type cipher_name: `str` :type hash_name: `str` :type pbe_version: `int` :type rounds: `int` :returns: `bytes` representing the exported private key """ if format_name in ('pkcs1-der', 'pkcs1-pem'): data = der_encode(self.encode_pkcs1_private()) if passphrase is not None: if format_name == 'pkcs1-der': raise KeyExportError('PKCS#1 DER format does not support ' 'private key encryption') alg, iv, data = pkcs1_encrypt(data, cipher_name, passphrase) headers = (b'Proc-Type: 4,ENCRYPTED\n' + b'DEK-Info: ' + alg + b',' + binascii.b2a_hex(iv).upper() + b'\n\n') else: headers = b'' if format_name == 'pkcs1-pem': keytype = self.pem_name + b' PRIVATE KEY' data = (b'-----BEGIN ' + keytype + b'-----\n' + headers + _wrap_base64(data) + b'-----END ' + keytype + b'-----\n') return data elif format_name in ('pkcs8-der', 'pkcs8-pem'): alg_params, pkcs8_data = self.encode_pkcs8_private() if alg_params is OMIT: data = der_encode((0, (self.pkcs8_oid,), pkcs8_data)) else: data = der_encode((0, (self.pkcs8_oid, alg_params), pkcs8_data)) if passphrase is not None: data = pkcs8_encrypt(data, cipher_name, hash_name, pbe_version, passphrase) if format_name == 'pkcs8-pem': if passphrase is not None: keytype = b'ENCRYPTED PRIVATE KEY' else: keytype = b'PRIVATE KEY' data = (b'-----BEGIN ' + keytype + b'-----\n' + _wrap_base64(data) + b'-----END ' + keytype + b'-----\n') return data elif format_name == 'openssh': check = os.urandom(4) nkeys = 1 data = b''.join((check, check, self.private_data, String(self._comment or b''))) cipher: Optional[Encryption] if passphrase is not None: try: alg = cipher_name.encode('ascii') key_size, iv_size, block_size, _, _, _ = \ get_encryption_params(alg) except (KeyError, UnicodeEncodeError): raise KeyEncryptionError('Unknown cipher: ' + cipher_name)
from None if not _bcrypt_available: # pragma: no cover raise KeyExportError('OpenSSH private key encryption ' 'requires bcrypt with KDF support') kdf = b'bcrypt' salt = os.urandom(_OPENSSH_SALT_LEN) kdf_data = b''.join((String(salt), UInt32(rounds))) if isinstance(passphrase, str): passphrase = passphrase.encode('utf-8') key = bcrypt.kdf(passphrase, salt, key_size + iv_size, rounds, ignore_few_rounds) cipher = get_encryption(alg, key[:key_size], key[key_size:]) block_size = max(block_size, 8) else: cipher = None alg = b'none' kdf = b'none' kdf_data = b'' block_size = 8 mac = b'' pad = len(data) % block_size if pad: # pragma: no branch data = data + bytes(range(1, block_size + 1 - pad)) if cipher: data, mac = cipher.encrypt_packet(0, b'', data) else: mac = b'' data = b''.join((_OPENSSH_KEY_V1, String(alg), String(kdf), String(kdf_data), UInt32(nkeys), String(self.public_data), String(data), mac)) return (b'-----BEGIN OPENSSH PRIVATE KEY-----\n' + _wrap_base64(data, _OPENSSH_WRAP_LEN) + b'-----END OPENSSH PRIVATE KEY-----\n') else: raise KeyExportError('Unknown export format') def export_public_key(self, format_name: str = 'openssh') -> bytes: """Export a public key in the requested format This method returns this object's public key encoded in the requested format. Available formats include: pkcs1-der, pkcs1-pem, pkcs8-der, pkcs8-pem, openssh, rfc4716 By default, openssh format will be used. :param format_name: (optional) The format to export the key in. :type format_name: `str` :returns: `bytes` representing the exported public key """ if format_name in ('pkcs1-der', 'pkcs1-pem'): data = der_encode(self.encode_pkcs1_public()) if format_name == 'pkcs1-pem': keytype = self.pem_name + b' PUBLIC KEY' data = (b'-----BEGIN ' + keytype + b'-----\n' + _wrap_base64(data) + b'-----END ' + keytype + b'-----\n') return data elif format_name in ('pkcs8-der', 'pkcs8-pem'): alg_params, pkcs8_data = self.encode_pkcs8_public() pkcs8_data = BitString(pkcs8_data) if alg_params is OMIT: data = der_encode(((self.pkcs8_oid,), pkcs8_data)) else: data = der_encode(((self.pkcs8_oid, alg_params), pkcs8_data)) if format_name == 'pkcs8-pem': data = (b'-----BEGIN PUBLIC KEY-----\n' + _wrap_base64(data) + b'-----END PUBLIC KEY-----\n') return data elif format_name == 'openssh': if self._comment: comment = b' ' + self._comment else: comment = b'' return (self.algorithm + b' ' + binascii.b2a_base64(self.public_data)[:-1] + comment + b'\n') elif format_name == 'rfc4716': if self._comment: comment = (b'Comment: "' + self._comment + b'"\n') else: comment = b'' return (b'---- BEGIN SSH2 PUBLIC KEY ----\n' + comment + _wrap_base64(self.public_data) + b'---- END SSH2 PUBLIC KEY ----\n') else: raise KeyExportError('Unknown export format') def write_private_key(self, filename: FilePath, *args, **kwargs) -> None: """Write a private key to a file in the requested format This method is a simple wrapper around :meth:`export_private_key` which writes the exported key data to a file. :param filename: The filename to write the private key to. :param \\*args,\\ \\*\\*kwargs: Additional arguments to pass through to :meth:`export_private_key`. :type filename: :class:`PurePath ` or `str` """ write_file(filename, self.export_private_key(*args, **kwargs)) def write_public_key(self, filename: FilePath, *args, **kwargs) -> None: """Write a public key to a file in the requested format This method is a simple wrapper around :meth:`export_public_key` which writes the exported key data to a file. 
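           A usage sketch (filenames and the passphrase are illustrative,
           using the module-level generate_private_key() helper)::

               key = generate_private_key('ssh-ed25519')

               # Passphrase-protected private key in the default openssh
               # format, written alongside its public key
               key.write_private_key('id_ed25519', passphrase='secret')
               key.write_public_key('id_ed25519.pub')

               # The same key exported as unencrypted PKCS#8 PEM bytes
               pem = key.export_private_key('pkcs8-pem')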
:param filename: The filename to write the public key to. :param \\*args,\\ \\*\\*kwargs: Additional arguments to pass through to :meth:`export_public_key`. :type filename: :class:`PurePath ` or `str` """ write_file(filename, self.export_public_key(*args, **kwargs)) def append_private_key(self, filename: FilePath, *args, **kwargs) -> None: """Append a private key to a file in the requested format This method is a simple wrapper around :meth:`export_private_key` which appends the exported key data to an existing file. :param filename: The filename to append the private key to. :param \\*args,\\ \\*\\*kwargs: Additional arguments to pass through to :meth:`export_private_key`. :type filename: :class:`PurePath ` or `str` """ write_file(filename, self.export_private_key(*args, **kwargs), 'ab') def append_public_key(self, filename: FilePath, *args, **kwargs) -> None: """Append a public key to a file in the requested format This method is a simple wrapper around :meth:`export_public_key` which appends the exported key data to an existing file. :param filename: The filename to append the public key to. :param \\*args,\\ \\*\\*kwargs: Additional arguments to pass through to :meth:`export_public_key`. :type filename: :class:`PurePath ` or `str` """ write_file(filename, self.export_public_key(*args, **kwargs), 'ab') class SSHCertificate: """Parent class which holds an SSH certificate""" is_x509 = False is_x509_chain = False def __init__(self, algorithm: bytes, sig_algorithms: Sequence[bytes], host_key_algorithms: Sequence[bytes], key: SSHKey, public_data: bytes, comment: _Comment): self.algorithm = algorithm self.sig_algorithms = sig_algorithms self.host_key_algorithms = host_key_algorithms self.key = key self.public_data = public_data self.set_comment(comment) @classmethod def construct(cls, packet: SSHPacket, algorithm: bytes, key_handler: Optional[Type[SSHKey]], comment: _Comment) -> 'SSHCertificate': """Construct an SSH certificate from packetized data""" raise NotImplementedError def __eq__(self, other: object) -> bool: return (isinstance(other, type(self)) and self.public_data == other.public_data) def __hash__(self) -> int: return hash(self.public_data) def get_algorithm(self) -> str: """Return the algorithm associated with this certificate""" return self.algorithm.decode('ascii') def has_comment(self) -> bool: """Return whether a comment is set for this certificate :returns: `bool` """ return bool(self._comment) def get_comment_bytes(self) -> Optional[bytes]: """Return the comment associated with this certificate as a byte string :returns: `bytes` or `None` """ return self._comment def get_comment(self, encoding: str = 'utf-8', errors: str = 'strict') -> Optional[str]: """Return the comment associated with this certificate as a Unicode string :param encoding: The encoding to use to decode the comment as a Unicode string, defaulting to UTF-8 :param errors: The error handling scheme to use for Unicode decode errors :type encoding: `str` :type errors: `str` :returns: `str` or `None` :raises: :exc:`UnicodeDecodeError` if the comment cannot be decoded using the specified encoding """ return self._comment.decode(encoding, errors) if self._comment else None def set_comment(self, comment: _Comment, encoding: str = 'utf-8', errors: str = 'strict') -> None: """Set the comment associated with this certificate :param comment: The new comment to associate with this key :param encoding: The Unicode encoding to use to encode the comment, defaulting to UTF-8 :param errors: The error handling scheme to use 
for Unicode encode errors :type comment: `str`, `bytes`, or `None` :type encoding: `str` :type errors: `str` :raises: :exc:`UnicodeEncodeError` if the comment cannot be encoded using the specified encoding """ if isinstance(comment, str): comment = comment.encode(encoding, errors) self._comment = comment or None def export_certificate(self, format_name: str = 'openssh') -> bytes: """Export a certificate in the requested format This function returns this certificate encoded in the requested format. Available formats include: der, pem, openssh, rfc4716 By default, OpenSSH format will be used. :param format_name: (optional) The format to export the certificate in. :type format_name: `str` :returns: `bytes` representing the exported certificate """ if self.is_x509: if format_name == 'rfc4716': raise KeyExportError('RFC4716 format is not supported for ' 'X.509 certificates') else: if format_name in ('der', 'pem'): raise KeyExportError('DER and PEM formats are not supported ' 'for OpenSSH certificates') if format_name == 'der': return self.public_data elif format_name == 'pem': return (b'-----BEGIN CERTIFICATE-----\n' + _wrap_base64(self.public_data) + b'-----END CERTIFICATE-----\n') elif format_name == 'openssh': if self._comment: comment = b' ' + self._comment else: comment = b'' return (self.algorithm + b' ' + binascii.b2a_base64(self.public_data)[:-1] + comment + b'\n') elif format_name == 'rfc4716': if self._comment: comment = (b'Comment: "' + self._comment + b'"\n') else: comment = b'' return (b'---- BEGIN SSH2 PUBLIC KEY ----\n' + comment + _wrap_base64(self.public_data) + b'---- END SSH2 PUBLIC KEY ----\n') else: raise KeyExportError('Unknown export format') def write_certificate(self, filename: FilePath, format_name: str = 'openssh') -> None: """Write a certificate to a file in the requested format This function is a simple wrapper around export_certificate which writes the exported certificate to a file. :param filename: The filename to write the certificate to. :param format_name: (optional) The format to export the certificate in. :type filename: :class:`PurePath ` or `str` :type format_name: `str` """ write_file(filename, self.export_certificate(format_name)) def append_certificate(self, filename: FilePath, format_name: str = 'openssh') -> None: """Append a certificate to a file in the requested format This function is a simple wrapper around export_certificate which appends the exported certificate to an existing file. :param filename: The filename to append the certificate to. :param format_name: (optional) The format to export the certificate in. 
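           A usage sketch (the certificate and filenames are illustrative,
           continuing the earlier certificate examples)::

               # OpenSSH format certificate next to the user's key
               cert.write_certificate('id_ed25519-cert.pub')

               # The same certificate appended to a bundle in RFC 4716 format
               cert.append_certificate('trusted_certs', 'rfc4716')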
:type filename: :class:`PurePath ` or `str` :type format_name: `str` """ write_file(filename, self.export_certificate(format_name), 'ab') class SSHOpenSSHCertificate(SSHCertificate): """Class which holds an OpenSSH certificate""" _user_option_encoders: _OpenSSHCertEncoders = () _user_extension_encoders: _OpenSSHCertEncoders = () _host_option_encoders: _OpenSSHCertEncoders = () _host_extension_encoders: _OpenSSHCertEncoders = () _user_option_decoders: _OpenSSHCertDecoders = {} _user_extension_decoders: _OpenSSHCertDecoders = {} _host_option_decoders: _OpenSSHCertDecoders = {} _host_extension_decoders: _OpenSSHCertDecoders = {} def __init__(self, algorithm: bytes, key: SSHKey, data: bytes, principals: Sequence[str], options: _OpenSSHCertOptions, signing_key: SSHKey, serial: int, cert_type: int, key_id: str, valid_after: int, valid_before: int, comment: _Comment): super().__init__(algorithm, key.sig_algorithms, key.cert_algorithms or (algorithm,), key, data, comment) self.principals = principals self.options = options self.signing_key = signing_key self._serial = serial self._cert_type = cert_type self._key_id = key_id self._valid_after = valid_after self._valid_before = valid_before @classmethod def generate(cls, signing_key: 'SSHKey', algorithm: bytes, key: 'SSHKey', serial: int, cert_type: int, key_id: str, principals: Sequence[str], valid_after: int, valid_before: int, options: _OpenSSHCertOptions, sig_alg: bytes, comment: _Comment) -> 'SSHOpenSSHCertificate': """Generate a new SSH certificate""" principal_bytes = b''.join(String(p) for p in principals) if cert_type == CERT_TYPE_USER: cert_options = cls._encode_options(options, cls._user_option_encoders) cert_extensions = cls._encode_options(options, cls._user_extension_encoders) else: cert_options = cls._encode_options(options, cls._host_option_encoders) cert_extensions = cls._encode_options(options, cls._host_extension_encoders) key = key.convert_to_public() data = b''.join((String(algorithm), cls._encode(key, serial, cert_type, key_id, principal_bytes, valid_after, valid_before, cert_options, cert_extensions), String(signing_key.public_data))) data += String(signing_key.sign(data, sig_alg)) signing_key = signing_key.convert_to_public() return cls(algorithm, key, data, principals, options, signing_key, serial, cert_type, key_id, valid_after, valid_before, comment) @classmethod def construct(cls, packet: SSHPacket, algorithm: bytes, key_handler: Optional[Type[SSHKey]], comment: _Comment) -> 'SSHOpenSSHCertificate': """Construct an SSH certificate from packetized data""" assert key_handler is not None key_params, serial, cert_type, key_id, \ principals, valid_after, valid_before, \ options, extensions = cls._decode(packet, key_handler) signing_key = decode_ssh_public_key(packet.get_string()) data = packet.get_consumed_payload() signature = packet.get_string() packet.check_end() if not signing_key.verify(data, signature): raise KeyImportError('Invalid certificate signature') key = key_handler.make_public(key_params) data = packet.get_consumed_payload() try: key_id_bytes = key_id.decode('utf-8') except UnicodeDecodeError: raise KeyImportError('Invalid characters in key ID') from None packet = SSHPacket(principals) principals: List[str] = [] while packet: try: principal = packet.get_string().decode('utf-8') except UnicodeDecodeError: raise KeyImportError('Invalid characters in principal ' 'name') from None principals.append(principal) if cert_type == CERT_TYPE_USER: cert_options = cls._decode_options( options, cls._user_option_decoders, 
True) cert_options.update(cls._decode_options( extensions, cls._user_extension_decoders, False)) elif cert_type == CERT_TYPE_HOST: cert_options = cls._decode_options( options, cls._host_option_decoders, True) cert_options.update(cls._decode_options( extensions, cls._host_extension_decoders, False)) else: raise KeyImportError('Unknown certificate type') return cls(algorithm, key, data, principals, cert_options, signing_key, serial, cert_type, key_id_bytes, valid_after, valid_before, comment) @classmethod def _encode(cls, key: SSHKey, serial: int, cert_type: int, key_id: str, principals: bytes, valid_after: int, valid_before: int, options: bytes, extensions: bytes) -> bytes: """Encode an SSH certificate""" raise NotImplementedError @classmethod def _decode(cls, packet: SSHPacket, key_handler: Type[SSHKey]) -> _OpenSSHCertParams: """Decode an SSH certificate""" raise NotImplementedError @staticmethod def _encode_options(options: _OpenSSHCertOptions, encoders: _OpenSSHCertEncoders) -> bytes: """Encode options found in this certificate""" result = [] for name, encoder in encoders: value = options.get(name) if value: result.append(String(name) + String(encoder(value))) return b''.join(result) @staticmethod def _encode_bool(_value: object) -> bytes: """Encode a boolean option value""" return b'' @staticmethod def _encode_force_cmd(force_command: object) -> bytes: """Encode a force-command option""" return String(cast(BytesOrStr, force_command)) @staticmethod def _encode_source_addr(source_address: object) -> bytes: """Encode a source-address option""" return NameList(str(addr).encode('ascii') for addr in cast(Sequence[IPNetwork], source_address)) @staticmethod def _decode_bool(_packet: SSHPacket) -> bool: """Decode a boolean option value""" return True @staticmethod def _decode_force_cmd(packet: SSHPacket) -> str: """Decode a force-command option""" try: return packet.get_string().decode('utf-8') except UnicodeDecodeError: raise KeyImportError('Invalid characters in command') from None @staticmethod def _decode_source_addr(packet: SSHPacket) -> Sequence[IPNetwork]: """Decode a source-address option""" try: return [ip_network(addr.decode('ascii')) for addr in packet.get_namelist()] except (UnicodeDecodeError, ValueError): raise KeyImportError('Invalid source address') from None @staticmethod def _decode_options(options: bytes, decoders: _OpenSSHCertDecoders, critical: bool = True) -> _OpenSSHCertOptions: """Decode options found in this certificate""" packet = SSHPacket(options) result: _OpenSSHCertOptions = {} while packet: name = packet.get_string() decoder = decoders.get(name) if decoder: data_packet = SSHPacket(packet.get_string()) result[name.decode('ascii')] = decoder(data_packet) data_packet.check_end() elif critical: raise KeyImportError('Unrecognized critical option: ' + name.decode('ascii', errors='replace')) return result def validate(self, cert_type: int, principal: Optional[str]) -> None: """Validate an OpenSSH certificate""" if self._cert_type != cert_type: raise ValueError('Invalid certificate type') now = time.time() if now < self._valid_after: raise ValueError('Certificate not yet valid') if now >= self._valid_before: raise ValueError('Certificate expired') if principal is not None and self.principals and \ principal not in self.principals: raise ValueError('Certificate principal mismatch') class SSHOpenSSHCertificateV01(SSHOpenSSHCertificate): """Encoder/decoder class for version 01 OpenSSH certificates""" _user_option_encoders = ( ('force-command', 
SSHOpenSSHCertificate._encode_force_cmd), ('source-address', SSHOpenSSHCertificate._encode_source_addr) ) _user_extension_encoders = ( ('permit-X11-forwarding', SSHOpenSSHCertificate._encode_bool), ('permit-agent-forwarding', SSHOpenSSHCertificate._encode_bool), ('permit-port-forwarding', SSHOpenSSHCertificate._encode_bool), ('permit-pty', SSHOpenSSHCertificate._encode_bool), ('permit-user-rc', SSHOpenSSHCertificate._encode_bool), ('no-touch-required', SSHOpenSSHCertificate._encode_bool) ) _user_option_decoders = { b'force-command': SSHOpenSSHCertificate._decode_force_cmd, b'source-address': SSHOpenSSHCertificate._decode_source_addr } _user_extension_decoders = { b'permit-X11-forwarding': SSHOpenSSHCertificate._decode_bool, b'permit-agent-forwarding': SSHOpenSSHCertificate._decode_bool, b'permit-port-forwarding': SSHOpenSSHCertificate._decode_bool, b'permit-pty': SSHOpenSSHCertificate._decode_bool, b'permit-user-rc': SSHOpenSSHCertificate._decode_bool, b'no-touch-required': SSHOpenSSHCertificate._decode_bool } @classmethod def _encode(cls, key: SSHKey, serial: int, cert_type: int, key_id: str, principals: bytes, valid_after: int, valid_before: int, options: bytes, extensions: bytes) -> bytes: """Encode a version 01 SSH certificate""" return b''.join((String(os.urandom(32)), key.encode_ssh_public(), UInt64(serial), UInt32(cert_type), String(key_id), String(principals), UInt64(valid_after), UInt64(valid_before), String(options), String(extensions), String(''))) @classmethod def _decode(cls, packet: SSHPacket, key_handler: Type[SSHKey]) -> _OpenSSHCertParams: """Decode a version 01 SSH certificate""" _ = packet.get_string() # nonce key_params = key_handler.decode_ssh_public(packet) serial = packet.get_uint64() cert_type = packet.get_uint32() key_id = packet.get_string() principals = packet.get_string() valid_after = packet.get_uint64() valid_before = packet.get_uint64() options = packet.get_string() extensions = packet.get_string() _ = packet.get_string() # reserved return (key_params, serial, cert_type, key_id, principals, valid_after, valid_before, options, extensions) class SSHX509Certificate(SSHCertificate): """Encoder/decoder class for SSH X.509 certificates""" is_x509 = True def __init__(self, key: SSHKey, x509_cert: 'X509Certificate', comment: _Comment = None): super().__init__(b'x509v3-' + key.algorithm, key.x509_algorithms, key.x509_algorithms, key, x509_cert.data, x509_cert.comment or comment) self.subject = x509_cert.subject self.issuer = x509_cert.issuer self.issuer_hash = x509_cert.issuer_hash self.user_principals = x509_cert.user_principals self.x509_cert = x509_cert def _expand_trust_store(self, cert: 'SSHX509Certificate', trusted_cert_paths: Sequence[FilePath], trust_store: Set['SSHX509Certificate']) -> None: """Look up certificates by issuer hash to build a trust store""" issuer_hash = cert.issuer_hash for path in trusted_cert_paths: idx = 0 try: while True: cert_path = Path(path, issuer_hash + '.' 
+ str(idx)) idx += 1 c = cast('SSHX509Certificate', read_certificate(cert_path)) if c.subject != cert.issuer or c in trust_store: continue trust_store.add(c) self._expand_trust_store(c, trusted_cert_paths, trust_store) except (OSError, KeyImportError): pass @classmethod def construct(cls, packet: SSHPacket, algorithm: bytes, key_handler: Optional[Type[SSHKey]], comment: _Comment) -> 'SSHX509Certificate': """Construct an SSH X.509 certificate from packetized data""" raise RuntimeError # pragma: no cover @classmethod def generate(cls, signing_key: 'SSHKey', key: 'SSHKey', subject: str, issuer: Optional[str], serial: Optional[int], valid_after: int, valid_before: int, ca: bool, ca_path_len: Optional[int], purposes: X509CertPurposes, user_principals: _CertPrincipals, host_principals: _CertPrincipals, hash_name: str, comment: _Comment) -> 'SSHX509Certificate': """Generate a new X.509 certificate""" key = key.convert_to_public() x509_cert = generate_x509_certificate(signing_key.pyca_key, key.pyca_key, subject, issuer, serial, valid_after, valid_before, ca, ca_path_len, purposes, user_principals, host_principals, hash_name, comment) return cls(key, x509_cert) @classmethod def construct_from_der(cls, data: bytes, comment: _Comment = None) -> 'SSHX509Certificate': """Construct an SSH X.509 certificate from DER data""" try: x509_cert = import_x509_certificate(data) key = import_public_key(x509_cert.key_data) except ValueError as exc: raise KeyImportError(str(exc)) from None return cls(key, x509_cert, comment) def validate_chain(self, trust_chain: Sequence['SSHX509Certificate'], trusted_certs: Sequence['SSHX509Certificate'], trusted_cert_paths: Sequence[FilePath], purposes: X509CertPurposes, user_principal: str = '', host_principal: str = '') -> None: """Validate an X.509 certificate chain""" trust_store = {c for c in trust_chain if c.subject != c.issuer} | \ set(trusted_certs) if trusted_cert_paths: self._expand_trust_store(self, trusted_cert_paths, trust_store) for c in trust_chain: self._expand_trust_store(c, trusted_cert_paths, trust_store) self.x509_cert.validate([c.x509_cert for c in trust_store], purposes, user_principal, host_principal) class SSHX509CertificateChain(SSHCertificate): """Encoder/decoder class for an SSH X.509 certificate chain""" is_x509_chain = True def __init__(self, algorithm: bytes, certs: Sequence[SSHCertificate], ocsp_responses: Sequence[bytes], comment: _Comment): key = certs[0].key data = self._public_data(algorithm, certs, ocsp_responses) super().__init__(algorithm, key.x509_algorithms, key.x509_algorithms, key, data, comment) x509_certs = cast(Sequence[SSHX509Certificate], certs) first_cert = x509_certs[0] last_cert = x509_certs[-1] self.subject = first_cert.subject self.issuer = last_cert.issuer self.user_principals = first_cert.user_principals self._certs = x509_certs self._ocsp_responses = ocsp_responses @staticmethod def _public_data(algorithm: bytes, certs: Sequence[SSHCertificate], ocsp_responses: Sequence[bytes]) -> bytes: """Return the X509 chain public data""" return (String(algorithm) + UInt32(len(certs)) + b''.join(String(c.public_data) for c in certs) + UInt32(len(ocsp_responses)) + b''.join(String(resp) for resp in ocsp_responses)) @classmethod def construct(cls, packet: SSHPacket, algorithm: bytes, key_handler: Optional[Type[SSHKey]], comment: _Comment) -> 'SSHX509CertificateChain': """Construct an SSH X.509 certificate from packetized data""" cert_count = packet.get_uint32() certs = [import_certificate(packet.get_string()) for _ in range(cert_count)] 
ocsp_resp_count = packet.get_uint32() ocsp_responses = [packet.get_string() for _ in range(ocsp_resp_count)] packet.check_end() if not certs: raise KeyImportError('No certificates present') return cls(algorithm, certs, ocsp_responses, comment) @classmethod def construct_from_certs(cls, certs: Sequence['SSHCertificate']) -> \ 'SSHX509CertificateChain': """Construct an SSH X.509 certificate chain from certificates""" cert = certs[0] return cls(cert.algorithm, certs, (), cert.get_comment_bytes()) def adjust_public_data(self, algorithm: bytes) -> bytes: """Adjust public data to reflect chosen signature algorithm""" return self._public_data(algorithm, self._certs, self._ocsp_responses) def validate_chain(self, trusted_certs: Sequence[SSHX509Certificate], trusted_cert_paths: Sequence[FilePath], revoked_certs: Set[SSHX509Certificate], purposes: X509CertPurposes, user_principal: str = '', host_principal: str = '') -> None: """Validate an X.509 certificate chain""" if revoked_certs: for cert in self._certs: if cert in revoked_certs: raise ValueError('Revoked X.509 certificate in ' 'certificate chain') self._certs[0].validate_chain(self._certs[1:], trusted_certs, trusted_cert_paths, purposes, user_principal, host_principal) class SSHKeyPair: """Parent class which represents an asymmetric key pair This is an abstract class which provides a method to sign data with a private key and members to access the corresponding algorithm and public key or certificate information needed to identify what key was used for signing. """ _key_type = 'unknown' def __init__(self, algorithm: bytes, sig_algorithm: bytes, sig_algorithms: Sequence[bytes], host_key_algorithms: Sequence[bytes], public_data: bytes, comment: _Comment, cert: Optional[SSHCertificate] = None, filename: Optional[bytes] = None, use_executor: bool = False, use_webauthn: bool = False): self.key_algorithm = algorithm self.key_public_data = public_data self.set_comment(comment) self._cert = cert self._filename = filename self.use_executor = use_executor self.use_webauthn = use_webauthn if cert: if cert.key.public_data != self.key_public_data: raise ValueError('Certificate key mismatch') self.algorithm = cert.algorithm if cert.is_x509_chain: self.sig_algorithm = cert.algorithm else: self.sig_algorithm = sig_algorithm self.sig_algorithms = cert.sig_algorithms self.host_key_algorithms = cert.host_key_algorithms self.public_data = cert.public_data else: self.algorithm = algorithm self.sig_algorithm = algorithm self.sig_algorithms = sig_algorithms self.host_key_algorithms = host_key_algorithms self.public_data = public_data def get_key_type(self) -> str: """Return what type of key pair this is This method returns 'local' for locally loaded keys, and 'agent' for keys managed by an SSH agent. 
""" return self._key_type @property def has_cert(self) -> bool: """ Return if this key pair has an associated cert""" return bool(self._cert) @property def has_x509_chain(self) -> bool: """ Return if this key pair has an associated X.509 cert chain""" return self._cert.is_x509_chain if self._cert else False def get_algorithm(self) -> str: """Return the algorithm associated with this key pair""" return self.algorithm.decode('ascii') def get_agent_private_key(self) -> bytes: """Return binary encoding of keypair for upload to SSH agent""" # pylint: disable=no-self-use raise KeyImportError('Private key export to agent not supported') def get_comment_bytes(self) -> Optional[bytes]: """Return the comment associated with this key pair as a byte string :returns: `bytes` or `None` """ return self._comment or self._filename def get_comment(self, encoding: str = 'utf-8', errors: str = 'strict') -> Optional[str]: """Return the comment associated with this key pair as a Unicode string :param encoding: The encoding to use to decode the comment as a Unicode string, defaulting to UTF-8 :param errors: The error handling scheme to use for Unicode decode errors :type encoding: `str` :type errors: `str` :returns: `str` or `None` :raises: :exc:`UnicodeDecodeError` if the comment cannot be decoded using the specified encoding """ comment = self.get_comment_bytes() return comment.decode(encoding, errors) if comment else None def set_comment(self, comment: _Comment, encoding: str = 'utf-8', errors: str = 'strict') -> None: """Set the comment associated with this key pair :param comment: The new comment to associate with this key :param encoding: The Unicode encoding to use to encode the comment, defaulting to UTF-8 :param errors: The error handling scheme to use for Unicode encode errors :type comment: `str`, `bytes`, or `None` :type encoding: `str` :type errors: `str` :raises: :exc:`UnicodeEncodeError` if the comment cannot be encoded using the specified encoding """ if isinstance(comment, str): comment = comment.encode(encoding, errors) self._comment = comment or None def set_certificate(self, cert: SSHCertificate) -> None: """Set certificate to use with this key""" if cert.key.public_data != self.key_public_data: raise ValueError('Certificate key mismatch') self._cert = cert self.algorithm = cert.algorithm if cert.is_x509_chain: self.sig_algorithm = cert.algorithm else: self.sig_algorithm = self.key_algorithm self.sig_algorithms = cert.sig_algorithms self.host_key_algorithms = cert.host_key_algorithms self.public_data = cert.public_data def set_sig_algorithm(self, sig_algorithm: bytes) -> None: """Set the signature algorithm to use when signing data""" try: sig_algorithm = _certificate_sig_alg_map[sig_algorithm] except KeyError: pass self.sig_algorithm = sig_algorithm if not self.has_cert: self.algorithm = sig_algorithm elif self.has_x509_chain: self.algorithm = sig_algorithm cert = cast('SSHX509CertificateChain', self._cert) self.public_data = cert.adjust_public_data(sig_algorithm) def sign(self, data: bytes) -> bytes: """Sign a block of data with this private key""" # pylint: disable=no-self-use raise RuntimeError # pragma: no cover class SSHLocalKeyPair(SSHKeyPair): """Class which holds a local asymmetric key pair This class holds a private key and associated public data which can either be the matching public key or a certificate which has signed that public key. 
""" _key_type = 'local' def __init__(self, key: SSHKey, pubkey: Optional[SSHKey] = None, cert: Optional[SSHCertificate] = None): if pubkey and pubkey.public_data != key.public_data: raise ValueError('Public key mismatch') if key.has_comment(): comment = key.get_comment_bytes() elif cert and cert.has_comment(): comment = cert.get_comment_bytes() elif pubkey and pubkey.has_comment(): comment = pubkey.get_comment_bytes() else: comment = None super().__init__(key.algorithm, key.algorithm, key.sig_algorithms, key.sig_algorithms, key.public_data, comment, cert, key.get_filename(), key.use_executor, key.use_webauthn) self._key = key def get_agent_private_key(self) -> bytes: """Return binary encoding of keypair for upload to SSH agent""" if self._cert: data = String(self.public_data) + \ self._key.encode_agent_cert_private() else: data = self._key.encode_ssh_private() return String(self.algorithm) + data def sign(self, data: bytes) -> bytes: """Sign a block of data with this private key""" return self._key.sign(data, self.sig_algorithm) def _parse_openssh(data: bytes) -> Tuple[bytes, Optional[bytes], bytes]: """Parse an OpenSSH format public key or certificate""" line = data.split(None, 2) if len(line) < 2: raise KeyImportError('Invalid OpenSSH public key or certificate') elif len(line) == 2: comment = None else: comment = line[2] if (line[0] not in _public_key_alg_map and line[0] not in _certificate_alg_map): raise KeyImportError('Unknown OpenSSH public key algorithm') try: return line[0], comment, binascii.a2b_base64(line[1]) except binascii.Error: raise KeyImportError('Invalid OpenSSH public key ' 'or certificate') from None def _parse_pem(data: bytes) -> Tuple[Mapping[bytes, bytes], bytes]: """Parse a PEM data block""" start = 0 end: Optional[int] = None headers: Dict[bytes, bytes] = {} while True: end = data.find(b'\n', start) + 1 line = data[start:end] if end else data[start:] line = line.rstrip() if b':' in line: hdr, value = line.split(b':', 1) headers[hdr.strip()] = value.strip() else: break start = end if end != 0 else len(data) try: return headers, binascii.a2b_base64(data[start:]) except binascii.Error: raise KeyImportError('Invalid PEM data') from None def _parse_rfc4716(data: bytes) -> Tuple[Optional[bytes], bytes]: """Parse an RFC 4716 data block""" start = 0 end = None hdr = b'' comment = None while True: end = data.find(b'\n', start) + 1 line = data[start:end] if end else data[start:] line = line.rstrip() if line[-1:] == b'\\': hdr += line[:-1] else: hdr += line if b':' in hdr: hdr, value = hdr.split(b':', 1) if hdr.strip() == b'Comment': comment = value.strip() if comment[:1] == b'"' and comment[-1:] == b'"': comment = comment[1:-1] hdr = b'' else: break start = end if end != 0 else len(data) try: return comment, binascii.a2b_base64(data[start:]) except binascii.Error: raise KeyImportError('Invalid RFC 4716 data') from None def _match_block(data: bytes, start: int, header: bytes, fmt: str) -> Tuple[bytes, int]: """Match a block of data wrapped in a header/footer""" match = re.compile(b'^' + header[:5] + b'END' + header[10:] + rb'[ \t\r\f\v]*$', re.M).search(data, start) if not match: raise KeyImportError(f'Missing {fmt} footer') return data[start:match.start()], match.end() def _match_next(data: bytes, keytype: bytes, public: bool = False) -> \ Tuple[Optional[str], Tuple, Optional[int]]: """Find the next key/certificate and call the appropriate decode""" end: Optional[int] if data.startswith(b'\x30'): try: key_data, end = der_decode_partial(data) return 'der', (key_data,), end 
except ASN1DecodeError: pass start = 0 end = None while end != 0: end = data.find(b'\n', start) + 1 line = data[start:end] if end else data[start:] line = line.rstrip() if (line.startswith(b'-----BEGIN ') and line.endswith(b' ' + keytype + b'-----')): pem_name = line[11:-(6+len(keytype))].strip() data, end = _match_block(data, end, line, 'PEM') headers, data = _parse_pem(data) return 'pem', (pem_name, headers, data), end elif public: if line == b'---- BEGIN SSH2 PUBLIC KEY ----': data, end = _match_block(data, end, line, 'RFC 4716') return 'rfc4716', _parse_rfc4716(data), end else: try: cert = _parse_openssh(line) except KeyImportError: pass else: return 'openssh', cert, (end if end else len(data)) start = end return None, (), len(data) def _decode_pkcs1_private( pem_name: bytes, key_data: object, unsafe_skip_rsa_key_validation: Optional[bool]) -> SSHKey: """Decode a PKCS#1 format private key""" handler = _pem_map.get(pem_name) if handler is None: raise KeyImportError('Unknown PEM key type: ' + pem_name.decode('ascii')) key_params = handler.decode_pkcs1_private(key_data) if key_params is None: raise KeyImportError( f'Invalid {pem_name.decode("ascii")} private key') if pem_name == b'RSA': key_params = cast(Tuple, key_params) + \ (unsafe_skip_rsa_key_validation,) return handler.make_private(key_params) def _decode_pkcs1_public(pem_name: bytes, key_data: object) -> SSHKey: """Decode a PKCS#1 format public key""" handler = _pem_map.get(pem_name) if handler is None: raise KeyImportError('Unknown PEM key type: ' + pem_name.decode('ascii')) key_params = handler.decode_pkcs1_public(key_data) if key_params is None: raise KeyImportError(f'Invalid {pem_name.decode("ascii")} public key') return handler.make_public(key_params) def _decode_pkcs8_private( key_data: object, unsafe_skip_rsa_key_validation: Optional[bool]) -> SSHKey: """Decode a PKCS#8 format private key""" if (isinstance(key_data, tuple) and len(key_data) >= 3 and key_data[0] in (0, 1) and isinstance(key_data[1], tuple) and 1 <= len(key_data[1]) <= 2 and isinstance(key_data[2], bytes)): if len(key_data[1]) == 2: alg, alg_params = key_data[1] else: alg, alg_params = key_data[1][0], OMIT handler = _pkcs8_oid_map.get(alg) if handler is None: raise KeyImportError('Unknown PKCS#8 algorithm') key_params = handler.decode_pkcs8_private(alg_params, key_data[2]) if key_params is None: key_type = handler.pem_name.decode('ascii') if \ handler.pem_name else 'PKCS#8' raise KeyImportError(f'Invalid {key_type} private key') if alg == ObjectIdentifier('1.2.840.113549.1.1.1'): key_params = cast(Tuple, key_params) + \ (unsafe_skip_rsa_key_validation,) return handler.make_private(key_params) else: raise KeyImportError('Invalid PKCS#8 private key') def _decode_pkcs8_public(key_data: object) -> SSHKey: """Decode a PKCS#8 format public key""" if (isinstance(key_data, tuple) and len(key_data) == 2 and isinstance(key_data[0], tuple) and 1 <= len(key_data[0]) <= 2 and isinstance(key_data[1], BitString) and key_data[1].unused == 0): if len(key_data[0]) == 2: alg, alg_params = key_data[0] else: alg, alg_params = key_data[0][0], OMIT handler = _pkcs8_oid_map.get(alg) if handler is None: raise KeyImportError('Unknown PKCS#8 algorithm') key_params = handler.decode_pkcs8_public(alg_params, key_data[1].value) if key_params is None: key_type = handler.pem_name.decode('ascii') if \ handler.pem_name else 'PKCS#8' raise KeyImportError(f'Invalid {key_type} public key') return handler.make_public(key_params) else: raise KeyImportError('Invalid PKCS#8 public key') def 
_decode_openssh_private( data: bytes, passphrase: Optional[BytesOrStr], unsafe_skip_rsa_key_validation: Optional[bool]) -> SSHKey: """Decode an OpenSSH format private key""" try: if not data.startswith(_OPENSSH_KEY_V1): raise KeyImportError('Unrecognized OpenSSH private key type') data = data[len(_OPENSSH_KEY_V1):] packet = SSHPacket(data) cipher_name = packet.get_string() kdf = packet.get_string() kdf_data = packet.get_string() nkeys = packet.get_uint32() _ = packet.get_string() # public_key key_data = packet.get_string() mac = packet.get_remaining_payload() if nkeys != 1: raise KeyImportError('Invalid OpenSSH private key') if cipher_name != b'none': if passphrase is None: raise KeyImportError('Passphrase must be specified to import ' 'encrypted private keys') try: key_size, iv_size, _, _, _, _ = \ get_encryption_params(cipher_name) except KeyError: raise KeyEncryptionError('Unknown cipher: ' + cipher_name.decode('ascii')) from None if kdf != b'bcrypt': raise KeyEncryptionError('Unknown kdf: ' + kdf.decode('ascii')) if not _bcrypt_available: # pragma: no cover raise KeyEncryptionError('OpenSSH private key encryption ' 'requires bcrypt with KDF support') packet = SSHPacket(kdf_data) salt = packet.get_string() rounds = packet.get_uint32() packet.check_end() if isinstance(passphrase, str): passphrase = passphrase.encode('utf-8') try: bcrypt_key = bcrypt.kdf(passphrase, salt, key_size + iv_size, rounds, ignore_few_rounds=True) except ValueError: raise KeyEncryptionError('Invalid OpenSSH ' 'private key') from None cipher = get_encryption(cipher_name, bcrypt_key[:key_size], bcrypt_key[key_size:]) decrypted_key = cipher.decrypt_packet(0, b'', key_data, 0, mac) if decrypted_key is None: raise KeyEncryptionError('Incorrect passphrase') key_data = decrypted_key packet = SSHPacket(key_data) check1 = packet.get_uint32() check2 = packet.get_uint32() if check1 != check2: if cipher_name != b'none': raise KeyEncryptionError('Incorrect passphrase') from None else: raise KeyImportError('Invalid OpenSSH private key') alg = packet.get_string() handler = _public_key_alg_map.get(alg) if not handler: raise KeyImportError('Unknown OpenSSH private key algorithm') key_params = handler.decode_ssh_private(packet) comment = packet.get_string() pad = packet.get_remaining_payload() if len(pad) >= 256 or pad != bytes(range(1, len(pad) + 1)): raise KeyImportError('Invalid OpenSSH private key') if alg == b'ssh-rsa': key_params = cast(Tuple, key_params) + \ (unsafe_skip_rsa_key_validation,) key = handler.make_private(key_params) key.set_comment(comment) return key except PacketDecodeError: raise KeyImportError('Invalid OpenSSH private key') from None def _decode_openssh_public(data: bytes) -> SSHKey: """Decode public key within OpenSSH format private key""" try: if not data.startswith(_OPENSSH_KEY_V1): raise KeyImportError('Unrecognized OpenSSH private key type') data = data[len(_OPENSSH_KEY_V1):] packet = SSHPacket(data) _ = packet.get_string() # cipher_name _ = packet.get_string() # kdf _ = packet.get_string() # kdf_data nkeys = packet.get_uint32() pubkey = packet.get_string() if nkeys != 1: raise KeyImportError('Invalid OpenSSH private key') return decode_ssh_public_key(pubkey) except PacketDecodeError: raise KeyImportError('Invalid OpenSSH private key') from None def _decode_der_private( key_data: object, passphrase: Optional[BytesOrStr], unsafe_skip_rsa_key_validation: Optional[bool]) -> SSHKey: """Decode a DER format private key""" # First, if there's a passphrase, try to decrypt PKCS#8 if passphrase is not None: 
try: key_data = pkcs8_decrypt(key_data, passphrase) except KeyEncryptionError: # Decryption failed - try decoding it as unencrypted pass # Then, try to decode PKCS#8 try: return _decode_pkcs8_private(key_data, unsafe_skip_rsa_key_validation) except KeyImportError: # PKCS#8 failed - try PKCS#1 instead pass # If that fails, try each of the possible PKCS#1 encodings for pem_name in _pem_map: try: return _decode_pkcs1_private(pem_name, key_data, unsafe_skip_rsa_key_validation) except KeyImportError: # Try the next PKCS#1 encoding pass raise KeyImportError('Invalid DER private key') def _decode_der_public(key_data: object) -> SSHKey: """Decode a DER format public key""" # First, try to decode PKCS#8 try: return _decode_pkcs8_public(key_data) except KeyImportError: # PKCS#8 failed - try PKCS#1 instead pass # If that fails, try each of the possible PKCS#1 encodings for pem_name in _pem_map: try: return _decode_pkcs1_public(pem_name, key_data) except KeyImportError: # Try the next PKCS#1 encoding pass raise KeyImportError('Invalid DER public key') def _decode_der_certificate(data: bytes, comment: _Comment = None) -> SSHCertificate: """Decode a DER format X.509 certificate""" return SSHX509Certificate.construct_from_der(data, comment) def _decode_pem_private( pem_name: bytes, headers: Mapping[bytes, bytes], data: bytes, passphrase: Optional[BytesOrStr], unsafe_skip_rsa_key_validation: Optional[bool]) -> SSHKey: """Decode a PEM format private key""" if pem_name == b'OPENSSH': return _decode_openssh_private(data, passphrase, unsafe_skip_rsa_key_validation) if headers.get(b'Proc-Type') == b'4,ENCRYPTED': if passphrase is None: raise KeyImportError('Passphrase must be specified to import ' 'encrypted private keys') dek_info = headers.get(b'DEK-Info', b'').split(b',') if len(dek_info) != 2: raise KeyImportError('Invalid PEM encryption params') alg, iv = dek_info try: iv = binascii.a2b_hex(iv) except binascii.Error: raise KeyImportError('Invalid PEM encryption params') from None try: data = pkcs1_decrypt(data, alg, iv, passphrase) except KeyEncryptionError: raise KeyImportError('Unable to decrypt PKCS#1 ' 'private key') from None try: key_data = der_decode(data) except ASN1DecodeError: raise KeyImportError('Invalid PEM private key') from None if pem_name == b'ENCRYPTED': if passphrase is None: raise KeyImportError('Passphrase must be specified to import ' 'encrypted private keys') pem_name = b'' try: key_data = pkcs8_decrypt(key_data, passphrase) except KeyEncryptionError: raise KeyImportError('Unable to decrypt PKCS#8 ' 'private key') from None if pem_name: return _decode_pkcs1_private(pem_name, key_data, unsafe_skip_rsa_key_validation) else: return _decode_pkcs8_private(key_data, unsafe_skip_rsa_key_validation) def _decode_pem_public(pem_name: bytes, data: bytes) -> SSHKey: """Decode a PEM format public key""" try: key_data = der_decode(data) except ASN1DecodeError: raise KeyImportError('Invalid PEM public key') from None if pem_name: return _decode_pkcs1_public(pem_name, key_data) else: return _decode_pkcs8_public(key_data) def _decode_pem_certificate(pem_name: bytes, data: bytes) -> SSHCertificate: """Decode a PEM format X.509 certificate""" if pem_name == b'TRUSTED': # Strip off OpenSSL trust information try: _, end = der_decode_partial(data) data = data[:end] except ASN1DecodeError: raise KeyImportError('Invalid PEM trusted certificate') from None elif pem_name: raise KeyImportError('Invalid PEM certificate') return SSHX509Certificate.construct_from_der(data) def _decode_private( data: bytes, 
passphrase: Optional[BytesOrStr], unsafe_skip_rsa_key_validation: Optional[bool]) -> \ Tuple[Optional[SSHKey], Optional[int]]: """Decode a private key""" fmt, key_info, end = _match_next(data, b'PRIVATE KEY') key: Optional[SSHKey] if fmt == 'der': key = _decode_der_private(key_info[0], passphrase, unsafe_skip_rsa_key_validation) elif fmt == 'pem': pem_name, headers, data = key_info key = _decode_pem_private(pem_name, headers, data, passphrase, unsafe_skip_rsa_key_validation) else: key = None return key, end def _decode_public(data: bytes) -> Tuple[Optional[SSHKey], Optional[int]]: """Decode a public key""" fmt, key_info, end = _match_next(data, b'PUBLIC KEY', public=True) key: Optional[SSHKey] if fmt == 'der': key = _decode_der_public(key_info[0]) elif fmt == 'pem': pem_name, _, data = key_info key = _decode_pem_public(pem_name, data) elif fmt == 'openssh': algorithm, comment, data = key_info key = decode_ssh_public_key(data) if algorithm != key.algorithm: raise KeyImportError('Public key algorithm mismatch') key.set_comment(comment) elif fmt == 'rfc4716': comment, data = key_info key = decode_ssh_public_key(data) key.set_comment(comment) else: fmt, key_info, end = _match_next(data, b'PRIVATE KEY') if fmt == 'pem' and key_info[0] == b'OPENSSH': key = _decode_openssh_public(key_info[2]) else: key, _ = _decode_private(data, None, False) if key: key = key.convert_to_public() return key, end def _decode_certificate(data: bytes) -> \ Tuple[Optional[SSHCertificate], Optional[int]]: """Decode a certificate""" fmt, key_info, end = _match_next(data, b'CERTIFICATE', public=True) cert: Optional[SSHCertificate] if fmt == 'der': cert = _decode_der_certificate(data[:end]) elif fmt == 'pem': pem_name, _, data = key_info cert = _decode_pem_certificate(pem_name, data) elif fmt == 'openssh': algorithm, comment, data = key_info if algorithm.startswith(b'x509v3-'): cert = _decode_der_certificate(data, comment) else: cert = decode_ssh_certificate(data, comment) elif fmt == 'rfc4716': comment, data = key_info cert = decode_ssh_certificate(data, comment) else: cert = None return cert, end def _decode_private_list( data: bytes, passphrase: Optional[BytesOrStr], unsafe_skip_rsa_key_validation: Optional[bool]) -> Sequence[SSHKey]: """Decode a private key list""" keys: List[SSHKey] = [] while data: key, end = _decode_private(data, passphrase, unsafe_skip_rsa_key_validation) if key: keys.append(key) data = data[end:] return keys def _decode_public_list(data: bytes) -> Sequence[SSHKey]: """Decode a public key list""" keys: List[SSHKey] = [] while data: key, end = _decode_public(data) if key: keys.append(key) data = data[end:] return keys def _decode_certificate_list(data: bytes) -> Sequence[SSHCertificate]: """Decode a certificate list""" certs: List[SSHCertificate] = [] while data: cert, end = _decode_certificate(data) if cert: certs.append(cert) data = data[end:] return certs def register_sk_alg(sk_alg: int, handler: Type[SSHKey], *args: object) -> None: """Register a new security key algorithm""" _sk_alg_map[sk_alg] = handler, args def register_public_key_alg(algorithm: bytes, handler: Type[SSHKey], default: bool, sig_algorithms: Optional[Sequence[bytes]] = \ None) -> None: """Register a new public key algorithm""" if not sig_algorithms: sig_algorithms = handler.sig_algorithms _public_key_algs.extend(sig_algorithms) if default: _default_public_key_algs.extend(sig_algorithms) _public_key_alg_map[algorithm] = handler if handler.pem_name: _pem_map[handler.pem_name] = handler if handler.pkcs8_oid: # pragma: no branch 
_pkcs8_oid_map[handler.pkcs8_oid] = handler def register_certificate_alg(version: int, algorithm: bytes, cert_algorithm: bytes, key_handler: Type[SSHKey], cert_handler: Type[SSHOpenSSHCertificate], default: bool) -> None: """Register a new certificate algorithm""" _certificate_algs.append(cert_algorithm) if default: _default_certificate_algs.append(cert_algorithm) _certificate_alg_map[cert_algorithm] = (key_handler, cert_handler) _certificate_sig_alg_map[cert_algorithm] = algorithm _certificate_version_map[algorithm, version] = \ (cert_algorithm, cert_handler) def register_x509_certificate_alg(cert_algorithm: bytes, default: bool) -> None: """Register a new X.509 certificate algorithm""" if _x509_available: # pragma: no branch _x509_certificate_algs.append(cert_algorithm) if default: _default_x509_certificate_algs.append(cert_algorithm) _certificate_alg_map[cert_algorithm] = (None, SSHX509CertificateChain) def get_public_key_algs() -> List[bytes]: """Return supported public key algorithms""" return _public_key_algs def get_default_public_key_algs() -> List[bytes]: """Return default public key algorithms""" return _default_public_key_algs def get_certificate_algs() -> List[bytes]: """Return supported certificate-based public key algorithms""" return _certificate_algs def get_default_certificate_algs() -> List[bytes]: """Return default certificate-based public key algorithms""" return _default_certificate_algs def get_x509_certificate_algs() -> List[bytes]: """Return supported X.509 certificate-based public key algorithms""" return _x509_certificate_algs def get_default_x509_certificate_algs() -> List[bytes]: """Return default X.509 certificate-based public key algorithms""" return _default_x509_certificate_algs def decode_ssh_public_key(data: bytes) -> SSHKey: """Decode a packetized SSH public key""" try: packet = SSHPacket(data) alg = packet.get_string() handler = _public_key_alg_map.get(alg) if handler: key_params = handler.decode_ssh_public(packet) packet.check_end() key = handler.make_public(key_params) key.algorithm = alg return key else: raise KeyImportError('Unknown key algorithm: ' + alg.decode('ascii', errors='replace')) except PacketDecodeError: raise KeyImportError('Invalid public key') from None def decode_ssh_certificate(data: bytes, comment: _Comment = None) -> SSHCertificate: """Decode a packetized SSH certificate""" try: packet = SSHPacket(data) alg = packet.get_string() key_handler, cert_handler = _certificate_alg_map.get(alg, (None, None)) if cert_handler: return cert_handler.construct(packet, alg, key_handler, comment) else: raise KeyImportError('Unknown certificate algorithm: ' + alg.decode('ascii', errors='replace')) except (PacketDecodeError, ValueError): raise KeyImportError('Invalid OpenSSH certificate') from None def generate_private_key(alg_name: str, comment: _Comment = None, **kwargs) -> SSHKey: """Generate a new private key This function generates a new private key of a type matching the requested SSH algorithm. Depending on the algorithm, additional parameters can be passed which affect the generated key. Available algorithms include: ssh-dss, ssh-rsa, ecdsa-sha2-nistp256, ecdsa-sha2-nistp384, ecdsa-sha2-nistp521, ecdsa-sha2-1.3.132.0.10, ssh-ed25519, ssh-ed448, sk-ecdsa-sha2-nistp256\\@openssh.com, sk-ssh-ed25519\\@openssh.com For dss keys, no parameters are supported. The key size is fixed at 1024 bits due to the use of SHA1 signatures. 
For rsa keys, the key size can be specified using the `key_size` parameter, and the RSA public exponent can be changed using the `exponent` parameter. By default, generated keys are 2048 bits with a public exponent of 65537. For ecdsa keys, the curve to use is part of the SSH algorithm name and that determines the key size. No other parameters are supported. For ed25519 and ed448 keys, no parameters are supported. The key size is fixed by the algorithms at 256 bits and 448 bits, respectively. For sk keys, the application name to associate with the generated key can be specified using the `application` parameter. It defaults to `'ssh:'`. The user name to associate with the generated key can be specified using the `user` parameter. It defaults to `'AsyncSSH'`. When generating an sk key, a PIN can be provided via the `pin` parameter if the security key requires it. The `resident` parameter can be set to `True` to request that a resident key be created on the security key. This allows the key handle and public key information to later be retrieved so that the generated key can be used without having to store any information on the client system. It defaults to `False`. You can enable or disable the security key touch requirement by setting the `touch_required` parameter. It defaults to `True`, requiring that the user confirm their presence by touching the security key each time they use it to authenticate. :param alg_name: The SSH algorithm name corresponding to the desired type of key. :param comment: (optional) A comment to associate with this key. :param key_size: (optional) The key size in bits for RSA keys. :param exponent: (optional) The public exponent for RSA keys. :param application: (optional) The application name to associate with the generated SK key, defaulting to `'ssh:'`. :param user: (optional) The user name to associate with the generated SK key, defaulting to `'AsyncSSH'`. :param pin: (optional) The PIN to use to access the security key, defaulting to `None`. :param resident: (optional) Whether or not to create a resident key on the security key, defaulting to `False`. :param touch_required: (optional) Whether or not to require the user to touch the security key when authenticating with it, defaulting to `True`. :type alg_name: `str` :type comment: `str`, `bytes`, or `None` :type key_size: `int` :type exponent: `int` :type application: `str` :type user: `str` :type pin: `str` :type resident: `bool` :type touch_required: `bool` :returns: An :class:`SSHKey` private key :raises: :exc:`KeyGenerationError` if the requested key parameters are unsupported """ algorithm = alg_name.encode('utf-8') handler = _public_key_alg_map.get(algorithm) if handler: try: key = handler.generate(algorithm, **kwargs) except (TypeError, ValueError) as exc: raise KeyGenerationError(str(exc)) from None else: raise KeyGenerationError('Unknown algorithm: ' + alg_name) key.set_comment(comment) return key def import_private_key( data: BytesOrStr, passphrase: Optional[BytesOrStr] = None, unsafe_skip_rsa_key_validation: Optional[bool] = None) -> SSHKey: """Import a private key This function imports a private key encoded in PKCS#1 or PKCS#8 DER or PEM format or OpenSSH format. Encrypted private keys can be imported by specifying the passphrase needed to decrypt them. :param data: The data to import. :param passphrase: (optional) The passphrase to use to decrypt the key. 
:param unsafe_skip_rsa_key_validation: (optional) Whether or not to skip key validation when loading RSA private keys, defaulting to performing these checks unless changed by calling :func:`set_default_skip_rsa_key_validation`. :type data: `bytes` or ASCII `str` :type passphrase: `str` or `bytes` :type unsafe_skip_rsa_key_validation: bool :returns: An :class:`SSHKey` private key """ if isinstance(data, str): try: data = data.encode('ascii') except UnicodeEncodeError: raise KeyImportError('Invalid encoding for key') from None key, _ = _decode_private(data, passphrase, unsafe_skip_rsa_key_validation) if key: return key else: raise KeyImportError('Invalid private key') def import_private_key_and_certs( data: bytes, passphrase: Optional[BytesOrStr] = None, unsafe_skip_rsa_key_validation: Optional[bool] = None) -> \ Tuple[SSHKey, Optional[SSHX509CertificateChain]]: """Import a private key and optional certificate chain""" key, end = _decode_private(data, passphrase, unsafe_skip_rsa_key_validation) if key: return key, import_certificate_chain(data[end:]) else: raise KeyImportError('Invalid private key') def import_public_key(data: BytesOrStr) -> SSHKey: """Import a public key This function imports a public key encoded in OpenSSH, RFC4716, or PKCS#1 or PKCS#8 DER or PEM format. :param data: The data to import. :type data: `bytes` or ASCII `str` :returns: An :class:`SSHKey` public key """ if isinstance(data, str): try: data = data.encode('ascii') except UnicodeEncodeError: raise KeyImportError('Invalid encoding for key') from None key, _ = _decode_public(data) if key: return key else: raise KeyImportError('Invalid public key') def import_certificate(data: BytesOrStr) -> SSHCertificate: """Import a certificate This function imports an SSH certificate in DER, PEM, OpenSSH, or RFC4716 format. :param data: The data to import. :type data: `bytes` or ASCII `str` :returns: An :class:`SSHCertificate` object """ if isinstance(data, str): try: data = data.encode('ascii') except UnicodeEncodeError: raise KeyImportError('Invalid encoding for key') from None cert, _ = _decode_certificate(data) if cert: return cert else: raise KeyImportError('Invalid certificate') def import_certificate_chain(data: bytes) -> Optional[SSHX509CertificateChain]: """Import an X.509 certificate chain""" certs = _decode_certificate_list(data) chain: Optional[SSHX509CertificateChain] if certs: chain = SSHX509CertificateChain.construct_from_certs(certs) else: chain = None return chain def import_certificate_subject(data: str) -> str: """Import an X.509 certificate subject name""" try: algorithm, data = data.strip().split(None, 1) except ValueError: raise KeyImportError('Missing certificate subject algorithm') from None if algorithm.startswith('x509v3-'): match = _subject_pattern.match(data) if match: return data[match.end():] raise KeyImportError('Invalid certificate subject') def read_private_key( filename: FilePath, passphrase: Optional[BytesOrStr] = None, unsafe_skip_rsa_key_validation: Optional[bool] = None) -> SSHKey: """Read a private key from a file This function reads a private key from a file. See the function :func:`import_private_key` for information about the formats supported. :param filename: The file to read the key from. :param passphrase: (optional) The passphrase to use to decrypt the key. :param unsafe_skip_rsa_key_validation: (optional) Whether or not to skip key validation when loading RSA private keys, defaulting to performing these checks unless changed by calling :func:`set_default_skip_rsa_key_validation`. 
:type filename: :class:`PurePath ` or `str` :type passphrase: `str` or `bytes` :type unsafe_skip_rsa_key_validation: bool :returns: An :class:`SSHKey` private key """ key = import_private_key(read_file(filename), passphrase, unsafe_skip_rsa_key_validation) key.set_filename(filename) return key def read_private_key_and_certs( filename: FilePath, passphrase: Optional[BytesOrStr] = None, unsafe_skip_rsa_key_validation: Optional[bool] = None) -> \ Tuple[SSHKey, Optional[SSHX509CertificateChain]]: """Read a private key and optional certificate chain from a file""" key, cert = import_private_key_and_certs(read_file(filename), passphrase, unsafe_skip_rsa_key_validation) key.set_filename(filename) return key, cert def read_public_key(filename: FilePath) -> SSHKey: """Read a public key from a file This function reads a public key from a file. See the function :func:`import_public_key` for information about the formats supported. :param filename: The file to read the key from. :type filename: :class:`PurePath ` or `str` :returns: An :class:`SSHKey` public key """ key = import_public_key(read_file(filename)) key.set_filename(filename) return key def read_certificate(filename: FilePath) -> SSHCertificate: """Read a certificate from a file This function reads an SSH certificate from a file. See the function :func:`import_certificate` for information about the formats supported. :param filename: The file to read the certificate from. :type filename: :class:`PurePath ` or `str` :returns: An :class:`SSHCertificate` object """ return import_certificate(read_file(filename)) def read_private_key_list( filename: FilePath, passphrase: Optional[BytesOrStr] = None, unsafe_skip_rsa_key_validation: Optional[bool] = None) -> \ Sequence[SSHKey]: """Read a list of private keys from a file This function reads a list of private keys from a file. See the function :func:`import_private_key` for information about the formats supported. If any of the keys are encrypted, they must all be encrypted with the same passphrase. :param filename: The file to read the keys from. :param passphrase: (optional) The passphrase to use to decrypt the keys. :param unsafe_skip_rsa_key_validation: (optional) Whether or not to skip key validation when loading RSA private keys, defaulting to performing these checks unless changed by calling :func:`set_default_skip_rsa_key_validation`. :type filename: :class:`PurePath ` or `str` :type passphrase: `str` or `bytes` :type unsafe_skip_rsa_key_validation: bool :returns: A list of :class:`SSHKey` private keys """ keys = _decode_private_list(read_file(filename), passphrase, unsafe_skip_rsa_key_validation) for key in keys: key.set_filename(filename) return keys def read_public_key_list(filename: FilePath) -> Sequence[SSHKey]: """Read a list of public keys from a file This function reads a list of public keys from a file. See the function :func:`import_public_key` for information about the formats supported. :param filename: The file to read the keys from. :type filename: :class:`PurePath ` or `str` :returns: A list of :class:`SSHKey` public keys """ keys = _decode_public_list(read_file(filename)) for key in keys: key.set_filename(filename) return keys def read_certificate_list(filename: FilePath) -> Sequence[SSHCertificate]: """Read a list of certificates from a file This function reads a list of SSH certificates from a file. See the function :func:`import_certificate` for information about the formats supported. :param filename: The file to read the certificates from. 
:type filename: :class:`PurePath ` or `str` :returns: A list of :class:`SSHCertificate` certificates """ return _decode_certificate_list(read_file(filename)) def load_keypairs( keylist: KeyPairListArg, passphrase: Optional[BytesOrStr] = None, certlist: CertListArg = (), skip_public: bool = False, ignore_encrypted: bool = False, unsafe_skip_rsa_key_validation: Optional[bool] = None, loop: Optional[asyncio.AbstractEventLoop] = None) -> \ Sequence[SSHKeyPair]: """Load SSH private keys and optional matching certificates This function loads a list of SSH keys and optional matching certificates. When certificates are specified, the private key is added to the list both with and without the certificate. :param keylist: The list of private keys and certificates to load. :param passphrase: (optional) The passphrase to use to decrypt the keys, or a `callable` which takes a filename and returns the passphrase to decrypt it. :param certlist: (optional) A list of certificates to attempt to pair with the provided list of private keys. :param skip_public: (optional) An internal parameter used to skip public keys and certificates when IdentitiesOnly and IdentityFile are used to specify a mixture of private and public keys. :param unsafe_skip_rsa_key_validation: (optional) Whether or not to skip key validation when loading RSA private keys, defaulting to performing these checks unless changed by calling :func:`set_default_skip_rsa_key_validation`. :type keylist: *see* :ref:`SpecifyingPrivateKeys` :type passphrase: `str` or `bytes` :type certlist: *see* :ref:`SpecifyingCertificates` :type skip_public: `bool` :type unsafe_skip_rsa_key_validation: bool :returns: A list of :class:`SSHKeyPair` objects """ keys_to_load: Sequence[_KeyPairArg] result: List[SSHKeyPair] = [] certlist = load_certificates(certlist) certdict = {cert.key.public_data: cert for cert in certlist} if isinstance(keylist, (PurePath, str)): try: if callable(passphrase): resolved_passphrase = passphrase(str(keylist)) else: resolved_passphrase = passphrase if loop and inspect.isawaitable(resolved_passphrase): resolved_passphrase = asyncio.run_coroutine_threadsafe( resolved_passphrase, loop).result() priv_keys = read_private_key_list(keylist, resolved_passphrase, unsafe_skip_rsa_key_validation) if len(priv_keys) <= 1: keys_to_load = [keylist] passphrase = resolved_passphrase else: keys_to_load = priv_keys except KeyImportError: keys_to_load = [keylist] elif isinstance(keylist, (tuple, bytes, SSHKey, SSHKeyPair)): keys_to_load = [cast(_KeyPairArg, keylist)] else: keys_to_load = keylist if keylist else [] for key_to_load in keys_to_load: allow_certs = False key_prefix = None saved_exc = None pubkey_or_certs = None pubkey_to_load: Optional[_KeyArg] = None certs_to_load: Optional[_CertArg] = None key: Union['SSHKey', 'SSHKeyPair'] if isinstance(key_to_load, (PurePath, str, bytes)): allow_certs = True elif isinstance(key_to_load, tuple): key_to_load, pubkey_or_certs = key_to_load try: if isinstance(key_to_load, (PurePath, str)): key_prefix = str(key_to_load) if callable(passphrase): resolved_passphrase = passphrase(key_prefix) else: resolved_passphrase = passphrase if loop and inspect.isawaitable(resolved_passphrase): resolved_passphrase = asyncio.run_coroutine_threadsafe( resolved_passphrase, loop).result() if allow_certs: key, certs_to_load = read_private_key_and_certs( key_to_load, resolved_passphrase, unsafe_skip_rsa_key_validation) if not certs_to_load: certs_to_load = key_prefix + '-cert.pub' else: key = read_private_key(key_to_load, 
resolved_passphrase, unsafe_skip_rsa_key_validation) pubkey_to_load = key_prefix + '.pub' elif isinstance(key_to_load, bytes): if allow_certs: key, certs_to_load = import_private_key_and_certs( key_to_load, passphrase, unsafe_skip_rsa_key_validation) else: key = import_private_key(key_to_load, passphrase, unsafe_skip_rsa_key_validation) else: key = key_to_load except KeyImportError as exc: if skip_public or \ (ignore_encrypted and str(exc).startswith('Passphrase')): continue raise certs: Optional[Sequence[SSHCertificate]] if pubkey_or_certs: try: certs = load_certificates(pubkey_or_certs) except (TypeError, OSError, KeyImportError) as exc: saved_exc = exc certs = None if not certs: pubkey_to_load = cast(_KeyArg, pubkey_or_certs) elif certs_to_load: try: certs = load_certificates(certs_to_load) except (OSError, KeyImportError): certs = None else: certs = None pubkey: Optional[SSHKey] if pubkey_to_load: try: if isinstance(pubkey_to_load, (PurePath, str)): pubkey = read_public_key(pubkey_to_load) elif isinstance(pubkey_to_load, bytes): pubkey = import_public_key(pubkey_to_load) else: pubkey = pubkey_to_load except (OSError, KeyImportError): pubkey = None else: saved_exc = None else: pubkey = None if saved_exc: raise saved_exc # pylint: disable=raising-bad-type if not certs: if isinstance(key, SSHKeyPair): pubdata = key.key_public_data else: pubdata = key.public_data cert = certdict.get(pubdata) if cert and cert.is_x509: cert = SSHX509CertificateChain.construct_from_certs(certlist) elif len(certs) == 1 and not certs[0].is_x509: cert = certs[0] else: cert = SSHX509CertificateChain.construct_from_certs(certs) if isinstance(key, SSHKeyPair): if cert: key.set_certificate(cert) result.append(key) else: if cert: result.append(SSHLocalKeyPair(key, pubkey, cert)) result.append(SSHLocalKeyPair(key, pubkey)) return result def load_default_keypairs(passphrase: Optional[BytesOrStr] = None, certlist: CertListArg = ()) -> \ Sequence[SSHKeyPair]: """Return a list of default keys from the user's home directory""" result: List[SSHKeyPair] = [] for file, condition in _DEFAULT_KEY_FILES: if condition: # pragma: no branch try: path = Path('~', '.ssh', file).expanduser() result.extend(load_keypairs(path, passphrase, certlist, ignore_encrypted=True)) except OSError: pass return result def load_public_keys(keylist: KeyListArg) -> Sequence[SSHKey]: """Load public keys This function loads a list of SSH public keys. :param keylist: The list of public keys to load. 
:type keylist: *see* :ref:`SpecifyingPublicKeys` :returns: A list of :class:`SSHKey` objects """ if isinstance(keylist, (PurePath, str)): return read_public_key_list(keylist) else: result: List[SSHKey] = [] for key in keylist: if isinstance(key, (PurePath, str)): key = read_public_key(key) elif isinstance(key, bytes): key = import_public_key(key) result.append(key) return result def load_default_host_public_keys() -> Sequence[Union[SSHKey, SSHCertificate]]: """Return a list of default host public keys or certificates""" result: List[Union[SSHKey, SSHCertificate]] = [] for host_key_dir in _DEFAULT_HOST_KEY_DIRS: for file in _DEFAULT_HOST_KEY_FILES: try: cert = read_certificate(Path(host_key_dir, file + '-cert.pub')) except (OSError, KeyImportError): pass else: result.append(cert) for host_key_dir in _DEFAULT_HOST_KEY_DIRS: for file in _DEFAULT_HOST_KEY_FILES: try: key = read_public_key(Path(host_key_dir, file + '.pub')) except (OSError, KeyImportError): pass else: result.append(key) return result def load_certificates(certlist: CertListArg) -> Sequence[SSHCertificate]: """Load certificates This function loads a list of OpenSSH or X.509 certificates. :param certlist: The list of certificates to load. :type certlist: *see* :ref:`SpecifyingCertificates` :returns: A list of :class:`SSHCertificate` objects """ if isinstance(certlist, SSHCertificate): return [certlist] elif isinstance(certlist, (PurePath, str, bytes)): certlist = [certlist] result: List[SSHCertificate] = [] for cert in certlist: if isinstance(cert, (PurePath, str)): certs = read_certificate_list(cert) elif isinstance(cert, bytes): certs = _decode_certificate_list(cert) elif isinstance(cert, SSHCertificate): certs = [cert] else: certs = cert result.extend(certs) return result def load_identities(keylist: IdentityListArg, skip_private: bool = False) -> Sequence[bytes]: """Load public key and certificate identities""" if isinstance(keylist, (bytes, str, PurePath, SSHKey, SSHCertificate)): identities: Sequence[_IdentityArg] = [keylist] else: identities = keylist result = [] for identity in identities: if isinstance(identity, (PurePath, str)): try: pubdata = read_certificate(identity).public_data except KeyImportError: try: pubdata = read_public_key(identity).public_data except KeyImportError: if skip_private: continue raise elif isinstance(identity, (SSHKey, SSHCertificate)): pubdata = identity.public_data else: pubdata = identity result.append(pubdata) return result def load_default_identities() -> Sequence[bytes]: """Return a list of default public key and certificate identities""" result: List[bytes] = [] for file, condition in _DEFAULT_KEY_FILES: if condition: # pragma: no branch try: cert = read_certificate(Path('~', '.ssh', file + '-cert.pub')) except (OSError, KeyImportError): pass else: result.append(cert.public_data) try: key = read_public_key(Path('~', '.ssh', file + '.pub')) except (OSError, KeyImportError): pass else: result.append(key.public_data) return result def load_resident_keys(pin: str, *, application: str = 'ssh:', user: Optional[str] = None, touch_required: bool = True) -> Sequence[SSHKey]: """Load keys resident on attached FIDO2 security keys This function loads keys resident on any FIDO2 security keys currently attached to the system. The user name associated with each key is returned in the key's comment field. :param pin: The PIN to use to access the security keys, defaulting to `None`. :param application: (optional) The application name associated with the keys to load, defaulting to `'ssh:'`. 
:param user: (optional) The user name associated with the keys to load. By default, keys for all users are loaded. :param touch_required: (optional) Whether or not to require the user to touch the security key when authenticating with it, defaulting to `True`. :type application: `str` :type user: `str` :type pin: `str` :type touch_required: `bool` """ flags = SSH_SK_USER_PRESENCE_REQD if touch_required else 0 reserved = b'' try: resident_keys = sk_get_resident(application, user, pin) except ValueError as exc: raise KeyImportError(str(exc)) from None result: List[SSHKey] = [] for sk_alg, name, public_value, key_handle in resident_keys: handler, key_params = _sk_alg_map[sk_alg] key_params += (public_value, application, flags, key_handle, reserved) key = handler.make_private(key_params) key.set_comment(name) result.append(key) return result asyncssh-2.20.0/asyncssh/py.typed000066400000000000000000000000001475467777400170020ustar00rootroot00000000000000asyncssh-2.20.0/asyncssh/rsa.py000066400000000000000000000263221475467777400164610ustar00rootroot00000000000000# Copyright (c) 2013-2023 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation """RSA public key encryption handler""" from typing import Optional, Tuple, Union, cast from .asn1 import ASN1DecodeError, ObjectIdentifier, der_encode, der_decode from .crypto import RSAPrivateKey, RSAPublicKey from .misc import all_ints from .packet import MPInt, String, SSHPacket from .public_key import SSHKey, SSHOpenSSHCertificateV01, KeyExportError from .public_key import register_public_key_alg, register_certificate_alg from .public_key import register_x509_certificate_alg _hash_algs = {b'ssh-rsa': 'sha1', b'rsa-sha2-256': 'sha256', b'rsa-sha2-512': 'sha512', b'ssh-rsa-sha224@ssh.com': 'sha224', b'ssh-rsa-sha256@ssh.com': 'sha256', b'ssh-rsa-sha384@ssh.com': 'sha384', b'ssh-rsa-sha512@ssh.com': 'sha512', b'rsa1024-sha1': 'sha1', b'rsa2048-sha256': 'sha256'} _PrivateKeyArgs = Tuple[int, int, int, int, int, int, int, int] _PrivateKeyConstructArgs = Tuple[int, int, int, int, int, int, int, int, bool] _PublicKeyArgs = Tuple[int, int] _default_skip_rsa_key_validation = False def set_default_skip_rsa_key_validation(skip_validation: bool) -> None: """Set whether to disable RSA key validation in OpenSSL OpenSSL 3.x does additional validation when loading RSA keys as an added security measure. However, the result is that loading a key can take significantly longer than it did before. If all your RSA keys are coming from a trusted source, you can call this function with a value of `True` to default to skipping these checks on RSA keys, reducing the cost back down to what it was in earlier releases. This can also be set on a case by case basis by using the new `unsafe_skip_rsa_key_validation` argument on the functions used to load keys. This will only affect loading keys of type RSA. .. 
note:: The extra cost only applies to loading existing keys, and not to generating new keys. Also, in cases where a key is used repeatedly, it can be loaded once into an `SSHKey` object and reused without having to pay the cost each time. So, this call should not be needed in most applications. If an application does need this, it is strongly recommended that the `unsafe_skip_rsa_key_validation` argument be used rather than using this function to change the default behavior for all load operations. """ # pylint: disable=global-statement global _default_skip_rsa_key_validation _default_skip_rsa_key_validation = skip_validation class RSAKey(SSHKey): """Handler for RSA public key encryption""" _key: Union[RSAPrivateKey, RSAPublicKey] algorithm = b'ssh-rsa' default_x509_hash = 'sha256' pem_name = b'RSA' pkcs8_oid = ObjectIdentifier('1.2.840.113549.1.1.1') sig_algorithms = (b'rsa-sha2-256', b'rsa-sha2-512', b'ssh-rsa-sha224@ssh.com', b'ssh-rsa-sha256@ssh.com', b'ssh-rsa-sha384@ssh.com', b'ssh-rsa-sha512@ssh.com', b'ssh-rsa') cert_sig_algorithms = (b'rsa-sha2-256', b'rsa-sha2-512', b'ssh-rsa') cert_algorithms = tuple(alg + b'-cert-v01@openssh.com' for alg in cert_sig_algorithms) x509_sig_algorithms = (b'rsa2048-sha256', b'ssh-rsa') x509_algorithms = tuple(b'x509v3-' + alg for alg in x509_sig_algorithms) all_sig_algorithms = set(x509_sig_algorithms + sig_algorithms) def __eq__(self, other: object) -> bool: # This isn't protected access - both objects are RSAKey instances # pylint: disable=protected-access if not isinstance(other, RSAKey): return NotImplemented return (self._key.n == other._key.n and self._key.e == other._key.e and self._key.d == other._key.d) def __hash__(self) -> int: return hash((self._key.n, self._key.e, self._key.d, self._key.p, self._key.q)) @classmethod def generate(cls, algorithm: bytes, *, # type: ignore key_size: int = 2048, exponent: int = 65537) -> 'RSAKey': """Generate a new RSA private key""" # pylint: disable=arguments-differ,unused-argument return cls(RSAPrivateKey.generate(key_size, exponent)) @classmethod def make_private(cls, key_params: object) -> SSHKey: """Construct an RSA private key""" n, e, d, p, q, dmp1, dmq1, iqmp, unsafe_skip_rsa_key_validation = \ cast(_PrivateKeyConstructArgs, key_params) if unsafe_skip_rsa_key_validation is None: unsafe_skip_rsa_key_validation = _default_skip_rsa_key_validation return cls(RSAPrivateKey.construct(n, e, d, p, q, dmp1, dmq1, iqmp, unsafe_skip_rsa_key_validation)) @classmethod def make_public(cls, key_params: object) -> SSHKey: """Construct an RSA public key""" n, e = cast(_PublicKeyArgs, key_params) return cls(RSAPublicKey.construct(n, e)) @classmethod def decode_pkcs1_private(cls, key_data: object) -> \ Optional[_PrivateKeyArgs]: """Decode a PKCS#1 format RSA private key""" if (isinstance(key_data, tuple) and all_ints(key_data) and len(key_data) >= 9): return cast(_PrivateKeyArgs, key_data[1:9]) else: return None @classmethod def decode_pkcs1_public(cls, key_data: object) -> \ Optional[_PublicKeyArgs]: """Decode a PKCS#1 format RSA public key""" if (isinstance(key_data, tuple) and all_ints(key_data) and len(key_data) == 2): return cast(_PublicKeyArgs, key_data) else: return None @classmethod def decode_pkcs8_private(cls, alg_params: object, data: bytes) -> Optional[_PrivateKeyArgs]: """Decode a PKCS#8 format RSA private key""" if alg_params is not None: return None try: key_data = der_decode(data) except ASN1DecodeError: return None return cls.decode_pkcs1_private(key_data) @classmethod def decode_pkcs8_public(cls, 
alg_params: object, data: bytes) -> Optional[_PublicKeyArgs]: """Decode a PKCS#8 format RSA public key""" if alg_params is not None: return None try: key_data = der_decode(data) except ASN1DecodeError: return None return cls.decode_pkcs1_public(key_data) @classmethod def decode_ssh_private(cls, packet: SSHPacket) -> _PrivateKeyArgs: """Decode an SSH format RSA private key""" n = packet.get_mpint() e = packet.get_mpint() d = packet.get_mpint() iqmp = packet.get_mpint() p = packet.get_mpint() q = packet.get_mpint() return n, e, d, p, q, d % (p-1), d % (q-1), iqmp @classmethod def decode_ssh_public(cls, packet: SSHPacket) -> _PublicKeyArgs: """Decode an SSH format RSA public key""" e = packet.get_mpint() n = packet.get_mpint() return n, e def encode_pkcs1_private(self) -> object: """Encode a PKCS#1 format RSA private key""" if not self._key.d: raise KeyExportError('Key is not private') return (0, self._key.n, self._key.e, self._key.d, self._key.p, self._key.q, self._key.dmp1, self._key.dmq1, self._key.iqmp) def encode_pkcs1_public(self) -> object: """Encode a PKCS#1 format RSA public key""" return self._key.n, self._key.e def encode_pkcs8_private(self) -> Tuple[object, object]: """Encode a PKCS#8 format RSA private key""" return None, der_encode(self.encode_pkcs1_private()) def encode_pkcs8_public(self) -> Tuple[object, object]: """Encode a PKCS#8 format RSA public key""" return None, der_encode(self.encode_pkcs1_public()) def encode_ssh_private(self) -> bytes: """Encode an SSH format RSA private key""" if not self._key.d: raise KeyExportError('Key is not private') assert self._key.iqmp is not None assert self._key.p is not None assert self._key.q is not None return b''.join((MPInt(self._key.n), MPInt(self._key.e), MPInt(self._key.d), MPInt(self._key.iqmp), MPInt(self._key.p), MPInt(self._key.q))) def encode_ssh_public(self) -> bytes: """Encode an SSH format RSA public key""" return b''.join((MPInt(self._key.e), MPInt(self._key.n))) def encode_agent_cert_private(self) -> bytes: """Encode RSA certificate private key data for agent""" if not self._key.d: raise KeyExportError('Key is not private') assert self._key.iqmp is not None assert self._key.p is not None assert self._key.q is not None return b''.join((MPInt(self._key.d), MPInt(self._key.iqmp), MPInt(self._key.p), MPInt(self._key.q))) def sign_ssh(self, data: bytes, sig_algorithm: bytes) -> bytes: """Compute an SSH-encoded signature of the specified data""" if not self._key.d: raise ValueError('Private key needed for signing') return String(self._key.sign(data, _hash_algs[sig_algorithm])) def verify_ssh(self, data: bytes, sig_algorithm: bytes, packet: SSHPacket) -> bool: """Verify an SSH-encoded signature of the specified data""" sig = packet.get_string() packet.check_end() return self._key.verify(data, sig, _hash_algs[sig_algorithm]) def encrypt(self, data: bytes, algorithm: bytes) -> Optional[bytes]: """Encrypt a block of data with this key""" pub_key = cast(RSAPublicKey, self._key) return pub_key.encrypt(data, _hash_algs[algorithm]) def decrypt(self, data: bytes, algorithm: bytes) -> Optional[bytes]: """Decrypt a block of data with this key""" priv_key = cast(RSAPrivateKey, self._key) return priv_key.decrypt(data, _hash_algs[algorithm]) register_public_key_alg(b'ssh-rsa', RSAKey, True) for _alg in RSAKey.cert_sig_algorithms: register_certificate_alg(1, _alg, _alg + b'-cert-v01@openssh.com', RSAKey, SSHOpenSSHCertificateV01, True) for _alg in RSAKey.x509_algorithms: register_x509_certificate_alg(_alg, True) 
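# --- Illustrative sketch (not part of the library source) ------------------
# A short, hedged example of the key handling documented above: generate an
# RSA key, export it (passphrase-protected PKCS#8 PEM plus an OpenSSH public
# key line), and import it again.  The function name, key size, comment, and
# passphrase below are arbitrary example values, not anything defined by
# AsyncSSH itself.

def _example_rsa_key_round_trip() -> None:
    """Generate, export, and re-import an RSA key (illustrative only)"""

    import asyncssh

    # Generate a new RSA private key (see generate_private_key() above)
    key = asyncssh.generate_private_key('ssh-rsa', comment='example key',
                                        key_size=3072)

    # Export the private key as passphrase-protected PKCS#8 PEM and the
    # public key in OpenSSH one-line format
    priv_pem = key.export_private_key('pkcs8-pem', passphrase='example pass')
    pub_line = key.export_public_key('openssh')

    # Re-import both.  Skipping RSA key validation is reasonable here only
    # because the key was just generated locally by a trusted source.
    key2 = asyncssh.import_private_key(priv_pem, passphrase='example pass',
                                       unsafe_skip_rsa_key_validation=True)
    pubkey = asyncssh.import_public_key(pub_line)

    # The re-imported private key should match the exported public key
    assert key2.convert_to_public().public_data == pubkey.public_data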
asyncssh-2.20.0/asyncssh/saslprep.py000066400000000000000000000071301475467777400175210ustar00rootroot00000000000000# Copyright (c) 2013-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation """SASLprep implementation This module implements the stringprep algorithm defined in RFC 3454 and the SASLprep profile of stringprep defined in RFC 4013. This profile is used to normalize usernames and passwords sent in the SSH protocol. """ # The stringprep module should not be flagged as deprecated # pylint: disable=deprecated-module import stringprep # pylint: enable=deprecated-module import unicodedata from typing import Callable, Optional, Sequence from typing_extensions import Literal class SASLPrepError(ValueError): """Invalid data provided to saslprep""" def _check_bidi(s: str) -> None: """Enforce bidirectional character check from RFC 3454 (stringprep)""" r_and_al_cat = False l_cat = False for c in s: if not r_and_al_cat and stringprep.in_table_d1(c): r_and_al_cat = True if not l_cat and stringprep.in_table_d2(c): l_cat = True if r_and_al_cat and l_cat: raise SASLPrepError('Both RandALCat and LCat characters present') if r_and_al_cat and not (stringprep.in_table_d1(s[0]) and stringprep.in_table_d1(s[-1])): raise SASLPrepError('RandALCat character not at both start and end') def _stringprep(s: str, check_unassigned: bool, mapping: Optional[Callable[[str], str]], normalization: Literal['NFC', 'NFD', 'NFKC', 'NFKD'], prohibited: Sequence[Callable[[str], bool]], bidi: bool) -> str: """Implement a stringprep profile as defined in RFC 3454""" if check_unassigned: # pragma: no branch for c in s: if stringprep.in_table_a1(c): raise SASLPrepError(f'Unassigned character: {c!r}') if mapping: # pragma: no branch s = mapping(s) if normalization: # pragma: no branch s = unicodedata.normalize(normalization, s) if prohibited: # pragma: no branch for c in s: for lookup in prohibited: if lookup(c): raise SASLPrepError(f'Prohibited character: {c!r}') if bidi: # pragma: no branch _check_bidi(s) return s def _map_saslprep(s: str) -> str: """Map stringprep table B.1 to nothing and C.1.2 to ASCII space""" r = [] for c in s: if stringprep.in_table_c12(c): r.append(' ') elif not stringprep.in_table_b1(c): r.append(c) return ''.join(r) def saslprep(s: str) -> str: """Implement SASLprep profile defined in RFC 4013""" prohibited = (stringprep.in_table_c12, stringprep.in_table_c21_c22, stringprep.in_table_c3, stringprep.in_table_c4, stringprep.in_table_c5, stringprep.in_table_c6, stringprep.in_table_c7, stringprep.in_table_c8, stringprep.in_table_c9) return _stringprep(s, True, _map_saslprep, 'NFKC', prohibited, True) asyncssh-2.20.0/asyncssh/scp.py000066400000000000000000001131111475467777400164520ustar00rootroot00000000000000# Copyright (c) 2017-2024 by Ron Frederick and others. 
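# --- Illustrative sketch (not part of the library source) ------------------
# A hedged example of the SASLprep profile implemented in saslprep.py above,
# using sample strings drawn from RFC 4013, Section 3.  The helper function
# name is made up for this example.

def _example_saslprep() -> None:
    """Show SASLprep normalization and error behavior (illustrative only)"""

    from asyncssh.saslprep import SASLPrepError, saslprep

    # SOFT HYPHEN (stringprep table B.1) is mapped to nothing
    assert saslprep('I\u00adX') == 'IX'

    # ROMAN NUMERAL NINE is normalized to 'IX' by NFKC
    assert saslprep('\u2168') == 'IX'

    # Prohibited characters (here, an ASCII control character) are rejected
    try:
        saslprep('bad\u0007input')
    except SASLPrepError:
        pass
    else:
        raise AssertionError('prohibited character was not rejected')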
# # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation # Jonathan Slenders - proposed changes to allow SFTP server callbacks # to be coroutines """SCP handlers""" import argparse import asyncio import posixpath from pathlib import PurePath import shlex import string import sys from types import TracebackType from typing import TYPE_CHECKING, AsyncIterator, List, NoReturn, Optional from typing import Sequence, Tuple, Type, Union, cast from typing_extensions import Protocol, Self from .constants import DEFAULT_LANG from .constants import FILEXFER_TYPE_REGULAR, FILEXFER_TYPE_DIRECTORY from .logging import SSHLogger from .misc import BytesOrStr, FilePath, HostPort, MaybeAwait from .misc import async_context_manager, plural from .sftp import SFTPAttrs, SFTPGlob, SFTPName, SFTPServer, SFTPServerFS from .sftp import SFTPFileProtocol, SFTPError, SFTPFailure, SFTPBadMessage from .sftp import SFTPConnectionLost, SFTPErrorHandler, SFTPProgressHandler from .sftp import local_fs if TYPE_CHECKING: # pylint: disable=cyclic-import from .channel import SSHServerChannel from .connection import SSHClientConnection from .stream import SSHReader, SSHWriter _SCPConn = Union[None, bytes, str, HostPort, 'SSHClientConnection'] _SCPPath = Union[bytes, FilePath] _SCPConnPath = Union[Tuple[_SCPConn, _SCPPath], _SCPConn, _SCPPath] _SCP_BLOCK_SIZE = 256*1024 # 256 KiB class _SCPFSProtocol(Protocol): """Protocol for accessing a filesystem during an SCP copy""" @staticmethod def basename(path: bytes) -> bytes: """Return the final component of a POSIX-style path""" async def stat(self, path: bytes) -> 'SFTPAttrs': """Get attributes of a file or directory, following symlinks""" async def setstat(self, path: bytes, attrs: 'SFTPAttrs') -> None: """Set attributes of a file or directory""" async def exists(self, path: bytes) -> bool: """Return if a path exists""" async def isdir(self, path: bytes) -> bool: """Return if the path refers to a directory""" def scandir(self, path: bytes) -> AsyncIterator[SFTPName]: """Read the names and attributes of files in a directory""" async def mkdir(self, path: bytes) -> None: """Create a directory""" @async_context_manager async def open(self, path: bytes, mode: str) -> SFTPFileProtocol: """Open a file""" def _scp_error(exc_class: Type[Exception], reason: BytesOrStr, path: Optional[bytes] = None, fatal: bool = False, suppress_send: bool = False, lang: str = DEFAULT_LANG) -> Exception: """Construct SCP version of SFTPError exception""" if isinstance(reason, bytes): reason = reason.decode('utf-8', errors='replace') if path: reason = reason + ': ' + path.decode('utf-8', errors='replace') exc = exc_class(reason, lang) setattr(exc, 'fatal', fatal) setattr(exc, 'suppress_send', suppress_send) return exc def _parse_cd_args(args: bytes) -> Tuple[int, int, bytes]: """Parse arguments to an SCP copy or dir request""" try: permissions, size, name = args.split(None, 2) return int(permissions, 8), int(size), name except ValueError: 
raise _scp_error(SFTPBadMessage, 'Invalid copy or dir request') from None def _parse_t_args(args: bytes) -> Tuple[int, int]: """Parse argument to an SCP time request""" try: mtime, _, atime, _ = args.split() return int(atime), int(mtime) except ValueError: raise _scp_error(SFTPBadMessage, 'Invalid time request') from None async def _parse_path(path: _SCPConnPath, **kwargs) -> \ Tuple[Optional['SSHClientConnection'], _SCPPath, bool]: """Convert an SCP path into an SSHClientConnection and path""" # pylint: disable=cyclic-import,import-outside-toplevel from . import connect conn: _SCPConn if isinstance(path, tuple): conn, path = cast(Tuple[_SCPConn, _SCPPath], path) elif isinstance(path, str) and sys.platform == 'win32' and \ path[:1] in string.ascii_letters and \ path[1:2] == ':': # pragma: no cover (win32) conn = None elif isinstance(path, str) and ':' in path: conn, path = path.split(':', 1) elif isinstance(path, bytes) and b':' in path: conn, path = path.split(b':', 1) conn = conn.decode('utf-8') elif isinstance(path, (bytes, str, PurePath)): conn = None else: conn = path path = b'.' if isinstance(conn, str): close_conn = True conn = await connect(conn, **kwargs) elif isinstance(conn, tuple): close_conn = True conn = await connect(*conn, **kwargs) else: close_conn = False return (cast(Optional['SSHClientConnection'], conn), cast(_SCPPath, path), close_conn) async def _start_remote(conn: 'SSHClientConnection', source: bool, must_be_dir: bool, preserve: bool, recurse: bool, path: _SCPPath) -> \ Tuple['SSHReader[bytes]', 'SSHWriter[bytes]']: """Start remote SCP server""" if isinstance(path, PurePath): path = str(path) if isinstance(path, str): path = path.encode('utf-8') command = (b'scp ' + (b'-f ' if source else b'-t ') + (b'-d ' if must_be_dir else b'') + (b'-p ' if preserve else b'') + (b'-r ' if recurse else b'') + path) conn.logger.get_child('sftp').info('Starting remote SCP, args: %s', command[4:]) writer, reader, _ = await conn.open_session(command, encoding=None) return reader, writer class _SCPArgs(argparse.Namespace): """SCP command line arguments""" path: str source: bool must_be_dir: bool preserve: bool recurse: bool class _SCPArgParser(argparse.ArgumentParser): """A parser for SCP arguments""" def __init__(self) -> None: super().__init__(add_help=False) group = self.add_mutually_exclusive_group(required=True) group.add_argument('-f', dest='source', action='store_true') group.add_argument('-t', dest='source', action='store_false') self.add_argument('-d', dest='must_be_dir', action='store_true') self.add_argument('-p', dest='preserve', action='store_true') self.add_argument('-r', dest='recurse', action='store_true') self.add_argument('-v', dest='verbose', action='store_true') self.add_argument('path') def error(self, message: str) -> NoReturn: raise ValueError(message) def parse(self, command: str) -> _SCPArgs: """Parse an SCP command""" return self.parse_args(shlex.split(command)[1:], namespace=_SCPArgs()) class _SCPHandler: """SCP handler""" def __init__(self, reader: 'SSHReader[bytes]', writer: 'SSHWriter[bytes]', error_handler: SFTPErrorHandler = None, server: bool = False): self._reader = reader self._writer = writer self._error_handler = error_handler self._server = server self._logger = reader.logger.get_child('sftp') async def __aenter__(self) -> Self: # pragma: no cover """Allow _SCPHandler to be used as an async context manager""" return self async def __aexit__(self, _exc_type: Optional[Type[BaseException]], _exc_value: Optional[BaseException], _traceback: 
Optional[TracebackType]) -> \ bool: # pragma: no cover """Wait for file close when used as an async context manager""" await self.close() return False @property def logger(self) -> SSHLogger: """A logger associated with this SCP handler""" return self._logger async def await_response(self) -> Optional[Exception]: """Wait for an SCP response""" result = await self._reader.read(1) if result != b'\0': reason = await self._reader.readline() if not result or not reason.endswith(b'\n'): raise _scp_error(SFTPConnectionLost, 'Connection lost', fatal=True, suppress_send=True) if result not in b'\x01\x02': reason = result + reason return _scp_error(SFTPFailure, reason[:-1], fatal=result != b'\x01', suppress_send=True) self.logger.debug1('Received SCP OK') return None def send_request(self, *args: bytes) -> None: """Send an SCP request""" request = b''.join(args) self.logger.debug1('Sending SCP request: %s', request) self._writer.write(request + b'\n') async def make_request(self, *args: bytes) -> None: """Send an SCP request and wait for a response""" self.send_request(*args) exc = await self.await_response() if exc: raise exc async def send_data(self, data: bytes) -> None: """Send SCP file data""" self.logger.debug1('Sending %s', plural(len(data), 'SCP data byte')) self._writer.write(data) await self._writer.drain() await asyncio.sleep(0) def send_ok(self) -> None: """Send an SCP OK response""" self.logger.debug1('Sending SCP OK') self._writer.write(b'\0') def send_error(self, exc: Exception) -> None: """Send an SCP error response""" if isinstance(exc, SFTPError): reason = exc.reason.encode('utf-8') elif isinstance(exc, OSError): # pragma: no branch (win32) reason = exc.strerror.encode('utf-8') filename = cast(BytesOrStr, exc.filename) if filename: if isinstance(filename, str): # pragma: no cover (win32) filename = filename.encode('utf-8') reason += b': ' + filename else: # pragma: no cover (win32) reason = str(exc).encode('utf-8') fatal = cast(bool, getattr(exc, 'fatal', False)) self.logger.debug1('Sending SCP %serror: %s', 'fatal ' if fatal else '', reason) self._writer.write((b'\x02' if fatal else b'\x01') + b'scp: ' + reason + b'\n') async def recv_request(self) -> Tuple[Optional[bytes], Optional[bytes]]: """Receive SCP request""" request = await self._reader.readline() if not request: return None, None action, args = request[:1], request[1:-1] if action not in b'\x01\x02': self.logger.debug1('Received SCP request: %s%s', action, args) else: self.logger.debug1('Received SCP %serror: %s', 'fatal ' if action != b'\x01' else '', args) return action, args async def recv_data(self, n: int) -> bytes: """Receive SCP file data""" data = await self._reader.read(n) self.logger.debug1('Received %s', plural(len(data), 'SCP data byte')) return data def handle_error(self, exc: Exception) -> None: """Handle an SCP error""" if isinstance(exc, BrokenPipeError): exc = _scp_error(SFTPConnectionLost, 'Connection lost', fatal=True, suppress_send=True) if not getattr(exc, 'suppress_send', False): self.send_error(exc) self.logger.debug1('Handling SCP error: %s', str(exc)) if self._error_handler and not getattr(exc, 'fatal', False): self._error_handler(exc) elif not self._server: raise exc async def close(self, cancelled: bool = False) -> None: """Close an SCP session""" self.logger.info('Stopping remote SCP') if cancelled: self._writer.channel.abort() else: if self._server: cast('SSHServerChannel', self._writer.channel).exit(0) else: self._writer.close() await self._writer.wait_closed() class _SCPSource(_SCPHandler): 
"""SCP handler for sending files""" def __init__(self, fs: _SCPFSProtocol, reader: 'SSHReader[bytes]', writer: 'SSHWriter[bytes]', preserve: bool, recurse: bool, block_size: int = _SCP_BLOCK_SIZE, progress_handler: SFTPProgressHandler = None, error_handler: SFTPErrorHandler = None, server: bool = False): super().__init__(reader, writer, error_handler, server) self._fs = fs self._preserve = preserve self._recurse = recurse self._block_size = block_size self._progress_handler = progress_handler async def _make_cd_request(self, action: bytes, attrs: SFTPAttrs, size: int, path: bytes) -> None: """Make an SCP copy or dir request""" assert attrs.permissions is not None args = f'{attrs.permissions & 0o7777:04o} {size} ' await self.make_request(action, args.encode('ascii'), self._fs.basename(path)) async def _make_t_request(self, attrs: SFTPAttrs) -> None: """Make an SCP time request""" self.logger.info(' Preserving attrs: %s', SFTPAttrs(atime=attrs.atime, mtime=attrs.mtime)) assert attrs.mtime is not None assert attrs.atime is not None args = f'{attrs.mtime} 0 {attrs.atime} 0' await self.make_request(b'T', args.encode('ascii')) async def _send_file(self, srcpath: bytes, dstpath: bytes, attrs: SFTPAttrs) -> None: """Send a file over SCP""" assert attrs.size is not None file_obj = await self._fs.open(srcpath, 'rb') size = attrs.size local_exc = None offset = 0 self.logger.info(' Sending file %s, size %d', srcpath, size) try: await self._make_cd_request(b'C', attrs, size, srcpath) if self._progress_handler and size == 0: self._progress_handler(srcpath, dstpath, 0, 0) while offset < size: blocklen = min(size - offset, self._block_size) if local_exc: data = blocklen * b'\0' else: try: data = cast(bytes, await file_obj.read(blocklen, offset)) if not data: raise _scp_error(SFTPFailure, 'Unexpected EOF') except (OSError, SFTPError) as exc: local_exc = exc await self.send_data(data) offset += len(data) if self._progress_handler: self._progress_handler(srcpath, dstpath, offset, size) finally: await file_obj.close() if local_exc: self.send_error(local_exc) setattr(local_exc, 'suppress_send', True) else: self.send_ok() remote_exc = await self.await_response() final_exc = remote_exc or local_exc if final_exc: raise final_exc async def _send_dir(self, srcpath: bytes, dstpath: bytes, attrs: SFTPAttrs) -> None: """Send directory over SCP""" self.logger.info(' Starting send of directory %s', srcpath) await self._make_cd_request(b'D', attrs, 0, srcpath) async for entry in self._fs.scandir(srcpath): name = cast(bytes, entry.filename) if name in (b'.', b'..'): continue await self._send_files(posixpath.join(srcpath, name), posixpath.join(dstpath, name), entry.attrs) await self.make_request(b'E') self.logger.info(' Finished send of directory %s', srcpath) async def _send_files(self, srcpath: bytes, dstpath: bytes, attrs: SFTPAttrs) -> None: """Send files via SCP""" try: if self._preserve: await self._make_t_request(attrs) if self._recurse and attrs.type == FILEXFER_TYPE_DIRECTORY: await self._send_dir(srcpath, dstpath, attrs) elif attrs.type == FILEXFER_TYPE_REGULAR: await self._send_file(srcpath, dstpath, attrs) else: raise _scp_error(SFTPFailure, 'Not a regular file', srcpath) except (OSError, SFTPError, ValueError) as exc: self.handle_error(exc) async def run(self, srcpath: _SCPPath) -> None: """Start SCP transfer""" cancelled = False try: if isinstance(srcpath, PurePath): srcpath = str(srcpath) if isinstance(srcpath, str): srcpath = srcpath.encode('utf-8') exc = await self.await_response() if exc: raise exc for 
name in await SFTPGlob(self._fs).match(srcpath): await self._send_files(cast(bytes, name.filename), b'', name.attrs) except asyncio.CancelledError: cancelled = True except (OSError, SFTPError) as exc: self.handle_error(exc) finally: await self.close(cancelled) class _SCPSink(_SCPHandler): """SCP handler for receiving files""" def __init__(self, fs: _SCPFSProtocol, reader: 'SSHReader[bytes]', writer: 'SSHWriter[bytes]', must_be_dir: bool, preserve: bool, recurse: bool, block_size: int = _SCP_BLOCK_SIZE, progress_handler: SFTPProgressHandler = None, error_handler: SFTPErrorHandler = None, server: bool = False): super().__init__(reader, writer, error_handler, server) self._fs = fs self._must_be_dir = must_be_dir self._preserve = preserve self._recurse = recurse self._block_size = block_size self._progress_handler = progress_handler async def _recv_file(self, srcpath: bytes, dstpath: bytes, size: int) -> None: """Receive a file via SCP""" file_obj = await self._fs.open(dstpath, 'wb') local_exc = None offset = 0 self.logger.info(' Receiving file %s, size %d', dstpath, size) try: self.send_ok() if self._progress_handler and size == 0: self._progress_handler(srcpath, dstpath, 0, 0) while offset < size: blocklen = min(size - offset, self._block_size) data = await self.recv_data(blocklen) if not data: raise _scp_error(SFTPConnectionLost, 'Connection lost', fatal=True, suppress_send=True) if not local_exc: try: await file_obj.write(data, offset) except (OSError, SFTPError) as exc: local_exc = exc offset += len(data) if self._progress_handler: self._progress_handler(srcpath, dstpath, offset, size) finally: await file_obj.close() remote_exc = await self.await_response() if local_exc: self.send_error(local_exc) setattr(local_exc, 'suppress_send',True) else: self.send_ok() final_exc = remote_exc or local_exc if final_exc: raise final_exc async def _recv_dir(self, srcpath: bytes, dstpath: bytes) -> None: """Receive a directory over SCP""" if not self._recurse: raise _scp_error(SFTPBadMessage, 'Directory received without recurse') self.logger.info(' Starting receive of directory %s', dstpath) if await self._fs.exists(dstpath): if not await self._fs.isdir(dstpath): raise _scp_error(SFTPFailure, 'Not a directory', dstpath) else: await self._fs.mkdir(dstpath) await self._recv_files(srcpath, dstpath) self.logger.info(' Finished receive of directory %s', dstpath) async def _recv_files(self, srcpath: bytes, dstpath: bytes) -> None: """Receive files over SCP""" self.send_ok() attrs = SFTPAttrs() while True: action, args = await self.recv_request() if not action: break assert args is not None try: if action in b'\x01\x02': raise _scp_error(SFTPFailure, args, fatal=action != b'\x01', suppress_send=True) elif action == b'T': if self._preserve: attrs.atime, attrs.mtime = _parse_t_args(args) self.send_ok() elif action == b'E': self.send_ok() break elif action in b'CD': try: attrs.permissions, size, name = _parse_cd_args(args) new_srcpath = posixpath.join(srcpath, name) if await self._fs.isdir(dstpath): new_dstpath = posixpath.join(dstpath, name) else: new_dstpath = dstpath if action == b'D': await self._recv_dir(new_srcpath, new_dstpath) else: await self._recv_file(new_srcpath, new_dstpath, size) if self._preserve: self.logger.info(' Preserving attrs: %s', attrs) await self._fs.setstat(new_dstpath, attrs) finally: attrs = SFTPAttrs() else: raise _scp_error(SFTPBadMessage, 'Unknown request') except (OSError, SFTPError) as exc: self.handle_error(exc) async def run(self, dstpath: _SCPPath) -> None: """Start SCP file 
receive""" cancelled = False try: if isinstance(dstpath, PurePath): dstpath = str(dstpath) if isinstance(dstpath, str): dstpath = dstpath.encode('utf-8') if self._must_be_dir and not await self._fs.isdir(dstpath): self.handle_error(_scp_error(SFTPFailure, 'Not a directory', dstpath)) else: await self._recv_files(b'', dstpath) except asyncio.CancelledError: cancelled = True except (OSError, SFTPError, ValueError) as exc: self.handle_error(exc) finally: await self.close(cancelled) class _SCPCopier: """SCP handler for remote-to-remote copies""" def __init__(self, src_reader: 'SSHReader[bytes]', src_writer: 'SSHWriter[bytes]', dst_reader: 'SSHReader[bytes]', dst_writer: 'SSHWriter[bytes]', block_size: int = _SCP_BLOCK_SIZE, progress_handler: SFTPProgressHandler = None, error_handler: SFTPErrorHandler = None): self._source = _SCPHandler(src_reader, src_writer) self._sink = _SCPHandler(dst_reader, dst_writer) self._logger = self._source.logger self._block_size = block_size self._progress_handler = progress_handler self._error_handler = error_handler @property def logger(self) -> SSHLogger: """A logger associated with this SCP handler""" return self._logger def _handle_error(self, exc: Exception) -> None: """Handle an SCP error""" if isinstance(exc, BrokenPipeError): exc = _scp_error(SFTPConnectionLost, 'Connection lost', fatal=True, suppress_send=True) self.logger.debug1('Handling SCP error: %s', str(exc)) if self._error_handler and not getattr(exc, 'fatal', False): self._error_handler(exc) else: raise exc async def _forward_response(self, src: _SCPHandler, dst: _SCPHandler) -> Optional[Exception]: """Forward an SCP response between two remote SCP servers""" # pylint: disable=no-self-use try: exc = await src.await_response() if exc: dst.send_error(exc) return exc else: dst.send_ok() return None except OSError as exc: return exc async def _copy_file(self, path: bytes, size: int) -> None: """Copy a file from one remote SCP server to another""" self.logger.info(' Copying file %s, size %d', path, size) offset = 0 if self._progress_handler and size == 0: self._progress_handler(path, path, 0, 0) while offset < size: blocklen = min(size - offset, self._block_size) data = await self._source.recv_data(blocklen) if not data: raise _scp_error(SFTPConnectionLost, 'Connection lost', fatal=True, suppress_send=True) await self._sink.send_data(data) offset += len(data) if self._progress_handler: self._progress_handler(path, path, offset, size) source_exc = await self._forward_response(self._source, self._sink) sink_exc = await self._forward_response(self._sink, self._source) exc = sink_exc or source_exc if exc: self._handle_error(exc) async def _copy_files(self) -> None: """Copy files from one SCP server to another""" exc = await self._forward_response(self._sink, self._source) if exc: self._handle_error(exc) pathlist: List[bytes] = [] attrlist: List[SFTPAttrs] = [] attrs = SFTPAttrs() while True: action, args = await self._source.recv_request() if not action: break assert args is not None self._sink.send_request(action, args) if action in b'\x01\x02': exc = _scp_error(SFTPFailure, args, fatal=action != b'\x01') self._handle_error(exc) continue exc = await self._forward_response(self._sink, self._source) if exc: self._handle_error(exc) continue if action in b'CD': try: attrs.permissions, size, name = _parse_cd_args(args) if action == b'C': path = b'/'.join(pathlist + [name]) await self._copy_file(path, size) self.logger.info(' Preserving attrs: %s', attrs) else: pathlist.append(name) attrlist.append(attrs) 
self.logger.info(' Starting copy of directory %s', b'/'.join(pathlist)) finally: attrs = SFTPAttrs() elif action == b'E': if pathlist: self.logger.info(' Finished copy of directory %s', b'/'.join(pathlist)) pathlist.pop() attrs = attrlist.pop() self.logger.info(' Preserving attrs: %s', attrs) else: break elif action == b'T': attrs.atime, attrs.mtime = _parse_t_args(args) else: raise _scp_error(SFTPBadMessage, 'Unknown SCP action') async def run(self) -> None: """Start SCP remote-to-remote transfer""" cancelled = False try: await self._copy_files() except asyncio.CancelledError: cancelled = True except (OSError, SFTPError) as exc: self._handle_error(exc) finally: await self._source.close(cancelled) await self._sink.close(cancelled) async def scp(srcpaths: Union[_SCPConnPath, Sequence[_SCPConnPath]], dstpath: _SCPConnPath = None, *, preserve: bool = False, recurse: bool = False, block_size: int = _SCP_BLOCK_SIZE, progress_handler: SFTPProgressHandler = None, error_handler: SFTPErrorHandler = None, **kwargs) -> None: """Copy files using SCP This function is a coroutine which copies one or more files or directories using the SCP protocol. Source and destination paths can be `str` or `bytes` values to reference local files or can be a tuple of the form `(conn, path)` where `conn` is an open :class:`SSHClientConnection` to reference files and directories on a remote system. For convenience, a host name or tuple of the form `(host, port)` can be provided in place of the :class:`SSHClientConnection` to request that a new SSH connection be opened to a host using default connect arguments. A `str` or `bytes` value of the form `'host:path'` may also be used in place of the `(conn, path)` tuple to make a new connection to the requested host on the default SSH port. Either a single source path or a sequence of source paths can be provided, and each path can contain '*' and '?' wildcard characters which can be used to match multiple source files or directories. When copying a single file or directory, the destination path can be either the full path to copy data into or the path to an existing directory where the data should be placed. In the latter case, the base file name from the source path will be used as the destination name. When copying multiple files, the destination path must refer to a directory. If it doesn't already exist, a directory will be created with that name. If the destination path is an :class:`SSHClientConnection` without a path or the path provided is empty, files are copied into the default destination working directory. If preserve is `True`, the access and modification times and permissions of the original files and directories are set on the copied files. However, due to the timing of when this information is sent, the preserved access time will be what was set on the source file before the copy begins. So, the access time on the source file will no longer match the destination after the transfer completes. If recurse is `True` and the source path points at a directory, the entire subtree under that directory is copied. Symbolic links found on the source will have the contents of their target copied rather than creating a destination symbolic link. When using this option during a recursive copy, one needs to watch out for links that result in loops. SCP does not provide a mechanism for preserving links. If you need this, consider using SFTP instead. The block_size value controls the size of read and write operations issued to copy the files. It defaults to 256 KB.
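As a rough usage sketch (the host name, file names, and local paths below are placeholders rather than anything defined by this module), a few typical calls might look like::

    import asyncio

    import asyncssh

    async def copy_examples() -> None:
        # Copy a local file to a remote host named by a 'host:path' string,
        # preserving times and permissions
        await asyncssh.scp('example.txt', 'host.example.com:/tmp', preserve=True)

        # Recursively copy a remote directory to a local path, reusing an
        # already-open SSHClientConnection via a (conn, path) tuple
        async with asyncssh.connect('host.example.com') as conn:
            await asyncssh.scp((conn, '/var/log/myapp'), 'logs', recurse=True)

    asyncio.run(copy_examples())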
If progress_handler is specified, it will be called after each block of a file is successfully copied. The arguments passed to this handler will be the relative path of the file being copied, bytes copied so far, and total bytes in the file being copied. If multiple source paths are provided or recurse is set to `True`, the progress_handler will be called consecutively on each file being copied. If error_handler is specified and an error occurs during the copy, this handler will be called with the exception instead of it being raised. This is intended to primarily be used when multiple source paths are provided or when recurse is set to `True`, to allow error information to be collected without aborting the copy of the remaining files. The error handler can raise an exception if it wants the copy to completely stop. Otherwise, after an error, the copy will continue starting with the next file. If any other keyword arguments are specified, they will be passed to the AsyncSSH connect() call when attempting to open any new SSH connections needed to perform the file transfer. :param srcpaths: The paths of the source files or directories to copy :param dstpath: (optional) The path of the destination file or directory to copy into :param preserve: (optional) Whether or not to preserve the original file attributes :param recurse: (optional) Whether or not to recursively copy directories :param block_size: (optional) The block size to use for file reads and writes :param progress_handler: (optional) The function to call to report copy progress :param error_handler: (optional) The function to call when an error occurs :type preserve: `bool` :type recurse: `bool` :type block_size: `int` :type progress_handler: `callable` :type error_handler: `callable` :raises: | :exc:`OSError` if a local file I/O error occurs | :exc:`SFTPError` if the server returns an error | :exc:`ValueError` if both source and destination are local """ if (isinstance(srcpaths, (bytes, str, PurePath)) or (isinstance(srcpaths, tuple) and len(srcpaths) == 2)): srcpaths = [srcpaths] # type: ignore srcpaths: Sequence[_SCPConnPath] must_be_dir = len(srcpaths) > 1 dstconn, dstpath, close_dst = await _parse_path(dstpath, **kwargs) try: for srcpath in srcpaths: srcconn, srcpath, close_src = await _parse_path(srcpath, **kwargs) try: if srcconn and dstconn: src_reader, src_writer = await _start_remote( srcconn, True, must_be_dir, preserve, recurse, srcpath) dst_reader, dst_writer = await _start_remote( dstconn, False, must_be_dir, preserve, recurse, dstpath) copier = _SCPCopier(src_reader, src_writer, dst_reader, dst_writer, block_size, progress_handler, error_handler) await copier.run() elif srcconn: reader, writer = await _start_remote( srcconn, True, must_be_dir, preserve, recurse, srcpath) sink = _SCPSink(local_fs, reader, writer, must_be_dir, preserve, recurse, block_size, progress_handler, error_handler) await sink.run(dstpath) elif dstconn: reader, writer = await _start_remote( dstconn, False, must_be_dir, preserve, recurse, dstpath) source = _SCPSource(local_fs, reader, writer, preserve, recurse, block_size, progress_handler, error_handler) await source.run(srcpath) else: raise ValueError('Local copy not supported') finally: if close_src: assert srcconn is not None srcconn.close() await srcconn.wait_closed() finally: if close_dst: assert dstconn is not None dstconn.close() await dstconn.wait_closed() def run_scp_server(sftp_server: SFTPServer, command: str, stdin: 'SSHReader[bytes]', stdout: 'SSHWriter[bytes]', stderr: 
'SSHWriter[bytes]') -> MaybeAwait[None]: """Return a handler for an SCP server session""" async def _run_handler() -> None: """Run an SCP server to handle this request""" try: await handler.run(args.path) finally: sftp_server.exit() try: args = _SCPArgParser().parse(command) except ValueError as exc: stdin.logger.info('Error starting SCP server: %s', str(exc)) stderr.write(b'scp: ' + str(exc).encode('utf-8') + b'\n') cast('SSHServerChannel', stderr.channel).exit(1) return None stdin.logger.info('Starting SCP server, args: %s', command[4:].strip()) fs = SFTPServerFS(sftp_server) handler: Union[_SCPSource, _SCPSink] if args.source: handler = _SCPSource(fs, stdin, stdout, args.preserve, args.recurse, error_handler=False, server=True) else: handler = _SCPSink(fs, stdin, stdout, args.must_be_dir, args.preserve, args.recurse, error_handler=False, server=True) return _run_handler() asyncssh-2.20.0/asyncssh/server.py000066400000000000000000001334701475467777400172050ustar00rootroot00000000000000# Copyright (c) 2013-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation """SSH server protocol handler""" from typing import TYPE_CHECKING, Optional, Tuple, Union from .auth import KbdIntChallenge, KbdIntResponse from .listener import SSHListener from .misc import MaybeAwait from .public_key import SSHKey from .stream import SSHSocketSessionFactory, SSHServerSessionFactory if TYPE_CHECKING: # pylint: disable=cyclic-import from .connection import SSHServerConnection, SSHAcceptHandler from .channel import SSHServerChannel, SSHTCPChannel, SSHUNIXChannel from .channel import SSHTunTapChannel from .session import SSHServerSession, SSHTCPSession, SSHUNIXSession from .session import SSHTunTapSession _NewSession = Union[bool, 'SSHServerSession', SSHServerSessionFactory, Tuple['SSHServerChannel', 'SSHServerSession'], Tuple['SSHServerChannel', SSHServerSessionFactory]] _NewTCPSession = Union[bool, 'SSHTCPSession', SSHSocketSessionFactory, Tuple['SSHTCPChannel', 'SSHTCPSession'], Tuple['SSHTCPChannel', SSHSocketSessionFactory]] _NewUNIXSession = Union[bool, 'SSHUNIXSession', SSHSocketSessionFactory, Tuple['SSHUNIXChannel', 'SSHUNIXSession'], Tuple['SSHUNIXChannel', SSHSocketSessionFactory]] _NewTunTapSession = Union[bool, 'SSHTunTapSession', SSHSocketSessionFactory, Tuple['SSHTunTapChannel', 'SSHTunTapSession'], Tuple['SSHTunTapChannel', SSHSocketSessionFactory]] _NewListener = Union[bool, 'SSHAcceptHandler', SSHListener] class SSHServer: """SSH server protocol handler Applications may subclass this when implementing an SSH server to provide custom authentication and request handlers. Whenever a new SSH server connection is accepted, a corresponding SSHServer object is created and the method :meth:`connection_made` is called, passing in the :class:`SSHServerConnection` object. 
When the connection is closed, the method :meth:`connection_lost` is called with an exception representing the reason for the disconnect, or `None` if the connection was closed cleanly. The method :meth:`begin_auth` can be overridden to decide whether or not authentication is required, and additional callbacks are provided for each form of authentication in cases where authentication information is not provided in the call to :func:`create_server`. In addition, the methods :meth:`session_requested`, :meth:`connection_requested`, :meth:`server_requested`, :meth:`unix_connection_requested`, or :meth:`unix_server_requested` can be overridden to handle requests to open sessions or direct connections or set up listeners for forwarded connections. .. note:: The authentication callbacks described here can be defined as coroutines. However, they may be cancelled if they are running when the SSH connection is closed by the client. If they attempt to catch the CancelledError exception to perform cleanup, they should make sure to re-raise it to allow AsyncSSH to finish its own cleanup. """ # pylint: disable=no-self-use,unused-argument def connection_made(self, conn: 'SSHServerConnection') -> None: """Called when a connection is made This method is called when a new TCP connection is accepted. The `conn` parameter should be stored if needed for later use. :param conn: The connection which was successfully opened :type conn: :class:`SSHServerConnection` """ def connection_lost(self, exc: Optional[Exception]) -> None: """Called when a connection is lost or closed This method is called when a connection is closed. If the connection is shut down cleanly, *exc* will be `None`. Otherwise, it will be an exception explaining the reason for the disconnect. """ def debug_msg_received(self, msg: str, lang: str, always_display: bool) -> None: """A debug message was received on this connection This method is called when the other end of the connection sends a debug message. Applications should implement this method if they wish to process these debug messages. :param msg: The debug message sent :param lang: The language the message is in :param always_display: Whether or not to display the message :type msg: `str` :type lang: `str` :type always_display: `bool` """ def begin_auth(self, username: str) -> MaybeAwait[bool]: """Authentication has been requested by the client This method will be called when authentication is attempted for the specified user. Applications should use this method to prepare whatever state they need to complete the authentication, such as loading in the set of authorized keys for that user. If no authentication is required for this user, this method should return `False` to cause the authentication to immediately succeed. Otherwise, it should return `True` to indicate that authentication should proceed. If blocking operations need to be performed to prepare the state needed to complete the authentication, this method may be defined as a coroutine. :param username: The name of the user being authenticated :type username: `str` :returns: A `bool` indicating whether authentication is required """ return True # pragma: no cover def auth_completed(self) -> None: """Authentication was completed successfully This method is called when authentication has completed successfully.
Applications may use this method to perform processing based on the authenticated username or options in the authorized keys list or certificate associated with the user before any sessions are opened or forwarding requests are handled. """ def validate_gss_principal(self, username: str, user_principal: str, host_principal: str) -> MaybeAwait[bool]: """Return whether a GSS principal is valid for this user This method should return `True` if the specified user principal is valid for the user being authenticated. It can be overridden by applications wishing to perform their own authentication. If blocking operations need to be performed to determine the validity of the principal, this method may be defined as a coroutine. By default, this method will return `True` only when the name in the user principal exactly matches the username and the domain of the user principal matches the domain of the host principal. :param username: The user being authenticated :param user_principal: The user principal sent by the client :param host_principal: The host principal sent by the server :type username: `str` :type user_principal: `str` :type host_principal: `str` :returns: A `bool` indicating if the specified user principal is valid for the user being authenticated """ host_domain = host_principal.rsplit('@')[-1] return user_principal == username + '@' + host_domain def host_based_auth_supported(self) -> bool: """Return whether or not host-based authentication is supported This method should return `True` if client host-based authentication is supported. Applications wishing to support it must have this method return `True` and implement :meth:`validate_host_public_key` and/or :meth:`validate_host_ca_key` to return whether or not the key provided by the client is valid for the client host being authenticated. By default, it returns `False` indicating the client host based authentication is not supported. :returns: A `bool` indicating if host-based authentication is supported or not """ return False # pragma: no cover def validate_host_public_key(self, client_host: str, client_addr: str, client_port: int, key: SSHKey) -> bool: """Return whether key is an authorized host key for this client host Host key based client authentication can be supported by passing authorized host keys in the `known_client_hosts` argument of :func:`create_server`. However, for more flexibility in matching on the allowed set of keys, this method can be implemented by the application to do the matching itself. It should return `True` if the specified key is a valid host key for the client host being authenticated. This method may be called multiple times with different keys provided by the client. Applications should precompute as much as possible in the :meth:`begin_auth` method so that this function can quickly return whether the key provided is in the list. By default, this method returns `False` for all client host keys. .. note:: This function only needs to report whether the public key provided is a valid key for this client host. If it is, AsyncSSH will verify that the client possesses the corresponding private key before allowing the authentication to succeed. 
:param client_host: The hostname of the client host :param client_addr: The IP address of the client host :param client_port: The port number on the client host :param key: The host public key sent by the client :type client_host: `str` :type client_addr: `str` :type client_port: `int` :type key: :class:`SSHKey` *public key* :returns: A `bool` indicating if the specified key is a valid key for the client host being authenticated """ return False # pragma: no cover def validate_host_ca_key(self, client_host: str, client_addr: str, client_port: int, key: SSHKey) -> bool: """Return whether key is an authorized CA key for this client host Certificate based client host authentication can be supported by passing authorized host CA keys in the `known_client_hosts` argument of :func:`create_server`. However, for more flexibility in matching on the allowed set of keys, this method can be implemented by the application to do the matching itself. It should return `True` if the specified key is a valid certificate authority key for the client host being authenticated. This method may be called multiple times with different keys provided by the client. Applications should precompute as much as possible in the :meth:`begin_auth` method so that this function can quickly return whether the key provided is in the list. By default, this method returns `False` for all CA keys. .. note:: This function only needs to report whether the public key provided is a valid CA key for this client host. If it is, AsyncSSH will verify that the certificate is valid, that the client host is one of the valid principals for the certificate, and that the client possesses the private key corresponding to the public key in the certificate before allowing the authentication to succeed. :param client_host: The hostname of the client host :param client_addr: The IP address of the client host :param client_port: The port number on the client host :param key: The public key which signed the certificate sent by the client :type client_host: `str` :type client_addr: `str` :type client_port: `int` :type key: :class:`SSHKey` *public key* :returns: A `bool` indicating if the specified key is a valid CA key for the client host being authenticated """ return False # pragma: no cover def validate_host_based_user(self, username: str, client_host: str, client_username: str) -> MaybeAwait[bool]: """Return whether remote host and user is authorized for this user This method should return `True` if the specified client host and user is valid for the user being authenticated. It can be overridden by applications wishing to enforce restrictions on which remote users are allowed to authenticate as particular local users. If blocking operations need to be performed to determine the validity of the client host and user, this method may be defined as a coroutine. By default, this method will return `True` when the client username matches the name of the user being authenticated. 
:param username: The user being authenticated :param client_host: The hostname of the client host making the request :param client_username: The username of the user on the client host :type username: `str` :type client_host: `str` :type client_username: `str` :returns: A `bool` indicating if the specified client host and user is valid for the user being authenticated """ return username == client_username def public_key_auth_supported(self) -> bool: """Return whether or not public key authentication is supported This method should return `True` if client public key authentication is supported. Applications wishing to support it must have this method return `True` and implement :meth:`validate_public_key` and/or :meth:`validate_ca_key` to return whether or not the key provided by the client is valid for the user being authenticated. By default, it returns `False` indicating the client public key authentication is not supported. :returns: A `bool` indicating if public key authentication is supported or not """ return False # pragma: no cover def validate_public_key(self, username: str, key: SSHKey) -> \ MaybeAwait[bool]: """Return whether key is an authorized client key for this user Key based client authentication can be supported by passing authorized keys in the `authorized_client_keys` argument of :func:`create_server`, or by calling :meth:`set_authorized_keys ` on the server connection from the :meth:`begin_auth` method. However, for more flexibility in matching on the allowed set of keys, this method can be implemented by the application to do the matching itself. It should return `True` if the specified key is a valid client key for the user being authenticated. This method may be called multiple times with different keys provided by the client. Applications should precompute as much as possible in the :meth:`begin_auth` method so that this function can quickly return whether the key provided is in the list. If blocking operations need to be performed to determine the validity of the key, this method may be defined as a coroutine. By default, this method returns `False` for all client keys. .. note:: This function only needs to report whether the public key provided is a valid client key for this user. If it is, AsyncSSH will verify that the client possesses the corresponding private key before allowing the authentication to succeed. :param username: The user being authenticated :param key: The public key sent by the client :type username: `str` :type key: :class:`SSHKey` *public key* :returns: A `bool` indicating if the specified key is a valid client key for the user being authenticated """ return False # pragma: no cover def validate_ca_key(self, username: str, key: SSHKey) -> MaybeAwait[bool]: """Return whether key is an authorized CA key for this user Certificate based client authentication can be supported by passing authorized CA keys in the `authorized_client_keys` argument of :func:`create_server`, or by calling :meth:`set_authorized_keys ` on the server connection from the :meth:`begin_auth` method. However, for more flexibility in matching on the allowed set of keys, this method can be implemented by the application to do the matching itself. It should return `True` if the specified key is a valid certificate authority key for the user being authenticated. This method may be called multiple times with different keys provided by the client. 
Applications should precompute as much as possible in the :meth:`begin_auth` method so that this function can quickly return whether the key provided is in the list. If blocking operations need to be performed to determine the validity of the key, this method may be defined as a coroutine. By default, this method returns `False` for all CA keys. .. note:: This function only needs to report whether the public key provided is a valid CA key for this user. If it is, AsyncSSH will verify that the certificate is valid, that the user is one of the valid principals for the certificate, and that the client possesses the private key corresponding to the public key in the certificate before allowing the authentication to succeed. :param username: The user being authenticated :param key: The public key which signed the certificate sent by the client :type username: `str` :type key: :class:`SSHKey` *public key* :returns: A `bool` indicating if the specified key is a valid CA key for the user being authenticated """ return False # pragma: no cover def password_auth_supported(self) -> bool: """Return whether or not password authentication is supported This method should return `True` if password authentication is supported. Applications wishing to support it must have this method return `True` and implement :meth:`validate_password` to return whether or not the password provided by the client is valid for the user being authenticated. By default, this method returns `False` indicating that password authentication is not supported. :returns: A `bool` indicating if password authentication is supported or not """ return False # pragma: no cover def validate_password(self, username: str, password: str) -> \ MaybeAwait[bool]: """Return whether password is valid for this user This method should return `True` if the specified password is a valid password for the user being authenticated. It must be overridden by applications wishing to support password authentication. If the password provided is valid but expired, this method may raise :exc:`PasswordChangeRequired` to request that the client provide a new password before authentication is allowed to complete. In this case, the application must override :meth:`change_password` to handle the password change request. This method may be called multiple times with different passwords provided by the client. Applications may wish to limit the number of attempts which are allowed. This can be done by having :meth:`password_auth_supported` begin returning `False` after the maximum number of attempts is exceeded. If blocking operations need to be performed to determine the validity of the password, this method may be defined as a coroutine. By default, this method returns `False` for all passwords. :param username: The user being authenticated :param password: The password sent by the client :type username: `str` :type password: `str` :returns: A `bool` indicating if the specified password is valid for the user being authenticated :raises: :exc:`PasswordChangeRequired` if the password provided is expired and needs to be changed """ return False # pragma: no cover def change_password(self, username: str, old_password: str, new_password: str) -> MaybeAwait[bool]: """Handle a request to change a user's password This method is called when a user makes a request to change their password. It should first validate that the old password provided is correct and then attempt to change the user's password to the new value. 
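As an illustrative sketch only (the in-memory `passwords` table below is hypothetical and not something AsyncSSH provides), a server supporting password authentication and password changes might look roughly like::

    import asyncssh

    class MyServer(asyncssh.SSHServer):
        # Hypothetical user database mapping usernames to passwords
        passwords = {'guest': 'secretpw'}

        def password_auth_supported(self) -> bool:
            return True

        def validate_password(self, username: str, password: str) -> bool:
            return self.passwords.get(username) == password

        def change_password(self, username: str, old_password: str,
                            new_password: str) -> bool:
            if self.passwords.get(username) != old_password:
                return False

            self.passwords[username] = new_password
            return True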
If the old password provided is valid and the change to the new password is successful, this method should return `True`. If the old password is not valid or password changes are not supported, it should return `False`. It may also raise :exc:`PasswordChangeRequired` to request that the client try again if the new password is not acceptable for some reason. If blocking operations need to be performed to determine the validity of the old password or to change to the new password, this method may be defined as a coroutine. By default, this method returns `False`, rejecting all password changes. :param username: The user whose password should be changed :param old_password: The user's current password :param new_password: The new password being requested :type username: `str` :type old_password: `str` :type new_password: `str` :returns: A `bool` indicating if the password change is successful or not :raises: :exc:`PasswordChangeRequired` if the new password is not acceptable and the client should be asked to provide another """ return False # pragma: no cover def kbdint_auth_supported(self) -> bool: """Return whether or not keyboard-interactive authentication is supported This method should return `True` if keyboard-interactive authentication is supported. Applications wishing to support it must have this method return `True` and implement :meth:`get_kbdint_challenge` and :meth:`validate_kbdint_response` to generate the appropriate challenges and validate the responses for the user being authenticated. By default, this method returns `NotImplemented` tying this authentication to password authentication. If the application implements password authentication and this method is not overridden, keyboard-interactive authentication will be supported by prompting for a password and passing that to the password authentication callbacks. :returns: A `bool` indicating if keyboard-interactive authentication is supported or not """ return NotImplemented # pragma: no cover def get_kbdint_challenge(self, username: str, lang: str, submethods: str) -> MaybeAwait[KbdIntChallenge]: """Return a keyboard-interactive auth challenge This method should return `True` if authentication should succeed without any challenge, `False` if authentication should fail without any challenge, or an auth challenge consisting of a challenge name, instructions, a language tag, and a list of tuples containing prompt strings and booleans indicating whether input should be echoed when a value is entered for that prompt. If blocking operations need to be performed to determine the challenge to issue, this method may be defined as a coroutine. :param username: The user being authenticated :param lang: The language requested by the client for the challenge :param submethods: A comma-separated list of the types of challenges the client can support, or the empty string if the server should choose :type username: `str` :type lang: `str` :type submethods: `str` :returns: An authentication challenge as described above """ return False # pragma: no cover def validate_kbdint_response( self, username: str, responses: KbdIntResponse) -> \ MaybeAwait[KbdIntChallenge]: """Return whether the keyboard-interactive response is valid for this user This method should validate the keyboard-interactive responses provided and return `True` if authentication should succeed with no further challenge, `False` if authentication should fail, or an additional auth challenge in the same format returned by :meth:`get_kbdint_challenge`. 
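As a loose sketch (the single PIN prompt and hard-coded value are made up purely for illustration), a one-round challenge could be implemented as::

    import asyncssh

    class MyServer(asyncssh.SSHServer):
        def kbdint_auth_supported(self) -> bool:
            return True

        def get_kbdint_challenge(self, username: str, lang: str, submethods: str):
            # Challenge name, instructions, language tag, and a list of
            # (prompt, echo) tuples
            return ('', '', 'en-US', [('PIN: ', False)])

        def validate_kbdint_response(self, username: str, responses):
            # Return True, False, or another challenge in the same format
            return len(responses) == 1 and responses[0] == '1234'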
Any series of challenges can be returned this way. To print a message in the middle of a sequence of challenges without prompting for additional data, a challenge can be returned with an empty list of prompts. After the client acknowledges this message, this function will be called again with an empty list of responses to continue the authentication. If blocking operations need to be performed to determine the validity of the response or the next challenge to issue, this method may be defined as a coroutine. :param username: The user being authenticated :param responses: A list of responses to the last challenge :type username: `str` :type responses: `list` of `str` :returns: `True`, `False`, or the next challenge """ return False # pragma: no cover def session_requested(self) -> MaybeAwait[_NewSession]: """Handle an incoming session request This method is called when a session open request is received from the client, indicating it wishes to open a channel to be used for running a shell, executing a command, or connecting to a subsystem. If the application wishes to accept the session, it must override this method to return either an :class:`SSHServerSession` object to use to process the data received on the channel or a tuple consisting of an :class:`SSHServerChannel` object created with :meth:`create_server_channel ` and an :class:`SSHServerSession`, if the application wishes to pass non-default arguments when creating the channel. If blocking operations need to be performed before the session can be created, a coroutine which returns an :class:`SSHServerSession` object can be returned instead of the session itself. This can be either returned directly or as a part of a tuple with an :class:`SSHServerChannel` object. To reject this request, this method should return `False` to send back a "Session refused" response or raise a :exc:`ChannelOpenError` exception with the reason for the failure. The details of what type of session the client wants to start will be delivered to methods on the :class:`SSHServerSession` object which is returned, along with other information such as environment variables, terminal type, size, and modes. By default, all session requests are rejected. :returns: One of the following: * An :class:`SSHServerSession` object or a coroutine which returns an :class:`SSHServerSession` * A tuple consisting of an :class:`SSHServerChannel` and the above * A `callable` or coroutine handler function which takes AsyncSSH stream objects for stdin, stdout, and stderr as arguments * A tuple consisting of an :class:`SSHServerChannel` and the above * `False` to refuse the request :raises: :exc:`ChannelOpenError` if the session shouldn't be accepted """ return False # pragma: no cover def connection_requested(self, dest_host: str, dest_port: int, orig_host: str, orig_port: int) -> _NewTCPSession: """Handle a direct TCP/IP connection request This method is called when a direct TCP/IP connection request is received by the server. Applications wishing to accept such connections must override this method. To allow standard port forwarding of data on the connection to the requested destination host and port, this method should return `True`. To reject this request, this method should return `False` to send back a "Connection refused" response or raise an :exc:`ChannelOpenError` exception with the reason for the failure. 
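As a minimal sketch (the allowed destination ports are arbitrary), standard port forwarding could be enabled selectively like this::

    import asyncssh

    class MyServer(asyncssh.SSHServer):
        def connection_requested(self, dest_host: str, dest_port: int,
                                 orig_host: str, orig_port: int) -> bool:
            # Permit standard forwarding only to web ports; anything else
            # is refused with a "Connection refused" response
            return dest_port in (80, 443)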
If the application wishes to process the data on the connection itself, this method should return either an :class:`SSHTCPSession` object which can be used to process the data received on the channel or a tuple consisting of an :class:`SSHTCPChannel` object created with :meth:`create_tcp_channel() ` and an :class:`SSHTCPSession`, if the application wishes to pass non-default arguments when creating the channel. If blocking operations need to be performed before the session can be created, a coroutine which returns an :class:`SSHTCPSession` object can be returned instead of the session itself. This can be either returned directly or as a part of a tuple with an :class:`SSHTCPChannel` object. By default, all connection requests are rejected. :param dest_host: The address the client wishes to connect to :param dest_port: The port the client wishes to connect to :param orig_host: The address the connection was originated from :param orig_port: The port the connection was originated from :type dest_host: `str` :type dest_port: `int` :type orig_host: `str` :type orig_port: `int` :returns: One of the following: * An :class:`SSHTCPSession` object or a coroutine which returns an :class:`SSHTCPSession` * A tuple consisting of an :class:`SSHTCPChannel` and the above * A `callable` or coroutine handler function which takes AsyncSSH stream objects for reading from and writing to the connection * A tuple consisting of an :class:`SSHTCPChannel` and the above * `True` to request standard port forwarding * `False` to refuse the connection :raises: :exc:`ChannelOpenError` if the connection shouldn't be accepted """ return False # pragma: no cover def server_requested(self, listen_host: str, listen_port: int) -> MaybeAwait[_NewListener]: """Handle a request to listen on a TCP/IP address and port This method is called when a client makes a request to listen on an address and port for incoming TCP connections. The port to listen on may be `0` to request a dynamically allocated port. Applications wishing to allow TCP/IP connection forwarding must override this method. To set up standard port forwarding of connections received on this address and port, this method should return `True`. If the application wishes to manage listening for incoming connections itself, this method should return an :class:`SSHListener` object that listens for new connections and calls :meth:`create_connection ` on each of them to forward them back to the client or return `None` if the listener can't be set up. If blocking operations need to be performed to set up the listener, a coroutine which returns an :class:`SSHListener` can be returned instead of the listener itself. To reject this request, this method should return `False`. By default, this method rejects all server requests. :param listen_host: The address the server should listen on :param listen_port: The port the server should listen on, or the value `0` to request that the server dynamically allocate a port :type listen_host: `str` :type listen_port: `int` :returns: One of the following: * An :class:`SSHListener` object * `True` to set up standard port forwarding * `False` to reject the request * A coroutine object which returns one of the above """ return False # pragma: no cover def unix_connection_requested(self, dest_path: str) -> _NewUNIXSession: """Handle a direct UNIX domain socket connection request This method is called when a direct UNIX domain socket connection request is received by the server.
Applications wishing to accept such connections must override this method. To allow standard path forwarding of data on the connection to the requested destination path, this method should return `True`. To reject this request, this method should return `False` to send back a "Connection refused" response or raise an :exc:`ChannelOpenError` exception with the reason for the failure. If the application wishes to process the data on the connection itself, this method should return either an :class:`SSHUNIXSession` object which can be used to process the data received on the channel or a tuple consisting of an :class:`SSHUNIXChannel` object created with :meth:`create_unix_channel() ` and an :class:`SSHUNIXSession`, if the application wishes to pass non-default arguments when creating the channel. If blocking operations need to be performed before the session can be created, a coroutine which returns an :class:`SSHUNIXSession` object can be returned instead of the session itself. This can be either returned directly or as a part of a tuple with an :class:`SSHUNIXChannel` object. By default, all connection requests are rejected. :param dest_path: The path the client wishes to connect to :type dest_path: `str` :returns: One of the following: * An :class:`SSHUNIXSession` object or a coroutine which returns an :class:`SSHUNIXSession` * A tuple consisting of an :class:`SSHUNIXChannel` and the above * A `callable` or coroutine handler function which takes AsyncSSH stream objects for reading from and writing to the connection * A tuple consisting of an :class:`SSHUNIXChannel` and the above * `True` to request standard path forwarding * `False` to refuse the connection :raises: :exc:`ChannelOpenError` if the connection shouldn't be accepted """ return False # pragma: no cover def unix_server_requested(self, listen_path: str) -> \ MaybeAwait[_NewListener]: """Handle a request to listen on a UNIX domain socket This method is called when a client makes a request to listen on a path for incoming UNIX domain socket connections. Applications wishing to allow UNIX domain socket forwarding must override this method. To set up standard path forwarding of connections received on this path, this method should return `True`. If the application wishes to manage listening for incoming connections itself, this method should return an :class:`SSHListener` object that listens for new connections and calls :meth:`create_unix_connection ` on each of them to forward them back to the client or return `None` if the listener can't be set up. If blocking operations need to be performed to set up the listener, a coroutine which returns an :class:`SSHListener` can be returned instead of the listener itself. To reject this request, this method should return `False`. By default, this method rejects all server requests. :param listen_path: The path the server should listen on :type listen_path: `str` :returns: One of the following: * An :class:`SSHListener` object or a coroutine which returns an :class:`SSHListener` or `False` if the listener can't be opened * `True` to set up standard path forwarding * `False` to reject the request """ return False # pragma: no cover def tun_requested(self, unit: Optional[int]) -> _NewTunTapSession: """Handle a layer 3 tunnel request This method is called when a layer 3 tunnel request is received by the server. Applications wishing to accept such tunnels must override this method.
To allow standard forwarding of data on the connection to the requested TUN device, this method should return `True`. To reject this request, this method should return `False` to send back a "Connection refused" response or raise an :exc:`ChannelOpenError` exception with the reason for the failure. If the application wishes to process the data on the connection itself, this method should return either an :class:`SSHTunTapSession` object which can be used to process the data received on the channel or a tuple consisting of an :class:`SSHTunTapChannel` object created with :meth:`create_tuntap_channel() ` and an :class:`SSHTunTapSession`, if the application wishes to pass non-default arguments when creating the channel. If blocking operations need to be performed before the session can be created, a coroutine which returns an :class:`SSHTunTapSession` object can be returned instead of the session itself. This can be either returned directly or as a part of a tuple with an :class:`SSHTunTapChannel` object. By default, all layer 3 tunnel requests are rejected. :param unit: The unit number of the TUN device the client wishes to open, or `None` if any available unit is acceptable :type unit: `int` or `None` :returns: One of the following: * An :class:`SSHTunTapSession` object or a coroutine which returns an :class:`SSHTunTapSession` * A tuple consisting of an :class:`SSHTunTapChannel` and the above * A `callable` or coroutine handler function which takes AsyncSSH stream objects for reading from and writing to the connection * A tuple consisting of an :class:`SSHTunTapChannel` and the above * `True` to request standard layer 3 tunnel forwarding * `False` to refuse the connection :raises: :exc:`ChannelOpenError` if the connection shouldn't be accepted """ return False # pragma: no cover def tap_requested(self, unit: Optional[int]) -> _NewTunTapSession: """Handle a layer 2 tunnel request This method is called when a layer 2 tunnel request is received by the server. Applications wishing to accept such tunnels must override this method. To allow standard forwarding of data on the connection to the requested TAP device, this method should return `True`. To reject this request, this method should return `False` to send back a "Connection refused" response or raise an :exc:`ChannelOpenError` exception with the reason for the failure. If the application wishes to process the data on the connection itself, this method should return either an :class:`SSHTunTapSession` object which can be used to process the data received on the channel or a tuple consisting of an :class:`SSHTunTapChannel` object created with :meth:`create_tuntap_channel() ` and an :class:`SSHTunTapSession`, if the application wishes to pass non-default arguments when creating the channel. If blocking operations need to be performed before the session can be created, a coroutine which returns an :class:`SSHTunTapSession` object can be returned instead of the session itself. This can be either returned directly or as a part of a tuple with an :class:`SSHTunTapChannel` object. By default, all layer 2 tunnel requests are rejected.
:param dest_path: The path the client wishes to connect to :type dest_path: `str` :returns: One of the following: * An :class:`SSHTunTapSession` object or a coroutine which returns an :class:`SSHTunTapSession` * A tuple consisting of an :class:`SSHTunTapChannel` and the above * A `callable` or coroutine handler function which takes AsyncSSH stream objects for reading from and writing to the connection * A tuple consisting of an :class:`SSHTunTapChannel` and the above * `True` to request standard layer 2 tunnel forwarding * `False` to refuse the connection :raises: :exc:`ChannelOpenError` if the connection shouldn't be accepted """ return False # pragma: no cover asyncssh-2.20.0/asyncssh/session.py000066400000000000000000000545341475467777400173650ustar00rootroot00000000000000# Copyright (c) 2013-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation """SSH session handlers""" from typing import TYPE_CHECKING, Any, AnyStr, Callable, Generic from typing import Mapping, Optional, Tuple if TYPE_CHECKING: # pylint: disable=cyclic-import from .channel import SSHClientChannel, SSHServerChannel from .channel import SSHTCPChannel, SSHUNIXChannel, SSHTunTapChannel DataType = Optional[int] class SSHSession(Generic[AnyStr]): """SSH session handler""" # pylint: disable=no-self-use,unused-argument def connection_made(self, chan: Any) -> None: """Called when a channel is opened successfully""" def connection_lost(self, exc: Optional[Exception]) -> None: """Called when a channel is closed This method is called when a channel is closed. If the channel is shut down cleanly, *exc* will be `None`. Otherwise, it will be an exception explaining the reason for the channel close. :param exc: The exception which caused the channel to close, or `None` if the channel closed cleanly. :type exc: :class:`Exception` """ def session_started(self) -> None: """Called when the session is started This method is called when a session has started up. For client and server sessions, this will be called once a shell, exec, or subsystem request has been successfully completed. For TCP and UNIX domain socket sessions, it will be called immediately after the connection is opened. """ def data_received(self, data: AnyStr, datatype: DataType) -> None: """Called when data is received on the channel This method is called when data is received on the channel. If an encoding was specified when the channel was created, the data will be delivered as a string after decoding with the requested encoding. Otherwise, the data will be delivered as bytes. :param data: The data received on the channel :param datatype: The extended data type of the data, from :ref:`extended data types ` :type data: `str` or `bytes` """ def eof_received(self) -> bool: """Called when EOF is received on the channel This method is called when an end-of-file indication is received on the channel, after which no more data will be received. 
If this method returns `True`, the channel remains half open and data may still be sent. Otherwise, the channel is automatically closed after this method returns. This is the default behavior for classes derived directly from :class:`SSHSession`, but not when using the higher-level streams API. Because input is buffered in that case, streaming sessions enable half-open channels to allow applications to respond to input read after an end-of-file indication is received. """ return False # pragma: no cover def pause_writing(self) -> None: """Called when the write buffer becomes full This method is called when the channel's write buffer becomes full and no more data can be sent until the remote system adjusts its window. While data can still be buffered locally, applications may wish to stop producing new data until the write buffer has drained. """ def resume_writing(self) -> None: """Called when the write buffer has sufficiently drained This method is called when the channel's send window reopens and enough data has drained from the write buffer to allow the application to produce more data. """ class SSHClientSession(SSHSession[AnyStr]): """SSH client session handler Applications should subclass this when implementing an SSH client session handler. The functions listed below should be implemented to define application-specific behavior. In particular, the standard `asyncio` protocol methods such as :meth:`connection_made`, :meth:`connection_lost`, :meth:`data_received`, :meth:`eof_received`, :meth:`pause_writing`, and :meth:`resume_writing` are all supported. In addition, :meth:`session_started` is called as soon as the SSH session is fully started, :meth:`xon_xoff_requested` can be used to determine if the server wants the client to support XON/XOFF flow control, and :meth:`exit_status_received` and :meth:`exit_signal_received` can be used to receive session exit information. """ # pylint: disable=no-self-use,unused-argument def connection_made(self, chan: 'SSHClientChannel[AnyStr]') -> None: """Called when a channel is opened successfully This method is called when a channel is opened successfully. The channel parameter should be stored if needed for later use. :param chan: The channel which was successfully opened. :type chan: :class:`SSHClientChannel` """ def xon_xoff_requested(self, client_can_do: bool) -> None: """XON/XOFF flow control has been enabled or disabled This method is called to notify the client whether or not to enable XON/XOFF flow control. If client_can_do is `True` and output is being sent to an interactive terminal the application should allow input of Control-S and Control-Q to pause and resume output, respectively. If client_can_do is `False`, Control-S and Control-Q should be treated as normal input and passed through to the server. Non-interactive applications can ignore this request. By default, this message is ignored. :param client_can_do: Whether or not to enable XON/XOFF flow control :type client_can_do: `bool` """ def exit_status_received(self, status: int) -> None: """A remote exit status has been received for this session This method is called when the shell, command, or subsystem running on the server terminates and returns an exit status. A zero exit status generally means that the operation was successful. This call will generally be followed by a call to :meth:`connection_lost`. By default, the exit status is ignored. 
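For example, here is a minimal sketch of a client session which records the status for later inspection instead of ignoring it (the class name is purely illustrative)::

    import asyncssh

    class MyClientSession(asyncssh.SSHClientSession):
        def __init__(self):
            self.exit_status = None

        def exit_status_received(self, status):
            # Remember the remote exit status for later inspection
            self.exit_status = status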
:param status: The exit status returned by the remote process :type status: `int` """ def exit_signal_received(self, signal: str, core_dumped: bool, msg: str, lang: str) -> None: """A remote exit signal has been received for this session This method is called when the shell, command, or subsystem running on the server terminates abnormally with a signal. A more detailed error may also be provided, along with an indication of whether the remote process dumped core. This call will generally be followed by a call to :meth:`connection_lost`. By default, exit signals are ignored. :param signal: The signal which caused the remote process to exit :param core_dumped: Whether or not the remote process dumped core :param msg: Details about what error occurred :param lang: The language the error message is in :type signal: `str` :type core_dumped: `bool` :type msg: `str` :type lang: `str` """ class SSHServerSession(SSHSession[AnyStr]): """SSH server session handler Applications should subclass this when implementing an SSH server session handler. The functions listed below should be implemented to define application-specific behavior. In particular, the standard `asyncio` protocol methods such as :meth:`connection_made`, :meth:`connection_lost`, :meth:`data_received`, :meth:`eof_received`, :meth:`pause_writing`, and :meth:`resume_writing` are all supported. In addition, :meth:`pty_requested` is called when the client requests a pseudo-terminal, one of :meth:`shell_requested`, :meth:`exec_requested`, or :meth:`subsystem_requested` is called depending on what type of session the client wants to start, :meth:`session_started` is called once the SSH session is fully started, :meth:`terminal_size_changed` is called when the client's terminal size changes, :meth:`signal_received` is called when the client sends a signal, and :meth:`break_received` is called when the client sends a break. """ # pylint: disable=no-self-use,unused-argument def connection_made(self, chan: 'SSHServerChannel[AnyStr]') -> None: """Called when a channel is opened successfully This method is called when a channel is opened successfully. The channel parameter should be stored if needed for later use. :param chan: The channel which was successfully opened. :type chan: :class:`SSHServerChannel` """ def pty_requested(self, term_type: str, term_size: Tuple[int, int, int, int], term_modes: Mapping[int, int]) -> bool: """A pseudo-terminal has been requested This method is called when the client sends a request to allocate a pseudo-terminal with the requested terminal type, size, and POSIX terminal modes. This method should return `True` if the request for the pseudo-terminal is accepted. Otherwise, it should return `False` to reject the request. By default, requests to allocate a pseudo-terminal are accepted but nothing is done with the associated terminal information. Applications wishing to use this information should implement this method and have it return `True`, or call :meth:`get_terminal_type() `, :meth:`get_terminal_size() `, or :meth:`get_terminal_mode() ` on the :class:`SSHServerChannel` to get the information they need after a shell, command, or subsystem is started. 
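For example, a minimal sketch of a server session which accepts the request and records the requested terminal type and size might look like this (the class and attribute names are illustrative)::

    import asyncssh

    class MyServerSession(asyncssh.SSHServerSession):
        def pty_requested(self, term_type, term_size, term_modes):
            # term_size is (width, height, pixwidth, pixheight)
            self._term_type = term_type
            self._term_width, self._term_height = term_size[:2]
            return True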
:param term_type: Terminal type to set for this session :param term_size: Terminal size to set for this session provided as a tuple of four `int` values: the width and height of the terminal in characters followed by the width and height of the terminal in pixels :param term_modes: POSIX terminal modes to set for this session, where keys are taken from :ref:`POSIX terminal modes ` with values defined in section 8 of :rfc:`RFC 4254 <4254#section-8>`. :type term_type: `str` :type term_size: tuple of 4 `int` values :type term_modes: `dict` :returns: A `bool` indicating if the request for a pseudo-terminal was allowed or not """ return True # pragma: no cover def terminal_size_changed(self, width: int, height: int, pixwidth: int, pixheight: int) -> None: """The terminal size has changed This method is called when a client requests a pseudo-terminal and again whenever the size of the client's terminal window changes. By default, this information is ignored, but applications wishing to use the terminal size can implement this method to get notified whenever it changes. :param width: The width of the terminal in characters :param height: The height of the terminal in characters :param pixwidth: (optional) The width of the terminal in pixels :param pixheight: (optional) The height of the terminal in pixels :type width: `int` :type height: `int` :type pixwidth: `int` :type pixheight: `int` """ def shell_requested(self) -> bool: """The client has requested a shell This method should be implemented by the application to perform whatever processing is required when a client makes a request to open an interactive shell. It should return `True` to accept the request, or `False` to reject it. If the application returns `True`, the :meth:`session_started` method will be called once the channel is fully open. No output should be sent until this method is called. By default this method returns `False` to reject all requests. :returns: A `bool` indicating if the shell request was allowed or not """ return False # pragma: no cover def exec_requested(self, command: str) -> bool: """The client has requested to execute a command This method should be implemented by the application to perform whatever processing is required when a client makes a request to execute a command. It should return `True` to accept the request, or `False` to reject it. If the application returns `True`, the :meth:`session_started` method will be called once the channel is fully open. No output should be sent until this method is called. By default this method returns `False` to reject all requests. :param command: The command the client has requested to execute :type command: `str` :returns: A `bool` indicating if the exec request was allowed or not """ return False # pragma: no cover def subsystem_requested(self, subsystem: str) -> bool: """The client has requested to start a subsystem This method should be implemented by the application to perform whatever processing is required when a client makes a request to start a subsystem. It should return `True` to accept the request, or `False` to reject it. If the application returns `True`, the :meth:`session_started` method will be called once the channel is fully open. No output should be sent until this method is called. By default this method returns `False` to reject all requests.
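For example, a minimal sketch of a server session which accepts only one subsystem name (the name 'my-subsystem' here is purely illustrative) might be::

    import asyncssh

    class MyServerSession(asyncssh.SSHServerSession):
        def subsystem_requested(self, subsystem):
            # Accept only the one subsystem this application implements
            return subsystem == 'my-subsystem'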
:param subsystem: The subsystem to start :type subsystem: `str` :returns: A `bool` indicating if the request to open the subsystem was allowed or not """ return False # pragma: no cover def break_received(self, msec: int) -> bool: """The client has sent a break This method is called when the client requests that the server perform a break operation on the terminal. If the break is performed, this method should return `True`. Otherwise, it should return `False`. By default, this method returns `False` indicating that no break was performed. :param msec: The duration of the break in milliseconds :type msec: `int` :returns: A `bool` to indicate if the break operation was performed or not """ return False # pragma: no cover def signal_received(self, signal: str) -> None: """The client has sent a signal This method is called when the client delivers a signal on the channel. By default, signals from the client are ignored. :param signal: The name of the signal received :type signal: `str` """ def soft_eof_received(self) -> None: """The client has sent a soft EOF This method is called by the line editor when the client send a soft EOF (Ctrl-D on an empty input line). By default, soft EOF will trigger an EOF to an outstanding read call but still allow additional input to be received from the client after that. """ class SSHTCPSession(SSHSession[AnyStr]): """SSH TCP session handler Applications should subclass this when implementing a handler for SSH direct or forwarded TCP connections. SSH client applications wishing to open a direct connection should call :meth:`create_connection() ` on their :class:`SSHClientConnection`, passing in a factory which returns instances of this class. Server applications wishing to allow direct connections should implement the coroutine :meth:`connection_requested() ` on their :class:`SSHServer` object and have it return instances of this class. Server applications wishing to allow connection forwarding back to the client should implement the coroutine :meth:`server_requested() ` on their :class:`SSHServer` object and call :meth:`create_connection() ` on their :class:`SSHServerConnection` for each new connection, passing it a factory which returns instances of this class. When a connection is successfully opened, :meth:`session_started` will be called, after which the application can begin sending data. Received data will be passed to the :meth:`data_received` method. """ def connection_made(self, chan: 'SSHTCPChannel[AnyStr]') -> None: """Called when a channel is opened successfully This method is called when a channel is opened successfully. The channel parameter should be stored if needed for later use. :param chan: The channel which was successfully opened. :type chan: :class:`SSHTCPChannel` """ class SSHUNIXSession(SSHSession[AnyStr]): """SSH UNIX domain socket session handler Applications should subclass this when implementing a handler for SSH direct or forwarded UNIX domain socket connections. SSH client applications wishing to open a direct connection should call :meth:`create_unix_connection() ` on their :class:`SSHClientConnection`, passing in a factory which returns instances of this class. Server applications wishing to allow direct connections should implement the coroutine :meth:`unix_connection_requested() ` on their :class:`SSHServer` object and have it return instances of this class. 
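For instance, a minimal sketch of such a server might look like the following, where `MyUNIXSession` is a hypothetical subclass of this class and the socket path is purely illustrative::

    import asyncssh

    class MySSHServer(asyncssh.SSHServer):
        def unix_connection_requested(self, dest_path):
            # Accept direct connections only to one well-known path;
            # MyUNIXSession is a hypothetical SSHUNIXSession subclass
            # defined elsewhere
            if dest_path == '/run/example.sock':
                return MyUNIXSession()

            return False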
Server applications wishing to allow connection forwarding back to the client should implement the coroutine :meth:`unix_server_requested() ` on their :class:`SSHServer` object and call :meth:`create_unix_connection() ` on their :class:`SSHServerConnection` for each new connection, passing it a factory which returns instances of this class. When a connection is successfully opened, :meth:`session_started` will be called, after which the application can begin sending data. Received data will be passed to the :meth:`data_received` method. """ def connection_made(self, chan: 'SSHUNIXChannel[AnyStr]') -> None: """Called when a channel is opened successfully This method is called when a channel is opened successfully. The channel parameter should be stored if needed for later use. :param chan: The channel which was successfully opened. :type chan: :class:`SSHUNIXChannel` """ class SSHTunTapSession(SSHSession[bytes]): """SSH TUN/TAP session handler Applications should subclass this when implementing a handler for SSH TUN/TAP tunnels. SSH client applications wishing to open a tunnel should call :meth:`create_tun() ` or :meth:`create_tap() ` on their :class:`SSHClientConnection`, passing in a factory which returns instances of this class. Server applications wishing to allow tunnel connections should implement the coroutine :meth:`tun_requested() ` or :meth:`tap_requested() ` on their :class:`SSHServer` object and have it return instances of this class. When a connection is successfully opened, :meth:`session_started` will be called, after which the application can begin sending data. Received data will be passed to the :meth:`data_received` method. """ def connection_made(self, chan: 'SSHTunTapChannel') -> None: """Called when a channel is opened successfully This method is called when a channel is opened successfully. The channel parameter should be stored if needed for later use. :param chan: The channel which was successfully opened. :type chan: :class:`SSHTunTapChannel` """ SSHSessionFactory = Callable[[], SSHSession[AnyStr]] SSHClientSessionFactory = Callable[[], SSHClientSession[AnyStr]] SSHTCPSessionFactory = Callable[[], SSHTCPSession[AnyStr]] SSHUNIXSessionFactory = Callable[[], SSHUNIXSession[AnyStr]] SSHTunTapSessionFactory = Callable[[], SSHTunTapSession] asyncssh-2.20.0/asyncssh/sftp.py000066400000000000000000010545461475467777400166620ustar00rootroot00000000000000# Copyright (c) 2015-2024 by Ron Frederick and others. 
# # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation # Jonathan Slenders - proposed changes to allow SFTP server callbacks # to be coroutines """SFTP handlers""" import asyncio import errno from fnmatch import fnmatch import inspect import os from os import SEEK_SET, SEEK_CUR, SEEK_END from pathlib import PurePath import posixpath import stat import sys import time from types import TracebackType from typing import TYPE_CHECKING, AnyStr, AsyncIterator, Awaitable, Callable from typing import Dict, Generic, IO, Iterable, List, Mapping, Optional from typing import Sequence, Set, Tuple, Type, TypeVar, Union, cast, overload from typing_extensions import Literal, Protocol, Self from . import constants from .constants import DEFAULT_LANG from .constants import FXP_INIT, FXP_VERSION, FXP_OPEN, FXP_CLOSE, FXP_READ from .constants import FXP_WRITE, FXP_LSTAT, FXP_FSTAT, FXP_SETSTAT from .constants import FXP_FSETSTAT, FXP_OPENDIR, FXP_READDIR, FXP_REMOVE from .constants import FXP_MKDIR, FXP_RMDIR, FXP_REALPATH, FXP_STAT, FXP_RENAME from .constants import FXP_READLINK, FXP_SYMLINK, FXP_LINK, FXP_BLOCK from .constants import FXP_UNBLOCK, FXP_STATUS, FXP_HANDLE, FXP_DATA from .constants import FXP_NAME, FXP_ATTRS, FXP_EXTENDED, FXP_EXTENDED_REPLY from .constants import FXR_OVERWRITE from .constants import FXRP_NO_CHECK, FXRP_STAT_IF_EXISTS, FXRP_STAT_ALWAYS from .constants import FXF_READ, FXF_WRITE, FXF_APPEND from .constants import FXF_CREAT, FXF_TRUNC, FXF_EXCL from .constants import FXF_ACCESS_DISPOSITION, FXF_CREATE_NEW from .constants import FXF_CREATE_TRUNCATE, FXF_OPEN_EXISTING from .constants import FXF_OPEN_OR_CREATE, FXF_TRUNCATE_EXISTING from .constants import FXF_APPEND_DATA from .constants import ACE4_READ_DATA, ACE4_WRITE_DATA, ACE4_APPEND_DATA from .constants import ACE4_READ_ATTRIBUTES, ACE4_WRITE_ATTRIBUTES from .constants import FILEXFER_ATTR_SIZE, FILEXFER_ATTR_UIDGID from .constants import FILEXFER_ATTR_PERMISSIONS, FILEXFER_ATTR_ACMODTIME from .constants import FILEXFER_ATTR_EXTENDED, FILEXFER_ATTR_DEFINED_V3 from .constants import FILEXFER_ATTR_ACCESSTIME, FILEXFER_ATTR_CREATETIME from .constants import FILEXFER_ATTR_MODIFYTIME, FILEXFER_ATTR_ACL from .constants import FILEXFER_ATTR_OWNERGROUP, FILEXFER_ATTR_SUBSECOND_TIMES from .constants import FILEXFER_ATTR_DEFINED_V4 from .constants import FILEXFER_ATTR_BITS, FILEXFER_ATTR_DEFINED_V5 from .constants import FILEXFER_ATTR_ALLOCATION_SIZE, FILEXFER_ATTR_TEXT_HINT from .constants import FILEXFER_ATTR_MIME_TYPE, FILEXFER_ATTR_LINK_COUNT from .constants import FILEXFER_ATTR_UNTRANSLATED_NAME, FILEXFER_ATTR_CTIME from .constants import FILEXFER_ATTR_DEFINED_V6 from .constants import FX_OK, FX_EOF, FX_NO_SUCH_FILE, FX_PERMISSION_DENIED from .constants import FX_FAILURE, FX_BAD_MESSAGE, FX_NO_CONNECTION from .constants import FX_CONNECTION_LOST, FX_OP_UNSUPPORTED, FX_V3_END from .constants import FX_INVALID_HANDLE, FX_NO_SUCH_PATH from .constants import 
FX_FILE_ALREADY_EXISTS, FX_WRITE_PROTECT, FX_NO_MEDIA from .constants import FX_V4_END, FX_NO_SPACE_ON_FILESYSTEM, FX_QUOTA_EXCEEDED from .constants import FX_UNKNOWN_PRINCIPAL, FX_LOCK_CONFLICT, FX_V5_END from .constants import FX_DIR_NOT_EMPTY, FX_NOT_A_DIRECTORY from .constants import FX_INVALID_FILENAME, FX_LINK_LOOP, FX_CANNOT_DELETE from .constants import FX_INVALID_PARAMETER, FX_FILE_IS_A_DIRECTORY from .constants import FX_BYTE_RANGE_LOCK_CONFLICT, FX_BYTE_RANGE_LOCK_REFUSED from .constants import FX_DELETE_PENDING, FX_FILE_CORRUPT, FX_OWNER_INVALID from .constants import FX_GROUP_INVALID, FX_NO_MATCHING_BYTE_RANGE_LOCK from .constants import FX_V6_END from .constants import FILEXFER_TYPE_REGULAR, FILEXFER_TYPE_DIRECTORY from .constants import FILEXFER_TYPE_SYMLINK, FILEXFER_TYPE_SPECIAL from .constants import FILEXFER_TYPE_UNKNOWN, FILEXFER_TYPE_SOCKET from .constants import FILEXFER_TYPE_CHAR_DEVICE, FILEXFER_TYPE_BLOCK_DEVICE from .constants import FILEXFER_TYPE_FIFO from .logging import SSHLogger from .misc import BytesOrStr, Error, FilePath, MaybeAwait, OptExcInfo, Record from .misc import ConnectionLost from .misc import async_context_manager, get_symbol_names, hide_empty, plural from .packet import Boolean, Byte, String, UInt16, UInt32, UInt64 from .packet import PacketDecodeError, SSHPacket, SSHPacketLogger from .version import __author__, __version__ if TYPE_CHECKING: # pylint: disable=cyclic-import from .channel import SSHServerChannel from .connection import SSHClientConnection, SSHServerConnection from .stream import SSHReader, SSHWriter if TYPE_CHECKING: _RequestWaiter = asyncio.Future[Tuple[int, SSHPacket]] else: _RequestWaiter = asyncio.Future if sys.platform == 'win32': # pragma: no cover _LocalPath = str else: _LocalPath = bytes _SFTPFileObj = IO[bytes] _SFTPPath = Union[bytes, FilePath] _SFTPPaths = Union[_SFTPPath, Sequence[_SFTPPath]] _SFTPPatList = List[Union[bytes, List[bytes]]] _SFTPStatFunc = Callable[[_SFTPPath], Awaitable['SFTPAttrs']] _SFTPClientFileOrPath = Union['SFTPClientFile', _SFTPPath] _SFTPNames = Tuple[Sequence['SFTPName'], bool] _SFTPOSAttrs = Union[os.stat_result, 'SFTPAttrs'] _SFTPOSVFSAttrs = Union[os.statvfs_result, 'SFTPVFSAttrs'] _SFTPOnErrorHandler = Optional[Callable[[Callable, bytes, OptExcInfo], None]] _SFTPPacketHandler = Optional[Callable[['SFTPServerHandler', SSHPacket], Awaitable[object]]] SFTPErrorHandler = Union[None, Literal[False], Callable[[Exception], None]] SFTPProgressHandler = Optional[Callable[[bytes, bytes, int, int], None]] _T = TypeVar('_T') MIN_SFTP_VERSION = 3 MAX_SFTP_VERSION = 6 SAFE_SFTP_READ_LEN = 16*1024 # 16 KiB SAFE_SFTP_WRITE_LEN = 16*1024 # 16 KiB MAX_SFTP_READ_LEN = 4*1024*1024 # 4 MiB MAX_SFTP_WRITE_LEN = 4*1024*1024 # 4 MiB MAX_SFTP_PACKET_LEN = MAX_SFTP_WRITE_LEN + 1024 _COPY_DATA_BLOCK_SIZE = 256*1024 # 256 KiB _MAX_SFTP_REQUESTS = 128 _MAX_READDIR_NAMES = 128 _NSECS_IN_SEC = 1_000_000_000 _const_dict: Mapping[str, int] = constants.__dict__ _valid_attr_flags = { 3: FILEXFER_ATTR_DEFINED_V3, 4: FILEXFER_ATTR_DEFINED_V4, 5: FILEXFER_ATTR_DEFINED_V5, 6: FILEXFER_ATTR_DEFINED_V6 } _open_modes = { 'r': FXF_READ, 'w': FXF_WRITE | FXF_CREAT | FXF_TRUNC, 'a': FXF_WRITE | FXF_CREAT | FXF_APPEND, 'x': FXF_WRITE | FXF_CREAT | FXF_EXCL, 'r+': FXF_READ | FXF_WRITE, 'w+': FXF_READ | FXF_WRITE | FXF_CREAT | FXF_TRUNC, 'a+': FXF_READ | FXF_WRITE | FXF_CREAT | FXF_APPEND, 'x+': FXF_READ | FXF_WRITE | FXF_CREAT | FXF_EXCL } _file_types = {k: v.lower() for k, v in get_symbol_names(_const_dict, 'FILEXFER_TYPE_', 
14).items()} class _SupportsEncode(Protocol): """Protocol for applying encoding to path names""" def encode(self, sftp_version: int) -> bytes: """Encode result as bytes in an SSH packet""" class _SFTPGlobProtocol(Protocol): """Protocol for getting files to perform glob matching against""" async def stat(self, path: bytes) -> 'SFTPAttrs': """Get attributes of a file""" def scandir(self, path: bytes) -> AsyncIterator['SFTPName']: """Return names and attributes of the files in a directory""" class SFTPFileProtocol(Protocol): """Protocol for accessing a file via an SFTP server""" async def __aenter__(self) -> Self: """Allow SFTPFileProtocol to be used as an async context manager""" async def __aexit__(self, _exc_type: Optional[Type[BaseException]], _exc_value: Optional[BaseException], _traceback: Optional[TracebackType]) -> bool: """Wait for file close when used as an async context manager""" async def read(self, size: int, offset: int) -> bytes: """Read data from the local file""" async def write(self, data: bytes, offset: int) -> int: """Write data to the local file""" async def close(self) -> None: """Close the local file""" class _SFTPFSProtocol(Protocol): """Protocol for accessing a filesystem via an SFTP server""" @property def limits(self) -> 'SFTPLimits': """SFTP server limits associated with this SFTP session""" @staticmethod def basename(path: bytes) -> bytes: """Return the final component of a POSIX-style path""" def encode(self, path: _SFTPPath) -> bytes: """Encode path name using configured path encoding""" def compose_path(self, path: bytes, parent: Optional[bytes] = None) -> bytes: """Compose a path""" async def stat(self, path: bytes, *, follow_symlinks: bool = True) -> 'SFTPAttrs': """Get attributes of a file, directory, or symlink""" async def setstat(self, path: bytes, attrs: 'SFTPAttrs', *, follow_symlinks: bool = True) -> None: """Set attributes of a file, directory, or symlink""" async def isdir(self, path: bytes) -> bool: """Return if the path refers to a directory""" def scandir(self, path: bytes) -> AsyncIterator['SFTPName']: """Return names and attributes of the files in a directory""" async def mkdir(self, path: bytes) -> None: """Create a directory""" async def readlink(self, path: bytes) -> bytes: """Return the target of a symbolic link""" async def symlink(self, oldpath: bytes, newpath: bytes) -> None: """Create a symbolic link""" @async_context_manager async def open(self, path: bytes, mode: str, block_size: int = -1) -> SFTPFileProtocol: """Open a file""" def _parse_acl_supported(data: bytes) -> int: """Parse an SFTPv6 "acl-supported" extension""" packet = SSHPacket(data) capabilities = packet.get_uint32() packet.check_end() return capabilities def _parse_supported(data: bytes) -> \ Tuple[int, int, int, int, int, Sequence[bytes]]: """Parse an SFTPv5 "supported" extension""" packet = SSHPacket(data) attr_mask = packet.get_uint32() attrib_mask = packet.get_uint32() open_flags = packet.get_uint32() access_mask = packet.get_uint32() max_read_size = packet.get_uint32() ext_names: List[bytes] = [] while packet: name = packet.get_string() ext_names.append(name) return (attr_mask, attrib_mask, open_flags, access_mask, max_read_size, ext_names) def _parse_supported2(data: bytes) -> Tuple[int, int, int, int, int, int, int, Sequence[bytes], Sequence[bytes]]: """Parse an SFTPv6 "supported2" extension""" packet = SSHPacket(data) attr_mask = packet.get_uint32() attrib_mask = packet.get_uint32() open_flags = packet.get_uint32() access_mask = packet.get_uint32() max_read_size = 
packet.get_uint32() open_block_vector = packet.get_uint16() block_vector = packet.get_uint16() attrib_ext_count = packet.get_uint32() attrib_ext_names: List[bytes] = [] for _ in range(attrib_ext_count): attrib_ext_names.append(packet.get_string()) ext_count = packet.get_uint32() ext_names: List[bytes] = [] for _ in range(ext_count): ext_names.append(packet.get_string()) packet.check_end() return (attr_mask, attrib_mask, open_flags, access_mask, max_read_size, open_block_vector, block_vector, attrib_ext_names, ext_names) def _parse_vendor_id(data: bytes) -> Tuple[str, str, str, int]: """Parse a "vendor-id" extension""" packet = SSHPacket(data) vendor_name = packet.get_string().decode('utf-8', 'backslashreplace') product_name = packet.get_string().decode('utf-8', 'backslashreplace') product_version = packet.get_string().decode('utf-8', 'backslashreplace') product_build = packet.get_uint64() return vendor_name, product_name, product_version, product_build def _stat_mode_to_filetype(mode: int) -> int: """Convert stat mode/permissions to file type""" if stat.S_ISREG(mode): filetype = FILEXFER_TYPE_REGULAR elif stat.S_ISDIR(mode): filetype = FILEXFER_TYPE_DIRECTORY elif stat.S_ISLNK(mode): filetype = FILEXFER_TYPE_SYMLINK elif stat.S_ISSOCK(mode): filetype = FILEXFER_TYPE_SOCKET elif stat.S_ISCHR(mode): filetype = FILEXFER_TYPE_CHAR_DEVICE elif stat.S_ISBLK(mode): filetype = FILEXFER_TYPE_BLOCK_DEVICE elif stat.S_ISFIFO(mode): filetype = FILEXFER_TYPE_FIFO elif stat.S_IFMT(mode) != 0: filetype = FILEXFER_TYPE_SPECIAL else: filetype = FILEXFER_TYPE_UNKNOWN return filetype def _nsec_to_tuple(nsec: int) -> Tuple[int, int]: """Convert nanoseconds since epoch to seconds & remainder""" return divmod(nsec, _NSECS_IN_SEC) def _float_sec_to_tuple(sec: float) -> Tuple[int, int]: """Convert float seconds since epoch to seconds & remainder""" return (int(sec), int((sec % 1) * _NSECS_IN_SEC)) def _tuple_to_float_sec(sec: int, nsec: Optional[int]) -> float: """Convert seconds and remainder to float seconds since epoch""" return sec + float(nsec or 0) / _NSECS_IN_SEC def _tuple_to_nsec(sec: int, nsec: Optional[int]) -> int: """Convert seconds and remainder to nanoseconds since epoch""" return sec * _NSECS_IN_SEC + (nsec or 0) def _utime_to_attrs(times: Optional[Tuple[float, float]] = None, ns: Optional[Tuple[int, int]] = None) -> 'SFTPAttrs': """Convert utime arguments to SFTPAttrs""" if ns: atime, atime_ns = _nsec_to_tuple(ns[0]) mtime, mtime_ns = _nsec_to_tuple(ns[1]) elif times: atime, atime_ns = _float_sec_to_tuple(times[0]) mtime, mtime_ns = _float_sec_to_tuple(times[1]) else: if hasattr(time, 'time_ns'): atime, atime_ns = _nsec_to_tuple(time.time_ns()) else: # pragma: no cover atime, atime_ns = _float_sec_to_tuple(time.time()) mtime, mtime_ns = atime, atime_ns return SFTPAttrs(atime=atime, atime_ns=atime_ns, mtime=mtime, mtime_ns=mtime_ns) def _lookup_uid(user: Optional[str]) -> Optional[int]: """Return the uid associated with a user name""" if user is not None: try: # pylint: disable=import-outside-toplevel import pwd uid = pwd.getpwnam(user).pw_uid except (ImportError, KeyError): try: uid = int(user) except ValueError: raise SFTPOwnerInvalid(f'Invalid owner: {user}') from None else: uid = None return uid def _lookup_gid(group: Optional[str]) -> Optional[int]: """Return the gid associated with a group name""" if group is not None: try: # pylint: disable=import-outside-toplevel import grp gid = grp.getgrnam(group).gr_gid except (ImportError, KeyError): try: gid = int(group) except ValueError: raise 
SFTPGroupInvalid(f'Invalid group: {group}') from None else: gid = None return gid def _lookup_user(uid: Optional[int]) -> str: """Return the user name associated with a uid""" if uid is not None: try: # pylint: disable=import-outside-toplevel import pwd user = pwd.getpwuid(uid).pw_name except (ImportError, KeyError): user = str(uid) else: user = '' return user def _lookup_group(gid: Optional[int]) -> str: """Return the group name associated with a gid""" if gid is not None: try: # pylint: disable=import-outside-toplevel import grp group = grp.getgrgid(gid).gr_name except (ImportError, KeyError): group = str(gid) else: group = '' return group def _mode_to_pflags(mode: str) -> Tuple[int, bool]: """Convert open mode to SFTP open flags""" if 'b' in mode: mode = mode.replace('b', '') binary = True else: binary = False pflags = _open_modes.get(mode) if not pflags: raise ValueError(f'Invalid mode: {mode!r}') return pflags, binary def _pflags_to_flags(pflags: int) -> Tuple[int, int]: """Convert SFTPv3 pflags to SFTPv5 desired-access and flags""" desired_access = 0 flags = 0 if pflags & (FXF_CREAT | FXF_EXCL) == (FXF_CREAT | FXF_EXCL): flags = FXF_CREATE_NEW elif pflags & (FXF_CREAT | FXF_TRUNC) == (FXF_CREAT | FXF_TRUNC): flags = FXF_CREATE_TRUNCATE elif pflags & FXF_CREAT: flags = FXF_OPEN_OR_CREATE elif pflags & FXF_TRUNC: flags = FXF_TRUNCATE_EXISTING else: flags = FXF_OPEN_EXISTING if pflags & FXF_READ: desired_access |= ACE4_READ_DATA | ACE4_READ_ATTRIBUTES if pflags & FXF_WRITE: desired_access |= ACE4_WRITE_DATA | ACE4_WRITE_ATTRIBUTES if pflags & FXF_APPEND: desired_access |= ACE4_APPEND_DATA flags |= FXF_APPEND_DATA return desired_access, flags def _from_local_path(path: _SFTPPath) -> bytes: """Convert local path to SFTP path""" path = os.fsencode(path) if sys.platform == 'win32': # pragma: no cover path = path.replace(b'\\', b'/') if path[:1] != b'/' and path[1:2] == b':': path = b'/' + path return path def _to_local_path(path: bytes) -> _LocalPath: """Convert SFTP path to local path""" if sys.platform == 'win32': # pragma: no cover path = os.fsdecode(path) if path[:1] == '/' and path[2:3] == ':': path = path[1:] path = path.replace('/', '\\') else: path = os.fsencode(path) return path def _setstat(path: Union[int, _SFTPPath], attrs: 'SFTPAttrs', *, follow_symlinks: bool = True) -> None: """Utility function to set file attributes""" if attrs.size is not None: os.truncate(path, attrs.size) uid = _lookup_uid(attrs.owner) if attrs.uid is None else attrs.uid gid = _lookup_gid(attrs.group) if attrs.gid is None else attrs.gid atime_ns = _tuple_to_nsec(attrs.atime, attrs.atime_ns) \ if attrs.atime is not None else None mtime_ns = _tuple_to_nsec(attrs.mtime, attrs.mtime_ns) \ if attrs.mtime is not None else None if ((atime_ns is None and mtime_ns is not None) or (atime_ns is not None and mtime_ns is None)): stat_result = os.stat(path, follow_symlinks=follow_symlinks) if atime_ns is None and mtime_ns is not None: atime_ns = stat_result.st_atime_ns if atime_ns is not None and mtime_ns is None: mtime_ns = stat_result.st_mtime_ns if uid is not None and gid is not None: try: os.chown(path, uid, gid, follow_symlinks=follow_symlinks) except NotImplementedError: # pragma: no cover pass except AttributeError: # pragma: no cover raise NotImplementedError from None if attrs.permissions is not None: try: os.chmod(path, stat.S_IMODE(attrs.permissions), follow_symlinks=follow_symlinks) except NotImplementedError: # pragma: no cover pass if atime_ns is not None and mtime_ns is not None: try: os.utime(path, 
ns=(atime_ns, mtime_ns), follow_symlinks=follow_symlinks) except NotImplementedError: # pragma: no cover pass class _SFTPParallelIO(Generic[_T]): """Parallelize I/O requests on files This class issues parallel read and write requests on files. """ def __init__(self, block_size: int, max_requests: int, offset: int, size: int): self._block_size = block_size self._max_requests = max_requests self._offset = offset self._bytes_left = size self._pending: Set['asyncio.Task[Tuple[int, int, int, _T]]'] = set() async def _start_task(self, offset: int, size: int) -> \ Tuple[int, int, int, _T]: """Start a task to perform file I/O on a particular byte range""" count, result = await self.run_task(offset, size) return offset, size, count, result def _start_tasks(self) -> None: """Create parallel file I/O tasks""" while self._bytes_left and len(self._pending) < self._max_requests: size = min(self._bytes_left, self._block_size) task = asyncio.ensure_future(self._start_task(self._offset, size)) self._pending.add(task) self._offset += size self._bytes_left -= size async def run_task(self, offset: int, size: int) -> Tuple[int, _T]: """Perform file I/O on a particular byte range""" raise NotImplementedError async def iter(self) -> AsyncIterator[Tuple[int, _T]]: """Perform file I/O and return async iterator of results""" self._start_tasks() while self._pending: done, self._pending = await asyncio.wait( self._pending, return_when=asyncio.FIRST_COMPLETED) exceptions = [] for task in done: try: offset, size, count, result = task.result() yield offset, result if count and count < size: self._pending.add(asyncio.ensure_future( self._start_task(offset+count, size-count))) except SFTPEOFError: self._bytes_left = 0 except (OSError, SFTPError) as exc: exceptions.append(exc) if exceptions: for task in self._pending: task.cancel() raise exceptions[0] self._start_tasks() class _SFTPFileReader(_SFTPParallelIO[bytes]): """Parallelized SFTP file reader""" def __init__(self, block_size: int, max_requests: int, handler: 'SFTPClientHandler', handle: bytes, offset: int, size: int): super().__init__(block_size, max_requests, offset, size) self._handler = handler self._handle = handle self._start = offset async def run_task(self, offset: int, size: int) -> Tuple[int, bytes]: """Read a block of the file""" data, _ = await self._handler.read(self._handle, offset, size) return len(data), data async def run(self) -> bytes: """Reassemble and return data from parallel reads""" result = bytearray() async for offset, data in self.iter(): pos = offset - self._start pad = pos - len(result) if pad > 0: result += pad * b'\0' result[pos:pos+len(data)] = data return bytes(result) class _SFTPFileWriter(_SFTPParallelIO[int]): """Parallelized SFTP file writer""" def __init__(self, block_size: int, max_requests: int, handler: 'SFTPClientHandler', handle: bytes, offset: int, data: bytes): super().__init__(block_size, max_requests, offset, len(data)) self._handler = handler self._handle = handle self._start = offset self._data = data async def run_task(self, offset: int, size: int) -> Tuple[int, int]: """Write a block to the file""" pos = offset - self._start await self._handler.write(self._handle, offset, self._data[pos:pos+size]) return size, size async def run(self): """Perform parallel writes""" async for _ in self.iter(): pass class _SFTPFileCopier(_SFTPParallelIO[int]): """SFTP file copier This class parforms an SFTP file copy, initiating multiple read and write requests to copy chunks of the file in parallel. 
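Conceptually, and ignoring the request limiting, short-read handling, and error handling done here, the parallel copy behaves like the following simplified sketch (the helper name is hypothetical and this is not the actual implementation)::

    import asyncio

    async def copy_in_chunks(src, dst, size, block_size):
        async def copy_block(offset):
            data = await src.read(block_size, offset)
            await dst.write(data, offset)

        # Issue one read/write pair per block and run them concurrently
        await asyncio.gather(*(copy_block(offset)
                               for offset in range(0, size, block_size)))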
""" def __init__(self, block_size: int, max_requests: int, offset: int, total_bytes: int, srcfs: _SFTPFSProtocol, dstfs: _SFTPFSProtocol, srcpath: bytes, dstpath: bytes, progress_handler: SFTPProgressHandler): super().__init__(block_size, max_requests, offset, total_bytes) self._srcfs = srcfs self._dstfs = dstfs self._srcpath = srcpath self._dstpath = dstpath self._src: Optional[SFTPFileProtocol] = None self._dst: Optional[SFTPFileProtocol] = None self._bytes_copied = 0 self._total_bytes = total_bytes self._progress_handler = progress_handler async def run_task(self, offset: int, size: int) -> Tuple[int, int]: """Copy a block of the source file""" assert self._src is not None assert self._dst is not None data = await self._src.read(size, offset) await self._dst.write(data, offset) datalen = len(data) return datalen, datalen async def run(self) -> None: """Perform parallel file copy""" try: self._src = await self._srcfs.open(self._srcpath, 'rb', block_size=0) self._dst = await self._dstfs.open(self._dstpath, 'wb', block_size=0) if self._progress_handler and self._total_bytes == 0: self._progress_handler(self._srcpath, self._dstpath, 0, 0) if self._srcfs == self._dstfs and \ isinstance(self._srcfs, SFTPClient) and \ self._srcfs.supports_remote_copy: await self._srcfs.remote_copy(cast(SFTPClientFile, self._src), cast(SFTPClientFile, self._dst)) self._bytes_copied = self._total_bytes if self._progress_handler: self._progress_handler(self._srcpath, self._dstpath, self._bytes_copied, self._total_bytes) else: async for _, datalen in self.iter(): if datalen: self._bytes_copied += datalen if self._progress_handler: self._progress_handler(self._srcpath, self._dstpath, self._bytes_copied, self._total_bytes) if self._bytes_copied != self._total_bytes: exc = SFTPFailure('Unexpected EOF during file copy') setattr(exc, 'filename', self._srcpath) setattr(exc, 'offset', self._bytes_copied) raise exc finally: if self._src: # pragma: no branch await self._src.close() if self._dst: # pragma: no branch await self._dst.close() class SFTPError(Error): """SFTP error This exception is raised when an error occurs while processing an SFTP request. Exception codes should be taken from :ref:`SFTP error codes `. :param code: Disconnect reason, taken from :ref:`disconnect reason codes ` :param reason: A human-readable reason for the disconnect :param lang: (optional) The language the reason is in :type code: `int` :type reason: `str` :type lang: `str` """ @staticmethod def construct(packet: SSHPacket) -> Optional['SFTPError']: """Construct an SFTPError from an FXP_STATUS response""" code = packet.get_uint32() if packet: try: reason = packet.get_string().decode('utf-8') lang = packet.get_string().decode('ascii') except UnicodeDecodeError: raise SFTPBadMessage('Invalid status message') from None else: # Some servers may not always send reason and lang (usually # when responding with FX_OK). Tolerate this, automatically # filling in empty strings for them if they're not present. 
reason = '' lang = '' if code == FX_OK: return None else: try: exc = _sftp_error_map[code](reason, lang) except KeyError: exc = SFTPError(code, f'{reason} (error {code})', lang) exc.decode(packet) return exc def encode(self, version: int) -> bytes: """Encode an SFTPError as bytes in an SSHPacket""" if self.code == FX_NOT_A_DIRECTORY and version < 6: code = FX_NO_SUCH_FILE elif (self.code <= FX_V6_END and ((self.code > FX_V3_END and version <= 3) or (self.code > FX_V4_END and version <= 4) or (self.code > FX_V5_END and version <= 5))): code = FX_FAILURE else: code = self.code return UInt32(code) + String(self.reason) + String(self.lang) def decode(self, packet: SSHPacket) -> None: """Decode error-specific data""" # pylint: disable=no-self-use # By default, expect no error-specific data class SFTPEOFError(SFTPError): """SFTP EOF error This exception is raised when end of file is reached when reading a file or directory. :param reason: (optional) Details about the EOF :param lang: (optional) The language the reason is in :type reason: `str` :type lang: `str` """ def __init__(self, reason: str = '', lang: str = DEFAULT_LANG): super().__init__(FX_EOF, reason, lang) class SFTPNoSuchFile(SFTPError): """SFTP no such file This exception is raised when the requested file is not found. :param reason: Details about the missing file :param lang: (optional) The language the reason is in :type reason: `str` :type lang: `str` """ def __init__(self, reason: str, lang: str = DEFAULT_LANG): super().__init__(FX_NO_SUCH_FILE, reason, lang) class SFTPPermissionDenied(SFTPError): """SFTP permission denied This exception is raised when the permissions are not available to perform the requested operation. :param reason: Details about the invalid permissions :param lang: (optional) The language the reason is in :type reason: `str` :type lang: `str` """ def __init__(self, reason: str, lang: str = DEFAULT_LANG): super().__init__(FX_PERMISSION_DENIED, reason, lang) class SFTPFailure(SFTPError): """SFTP failure This exception is raised when an unexpected SFTP failure occurs. :param reason: Details about the failure :param lang: (optional) The language the reason is in :type reason: `str` :type lang: `str` """ def __init__(self, reason: str, lang: str = DEFAULT_LANG): super().__init__(FX_FAILURE, reason, lang) class SFTPBadMessage(SFTPError): """SFTP bad message This exception is raised when an invalid SFTP message is received. :param reason: Details about the invalid message :param lang: (optional) The language the reason is in :type reason: `str` :type lang: `str` """ def __init__(self, reason: str, lang: str = DEFAULT_LANG): super().__init__(FX_BAD_MESSAGE, reason, lang) class SFTPNoConnection(SFTPError): """SFTP no connection This exception is raised when an SFTP request is made on a closed SSH connection. :param reason: Details about the closed connection :param lang: (optional) The language the reason is in :type reason: `str` :type lang: `str` """ def __init__(self, reason: str, lang: str = DEFAULT_LANG): super().__init__(FX_NO_CONNECTION, reason, lang) class SFTPConnectionLost(SFTPError): """SFTP connection lost This exception is raised when the SSH connection is lost or closed while making an SFTP request. 
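For example (a minimal sketch, where `sftp` is an assumed :class:`SFTPClient` and the file name is illustrative), a caller might distinguish this condition from other SFTP errors::

    async def download(sftp):
        try:
            await sftp.get('example.txt')
        except SFTPConnectionLost:
            # The SSH connection went away mid-transfer; the caller
            # would need to reconnect before retrying
            raise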
:param reason: Details about the connection failure :param lang: (optional) The language the reason is in :type reason: `str` :type lang: `str` """ def __init__(self, reason: str, lang: str = DEFAULT_LANG): super().__init__(FX_CONNECTION_LOST, reason, lang) class SFTPOpUnsupported(SFTPError): """SFTP operation unsupported This exception is raised when the requested SFTP operation is not supported. :param reason: Details about the unsupported operation :param lang: (optional) The language the reason is in :type reason: `str` :type lang: `str` """ def __init__(self, reason: str, lang: str = DEFAULT_LANG): super().__init__(FX_OP_UNSUPPORTED, reason, lang) class SFTPInvalidHandle(SFTPError): """SFTP invalid handle (SFTPv4+) This exception is raised when the handle provided is invalid. :param reason: Details about the invalid handle :param lang: (optional) The language the reason is in :type reason: `str` :type lang: `str` """ def __init__(self, reason: str, lang: str = DEFAULT_LANG): super().__init__(FX_INVALID_HANDLE, reason, lang) class SFTPNoSuchPath(SFTPError): """SFTP no such path (SFTPv4+) This exception is raised when the requested path is not found. :param reason: Details about the missing path :param lang: (optional) The language the reason is in :type reason: `str` :type lang: `str` """ def __init__(self, reason: str, lang: str = DEFAULT_LANG): super().__init__(FX_NO_SUCH_PATH, reason, lang) class SFTPFileAlreadyExists(SFTPError): """SFTP file already exists (SFTPv4+) This exception is raised when the requested file already exists. :param reason: Details about the existing file :param lang: (optional) The language the reason is in :type reason: `str` :type lang: `str` """ def __init__(self, reason: str, lang: str = DEFAULT_LANG): super().__init__(FX_FILE_ALREADY_EXISTS, reason, lang) class SFTPWriteProtect(SFTPError): """SFTP write protect (SFTPv4+) This exception is raised when a write is attempted to a file on read-only or write protected media. :param reason: Details about the requested file :param lang: (optional) The language the reason is in :type reason: `str` :type lang: `str` """ def __init__(self, reason: str, lang: str = DEFAULT_LANG): super().__init__(FX_WRITE_PROTECT, reason, lang) class SFTPNoMedia(SFTPError): """SFTP no media (SFTPv4+) This exception is raised when there is no media in the requested drive. :param reason: Details about the requested drive :param lang: (optional) The language the reason is in :type reason: `str` :type lang: `str` """ def __init__(self, reason: str, lang: str = DEFAULT_LANG): super().__init__(FX_NO_MEDIA, reason, lang) class SFTPNoSpaceOnFilesystem(SFTPError): """SFTP no space on filesystem (SFTPv5+) This exception is raised when there is no space available on the filesystem a file is being written to. :param reason: Details about the filesystem which has filled up :param lang: (optional) The language the reason is in :type reason: `str` :type lang: `str` """ def __init__(self, reason: str, lang: str = DEFAULT_LANG): super().__init__(FX_NO_SPACE_ON_FILESYSTEM, reason, lang) class SFTPQuotaExceeded(SFTPError): """SFTP quota exceeded (SFTPv5+) This exception is raised when the user's storage quota is exceeded. 
:param reason: Details about the exceeded quota :param lang: (optional) The language the reason is in :type reason: `str` :type lang: `str` """ def __init__(self, reason: str, lang: str = DEFAULT_LANG): super().__init__(FX_QUOTA_EXCEEDED, reason, lang) class SFTPUnknownPrincipal(SFTPError): """SFTP unknown principal (SFTPv5+) This exception is raised when a file owner or group is not reocgnized. :param reason: Details about the unknown principal :param lang: (optional) The language the reason is in :param unknown_names: (optional) A list of unknown principal names :type reason: `str` :type lang: `str` :type unknown_names: list of `str` """ def __init__(self, reason: str, lang: str = DEFAULT_LANG, unknown_names: Sequence[str] = ()): super().__init__(FX_UNKNOWN_PRINCIPAL, reason, lang) self.unknown_names = unknown_names def encode(self, version: int) -> bytes: """Encode an SFTPUnknownPrincipal as bytes in an SSHPacket""" return super().encode(version) + \ b''.join(String(name) for name in self.unknown_names) def decode(self, packet: SSHPacket) -> None: """Decode error-specific data""" self.unknown_names = [] try: while packet: self.unknown_names.append( packet.get_string().decode('utf-8')) except UnicodeDecodeError: raise SFTPBadMessage('Invalid status message') from None class SFTPLockConflict(SFTPError): """SFTP lock conflict (SFTPv5+) This exception is raised when a requested lock is held by another process. :param reason: Details about the conflicting lock :param lang: (optional) The language the reason is in :type reason: `str` :type lang: `str` """ def __init__(self, reason: str, lang: str = DEFAULT_LANG): super().__init__(FX_LOCK_CONFLICT, reason, lang) class SFTPDirNotEmpty(SFTPError): """SFTP directory not empty (SFTPv6+) This exception is raised when a directory is not empty. :param reason: Details about the non-empty directory :param lang: (optional) The language the reason is in :type reason: `str` :type lang: `str` """ def __init__(self, reason: str, lang: str = DEFAULT_LANG): super().__init__(FX_DIR_NOT_EMPTY, reason, lang) class SFTPNotADirectory(SFTPError): """SFTP not a directory (SFTPv6+) This exception is raised when a specified file is not a directory where one was expected. :param reason: Details about the file expected to be a directory :param lang: (optional) The language the reason is in :type reason: `str` :type lang: `str` """ def __init__(self, reason: str, lang: str = DEFAULT_LANG): super().__init__(FX_NOT_A_DIRECTORY, reason, lang) class SFTPInvalidFilename(SFTPError): """SFTP invalid filename (SFTPv6+) This exception is raised when a filename is not valid. :param reason: Details about the invalid filename :param lang: (optional) The language the reason is in :type reason: `str` :type lang: `str` """ def __init__(self, reason: str, lang: str = DEFAULT_LANG): super().__init__(FX_INVALID_FILENAME, reason, lang) class SFTPLinkLoop(SFTPError): """SFTP link loop (SFTPv6+) This exception is raised when a symbolic link loop is detected. :param reason: Details about the link loop :param lang: (optional) The language the reason is in :type reason: `str` :type lang: `str` """ def __init__(self, reason: str, lang: str = DEFAULT_LANG): super().__init__(FX_LINK_LOOP, reason, lang) class SFTPCannotDelete(SFTPError): """SFTP cannot delete (SFTPv6+) This exception is raised when a file cannot be deleted. 
:param reason: Details about the undeletable file :param lang: (optional) The language the reason is in :type reason: `str` :type lang: `str` """ def __init__(self, reason: str, lang: str = DEFAULT_LANG): super().__init__(FX_CANNOT_DELETE, reason, lang) class SFTPInvalidParameter(SFTPError): """SFTP invalid parameter (SFTPv6+) This exception is raised when parameters in a request are out of range or incompatible with one another. :param reason: Details about the invalid parameter :param lang: (optional) The language the reason is in :type reason: `str` :type lang: `str` """ def __init__(self, reason: str, lang: str = DEFAULT_LANG): super().__init__(FX_INVALID_PARAMETER, reason, lang) class SFTPFileIsADirectory(SFTPError): """SFTP file is a directory (SFTPv6+) This exception is raised when a specified file is a directory where one isn't allowed. :param reason: Details about the unexpected directory :param lang: (optional) The language the reason is in :type reason: `str` :type lang: `str` """ def __init__(self, reason: str, lang: str = DEFAULT_LANG): super().__init__(FX_FILE_IS_A_DIRECTORY, reason, lang) class SFTPByteRangeLockConflict(SFTPError): """SFTP byte range lock conflict (SFTPv6+) This exception is raised when a read or write request overlaps a byte range lock held by another process. :param reason: Details about the conflicting byte range lock :param lang: (optional) The language the reason is in :type reason: `str` :type lang: `str` """ def __init__(self, reason: str, lang: str = DEFAULT_LANG): super().__init__(FX_BYTE_RANGE_LOCK_CONFLICT, reason, lang) class SFTPByteRangeLockRefused(SFTPError): """SFTP byte range lock refused (SFTPv6+) This exception is raised when a request for a byte range lock was refused. :param reason: Details about the refused byte range lock :param lang: (optional) The language the reason is in :type reason: `str` :type lang: `str` """ def __init__(self, reason: str, lang: str = DEFAULT_LANG): super().__init__(FX_BYTE_RANGE_LOCK_REFUSED, reason, lang) class SFTPDeletePending(SFTPError): """SFTP delete pending (SFTPv6+) This exception is raised when an operation was attempted on a file for which a delete operation is pending. another process. :param reason: Details about the file being deleted :param lang: (optional) The language the reason is in :type reason: `str` :type lang: `str` """ def __init__(self, reason: str, lang: str = DEFAULT_LANG): super().__init__(FX_DELETE_PENDING, reason, lang) class SFTPFileCorrupt(SFTPError): """SFTP file corrupt (SFTPv6+) This exception is raised when filesystem corruption is detected. :param reason: Details about the corrupted filesystem :param lang: (optional) The language the reason is in :type reason: `str` :type lang: `str` """ def __init__(self, reason: str, lang: str = DEFAULT_LANG): super().__init__(FX_FILE_CORRUPT, reason, lang) class SFTPOwnerInvalid(SFTPError): """SFTP owner invalid (SFTPv6+) This exception is raised when a principal cannot be assigned as the owner of a file. :param reason: Details about the principal being set as a file's owner :param lang: (optional) The language the reason is in :type reason: `str` :type lang: `str` """ def __init__(self, reason: str, lang: str = DEFAULT_LANG): super().__init__(FX_OWNER_INVALID, reason, lang) class SFTPGroupInvalid(SFTPError): """SFTP group invalid (SFTPv6+) This exception is raised when a principal cannot be assigned as the primary group of a file. 
:param reason: Details about the principal being set as a file's group :param lang: (optional) The language the reason is in :type reason: `str` :type lang: `str` """ def __init__(self, reason: str, lang: str = DEFAULT_LANG): super().__init__(FX_GROUP_INVALID, reason, lang) class SFTPNoMatchingByteRangeLock(SFTPError): """SFTP no matching byte range lock (SFTPv6+) This exception is raised when an unlock is requested for a byte range lock which is not currently held. :param reason: Details about the byte range lock being released :param lang: (optional) The language the reason is in :type reason: `str` :type lang: `str` """ def __init__(self, reason: str, lang: str = DEFAULT_LANG): super().__init__(FX_NO_MATCHING_BYTE_RANGE_LOCK, reason, lang) _sftp_error_map: Dict[int, Callable[[str, str], SFTPError]] = { FX_EOF: SFTPEOFError, FX_NO_SUCH_FILE: SFTPNoSuchFile, FX_PERMISSION_DENIED: SFTPPermissionDenied, FX_FAILURE: SFTPFailure, FX_BAD_MESSAGE: SFTPBadMessage, FX_NO_CONNECTION: SFTPNoConnection, FX_CONNECTION_LOST: SFTPConnectionLost, FX_OP_UNSUPPORTED: SFTPOpUnsupported, FX_INVALID_HANDLE: SFTPInvalidHandle, FX_NO_SUCH_PATH: SFTPNoSuchPath, FX_FILE_ALREADY_EXISTS: SFTPFileAlreadyExists, FX_WRITE_PROTECT: SFTPWriteProtect, FX_NO_MEDIA: SFTPNoMedia, FX_NO_SPACE_ON_FILESYSTEM: SFTPNoSpaceOnFilesystem, FX_QUOTA_EXCEEDED: SFTPQuotaExceeded, FX_UNKNOWN_PRINCIPAL: SFTPUnknownPrincipal, FX_LOCK_CONFLICT: SFTPLockConflict, FX_DIR_NOT_EMPTY: SFTPDirNotEmpty, FX_NOT_A_DIRECTORY: SFTPNotADirectory, FX_INVALID_FILENAME: SFTPInvalidFilename, FX_LINK_LOOP: SFTPLinkLoop, FX_CANNOT_DELETE: SFTPCannotDelete, FX_INVALID_PARAMETER: SFTPInvalidParameter, FX_FILE_IS_A_DIRECTORY: SFTPFileIsADirectory, FX_BYTE_RANGE_LOCK_CONFLICT: SFTPByteRangeLockConflict, FX_BYTE_RANGE_LOCK_REFUSED: SFTPByteRangeLockRefused, FX_DELETE_PENDING: SFTPDeletePending, FX_FILE_CORRUPT: SFTPFileCorrupt, FX_OWNER_INVALID: SFTPOwnerInvalid, FX_GROUP_INVALID: SFTPGroupInvalid, FX_NO_MATCHING_BYTE_RANGE_LOCK: SFTPNoMatchingByteRangeLock } class SFTPAttrs(Record): """SFTP file attributes SFTPAttrs is a simple record class with the following fields: ============ ================================================= ====== Field Description Type ============ ================================================= ====== type File type (SFTPv4+) byte size File size in bytes uint64 alloc_size Allocation file size in bytes (SFTPv6+) uint64 uid User id of file owner uint32 gid Group id of file owner uint32 owner User name of file owner (SFTPv4+) string group Group name of file owner (SFTPv4+) string permissions Bit mask of POSIX file permissions uint32 atime Last access time, UNIX epoch seconds uint64 atime_ns Last access time, nanoseconds (SFTPv4+) uint32 crtime Creation time, UNIX epoch seconds (SFTPv4+) uint64 crtime_ns Creation time, nanoseconds (SFTPv4+) uint32 mtime Last modify time, UNIX epoch seconds uint64 mtime_ns Last modify time, nanoseconds (SFTPv4+) uint32 ctime Last change time, UNIX epoch seconds (SFTPv6+) uint64 ctime_ns Last change time, nanoseconds (SFTPv6+) uint32 acl Access control list for file (SFTPv4+) bytes attrib_bits Attribute bits set for file (SFTPv5+) uint32 attrib_valid Valid attribute bits for file (SFTPv5+) uint32 text_hint Text/binary hint for file (SFTPv6+) byte mime_type MIME type for file (SFTPv6+) string nlink Link count for file (SFTPv6+) uint32 untrans_name Untranslated name for file (SFTPv6+) bytes ============ ================================================= ====== Extended attributes can also be added via a field 
named `extended` which is a list of bytes name/value pairs. When setting attributes using an :class:`SFTPAttrs`, only fields which have been initialized will be changed on the selected file. """ type: int = FILEXFER_TYPE_UNKNOWN size: Optional[int] alloc_size: Optional[int] uid: Optional[int] gid: Optional[int] owner: Optional[str] group: Optional[str] permissions: Optional[int] atime: Optional[int] atime_ns: Optional[int] crtime: Optional[int] crtime_ns: Optional[int] mtime: Optional[int] mtime_ns: Optional[int] ctime: Optional[int] ctime_ns: Optional[int] acl: Optional[bytes] attrib_bits: Optional[int] attrib_valid: Optional[int] text_hint: Optional[int] mime_type: Optional[str] nlink: Optional[int] untrans_name: Optional[bytes] extended: Sequence[Tuple[bytes, bytes]] = () def _format_ns(self, k: str): """Convert epoch seconds & nanoseconds to a string date & time""" result = time.ctime(getattr(self, k)) nsec = getattr(self, k + '_ns') if result and nsec: result = result[:19] + f'.{nsec:09d}' + result[19:] return result def _format(self, k: str, v: object) -> Optional[str]: """Convert attributes to more readable values""" if v is None or k == 'extended' and not v: return None if k == 'type': return _file_types.get(cast(int, v), str(v)) \ if v != FILEXFER_TYPE_UNKNOWN else None elif k == 'permissions': return f'{cast(int, v):04o}' elif k in ('atime', 'crtime', 'mtime', 'ctime'): return self._format_ns(k) elif k in ('atime_ns', 'crtime_ns', 'mtime_ns', 'ctime_ns'): return None else: return str(v) or None def encode(self, sftp_version: int) -> bytes: """Encode SFTP attributes as bytes in an SSH packet""" flags = 0 attrs = [] if sftp_version >= 4: if sftp_version < 5 and self.type >= FILEXFER_TYPE_SOCKET: filetype = FILEXFER_TYPE_SPECIAL else: filetype = self.type attrs.append(Byte(filetype)) if self.size is not None: flags |= FILEXFER_ATTR_SIZE attrs.append(UInt64(self.size)) if self.alloc_size is not None: flags |= FILEXFER_ATTR_ALLOCATION_SIZE attrs.append(UInt64(self.alloc_size)) if sftp_version == 3: if self.uid is not None and self.gid is not None: flags |= FILEXFER_ATTR_UIDGID attrs.append(UInt32(self.uid) + UInt32(self.gid)) elif self.owner is not None and self.group is not None: raise ValueError('Setting owner and group requires SFTPv4 ' 'or later') else: if self.owner is not None and self.group is not None: flags |= FILEXFER_ATTR_OWNERGROUP attrs.append(String(self.owner) + String(self.group)) elif self.uid is not None and self.gid is not None: flags |= FILEXFER_ATTR_OWNERGROUP attrs.append(String(str(self.uid)) + String(str(self.gid))) if self.permissions is not None: flags |= FILEXFER_ATTR_PERMISSIONS attrs.append(UInt32(self.permissions)) if sftp_version == 3: if self.atime is not None and self.mtime is not None: flags |= FILEXFER_ATTR_ACMODTIME attrs.append(UInt32(int(self.atime)) + UInt32(int(self.mtime))) else: subsecond = (self.atime_ns is not None or self.crtime_ns is not None or self.mtime_ns is not None or self.ctime_ns is not None) if subsecond: flags |= FILEXFER_ATTR_SUBSECOND_TIMES if self.atime is not None: flags |= FILEXFER_ATTR_ACCESSTIME attrs.append(UInt64(int(self.atime))) if subsecond: attrs.append(UInt32(self.atime_ns or 0)) if self.crtime is not None: flags |= FILEXFER_ATTR_CREATETIME attrs.append(UInt64(int(self.crtime))) if subsecond: attrs.append(UInt32(self.crtime_ns or 0)) if self.mtime is not None: flags |= FILEXFER_ATTR_MODIFYTIME attrs.append(UInt64(int(self.mtime))) if subsecond: attrs.append(UInt32(self.mtime_ns or 0)) if sftp_version >= 6 and 
self.ctime is not None: flags |= FILEXFER_ATTR_CTIME attrs.append(UInt64(int(self.ctime))) if subsecond: attrs.append(UInt32(self.ctime_ns or 0)) if sftp_version >= 4 and self.acl is not None: flags |= FILEXFER_ATTR_ACL attrs.append(String(self.acl)) if sftp_version >= 5 and \ self.attrib_bits is not None and \ self.attrib_valid is not None: flags |= FILEXFER_ATTR_BITS attrs.append(UInt32(self.attrib_bits) + UInt32(self.attrib_valid)) if sftp_version >= 6: if self.text_hint is not None: flags |= FILEXFER_ATTR_TEXT_HINT attrs.append(Byte(self.text_hint)) if self.mime_type is not None: flags |= FILEXFER_ATTR_MIME_TYPE attrs.append(String(self.mime_type)) if self.nlink is not None: flags |= FILEXFER_ATTR_LINK_COUNT attrs.append(UInt32(self.nlink)) if self.untrans_name is not None: flags |= FILEXFER_ATTR_UNTRANSLATED_NAME attrs.append(String(self.untrans_name)) if self.extended: flags |= FILEXFER_ATTR_EXTENDED attrs.append(UInt32(len(self.extended))) attrs.extend(String(type) + String(data) for type, data in self.extended) return UInt32(flags) + b''.join(attrs) @classmethod def decode(cls, packet: SSHPacket, sftp_version: int) -> 'SFTPAttrs': """Decode bytes in an SSH packet as SFTP attributes""" flags = packet.get_uint32() attrs = cls() # Work around a bug seen in a Huawei SFTP server where # FILEXFER_ATTR_MODIFYTIME is included in flags, even though # the SFTP version is set to 3. That flag is only defined for # SFTPv4 and later. if sftp_version == 3 and flags & (FILEXFER_ATTR_ACMODTIME | FILEXFER_ATTR_MODIFYTIME): flags &= ~FILEXFER_ATTR_MODIFYTIME unsupported_attrs = flags & ~_valid_attr_flags[sftp_version] if unsupported_attrs: raise SFTPBadMessage( f'Unsupported attribute flags: 0x{unsupported_attrs:08x}') if sftp_version >= 4: attrs.type = packet.get_byte() if flags & FILEXFER_ATTR_SIZE: attrs.size = packet.get_uint64() if flags & FILEXFER_ATTR_ALLOCATION_SIZE: attrs.alloc_size = packet.get_uint64() if sftp_version == 3: if flags & FILEXFER_ATTR_UIDGID: attrs.uid = packet.get_uint32() attrs.gid = packet.get_uint32() else: if flags & FILEXFER_ATTR_OWNERGROUP: owner = packet.get_string() try: attrs.owner = owner.decode('utf-8') except UnicodeDecodeError: raise SFTPOwnerInvalid('Invalid owner name: ' + owner.decode('utf-8', 'backslashreplace')) from None group = packet.get_string() try: attrs.group = group.decode('utf-8') except UnicodeDecodeError: raise SFTPGroupInvalid('Invalid group name: ' + group.decode('utf-8', 'backslashreplace')) from None if flags & FILEXFER_ATTR_PERMISSIONS: mode = packet.get_uint32() if sftp_version == 3: attrs.type = _stat_mode_to_filetype(mode) attrs.permissions = mode & 0xffff else: attrs.permissions = mode & 0xfff if sftp_version == 3: if flags & FILEXFER_ATTR_ACMODTIME: attrs.atime = packet.get_uint32() attrs.mtime = packet.get_uint32() else: if flags & FILEXFER_ATTR_ACCESSTIME: attrs.atime = packet.get_uint64() if flags & FILEXFER_ATTR_SUBSECOND_TIMES: attrs.atime_ns = packet.get_uint32() if flags & FILEXFER_ATTR_CREATETIME: attrs.crtime = packet.get_uint64() if flags & FILEXFER_ATTR_SUBSECOND_TIMES: attrs.crtime_ns = packet.get_uint32() if flags & FILEXFER_ATTR_MODIFYTIME: attrs.mtime = packet.get_uint64() if flags & FILEXFER_ATTR_SUBSECOND_TIMES: attrs.mtime_ns = packet.get_uint32() if flags & FILEXFER_ATTR_CTIME: attrs.ctime = packet.get_uint64() if flags & FILEXFER_ATTR_SUBSECOND_TIMES: attrs.ctime_ns = packet.get_uint32() if flags & FILEXFER_ATTR_ACL: attrs.acl = packet.get_string() if flags & FILEXFER_ATTR_BITS: attrs.attrib_bits = 
packet.get_uint32() attrs.attrib_valid = packet.get_uint32() if flags & FILEXFER_ATTR_TEXT_HINT: attrs.text_hint = packet.get_byte() if flags & FILEXFER_ATTR_MIME_TYPE: try: attrs.mime_type = packet.get_string().decode('utf-8') except UnicodeDecodeError: raise SFTPBadMessage('Invalid MIME type') from None if flags & FILEXFER_ATTR_LINK_COUNT: attrs.nlink = packet.get_uint32() if flags & FILEXFER_ATTR_UNTRANSLATED_NAME: attrs.untrans_name = packet.get_string() if flags & FILEXFER_ATTR_EXTENDED: count = packet.get_uint32() attrs.extended = [] for _ in range(count): attr = packet.get_string() data = packet.get_string() attrs.extended.append((attr, data)) return attrs @classmethod def from_local(cls, result: os.stat_result) -> 'SFTPAttrs': """Convert from local stat attributes""" mode = result.st_mode filetype = _stat_mode_to_filetype(mode) if sys.platform == 'win32': # pragma: no cover uid = 0 gid = 0 owner = '' group = '' else: uid = result.st_uid gid = result.st_gid owner = _lookup_user(uid) group = _lookup_group(gid) atime, atime_ns = _nsec_to_tuple(result.st_atime_ns) mtime, mtime_ns = _nsec_to_tuple(result.st_mtime_ns) ctime, ctime_ns = _nsec_to_tuple(result.st_ctime_ns) if sys.platform == 'win32': # pragma: no cover crtime, crtime_ns = ctime, ctime_ns elif hasattr(result, 'st_birthtime'): # pragma: no cover crtime, crtime_ns = _float_sec_to_tuple(result.st_birthtime) else: # pragma: no cover crtime, crtime_ns = mtime, mtime_ns return cls(filetype, result.st_size, None, uid, gid, owner, group, mode, atime, atime_ns, crtime, crtime_ns, mtime, mtime_ns, ctime, ctime_ns, None, None, None, None, None, result.st_nlink, None) class SFTPVFSAttrs(Record): """SFTP file system attributes SFTPVFSAttrs is a simple record class with the following fields: ============ =========================================== ====== Field Description Type ============ =========================================== ====== bsize File system block size (I/O size) uint64 frsize Fundamental block size (allocation size) uint64 blocks Total data blocks (in frsize units) uint64 bfree Free data blocks uint64 bavail Available data blocks (for non-root) uint64 files Total file inodes uint64 ffree Free file inodes uint64 favail Available file inodes (for non-root) uint64 fsid File system id uint64 flags File system flags (read-only, no-setuid) uint64 namemax Maximum filename length uint64 ============ =========================================== ====== """ bsize: int = 0 frsize: int = 0 blocks: int = 0 bfree: int = 0 bavail: int = 0 files: int = 0 ffree: int = 0 favail: int = 0 fsid: int = 0 flags: int = 0 namemax: int = 0 def encode(self, sftp_version: int) -> bytes: """Encode SFTP statvfs attributes as bytes in an SSH packet""" # pylint: disable=unused-argument return b''.join((UInt64(self.bsize), UInt64(self.frsize), UInt64(self.blocks), UInt64(self.bfree), UInt64(self.bavail), UInt64(self.files), UInt64(self.ffree), UInt64(self.favail), UInt64(self.fsid), UInt64(self.flags), UInt64(self.namemax))) @classmethod def decode(cls, packet: SSHPacket, sftp_version: int) -> 'SFTPVFSAttrs': """Decode bytes in an SSH packet as SFTP statvfs attributes""" # pylint: disable=unused-argument vfsattrs = cls() vfsattrs.bsize = packet.get_uint64() vfsattrs.frsize = packet.get_uint64() vfsattrs.blocks = packet.get_uint64() vfsattrs.bfree = packet.get_uint64() vfsattrs.bavail = packet.get_uint64() vfsattrs.files = packet.get_uint64() vfsattrs.ffree = packet.get_uint64() vfsattrs.favail = packet.get_uint64() vfsattrs.fsid = packet.get_uint64() 
vfsattrs.flags = packet.get_uint64() vfsattrs.namemax = packet.get_uint64() return vfsattrs @classmethod def from_local(cls, result: os.statvfs_result) -> 'SFTPVFSAttrs': """Convert from local statvfs attributes""" return cls(result.f_bsize, result.f_frsize, result.f_blocks, result.f_bfree, result.f_bavail, result.f_files, result.f_ffree, result.f_favail, 0, result.f_flag, result.f_namemax) class SFTPName(Record): """SFTP file name and attributes SFTPName is a simple record class with the following fields: ========= ================================== ================== Field Description Type ========= ================================== ================== filename Filename `str` or `bytes` longname Expanded form of filename & attrs `str` or `bytes` attrs File attributes :class:`SFTPAttrs` ========= ================================== ================== A list of these is returned by :meth:`readdir() ` in :class:`SFTPClient` when retrieving the contents of a directory. """ filename: BytesOrStr = '' longname: BytesOrStr = '' attrs: SFTPAttrs = SFTPAttrs() def _format(self, k: str, v: object) -> Optional[str]: """Convert name fields to more readable values""" if k == 'longname' and not v: return None if isinstance(v, bytes): v = v.decode('utf-8', 'backslashreplace') return str(v) or None def encode(self, sftp_version: int) -> bytes: """Encode an SFTP name as bytes in an SSH packet""" longname = String(self.longname) if sftp_version == 3 else b'' return (String(self.filename) + longname + self.attrs.encode(sftp_version)) @classmethod def decode(cls, packet: SSHPacket, sftp_version: int) -> 'SFTPName': """Decode bytes in an SSH packet as an SFTP name""" filename = packet.get_string() longname = packet.get_string() if sftp_version == 3 else None attrs = SFTPAttrs.decode(packet, sftp_version) return cls(filename, longname, attrs) class SFTPLimits(Record): """SFTP server limits SFTPLimits is a simple record class with the following fields: ================= ========================================= ====== Field Description Type ================= ========================================= ====== max_packet_len Max allowed size of an SFTP packet uint64 max_read_len Max allowed size of an SFTP read request uint64 max_write_len Max allowed size of an SFTP write request uint64 max_open_handles Max allowed number of open file handles uint64 ================= ========================================= ====== """ max_packet_len: int max_read_len: int max_write_len: int max_open_handles: int def encode(self, sftp_version: int) -> bytes: """Encode SFTP server limits in an SSH packet""" # pylint: disable=unused-argument return (UInt64(self.max_packet_len) + UInt64(self.max_read_len) + UInt64(self.max_write_len) + UInt64(self.max_open_handles)) @classmethod def decode(cls, packet: SSHPacket, sftp_version: int) -> 'SFTPLimits': """Decode bytes in an SSH packet as SFTP server limits""" # pylint: disable=unused-argument max_packet_len = packet.get_uint64() max_read_len = packet.get_uint64() max_write_len = packet.get_uint64() max_open_handles = packet.get_uint64() return cls(max_packet_len, max_read_len, max_write_len, max_open_handles) class SFTPGlob: """SFTP glob matcher""" def __init__(self, fs: _SFTPGlobProtocol, multiple=False): self._fs = fs self._multiple = multiple self._prev_matches: Set[bytes] = set() self._new_matches: List[SFTPName] = [] self._matched = False self._stat_cache: Dict[bytes, Optional[SFTPAttrs]] = {} self._scandir_cache: Dict[bytes, List[SFTPName]] = {} def _split(self, pattern: bytes) -> 
Tuple[bytes, _SFTPPatList]: """Split out exact parts of a glob pattern""" patlist: _SFTPPatList = [] if any(c in pattern for c in b'*?[]'): path = b'' plain: List[bytes] = [] for current in pattern.split(b'/'): if any(c in current for c in b'*?[]'): if plain: if patlist: patlist.append(plain) else: path = b'/'.join(plain) or b'/' plain = [] patlist.append(current) else: plain.append(current) if plain: patlist.append(plain) else: path = pattern return path, patlist def _report_match(self, path, attrs): """Report a matching name""" self._matched = True if self._multiple: if path not in self._prev_matches: self._prev_matches.add(path) else: return self._new_matches.append(SFTPName(path, attrs=attrs)) async def _stat(self, path) -> Optional[SFTPAttrs]: """Cache results of calls to stat""" try: return self._stat_cache[path] except KeyError: pass try: attrs = await self._fs.stat(path) except (SFTPNoSuchFile, SFTPPermissionDenied, SFTPNoSuchPath): attrs = None self._stat_cache[path] = attrs return attrs async def _scandir(self, path) -> AsyncIterator[SFTPName]: """Cache results of calls to scandir""" try: for entry in self._scandir_cache[path]: yield entry return except KeyError: pass entries: List[SFTPName] = [] try: async for entry in self._fs.scandir(path): entries.append(entry) yield entry except (SFTPNoSuchFile, SFTPPermissionDenied, SFTPNoSuchPath): pass self._scandir_cache[path] = entries async def _match_exact(self, path: bytes, pattern: Sequence[bytes], patlist: _SFTPPatList) -> None: """Match on an exact portion of a path""" newpath = posixpath.join(path, *pattern) newpatlist = patlist[1:] attrs = await self._stat(newpath) if attrs is None: return if newpatlist: if attrs.type == FILEXFER_TYPE_DIRECTORY: await self._match(newpath, attrs, newpatlist) else: self._report_match(newpath, attrs) async def _match_pattern(self, path: bytes, attrs: SFTPAttrs, pattern: bytes, patlist: _SFTPPatList) -> None: """Match on a pattern portion of a path""" newpatlist = patlist[1:] if pattern == b'**': if newpatlist: await self._match(path, attrs, newpatlist) else: self._report_match(path, attrs) async for entry in self._scandir(path or b'.'): filename = cast(bytes, entry.filename) if filename in (b'.', b'..'): continue if not pattern or fnmatch(filename, pattern): newpath = posixpath.join(path, filename) attrs = entry.attrs if pattern == b'**' and attrs.type == FILEXFER_TYPE_DIRECTORY: await self._match(newpath, attrs, patlist) elif newpatlist: if attrs.type == FILEXFER_TYPE_DIRECTORY: await self._match(newpath, attrs, newpatlist) else: self._report_match(newpath, attrs) async def _match(self, path: bytes, attrs: SFTPAttrs, patlist: _SFTPPatList) -> None: """Recursively match against a glob pattern""" pattern = patlist[0] if isinstance(pattern, list): await self._match_exact(path, pattern, patlist) else: await self._match_pattern(path, attrs, pattern, patlist) async def match(self, pattern: bytes, error_handler: SFTPErrorHandler = None, sftp_version = MIN_SFTP_VERSION) -> Sequence[SFTPName]: """Match against a glob pattern""" self._new_matches = [] self._matched = False path, patlist = self._split(pattern) try: attrs = await self._stat(path or b'.') if attrs: if patlist: if attrs.type == FILEXFER_TYPE_DIRECTORY: await self._match(path, attrs, patlist) elif path: self._report_match(path, attrs) if pattern and not self._matched: exc = SFTPNoSuchPath if sftp_version >= 4 else SFTPNoSuchFile raise exc('No matches found') except (OSError, SFTPError) as exc: setattr(exc, 'srcpath', pattern) if error_handler: 
error_handler(exc) else: raise return self._new_matches class SFTPHandler(SSHPacketLogger): """SFTP session handler""" _data_pkttypes = {FXP_WRITE, FXP_DATA} _handler_names = get_symbol_names(_const_dict, 'FXP_') _realpath_check_names = get_symbol_names(_const_dict, 'FXRP_', 5) # SFTP implementations with broken order for SYMLINK arguments _nonstandard_symlink_impls = ['OpenSSH', 'paramiko'] # Return types by message -- unlisted entries always return FXP_STATUS, # those below return FXP_STATUS on error _return_types = { FXP_OPEN: FXP_HANDLE, FXP_READ: FXP_DATA, FXP_LSTAT: FXP_ATTRS, FXP_FSTAT: FXP_ATTRS, FXP_OPENDIR: FXP_HANDLE, FXP_READDIR: FXP_NAME, FXP_REALPATH: FXP_NAME, FXP_STAT: FXP_ATTRS, FXP_READLINK: FXP_NAME, b'statvfs@openssh.com': FXP_EXTENDED_REPLY, b'fstatvfs@openssh.com': FXP_EXTENDED_REPLY, b'limits@openssh.com': FXP_EXTENDED_REPLY } def __init__(self, reader: 'SSHReader[bytes]', writer: 'SSHWriter[bytes]'): self._reader: Optional['SSHReader[bytes]'] = reader self._writer: Optional['SSHWriter[bytes]'] = writer self._logger = reader.logger.get_child('sftp') self.limits = SFTPLimits(0, SAFE_SFTP_READ_LEN, SAFE_SFTP_WRITE_LEN, 0) @property def logger(self) -> SSHLogger: """A logger associated with this SFTP handler""" return self._logger async def _cleanup(self, exc: Optional[Exception]) -> None: """Clean up this SFTP session""" # pylint: disable=unused-argument if self._writer: # pragma: no branch self._writer.close() self._reader = None self._writer = None def _log_extensions(self, extensions: Sequence[Tuple[bytes, bytes]]): """Dump a formatted list of extensions to the debug log""" for name, data in extensions: if name == b'acl-supported': capabilities = _parse_acl_supported(data) self.logger.debug1(' acl-supported:') self.logger.debug1(' capabilities: 0x%08x', capabilities) elif name == b'supported': attr_mask, attrib_mask, open_flags, access_mask, \ max_read_size, ext_names = _parse_supported(data) self.logger.debug1(' supported:') self.logger.debug1(' attr_mask: 0x%08x', attr_mask) self.logger.debug1(' attrib_mask: 0x%08x', attrib_mask) self.logger.debug1(' open_flags: 0x%08x', open_flags) self.logger.debug1(' access_mask: 0x%08x', access_mask) self.logger.debug1(' max_read_size: %d', max_read_size) if ext_names: self.logger.debug1(' extensions:') for ext_name in ext_names: self.logger.debug1(' %s', ext_name) elif name == b'supported2': attr_mask, attrib_mask, open_flags, access_mask, \ max_read_size, open_block_vector, block_vector, \ attrib_ext_names, ext_names = _parse_supported2(data) self.logger.debug1(' supported2:') self.logger.debug1(' attr_mask: 0x%08x', attr_mask) self.logger.debug1(' attrib_mask: 0x%08x', attrib_mask) self.logger.debug1(' open_flags: 0x%08x', open_flags) self.logger.debug1(' access_mask: 0x%08x', access_mask) self.logger.debug1(' max_read_size: %d', max_read_size) self.logger.debug1(' open_block_vector: 0x%04x', open_block_vector) self.logger.debug1(' block_vector: 0x%04x', block_vector) if attrib_ext_names: self.logger.debug1(' attrib_extensions:') for attrib_ext_name in attrib_ext_names: self.logger.debug1(' %s', attrib_ext_name) if ext_names: self.logger.debug1(' extensions:') for ext_name in ext_names: self.logger.debug1(' %s', ext_name) elif name == b'vendor-id': vendor_name, product_name, product_version, product_build = \ _parse_vendor_id(data) self.logger.debug1(' vendor-id:') self.logger.debug1(' vendor_name: %s', vendor_name) self.logger.debug1(' product_name: %s', product_name) self.logger.debug1(' product_version: %s', 
product_version) self.logger.debug1(' product_build: %d', product_build) else: self.logger.debug1(' %s%s%s', name, ': ' if data else '', data) def _log_limits(self, limits: SFTPLimits) -> None: """Log SFTP server limits""" self.logger.debug1(' Max packet len: %d', limits.max_packet_len) self.logger.debug1(' Max read len: %d', limits.max_read_len) self.logger.debug1(' Max write len: %d', limits.max_write_len) self.logger.debug1(' Max open handles: %d', limits.max_open_handles) async def _process_packet(self, pkttype: int, pktid: int, packet: SSHPacket) -> None: """Abstract method for processing SFTP packets""" raise NotImplementedError def send_packet(self, pkttype: int, pktid: Optional[int], *args: bytes) -> None: """Send an SFTP packet""" if not self._writer: raise SFTPNoConnection('Connection not open') payload = Byte(pkttype) + b''.join(args) try: self._writer.write(UInt32(len(payload)) + payload) except ConnectionError as exc: raise SFTPConnectionLost(str(exc)) from None self.log_sent_packet(pkttype, pktid, payload) async def recv_packet(self) -> SSHPacket: """Receive an SFTP packet""" assert self._reader is not None pktlen = await self._reader.readexactly(4) pktlen = int.from_bytes(pktlen, 'big') packet = await self._reader.readexactly(pktlen) return SSHPacket(packet) async def recv_packets(self) -> None: """Receive and process SFTP packets""" try: while self._reader: # pragma: no branch packet = await self.recv_packet() pkttype = packet.get_byte() pktid = packet.get_uint32() self.log_received_packet(pkttype, pktid, packet) await self._process_packet(pkttype, pktid, packet) except PacketDecodeError as exc: await self._cleanup(SFTPBadMessage(str(exc))) except EOFError: await self._cleanup(None) except (OSError, Error) as exc: await self._cleanup(exc) class SFTPClientHandler(SFTPHandler): """An SFTP client session handler""" def __init__(self, loop: asyncio.AbstractEventLoop, reader: 'SSHReader[bytes]', writer: 'SSHWriter[bytes]', sftp_version: int): super().__init__(reader, writer) self._loop = loop self._version = sftp_version self._next_pktid = 0 self._requests: Dict[int, _RequestWaiter] = {} self._nonstandard_symlink = False self._supports_posix_rename = False self._supports_statvfs = False self._supports_fstatvfs = False self._supports_hardlink = False self._supports_fsync = False self._supports_lsetstat = False self._supports_limits = False self._supports_copy_data = False @property def version(self) -> int: """SFTP version associated with this SFTP session""" return self._version @property def supports_copy_data(self) -> bool: """Return whether or not SFTP remote copy is supported""" return self._supports_copy_data async def _cleanup(self, exc: Optional[Exception]) -> None: """Clean up this SFTP client session""" req_exc = exc or SFTPConnectionLost('Connection closed') for waiter in list(self._requests.values()): if not waiter.cancelled(): # pragma: no branch waiter.set_exception(req_exc) self._requests = {} self.logger.info('SFTP client exited%s', ': ' + str(exc) if exc else '') await super()._cleanup(exc) async def _process_packet(self, pkttype: int, pktid: int, packet: SSHPacket) -> None: """Process incoming SFTP responses""" try: waiter = self._requests.pop(pktid) except KeyError: await self._cleanup(SFTPBadMessage('Invalid response id')) else: if not waiter.cancelled(): # pragma: no branch waiter.set_result((pkttype, packet)) def _send_request(self, pkttype: Union[int, bytes], args: Sequence[bytes], waiter: _RequestWaiter) -> None: """Send an SFTP request""" pktid = 
self._next_pktid self._next_pktid = (self._next_pktid + 1) & 0xffffffff self._requests[pktid] = waiter if isinstance(pkttype, bytes): hdr = UInt32(pktid) + String(pkttype) pkttype = FXP_EXTENDED else: hdr = UInt32(pktid) self.send_packet(pkttype, pktid, hdr, *args) async def _make_request(self, pkttype: Union[int, bytes], *args: bytes) -> object: """Make an SFTP request and wait for a response""" waiter: _RequestWaiter = self._loop.create_future() self._send_request(pkttype, args, waiter) resptype, resp = await waiter return_type = self._return_types.get(pkttype) if resptype not in (FXP_STATUS, return_type): raise SFTPBadMessage(f'Unexpected response type: {resptype}') result = self._packet_handlers[resptype](self, resp) if result is not None or return_type is None: return result else: raise SFTPBadMessage('Unexpected FX_OK response') def _process_status(self, packet: SSHPacket) -> None: """Process an incoming SFTP status response""" exc = SFTPError.construct(packet) if self._version < 6: packet.check_end() if exc: raise exc else: self.logger.debug1('Received OK') def _process_handle(self, packet: SSHPacket) -> bytes: """Process an incoming SFTP handle response""" handle = packet.get_string() if self._version < 6: packet.check_end() self.logger.debug1('Received handle %s', handle.hex()) return handle def _process_data(self, packet: SSHPacket) -> Tuple[bytes, bool]: """Process an incoming SFTP data response""" data = packet.get_string() at_end = packet.get_boolean() if packet and self._version >= 6 \ else False if self._version < 6: packet.check_end() self.logger.debug1('Received %s%s', plural(len(data), 'data byte'), ' (at end)' if at_end else '') return data, at_end def _process_name(self, packet: SSHPacket) -> _SFTPNames: """Process an incoming SFTP name response""" count = packet.get_uint32() names = [SFTPName.decode(packet, self._version) for _ in range(count)] at_end = packet.get_boolean() if packet and self._version >= 6 \ else False if self._version < 6: packet.check_end() self.logger.debug1('Received %s%s', plural(len(names), 'name'), ' (at end)' if at_end else '') for name in names: self.logger.debug1(' %s', name) return names, at_end def _process_attrs(self, packet: SSHPacket) -> SFTPAttrs: """Process an incoming SFTP attributes response""" attrs = SFTPAttrs().decode(packet, self._version) if self._version < 6: packet.check_end() self.logger.debug1('Received %s', attrs) return attrs def _process_extended_reply(self, packet: SSHPacket) -> SSHPacket: """Process an incoming SFTP extended reply response""" # pylint: disable=no-self-use # Let the caller do the decoding for extended replies return packet _packet_handlers = { FXP_STATUS: _process_status, FXP_HANDLE: _process_handle, FXP_DATA: _process_data, FXP_NAME: _process_name, FXP_ATTRS: _process_attrs, FXP_EXTENDED_REPLY: _process_extended_reply } async def start(self) -> None: """Start an SFTP client""" assert self._reader is not None self.logger.debug1('Sending init, version=%d', self._version) self.send_packet(FXP_INIT, None, UInt32(self._version)) try: resp = await self.recv_packet() resptype = resp.get_byte() self.log_received_packet(resptype, None, resp) if resptype != FXP_VERSION: raise SFTPBadMessage('Expected version message') version = resp.get_uint32() if not MIN_SFTP_VERSION <= version <= MAX_SFTP_VERSION: raise SFTPBadMessage(f'Unsupported version: {version}') rcvd_extensions: List[Tuple[bytes, bytes]] = [] while resp: name = resp.get_string() data = resp.get_string() rcvd_extensions.append((name, data)) except 
PacketDecodeError as exc: raise SFTPBadMessage(str(exc)) from None except SFTPError: raise except ConnectionLost as exc: raise SFTPConnectionLost(str(exc)) from None except (asyncio.IncompleteReadError, Error) as exc: raise SFTPConnectionLost(str(exc)) from None self.logger.debug1('Received version=%d%s', version, ', extensions:' if rcvd_extensions else '') self._log_extensions(rcvd_extensions) self._version = version for name, data in rcvd_extensions: if name == b'posix-rename@openssh.com' and data == b'1': self._supports_posix_rename = True elif name == b'statvfs@openssh.com' and data == b'2': self._supports_statvfs = True elif name == b'fstatvfs@openssh.com' and data == b'2': self._supports_fstatvfs = True elif name == b'hardlink@openssh.com' and data == b'1': self._supports_hardlink = True elif name == b'fsync@openssh.com' and data == b'1': self._supports_fsync = True elif name == b'lsetstat@openssh.com' and data == b'1': self._supports_lsetstat = True elif name == b'limits@openssh.com' and data == b'1': self._supports_limits = True elif name == b'copy-data' and data == b'1': self._supports_copy_data = True if version == 3: # Check if the server has a buggy SYMLINK implementation server_version = cast(str, self._reader.get_extra_info('server_version', '')) if any(name in server_version for name in self._nonstandard_symlink_impls): self.logger.debug1('Adjusting for non-standard symlink ' 'implementation') self._nonstandard_symlink = True async def request_limits(self) -> None: """Request SFTP server limits""" if self._supports_limits: packet = cast(SSHPacket, await self._make_request( b'limits@openssh.com')) limits = SFTPLimits.decode(packet, self._version) packet.check_end() self.logger.debug1('Received server limits:') self._log_limits(limits) if limits.max_read_len: self.limits.max_read_len = limits.max_read_len if limits.max_write_len: self.limits.max_write_len = limits.max_write_len async def open(self, filename: bytes, pflags: int, attrs: SFTPAttrs) -> bytes: """Make an SFTP open request""" if self._version >= 5: desired_access, flags = _pflags_to_flags(pflags) self.logger.debug1('Sending open for %s, desired_access=0x%08x, ' 'flags=0x%08x%s', filename, desired_access, flags, hide_empty(attrs)) return cast(bytes, await self._make_request( FXP_OPEN, String(filename), UInt32(desired_access), UInt32(flags), attrs.encode(self._version))) else: self.logger.debug1('Sending open for %s, mode 0x%02x%s', filename, pflags, hide_empty(attrs)) return cast(bytes, await self._make_request( FXP_OPEN, String(filename), UInt32(pflags), attrs.encode(self._version))) async def open56(self, filename: bytes, desired_access: int, flags: int, attrs: SFTPAttrs) -> bytes: """Make an SFTPv5/v6 open request""" self.logger.debug1('Sending open for %s, desired_access=0x%08x, ' 'flags=0x%08x%s', filename, desired_access, flags, hide_empty(attrs)) if self._version >= 5: return cast(bytes, await self._make_request( FXP_OPEN, String(filename), UInt32(desired_access), UInt32(flags), attrs.encode(self._version))) else: raise SFTPOpUnsupported('SFTPv5/v6 open not supported by server') async def close(self, handle: bytes) -> None: """Make an SFTP close request""" self.logger.debug1('Sending close for handle %s', handle.hex()) if self._writer: await self._make_request(FXP_CLOSE, String(handle)) async def read(self, handle: bytes, offset: int, length: int) -> Tuple[bytes, bool]: """Make an SFTP read request""" self.logger.debug1('Sending read for %s at offset %d in handle %s', plural(length, 'byte'), offset, 
handle.hex()) return cast(Tuple[bytes, bool], await self._make_request( FXP_READ, String(handle), UInt64(offset), UInt32(length))) async def write(self, handle: bytes, offset: int, data: bytes) -> int: """Make an SFTP write request""" self.logger.debug1('Sending write for %s at offset %d in handle %s', plural(len(data), 'byte'), offset, handle.hex()) return cast(int, await self._make_request( FXP_WRITE, String(handle), UInt64(offset), String(data))) async def stat(self, path: bytes, flags: int, *, follow_symlinks: bool = True) -> SFTPAttrs: """Make an SFTP stat or lstat request""" if self._version >= 4: flag_bytes = UInt32(flags) flag_text = f', flags 0x{flags:08x}' else: flag_bytes = b'' flag_text = '' if follow_symlinks: self.logger.debug1('Sending stat for %s%s', path, flag_text) return cast(SFTPAttrs, await self._make_request( FXP_STAT, String(path), flag_bytes)) else: self.logger.debug1('Sending lstat for %s%s', path, flag_text) return cast(SFTPAttrs, await self._make_request( FXP_LSTAT, String(path), flag_bytes)) async def lstat(self, path: bytes, flags: int) -> SFTPAttrs: """Make an SFTP lstat request""" if self._version >= 4: flag_bytes = UInt32(flags) flag_text = f', flags 0x{flags:08x}' else: flag_bytes = b'' flag_text = '' self.logger.debug1('Sending lstat for %s%s', path, flag_text) return cast(SFTPAttrs, await self._make_request( FXP_LSTAT, String(path), flag_bytes)) async def fstat(self, handle: bytes, flags: int) -> SFTPAttrs: """Make an SFTP fstat request""" if self._version >= 4: flag_bytes = UInt32(flags) flag_text = f', flags 0x{flags:08x}' else: flag_bytes = b'' flag_text = '' self.logger.debug1('Sending fstat for handle %s%s', handle.hex(), flag_text) return cast(SFTPAttrs, await self._make_request( FXP_FSTAT, String(handle), flag_bytes)) async def setstat(self, path: bytes, attrs: SFTPAttrs, *, follow_symlinks: bool = True) -> None: """Make an SFTP setstat or lsetstat request""" if follow_symlinks: self.logger.debug1('Sending setstat for %s%s', path, hide_empty(attrs)) await self._make_request(FXP_SETSTAT, String(path), attrs.encode(self._version)) elif self._supports_lsetstat: self.logger.debug1('Sending lsetstat for %s%s', path, hide_empty(attrs)) await self._make_request(b'lsetstat@openssh.com', String(path), attrs.encode(self._version)) else: raise SFTPOpUnsupported('lsetstat not supported by server') async def fsetstat(self, handle: bytes, attrs: SFTPAttrs) -> None: """Make an SFTP fsetstat request""" self.logger.debug1('Sending fsetstat for handle %s%s', handle.hex(), hide_empty(attrs)) await self._make_request(FXP_FSETSTAT, String(handle), attrs.encode(self._version)) async def statvfs(self, path: bytes) -> SFTPVFSAttrs: """Make an SFTP statvfs request""" if self._supports_statvfs: self.logger.debug1('Sending statvfs for %s', path) packet = cast(SSHPacket, await self._make_request( b'statvfs@openssh.com', String(path))) vfsattrs = SFTPVFSAttrs.decode(packet, self._version) packet.check_end() self.logger.debug1('Received %s', vfsattrs) return vfsattrs else: raise SFTPOpUnsupported('statvfs not supported') async def fstatvfs(self, handle: bytes) -> SFTPVFSAttrs: """Make an SFTP fstatvfs request""" if self._supports_fstatvfs: self.logger.debug1('Sending fstatvfs for handle %s', handle.hex()) packet = cast(SSHPacket, await self._make_request( b'fstatvfs@openssh.com', String(handle))) vfsattrs = SFTPVFSAttrs.decode(packet, self._version) packet.check_end() self.logger.debug1('Received %s', vfsattrs) return vfsattrs else: raise SFTPOpUnsupported('fstatvfs not 
supported') async def remove(self, path: bytes) -> None: """Make an SFTP remove request""" self.logger.debug1('Sending remove for %s', path) await self._make_request(FXP_REMOVE, String(path)) async def rename(self, oldpath: bytes, newpath: bytes, flags: int) -> None: """Make an SFTP rename request""" if self._version >= 5: self.logger.debug1('Sending rename request from %s to %s%s', oldpath, newpath, f', flags=0x{flags:x}' if flags else '') await self._make_request(FXP_RENAME, String(oldpath), String(newpath), UInt32(flags)) elif flags and self._supports_posix_rename: self.logger.debug1('Sending OpenSSH POSIX rename request ' 'from %s to %s', oldpath, newpath) await self._make_request(b'posix-rename@openssh.com', String(oldpath), String(newpath)) elif not flags: self.logger.debug1('Sending rename request from %s to %s', oldpath, newpath) await self._make_request(FXP_RENAME, String(oldpath), String(newpath)) else: raise SFTPOpUnsupported('Rename with overwrite not supported') async def posix_rename(self, oldpath: bytes, newpath: bytes) -> None: """Make an SFTP POSIX rename request""" if self._supports_posix_rename: self.logger.debug1('Sending OpenSSH POSIX rename request ' 'from %s to %s', oldpath, newpath) await self._make_request(b'posix-rename@openssh.com', String(oldpath), String(newpath)) elif self._version >= 5: self.logger.debug1('Sending rename request from %s to %s ' 'with overwrite', oldpath, newpath) await self._make_request(FXP_RENAME, String(oldpath), String(newpath), UInt32(FXR_OVERWRITE)) else: raise SFTPOpUnsupported('POSIX rename not supported') async def opendir(self, path: bytes) -> bytes: """Make an SFTP opendir request""" self.logger.debug1('Sending opendir for %s', path) return cast(bytes, await self._make_request( FXP_OPENDIR, String(path))) async def readdir(self, handle: bytes) -> _SFTPNames: """Make an SFTP readdir request""" self.logger.debug1('Sending readdir for handle %s', handle.hex()) return cast(_SFTPNames, await self._make_request( FXP_READDIR, String(handle))) async def mkdir(self, path: bytes, attrs: SFTPAttrs) -> None: """Make an SFTP mkdir request""" self.logger.debug1('Sending mkdir for %s', path) await self._make_request(FXP_MKDIR, String(path), attrs.encode(self._version)) async def rmdir(self, path: bytes) -> None: """Make an SFTP rmdir request""" self.logger.debug1('Sending rmdir for %s', path) await self._make_request(FXP_RMDIR, String(path)) async def realpath(self, path: bytes, *compose_paths: bytes, check: int = FXRP_NO_CHECK) -> _SFTPNames: """Make an SFTP realpath request""" if check == FXRP_NO_CHECK: checkmsg = '' else: try: checkmsg = f', check={self._realpath_check_names[check]}' except KeyError: checkmsg = f', check={check}' self.logger.debug1('Sending realpath of %s%s%s', path, b', compose_path: ' + b', '.join(compose_paths) if compose_paths else b'', checkmsg) if self._version >= 6: return cast(_SFTPNames, await self._make_request( FXP_REALPATH, String(path), Byte(check), *map(String, compose_paths))) else: return cast(_SFTPNames, await self._make_request( FXP_REALPATH, String(path))) async def readlink(self, path: bytes) -> _SFTPNames: """Make an SFTP readlink request""" self.logger.debug1('Sending readlink for %s', path) return cast(_SFTPNames, await self._make_request( FXP_READLINK, String(path))) async def symlink(self, oldpath: bytes, newpath: bytes) -> None: """Make an SFTP symlink request""" self.logger.debug1('Sending symlink request from %s to %s', oldpath, newpath) if self._version >= 6: await self._make_request(FXP_LINK, 
String(newpath), String(oldpath), Boolean(True)) else: if self._nonstandard_symlink: args = String(oldpath) + String(newpath) else: args = String(newpath) + String(oldpath) await self._make_request(FXP_SYMLINK, args) async def link(self, oldpath: bytes, newpath: bytes) -> None: """Make an SFTP hard link request""" if self._version >= 6 or self._supports_hardlink: self.logger.debug1('Sending hardlink request from %s to %s', oldpath, newpath) if self._version >= 6: await self._make_request(FXP_LINK, String(newpath), String(oldpath), Boolean(False)) else: await self._make_request(b'hardlink@openssh.com', String(oldpath), String(newpath)) else: raise SFTPOpUnsupported('link not supported') async def lock(self, handle: bytes, offset: int, length: int, flags: int) -> None: """Make an SFTP byte range lock request""" if self._version >= 6: self.logger.debug1('Sending byte range lock request for ' 'handle %s, offset %d, length %d, ' 'flags 0x%04x', handle.hex(), offset, length, flags) await self._make_request(FXP_BLOCK, String(handle), UInt64(offset), UInt64(length), UInt32(flags)) else: raise SFTPOpUnsupported('Byte range locks not supported') async def unlock(self, handle: bytes, offset: int, length: int) -> None: """Make an SFTP byte range unlock request""" if self._version >= 6: self.logger.debug1('Sending byte range unlock request for ' 'handle %s, offset %d, length %d', handle.hex(), offset, length) await self._make_request(FXP_UNBLOCK, String(handle), UInt64(offset), UInt64(length)) else: raise SFTPOpUnsupported('Byte range locks not supported') async def fsync(self, handle: bytes) -> None: """Make an SFTP fsync request""" if self._supports_fsync: self.logger.debug1('Sending fsync for handle %s', handle.hex()) await self._make_request(b'fsync@openssh.com', String(handle)) else: raise SFTPOpUnsupported('fsync not supported') async def copy_data(self, read_from_handle: bytes, read_from_offset: int, read_from_length: int, write_to_handle: bytes, write_to_offset: int) -> None: """Make an SFTP copy data request""" if self._supports_copy_data: self.logger.debug1('Sending copy-data from handle %s, ' 'offset %d, length %d to handle %s, ' 'offset %d', read_from_handle.hex(), read_from_offset, read_from_length, write_to_handle.hex(), write_to_offset) await self._make_request(b'copy-data', String(read_from_handle), UInt64(read_from_offset), UInt64(read_from_length), String(write_to_handle), UInt64(write_to_offset)) else: raise SFTPOpUnsupported('copy-data not supported') def exit(self) -> None: """Handle a request to close the SFTP session""" if self._writer: self._writer.write_eof() async def wait_closed(self) -> None: """Wait for this SFTP session to close""" if self._writer: await self._writer.channel.wait_closed() class SFTPClientFile: """SFTP client remote file object This class represents an open file on a remote SFTP server. It is opened with the :meth:`open() ` method on the :class:`SFTPClient` class and provides methods to read and write data and get and set attributes on the open file. 
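    As a minimal usage sketch (the host name and file name below are illustrative placeholders, and the connection setup is assumed rather than provided by this class), an :class:`SFTPClientFile` is typically obtained and used like this::

        import asyncssh

        async def read_remote() -> str:
            async with asyncssh.connect('host.example.com') as conn:
                async with conn.start_sftp_client() as sftp:
                    async with sftp.open('example.txt') as f:
                        return await f.read()

    With the default `utf-8` encoding, :meth:`read` returns a `str`; open the file with `encoding=None` to receive `bytes` instead.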
""" def __init__(self, handler: SFTPClientHandler, handle: bytes, appending: bool, encoding: Optional[str], errors: str, block_size: int, max_requests: int): self._handler = handler self._handle: Optional[bytes] = handle self._appending = appending self._encoding = encoding self._errors = errors self._offset = None if appending else 0 self.read_len = \ handler.limits.max_read_len if block_size == -1 else block_size self.write_len = \ handler.limits.max_write_len if block_size == -1 else block_size if max_requests <= 0: if self.read_len: max_requests = max(16, min(MAX_SFTP_READ_LEN // self.read_len, 128)) else: max_requests = 1 self._max_requests = max_requests async def __aenter__(self) -> Self: """Allow SFTPClientFile to be used as an async context manager""" return self async def __aexit__(self, _exc_type: Optional[Type[BaseException]], _exc_value: Optional[BaseException], _traceback: Optional[TracebackType]) -> bool: """Wait for file close when used as an async context manager""" await self.close() return False @property def handle(self) -> bytes: """Return handle or raise an error if clsoed""" if self._handle is None: raise ValueError('I/O operation on closed file') return self._handle async def _end(self) -> int: """Return the offset of the end of the file""" attrs = await self.stat() return attrs.size or 0 async def read(self, size: int = -1, offset: Optional[int] = None) -> AnyStr: """Read data from the remote file This method reads and returns up to `size` bytes of data from the remote file. If size is negative, all data up to the end of the file is returned. If offset is specified, the read will be performed starting at that offset rather than the current file position. This argument should be provided if you want to issue parallel reads on the same file, since the file position is not predictable in that case. Data will be returned as a string if an encoding was set when the file was opened. Otherwise, data is returned as bytes. An empty `str` or `bytes` object is returned when at EOF. :param size: The number of bytes to read :param offset: (optional) The offset from the beginning of the file to begin reading :type size: `int` :type offset: `int` :returns: data read from the file, as a `str` or `bytes` :raises: | :exc:`ValueError` if the file has been closed | :exc:`UnicodeDecodeError` if the data can't be decoded using the requested encoding | :exc:`SFTPError` if the server returns an error """ if self._handle is None: raise ValueError('I/O operation on closed file') if offset is None: offset = self._offset # If self._offset is None, we're appending and haven't sought # backward in the file since the last write, so there's no # data to return data = b'' if offset is not None: if size is None or size < 0: size = (await self._end()) - offset try: if self.read_len and size > \ min(self.read_len, self._handler.limits.max_read_len): data = await _SFTPFileReader( self.read_len, self._max_requests, self._handler, self._handle, offset, size).run() else: data, _ = await self._handler.read(self._handle, offset, size) self._offset = offset + len(data) except SFTPEOFError: pass if self._encoding: return cast(AnyStr, data.decode(self._encoding, self._errors)) else: return cast(AnyStr, data) async def read_parallel(self, size: int = -1, offset: Optional[int] = None) -> \ AsyncIterator[Tuple[int, bytes]]: """Read parallel blocks of data from the remote file This method reads and returns up to `size` bytes of data from the remote file. 
If size is negative, all data up to the end of the file is returned. If offset is specified, the read will be performed starting at that offset rather than the current file position. Data is returned as a series of tuples delivered by an async iterator, where each tuple contains an offset and data bytes. Encoding is ignored here, since multi-byte characters may be split across block boundaries. To maximize performance, multiple reads are issued in parallel, and data blocks may be returned out of order. The size of the blocks and the maximum number of outstanding read requests can be controlled using the `block_size` and `max_requests` arguments passed in the call to the :meth:`open() ` method on the :class:`SFTPClient` class. :param size: The number of bytes to read :param offset: (optional) The offset from the beginning of the file to begin reading :type size: `int` :type offset: `int` :returns: an async iterator of tuples of offset and data bytes :raises: | :exc:`ValueError` if the file has been closed | :exc:`SFTPError` if the server returns an error """ if self._handle is None: raise ValueError('I/O operation on closed file') if offset is None: offset = self._offset # If self._offset is None, we're appending and haven't sought # backward in the file since the last write, so there's no # data to return if offset is not None: if size is None or size < 0: size = (await self._end()) - offset else: offset = 0 size = 0 return _SFTPFileReader(self.read_len, self._max_requests, self._handler, self._handle, offset, size).iter() async def write(self, data: AnyStr, offset: Optional[int] = None) -> int: """Write data to the remote file This method writes the specified data at the current position in the remote file. :param data: The data to write to the file :param offset: (optional) The offset from the beginning of the file to begin writing :type data: `str` or `bytes` :type offset: `int` If offset is specified, the write will be performed starting at that offset rather than the current file position. This argument should be provided if you want to issue parallel writes on the same file, since the file position is not predictable in that case. :returns: number of bytes written :raises: | :exc:`ValueError` if the file has been closed | :exc:`UnicodeEncodeError` if the data can't be encoded using the requested encoding | :exc:`SFTPError` if the server returns an error """ if self._handle is None: raise ValueError('I/O operation on closed file') if offset is None: # Offset is ignored when appending, so fill in an offset of 0 # if we don't have a current file position offset = self._offset or 0 if self._encoding: data_bytes = cast(str, data).encode(self._encoding, self._errors) else: data_bytes = cast(bytes, data) datalen = len(data_bytes) if self.write_len and datalen > self.write_len: await _SFTPFileWriter( self.write_len, self._max_requests, self._handler, self._handle, offset, data_bytes).run() else: await self._handler.write(self._handle, offset, data_bytes) self._offset = None if self._appending else offset + datalen return datalen async def seek(self, offset: int, from_what: int = SEEK_SET) -> int: """Seek to a new position in the remote file This method changes the position in the remote file. The `offset` passed in is treated as relative to the beginning of the file if `from_what` is set to `SEEK_SET` (the default), relative to the current file position if it is set to `SEEK_CUR`, or relative to the end of the file if it is set to `SEEK_END`. 
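For example (an illustrative call, assuming `f` is an open :class:`SFTPClientFile`), `await f.seek(0, SEEK_END)` positions at the end of the file and returns the file size in bytes, while `await f.seek(0)` rewinds to the beginning.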
:param offset: The amount to seek :param from_what: (optional) The reference point to use :type offset: `int` :type from_what: `SEEK_SET`, `SEEK_CUR`, or `SEEK_END` :returns: The new byte offset from the beginning of the file """ if self._handle is None: raise ValueError('I/O operation on closed file') if from_what == SEEK_SET: self._offset = offset elif from_what == SEEK_CUR: if self._offset is None: self._offset = (await self._end()) + offset else: self._offset += offset elif from_what == SEEK_END: self._offset = (await self._end()) + offset else: raise ValueError('Invalid reference point') return self._offset async def tell(self) -> int: """Return the current position in the remote file This method returns the current position in the remote file. :returns: The current byte offset from the beginning of the file """ if self._handle is None: raise ValueError('I/O operation on closed file') if self._offset is None: self._offset = await self._end() return self._offset async def stat(self, flags = FILEXFER_ATTR_DEFINED_V4) -> SFTPAttrs: """Return file attributes of the remote file This method queries file attributes of the currently open file. :param flags: (optional) Flags indicating attributes of interest (SFTPv4 or later) :type flags: `int` :returns: An :class:`SFTPAttrs` containing the file attributes :raises: :exc:`SFTPError` if the server returns an error """ if self._handle is None: raise ValueError('I/O operation on closed file') return await self._handler.fstat(self._handle, flags) async def setstat(self, attrs: SFTPAttrs) -> None: """Set attributes of the remote file This method sets file attributes of the currently open file. :param attrs: File attributes to set on the file :type attrs: :class:`SFTPAttrs` :raises: :exc:`SFTPError` if the server returns an error """ if self._handle is None: raise ValueError('I/O operation on closed file') await self._handler.fsetstat(self._handle, attrs) async def statvfs(self) -> SFTPVFSAttrs: """Return file system attributes of the remote file This method queries attributes of the file system containing the currently open file. :returns: An :class:`SFTPVFSAttrs` containing the file system attributes :raises: :exc:`SFTPError` if the server doesn't support this extension or returns an error """ if self._handle is None: raise ValueError('I/O operation on closed file') return await self._handler.fstatvfs(self._handle) async def truncate(self, size: Optional[int] = None) -> None: """Truncate the remote file to the specified size This method changes the remote file's size to the specified value. If a size is not provided, the current file position is used. :param size: (optional) The desired size of the file, in bytes :type size: `int` :raises: :exc:`SFTPError` if the server returns an error """ if size is None: size = self._offset await self.setstat(SFTPAttrs(size=size)) @overload async def chown(self, uid: int, gid: int) -> None: ... # pragma: no cover @overload async def chown(self, owner: str, group: str) -> None: ... # pragma: no cover async def chown(self, uid_or_owner = None, gid_or_group = None, uid = None, gid = None, owner = None, group = None): """Change the owner user and group of the remote file This method changes the user and group of the currently open file. 
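As an illustrative sketch (the ids and names are placeholders), `await f.chown(1000, 1000)` sets numeric user and group ids, while `await f.chown('alice', 'users')` sets the owner and group by name, which requires SFTPv4 or later.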
:param uid: The new user id to assign to the file :param gid: The new group id to assign to the file :param owner: The new owner to assign to the file (SFTPv4 only) :param group: The new group to assign to the file (SFTPv4 only) :type uid: `int` :type gid: `int` :type owner: `str` :type group: `str` :raises: :exc:`SFTPError` if the server returns an error """ if isinstance(uid_or_owner, int): uid = uid_or_owner elif isinstance(uid_or_owner, str): owner = uid_or_owner if isinstance(gid_or_group, int): gid = gid_or_group elif isinstance(gid_or_group, str): group = gid_or_group await self.setstat(SFTPAttrs(uid=uid, gid=gid, owner=owner, group=group)) async def chmod(self, mode: int) -> None: """Change the file permissions of the remote file This method changes the permissions of the currently open file. :param mode: The new file permissions, expressed as an int :type mode: `int` :raises: :exc:`SFTPError` if the server returns an error """ await self.setstat(SFTPAttrs(permissions=mode)) async def utime(self, times: Optional[Tuple[float, float]] = None, ns: Optional[Tuple[int, int]] = None) -> None: """Change the access and modify times of the remote file This method changes the access and modify times of the currently open file. If `times` is not provided, the times will be changed to the current time. :param times: (optional) The new access and modify times, as seconds relative to the UNIX epoch :param ns: (optional) The new access and modify times, as nanoseconds relative to the UNIX epoch :type times: tuple of two `int` or `float` values :type ns: tuple of two `int` values :raises: :exc:`SFTPError` if the server returns an error """ await self.setstat(_utime_to_attrs(times, ns)) async def lock(self, offset: int, length: int, flags: int) -> None: """Acquire a byte range lock on the remote file""" if self._handle is None: raise ValueError('I/O operation on closed file') await self._handler.lock(self._handle, offset, length, flags) async def unlock(self, offset: int, length: int) -> None: """Release a byte range lock on the remote file""" if self._handle is None: raise ValueError('I/O operation on closed file') await self._handler.unlock(self._handle, offset, length) async def fsync(self) -> None: """Force the remote file data to be written to disk""" if self._handle is None: raise ValueError('I/O operation on closed file') await self._handler.fsync(self._handle) async def close(self) -> None: """Close the remote file""" if self._handle: await self._handler.close(self._handle) self._handle = None class SFTPClient: """SFTP client This class represents the client side of an SFTP session. It is started by calling the :meth:`start_sftp_client() ` method on the :class:`SSHClientConnection` class. 
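    As a minimal usage sketch (the host and file names are illustrative placeholders), an SFTP client session is typically used like this::

        import asyncssh

        async def download_file() -> None:
            async with asyncssh.connect('host.example.com') as conn:
                async with conn.start_sftp_client() as sftp:
                    await sftp.get('remote.txt', 'local.txt')

    Both the connection and the SFTP client are async context managers, so each is closed automatically when its block exits.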
""" def __init__(self, handler: SFTPClientHandler, path_encoding: Optional[str], path_errors: str): self._handler = handler self._path_encoding = path_encoding self._path_errors = path_errors self._cwd: Optional[bytes] = None async def __aenter__(self) -> Self: """Allow SFTPClient to be used as an async context manager""" return self async def __aexit__(self, _exc_type: Optional[Type[BaseException]], _exc_value: Optional[BaseException], _traceback: Optional[TracebackType]) -> bool: """Wait for client close when used as an async context manager""" self.exit() await self.wait_closed() return False @property def logger(self) -> SSHLogger: """A logger associated with this SFTP client""" return self._handler.logger @property def version(self) -> int: """SFTP version associated with this SFTP session""" return self._handler.version @property def limits(self) -> SFTPLimits: """:class:`SFTPLimits` associated with this SFTP session""" return self._handler.limits @property def supports_remote_copy(self) -> bool: """Return whether or not SFTP remote copy is supported""" return self._handler.supports_copy_data @staticmethod def basename(path: bytes) -> bytes: """Return the final component of a POSIX-style path""" return posixpath.basename(path) def encode(self, path: _SFTPPath) -> bytes: """Encode path name using configured path encoding This method has no effect if the path is already bytes. """ if isinstance(path, PurePath): path = str(path) if isinstance(path, str): if self._path_encoding: path = path.encode(self._path_encoding, self._path_errors) else: raise SFTPBadMessage('Path must be bytes when ' 'encoding is not set') return path def decode(self, path: bytes, want_string: bool = True) -> BytesOrStr: """Decode path name using configured path encoding This method has no effect if want_string is set to `False`. """ if want_string and self._path_encoding: try: return path.decode(self._path_encoding, self._path_errors) except UnicodeDecodeError: raise SFTPBadMessage('Unable to decode name') from None return path def compose_path(self, path: _SFTPPath, parent: Optional[bytes] = None) -> bytes: """Compose a path If parent is not specified, return a path relative to the current remote working directory. 
""" if parent is None: parent = self._cwd path = self.encode(path) return posixpath.join(parent, path) if parent else path async def _type(self, path: _SFTPPath, statfunc: Optional[_SFTPStatFunc] = None) -> int: """Return the file type of a remote path, or FILEXFER_TYPE_UNKNOWN if it can't be accessed""" if statfunc is None: statfunc = self.stat try: return (await statfunc(path)).type except (SFTPNoSuchFile, SFTPNoSuchPath, SFTPPermissionDenied): return FILEXFER_TYPE_UNKNOWN async def _copy(self, srcfs: _SFTPFSProtocol, dstfs: _SFTPFSProtocol, srcpath: bytes, dstpath: bytes, srcattrs: SFTPAttrs, preserve: bool, recurse: bool, follow_symlinks: bool, block_size: int, max_requests: int, progress_handler: SFTPProgressHandler, error_handler: SFTPErrorHandler, remote_only: bool) -> None: """Copy a file, directory, or symbolic link""" try: filetype = srcattrs.type if follow_symlinks and filetype == FILEXFER_TYPE_SYMLINK: srcattrs = await srcfs.stat(srcpath) filetype = srcattrs.type if filetype == FILEXFER_TYPE_DIRECTORY: if not recurse: exc = SFTPFileIsADirectory if self.version >= 6 \ else SFTPFailure raise exc(srcpath.decode('utf-8', 'backslashreplace') + ' is a directory') self.logger.info(' Starting copy of directory %s to %s', srcpath, dstpath) if not await dstfs.isdir(dstpath): await dstfs.mkdir(dstpath) async for srcname in srcfs.scandir(srcpath): filename = cast(bytes, srcname.filename) if filename in (b'.', b'..'): continue srcfile = posixpath.join(srcpath, filename) dstfile = posixpath.join(dstpath, filename) await self._copy(srcfs, dstfs, srcfile, dstfile, srcname.attrs, preserve, recurse, follow_symlinks, block_size, max_requests, progress_handler, error_handler, remote_only) self.logger.info(' Finished copy of directory %s to %s', srcpath, dstpath) elif filetype == FILEXFER_TYPE_SYMLINK: targetpath = await srcfs.readlink(srcpath) self.logger.info(' Copying symlink %s to %s', srcpath, dstpath) self.logger.info(' Target path: %s', targetpath) await dstfs.symlink(targetpath, dstpath) else: self.logger.info(' Copying file %s to %s', srcpath, dstpath) if remote_only and not self.supports_remote_copy: raise SFTPOpUnsupported('Remote copy not supported') await _SFTPFileCopier(block_size, max_requests, 0, srcattrs.size or 0, srcfs, dstfs, srcpath, dstpath, progress_handler).run() if preserve: attrs = await srcfs.stat(srcpath, follow_symlinks=follow_symlinks) attrs = SFTPAttrs(permissions=attrs.permissions, atime=attrs.atime, atime_ns=attrs.atime_ns, mtime=attrs.mtime, mtime_ns=attrs.mtime_ns) try: await dstfs.setstat(dstpath, attrs, follow_symlinks=follow_symlinks or filetype != FILEXFER_TYPE_SYMLINK) self.logger.info(' Preserved attrs: %s', attrs) except SFTPOpUnsupported: self.logger.info(' Preserving symlink attrs unsupported') except (OSError, SFTPError) as exc: setattr(exc, 'srcpath', srcpath) setattr(exc, 'dstpath', dstpath) if error_handler: error_handler(exc) else: raise async def _begin_copy(self, srcfs: _SFTPFSProtocol, dstfs: _SFTPFSProtocol, srcpaths: _SFTPPaths, dstpath: Optional[_SFTPPath], copy_type: str, expand_glob: bool, preserve: bool, recurse: bool, follow_symlinks: bool, block_size: int, max_requests: int, progress_handler: SFTPProgressHandler, error_handler: SFTPErrorHandler, remote_only: bool = False) -> None: """Begin a new file upload, download, or copy""" if block_size <= 0: block_size = min(srcfs.limits.max_read_len, dstfs.limits.max_write_len) if max_requests <= 0: max_requests = max(16, min(MAX_SFTP_READ_LEN // block_size, 128)) if isinstance(srcpaths, (bytes, 
str, PurePath)): srcpaths = [srcpaths] elif not isinstance(srcpaths, list): srcpaths = list(srcpaths) self.logger.info('Starting SFTP %s of %s to %s', copy_type, srcpaths, dstpath) srcnames: List[SFTPName] = [] if expand_glob: glob = SFTPGlob(srcfs, len(srcpaths) > 1) for srcpath in srcpaths: srcnames.extend(await glob.match(srcfs.encode(srcpath), error_handler, self.version)) else: for srcpath in srcpaths: srcpath = srcfs.encode(srcpath) srcattrs = await srcfs.stat(srcpath, follow_symlinks=follow_symlinks) srcnames.append(SFTPName(srcpath, attrs=srcattrs)) if dstpath: dstpath = dstfs.encode(dstpath) dstpath: Optional[bytes] dst_isdir = dstpath is None or (await dstfs.isdir(dstpath)) if len(srcnames) > 1 and not dst_isdir: assert dstpath is not None exc = SFTPNotADirectory if self.version >= 6 else SFTPFailure raise exc(dstpath.decode('utf-8', 'backslashreplace') + ' must be a directory') for srcname in srcnames: srcfile = cast(bytes, srcname.filename) basename = srcfs.basename(srcfile) if dstpath is None: dstfile = basename elif dst_isdir: dstfile = dstfs.compose_path(basename, parent=dstpath) else: dstfile = dstpath await self._copy(srcfs, dstfs, srcfile, dstfile, srcname.attrs, preserve, recurse, follow_symlinks, block_size, max_requests, progress_handler, error_handler, remote_only) async def get(self, remotepaths: _SFTPPaths, localpath: Optional[_SFTPPath] = None, *, preserve: bool = False, recurse: bool = False, follow_symlinks: bool = False, block_size: int = -1, max_requests: int = -1, progress_handler: SFTPProgressHandler = None, error_handler: SFTPErrorHandler = None) -> None: """Download remote files This method downloads one or more files or directories from the remote system. Either a single remote path or a sequence of remote paths to download can be provided. When downloading a single file or directory, the local path can be either the full path to download data into or the path to an existing directory where the data should be placed. In the latter case, the base file name from the remote path will be used as the local name. When downloading multiple files, the local path must refer to an existing directory. If no local path is provided, the file is downloaded into the current local working directory. If preserve is `True`, the access and modification times and permissions of the original file are set on the downloaded file. If recurse is `True` and the remote path points at a directory, the entire subtree under that directory is downloaded. If follow_symlinks is set to `True`, symbolic links found on the remote system will have the contents of their target downloaded rather than creating a local symbolic link. When using this option during a recursive download, one needs to watch out for links that result in loops. The block_size argument specifies the size of read and write requests issued when downloading the files, defaulting to the maximum allowed by the server, or 16 KB if the server doesn't advertise limits. The max_requests argument specifies the maximum number of parallel read or write requests issued, defaulting to a value between 16 and 128 depending on the selected block size to avoid excessive memory usage. If progress_handler is specified, it will be called after each block of a file is successfully downloaded. The arguments passed to this handler will be the source path, destination path, bytes downloaded so far, and total bytes in the file being downloaded. 
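# Sketch of a download using the progress_handler described above: after each
# block, the handler receives the source path, destination path, bytes
# transferred so far, and total bytes.  'sftp' is assumed to be an open
# SFTPClient and the paths below are placeholders.

def report_progress(srcpath, dstpath, bytes_copied, total_bytes):
    # Report a simple percentage; guard against zero-length files
    pct = 100 * bytes_copied // total_bytes if total_bytes else 100
    print(f'{srcpath!r} -> {dstpath!r}: {pct}%')

async def download_with_progress(sftp) -> None:
    await sftp.get('data/remote.bin', 'remote.bin',
                   progress_handler=report_progress)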
If multiple source paths are provided or recurse is set to `True`, the progress_handler will be called consecutively on each file being downloaded. If error_handler is specified and an error occurs during the download, this handler will be called with the exception instead of it being raised. This is intended to primarily be used when multiple remote paths are provided or when recurse is set to `True`, to allow error information to be collected without aborting the download of the remaining files. The error handler can raise an exception if it wants the download to completely stop. Otherwise, after an error, the download will continue starting with the next file. :param remotepaths: The paths of the remote files or directories to download :param localpath: (optional) The path of the local file or directory to download into :param preserve: (optional) Whether or not to preserve the original file attributes :param recurse: (optional) Whether or not to recursively copy directories :param follow_symlinks: (optional) Whether or not to follow symbolic links :param block_size: (optional) The block size to use for file reads and writes :param max_requests: (optional) The maximum number of parallel read or write requests :param progress_handler: (optional) The function to call to report download progress :param error_handler: (optional) The function to call when an error occurs :type remotepaths: :class:`PurePath `, `str`, or `bytes`, or a sequence of these :type localpath: :class:`PurePath `, `str`, or `bytes` :type preserve: `bool` :type recurse: `bool` :type follow_symlinks: `bool` :type block_size: `int` :type max_requests: `int` :type progress_handler: `callable` :type error_handler: `callable` :raises: | :exc:`OSError` if a local file I/O error occurs | :exc:`SFTPError` if the server returns an error """ await self._begin_copy(self, local_fs, remotepaths, localpath, 'get', False, preserve, recurse, follow_symlinks, block_size, max_requests, progress_handler, error_handler) async def put(self, localpaths: _SFTPPaths, remotepath: Optional[_SFTPPath] = None, *, preserve: bool = False, recurse: bool = False, follow_symlinks: bool = False, block_size: int = -1, max_requests: int = -1, progress_handler: SFTPProgressHandler = None, error_handler: SFTPErrorHandler = None) -> None: """Upload local files This method uploads one or more files or directories to the remote system. Either a single local path or a sequence of local paths to upload can be provided. When uploading a single file or directory, the remote path can be either the full path to upload data into or the path to an existing directory where the data should be placed. In the latter case, the base file name from the local path will be used as the remote name. When uploading multiple files, the remote path must refer to an existing directory. If no remote path is provided, the file is uploaded into the current remote working directory. If preserve is `True`, the access and modification times and permissions of the original file are set on the uploaded file. If recurse is `True` and the local path points at a directory, the entire subtree under that directory is uploaded. If follow_symlinks is set to `True`, symbolic links found on the local system will have the contents of their target uploaded rather than creating a remote symbolic link. When using this option during a recursive upload, one needs to watch out for links that result in loops. 
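# Sketch of a recursive download that collects per-file errors through the
# error_handler described above rather than aborting.  As _copy() shows, a
# failed transfer has 'srcpath' and 'dstpath' attributes attached to the
# exception before the handler is called.  'sftp' and the paths are
# placeholders.

from typing import List

async def download_tree(sftp) -> List[Exception]:
    errors: List[Exception] = []

    def collect(exc: Exception) -> None:
        # Remember the failure and let the transfer move on to the next file
        errors.append(exc)

    await sftp.get('logs', 'local_logs', recurse=True,
                   preserve=True, error_handler=collect)
    return errors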
The block_size argument specifies the size of read and write requests issued when uploading the files, defaulting to the maximum allowed by the server, or 16 KB if the server doesn't advertise limits. The max_requests argument specifies the maximum number of parallel read or write requests issued, defaulting to a value between 16 and 128 depending on the selected block size to avoid excessive memory usage. If progress_handler is specified, it will be called after each block of a file is successfully uploaded. The arguments passed to this handler will be the source path, destination path, bytes uploaded so far, and total bytes in the file being uploaded. If multiple source paths are provided or recurse is set to `True`, the progress_handler will be called consecutively on each file being uploaded. If error_handler is specified and an error occurs during the upload, this handler will be called with the exception instead of it being raised. This is intended to primarily be used when multiple local paths are provided or when recurse is set to `True`, to allow error information to be collected without aborting the upload of the remaining files. The error handler can raise an exception if it wants the upload to completely stop. Otherwise, after an error, the upload will continue starting with the next file. :param localpaths: The paths of the local files or directories to upload :param remotepath: (optional) The path of the remote file or directory to upload into :param preserve: (optional) Whether or not to preserve the original file attributes :param recurse: (optional) Whether or not to recursively copy directories :param follow_symlinks: (optional) Whether or not to follow symbolic links :param block_size: (optional) The block size to use for file reads and writes :param max_requests: (optional) The maximum number of parallel read or write requests :param progress_handler: (optional) The function to call to report upload progress :param error_handler: (optional) The function to call when an error occurs :type localpaths: :class:`PurePath `, `str`, or `bytes`, or a sequence of these :type remotepath: :class:`PurePath `, `str`, or `bytes` :type preserve: `bool` :type recurse: `bool` :type follow_symlinks: `bool` :type block_size: `int` :type max_requests: `int` :type progress_handler: `callable` :type error_handler: `callable` :raises: | :exc:`OSError` if a local file I/O error occurs | :exc:`SFTPError` if the server returns an error """ await self._begin_copy(local_fs, self, localpaths, remotepath, 'put', False, preserve, recurse, follow_symlinks, block_size, max_requests, progress_handler, error_handler) async def copy(self, srcpaths: _SFTPPaths, dstpath: Optional[_SFTPPath] = None, *, preserve: bool = False, recurse: bool = False, follow_symlinks: bool = False, block_size: int = -1, max_requests: int = -1, progress_handler: SFTPProgressHandler = None, error_handler: SFTPErrorHandler = None, remote_only: bool = False) -> None: """Copy remote files to a new location This method copies one or more files or directories on the remote system to a new location. Either a single source path or a sequence of source paths to copy can be provided. When copying a single file or directory, the destination path can be either the full path to copy data into or the path to an existing directory where the data should be placed. In the latter case, the base file name from the source path will be used as the destination name. 
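# Sketch of the matching upload path: several local files placed into an
# existing remote directory, as put() describes above.  'sftp' and the file
# names are placeholders.

async def upload_reports(sftp) -> None:
    # With multiple sources, the remote path must be an existing directory
    await sftp.put(['report1.csv', 'report2.csv'], 'uploads', preserve=True)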
When copying multiple files, the destination path must refer to an existing remote directory. If no destination path is provided, the file is copied into the current remote working directory. If preserve is `True`, the access and modification times and permissions of the original file are set on the copied file. If recurse is `True` and the source path points at a directory, the entire subtree under that directory is copied. If follow_symlinks is set to `True`, symbolic links found in the source will have the contents of their target copied rather than creating a copy of the symbolic link. When using this option during a recursive copy, one needs to watch out for links that result in loops. The block_size argument specifies the size of read and write requests issued when copying the files, defaulting to the maximum allowed by the server, or 16 KB if the server doesn't advertise limits. The max_requests argument specifies the maximum number of parallel read or write requests issued, defaulting to a value between 16 and 128 depending on the selected block size to avoid excessive memory usage. If progress_handler is specified, it will be called after each block of a file is successfully copied. The arguments passed to this handler will be the source path, destination path, bytes copied so far, and total bytes in the file being copied. If multiple source paths are provided or recurse is set to `True`, the progress_handler will be called consecutively on each file being copied. If error_handler is specified and an error occurs during the copy, this handler will be called with the exception instead of it being raised. This is intended to primarily be used when multiple source paths are provided or when recurse is set to `True`, to allow error information to be collected without aborting the copy of the remaining files. The error handler can raise an exception if it wants the copy to completely stop. Otherwise, after an error, the copy will continue starting with the next file. 
:param srcpaths: The paths of the remote files or directories to copy :param dstpath: (optional) The path of the remote file or directory to copy into :param preserve: (optional) Whether or not to preserve the original file attributes :param recurse: (optional) Whether or not to recursively copy directories :param follow_symlinks: (optional) Whether or not to follow symbolic links :param block_size: (optional) The block size to use for file reads and writes :param max_requests: (optional) The maximum number of parallel read or write requests :param progress_handler: (optional) The function to call to report copy progress :param error_handler: (optional) The function to call when an error occurs :param remote_only: (optional) Whether or not to only allow this to be a remote copy :type srcpaths: :class:`PurePath `, `str`, or `bytes`, or a sequence of these :type dstpath: :class:`PurePath `, `str`, or `bytes` :type preserve: `bool` :type recurse: `bool` :type follow_symlinks: `bool` :type block_size: `int` :type max_requests: `int` :type progress_handler: `callable` :type error_handler: `callable` :type remote_only: `bool` :raises: | :exc:`OSError` if a local file I/O error occurs | :exc:`SFTPError` if the server returns an error """ await self._begin_copy(self, self, srcpaths, dstpath, 'remote copy', False, preserve, recurse, follow_symlinks, block_size, max_requests, progress_handler, error_handler, remote_only) async def mget(self, remotepaths: _SFTPPaths, localpath: Optional[_SFTPPath] = None, *, preserve: bool = False, recurse: bool = False, follow_symlinks: bool = False, block_size: int = -1, max_requests: int = -1, progress_handler: SFTPProgressHandler = None, error_handler: SFTPErrorHandler = None) -> None: """Download remote files with glob pattern match This method downloads files and directories from the remote system matching one or more glob patterns. The arguments to this method are identical to the :meth:`get` method, except that the remote paths specified can contain wildcard patterns. """ await self._begin_copy(self, local_fs, remotepaths, localpath, 'mget', True, preserve, recurse, follow_symlinks, block_size, max_requests, progress_handler, error_handler) async def mput(self, localpaths: _SFTPPaths, remotepath: Optional[_SFTPPath] = None, *, preserve: bool = False, recurse: bool = False, follow_symlinks: bool = False, block_size: int = -1, max_requests: int = -1, progress_handler: SFTPProgressHandler = None, error_handler: SFTPErrorHandler = None) -> None: """Upload local files with glob pattern match This method uploads files and directories to the remote system matching one or more glob patterns. The arguments to this method are identical to the :meth:`put` method, except that the local paths specified can contain wildcard patterns. """ await self._begin_copy(local_fs, self, localpaths, remotepath, 'mput', True, preserve, recurse, follow_symlinks, block_size, max_requests, progress_handler, error_handler) async def mcopy(self, srcpaths: _SFTPPaths, dstpath: Optional[_SFTPPath] = None, *, preserve: bool = False, recurse: bool = False, follow_symlinks: bool = False, block_size: int = -1, max_requests: int = -1, progress_handler: SFTPProgressHandler = None, error_handler: SFTPErrorHandler = None, remote_only: bool = False) -> None: """Copy remote files with glob pattern match This method copies files and directories on the remote system matching one or more glob patterns. 
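# Sketch combining the methods above: a server-side copy that must not fall
# back to streaming data through the client (remote_only), followed by a glob
# download via mget().  'sftp' and all paths and patterns are placeholders.

async def copy_and_fetch(sftp) -> None:
    # Raises SFTPOpUnsupported if the server lacks the copy-data extension
    await sftp.copy('archive/2024', 'archive/2024-backup',
                    recurse=True, remote_only=True)

    # Wildcards are only expanded by the m* variants
    await sftp.mget('archive/2024/*.log', 'local_logs')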
The arguments to this method are identical to the :meth:`copy` method, except that the source paths specified can contain wildcard patterns. """ await self._begin_copy(self, self, srcpaths, dstpath, 'remote mcopy', True, preserve, recurse, follow_symlinks, block_size, max_requests, progress_handler, error_handler, remote_only) async def remote_copy(self, src: _SFTPClientFileOrPath, dst: _SFTPClientFileOrPath, src_offset: int = 0, src_length: int = 0, dst_offset: int = 0) -> None: """Copy data between remote files :param src: The remote file object to read data from :param dst: The remote file object to write data to :param src_offset: (optional) The offset to begin reading data from :param src_length: (optional) The number of bytes to attempt to copy :param dst_offset: (optional) The offset to begin writing data to :type src: :class:`SFTPClientFile`, :class:`PurePath `, `str`, or `bytes` :type dst: :class:`SFTPClientFile`, :class:`PurePath `, `str`, or `bytes` :type src_offset: `int` :type src_length: `int` :type dst_offset: `int` :raises: :exc:`SFTPError` if the server doesn't support this extension or returns an error """ if isinstance(src, (bytes, str, PurePath)): src = await self.open(src, 'rb', block_size=0) if isinstance(dst, (bytes, str, PurePath)): dst = await self.open(dst, 'wb', block_size=0) await self._handler.copy_data(src.handle, src_offset, src_length, dst.handle, dst_offset) async def glob(self, patterns: _SFTPPaths, error_handler: SFTPErrorHandler = None) -> \ Sequence[BytesOrStr]: """Match remote files against glob patterns This method matches remote files against one or more glob patterns. Either a single pattern or a sequence of patterns can be provided to match against. Supported wildcard characters include '*', '?', and character ranges in square brackets. In addition, '**' can be used to trigger a recursive directory search at that point in the pattern, and a trailing slash can be used to request that only directories get returned. If error_handler is specified and an error occurs during the match, this handler will be called with the exception instead of it being raised. This is intended to primarily be used when multiple patterns are provided to allow error information to be collected without aborting the match against the remaining patterns. The error handler can raise an exception if it wants to completely abort the match. Otherwise, after an error, the match will continue starting with the next pattern. An error will be raised if any of the patterns completely fail to match, and this can either stop the match against the remaining patterns or be handled by the error_handler just like other errors. :param patterns: Glob patterns to try and match remote files against :param error_handler: (optional) The function to call when an error occurs :type patterns: :class:`PurePath `, `str`, or `bytes`, or a sequence of these :type error_handler: `callable` :raises: :exc:`SFTPError` if the server returns an error or no match is found """ return [name.filename for name in await self.glob_sftpname(patterns, error_handler)] async def glob_sftpname(self, patterns: _SFTPPaths, error_handler: SFTPErrorHandler = None) -> \ Sequence[SFTPName]: """Match glob patterns and return SFTPNames This method is similar to :meth:`glob`, but it returns matching file names and attributes as :class:`SFTPName` objects. 
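# Sketch of remote_copy() and glob() as documented above: splice a byte range
# between two already-open remote files (both opened with block_size=0, as
# remote_copy() itself does when given paths), then match names with a
# recursive '**' pattern.  'sftp', the paths, offsets, and sizes are
# illustrative.

async def splice_and_match(sftp) -> None:
    async with sftp.open('src.bin', 'rb', block_size=0) as src, \
               sftp.open('dst.bin', 'wb', block_size=0) as dst:
        # Copy 4 KB starting at a 1 MB offset into the start of dst.bin
        await sftp.remote_copy(src, dst, src_offset=1024 * 1024,
                               src_length=4096, dst_offset=0)

    matches = await sftp.glob(['data/**/*.json', 'conf/*/'])
    print(matches)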
""" if isinstance(patterns, (bytes, str, PurePath)): patterns = [patterns] glob = SFTPGlob(self, len(patterns) > 1) matches: List[SFTPName] = [] for pattern in patterns: new_matches = await glob.match(self.encode(pattern), error_handler, self.version) if isinstance(pattern, (str, PurePath)): for name in new_matches: name.filename = self.decode(cast(bytes, name.filename)) matches.extend(new_matches) return matches async def makedirs(self, path: _SFTPPath, attrs: SFTPAttrs = SFTPAttrs(), exist_ok: bool = False) -> None: """Create a remote directory with the specified attributes This method creates a remote directory at the specified path similar to :meth:`mkdir`, but it will also create any intermediate directories which don't yet exist. If the target directory already exists and exist_ok is set to `False`, this method will raise an error. :param path: The path of where the new remote directory should be created :param attrs: (optional) The file attributes to use when creating the directory or any intermediate directories :param exist_ok: (optional) Whether or not to raise an error if thet target directory already exists :type path: :class:`PurePath `, `str`, or `bytes` :type attrs: :class:`SFTPAttrs` :type exist_ok: `bool` :raises: :exc:`SFTPError` if the server returns an error """ path = self.encode(path) curpath = b'/' if posixpath.isabs(path) else (self._cwd or b'') exists = True parts = path.split(b'/') last = len(parts) - 1 exc: Type[SFTPError] for i, part in enumerate(parts): curpath = posixpath.join(curpath, part) try: await self.mkdir(curpath, attrs) exists = False except (SFTPFailure, SFTPFileAlreadyExists): filetype = await self._type(curpath) if filetype != FILEXFER_TYPE_DIRECTORY: curpath_str = curpath.decode('utf-8', 'backslashreplace') exc = SFTPNotADirectory if self.version >= 6 \ else SFTPFailure raise exc(f'{curpath_str} is not a directory') from None except SFTPPermissionDenied: if i == last: raise if exists and not exist_ok: exc = SFTPFileAlreadyExists if self.version >= 6 else SFTPFailure raise exc(curpath.decode('utf-8', 'backslashreplace') + ' already exists') async def rmtree(self, path: _SFTPPath, ignore_errors: bool = False, onerror: _SFTPOnErrorHandler = None) -> None: """Recursively delete a directory tree This method removes all the files in a directory tree. If ignore_errors is set, errors are ignored. Otherwise, if onerror is set, it will be called with arguments of the function which failed, the path it failed on, and exception information returns by :func:`sys.exc_info()`. If follow_symlinks is set, files or directories pointed at by symlinks (and their subdirectories, if any) will be removed in addition to the links pointing at them. 
:param path: The path of the parent directory to remove :param ignore_errors: (optional) Whether or not to ignore errors during the remove :param onerror: (optional) A function to call when errors occur :type path: :class:`PurePath `, `str`, or `bytes` :type ignore_errors: `bool` :type onerror: `callable` :raises: :exc:`SFTPError` if the server returns an error """ async def _unlink(path: bytes) -> None: """Internal helper for unlinking non-directories""" assert onerror is not None try: await self.unlink(path) except SFTPError: onerror(self.unlink, path, sys.exc_info()) async def _rmtree(path: bytes) -> None: """Internal helper for rmtree recursion""" assert onerror is not None tasks = [] try: async with sem: async for entry in self.scandir(path): filename = cast(bytes, entry.filename) if filename in (b'.', b'..'): continue filename = posixpath.join(path, filename) if entry.attrs.type == FILEXFER_TYPE_DIRECTORY: task = _rmtree(filename) else: task = _unlink(filename) tasks.append(asyncio.ensure_future(task)) except SFTPError: onerror(self.scandir, path, sys.exc_info()) results = await asyncio.gather(*tasks, return_exceptions=True) exc = next((result for result in results if isinstance(result, Exception)), None) if exc: raise exc try: await self.rmdir(path) except SFTPError: onerror(self.rmdir, path, sys.exc_info()) # pylint: disable=function-redefined if ignore_errors: def onerror(*_args: object) -> None: pass elif onerror is None: def onerror(*_args: object) -> None: raise # pylint: disable=misplaced-bare-raise # pylint: enable=function-redefined assert onerror is not None path = self.encode(path) sem = asyncio.Semaphore(_MAX_SFTP_REQUESTS) try: if await self.islink(path): raise SFTPNoSuchFile(path.decode('utf-8', 'backslashreplace') + ' must not be a symlink') except SFTPError: onerror(self.islink, path, sys.exc_info()) return await _rmtree(path) @async_context_manager async def open(self, path: _SFTPPath, pflags_or_mode: Union[int, str] = FXF_READ, attrs: SFTPAttrs = SFTPAttrs(), encoding: Optional[str] = 'utf-8', errors: str = 'strict', block_size: int = -1, max_requests: int = -1) -> SFTPClientFile: """Open a remote file This method opens a remote file and returns an :class:`SFTPClientFile` object which can be used to read and write data and get and set file attributes. The path can be either a `str` or `bytes` value. If it is a str, it will be encoded using the file encoding specified when the :class:`SFTPClient` was started. The following open mode flags are supported: ========== ====================================================== Mode Description ========== ====================================================== FXF_READ Open the file for reading. FXF_WRITE Open the file for writing. If both this and FXF_READ are set, open the file for both reading and writing. FXF_APPEND Force writes to append data to the end of the file regardless of seek position. FXF_CREAT Create the file if it doesn't exist. Without this, attempts to open a non-existent file will fail. FXF_TRUNC Truncate the file to zero length if it already exists. FXF_EXCL Return an error when trying to open a file which already exists. ========== ====================================================== Instead of these flags, a Python open mode string can also be provided. 
Python open modes map to the above flags as follows: ==== ============================================= Mode Flags ==== ============================================= r FXF_READ w FXF_WRITE | FXF_CREAT | FXF_TRUNC a FXF_WRITE | FXF_CREAT | FXF_APPEND x FXF_WRITE | FXF_CREAT | FXF_EXCL r+ FXF_READ | FXF_WRITE w+ FXF_READ | FXF_WRITE | FXF_CREAT | FXF_TRUNC a+ FXF_READ | FXF_WRITE | FXF_CREAT | FXF_APPEND x+ FXF_READ | FXF_WRITE | FXF_CREAT | FXF_EXCL ==== ============================================= Including a 'b' in the mode causes the `encoding` to be set to `None`, forcing all data to be read and written as bytes in binary format. Most applications should be able to use this method regardless of the version of the SFTP protocol negotiated with the SFTP server. A conversion from the pflags_or_mode values to the SFTPv5/v6 flag values will happen automatically. However, if an application wishes to set flags only available in SFTPv5/v6, the :meth:`open56` method may be used to specify these flags explicitly. The attrs argument is used to set initial attributes of the file if it needs to be created. Otherwise, this argument is ignored. The block_size argument specifies the size of parallel read and write requests issued on the file. If set to `None`, each read or write call will become a single request to the SFTP server. Otherwise, read or write calls larger than this size will be turned into parallel requests to the server of the requested size, defaulting to the maximum allowed by the server, or 16 KB if the server doesn't advertise limits. .. note:: The OpenSSH SFTP server will close the connection if it receives a message larger than 256 KB. So, when connecting to an OpenSSH SFTP server, it is recommended that the block_size be left at its default of using the server-advertised limits. The max_requests argument specifies the maximum number of parallel read or write requests issued, defaulting to a value between 16 and 128 depending on the selected block size to avoid excessive memory usage. 
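# Sketch of open() with Python mode strings, per the mapping table above:
# 'a' maps to FXF_WRITE | FXF_CREAT | FXF_APPEND, and a 'b' suffix forces
# binary mode (encoding=None).  The SFTPClientFile read()/write() calls and
# the file names are assumptions for illustration; 'sftp' is an open client.

async def append_log_line(sftp, line: str) -> None:
    async with sftp.open('run.log', 'a') as f:
        await f.write(line + '\n')

async def read_header(sftp) -> bytes:
    async with sftp.open('blob.bin', 'rb') as f:
        return await f.read(16)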
:param path: The name of the remote file to open :param pflags_or_mode: (optional) The access mode to use for the remote file (see above) :param attrs: (optional) File attributes to use if the file needs to be created :param encoding: (optional) The Unicode encoding to use for data read and written to the remote file :param errors: (optional) The error-handling mode if an invalid Unicode byte sequence is detected, defaulting to 'strict' which raises an exception :param block_size: (optional) The block size to use for read and write requests :param max_requests: (optional) The maximum number of parallel read or write requests :type path: :class:`PurePath `, `str`, or `bytes` :type pflags_or_mode: `int` or `str` :type attrs: :class:`SFTPAttrs` :type encoding: `str` :type errors: `str` :type block_size: `int` or `None` :type max_requests: `int` :returns: An :class:`SFTPClientFile` to use to access the file :raises: | :exc:`ValueError` if the mode is not valid | :exc:`SFTPError` if the server returns an error """ if isinstance(pflags_or_mode, str): pflags, binary = _mode_to_pflags(pflags_or_mode) if binary: encoding = None else: pflags = pflags_or_mode path = self.compose_path(path) handle = await self._handler.open(path, pflags, attrs) return SFTPClientFile(self._handler, handle, bool(pflags & FXF_APPEND), encoding, errors, block_size, max_requests) @async_context_manager async def open56(self, path: _SFTPPath, desired_access: int = ACE4_READ_DATA | ACE4_READ_ATTRIBUTES, flags: int = FXF_OPEN_EXISTING, attrs: SFTPAttrs = SFTPAttrs(), encoding: Optional[str] = 'utf-8', errors: str = 'strict', block_size: int = -1, max_requests: int = -1) -> SFTPClientFile: """Open a remote file using SFTP v5/v6 flags This method is very similar to :meth:`open`, but the pflags_or_mode argument is replaced with SFTPv5/v6 desired_access and flags arguments. Most applications can continue to use :meth:`open` even when talking to an SFTPv5/v6 server and the translation of the flags will happen automatically. However, if an application wishes to set flags only available in SFTPv5/v6, this method provides that capability. The following desired_access flags can be specified: | ACE4_READ_DATA | ACE4_WRITE_DATA | ACE4_APPEND_DATA | ACE4_READ_ATTRIBUTES | ACE4_WRITE_ATTRIBUTES The following flags can be specified: | FXF_CREATE_NEW | FXF_CREATE_TRUNCATE | FXF_OPEN_EXISTING | FXF_OPEN_OR_CREATE | FXF_TRUNCATE_EXISTING | FXF_APPEND_DATA | FXF_APPEND_DATA_ATOMIC | FXF_BLOCK_READ | FXF_BLOCK_WRITE | FXF_BLOCK_DELETE | FXF_BLOCK_ADVISORY (SFTPv6) | FXF_NOFOLLOW (SFTPv6) | FXF_DELETE_ON_CLOSE (SFTPv6) | FXF_ACCESS_AUDIT_ALARM_INFO (SFTPv6) | FXF_ACCESS_BACKUP (SFTPv6) | FXF_BACKUP_STREAM (SFTPv6) | FXF_OVERRIDE_OWNER (SFTPv6) At this time, FXF_TEXT_MODE is not supported. Also, servers may support only a subset of these flags. For example, the AsyncSSH SFTP server doesn't currently support ACLs, file locking, or most of the SFTPv6 open flags, but support for some of these may be added over time. 
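# Sketch of open56() using the desired_access and flags constants listed
# above (the same names this module uses elsewhere).  The write() call on the
# returned SFTPClientFile is an assumption for illustration; fsync() and
# close() are the methods defined earlier in this file.

async def open_v5_style(sftp) -> None:
    f = await sftp.open56('notes.txt',
                          desired_access=ACE4_READ_DATA | ACE4_WRITE_DATA |
                                         ACE4_READ_ATTRIBUTES,
                          flags=FXF_OPEN_OR_CREATE)
    try:
        await f.write('checkpoint\n')
        await f.fsync()   # fsync@openssh.com extension, if the server supports it
    finally:
        await f.close()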
:param path: The name of the remote file to open :param desired_access: (optional) The access mode to use for the remote file (see above) :param flags: (optional) The access flags to use for the remote file (see above) :param attrs: (optional) File attributes to use if the file needs to be created :param encoding: (optional) The Unicode encoding to use for data read and written to the remote file :param errors: (optional) The error-handling mode if an invalid Unicode byte sequence is detected, defaulting to 'strict' which raises an exception :param block_size: (optional) The block size to use for read and write requests :param max_requests: (optional) The maximum number of parallel read or write requests :type path: :class:`PurePath `, `str`, or `bytes` :type desired_access: int :type flags: int :type attrs: :class:`SFTPAttrs` :type encoding: `str` :type errors: `str` :type block_size: `int` or `None` :type max_requests: `int` :returns: An :class:`SFTPClientFile` to use to access the file :raises: | :exc:`ValueError` if the mode is not valid | :exc:`SFTPError` if the server returns an error """ path = self.compose_path(path) handle = await self._handler.open56(path, desired_access, flags, attrs) return SFTPClientFile(self._handler, handle, bool(desired_access & ACE4_APPEND_DATA or flags & FXF_APPEND_DATA), encoding, errors, block_size, max_requests) async def stat(self, path: _SFTPPath, flags = FILEXFER_ATTR_DEFINED_V4, *, follow_symlinks: bool = True) -> SFTPAttrs: """Get attributes of a remote file, directory, or symlink This method queries the attributes of a remote file, directory, or symlink. If the path provided is a symlink and follow_symlinks is `True`, the returned attributes will correspond to the target of the link. :param path: The path of the remote file or directory to get attributes for :param flags: (optional) Flags indicating attributes of interest (SFTPv4 only) :param follow_symlinks: (optional) Whether or not to follow symbolic links :type path: :class:`PurePath `, `str`, or `bytes` :type flags: `int` :type follow_symlinks: `bool` :returns: An :class:`SFTPAttrs` containing the file attributes :raises: :exc:`SFTPError` if the server returns an error """ path = self.compose_path(path) return await self._handler.stat(path, flags, follow_symlinks=follow_symlinks) async def lstat(self, path: _SFTPPath, flags = FILEXFER_ATTR_DEFINED_V4) -> SFTPAttrs: """Get attributes of a remote file, directory, or symlink This method queries the attributes of a remote file, directory, or symlink. Unlike :meth:`stat`, this method returns the attributes of a symlink itself rather than the target of that link. :param path: The path of the remote file, directory, or link to get attributes for :param flags: (optional) Flags indicating attributes of interest (SFTPv4 only) :type path: :class:`PurePath `, `str`, or `bytes` :type flags: `int` :returns: An :class:`SFTPAttrs` containing the file attributes :raises: :exc:`SFTPError` if the server returns an error """ path = self.compose_path(path) return await self._handler.lstat(path, flags) async def setstat(self, path: _SFTPPath, attrs: SFTPAttrs, *, follow_symlinks: bool = True) -> None: """Set attributes of a remote file, directory, or symlink This method sets attributes of a remote file, directory, or symlink. If the path provided is a symlink and follow_symlinks is `True`, the attributes will be set on the target of the link. A subset of the fields in `attrs` can be initialized and only those attributes will be changed. 
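# Sketch of stat()/setstat() as described above: only the SFTPAttrs fields
# that are filled in get changed on the server.  'sftp' and the path are
# placeholders.

async def make_group_writable(sftp, path: str) -> None:
    attrs = await sftp.stat(path)

    if attrs.permissions is not None:
        # Add group write while leaving every other attribute untouched
        await sftp.setstat(path, SFTPAttrs(permissions=attrs.permissions | 0o020))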
:param path: The path of the remote file or directory to set attributes for :param attrs: File attributes to set :type path: :class:`PurePath `, `str`, or `bytes` :type attrs: :class:`SFTPAttrs` :raises: :exc:`SFTPError` if the server returns an error """ path = self.compose_path(path) await self._handler.setstat(path, attrs, follow_symlinks=follow_symlinks) async def statvfs(self, path: _SFTPPath) -> SFTPVFSAttrs: """Get attributes of a remote file system This method queries the attributes of the file system containing the specified path. :param path: The path of the remote file system to get attributes for :type path: :class:`PurePath `, `str`, or `bytes` :returns: An :class:`SFTPVFSAttrs` containing the file system attributes :raises: :exc:`SFTPError` if the server doesn't support this extension or returns an error """ path = self.compose_path(path) return await self._handler.statvfs(path) async def truncate(self, path: _SFTPPath, size: int) -> None: """Truncate a remote file to the specified size This method truncates a remote file to the specified size. If the path provided is a symbolic link, the target of the link will be truncated. :param path: The path of the remote file to be truncated :param size: The desired size of the file, in bytes :type path: :class:`PurePath `, `str`, or `bytes` :type size: `int` :raises: :exc:`SFTPError` if the server returns an error """ await self.setstat(path, SFTPAttrs(size=size)) @overload async def chown(self, path: _SFTPPath, uid: int, gid: int, *, follow_symlinks: bool = True) -> \ None: ... # pragma: no cover @overload async def chown(self, path: _SFTPPath, owner: str, group: str, *, follow_symlinks: bool = True) -> \ None: ... # pragma: no cover async def chown(self, path, uid_or_owner = None, gid_or_group = None, uid = None, gid = None, owner = None, group = None, *, follow_symlinks = True): """Change the owner of a remote file, directory, or symlink This method changes the user and group id of a remote file, directory, or symlink. If the path provided is a symlink and follow_symlinks is `True`, the target of the link will be changed. :param path: The path of the remote file to change :param uid: The new user id to assign to the file :param gid: The new group id to assign to the file :param owner: The new owner to assign to the file (SFTPv4 only) :param group: The new group to assign to the file (SFTPv4 only) :param follow_symlinks: (optional) Whether or not to follow symbolic links :type path: :class:`PurePath `, `str`, or `bytes` :type uid: `int` :type gid: `int` :type owner: `str` :type group: `str` :type follow_symlinks: `bool` :raises: :exc:`SFTPError` if the server returns an error """ if isinstance(uid_or_owner, int): uid = uid_or_owner elif isinstance(uid_or_owner, str): owner = uid_or_owner if isinstance(gid_or_group, int): gid = gid_or_group elif isinstance(gid_or_group, str): group = gid_or_group await self.setstat(path, SFTPAttrs(uid=uid, gid=gid, owner=owner, group=group), follow_symlinks=follow_symlinks) async def chmod(self, path: _SFTPPath, mode: int, *, follow_symlinks: bool = True) -> None: """Change the permissions of a remote file, directory, or symlink This method changes the permissions of a remote file, directory, or symlink. If the path provided is a symlink and follow_symlinks is `True`, the target of the link will be changed. 
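# Sketch of the chown() overloads above: numeric uid/gid, or owner/group
# names as described for SFTPv4, optionally operating on a symlink itself.
# The ids, names, and paths are placeholders.

async def reassign_ownership(sftp) -> None:
    # By numeric ids
    await sftp.chown('data/report.txt', 1000, 1000)

    # By owner/group name (SFTPv4, per the docs above), without following the link
    await sftp.chown('data/latest', 'deploy', 'staff', follow_symlinks=False)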
:param path: The path of the remote file to change :param mode: The new file permissions, expressed as an int :param follow_symlinks: (optional) Whether or not to follow symbolic links :type path: :class:`PurePath `, `str`, or `bytes` :type mode: `int` :type follow_symlinks: `bool` :raises: :exc:`SFTPError` if the server returns an error """ await self.setstat(path, SFTPAttrs(permissions=mode), follow_symlinks=follow_symlinks) async def utime(self, path: _SFTPPath, times: Optional[Tuple[float, float]] = None, ns: Optional[Tuple[int, int]] = None, *, follow_symlinks: bool = True) -> None: """Change the timestamps of a remote file, directory, or symlink This method changes the access and modify times of a remote file, directory, or symlink. If neither `times` nor '`ns` is provided, the times will be changed to the current time. If the path provided is a symlink and follow_symlinks is `True`, the target of the link will be changed. :param path: The path of the remote file to change :param times: (optional) The new access and modify times, as seconds relative to the UNIX epoch :param ns: (optional) The new access and modify times, as nanoseconds relative to the UNIX epoch :param follow_symlinks: (optional) Whether or not to follow symbolic links :type path: :class:`PurePath `, `str`, or `bytes` :type times: tuple of two `int` or `float` values :type ns: tuple of two `int` values :type follow_symlinks: `bool` :raises: :exc:`SFTPError` if the server returns an error """ await self.setstat(path, _utime_to_attrs(times, ns), follow_symlinks=follow_symlinks) async def exists(self, path: _SFTPPath) -> bool: """Return if the remote path exists and isn't a broken symbolic link :param path: The remote path to check :type path: :class:`PurePath `, `str`, or `bytes` :raises: :exc:`SFTPError` if the server returns an error """ return (await self._type(path)) != FILEXFER_TYPE_UNKNOWN async def lexists(self, path: _SFTPPath) -> bool: """Return if the remote path exists, without following symbolic links :param path: The remote path to check :type path: :class:`PurePath `, `str`, or `bytes` :raises: :exc:`SFTPError` if the server returns an error """ return (await self._type(path, statfunc=self.lstat)) != \ FILEXFER_TYPE_UNKNOWN async def getatime(self, path: _SFTPPath) -> Optional[float]: """Return the last access time of a remote file or directory :param path: The remote path to check :type path: :class:`PurePath `, `str`, or `bytes` :raises: :exc:`SFTPError` if the server returns an error """ attrs = await self.stat(path) return _tuple_to_float_sec(attrs.atime, attrs.atime_ns) \ if attrs.atime is not None else None async def getatime_ns(self, path: _SFTPPath) -> Optional[int]: """Return the last access time of a remote file or directory The time returned is nanoseconds since the epoch. 
:param path: The remote path to check :type path: :class:`PurePath `, `str`, or `bytes` :raises: :exc:`SFTPError` if the server returns an error """ attrs = await self.stat(path) return _tuple_to_nsec(attrs.atime, attrs.atime_ns) \ if attrs.atime is not None else None async def getcrtime(self, path: _SFTPPath) -> Optional[float]: """Return the creation time of a remote file or directory (SFTPv4 only) :param path: The remote path to check :type path: :class:`PurePath `, `str`, or `bytes` :raises: :exc:`SFTPError` if the server returns an error """ attrs = await self.stat(path) return _tuple_to_float_sec(attrs.crtime, attrs.crtime_ns) \ if attrs.crtime is not None else None async def getcrtime_ns(self, path: _SFTPPath) -> Optional[int]: """Return the creation time of a remote file or directory The time returned is nanoseconds since the epoch. :param path: The remote path to check :type path: :class:`PurePath `, `str`, or `bytes` :raises: :exc:`SFTPError` if the server returns an error """ attrs = await self.stat(path) return _tuple_to_nsec(attrs.crtime, attrs.crtime_ns) \ if attrs.crtime is not None else None async def getmtime(self, path: _SFTPPath) -> Optional[float]: """Return the last modification time of a remote file or directory :param path: The remote path to check :type path: :class:`PurePath `, `str`, or `bytes` :raises: :exc:`SFTPError` if the server returns an error """ attrs = await self.stat(path) return _tuple_to_float_sec(attrs.mtime, attrs.mtime_ns) \ if attrs.mtime is not None else None async def getmtime_ns(self, path: _SFTPPath) -> Optional[int]: """Return the last modification time of a remote file or directory The time returned is nanoseconds since the epoch. :param path: The remote path to check :type path: :class:`PurePath `, `str`, or `bytes` :raises: :exc:`SFTPError` if the server returns an error """ attrs = await self.stat(path) return _tuple_to_nsec(attrs.mtime, attrs.mtime_ns) \ if attrs.mtime is not None else None async def getsize(self, path: _SFTPPath) -> Optional[int]: """Return the size of a remote file or directory :param path: The remote path to check :type path: :class:`PurePath `, `str`, or `bytes` :raises: :exc:`SFTPError` if the server returns an error """ return (await self.stat(path)).size async def isdir(self, path: _SFTPPath) -> bool: """Return if the remote path refers to a directory :param path: The remote path to check :type path: :class:`PurePath `, `str`, or `bytes` :raises: :exc:`SFTPError` if the server returns an error """ return (await self._type(path)) == FILEXFER_TYPE_DIRECTORY async def isfile(self, path: _SFTPPath) -> bool: """Return if the remote path refers to a regular file :param path: The remote path to check :type path: :class:`PurePath `, `str`, or `bytes` :raises: :exc:`SFTPError` if the server returns an error """ return (await self._type(path)) == FILEXFER_TYPE_REGULAR async def islink(self, path: _SFTPPath) -> bool: """Return if the remote path refers to a symbolic link :param path: The remote path to check :type path: :class:`PurePath `, `str`, or `bytes` :raises: :exc:`SFTPError` if the server returns an error """ return (await self._type(path, statfunc=self.lstat)) == \ FILEXFER_TYPE_SYMLINK async def remove(self, path: _SFTPPath) -> None: """Remove a remote file This method removes a remote file or symbolic link. 
:param path: The path of the remote file or link to remove :type path: :class:`PurePath `, `str`, or `bytes` :raises: :exc:`SFTPError` if the server returns an error """ path = self.compose_path(path) await self._handler.remove(path) async def unlink(self, path: _SFTPPath) -> None: """Remove a remote file (see :meth:`remove`)""" await self.remove(path) async def rename(self, oldpath: _SFTPPath, newpath: _SFTPPath, flags: int = 0) -> None: """Rename a remote file, directory, or link This method renames a remote file, directory, or link. .. note:: By default, this version of rename will not overwrite the new path if it already exists. However, this can be controlled using the `flags` argument, available in SFTPv5 and later. When a connection is negotiated to use an earlier version of SFTP and `flags` is set, this method will attempt to fall back to the OpenSSH "posix-rename" extension if it is available. That can also be invoked directly by calling :meth:`posix_rename`. :param oldpath: The path of the remote file, directory, or link to rename :param newpath: The new name for this file, directory, or link :param flags: (optional) A combination of the `FXR_OVERWRITE`, `FXR_ATOMIC`, and `FXR_NATIVE` flags to specify what happens when `newpath` already exists, defaulting to not allowing the overwrite (SFTPv5 and later) :type oldpath: :class:`PurePath `, `str`, or `bytes` :type newpath: :class:`PurePath `, `str`, or `bytes` :type flags: `int` :raises: :exc:`SFTPError` if the server returns an error """ oldpath = self.compose_path(oldpath) newpath = self.compose_path(newpath) await self._handler.rename(oldpath, newpath, flags) async def posix_rename(self, oldpath: _SFTPPath, newpath: _SFTPPath) -> None: """Rename a remote file, directory, or link with POSIX semantics This method renames a remote file, directory, or link, removing the prior instance of new path if it previously existed. This method may not be supported by all SFTP servers. If it is not available but the server supports SFTPv5 or later, this method will attempt to send the standard SFTP rename request with the `FXR_OVERWRITE` flag set. :param oldpath: The path of the remote file, directory, or link to rename :param newpath: The new name for this file, directory, or link :type oldpath: :class:`PurePath `, `str`, or `bytes` :type newpath: :class:`PurePath `, `str`, or `bytes` :raises: :exc:`SFTPError` if the server doesn't support this extension or returns an error """ oldpath = self.compose_path(oldpath) newpath = self.compose_path(newpath) await self._handler.posix_rename(oldpath, newpath) async def scandir(self, path: _SFTPPath = '.') -> AsyncIterator[SFTPName]: """Return names and attributes of the files in a remote directory This method reads the contents of a directory, returning the names and attributes of what is contained there as an async iterator. If no path is provided, it defaults to the current remote working directory.
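# Sketch of rename() with overwrite and of directory iteration via scandir(),
# as described above.  FXR_OVERWRITE is the flag named in the rename()
# documentation; 'sftp' and the paths are placeholders.

async def rotate_and_list(sftp) -> None:
    # Replace any existing backup (SFTPv5+, with the posix-rename fallback
    # on older servers as noted above)
    await sftp.rename('config.yaml', 'config.yaml.bak', FXR_OVERWRITE)

    async for entry in sftp.scandir('conf.d'):
        if entry.filename not in ('.', '..'):
            print(entry.filename, entry.attrs.size)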
:param path: (optional) The path of the remote directory to read :type path: :class:`PurePath `, `str`, or `bytes` :returns: An async iterator of :class:`SFTPName` entries, with path names matching the type used to pass in the path :raises: :exc:`SFTPError` if the server returns an error """ dirpath = self.compose_path(path) handle = await self._handler.opendir(dirpath) at_end = False try: while not at_end: names, at_end = await self._handler.readdir(handle) for entry in names: if isinstance(path, (str, PurePath)): entry.filename = \ self.decode(cast(bytes, entry.filename)) if entry.longname is not None: entry.longname = \ self.decode(cast(bytes, entry.longname)) yield entry except SFTPEOFError: pass finally: await self._handler.close(handle) async def readdir(self, path: _SFTPPath = '.') -> Sequence[SFTPName]: """Read the contents of a remote directory This method reads the contents of a directory, returning the names and attributes of what is contained there. If no path is provided, it defaults to the current remote working directory. :param path: (optional) The path of the remote directory to read :type path: :class:`PurePath `, `str`, or `bytes` :returns: A list of :class:`SFTPName` entries, with path names matching the type used to pass in the path :raises: :exc:`SFTPError` if the server returns an error """ return [entry async for entry in self.scandir(path)] @overload async def listdir(self, path: bytes) -> \ Sequence[bytes]: ... # pragma: no cover @overload async def listdir(self, path: FilePath = ...) -> \ Sequence[str]: ... # pragma: no cover async def listdir(self, path: _SFTPPath = '.') -> Sequence[BytesOrStr]: """Read the names of the files in a remote directory This method reads the names of files and subdirectories in a remote directory. If no path is provided, it defaults to the current remote working directory. :param path: (optional) The path of the remote directory to read :type path: :class:`PurePath `, `str`, or `bytes` :returns: A list of file/subdirectory names, as a `str` or `bytes` matching the type used to pass in the path :raises: :exc:`SFTPError` if the server returns an error """ names = await self.readdir(path) return [name.filename for name in names] async def mkdir(self, path: _SFTPPath, attrs: SFTPAttrs = SFTPAttrs()) -> None: """Create a remote directory with the specified attributes This method creates a new remote directory at the specified path with the requested attributes. :param path: The path of where the new remote directory should be created :param attrs: (optional) The file attributes to use when creating the directory :type path: :class:`PurePath `, `str`, or `bytes` :type attrs: :class:`SFTPAttrs` :raises: :exc:`SFTPError` if the server returns an error """ path = self.compose_path(path) await self._handler.mkdir(path, attrs) async def rmdir(self, path: _SFTPPath) -> None: """Remove a remote directory This method removes a remote directory. The directory must be empty for the removal to succeed. :param path: The path of the remote directory to remove :type path: :class:`PurePath `, `str`, or `bytes` :raises: :exc:`SFTPError` if the server returns an error """ path = self.compose_path(path) await self._handler.rmdir(path) @overload async def realpath(self, path: bytes, # pragma: no cover *compose_paths: bytes) -> bytes: ... @overload async def realpath(self, path: FilePath, # pragma: no cover *compose_paths: FilePath) -> str: ... 
@overload async def realpath(self, path: bytes, # pragma: no cover *compose_paths: bytes, check: int) -> SFTPName: ... @overload async def realpath(self, path: FilePath, # pragma: no cover *compose_paths: FilePath, check: int) -> SFTPName: ... async def realpath(self, path: _SFTPPath, *compose_paths: _SFTPPath, check: int = FXRP_NO_CHECK) -> \ Union[BytesOrStr, SFTPName]: """Return the canonical version of a remote path This method returns a canonical version of the requested path. :param path: (optional) The path of the remote directory to canonicalize :param compose_paths: (optional) A list of additional paths that the server should compose with `path` before canonicalizing it :param check: (optional) One of `FXRP_NO_CHECK`, `FXRP_STAT_IF_EXISTS`, and `FXRP_STAT_ALWAYS`, specifying when to perform a stat operation on the resulting path, defaulting to `FXRP_NO_CHECK` :type path: :class:`PurePath `, `str`, or `bytes` :type compose_paths: :class:`PurePath `, `str`, or `bytes` :type check: int :returns: The canonical path as a `str` or `bytes`, matching the type used to pass in the path if `check` is set to `FXRP_NO_CHECK`, or an :class:`SFTPName` containing the canonical path name and attributes otherwise :raises: :exc:`SFTPError` if the server returns an error """ if compose_paths and isinstance(compose_paths[-1], int): check = compose_paths[-1] compose_paths = compose_paths[:-1] path_bytes = self.compose_path(path) if self.version >= 6: names, _ = await self._handler.realpath( path_bytes, *map(self.encode, compose_paths), check=check) else: for cpath in compose_paths: path_bytes = self.compose_path(cpath, path_bytes) names, _ = await self._handler.realpath(path_bytes) if len(names) > 1: raise SFTPBadMessage('Too many names returned') if check != FXRP_NO_CHECK: if self.version < 6: try: names[0].attrs = await self._handler.stat( self.encode(names[0].filename), _valid_attr_flags[self.version]) except SFTPError: if check == FXRP_STAT_IF_EXISTS: names[0].attrs = SFTPAttrs(type=FILEXFER_TYPE_UNKNOWN) else: raise return names[0] else: return self.decode(cast(bytes, names[0].filename), isinstance(path, (str, PurePath))) async def getcwd(self) -> BytesOrStr: """Return the current remote working directory :returns: The current remote working directory, decoded using the specified path encoding :raises: :exc:`SFTPError` if the server returns an error """ if self._cwd is None: self._cwd = await self.realpath(b'.') return self.decode(self._cwd) async def chdir(self, path: _SFTPPath) -> None: """Change the current remote working directory :param path: The path to set as the new remote working directory :type path: :class:`PurePath `, `str`, or `bytes` :raises: :exc:`SFTPError` if the server returns an error """ self._cwd = await self.realpath(self.encode(path)) @overload async def readlink(self, path: bytes) -> bytes: ... # pragma: no cover @overload async def readlink(self, path: FilePath) -> str: ... # pragma: no cover async def readlink(self, path: _SFTPPath) -> BytesOrStr: """Return the target of a remote symbolic link This method returns the target of a symbolic link. 
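# Sketch of the path helpers above: canonicalize a composed path and stat it
# if it exists (FXRP_STAT_IF_EXISTS is one of the check values used in
# realpath() above), then change the remote working directory.  'sftp' and
# the paths are placeholders.

async def enter_release_dir(sftp) -> None:
    name = await sftp.realpath('releases', 'current',
                               check=FXRP_STAT_IF_EXISTS)
    print(name.filename, name.attrs.type)

    await sftp.chdir('releases/current')
    print(await sftp.getcwd())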
:param path: The path of the remote symbolic link to follow :type path: :class:`PurePath `, `str`, or `bytes` :returns: The target path of the link as a `str` or `bytes` :raises: :exc:`SFTPError` if the server returns an error """ linkpath = self.compose_path(path) names, _ = await self._handler.readlink(linkpath) if len(names) > 1: raise SFTPBadMessage('Too many names returned') return self.decode(cast(bytes, names[0].filename), isinstance(path, (str, PurePath))) async def symlink(self, oldpath: _SFTPPath, newpath: _SFTPPath) -> None: """Create a remote symbolic link This method creates a symbolic link. The argument order here matches the standard Python :meth:`os.symlink` call. The argument order sent on the wire is automatically adapted depending on the version information sent by the server, as a number of servers (OpenSSH in particular) did not follow the SFTP standard when implementing this call. :param oldpath: The path the link should point to :param newpath: The path of where to create the remote symbolic link :type oldpath: :class:`PurePath `, `str`, or `bytes` :type newpath: :class:`PurePath `, `str`, or `bytes` :raises: :exc:`SFTPError` if the server returns an error """ oldpath = self.encode(oldpath) newpath = self.compose_path(newpath) await self._handler.symlink(oldpath, newpath) async def link(self, oldpath: _SFTPPath, newpath: _SFTPPath) -> None: """Create a remote hard link This method creates a hard link to the remote file specified by oldpath at the location specified by newpath. This method may not be supported by all SFTP servers. :param oldpath: The path of the remote file the hard link should point to :param newpath: The path of where to create the remote hard link :type oldpath: :class:`PurePath `, `str`, or `bytes` :type newpath: :class:`PurePath `, `str`, or `bytes` :raises: :exc:`SFTPError` if the server doesn't support this extension or returns an error """ oldpath = self.compose_path(oldpath) newpath = self.compose_path(newpath) await self._handler.link(oldpath, newpath) def exit(self) -> None: """Exit the SFTP client session This method exits the SFTP client session, closing the corresponding channel opened on the server. 
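# Sketch of symlink()/link() as documented above: the argument order follows
# os.symlink (target first, then the new link), and the on-the-wire order is
# adapted automatically for servers such as OpenSSH.  Paths are placeholders.

async def make_links(sftp) -> None:
    # 'current' becomes a symlink pointing at 'releases/v2'
    await sftp.symlink('releases/v2', 'current')

    # Hard link; not every server supports this extension
    await sftp.link('releases/v2/app.bin', 'app.bin')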
""" self._handler.exit() async def wait_closed(self) -> None: """Wait for this SFTP client session to close""" await self._handler.wait_closed() class SFTPServerHandler(SFTPHandler): """An SFTP server session handler""" # Supported attribute flags in setstat/fsetstat/lsetstat _supported_attr_mask = FILEXFER_ATTR_SIZE | \ FILEXFER_ATTR_PERMISSIONS | \ FILEXFER_ATTR_ACCESSTIME | \ FILEXFER_ATTR_MODIFYTIME | \ FILEXFER_ATTR_OWNERGROUP | \ FILEXFER_ATTR_SUBSECOND_TIMES # No attrib bits currently supported _supported_attrib_mask = 0 # Supported SFTPv5/v6 open flags _supported_open_flags = FXF_ACCESS_DISPOSITION | FXF_APPEND_DATA # Supported SFTPv5/v6 desired access flags _supported_access_mask = ACE4_READ_DATA | ACE4_WRITE_DATA | \ ACE4_APPEND_DATA | ACE4_READ_ATTRIBUTES | \ ACE4_WRITE_ATTRIBUTES # Locking not currently supported _supported_open_block_vector = _supported_block_vector = 0x0001 _vendor_id = String(__author__) + String('AsyncSSH') + \ String(__version__) + UInt64(0) _extensions: List[Tuple[bytes, bytes]] = [ (b'newline', os.linesep.encode('utf-8')), (b'vendor-id', _vendor_id), (b'posix-rename@openssh.com', b'1'), (b'hardlink@openssh.com', b'1'), (b'fsync@openssh.com', b'1'), (b'lsetstat@openssh.com', b'1'), (b'limits@openssh.com', b'1'), (b'copy-data', b'1')] _attrib_extensions: List[bytes] = [] if hasattr(os, 'statvfs'): # pragma: no branch _extensions += [(b'statvfs@openssh.com', b'2'), (b'fstatvfs@openssh.com', b'2')] def __init__(self, server: 'SFTPServer', reader: 'SSHReader[bytes]', writer: 'SSHWriter[bytes]', sftp_version: int): super().__init__(reader, writer) self._server = server self._version = sftp_version self._nonstandard_symlink = False self._next_handle = 0 self._file_handles: Dict[bytes, object] = {} self._dir_handles: Dict[bytes, AsyncIterator[SFTPName]] = {} async def _cleanup(self, exc: Optional[Exception]) -> None: """Clean up this SFTP server session""" if self._server: # pragma: no branch for file_obj in list(self._file_handles.values()): result = self._server.close(file_obj) if inspect.isawaitable(result): assert result is not None await result self._server.exit() self._file_handles = {} self._dir_handles = {} self.logger.info('SFTP server exited%s', ': ' + str(exc) if exc else '') await super()._cleanup(exc) def _get_next_handle(self) -> bytes: """Get the next available unique file handle number""" while True: handle = self._next_handle.to_bytes(4, 'big') self._next_handle = (self._next_handle + 1) & 0xffffffff if (handle not in self._file_handles and handle not in self._dir_handles): return handle async def _process_packet(self, pkttype: int, pktid: int, packet: SSHPacket) -> None: """Process incoming SFTP requests""" # pylint: disable=broad-except try: if pkttype == FXP_EXTENDED: handler_type: Union[int, bytes] = packet.get_string() else: handler_type = pkttype handler = self._packet_handlers.get(handler_type) if not handler: raise SFTPOpUnsupported(f'Unsupported request type: {pkttype}') return_type = self._return_types.get(handler_type, FXP_STATUS) result = await handler(self, packet) if return_type == FXP_STATUS: self.logger.debug1('Sending OK') response = UInt32(FX_OK) + String('') + String('') elif return_type == FXP_HANDLE: handle = cast(bytes, result) self.logger.debug1('Sending handle %s', handle.hex()) response = String(handle) elif return_type == FXP_DATA: data, at_end = cast(Tuple[bytes, bool], result) self.logger.debug1('Sending %s%s', plural(len(data), 'data byte'), ' (at end)' if at_end else '') end = Boolean(at_end) if at_end and 
self._version >= 6 else b'' response = String(data) + end elif return_type == FXP_NAME: names, at_end = cast(_SFTPNames, result) self.logger.debug1('Sending %s%s', plural(len(names), 'name'), ' (at end)' if at_end else '') for name in names: self.logger.debug1(' %s', name) end = Boolean(at_end) if at_end and self._version >= 6 else b'' response = (UInt32(len(names)) + b''.join(name.encode(self._version) for name in names) + end) elif isinstance(result, SFTPLimits): self.logger.debug1('Sending server limits:') self._log_limits(result) response = result.encode(self._version) else: attrs: _SupportsEncode if isinstance(result, os.stat_result): attrs = SFTPAttrs.from_local(cast(os.stat_result, result)) elif isinstance(result, os.statvfs_result): attrs = SFTPVFSAttrs.from_local(cast(os.statvfs_result, result)) else: attrs = cast(_SupportsEncode, result) self.logger.debug1('Sending %s', attrs) response = attrs.encode(self._version) except PacketDecodeError as exc: return_type = FXP_STATUS self.logger.debug1('Sending bad message error: %s', str(exc)) response = (UInt32(FX_BAD_MESSAGE) + String(str(exc)) + String(DEFAULT_LANG)) except SFTPError as exc: return_type = FXP_STATUS if exc.code == FX_EOF: self.logger.debug1('Sending EOF') else: self.logger.debug1('Sending %s: %s', exc.__class__.__name__, str(exc.reason)) response = exc.encode(self._version) except NotImplementedError: assert handler is not None return_type = FXP_STATUS op_name = handler.__name__[9:] self.logger.debug1('Sending operation not supported: %s', op_name) response = (UInt32(FX_OP_UNSUPPORTED) + String(f'Operation not supported: {op_name}') + String(DEFAULT_LANG)) except OSError as exc: return_type = FXP_STATUS reason = exc.strerror or str(exc) if exc.errno == errno.ENOENT: self.logger.debug1('Sending no such file: %s', reason) code = FX_NO_SUCH_FILE elif exc.errno == errno.EACCES: self.logger.debug1('Sending permission denied: %s', reason) code = FX_PERMISSION_DENIED elif exc.errno == errno.EEXIST: self.logger.debug1('Sending file already exists: %s', reason) code = FX_FILE_ALREADY_EXISTS elif exc.errno == errno.EROFS: self.logger.debug1('Sending write protect: %s', reason) code = FX_WRITE_PROTECT elif exc.errno == errno.ENOSPC: self.logger.debug1('Sending no space on ' 'filesystem: %s', reason) code = FX_NO_SPACE_ON_FILESYSTEM elif exc.errno == errno.EDQUOT: self.logger.debug1('Sending disk quota exceeded: %s', reason) code = FX_QUOTA_EXCEEDED elif exc.errno == errno.ENOTEMPTY: self.logger.debug1('Sending directory not empty: %s', reason) code = FX_DIR_NOT_EMPTY elif exc.errno == errno.ENOTDIR: self.logger.debug1('Sending not a directory: %s', reason) code = FX_NOT_A_DIRECTORY elif exc.errno in (errno.ENAMETOOLONG, errno.EILSEQ): self.logger.debug1('Sending invalid filename: %s', reason) code = FX_INVALID_FILENAME elif exc.errno == errno.ELOOP: self.logger.debug1('Sending link loop: %s', reason) code = FX_LINK_LOOP elif exc.errno == errno.EINVAL: self.logger.debug1('Sending invalid parameter: %s', reason) code = FX_INVALID_PARAMETER elif exc.errno == errno.EISDIR: self.logger.debug1('Sending file is a directory: %s', reason) code = FX_FILE_IS_A_DIRECTORY else: self.logger.debug1('Sending failure: %s', reason) code = FX_FAILURE response = SFTPError(code, reason).encode(self._version) except Exception as exc: # pragma: no cover return_type = FXP_STATUS reason = f'Uncaught exception: {exc}' self.logger.debug1('Sending failure: %s', reason, exc_info=sys.exc_info) response = (UInt32(FX_FAILURE) + String(reason) + 
String(DEFAULT_LANG)) self.send_packet(return_type, pktid, UInt32(pktid), response) async def _process_open(self, packet: SSHPacket) -> bytes: """Process an incoming SFTP open request""" path = packet.get_string() if self._version >= 5: desired_access = packet.get_uint32() flags = packet.get_uint32() flagmsg = f'access=0x{desired_access:04x}, flags=0x{flags:04x}' else: pflags = packet.get_uint32() flagmsg = f'pflags=0x{pflags:02x}' attrs = SFTPAttrs.decode(packet, self._version) if self._version < 6: packet.check_end() self.logger.debug1('Received open request for %s, %s%s', path, flagmsg, hide_empty(attrs)) if self._version >= 5: unsupported_access = desired_access & ~self._supported_access_mask if unsupported_access: raise SFTPInvalidParameter( f'Unsupported access flags: 0x{unsupported_access:08x}') unsupported_flags = flags & ~self._supported_open_flags if unsupported_flags: raise SFTPInvalidParameter( f'Unsupported open flags: 0x{unsupported_flags:08x}') result = self._server.open56(path, desired_access, flags, attrs) else: result = self._server.open(path, pflags, attrs) if inspect.isawaitable(result): result = await cast(Awaitable[object], result) handle = self._get_next_handle() self._file_handles[handle] = result return handle async def _process_close(self, packet: SSHPacket) -> None: """Process an incoming SFTP close request""" handle = packet.get_string() if self._version < 6: packet.check_end() self.logger.debug1('Received close for handle %s', handle.hex()) file_obj = self._file_handles.pop(handle, None) if file_obj: result = self._server.close(file_obj) if inspect.isawaitable(result): assert result is not None await result return if self._dir_handles.pop(handle, None) is not None: return raise SFTPInvalidHandle('Invalid file handle') async def _process_read(self, packet: SSHPacket) -> Tuple[bytes, bool]: """Process an incoming SFTP read request""" handle = packet.get_string() offset = packet.get_uint64() length = packet.get_uint32() if self._version < 6: packet.check_end() self.logger.debug1('Received read for %s at offset %d in handle %s', plural(length, 'byte'), offset, handle.hex()) file_obj = self._file_handles.get(handle) if file_obj: result = self._server.read(file_obj, offset, length) if inspect.isawaitable(result): result = await cast(Awaitable[bytes], result) result: bytes if self._version >= 6: attrs = await self._server.convert_attrs( self._server.fstat(file_obj)) at_end = offset + len(result) == attrs.size else: at_end = False if result: return cast(bytes, result), at_end else: raise SFTPEOFError else: raise SFTPInvalidHandle('Invalid file handle') async def _process_write(self, packet: SSHPacket) -> int: """Process an incoming SFTP write request""" handle = packet.get_string() offset = packet.get_uint64() data = packet.get_string() if self._version < 6: packet.check_end() self.logger.debug1('Received write for %s at offset %d in handle %s', plural(len(data), 'byte'), offset, handle.hex()) file_obj = self._file_handles.get(handle) if file_obj: result = self._server.write(file_obj, offset, data) if inspect.isawaitable(result): result = await cast(Awaitable[int], result) return cast(int, result) else: raise SFTPInvalidHandle('Invalid file handle') async def _process_lstat(self, packet: SSHPacket) -> _SFTPOSAttrs: """Process an incoming SFTP lstat request""" path = packet.get_string() flags = packet.get_uint32()if self._version >= 4 else 0 if self._version < 6: packet.check_end() self.logger.debug1('Received lstat for %s%s', path, f', flags=0x{flags:08x}' if flags 
else '') # Ignore flags for now, returning all available fields result = self._server.lstat(path) if inspect.isawaitable(result): result = await cast(Awaitable[_SFTPOSAttrs], result) result: _SFTPOSAttrs return result async def _process_fstat(self, packet: SSHPacket) -> _SFTPOSAttrs: """Process an incoming SFTP fstat request""" handle = packet.get_string() flags = packet.get_uint32() if self._version >= 4 else 0 if self._version < 6: packet.check_end() self.logger.debug1('Received fstat for handle %s%s', handle.hex(), f', flags=0x{flags:08x}' if flags else '') file_obj = self._file_handles.get(handle) if file_obj: # Ignore flags for now, returning all available fields result = self._server.fstat(file_obj) if inspect.isawaitable(result): result = await cast(Awaitable[_SFTPOSAttrs], result) result: _SFTPOSAttrs return result else: raise SFTPInvalidHandle('Invalid file handle') async def _process_setstat(self, packet: SSHPacket) -> None: """Process an incoming SFTP setstat request""" path = packet.get_string() attrs = SFTPAttrs.decode(packet, self._version) if self._version < 6: packet.check_end() self.logger.debug1('Received setstat for %s%s', path, hide_empty(attrs)) result = self._server.setstat(path, attrs) if inspect.isawaitable(result): assert result is not None await result async def _process_fsetstat(self, packet: SSHPacket) -> None: """Process an incoming SFTP fsetstat request""" handle = packet.get_string() attrs = SFTPAttrs.decode(packet, self._version) if self._version < 6: packet.check_end() self.logger.debug1('Received fsetstat for handle %s%s', handle.hex(), hide_empty(attrs)) file_obj = self._file_handles.get(handle) if file_obj: result = self._server.fsetstat(file_obj, attrs) if inspect.isawaitable(result): assert result is not None await result else: raise SFTPInvalidHandle('Invalid file handle') async def _process_opendir(self, packet: SSHPacket) -> bytes: """Process an incoming SFTP opendir request""" path = packet.get_string() if self._version < 6: packet.check_end() self.logger.debug1('Received opendir for %s', path) handle = self._get_next_handle() self._dir_handles[handle] = self._server.scandir(path) return handle async def _process_readdir(self, packet: SSHPacket) -> _SFTPNames: """Process an incoming SFTP readdir request""" handle = packet.get_string() if self._version < 6: packet.check_end() self.logger.debug1('Received readdir for handle %s', handle.hex()) names = self._dir_handles.get(handle) if names: count = 0 result: List[SFTPName] = [] async for name in names: if not name.longname and self._version == 3: longname_result = self._server.format_longname(name) if inspect.isawaitable(longname_result): assert longname_result is not None await longname_result result.append(name) count += 1 if count == _MAX_READDIR_NAMES: break if result: return result, count < _MAX_READDIR_NAMES else: raise SFTPEOFError else: raise SFTPInvalidHandle('Invalid file handle') async def _process_remove(self, packet: SSHPacket) -> None: """Process an incoming SFTP remove request""" path = packet.get_string() if self._version < 6: packet.check_end() self.logger.debug1('Received remove for %s', path) result = self._server.remove(path) if inspect.isawaitable(result): assert result is not None await result async def _process_mkdir(self, packet: SSHPacket) -> None: """Process an incoming SFTP mkdir request""" path = packet.get_string() attrs = SFTPAttrs.decode(packet, self._version) if self._version < 6: packet.check_end() self.logger.debug1('Received mkdir for %s', path) result = 
self._server.mkdir(path, attrs) if inspect.isawaitable(result): assert result is not None await result async def _process_rmdir(self, packet: SSHPacket) -> None: """Process an incoming SFTP rmdir request""" path = packet.get_string() if self._version < 6: packet.check_end() self.logger.debug1('Received rmdir for %s', path) result = self._server.rmdir(path) if inspect.isawaitable(result): assert result is not None await result async def _process_realpath(self, packet: SSHPacket) -> _SFTPNames: """Process an incoming SFTP realpath request""" path = packet.get_string() checkmsg = '' compose_paths: List[bytes] = [] if self._version >= 6: check = packet.get_byte() while packet: compose_paths.append(packet.get_string()) try: checkmsg = f', check={self._realpath_check_names[check]}' except KeyError: raise SFTPInvalidParameter('Invalid check value') from None else: check = FXRP_NO_CHECK self.logger.debug1('Received realpath for %s%s%s', path, b', compose_path: ' + b', '.join(compose_paths) if compose_paths else b'', checkmsg) for cpath in compose_paths: path = posixpath.join(path, cpath) result = self._server.realpath(path) if inspect.isawaitable(result): result = await cast(Awaitable[bytes], result) result: bytes attrs = SFTPAttrs() if check != FXRP_NO_CHECK: try: attrs = await self._server.convert_attrs( self._server.stat(result)) except (OSError, SFTPError): if check == FXRP_STAT_ALWAYS: raise return [SFTPName(result, attrs=attrs)], False async def _process_stat(self, packet: SSHPacket) -> _SFTPOSAttrs: """Process an incoming SFTP stat request""" path = packet.get_string() flags = packet.get_uint32() if self._version >= 4 else 0 if self._version < 6: packet.check_end() self.logger.debug1('Received stat for %s%s', path, f', flags=0x{flags:08x}' if flags else '') # Ignore flags for now, returning all available fields result = self._server.stat(path) if inspect.isawaitable(result): result = await cast(Awaitable[_SFTPOSAttrs], result) result: _SFTPOSAttrs return result async def _process_rename(self, packet: SSHPacket) -> None: """Process an incoming SFTP rename request""" oldpath = packet.get_string() newpath = packet.get_string() if self._version >= 5: flags = packet.get_uint32() flag_text = f', flags=0x{flags:08x}' else: flags = 0 flag_text = '' if self._version < 6: packet.check_end() self.logger.debug1('Received rename request from %s to %s%s', oldpath, newpath, flag_text) if flags: result = self._server.posix_rename(oldpath, newpath) else: result = self._server.rename(oldpath, newpath) if inspect.isawaitable(result): assert result is not None await result async def _process_readlink(self, packet: SSHPacket) -> _SFTPNames: """Process an incoming SFTP readlink request""" path = packet.get_string() if self._version < 6: packet.check_end() self.logger.debug1('Received readlink for %s', path) result = self._server.readlink(path) if inspect.isawaitable(result): result = await cast(Awaitable[bytes], result) result: bytes return [SFTPName(result)], False async def _process_symlink(self, packet: SSHPacket) -> None: """Process an incoming SFTP symlink request""" if self._nonstandard_symlink: oldpath = packet.get_string() newpath = packet.get_string() else: newpath = packet.get_string() oldpath = packet.get_string() packet.check_end() self.logger.debug1('Received symlink request from %s to %s', oldpath, newpath) result = self._server.symlink(oldpath, newpath) if inspect.isawaitable(result): assert result is not None await result async def _process_link(self, packet: SSHPacket) -> None: """Process an 
incoming SFTP hard link request""" newpath = packet.get_string() oldpath = packet.get_string() symlink = packet.get_boolean() if symlink: self.logger.debug1('Received symlink request from %s to %s', oldpath, newpath) result = self._server.symlink(oldpath, newpath) else: self.logger.debug1('Received hardlink request from %s to %s', oldpath, newpath) result = self._server.link(oldpath, newpath) if inspect.isawaitable(result): assert result is not None await result async def _process_lock(self, packet: SSHPacket) -> None: """Process an incoming SFTP byte range lock request""" handle = packet.get_string() offset = packet.get_uint64() length = packet.get_uint64() flags = packet.get_uint32() self.logger.debug1('Received byte range lock request for ' 'handle %s, offset %d, length %d, ' 'flags 0x%04x', handle.hex(), offset, length, flags) file_obj = self._file_handles.get(handle) if file_obj: result = self._server.lock(file_obj, offset, length, flags) if inspect.isawaitable(result): # pragma: no branch assert result is not None await result else: raise SFTPInvalidHandle('Invalid file handle') async def _process_unlock(self, packet: SSHPacket) -> None: """Process an incoming SFTP byte range unlock request""" handle = packet.get_string() offset = packet.get_uint64() length = packet.get_uint64() self.logger.debug1('Received byte range lock request for ' 'handle %s, offset %d, length %d', handle.hex(), offset, length) file_obj = self._file_handles.get(handle) if file_obj: result = self._server.unlock(file_obj, offset, length) if inspect.isawaitable(result): # pragma: no branch assert result is not None await result else: raise SFTPInvalidHandle('Invalid file handle') async def _process_posix_rename(self, packet: SSHPacket) -> None: """Process an incoming SFTP POSIX rename request""" oldpath = packet.get_string() newpath = packet.get_string() packet.check_end() self.logger.debug1('Received POSIX rename request from %s to %s', oldpath, newpath) result = self._server.posix_rename(oldpath, newpath) if inspect.isawaitable(result): assert result is not None await result async def _process_statvfs(self, packet: SSHPacket) -> _SFTPOSVFSAttrs: """Process an incoming SFTP statvfs request""" path = packet.get_string() packet.check_end() self.logger.debug1('Received statvfs for %s', path) result = self._server.statvfs(path) if inspect.isawaitable(result): result = await cast(Awaitable[_SFTPOSVFSAttrs], result) result: _SFTPOSVFSAttrs return result async def _process_fstatvfs(self, packet: SSHPacket) -> _SFTPOSVFSAttrs: """Process an incoming SFTP fstatvfs request""" handle = packet.get_string() packet.check_end() self.logger.debug1('Received fstatvfs for handle %s', handle.hex()) file_obj = self._file_handles.get(handle) if file_obj: result = self._server.fstatvfs(file_obj) if inspect.isawaitable(result): result = await cast(Awaitable[_SFTPOSVFSAttrs], result) result: _SFTPOSVFSAttrs return result else: raise SFTPInvalidHandle('Invalid file handle') async def _process_openssh_link(self, packet: SSHPacket) -> None: """Process an incoming SFTP hard link request""" oldpath = packet.get_string() newpath = packet.get_string() packet.check_end() self.logger.debug1('Received hardlink request from %s to %s', oldpath, newpath) result = self._server.link(oldpath, newpath) if inspect.isawaitable(result): assert result is not None await result async def _process_fsync(self, packet: SSHPacket) -> None: """Process an incoming SFTP fsync request""" handle = packet.get_string() packet.check_end() self.logger.debug1('Received 
fsync for handle %s', handle.hex()) file_obj = self._file_handles.get(handle) if file_obj: result = self._server.fsync(file_obj) if inspect.isawaitable(result): assert result is not None await result else: raise SFTPInvalidHandle('Invalid file handle') async def _process_lsetstat(self, packet: SSHPacket) -> None: """Process an incoming SFTP lsetstat request""" path = packet.get_string() attrs = SFTPAttrs.decode(packet, self._version) if self._version < 6: packet.check_end() self.logger.debug1('Received lsetstat for %s%s', path, hide_empty(attrs)) result = self._server.lsetstat(path, attrs) if inspect.isawaitable(result): assert result is not None await result async def _process_limits(self, packet: SSHPacket) -> SFTPLimits: """Process an incoming SFTP limits request""" packet.check_end() nfiles = os.sysconf('SC_OPEN_MAX') - 5 if hasattr(os, 'sysconf') else 0 return SFTPLimits(MAX_SFTP_PACKET_LEN, MAX_SFTP_READ_LEN, MAX_SFTP_WRITE_LEN, nfiles) async def _process_copy_data(self, packet: SSHPacket) -> None: """Process an incoming copy data request""" read_from_handle = packet.get_string() read_from_offset = packet.get_uint64() read_from_length = packet.get_uint64() write_to_handle = packet.get_string() write_to_offset = packet.get_uint64() packet.check_end() self.logger.debug1('Received copy-data from handle %s, ' 'offset %d, length %d to handle %s, ' 'offset %d', read_from_handle.hex(), read_from_offset, read_from_length, write_to_handle.hex(), write_to_offset) src = self._file_handles.get(read_from_handle) dst = self._file_handles.get(write_to_handle) if src and dst: read_to_end = read_from_length == 0 while read_to_end or read_from_length: if read_to_end: size = _COPY_DATA_BLOCK_SIZE else: size = min(read_from_length, _COPY_DATA_BLOCK_SIZE) data = self._server.read(src, read_from_offset, size) if inspect.isawaitable(data): data = await cast(Awaitable[bytes], data) result = self._server.write(dst, write_to_offset, data) if inspect.isawaitable(result): await result if len(data) < size: break read_from_offset += size write_to_offset += size if not read_to_end: read_from_length -= size else: raise SFTPInvalidHandle('Invalid file handle') _packet_handlers: Dict[Union[int, bytes], _SFTPPacketHandler] = { FXP_OPEN: _process_open, FXP_CLOSE: _process_close, FXP_READ: _process_read, FXP_WRITE: _process_write, FXP_LSTAT: _process_lstat, FXP_FSTAT: _process_fstat, FXP_SETSTAT: _process_setstat, FXP_FSETSTAT: _process_fsetstat, FXP_OPENDIR: _process_opendir, FXP_READDIR: _process_readdir, FXP_REMOVE: _process_remove, FXP_MKDIR: _process_mkdir, FXP_RMDIR: _process_rmdir, FXP_REALPATH: _process_realpath, FXP_STAT: _process_stat, FXP_RENAME: _process_rename, FXP_READLINK: _process_readlink, FXP_SYMLINK: _process_symlink, FXP_LINK: _process_link, FXP_BLOCK: _process_lock, FXP_UNBLOCK: _process_unlock, b'posix-rename@openssh.com': _process_posix_rename, b'statvfs@openssh.com': _process_statvfs, b'fstatvfs@openssh.com': _process_fstatvfs, b'hardlink@openssh.com': _process_openssh_link, b'fsync@openssh.com': _process_fsync, b'lsetstat@openssh.com': _process_lsetstat, b'limits@openssh.com': _process_limits, b'copy-data': _process_copy_data } async def run(self) -> None: """Run an SFTP server""" assert self._reader is not None try: packet = await self.recv_packet() pkttype = packet.get_byte() self.log_received_packet(pkttype, None, packet) if pkttype != FXP_INIT: await self._cleanup(SFTPBadMessage('Expected init message')) return version = packet.get_uint32() rcvd_extensions: List[Tuple[bytes, bytes]] = [] if 
version == 3: while packet: name = packet.get_string() data = packet.get_string() rcvd_extensions.append((name, data)) else: packet.check_end() except PacketDecodeError as exc: await self._cleanup(SFTPBadMessage(str(exc))) return except Error as exc: await self._cleanup(exc) return self.logger.debug1('Received init, version=%d%s', version, ', extensions:' if rcvd_extensions else '') self._log_extensions(rcvd_extensions) self._version = min(version, self._version) extensions: List[Tuple[bytes, bytes]] = [] ext_names = b''.join(String(name) for (name, _) in self._extensions) attrib_ext_names = b''.join(String(name) for name in self._attrib_extensions) if self._version == 5: supported = UInt32(self._supported_attr_mask) + \ UInt32(self._supported_attrib_mask) + \ UInt32(self._supported_open_flags) + \ UInt32(self._supported_access_mask) + \ UInt32(MAX_SFTP_READ_LEN) + ext_names + \ attrib_ext_names extensions.append((b'supported', supported)) elif self._version == 6: acl_supported = UInt32(0) # No ACL support yet supported2 = UInt32(self._supported_attr_mask) + \ UInt32(self._supported_attrib_mask) + \ UInt32(self._supported_open_flags) + \ UInt32(self._supported_access_mask) + \ UInt32(MAX_SFTP_READ_LEN) + \ UInt16(self._supported_open_block_vector) + \ UInt16(self._supported_block_vector) + \ UInt32(len(self._attrib_extensions)) + \ attrib_ext_names + \ UInt32(len(self._extensions)) + \ ext_names extensions.append((b'acl-supported', acl_supported)) extensions.append((b'supported2', supported2)) extensions.extend(self._extensions) self.logger.debug1('Sending version=%d%s', self._version, ', extensions:' if extensions else '') self._log_extensions(extensions) sent_extensions: Iterable[bytes] = \ (String(name) + String(data) for name, data in extensions) try: self.send_packet(FXP_VERSION, None, UInt32(self._version), *sent_extensions) except SFTPError as exc: await self._cleanup(exc) return if self._version == 3: # Check if the client has a buggy SYMLINK implementation client_version = cast(str, self._reader.get_extra_info('client_version', '')) if any(name in client_version for name in self._nonstandard_symlink_impls): self.logger.debug1('Adjusting for non-standard symlink ' 'implementation') self._nonstandard_symlink = True await self.recv_packets() class SFTPServer: """SFTP server Applications should subclass this when implementing an SFTP server. The methods listed below should be implemented to provide the desired application behavior. .. note:: Any method can optionally be defined as a coroutine if that method needs to perform blocking operations to determine its result. The `chan` object provided here is the :class:`SSHServerChannel` instance this SFTP server is associated with. It can be queried to determine which user the client authenticated as, environment variables set on the channel when it was opened, and key and certificate options or permissions associated with this session. .. note:: In AsyncSSH 1.x, this first argument was an :class:`SSHServerConnection`, not an :class:`SSHServerChannel`. When moving to AsyncSSH 2.x, subclasses of :class:`SFTPServer` which implement an __init__ method will need to be updated to account for this change, and pass this through to the parent. If the `chroot` argument is specified when this object is created, the default :meth:`map_path` and :meth:`reverse_map_path` methods will enforce a virtual root directory starting in that location, limiting access to only files within that directory tree. 
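# --- Usage sketch (illustrative, not part of this module) ---------------------
# A sketch of the per-user "chroot" behavior described above, following the
# pattern shown in the AsyncSSH documentation.  The port number, key file
# names, and root directory are placeholders.
import os
import asyncssh

class ChrootSFTPServer(asyncssh.SFTPServer):
    def __init__(self, chan: 'asyncssh.SSHServerChannel'):
        # Confine each authenticated user to a private directory tree
        # (bytes, to match the chroot type hint in this module)
        root = os.path.join(b'/tmp/sftp',
                            os.fsencode(chan.get_extra_info('username')))
        os.makedirs(root, exist_ok=True)
        super().__init__(chan, chroot=root)

async def start_sftp_only_server() -> None:
    await asyncssh.listen('', 8022, server_host_keys=['ssh_host_key'],
                          authorized_client_keys='authorized_keys',
                          sftp_factory=ChrootSFTPServer)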
This will also affect path names returned by the :meth:`realpath` and :meth:`readlink` methods. """ # The default implementation of a number of these methods don't need self # pylint: disable=no-self-use def __init__(self, chan: 'SSHServerChannel', chroot: Optional[bytes] = None): self._chan = chan self._chroot: Optional[bytes] if chroot: self._chroot = _from_local_path(os.path.realpath(chroot)) else: self._chroot = None @property def channel(self) -> 'SSHServerChannel': """The channel associated with this SFTP server session""" return self._chan @property def connection(self) -> 'SSHServerConnection': """The channel associated with this SFTP server session""" return cast('SSHServerConnection', self._chan.get_connection()) @property def env(self) -> Mapping[str, str]: """The environment associated with this SFTP server session This method returns the environment set by the client when this SFTP session was opened. :returns: A dictionary containing the environment variables set by the client """ return self._chan.get_environment() @property def logger(self) -> SSHLogger: """A logger associated with this SFTP server""" return self._chan.logger async def convert_attrs(self, result: MaybeAwait[_SFTPOSAttrs]) -> \ SFTPAttrs: """Convert stat result to SFTPAttrs""" if inspect.isawaitable(result): result = await cast(Awaitable[_SFTPOSAttrs], result) result: _SFTPOSAttrs if isinstance(result, os.stat_result): result = SFTPAttrs.from_local(result) result: SFTPAttrs return result async def _to_sftpname(self, parent: bytes, name: bytes) -> SFTPName: """Construct an SFTPName for a filename in a directory""" path = posixpath.join(parent, name) attrs = await self.convert_attrs(self.lstat(path)) return SFTPName(name, attrs=attrs) def format_user(self, uid: Optional[int]) -> str: """Return the user name associated with a uid This method returns a user name string to insert into the `longname` field of an :class:`SFTPName` object. By default, it calls the Python :func:`pwd.getpwuid` function if it is available, or returns the numeric uid as a string if not. If there is no uid, it returns an empty string. :param uid: The uid value to look up :type uid: `int` or `None` :returns: The formatted user name string """ return _lookup_user(uid) def format_group(self, gid: Optional[int]) -> str: """Return the group name associated with a gid This method returns a group name string to insert into the `longname` field of an :class:`SFTPName` object. By default, it calls the Python :func:`grp.getgrgid` function if it is available, or returns the numeric gid as a string if not. If there is no gid, it returns an empty string. :param gid: The gid value to look up :type gid: `int` or `None` :returns: The formatted group name string """ return _lookup_group(gid) def format_longname(self, name: SFTPName) -> MaybeAwait[None]: """Format the long name associated with an SFTP name This method fills in the `longname` field of a :class:`SFTPName` object. By default, it generates something similar to UNIX "ls -l" output. The `filename` and `attrs` fields of the :class:`SFTPName` should already be filled in before this method is called. 
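# --- Usage sketch (illustrative, not part of this module) ---------------------
# A sketch of overriding format_user() so long-format listings show
# application-specific names instead of local account names.  The class name
# and uid mapping are hypothetical.
import asyncssh

class VirtualUserSFTPServer(asyncssh.SFTPServer):
    # Hypothetical uid-to-name mapping used instead of the local passwd file
    _usernames = {1000: 'alice', 1001: 'bob'}

    def format_user(self, uid):
        if uid is None:
            return ''
        return self._usernames.get(uid, str(uid))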
:param name: The :class:`SFTPName` instance to format the long name for :type name: :class:`SFTPName` """ if name.attrs.permissions is not None: mode = stat.filemode(name.attrs.permissions) else: mode = '' nlink = str(name.attrs.nlink) if name.attrs.nlink else '' user = self.format_user(name.attrs.uid) group = self.format_group(name.attrs.gid) size = str(name.attrs.size) if name.attrs.size is not None else '' if name.attrs.mtime is not None: now = time.time() mtime = time.localtime(name.attrs.mtime) modtime = time.strftime('%b ', mtime) try: modtime += time.strftime('%e', mtime) except ValueError: modtime += time.strftime('%d', mtime) if now - 365*24*60*60/2 < name.attrs.mtime <= now: modtime += time.strftime(' %H:%M', mtime) else: modtime += time.strftime(' %Y', mtime) else: modtime = '' detail = f'{mode:10s} {nlink:>4s} {user:8s} {group:8s} ' \ f'{size:>8s} {modtime:12s} ' name.longname = detail.encode('utf-8') + cast(bytes, name.filename) return None def map_path(self, path: bytes) -> bytes: """Map the path requested by the client to a local path This method can be overridden to provide a custom mapping from path names requested by the client to paths in the local filesystem. By default, it will enforce a virtual "chroot" if one was specified when this server was created. Otherwise, path names are left unchanged, with relative paths being interpreted based on the working directory of the currently running process. :param path: The path name to map :type path: `bytes` :returns: bytes containing the local path name to operate on """ if self._chroot: normpath = posixpath.normpath(posixpath.join(b'/', path)) return posixpath.join(self._chroot, normpath[1:]) else: return path def reverse_map_path(self, path: bytes) -> bytes: """Reverse map a local path into the path reported to the client This method can be overridden to provide a custom reverse mapping for the mapping provided by :meth:`map_path`. By default, it hides the portion of the local path associated with the virtual "chroot" if one was specified. :param path: The local path name to reverse map :type path: `bytes` :returns: bytes containing the path name to report to the client """ if self._chroot: if path == self._chroot: return b'/' elif path.startswith(self._chroot + b'/'): return path[len(self._chroot):] else: raise SFTPNoSuchFile('File not found') else: return path def open(self, path: bytes, pflags: int, attrs: SFTPAttrs) -> \ MaybeAwait[object]: """Open a file to serve to a remote client This method returns a file object which can be used to read and write data and get and set file attributes. The possible open mode flags and their meanings are: ========== ====================================================== Mode Description ========== ====================================================== FXF_READ Open the file for reading. If neither FXF_READ nor FXF_WRITE are set, this is the default. FXF_WRITE Open the file for writing. If both this and FXF_READ are set, open the file for both reading and writing. FXF_APPEND Force writes to append data to the end of the file regardless of seek position. FXF_CREAT Create the file if it doesn't exist. Without this, attempts to open a non-existent file will fail. FXF_TRUNC Truncate the file to zero length if it already exists. FXF_EXCL Return an error when trying to open a file which already exists. ========== ====================================================== The attrs argument is used to set initial attributes of the file if it needs to be created. 
Otherwise, this argument is ignored. :param path: The name of the file to open :param pflags: The access mode to use for the file (see above) :param attrs: File attributes to use if the file needs to be created :type path: `bytes` :type pflags: `int` :type attrs: :class:`SFTPAttrs` :returns: A file object to use to access the file :raises: :exc:`SFTPError` to return an error to the client """ if pflags & FXF_EXCL: mode = 'xb' elif pflags & FXF_APPEND: mode = 'ab' elif pflags & FXF_WRITE and not pflags & FXF_READ: mode = 'wb' else: mode = 'rb' if pflags & FXF_READ and pflags & FXF_WRITE: mode += '+' flags = os.O_RDWR elif pflags & FXF_WRITE: flags = os.O_WRONLY else: flags = os.O_RDONLY if pflags & FXF_APPEND: flags |= os.O_APPEND if pflags & FXF_CREAT: flags |= os.O_CREAT if pflags & FXF_TRUNC: flags |= os.O_TRUNC if pflags & FXF_EXCL: flags |= os.O_EXCL try: flags |= os.O_BINARY except AttributeError: # pragma: no cover pass perms = 0o666 if attrs.permissions is None else attrs.permissions return open(_to_local_path(self.map_path(path)), mode, buffering=0, opener=lambda path, _: os.open(path, flags, perms)) def open56(self, path: bytes, desired_access: int, flags: int, attrs: SFTPAttrs) -> MaybeAwait[object]: """Open a file to serve to a remote client (SFTPv5 and later) This method returns a file object which can be used to read and write data and get and set file attributes. Supported desired_access bits include `ACE4_READ_DATA`, `ACE4_WRITE_DATA`, `ACE4_APPEND_DATA`, `ACE4_READ_ATTRIBUTES`, and `ACE4_WRITE_ATTRIBUTES`. Supported disposition bits in flags and their meanings are: ===================== ============================================ Disposition Description ===================== ============================================ FXF_OPEN_EXISTING Open an existing file FXF_OPEN_OR_CREATE Open an existing file or create a new one FXF_CREATE_NEW Create a new file FXF_CREATE_TRUNCATE Create a new file or truncate an existing one FXF_TRUNCATE_EXISTING Truncate an existing file ===================== ============================================ Other supported flag bits are: ===================== ============================================ Flag Description ===================== ============================================ FXF_APPEND_DATA Append data writes to the end of the file ===================== ============================================ The attrs argument is used to set initial attributes of the file if it needs to be created. Otherwise, this argument is ignored. 
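# --- Usage sketch (illustrative, not part of this module) ---------------------
# A sketch of overriding open() in an SFTPServer subclass while keeping the
# default file handling.  The class name is hypothetical; only the logging
# line is added before delegating to the base implementation.
import asyncssh

class LoggingSFTPServer(asyncssh.SFTPServer):
    def open(self, path, pflags, attrs):
        # Log the request, then fall back to the default implementation
        self.logger.info('open %r (pflags=0x%04x)', path, pflags)
        return super().open(path, pflags, attrs)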
:param path: The name of the file to open :param desired_access: The access mode to use for the file (see above) :param flags: The access flags to use for the file (see above) :param attrs: File attributes to use if the file needs to be created :type path: `bytes` :type desired_access: `int` :type flags: `int` :type attrs: :class:`SFTPAttrs` :returns: A file object to use to access the file :raises: :exc:`SFTPError` to return an error to the client """ if desired_access & ACE4_READ_DATA and \ desired_access & ACE4_WRITE_DATA: open_flags = os.O_RDWR elif desired_access & ACE4_WRITE_DATA: open_flags = os.O_WRONLY else: open_flags = os.O_RDONLY disp = flags & FXF_ACCESS_DISPOSITION if disp == FXF_CREATE_NEW: mode = 'xb' open_flags |= os.O_CREAT | os.O_EXCL elif disp == FXF_CREATE_TRUNCATE: mode = 'wb' open_flags |= os.O_CREAT | os.O_TRUNC elif disp == FXF_OPEN_OR_CREATE: mode = 'wb' open_flags |= os.O_CREAT elif disp == FXF_TRUNCATE_EXISTING: mode = 'wb' open_flags |= os.O_TRUNC else: mode = 'wb' if desired_access & ACE4_WRITE_DATA else 'rb' if desired_access & ACE4_APPEND_DATA or flags & FXF_APPEND_DATA: mode = 'ab' open_flags |= os.O_APPEND if desired_access & ACE4_READ_DATA and \ desired_access & ACE4_WRITE_DATA: mode += '+' try: open_flags |= os.O_BINARY except AttributeError: # pragma: no cover pass perms = 0o666 if attrs.permissions is None else attrs.permissions return open(_to_local_path(self.map_path(path)), mode, buffering=0, opener=lambda path, _: os.open(path, open_flags, perms)) def close(self, file_obj: object) -> MaybeAwait[None]: """Close an open file or directory :param file_obj: The file or directory object to close :type file_obj: file :raises: :exc:`SFTPError` to return an error to the client """ file_obj = cast(_SFTPFileObj, file_obj) file_obj.close() return None def read(self, file_obj: object, offset: int, size: int) -> \ MaybeAwait[bytes]: """Read data from an open file :param file_obj: The file to read from :param offset: The offset from the beginning of the file to begin reading :param size: The number of bytes to read :type file_obj: file :type offset: `int` :type size: `int` :returns: bytes read from the file :raises: :exc:`SFTPError` to return an error to the client """ file_obj = cast(_SFTPFileObj, file_obj) file_obj.seek(offset) return file_obj.read(size) def write(self, file_obj: object, offset: int, data: bytes) -> \ MaybeAwait[int]: """Write data to an open file :param file_obj: The file to write to :param offset: The offset from the beginning of the file to begin writing :param data: The data to write to the file :type file_obj: file :type offset: `int` :type data: `bytes` :returns: number of bytes written :raises: :exc:`SFTPError` to return an error to the client """ file_obj = cast(_SFTPFileObj, file_obj) file_obj.seek(offset) return file_obj.write(data) def lstat(self, path: bytes) -> MaybeAwait[_SFTPOSAttrs]: """Get attributes of a file, directory, or symlink This method queries the attributes of a file, directory, or symlink. Unlike :meth:`stat`, this method should return the attributes of a symlink itself rather than the target of that link. 
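# --- Usage sketch (illustrative, not part of this module) ---------------------
# A sketch of rejecting writes from an SFTPServer subclass by raising an SFTP
# error, as the write() docstring above allows.  A complete read-only server
# would also need to deny write access in open()/open56() and in the
# setstat/remove/rename family of methods; the class name is hypothetical.
import asyncssh

class ReadOnlySFTPServer(asyncssh.SFTPServer):
    def write(self, file_obj, offset, data):
        # Reject all write requests with a permission-denied error
        raise asyncssh.SFTPPermissionDenied('This server is read-only')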
:param path: The path of the file, directory, or link to get attributes for :type path: `bytes` :returns: An :class:`SFTPAttrs` or an os.stat_result containing the file attributes :raises: :exc:`SFTPError` to return an error to the client """ return os.lstat(_to_local_path(self.map_path(path))) def fstat(self, file_obj: object) -> MaybeAwait[_SFTPOSAttrs]: """Get attributes of an open file :param file_obj: The file to get attributes for :type file_obj: file :returns: An :class:`SFTPAttrs` or an os.stat_result containing the file attributes :raises: :exc:`SFTPError` to return an error to the client """ file_obj = cast(_SFTPFileObj, file_obj) file_obj.flush() return os.fstat(file_obj.fileno()) def setstat(self, path: bytes, attrs: SFTPAttrs) -> MaybeAwait[None]: """Set attributes of a file or directory This method sets attributes of a file or directory. If the path provided is a symbolic link, the attributes should be set on the target of the link. A subset of the fields in `attrs` can be initialized and only those attributes should be changed. :param path: The path of the remote file or directory to set attributes for :param attrs: File attributes to set :type path: `bytes` :type attrs: :class:`SFTPAttrs` :raises: :exc:`SFTPError` to return an error to the client """ _setstat(_to_local_path(self.map_path(path)), attrs) return None def lsetstat(self, path: bytes, attrs: SFTPAttrs) -> MaybeAwait[None]: """Set attributes of a file, directory, or symlink This method sets attributes of a file, directory, or symlink. A subset of the fields in `attrs` can be initialized and only those attributes should be changed. :param path: The path of the remote file or directory to set attributes for :param attrs: File attributes to set :type path: `bytes` :type attrs: :class:`SFTPAttrs` :raises: :exc:`SFTPError` to return an error to the client """ _setstat(_to_local_path(self.map_path(path)), attrs, follow_symlinks=False) return None def fsetstat(self, file_obj: object, attrs: SFTPAttrs) -> MaybeAwait[None]: """Set attributes of an open file :param file_obj: The file to set attributes for :param attrs: File attributes to set on the file :type file_obj: file :type attrs: :class:`SFTPAttrs` :raises: :exc:`SFTPError` to return an error to the client """ file_obj = cast(_SFTPFileObj, file_obj) file_obj.flush() if sys.platform == 'win32': # pragma: no cover _setstat(file_obj.name, attrs) else: _setstat(file_obj.fileno(), attrs) return None async def scandir(self, path: bytes) -> AsyncIterator[SFTPName]: """Return names and attributes of the files in a directory This function returns an async iterator of :class:`SFTPName` entries corresponding to files in the requested directory. :param path: The path of the directory to scan :type path: `bytes` :returns: An async iterator of :class:`SFTPName` :raises: :exc:`SFTPError` to return an error to the client """ if hasattr(self, 'listdir'): # Support backward compatibility with older AsyncSSH versions # which allowed listdir() to be overridden, returning a list # of either :class:`SFTPName` objects or plain filenames, in # which case :meth:`lstat` is called to retrieve attribute # information. 
# pylint: disable=no-member listdir_result = self.listdir(path) # type: ignore if inspect.isawaitable(listdir_result): listdir_result = await cast( Awaitable[Sequence[Union[bytes, SFTPName]]], listdir_result) listdir_result: Sequence[Union[bytes, SFTPName]] for name in listdir_result: if isinstance(name, bytes): yield await self._to_sftpname(path, name) else: yield name else: for name in (b'.', b'..'): yield await self._to_sftpname(path, name) with os.scandir(_to_local_path(self.map_path(path))) as entries: for entry in entries: filename = entry.name if sys.platform == 'win32': # pragma: no cover filename = os.fsencode(filename) attrs = SFTPAttrs.from_local( entry.stat(follow_symlinks=False)) yield SFTPName(filename, attrs=attrs) def remove(self, path: bytes) -> MaybeAwait[None]: """Remove a file or symbolic link :param path: The path of the file or link to remove :type path: `bytes` :raises: :exc:`SFTPError` to return an error to the client """ os.remove(_to_local_path(self.map_path(path))) return None def mkdir(self, path: bytes, attrs: SFTPAttrs) -> MaybeAwait[None]: """Create a directory with the specified attributes :param path: The path of where the new directory should be created :param attrs: The file attributes to use when creating the directory :type path: `bytes` :type attrs: :class:`SFTPAttrs` :raises: :exc:`SFTPError` to return an error to the client """ mode = 0o777 if attrs.permissions is None else attrs.permissions os.mkdir(_to_local_path(self.map_path(path)), mode) return None def rmdir(self, path: bytes) -> MaybeAwait[None]: """Remove a directory :param path: The path of the directory to remove :type path: `bytes` :raises: :exc:`SFTPError` to return an error to the client """ os.rmdir(_to_local_path(self.map_path(path))) return None def realpath(self, path: bytes) -> MaybeAwait[bytes]: """Return the canonical version of a path :param path: The path of the directory to canonicalize :type path: `bytes` :returns: bytes containing the canonical path :raises: :exc:`SFTPError` to return an error to the client """ path = os.path.realpath(_to_local_path(self.map_path(path))) return self.reverse_map_path(_from_local_path(path)) def stat(self, path: bytes) -> MaybeAwait[_SFTPOSAttrs]: """Get attributes of a file or directory, following symlinks This method queries the attributes of a file or directory. If the path provided is a symbolic link, the returned attributes should correspond to the target of the link. :param path: The path of the remote file or directory to get attributes for :type path: `bytes` :returns: An :class:`SFTPAttrs` or an os.stat_result containing the file attributes :raises: :exc:`SFTPError` to return an error to the client """ return os.stat(_to_local_path(self.map_path(path))) def rename(self, oldpath: bytes, newpath: bytes) -> MaybeAwait[None]: """Rename a file, directory, or link This method renames a file, directory, or link. .. note:: This is a request for the standard SFTP version of rename which will not overwrite the new path if it already exists. The :meth:`posix_rename` method will be called if the client requests the POSIX behavior where an existing instance of the new path is removed before the rename. 
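# --- Usage sketch (illustrative, not part of this module) ---------------------
# A sketch of overriding scandir() to filter directory listings.  Server-side
# file names are bytes, hence the b'.' comparisons; the class name is
# hypothetical.
import asyncssh

class HideDotfilesSFTPServer(asyncssh.SFTPServer):
    async def scandir(self, path):
        # Keep '.' and '..' but drop other names beginning with a dot
        async for name in super().scandir(path):
            filename = name.filename
            if filename in (b'.', b'..') or not filename.startswith(b'.'):
                yield name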
:param oldpath: The path of the file, directory, or link to rename :param newpath: The new name for this file, directory, or link :type oldpath: `bytes` :type newpath: `bytes` :raises: :exc:`SFTPError` to return an error to the client """ oldpath = _to_local_path(self.map_path(oldpath)) newpath = _to_local_path(self.map_path(newpath)) if os.path.exists(newpath): raise SFTPFileAlreadyExists('File already exists') os.rename(oldpath, newpath) return None def readlink(self, path: bytes) -> MaybeAwait[bytes]: """Return the target of a symbolic link :param path: The path of the symbolic link to follow :type path: `bytes` :returns: bytes containing the target path of the link :raises: :exc:`SFTPError` to return an error to the client """ path = os.readlink(_to_local_path(self.map_path(path))) if sys.platform == 'win32' and \ path.startswith('\\\\?\\'): # pragma: no cover path = path[4:] if self._chroot: path = os.path.realpath(path) return self.reverse_map_path(_from_local_path(path)) def symlink(self, oldpath: bytes, newpath: bytes) -> MaybeAwait[None]: """Create a symbolic link :param oldpath: The path the link should point to :param newpath: The path of where to create the symbolic link :type oldpath: `bytes` :type newpath: `bytes` :raises: :exc:`SFTPError` to return an error to the client """ if posixpath.isabs(oldpath): oldpath = self.map_path(oldpath) else: newdir = posixpath.dirname(newpath) abspath1 = self.map_path(posixpath.join(newdir, oldpath)) mapped_newdir = self.map_path(newdir) abspath2 = os.path.join(mapped_newdir, oldpath) # Make sure the symlink doesn't point outside the chroot if os.path.realpath(abspath1) != os.path.realpath(abspath2): oldpath = os.path.relpath(abspath1, start=mapped_newdir) newpath = self.map_path(newpath) os.symlink(_to_local_path(oldpath), _to_local_path(newpath)) return None def link(self, oldpath: bytes, newpath: bytes) -> MaybeAwait[None]: """Create a hard link :param oldpath: The path of the file the hard link should point to :param newpath: The path of where to create the hard link :type oldpath: `bytes` :type newpath: `bytes` :raises: :exc:`SFTPError` to return an error to the client """ oldpath = _to_local_path(self.map_path(oldpath)) newpath = _to_local_path(self.map_path(newpath)) os.link(oldpath, newpath) return None def lock(self, file_obj: object, offset: int, length: int, flags: int) -> MaybeAwait[None]: """Acquire a byte range lock on an open file""" raise SFTPOpUnsupported('Byte range locks not supported') def unlock(self, file_obj: object, offset: int, length: int) -> MaybeAwait[None]: """Release a byte range lock on an open file""" raise SFTPOpUnsupported('Byte range locks not supported') def posix_rename(self, oldpath: bytes, newpath: bytes) -> MaybeAwait[None]: """Rename a file, directory, or link with POSIX semantics This method renames a file, directory, or link, removing the prior instance of new path if it previously existed. 
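# --- Usage sketch (illustrative, not part of this module) ---------------------
# A client-side sketch of the difference between the two rename flavors the
# server methods above implement.  The file names are hypothetical.
import asyncssh

async def rename_examples(sftp: asyncssh.SFTPClient) -> None:
    # Standard SFTP rename: raises an error if 'new.txt' already exists
    await sftp.rename('old.txt', 'new.txt')

    # POSIX-style rename: silently replaces 'new.txt' if it exists (uses the
    # posix-rename@openssh.com extension, so the server must support it)
    await sftp.posix_rename('other.txt', 'new.txt')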
:param oldpath: The path of the file, directory, or link to rename :param newpath: The new name for this file, directory, or link :type oldpath: `bytes` :type newpath: `bytes` :raises: :exc:`SFTPError` to return an error to the client """ oldpath = _to_local_path(self.map_path(oldpath)) newpath = _to_local_path(self.map_path(newpath)) os.replace(oldpath, newpath) return None def statvfs(self, path: bytes) -> MaybeAwait[_SFTPOSVFSAttrs]: """Get attributes of the file system containing a file :param path: The path of the file system to get attributes for :type path: `bytes` :returns: An :class:`SFTPVFSAttrs` or an os.statvfs_result containing the file system attributes :raises: :exc:`SFTPError` to return an error to the client """ try: return os.statvfs(_to_local_path(self.map_path(path))) except AttributeError: # pragma: no cover raise SFTPOpUnsupported('statvfs not supported') from None def fstatvfs(self, file_obj: object) -> MaybeAwait[_SFTPOSVFSAttrs]: """Return attributes of the file system containing an open file :param file_obj: The open file to get file system attributes for :type file_obj: file :returns: An :class:`SFTPVFSAttrs` or an os.statvfs_result containing the file system attributes :raises: :exc:`SFTPError` to return an error to the client """ file_obj = cast(_SFTPFileObj, file_obj) try: return os.statvfs(file_obj.fileno()) except AttributeError: # pragma: no cover raise SFTPOpUnsupported('fstatvfs not supported') from None def fsync(self, file_obj: object) -> MaybeAwait[None]: """Force file data to be written to disk :param file_obj: The open file containing the data to flush to disk :type file_obj: file :raises: :exc:`SFTPError` to return an error to the client """ file_obj = cast(_SFTPFileObj, file_obj) os.fsync(file_obj.fileno()) return None def exit(self) -> MaybeAwait[None]: """Shut down this SFTP server""" return None class LocalFile: """An async wrapper around local file I/O""" def __init__(self, file: _SFTPFileObj): self._file = file async def __aenter__(self) -> Self: # pragma: no cover """Allow LocalFile to be used as an async context manager""" return self async def __aexit__(self, _exc_type: Optional[Type[BaseException]], _exc_value: Optional[BaseException], _traceback: Optional[TracebackType]) -> \ bool: # pragma: no cover """Wait for file close when used as an async context manager""" await self.close() return False async def read(self, size: int, offset: int) -> bytes: """Read data from the local file""" self._file.seek(offset) return self._file.read(size) async def write(self, data: bytes, offset: int) -> int: """Write data to the local file""" self._file.seek(offset) return self._file.write(data) async def close(self) -> None: """Close the local file""" self._file.close() class LocalFS: """An async wrapper around local filesystem access""" limits = SFTPLimits(0, MAX_SFTP_READ_LEN, MAX_SFTP_WRITE_LEN, 0) @staticmethod def basename(path: bytes) -> bytes: """Return the final component of a local file path""" return os.path.basename(path) def encode(self, path: _SFTPPath) -> bytes: """Encode path name using filesystem native encoding This method has no effect if the path is already bytes. """ # pylint: disable=no-self-use return os.fsencode(path) def compose_path(self, path: bytes, parent: Optional[bytes] = None) -> bytes: """Compose a path If parent is not specified, just encode the path. 
""" path = self.encode(path) return posixpath.join(parent, path) if parent else path async def stat(self, path: bytes, *, follow_symlinks: bool = True) -> 'SFTPAttrs': """Get attributes of a local file, directory, or symlink""" return SFTPAttrs.from_local(os.stat(_to_local_path(path), follow_symlinks=follow_symlinks)) async def setstat(self, path: bytes, attrs: 'SFTPAttrs', *, follow_symlinks: bool = True) -> None: """Set attributes of a local file, directory, or symlink""" _setstat(_to_local_path(path), attrs, follow_symlinks=follow_symlinks) async def exists(self, path: bytes) -> bool: """Return if the local path exists and isn't a broken symbolic link""" return os.path.exists(_to_local_path(path)) async def isdir(self, path: bytes) -> bool: """Return if the local path refers to a directory""" return os.path.isdir(_to_local_path(path)) async def scandir(self, path: bytes) -> AsyncIterator[SFTPName]: """Return names and attributes of the files in a local directory""" with os.scandir(_to_local_path(path)) as entries: for entry in entries: filename = entry.name if sys.platform == 'win32': # pragma: no cover filename = os.fsencode(filename) attrs = SFTPAttrs.from_local(entry.stat(follow_symlinks=False)) yield SFTPName(filename, attrs=attrs) async def mkdir(self, path: bytes) -> None: """Create a local directory with the specified attributes""" os.mkdir(_to_local_path(path)) async def readlink(self, path: bytes) -> bytes: """Return the target of a local symbolic link""" path = os.readlink(_to_local_path(path)) if sys.platform == 'win32' and \ path.startswith('\\\\?\\'): # pragma: no cover path = path[4:] return _from_local_path(path) async def symlink(self, oldpath: bytes, newpath: bytes) -> None: """Create a local symbolic link""" os.symlink(_to_local_path(oldpath), _to_local_path(newpath)) @async_context_manager async def open(self, path: bytes, mode: str, block_size: int = -1) -> LocalFile: """Open a local file""" # pylint: disable=unused-argument return LocalFile(open(_to_local_path(path), mode)) local_fs = LocalFS() class SFTPServerFile: """A wrapper around SFTPServer used to access files it manages""" def __init__(self, server: SFTPServer, file_obj: object): self._server = server self._file_obj = file_obj async def __aenter__(self) -> Self: # pragma: no cover """Allow SFTPServerFile to be used as an async context manager""" return self async def __aexit__(self, _exc_type: Optional[Type[BaseException]], _exc_value: Optional[BaseException], _traceback: Optional[TracebackType]) -> \ bool: # pragma: no cover """Wait for client close when used as an async context manager""" await self.close() return False async def read(self, size: int, offset: int) -> bytes: """Read bytes from the file""" data = self._server.read(self._file_obj, offset, size) if inspect.isawaitable(data): data = await cast(Awaitable[bytes], data) data: bytes return data async def write(self, data: bytes, offset: int) -> int: """Write bytes to the file""" size = self._server.write(self._file_obj, offset, data) if inspect.isawaitable(size): size = await cast(Awaitable[int], size) size: int return size async def close(self) -> None: """Close a file managed by the associated SFTPServer""" result = self._server.close(self._file_obj) if inspect.isawaitable(result): assert result is not None await result class SFTPServerFS: """A wrapper around SFTPServer used to access its filesystem""" def __init__(self, server: SFTPServer): self._server = server @staticmethod def basename(path: bytes) -> bytes: """Return the final component of a 
POSIX-style path""" return posixpath.basename(path) async def stat(self, path: bytes) -> SFTPAttrs: """Get attributes of a file or directory, following symlinks""" attrs = self._server.stat(path) if inspect.isawaitable(attrs): attrs = await cast(Awaitable[_SFTPOSAttrs], attrs) attrs: _SFTPOSAttrs if isinstance(attrs, os.stat_result): attrs = SFTPAttrs.from_local(attrs) return attrs async def setstat(self, path: bytes, attrs: SFTPAttrs) -> None: """Set attributes of a file or directory""" result = self._server.setstat(path, attrs) if inspect.isawaitable(result): assert result is not None await result async def _type(self, path: bytes) -> int: """Return the file type of a path, or 0 if it can't be accessed""" try: return (await self.stat(path)).type except OSError as exc: if exc.errno in (errno.ENOENT, errno.EACCES): return FILEXFER_TYPE_UNKNOWN else: raise except (SFTPNoSuchFile, SFTPNoSuchPath, SFTPPermissionDenied): return FILEXFER_TYPE_UNKNOWN async def exists(self, path: bytes) -> bool: """Return if a path exists""" return (await self._type(path)) != FILEXFER_TYPE_UNKNOWN async def isdir(self, path: bytes) -> bool: """Return if the path refers to a directory""" return (await self._type(path)) == FILEXFER_TYPE_DIRECTORY def scandir(self, path: bytes) -> AsyncIterator[SFTPName]: """Return names and attributes of the files in a directory""" return self._server.scandir(path) async def mkdir(self, path: bytes) -> None: """Create a directory""" result = self._server.mkdir(path, SFTPAttrs()) if inspect.isawaitable(result): assert result is not None await result @async_context_manager async def open(self, path: bytes, mode: str) -> SFTPServerFile: """Open a file""" pflags, _ = _mode_to_pflags(mode) file_obj = self._server.open(path, pflags, SFTPAttrs()) if inspect.isawaitable(file_obj): file_obj = await cast(Awaitable[object], file_obj) return SFTPServerFile(self._server, file_obj) async def start_sftp_client(conn: 'SSHClientConnection', loop: asyncio.AbstractEventLoop, reader: 'SSHReader[bytes]', writer: 'SSHWriter[bytes]', path_encoding: Optional[str], path_errors: str, sftp_version: int) -> SFTPClient: """Start an SFTP client""" handler = SFTPClientHandler(loop, reader, writer, sftp_version) handler.logger.info('Starting SFTP client') await handler.start() conn.create_task(handler.recv_packets(), handler.logger) await handler.request_limits() return SFTPClient(handler, path_encoding, path_errors) def run_sftp_server(sftp_server: SFTPServer, reader: 'SSHReader[bytes]', writer: 'SSHWriter[bytes]', sftp_version: int) -> Awaitable[None]: """Return a handler for an SFTP server session""" handler = SFTPServerHandler(sftp_server, reader, writer, sftp_version) handler.logger.info('Starting SFTP server') return handler.run() asyncssh-2.20.0/asyncssh/sk.py000066400000000000000000000271031475467777400163070ustar00rootroot00000000000000# Copyright (c) 2019-2024 by Ron Frederick and others. 
# # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation """U2F security key handler""" from base64 import urlsafe_b64encode import ctypes from hashlib import sha256 import hmac import time from typing import Callable, List, Mapping, NoReturn, Optional from typing import Sequence, Tuple, TypeVar, cast _PollResult = TypeVar('_PollResult') _SKResidentKey = Tuple[int, str, bytes, bytes] _CTAP1_POLL_INTERVAL = 0.1 _dummy_hash = 32 * b'\0' # Flags SSH_SK_USER_PRESENCE_REQD = 0x01 # Algorithms SSH_SK_ECDSA = -7 SSH_SK_ED25519 = -8 def _decode_public_key(alg: int, public_key: Mapping[int, object]) -> bytes: """Decode algorithm and public value from a CTAP public key""" result = cast(bytes, public_key[-2]) if alg == SSH_SK_ED25519: return result else: return b'\x04' + result + cast(bytes, public_key[-3]) def _verify_rp_id(_rp_id: str, _origin: str): """Allow any relying party name -- SSH encodes the application here""" return True def _ctap1_poll(poll_interval: float, func: Callable[..., _PollResult], *args: object) -> _PollResult: """Poll until a CTAP1 response is received""" while True: try: return func(*args) except ApduError as exc: if exc.code != APDU.USE_NOT_SATISFIED: raise time.sleep(poll_interval) def _ctap1_enroll(dev: 'CtapHidDevice', alg: int, application: str) -> Tuple[bytes, bytes]: """Enroll a new security key using CTAP version 1""" ctap1 = Ctap1(dev) if alg != SSH_SK_ECDSA: raise ValueError('Unsupported algorithm') app_hash = sha256(application.encode('utf-8')).digest() registration = _ctap1_poll(_CTAP1_POLL_INTERVAL, ctap1.register, _dummy_hash, app_hash) return registration.public_key, registration.key_handle def _ctap2_enroll(dev: 'CtapHidDevice', alg: int, application: str, user: str, pin: Optional[str], resident: bool) -> Tuple[bytes, bytes]: """Enroll a new security key using CTAP version 2""" ctap2 = Ctap2(dev) rp = {'id': application, 'name': application} user_cred = {'id': user.encode('utf-8'), 'name': user} key_params = [{'type': 'public-key', 'alg': alg}] options = {'rk': resident} pin_protocol: Optional[PinProtocolV1] pin_auth: Optional[bytes] if pin: pin_protocol = PinProtocolV1() pin_token = ClientPin(ctap2, pin_protocol).get_pin_token(pin) pin_auth = hmac.new(pin_token, _dummy_hash, sha256).digest()[:16] else: pin_protocol = None pin_auth = None pin_version = pin_protocol.VERSION if pin_protocol else None cred = ctap2.make_credential(_dummy_hash, rp, user_cred, key_params, options=options, pin_uv_param=pin_auth, pin_uv_protocol=pin_version) cdata = cred.auth_data.credential_data # pylint: disable=no-member return _decode_public_key(alg, cdata.public_key), cdata.credential_id def _win_enroll(alg: int, application: str, user: str) -> Tuple[bytes, bytes]: """Enroll a new security key using Windows WebAuthn API""" client = WindowsClient(application, verify=_verify_rp_id) rp = {'id': application, 'name': application} user_cred = {'id': user.encode('utf-8'), 'name': user} key_params = [{'type': 
'public-key', 'alg': alg}] options = {'rp': rp, 'user': user_cred, 'challenge': b'', 'pubKeyCredParams': key_params} result = client.make_credential(options) cdata = result.attestation_object.auth_data.credential_data # pylint: disable=no-member return _decode_public_key(alg, cdata.public_key), cdata.credential_id def _ctap1_sign(dev: 'CtapHidDevice', message_hash: bytes, application: str, key_handle: bytes) -> Tuple[int, int, bytes]: """Sign a message with a security key using CTAP version 1""" ctap1 = Ctap1(dev) app_hash = sha256(application.encode('utf-8')).digest() auth_response = _ctap1_poll(_CTAP1_POLL_INTERVAL, ctap1.authenticate, message_hash, app_hash, key_handle) flags = auth_response[0] counter = int.from_bytes(auth_response[1:5], 'big') sig = auth_response[5:] return flags, counter, sig def _ctap2_sign(dev: 'CtapHidDevice', message_hash: bytes, application: str, key_handle: bytes, touch_required: bool) -> Tuple[int, int, bytes]: """Sign a message with a security key using CTAP version 2""" ctap2 = Ctap2(dev) allow_creds = [{'type': 'public-key', 'id': key_handle}] options = {'up': touch_required} # See if key handle exists before requiring touch if touch_required: ctap2.get_assertions(application, message_hash, allow_creds, options={'up': False}) assertion = ctap2.get_assertions(application, message_hash, allow_creds, options=options)[0] auth_data = assertion.auth_data return auth_data.flags, auth_data.counter, assertion.signature def _win_sign(data: bytes, application: str, key_handle: bytes) -> Tuple[int, int, bytes, bytes]: """Sign a message with a security key using Windows WebAuthn API""" client = WindowsClient(application, verify=_verify_rp_id) creds = [{'type': 'public-key', 'id': key_handle}] options = {'challenge': data, 'rpId': application, 'allowCredentials': creds} result = client.get_assertion(options).get_response(0) auth_data = result.authenticator_data return auth_data.flags, auth_data.counter, \ result.signature, bytes(result.client_data) def sk_webauthn_prefix(data: bytes, application: str) -> bytes: """Calculate a WebAuthn request prefix""" return b'{"type":"webauthn.get","challenge":"' + \ urlsafe_b64encode(data).rstrip(b'=') + b'","origin":"' + \ application.encode('utf-8') + b'"' def sk_enroll(alg: int, application: str, user: str, pin: Optional[str], resident: bool) -> Tuple[bytes, bytes]: """Enroll a new security key""" if sk_use_webauthn: return _win_enroll(alg, application, user) try: dev = next(CtapHidDevice.list_devices()) except StopIteration: raise ValueError('No security key found') from None try: return _ctap2_enroll(dev, alg, application, user, pin, resident) except CtapError as exc: if exc.code == CtapError.ERR.PUAT_REQUIRED: raise ValueError('PIN required') from None elif exc.code == CtapError.ERR.PIN_INVALID: raise ValueError('Invalid PIN') from None else: raise ValueError(str(exc)) from None except ValueError: try: return _ctap1_enroll(dev, alg, application) except ApduError as exc: raise ValueError(str(exc)) from None finally: dev.close() def sk_sign(data: bytes, application: str, key_handle: bytes, flags: int, is_webauthn: bool = False) -> Tuple[int, int, bytes, bytes]: """Sign a message with a security key""" touch_required = bool(flags & SSH_SK_USER_PRESENCE_REQD) if is_webauthn and sk_use_webauthn: return _win_sign(data, application, key_handle) if is_webauthn: data = sk_webauthn_prefix(data, application) + b'}' message_hash = sha256(data).digest() for dev in CtapHidDevice.list_devices(): try: flags, counter, sig = _ctap2_sign(dev, 
message_hash, application, key_handle, touch_required) return flags, counter, sig, data except CtapError as exc: if exc.code != CtapError.ERR.NO_CREDENTIALS: raise ValueError(str(exc)) from None except ValueError: try: flags, counter, sig = _ctap1_sign(dev, message_hash, application, key_handle) return flags, counter, sig, data except ApduError as exc: if exc.code != APDU.WRONG_DATA: raise ValueError(str(exc)) from None finally: dev.close() raise ValueError('Security key credential not found') def sk_get_resident(application: str, user: Optional[str], pin: str) -> Sequence[_SKResidentKey]: """Get keys resident on a security key""" app_hash = sha256(application.encode('utf-8')).digest() result: List[_SKResidentKey] = [] for dev in CtapHidDevice.list_devices(): try: ctap2 = Ctap2(dev) pin_protocol = PinProtocolV1() pin_token = ClientPin(ctap2, pin_protocol).get_pin_token(pin) cred_mgmt = CredentialManagement(ctap2, pin_protocol, pin_token) for cred in cred_mgmt.enumerate_creds(app_hash): user_info = cast(Mapping[str, object], cred[CredentialManagement.RESULT.USER]) name = cast(str, user_info['name']) if user and name != user: continue cred_id = cast(Mapping[str, object], cred[CredentialManagement.RESULT.CREDENTIAL_ID]) key_handle = cast(bytes, cred_id['id']) public_key = cast(Mapping[int, object], cred[CredentialManagement.RESULT.PUBLIC_KEY]) alg = cast(int, public_key[3]) public_value = _decode_public_key(alg, public_key) result.append((alg, name, public_value, key_handle)) except CtapError as exc: if exc.code == CtapError.ERR.NO_CREDENTIALS: continue elif exc.code == CtapError.ERR.PIN_INVALID: raise ValueError('Invalid PIN') from None elif exc.code == CtapError.ERR.PIN_NOT_SET: raise ValueError('PIN not set') from None else: raise ValueError(str(exc)) from None finally: dev.close() return result try: from fido2.client import WindowsClient from fido2.ctap import CtapError from fido2.ctap1 import Ctap1, APDU, ApduError from fido2.ctap2 import Ctap2, ClientPin, PinProtocolV1 from fido2.ctap2 import CredentialManagement from fido2.hid import CtapHidDevice sk_available = True sk_use_webauthn = WindowsClient.is_available() and \ hasattr(ctypes, 'windll') and \ not ctypes.windll.shell32.IsUserAnAdmin() except (ImportError, OSError, AttributeError): # pragma: no cover sk_available = False sk_use_webauthn = False def _sk_not_available(*args: object, **kwargs: object) -> NoReturn: """Report that security key support is unavailable""" raise ValueError('Security key support not available') sk_enroll = _sk_not_available sk_sign = _sk_not_available sk_get_resident = _sk_not_available asyncssh-2.20.0/asyncssh/sk_ecdsa.py000066400000000000000000000207331475467777400174500ustar00rootroot00000000000000# Copyright (c) 2019-2024 by Ron Frederick and others. 
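# A small, self-contained sketch of the WebAuthn client-data handling shown in
# sk_sign() and sk_webauthn_prefix() above: for 'webauthn-*' signature
# algorithms, the raw SSH signing data is wrapped in a JSON-like client-data
# blob and the SHA-256 of that blob is what the authenticator actually signs.
# The challenge and application values below are made up for illustration;
# only the layout is taken from the code above.

from base64 import urlsafe_b64encode
from hashlib import sha256


def webauthn_prefix(data: bytes, application: str) -> bytes:
    """Build the client-data prefix used for webauthn-* signatures"""

    return (b'{"type":"webauthn.get","challenge":"' +
            urlsafe_b64encode(data).rstrip(b'=') + b'","origin":"' +
            application.encode('utf-8') + b'"')


challenge = b'example SSH data to sign'       # hypothetical signing data
application = 'ssh:'                          # default SSH application string

client_data = webauthn_prefix(challenge, application) + b'}'
message_hash = sha256(client_data).digest()

print(client_data)
print(message_hash.hex())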
# # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation """U2F ECDSA public key encryption handler""" from hashlib import sha256 from typing import Optional, Tuple, cast from .asn1 import der_encode, der_decode from .crypto import ECDSAPublicKey from .packet import Byte, MPInt, String, UInt32, SSHPacket from .public_key import KeyExportError, SSHKey, SSHOpenSSHCertificateV01 from .public_key import register_public_key_alg, register_certificate_alg from .public_key import register_sk_alg from .sk import SSH_SK_ECDSA, SSH_SK_USER_PRESENCE_REQD from .sk import sk_enroll, sk_sign, sk_webauthn_prefix, sk_use_webauthn _PrivateKeyArgs = Tuple[bytes, bytes, str, int, bytes, bytes] _PublicKeyArgs = Tuple[bytes, bytes, str] class _SKECDSAKey(SSHKey): """Handler for U2F ECDSA public key encryption""" _key: ECDSAPublicKey use_executor = True def __init__(self, curve_id: bytes, public_value: bytes, application: str, flags: int = 0, key_handle: Optional[bytes] = None, reserved: bytes = b''): super().__init__(ECDSAPublicKey.construct(curve_id, public_value)) self.algorithm = b'sk-ecdsa-sha2-' + curve_id + b'@openssh.com' self.sig_algorithms = (self.algorithm, b'webauthn-' + self.algorithm) self.all_sig_algorithms = set(self.sig_algorithms) self.use_webauthn = sk_use_webauthn self._application = application self._app_hash = sha256(application.encode('utf-8')).digest() self._flags = flags self._key_handle = key_handle self._reserved = reserved def __eq__(self, other: object) -> bool: # This isn't protected access - both objects are _SKECDSAKey instances # pylint: disable=protected-access return (isinstance(other, type(self)) and self._key.curve_id == other._key.curve_id and self._key.public_value == other._key.public_value and self._application == other._application and self._flags == other._flags and self._key_handle == other._key_handle and self._reserved == other._reserved) def __hash__(self) -> int: return hash((self._key.curve_id, self._key.public_value, self._application, self._flags, self._key_handle, self._reserved)) @classmethod def generate(cls, algorithm: bytes, *, # type: ignore application: str = 'ssh:', user: str = 'AsyncSSH', pin: Optional[str] = None, resident: bool = False, touch_required: bool = True) -> '_SKECDSAKey': """Generate a new SK ECDSA private key""" # pylint: disable=arguments-differ flags = SSH_SK_USER_PRESENCE_REQD if touch_required else 0 public_value, key_handle = sk_enroll(SSH_SK_ECDSA, application, user, pin, resident) # Strip prefix and suffix of algorithm to get curve_id return cls(algorithm[14:-12], public_value, application, flags, key_handle, b'') @classmethod def make_private(cls, key_params: object) -> SSHKey: """Construct a U2F ECDSA private key""" curve_id, public_value, application, flags, key_handle, reserved = \ cast(_PrivateKeyArgs, key_params) return cls(curve_id, public_value, application, flags, key_handle, reserved) @classmethod def make_public(cls, key_params: object) -> SSHKey: 
"""Construct a U2F ECDSA public key""" curve_id, public_value, application = cast(_PublicKeyArgs, key_params) return cls(curve_id, public_value, application) @classmethod def decode_ssh_private(cls, packet: SSHPacket) -> _PrivateKeyArgs: """Decode an SSH format SK ECDSA private key""" curve_id = packet.get_string() public_value = packet.get_string() application = packet.get_string().decode('utf-8') flags = packet.get_byte() key_handle = packet.get_string() reserved = packet.get_string() return curve_id, public_value, application, flags, key_handle, reserved @classmethod def decode_ssh_public(cls, packet: SSHPacket) -> _PublicKeyArgs: """Decode an SSH format SK ECDSA public key""" curve_id = packet.get_string() public_value = packet.get_string() application = packet.get_string().decode('utf-8') return curve_id, public_value, application def encode_ssh_private(self) -> bytes: """Encode an SSH format SK ECDSA private key""" if self._key_handle is None: raise KeyExportError('Key is not private') return b''.join((String(self._key.curve_id), String(self._key.public_value), String(self._application), Byte(self._flags), String(self._key_handle), String(self._reserved))) def encode_ssh_public(self) -> bytes: """Encode an SSH format SK ECDSA public key""" return b''.join((String(self._key.curve_id), String(self._key.public_value), String(self._application))) def encode_agent_cert_private(self) -> bytes: """Encode U2F ECDSA certificate private key data for agent""" if self._key_handle is None: raise KeyExportError('Key is not private') return b''.join((String(self._application), Byte(self._flags), String(self._key_handle), String(self._reserved))) def sign_ssh(self, data: bytes, sig_algorithm: bytes) -> bytes: """Compute an SSH-encoded signature of the specified data""" if self._key_handle is None: raise ValueError('Key handle needed for signing') is_webauthn = sig_algorithm.startswith(b'webauthn') flags, counter, sig, client_data = sk_sign(data, self._application, self._key_handle, self._flags, is_webauthn) r, s = cast(Tuple[int, int], der_decode(sig)) sig = String(MPInt(r) + MPInt(s)) + Byte(flags) + UInt32(counter) if is_webauthn: sig += String(self._application) + String(client_data) + String('') return sig def verify_ssh(self, data: bytes, sig_algorithm: bytes, packet: SSHPacket) -> bool: """Verify an SSH-encoded signature of the specified data""" is_webauthn = sig_algorithm.startswith(b'webauthn') sig = packet.get_string() flags = packet.get_byte() counter = packet.get_uint32() if is_webauthn: _ = packet.get_string() # origin client_data = packet.get_string() _ = packet.get_string() # extensions prefix = sk_webauthn_prefix(data, self._application) if not client_data.startswith(prefix): return False data = client_data packet.check_end() if self._touch_required and not flags & SSH_SK_USER_PRESENCE_REQD: return False packet = SSHPacket(sig) r = packet.get_mpint() s = packet.get_mpint() packet.check_end() sig = der_encode((r, s)) return self._key.verify(self._app_hash + Byte(flags) + UInt32(counter) + sha256(data).digest(), sig, 'sha256') _algorithm = b'sk-ecdsa-sha2-nistp256@openssh.com' _cert_algorithm = b'sk-ecdsa-sha2-nistp256-cert-v01@openssh.com' register_sk_alg(SSH_SK_ECDSA, _SKECDSAKey, b'nistp256') register_public_key_alg(_algorithm, _SKECDSAKey, True, (_algorithm, b'webauthn-' + _algorithm)) register_certificate_alg(1, _algorithm, _cert_algorithm, _SKECDSAKey, SSHOpenSSHCertificateV01, True) 
asyncssh-2.20.0/asyncssh/sk_eddsa.py000066400000000000000000000153331475467777400174510ustar00rootroot00000000000000# Copyright (c) 2019-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation """U2F EdDSA public key encryption handler""" from hashlib import sha256 from typing import Optional, Tuple, cast from .crypto import EdDSAPublicKey, ed25519_available from .packet import Byte, String, UInt32, SSHPacket from .public_key import KeyExportError, SSHKey, SSHOpenSSHCertificateV01 from .public_key import register_public_key_alg, register_certificate_alg from .public_key import register_sk_alg from .sk import SSH_SK_ED25519, SSH_SK_USER_PRESENCE_REQD, sk_enroll, sk_sign _PrivateKeyArgs = Tuple[bytes, str, int, bytes, bytes] _PublicKeyArgs = Tuple[bytes, str] class _SKEd25519Key(SSHKey): """Handler for U2F Ed25519 public key encryption""" _key: EdDSAPublicKey algorithm = b'sk-ssh-ed25519@openssh.com' sig_algorithms = (algorithm,) all_sig_algorithms = set(sig_algorithms) use_executor = True def __init__(self, public_value: bytes, application: str, flags: int = 0, key_handle: Optional[bytes] = None, reserved: bytes = b''): super().__init__(EdDSAPublicKey.construct(b'ed25519', public_value)) self._application = application self._app_hash = sha256(application.encode('utf-8')).digest() self._flags = flags self._key_handle = key_handle self._reserved = reserved def __eq__(self, other: object) -> bool: # This isn't protected access - both objects are _SKEd25519Key instances # pylint: disable=protected-access return (isinstance(other, type(self)) and self._key.public_value == other._key.public_value and self._application == other._application and self._flags == other._flags and self._key_handle == other._key_handle and self._reserved == other._reserved) def __hash__(self) -> int: return hash((self._key.public_value, self._application, self._flags, self._key_handle, self._reserved)) @classmethod def generate(cls, algorithm: bytes, *, # type: ignore application: str = 'ssh:', user: str = 'AsyncSSH', pin: Optional[str] = None, resident: bool = False, touch_required: bool = True) -> '_SKEd25519Key': """Generate a new U2F Ed25519 private key""" # pylint: disable=arguments-differ flags = SSH_SK_USER_PRESENCE_REQD if touch_required else 0 public_value, key_handle = sk_enroll(SSH_SK_ED25519, application, user, pin, resident) return cls(public_value, application, flags, key_handle, b'') @classmethod def make_private(cls, key_params: object) -> SSHKey: """Construct a U2F Ed25519 private key""" public_value, application, flags, key_handle, reserved = \ cast(_PrivateKeyArgs, key_params) return cls(public_value, application, flags, key_handle, reserved) @classmethod def make_public(cls, key_params: object) -> SSHKey: """Construct a U2F Ed25519 public key""" public_value, application = cast(_PublicKeyArgs, key_params) return cls(public_value, application) @classmethod def decode_ssh_private(cls, packet: 
SSHPacket) -> _PrivateKeyArgs: """Decode an SSH format U2F Ed25519 private key""" public_value = packet.get_string() application = packet.get_string().decode('utf-8') flags = packet.get_byte() key_handle = packet.get_string() reserved = packet.get_string() return public_value, application, flags, key_handle, reserved @classmethod def decode_ssh_public(cls, packet: SSHPacket) -> _PublicKeyArgs: """Decode an SSH format U2F Ed25519 public key""" public_value = packet.get_string() application = packet.get_string().decode('utf-8') return public_value, application def encode_ssh_private(self) -> bytes: """Encode an SSH format U2F Ed25519 private key""" if self._key_handle is None: raise KeyExportError('Key is not private') return b''.join((String(self._key.public_value), String(self._application), Byte(self._flags), String(self._key_handle), String(self._reserved))) def encode_ssh_public(self) -> bytes: """Encode an SSH format U2F Ed25519 public key""" return b''.join((String(self._key.public_value), String(self._application))) def encode_agent_cert_private(self) -> bytes: """Encode U2F Ed25519 certificate private key data for agent""" return self.encode_ssh_private() def sign_ssh(self, data: bytes, sig_algorithm: bytes) -> bytes: """Compute an SSH-encoded signature of the specified data""" # pylint: disable=unused-argument if self._key_handle is None: raise ValueError('Key handle needed for signing') flags, counter, sig, _ = sk_sign(data, self._application, self._key_handle, self._flags) return String(sig) + Byte(flags) + UInt32(counter) def verify_ssh(self, data: bytes, sig_algorithm: bytes, packet: SSHPacket) -> bool: """Verify an SSH-encoded signature of the specified data""" # pylint: disable=unused-argument sig = packet.get_string() flags = packet.get_byte() counter = packet.get_uint32() packet.check_end() if self._touch_required and not flags & SSH_SK_USER_PRESENCE_REQD: return False return self._key.verify(self._app_hash + Byte(flags) + UInt32(counter) + sha256(data).digest(), sig) if ed25519_available: # pragma: no branch register_sk_alg(SSH_SK_ED25519, _SKEd25519Key) register_public_key_alg(b'sk-ssh-ed25519@openssh.com', _SKEd25519Key, True) register_certificate_alg(1, b'sk-ssh-ed25519@openssh.com', b'sk-ssh-ed25519-cert-v01@openssh.com', _SKEd25519Key, SSHOpenSSHCertificateV01, True) asyncssh-2.20.0/asyncssh/socks.py000066400000000000000000000166351475467777400170240ustar00rootroot00000000000000# Copyright (c) 2018-2023 by Ron Frederick and others. 
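# A minimal sketch of the blob that _SKEd25519Key.verify_ssh() above checks
# the Ed25519 signature against: SHA-256 of the application string, the flags
# byte, the 32-bit big-endian use counter, and the SHA-256 of the message,
# concatenated in that order.  The sample values are made up for illustration;
# only the layout is taken from the code above.

from hashlib import sha256
import struct

SSH_SK_USER_PRESENCE_REQD = 0x01

application = 'ssh:'
message = b'data covered by the SSH signature'
flags = SSH_SK_USER_PRESENCE_REQD     # user-presence bit reported by the token
counter = 42                          # signature counter reported by the token

app_hash = sha256(application.encode('utf-8')).digest()

signed_blob = (app_hash +
               struct.pack('B', flags) +
               struct.pack('>I', counter) +
               sha256(message).digest())

# An Ed25519 verifier would check the token's raw signature over signed_blob
print(len(signed_blob), signed_blob.hex())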
# # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation """SOCKS forwarding support""" from ipaddress import ip_address from typing import TYPE_CHECKING, Callable, Optional from .forward import SSHForwarderCoro, SSHLocalForwarder from .session import DataType if TYPE_CHECKING: # pylint: disable=cyclic-import from .connection import SSHConnection _RecvHandler = Optional[Callable[[bytes], None]] SOCKS4 = 0x04 SOCKS5 = 0x05 SOCKS_CONNECT = 0x01 SOCKS4_OK = 0x5a SOCKS5_OK = 0x00 SOCKS5_AUTH_NONE = 0x00 SOCKS5_ADDR_IPV4 = 0x01 SOCKS5_ADDR_HOSTNAME = 0x03 SOCKS5_ADDR_IPV6 = 0x04 SOCKS4_OK_RESPONSE = bytes((0, SOCKS4_OK, 0, 0, 0, 0, 0, 0)) SOCKS5_OK_RESPONSE_HDR = bytes((SOCKS5, SOCKS5_OK, 0)) _socks5_addr_len = { SOCKS5_ADDR_IPV4: 4, SOCKS5_ADDR_IPV6: 16 } class SSHSOCKSForwarder(SSHLocalForwarder): """SOCKS dynamic port forwarding connection handler""" def __init__(self, conn: 'SSHConnection', coro: SSHForwarderCoro): super().__init__(conn, coro) self._inpbuf = b'' self._bytes_needed = 2 self._recv_handler: _RecvHandler = self._recv_version self._addrtype = 0 self._host = '' self._port = 0 def _connect(self) -> None: """Send request to open a new tunnel connection""" assert self._transport is not None self._recv_handler = None orig_host, orig_port = self._transport.get_extra_info('peername')[:2] self.forward(self._host, self._port, orig_host, orig_port) def _send_socks4_ok(self) -> None: """Send SOCKS4 success response""" assert self._transport is not None self._transport.write(SOCKS4_OK_RESPONSE) def _send_socks5_ok(self) -> None: """Send SOCKS5 success response""" assert self._transport is not None addrlen = _socks5_addr_len[self._addrtype] + 2 self._transport.write(SOCKS5_OK_RESPONSE_HDR + bytes((self._addrtype,)) + addrlen * b'\0') def _recv_version(self, data: bytes) -> None: """Parse SOCKS version""" if data[0] == SOCKS4: if data[1] == SOCKS_CONNECT: self._bytes_needed = 6 self._recv_handler = self._recv_socks4_addr else: self.close() elif data[0] == SOCKS5: self._bytes_needed = data[1] self._recv_handler = self._recv_socks5_authlist else: self.close() def _recv_socks4_addr(self, data: bytes) -> None: """Parse SOCKSv4 address and port""" self._port = (data[0] << 8) + data[1] # If address is 0.0.0.x, read a hostname later if data[2:5] != b'\0\0\0' or data[5] == 0: self._host = str(ip_address(data[2:])) self._bytes_needed = -1 self._recv_handler = self._recv_socks4_user def _recv_socks4_user(self, data: bytes) -> None: """Parse SOCKSv4 username""" # pylint: disable=unused-argument if self._host: self._send_socks4_ok() self._connect() else: self._bytes_needed = -1 self._recv_handler = self._recv_socks4_hostname def _recv_socks4_hostname(self, data: bytes) -> None: """Parse SOCKSv4 hostname""" try: self._host = data.decode('utf-8') except UnicodeDecodeError: self.close() return self._send_socks4_ok() self._connect() def _recv_socks5_authlist(self, data: bytes) -> None: """Parse SOCKSv5 list of authentication methods""" 
assert self._transport is not None if SOCKS5_AUTH_NONE in data: self._transport.write(bytes((SOCKS5, SOCKS5_AUTH_NONE))) self._bytes_needed = 4 self._recv_handler = self._recv_socks5_command else: self.close() def _recv_socks5_command(self, data: bytes) -> None: """Parse SOCKSv5 command""" if data[0] == SOCKS5 and data[1] == SOCKS_CONNECT and data[2] == 0: if data[3] == SOCKS5_ADDR_HOSTNAME: self._bytes_needed = 1 self._recv_handler = self._recv_socks5_hostlen self._addrtype = SOCKS5_ADDR_IPV4 else: addrlen = _socks5_addr_len.get(data[3]) if addrlen: self._bytes_needed = addrlen self._recv_handler = self._recv_socks5_addr self._addrtype = data[3] else: self.close() else: self.close() def _recv_socks5_addr(self, data: bytes) -> None: """Parse SOCKSv5 address""" self._host = str(ip_address(data)) self._bytes_needed = 2 self._recv_handler = self._recv_socks5_port def _recv_socks5_hostlen(self, data: bytes) -> None: """Parse SOCKSv5 host length""" self._bytes_needed = data[0] self._recv_handler = self._recv_socks5_host def _recv_socks5_host(self, data: bytes) -> None: """Parse SOCKSv5 host""" try: self._host = data.decode('utf-8') except UnicodeDecodeError: self.close() return self._bytes_needed = 2 self._recv_handler = self._recv_socks5_port def _recv_socks5_port(self, data: bytes) -> None: """Parse SOCKSv5 port""" self._port = (data[0] << 8) + data[1] self._send_socks5_ok() self._connect() def data_received(self, data: bytes, datatype: DataType = None) -> None: """Handle incoming data from the SOCKS client""" if self._recv_handler: self._inpbuf += data while self._recv_handler: # type: ignore[truthy-function] if self._bytes_needed < 0: idx = self._inpbuf.find(b'\0') if idx >= 0: data = self._inpbuf[:idx] self._inpbuf = self._inpbuf[idx+1:] self._recv_handler(data) elif len(self._inpbuf) > 255: # SOCKSv4 user or hostname too long self.close() return else: return else: if len(self._inpbuf) >= self._bytes_needed: data = self._inpbuf[:self._bytes_needed] self._inpbuf = self._inpbuf[self._bytes_needed:] self._recv_handler(data) else: return data = self._inpbuf self._inpbuf = b'' if data: super().data_received(data, datatype) asyncssh-2.20.0/asyncssh/stream.py000066400000000000000000000735141475467777400171740ustar00rootroot00000000000000# Copyright (c) 2013-2024 by Ron Frederick and others. 
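# A byte-level sketch of the SOCKS5 CONNECT exchange that SSHSOCKSForwarder
# above parses when dynamic port forwarding is used.  This just builds the
# client-side messages for a hostname CONNECT; the host and port values are
# made up for illustration.

import struct

SOCKS5 = 0x05
SOCKS_CONNECT = 0x01
SOCKS5_AUTH_NONE = 0x00
SOCKS5_ADDR_HOSTNAME = 0x03

host = b'example.com'
port = 443

# 1) Greeting: version, number of auth methods, then the method list.
#    The forwarder only accepts "no authentication" (0x00).
greeting = bytes((SOCKS5, 1, SOCKS5_AUTH_NONE))

# 2) CONNECT request: version, command, reserved byte, address type,
#    hostname length, hostname, and big-endian port.
request = (bytes((SOCKS5, SOCKS_CONNECT, 0x00, SOCKS5_ADDR_HOSTNAME,
                  len(host))) + host + struct.pack('>H', port))

# 3) On success the forwarder replies with SOCKS5_OK_RESPONSE_HDR followed by
#    the address type and a zeroed bind address and port, then tunnels the
#    connection over an SSH direct-tcpip channel.
print(greeting.hex(), request.hex())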
# # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation """SSH stream handlers""" import asyncio import inspect import re from typing import TYPE_CHECKING, Any, AnyStr, AsyncIterator from typing import Callable, Dict, Generic, Iterable, List from typing import Optional, Pattern, Set, Tuple, Union, cast from .constants import EXTENDED_DATA_STDERR from .logging import SSHLogger from .misc import MaybeAwait, BreakReceived, SignalReceived from .misc import SoftEOFReceived, TerminalSizeChanged from .session import DataType, SSHClientSession, SSHServerSession from .session import SSHTCPSession, SSHUNIXSession, SSHTunTapSession from .sftp import SFTPServer, run_sftp_server from .scp import run_scp_server if TYPE_CHECKING: # pylint: disable=cyclic-import from .channel import SSHChannel from .connection import SSHConnection if TYPE_CHECKING: _WaiterFuture = asyncio.Future[None] else: _WaiterFuture = asyncio.Future _RecvBuf = List[Union[AnyStr, Exception]] _RecvBufMap = Dict[DataType, _RecvBuf[AnyStr]] _ReadLocks = Dict[DataType, asyncio.Lock] _ReadWaiters = Dict[DataType, Optional[_WaiterFuture]] _DrainWaiters = Dict[DataType, Set[_WaiterFuture]] SSHSocketSessionFactory = Callable[['SSHReader', 'SSHWriter'], MaybeAwait[None]] _OptSocketSessionFactory = Optional[SSHSocketSessionFactory] SSHServerSessionFactory = Callable[['SSHReader', 'SSHWriter', 'SSHWriter'], MaybeAwait[None]] _OptServerSessionFactory = Optional[SSHServerSessionFactory] SFTPServerFactory = Callable[['SSHChannel[bytes]'], SFTPServer] _OptSFTPServerFactory = Optional[SFTPServerFactory] _NEWLINE = object() class SSHReader(Generic[AnyStr]): """SSH read stream handler""" def __init__(self, session: 'SSHStreamSession[AnyStr]', chan: 'SSHChannel[AnyStr]', datatype: DataType = None): self._session: 'SSHStreamSession[AnyStr]' = session self._chan: 'SSHChannel[AnyStr]' = chan self._datatype = datatype async def __aiter__(self) -> AsyncIterator[AnyStr]: """Allow SSHReader to be an async iterator""" async for result in self._session.aiter(self._datatype): yield result @property def channel(self) -> 'SSHChannel[AnyStr]': """The SSH channel associated with this stream""" return self._chan @property def logger(self) -> SSHLogger: """The SSH logger associated with this stream""" return self._chan.logger def get_extra_info(self, name: str, default: Any = None) -> Any: """Return additional information about this stream This method returns extra information about the channel associated with this stream. See :meth:`get_extra_info() ` on :class:`SSHClientChannel` for additional information. """ return self._chan.get_extra_info(name, default) def feed_data(self, data: AnyStr) -> None: """Feed data to the associated session This method feeds data to the SSH session associated with this stream, providing compatibility with the :meth:`feed_data() ` method on :class:`asyncio.StreamReader`. This is mostly useful for testing. 
""" self._session.data_received(data, self._datatype) def feed_eof(self) -> None: """Feed EOF to the associated session This method feeds an end-of-file indication to the SSH session associated with this stream, providing compatibility with the :meth:`feed_eof() ` method on :class:`asyncio.StreamReader`. This is mostly useful for testing. """ self._session.eof_received() async def read(self, n: int = -1) -> AnyStr: """Read data from the stream This method is a coroutine which reads up to `n` bytes or characters from the stream. If `n` is not provided or set to `-1`, it reads until EOF or a signal is received. If EOF is received and the receive buffer is empty, an empty `bytes` or `str` object is returned. If the next data in the stream is a signal, the signal is delivered as a raised exception. .. note:: Unlike traditional `asyncio` stream readers, the data will be delivered as either `bytes` or a `str` depending on whether an encoding was specified when the underlying channel was opened. """ return await self._session.read(self._datatype, n, exact=False) async def readline(self) -> AnyStr: """Read one line from the stream This method is a coroutine which reads one line, ending in `'\\n'`. If EOF is received before `'\\n'` is found, the partial line is returned. If EOF is received and the receive buffer is empty, an empty `bytes` or `str` object is returned. If the next data in the stream is a signal, the signal is delivered as a raised exception. .. note:: In Python 3.5 and later, :class:`SSHReader` objects can also be used as async iterators, returning input data one line at a time. """ return await self._session.readline(self._datatype) async def readuntil(self, separator: object, max_separator_len = 0) -> AnyStr: """Read data from the stream until `separator` is seen This method is a coroutine which reads from the stream until the requested separator is seen. If a match is found, the returned data will include the separator at the end. The `separator` argument can be a single `bytes` or `str` value, a sequence of multiple `bytes` or `str` values, or a compiled regex (`re.Pattern`) to match against, returning data as soon as a matching separator is found in the stream. When passing a regex pattern as the separator, the `max_separator_len` argument should be set to the maximum length of an expected separator match. This can greatly improve performance, by minimizing how far back into the stream must be searched for a match. When passing literal separators to match against, the max separator length will be set automatically. .. note:: For best results, a separator regex should both begin and end with data which is as unique as possible, and should not start or end with optional or repeated elements. Otherwise, you run the risk of failing to match parts of a separator when it is split across multiple reads. If EOF or a signal is received before a match occurs, an :exc:`IncompleteReadError ` is raised and its `partial` attribute will contain the data in the stream prior to the EOF or signal. If the next data in the stream is a signal, the signal is delivered as a raised exception. """ return await self._session.readuntil(separator, self._datatype, max_separator_len) async def readexactly(self, n: int) -> AnyStr: """Read an exact amount of data from the stream This method is a coroutine which reads exactly n bytes or characters from the stream. 
If EOF or a signal is received in the stream before `n` bytes are read, an :exc:`IncompleteReadError ` is raised and its `partial` attribute will contain the data before the EOF or signal. If the next data in the stream is a signal, the signal is delivered as a raised exception. """ return await self._session.read(self._datatype, n, exact=True) def at_eof(self) -> bool: """Return whether the stream is at EOF This method returns `True` when EOF has been received and all data in the stream has been read. """ return self._session.at_eof(self._datatype) def get_redirect_info(self) -> Tuple['SSHStreamSession[AnyStr]', DataType]: """Get information needed to redirect from this SSHReader""" return self._session, self._datatype class SSHWriter(Generic[AnyStr]): """SSH write stream handler""" def __init__(self, session: 'SSHStreamSession[AnyStr]', chan: 'SSHChannel[AnyStr]', datatype: DataType = None): self._session: 'SSHStreamSession[AnyStr]' = session self._chan: 'SSHChannel[AnyStr]' = chan self._datatype = datatype @property def channel(self) -> 'SSHChannel[AnyStr]': """The SSH channel associated with this stream""" return self._chan @property def logger(self) -> SSHLogger: """The SSH logger associated with this stream""" return self._chan.logger def get_extra_info(self, name: str, default: Any = None) -> Any: """Return additional information about this stream This method returns extra information about the channel associated with this stream. See :meth:`get_extra_info() ` on :class:`SSHClientChannel` for additional information. """ return self._chan.get_extra_info(name, default) def can_write_eof(self) -> bool: """Return whether the stream supports :meth:`write_eof`""" return self._chan.can_write_eof() def close(self) -> None: """Close the channel .. note:: After this is called, no data can be read or written from any of the streams associated with this channel. """ return self._chan.close() def is_closing(self) -> bool: """Return if the stream is closing or is closed""" return self._chan.is_closing() async def wait_closed(self) -> None: """Wait until the stream is closed This should be called after :meth:`close` to wait until the underlying connection is closed. """ await self._chan.wait_closed() async def drain(self) -> None: """Wait until the write buffer on the channel is flushed This method is a coroutine which blocks the caller if the stream is currently paused for writing, returning when enough data has been sent on the channel to allow writing to resume. This can be used to avoid buffering an excessive amount of data in the channel's send buffer. """ await self._session.drain(self._datatype) def write(self, data: AnyStr) -> None: """Write data to the stream This method writes bytes or characters to the stream. .. note:: Unlike traditional `asyncio` stream writers, the data must be supplied as either `bytes` or a `str` depending on whether an encoding was specified when the underlying channel was opened. """ return self._chan.write(data, self._datatype) def writelines(self, list_of_data: Iterable[AnyStr]) -> None: """Write a collection of data to the stream""" return self._chan.writelines(list_of_data, self._datatype) def write_eof(self) -> None: """Write EOF on the channel This method sends an end-of-file indication on the channel, after which no more data can be written. .. note:: On an :class:`SSHServerChannel` where multiple output streams are created, writing EOF on one stream signals EOF for all of them, since it applies to the channel as a whole. 
""" return self._chan.write_eof() def get_redirect_info(self) -> Tuple['SSHStreamSession[AnyStr]', DataType]: """Get information needed to redirect to this SSHWriter""" return self._session, self._datatype class SSHStreamSession(Generic[AnyStr]): """SSH stream session handler""" def __init__(self) -> None: self._chan: Optional['SSHChannel[AnyStr]'] = None self._conn: Optional['SSHConnection'] = None self._encoding: Optional[str] = None self._errors = 'strict' self._loop: Optional[asyncio.AbstractEventLoop] = None self._limit = 0 self._exception: Optional[Exception] = None self._eof_received = False self._connection_lost = False self._read_paused = False self._write_paused = False self._recv_buf_len = 0 self._recv_buf: _RecvBufMap[AnyStr] = {None: []} self._read_locks: _ReadLocks = {None: asyncio.Lock()} self._read_waiters: _ReadWaiters = {None: None} self._drain_waiters: _DrainWaiters = {None: set()} async def aiter(self, datatype: DataType) -> AsyncIterator[AnyStr]: """Allow SSHReader to be an async iterator""" while not self.at_eof(datatype): yield await self.readline(datatype) async def _block_read(self, datatype: DataType) -> None: """Wait for more data to arrive on the stream""" try: assert self._loop is not None waiter: _WaiterFuture = self._loop.create_future() self._read_waiters[datatype] = waiter await waiter finally: self._read_waiters[datatype] = None def _unblock_read(self, datatype: DataType) -> None: """Signal that more data has arrived on the stream""" waiter = self._read_waiters[datatype] if waiter and not waiter.done(): waiter.set_result(None) def _should_block_drain(self, datatype: DataType) -> bool: """Return whether output is still being written to the channel""" # pylint: disable=unused-argument return self._write_paused and not self._connection_lost def _unblock_drain(self, datatype: DataType) -> None: """Signal that more data can be written on the stream""" if not self._should_block_drain(datatype): for waiter in self._drain_waiters[datatype]: if not waiter.done(): # pragma: no branch waiter.set_result(None) def _should_pause_reading(self) -> bool: """Return whether to pause reading from the channel""" return bool(self._limit) and self._recv_buf_len >= self._limit def _maybe_pause_reading(self) -> bool: """Pause reading if necessary""" if not self._read_paused and self._should_pause_reading(): assert self._chan is not None self._read_paused = True self._chan.pause_reading() return True else: return False def _maybe_resume_reading(self) -> bool: """Resume reading if necessary""" if self._read_paused and not self._should_pause_reading(): assert self._chan is not None self._read_paused = False self._chan.resume_reading() return True else: return False def connection_made(self, chan: 'SSHChannel[AnyStr]') -> None: """Handle a newly opened channel""" self._chan = chan self._conn = chan.get_connection() self._encoding, self._errors = chan.get_encoding() self._loop = chan.get_loop() self._limit = self._chan.get_recv_window() for datatype in chan.get_read_datatypes(): self._recv_buf[datatype] = [] self._read_locks[datatype] = asyncio.Lock() self._read_waiters[datatype] = None for datatype in chan.get_write_datatypes(): self._drain_waiters[datatype] = set() def connection_lost(self, exc: Optional[Exception]) -> None: """Handle an incoming channel close""" self._connection_lost = True self._exception = exc if not self._eof_received: if exc: for datatype in self._read_waiters: self._recv_buf[datatype].append(exc) self.eof_received() for datatype in self._drain_waiters: 
self._unblock_drain(datatype) def data_received(self, data: AnyStr, datatype: DataType) -> None: """Handle incoming data on the channel""" self._recv_buf[datatype].append(data) self._recv_buf_len += len(data) self._unblock_read(datatype) self._maybe_pause_reading() def eof_received(self) -> bool: """Handle an incoming end of file on the channel""" self._eof_received = True for datatype in self._read_waiters: self._unblock_read(datatype) return True def at_eof(self, datatype: DataType) -> bool: """Return whether end of file has been received on the channel""" return self._eof_received and not self._recv_buf[datatype] def pause_writing(self) -> None: """Handle a request to pause writing on the channel""" self._write_paused = True def resume_writing(self) -> None: """Handle a request to resume writing on the channel""" self._write_paused = False for datatype in self._drain_waiters: self._unblock_drain(datatype) async def read(self, datatype: DataType, n: int, exact: bool) -> AnyStr: """Read data from the channel""" recv_buf = self._recv_buf[datatype] data: List[AnyStr] = [] break_read = False async with self._read_locks[datatype]: while True: while recv_buf and n != 0: if isinstance(recv_buf[0], Exception): if data: break_read = True break else: exc = cast(Exception, recv_buf.pop(0)) if isinstance(exc, SoftEOFReceived): n = 0 break else: raise exc l = len(recv_buf[0]) if l > n > 0: data.append(recv_buf[0][:n]) recv_buf[0] = recv_buf[0][n:] self._recv_buf_len -= n n = 0 break data.append(cast(AnyStr, recv_buf.pop(0))) self._recv_buf_len -= l n -= l if self._maybe_resume_reading(): continue if n == 0 or (n > 0 and data and not exact) or \ (n < 0 and recv_buf) or \ self._eof_received or break_read: break await self._block_read(datatype) result = cast(AnyStr, '' if self._encoding else b'').join(data) if n > 0 and exact: raise asyncio.IncompleteReadError(cast(bytes, result), len(result) + n) return result async def readline(self, datatype: DataType) -> AnyStr: """Read one line from the stream""" try: return await self.readuntil(_NEWLINE, datatype) except asyncio.IncompleteReadError as exc: return cast(AnyStr, exc.partial) async def readuntil(self, separator: object, datatype: DataType, max_separator_len = 0) -> AnyStr: """Read data from the channel until a separator is seen""" if not separator: raise ValueError('Separator cannot be empty') buf = cast(AnyStr, '' if self._encoding else b'') recv_buf = self._recv_buf[datatype] if separator is _NEWLINE: seplen = 1 separators = cast(AnyStr, '\n' if self._encoding else b'\n') pat = re.compile(separators) elif isinstance(separator, (bytes, str)): seplen = len(separator) pat = re.compile(re.escape(cast(AnyStr, separator))) elif isinstance(separator, Pattern): seplen = max_separator_len pat = cast(Pattern[AnyStr], separator) else: bar = cast(AnyStr, '|' if self._encoding else b'|') seplist = list(cast(Iterable[AnyStr], separator)) seplen = max(len(sep) for sep in seplist) separators = bar.join(re.escape(sep) for sep in seplist) pat = re.compile(separators) curbuf = 0 buflen = 0 async with self._read_locks[datatype]: while True: while curbuf < len(recv_buf): if isinstance(recv_buf[curbuf], Exception): if buf: recv_buf[:curbuf] = [] self._recv_buf_len -= buflen raise asyncio.IncompleteReadError( cast(bytes, buf), None) else: exc = recv_buf.pop(0) if isinstance(exc, SoftEOFReceived): return buf else: raise cast(Exception, exc) newbuf = cast(AnyStr, recv_buf[curbuf]) buf += newbuf start = 0 if seplen == 0 else max(buflen + 1 - seplen, 0) match = 
pat.search(buf, start) if match: idx = match.end() recv_buf[:curbuf] = [] recv_buf[0] = buf[idx:] buf = buf[:idx] self._recv_buf_len -= idx if not recv_buf[0]: recv_buf.pop(0) self._maybe_resume_reading() return buf buflen += len(newbuf) curbuf += 1 if self._read_paused or self._eof_received: recv_buf[:curbuf] = [] self._recv_buf_len -= buflen self._maybe_resume_reading() raise asyncio.IncompleteReadError(cast(bytes, buf), None) await self._block_read(datatype) async def drain(self, datatype: DataType) -> None: """Wait for data written to the channel to drain""" while self._should_block_drain(datatype): try: assert self._loop is not None waiter: _WaiterFuture = self._loop.create_future() self._drain_waiters[datatype].add(waiter) await waiter finally: self._drain_waiters[datatype].remove(waiter) if self._connection_lost: exc = self._exception if not exc and self._write_paused: exc = BrokenPipeError() if exc: raise exc class SSHClientStreamSession(SSHStreamSession[AnyStr], SSHClientSession[AnyStr]): """SSH client stream session handler""" class SSHServerStreamSession(SSHStreamSession[AnyStr], SSHServerSession[AnyStr]): """SSH server stream session handler""" def __init__(self, session_factory: _OptServerSessionFactory, sftp_factory: _OptSFTPServerFactory = None, sftp_version = 0, allow_scp = False): super().__init__() self._session_factory = session_factory self._sftp_factory = sftp_factory self._sftp_version = sftp_version self._allow_scp = allow_scp and bool(sftp_factory) def _init_sftp_server(self) -> SFTPServer: """Initialize an SFTP server for this stream to use""" assert self._chan is not None self._chan.set_encoding(None) self._encoding = None assert self._sftp_factory is not None return self._sftp_factory(cast('SSHChannel[bytes]', self._chan)) def shell_requested(self) -> bool: """Return whether a shell can be requested""" return bool(self._session_factory) def exec_requested(self, command: str) -> bool: """Return whether execution of a command can be requested""" # Avoid incorrect pylint suggestion to use ternary # pylint: disable=consider-using-ternary return ((self._allow_scp and command.startswith('scp ')) or bool(self._session_factory)) def subsystem_requested(self, subsystem: str) -> bool: """Return whether starting a subsystem can be requested""" if subsystem == 'sftp': return bool(self._sftp_factory) else: return bool(self._session_factory) def session_started(self) -> None: """Start a session for this newly opened server channel""" assert self._chan is not None command = self._chan.get_command() stdin = SSHReader[AnyStr](self, self._chan) stdout = SSHWriter[AnyStr](self, self._chan) stderr = SSHWriter[AnyStr](self, self._chan, EXTENDED_DATA_STDERR) handler: MaybeAwait[None] if self._chan.get_subsystem() == 'sftp': stdin_bytes = cast(SSHReader[bytes], stdin) stdout_bytes = cast(SSHWriter[bytes], stdout) handler = run_sftp_server(self._init_sftp_server(), stdin_bytes, stdout_bytes, self._sftp_version) elif self._allow_scp and command and command.startswith('scp '): stdin_bytes = cast(SSHReader[bytes], stdin) stdout_bytes = cast(SSHWriter[bytes], stdout) stderr_bytes = cast(SSHWriter[bytes], stderr) handler = run_scp_server(self._init_sftp_server(), command, stdin_bytes, stdout_bytes, stderr_bytes) else: assert self._session_factory is not None handler = self._session_factory(stdin, stdout, stderr) if inspect.isawaitable(handler): assert self._conn is not None assert handler is not None self._conn.create_task(handler, stdin.logger) def exception_received(self, exc: Exception) 
-> None: """Handle an incoming exception on the channel""" self._recv_buf[None].append(exc) self._unblock_read(None) def break_received(self, msec: int) -> bool: """Handle an incoming break on the channel""" self.exception_received(BreakReceived(msec)) return True def signal_received(self, signal: str) -> None: """Handle an incoming signal on the channel""" self.exception_received(SignalReceived(signal)) def soft_eof_received(self) -> None: """Handle an incoming soft EOF on the channel""" self.exception_received(SoftEOFReceived()) def terminal_size_changed(self, width: int, height: int, pixwidth: int, pixheight: int) -> None: """Handle an incoming terminal size change on the channel""" self.exception_received(TerminalSizeChanged(width, height, pixwidth, pixheight)) class SSHSocketStreamSession(SSHStreamSession[AnyStr]): """Socket stream session handler""" def __init__(self, session_factory: _OptSocketSessionFactory = None): super().__init__() self._session_factory = session_factory def session_started(self) -> None: """Start a session for this newly opened socket channel""" if self._session_factory: assert self._chan is not None reader = SSHReader[AnyStr](self, self._chan) writer = SSHWriter[AnyStr](self, self._chan) handler = self._session_factory(reader, writer) if inspect.isawaitable(handler): assert self._conn is not None assert handler is not None self._conn.create_task(handler, reader.logger) class SSHTCPStreamSession(SSHSocketStreamSession[AnyStr], SSHTCPSession[AnyStr]): """TCP stream session handler""" class SSHUNIXStreamSession(SSHSocketStreamSession[AnyStr], SSHUNIXSession[AnyStr]): """UNIX stream session handler""" class SSHTunTapStreamSession(SSHSocketStreamSession[bytes], SSHTunTapSession): """TUN/TAP stream session handler""" async def aiter(self, datatype: DataType) -> AsyncIterator[bytes]: """Allow SSHReader to be an async iterator""" while True: packet = await self.read(datatype) if packet: yield packet else: break async def read(self, datatype: DataType, n: int = -1, exact: bool = False) -> bytes: """Override read to preserve TUN/TAP packet boundaries""" recv_buf = self._recv_buf[datatype] while not self._eof_received: if recv_buf: data = cast(bytes, recv_buf.pop(0)) self._recv_buf_len -= len(data) self._maybe_resume_reading() return data else: await self._block_read(datatype) return b'' asyncssh-2.20.0/asyncssh/subprocess.py000066400000000000000000000252121475467777400200610ustar00rootroot00000000000000# Copyright (c) 2019-2023 by Ron Frederick and others. 
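# A hedged usage sketch for the SSHReader/SSHWriter stream API defined above,
# showing readuntil() with a compiled regex and max_separator_len as described
# in its docstring.  The host name and shell prompt pattern are placeholders;
# this assumes conn.open_session() returns (stdin, stdout, stderr) streams as
# in the asyncssh documentation.

import asyncio
import re

import asyncssh

PROMPT = re.compile(r'\$ ')       # hypothetical shell prompt to wait for


async def run_interactive() -> None:
    async with asyncssh.connect('example.com') as conn:
        stdin, stdout, stderr = await conn.open_session(term_type='xterm')

        # Wait for the prompt; max_separator_len bounds how far back the
        # session has to re-scan buffered data for a regex match.
        print(await stdout.readuntil(PROMPT, max_separator_len=2), end='')

        stdin.write('echo hello\n')
        print(await stdout.readuntil(PROMPT, max_separator_len=2), end='')

        stdin.write_eof()
        await stdout.channel.wait_closed()


if __name__ == '__main__':
    asyncio.run(run_interactive())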
# # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation """SSH subprocess handlers""" from typing import TYPE_CHECKING, Any, AnyStr, Callable from typing import Dict, Generic, Iterable, Optional from .constants import EXTENDED_DATA_STDERR from .process import SSHClientProcess from .session import DataType if TYPE_CHECKING: # pylint: disable=cyclic-import from .channel import SSHChannel, SSHClientChannel SubprocessFactory = Callable[[], 'SSHSubprocessProtocol'] class SSHSubprocessPipe(Generic[AnyStr]): """SSH subprocess pipe""" def __init__(self, chan: 'SSHClientChannel[AnyStr]', datatype: DataType = None): self._chan: 'SSHClientChannel[AnyStr]' = chan self._datatype = datatype def close(self) -> None: """Shut down the remote process""" self._chan.close() def get_extra_info(self, name: str, default: Any = None) -> Any: """Return additional information about the remote process This method returns extra information about the channel associated with this subprocess. See :meth:`get_extra_info() ` on :class:`SSHClientChannel` for additional information. """ return self._chan.get_extra_info(name, default) class SSHSubprocessReadPipe(SSHSubprocessPipe[AnyStr]): """SSH subprocess pipe reader""" def pause_reading(self) -> None: """Pause delivery of incoming data from the remote process""" self._chan.pause_reading() def resume_reading(self) -> None: """Resume delivery of incoming data from the remote process""" self._chan.resume_reading() class SSHSubprocessWritePipe(SSHSubprocessPipe[AnyStr]): """SSH subprocess pipe writer""" def abort(self) -> None: """Forcibly close the channel to the remote process""" self._chan.abort() def can_write_eof(self) -> bool: """Return whether the pipe supports :meth:`write_eof`""" return self._chan.can_write_eof() def get_write_buffer_size(self) -> int: """Return the current size of the pipe's output buffer""" return self._chan.get_write_buffer_size() def set_write_buffer_limits(self, high: Optional[int] = None, low: Optional[int] = None) -> None: """Set the high- and low-water limits for write flow control""" self._chan.set_write_buffer_limits(high, low) def write(self, data: AnyStr) -> None: """Write data on this pipe""" self._chan.write(data, self._datatype) def writelines(self, list_of_data: Iterable[AnyStr]) -> None: """Write a list of data bytes on this pipe""" self._chan.writelines(list_of_data, self._datatype) def write_eof(self) -> None: """Write EOF on this pipe""" self._chan.write_eof() class SSHSubprocessProtocol(Generic[AnyStr]): """SSH subprocess protocol This class conforms to :class:`asyncio.SubprocessProtocol`, but with the following enhancement: * If encoding is set when the subprocess is created, all data passed to :meth:`pipe_data_received` will be string values containing Unicode data. However, for compatibility with :class:`asyncio.SubprocessProtocol`, encoding defaults to `None`, in which case all data is delivered as bytes. 
""" def connection_made(self, transport: 'SSHSubprocessTransport[AnyStr]') -> None: """Called when a remote process is successfully started This method is called when a remote process is successfully started. The transport parameter should be stored if needed for later use. :param transport: The transport to use to communicate with the remote process. :type transport: :class:`SSHSubprocessTransport` """ def pipe_data_received(self, fd: int, data: AnyStr) -> None: """Called when data is received from the remote process This method is called when data is received from the remote process. If an encoding was specified when the process was started, the data will be delivered as a string after decoding with the requested encoding. Otherwise, the data will be delivered as bytes. :param fd: The integer file descriptor of the pipe data was received on. This will be 1 for stdout or 2 for stderr. :param data: The data received from the remote process :type fd: `int` :type data: `str` or `bytes` """ def pipe_connection_lost(self, fd: int, exc: Optional[Exception]) -> None: """Called when the pipe to a remote process is closed This method is called when a pipe to a remote process is closed. If the channel is shut down cleanly, *exc* will be `None`. Otherwise, it will be an exception explaining the reason the pipe was closed. :param fd: The integer file descriptor of the pipe which was closed. This will be 1 for stdout or 2 for stderr. :param exc: The exception which caused the channel to close, or `None` if the channel closed cleanly. :type fd: `int` :type exc: :class:`Exception` or `None` """ def process_exited(self) -> None: """Called when a remote process has exited This method is called when the remote process has exited. Exit status information can be retrieved by calling :meth:`get_returncode() ` on the transport provided in :meth:`connection_made`. """ class SSHSubprocessTransport(SSHClientProcess[AnyStr]): """SSH subprocess transport This class conforms to :class:`asyncio.SubprocessTransport`, but with the following enhancements: * All functionality available through :class:`SSHClientProcess` is also available here, such as the ability to dynamically redirect stdin, stdout, and stderr at any time during the lifetime of the process. * If encoding is set when the subprocess is created, all data written to the transports created by :meth:`get_pipe_transport` should be strings containing Unicode data. The encoding defaults to `None`, though, to preserve compatibility with :class:`asyncio.SubprocessTransport`, which expects data to be written as bytes. 
""" _chan: 'SSHClientChannel[AnyStr]' def __init__(self, protocol_factory: SubprocessFactory): super().__init__() self._pipes: Dict[int, SSHSubprocessPipe[AnyStr]] = {} self._protocol: SSHSubprocessProtocol[AnyStr] = protocol_factory() def get_protocol(self) -> SSHSubprocessProtocol[AnyStr]: """Return the subprocess protocol associated with this transport""" return self._protocol def connection_made(self, chan: 'SSHChannel[AnyStr]') -> None: """Handle a newly opened channel""" super().connection_made(chan) self._protocol.connection_made(self) self._pipes = { 0: SSHSubprocessWritePipe(self._chan), 1: SSHSubprocessReadPipe(self._chan), 2: SSHSubprocessReadPipe(self._chan, EXTENDED_DATA_STDERR) } def session_started(self) -> None: """Override SSHClientProcess to avoid creating SSHReader/SSHWriter streams, since this class uses read/write pipe objects instead""" def connection_lost(self, exc: Optional[Exception]) -> None: """Handle an incoming channel close""" self._protocol.pipe_connection_lost(1, exc) self._protocol.pipe_connection_lost(2, exc) super().connection_lost(exc) def data_received(self, data: AnyStr, datatype: DataType) -> None: """Handle incoming data from the remote process""" writer = self._writers.get(datatype) if writer: writer.write(data) else: fd = 2 if datatype == EXTENDED_DATA_STDERR else 1 self._protocol.pipe_data_received(fd, data) def exit_status_received(self, status: int) -> None: """Handle exit status for the remote process""" super().exit_status_received(status) self._protocol.process_exited() def exit_signal_received(self, signal: str, core_dumped: bool, msg: str, lang: str) -> None: """Handle exit signal for the remote process""" super().exit_signal_received(signal, core_dumped, msg, lang) self._protocol.process_exited() def get_pid(self) -> Optional[int]: """Return the PID of the remote process This method always returns `None`, since SSH doesn't report remote PIDs. """ # pylint: disable=no-self-use return None def get_pipe_transport(self, fd: int) -> \ Optional[SSHSubprocessPipe[AnyStr]]: """Return a transport for the requested stream :param fd: The integer file descriptor (0-2) to return the transport for, where 0 means stdin, 1 means stdout, and 2 means stderr. :type fd: `int` :returns: an :class:`SSHSubprocessReadPipe` or :class:`SSHSubprocessWritePipe` """ return self._pipes.get(fd) def get_returncode(self) -> Optional[int]: """Return the exit status or signal for the remote process This method returns the exit status of the session if one has been sent. If an exit signal was sent, this method returns the negative of the numeric value of that signal, matching the behavior of :meth:`asyncio.SubprocessTransport.get_returncode`. If neither has been sent, this method returns `None`. :returns: `int` or `None` """ return self.returncode asyncssh-2.20.0/asyncssh/tuntap.py000066400000000000000000000276341475467777400172160ustar00rootroot00000000000000# Copyright (c) 2024 by Ron Frederick and others. 
# # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation """SSH TUN/TAP forwarding support""" import asyncio import errno import os import socket import struct import sys import threading from typing import Callable, Optional, Tuple, cast if sys.platform != 'win32': # pragma: no branch import fcntl SSH_TUN_MODE_POINTTOPOINT = 1 # layer 3 IP packets SSH_TUN_MODE_ETHERNET = 2 # layer 2 Ethenet frames SSH_TUN_UNIT_ANY = 0x7fffffff # The server may choose the unit SSH_TUN_AF_INET = 2 # IPv4 SSH_TUN_AF_INET6 = 24 # IPv6 DARWIN_CTLIOCGINFO = 0xc0644e03 DARWIN_CTLIOCGINFO_FMT = 'I96s' DARWIN_SIOCGIFFLAGS = 0xc0206911 DARWIN_SIOCSIFFLAGS = 0x80206910 LINUX_TUNSETIFF = 0x400454ca LINUX_IFF_TUN = 0x1 LINUX_IFF_TAP = 0x2 LINUX_IFF_NO_PI = 0x1000 IFF_FMT = '16sH' IFF_UP = 0x1 class SSHTunTapTransport(asyncio.Transport): """Layer 2/3 tunnel transport""" def __init__(self, loop: asyncio.AbstractEventLoop, interface: str): super().__init__(extra={'interface': interface}) self._loop = loop self._protocol: Optional[asyncio.Protocol] = None def get_protocol(self) -> asyncio.BaseProtocol: # pragma: no cover """Get protocol object associated with transport""" assert self._protocol is not None return self._protocol def set_protocol(self, protocol: asyncio.BaseProtocol) -> None: """Set protocol associated with transport""" self._protocol = cast(asyncio.Protocol, protocol) def abort(self) -> None: # pragma: no cover """Abort this transport""" self.close() def is_reading(self) -> bool: """Return if the transport is reading data""" raise NotImplementedError def pause_reading(self) -> None: """Pause reading""" raise NotImplementedError def resume_reading(self) -> None: """Resume reading""" raise NotImplementedError def can_write_eof(self) -> bool: # pragma: no cover """This transport doesn't support writing EOF""" return False def get_write_buffer_size(self) -> int: # pragma: no cover """This transport has no output buffer""" return 0 def get_write_buffer_limits(self) -> Tuple[int, int]: # pragma: no cover """This transport doesn't support write buffer limits""" return 0, 0 def set_write_buffer_limits(self, high: Optional[int] = None, low: Optional[int] = None) -> None: """This transport doesn't support write buffer limits""" def write_eof(self) -> None: """Ignore writing EOF on this transport""" def write(self, data: bytes) -> None: """Write a packet""" raise NotImplementedError def is_closing(self) -> bool: # pragma: no cover """Return if the transport is closing""" return False def close(self) -> None: """Close this transport""" raise NotImplementedError class SSHTunTapOSXTransport(SSHTunTapTransport): """TunTapOSX transport""" def __init__(self, loop: asyncio.AbstractEventLoop, mode: int, unit: Optional[int]): prefix = 'tun' if mode == SSH_TUN_MODE_POINTTOPOINT else 'tap' if unit is None: for i in range(16): try: file = open(f'/dev/{prefix}{i}', 'rb+', buffering=0) except OSError: pass else: unit = i break else: raise OSError(errno.EBUSY, f'No 
{prefix} devices available') else: file = open(f'/dev/{prefix}{unit}', 'rb+', buffering=0) interface = f'{prefix}{unit}' name = interface.encode() sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) try: ifr = struct.pack(IFF_FMT, name, 0) ifr = fcntl.ioctl(sock, DARWIN_SIOCGIFFLAGS, ifr) _, flags = struct.unpack(IFF_FMT, ifr) flags |= IFF_UP ifr = struct.pack(IFF_FMT, name, flags) fcntl.ioctl(sock, DARWIN_SIOCSIFFLAGS, ifr) finally: sock.close() super().__init__(loop, interface) self._file = file self._read_thread: Optional[threading.Thread] = None os.set_blocking(file.fileno(), True) def is_reading(self) -> bool: """Return if the transport is reading data""" return self._read_thread is not None # pragma: no cover def pause_reading(self) -> None: """Pause reading""" if self._read_thread: # pragma: no branch self._read_thread.join() self._read_thread = None def resume_reading(self) -> None: """Resume reading""" if not self._read_thread: # pragma: no branch self._read_thread = threading.Thread(target=self._read_loop) self._read_thread.daemon = True self._read_thread.start() def _read_loop(self) -> None: """Loop reading packets until read is paused or done""" assert self._protocol is not None while True: try: data = self._file.read(65536) except OSError as exc: if exc.errno != errno.EBADF: # pragma: no cover self._loop.call_soon_threadsafe( self._protocol.connection_lost, exc) break else: self._loop.call_soon_threadsafe( self._protocol.data_received, data) def write(self, data: bytes) -> None: """Write a packet""" self._file.write(data) def close(self) -> None: """Close this transport""" self._file.close() self.pause_reading() class SSHDarwinUTunTransport(SSHTunTapTransport): """Darwin UTun transport""" def __init__(self, loop: asyncio.AbstractEventLoop, unit: Optional[int]): sock = socket.socket(socket.PF_SYSTEM, socket.SOCK_DGRAM, socket.SYSPROTO_CONTROL) try: arg = struct.pack(DARWIN_CTLIOCGINFO_FMT, 0, b'com.apple.net.utun_control') ctl_info = fcntl.ioctl(sock, DARWIN_CTLIOCGINFO, arg) ctl_id, _ = struct.unpack(DARWIN_CTLIOCGINFO_FMT, ctl_info) unit = 0 if unit is None else unit - 15 sock.setblocking(False) sock.connect((ctl_id, unit)) _, unit = sock.getpeername() except OSError: sock.close() raise unit: int super().__init__(loop, f'utun{unit-1}') self._sock = sock self._reading = False def is_reading(self) -> bool: # pragma: no cover """Return if the transport is reading data""" return self._reading def pause_reading(self) -> None: """Pause reading""" self._reading = False self._loop.remove_reader(self._sock) def resume_reading(self) -> None: """Resume reading""" self._reading = True self._loop.add_reader(self._sock, self._read_ready) def _read_ready(self) -> None: """Read available packets from the transport""" assert self._protocol is not None while True: try: data = self._sock.recv(65540)[4:] except (BlockingIOError, InterruptedError): break except OSError as exc: # pragma: no cover self._protocol.connection_lost(exc) break else: self._protocol.data_received(data) def write(self, data: bytes) -> None: """Write a packet""" version = data[0] >> 4 family = socket.AF_INET if version == 4 else socket.AF_INET6 data = family.to_bytes(4, 'big') + data self._sock.send(data) def close(self) -> None: """Close this transport""" self._sock.close() self.pause_reading() class SSHLinuxTunTapTransport(SSHTunTapTransport): """Linux TUN/TAP transport""" def __init__(self, loop: asyncio.AbstractEventLoop, mode: int, unit: Optional[int]): file = open('/dev/net/tun', 'rb+', buffering=0) if mode == 
SSH_TUN_MODE_POINTTOPOINT: flags = LINUX_IFF_TUN | LINUX_IFF_NO_PI prefix = 'tun' else: flags = LINUX_IFF_TAP | LINUX_IFF_NO_PI prefix = 'tap' name = b'' if unit is None else f'{prefix}{unit}'.encode() ifr = struct.pack(IFF_FMT, name, flags) try: ifr = fcntl.ioctl(file, LINUX_TUNSETIFF, ifr) except OSError: file.close() raise name, _ = struct.unpack(IFF_FMT, ifr) interface = name.strip(b'\0').decode() super().__init__(loop, interface) self._file = file self._reading = False os.set_blocking(file.fileno(), False) def is_reading(self) -> bool: # pragma: no cover """Return if the transport is reading data""" return self._reading def pause_reading(self) -> None: """Pause reading""" self._reading = False try: self._loop.remove_reader(self._file) except OSError: # pragma: no cover pass def resume_reading(self) -> None: """Resume reading""" self._reading = True self._loop.add_reader(self._file, self._read_ready) def _read_ready(self) -> None: """Read available packets from the transport""" assert self._protocol is not None while True: try: data = self._file.read(65536) except OSError as exc: # pragma: no cover self._protocol.connection_lost(exc) break else: if data is None: break self._protocol.data_received(data) def write(self, data: bytes) -> None: """Write a packet""" self._file.write(data) def close(self) -> None: """Close this transport""" self._file.close() self.pause_reading() def create_tuntap(protocol_factory: Callable[[], asyncio.BaseProtocol], mode: int, unit: Optional[int]) -> \ Tuple[SSHTunTapTransport, asyncio.BaseProtocol]: """Create a local TUN or TAP network interface""" loop = asyncio.get_event_loop() transport: Optional[SSHTunTapTransport] = None if sys.platform == 'darwin': if unit is None: try: transport = SSHTunTapOSXTransport(loop, mode, unit) except OSError: if mode == SSH_TUN_MODE_POINTTOPOINT: transport = SSHDarwinUTunTransport(loop, unit) else: raise elif mode == SSH_TUN_MODE_POINTTOPOINT and unit >= 16: transport = SSHDarwinUTunTransport(loop, unit) else: transport = SSHTunTapOSXTransport(loop, mode, unit) elif sys.platform == 'linux': transport = SSHLinuxTunTapTransport(loop, mode, unit) else: raise OSError(errno.EPROTONOSUPPORT, f'TunTap not supported on {sys.platform}') assert transport is not None protocol = protocol_factory() protocol.connection_made(transport) transport.set_protocol(protocol) transport.resume_reading() return transport, protocol asyncssh-2.20.0/asyncssh/version.py000066400000000000000000000016031475467777400173540ustar00rootroot00000000000000# Copyright (c) 2013-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation """AsyncSSH version information""" __author__ = 'Ron Frederick' __author_email__ = 'ronf@timeheart.net' __url__ = 'http://asyncssh.timeheart.net' __version__ = '2.20.0' asyncssh-2.20.0/asyncssh/x11.py000066400000000000000000000411551475467777400163060ustar00rootroot00000000000000# Copyright (c) 2016-2024 by Ron Frederick and others. 
# # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation """X11 forwarding support""" import asyncio import os from pathlib import Path import socket import time from typing import TYPE_CHECKING, Callable, Dict, Iterable from typing import NamedTuple, Optional, Sequence, Set, Tuple from .constants import OPEN_CONNECT_FAILED from .forward import SSHForwarder, SSHForwarderCoro from .listener import SSHListener, create_tcp_forward_listener from .logging import logger from .misc import ChannelOpenError from .session import DataType if TYPE_CHECKING: # pylint: disable=cyclic-import from .channel import SSHChannel from .connection import SSHServerConnection _RecvHandler = Optional[Callable[[bytes], None]] # Xauth address families XAUTH_FAMILY_IPV4 = 0 XAUTH_FAMILY_DECNET = 1 XAUTH_FAMILY_IPV6 = 6 XAUTH_FAMILY_HOSTNAME = 256 XAUTH_FAMILY_WILD = 65535 # Xauth protocol values XAUTH_PROTO_COOKIE = b'MIT-MAGIC-COOKIE-1' XAUTH_COOKIE_LEN = 16 # Xauth lock information XAUTH_LOCK_SUFFIX = '-c' XAUTH_LOCK_TRIES = 5 XAUTH_LOCK_DELAY = 0.2 XAUTH_LOCK_DEAD = 5 # X11 display and port numbers X11_BASE_PORT = 6000 X11_DISPLAY_START = 10 X11_MAX_DISPLAYS = 64 # Host to listen on when doing X11 forwarding X11_LISTEN_HOST = 'localhost' def _parse_display(display: str) -> Tuple[str, str, int]: """Parse an X11 display value""" try: host, dpynum = display.rsplit(':', 1) if host.startswith('[') and host.endswith(']'): host = host[1:-1] idx = dpynum.find('.') if idx >= 0: screen = int(dpynum[idx+1:]) dpynum = dpynum[:idx] else: screen = 0 except (ValueError, UnicodeEncodeError): raise ValueError('Invalid X11 display') from None return host, dpynum, screen async def _lookup_host(loop: asyncio.AbstractEventLoop, host: str, family: int) -> Sequence[str]: """Look up IPv4 or IPv6 addresses of a host name""" try: addrinfo = await loop.getaddrinfo(host, 0, family=family, type=socket.SOCK_STREAM) except socket.gaierror: return [] return [ai[4][0] for ai in addrinfo] class SSHXAuthorityEntry(NamedTuple): """An entry in an Xauthority file""" family: int addr: bytes dpynum: bytes proto: bytes data: bytes def __bytes__(self) -> bytes: """Construct an Xauthority entry""" def _uint16(value: int) -> bytes: """Construct a big-endian 16-bit unsigned integer""" return value.to_bytes(2, 'big') def _string(data: bytes) -> bytes: """Construct a binary string with a 16-bit length""" return _uint16(len(data)) + data return b''.join((_uint16(self.family), _string(self.addr), _string(self.dpynum), _string(self.proto), _string(self.data))) class SSHX11ClientForwarder(SSHForwarder): """X11 forwarding connection handler""" def __init__(self, listener: 'SSHX11ClientListener', peer: SSHForwarder): super().__init__(peer) self._listener = listener self._inpbuf = b'' self._bytes_needed = 12 self._recv_handler: _RecvHandler = self._recv_prefix self._endian = b'' self._prefix = b'' self._auth_proto_len = 0 self._auth_data_len = 0 self._auth_proto = b'' self._auth_proto_pad = b'' 
self._auth_data = b'' self._auth_data_pad = b'' def _encode_uint16(self, value: int) -> bytes: """Encode a 16-bit unsigned integer""" if self._endian == b'B': return bytes((value >> 8, value & 255)) else: return bytes((value & 255, value >> 8)) def _decode_uint16(self, value: bytes) -> int: """Decode a 16-bit unsigned integer""" if self._endian == b'B': return (value[0] << 8) + value[1] else: return (value[1] << 8) + value[0] @staticmethod def _padded_len(length: int) -> int: """Return length rounded up to the next multiple of 4 bytes""" return ((length + 3) // 4) * 4 @staticmethod def _pad(data: bytes) -> bytes: """Pad a string to a multiple of 4 bytes""" length = len(data) % 4 return data + ((4 - length) * b'\00' if length else b'') def _recv_prefix(self, data: bytes) -> None: """Parse X11 client prefix""" self._endian = data[:1] self._prefix = data self._auth_proto_len = self._decode_uint16(data[6:8]) self._auth_data_len = self._decode_uint16(data[8:10]) self._recv_handler = self._recv_auth_proto self._bytes_needed = self._padded_len(self._auth_proto_len) def _recv_auth_proto(self, data: bytes) -> None: """Extract X11 auth protocol""" self._auth_proto = data[:self._auth_proto_len] self._auth_proto_pad = data[self._auth_proto_len:] self._recv_handler = self._recv_auth_data self._bytes_needed = self._padded_len(self._auth_data_len) def _recv_auth_data(self, data: bytes) -> None: """Extract X11 auth data""" self._auth_data = data[:self._auth_data_len] self._auth_data_pad = data[self._auth_data_len:] try: self._auth_data = self._listener.validate_auth(self._auth_data) except KeyError: reason = b'Invalid authentication key\n' response = b''.join((bytes((0, len(reason))), self._encode_uint16(11), self._encode_uint16(0), self._encode_uint16((len(reason) + 3) // 4), self._pad(reason))) try: self.write(response) self.write_eof() except OSError: # pragma: no cover pass self._inpbuf = b'' else: self._inpbuf = (self._prefix + self._auth_proto + self._auth_proto_pad + self._auth_data + self._auth_data_pad) self._recv_handler = None self._bytes_needed = 0 def data_received(self, data: bytes, datatype: DataType = None) -> None: """Handle incoming data from the X11 client""" if self._recv_handler: self._inpbuf += data while self._recv_handler: # type: ignore[truthy-function] if len(self._inpbuf) >= self._bytes_needed: data = self._inpbuf[:self._bytes_needed] self._inpbuf = self._inpbuf[self._bytes_needed:] self._recv_handler(data) else: return data = self._inpbuf self._inpbuf = b'' if data: super().data_received(data, datatype) class SSHX11ClientListener: """Client listener used to accept forwarded X11 connections""" def __init__(self, loop: asyncio.AbstractEventLoop, host: str, dpynum: str, auth_proto: bytes, auth_data: bytes): self._host = host self._dpynum = dpynum self._auth_proto = auth_proto self._local_auth = auth_data if host.startswith('/'): self._connect_coro: SSHForwarderCoro = loop.create_unix_connection self._connect_args: Sequence[object] = (host + ':' + dpynum,) elif host in ('', 'unix'): self._connect_coro = loop.create_unix_connection self._connect_args = ('/tmp/.X11-unix/X' + dpynum,) else: self._connect_coro = loop.create_connection self._connect_args = (host, X11_BASE_PORT + int(dpynum)) self._remote_auth: Dict['SSHChannel', bytes] = {} self._channel: Dict[bytes, Tuple['SSHChannel', bool]] = {} def attach(self, display: str, chan: 'SSHChannel', single_connection: bool) -> Tuple[bytes, bytes, int]: """Attach a channel to this listener""" host, dpynum, screen = _parse_display(display) 
if self._host != host or self._dpynum != dpynum: raise ValueError('Already forwarding to another X11 display') remote_auth = os.urandom(len(self._local_auth)) self._remote_auth[chan] = remote_auth self._channel[remote_auth] = chan, single_connection return self._auth_proto, remote_auth, screen def detach(self, chan: 'SSHChannel') -> bool: """Detach a channel from this listener""" try: remote_auth = self._remote_auth.pop(chan) del self._channel[remote_auth] except KeyError: pass return not bool(self._remote_auth) async def forward_connection(self) -> SSHX11ClientForwarder: """Forward an incoming connection to the local X server""" peer: SSHForwarder try: _, peer = await self._connect_coro(SSHForwarder, *self._connect_args) except OSError as exc: raise ChannelOpenError(OPEN_CONNECT_FAILED, str(exc)) from None return SSHX11ClientForwarder(self, peer) def validate_auth(self, remote_auth: bytes) -> bytes: """Validate client auth and enforce single connection flag""" chan, single_connection = self._channel[remote_auth] if single_connection: del self._channel[remote_auth] del self._remote_auth[chan] return self._local_auth class SSHX11ServerListener: """Server listener used to forward X11 connections""" def __init__(self, tcp_listener: SSHListener, display: str): self._tcp_listener = tcp_listener self._display = display self._channels: Set[object] = set() def attach(self, chan: 'SSHChannel', screen: int) -> str: """Attach a channel to this listener and return its display""" self._channels.add(chan) return f'{self._display}.{screen}' def detach(self, chan: 'SSHChannel') -> bool: """Detach a channel from this listener""" try: self._channels.remove(chan) except KeyError: pass if not self._channels: self._tcp_listener.close() return True else: return False def get_xauth_path(auth_path: Optional[str]) -> str: """Compute the path to the Xauthority file""" if not auth_path: auth_path = os.environ.get('XAUTHORITY') if not auth_path: auth_path = str(Path('~', '.Xauthority').expanduser()) return auth_path def walk_xauth(auth_path: str) -> Iterable[SSHXAuthorityEntry]: """Walk the entries in an Xauthority file""" def _read_bytes(n: int) -> bytes: """Read exactly n bytes""" data = auth_file.read(n) if len(data) != n: raise EOFError return data def _read_uint16() -> int: """Read a 16-bit unsigned integer""" return int.from_bytes(_read_bytes(2), 'big') def _read_string() -> bytes: """Read a string""" return _read_bytes(_read_uint16()) try: with open(auth_path, 'rb') as auth_file: while True: try: family = _read_uint16() except EOFError: break try: yield SSHXAuthorityEntry(family, _read_string(), _read_string(), _read_string(), _read_string()) except EOFError: raise ValueError('Incomplete Xauthority entry') from None except OSError: pass async def lookup_xauth(loop: asyncio.AbstractEventLoop, auth_path: Optional[str], host: str, dpynum: str) -> Tuple[bytes, bytes]: """Look up Xauthority data for the specified display""" auth_path = get_xauth_path(auth_path) if host.startswith('/') or host in ('', 'unix', 'localhost'): host = socket.gethostname() dpynum = dpynum.encode('ascii') ipv4_addrs: Sequence[str] = [] ipv6_addrs: Sequence[str] = [] for entry in walk_xauth(auth_path): if entry.dpynum and entry.dpynum != dpynum: continue if entry.family == XAUTH_FAMILY_IPV4: if not ipv4_addrs: ipv4_addrs = await _lookup_host(loop, host, socket.AF_INET) addr = socket.inet_ntop(socket.AF_INET, entry.addr) match = addr in ipv4_addrs elif entry.family == XAUTH_FAMILY_IPV6: if not ipv6_addrs: ipv6_addrs = await 
_lookup_host(loop, host, socket.AF_INET6) addr = socket.inet_ntop(socket.AF_INET6, entry.addr) match = addr in ipv6_addrs elif entry.family == XAUTH_FAMILY_HOSTNAME: match = entry.addr == host.encode('idna') elif entry.family == XAUTH_FAMILY_WILD: match = True else: match = False if match: return entry.proto, entry.data logger.debug1('No xauth entry found for display: using random auth') return XAUTH_PROTO_COOKIE, os.urandom(XAUTH_COOKIE_LEN) async def update_xauth(auth_path: Optional[str], host: str, dpynum: str, auth_proto: bytes, auth_data: bytes) -> None: """Update Xauthority data for the specified display""" if host.startswith('/') or host in ('', 'unix', 'localhost'): host = socket.gethostname() host = host.encode('idna') dpynum = str(dpynum).encode('ascii') auth_path = get_xauth_path(auth_path) new_auth_path = auth_path + XAUTH_LOCK_SUFFIX new_file = None try: if time.time() - os.stat(new_auth_path).st_ctime > XAUTH_LOCK_DEAD: os.unlink(new_auth_path) except FileNotFoundError: pass for _ in range(XAUTH_LOCK_TRIES): try: new_file = open(new_auth_path, 'xb') except FileExistsError: await asyncio.sleep(XAUTH_LOCK_DELAY) else: break if not new_file: raise ValueError('Unable to acquire Xauthority lock') new_entry = SSHXAuthorityEntry(XAUTH_FAMILY_HOSTNAME, host, dpynum, auth_proto, auth_data) new_file.write(bytes(new_entry)) for entry in walk_xauth(auth_path): if (entry.family != new_entry.family or entry.addr != new_entry.addr or entry.dpynum != new_entry.dpynum): new_file.write(bytes(entry)) new_file.close() os.replace(new_auth_path, auth_path) async def create_x11_client_listener(loop: asyncio.AbstractEventLoop, display: str, auth_path: Optional[str]) -> \ SSHX11ClientListener: """Create a listener to accept X11 connections forwarded over SSH""" host, dpynum, _ = _parse_display(display) auth_proto, auth_data = await lookup_xauth(loop, auth_path, host, dpynum) return SSHX11ClientListener(loop, host, dpynum, auth_proto, auth_data) async def create_x11_server_listener(conn: 'SSHServerConnection', loop: asyncio.AbstractEventLoop, auth_path: Optional[str], auth_proto: bytes, auth_data: bytes) -> \ Optional[SSHX11ServerListener]: """Create a listener to forward X11 connections over SSH""" for dpynum in range(X11_DISPLAY_START, X11_MAX_DISPLAYS): try: tcp_listener = await create_tcp_forward_listener( conn, loop, conn.create_x11_connection, X11_LISTEN_HOST, X11_BASE_PORT + dpynum) except OSError: continue display = f'{X11_LISTEN_HOST}:{dpynum}' try: await update_xauth(auth_path, X11_LISTEN_HOST, str(dpynum), auth_proto, auth_data) except ValueError: tcp_listener.close() break return SSHX11ServerListener(tcp_listener, display) return None asyncssh-2.20.0/docs/000077500000000000000000000000001475467777400144125ustar00rootroot00000000000000asyncssh-2.20.0/docs/_templates/000077500000000000000000000000001475467777400165475ustar00rootroot00000000000000asyncssh-2.20.0/docs/_templates/sidebarbottom.html000066400000000000000000000007541475467777400223010ustar00rootroot00000000000000

Change Log
Contributing
API Documentation
Source on PyPI
Source on GitHub
Issue Tracker
Search

asyncssh-2.20.0/docs/_templates/sidebartop.html000066400000000000000000000001231475467777400215650ustar00rootroot00000000000000 AsyncSSH
Version {{version}}

asyncssh-2.20.0/docs/api.rst000066400000000000000000002410161475467777400157210ustar00rootroot00000000000000.. module:: asyncssh .. _API: API Documentation ***************** Overview ======== The AsyncSSH API is modeled after the new Python ``asyncio`` framework, with a :func:`create_connection` coroutine to create an SSH client and a :func:`create_server` coroutine to create an SSH server. Like the ``asyncio`` framework, these calls take a parameter of a factory which creates protocol objects to manage the connections once they are open. For AsyncSSH, :func:`create_connection` should be passed a ``client_factory`` which returns objects derived from :class:`SSHClient` and :func:`create_server` should be passed a ``server_factory`` which returns objects derived from :class:`SSHServer`. In addition, each connection will have an associated :class:`SSHClientConnection` or :class:`SSHServerConnection` object passed to the protocol objects which can be used to perform actions on the connection. For client connections, authentication can be performed by passing in a username and password or SSH keys as arguments to :func:`create_connection` or by implementing handler methods on the :class:`SSHClient` object which return credentials when the server requests them. If no credentials are provided, AsyncSSH automatically attempts to send the username of the local user and the keys found in their :file:`.ssh` subdirectory. A list of expected server host keys can also be specified, with AsyncSSH defaulting to looking for matching lines in the user's :file:`.ssh/known_hosts` file. For server connections, handlers can be implemented on the :class:`SSHServer` object to return which authentication methods are supported and to validate credentials provided by clients. Once an SSH client connection is established and authentication is successful, multiple simultaneous channels can be opened on it. This is accomplished calling methods such as :meth:`create_session() `, :meth:`create_connection() `, :meth:`create_unix_connection() `, :meth:`create_tun() `, and :meth:`create_tap() ` on the :class:`SSHClientConnection` object. The client can also set up listeners on remote TCP ports and UNIX domain sockets by calling :meth:`create_server() ` and :meth:`create_unix_server() `. All of these methods take ``session_factory`` arguments that return :class:`SSHClientSession`, :class:`SSHTCPSession`, or :class:`SSHUNIXSession` objects used to manage the channels once they are open. Alternately, channels can be opened using :meth:`open_session() `, :meth:`open_connection() `, :meth:`open_unix_connection() `, :meth:`open_tun() `, or :meth:`open_tap() `, which return :class:`SSHReader` and :class:`SSHWriter` objects that can be used to perform I/O on the channel. The methods :meth:`start_server() ` and :meth:`start_unix_server() ` can be used to set up listeners on remote TCP ports or UNIX domain sockets and get back these :class:`SSHReader` and :class:`SSHWriter` objects in a callback when new connections are opened. SSH client sessions can also be opened by calling :meth:`create_process() `. This returns a :class:`SSHClientProcess` object which has members ``stdin``, ``stdout``, and ``stderr`` which are :class:`SSHReader` and :class:`SSHWriter` objects. This API also makes it very easy to redirect input and output from the remote process to local files, pipes, sockets, or other :class:`SSHReader` and :class:`SSHWriter` objects. 
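For example (a minimal sketch, assuming a reachable server at the
placeholder host ``example.com`` and default key-based authentication),
a remote command can be driven interactively through these streams::

    import asyncio, asyncssh

    async def run_client() -> None:
        async with asyncssh.connect('example.com') as conn:
            async with conn.create_process('tr a-z A-Z') as process:
                # Write a line to the remote process and read back
                # the transformed result
                process.stdin.write('hello\n')
                process.stdin.write_eof()
                print(await process.stdout.read(), end='')

    asyncio.run(run_client())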
In cases where you just want to run a remote process to completion and get back an object containing captured output and exit status, the :meth:`run() ` method can be used. It returns an :class:`SSHCompletedProcess` with the results of the run, or can be set up to raise :class:`ProcessError` if the process exits with a non-zero exit status. It can also raise :class:`TimeoutError` if a specified timeout expires before the process exits. The client can also set up TCP port forwarding by calling :meth:`forward_local_port() ` or :meth:`forward_remote_port() ` and UNIX domain socket forwarding by calling :meth:`forward_local_path() ` or :meth:`forward_remote_path() `. Mixed forwarding from a TCP port to a UNIX domain socket or vice-versa can be set up using the functions :meth:`forward_local_port_to_path() `, :meth:`forward_local_path_to_port() `, :meth:`forward_remote_port_to_path() `, and :meth:`forward_remote_path_to_port() `. In these cases, data transfer on the channels is managed automatically by AsyncSSH whenever new connections are opened, so custom session objects are not required. Dynamic TCP port forwarding can be set up by calling :meth:`forward_socks() `. The SOCKS listener set up by AsyncSSH on the requested port accepts SOCKS connect requests and is compatible with SOCKS versions 4, 4a, and 5. Bidirectional packet forwarding at layer 2 or 3 is also supported using the functions :meth:`forward_tun() ` and :meth:`forward_tap() ` to set up tunnels between local and remote TUN or TAP interfaces. Once a tunnel is established, packets arriving on TUN/TAP interfaces on either side are sent over the tunnel and automatically sent out the TUN/TAP interface on the other side. When an SSH server receives a new connection and authentication is successful, handlers such as :meth:`session_requested() `, :meth:`connection_requested() `, :meth:`unix_connection_requested() `, :meth:`server_requested() `, and :meth:`unix_server_requested() ` on the associated :class:`SSHServer` object will be called when clients attempt to open channels or set up listeners. These methods return coroutines which can set up the requested sessions or connections, returning :class:`SSHServerSession` or :class:`SSHTCPSession` objects or handler functions that accept :class:`SSHReader` and :class:`SSHWriter` objects as arguments which manage the channels once they are open. To better support interactive server applications, AsyncSSH defaults to providing echoing of input and basic line editing capabilities when an inbound SSH session requests a pseudo-terminal. This behavior can be disabled by setting the ``line_editor`` argument to ``False`` when starting up an SSH server. When this feature is enabled, server sessions can enable or disable line mode using the :meth:`set_line_mode() ` method of :class:`SSHLineEditorChannel`. They can also enable or disable input echoing using the :meth:`set_echo() ` method. Handling of specific keys during line editing can be customized using the :meth:`register_key() ` and :meth:`unregister_key() ` methods. Each session object also has an associated :class:`SSHClientChannel`, :class:`SSHServerChannel`, or :class:`SSHTCPChannel` object passed to it which can be used to perform actions on the channel. These channel objects provide a superset of the functionality found in ``asyncio`` transport objects. 
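As a further sketch (again with ``example.com`` and the port numbers
as placeholder values), :meth:`run() <SSHClientConnection.run>` and the
forwarding helpers described above can be combined on a single
connection::

    import asyncio, asyncssh

    async def run_and_forward() -> None:
        async with asyncssh.connect('example.com') as conn:
            # check=True raises asyncssh.ProcessError on a non-zero exit
            result = await conn.run('echo hello', check=True)
            print(result.stdout, end='')

            # Forward connections made to local port 8080 to port 80
            # on the remote system
            listener = await conn.forward_local_port('', 8080, 'localhost', 80)
            listener.close()
            await listener.wait_closed()

    asyncio.run(run_and_forward())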
In addition to the above functions and classes, helper functions for importing public and private keys can be found below under :ref:`PublicKeySupport`, exceptions can be found under :ref:`Exceptions`, supported algorithms can be found under :ref:`SupportedAlgorithms`, and some useful constants can be found under :ref:`Constants`. Main Functions ============== .. autofunction:: connect .. autofunction:: connect_reverse .. autofunction:: listen .. autofunction:: listen_reverse .. autofunction:: run_client .. autofunction:: run_server .. autofunction:: create_connection .. autofunction:: create_server .. autofunction:: get_server_host_key .. autofunction:: get_server_auth_methods .. autofunction:: scp Main Classes ============ .. autoclass:: SSHClient ================================== = General connection handlers ================================== = .. automethod:: connection_made .. automethod:: connection_lost .. automethod:: debug_msg_received ================================== = ======================================== = Host key validation handlers ======================================== = .. automethod:: validate_host_public_key .. automethod:: validate_host_ca_key ======================================== = ==================================== = General authentication handlers ==================================== = .. automethod:: auth_banner_received .. automethod:: auth_completed ==================================== = ========================================= = Public key authentication handlers ========================================= = .. automethod:: public_key_auth_requested ========================================= = ========================================= = Password authentication handlers ========================================= = .. automethod:: password_auth_requested .. automethod:: password_change_requested .. automethod:: password_changed .. automethod:: password_change_failed ========================================= = ============================================ = Keyboard-interactive authentication handlers ============================================ = .. automethod:: kbdint_auth_requested .. automethod:: kbdint_challenge_received ============================================ = .. autoclass:: SSHServer ================================== = General connection handlers ================================== = .. automethod:: connection_made .. automethod:: connection_lost .. automethod:: debug_msg_received ================================== = =============================== = General authentication handlers =============================== = .. automethod:: begin_auth .. automethod:: auth_completed =============================== = ====================================== = GSSAPI authentication handlers ====================================== = .. automethod:: validate_gss_principal ====================================== = ========================================= = Host-based authentication handlers ========================================= = .. automethod:: host_based_auth_supported .. automethod:: validate_host_public_key .. automethod:: validate_host_ca_key .. automethod:: validate_host_based_user ========================================= = ========================================= = Public key authentication handlers ========================================= = .. automethod:: public_key_auth_supported .. automethod:: validate_public_key .. 
automethod:: validate_ca_key ========================================= = ======================================= = Password authentication handlers ======================================= = .. automethod:: password_auth_supported .. automethod:: validate_password .. automethod:: change_password ======================================= = ============================================ = Keyboard-interactive authentication handlers ============================================ = .. automethod:: kbdint_auth_supported .. automethod:: get_kbdint_challenge .. automethod:: validate_kbdint_response ============================================ = ========================================= = Channel session open handlers ========================================= = .. automethod:: session_requested .. automethod:: connection_requested .. automethod:: unix_connection_requested .. automethod:: server_requested .. automethod:: unix_server_requested .. automethod:: tun_requested .. automethod:: tap_requested ========================================= = Connection Classes ================== .. autoclass:: SSHClientConnection() ======================================================================= = Connection attributes ======================================================================= = .. autoattribute:: logger ======================================================================= = =================================== = General connection methods =================================== = .. automethod:: get_extra_info .. automethod:: set_extra_info .. automethod:: set_keepalive .. automethod:: get_server_host_key .. automethod:: send_debug .. automethod:: is_closed =================================== = ====================================================================================================================================================== = Client session open methods ====================================================================================================================================================== = .. automethod:: create_session .. automethod:: open_session .. automethod:: create_process(*args, bufsize=io.DEFAULT_BUFFER_SIZE, input=None, stdin=PIPE, stdout=PIPE, stderr=PIPE, **kwargs) .. automethod:: create_subprocess(protocol_factory, *args, bufsize=io.DEFAULT_BUFFER_SIZE, input=None, stdin=PIPE, stdout=PIPE, stderr=PIPE, **kwargs) .. automethod:: run(*args, check=False, timeout=None, **kwargs) .. automethod:: start_sftp_client .. automethod:: create_ssh_connection .. automethod:: connect_ssh .. automethod:: connect_reverse_ssh .. automethod:: listen_ssh .. automethod:: listen_reverse_ssh ====================================================================================================================================================== = ====================================== = Client connection open methods ====================================== = .. automethod:: create_connection .. automethod:: open_connection .. automethod:: create_server .. automethod:: start_server .. automethod:: create_unix_connection .. automethod:: open_unix_connection .. automethod:: create_unix_server .. automethod:: start_unix_server .. automethod:: create_tun .. automethod:: create_tap .. automethod:: open_tun .. automethod:: open_tap ====================================== = =========================================== = Client forwarding methods =========================================== = .. automethod:: forward_local_port .. automethod:: forward_local_path .. 
automethod:: forward_local_port_to_path .. automethod:: forward_local_path_to_port .. automethod:: forward_remote_port .. automethod:: forward_remote_path .. automethod:: forward_remote_port_to_path .. automethod:: forward_remote_path_to_port .. automethod:: forward_socks .. automethod:: forward_tun .. automethod:: forward_tap =========================================== = =========================== = Connection close methods =========================== = .. automethod:: abort .. automethod:: close .. automethod:: disconnect .. automethod:: wait_closed =========================== = .. autoclass:: SSHServerConnection() ======================================================================= = Connection attributes ======================================================================= = .. autoattribute:: logger ======================================================================= = ============================== = General connection methods ============================== = .. automethod:: get_extra_info .. automethod:: set_extra_info .. automethod:: set_keepalive .. automethod:: send_debug .. automethod:: is_closed ============================== = ============================================ = Server authentication methods ============================================ = .. automethod:: send_auth_banner .. automethod:: set_authorized_keys .. automethod:: get_key_option .. automethod:: check_key_permission .. automethod:: get_certificate_option .. automethod:: check_certificate_permission ============================================ = ====================================== = Server connection open methods ====================================== = .. automethod:: create_connection .. automethod:: open_connection .. automethod:: create_unix_connection .. automethod:: open_unix_connection ====================================== = ===================================== = Server channel creation methods ===================================== = .. automethod:: create_server_channel .. automethod:: create_tcp_channel .. automethod:: create_unix_channel .. automethod:: create_tuntap_channel ===================================== = =========================== = Connection close methods =========================== = .. automethod:: abort .. automethod:: close .. automethod:: disconnect .. automethod:: wait_closed =========================== = .. autoclass:: SSHClientConnectionOptions() .. autoclass:: SSHServerConnectionOptions() Process Classes =============== .. autoclass:: SSHClientProcess ======================================================================= = Client process attributes ======================================================================= = .. autoattribute:: channel .. autoattribute:: logger .. autoattribute:: env .. autoattribute:: command .. autoattribute:: subsystem .. autoattribute:: stdin .. autoattribute:: stdout .. autoattribute:: stderr .. autoattribute:: exit_status .. autoattribute:: exit_signal .. autoattribute:: returncode ======================================================================= = ==================================== = Other client process methods ==================================== = .. automethod:: get_extra_info .. automethod:: redirect .. automethod:: collect_output .. automethod:: communicate .. automethod:: wait .. automethod:: change_terminal_size .. automethod:: send_break .. 
automethod:: send_signal ==================================== = ======================================================================= = Client process close methods ======================================================================= = .. automethod:: terminate .. automethod:: kill .. automethod:: close .. automethod:: is_closing .. automethod:: wait_closed ======================================================================= = .. autoclass:: SSHServerProcess ============================== = Server process attributes ============================== = .. autoattribute:: channel .. autoattribute:: logger .. autoattribute:: command .. autoattribute:: subsystem .. autoattribute:: env .. autoattribute:: term_type .. autoattribute:: term_size .. autoattribute:: term_modes .. autoattribute:: stdin .. autoattribute:: stdout .. autoattribute:: stderr ============================== = ============================== = Other server process methods ============================== = .. automethod:: get_extra_info .. automethod:: redirect ============================== = ================================ = Server process close methods ================================ = .. automethod:: exit .. automethod:: exit_with_signal .. automethod:: close .. automethod:: is_closing .. automethod:: wait_closed ================================ = .. autoclass:: SSHCompletedProcess() .. autoclass:: SSHSubprocessReadPipe() ==================================== = General subprocess pipe info methods ==================================== = .. automethod:: get_extra_info ==================================== = ======================================================================= = Subprocess pipe read methods ======================================================================= = .. automethod:: pause_reading .. automethod:: resume_reading ======================================================================= = ======================================================================= = General subprocess pipe close methods ======================================================================= = .. automethod:: close ======================================================================= = .. autoclass:: SSHSubprocessWritePipe() ==================================== = General subprocess pipe info methods ==================================== = .. automethod:: get_extra_info ==================================== = ======================================================================= = Subprocess pipe write methods ======================================================================= = .. automethod:: can_write_eof .. automethod:: get_write_buffer_size .. automethod:: set_write_buffer_limits .. automethod:: write .. automethod:: writelines .. automethod:: write_eof ======================================================================= = ======================================================================= = General subprocess pipe close methods ======================================================================= = .. automethod:: abort .. automethod:: close ======================================================================= = .. autoclass:: SSHSubprocessProtocol ==================================== = General subprocess protocol handlers ==================================== = .. automethod:: connection_made .. automethod:: pipe_connection_lost ==================================== = ================================== = Subprocess protocol read handlers ================================== = .. 
automethod:: pipe_data_received ================================== = ================================== = Other subprocess protocol handlers ================================== = .. automethod:: process_exited ================================== = .. autoclass:: SSHSubprocessTransport ==================================== = General subprocess transport methods ==================================== = .. automethod:: get_extra_info .. automethod:: get_pid .. automethod:: get_pipe_transport .. automethod:: get_returncode .. automethod:: change_terminal_size .. automethod:: send_break .. automethod:: send_signal ==================================== = ======================================================================= = Subprocess transport close methods ======================================================================= = .. automethod:: terminate .. automethod:: kill .. automethod:: close .. automethod:: is_closing .. automethod:: wait_closed ======================================================================= = Session Classes =============== .. autoclass:: SSHClientSession =============================== = General session handlers =============================== = .. automethod:: connection_made .. automethod:: connection_lost .. automethod:: session_started =============================== = ============================= = General session read handlers ============================= = .. automethod:: data_received .. automethod:: eof_received ============================= = ============================== = General session write handlers ============================== = .. automethod:: pause_writing .. automethod:: resume_writing ============================== = ==================================== = Other client session handlers ==================================== = .. automethod:: xon_xoff_requested .. automethod:: exit_status_received .. automethod:: exit_signal_received ==================================== = .. autoclass:: SSHServerSession =============================== = General session handlers =============================== = .. automethod:: connection_made .. automethod:: connection_lost .. automethod:: session_started =============================== = =================================== = Server session open handlers =================================== = .. automethod:: pty_requested .. automethod:: shell_requested .. automethod:: exec_requested .. automethod:: subsystem_requested =================================== = ============================= = General session read handlers ============================= = .. automethod:: data_received .. automethod:: eof_received ============================= = ============================== = General session write handlers ============================== = .. automethod:: pause_writing .. automethod:: resume_writing ============================== = ===================================== = Other server session handlers ===================================== = .. automethod:: break_received .. automethod:: signal_received .. automethod:: terminal_size_changed ===================================== = .. autoclass:: SSHTCPSession =============================== = General session handlers =============================== = .. automethod:: connection_made .. automethod:: connection_lost .. automethod:: session_started =============================== = ============================= = General session read handlers ============================= = .. automethod:: data_received .. 
automethod:: eof_received ============================= = ============================== = General session write handlers ============================== = .. automethod:: pause_writing .. automethod:: resume_writing ============================== = .. autoclass:: SSHUNIXSession =============================== = General session handlers =============================== = .. automethod:: connection_made .. automethod:: connection_lost .. automethod:: session_started =============================== = ============================= = General session read handlers ============================= = .. automethod:: data_received .. automethod:: eof_received ============================= = ============================== = General session write handlers ============================== = .. automethod:: pause_writing .. automethod:: resume_writing ============================== = .. autoclass:: SSHTunTapSession =============================== = General session handlers =============================== = .. automethod:: connection_made .. automethod:: connection_lost .. automethod:: session_started =============================== = ============================= = General session read handlers ============================= = .. automethod:: data_received .. automethod:: eof_received ============================= = ============================== = General session write handlers ============================== = .. automethod:: pause_writing .. automethod:: resume_writing ============================== = Channel Classes =============== .. autoclass:: SSHClientChannel() ========================= = Channel attributes ========================= = .. autoattribute:: logger ========================= = =============================== = General channel info methods =============================== = .. automethod:: get_extra_info .. automethod:: set_extra_info .. automethod:: get_environment .. automethod:: get_command .. automethod:: get_subsystem =============================== = ============================== = Client channel read methods ============================== = .. automethod:: pause_reading .. automethod:: resume_reading ============================== = ======================================= = Client channel write methods ======================================= = .. automethod:: can_write_eof .. automethod:: get_write_buffer_size .. automethod:: set_write_buffer_limits .. automethod:: write .. automethod:: writelines .. automethod:: write_eof ======================================= = ===================================== = Other client channel methods ===================================== = .. automethod:: get_exit_status .. automethod:: get_exit_signal .. automethod:: get_returncode .. automethod:: change_terminal_size .. automethod:: send_break .. automethod:: send_signal .. automethod:: kill .. automethod:: terminate ===================================== = ============================= = General channel close methods ============================= = .. automethod:: abort .. automethod:: close .. automethod:: is_closing .. automethod:: wait_closed ============================= = .. autoclass:: SSHServerChannel() ======================================================================= = Channel attributes ======================================================================= = .. autoattribute:: logger ======================================================================= = =============================== = General channel info methods =============================== = .. automethod:: get_extra_info .. 
automethod:: set_extra_info .. automethod:: get_environment .. automethod:: get_command .. automethod:: get_subsystem =============================== = ================================== = Server channel info methods ================================== = .. automethod:: get_terminal_type .. automethod:: get_terminal_size .. automethod:: get_terminal_mode .. automethod:: get_terminal_modes .. automethod:: get_x11_display .. automethod:: get_agent_path ================================== = ============================== = Server channel read methods ============================== = .. automethod:: pause_reading .. automethod:: resume_reading ============================== = ======================================= = Server channel write methods ======================================= = .. automethod:: can_write_eof .. automethod:: get_write_buffer_size .. automethod:: set_write_buffer_limits .. automethod:: write .. automethod:: writelines .. automethod:: write_stderr .. automethod:: writelines_stderr .. automethod:: write_eof ======================================= = ================================= = Other server channel methods ================================= = .. automethod:: set_xon_xoff .. automethod:: exit .. automethod:: exit_with_signal ================================= = ============================= = General channel close methods ============================= = .. automethod:: abort .. automethod:: close .. automethod:: is_closing .. automethod:: wait_closed ============================= = .. autoclass:: SSHLineEditorChannel() ============================== = Line editor methods ============================== = .. automethod:: register_key .. automethod:: unregister_key .. automethod:: set_line_mode .. automethod:: set_echo ============================== = .. autoclass:: SSHTCPChannel() ======================================================================= = Channel attributes ======================================================================= = .. autoattribute:: logger ======================================================================= = ============================== = General channel info methods ============================== = .. automethod:: get_extra_info .. automethod:: set_extra_info ============================== = ============================== = General channel read methods ============================== = .. automethod:: pause_reading .. automethod:: resume_reading ============================== = ======================================= = General channel write methods ======================================= = .. automethod:: can_write_eof .. automethod:: get_write_buffer_size .. automethod:: set_write_buffer_limits .. automethod:: write .. automethod:: writelines .. automethod:: write_eof ======================================= = ============================= = General channel close methods ============================= = .. automethod:: abort .. automethod:: close .. automethod:: is_closing .. automethod:: wait_closed ============================= = .. autoclass:: SSHUNIXChannel() ======================================================================= = Channel attributes ======================================================================= = .. autoattribute:: logger ======================================================================= = ============================== = General channel info methods ============================== = .. automethod:: get_extra_info .. 
automethod:: set_extra_info ============================== = ============================== = General channel read methods ============================== = .. automethod:: pause_reading .. automethod:: resume_reading ============================== = ======================================= = General channel write methods ======================================= = .. automethod:: can_write_eof .. automethod:: get_write_buffer_size .. automethod:: set_write_buffer_limits .. automethod:: write .. automethod:: writelines .. automethod:: write_eof ======================================= = ============================= = General channel close methods ============================= = .. automethod:: abort .. automethod:: close .. automethod:: is_closing .. automethod:: wait_closed ============================= = .. autoclass:: SSHTunTapChannel() ======================================================================= = Channel attributes ======================================================================= = .. autoattribute:: logger ======================================================================= = ============================== = General channel info methods ============================== = .. automethod:: get_extra_info .. automethod:: set_extra_info ============================== = ============================== = General channel read methods ============================== = .. automethod:: pause_reading .. automethod:: resume_reading ============================== = ======================================= = General channel write methods ======================================= = .. automethod:: can_write_eof .. automethod:: get_write_buffer_size .. automethod:: set_write_buffer_limits .. automethod:: write .. automethod:: writelines .. automethod:: write_eof ======================================= = ============================= = General channel close methods ============================= = .. automethod:: abort .. automethod:: close .. automethod:: is_closing .. automethod:: wait_closed ============================= = Forwarder Classes ================= .. autoclass:: SSHForwarder() ============================== = .. automethod:: get_extra_info .. automethod:: close ============================== = Listener Classes ================ .. autoclass:: SSHAcceptor() ============================= = .. automethod:: get_addresses .. automethod:: get_port .. automethod:: close .. automethod:: wait_closed .. automethod:: update ============================= = .. autoclass:: SSHListener() =========================== = .. automethod:: get_port .. automethod:: close .. automethod:: wait_closed =========================== = Stream Classes ============== .. autoclass:: SSHReader() ============================== = .. autoattribute:: channel .. autoattribute:: logger .. automethod:: get_extra_info .. automethod:: feed_data .. automethod:: feed_eof .. automethod:: at_eof .. automethod:: read .. automethod:: readline .. automethod:: readuntil .. automethod:: readexactly ============================== = .. autoclass:: SSHWriter() ============================== = .. autoattribute:: channel .. autoattribute:: logger .. automethod:: get_extra_info .. automethod:: can_write_eof .. automethod:: drain .. automethod:: write .. automethod:: writelines .. automethod:: write_eof .. automethod:: close .. automethod:: is_closing .. automethod:: wait_closed ============================== = SFTP Support ============ .. 
autoclass:: SFTPClient() ======================================= = SFTP client attributes ======================================= = .. autoattribute:: logger .. autoattribute:: version .. autoattribute:: limits .. autoattribute:: supports_remote_copy ======================================= = =========================== = File transfer methods =========================== = .. automethod:: get .. automethod:: put .. automethod:: copy .. automethod:: mget .. automethod:: mput .. automethod:: mcopy .. automethod:: remote_copy =========================== = ============================================================================================================================================================================================================================== = File access methods ============================================================================================================================================================================================================================== = .. automethod:: open(path, mode='r', attrs=SFTPAttrs(), encoding='utf-8', errors='strict', block_size=SFTP_BLOCK_SIZE, max_requests=_MAX_SFTP_REQUESTS) .. automethod:: open56(path, desired_access=ACE4_READ_DATA | ACE4_READ_ATTRIBUTES, flags=FXF_OPEN_EXISTING, attrs=SFTPAttrs(), encoding='utf-8', errors='strict', block_size=SFTP_BLOCK_SIZE, max_requests=_MAX_SFTP_REQUESTS) .. automethod:: truncate .. automethod:: rename .. automethod:: posix_rename .. automethod:: remove .. automethod:: unlink .. automethod:: readlink .. automethod:: symlink .. automethod:: link .. automethod:: realpath ============================================================================================================================================================================================================================== = ======================================================= = File attribute access methods ======================================================= = .. automethod:: stat .. automethod:: lstat .. automethod:: setstat .. automethod:: statvfs .. automethod:: chown(path, uid or owner, gid or group) .. automethod:: chmod .. automethod:: utime .. automethod:: exists .. automethod:: lexists .. automethod:: getatime .. automethod:: getatime_ns .. automethod:: getmtime .. automethod:: getcrtime_ns .. automethod:: getcrtime .. automethod:: getmtime_ns .. automethod:: getsize .. automethod:: isdir .. automethod:: isfile .. automethod:: islink ======================================================= = ================================================= = Directory access methods ================================================= = .. automethod:: chdir .. automethod:: getcwd .. automethod:: mkdir(path, attrs=SFTPAttrs()) .. automethod:: makedirs(path, attrs=SFTPAttrs()) .. automethod:: rmdir .. automethod:: rmtree .. automethod:: scandir .. automethod:: readdir .. automethod:: listdir .. automethod:: glob .. automethod:: glob_sftpname ================================================= = =========================== = Cleanup methods =========================== = .. automethod:: exit .. automethod:: wait_closed =========================== = .. autoclass:: SFTPClientFile() ================================================= = .. automethod:: read .. automethod:: read_parallel .. automethod:: write .. automethod:: seek(offset, from_what=SEEK_SET) .. automethod:: tell .. automethod:: stat .. automethod:: setstat .. automethod:: statvfs .. automethod:: truncate .. 
automethod:: chown(uid or owner, gid or group) .. automethod:: chmod .. automethod:: utime .. automethod:: lock .. automethod:: unlock .. automethod:: fsync .. automethod:: close ================================================= = .. autoclass:: SFTPServer ============================= = SFTP server attributes ============================= = .. autoattribute:: channel .. autoattribute:: connection .. autoattribute:: env .. autoattribute:: logger ============================= = ================================== = Path remapping and display methods ================================== = .. automethod:: format_user .. automethod:: format_group .. automethod:: format_longname .. automethod:: map_path .. automethod:: reverse_map_path ================================== = ============================ = File access methods ============================ = .. automethod:: open .. automethod:: open56 .. automethod:: close .. automethod:: read .. automethod:: write .. automethod:: rename .. automethod:: posix_rename .. automethod:: remove .. automethod:: readlink .. automethod:: symlink .. automethod:: link .. automethod:: realpath ============================ = ============================= = File attribute access methods ============================= = .. automethod:: stat .. automethod:: lstat .. automethod:: fstat .. automethod:: setstat .. automethod:: fsetstat .. automethod:: statvfs .. automethod:: fstatvfs .. automethod:: lock .. automethod:: unlock ============================= = ======================== = Directory access methods ======================== = .. automethod:: mkdir .. automethod:: rmdir .. automethod:: scandir ======================== = ===================== = Cleanup methods ===================== = .. automethod:: exit ===================== = .. autoclass:: SFTPAttrs() .. autoclass:: SFTPVFSAttrs() .. autoclass:: SFTPName() .. autoclass:: SFTPLimits() .. index:: Public key and certificate support .. _PublicKeySupport: Public Key Support ================== AsyncSSH has extensive public key and certificate support. Supported public key types include DSA, RSA, and ECDSA. In addition, Ed25519 and Ed448 keys are supported if OpenSSL 1.1.1b or later is installed. Alternately, Ed25519 support is available when the libnacl package and libsodium library are installed. Supported certificate types include OpenSSH version 01 certificates for DSA, RSA, ECDSA, Ed25519, and Ed448 keys and X.509 certificates for DSA, RSA, and ECDSA keys. Support is also available for the certificate critical options of force-command and source-address and the extensions permit-X11-forwarding, permit-agent-forwarding, permit-port-forwarding, and permit-pty in OpenSSH certificates. Several public key and certificate formats are supported including PKCS#1 and PKCS#8 DER and PEM, OpenSSH, RFC4716, and X.509 DER and PEM formats. PEM and PKCS#8 password-based encryption of private keys is supported, as is OpenSSH private key encryption when the bcrypt package is installed. .. index:: Specifying private keys .. _SpecifyingPrivateKeys: Specifying private keys ----------------------- Private keys may be passed into AsyncSSH in a variety of forms. The simplest option is to pass the name of a file to read one or more private keys from. An alternate form involves passing in a list of values which can be either a reference to a private key or a tuple containing a reference to a private key and a reference to a corresponding certificate or certificate chain. 
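As a brief, hedged illustration of the list form just described, the sketch below passes a mix of these values as the ``client_keys`` option of :func:`connect`. The host name and key file names used here are hypothetical placeholders; the accepted forms of each key and certificate reference are detailed in the paragraphs that follow.

.. code-block:: python

   import asyncio

   import asyncssh

   async def main() -> None:
       # An already loaded private key (the file name is a placeholder)
       loaded_key = asyncssh.read_private_key('backup_key')

       async with asyncssh.connect(
               'host.example.com',
               client_keys=[('my_key', 'my_key-cert.pub'),  # key plus certificate
                            'second_key',                   # file name only; a matching
                                                            # 'second_key-cert.pub' is
                                                            # picked up automatically
                            loaded_key]) as conn:           # already loaded SSHKey
           await conn.run('true', check=True)

   asyncio.run(main())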
Key references can either be the name of a file to load a key from, a byte string to import as a key, or an already loaded :class:`SSHKey` private key. See the function :func:`import_private_key` for the list of supported private key formats. Certificate references can be the name of a file to load a certificate from, a byte string to import as a certificate, an already loaded :class:`SSHCertificate`, or ``None`` if no certificate should be associated with the key. Whenever a filename is provided to read the private key from, an attempt is made to load a corresponding certificate or certificate chain from a file constructed by appending '-cert.pub' to the end of the name. X.509 certificates may also be provided in the same file as the private key, when using DER or PEM format. When using X.509 certificates, a list of certificates can also be provided. These certificates should form a trust chain from a user or host certificate up to some self-signed root certificate authority which is trusted by the remote system. Instead of passing tuples of keys and certificates or relying on file naming conventions for certificates, you also have the option of providing a list of keys and a separate list of certificates. In this case, AsyncSSH will automatically match up the keys with their associated certificates when they are present. New private keys can be generated using the :func:`generate_private_key` function. The resulting :class:`SSHKey` objects have methods which can then be used to export the generated keys in several formats for consumption by other tools, as well as methods for generating new OpenSSH or X.509 certificates. .. index:: Specifying public keys .. _SpecifyingPublicKeys: Specifying public keys ---------------------- Public keys may be passed into AsyncSSH in a variety of forms. The simplest option is to pass the name of a file to read one or more public keys from. An alternate form involves passing in a list of values each of which can be either the name of a file to load a key from, a byte string to import it from, or an already loaded :class:`SSHKey` public key. See the function :func:`import_public_key` for the list of supported public key formats. .. index:: Specifying certificates .. _SpecifyingCertificates: Specifying certificates ----------------------- Certificates may be passed into AsyncSSH in a variety of forms. The simplest option is to pass the name of a file to read one or more certificates from. An alternate form involves passing in a list of values each of which can be either the name of a file to load a certificate from, a byte string to import it from, or an already loaded :class:`SSHCertificate` object. See the function :func:`import_certificate` for the list of supported certificate formats. .. index:: Specifying X.509 subject names .. _SpecifyingX509Subjects: Specifying X.509 subject names ------------------------------ X.509 certificate subject names may be specified in place of public keys or certificates in authorized_keys and known_hosts files, allowing any X.509 certificate which matches that subject name to be considered a known host or authorized key. The syntax supported for this is compatible with PKIX-SSH, which adds X.509 certificate support to OpenSSH. To specify a subject name pattern instead of a specific certificate, base64-encoded certificate data should be replaced with the string 'Subject:' followed by a comma-separated list of X.509 relative distinguished name components. 
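As a purely illustrative sketch of this syntax, the example below passes known hosts data containing a subject name entry as a byte string (byte string data is one of the accepted forms described under :ref:`SpecifyingKnownHosts` below). The host name, algorithm name, and subject components are hypothetical, and the exact layout of such an entry should be checked against a real PKIX-SSH style configuration.

.. code-block:: python

   import asyncio

   import asyncssh

   # Illustrative known_hosts data: the base64-encoded certificate data is
   # replaced by 'Subject:' followed by comma-separated relative
   # distinguished name components (all names here are placeholders)
   KNOWN_HOSTS = (b'server.example.com x509v3-ssh-rsa '
                  b'Subject:O=Example Corp,CN=server.example.com\n')

   async def main() -> None:
       async with asyncssh.connect('server.example.com',
                                   known_hosts=KNOWN_HOSTS) as conn:
           await conn.run('true', check=True)

   asyncio.run(main())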
AsyncSSH extends the PKIX-SSH syntax to also support matching on a prefix of a subject name. To indicate this, a partial subject name can be specified which ends in ',*'. Any subject which matches the relative distinguished names listed before the ",*" will be treated as a match, even if the certificate provided has additional relative distinguished names following what was matched. .. index:: Specifying X.509 purposes .. _SpecifyingX509Purposes: Specifying X.509 purposes ------------------------- When performing X.509 certificate authentication, AsyncSSH can be passed an allowed set of ExtendedKeyUsage purposes. Purposes are matched in X.509 certificates as OID values, but AsyncSSH also allows the following well-known purpose values to be specified by name: .. table:: :align: left ================= ================== Name OID ================= ================== serverAuth 1.3.6.1.5.5.7.3.1 clientAuth 1.3.6.1.5.5.7.3.2 secureShellClient 1.3.6.1.5.5.7.3.20 secureShellServer 1.3.6.1.5.5.7.3.21 ================= ================== Values not in the list above can be specified directly by OID as a dotted numeric string value. Either a single value or a list of values can be provided. The check succeeds if any of the specified values are present in the certificate's ExtendedKeyUsage. It will also succeed if the certificate does not contain an ExtendedKeyUsage or if the ExtendedKeyUsage contains the OID 2.5.29.37.0, which indicates the certificate can be used for any purpose. This check defaults to requiring a purpose of 'secureShellClient' for client certificates and 'secureShellServer' for server certificates and should not normally need to be changed. However, certificates which contain other purposes can be supported by providing alternate values to match against, or by passing in the purpose 'any' to disable this checking. .. index:: Specifying time values .. _SpecifyingTimeValues: Specifying time values ---------------------- When generating certificates, an optional validity interval can be specified using the ``valid_after`` and ``valid_before`` parameters to the :meth:`generate_user_certificate() ` and :meth:`generate_host_certificate() ` methods. These values can be specified in any of the following ways: * An int or float UNIX epoch time, such as what is returned by :func:`time.time`. * A :class:`datetime.datetime` value. * A string value of ``now`` to request the current time. * A string value in the form ``YYYYMMDD`` to specify an absolute date. * A string value in the form ``YYYYMMDDHHMMSS`` to specify an absolute date and time. * A time interval described in :ref:`SpecifyingTimeIntervals` which is interpreted as a relative time from now. This value can be negative to refer to times in the past or positive to refer to times in the future. Key and certificate classes/functions ------------------------------------- .. autoclass:: SSHKey() ============================================== = .. automethod:: get_algorithm .. automethod:: get_comment_bytes .. automethod:: get_comment .. automethod:: set_comment .. automethod:: get_fingerprint .. automethod:: convert_to_public .. automethod:: generate_user_certificate .. automethod:: generate_host_certificate .. automethod:: generate_x509_user_certificate .. automethod:: generate_x509_host_certificate .. automethod:: generate_x509_ca_certificate .. automethod:: export_private_key .. automethod:: export_public_key .. automethod:: write_private_key .. automethod:: write_public_key .. automethod:: append_private_key ..
automethod:: append_public_key ============================================== = .. autoclass:: SSHKeyPair() ================================= = .. automethod:: get_key_type .. automethod:: get_algorithm .. automethod:: set_certificate .. automethod:: get_comment_bytes .. automethod:: get_comment .. automethod:: set_comment ================================= = .. autoclass:: SSHCertificate() ================================== = .. automethod:: get_algorithm .. automethod:: get_comment_bytes .. automethod:: get_comment .. automethod:: set_comment .. automethod:: export_certificate .. automethod:: write_certificate .. automethod:: append_certificate ================================== = .. autofunction:: generate_private_key .. autofunction:: import_private_key .. autofunction:: import_public_key .. autofunction:: import_certificate .. autofunction:: read_private_key .. autofunction:: read_public_key .. autofunction:: read_certificate .. autofunction:: read_private_key_list .. autofunction:: read_public_key_list .. autofunction:: read_certificate_list .. autofunction:: load_keypairs .. autofunction:: load_public_keys .. autofunction:: load_certificates .. autofunction:: load_pkcs11_keys .. autofunction:: load_resident_keys .. autofunction:: set_default_skip_rsa_key_validation .. index:: SSH agent support .. _SSHAgentSupport: SSH Agent Support ================= AsyncSSH supports the ability to use private keys managed by the OpenSSH ssh-agent on UNIX systems. It can connect via a UNIX domain socket to the agent and offload all private key operations to it, avoiding the need to read these keys into AsyncSSH itself. An ssh-agent is automatically used in :func:`create_connection` when a valid ``SSH_AUTH_SOCK`` is set in the environment. An alternate path to the agent can be specified via the ``agent_path`` argument to this function. An ssh-agent can also be accessed directly from AsyncSSH by calling :func:`connect_agent`. When successful, this function returns an :class:`SSHAgentClient` which can be used to get a list of available keys, add and remove keys, and lock and unlock access to this agent. SSH agent forwarding may be enabled when making outbound SSH connections by specifying the ``agent_forwarding`` argument when calling :func:`create_connection`, allowing processes running on the server to tunnel requests back over the SSH connection to the client's ssh-agent. Agent forwarding can be enabled when starting an SSH server by specifying the ``agent_forwarding`` argument when calling :func:`create_server`. In this case, the client's ssh-agent can be accessed from the server by passing the :class:`SSHServerConnection` as the argument to :func:`connect_agent` instead of a local path. Alternately, when an :class:`SSHServerChannel` has been opened, the :meth:`get_agent_path() ` method may be called on it to get a path to a UNIX domain socket which can be passed as the ``SSH_AUTH_SOCK`` to local applications which need this access. Any requests sent to this socket are forwarded over the SSH connection to the client's ssh-agent. .. autoclass:: SSHAgentClient() ===================================== = .. automethod:: get_keys .. automethod:: add_keys .. automethod:: add_smartcard_keys .. automethod:: remove_keys .. automethod:: remove_smartcard_keys .. automethod:: remove_all .. automethod:: lock .. automethod:: unlock .. automethod:: query_extensions .. automethod:: close .. automethod:: wait_closed ===================================== = .. 
autoclass:: SSHAgentKeyPair() ================================= = .. automethod:: get_key_type .. automethod:: get_algorithm .. automethod:: get_comment_bytes .. automethod:: get_comment .. automethod:: set_comment .. automethod:: remove ================================= = .. autofunction:: connect_agent .. index:: Config file support .. _ConfigFileSupport: Config File Support =================== AsyncSSH has partial support for parsing OpenSSH client and server configuration files (documented in the "ssh_config" and "sshd_config" UNIX man pages, respectively). Not all OpenSSH configuration options are applicable, so unsupported options are simply ignored. See below for the OpenSSH config options that AsyncSSH supports. AsyncSSH also supports "Host" and "Match" conditional blocks. As with the config options themselves, not all match criteria are supported, but the supported criteria should function similarly to OpenSSH. AsyncSSH also supports the "Include" directive, to allow one config file to trigger the loading of others. .. index:: Supported client config options .. _SupportedClientConfigOptions: Supported client config options ------------------------------- The following OpenSSH client config options are currently supported: | AddressFamily | BindAddress | CanonicalDomains | CanonicalizeFallbackLocal | CanonicalizeHostname | CanonicalizeMaxDots | CanonicalizePermittedCNAMEs | CASignatureAlgorithms | CertificateFile | ChallengeResponseAuthentication | Ciphers | Compression | ConnectTimeout | EnableSSHKeySign | ForwardAgent | ForwardX11Trusted | GlobalKnownHostsFile | GSSAPIAuthentication | GSSAPIDelegateCredentials | GSSAPIKeyExchange | HostbasedAuthentication | HostKeyAlgorithms | HostKeyAlias | Hostname | IdentityAgent | IdentityFile | KbdInteractiveAuthentication | KexAlgorithms | MACs | PasswordAuthentication | PreferredAuthentications | Port | ProxyCommand | ProxyJump | PubkeyAuthentication | RekeyLimit | RemoteCommand | RequestTTY | SendEnv | ServerAliveCountMax | ServerAliveInterval | SetEnv | TCPKeepAlive | User | UserKnownHostsFile For the "Match" conditional, the following criteria are currently supported: | All | Canonical | Exec | Final | Host | LocalUser | OriginalHost | User .. warning:: When instantiating :class:`SSHClientConnectionOptions` objects manually within an asyncio task, you may block the event loop if the options refer to a config file with "Match Exec" directives which don't return immediate results. In such cases, the asyncio `run_in_executor()` function should be used. This is taken care of automatically when options objects are created by AsyncSSH APIs such as :func:`connect` and :func:`listen`. Match criteria can be negated by prefixing the criteria name with '!'. This will negate the criteria, causing the match block to be evaluated only if the negated criteria all fail to match.
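As a sketch of how these options and match criteria can be combined, the example below writes a small OpenSSH-style client config and passes it to :func:`connect` via the ``config`` option. Only options and criteria from the lists above are used, and all host names, user names, and paths are hypothetical placeholders.

.. code-block:: python

   import asyncio

   import asyncssh

   # A minimal client config using only supported options and Match criteria
   CONFIG = '''\
   Host bastion.example.com
       User admin
       IdentityFile ~/.ssh/bastion_key

   Match host *.internal.example.com user deploy
       ProxyJump admin@bastion.example.com
       IdentityFile ~/.ssh/deploy_key
   '''

   async def main() -> None:
       with open('my_ssh_config', 'w') as f:    # hypothetical config path
           f.write(CONFIG)

       # Options from matching Host/Match blocks fill in anything not
       # passed explicitly to connect()
       async with asyncssh.connect('web1.internal.example.com',
                                   username='deploy',
                                   config=['my_ssh_config']) as conn:
           await conn.run('true', check=True)

   asyncio.run(main())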
The following client config token expansions are currently supported: .. table:: :align: left ===== ============================================================ Token Expansion ===== ============================================================ %% Literal '%' %C SHA-1 Hash of connection info (local host, host, port, user) %d Local user's home directory %h Remote host %i Local uid (UNIX-only) %L Short local hostname (without the domain) %l Local hostname (including the domain) %n Original remote host %p Remote port %r Remote username %u Local username ===== ============================================================ These expansions are available in the values of the following config options: | CertificateFile | IdentityAgent | IdentityFile | RemoteCommand .. index:: Supported server config options .. _SupportedServerConfigOptions: Supported server config options ------------------------------- The following OpenSSH server config options are currently supported: | AddressFamily | AuthorizedKeysFile | AllowAgentForwarding | BindAddress | CanonicalDomains | CanonicalizeFallbackLocal | CanonicalizeHostname | CanonicalizeMaxDots | CanonicalizePermittedCNAMEs | CASignatureAlgorithms | ChallengeResponseAuthentication | Ciphers | ClientAliveCountMax | ClientAliveInterval | Compression | GSSAPIAuthentication | GSSAPIKeyExchange | HostbasedAuthentication | HostCertificate | HostKey | KbdInteractiveAuthentication | KexAlgorithms | LoginGraceTime | MACs | PasswordAuthentication | PermitTTY | Port | ProxyCommand | PubkeyAuthentication | RekeyLimit | TCPKeepAlive | UseDNS For the "Match" conditional, the following criteria are currently supported: | All | Canonical | Exec | Final | Address | Host | LocalAddress | LocalPort | User .. warning:: When instantiating :class:`SSHServerConnectionOptions` objects manually within an asyncio task, you may block the event loop if the options refer to a config file with "Match Exec" directives which don't return immediate results. In such cases, the asyncio `run_in_executor()` function should be used. This is taken care of automatically when options objects are created by AsyncSSH APIs such as :func:`connect` and :func:`listen`. Match criteria can be negated by prefixing the criteria name with '!'. This will negate the criteria, causing the match block to be evaluated only if the negated criteria all fail to match. The following server config token expansions are currently supported: .. table:: :align: left ===== =========== Token Expansion ===== =========== %% Literal '%' %u Username ===== =========== These expansions are available in the values of the following config options: | AuthorizedKeysFile .. index:: Specifying byte counts .. _SpecifyingByteCounts: Specifying byte counts ---------------------- A byte count may be passed into AsyncSSH as an integer value, or as a string made up of a mix of numbers followed by an optional letter of 'k', 'm', or 'g', indicating kilobytes, megabytes, or gigabytes, respectively. Multiple values of this form can be combined. For instance, '2.5m' means 2.5 megabytes. This could also be expressed as '2m512k' or '2560k'. .. index:: Specifying time intervals .. _SpecifyingTimeIntervals: Specifying time intervals ------------------------- A time interval may be passed into AsyncSSH as an integer or float value, or as a string made up of a mix of positive or negative numbers and the letters 'w', 'd', 'h', 'm', and 's', indicating weeks, days, hours, minutes, or seconds, respectively. Multiple values of this form can be combined. For instance, '1w2d3h' means 1 week, 2 days, and 3 hours. .. index:: Known hosts ..
_KnownHosts: Known Hosts =========== AsyncSSH supports OpenSSH-style known_hosts files, including both plain and hashed host entries. Regular and negated host patterns are supported in plain entries. AsyncSSH also supports the ``@cert-authority`` marker to indicate keys and certificates which should be trusted as certificate authorities and the ``@revoked`` marker to indicate keys and certificates which should be explicitly reported as no longer trusted. .. index:: Specifying known hosts .. _SpecifyingKnownHosts: Specifying known hosts ---------------------- Known hosts may be passed into AsyncSSH via the ``known_hosts`` argument to :func:`create_connection`. This can be the name of a file or list of files containing known hosts, a byte string containing data in known hosts format, or an :class:`SSHKnownHosts` object which was previously imported from a string by calling :func:`import_known_hosts` or read from files by calling :func:`read_known_hosts`. In all of these cases, the host patterns in the list will be compared against the target host, address, and port being connected to and the matching trusted host keys, trusted CA keys, revoked keys, trusted X.509 certificates, revoked X.509 certificates, trusted X.509 subject names, and revoked X.509 subject names will be returned. Alternately, a function can be passed in as the ``known_hosts`` argument that accepts a target host, address, and port and returns lists containing trusted host keys, trusted CA keys, revoked keys, trusted X.509 certificates, revoked X.509 certificates, trusted X.509 subject names, and revoked X.509 subject names. If no matching is required and the caller already knows exactly what the above values should be, these seven lists can also be provided directly in the ``known_hosts`` argument. See :ref:`SpecifyingPublicKeys` for the allowed form of public key values which can be provided, :ref:`SpecifyingCertificates` for the allowed form of certificates, and :ref:`SpecifyingX509Subjects` for the allowed form of X.509 subject names. Known hosts classes/functions ----------------------------- .. autoclass:: SSHKnownHosts() ===================== = .. automethod:: match ===================== = .. autofunction:: import_known_hosts .. autofunction:: read_known_hosts .. autofunction:: match_known_hosts .. index:: Authorized keys .. _AuthorizedKeys: Authorized Keys =============== AsyncSSH supports OpenSSH-style authorized_keys files, including the cert-authority option to validate user certificates, enforcement of from and principals options to restrict key matching, enforcement of no-X11-forwarding, no-agent-forwarding, no-pty, no-port-forwarding, and permitopen options, and support for command and environment options. .. index:: Specifying authorized keys .. _SpecifyingAuthorizedKeys: Specifying authorized keys -------------------------- Authorized keys may be passed into AsyncSSH via the ``authorized_client_keys`` argument to :func:`create_server` or by calling :meth:`set_authorized_keys() ` on the :class:`SSHServerConnection` from within the :meth:`begin_auth() ` method in :class:`SSHServer`. Authorized keys can be provided as either the name of a file or list of files to read authorized keys from or an :class:`SSHAuthorizedKeys` object which was previously imported from a string by calling :func:`import_authorized_keys` or read from files by calling :func:`read_authorized_keys`. An authorized keys file may contain public keys or X.509 certificates in OpenSSH format or X.509 certificate subject names.
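As a sketch of loading per-user authorized keys from within ``begin_auth()``, as described above, the example below looks up a key file based on the username being authenticated. The listening port, host key path, and per-user key directory are hypothetical placeholders, and error handling is kept to a minimum.

.. code-block:: python

   import asyncio

   import asyncssh

   class MySSHServer(asyncssh.SSHServer):
       def connection_made(self, conn: asyncssh.SSHServerConnection) -> None:
           self._conn = conn

       def begin_auth(self, username: str) -> bool:
           try:
               # Hypothetical per-user authorized_keys file
               self._conn.set_authorized_keys(f'keys/{username}')
           except OSError:
               pass    # no key file for this user

           return True    # authentication is required

   async def main() -> None:
       await asyncssh.create_server(MySSHServer, '', 8022,
                                    server_host_keys=['ssh_host_key'])

       await asyncio.Event().wait()    # keep serving until cancelled

   asyncio.run(main())

The file loaded this way uses the same authorized_keys format described in this section.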
See :ref:`SpecifyingX509Subjects` for more information on using subject names in place of specific X.509 certificates. Authorized keys classes/functions --------------------------------- .. autoclass:: SSHAuthorizedKeys() .. autofunction:: import_authorized_keys .. autofunction:: read_authorized_keys .. index:: Logging .. _Logging: Logging ======= AsyncSSH supports logging through the standard Python `logging` package. Logging is done under the logger named `'asyncssh'` as well as a child logger named `'asyncssh.sftp'` to allow different log levels to be set for SFTP related log messages. The base AsyncSSH log level can be set using the :func:`set_log_level` function and the SFTP log level can be set using the :func:`set_sftp_log_level` function. In addition, when either of these loggers is set to level DEBUG, AsyncSSH provides fine-grained control over the level of debug logging via the :func:`set_debug_level` function. AsyncSSH also provides logger objects as members of connection, channel, stream, and process objects that automatically log additional context about the connection or channel the logger is a member of. These objects can be used by application code to output custom log information associated with a particular connection or channel. Logger objects are also provided as members of SFTP client and server objects. .. autofunction:: set_log_level .. autofunction:: set_sftp_log_level .. autofunction:: set_debug_level .. index:: Exceptions .. _Exceptions: Exceptions ========== .. autoexception:: PasswordChangeRequired .. autoexception:: BreakReceived .. autoexception:: SignalReceived .. autoexception:: TerminalSizeChanged .. autoexception:: DisconnectError .. autoexception:: CompressionError .. autoexception:: ConnectionLost .. autoexception:: HostKeyNotVerifiable .. autoexception:: IllegalUserName .. autoexception:: KeyExchangeFailed .. autoexception:: MACError .. autoexception:: PermissionDenied .. autoexception:: ProtocolError .. autoexception:: ProtocolNotSupported .. autoexception:: ServiceNotAvailable .. autoexception:: ChannelOpenError .. autoexception:: ChannelListenError .. autoexception:: ProcessError .. autoexception:: TimeoutError .. autoexception:: SFTPError .. autoexception:: SFTPEOFError .. autoexception:: SFTPNoSuchFile .. autoexception:: SFTPPermissionDenied .. autoexception:: SFTPFailure .. autoexception:: SFTPBadMessage .. autoexception:: SFTPNoConnection .. autoexception:: SFTPConnectionLost .. autoexception:: SFTPOpUnsupported .. autoexception:: SFTPInvalidHandle .. autoexception:: SFTPNoSuchPath .. autoexception:: SFTPFileAlreadyExists .. autoexception:: SFTPWriteProtect .. autoexception:: SFTPNoMedia .. autoexception:: SFTPNoSpaceOnFilesystem .. autoexception:: SFTPQuotaExceeded .. autoexception:: SFTPUnknownPrincipal .. autoexception:: SFTPLockConflict .. autoexception:: SFTPDirNotEmpty .. autoexception:: SFTPNotADirectory .. autoexception:: SFTPInvalidFilename .. autoexception:: SFTPLinkLoop .. autoexception:: SFTPCannotDelete .. autoexception:: SFTPInvalidParameter .. autoexception:: SFTPFileIsADirectory .. autoexception:: SFTPByteRangeLockConflict .. autoexception:: SFTPByteRangeLockRefused .. autoexception:: SFTPDeletePending .. autoexception:: SFTPFileCorrupt .. autoexception:: SFTPOwnerInvalid .. autoexception:: SFTPGroupInvalid .. autoexception:: SFTPNoMatchingByteRangeLock .. autoexception:: KeyImportError .. autoexception:: KeyExportError .. autoexception:: KeyEncryptionError .. autoexception:: KeyGenerationError .. 
autoexception:: ConfigParseError .. index:: Supported algorithms .. _SupportedAlgorithms: Supported Algorithms ==================== Algorithms can be specified as either a list of exact algorithm names or as a string of comma-separated algorithm names that may optionally include wildcards. An '*' in a name matches zero or more characters and a '?' matches exactly one character. When specifying algorithms as a string, it can also be prefixed with '^' to insert the matching algorithms in front of the default algorithms of that type, a '+' to insert the matching algorithms after the default algorithms, or a '-' to return the default algorithms with the matching algorithms removed. .. index:: Key exchange algorithms .. _KexAlgs: Key exchange algorithms ----------------------- The following are the default key exchange algorithms currently supported by AsyncSSH: | gss-curve25519-sha256 | gss-curve448-sha512 | gss-nistp521-sha512 | gss-nistp384-sha384 | gss-nistp256-sha256 | gss-1.3.132.0.10-sha256 | gss-gex-sha256 | gss-group14-sha256 | gss-group15-sha512 | gss-group16-sha512 | gss-group17-sha512 | gss-group18-sha512 | gss-group14-sha1 | mlkem768x25519-sha256 | mlkem768nistp256-sha256 | mlkem1024nistp384-sha384 | sntrup761x25519-sha512 | sntrup761x25519-sha512\@openssh.com | curve25519-sha256 | curve25519-sha256\@libssh.org | curve448-sha512 | ecdh-sha2-nistp521 | ecdh-sha2-nistp384 | ecdh-sha2-nistp256 | ecdh-sha2-1.3.132.0.10 | diffie-hellman-group-exchange-sha256 | diffie-hellman-group14-sha256 | diffie-hellman-group15-sha512 | diffie-hellman-group16-sha512 | diffie-hellman-group17-sha512 | diffie-hellman-group18-sha512 | diffie-hellman-group14-sha256\@ssh.com | diffie-hellman-group14-sha1 | rsa2048-sha256 The following key exchange algorithms are supported by AsyncSSH, but disabled by default: | gss-gex-sha1 | gss-group1-sha1 | diffie-hellman-group-exchange-sha224\@ssh.com | diffie-hellman-group-exchange-sha384\@ssh.com | diffie-hellman-group-exchange-sha512\@ssh.com | diffie-hellman-group-exchange-sha1 | diffie-hellman-group14-sha224\@ssh.com | diffie-hellman-group15-sha256\@ssh.com | diffie-hellman-group15-sha384\@ssh.com | diffie-hellman-group16-sha384\@ssh.com | diffie-hellman-group16-sha512\@ssh.com | diffie-hellman-group18-sha512\@ssh.com | diffie-hellman-group1-sha1 | rsa1024-sha1 GSS authentication support is only available when the gssapi package is installed on UNIX or the pywin32 package is installed on Windows. Curve25519 and Curve448 support is available when OpenSSL 1.1.1 or later is installed. Alternately, Curve25519 is available when the libnacl package and libsodium library are installed. SNTRUP support is available when the Open Quantum Safe (liboqs) dynamic library is installed. .. index:: Encryption algorithms .. _EncryptionAlgs: Encryption algorithms --------------------- The following are the default encryption algorithms currently supported by AsyncSSH: | chacha20-poly1305\@openssh.com | aes256-gcm\@openssh.com | aes128-gcm\@openssh.com | aes256-ctr | aes192-ctr | aes128-ctr The following encryption algorithms are supported by AsyncSSH, but disabled by default: | aes256-cbc | aes192-cbc | aes128-cbc | 3des-cbc | blowfish-cbc | cast128-cbc | seed-cbc\@ssh.com | arcfour256 | arcfour128 | arcfour Chacha20-Poly1305 support is available when either OpenSSL 1.1.1b or later or the libnacl package and libsodium library are installed. .. index:: MAC algorithms .. 
_MACAlgs: MAC algorithms -------------- The following are the default MAC algorithms currently supported by AsyncSSH: | umac-64-etm\@openssh.com | umac-128-etm\@openssh.com | hmac-sha2-256-etm\@openssh.com | hmac-sha2-512-etm\@openssh.com | hmac-sha1-etm\@openssh.com | umac-64\@openssh.com | umac-128\@openssh.com | hmac-sha2-256 | hmac-sha2-512 | hmac-sha1 | hmac-sha256-2\@ssh.com | hmac-sha224\@ssh.com | hmac-sha256\@ssh.com | hmac-sha384\@ssh.com | hmac-sha512\@ssh.com The following MAC algorithms are supported by AsyncSSH, but disabled by default: | hmac-md5-etm\@openssh.com | hmac-sha2-256-96-etm\@openssh.com | hmac-sha2-512-96-etm\@openssh.com | hmac-sha1-96-etm\@openssh.com | hmac-md5-96-etm\@openssh.com | hmac-md5 | hmac-sha2-256-96 | hmac-sha2-512-96 | hmac-sha1-96 | hmac-md5-96 UMAC support is only available when the nettle library is installed. .. index:: Compression algorithms .. _CompressionAlgs: Compression algorithms ---------------------- The following are the default compression algorithms currently supported by AsyncSSH: | zlib\@openssh.com | none The following compression algorithms are supported by AsyncSSH, but disabled by default: | zlib .. index:: Signature algorithms .. _SignatureAlgs: Signature algorithms -------------------- The following are the default public key signature algorithms currently supported by AsyncSSH: | x509v3-ssh-ed25519 | x509v3-ssh-ed448 | x509v3-ecdsa-sha2-nistp521 | x509v3-ecdsa-sha2-nistp384 | x509v3-ecdsa-sha2-nistp256 | x509v3-ecdsa-sha2-1.3.132.0.10 | x509v3-rsa2048-sha256 | x509v3-ssh-rsa | sk-ssh-ed25519\@openssh.com | sk-ecdsa-sha2-nistp256\@openssh.com | webauthn-sk-ecdsa-sha2-nistp256\@openssh.com | ssh-ed25519 | ssh-ed448 | ecdsa-sha2-nistp521 | ecdsa-sha2-nistp384 | ecdsa-sha2-nistp256 | ecdsa-sha2-1.3.132.0.10 | rsa-sha2-256 | rsa-sha2-512 | ssh-rsa-sha224\@ssh.com | ssh-rsa-sha256\@ssh.com | ssh-rsa-sha384\@ssh.com | ssh-rsa-sha512\@ssh.com | ssh-rsa The following public key signature algorithms are supported by AsyncSSH, but disabled by default: | x509v3-ssh-dss | ssh-dss .. index:: Public key & certificate algorithms .. _PublicKeyAlgs: Public key & certificate algorithms ----------------------------------- The following are the default public key and certificate algorithms currently supported by AsyncSSH: | x509v3-ssh-ed25519 | x509v3-ssh-ed448 | x509v3-ecdsa-sha2-nistp521 | x509v3-ecdsa-sha2-nistp384 | x509v3-ecdsa-sha2-nistp256 | x509v3-ecdsa-sha2-1.3.132.0.10 | x509v3-rsa2048-sha256 | x509v3-ssh-rsa | sk-ssh-ed25519-cert-v01\@openssh.com | sk-ecdsa-sha2-nistp256-cert-v01\@openssh.com | ssh-ed25519-cert-v01\@openssh.com | ssh-ed448-cert-v01\@openssh.com | ecdsa-sha2-nistp521-cert-v01\@openssh.com | ecdsa-sha2-nistp384-cert-v01\@openssh.com | ecdsa-sha2-nistp256-cert-v01\@openssh.com | ecdsa-sha2-1.3.132.0.10-cert-v01\@openssh.com | rsa-sha2-256-cert-v01\@openssh.com | rsa-sha2-512-cert-v01\@openssh.com | ssh-rsa-cert-v01\@openssh.com | sk-ssh-ed25519\@openssh.com | sk-ecdsa-sha2-nistp256\@openssh.com | ssh-ed25519 | ssh-ed448 | ecdsa-sha2-nistp521 | ecdsa-sha2-nistp384 | ecdsa-sha2-nistp256 | ecdsa-sha2-1.3.132.0.10 | rsa-sha2-256 | rsa-sha2-512 | ssh-rsa-sha224\@ssh.com | ssh-rsa-sha256\@ssh.com | ssh-rsa-sha384\@ssh.com | ssh-rsa-sha512\@ssh.com | ssh-rsa The following public key and certificate algorithms are supported by AsyncSSH, but disabled by default: | x509v3-ssh-dss | ssh-dss-cert-v01\@openssh.com | ssh-dss Ed25519 and Ed448 support is available when OpenSSL 1.1.1b or later is installed. 
Alternately, Ed25519 is available when the libnacl package and libsodium library are installed. .. index:: Constants .. _Constants: Constants ========= .. index:: Disconnect reasons .. _DisconnectReasons: Disconnect reasons ------------------ The following values defined in section 11.1 of :rfc:`4253#section-11.1` can be specified as disconnect reason codes: | DISC_HOST_NOT_ALLOWED_TO_CONNECT | DISC_PROTOCOL_ERROR | DISC_KEY_EXCHANGE_FAILED | DISC_RESERVED | DISC_MAC_ERROR | DISC_COMPRESSION_ERROR | DISC_SERVICE_NOT_AVAILABLE | DISC_PROTOCOL_VERSION_NOT_SUPPORTED | DISC_HOST_KEY_NOT_VERIFIABLE | DISC_CONNECTION_LOST | DISC_BY_APPLICATION | DISC_TOO_MANY_CONNECTIONS | DISC_AUTH_CANCELLED_BY_USER | DISC_NO_MORE_AUTH_METHODS_AVAILABLE | DISC_ILLEGAL_USER_NAME .. index:: Channel open failure reasons .. _ChannelOpenFailureReasons: Channel open failure reasons ---------------------------- The following values defined in section 5.1 of :rfc:`4254#section-5.1` can be specified as channel open failure reason codes: | OPEN_ADMINISTRATIVELY_PROHIBITED | OPEN_CONNECT_FAILED | OPEN_UNKNOWN_CHANNEL_TYPE | OPEN_RESOURCE_SHORTAGE In addition, AsyncSSH defines the following channel open failure reason codes: | OPEN_REQUEST_X11_FORWARDING_FAILED | OPEN_REQUEST_PTY_FAILED | OPEN_REQUEST_SESSION_FAILED .. index:: SFTP error codes .. _SFTPErrorCodes: SFTP error codes ---------------- The following values defined in section 9.1 of the `SSH File Transfer Protocol Internet Draft `_ can be specified as SFTP error codes: .. table:: :align: left =============================== ==================== Error code Minimum SFTP version =============================== ==================== FX_OK 3 FX_EOF 3 FX_NO_SUCH_FILE 3 FX_PERMISSION_DENIED 3 FX_FAILURE 3 FX_BAD_MESSAGE 3 FX_NO_CONNECTION 3 FX_CONNECTION_LOST 3 FX_OP_UNSUPPORTED 3 FX_INVALID_HANDLE 4 FX_NO_SUCH_PATH 4 FX_FILE_ALREADY_EXISTS 4 FX_WRITE_PROTECT 4 FX_NO_MEDIA 4 FX_NO_SPACE_ON_FILESYSTEM 5 FX_QUOTA_EXCEEDED 5 FX_UNKNOWN_PRINCIPAL 5 FX_LOCK_CONFLICT 5 FX_DIR_NOT_EMPTY 6 FX_NOT_A_DIRECTORY 6 FX_INVALID_FILENAME 6 FX_LINK_LOOP 6 FX_CANNOT_DELETE 6 FX_INVALID_PARAMETER 6 FX_FILE_IS_A_DIRECTORY 6 FX_BYTE_RANGE_LOCK_CONFLICT 6 FX_BYTE_RANGE_LOCK_REFUSED 6 FX_DELETE_PENDING 6 FX_FILE_CORRUPT 6 FX_OWNER_INVALID 6 FX_GROUP_INVALID 6 FX_NO_MATCHING_BYTE_RANGE_LOCK 6 =============================== ==================== .. index:: Extended data types .. _ExtendedDataTypes: Extended data types ------------------- The following values defined in section 5.2 of :rfc:`4254#section-5.2` can be specified as SSH extended channel data types: | EXTENDED_DATA_STDERR .. index:: POSIX terminal modes .. 
_PTYModes: POSIX terminal modes -------------------- The following values defined in section 8 of :rfc:`4254#section-8` can be specified as PTY mode opcodes: | PTY_OP_END | PTY_VINTR | PTY_VQUIT | PTY_VERASE | PTY_VKILL | PTY_VEOF | PTY_VEOL | PTY_VEOL2 | PTY_VSTART | PTY_VSTOP | PTY_VSUSP | PTY_VDSUSP | PTY_VREPRINT | PTY_WERASE | PTY_VLNEXT | PTY_VFLUSH | PTY_VSWTCH | PTY_VSTATUS | PTY_VDISCARD | PTY_IGNPAR | PTY_PARMRK | PTY_INPCK | PTY_ISTRIP | PTY_INLCR | PTY_IGNCR | PTY_ICRNL | PTY_IUCLC | PTY_IXON | PTY_IXANY | PTY_IXOFF | PTY_IMAXBEL | PTY_ISIG | PTY_ICANON | PTY_XCASE | PTY_ECHO | PTY_ECHOE | PTY_ECHOK | PTY_ECHONL | PTY_NOFLSH | PTY_TOSTOP | PTY_IEXTEN | PTY_ECHOCTL | PTY_ECHOKE | PTY_PENDIN | PTY_OPOST | PTY_OLCUC | PTY_ONLCR | PTY_OCRNL | PTY_ONOCR | PTY_ONLRET | PTY_CS7 | PTY_CS8 | PTY_PARENB | PTY_PARODD | PTY_OP_ISPEED | PTY_OP_OSPEED asyncssh-2.20.0/docs/changes.rst000066400000000000000000003457561475467777400166000ustar00rootroot00000000000000.. currentmodule:: asyncssh Change Log ========== Release 2.20.0 (17 Feb 2025) ---------------------------- * Added support for specifying an explicit path when configuring agent forwarding. Thanks go to Aleksandr Ilin for pointing out that this options supports more than just a boolean value. * Added support for environment variable expansion in SSH config, for options which support percent expansion. * Added a new begin_auth callback in SSHClient, reporting the username being sent during SSH client authentication. This can be useful when the user is conditionally set via an SSH config file. * Improved strict-kex interoperability during re-keying. Thanks go to GitHub user emeryalden for reporting this issue and helping to track down the source of the problem. * Updated SFTP max_requests default to reduce memory usage when using large block sizes. * Updated testing to add Python 3.13 and drop Python 3.7, avoiding deprecation warnings from the cryptography package. * Fixed unit test issues under Windows, allowing unit tests to run on Windows on all supported versions of Python. * Fixed a couple of issues with Python 3.14. Thanks go to Georg Sauthoff for initially reporting this. Release 2.19.0 (12 Dec 2024) ---------------------------- * Added support for WebAuthN authentication with U2F security keys, allowing non-admin Windows users to use these keys for authentication. Previously, authentication with U2F keys worked on Windows, but only for admin users. * Added support for hostname canonicalization, compatible with the configuration parameters used in OpenSSH, as well as support for the "canonical" and "final" match keywords and negation support for match. Thanks go to GitHub user commonism who suggested this and provided a proposed implementation for negation. * Added client and server support for SFTP copy-data extension and a new SFTP remote_copy() function which allows data to be moved between two remote files without downloading and re-uploading the data. Thanks go to Ali Khosravi for suggesting this addition. * Moved project metadata from setup.py to pyproject.toml. Thanks go to Marc Mueller for contributing this. * Updated SSH connection to keep strong references to outstanding tasks, to avoid potential issues with the garbage collector while the connection is active. Thanks go to GitHub user Birnendampf for pointing out this potential issue and suggesting a simple fix. * Fixed some issues with block_size argument in SFTP copy functions. Thanks go to Krzysztof Kotlenga for finding and reporting these issues. 
* Fixed an import error when fido2 package wasn't available. Thanks go to GitHub user commonism for reporting this issue. Release 2.18.0 (26 Oct 2024) ---------------------------- * Added support for post-quantum ML-KEM key exchange algorithms, interoperable with OpenSSH 9.9. * Added support for the OpenSSH "limits" extension, allowing the client to query server limits such as the maximum supported read and write sizes. The client will automatically default to the reported maximum size on servers that support this extension. * Added more ways to specify environment variables via the `env` option. Sequences of either 'key=value' strings or (key, value) tuples are now supported, in addition to a dict. * Added support for getting/setting environment variables as byte strings on platforms which support it. Previously, only Unicode strings were accepted and they were always encoded on the wire using UTF-8. * Added support for non-TCP sockets (such as a socketpair) as the `sock` parameter in connect calls. Thanks go to Christian Wendt for reporting this problem and proposing a fix. * Changed compression to be disabled by default to avoid it becoming a performance bottleneck on high-bandwidth connections. This now also matches the OpenSSH default. * Improved speed of parallelized SFTP reads when read-ahead goes beyond the end of the file. Thanks go to Maximilian Knespel for reporting this issue and providing performance measurements on the code before and after the change. * Improved cancellation handling during SCP transfers. * Improved support for selecting the currently available security key when the application lists multiple keys to try. Thanks go to GitHub user zanda8893 for reporting the issue and helping to work out the details of the problem. * Improved handling of reverse DNS failures in host-based authentication. Thanks go to GitHub user xBiggs for suggesting this change. * Improved debug logging of byte strings with non-printable characters. * Switched to using an executor on GSSAPI calls to avoid blocking the event loop. * Fixed handling of "UserKnownHostsFile none" in config files. This previously caused it to use the default known hosts, rather than disabling known host checking. * Fixed a runtime warning about not awaiting a coroutine in unit tests. * Fixed a unit test failure on Windows when calling abort on a transport. * Fixed a problem where a "MAC verification failed" error was sometimes sent on connection close. * Fixed SSHClientProcess code to not raise a runtime exception when waiting more than once for a process to finish. Thanks go to GitHub user starflows for reporting this issue. * Handled an error when attempting to import older verions of pyOpenSSL. Thanks go to Maximilian Knespel for reporting this issue and testing the fix. * Updated simple_server example code to switch from crypt to bcrypt, since crypt has been removed in Python 3.13. Thanks go to Colin Watson for providing this update. Release 2.17.0 (2 Sep 2024) --------------------------- * Added support for specifying a per-connection credential store for GSSAPI authentication. Thanks go to GitHub user zarganum for suggesting this feature and proposing a detailed design. * Fixed a regression introduced in AsyncSSH 2.15.0 which could cause connections to be closed with an uncaught exception when a session on the connection was closed. Thanks go to Wilson Conley for being the first to help reproduce this issue, and others who also helped to confirm the fix. 
* Added a workaround where getaddrinfo() on some systems may return duplicate entries, causing bind() to fail when opening a listener. Thanks go to Colin Watson for reporting this issue and suggesting a fix. * Relaxed padding length check on OpenSSH private keys to provide better compatibility with keys generated by PuTTYgen. * Improved documentation on SSHClient and SSHServer classes to explain when they are created and their relationship to the SSHClientConnection and SSHServerConnection classes. * Updated examples to use Python 3.7 and made some minor improvements. Release 2.16.0 (17 Aug 2024) ---------------------------- * Added client and server support for the OpenSSH "hostkeys" extension. When using known_hosts, clients can provide a handler which will be called with the changes between the keys currently trusted in the client's known hosts and those available on the server. On the server side, an application can choose whether or not to enable the sending of this host key information. Thanks go to Matthijs Kooijman for getting me to take another look at how this might be supported. * Related to the above, AsyncSSH now allows the configuration of multiple server host keys of the same type when the send_server_host_keys option is enabled. Only the first key of each type will be used in the SSH handshake, but the others can appear in the list of supported host keys for clients to begin trusting, allowing for smoother key rotation. * Fixed logging and typing issues in SFTP high-level copy functions. A mix of bytes, str, and PurePath entries are now supported in places where a list of file paths is allowed, and the type signatures have been updated to reflect that the functions accept either a single path or a list of paths. Thanks go to GitHub user eyalgolan1337 for reporting these issues. * Improved typing on SFTP listdir() function. Thanks go to Tim Stumbaugh for contributing this change. * Reworked the config file parser to improve on a previous fix related to handling key/value pairs with an equals delimiter. * Improved handling of ciphers deprecated in cryptography 43.0.0. Thanks go to Guillaume Mulocher for reporting this issue. * Improved support for use of Windows pathnames in ProxyCommand. Thanks go to GitHub user chipolux for reporting this issue and investigating the existing OpenSSH parsing behavior. Release 2.15.0 (3 Jul 2024) --------------------------- * Added experimental support for tunneling of TUN/TAP network interfaces on Linux and macOS, allowing for either automatic packet forwarding or explicit reading and writing of packets sent through the tunnel by the application. Both callback and stream APIs are available. * Added support for forwarding terminal size and terminal size changes when stdin on an SSHServerProcess is redirected to a local TTY. * Added support for multiple tunnel/ProxyJump hosts. Thanks go to Adam Martin for suggesting this enhancement and proposing a solution. * Added support for OpenSSH lsetstat SFTP extension to set attributes on symbolic links on platforms which support that and use it to improve symlink handling in the SFTP get, put, and copy methods. In addition, a follow_symlinks option has been added on various SFTPClient methods which get and set these attributes. Thanks go to GitHub user eyalgolan1337 for reporting this issue. * Added support for password and passphrase arguments to be a callable or awaitable, called when performing authentication or loading encrypted private keys. 
  Thanks go to GitHub user goblin for suggesting this enhancement.

* Added support for proper flow control when using AsyncFileWriter or StreamWriter classes to do SSH process redirection. Thanks go to Benjy Wiener for reporting this issue and providing feedback on the fix.

* Added an is_closed() method to SSHClientConnection/SSHServerConnection to return whether the associated network connection is closed or not.

* Added support for setting and matching tags in OpenSSH config files.

* Added an example of using "await" in addition to "async with" when opening a new SSHClientConnection. Thanks go to Michael Davis for suggesting this added documentation.

* Improved handling of CancelledError in SCP, avoiding an issue where AsyncSSH could sometimes get stuck waiting for the channel to close. Thanks go to Max Orlov for reporting the problem and providing code to reproduce it.

* Fixed a regression from 2.14.1 related to rekeying an SSH connection when there's activity on the connection in the middle of rekeying. Thanks go to GitHub user eyalgolan1337 for helping to narrow down this problem and test the fix.

* Fixed a problem with process redirection when a close is received without a preceding EOF. Thanks go to GitHub user xuoguoto who helped to provide sample scripts and ran tests to help track this down.

* Fixed the processing of paths in SFTP client symlink requests. Thanks go to André Glüpker for reporting the problem and providing test code to demonstrate it.

* Fixed an OpenSSH config file parsing issue. Thanks go to Siddh Raman Pant for reporting this issue.

* Worked around a bug in a user auth banner generated by the cryptlib library. Thanks go to GitHub user mmayomoar for reporting this issue and suggesting a fix.

Release 2.14.2 (18 Dec 2023)
----------------------------

* Implemented "strict kex" support and other countermeasures to protect against the Terrapin Attack described in `CVE-2023-48795 `_. Thanks once again go to Fabian Bäumer, Marcus Brinkmann, and Jörg Schwenk for identifying and reporting this vulnerability and providing detailed analysis and suggestions about proposed fixes.

* Fixed config parser to properly handle an optional equals delimiter in all config arguments. Thanks go to Fawaz Orabi for reporting this issue.

* Fixed TCP send error handling to avoid a race condition when receiving an incoming disconnect message.

* Improved type signature in SSHConnection async context manager. Thanks go to Pieter-Jan Briers for providing this.

Release 2.14.1 (8 Nov 2023)
---------------------------

* Hardened AsyncSSH state machine against potential message injection attacks, described in more detail in `CVE-2023-46445 `_ and `CVE-2023-46446 `_. Thanks go to Fabian Bäumer, Marcus Brinkmann, and Jörg Schwenk for identifying and reporting these vulnerabilities and providing detailed analysis and suggestions about the proposed fixes.

* Added support for passing in a regex in readuntil in SSHReader, contributed by Oded Engel.

* Added support for get_addresses() and get_port() methods on SSHAcceptor. Thanks go to Allison Karlitskaya for suggesting this feature.

* Fixed an issue with AsyncFileWriter potentially writing data out of order. Thanks go to Chan Chun Wai for reporting this issue and providing code to reproduce it.

* Updated testing to include Python 3.12.

* Updated readthedocs integration to use YAML config file.
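
As a brief illustration of the readuntil() enhancement in release 2.14.1 above, the sketch below assumes a compiled regular expression can be passed as the separator; the host and prompt pattern are placeholders::

    import asyncio
    import re
    import asyncssh

    async def main() -> None:
        # Placeholder host; host key checking disabled only for brevity
        async with asyncssh.connect('example.com', known_hosts=None) as conn:
            async with conn.create_process(term_type='xterm') as process:
                # Wait for a '$' or '#' shell prompt using a regex separator
                prompt = re.compile(r'[$#] $')
                output = await process.stdout.readuntil(prompt)
                print(output)

    asyncio.run(main())
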

Release 2.14.0 (30 Sep 2023)
----------------------------

* Added support for a new accept_handler argument when setting up local port forwarding, allowing the client host and port to be validated and/or logged for each new forwarded connection. An accept handler can also be returned from the server_requested function to provide this functionality when acting as a server. Thanks go to GitHub user zgxkbtl for suggesting this feature.

* Added an option to disable expensive RSA private key checks when using OpenSSL 3.x. Functions that read private keys have been modified to include a new unsafe_skip_rsa_key_validation argument which can be used to avoid these additional checks, if you are loading keys from a trusted source.

* Added host information into AsyncSSH exceptions when host key validation fails, and a few other improvements related to X.509 certificate validation errors. Thanks go to Peter Moore for suggesting this and providing an example.

* Fixed a regression which prevented keys loaded into an SSH agent with a certificate from working correctly in AsyncSSH releases after version 2.5.0. Thanks go to GitHub user htol for reporting this issue and suggesting the commit which caused the problem.

* Fixed an issue which was triggering an internal exception when shutting down server sessions with the line editor enabled, which could cause some output to be lost on exit, especially when running on Windows. Thanks go to GitHub user jerrbe for reporting this issue.

* Fixed an issue in a unit test seen in Python 3.12 beta. Thanks go to Georg Sauthoff for providing this fix.

* Fixed a documentation error in SSHClientConnectionOptions and SSHServerConnectionOptions. Thanks go to GitHub user bowenerchen for reporting this issue.

Release 2.13.2 (21 Jun 2023)
----------------------------

* Fixed an issue with host-based authentication when using proxy_command, allowing it to be used if the caller explicitly specifies client_host. Thanks go to GitHub user yuqingm7 for reporting this issue.

* Improved handling of signature algorithms for OpenSSH certificates so that RSA SHA-2 signatures will work with both older and newer versions of OpenSSH.

* Worked around an issue with some Cisco SSH implementations generating invalid "ignore" packets. Thanks go to Jost Luebbe for reporting and helping to debug this issue.

* Fixed unit tests to avoid errors when cryptography's version of OpenSSL disables support for SHA-1 signatures.

* Fixed unit tests to avoid errors when the filesystem enforces that filenames be valid UTF-8 strings. Thanks go to Robert Schütz and Martin Weinelt for reporting this issue.

* Added documentation about which config options apply when passing a string as a tunnel argument.

Release 2.13.1 (18 Feb 2023)
----------------------------

* Updated type definitions for mypy 1.0.0, removing a dependency on implicit Optional types, and working around an issue that could trigger a mypy internal error.

* Updated unit tests to avoid calculation of SHA-1 signatures, which are no longer allowed in cryptography 39.0.0.

Release 2.13.0 (27 Dec 2022)
----------------------------

* Updated testing and coverage to drop Python 3.6 and add Python 3.11. Thanks go to GitHub user hexchain for maintaining the GitHub workflows supporting this!

* Added new "recv_eof" option to not pass an EOF from a channel to a redirected target, allowing output from multiple SSH sessions to be sent and mixed with other direct output to that target.
  This is meant to be similar to the existing "send_eof" option which controls whether EOF on a redirect source is passed through to the SSH channel. Thanks go to Stuart Reynolds for inspiring this idea.

* Added new methods to make it easy to perform forwarding between TCP ports and UNIX domain sockets. Thanks go to Alex Rogozhnikov for suggesting this use case.

* Added a workaround for a problem seen on a Huawei SFTP server where it sends an invalid combination of file attribute flags. In cases where the flags are otherwise valid and the right amount of attribute data is available, AsyncSSH will ignore the invalid flags and proceed.

* Fixed an issue with copying files to SFTP servers that don't support random access I/O. The potential to trigger this failure goes back several releases, but a change in AsyncSSH 2.12 made out-of-order writes much more likely. This fix returns AsyncSSH to its previous behavior where out-of-order writes are unlikely even when taking advantage of parallel reads. Thanks go to Patrik Lindgren and Stefan Walkner for reporting this issue and helping to identify the source of the problem.

* Fixed an issue when requesting remote port forwarding on a dynamically allocated port. Thanks go to Daniel Shimon for reporting this and proposing a fix.

* Fixed an issue where readexactly could block indefinitely when a signal is delivered in the stream before the requested number of bytes are available. Thanks go to Artem Bezborodko for reporting this and providing a fix.

* Fixed an interoperability issue with OpenSSH when using SSH certificates with RSA keys with a SHA-2 signature. Thanks go to Łukasz Siudut for reporting this.

* Fixed an issue with handling "None" in ProxyCommand, GlobalKnownHostsFile, and UserKnownHostsFile config file options. Thanks go to GitHub user dtrifiro for reporting this issue and suggesting a fix.

Release 2.12.0 (10 Aug 2022)
----------------------------

* Added top-level functions run_client() and run_server() which allow you to begin running an SSH client or server on an already-connected socket. This capability is also available via a new "sock" argument in the existing connect(), connect_reverse(), get_server_host_key(), and get_server_auth_methods() functions.

* Added "sock" argument to listen() and listen_reverse() functions which takes an already-bound listening socket instead of a host and port to bind a new socket to.

* Added support for forwarding break, signal, and terminal size updates when redirection of stdin is set up between two SSHProcess instances.

* Added support for sntrup761x25519-sha512@openssh.com post-quantum key exchange algorithm. For this to be available, the Open Quantum Safe (liboqs) dynamic library must be installed.

* Added "sig_alg" argument to set a signature algorithm when creating OpenSSH certificates, allowing a choice between ssh-rsa, rsa-sha2-256, and rsa-sha2-512 for certificates signed by RSA keys.

* Added new read_parallel() method in SFTPClientFile which allows parallel reads to be performed from a remote file, delivering incremental results as these reads complete. Previously, large reads would automatically be parallelized, but a result was only returned after all reads completed.

* Added definition of __all__ for public symbols in AsyncSSH to make pyright autocompletion work better. Thanks go to Nicolas Riebesel for providing this change.

* Updated SFTP and SCP glob and copy functions to use scandir() instead of listdir() to improve efficiency.
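
The sketch below illustrates the "sock" argument added in release 2.12.0 above, starting an SSH client on an already-connected socket; the host, port, and the omission of an explicit host argument are assumptions of this example::

    import asyncio
    import socket
    import asyncssh

    async def main() -> None:
        # Any already-connected stream socket will do; a plain TCP
        # connection to a placeholder host is used here
        sock = socket.create_connection(('example.com', 22))

        # Hand the socket to connect() instead of letting it dial out itself
        async with asyncssh.connect(sock=sock, known_hosts=None) as conn:
            result = await conn.run('echo hello', check=True)
            print(result.stdout, end='')

    asyncio.run(main())
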
* Updated default for "ignore_encrypted" client connection option to ignore encrypted keys specified in an OpenSSH config file when no passphrase is provided, similar to what was previously done for keys with default names. * Fixed an issue when using an SSH agent with RSA keys and an X.509 certificate while requesting SHA-2 signatures. * Fixed an issue with use of expanduser() in unit tests on newer versions of Python. Thanks go to Georg Sauthoff for providing an initial version of this fix. * Fixed an issue with fallback to a Pageant agent not working properly on Windows when no agent_path or SSH_AUTH_SOCK was set. * Fixed improper escaping in readuntil(), causing certain punctuation in separator to not match properly. Thanks go to Github user MazokuMaxy for reporting this issue. * Fixed the connection close handler to properly mark channels as fully closed when the peer unexpected closes the connection, allowing exceptions to fire if an application continues to try and use the channel. Thanks go to Taha Jahangir for reporting this issue and suggesting a possible fix. * Eliminated unit testing against OpenSSH for tests involving DSA and RSA keys using SHA-1 signatures, since this support is being dropped in some distributions of OpenSSH. These tests are still performed, but using only AsyncSSH code. Thanks go to Ken Dreyer and Georg Sauthoff for reporting this issue and helping me to reproduce it. Release 2.11.0 (4 Jun 2022) --------------------------- * Made a number of improvements in SFTP glob support, with thanks to Github user LuckyDams for all the help working out these changes! * Added a new glob_sftpname() method which returns glob matches together with attribute information, avoiding the need for a caller to make separate calls to stat() on the returned results. * Switched from listdir() to scandir() to reduce the number of stat() operations required while finding matches. * Added code to remove duplicates when glob() is called with multiple patterns that match the same path. * Added a cache of directory listing and stat results to improve performance when matching patterns with overlapping paths. * Fixed an "index out of range" bug in recursive glob matching and aligned it better with results reeturned by UNIX shells. * Changed matching to ignore inaccessible or non-existent paths in a glob pattern, to allow accessible paths to be fully explored before returning an error. The error handler will now be called only if a pattern results in no matches, or if a more serious error occurs while scanning. * Changed SFTP makedirs() method to work better in cases where parts of requested path already exist but don't allow read access. As long as the entire path can be created, makedirs() will succeed, even if some directories on the path don't allow their contents to be read. Thanks go to Peter Rowlands for providing this fix. * Replaced custom Diffie Hellman implementation in AsyncSSH with the one in the cryptography package, resulting in an over 10x speedup. Thanks go to Github user iwanb for suggesting this change. * Fixed AsyncSSH to re-acquire GSS credentials when performing key renegotiation to avoid expired credentials on long-lived connections. Thanks go to Github user PromyLOPh for pointing out this issue and suggesting a fix. * Fixed GSS MIC to work properly with GSS key exchange when AsyncSSH is running as a server. This was previously fixed on the client side, but a similar fix for the server was missed. 
* Changed connection timeout unit tests to work better in environments where a firewall is present. Thanks go to Stefano Rivera for reporting this issue.

* Improved unit tests of Windows SSPI GSSAPI module.

* Improved speed of unit tests by reducing the number of key generation calls. RSA key generation in particular has gotten much more expensive in OpenSSL 3.

Release 2.10.1 (16 Apr 2022)
----------------------------

* Added a workaround for a bug in dropbear which can improperly reject full-sized data packets when compression is enabled. Thanks go to Matti Niemenmaa for reporting this issue and helping to reproduce it.

* Added support for "Match Exec" in config files and updated AsyncSSH API calls to do config parsing in an executor to avoid blocking the event loop if a "Match Exec" command doesn't return immediately.

* Fixed an issue where settings associated with server channels which were set when creating a listener, rather than at the time a new channel is opened, were not always being applied correctly.

* Fixed config file handling to be more consistent with OpenSSH, making all relative paths be evaluated relative to ~/.ssh and allowing references to config file patterns which don't match anything to only trigger a debug message rather than an error. Thanks go to Caleb Ho for reporting this issue!

* Updated minimum required version of cryptography package to 3.1, to allow calls to it to be made without passing in a "backend" argument. This was missed back in the 2.9 release. Thanks go to Github users sebby97 and JavaScriptDude for reporting this issue!

Release 2.10.0 (26 Mar 2022)
----------------------------

* Added new get_server_auth_methods() function which returns the set of auth methods available for a given user and SSH server.

* Added support for new line_echo argument when creating a server channel which controls whether input in the line editor is echoed to the output immediately or under the control of the application, allowing more control over the ordering of input and output.

* Added explicit support for RSA SHA-2 certificate algorithms. Previously, SHA-2 signatures were supported using the original ssh-rsa-cert-v01@openssh.com algorithm name, but recent versions of SSH now disable this algorithm by default, so the new SHA-2 algorithm names need to be advertised for SHA-2 signatures to work when using OpenSSH certificates.

* Improved handling of config file loading when options argument is used, allowing config loading to be overridden at connect() time even if the options passed in referenced a config file.

* Improved speed of unit tests by avoiding some network timeouts when connecting to invalid addresses.

* Merged GitHub workflows contributed by GitHub user hexchain to run unit tests and collect code coverage information on multiple platforms and Python versions. Thanks so much for this work!

* Fixed issue with GSS auth unit tests hanging on Windows.

* Fixed issue with known_hosts matching when ProxyJump is being used. Thanks go to GitHub user velavokr for reporting this and helping to debug it.

* Fixed type annotations for SFTP client and server open methods. Thanks go to Marat Sharafutdinov for reporting this!

Release 2.9.0 (23 Jan 2022)
---------------------------

* Added mypy-compatible type annotations to all AsyncSSH modules, and a "py.typed" file to signal that annotations are now available for this package.

* Added experimental support for SFTP versions 4-6.
  While AsyncSSH still defaults to only advertising version 3 when acting as both a client and a server, applications can explicitly enable support for later versions, which will be used if both ends of the connection agree. Not all features are fully supported, but a number of useful enhancements are now available, including users and groups specified by name, higher resolution timestamps, and more granular error reporting.

* Updated documentation to make it clear that keys from a PKCS11 provider or ssh-agent will be used even when client_keys is specified, unless those sources are explicitly disabled.

* Improved handling of task cancellation in AsyncSSH to avoid triggering an error of "Future exception was never retrieved". Thanks go to Krzysztof Kotlenga for reporting this issue and providing test code to reliably reproduce it.

* Changed implementation of OpenSSH keepalive handler to improve interoperability with servers which don't expect a "success" response when this message is sent.

Release 2.8.1 (8 Nov 2021)
--------------------------

* Fixed a regression in handling of the passphrase argument used to decrypt private keys.

Release 2.8.0 (3 Nov 2021)
--------------------------

* Added new connect_timeout option to set a timeout which includes the time taken to open an outbound TCP connection, allowing connections to be aborted without waiting for the default socket connect timeout. The existing login_timeout option only applies after the TCP connection was established, so it could not be used for this. The support for the ConnectTimeout config file option has also been updated to use this new capability, making it more consistent with OpenSSH's behavior.

* Added the ability for the passphrase argument specified in a connect call to be used to decrypt keys used to connect to bastion hosts. Previously, this argument was only applied when making a connection to the main host and encrypted keys could only be used when they were loaded separately.

* Updated AsyncSSH's "Record" class to make it more IDE-friendly when it comes to things like auto-completion. This class is used as a base class for SSHCompletedProcess and various SFTP attribute classes. Thanks go to Github user zentarim for suggesting this improvement.

* Fixed a potential uncaught exception when handling forwarded connections which are immediately closed by a peer.

Release 2.7.2 (15 Sep 2021)
---------------------------

* Fixed a regression related to server host key selection when attempting to use a leading '+' to add algorithms to the front of the default list.

* Fixed logging to properly handle SFTPName objects with string filenames.

* Fixed SSH_EXT_INFO to only be sent after the first key exchange.

Release 2.7.1 (6 Sep 2021)
--------------------------

* Added an option to allow encrypted keys to be ignored when no passphrase is set. This behavior previously happened by default when loading keys from default locations, but now this option to load_keypairs() can be specified when loading any set of keys.

* Changed loading of default keys to automatically skip key types which aren't supported due to missing dependencies.

* Added the ability to specify "default" for server_host_key_algs, as a way for a client to request that its full set of default algorithms be advertised to the server, rather than just the algorithms matching keys in the client's known hosts list. Thanks go to Manfred Kaiser for suggesting this improvement.

* Added support for tilde-expansion in the config file "include" directive.
  Thanks go to Zack Cerza for reporting this and suggesting a fix.

* Improved interoperability of AsyncSSH SOCKS listener by sending a zero address rather than an empty hostname in the SOCKS CONNECT response. Thanks go to Github user juouy for reporting this and suggesting a fix.

* Fixed a couple of issues related to sending SSH_EXT_INFO messages.

* Fixed an issue with using SSHAcceptor as an async context manager. Thanks go to Paulo Costa for reporting this.

* Fixed an issue where a tunnel wasn't always cleaned up properly when creating a remote listener.

* Improved handling of connection drops, avoiding exceptions from being raised in some cases when the transport is abruptly closed.

* Made AsyncSSH SFTP support more tolerant of file permission values with undefined bits set. Thanks go to GitHub user ccwufu for reporting this.

* Added some missing key exchange algorithms in the AsyncSSH documentation. Thanks go to Jeremy Norris for noticing and reporting this.

* Added support for running AsyncSSH unit tests on systems with OpenSSL 3.0 installed. Thanks go to Ken Dreyer for raising this issue and pointing out the new OpenSSL "provider" support for legacy algorithms.

Release 2.7.0 (19 Jun 2021)
---------------------------

* Added support for the ProxyCommand config file option and a corresponding proxy_command argument in the SSH connection options, allowing a subprocess to be used to make the connection to the SSH server. When the config option is used, it should be fully compatible with OpenSSH percent expansion in the command to run.

* Added support for accessing terminal information as properties in the SSHServerProcess class. As part of this change, both the environment and terminal modes are now available as read-only mappings. Thanks again to velavokr for suggesting this and submitting a PR with a proposed version of the change.

* Fixed terminal information passed to pty_requested() callback to properly reflect requested terminal type, size, and modes. Thanks go to velavokr for reporting this issue and proposing a fix.

* Fixed an edge case where a connection object might not be cleaned up properly if the connection request was cancelled before it was fully established.

* Fixed an issue where some unit tests weren't properly closing connection objects before exiting.

Release 2.6.0 (1 May 2021)
--------------------------

* Added support for the HostKeyAlias client config option and a corresponding host_key_alias option, allowing known_hosts lookups and host certificate validation to be done against a different hostname than what is used to make the connection. Thanks go to Pritam Baral for contributing this feature!

* Added the capability to specify client channel options as connection options, allowing them to be set in a connect() call or as values in SSHClientConnectionOptions. These values will act as defaults for any sessions opened on the connection but can still be overridden via arguments in the create_session() call.

* Added support for dynamically updating SSH options set up in a listen() or listen_reverse() call. A new SSHAcceptor class is now returned by these calls which has an update() method which takes the same keyword arguments as SSHClientConnectionOptions or SSHServerConnectionOptions, allowing you to update any of the options on an existing listener except those involved in setting up the listening sockets themselves. Updates will apply to future connections accepted by that listener.
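
A minimal sketch of the dynamic option updates described in the last item above; the key and authorized_keys file names are hypothetical::

    import asyncio
    import asyncssh

    async def main() -> None:
        # File names below are placeholders
        acceptor = await asyncssh.listen(
            '', 8022, server_host_keys=['ssh_host_key'],
            authorized_client_keys='authorized_keys')

        # update() takes the same keyword arguments as the connection
        # options classes and applies them to future connections only
        acceptor.update(authorized_client_keys='authorized_keys_rotated')

        await acceptor.wait_closed()

    asyncio.run(main())
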
* Added support for a number of algorithms supported by the ssh.com Tectia SSH client/server:

  Key exchange:

  | diffie-hellman-group14-sha256\@ssh.com (enabled by default)
  | diffie-hellman-group14-sha224\@ssh.com (available but not default)
  | diffie-hellman-group15-sha256\@ssh.com
  | diffie-hellman-group15-sha384\@ssh.com
  | diffie-hellman-group16-sha384\@ssh.com
  | diffie-hellman-group16-sha512\@ssh.com
  | diffie-hellman-group18-sha512\@ssh.com

  HMAC:

  | hmac-sha256-2\@ssh.com (all enabled by default)
  | hmac-sha224\@ssh.com
  | hmac-sha256\@ssh.com
  | hmac-sha384\@ssh.com
  | hmac-sha512\@ssh.com

  RSA public key algorithms:

  | ssh-rsa-sha224\@ssh.com (all enabled by default)
  | ssh-rsa-sha256\@ssh.com
  | ssh-rsa-sha384\@ssh.com
  | ssh-rsa-sha512\@ssh.com

  Encryption:

  | seed-cbc\@ssh.com (available but not default)

* Added a new 'ignore-failure' value to the x11_forwarding argument in create_session(). When specified, AsyncSSH will attempt to set up X11 forwarding but ignore failures, behaving as if forwarding was never requested instead of raising a ConnectionOpenError.

* Extended support for replacing certificates in an SSHKeyPair, allowing alternate certificates to be used with SSH agent and PKCS11 keys. This provides a way to use X.509 certificates with an SSH agent key or OpenSSH certificates with a PKCS11 key.

* Extended the config file parser to support '=' as a delimiter between keywords and arguments. While this syntax appears to be rarely used, it is supported by OpenSSH.

* Updated Fido2 support to use version 0.9.1 of the fido2 package, which included some changes that were not backward compatible with 0.8.1.

* Fixed problem with setting config options with percent substitutions to 'none'. Percent substitution should not be performed in this case. Thanks go to Yuqing Miao for finding and reporting this issue!

* Fixed return type of filenames in SFTPClient scandir() and readlink() when the argument passed in is a Path value. Previously, the return value in this case was bytes, but that was only meant to apply when the input argument was passed as bytes.

* Fixed a race condition related to closing a channel before it is fully open, preventing a client from potentially hanging forever if a session was closed while the client was still attempting to request a PTY or make other requests as part of opening the session.

* Fixed a potential race condition related to making parallel calls to SFTPClient makedirs() which try to create the same directory or a common parent directory.

* Fixed RFC 4716 parser to allow colons in header values.

* Improved error message when AsyncSSH is unable to get the local username on a client. Thanks go to Matthew Plachter for reporting this issue.

Release 2.5.0 (23 Dec 2020)
---------------------------

* Added support for limiting which identities in an SSH agent will be used when making a connection, via a new "agent_identities" config option. This change also adds compatibility with the OpenSSH config file option "IdentitiesOnly".

* Added support for including Subject Key Identifier and Authority Key Identifier extensions in generated X.509 certificates to better comply with RFC 5280.

* Added support for makedirs() and rmtree() methods in the AsyncSSH SFTP client, as well as a new scandir() method which returns an async iterator to more efficiently process very large directories. Thanks go to Joseph Ernest for suggesting these improvements.
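
To illustrate the SFTP additions from release 2.5.0 mentioned in the last item, here is a rough sketch; the host and remote paths are placeholders::

    import asyncio
    import asyncssh

    async def main() -> None:
        async with asyncssh.connect('example.com') as conn:
            async with conn.start_sftp_client() as sftp:
                await sftp.makedirs('backups/2020/q4')

                # scandir() yields entries incrementally, avoiding the need
                # to read a very large directory into memory all at once
                async for entry in sftp.scandir('backups'):
                    print(entry.filename)

                await sftp.rmtree('backups/old')

    asyncio.run(main())
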
* Significantly reworked AsyncSSH line editor support to improve its performance by several orders of magnitude on long input lines, and added a configurable maximum line length when the editor is in use to avoid potential denial-of-service attacks. This limit defaults to 1024 bytes, but with the improvements it can reasonably handle lines which are megabytes in size if needed.

* Changed AsyncSSH to allow SSH agent identities to still be used when an explicit list of client keys is specified, for better compatibility with OpenSSH. The previous behavior can still be achieved by explicitly setting the agent_path option to None when setting client_keys.

* Changed AsyncSSH to enforce a limit of 1024 characters on usernames when acting as a server to avoid a potential denial-of-service issue related to SASLprep username normalization.

* Changed SCP implementation to explicitly yield to other coroutines when sending a large file to better share an event loop.

* Fixed a few potential race conditions related to cleanup of objects during connection close. Thanks go to Thomas Léveil for reporting one of these places and suggesting a fix.

* Re-applied a previous fix which was unintentionally lost to allow Pageant to be used by default on Windows.

Release 2.4.2 (11 Sep 2020)
---------------------------

* Fixed a potential race condition when receiving EOF right after a channel is opened. Thanks go to Alex Shafer for reporting this and helping to track down the root cause.

* Fixed a couple of issues related to the error_handler and progress_handler callbacks in AsyncSSH SFTP/SCP. Thanks go to geraldnj for noticing and reporting these.

* Fixed a couple of issues related to using pathlib objects with AsyncSSH SCP.

Release 2.4.1 (5 Sep 2020)
--------------------------

* Fixed SCP server to send back an exit status when closing the SSH channel, since the OpenSSH scp client returns this status to the shell which executed it. Thanks go to girtsf for catching this.

* Fixed listeners created by forward_local_port(), forward_local_path(), and forward_socks() to automatically close when the SSH connection closes, unblocking any wait_closed() calls which are in progress. Thanks go to rmawatson for catching this.

* Fixed a potential exception that could trigger when the SSH connection is closed while authentication is in progress.

* Fixed tunnel connect code to properly clean up an implicitly created tunnel when a failure occurs in trying to open a connection over that tunnel.

Release 2.4.0 (29 Aug 2020)
---------------------------

* Added support for accessing keys through a PKCS#11 provider, allowing keys on PIV security tokens to be used directly by AsyncSSH without the need to run an SSH agent. X.509 certificates can also be retrieved from the security token and used with SSH servers which support that.

* Added support for using Ed25519 and Ed448 keys in X.509 certificates, and the corresponding SSH certificate and signature algorithms. Certificates can use these keys as either subject keys or signing keys, and certificates can be generated by either AsyncSSH or by OpenSSL version 1.1.1 or later.

* Added support for feed_data() and feed_eof() methods in SSHReader, mirroring methods of the same name in asyncio's StreamReader to improve interoperability between the two APIs. Thanks go to Mikhail Terekhov for suggesting this and providing an example implementation.

* Updated unit tests to test interoperability with OpenSSL 1.1.1 when reading and writing Ed25519 and Ed448 public and private key files.
  Previously, due to lack of support in OpenSSL, AsyncSSH could only test against OpenSSH, and only in OpenSSH key formats. With OpenSSL 1.1.1, testing is now also done using PKCS#8 format.

* Fixed config file parser to properly ignore all comment lines, even if the lines contain unbalanced quotes.

* Removed a note about the lack of a timeout parameter in the AsyncSSH connect() method, now that it supports a login_timeout argument. Thanks go to Tomasz Drożdż for catching this.

Release 2.3.0 (12 Jul 2020)
---------------------------

* Added initial support for reading configuration from OpenSSH-compatible config files, when present. Both client and server configuration files are supported, but not all config options are supported. See the AsyncSSH documentation for the latest list of what client and server options are supported, as well as what match conditions and percent substitutions are understood.

* Added support for the concept of only a subset of supported algorithms being enabled by default, and for the ability to use wildcards when specifying algorithm names. Also, OpenSSH's syntax of prefixing the list with '^', '+', or '-' is supported for incrementally adjusting the list of algorithms starting from the default set.

* Added support for specifying a preferred list of client authentication methods, in order of preference. Previously, the order of preference was hard-coded into AsyncSSH.

* Added the ability to use AsyncSSH's "password" argument on servers which are using keyboard-interactive authentication to prompt for a "passcode". Previously, this was only supported when the prompt was for a "password".

* Added support for providing separate lists of private keys and certificates, rather than requiring them to be specified together as a tuple. When this new option is used, AsyncSSH will automatically associate the private keys with their corresponding certificates if matching certificates are present in the list.

* Added support for the "known_hosts" argument to accept a list of known host files, rather than just a single file. Known hosts can also be specified using the GlobalKnownHostsFile and UserKnownHostsFile config file options, each of which can take multiple filenames.

* Added new "request_tty" option to provide finer grained control over whether AsyncSSH will request a TTY when opening new sessions. The default is to still tie this to whether a "term_type" is specified, but now that can be overridden. Supported options of "yes", "no", "force", and "auto" match the values supported by OpenSSH.

* Added new "rdns_lookup" option to control whether the server does a reverse DNS lookup of client addresses to allow matching of clients based on hostname in authorized keys and config files. When this option is disabled (the default), matches can only be based on client IP.

* Added new "send_env" argument when opening a session to forward local environment variables using their existing values, augmenting the "env" argument that lets you specify remote environment variables to set and their corresponding values.

* Added new "tcp_keepalive" option to control whether TCP-level keepalives are enabled or not on SSH connections. Previously, TCP keepalives were enabled unconditionally and this is still the default, but the new option provides a way to disable them.

* Added support for sending and parsing client EXT_INFO messages, and for sending the "global-requests-ok" option in these messages when AsyncSSH is acting as a client.
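
A short sketch of a few of the client options described in the 2.3.0 notes above (a list of known_hosts files, request_tty, and send_env); the host, the file names, and this particular combination of options are assumptions of the example rather than part of the release notes::

    import asyncio
    import asyncssh

    async def main() -> None:
        # Placeholder host and known-hosts file names
        async with asyncssh.connect(
                'example.com',
                known_hosts=['/etc/ssh/ssh_known_hosts', 'extra_known_hosts'],
                request_tty='auto') as conn:
            # send_env forwards the local values of the named variables,
            # while env sets explicit values on the remote side
            result = await conn.run('env', send_env=['LANG'], env={'DEBUG': '1'})
            print(result.stdout)

    asyncio.run(main())
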
* Added support for '~' home directory expansion when specifying arguments which contain filenames.

* Added support for time intervals and byte counts to optionally be specified as string values with units, allowing for values such as "1.5h" or "1h30m" instead of having to specify that as 5400 seconds. Similarly, a byte count of "1g" can be passed to indicate 1 gigabyte, rather than specifying 1073741824 bytes.

* Enhanced logging to report lists of sent and received algorithms when no matching algorithm is found. Thanks go to Jeremy Schulman for suggesting this.

* Fixed an interoperability issue with PKIXSSH when attempting to use X.509 certificates with a signature algorithm of "x509v3-rsa2048-sha256".

* Fixed an issue with some links not working in the ReadTheDocs sidebar. Thanks go to Christoph Giese for reporting this issue.

* Fixed keepalive handler to avoid leaking a timer object in some cases. Thanks go to Tom van Neerijnen for reporting this issue.

Release 2.2.1 (18 Apr 2020)
---------------------------

* Added optional timeout parameter to SSHClientProcess.wait() and SSHClientConnection.run() methods.

* Created subclasses for SFTPError exceptions, allowing applications to more easily have distinct exception handling for different errors.

* Fixed an issue in SFTP parallel I/O related to handling low-level connection failures. Thanks go to Mikhail Terekhov for reporting this issue.

* Fixed an issue with SFTP file copy where a local file could sometimes be left open if an attempt to close a remote file failed.

* Fixed an issue in the handling of boolean return values when SSHServer.server_requested() returns a coroutine. Thanks go to Tom van Neerijnen for contributing this fix.

* Fixed an issue with passing tuples to the SFTP copy functions. Thanks go to Marc Gagné for reporting this and doing the initial analysis.

Release 2.2.0 (29 Feb 2020)
---------------------------

* Added support for U2F/FIDO2 security keys, with the following capabilities:

  * ECDSA (NISTP256) and Ed25519 key algorithms
  * Key generation, including control over the application and user the key is associated with and whether touch is required when using the key
Release 2.1.0 (30 Nov 2019) --------------------------- * Added support in the SSHProcess redirect mechanism to accept asyncio StreamReader and StreamWriter objects, allowing asyncio streams to be plugged in as stdin/stdout/stderr in an SSHProcess. * Added support for key handlers in the AsyncSSH line editor to trigger signals being delivered when certain "hot keys" are hit while reading input. * Improved cleanup of unreturned connection objects when an error occurs or the connection request is canceled or times out. * Improved cleanup of SSH agent client objects to avoid triggering a false positive warning in Python 3.8. * Added an example to the documentation for how to create reverse-direction SSH client and server connections. * Made check of session objects against None explicit to avoid confusion on user-defined sessions that implement __len__ or __bool__. Thanks go to Lars-Dominik Braun for contributing this improvement! Release 2.0.1 (2 Nov 2019) -------------------------- * Some API changes which should have been included in the 2.0.0 release were missed. This release corrects that, but means that additional changes may be needed in applications moving to 2.0.1. This should hopefully be the last of such changes, but if any other issues are discovered, additional changes will be limited to 2.0.x patch releases and the API will stabilize again in the AsyncSSH 2.1 release. See the next bullet for details about the additional incompatible change. * To be consistent with other connect and listen functions, all methods on SSHClientConnection which previously returned None on listen failures have been changed to raise an exception instead. A new ChannelListenError exception will now be raised when an SSH server returns failure on a request to open a remote listener. This change affects the following SSHClientConnection methods: create_server, create_unix_server, start_server, start_unix_server, forward_remote_port, and forward_remote_path. * Restored the ability for SSHListener objects to be used as async context managers. This previously worked in AsyncSSH 1.x and was unintentionally broken in AsyncSSH 2.0.0. * Added support for a number of additional functions to be called from within an "async with" statement. These functions already returned objects capable of being async context managers, but were not decorated to allow them to be directly called from within "async with". This change applies to the top level functions create_server, listen, and listen_reverse and the SSHClientConnection methods create_server, create_unix_server, start_server, start_unix_server, forward_local_port, forward_local_path, forward_remote_port, forward_remote_path, listen_ssh, and listen_reverse_ssh, * Fixed a couple of issues in loading OpenSSH-format certificates which were missing a trailing newline. * Changed load_certificates() to allow multiple certificates to be loaded from a single byte string argument, making it more consistent with how load_certificates() works when reading from a file. Release 2.0.0 (26 Oct 2019) --------------------------- * NEW MAJOR VERSION: See below for potentially incompatible changes. * Updated AsyncSSH to use the modern async/await syntax internally, now requiring Python 3.6 or later. Those wishing to use AsyncSSH on Python 3.4 or 3.5 should stick to the AsyncSSH 1.x releases. 
* Changed first argument of SFTPServer constructor from an SSHServerConnection (conn) to an SSHServerChannel (chan) to allow custom SFTP server implementations to access environment variables set on the channel that SFTP is run over. Applications which subclass the SFTPServer class and implement an __init__ method will need to be updated to account for this change and pass the new argument through to the SFTPServer parent class. If the subclass has no __init__ and just uses the connection, channel, and env properties of SFTPServer to access this information, no changes should be required.

* Removed deprecated "session_encoding" and "session_errors" arguments from create_server() and listen() functions. These arguments were renamed to "encoding" and "errors" back in version 1.16.0 to be consistent with other AsyncSSH APIs.

* Removed get_environment(), get_command(), and get_subsystem() methods on SSHServerProcess class. This information was made available as "env", "command", and "subsystem" properties of SSHServerProcess in AsyncSSH 1.11.0.

* Removed optional loop argument from all public AsyncSSH APIs, consistent with the deprecation of this argument in the asyncio package in Python 3.8. Calls will now always use the event loop which is active at the time of the call.

* Removed support for non-async context managers on AsyncSSH connections and processes and SFTP client connections and file objects. Callers should use "async with" to invoke the async context managers on these objects.

* Added support for SSHAgentClient being an async context manager. To be consistent with other connect calls, connect_agent() will now raise an exception when no agent is found or a connection failure occurs, rather than logging a warning and returning None. Callers should catch OSError or ChannelOpenError exceptions rather than looking for a return value of None when calling this function.

* Added set_input() and clear_input() methods on SSHLineEditorChannel to change the value of the current input line when line editing is enabled.

* Added is_closing() method to the SSHChannel, SSHProcess, SSHWriter, and SSHSubprocessTransport classes, mirroring the asyncio BaseTransport and StreamWriter methods added in Python 3.7.

* Added wait_closed() async method to the SSHWriter class, mirroring the asyncio StreamWriter method added in Python 3.7.

Release 1.18.0 (23 Aug 2019)
----------------------------

* Added support for GSSAPI ECDH and Edwards DH key exchange algorithms.

* Fixed gssapi-with-mic authentication to work with GSS key exchanges, in cases where gssapi-keyex is not supported.

* Made connect_ssh and connect_reverse_ssh methods into async context managers, simplifying the syntax needed to use them to create tunneled SSH connections.

* Fixed a couple of issues with known hosts matching on tunneled SSH connections.

* Improved flexibility of key/certificate parser automatic format detection to properly recognize PEM even when other arbitrary text is present at the beginning of the file. With this change, the parser can also now handle mixing of multiple key formats in a single file.

* Added support for OpenSSL "TRUSTED" PEM certificates. For now, no enforcement is done of the additional trust restrictions, but such certificates can be loaded and used by AsyncSSH without converting them back to regular PEM format.

* Fixed some additional SFTP and SCP issues related to parsing of Windows paths with drive letters and paths with multiple colons.
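
A rough sketch of the tunneled-connection usage enabled by the 1.18.0 change above that made connect_ssh() usable as an async context manager; the hostnames are placeholders::

    import asyncio
    import asyncssh

    async def main() -> None:
        # Placeholder hostnames: an outer "jump" host and an inner target
        async with asyncssh.connect('jump.example.com') as tunnel:
            # Open a second SSH connection tunneled over the first one
            async with tunnel.connect_ssh('inner.example.com') as conn:
                result = await conn.run('hostname')
                print(result.stdout, end='')

    asyncio.run(main())
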
* Made AsyncSSH tolerant of a client which sends multiple service requests for the "ssh-userauth" service. This is needed by the Paramiko client when it tries more than one form of authentication on a connection.

Release 1.17.1 (23 Jul 2019)
----------------------------

* Improved construction of file paths in SFTP to better handle native Windows source paths containing backslashes or drive letters.

* Improved SFTP parallel I/O for large reads and file copies to better handle the case where a read returns less data than what was requested when not at the end of the file, allowing AsyncSSH to get back the right result even if the requested block size is larger than the SFTP server can handle.

* Fixed an issue where the requested SFTP block_size wasn't used in the get, copy, mget, and mcopy functions if it was larger than the default size of 16 KB.

* Fixed a problem where the list of client keys provided in an SSHClientConnectionOptions object wasn't always preserved properly across the opening of multiple SSH connections.

* Changed SSH agent client code to avoid printing a warning on Windows when unable to connect to the SSH agent using the default path. A warning will be printed if the agent_path or SSH_AUTH_SOCK is explicitly set, but AsyncSSH will remain quiet if no agent path is set and no SSH agent is running.

* Made AsyncSSH tolerant of unexpected authentication success/failure messages sent after authentication completes. AsyncSSH previously treated this as a protocol error and dropped the connection, while most other SSH implementations ignored these messages and allowed the connection to continue.

* Made AsyncSSH tolerant of SFTP status responses which are missing error message and language tag fields, improving interoperability with servers that omit these fields. When missing, AsyncSSH treats these fields as if they were set to empty strings.

Release 1.17.0 (31 May 2019)
----------------------------

* Added support for "reverse direction" SSH connections, useful to support applications like NETCONF Call Home, described in RFC 8071.

* Added support for the PyCA implementation of Chacha20-Poly1305, eliminating the dependency on libnacl/libsodium to provide this functionality, as long as OpenSSL 1.1.1b or later is installed.

* Restored libnacl support for Curve25519/Ed25519 on systems which have an older version of OpenSSL that doesn't have that support. This fallback also applies to Chacha20-Poly1305.

* Fixed Pageant support on Windows to use the Pageant agent by default when it is available and client keys are not explicitly configured.

* Disabled the use of RSA SHA-2 signatures when using the Pageant or Windows 10 OpenSSH agent on Windows, since neither of those support the signature flags options to request them.

* Fixed a regression where a callable was no longer usable in the sftp_factory argument of create_server.

Release 1.16.1 (30 Mar 2019)
----------------------------

* Added channel, connection, and env properties to SFTPServer instances, so connection and channel information can be used to influence the SFTP server's behavior. Previously, connection information was made available through the constructor, but channel and environment information was not. Now, all of these are available as properties on the SFTPServer instance without the need to explicitly store anything in a custom constructor.

* Optimized SFTP glob matching when the glob pattern contains directory names without glob characters in them. Thanks go to Mikhail Terekhov for contributing this improvement!
* Added support for PurePath in a few places that were missed when this support was originally added. Once again, thanks go to Mikhail Terekhov for these fixes.

* Fixed bug in SFTP parallel I/O file reader where it sometimes returned EOF prematurely. Thanks go to David G for reporting this problem and providing a reproducible test case.

* Fixed test failures seen on Fedora Rawhide. Thanks go to Georg Sauthoff for reporting this issue and providing a test environment to help debug it.

* Updated Ed25519/448 and Curve25519/448 tests to only run when these algorithms are available. Thanks go to Ondřej Súkup for reporting this issue and providing a suggested fix.

Release 1.16.0 (2 Mar 2019)
---------------------------

* Added support for Ed448 host/client keys and certificates and rewrote Ed25519 support to use the PyCA implementation, reducing the dependency on libnacl and libsodium to only be needed to support the chacha20-poly1305 cipher.

* Added support for PKCS#8 format Ed25519 and Ed448 private and public keys (in addition to the OpenSSH format previously supported).

* Added support for multiple delimiters in SSHReader's readuntil() function, causing it to return data as soon as any of the specified delimiters are matched.

* Added the ability to register custom key handlers in the line editor which can modify the input line, extending the built-in editing functionality.

* Added SSHSubprocessProtocol and SSHSubprocessTransport classes to provide compatibility with asyncio.SubprocessProtocol and asyncio.SubprocessTransport. Code which is designed to call BaseEventLoop.subprocess_shell() or BaseEventLoop.subprocess_exec() can be easily adapted to work against a remote process by calling SSHClientConnection.create_subprocess().

* Added support for sending keepalive messages when the SSH connection is idle, with an option to automatically disconnect the connection if the remote system doesn't respond to these keepalives.

* Changed AsyncSSH to ignore errors when loading unsupported key types from the default file locations.

* Changed the reuse_port option to only be available on Python releases which support it (3.4.4 and later).

* Fixed an issue where MSG_IGNORE packets could sometimes be sent between MSG_NEWKEYS and MSG_EXT_INFO, which caused some SSH implementations to fail to properly parse the MSG_EXT_INFO.

* Fixed a couple of errors in the handling of disconnects occurring prior to authentication completing.

* Renamed "session_encoding" and "session_errors" arguments in asyncssh.create_server() to "encoding" and "errors", to match the names used for these arguments in other AsyncSSH APIs. The old names are still supported for now, but they are marked as deprecated and will be removed in a future release.

Release 1.15.1 (21 Jan 2019)
----------------------------

* Added callback-based host validation in SSHClient, allowing callers to decide programmatically whether to trust server host keys and certificates rather than having to provide a list of trusted values in advance.

* Changed SSH client code to only load the default known hosts file if it exists. Previously an error was returned if a known_hosts value wasn't specified and the default known_hosts file didn't exist. For host validation to work in this case, verification callbacks must be implemented or other forms of validation such as X.509 trusted CAs or GSS-based key exchange must be used.

* Fixed known hosts validation to completely disable certificate checks when known_hosts is set to None.
  Previously, key checking was disabled in this case but other checks for certificate expiration and hostname mismatch were still performed, causing connections to fail even when checking was supposed to be disabled.

* Switched curve25519 key exchange to use the PyCA implementation, avoiding a dependency on libnacl/libsodium. For now, support for Ed25519 keys still requires these libraries.

* Added get_fingerprint() method to return a fingerprint of an SSHKey.

Release 1.15.0 (26 Nov 2018)
----------------------------

* Added the ability to pass keyword arguments provided in the scp() command through to asyncssh.connect() calls it makes, allowing things like custom credentials to be specified.

* Added support for a reuse_port argument in create_server(). If set, this will be passed to the asyncio loop.create_server() call which creates listening sockets.

* Added support for "soft" EOF when line editing is enabled so that EOF can be signalled multiple times on a channel. When Ctrl-D is received on a channel with line editing enabled, EOF is returned to the application but the channel remains open and capable of accepting more input, allowing an interactive shell to process the EOF for one command but still accept input for subsequent commands.

* Added support for the Windows 10 OpenSSH ssh-agent. Thanks go to SamP20 for providing an initial proof of concept and a suggested implementation.

* Reworked scoped link-local IPv6 address normalization to work better on Linux systems.

* Fixed a problem preserving directory structure in recursive scp().

* Fixed SFTP chmod tests to avoid attempting to set the sticky bit on a plain file, as this caused test failures on FreeBSD.

* Updated note in SSHClientChannel's send_signal() documentation to reflect that OpenSSH 7.9 and later should now support processing of signal messages.

Release 1.14.0 (8 Sep 2018)
---------------------------

* Changed license from EPL 1.0 to EPL 2.0 with GPL 2.0 or later as an available secondary license.

* Added support for automatically parallelizing large reads and writes made using the SFTPClientFile class, similar to what was already available in the get/put/copy methods of SFTPClient.

* Added support for get_extra_info() in SSH process classes, returning information associated with the channel the process is tied to.

* Added new set_extra_info() method on SSH connection and channel classes, allowing applications to store additional information on these objects.

* Added handlers for OpenSSH keepalive global & channel requests to avoid messages about unknown requests in the debug log. These requests are still logged, but at debug level 2 instead of 1 and they are not labeled as unknown.

* Fixed race condition when closing sockets associated with forwarded connections.

* Improved error handling during connection close in SFTPClient.

* Worked around issues with integer overflow on systems with a 32-bit time_t value when dates beyond 2038 are used in X.509 certificates.

* Added guards around some imports and tests which were causing problems on Fedora 27.

* Changed debug level for reporting PTY modes from 1 to 2 to reduce noise in the logs.

* Improved SFTP debug log output when sending EOF responses.

Release 1.13.3 (23 Jul 2018)
----------------------------

* Added support for setting the Unicode error handling strategy in conjunction with setting an encoding when creating new SSH sessions, streams, and processes.
  This strategy can also be set when specifying a session encoding in create_server(), and when providing an encoding in the get_comment() and set_comment() functions on private/public keys and certificates.

* Changed handling of Unicode in channels to use an incremental codec, similar to what was previously done in process redirection.

* Added Python 3.7 to the list of classifiers in setup.py, now that it has been released.

* Updated Travis CI configuration to add Python 3.7 builds, and moved Linux builds on newer versions of Python up to xenial.

* Added missing coroutine decorator in test_channel.

Release 1.13.2 (3 Jul 2018)
---------------------------

* Added support for accessing client host keys via the OpenSSH ssh-keysign program when doing host-based authentication. If ssh-keysign is present and enabled on the system, an AsyncSSH based SSH client can use host-based authentication without access to the host private keys.

* Added support for using pathlib path objects when reading and writing private and public keys and certificates.

* Added support for auth_completed() callback in the SSHServer class which runs when authentication completes successfully on each new connection.

* Fixed host-based authentication unit tests to mock out calls to getnameinfo() to avoid failures on systems with restricted network functionality.

Release 1.13.1 (16 Jun 2018)
----------------------------

* Added client and server support for host-based SSH authentication. If enabled, this will allow all users from a given host to be authenticated by a shared host key, rather than each user needing their own key. This should only be used with hosts which are trusted to keep their host keys secure and provide accurate client usernames.

* Added support for RSA key exchange algorithms (rsa2048-sha256 and rsa1024-sha1) available in PuTTY and some mobile SSH clients.

* Added support for the SECP256K1 elliptic curve for ECDSA keys and ECDH key exchange. This curve is supported by the Bitvise SSH client and server.

* Added debug logging of the algorithms listed in a received kexinit message.

Release 1.13.0 (20 May 2018)
----------------------------

* Added support for dynamic port forwarding via SOCKS, where AsyncSSH will open a listener which understands SOCKS connect requests and for each request open a TCP/IP tunnel over SSH to the requested host and port.

* Added support in SSHProcess for I/O redirection to file objects that implement read(), write(), and close() functions as coroutines, such as the "aiofiles" package. In such cases, AsyncSSH will automatically detect that it needs to make async calls to these methods when it performs I/O.

* Added support for using pathlib objects in SSHProcess I/O redirection.

* Added multiple improvements to pattern matching support in the SFTPClient glob(), mget(), mput(), and mcopy() methods. AsyncSSH now allows you to use '**' in a pattern to do a recursive directory search, allows character ranges in square brackets in a pattern, and allows a trailing slash in a pattern to be specified to request that only directories matching the pattern should be returned.

* Fixed an issue with calling readline() and readuntil() with a timeout, where partial data received before the timeout was sometimes discarded. Any partial data which was received when a timeout occurs will now be left in the input buffer, so it is still available to future read() calls.
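
A minimal sketch of the dynamic SOCKS forwarding added in release 1.13.0 above; the SSH host and the local SOCKS port are placeholders::

    import asyncio
    import asyncssh

    async def main() -> None:
        async with asyncssh.connect('example.com') as conn:
            # Listen locally for SOCKS connect requests and tunnel each
            # requested destination over the SSH connection
            listener = await conn.forward_socks('localhost', 1080)
            print('SOCKS proxy listening on port', listener.get_port())
            await listener.wait_closed()

    asyncio.run(main())
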
* Fixed a race condition where trying to restart a read() after a timeout could sometimes raise an exception about multiple simultaneous reads. * Changed readuntil() in SSHReader to raise IncompleteReadError if the receive window fills up before a delimiter match is found. This also applies to readline(), which will return a partial line without a newline at the end when this occurs. To support longer lines, a caller can call readuntil() or readline() as many times as they'd like, appending the data returned to the previous partial data until a delimiter is found or some maximum size is exceeded. Since the default window size is 2 MBytes, though, it's very unlikely this will be needed in most applications. * Reworked the crypto support in AsyncSSH to separate packet encryption and decryption into its own module and simplified the directory structure of the asyncssh.crypto package, eliminating a pyca subdirectory that was created back when AsyncSSH used a mix of PyCA and PyCrypto. Release 1.12.2 (17 Apr 2018) ---------------------------- * Added support for using pathlib objects as paths in calls to SFTP methods, in addition to Unicode and byte strings. This is mainly intended for use in constructing local paths, but it can also be used for remote paths as long as POSIX-style pathlib objects are used and an appropriate path encoding is set to handle the conversion from Unicode to bytes. * Changed server EXT_INFO message to only be sent after the first SSH key exchange, to match the specification recently published in RFC 8308. * Fixed edge case in TCP connection forwarding where data received on a forward TCP connection was not delivered if the connection was closed or half-closed before the corresponding SSH tunnel was fully established. * Made note about OpenSSH not properly handling send_signal more visible. Release 1.12.1 (10 Mar 2018) ---------------------------- * Implemented a fix for CVE-2018-7749, where a modified SSH client could request that an AsyncSSH server perform operations before authentication had completed. Thanks go to Matthijs Kooijman for discovering and reporting this issue and helping to review the fix. * Added a non-blocking collect_output() method to SSHClientProcess to allow applications to retrieve data received on an output stream without blocking. This call can be called multiple times and freely intermixed with regular read calls with a guarantee that output will always be returned in order and without duplication. * Updated debug logging implementation to make it more maintainable, and to fix an issue where unprocessed packets were not logged in some cases. * Extended the support below for non-ASCII characters in comments to apply to X.509 certificates, allowing an optional encoding to be passed in to get_comment() and set_comment() and a get_comment_bytes() function to get the raw comment bytes without performing Unicode decoding. * Fixed an issue where a UnicodeDecodeError could be reported in some cases instead of a KeyEncryptionError when a private key was imported using the wrong passphrase. * Fixed the reporting of the MAC algorithm selected during key exchange to properly report the cipher name for GCM and Chacha ciphers that don't use a separate MAC algorithm. The correct value was being returned in queries after the key exchange was complete, but the logging was being done before this adjustment was made. * Fixed the documentation of connection_made() in SSHSession subclasses to properly reflect the type of SSHChannel objects passed to them. 
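The comment-handling calls mentioned in the 1.12.1 notes above and the 1.12.0 notes below can be combined roughly as follows. This is only a minimal sketch: the key file name is hypothetical, and the encoding shown is simply the documented default.

.. code::

   import asyncssh

   # Load an existing private key (the file name here is just an example)
   key = asyncssh.read_private_key('my_key')

   # Comments are decoded as UTF-8 by default, but an alternate encoding
   # can be requested, or the raw bytes can be fetched without decoding
   comment = key.get_comment(encoding='utf-8')
   raw_comment = key.get_comment_bytes()

   # Comments can also be replaced before re-exporting the key
   key.set_comment('backup copy of my_key')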
Release 1.12.0 (5 Feb 2018) --------------------------- * Enhanced AsyncSSH logging framework to provide detailed logging of events in the connection, channel, key exchange, authentication, sftp, and scp modules. Both high-level information logs and more detailed debug logs are available, and debug logging supports multiple debug levels with different amounts of verbosity. Logger objects are also available on various AsyncSSH classes to allow applications to report their own log events in a manner that can be tied back to a specific SSH connection or channel. * Added support for begin_auth() to be a coroutine, so asynchronous operations can be performed within it to load state needed to perform SSH authentication. * Adjusted key usage flags set on generated X.509 certificates to be more RFC compliant and work around an issue with OpenSSL validation of self-signed non-CA certificates. * Updated key and certificate comment handling to be less sensitive to the encoding of non-ASCII characters. The get_comment() and set_comment() functions now take an optional encoding parameter, defaulting to UTF-8 but allowing for other encodings. There's also a get_comment_bytes() function to get the comment data as bytes without performing Unicode decoding. * Updated AsyncSSH to be compatible with the beta release of Python 3.7. * Updated code to address warnings reported by the latest version of pylint. * Cleaned up various formatting issues in Sphinx documentation. * Significantly reduced the time it takes to run unit tests by decreasing the rounds of bcrypt encryption used when unit testing encrypted OpenSSH private keys. * Added support for testing against uvloop in Travis CI. Release 1.11.1 (15 Nov 2017) ---------------------------- * Switched to using the PBKDF2 implementation provided by PyCA, replacing a much slower pure-Python implementation used in earlier releases. * Improved support for file-like objects in process I/O redirection, properly handling objects which don't support fileno() and allowing both text and binary file objects based on whether they have an 'encoding' member. * Changed PEM parser to be forgiving of trailing blank lines. * Updated documentation to note lack of support in OpenSSH for send_signal(), terminate(), and kill() channel requests. * Updated unit tests to work better with OpenSSH 7.6. * Updated Travis CI config to test with more recent Python versions. Release 1.11.0 (9 Sep 2017) --------------------------- * Added support for X.509 certificate based client and server authentication, as defined in RFC 6187. * DSA, RSA, and ECDSA keys are supported. * New methods are available on SSHKey private keys to generate X.509 user, host, and CA certificates. * Authorized key and known host support has been enhanced to support matching on X.509 certificates and X.509 subject names. * New arguments have been added to create_connection() and create_server() to specify X.509 trusted root CAs, X.509 trusted root CA hash directories, and allowed X.509 certificate purposes. * A new load_certificates() function has been added to more easily pre-load a list of certificates from byte strings or files. * Support for including and validating OCSP responses is not yet available, but may be added in a future release. * This support adds a new optional dependency on pyOpenSSL in setup.py.
* Added command, subsystem, and environment properties to SSHProcess, SSHCompletedProcess, and ProcessError classes, as well as stdout and stderr properties in ProcessError which mirror what is already present in SSHCompletedProcess. Thanks go to iforapsy for suggesting this. * Worked around a datetime.max bug on Windows. * Increased the build timeout on TravisCI to avoid build failures. Release 1.10.1 (19 May 2017) ---------------------------- * Fixed SCP to properly call exit() on SFTPServer when the copy completes. Thanks go to Arthur Darcet for discovering this and providing a suggested fix. * Added support for a passphrase to be specified when loading default client keys, and to ignore encrypted default keys if no passphrase is specified. * Added additional known hosts test cases. Thanks go to Rafael Viotti for providing these. * Increased the default number of rounds for OpenSSH-compatible bcrypt private key encryption to avoid a warning in the latest version of the bcrypt module, and added a note that the encryption strength scales linearly with the rounds value, not logarithmically. * Fixed SCP unit test errors on Windows. * Fixed some issues with Travis and Appveyor CI builds. Release 1.10.0 (5 May 2017) --------------------------- * Added SCP client and server support. The new asyncssh.scp() function can get and put files on a remote SCP server and copy files between two or more remote SCP servers, with options similar to what was previously supported for SFTP. On the server side, an SFTPServer used to serve files over SFTP can also serve files over SCP by simply setting allow_scp to True in the call to create_server(). * Added a new SSHServerProcess class which supports I/O redirection on inbound connections to an SSH server, mirroring the SSHClientProcess class added previously for outbound SSH client connections. * Enabled TCP keepalive on SSH client and server connections. * Enabled Python 3 highlighting in Sphinx documentation. * Fixed a bug where a previously loaded SSHKnownHosts object wasn't properly accepted as a known_hosts value in create_connection() and enhanced known_hosts to accept a callable to allow applications to provide their own function to return trusted host keys. * Fixed a bug where an exception was raised if the connection closed while waiting for an asynchronous authentication callback to complete. * Fixed a bug where empty passwords weren't being properly supported. Release 1.9.0 (18 Feb 2017) --------------------------- * Added support for GSSAPI key exchange and authentication when the "gssapi" module is installed on UNIX or the "sspi" module from pypiwin32 is installed on Windows. * Added support for additional Diffie Hellman groups, and added the ability for Diffie Hellman and GSS group exchange to select larger group sizes. * Added overridable methods format_user() and format_group() to format user and group names in the SFTP server, defaulting to the previous behavior of using pwd.getpwuid() and grp.getgrgid() on platforms that support those. * Added an optional progress reporting callback on SFTP file transfers, and made the block size for these transfers configurable. * Added append_private_key(), append_public_key(), and append_certificate() methods on the corresponding key and certificate classes to simplify the creation of files containing a list of keys/certificates. * Updated readdir to break responses into chunks to avoid hitting maximum message size limits on large directories.
* Updated SFTP to work better on Windows, properly handling drive letters and conversion between forward and back slashes in paths and handling setting of attributes on open files and proper support for POSIX rename. Also, file closes now block until the close completes, to avoid issues with file locking. * Updated the unit tests to run on Windows, and enabled continuous integration builds for Windows to automatically run on Appveyor. Release 1.8.1 (29 Dec 2016) --------------------------- * Fix an issue in attempting to load the 'nettle' library on Windows. Release 1.8.0 (29 Dec 2016) --------------------------- * Added support for forwarding X11 connections. When requested, AsyncSSH clients will allow remote X11 applications to tunnel data back to a local X server and AsyncSSH servers can request an X11 DISPLAY value to export to X11 applications they launch which will tunnel data back to an X server associated with the client. * Improved ssh-agent forwarding support on UNIX to allow AsyncSSH servers to request an SSH_AUTH_SOCK value to export to applications they launch in order to access the client's ssh-agent. Previously, there was support for agent forwarding on server connections within AsyncSSH itself, but they did not provide this forwarding to other applications. * Added support for PuTTY's Pageant agent on Windows systems, providing functionality similar to the OpenSSH agent on UNIX. AsyncSSH client connections from Windows can now access keys stored in the Pageant agent when they perform public key authentication. * Added support for the umac-64 and umac-128 MAC algorithms, compatible with the implementation in OpenSSH. These algorithms are preferred over the HMAC algorithms when both are available and the cipher chosen doesn't already include a MAC. * Added curve25519-sha256 as a supported key exchange algorithm. This algorithm is identical to the previously supported algorithm named 'curve25519-sha256\@libssh.org', matching what was done in OpenSSH 7.3. Either name may now be used to request this type of key exchange. * Changed the default order of key exchange algorithms to prefer the curve25519-sha256 algorithm over the ecdh-sha2-nistp algorithms. * Added support for a readuntil() function in SSHReader, modeled after the readuntil() function in asyncio.StreamReader added in Python 3.5.2. Thanks go to wwjiang for suggesting this and providing an example implementation. * Fixed issues where the explicitly provided event loop value was not being passed through to all of the places which needed it. Thanks go to Vladimir Rutsky for pointing out this problem and providing some initial fixes. * Improved error handling when port forwarding is requested for a port number outside of the range 0-65535. * Disabled use of IPv6 in unit tests when opening local loopback sockets to avoid issues with incomplete IPv6 support in TravisCI. * Changed the unit tests to always start with a known set of environment variables rather than inheriting the environment from the shell running the tests. This was leading to test breakage in some cases. Release 1.7.3 (22 Nov 2016) --------------------------- * Updated unit tests to run properly in environments where OpenSSH and OpenSSL are not installed. * Updated a process unit test to not depend on the system's default file encoding being UTF-8. * Updated Mac TravisCI builds to use Xcode 8.1. * Cleaned up some wording in the documentation. 
Release 1.7.2 (28 Oct 2016) --------------------------- * Fixed an issue with preserving file access times in SFTP, and updated the unit tests to more accurately detect this kind of failure. * Fixed some markup errors in the documentation. * Fixed a small error in the change log for release 1.7.0 regarding the newly added Diffie Hellman key exchange algorithms. Release 1.7.1 (7 Oct 2016) -------------------------- * Fixed an error that prevented the docs from building. Release 1.7.0 (7 Oct 2016) -------------------------- * Added support for group 14, 16, and 18 Diffie Hellman key exchange algorithms which use SHA-256 and SHA-512. * Added support for using SHA-256 and SHA-512 based signature algorithms for RSA keys and support for OpenSSH extension negotiation to advertise these signature algorithms. * Added new load_keypairs and load_public_keys API functions which support explicitly loading keys using the same syntax that was previously available for specifying client_keys, authorized_client_keys, and server_host_keys arguments when creating SSH clients and servers. * Enhanced the SSH agent client to support adding and removing keys and certificates (including support for constraints) and locking and unlocking the agent. Support has also been added for adding and removing smart card keys in the agent. * Added support for getting and setting a comment value when generating keys and certificates, and decoding and encoding this comment when importing and exporting keys that support it. Currently, this is available for OpenSSH format private keys and OpenSSH and RFC 4716 format public keys. These comment values are also passed on to the SSH agent when keys are added to it. * Fixed a bug in the generation of ECDSA certificates that showed up when trying to use the nistp384 or nistp521 curves. * Updated unit tests to use the new key and certificate generation functions, eliminating the dependency on the ssh-keygen program. * Updated unit tests to use the new SSH agent support when adding keys to the SSH agent, eliminating the dependency on the ssh-add program. * Incorporated a fix from Vincent Bernat for an issue with launching ssh-agent on some systems during unit testing. * Fixed some typos in the documentation found by Jakub Wilk. Release 1.6.2 (4 Sep 2016) -------------------------- * Added generate_user_certificate() and generate_host_certificate() methods to the SSHKey class to generate SSH certificates, and export_certificate() and write_certificate() methods on the SSHCertificate class to export certificates for use in other tools. * Improved editor unit tests to eliminate a timing dependency. * Cleaned up a few minor documentation issues. Release 1.6.1 (27 Aug 2016) --------------------------- * Added a generate_private_key() function to create new DSA, RSA, ECDSA, or Ed25519 private keys which can be used as SSH user and host keys (a short usage sketch follows these release notes). * Removed an unintended dependency in the SSHLineEditor on session objects keeping a private member which referenced the corresponding channel. * Fixed a race condition in SFTP unit tests. * Updated dependencies to require version 1.5 of the cryptography module and started to take advantage of the new one-shot sign and verify APIs it now supports. * Clarified the documentation of the default return value of eof_received(). * Added new multi-user client and server examples, showing a single process opening multiple SSH connections in parallel. * Updated development status and Python versions listed in setup.py.
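As referenced in the 1.6.1 notes above, key generation can be done entirely within AsyncSSH. The following is only a minimal sketch of how the generate_private_key() call (together with the comment support added in 1.7.0) might be used; the algorithm name, comment, and file names are illustrative rather than required values.

.. code::

   import asyncssh

   # Generate a new Ed25519 private key (other algorithms such as
   # 'ssh-rsa' or 'ecdsa-sha2-nistp256' can be requested instead)
   key = asyncssh.generate_private_key('ssh-ed25519', comment='example key')

   # Write the private key and its matching public key out to files
   key.write_private_key('example_key')
   key.write_public_key('example_key.pub')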
Release 1.6.0 (13 Aug 2016) --------------------------- * Added new create_process() and run() APIs modeled after the "subprocess" module to simplify redirection of stdin, stdout, and stderr and collection of output from remote SSH processes. * Added input line editing and echoing capabilities to better support interactive SSH server applications. AsyncSSH server sessions will now automatically perform input echoing and provide basic line editing capabilities to clients which request a pseudo-terminal, avoiding the need for applications to provide this functionality. * Added the ability to use SSHReader objects as async iterators in Python 3.5, returning input a line at a time. * Added support for the IUTF8 terminal mode now recognized by OpenSSH 7.3. * Fixed a bug where an SSHReader read() call could return an empty string when it followed a call to readline() instead of blocking until more input was available. * Updated AsyncSSH to use the bcrypt package from PyCA, now that it has support for the kdf function. * Updated the documentation and examples to show how to take advantage of the new features listed here. Release 1.5.6 (18 Jun 2016) --------------------------- * Added support for Python 3.5 asynchronous context managers in SSHConnection, SFTPClient, and SFTPFile, while still maintaining backward compatibility with older Python 3.4 syntax. * Updated bcrypt check in test code to only test features that depend on it when the right version is available. * Switched testing over to using tox to better support testing on multiple versions of Python. * Added tests of new Python 3.5 async syntax. * Expanded Travis CI coverage to test both Python 3.4 and 3.5 on MacOS. * Updated documentation and examples to use Python 3.5 syntax. Release 1.5.5 (11 Jun 2016) --------------------------- * Updated public_key module to make sure the right version of bcrypt is installed before attempting to use it. * Updated forward and sftp module unit tests to work better on Linux. * Changed README links to point at new readthedocs.io domain. Release 1.5.4 (6 Jun 2016) -------------------------- * Added support for setting custom SSH client and server version strings. * Added unit tests for the sftp module, bringing AsyncSSH up to 100% code coverage under test on all modules. * Added new wait_closed() method in SFTPClient class to wait for an SFTP client session to be fully closed. * Fixed an issue with error handling in new parallel SFTP file copy code. * Fixed some other minor issues in SFTP found during unit tests. * Fixed some minor documentation issues. Release 1.5.3 (2 Apr 2016) -------------------------- * Added support for opening tunneled SSH connections, where an SSH connection is opened over another SSH connection's direct TCP/IP channel. * Improved the performance of SFTP over high latency connections by having the internal copy method issue multiple read requests in parallel. * Reworked SFTP to mark all coroutine functions explicitly, to provide better compatibility with the new Python 3.5 "await" syntax. * Reworked create_connection() and create_server() functions to do argument checking immediately rather than in the SSHConnection constructors, improving error reporting and avoiding a bug in asyncio which can leak socket objects. * Fixed a hang which could occur when attempting to close an SSH connection with a listener still active. * Fixed an error related to passing keys in via public_key_auth_requested().
* Fixed a potential leak of an SSHAgentClient object when an error occurs while opening a client connection. * Fixed some race conditions related to channel and connection closes. * Fixed some minor documentation issues. * Continued to expand unit test coverage, completing coverage of the connection module. Release 1.5.2 (25 Feb 2016) --------------------------- * Fixed a bug in UNIX domain socket forwarding introduced in 1.5.1 by the TCP_NODELAY change. * Fixed channel code to report when a channel is closed with incomplete Unicode data in the receive buffer. This was previously reported correctly when EOF was received on a channel, but not when it was closed without sending EOF. * Added unit tests for channel, forward, and stream modules, partial unit tests for the connection module, and a placeholder for unit tests for the sftp module. Release 1.5.1 (23 Feb 2016) --------------------------- * Added basic support for running AsyncSSH on Windows. Some functionality such as UNIX domain sockets will not work there, and the test suite will not run there yet, but basic functionality has been tested and seems to work. This includes features like bcrypt and support for newer ciphers provided by libnacl when these optional packages are installed. * Greatly improved the performance of known_hosts matching on exact hostnames and addresses. Full wildcard pattern matching is still supported, but entries involving exact hostnames or addresses are now matched thousands of times faster. * Split known_hosts parsing and matching into separate calls so that a known_hosts file can be parsed once and used to make connections to several different hosts. Thanks go to Josh Yudaken for suggesting this and providing a sample implementation. * Updated AsyncSSH to allow SSH agent forwarding when it is requested even when local client keys are used to perform SSH authentication. * Updated channel state machine to better handle close being received while the channel is paused for reading. Previously, some data would not be delivered in this case. * Set TCP_NODELAY on sockets to avoid latency problems caused by TCP delayed ACK. * Fixed a bug where exceptions were not always returned properly when attempting to drain writes on a stream. * Fixed a bug which could leak a socket object after an error opening a local TCP listening socket. * Fixed a number of race conditions uncovered during unit testing. Release 1.5.0 (27 Jan 2016) --------------------------- * Added support for OpenSSH-compatible direct and forwarded UNIX domain socket channels and local and remote UNIX domain socket forwarding. * Added support for client and server side ssh-agent forwarding. * Fixed the open_connection() method on SSHServerConnection to not include a handler_factory argument. This should only have been present on the start_server() method. * Fixed wait_closed() on SSHForwardListener to work properly when a close is in progress at the time of the call. Release 1.4.1 (23 Jan 2016) --------------------------- * Fixed a bug in SFTP introduced in 1.4.0 related to handling of responses to non-blocking file closes. * Updated code to avoid calling asyncio.async(), deprecated in Python 3.4.4. * Updated unit tests to avoid errors on systems with an older version of OpenSSL installed. Release 1.4.0 (17 Jan 2016) --------------------------- * Added ssh-agent client support, automatically using it when SSH_AUTH_SOCK is set and client private keys aren't explicitly provided.
* Added new wait_closed() API on SSHConnection to allow applications to wait for a connection to be fully closed and updated examples to use it. * Added a new login_timeout argument when creating an SSH server. * Added a missing acknowledgement response when canceling port forwarding and fixed a few other issues related to cleaning up port forwarding listeners. * Added handlers to improve the catching and reporting of exceptions that are raised in asynchronous tasks. * Reworked channel state machine to perform clean up on a channel only after a close is both sent and received. * Fixed SSHChannel to run the connection_lost() handler on the SSHSession before unblocking callers of wait_closed(). * Fixed wait_closed() on SSHListener to wait for the acknowledgement from the SSH server before returning. * Fixed a race condition in port forwarding code. * Fixed a bug related to sending a close on a channel which got a failure when being opened. * Fixed a bug related to handling term_type being set without term_size. * Fixed some issues related to the automatic conversion of client keyboard-interactive auth to password auth. With this change, automatic conversion will only occur if the application doesn't override the kbdint_challenge_received() method and it will only attempt to authenticate once with the password provided. Release 1.3.2 (26 Nov 2015) --------------------------- * Added server-side support for handling password changes during password authentication, and fixed a few other auth-related bugs. * Added the ability to override the automatic support for keyboard-interactive authentication when password authentication is supported. * Fixed a race condition in unblocking streams. * Removed support for OpenSSH v00 certificates now that OpenSSH no longer supports them. * Added unit tests for the auth module. Release 1.3.1 (6 Nov 2015) -------------------------- * Updated AsyncSSH to depend on version 1.1 or later of PyCA and added support for using its new Elliptic Curve Diffie Hellman (ECDH) implementation, replacing the previous AsyncSSH native Python version. * Added support for specifying a passphrase in the create_connection, create_server, connect, and listen functions to allow file names or byte strings containing encrypted client and server host keys to be specified in those calls. * Fixed handling of cancellation in a few AsyncSSH calls, so it is now possible to make calls to things like stream read or drain which time out. * Fixed a bug in keyboard-interactive fallback to password auth which was introduced when support was added for auth functions optionally being coroutines. * Moved the bcrypt check in encrypted key handling to the point where it is needed so better errors can be returned if a passphrase is not specified or the key derivation function used in a key is unknown. * Added unit tests for the auth_keys module. * Updated unit tests to better handle bcrypt or libnacl not being installed. Release 1.3.0 (10 Oct 2015) --------------------------- * Updated AsyncSSH dependencies to make PyCA version 1.0.0 or later mandatory and remove the older PyCrypto support. This change also adds support for the PyCA implementation of ECDSA and removes support for RC2-based private key encryption that was only supported by PyCrypto. * Refactored ECDH and Curve25519 key exchange code so they can share an implementation, and prepared the code for adding a PyCA shim for this as soon as support for that is released.
* Hardened the DSA and RSA implementations to do stricter checking of the key exchange response, and sped up the RSA implementation by taking advantage of optional RSA private key parameters when they are present. * Added support for asynchronous client and server authentication, allowing auth-related callbacks in SSHClient and SSHServer to optionally be defined as coroutines. * Added support for asynchronous SFTP server processing, allowing callbacks in SFTPServer to optionally be defined as coroutines. * Added support for a broader set of open mode flags in the SFTP server. Note that this change is not completely backward compatible with previous releases. If you have application code which expects a Python mode string as an argument to the SFTPServer open method, it will need to be changed to expect a pflags value instead. * Fixed handling of eof_received() when it returns false to close the half-open connection but still allow sending or receiving of exit status and exit signals. * Added unit tests for the asn1, cipher, compression, ec, kex, known_hosts, mac, and saslprep modules and expanded the set of pbe and public_key unit tests. * Fixed a set of issues uncovered by ASN.1 unit tests: * Removed extra 0xff byte when encoding integers of the form -128*256^n * Fixed decoding error for OIDs beginning with 2.n where n >= 40 * Fixed range check for second component of ObjectIdentifier * Added check for extraneous 0x80 bytes in ObjectIdentifier components * Added check for negative component values in ObjectIdentifier * Added error handling for ObjectIdentifier components being non-integer * Added handling for missing length byte after extended tag * Raised ASN1EncodeError instead of TypeError on unsupported types * Added validation on asn1_class argument, and equality and hash methods to BitString, RawDERObject, and TaggedDERObject. Also, reordered RawDERObject arguments to be consistent with TaggedDERObject and added str method to ObjectIdentifier. * Fixed a set of issues uncovered by additional pbe unit tests: * Encoding and decoding of PBES2-encrypted keys with a PRF other than SHA1 is now handled correctly. * Some exception messages were made more specific. * Additional checks were put in for empty salt or zero iteration count in encryption parameters. * Fixed a set of issues uncovered by additional public key unit tests: * Properly handle PKCS#8 keys with invalid ASN.1 data * Properly handle PKCS#8 DSA & RSA keys with non-sequence for arg_params * Properly handle attempts to import empty string as a public key * Properly handle encrypted PEM keys with missing DEK-Info header * Report check byte mismatches for encrypted OpenSSH keys as bad passphrase * Return KeyImportError instead of KeyEncryptionError when passphrase is needed but not provided * Added information about branches to CONTRIBUTING guide. * Performed a bunch of code cleanup suggested by pylint. Release 1.2.1 (26 Aug 2015) --------------------------- * Fixed a problem with passing in client_keys=None to disable public key authentication in the SSH client. * Updated Unicode handling to allow multi-byte Unicode characters to be split across successive SSH data messages. * Added a note to the documentation for AsyncSSH create_connection() explaining how to perform the equivalent of a connect with a timeout. Release 1.2.0 (6 Jun 2015) -------------------------- * Fixed a problem with the SSHConnection context manager on Python versions older than 3.4.2.
* Updated the documentation for get_extra_info() in the SSHConnection, SSHChannel, SSHReader, and SSHWriter classes to contain pointers to get_extra_info() in their parent transports to make it easier to see all of the attributes which can be queried. * Clarified the legal return values for the session_requested(), connection_requested(), and server_requested() methods in SSHServer. * Eliminated calls to the deprecated importlib.find_loader() method. * Made improvements to README suggested by Nicholas Chammas. * Fixed a number of issues identified by pylint. Release 1.1.1 (25 May 2015) --------------------------- * Added new start_sftp_server method on SSHChannel to allow applications using the non-streams API to start an SFTP server. * Enhanced the default format_longname() method in SFTPServer to properly handle the case where not all of the file attributes are returned by stat(). * Fixed a bug related to the new allow_pty parameter in create_server. * Fixed a bug in the hashed known_hosts support introduced in some recent refactoring of the host pattern matching code. Release 1.1.0 (22 May 2015) --------------------------- * SFTP is now supported! * Both client and server support is available. * SFTP version 3 is supported, with OpenSSH extensions. * Recursive transfers and glob matching are supported in the client. * File I/O APIs allow files to be accessed without downloading them. * New simplified connect and listen APIs have been added. * SSHConnection can now be used as a context manager. * New arguments to create_server now allow the specification of a session_factory and encoding or sftp_factory as well as controls over whether a pty is allowed and the window and max packet size, avoiding the need to create custom SSHServer subclasses or custom SSHServerChannel instances. * New examples have been added for SFTP and to show the use of the new connect and listen APIs. * Copyrights in changed files have all been updated to 2015. Release 1.0.1 (13 Apr 2015) --------------------------- * Fixed a bug in OpenSSH private key encryption introduced in some recent cipher refactoring. * Added bcrypt and libnacl as optional dependencies in setup.py. * Changed test_keys test to work properly when bcrypt or libnacl aren't installed. Release 1.0.0 (11 Apr 2015) --------------------------- * This release finishes adding a number of major features, finally making it worthy of being called a "1.0" release. * Host and user certificates are now supported! * Enforcement is done on principals in certificates. * Enforcement is done on force-command and source-address critical options. * Enforcement is done on permit-pty and permit-port-forwarding extensions. * OpenSSH-style known hosts files are now supported! * Positive and negative wildcard and CIDR-style patterns are supported. * HMAC-SHA1 hashed host entries are supported. * The @cert-authority and @revoked markers are supported. * OpenSSH-style authorized keys files are now supported! * Both client keys and certificate authorities are supported. * Enforcement is done on from and principals options during key matching. * Enforcement is done on no-pty, no-port-forwarding, and permitopen. * The command and environment options are supported. * Applications can query for their own non-standard options. * Support has been added for OpenSSH format private keys. * DSA, RSA, and ECDSA keys in this format are now supported. * Ed25519 keys are supported when libnacl and libsodium are installed. 
* OpenSSH private key encryption is supported when bcrypt is installed. * Curve25519 Diffie-Hellman key exchange is now available via either the curve25519-donna or libnacl and libsodium packages. * ECDSA key support has been enhanced. * Support is now available for PKCS#8 ECDSA v2 keys. * Support is now available for both NamedCurve and explicit ECParameter versions of keys, as long as the parameters match one of the supported curves (nistp256, nistp384, or nistp521). * Support is now available for the OpenSSH chacha20-poly1305 cipher when libnacl and libsodium are installed. * Cipher names specified in private key encryption have been changed to be consistent with OpenSSH cipher naming, and all SSH ciphers can now be used for encryption of keys in OpenSSH private key format. * A couple of race conditions in SSHChannel have been fixed and channel cleanup is now delayed to allow outstanding message handling to finish. * Channel exceptions are now properly delivered in the streams API. * A bug in SSHStream read() where it could sometimes return more data than requested has been fixed. Also, read() has been changed to properly block and return all data until EOF or a signal is received when it is called with no length. * A bug in the default implementation of keyboard-interactive authentication has been fixed, and the matching of a password prompt has been loosened to allow it to be used for password authentication on more devices. * Missing code to resume reading after a stream is paused has been added. * Improvements have been made in the handling of canceled requests. * The test code has been updated to test Ed25519 and OpenSSH format private keys. * Examples have been updated to reflect some of the new capabilities. Release 0.9.2 (26 Jan 2015) --------------------------- * Fixed a bug in PyCrypto CipherFactory introduced during PyCA refactoring. Release 0.9.1 (3 Dec 2014) -------------------------- * Added some missing items in setup.py and MANIFEST.in. * Fixed the install to work even when cryptographic dependencies aren't yet installed. * Fixed an issue where get_extra_info calls could fail if called when a connection or session was shutting down. Release 0.9.0 (14 Nov 2014) --------------------------- * Added support to use PyCA (0.6.1 or later) for cryptography. AsyncSSH will automatically detect and use either PyCA, PyCrypto, or both depending on which is installed and which algorithms are requested. * Added support for AES-GCM ciphers when PyCA is installed. Release 0.8.4 (12 Sep 2014) --------------------------- * Fixed an error in the encode/decode functions for PKCS#1 DSA public keys. * Fixed a bug in the unit test code for import/export of RFC4716 public keys. Release 0.8.3 (16 Aug 2014) --------------------------- * Added a missing import in the curve25519 implementation. Release 0.8.2 (16 Aug 2014) --------------------------- * Provided a better long description for PyPI. * Added link to PyPI in documentation sidebar. Release 0.8.1 (15 Aug 2014) --------------------------- * Added a note in the :meth:`validate_public_key() ` documentation clarifying that AsyncSSH will verify that the client possesses the corresponding private key before authentication is allowed to succeed. * Switched from setuptools to distutils and added an initial set of unit tests. * Prepared the package to be uploaded to PyPI. Release 0.8.0 (15 Jul 2014) --------------------------- * Added support for Curve25519 Diffie Hellman key exchange on systems with the curve25519-donna Python package installed. 
* Updated the examples to more clearly show what values are returned even when not all of the return values are used. Release 0.7.0 (7 Jun 2014) -------------------------- * This release adds support for the "high-level" ``asyncio`` streams API, in the form of the :class:`SSHReader` and :class:`SSHWriter` classes and wrapper methods such as :meth:`open_session() `, :meth:`open_connection() `, and :meth:`start_server() `. It also allows the callback methods on :class:`SSHServer` to return either SSH session objects or handler functions that take :class:`SSHReader` and :class:`SSHWriter` objects as arguments. See :meth:`session_requested() `, :meth:`connection_requested() `, and :meth:`server_requested() ` for more information. * Added new exceptions :exc:`BreakReceived`, :exc:`SignalReceived`, and :exc:`TerminalSizeChanged` to report when these messages are received while trying to read from an :class:`SSHServerChannel` using the new streams API. * Changed :meth:`create_server() ` to accept either a callable or a coroutine for its ``session_factory`` argument, to allow asynchronous operations to be used when deciding whether to accept a forwarded TCP connection. * Renamed ``accept_connection()`` to :meth:`create_connection() ` in the :class:`SSHServerConnection` class for consistency with :class:`SSHClientConnection`, and added a corresponding :meth:`open_connection() ` method as part of the streams API. * Added :meth:`get_exit_status() ` and :meth:`get_exit_signal() ` methods to the :class:`SSHClientChannel` class. * Added :meth:`get_command() ` and :meth:`get_subsystem() ` methods to the :class:`SSHServerChannel` class. * Fixed the name of the :meth:`write_stderr() ` method and added the missing :meth:`writelines_stderr() ` method to the :class:`SSHServerChannel` class for outputting data to the stderr channel. * Added support for a return value in the :meth:`eof_received() ` of :class:`SSHClientSession`, :class:`SSHServerSession`, and :class:`SSHTCPSession` to support half-open channels. By default, the channel is automatically closed after :meth:`eof_received() ` returns, but returning ``True`` will now keep the channel open, allowing output to still be sent on the half-open channel. This is done automatically when the new streams API is used. * Added values ``'local_peername'`` and ``'remote_peername'`` to the set of information available from the :meth:`get_extra_info() ` method in the :class:`SSHTCPChannel` class. * Updated functions returning :exc:`IOError` or :exc:`socket.error` to return the new :exc:`OSError` exception introduced in Python 3.3. * Cleaned up some errors in the documentation. * The :ref:`API`, :ref:`ClientExamples`, and :ref:`ServerExamples` have all been updated to reflect these changes, and new examples showing the streams API have been added. Release 0.6.0 (11 May 2014) --------------------------- * This release is a major revamp of the code to migrate from the ``asyncore`` framework to the new ``asyncio`` framework in Python 3.4. All the APIs have been adapted to fit the new ``asyncio`` paradigm, using coroutines wherever possible to avoid the need for callbacks when performing asynchronous operations. So far, this release only supports the "low-level" ``asyncio`` API. * The :ref:`API`, :ref:`ClientExamples`, and :ref:`ServerExamples` have all been updated to reflect these changes. 
Release 0.5.0 (11 Oct 2013) --------------------------- * Added the following new classes to support fully asynchronous connection forwarding, replacing the methods previously added in release 0.2.0: * :class:`SSHClientListener` * :class:`SSHServerListener` * :class:`SSHClientLocalPortForwarder` * :class:`SSHClientRemotePortForwarder` * :class:`SSHServerPortForwarder` These new classes allow for DNS lookups and other operations to be performed fully asynchronously when new listeners are set up. As with the asynchronous connect changes below, methods are now available to report when the listener is opened or when an error occurs during the open rather than requiring the listener to be fully set up in a single call. * Updated examples in :ref:`ClientExamples` and :ref:`ServerExamples` to reflect the above changes. Release 0.4.0 (28 Sep 2013) --------------------------- * Added support in :class:`SSHTCPConnection` for the following methods to allow asynchronous operations to be used when accepting inbound connection requests: * :meth:`handle_open_request() ` * :meth:`report_open() ` * :meth:`report_open_error() ` These new methods are used to implement asynchronous connect support for local and remote port forwarding, and to support trying multiple destination addresses when connection failures occur. * Cleaned up a few minor documentation errors. Release 0.3.0 (26 Sep 2013) --------------------------- * Added support in :class:`SSHClient` and :class:`SSHServer` for setting the key exchange, encryption, MAC, and compression algorithms allowed in the SSH handshake. * Refactored the algorithm selection code to pull a common matching function back into ``_SSHConnection`` and simplify other modules. * Extended the listener class to open multiple listening sockets when necessary, fixing a bug where sockets opened to listen on ``localhost`` were not properly accepting both IPv4 and IPv6 connections. Now, any listen request which resolves to multiple addresses will open listening sockets for each address. * Fixed a bug related to tracking of listeners opened on dynamic ports. Release 0.2.0 (21 Sep 2013) --------------------------- * Added support in :class:`SSHClient` for the following methods related to performing standard SSH port forwarding: * :meth:`forward_local_port() ` * :meth:`cancel_local_port_forwarding() ` * :meth:`forward_remote_port() ` * :meth:`cancel_remote_port_forwarding() ` * :meth:`handle_remote_port_forwarding() ` * :meth:`handle_remote_port_forwarding_error() ` * Added support in :class:`SSHServer` for new return values in :meth:`handle_direct_connection() ` and :meth:`handle_listen() ` to activate standard SSH server-side port forwarding. * Added a client_addr argument and member variable to :class:`SSHServer` to hold the client's address information. * Added and updated examples related to port forwarding and using :class:`SSHTCPConnection` to open direct and forwarded TCP connections in :ref:`ClientExamples` and :ref:`ServerExamples`. * Cleaned up some of the other documentation. * Removed a debug print statement accidentally left in related to SSH rekeying. Release 0.1.0 (14 Sep 2013) --------------------------- * Initial release asyncssh-2.20.0/docs/conf.py000066400000000000000000000174111475467777400157150ustar00rootroot00000000000000#!/usr/bin/env python3 # -*- coding: utf-8 -*- # # AsyncSSH documentation build configuration file, created by # sphinx-quickstart on Sun Sep 1 17:36:31 2013. # # This file is execfile()d with the current directory set to its containing dir. 
# # Note that not all possible configuration values are present in this # autogenerated file. # # All configuration values have a default; values that are commented out # serve to show the default. import sys, os # If extensions (or modules to document with autodoc) are in another directory, # add these directories to sys.path here. If the directory is relative to the # documentation root, use os.path.abspath to make it absolute, like shown here. sys.path.insert(0, os.path.abspath('..')) from asyncssh import __author__, __version__ # -- General configuration ----------------------------------------------------- # If your documentation needs a minimal Sphinx version, state it here. #needs_sphinx = '1.0' # Add any Sphinx extension module names here, as strings. They can be extensions # coming with Sphinx (named 'sphinx.ext.*') or your custom ones. extensions = ['sphinx.ext.autodoc', 'sphinx.ext.doctest', 'sphinx.ext.todo', 'sphinx.ext.coverage', 'sphinx.ext.viewcode', 'sphinx.ext.intersphinx'] # Add any paths that contain templates here, relative to this directory. templates_path = ['_templates'] # The suffix of source filenames. source_suffix = '.rst' # The encoding of source files. #source_encoding = 'utf-8-sig' # The master toctree document. master_doc = 'index' # General information about the project. project = 'AsyncSSH' copyright = '2013-2023, ' + __author__ # The version info for the project you're documenting, acts as replacement for # |version| and |release|, also used in various other places throughout the # built documents. # # The full version, including alpha/beta/rc tags. release = __version__ # The short X.Y.Z version. version = '.'.join(release.split('.')[:3]) # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. #language = None # There are two options for replacing |today|: either, you set today to some # non-false value, then it is used: #today = '' # Else, today_fmt is used as the format for a strftime call. #today_fmt = '%B %d, %Y' # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. exclude_patterns = ['_build'] # The reST default role (used for this markup: `text`) to use for all documents. default_role = 'py:obj' # If true, '()' will be appended to :func: etc. cross-reference text. #add_function_parentheses = True # If true, the current module name will be prepended to all description # unit titles (such as .. function::). #add_module_names = True # If true, sectionauthor and moduleauthor directives will be shown in the # output. They are ignored by default. #show_authors = False # The name of the Pygments (syntax highlighting) style to use. pygments_style = 'sphinx' highlight_language = 'python3' # A list of ignored prefixes for module index sorting. #modindex_common_prefix = [] # -- Options for HTML output --------------------------------------------------- # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. html_theme = 'rftheme' # Theme options are theme-specific and customize the look and feel of a theme # further. For a list of options available for each theme, see the # documentation. html_theme_options = { "sidebarwidth": 450, "stickysidebar": "true" } # Add any paths that contain custom themes here, relative to this directory. html_theme_path = ['.'] # The name for this set of Sphinx documents. If None, it defaults to # " v documentation". 
#html_title = None # A shorter title for the navigation bar. Default is the same as html_title. #html_short_title = None # The name of an image file (relative to this directory) to place at the top # of the sidebar. #html_logo = None # The name of an image file (within the static path) to use as favicon of the # docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32 # pixels large. #html_favicon = None # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". #html_static_path = ['_static'] # If not '', a 'Last updated on:' timestamp is inserted at every page bottom, # using the given strftime format. #html_last_updated_fmt = '%b %d, %Y' # If true, SmartyPants will be used to convert quotes and dashes to # typographically correct entities. #html_use_smartypants = True # Custom sidebar templates, maps document names to template names. html_sidebars = {'**': ['sidebartop.html', 'localtoc.html', 'sidebarbottom.html']} # Additional templates that should be rendered to pages, maps page names to # template names. #html_additional_pages = {} # If false, no module index is generated. html_domain_indices = False # If false, no index is generated. #html_use_index = True # If true, the index is split into individual pages for each letter. #html_split_index = False # If true, links to the reST sources are added to the pages. html_show_sourcelink = False # If true, "Created using Sphinx" is shown in the HTML footer. Default is True. #html_show_sphinx = True # If true, "(C) Copyright ..." is shown in the HTML footer. Default is True. #html_show_copyright = True # If true, an OpenSearch description file will be output, and all pages will # contain a tag referring to it. The value of this option must be the # base URL from which the finished HTML is served. #html_use_opensearch = '' # This is the file name suffix for HTML files (e.g. ".xhtml"). #html_file_suffix = None # Output file base name for HTML help builder. htmlhelp_basename = 'AsyncSSHdoc' # -- Options for LaTeX output -------------------------------------------------- latex_elements = { # The paper size ('letterpaper' or 'a4paper'). #'papersize': 'letterpaper', # The font size ('10pt', '11pt' or '12pt'). #'pointsize': '10pt', # Additional stuff for the LaTeX preamble. #'preamble': '', } # Grouping the document tree into LaTeX files. List of tuples # (source start file, target name, title, author, documentclass [howto/manual]). latex_documents = [ ] # The name of an image file (relative to this directory) to place at the top of # the title page. #latex_logo = None # For "manual" documents, if this is true, then toplevel headings are parts, # not chapters. #latex_use_parts = False # If true, show page references after internal links. #latex_show_pagerefs = False # If true, show URL addresses after external links. #latex_show_urls = False # Documents to append as an appendix to all manuals. #latex_appendices = [] # If false, no module index is generated. #latex_domain_indices = True # -- Options for manual page output -------------------------------------------- # One entry per manual page. List of tuples # (source start file, name, description, authors, manual section). man_pages = [ ] # If true, show URL addresses after external links. 
#man_show_urls = False # -- Options for Texinfo output ------------------------------------------------ # Grouping the document tree into Texinfo files. List of tuples # (source start file, target name, title, author, # dir menu entry, description, category) texinfo_documents = [ ] # Documents to append as an appendix to all manuals. #texinfo_appendices = [] # If false, no module index is generated. #texinfo_domain_indices = True # How to display URL addresses: 'footnote', 'no', or 'inline'. #texinfo_show_urls = 'footnote' intersphinx_mapping = {'python': ('https://docs.python.org/3', None)} autodoc_typehints = "none" asyncssh-2.20.0/docs/contributing.rst000066400000000000000000000000411475467777400176460ustar00rootroot00000000000000.. include:: ../CONTRIBUTING.rst asyncssh-2.20.0/docs/index.rst000066400000000000000000000577531475467777400162740ustar00rootroot00000000000000.. toctree:: :hidden: changes contributing api .. currentmodule:: asyncssh .. include:: ../README.rst .. _ClientExamples: Client Examples =============== Simple client ------------- The following code shows an example of a simple SSH client which logs into localhost and lists files in a directory named 'abc' under the user's home directory. The username provided is the logged in user, and the user's default SSH client keys or certificates are presented during authentication. The server's host key is checked against the user's SSH known_hosts file and the connection will fail if there's no entry for localhost there or if the key doesn't match. .. include:: ../examples/simple_client.py :literal: :start-line: 22 This example shows using the :class:`SSHClientConnection` returned by :func:`connect()` as a context manager, so that the connection is automatically closed when the end of the code block which opened it is reached. However, if you need the connection object to live longer, you can use "await" instead of "async with": .. code:: conn = await asyncssh.connect('localhost') In this case, the application will need to close the connection explicitly when done with it, and it is best to also wait for the close to complete. This can be done with the following code from inside an async function: .. code:: conn.close() await conn.wait_closed() Only stdout is referenced this example, but output on stderr is also collected as another attribute in the returned :class:`SSHCompletedProcess` object. Shell and exec sessions default to an encoding of 'utf-8', so read and write calls operate on strings by default. If you want to send and receive binary data, you can set the encoding to `None` when the session is opened to make read and write operate on bytes instead. Alternate encodings can also be selected to change how strings are converted to and from bytes. To check against a different set of server host keys, they can be provided in the known_hosts argument when the connection is opened: .. code:: async with asyncssh.connect('localhost', known_hosts='my_known_hosts') as conn: Server host key checking can be disabled by setting the known_hosts argument to ``None``, but that's not recommended as it makes the connection vulnerable to a man-in-the-middle attack. To log in as a different remote user, the username argument can be provided: .. code:: async with asyncssh.connect('localhost', username='user123') as conn: To use a different set of client keys for authentication, they can be provided in the client_keys argument: .. 
code:: async with asyncssh.connect('localhost', client_keys=['my_ssh_key']) as conn: Password authentication can be used by providing a password argument: .. code:: async with asyncssh.connect('localhost', password='secretpw') as conn: Any of the arguments above can be combined together as needed. If client keys and a password are both provided, either may be used depending on what forms of authentication the server supports and whether the authentication with them is successful. Callback example ---------------- AsyncSSH also provides APIs that use callbacks rather than "await" and "async with". Here's the example above written using custom :class:`SSHClient` and :class:`SSHClientSession` subclasses: .. include:: ../examples/callback_client.py :literal: :start-line: 22 In cases where you don't need to customize callbacks on the SSHClient class, this code can be simplified somewhat to: .. include:: ../examples/callback_client2.py :literal: :start-line: 22 If you need to distinguish output going to stdout vs. stderr, that's easy to do with the following change: .. include:: ../examples/callback_client3.py :literal: :start-line: 22 Interactive input ----------------- The following example demonstrates sending interactive input to a remote process. It executes the calculator program ``bc`` and performs some basic math calculations. Note that it uses the :meth:`create_process ` method rather than the :meth:`run ` method. This starts the process but doesn't wait for it to exit, allowing interaction with it. .. include:: ../examples/math_client.py :literal: :start-line: 22 When run, this program should produce the following output: .. code:: 2+2 = 4 1*2*3*4 = 24 2^32 = 4294967296 I/O redirection --------------- The following example shows how to pass a fixed input string to a remote process and redirect the resulting output to the local file '/tmp/stdout'. Input lines containing 1, 2, and 3 are passed into the 'tail -r' command and the output written to '/tmp/stdout' should contain the reversed lines 3, 2, and 1: .. include:: ../examples/redirect_input.py :literal: :start-line: 22 The ``stdin``, ``stdout``, and ``stderr`` arguments support redirecting to a variety of locations including local files, pipes, and sockets as well as :class:`SSHReader` or :class:`SSHWriter` objects associated with other remote SSH processes. Here's an example of piping stdout from a local process to a remote process: .. include:: ../examples/redirect_local_pipe.py :literal: :start-line: 22 Here's an example of piping one remote process to another: .. include:: ../examples/redirect_remote_pipe.py :literal: :start-line: 22 In this example both remote processes are running on the same SSH connection, but this redirection can just as easily be used between SSH sessions associated with connections going to different servers. Checking exit status -------------------- The following example shows how to test the exit status of a remote process: .. include:: ../examples/check_exit_status.py :literal: :start-line: 22 If an exit signal is received, the exit status will be set to -1 and exit signal information is provided in the ``exit_signal`` attribute of the returned :class:`SSHCompletedProcess`. If the ``check`` argument in :meth:`run ` is set to ``True``, any abnormal exit will raise a :exc:`ProcessError` exception instead of returning an :class:`SSHCompletedProcess`.
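To complement the example above, here is a minimal sketch (assuming an existing ``conn`` connection, as in the earlier examples, and a purely illustrative command) of handling the :exc:`ProcessError` raised when ``check`` is set to ``True``:

.. code::

   try:
       await conn.run('exit 42', check=True)
   except asyncssh.ProcessError as exc:
       # With check=True, an abnormal exit raises ProcessError, which
       # carries the exit status and any captured output
       print('Command failed with exit status', exc.exit_status)
       print(exc.stderr, end='')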
Running multiple clients ------------------------ The following example shows how to run multiple clients in parallel and process the results when all of them have completed: .. include:: ../examples/gather_results.py :literal: :start-line: 22 Results could be processed as they became available by setting up a loop which repeatedly called :func:`asyncio.wait` instead of calling :func:`asyncio.gather`. Setting environment variables ----------------------------- The following example demonstrates setting environment variables for the remote session and displaying them by executing the 'env' command. .. include:: ../examples/set_environment.py :literal: :start-line: 22 Any number of environment variables can be passed in the dictionary given to :meth:`create_session() `. Note that SSH servers may restrict which environment variables (if any) are accepted, so this feature may require setting options on the SSH server before it will work. Setting terminal information ---------------------------- The following example demonstrates setting the terminal type and size passed to the remote session. .. include:: ../examples/set_terminal.py :literal: :start-line: 22 Note that this will cause AsyncSSH to request a pseudo-tty from the server. When a pseudo-tty is used, the server will no longer send output going to stderr with a different data type. Instead, it will be mixed with output going to stdout (unless it is redirected elsewhere by the remote command). Port forwarding --------------- The following example demonstrates the client setting up a local TCP listener on port 8080 and requesting that connections which arrive on that port be forwarded across SSH to the server and on to port 80 on ``www.google.com``: .. include:: ../examples/local_forwarding_client.py :literal: :start-line: 22 To listen on a dynamically assigned port, the client can pass in ``0`` as the listening port. If the listener is successfully opened, the selected port will be available via the :meth:`get_port() ` method on the returned listener object: .. include:: ../examples/local_forwarding_client2.py :literal: :start-line: 22 The client can also request remote port forwarding from the server. The following example shows the client requesting that the server listen on port 8080 and that connections arriving there be forwarded across SSH and on to port 80 on ``localhost``: .. include:: ../examples/remote_forwarding_client.py :literal: :start-line: 22 To limit which connections are accepted or dynamically select where to forward traffic to, the client can implement their own session factory and call :meth:`forward_connection() ` on the connections they wish to forward and raise an error on those they wish to reject: .. include:: ../examples/remote_forwarding_client2.py :literal: :start-line: 22 Just as with local listeners, the client can request remote port forwarding from a dynamic port by passing in ``0`` as the listening port and then call :meth:`get_port() ` on the returned listener to determine which port was selected. Direct TCP connections ---------------------- The client can also ask the server to open a TCP connection and directly send and receive data on it by using the :meth:`create_connection() ` method on the :class:`SSHClientConnection` object. In this example, a connection is attempted to port 80 on ``www.google.com`` and an HTTP HEAD request is sent for the document root. 
Note that unlike sessions created with :meth:`create_session() `, the I/O on these connections defaults to sending and receiving bytes rather than strings, allowing arbitrary binary data to be exchanged. However, this can be changed by setting the encoding to use when the connection is created. .. include:: ../examples/direct_client.py :literal: :start-line: 22 To use the streams API to open a direct connection, you can use :meth:`open_connection ` instead of :meth:`create_connection `: .. include:: ../examples/stream_direct_client.py :literal: :start-line: 22 Forwarded TCP connections ------------------------- The client can also directly process data from incoming TCP connections received on the server. The following example demonstrates the client requesting that the server listen on port 8888 and forward any received connections back to it over SSH. It then has a simple handler which echoes any data it receives back to the sender. As in the direct TCP connection example above, the default would be to send and receive bytes on this connection rather than strings, but here we set the encoding explicitly so all data is sent and received as strings: .. include:: ../examples/listening_client.py :literal: :start-line: 22 To use the streams API to open a listening connection, you can use :meth:`start_server ` instead of :meth:`create_server `: .. include:: ../examples/stream_listening_client.py :literal: :start-line: 22 SFTP client ----------- AsyncSSH also provides SFTP support. The following code shows an example of starting an SFTP client and requesting the download of a file: .. include:: ../examples/sftp_client.py :literal: :start-line: 22 To recursively download a directory, preserving access and modification times and permissions on the files, the preserve and recurse arguments can be included: .. code:: await sftp.get('example_dir', preserve=True, recurse=True) Wild card pattern matching is supported by the :meth:`mget `, :meth:`mput `, and :meth:`mcopy ` methods. The following downloads all files with extension "txt": .. code:: await sftp.mget('*.txt') See the :class:`SFTPClient` documentation for the full list of available actions. SCP client ---------- AsyncSSH also supports SCP. The following code shows an example of downloading a file via SCP: .. include:: ../examples/scp_client.py :literal: :start-line: 22 To upload a file to a remote system, host information can be specified for the destination instead of the source: .. code:: await asyncssh.scp('example.txt', 'localhost:') If the destination path includes a file name, that name will be used instead of the original file name when performing the copy. For instance: .. code:: await asyncssh.scp('example.txt', 'localhost:example2.txt') If the destination path refers to a directory, the origin file name will be preserved, but it will be copied into the requested directory. Wild card patterns are also supported on local source paths. For instance, the following copies all files with extension "txt": .. code:: await asyncssh.scp('*.txt', 'localhost:') When copying files from a remote system, any wild card expansion is the responsibility of the remote SCP program or the shell which starts it. Similar to SFTP, SCP also supports options for recursively copying a directory and preserving modification times and permissions on files using the preserve and recurse arguments: .. 
code:: await asyncssh.scp('example_dir', 'localhost:', preserve=True, recurse=True) In addition to the ``'host:path'`` syntax for source and destination paths, a tuple of the form ``(host, path)`` is also supported. A non-default port can be specified by replacing ``host`` with ``(host, port)``, resulting in something like: .. code:: await asyncssh.scp((('localhost', 8022), 'example.txt'), '.') An already open :class:`SSHClientConnection` can also be passed as the host: .. code:: async with asyncssh.connect('localhost') as conn: await asyncssh.scp((conn, 'example.txt'), '.') Multiple file patterns can be copied to the same destination by making the source path argument a list. Source paths in this list can be a mixture of local and remote file references and the destination path can be local or remote, but one or both of source and destination must be remote. Local to local copies are not supported. See the :func:`scp` function documentation for the complete list of available options. .. _ServerExamples: Server Examples =============== Simple server ------------- The following code shows an example of a simple SSH server which listens for connections on port 8022, does password authentication, and prints a message when users authenticate successfully and start a shell. Shell and exec sessions default to an encoding of 'utf-8', so read and write calls operate on strings by default. If you want to send and receive binary data, you can set the encoding to `None` when the session is opened to make read and write operate on bytes instead. Alternate encodings can also be selected to change how strings are converted to and from bytes. .. include:: ../examples/simple_server.py :literal: :start-line: 22 To authenticate with SSH client keys or certificates, the server would look something like the following. Client and certificate authority keys for each user need to be placed in a file matching the username in a directory called ``authorized_keys``. .. include:: ../examples/simple_keyed_server.py :literal: :start-line: 30 It is also possible to use a single authorized_keys file for all users. This is common when using certificates, as AsyncSSH can automatically enforce that the certificates presented have a principal in them which matches the username. In this case, a custom :class:`SSHServer` subclass is no longer required, and so the :func:`listen` function can be used in place of :func:`create_server`. .. include:: ../examples/simple_cert_server.py :literal: :start-line: 29 Simple server with input ------------------------ The following example demonstrates reading input in a server session. It adds a column of numbers, displaying the total when it receives EOF. .. include:: ../examples/math_server.py :literal: :start-line: 29 Callback example ---------------- Here's an example of the server above written using callbacks in custom :class:`SSHServer` and :class:`SSHServerSession` subclasses. .. include:: ../examples/callback_math_server.py :literal: :start-line: 29 I/O redirection --------------- The following shows an example of I/O redirection on the server side, executing a process on the server with input and output redirected back to the SSH client: .. include:: ../examples/redirect_server.py :literal: :start-line: 29 Serving multiple clients ------------------------ The following is a slightly more complicated example showing how a server can manage multiple simultaneous clients. It implements a basic chat service, where clients can send messages to one other. .. 
include:: ../examples/chat_server.py :literal: :start-line: 29 Line editing ------------ When SSH clients request a pseudo-terminal, they generally default to sending input a character at a time and expect the remote system to provide character echo and line editing. To better support interactive applications like the one above, AsyncSSH defaults to providing basic line editing for server sessions which request a pseudo-terminal. When this line editor is enabled, it defaults to delivering input to the application a line at a time. Applications can switch between line and character at a time input using the :meth:`set_line_mode() ` method. Also, when in line mode, applications can enable or disable echoing of input using the :meth:`set_echo() ` method. The following code provides an example of this. .. include:: ../examples/editor.py :literal: :start-line: 29 Getting environment variables ----------------------------- The following example demonstrates reading environment variables set by the client. It will show all of the variables set by the client, or return an error if none are set. Note that SSH clients may restrict which environment variables (if any) are sent by default, so you may need to set options in the client to get it to do so. .. include:: ../examples/show_environment.py :literal: :start-line: 29 Getting terminal information ---------------------------- The following example demonstrates reading the client's terminal type and window size, and handling window size changes during a session. .. include:: ../examples/show_terminal.py :literal: :start-line: 29 Port forwarding --------------- The following example demonstrates a server accepting port forwarding requests from clients, but only when they are destined to port 80. When such a connection is received, a connection is attempted to the requested host and port and data is bidirectionally forwarded over SSH from the client to this destination. Requests by the client to connect to any other port are rejected. .. include:: ../examples/local_forwarding_server.py :literal: :start-line: 29 The server can also support forwarding inbound TCP connections back to the client. The following example demonstrates a server which will accept requests like this from clients, but only to listen on port 8080. When such a connection is received, the client is notified and data is bidirectionally forwarded from the incoming connection over SSH to the client. .. include:: ../examples/remote_forwarding_server.py :literal: :start-line: 29 Direct TCP connections ---------------------- The server can also accept direct TCP connection requests from the client and process the data on them itself. The following example demonstrates a server which accepts requests to port 7 (the "echo" port) for any host and echoes the data itself rather than forwarding the connection: .. include:: ../examples/direct_server.py :literal: :start-line: 29 Here's an example of this server written using the streams API. In this case, :meth:`connection_requested() ` returns a handler coroutine instead of a session object. When a new direct TCP connection is opened, the handler coroutine is called with AsyncSSH stream objects which can be used to perform I/O on the tunneled connection. .. include:: ../examples/stream_direct_server.py :literal: :start-line: 29 SFTP server ----------- The following example shows how to start an SFTP server with default behavior: .. 
include:: ../examples/simple_sftp_server.py :literal: :start-line: 29 A subclass of :class:`SFTPServer` can be provided as the value of the SFTP factory to override specific behavior. For example, the following code remaps path names so that each user gets access to only their own individual directory under ``/tmp/sftp``: .. include:: ../examples/chroot_sftp_server.py :literal: :start-line: 29 More complex path remapping can be performed by implementing the :meth:`map_path ` and :meth:`reverse_map_path ` methods. Individual SFTP actions can also be overridden as needed. See the :class:`SFTPServer` documentation for the full list of methods to override. SCP server ---------- The above server examples can be modified to also support SCP by simply adding ``allow_scp=True`` alongside the specification of the ``sftp_factory`` in the :func:`listen` call. This will use the same :class:`SFTPServer` instance when performing file I/O for both SFTP and SCP requests. For instance: .. include:: ../examples/simple_scp_server.py :literal: :start-line: 29 Reverse Direction Example ========================= One of the unique capabilities of AsyncSSH is its ability to support "reverse direction" SSH connections, using the functions :func:`connect_reverse` and :func:`listen_reverse`. This can be helpful when implementing protocols such as "NETCONF Call Home", described in :rfc:`8071`. When using this capability, the SSH protocol doesn't change, but the roles at the TCP level about which side acts as a TCP client and server are reversed, with the TCP client taking on the role of the SSH server and the TCP server taking on the role of the SSH client once the connection is established. For these examples to run, the following files must be created: * The file ``client_host_key`` must exist on the client and contain an SSH private key for the client to use to authenticate itself as a host to the server. An SSH certificate can optionally be provided in ``client_host_key-cert.pub``. * The file ``trusted_server_keys`` must exist on the client and contain a list of trusted server keys or a ``cert-authority`` entry with a public key trusted to sign server keys if certificates are used. This file should be in "authorized_keys" format. * The file ``server_key`` must exist on the server and contain an SSH private key for the server to use to authenticate itself to the client. An SSH certificate can optionally be provided in ``server_key-cert.pub``. * The file ``trusted_client_host_keys`` must exist on the server and contain a list of trusted client host keys or a ``@cert-authority`` entry with a public key trusted to sign client host keys if certificates are used. This file should be in "known_hosts" format. Reverse Direction Client ------------------------ The following example shows a reverse-direction SSH client which will run arbitrary shell commands given to it by the server it connects to: .. include:: ../examples/reverse_client.py :literal: :start-line: 32 Reverse Direction Server ------------------------ Here is the corresponding server which makes requests to run the commands: .. 
include:: ../examples/reverse_server.py :literal: :start-line: 32 asyncssh-2.20.0/docs/requirements.txt000066400000000000000000000000561475467777400176770ustar00rootroot00000000000000cryptography >= 39.0 typing_extensions >= 3.6 asyncssh-2.20.0/docs/rftheme/000077500000000000000000000000001475467777400160445ustar00rootroot00000000000000asyncssh-2.20.0/docs/rftheme/layout.html000066400000000000000000000002631475467777400202500ustar00rootroot00000000000000{% extends "basic/layout.html" %} {# Omit the top navigation bar. #} {% block relbar1 %} {% endblock %} {# Omit the bottom navigation bar. #} {% block relbar2 %} {% endblock %} asyncssh-2.20.0/docs/rftheme/static/000077500000000000000000000000001475467777400173335ustar00rootroot00000000000000asyncssh-2.20.0/docs/rftheme/static/rftheme.css_t000066400000000000000000000006141475467777400220230ustar00rootroot00000000000000@import url("classic.css"); .tight-list * { line-height: 110% !important; margin: 0 0 3px !important; } div.sphinxsidebar { top: 0; } div.sphinxsidebarwrapper { padding-top: 8px; } div.body p, div.body dd, div.body li { text-align: left; } tt, .note tt { font-size: 1.15em; background: none; } div.body p.rubric { font-size: 1.3em; margin: 15px 0 5px; } asyncssh-2.20.0/docs/rftheme/theme.conf000066400000000000000000000001131475467777400200100ustar00rootroot00000000000000[theme] inherit = classic stylesheet = rftheme.css pygments_style = sphinx asyncssh-2.20.0/docs/rtd-req.txt000066400000000000000000000000401475467777400165230ustar00rootroot00000000000000cryptography==2.8 sphinx==4.2.0 asyncssh-2.20.0/examples/000077500000000000000000000000001475467777400153005ustar00rootroot00000000000000asyncssh-2.20.0/examples/callback_client.py000077500000000000000000000033501475467777400207500ustar00rootroot00000000000000#!/usr/bin/env python3.7 # # Copyright (c) 2013-2024 by Ron Frederick and others. 
# # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation import asyncio, asyncssh, sys from typing import Optional class MySSHClientSession(asyncssh.SSHClientSession): def data_received(self, data: str, datatype: asyncssh.DataType) -> None: print(data, end='') def connection_lost(self, exc: Optional[Exception]) -> None: if exc: print('SSH session error: ' + str(exc), file=sys.stderr) class MySSHClient(asyncssh.SSHClient): def connection_made(self, conn: asyncssh.SSHClientConnection) -> None: print(f'Connection made to {conn.get_extra_info('peername')[0]}.') def auth_completed(self) -> None: print('Authentication successful.') async def run_client() -> None: conn, client = await asyncssh.create_connection(MySSHClient, 'localhost') async with conn: chan, session = await conn.create_session(MySSHClientSession, 'ls abc') await chan.wait_closed() try: asyncio.run(run_client()) except (OSError, asyncssh.Error) as exc: sys.exit('SSH connection failed: ' + str(exc)) asyncssh-2.20.0/examples/callback_client2.py000077500000000000000000000026511475467777400210350ustar00rootroot00000000000000#!/usr/bin/env python3.7 # # Copyright (c) 2013-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation import asyncio, asyncssh, sys from typing import Optional class MySSHClientSession(asyncssh.SSHClientSession): def data_received(self, data: str, datatype: asyncssh.DataType) -> None: print(data, end='') def connection_lost(self, exc: Optional[Exception]) -> None: if exc: print('SSH session error: ' + str(exc), file=sys.stderr) async def run_client() -> None: async with asyncssh.connect('localhost') as conn: chan, session = await conn.create_session(MySSHClientSession, 'ls abc') await chan.wait_closed() try: asyncio.run(run_client()) except (OSError, asyncssh.Error) as exc: sys.exit('SSH connection failed: ' + str(exc)) asyncssh-2.20.0/examples/callback_client3.py000077500000000000000000000030421475467777400210310ustar00rootroot00000000000000#!/usr/bin/env python3.7 # # Copyright (c) 2013-2024 by Ron Frederick and others. 
# # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation import asyncio, asyncssh, sys from typing import Optional class MySSHClientSession(asyncssh.SSHClientSession): def data_received(self, data: str, datatype: asyncssh.DataType) -> None: if datatype == asyncssh.EXTENDED_DATA_STDERR: print(data, end='', file=sys.stderr) else: print(data, end='') def connection_lost(self, exc: Optional[Exception]) -> None: if exc: print('SSH session error: ' + str(exc), file=sys.stderr) async def run_client() -> None: async with asyncssh.connect('localhost') as conn: chan, session = await conn.create_session(MySSHClientSession, 'ls abc') await chan.wait_closed() try: asyncio.run(run_client()) except (OSError, asyncssh.Error) as exc: sys.exit('SSH connection failed: ' + str(exc)) asyncssh-2.20.0/examples/callback_math_server.py000077500000000000000000000053431475467777400220150ustar00rootroot00000000000000#!/usr/bin/env python3.7 # # Copyright (c) 2013-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation # To run this program, the file ``ssh_host_key`` must exist with an SSH # private key in it to use as a server host key. An SSH host certificate # can optionally be provided in the file ``ssh_host_key-cert.pub``. # # The file ``ssh_user_ca`` must exist with a cert-authority entry of # the certificate authority which can sign valid client certificates. 
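#
# Once the server is running, it can be exercised with a standard
# interactive SSH client (assuming a key or certificate accepted by
# ``ssh_user_ca``), for example:
#
#     ssh -p 8022 localhost
#
# Numbers are then entered one per line, with Ctrl-D (EOF) printing the
# total and closing the session.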
import asyncio, asyncssh, sys class MySSHServerSession(asyncssh.SSHServerSession): def __init__(self): self._input = '' self._total = 0 def connection_made(self, chan: asyncssh.SSHServerChannel): self._chan = chan def shell_requested(self) -> bool: return True def session_started(self) -> None: self._chan.write('Enter numbers one per line, or EOF when done:\n') def data_received(self, data: str, datatype: asyncssh.DataType) -> None: self._input += data lines = self._input.split('\n') for line in lines[:-1]: try: if line: self._total += int(line) except ValueError: self._chan.write_stderr(f'Invalid number: {line}\n') self._input = lines[-1] def eof_received(self) -> bool: self._chan.write(f'Total = {self._total}\n') self._chan.exit(0) return False def break_received(self, msec: int) -> bool: return self.eof_received() def soft_eof_received(self) -> None: self.eof_received() class MySSHServer(asyncssh.SSHServer): def session_requested(self) -> asyncssh.SSHServerSession: return MySSHServerSession() async def start_server() -> None: await asyncssh.create_server(MySSHServer, '', 8022, server_host_keys=['ssh_host_key'], authorized_client_keys='ssh_user_ca') loop = asyncio.new_event_loop() try: loop.run_until_complete(start_server()) except (OSError, asyncssh.Error) as exc: sys.exit('Error starting server: ' + str(exc)) loop.run_forever() asyncssh-2.20.0/examples/chat_server.py000077500000000000000000000053631475467777400201710ustar00rootroot00000000000000#!/usr/bin/env python3.7 # # Copyright (c) 2016-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation # To run this program, the file ``ssh_host_key`` must exist with an SSH # private key in it to use as a server host key. An SSH host certificate # can optionally be provided in the file ``ssh_host_key-cert.pub``. # # The file ``ssh_user_ca`` must exist with a cert-authority entry of # the certificate authority which can sign valid client certificates. 
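#
# To try the chat service, start the server and open several interactive
# SSH sessions to it (for example ``ssh -p 8022 localhost`` in multiple
# terminals, assuming keys or certificates accepted by ``ssh_user_ca``).
# Lines typed in one session are broadcast to all of the other sessions.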
import asyncio, asyncssh, sys from typing import List, cast class ChatClient: _clients: List['ChatClient'] = [] def __init__(self, process: asyncssh.SSHServerProcess): self._process = process @classmethod async def handle_client(cls, process: asyncssh.SSHServerProcess): await cls(process).run() async def readline(self) -> str: return cast(str, self._process.stdin.readline()) def write(self, msg: str) -> None: self._process.stdout.write(msg) def broadcast(self, msg: str) -> None: for client in self._clients: if client != self: client.write(msg) async def run(self) -> None: self.write('Welcome to chat!\n\n') self.write('Enter your name: ') name = (await self.readline()).rstrip('\n') self.write(f'\n{len(self._clients)} other users are connected.\n\n') self._clients.append(self) self.broadcast(f'*** {name} has entered chat ***\n') try: async for line in self._process.stdin: self.broadcast(f'{name}: {line}') except asyncssh.BreakReceived: pass self.broadcast(f'*** {name} has left chat ***\n') self._clients.remove(self) async def start_server() -> None: await asyncssh.listen('', 8022, server_host_keys=['ssh_host_key'], authorized_client_keys='ssh_user_ca', process_factory=ChatClient.handle_client) loop = asyncio.new_event_loop() try: loop.run_until_complete(start_server()) except (OSError, asyncssh.Error) as exc: sys.exit('Error starting server: ' + str(exc)) loop.run_forever() asyncssh-2.20.0/examples/check_exit_status.py000077500000000000000000000024171475467777400213720ustar00rootroot00000000000000#!/usr/bin/env python3.7 # # Copyright (c) 2013-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation import asyncio, asyncssh, sys async def run_client() -> None: async with asyncssh.connect('localhost') as conn: result = await conn.run('ls abc') if result.exit_status == 0: print(result.stdout, end='') else: print(result.stderr, end='', file=sys.stderr) print(f'Program exited with status {result.exit_status}', file=sys.stderr) try: asyncio.run(run_client()) except (OSError, asyncssh.Error) as exc: sys.exit('SSH connection failed: ' + str(exc)) asyncssh-2.20.0/examples/chroot_sftp_server.py000077500000000000000000000034161475467777400216010ustar00rootroot00000000000000#!/usr/bin/env python3.7 # # Copyright (c) 2016-2024 by Ron Frederick and others. 
# # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation # To run this program, the file ``ssh_host_key`` must exist with an SSH # private key in it to use as a server host key. An SSH host certificate # can optionally be provided in the file ``ssh_host_key-cert.pub``. # # The file ``ssh_user_ca`` must exist with a cert-authority entry of # the certificate authority which can sign valid client certificates. import asyncio, asyncssh, os, sys class MySFTPServer(asyncssh.SFTPServer): def __init__(self, chan: asyncssh.SSHServerChannel): root = '/tmp/sftp/' + chan.get_extra_info('username') os.makedirs(root, exist_ok=True) super().__init__(chan, chroot=root) async def start_server() -> None: await asyncssh.listen('', 8022, server_host_keys=['ssh_host_key'], authorized_client_keys='ssh_user_ca', sftp_factory=MySFTPServer) loop = asyncio.new_event_loop() try: loop.run_until_complete(start_server()) except (OSError, asyncssh.Error) as exc: sys.exit('Error starting server: ' + str(exc)) loop.run_forever() asyncssh-2.20.0/examples/direct_client.py000077500000000000000000000032771475467777400204760ustar00rootroot00000000000000#!/usr/bin/env python3.7 # # Copyright (c) 2013-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation import asyncio, asyncssh, sys from typing import Optional class MySSHTCPSession(asyncssh.SSHTCPSession): def data_received(self, data: bytes, datatype: asyncssh.DataType) -> None: # We use sys.stdout.buffer here because we're writing bytes sys.stdout.buffer.write(data) def connection_lost(self, exc: Optional[Exception]) -> None: if exc: print('Direct connection error:', str(exc), file=sys.stderr) async def run_client() -> None: async with asyncssh.connect('localhost') as conn: chan, session = await conn.create_connection(MySSHTCPSession, 'www.google.com', 80) # By default, TCP connections send and receive bytes chan.write(b'HEAD / HTTP/1.0\r\n\r\n') chan.write_eof() await chan.wait_closed() try: asyncio.run(run_client()) except (OSError, asyncssh.Error) as exc: sys.exit('SSH connection failed: ' + str(exc)) asyncssh-2.20.0/examples/direct_server.py000077500000000000000000000043151475467777400205200ustar00rootroot00000000000000#!/usr/bin/env python3.7 # # Copyright (c) 2013-2024 by Ron Frederick and others. 
# # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation # To run this program, the file ``ssh_host_key`` must exist with an SSH # private key in it to use as a server host key. An SSH host certificate # can optionally be provided in the file ``ssh_host_key-cert.pub``. # # The file ``ssh_user_ca`` must exist with a cert-authority entry of # the certificate authority which can sign valid client certificates. import asyncio, asyncssh, sys class MySSHTCPSession(asyncssh.SSHTCPSession): def connection_made(self, chan: asyncssh.SSHTCPChannel) -> None: self._chan = chan def data_received(self, data: bytes, datatype: asyncssh.DataType) -> None: self._chan.write(data) class MySSHServer(asyncssh.SSHServer): def connection_requested(self, dest_host: str, dest_port: int, orig_host: str, orig_port: int) -> \ asyncssh.SSHTCPSession: if dest_port == 7: return MySSHTCPSession() else: raise asyncssh.ChannelOpenError( asyncssh.OPEN_ADMINISTRATIVELY_PROHIBITED, 'Only echo connections allowed') async def start_server() -> None: await asyncssh.create_server(MySSHServer, '', 8022, server_host_keys=['ssh_host_key'], authorized_client_keys='ssh_user_ca') loop = asyncio.new_event_loop() try: loop.run_until_complete(start_server()) except (OSError, asyncssh.Error) as exc: sys.exit('SSH server failed: ' + str(exc)) loop.run_forever() asyncssh-2.20.0/examples/editor.py000077500000000000000000000042111475467777400171410ustar00rootroot00000000000000#!/usr/bin/env python3.7 # # Copyright (c) 2013-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation # To run this program, the file ``ssh_host_key`` must exist with an SSH # private key in it to use as a server host key. An SSH host certificate # can optionally be provided in the file ``ssh_host_key-cert.pub``. # # The file ``ssh_user_ca`` must exist with a cert-authority entry of # the certificate authority which can sign valid client certificates. 
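#
# The line-editing calls below (set_echo and set_line_mode) rely on
# AsyncSSH's built-in line editor, which is only active when the client
# requests a pseudo-terminal, so this example should be exercised from an
# interactive SSH session (e.g. ``ssh -p 8022 localhost``) rather than by
# running a remote command.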
import asyncio, asyncssh, sys from typing import cast async def handle_client(process: asyncssh.SSHServerProcess): channel = cast(asyncssh.SSHLineEditorChannel, process.channel) username = process.get_extra_info('username') process.stdout.write(f'Welcome to my SSH server, {username}!\n\n') channel.set_echo(False) process.stdout.write('Tell me a secret: ') secret = await process.stdin.readline() channel.set_line_mode(False) process.stdout.write('\nYour secret is safe with me! ' 'Press any key to exit...') await process.stdin.read(1) process.stdout.write('\n') process.exit(0) async def start_server() -> None: await asyncssh.listen('', 8022, server_host_keys=['ssh_host_key'], authorized_client_keys='ssh_user_ca', process_factory=handle_client) loop = asyncio.new_event_loop() try: loop.run_until_complete(start_server()) except (OSError, asyncssh.Error) as exc: sys.exit('Error starting server: ' + str(exc)) loop.run_forever() asyncssh-2.20.0/examples/gather_results.py000077500000000000000000000031211475467777400207050ustar00rootroot00000000000000#!/usr/bin/env python3.7 # # Copyright (c) 2016-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation import asyncio, asyncssh async def run_client(host, command: str) -> asyncssh.SSHCompletedProcess: async with asyncssh.connect(host) as conn: return await conn.run(command) async def run_multiple_clients() -> None: # Put your lists of hosts here hosts = 5 * ['localhost'] tasks = (run_client(host, 'ls abc') for host in hosts) results = await asyncio.gather(*tasks, return_exceptions=True) for i, result in enumerate(results, 1): if isinstance(result, Exception): print(f'Task {i} failed: {result}') elif result.exit_status != 0: print(f'Task {i} exited with status {result.exit_status}:') print(result.stderr, end='') else: print(f'Task {i} succeeded:') print(result.stdout, end='') print(75*'-') asyncio.run(run_multiple_clients()) asyncssh-2.20.0/examples/listening_client.py000077500000000000000000000032771475467777400212200ustar00rootroot00000000000000#!/usr/bin/env python3.7 # # Copyright (c) 2013-2024 by Ron Frederick and others. 
# # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation import asyncio, asyncssh, sys class MySSHTCPSession(asyncssh.SSHTCPSession): def connection_made(self, chan: asyncssh.SSHTCPChannel) -> None: self._chan = chan def data_received(self, data: bytes, datatype: asyncssh.DataType): self._chan.write(data) def connection_requested(orig_host: str, orig_port: int) -> asyncssh.SSHTCPSession: print(f'Connection received from {orig_host}, port {orig_port}') return MySSHTCPSession() async def run_client() -> None: async with asyncssh.connect('localhost') as conn: server = await conn.create_server(connection_requested, '', 8888, encoding='utf-8') if server: await server.wait_closed() else: print('Listener couldn\'t be opened.', file=sys.stderr) try: asyncio.run(run_client()) except (OSError, asyncssh.Error) as exc: sys.exit('SSH connection failed: ' + str(exc)) asyncssh-2.20.0/examples/local_forwarding_client.py000077500000000000000000000021341475467777400225270ustar00rootroot00000000000000#!/usr/bin/env python3.7 # # Copyright (c) 2013-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation import asyncio, asyncssh, sys async def run_client() -> None: async with asyncssh.connect('localhost') as conn: listener = await conn.forward_local_port('', 8080, 'www.google.com', 80) await listener.wait_closed() try: asyncio.run(run_client()) except (OSError, asyncssh.Error) as exc: sys.exit('SSH connection failed: ' + str(exc)) asyncssh-2.20.0/examples/local_forwarding_client2.py000077500000000000000000000022261475467777400226130ustar00rootroot00000000000000#!/usr/bin/env python3.7 # # Copyright (c) 2013-2024 by Ron Frederick and others. 
# # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation import asyncio, asyncssh, sys async def run_client() -> None: async with asyncssh.connect('localhost') as conn: listener = await conn.forward_local_port('', 0, 'www.google.com', 80) print(f'Listening on port {listener.get_port()}...') await listener.wait_closed() try: asyncio.run(run_client()) except (OSError, asyncssh.Error) as exc: sys.exit('SSH connection failed: ' + str(exc)) asyncssh-2.20.0/examples/local_forwarding_server.py000077500000000000000000000036711475467777400225660ustar00rootroot00000000000000#!/usr/bin/env python3.7 # # Copyright (c) 2013-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation # To run this program, the file ``ssh_host_key`` must exist with an SSH # private key in it to use as a server host key. An SSH host certificate # can optionally be provided in the file ``ssh_host_key-cert.pub``. # # The file ``ssh_user_ca`` must exist with a cert-authority entry of # the certificate authority which can sign valid client certificates. import asyncio, asyncssh, sys class MySSHServer(asyncssh.SSHServer): def connection_requested(self, dest_host: str, dest_port: int, orig_host: str, orig_port: int) -> bool: if dest_port == 80: return True else: raise asyncssh.ChannelOpenError( asyncssh.OPEN_ADMINISTRATIVELY_PROHIBITED, 'Only connections to port 80 are allowed') async def start_server() -> None: await asyncssh.create_server(MySSHServer, '', 8022, server_host_keys=['ssh_host_key'], authorized_client_keys='ssh_user_ca') loop = asyncio.new_event_loop() try: loop.run_until_complete(start_server()) except (OSError, asyncssh.Error) as exc: sys.exit('SSH server failed: ' + str(exc)) loop.run_forever() asyncssh-2.20.0/examples/math_client.py000077500000000000000000000023501475467777400201440ustar00rootroot00000000000000#!/usr/bin/env python3.7 # # Copyright (c) 2016-2024 by Ron Frederick and others. 
# # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation import asyncio, asyncssh, sys async def run_client() -> None: async with asyncssh.connect('localhost') as conn: async with conn.create_process('bc') as process: for op in ['2+2', '1*2*3*4', '2^32']: process.stdin.write(op + '\n') result = await process.stdout.readline() print(op, '=', result, end='') try: asyncio.run(run_client()) except (OSError, asyncssh.Error) as exc: sys.exit('SSH connection failed: ' + str(exc)) asyncssh-2.20.0/examples/math_server.py000077500000000000000000000040771475467777400202040ustar00rootroot00000000000000#!/usr/bin/env python3.7 # # Copyright (c) 2013-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation # To run this program, the file ``ssh_host_key`` must exist with an SSH # private key in it to use as a server host key. An SSH host certificate # can optionally be provided in the file ``ssh_host_key-cert.pub``. # # The file ``ssh_user_ca`` must exist with a cert-authority entry of # the certificate authority which can sign valid client certificates. import asyncio, asyncssh, sys async def handle_client(process: asyncssh.SSHServerProcess) -> None: process.stdout.write('Enter numbers one per line, or EOF when done:\n') total = 0 try: async for line in process.stdin: line = line.rstrip('\n') if line: try: total += int(line) except ValueError: process.stderr.write(f'Invalid number: {line}\n') except asyncssh.BreakReceived: pass process.stdout.write(f'Total = {total}\n') process.exit(0) async def start_server() -> None: await asyncssh.listen('', 8022, server_host_keys=['ssh_host_key'], authorized_client_keys='ssh_user_ca', process_factory=handle_client) loop = asyncio.new_event_loop() try: loop.run_until_complete(start_server()) except (OSError, asyncssh.Error) as exc: sys.exit('Error starting server: ' + str(exc)) loop.run_forever() asyncssh-2.20.0/examples/redirect_input.py000077500000000000000000000020611475467777400206740ustar00rootroot00000000000000#!/usr/bin/env python3.7 # # Copyright (c) 2013-2024 by Ron Frederick and others. 
# # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation import asyncio, asyncssh, sys async def run_client() -> None: async with asyncssh.connect('localhost') as conn: await conn.run('tail -r', input='1\n2\n3\n', stdout='/tmp/stdout') try: asyncio.run(run_client()) except (OSError, asyncssh.Error) as exc: sys.exit('SSH connection failed: ' + str(exc)) asyncssh-2.20.0/examples/redirect_local_pipe.py000077500000000000000000000023541475467777400216510ustar00rootroot00000000000000#!/usr/bin/env python3.7 # # Copyright (c) 2013-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation import asyncio, asyncssh, subprocess, sys async def run_client() -> None: async with asyncssh.connect('localhost') as conn: local_proc = subprocess.Popen(r'echo "1\n2\n3"', shell=True, stdout=subprocess.PIPE) remote_result = await conn.run('tail -r', stdin=local_proc.stdout) print(remote_result.stdout, end='') try: asyncio.run(run_client()) except (OSError, asyncssh.Error) as exc: sys.exit('SSH connection failed: ' + str(exc)) asyncssh-2.20.0/examples/redirect_remote_pipe.py000077500000000000000000000022231475467777400220450ustar00rootroot00000000000000#!/usr/bin/env python3.7 # # Copyright (c) 2013-2024 by Ron Frederick and others. 
# # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation import asyncio, asyncssh, sys async def run_client() -> None: async with asyncssh.connect('localhost') as conn: proc1 = await conn.create_process(r'echo "1\n2\n3"') proc2_result = await conn.run('tail -r', stdin=proc1.stdout) print(proc2_result.stdout, end='') try: asyncio.run(run_client()) except (OSError, asyncssh.Error) as exc: sys.exit('SSH connection failed: ' + str(exc)) asyncssh-2.20.0/examples/redirect_server.py000077500000000000000000000036601475467777400210510ustar00rootroot00000000000000#!/usr/bin/env python3.7 # # Copyright (c) 2017-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation # To run this program, the file ``ssh_host_key`` must exist with an SSH # private key in it to use as a server host key. An SSH host certificate # can optionally be provided in the file ``ssh_host_key-cert.pub``. # # The file ``ssh_user_ca`` must exist with a cert-authority entry of # the certificate authority which can sign valid client certificates. import asyncio, asyncssh, subprocess, sys async def handle_client(process: asyncssh.SSHServerProcess) -> None: bc_proc = subprocess.Popen('bc', shell=True, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE) await process.redirect(stdin=bc_proc.stdin, stdout=bc_proc.stdout, stderr=bc_proc.stderr) await process.stdout.drain() process.exit(0) async def start_server() -> None: await asyncssh.listen('', 8022, server_host_keys=['ssh_host_key'], authorized_client_keys='ssh_user_ca', process_factory=handle_client) loop = asyncio.new_event_loop() try: loop.run_until_complete(start_server()) except (OSError, asyncssh.Error) as exc: sys.exit('Error starting server: ' + str(exc)) loop.run_forever() asyncssh-2.20.0/examples/remote_forwarding_client.py000077500000000000000000000021301475467777400227240ustar00rootroot00000000000000#!/usr/bin/env python3.7 # # Copyright (c) 2013-2024 by Ron Frederick and others. 
# # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation import asyncio, asyncssh, sys async def run_client() -> None: async with asyncssh.connect('localhost') as conn: listener = await conn.forward_remote_port('', 8080, 'localhost', 80) await listener.wait_closed() try: asyncio.run(run_client()) except (OSError, asyncssh.Error) as exc: sys.exit('SSH connection failed: ' + str(exc)) asyncssh-2.20.0/examples/remote_forwarding_client2.py000077500000000000000000000031151475467777400230120ustar00rootroot00000000000000#!/usr/bin/env python3.7 # # Copyright (c) 2013-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation import asyncio, asyncssh, sys from functools import partial from typing import Awaitable def connection_requested(conn: asyncssh.SSHClientConnection, orig_host: str, orig_port: int) -> Awaitable[asyncssh.SSHForwarder]: if orig_host in ('127.0.0.1', '::1'): return conn.forward_connection('localhost', 80) else: raise asyncssh.ChannelOpenError( asyncssh.OPEN_ADMINISTRATIVELY_PROHIBITED, 'Connections only allowed from localhost') async def run_client() -> None: async with asyncssh.connect('localhost') as conn: listener = await conn.create_server( partial(connection_requested, conn), '', 8080) await listener.wait_closed() try: asyncio.run(run_client()) except (OSError, asyncssh.Error) as exc: sys.exit('SSH connection failed: ' + str(exc)) asyncssh-2.20.0/examples/remote_forwarding_server.py000077500000000000000000000032561475467777400227660ustar00rootroot00000000000000#!/usr/bin/env python3.7 # # Copyright (c) 2013-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation # To run this program, the file ``ssh_host_key`` must exist with an SSH # private key in it to use as a server host key. 
An SSH host certificate # can optionally be provided in the file ``ssh_host_key-cert.pub``. # # The file ``ssh_user_ca`` must exist with a cert-authority entry of # the certificate authority which can sign valid client certificates. import asyncio, asyncssh, sys class MySSHServer(asyncssh.SSHServer): def server_requested(self, listen_host: str, listen_port: int) -> bool: return listen_port == 8080 async def start_server() -> None: await asyncssh.create_server(MySSHServer, '', 8022, server_host_keys=['ssh_host_key'], authorized_client_keys='ssh_user_ca') loop = asyncio.new_event_loop() try: loop.run_until_complete(start_server()) except (OSError, asyncssh.Error) as exc: sys.exit('SSH server failed: ' + str(exc)) loop.run_forever() asyncssh-2.20.0/examples/reverse_client.py000077500000000000000000000044641475467777400206760ustar00rootroot00000000000000#!/usr/bin/env python3.7 # # Copyright (c) 2013-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation # To run this program, the file client_host_key must exist on the client, # containing an SSH private key for the client to use to authenticate # itself as a host to the server. An SSH certificate can optionally be # provided in the file client_host_key-cert.pub. # # The file trusted_server_keys must also exist on the client, containing a # list of trusted server keys or a cert-authority entry with a public key # trusted to sign server keys if certificates are used. This file should # be in "authorized_keys" format. import asyncio, asyncssh, sys from asyncio.subprocess import PIPE async def handle_request(process: asyncssh.SSHServerProcess) -> None: """Run a command on the client, piping I/O over an SSH session""" assert process.command is not None local_proc = await asyncio.create_subprocess_shell( process.command, stdin=PIPE, stdout=PIPE, stderr=PIPE) await process.redirect(stdin=local_proc.stdin, stdout=local_proc.stdout, stderr=local_proc.stderr) process.exit(await local_proc.wait()) await process.wait_closed() async def run_reverse_client() -> None: """Make an outbound connection and then become an SSH server on it""" conn = await asyncssh.connect_reverse( 'localhost', 8022, server_host_keys=['client_host_key'], authorized_client_keys='trusted_server_keys', process_factory=handle_request, encoding=None) await conn.wait_closed() try: asyncio.run(run_reverse_client()) except (OSError, asyncssh.Error) as exc: sys.exit('Reverse SSH connection failed: ' + str(exc)) asyncssh-2.20.0/examples/reverse_server.py000077500000000000000000000046341475467777400207250ustar00rootroot00000000000000#!/usr/bin/env python3.7 # # Copyright (c) 2013-2024 by Ron Frederick and others. 
# # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation # To run this program, the file server_key must exist on the server, # containing an SSH private key for the server to use to authenticate itself # to the client. An SSH certificate can optionally be provided in the file # server_key-cert.pub. # # The file trusted_client_host_keys must also exist on the server, containing # a list of trusted client host keys or a @cert-authority entry with a public # key trusted to sign client host keys if certificates are used. This file # should be in "known_hosts" format. import asyncio, asyncssh, sys async def run_commands(conn: asyncssh.SSHClientConnection) -> None: """Run a series of commands on the client which connected to us""" commands = ('ls', 'sleep 30 && date', 'sleep 5 && cat /proc/cpuinfo') async with conn: tasks = [conn.run(cmd) for cmd in commands] for task in asyncio.as_completed(tasks): result = await task print('Command:', result.command) print('Return code:', result.returncode) print('Stdout:') print(result.stdout, end='') print('Stderr:') print(result.stderr, end='') print(75*'-') async def start_reverse_server() -> None: """Accept inbound connections and then become an SSH client on them""" await asyncssh.listen_reverse(port=8022, client_keys=['server_key'], known_hosts='trusted_client_host_keys', acceptor=run_commands) loop = asyncio.new_event_loop() try: loop.run_until_complete(start_reverse_server()) except (OSError, asyncssh.Error) as exc: sys.exit('Error starting server: ' + str(exc)) loop.run_forever() asyncssh-2.20.0/examples/scp_client.py000077500000000000000000000017451475467777400200070ustar00rootroot00000000000000#!/usr/bin/env python3.7 # # Copyright (c) 2017-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation import asyncio, asyncssh, sys async def run_client() -> None: await asyncssh.scp('localhost:example.txt', '.') try: asyncio.run(run_client()) except (OSError, asyncssh.Error) as exc: sys.exit('SFTP operation failed: ' + str(exc)) asyncssh-2.20.0/examples/set_environment.py000077500000000000000000000022101475467777400210670ustar00rootroot00000000000000#!/usr/bin/env python3.7 # # Copyright (c) 2013-2024 by Ron Frederick and others. 
# # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation import asyncio, asyncssh, sys async def run_client() -> None: async with asyncssh.connect('localhost') as conn: result = await conn.run('env', env={'LANG': 'en_GB', 'LC_COLLATE': 'C'}) print(result.stdout, end='') try: asyncio.run(run_client()) except (OSError, asyncssh.Error) as exc: sys.exit('SSH connection failed: ' + str(exc)) asyncssh-2.20.0/examples/set_terminal.py000077500000000000000000000022611475467777400203440ustar00rootroot00000000000000#!/usr/bin/env python3.7 # # Copyright (c) 2013-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation import asyncio, asyncssh, sys async def run_client() -> None: async with asyncssh.connect('localhost') as conn: result = await conn.run('echo $TERM; stty size', term_type='xterm-color', term_size=(80, 24)) print(result.stdout, end='') try: asyncio.run(run_client()) except (OSError, asyncssh.Error) as exc: sys.exit('SSH connection failed: ' + str(exc)) asyncssh-2.20.0/examples/sftp_client.py000077500000000000000000000021051475467777400201650ustar00rootroot00000000000000#!/usr/bin/env python3.7 # # Copyright (c) 2015-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation import asyncio, asyncssh, sys async def run_client() -> None: async with asyncssh.connect('localhost') as conn: async with conn.start_sftp_client() as sftp: await sftp.get('example.txt') try: asyncio.run(run_client()) except (OSError, asyncssh.Error) as exc: sys.exit('SFTP operation failed: ' + str(exc)) asyncssh-2.20.0/examples/show_environment.py000077500000000000000000000036761475467777400212750ustar00rootroot00000000000000#!/usr/bin/env python3.7 # # Copyright (c) 2013-2024 by Ron Frederick and others. 
# # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation # To run this program, the file ``ssh_host_key`` must exist with an SSH # private key in it to use as a server host key. An SSH host certificate # can optionally be provided in the file ``ssh_host_key-cert.pub``. # # The file ``ssh_user_ca`` must exist with a cert-authority entry of # the certificate authority which can sign valid client certificates. import asyncio, asyncssh, sys async def handle_client(process: asyncssh.SSHServerProcess) -> None: if process.env: keywidth = max(map(len, process.env.keys()))+1 process.stdout.write('Environment:\n') for key, value in process.env.items(): process.stdout.write(f' {key+":":{keywidth}} {value}\n') process.exit(0) else: process.stderr.write('No environment sent.\n') process.exit(1) async def start_server() -> None: await asyncssh.listen('', 8022, server_host_keys=['ssh_host_key'], authorized_client_keys='ssh_user_ca', process_factory=handle_client) loop = asyncio.new_event_loop() try: loop.run_until_complete(start_server()) except (OSError, asyncssh.Error) as exc: sys.exit('Error starting server: ' + str(exc)) loop.run_forever() asyncssh-2.20.0/examples/show_terminal.py000077500000000000000000000045311475467777400205330ustar00rootroot00000000000000#!/usr/bin/env python3.7 # # Copyright (c) 2013-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation # To run this program, the file ``ssh_host_key`` must exist with an SSH # private key in it to use as a server host key. An SSH host certificate # can optionally be provided in the file ``ssh_host_key-cert.pub``. # # The file ``ssh_user_ca`` must exist with a cert-authority entry of # the certificate authority which can sign valid client certificates. 
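#
# One possible way to create these files is with OpenSSH's ssh-keygen (shown
# only as a sketch: the CA key name ``user_ca_key``, the certificate identity,
# and the principal are placeholders, and your certificate workflow may
# differ):
#
#     ssh-keygen -f ssh_host_key -N ''
#     ssh-keygen -f user_ca_key -N ''
#     echo "cert-authority $(cat user_ca_key.pub)" > ssh_user_ca
#     ssh-keygen -s user_ca_key -I example-user -n <username> ~/.ssh/id_rsa.pub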
import asyncio, asyncssh, sys async def handle_client(process: asyncssh.SSHServerProcess) -> None: width, height, pixwidth, pixheight = process.term_size process.stdout.write(f'Terminal type: {process.term_type}, ' f'size: {width}x{height}') if pixwidth and pixheight: process.stdout.write(f' ({pixwidth}x{pixheight} pixels)') process.stdout.write('\nTry resizing your window!\n') while not process.stdin.at_eof(): try: await process.stdin.read() except asyncssh.TerminalSizeChanged as exc: process.stdout.write(f'New window size: {exc.width}x{exc.height}') if exc.pixwidth and exc.pixheight: process.stdout.write(f' ({exc.pixwidth}' f'x{exc.pixheight} pixels)') process.stdout.write('\n') async def start_server() -> None: await asyncssh.listen('', 8022, server_host_keys=['ssh_host_key'], authorized_client_keys='ssh_user_ca', process_factory=handle_client) loop = asyncio.new_event_loop() try: loop.run_until_complete(start_server()) except (OSError, asyncssh.Error) as exc: sys.exit('Error starting server: ' + str(exc)) loop.run_forever() asyncssh-2.20.0/examples/simple_cert_server.py000077500000000000000000000033431475467777400215540ustar00rootroot00000000000000#!/usr/bin/env python3.7 # # Copyright (c) 2013-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation # To run this program, the file ``ssh_host_key`` must exist with an SSH # private key in it to use as a server host key. An SSH host certificate # can optionally be provided in the file ``ssh_host_key-cert.pub``. # # The file ``ssh_user_ca`` must exist with a cert-authority entry of # the certificate authority which can sign valid client certificates. import asyncio, asyncssh, sys def handle_client(process: asyncssh.SSHServerProcess) -> None: username = process.get_extra_info('username') process.stdout.write(f'Welcome to my SSH server, {username}!\n') process.exit(0) async def start_server() -> None: await asyncssh.listen('', 8022, server_host_keys=['ssh_host_key'], authorized_client_keys='ssh_user_ca', process_factory=handle_client) loop = asyncio.new_event_loop() try: loop.run_until_complete(start_server()) except (OSError, asyncssh.Error) as exc: sys.exit('Error starting server: ' + str(exc)) loop.run_forever() asyncssh-2.20.0/examples/simple_client.py000077500000000000000000000024351475467777400205100ustar00rootroot00000000000000#!/usr/bin/env python3.7 # # Copyright (c) 2013-2024 by Ron Frederick and others. 
# # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation import asyncio, asyncssh, sys async def run_client() -> None: async with asyncssh.connect('localhost') as conn: try: result = await conn.run('ls abc', check=True) except asyncssh.ProcessError as exc: print(exc.stderr, end='') print(f'Process exited with status {exc.exit_status}', file=sys.stderr) else: print(result.stdout, end='') try: asyncio.run(run_client()) except (OSError, asyncssh.Error) as exc: sys.exit('SSH connection failed: ' + str(exc)) asyncssh-2.20.0/examples/simple_keyed_server.py000077500000000000000000000041661475467777400217240ustar00rootroot00000000000000#!/usr/bin/env python3.7 # # Copyright (c) 2013-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation # To run this program, the file ``ssh_host_key`` must exist with an SSH # private key in it to use as a server host key. An SSH host certificate # can optionally be provided in the file ``ssh_host_key-cert.pub``. # # Authentication requires the directory authorized_keys to exist with # files in it named based on the username containing the client keys # and certificate authority keys which are accepted for that user. import asyncio, asyncssh, sys def handle_client(process: asyncssh.SSHServerProcess) -> None: username = process.get_extra_info('username') process.stdout.write(f'Welcome to my SSH server, {username}!\n') process.exit(0) class MySSHServer(asyncssh.SSHServer): def connection_made(self, conn: asyncssh.SSHServerConnection) -> None: self._conn = conn def begin_auth(self, username: str) -> bool: try: self._conn.set_authorized_keys(f'authorized_keys/{username}') except OSError: pass return True async def start_server() -> None: await asyncssh.create_server(MySSHServer, '', 8022, server_host_keys=['ssh_host_key'], process_factory=handle_client) loop = asyncio.new_event_loop() try: loop.run_until_complete(start_server()) except (OSError, asyncssh.Error) as exc: sys.exit('Error starting server: ' + str(exc)) loop.run_forever() asyncssh-2.20.0/examples/simple_scp_server.py000077500000000000000000000030341475467777400214010ustar00rootroot00000000000000#!/usr/bin/env python3.7 # # Copyright (c) 2015-2024 by Ron Frederick and others. 
# # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation # To run this program, the file ``ssh_host_key`` must exist with an SSH # private key in it to use as a server host key. An SSH host certificate # can optionally be provided in the file ``ssh_host_key-cert.pub``. # # The file ``ssh_user_ca`` must exist with a cert-authority entry of # the certificate authority which can sign valid client certificates. import asyncio, asyncssh, sys async def start_server() -> None: await asyncssh.listen('', 8022, server_host_keys=['ssh_host_key'], authorized_client_keys='ssh_user_ca', sftp_factory=True, allow_scp=True) loop = asyncio.new_event_loop() try: loop.run_until_complete(start_server()) except (OSError, asyncssh.Error) as exc: sys.exit('Error starting server: ' + str(exc)) loop.run_forever() asyncssh-2.20.0/examples/simple_server.py000077500000000000000000000053341475467777400205410ustar00rootroot00000000000000#!/usr/bin/env python3.7 # # Copyright (c) 2013-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation # To run this program, the file ``ssh_host_key`` must exist with an SSH # private key in it to use as a server host key. An SSH host certificate # can optionally be provided in the file ``ssh_host_key-cert.pub``. 
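#
# Once the server is running, any standard SSH client should be able to
# exercise it (a usage sketch, assuming OpenSSH on the client side; the
# accounts come from the hard-coded password table below):
#
#     ssh -p 8022 guest@localhost       # guest account, no password required
#     ssh -p 8022 user123@localhost     # prompts for the password 'secretpw'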
import asyncio, asyncssh, bcrypt, sys from typing import Optional passwords = {'guest': b'', # guest account with no password 'user123': bcrypt.hashpw(b'secretpw', bcrypt.gensalt()), } def handle_client(process: asyncssh.SSHServerProcess) -> None: username = process.get_extra_info('username') process.stdout.write(f'Welcome to my SSH server, {username}!\n') process.exit(0) class MySSHServer(asyncssh.SSHServer): def connection_made(self, conn: asyncssh.SSHServerConnection) -> None: peername = conn.get_extra_info('peername')[0] print(f'SSH connection received from {peername}.') def connection_lost(self, exc: Optional[Exception]) -> None: if exc: print('SSH connection error: ' + str(exc), file=sys.stderr) else: print('SSH connection closed.') def begin_auth(self, username: str) -> bool: # If the user's password is the empty string, no auth is required return passwords.get(username) != b'' def password_auth_supported(self) -> bool: return True def validate_password(self, username: str, password: str) -> bool: if username not in passwords: return False pw = passwords[username] if not password and not pw: return True return bcrypt.checkpw(password.encode('utf-8'), pw) async def start_server() -> None: await asyncssh.create_server(MySSHServer, '', 8022, server_host_keys=['ssh_host_key'], process_factory=handle_client) loop = asyncio.new_event_loop() try: loop.run_until_complete(start_server()) except (OSError, asyncssh.Error) as exc: sys.exit('Error starting server: ' + str(exc)) loop.run_forever() asyncssh-2.20.0/examples/simple_sftp_server.py000077500000000000000000000030141475467777400215660ustar00rootroot00000000000000#!/usr/bin/env python3.7 # # Copyright (c) 2015-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation # To run this program, the file ``ssh_host_key`` must exist with an SSH # private key in it to use as a server host key. An SSH host certificate # can optionally be provided in the file ``ssh_host_key-cert.pub``. # # The file ``ssh_user_ca`` must exist with a cert-authority entry of # the certificate authority which can sign valid client certificates. import asyncio, asyncssh, sys async def start_server() -> None: await asyncssh.listen('', 8022, server_host_keys=['ssh_host_key'], authorized_client_keys='ssh_user_ca', sftp_factory=True) loop = asyncio.new_event_loop() try: loop.run_until_complete(start_server()) except (OSError, asyncssh.Error) as exc: sys.exit('Error starting server: ' + str(exc)) loop.run_forever() asyncssh-2.20.0/examples/stream_direct_client.py000077500000000000000000000025201475467777400220370ustar00rootroot00000000000000#!/usr/bin/env python3.7 # # Copyright (c) 2013-2024 by Ron Frederick and others. 
# # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation import asyncio, asyncssh, sys async def run_client() -> None: async with asyncssh.connect('localhost') as conn: reader, writer = await conn.open_connection('www.google.com', 80) # By default, TCP connections send and receive bytes writer.write(b'HEAD / HTTP/1.0\r\n\r\n') writer.write_eof() # We use sys.stdout.buffer here because we're writing bytes response = await reader.read() sys.stdout.buffer.write(response) try: asyncio.run(run_client()) except (OSError, asyncssh.Error) as exc: sys.exit('SSH connection failed: ' + str(exc)) asyncssh-2.20.0/examples/stream_direct_server.py000077500000000000000000000044301475467777400220710ustar00rootroot00000000000000#!/usr/bin/env python3.7 # # Copyright (c) 2013-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation # To run this program, the file ``ssh_host_key`` must exist with an SSH # private key in it to use as a server host key. An SSH host certificate # can optionally be provided in the file ``ssh_host_key-cert.pub``. # # The file ``ssh_user_ca`` must exist with a cert-authority entry of # the certificate authority which can sign valid client certificates. import asyncio, asyncssh, sys async def handle_connection(reader: asyncssh.SSHReader, writer: asyncssh.SSHWriter) -> None: while not reader.at_eof(): data = await reader.read(8192) try: writer.write(data) except BrokenPipeError: break writer.close() class MySSHServer(asyncssh.SSHServer): def connection_requested(self, dest_host: str, dest_port: int, orig_host: str, orig_port: int) -> \ asyncssh.SSHSocketSessionFactory: if dest_port == 7: return handle_connection else: raise asyncssh.ChannelOpenError( asyncssh.OPEN_ADMINISTRATIVELY_PROHIBITED, 'Only echo connections allowed') async def start_server() -> None: await asyncssh.create_server(MySSHServer, '', 8022, server_host_keys=['ssh_host_key'], authorized_client_keys='ssh_user_ca') loop = asyncio.new_event_loop() try: loop.run_until_complete(start_server()) except (OSError, asyncssh.Error) as exc: sys.exit('SSH server failed: ' + str(exc)) loop.run_forever() asyncssh-2.20.0/examples/stream_listening_client.py000077500000000000000000000026731475467777400225720ustar00rootroot00000000000000#!/usr/bin/env python3.7 # # Copyright (c) 2013-2024 by Ron Frederick and others. 
# # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation import asyncio, asyncssh, sys async def handle_connection(reader, writer): while not reader.at_eof(): data = await reader.read(8192) writer.write(data) writer.close() def connection_requested(orig_host, orig_port): print(f'Connection received from {orig_host}, port {orig_port}') return handle_connection async def run_client(): async with asyncssh.connect('localhost') as conn: server = await conn.start_server(connection_requested, '', 8888, encoding='utf-8') await server.wait_closed() try: asyncio.run(run_client()) except (OSError, asyncssh.Error) as exc: sys.exit('SSH connection failed: ' + str(exc)) asyncssh-2.20.0/mypy.ini000066400000000000000000000000411475467777400151540ustar00rootroot00000000000000[mypy] allow_redefinition = True asyncssh-2.20.0/pylintrc000066400000000000000000000265161475467777400152630ustar00rootroot00000000000000[MASTER] # Specify a configuration file. #rcfile= # Python code to execute, usually for sys.path manipulation such as # pygtk.require(). #init-hook= # Add files or directories to the blacklist. They should be base names, not # paths. ignore= # Pickle collected data for later comparisons. persistent=yes # List of plugins (as comma separated values of python modules names) to load, # usually to register additional checkers. load-plugins= # Use multiple processes to speed up Pylint. jobs=1 # Allow loading of arbitrary C extensions. Extensions are imported into the # active Python interpreter and may run arbitrary code. unsafe-load-any-extension=no # A comma-separated list of package or module names from where C extensions may # be loaded. Extensions are loading into the active Python interpreter and may # run arbitrary code extension-pkg-whitelist= # Allow optimization of some AST trees. This will activate a peephole AST # optimizer, which will apply various small optimizations. For instance, it can # be used to obtain the result of joining multiple strings with the addition # operator. Joining a lot of strings can lead to a maximum recursion error in # Pylint and this flag can prevent that. It has one side effect, the resulting # AST will be different than the one from reality. optimize-ast=no [MESSAGES CONTROL] # Only show warnings with the listed confidence levels. Leave empty to show # all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED confidence= # Enable the message, report, category or checker with the given id(s). You can # either give multiple identifier separated by comma (,) or put this option # multiple time. See also the "--disable" option for examples. #enable= # Disable the message, report, category or checker with the given id(s). 
You # can either give multiple identifiers separated by comma (,) or put this # option multiple times (only on the command line, not in the configuration # file where it should appear only once).You can also use "--disable=all" to # disable everything first and then reenable specific checks. For example, if # you want to run only the similarities checker, you can use "--disable=all # --enable=similarities". If you want to run only the classes checker, but have # no Warning level messages displayed, use"--disable=all --enable=classes # --disable=W" disable=fixme,locally-enabled,locally-disabled,redefined-variable-type,no-else-break,no-else-continue,no-else-raise,no-else-return,inconsistent-return-statements,invalid-overridden-method,unbalanced-tuple-unpacking,useless-return,consider-using-with [REPORTS] # Set the output format. Available formats are text, parseable, colorized, msvs # (visual studio) and html. You can also give a reporter class, eg # mypackage.mymodule.MyReporterClass. output-format=text # Put messages in a separate file for each module / package specified on the # command line instead of printing them on stdout. Reports (if any) will be # written in a file name "pylint_global.[txt|html]". files-output=no # Tells whether to display a full report or only the messages reports=yes # Python expression which should return a note less than 10 (10 is the highest # note). You have access to the variables errors warning, statement which # respectively contain the number of errors / warnings messages and the total # number of statements analyzed. This is used by the global evaluation report # (RP0004). evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10) # Template used to display messages. This is a python new-style format string # used to format the message information. See doc for all details #msg-template= [BASIC] # List of builtins function names that should not be used, separated by a comma bad-functions=map,filter # Good variable names which should always be accepted, separated by a comma good-names=A,B,D,I,a,av,b,c,ca,ch,cn,d,e,f,fd,fp,fs,g,h,i,id,ip,iv,j,k,l,n,ns,p,q,r,rp,s,sa,t,v,x,y,_,_e,_f,_g,_k,_p,_q,_x # Bad variable names which should always be refused, separated by a comma bad-names= # Colon-delimited sets of names that determine each other's naming style when # the name regexes allow several styles. 
name-group= # Include a hint for the correct naming format with invalid-name include-naming-hint=no # Regular expression matching correct argument names argument-rgx=[a-z_][a-z0-9_]{2,75}$ # Naming hint for argument names argument-name-hint=[a-z_][a-z0-9_]{2,75}$ # Regular expression matching correct variable names variable-rgx=[a-z_][a-z0-9_]{2,75}$ # Naming hint for variable names variable-name-hint=[a-z_][a-z0-9_]{2,75}$ # Regular expression matching correct class names class-rgx=[A-Z_][a-zA-Z0-9]+$ # Naming hint for class names class-name-hint=[A-Z_][a-zA-Z0-9]+$ # Regular expression matching correct constant names const-rgx=(([A-Za-z_][A-Za-z0-9_]*)|(__.*__))$ # Naming hint for constant names const-name-hint=(([A-Z][A-Z-9_]*)|(__.*__))$ # Regular expression matching correct method names method-rgx=[a-z_][a-z0-9_]{2,75}$ # Naming hint for method names method-name-hint=[a-z_][a-z0-9_]{2,75}$ # Regular expression matching correct module names module-rgx=(([a-z_][a-z0-9_]*)|([A-Z][a-zA-Z0-9]+))$ # Naming hint for module names module-name-hint=(([a-z_][a-z0-9_]*)|([A-Z][a-zA-Z0-9]+))$ # Regular expression matching correct function names function-rgx=[A-Za-z_][A-Za-z0-9_]{2,75}$ # Naming hint for function names function-name-hint=[a-z_][a-z0-9_]{2,75}$ # Regular expression matching correct attribute names attr-rgx=[a-z_][a-z0-9_]{2,75}$ # Naming hint for attribute names attr-name-hint=[a-z_][a-z0-9_]{2,75}$ # Regular expression matching correct inline iteration names inlinevar-rgx=[A-Za-z_][A-Za-z0-9_]*$ # Naming hint for inline iteration names inlinevar-name-hint=[A-Za-z_][A-Za-z0-9_]*$ # Regular expression matching correct class attribute names class-attribute-rgx=([A-Za-z_][A-Za-z0-9_]{2,75}|(__.*__))$ # Naming hint for class attribute names class-attribute-name-hint=([A-Za-z_][A-Za-z0-9_]{2,75}|(__.*__))$ # Regular expression which should only match function or class names that do # not require a docstring. no-docstring-rgx=__.*__ # Minimum line length for functions/classes that require docstrings, shorter # ones are exempt. docstring-min-length=-1 [FORMAT] # Maximum number of characters on a single line. max-line-length=100 # Regexp for a line that is allowed to be longer than the limit. ignore-long-lines=^\s*(# )??$ # Allow the body of an if to be on the same line as the test if there is no # else. single-line-if-stmt=no # List of optional constructs for which whitespace checking is disabled no-space-check=trailing-comma,dict-separator # Maximum number of lines in a module max-module-lines=10000 # String used as indentation unit. This is usually " " (4 spaces) or "\t" (1 # tab). indent-string=' ' # Number of spaces of indent required inside a hanging or continued line. indent-after-paren=4 # Expected format of line ending, e.g. empty (any line ending), LF or CRLF. expected-line-ending-format= [REFACTORING] # Maximum number of nested blocks for function / method body max-nested-blocks=10 [LOGGING] # Logging modules to check that the string format arguments are in logging # function parameter format logging-modules=logging [MISCELLANEOUS] # List of note tags to take in consideration, separated by a comma. notes=FIXME,XXX,TODO [SIMILARITIES] # Minimum lines number of a similarity. min-similarity-lines=15 # Ignore comments when computing similarities. ignore-comments=yes # Ignore docstrings when computing similarities. ignore-docstrings=yes # Ignore imports when computing similarities. ignore-imports=no [SPELLING] # Spelling dictionary name. Available dictionaries: none. 
To make it working # install python-enchant package. spelling-dict= # List of comma separated words that should not be checked. spelling-ignore-words= # A path to a file that contains private dictionary; one word per line. spelling-private-dict-file= # Tells whether to store unknown words to indicated private dictionary in # --spelling-private-dict-file option instead of raising a message. spelling-store-unknown-words=no [TYPECHECK] # Tells whether missing members accessed in mixin class should be ignored. A # mixin class is detected if its name ends with "mixin" (case insensitive). ignore-mixin-members=yes # List of module names for which member attributes should not be checked # (useful for modules/projects where namespaces are manipulated during runtime # and thus existing member attributes cannot be deduced by static analysis ignored-modules= # List of classes names for which member attributes should not be checked # (useful for classes with attributes dynamically set). ignored-classes=Namespace # List of members which are set dynamically and missed by pylint inference # system, and so shouldn't trigger E0201 when accessed. Python regular # expressions are accepted. generated-members=REQUEST,acl_users,aq_parent [VARIABLES] # Tells whether we should check for unused import in __init__ files. init-import=no # A regular expression matching the name of dummy variables (i.e. expectedly # not used). dummy-variables-rgx=_$|dummy # List of additional names supposed to be defined in builtins. Remember that # you should avoid to define new builtins when possible. additional-builtins= # List of strings which can identify a callback function by name. A callback # name must start or end with one of those strings. callbacks=cb_,_cb [CLASSES] # List of method names used to declare (i.e. assign) instance attributes. defining-attr-methods=__init__,__new__,setUp # List of valid names for the first argument in a class method. valid-classmethod-first-arg=cls # List of valid names for the first argument in a metaclass class method. valid-metaclass-classmethod-first-arg=mcs # List of member names, which should be excluded from the protected access # warning. exclude-protected=_asdict,_fields,_replace,_source,_make [DESIGN] # Maximum number of arguments for function / method max-args=100 # Argument names that match this expression will be ignored. Default to name # with leading underscore ignored-argument-names=_.* # Maximum number of locals for function / method body max-locals=100 # Maximum number of return / yield for function / method body max-returns=10 # Maximum number of boolean expressions in a if statement max-bool-expr=20 # Maximum number of branch for function / method body max-branches=100 # Maximum number of statements in function / method body max-statements=150 # Maximum number of parents for a class (see R0901). max-parents=10 # Maximum number of attributes for a class (see R0902). max-attributes=150 # Minimum number of public methods for a class (see R0903). min-public-methods=0 # Maximum number of public methods for a class (see R0904). max-public-methods=250 [IMPORTS] # Deprecated modules which should not be used, separated by a comma deprecated-modules=stringprep,optparse # Create a graph of every (i.e. 
internal and external) dependencies in the # given file (report RP0402 must not be disabled) import-graph= # Create a graph of external dependencies in the given file (report RP0402 must # not be disabled) ext-import-graph= # Create a graph of internal dependencies in the given file (report RP0402 must # not be disabled) int-import-graph= [EXCEPTIONS] # Exceptions that will emit a warning when being caught. Defaults to # "Exception" overgeneral-exceptions=Exception asyncssh-2.20.0/pyproject.toml000066400000000000000000000034641475467777400164050ustar00rootroot00000000000000[build-system] requires = ['setuptools'] build-backend = 'setuptools.build_meta' [project] name = 'asyncssh' license = {text = 'EPL-2.0 OR GPL-2.0-or-later'} description = 'AsyncSSH: Asynchronous SSHv2 client and server library' readme = 'README.rst' authors = [{name = 'Ron Frederick', email = 'ronf@timeheart.net'}] classifiers = [ 'Development Status :: 5 - Production/Stable', 'Environment :: Console', 'Intended Audience :: Developers', 'License :: OSI Approved', 'Operating System :: MacOS :: MacOS X', 'Operating System :: POSIX', 'Programming Language :: Python :: 3.8', 'Programming Language :: Python :: 3.9', 'Programming Language :: Python :: 3.10', 'Programming Language :: Python :: 3.11', 'Programming Language :: Python :: 3.12', 'Programming Language :: Python :: 3.13', 'Topic :: Internet', 'Topic :: Security :: Cryptography', 'Topic :: Software Development :: Libraries :: Python Modules', 'Topic :: System :: Networking', ] requires-python = '>= 3.6' dependencies = [ 'cryptography >= 39.0', 'typing_extensions >= 4.0.0', ] dynamic = ['version'] [project.optional-dependencies] bcrypt = ['bcrypt >= 3.1.3'] fido2 = ['fido2 >= 0.9.2'] gssapi = ['gssapi >= 1.2.0'] libnacl = ['libnacl >= 1.4.2'] pkcs11 = ['python-pkcs11 >= 0.7.0'] pyOpenSSL = ['pyOpenSSL >= 23.0.0'] pywin32 = ['pywin32 >= 227'] [project.urls] Homepage = 'http://asyncssh.timeheart.net' Documentation = 'https://asyncssh.readthedocs.io' Source = 'https://github.com/ronf/asyncssh' Tracker = 'https://github.com/ronf/asyncssh/issues' [tool.setuptools.dynamic] version = {attr = 'asyncssh.version.__version__'} [tool.setuptools.packages.find] include = ['asyncssh*'] [tool.setuptools.package-data] asyncssh = ['py.typed'] asyncssh-2.20.0/tests/000077500000000000000000000000001475467777400146245ustar00rootroot00000000000000asyncssh-2.20.0/tests/__init__.py000066400000000000000000000013641475467777400167410ustar00rootroot00000000000000# Copyright (c) 2014-2018 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation """Unit tests for AsyncSSH""" asyncssh-2.20.0/tests/gss_stub.py000066400000000000000000000030111475467777400170220ustar00rootroot00000000000000# Copyright (c) 2017-2018 by Ron Frederick and others. 
# # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation """Stub GSS module for unit tests""" def step(host, token): """Perform next step in GSS authentication""" complete = False if token == b'errtok': return token, complete elif ((token is None and 'empty_init' in host) or (token == b'1' and 'empty_continue' in host)): return b'', complete elif token == b'0': if 'continue_token' in host: token = b'continue' else: complete = True token = b'extra' if 'extra_token' in host else None elif token: token = bytes((token[0]-1,)) else: token = host[0].encode('ascii') if token == b'0': if 'step_error' in host: return (b'errtok' if 'errtok' in host else b'error'), complete complete = True return token, complete asyncssh-2.20.0/tests/gssapi_stub.py000066400000000000000000000071421475467777400175250ustar00rootroot00000000000000# Copyright (c) 2017-2019 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation """Stub GSSAPI module for unit tests""" from enum import IntEnum from asyncssh.gss import GSSError from .gss_stub import step class Name: """Stub class for GSS principal name""" def __init__(self, base, _name_type=None): if 'init_error' in base: raise GSSError(99, 99) self.host = base[5:] class Credentials: """Stub class for GSS credentials""" def __init__(self, name=None, usage=None, store=None): # pylint: disable=unused-argument self.host = name.host if name else '' self.server = usage == 'accept' @property def mechs(self): """Return GSS mechanisms available for this host""" if self.server: return [0] if 'unknown_mech' in self.host else [1, 2] else: return [2] class RequirementFlag(IntEnum): """Stub class for GSS requirement flags""" # pylint: disable=invalid-name delegate_to_peer = 1 mutual_authentication = 2 integrity = 4 class SecurityContext: """Stub class for GSS security context""" def __init__(self, name=None, creds=None, flags=None): host = creds.host if creds.server else name.host if flags is None: flags = RequirementFlag.mutual_authentication | \ RequirementFlag.integrity if ((creds.server and 'no_server_integrity' in host) or (not creds.server and 'no_client_integrity' in host)): flags &= ~RequirementFlag.integrity self._host = host self._server = creds.server self._actual_flags = flags self._complete = False @property def complete(self): """Return whether or not GSS negotiation is complete""" return self._complete 
@property def actual_flags(self): """Return flags set on this context""" return self._actual_flags @property def initiator_name(self): """Return user principal associated with this context""" return 'user@TEST' @property def target_name(self): """Return host principal associated with this context""" return 'host@TEST' def step(self, token=None): """Perform next step in GSS security exchange""" token, complete = step(self._host, token) if complete: self._complete = True if token == b'error': raise GSSError(99, 99) elif token == b'errtok': raise GSSError(99, 99, token) else: return token def get_signature(self, _data): """Sign a block of data""" if 'sign_error' in self._host: raise GSSError(99, 99) return b'fail' if 'verify_error' in self._host else b'' def verify_signature(self, _data, sig): """Verify a signature for a block of data""" # pylint: disable=no-self-use if sig == b'fail': raise GSSError(99, 99) asyncssh-2.20.0/tests/keysign_stub.py000066400000000000000000000034641475467777400177130ustar00rootroot00000000000000# Copyright (c) 2018-2019 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation """Stub ssh-keysign module for unit tests""" import asyncssh from asyncssh.keysign import KEYSIGN_VERSION from asyncssh.packet import Byte, String, SSHPacket class SSHKeysignStub: """Stub class to replace process running ssh-keysign""" async def communicate(self, request): """Process SSH key signing request""" # pylint: disable=no-self-use packet = SSHPacket(request) request = packet.get_string() packet.check_end() packet = SSHPacket(request) version = packet.get_byte() _ = packet.get_uint32() # sock_fd data = packet.get_string() packet.check_end() if version == 0: return b'', b'' elif version == 1: return b'', b'invalid request' else: skey = asyncssh.load_keypairs('skey_ecdsa')[0] sig = skey.sign(data) return String(Byte(KEYSIGN_VERSION) + String(sig)), b'' async def create_subprocess_exec_stub(*_args, **_kwargs): """Return a stub for a subprocess running the ssh-keysign executable""" return SSHKeysignStub() asyncssh-2.20.0/tests/pkcs11_stub.py000066400000000000000000000155111475467777400173400ustar00rootroot00000000000000# Copyright (c) 2020-2024 by Ron Frederick and others. 
# # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation """Stub PKCS#11 security key module for unit tests""" import asyncssh from asyncssh.asn1 import der_decode from asyncssh.pkcs11 import pkcs11_available from .util import get_test_key if pkcs11_available: # pragma: no branch import pkcs11 from pkcs11 import Attribute, KeyType, Mechanism, ObjectClass def _encode_public(key): """Stub to encode a PKCS#11 public key""" return key.encode_public() _encoders = {KeyType.RSA: _encode_public, KeyType.EC: _encode_public} _key_types = {'ssh-rsa': KeyType.RSA, 'ecdsa-sha2-nistp256': KeyType.EC, 'ecdsa-sha2-nistp384': KeyType.EC, 'ssh-ed25519': KeyType.EC_EDWARDS} _hash_algs = {Mechanism.SHA1_RSA_PKCS: 'sha1', Mechanism.SHA224_RSA_PKCS: 'sha224', Mechanism.SHA256_RSA_PKCS: 'sha256', Mechanism.SHA384_RSA_PKCS: 'sha384', Mechanism.SHA512_RSA_PKCS: 'sha512', Mechanism.ECDSA_SHA256: 'sha256', Mechanism.ECDSA_SHA384: 'sha384', Mechanism.ECDSA_SHA512: 'sha512'} class _PKCS11Key: """Stub for unit testing PKCS#11 keys""" def __init__(self, alg_name, key_type, key_label, key_id): self._priv = get_test_key(alg_name, key_id, comment=key_label) self.key_type = key_type self.label = key_label self.id = key_id def get_cert(self): """Return self-signed X.509 cert for this key""" return self._priv.generate_x509_user_certificate( self._priv, f'OU={self.label},CN=ckey') def get_public(self): """Return public key corresponding to this key""" return self._priv.convert_to_public() def encode_public(self): """Stub to encode a PKCS#11 public key""" return self._priv.export_public_key('pkcs8-der') def sign(self, data, mechanism): """Sign a block of data with this key""" sig = self._priv.sign_raw(data, _hash_algs[mechanism]) if self.key_type == KeyType.EC: r, s = der_decode(sig) length = (max(r.bit_length(), s.bit_length()) + 7) // 8 sig = r.to_bytes(length, 'big') + s.to_bytes(length, 'big') return sig class _PKCS11Cert: """Stub for unit testing PKCS#11 certificates""" def __init__(self, key): self._cert = key.get_cert() def __getitem__(self, key): if key == Attribute.VALUE: # pragma: no branch return self._cert.export_certificate('der') def get_cert(self): """Return cert object""" return self._cert class _PKCS11Session: """Stub for unit testing PKCS#11 security token sessions""" def __init__(self, keys, certs): self._keys = keys self._certs = certs def get_objects(self, attrs): """Return a list of PKCS#11 key or certificate objects""" label = attrs.get(Attribute.LABEL) obj_id = attrs.get(Attribute.OBJECT_ID) objs = self._keys if attrs[Attribute.CLASS] == \ ObjectClass.PRIVATE_KEY else self._certs for obj in objs: if label is not None and obj.label != label: continue if obj_id is not None and obj.id != obj_id: continue yield obj def close(self): """Close this session""" class _PKCS11Token: """Stub for unit testing PKCS#11 security tokens""" def __init__(self, label, serial, key_info): self.manufacturer_id = 'Test' self.label = label self.serial = serial 
self._keys = [] self._pubkeys = [] self._certs = [] for i, (alg, key_label) in enumerate(key_info, 1): self._add_key(alg, _key_types[alg], key_label, i) def _add_key(self, alg, key_type, key_label, key_id): """Add key to this token""" key = _PKCS11Key(alg, key_type, key_label, bytes((key_id,))) self._keys.append(key) self._pubkeys.append(key.get_public()) self._certs.append(_PKCS11Cert(key)) def get_pubkeys(self): """Return public keys associated with this token""" return self._pubkeys def get_certs(self): """Return X.509 certificates associated with this token""" return [cert.get_cert() for cert in self._certs] def open(self, user_pin=None): """Open a session to access a security token""" # pylint: disable=unused-argument return _PKCS11Session(self._keys, self._certs) class PKCS11Lib: """"Stub for unit testing PKCS#11 providers""" tokens = [] public_keys = [] certs = [] @classmethod def init_tokens(cls, token_info): """Initialize PKCS#11 token stubs for unit testing""" cls.tokens = [_PKCS11Token(*info) for info in token_info] cls.public_keys = sum((token.get_pubkeys() for token in cls.tokens), []) cls.certs = sum((token.get_certs() for token in cls.tokens), []) def __init__(self, provider): # pylint: disable=unused-argument pass def get_tokens(self, token_label=None, token_serial=None): """Return PKCS#11 security tokens""" for token in self.tokens: if token_label is not None and token.label != token_label: continue if token_serial is not None and token.serial != token_serial: continue yield token def get_pkcs11_public_keys(): """Return PKCS#11 public keys to trust in unit tests""" return PKCS11Lib.public_keys def get_pkcs11_certs(): """Return PKCS#11 X.509 certificates to trust in unit tests""" return PKCS11Lib.certs def stub_pkcs11(token_info): """Stub out PKCS#11 security token functions for unit testing""" old_lib = pkcs11.lib old_encoders = asyncssh.pkcs11.encoders pkcs11.lib = PKCS11Lib asyncssh.pkcs11.encoders = _encoders PKCS11Lib.init_tokens(token_info) return old_lib, old_encoders def unstub_pkcs11(old_lib, old_encoders): """Restore PKCS#11 security token functions""" pkcs11.lib = old_lib asyncssh.pkcs11.encoders = old_encoders asyncssh-2.20.0/tests/server.py000066400000000000000000000322051475467777400165060ustar00rootroot00000000000000# Copyright (c) 2016-2024 by Ron Frederick and others. 
# # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation """SSH server used for unit tests""" import asyncio import os import shutil import signal import socket import subprocess import sys import asyncssh from asyncssh.misc import async_context_manager from .util import AsyncTestCase, all_tasks, current_task, get_test_key from .util import run, x509_available class Server(asyncssh.SSHServer): """Unit test SSH server""" def __init__(self): self._conn = None def connection_made(self, conn): """Record connection object for later use""" self._conn = conn def begin_auth(self, username): """Handle client authentication request""" return username != 'guest' class ServerTestCase(AsyncTestCase): """Unit test class which starts an SSH server and agent""" _server = None _server_addr = '' _server_port = 0 _agent_pid = None @classmethod @async_context_manager async def listen(cls, *, server_factory=(), options=None, **kwargs): """Create an SSH server for the tests to use""" if server_factory == (): server_factory = Server options = asyncssh.SSHServerConnectionOptions( options=options, server_factory=server_factory, gss_host=None, server_host_keys=['skey']) return await asyncssh.listen(port=0, family=socket.AF_INET, options=options, **kwargs) @classmethod @async_context_manager async def listen_reverse(cls, *, options=None, **kwargs): """Create a reverse SSH server for the tests to use""" options = asyncssh.SSHClientConnectionOptions( options=options, gss_host=None, known_hosts=(['skey.pub'], [], [])) return await asyncssh.listen_reverse(port=0, family=socket.AF_INET, options=options, **kwargs) @classmethod async def create_server(cls, server_factory=(), **kwargs): """Create an SSH server for the tests to use""" return await cls.listen(server_factory=server_factory, **kwargs) @classmethod async def start_server(cls): """Start an SSH server for the tests to use""" return NotImplemented # pragma: no cover # Pylint doesn't like mixed case method names, but this was chosen to # match the convention used in the unittest module. 
# pylint: disable=invalid-name @classmethod async def asyncSetUpClass(cls): """Set up keys, an SSH server, and an SSH agent for the tests to use""" # pylint: disable=too-many-statements ckey = get_test_key('ssh-rsa') ckey.write_private_key('ckey') ckey.write_private_key('ckey_encrypted', passphrase='passphrase') ckey.write_public_key('ckey.pub') ckey_ecdsa = get_test_key('ecdsa-sha2-nistp256') ckey_ecdsa.write_private_key('ckey_ecdsa') ckey_ecdsa.write_public_key('ckey_ecdsa.pub') ckey_cert = ckey.generate_user_certificate(ckey, 'name', principals=['ckey']) ckey_cert.write_certificate('ckey-cert.pub') skey = get_test_key('ssh-rsa', 1) skey.write_private_key('skey') skey.write_public_key('skey.pub') skey_ecdsa = get_test_key('ecdsa-sha2-nistp256', 1) skey_ecdsa.write_private_key('skey_ecdsa') skey_ecdsa.write_public_key('skey_ecdsa.pub') skey_cert = skey.generate_host_certificate( skey, 'name', principals=['127.0.0.1', 'localhost']) skey_cert.write_certificate('skey-cert.pub') skey_ecdsa_cert = skey_ecdsa.generate_host_certificate( skey_ecdsa, 'name', principals=['127.0.0.1', 'localhost']) skey_ecdsa_cert.write_certificate('skey_ecdsa-cert.pub') exp_cert = skey.generate_host_certificate(skey, 'name', valid_after='-2d', valid_before='-1d') skey.write_private_key('exp_skey') exp_cert.write_certificate('exp_skey-cert.pub') if x509_available: # pragma: no branch ckey_x509_self = ckey_ecdsa.generate_x509_user_certificate( ckey_ecdsa, 'OU=name', principals=['ckey']) ckey_ecdsa.write_private_key('ckey_x509_self') ckey_x509_self.append_certificate('ckey_x509_self', 'pem') ckey_x509_self.write_certificate('ckey_x509_self.pem', 'pem') ckey_x509_self.write_certificate('ckey_x509_self.pub') skey_x509_self = skey_ecdsa.generate_x509_host_certificate( skey_ecdsa, 'OU=name', principals=['127.0.0.1']) skey_ecdsa.write_private_key('skey_x509_self') skey_x509_self.append_certificate('skey_x509_self', 'pem') skey_x509_self.write_certificate('skey_x509_self.pem', 'pem') root_ca_key = get_test_key('ssh-rsa', 2) root_ca_key.write_private_key('root_ca_key') root_ca_cert = root_ca_key.generate_x509_ca_certificate( root_ca_key, 'OU=RootCA', ca_path_len=1) root_ca_cert.write_certificate('root_ca_cert.pem', 'pem') root_ca_cert.write_certificate('root_ca_cert.pub') int_ca_key = get_test_key('ssh-rsa', 3) int_ca_key.write_private_key('int_ca_key') int_ca_cert = root_ca_key.generate_x509_ca_certificate( int_ca_key, 'OU=IntCA', 'OU=RootCA', ca_path_len=0) int_ca_cert.write_certificate('int_ca_cert.pem', 'pem') ckey_x509_chain = int_ca_key.generate_x509_user_certificate( ckey, 'OU=name', 'OU=IntCA', principals=['ckey']) ckey.write_private_key('ckey_x509_chain') ckey_x509_chain.append_certificate('ckey_x509_chain', 'pem') int_ca_cert.append_certificate('ckey_x509_chain', 'pem') ckey_x509_chain.write_certificate('ckey_x509_partial.pem', 'pem') skey_x509_chain = int_ca_key.generate_x509_host_certificate( skey, 'OU=name', 'OU=IntCA', principals=['127.0.0.1']) skey.write_private_key('skey_x509_chain') skey_x509_chain.append_certificate('skey_x509_chain', 'pem') int_ca_cert.append_certificate('skey_x509_chain', 'pem') root_hash = root_ca_cert.x509_cert.subject_hash os.mkdir('cert_path') shutil.copy('root_ca_cert.pem', os.path.join('cert_path', root_hash + '.0')) # Intentional hash mismatch shutil.copy('int_ca_cert.pem', os.path.join('cert_path', root_hash + '.1')) for f in ('ckey', 'ckey_ecdsa', 'skey', 'exp_skey', 'skey_ecdsa'): os.chmod(f, 0o600) os.mkdir('.ssh', 0o700) os.mkdir(os.path.join('.ssh', 'crt'), 0o700) 
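        # Copy the client keys and certificate into the simulated home
        # directory's .ssh folder under their default identity file names
        # (id_ecdsa, id_rsa, id_rsa-cert.pub) for the tests below to use.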
shutil.copy('ckey_ecdsa', os.path.join('.ssh', 'id_ecdsa')) shutil.copy('ckey_ecdsa.pub', os.path.join('.ssh', 'id_ecdsa.pub')) shutil.copy('ckey_encrypted', os.path.join('.ssh', 'id_rsa')) shutil.copy('ckey.pub', os.path.join('.ssh', 'id_rsa.pub')) shutil.copy('ckey-cert.pub', os.path.join('.ssh', 'id_rsa-cert.pub')) with open('authorized_keys', 'w') as auth_keys: with open('ckey.pub') as ckey_pub: shutil.copyfileobj(ckey_pub, auth_keys) with open('ckey_ecdsa.pub') as ckey_ecdsa_pub: shutil.copyfileobj(ckey_ecdsa_pub, auth_keys) auth_keys.write('cert-authority,principals="ckey",' 'permitopen=:* ') with open('ckey.pub') as ckey_pub: shutil.copyfileobj(ckey_pub, auth_keys) if x509_available: # pragma: no branch with open('authorized_keys_x509', 'w') as auth_keys_x509: with open('ckey_x509_self.pub') as ckey_self_pub: shutil.copyfileobj(ckey_self_pub, auth_keys_x509) auth_keys_x509.write('cert-authority,principals="ckey" ') with open('root_ca_cert.pub') as root_pub: shutil.copyfileobj(root_pub, auth_keys_x509) shutil.copy('skey_x509_self.pem', os.path.join('.ssh', 'ca-bundle.crt')) os.environ['LOGNAME'] = 'guest' os.environ['HOME'] = '.' os.environ['USERPROFILE'] = '.' cls._server = await cls.start_server() sock = cls._server.sockets[0] cls._server_addr = '127.0.0.1' cls._server_port = sock.getsockname()[1] host = f'[{cls._server_addr}]:{cls._server_port},localhost ' with open('known_hosts', 'w') as known_hosts: known_hosts.write(host) with open('skey.pub') as skey_pub: shutil.copyfileobj(skey_pub, known_hosts) known_hosts.write('@cert-authority * ') with open('skey.pub') as skey_pub: shutil.copyfileobj(skey_pub, known_hosts) shutil.copy('known_hosts', os.path.join('.ssh', 'known_hosts')) if 'DISPLAY' in os.environ: # pragma: no cover del os.environ['DISPLAY'] if 'SSH_ASKPASS' in os.environ: # pragma: no cover del os.environ['SSH_ASKPASS'] if 'SSH_AUTH_SOCK' in os.environ: # pragma: no cover del os.environ['SSH_AUTH_SOCK'] if 'XAUTHORITY' in os.environ: # pragma: no cover del os.environ['XAUTHORITY'] if sys.platform != 'win32': try: output = run('ssh-agent -a agent 2>/dev/null') except subprocess.CalledProcessError: # pragma: no cover cls._agent_pid = None else: cls._agent_pid = int(output.splitlines()[2].split()[3][:-1]) os.environ['SSH_AUTH_SOCK'] = 'agent' async with asyncssh.connect_agent() as agent: await agent.add_keys([ckey_ecdsa, (ckey, ckey_cert)]) else: # pragma: no cover cls._agent_pid = None with open('ssh-keysign', 'wb'): pass @classmethod async def asyncTearDownClass(cls): """Shut down test server and agent""" cls._server.close() await cls._server.wait_closed() tasks = all_tasks() tasks.remove(current_task()) await asyncio.gather(*tasks, return_exceptions=True) if cls._agent_pid: # pragma: no branch os.kill(cls._agent_pid, signal.SIGTERM) # pylint: enable=invalid-name def agent_available(self): """Return whether SSH agent is available""" return bool(self._agent_pid) @async_context_manager async def connect(self, host=(), port=(), gss_host=None, options=None, **kwargs): """Open a connection to the test server""" return await asyncssh.connect(host or self._server_addr, port or self._server_port, gss_host=gss_host, options=options, **kwargs) @async_context_manager async def connect_reverse(self, options=None, gss_host=None, **kwargs): """Create a connection to the test server""" options = asyncssh.SSHServerConnectionOptions(options, server_factory=Server, server_host_keys=['skey'], gss_host=gss_host) return await asyncssh.connect_reverse(self._server_addr, self._server_port, 
options=options, **kwargs) @async_context_manager async def run_client(self, sock, config=(), options=None, **kwargs): """Run an SSH client on an already-connected socket""" return await asyncssh.run_client(sock, config, options, **kwargs) @async_context_manager async def run_server(self, sock, config=(), options=None, **kwargs): """Run an SSH server on an already-connected socket""" options = asyncssh.SSHServerConnectionOptions(options, server_factory=Server, server_host_keys=['skey']) return await asyncssh.run_server(sock, config, options, **kwargs) async def create_connection(self, client_factory, **kwargs): """Create a connection to the test server""" conn = await self.connect(client_factory=client_factory, **kwargs) return conn, conn.get_owner() asyncssh-2.20.0/tests/sk_stub.py000066400000000000000000000314321475467777400166530ustar00rootroot00000000000000# Copyright (c) 2019-2022 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation """Stub U2F security key module for unit tests""" from contextlib import contextmanager from hashlib import sha256 import asyncssh from asyncssh.asn1 import der_encode, der_decode from asyncssh.crypto import ECDSAPrivateKey, EdDSAPrivateKey from asyncssh.packet import Byte, UInt32 from asyncssh.sk import sk_available, sk_webauthn_prefix if sk_available: # pragma: no branch from asyncssh.sk import SSH_SK_ECDSA, SSH_SK_ED25519 from asyncssh.sk import SSH_SK_USER_PRESENCE_REQD from asyncssh.sk import APDU, ApduError, CtapError class _Registration: """Security key registration data""" def __init__(self, public_key, key_handle): self.public_key = public_key self.key_handle = key_handle class _AuthData: """Security key authentication data""" def __init__(self, flags, counter): self.flags = flags self.counter = counter class _Assertion: """Security key assertion""" def __init__(self, auth_data, signature): self.auth_data = auth_data self.signature = signature class _CredentialData: """Security key credential data""" def __init__(self, alg, public_value, key_handle): if alg == SSH_SK_ED25519: self.public_key = {-2: public_value} else: self.public_key = {-2: public_value[1:33], -3: public_value[33:]} self.public_key[3] = alg self.credential_id = key_handle class _CredentialAuthData: """Security key credential authentication data""" def __init__(self, credential_data): self.credential_data = credential_data class _Credential: """Security key credential""" def __init__(self, auth_data): self.auth_data = auth_data class _AttestationResponse: """Security key attestation response""" def __init__(self, attestation_object): self.attestation_object = attestation_object class _AuthenticatorData: """Security key authenticator data in aseertion""" def __init__(self, flags, counter): self.flags = flags self.counter = counter class _AssertionResponse: """Security key aseertion response""" def __init__(self, client_data, auth_data, signature): self.client_data = client_data 
self.authenticator_data = auth_data self.signature = signature class _AssertionSelection: """Security key assertion response list""" def __init__(self, assertions): self._assertions = assertions def get_response(self, index): """Return the assertion at specified index""" return self._assertions[index] class _CtapStub: """Stub for unit testing U2F security key support""" @staticmethod def _enroll(alg): """Enroll a new security key""" if alg == SSH_SK_ECDSA: key = ECDSAPrivateKey.generate(b'nistp256') else: key = EdDSAPrivateKey.generate(b'ed25519') key_handle = der_encode((alg, key.public_value, key.private_value)) return key.public_value, key_handle @staticmethod def _sign(message_hash, app_hash, key_handle, flags): """Sign a message with a security key""" alg, public_value, private_value = der_decode(key_handle) if alg == SSH_SK_ECDSA: key = ECDSAPrivateKey.construct( b'nistp256', public_value, int.from_bytes(private_value, 'big')) hash_alg = 'sha256' else: key = EdDSAPrivateKey.construct(b'ed25519', private_value) hash_alg = None counter = 0x12345678 sig = key.sign(app_hash + Byte(flags) + UInt32(counter) + message_hash, hash_alg) return flags, counter, sig class Ctap1(_CtapStub): """Stub for unit testing U2F security keys using CTAP version 1""" def __init__(self, dev): self.dev = dev self._polled = False def _poll(self): """Simulate needing to poll the security device""" if not self._polled: self._polled = True raise ApduError(APDU.USE_NOT_SATISFIED, b'') def register(self, client_data_hash, app_hash): """Enroll a new security key using CTAP version 1""" # pylint: disable=unused-argument self._poll() if self.dev.error == 'err': raise ApduError(0, b'') public_key, key_handle = self._enroll(SSH_SK_ECDSA) return _Registration(public_key, key_handle) def authenticate(self, message_hash, app_hash, key_handle): """Sign a message with a security key using CTAP version 1""" self._poll() if self.dev.error == 'nocred': raise ApduError(APDU.WRONG_DATA, b'') elif self.dev.error == 'err': raise ApduError(0, b'') flags, counter, sig = self._sign(message_hash, app_hash, key_handle, SSH_SK_USER_PRESENCE_REQD) return Byte(flags) + UInt32(counter) + sig class Ctap2(_CtapStub): """Stub for unit testing U2F security keys using CTAP version 2""" def __init__(self, dev): if dev.version != 2: raise ValueError('Wrong protocol version') self.dev = dev def make_credential(self, client_data_hash, rp, user, key_params, options, pin_uv_param, pin_uv_protocol): """Enroll a new security key using CTAP version 2""" # pylint: disable=unused-argument alg = key_params[0]['alg'] if self.dev.error == 'err': raise CtapError(CtapError.ERR.INVALID_CREDENTIAL) elif self.dev.error == 'pinreq': raise CtapError(CtapError.ERR.PUAT_REQUIRED) elif self.dev.error == 'badpin': raise CtapError(CtapError.ERR.PIN_INVALID) public_key, key_handle = self._enroll(alg) cdata = _CredentialData(alg, public_key, key_handle) if options.get('rk'): cred_mgmt = CredentialManagement(self) cred_mgmt.add_resident_key(user['name'], cdata) return _Credential(_CredentialAuthData(cdata)) def get_assertions(self, application, message_hash, allow_creds, options): """Sign a message with a security key using CTAP version 2""" app_hash = sha256(application.encode()).digest() key_handle = allow_creds[0]['id'] flags = SSH_SK_USER_PRESENCE_REQD if options['up'] else 0 if self.dev.error == 'nocred': raise CtapError(CtapError.ERR.NO_CREDENTIALS) elif self.dev.error == 'err': raise CtapError(CtapError.ERR.INVALID_CREDENTIAL) flags, counter, sig = 
self._sign(message_hash, app_hash, key_handle, flags) return [_Assertion(_AuthData(flags, counter), sig)] class WindowsClient(_CtapStub): """Stub for unit testing U2F security keys via Windows WebAuthn""" def __init__(self, origin, verify): self._origin = origin self._verify = verify def make_credential(self, options): """Make a credential using Windows WebAuthN API""" self._verify(options['rp']['id'], self._origin) alg = options['pubKeyCredParams'][0]['alg'] public_key, key_handle = self._enroll(alg) cdata = _CredentialData(alg, public_key, key_handle) return _AttestationResponse(_Credential(_CredentialAuthData(cdata))) def get_assertion(self, options): """Get assertion using Windows WebAuthN API""" self._verify(options['rpId'], self._origin) challenge = options['challenge'] application = options['rpId'] key_handle = options['allowCredentials'][0]['id'] flags = SSH_SK_USER_PRESENCE_REQD app_hash = sha256(application.encode()).digest() data = sk_webauthn_prefix(challenge, application) + b'}' message_hash = sha256(data).digest() flags, counter, sig = self._sign(message_hash, app_hash, key_handle, flags) auth_data = _AuthenticatorData(flags, counter) assertion = _AssertionResponse(data, auth_data, sig) return _AssertionSelection([assertion]) class CredentialManagement: """Stub for unit testing U2F security device resident keys""" class RESULT: """Credential management result keys""" USER = 6 CREDENTIAL_ID = 7 PUBLIC_KEY = 8 def __init__(self, ctap, pin_uv_protocol=None, pin_uv_token=None): # pylint: disable=unused-argument self.dev = ctap.dev if self.dev.error == 'err': raise CtapError(CtapError.ERR.INVALID_CREDENTIAL) elif self.dev.error == 'nocred': raise CtapError(CtapError.ERR.NO_CREDENTIALS) elif self.dev.error == 'nopin': raise CtapError(CtapError.ERR.PIN_NOT_SET) elif self.dev.error == 'badpin': raise CtapError(CtapError.ERR.PIN_INVALID) def enumerate_creds(self, app_hash): """Enumerate resident credentials""" # pylint: disable=unused-argument return self.dev.resident_keys def add_resident_key(self, user, cdata): """Add a resident key to a device""" self.dev.resident_keys.append( {self.RESULT.USER: {'id': b'', 'name': user, 'displayName': user}, self.RESULT.CREDENTIAL_ID: {'id': cdata.credential_id}, self.RESULT.PUBLIC_KEY: cdata.public_key}) class Device: """Stub for unit testing U2F security devices""" def __init__(self, version): self.version = version self.error = None self.resident_keys = [] def close(self): """Close this security device""" class ClientPin: """Stub for unit testing U2F security device PINs""" def __init__(self, ctap, protocol): # pylint: disable=unused-argument pass def get_pin_token(self, pin): """Return a PIN token""" # pylint: disable=no-self-use return pin class PinProtocolV1: """Stub for unit testing U2F pin protocol version 1""" VERSION = 1 def stub_sk(devices, use_webauthn=False): """Stub out security key module functions for unit testing""" devices = list(map(Device, devices)) old_ctap1 = asyncssh.sk.Ctap1 old_ctap2 = asyncssh.sk.Ctap2 old_windows_client = asyncssh.sk.WindowsClient old_use_webauthn = asyncssh.sk.sk_use_webauthn old_client_pin = asyncssh.sk.ClientPin old_cred_mgmt = asyncssh.sk.CredentialManagement old_pin_proto = asyncssh.sk.PinProtocolV1 old_list_devices = asyncssh.sk.CtapHidDevice.list_devices asyncssh.sk.Ctap1 = Ctap1 asyncssh.sk.Ctap2 = Ctap2 asyncssh.sk.WindowsClient = WindowsClient asyncssh.sk.sk_use_webauthn = use_webauthn asyncssh.sk_ecdsa.sk_use_webauthn = use_webauthn asyncssh.sk.ClientPin = ClientPin 
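    # Point credential management, the PIN protocol, and device enumeration
    # at the stubs as well, so the tests run without real security key
    # hardware.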
asyncssh.sk.CredentialManagement = CredentialManagement asyncssh.sk.PinProtocolV1 = PinProtocolV1 asyncssh.sk.CtapHidDevice.list_devices = lambda: iter(devices) return old_ctap1, old_ctap2, old_windows_client, old_use_webauthn, \ old_client_pin, old_cred_mgmt, old_pin_proto, old_list_devices def unstub_sk(old_ctap1, old_ctap2, old_windows_client, old_use_webauthn, old_client_pin, old_cred_mgmt, old_pin_proto, old_list_devices): """Restore security key module functions""" asyncssh.sk.Ctap1 = old_ctap1 asyncssh.sk.Ctap2 = old_ctap2 asyncssh.sk.WindowsClient = old_windows_client asyncssh.sk.sk_use_webauthn = old_use_webauthn asyncssh.sk_ecdsa.sk_use_webauthn = old_use_webauthn asyncssh.sk.ClientPin = old_client_pin asyncssh.sk.CredentialManagement = old_cred_mgmt asyncssh.sk.PinProtocolV1 = old_pin_proto asyncssh.sk.CtapHidDevice.list_devices = old_list_devices @contextmanager def patch_sk(devices): """Context manager to stub out security key functions""" old_sk_hooks = stub_sk(devices) try: yield finally: unstub_sk(*old_sk_hooks) @contextmanager def sk_error(err): """Set security key error condition""" try: for dev in asyncssh.sk.CtapHidDevice.list_devices(): dev.error = err yield finally: for dev in asyncssh.sk.CtapHidDevice.list_devices(): dev.error = None asyncssh-2.20.0/tests/sspi_stub.py000066400000000000000000000075031475467777400172160ustar00rootroot00000000000000# Copyright (c) 2017-2022 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation # Georg Sauthoff - fix for "setup.py test" command on non-Windows """Stub SSPI module for unit tests""" import sys from .gss_stub import step if sys.platform == 'win32': # pragma: no cover from asyncssh.gss_win32 import ASC_RET_INTEGRITY, ISC_RET_INTEGRITY from asyncssh.gss_win32 import SECPKG_ATTR_NATIVE_NAMES, SSPIError class SSPIBuffer: """Stub class for SSPI buffer""" def __init__(self, data): self._data = data @property def Buffer(self): # pylint: disable=invalid-name """Return the data in the buffer""" return self._data class SSPIContext: """Stub class for SSPI security context""" def QueryContextAttributes(self, attr): # pylint: disable=invalid-name """Return principal information associated with this context""" # pylint: disable=no-self-use if attr == SECPKG_ATTR_NATIVE_NAMES: return ['user@TEST', 'host@TEST'] else: # pragma: no cover return None class SSPIAuth: """Stub class for SSPI authentication""" def __init__(self, _package=None, spn=None, targetspn=None, scflags=None): host = spn or targetspn if 'init_error' in host: raise SSPIError('Authentication initialization error') if targetspn and 'no_client_integrity' in host: scflags &= ~ISC_RET_INTEGRITY elif spn and 'no_server_integrity' in host: scflags &= ~ASC_RET_INTEGRITY self._host = host[5:] self._flags = scflags self._ctxt = SSPIContext() self._complete = False self._error = False @property def authenticated(self): """Return whether authentication is complete""" return self._complete 
@property def ctxt(self): """Return authentication context""" return self._ctxt @property def ctxt_attr(self): """Return authentication flags""" return self._flags def reset(self): """Reset SSPI security context""" self._complete = False def authorize(self, token): """Perform next step in SSPI authentication""" if self._error: self._error = False raise SSPIError('Token authentication error') new_token, complete = step(self._host, token) if complete: self._complete = True if new_token in (b'error', b'errtok'): if token: raise SSPIError('Token authentication error') else: self._error = True return True, [SSPIBuffer(b'')] else: return bool(new_token), [SSPIBuffer(new_token)] def sign(self, data): """Sign a block of data""" # pylint: disable=no-self-use,unused-argument if 'sign_error' in self._host: raise SSPIError('Signing error') return b'fail' if 'verify_error' in self._host else b'' def verify(self, data, sig): """Verify a signature for a block of data""" # pylint: disable=no-self-use,unused-argument if sig == b'fail': raise SSPIError('Signature verification error') asyncssh-2.20.0/tests/test_agent.py000066400000000000000000000365361475467777400173500ustar00rootroot00000000000000# Copyright (c) 2016-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation """Unit tests for AsyncSSH ssh-agent client""" import asyncio import functools import os from pathlib import Path import signal import subprocess import unittest import asyncssh from asyncssh.agent import SSH_AGENT_SUCCESS, SSH_AGENT_FAILURE from asyncssh.agent import SSH_AGENT_IDENTITIES_ANSWER from asyncssh.crypto import ed25519_available from asyncssh.packet import Byte, String, UInt32 from .sk_stub import sk_available, patch_sk from .util import AsyncTestCase, asynctest, get_test_key, run, try_remove def agent_test(func): """Decorator for running SSH agent tests""" @asynctest @functools.wraps(func) async def agent_wrapper(self): """Run a test after connecting to an SSH agent""" async with asyncssh.connect_agent() as agent: await agent.remove_all() await func(self, agent) return agent_wrapper class _Agent: """Mock SSH agent for testing error cases""" def __init__(self, response): self._response = b'' if response is None else String(response) self._path = None self._server = None async def start(self, path): """Start a new mock SSH agent""" self._path = path # pylint doesn't think start_unix_server exists # pylint: disable=no-member self._server = \ await asyncio.start_unix_server(self.process_request, path) async def process_request(self, reader, writer): """Process a request sent to the mock SSH agent""" await reader.readexactly(4) writer.write(self._response) writer.close() async def stop(self): """Shut down the mock SSH agent""" self._server.close() await self._server.wait_closed() try_remove(self._path) class _TestAgent(AsyncTestCase): """Unit tests for AsyncSSH API""" _agent_pid = None _public_keys = {} @staticmethod def 
set_askpass(status): """Set return status for ssh-askpass""" with open('ssh-askpass', 'w') as f: f.write(f'#!/bin/sh\nexit {status}\n') os.chmod('ssh-askpass', 0o755) # Pylint doesn't like mixed case method names, but this was chosen to # match the convention used in the unittest module. # pylint: disable=invalid-name @classmethod async def asyncSetUpClass(cls): """Set up keys and an SSH server for the tests to use""" os.environ['DISPLAY'] = ' ' os.environ['HOME'] = '.' os.environ['SSH_ASKPASS'] = os.path.join(os.getcwd(), 'ssh-askpass') try: output = run('ssh-agent -a agent 2>/dev/null') except subprocess.CalledProcessError: # pragma: no cover return cls._agent_pid = int(output.splitlines()[2].split()[3][:-1]) os.environ['SSH_AUTH_SOCK'] = 'agent' @classmethod async def asyncTearDownClass(cls): """Shut down agents""" if cls._agent_pid: # pragma: no branch os.kill(cls._agent_pid, signal.SIGTERM) def setUp(self): """Skip unit tests if we couldn't start an agent""" if not self._agent_pid: # pragma: no cover self.skipTest('ssh-agent not available') # pylint: enable=invalid-name @agent_test async def test_connection(self, agent): """Test opening a connection to the agent""" self.assertIsNotNone(agent) @asynctest async def test_connection_failed(self): """Test failure in opening a connection to the agent""" with self.assertRaises(OSError): await asyncssh.connect_agent('xxx') @asynctest async def test_no_auth_sock(self): """Test failure when no auth sock is set""" del os.environ['SSH_AUTH_SOCK'] with self.assertRaises(OSError): await asyncssh.connect_agent() os.environ['SSH_AUTH_SOCK'] = 'agent' @agent_test async def test_get_keys(self, agent): """Test getting keys from the agent""" keys = await agent.get_keys() self.assertEqual(len(keys), len(self._public_keys)) @agent_test async def test_sign(self, agent): """Test signing a block of data using the agent""" algs = ['ssh-rsa', 'ecdsa-sha2-nistp256'] if ed25519_available: # pragma: no branch algs.append('ssh-ed25519') for alg_name in algs: key = get_test_key(alg_name) pubkey = key.convert_to_public() cert = key.generate_user_certificate(key, 'name') await agent.add_keys([(key, cert)]) agent_keys = await agent.get_keys() for agent_key in agent_keys: agent_key.set_sig_algorithm(agent_key.sig_algorithms[0]) sig = await agent_key.sign_async(b'test') self.assertTrue(pubkey.verify(b'test', sig)) await agent.remove_keys(agent_keys) @agent_test async def test_set_certificate(self, agent): """Test setting certificate on an existing keypair""" key = get_test_key('ssh-rsa') cert = key.generate_user_certificate(key, 'name') key2 = get_test_key('ssh-rsa', 1) cert2 = key.generate_user_certificate(key2, 'name') await agent.add_keys([key]) agent_key = (await agent.get_keys())[0] agent_key.set_certificate(cert) self.assertEqual(agent_key.public_data, cert.public_data) with self.assertRaises(ValueError): asyncssh.load_keypairs([(agent_key, cert2)]) agent_key = (await agent.get_keys())[0] agent_key = asyncssh.load_keypairs([(agent_key, cert)])[0] self.assertEqual(agent_key.public_data, cert.public_data) with self.assertRaises(ValueError): asyncssh.load_keypairs([(agent_key, cert2)]) @agent_test async def test_reconnect(self, agent): """Test reconnecting to the agent after closing it""" key = get_test_key('ecdsa-sha2-nistp256') pubkey = key.convert_to_public() async with agent: await agent.add_keys([key]) agent_keys = await agent.get_keys() for agent_key in agent_keys: sig = await agent_key.sign_async(b'test') self.assertTrue(pubkey.verify(b'test', sig)) @agent_test 
async def test_add_remove_keys(self, agent): """Test adding and removing keys""" await agent.add_keys() agent_keys = await agent.get_keys() self.assertEqual(len(agent_keys), 0) key = get_test_key('ssh-rsa') await agent.add_keys([key]) agent_keys = await agent.get_keys() self.assertEqual(len(agent_keys), 1) await agent.remove_keys(agent_keys) agent_keys = await agent.get_keys() self.assertEqual(len(agent_keys), 0) await agent.add_keys([key]) agent_keys = await agent.get_keys() self.assertEqual(len(agent_keys), 1) await agent_keys[0].remove() agent_keys = await agent.get_keys() self.assertEqual(len(agent_keys), 0) await agent.add_keys([key], lifetime=1) agent_keys = await agent.get_keys() self.assertEqual(len(agent_keys), 1) await asyncio.sleep(2) agent_keys = await agent.get_keys() self.assertEqual(len(agent_keys), 0) @agent_test async def test_add_nonlocal(self, agent): """Test failure when adding a non-local key to an agent""" key = get_test_key('ssh-rsa') async with agent: await agent.add_keys([key]) agent_keys = await agent.get_keys() with self.assertRaises(asyncssh.KeyImportError): await agent.add_keys(agent_keys) @agent_test async def test_add_keys_failure(self, agent): """Test failure adding keys to the agent""" os.mkdir('.ssh', 0o700) key = get_test_key('ssh-rsa') key.write_private_key(Path('.ssh', 'id_rsa')) try: mock_agent = _Agent(Byte(SSH_AGENT_FAILURE)) await mock_agent.start('mock_agent') async with asyncssh.connect_agent('mock_agent') as agent: async with agent: await agent.add_keys() async with agent: with self.assertRaises(ValueError): await agent.add_keys([key]) finally: await mock_agent.stop() os.remove(os.path.join('.ssh', 'id_rsa')) os.rmdir('.ssh') @unittest.skipUnless(sk_available, 'security key support not available') @patch_sk([2]) @asynctest async def test_add_sk_keys(self): """Test adding U2F security keys""" key = get_test_key('sk-ecdsa-sha2-nistp256@openssh.com') cert = key.generate_user_certificate(key, 'test') mock_agent = _Agent(Byte(SSH_AGENT_SUCCESS)) await mock_agent.start('mock_agent') async with asyncssh.connect_agent('mock_agent') as agent: for keypair in asyncssh.load_keypairs([key, (key, cert)]): async with agent: self.assertIsNone(await agent.add_keys([keypair])) async with agent: with self.assertRaises(asyncssh.KeyExportError): await agent.add_keys([key.convert_to_public()]) await mock_agent.stop() @unittest.skipUnless(sk_available, 'security key support not available') @patch_sk([2]) @asynctest async def test_get_sk_keys(self): """Test getting U2F security keys""" key = get_test_key('sk-ecdsa-sha2-nistp256@openssh.com') cert = key.generate_user_certificate(key, 'test') mock_agent = _Agent(Byte(SSH_AGENT_IDENTITIES_ANSWER) + UInt32(2) + String(key.public_data) + String('') + String(cert.public_data) + String('')) await mock_agent.start('mock_agent') async with asyncssh.connect_agent('mock_agent') as agent: await agent.get_keys() await mock_agent.stop() @asynctest async def test_add_remove_smartcard_keys(self): """Test adding and removing smart card keys""" mock_agent = _Agent(Byte(SSH_AGENT_SUCCESS)) await mock_agent.start('mock_agent') async with asyncssh.connect_agent('mock_agent') as agent: result = await agent.add_smartcard_keys('provider') self.assertIsNone(result) await mock_agent.stop() mock_agent = _Agent(Byte(SSH_AGENT_SUCCESS)) await mock_agent.start('mock_agent') async with asyncssh.connect_agent('mock_agent') as agent: result = await agent.remove_smartcard_keys('provider') self.assertIsNone(result) await mock_agent.stop() @agent_test 
async def test_confirm(self, agent): """Test confirmation of key""" key = get_test_key('ecdsa-sha2-nistp256') pubkey = key.convert_to_public() await agent.add_keys([key], confirm=True) agent_keys = await agent.get_keys() self.set_askpass(1) for agent_key in agent_keys: with self.assertRaises(ValueError): sig = await agent_key.sign_async(b'test') self.set_askpass(0) for agent_key in agent_keys: sig = await agent_key.sign_async(b'test') self.assertTrue(pubkey.verify(b'test', sig)) @agent_test async def test_lock(self, agent): """Test lock and unlock""" key = get_test_key('ecdsa-sha2-nistp256') pubkey = key.convert_to_public() await agent.add_keys([key]) agent_keys = await agent.get_keys() await agent.lock('passphrase') for agent_key in agent_keys: with self.assertRaises(ValueError): await agent_key.sign_async(b'test') await agent.unlock('passphrase') for agent_key in agent_keys: sig = await agent_key.sign_async(b'test') self.assertTrue(pubkey.verify(b'test', sig)) @asynctest async def test_query_extensions(self): """Test query of supported extensions""" mock_agent = _Agent(Byte(SSH_AGENT_SUCCESS) + String('xxx')) await mock_agent.start('mock_agent') async with asyncssh.connect_agent('mock_agent') as agent: extensions = await agent.query_extensions() self.assertEqual(extensions, ['xxx']) await mock_agent.stop() mock_agent = _Agent(Byte(SSH_AGENT_SUCCESS) + String(b'\xff')) await mock_agent.start('mock_agent') async with asyncssh.connect_agent('mock_agent') as agent: with self.assertRaises(ValueError): await agent.query_extensions() await mock_agent.stop() mock_agent = _Agent(Byte(SSH_AGENT_FAILURE)) await mock_agent.start('mock_agent') async with asyncssh.connect_agent('mock_agent') as agent: extensions = await agent.query_extensions() self.assertEqual(extensions, []) await mock_agent.stop() mock_agent = _Agent(b'\xff') await mock_agent.start('mock_agent') async with asyncssh.connect_agent('mock_agent') as agent: with self.assertRaises(ValueError): await agent.query_extensions() await mock_agent.stop() @agent_test async def test_unknown_key(self, agent): """Test failure when signing with an unknown key""" key = get_test_key('ssh-rsa') with self.assertRaises(ValueError): await agent.sign(key.public_data, b'test') @agent_test async def test_double_close(self, agent): """Test calling close more than once on the agent""" self.assertIsNotNone(agent) agent.close() @asynctest async def test_errors(self): """Test getting error responses from SSH agent""" key = get_test_key('ssh-rsa') keypair = asyncssh.load_keypairs(key)[0] for response in (None, b'', Byte(SSH_AGENT_FAILURE), b'\xff'): mock_agent = _Agent(response) await mock_agent.start('mock_agent') async with asyncssh.connect_agent('mock_agent') as agent: for request in (agent.get_keys(), agent.sign(b'xxx', b'test'), agent.add_keys([key]), agent.add_smartcard_keys('xxx'), agent.remove_keys([keypair]), agent.remove_smartcard_keys('xxx'), agent.remove_all(), agent.lock('passphrase'), agent.unlock('passphrase')): async with agent: with self.assertRaises(ValueError): await request await mock_agent.stop() asyncssh-2.20.0/tests/test_asn1.py000066400000000000000000000205641475467777400171060ustar00rootroot00000000000000# Copyright (c) 2015-2021 by Ron Frederick and others. 
# # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation """Unit tests for ASN.1 encoding and decoding""" import codecs import unittest from asyncssh.asn1 import der_encode, der_decode from asyncssh.asn1 import ASN1EncodeError, ASN1DecodeError from asyncssh.asn1 import BitString, IA5String, ObjectIdentifier from asyncssh.asn1 import RawDERObject, TaggedDERObject, PRIVATE class _TestASN1(unittest.TestCase): """Unit tests for ASN.1 module""" tests = [ (None, '0500'), (False, '010100'), (True, '0101ff'), (0, '020100'), (127, '02017f'), (128, '02020080'), (256, '02020100'), (-128, '020180'), (-129, '0202ff7f'), (-256, '0202ff00'), (b'', '0400'), (b'\0', '040100'), (b'abc', '0403616263'), (127*b'\0', '047f' + 127*'00'), (128*b'\0', '048180' + 128*'00'), ('', '0c00'), ('\0', '0c0100'), ('abc', '0c03616263'), ((), '3000'), ((1,), '3003020101'), ((1, 2), '3006020101020102'), (frozenset(), '3100'), (frozenset({1}), '3103020101'), (frozenset({1, 2}), '3106020101020102'), (frozenset({-128, 127}), '310602017f020180'), (BitString(b''), '030100'), (BitString(b'\0', 7), '03020700'), (BitString(b'\x80', 7), '03020780'), (BitString(b'\x80', named=True), '03020780'), (BitString(b'\x81', named=True), '03020081'), (BitString(b'\x81\x00', named=True), '03020081'), (BitString(b'\x80', 6), '03020680'), (BitString(b'\x80'), '03020080'), (BitString(b'\x80\x00', 7), '0303078000'), (BitString(''), '030100'), (BitString('0'), '03020700'), (BitString('1'), '03020780'), (BitString('10'), '03020680'), (BitString('10000000'), '03020080'), (BitString('10000001'), '03020081'), (BitString('100000000'), '0303078000'), (IA5String(b''), '1600'), (IA5String(b'\0'), '160100'), (IA5String(b'abc'), '1603616263'), (ObjectIdentifier('0.0'), '060100'), (ObjectIdentifier('1.2'), '06012a'), (ObjectIdentifier('1.2.840'), '06032a8648'), (ObjectIdentifier('2.5'), '060155'), (ObjectIdentifier('2.40'), '060178'), (TaggedDERObject(0, None), 'a0020500'), (TaggedDERObject(1, None), 'a1020500'), (TaggedDERObject(32, None), 'bf20020500'), (TaggedDERObject(128, None), 'bf8100020500'), (TaggedDERObject(0, None, PRIVATE), 'e0020500'), (RawDERObject(0, b'', PRIVATE), 'c000') ] encode_errors = [ (range, [1]), # Unsupported type (BitString, [b'', 1]), # Bit count with empty value (BitString, [b'', -1]), # Invalid unused bit count (BitString, [b'', 8]), # Invalid unused bit count (BitString, [b'0c0', 7]), # Unused bits not zero (BitString, ['', 1]), # Unused bits with string (BitString, [0]), # Invalid type (ObjectIdentifier, ['']), # Too few components (ObjectIdentifier, ['1']), # Too few components (ObjectIdentifier, ['-1.1']), # First component out of range (ObjectIdentifier, ['3.1']), # First component out of range (ObjectIdentifier, ['0.-1']), # Second component out of range (ObjectIdentifier, ['0.40']), # Second component out of range (ObjectIdentifier, ['1.-1']), # Second component out of range (ObjectIdentifier, ['1.40']), # Second component out of range (ObjectIdentifier, 
['1.1.-1']), # Later component out of range (TaggedDERObject, [0, None, 99]), # Invalid ASN.1 class (RawDERObject, [0, None, 99]), # Invalid ASN.1 class ] decode_errors = [ '', # Incomplete data '01', # Incomplete data '0101', # Incomplete data '1f00', # Incomplete data '1f8000', # Incomplete data '1f0001', # Incomplete data '1f80', # Incomplete tag '0180', # Indefinite length '050001', # Unexpected bytes at end '2500', # Constructed null '050100', # Null with content '2100', # Constructed boolean '010102', # Boolean value not 0x00/0xff '2200', # Constructed integer '2400', # Constructed octet string '2c00', # Constructed UTF-8 string '1000', # Non-constructed sequence '1100', # Non-constructed set '2300', # Constructed bit string '03020800', # Invalid unused bit count '3600', # Constructed IA5 string '2600', # Constructed object identifier '0600', # Empty object identifier '06020080', # Invalid component '06020081' # Incomplete component ] def test_asn1(self): """Unit test ASN.1 module""" for value, data in self.tests: data = codecs.decode(data, 'hex') with self.subTest(msg='encode', value=value): self.assertEqual(der_encode(value), data) with self.subTest(msg='decode', data=data): decoded_value = der_decode(data) self.assertEqual(decoded_value, value) self.assertEqual(hash(decoded_value), hash(value)) self.assertEqual(repr(decoded_value), repr(value)) self.assertEqual(str(decoded_value), str(value)) for cls, args in self.encode_errors: with self.subTest(msg='encode error', cls=cls.__name__, args=args): with self.assertRaises(ASN1EncodeError): der_encode(cls(*args)) for data in self.decode_errors: with self.subTest(msg='decode error', data=data): with self.assertRaises(ASN1DecodeError): der_decode(codecs.decode(data, 'hex')) asyncssh-2.20.0/tests/test_auth.py000066400000000000000000000621471475467777400172100ustar00rootroot00000000000000# Copyright (c) 2015-2022 by Ron Frederick and others. 
# # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation """Unit tests for authentication""" import asyncio import inspect import unittest import asyncssh from asyncssh.auth import MSG_USERAUTH_PK_OK, lookup_client_auth from asyncssh.auth import get_supported_server_auth_methods, lookup_server_auth from asyncssh.auth import MSG_USERAUTH_GSSAPI_RESPONSE from asyncssh.constants import MSG_USERAUTH_REQUEST, MSG_USERAUTH_FAILURE from asyncssh.constants import MSG_USERAUTH_SUCCESS from asyncssh.gss import GSSClient, GSSServer from asyncssh.packet import SSHPacket, Boolean, Byte, NameList, String from .util import asynctest, gss_available, patch_gss from .util import AsyncTestCase, ConnectionStub, get_test_key class _AuthConnectionStub(ConnectionStub): """Connection stub class to test authentication""" def connection_lost(self, exc): """Handle the closing of a connection""" raise NotImplementedError async def process_packet(self, data): """Process an incoming packet""" raise NotImplementedError def _get_userauth_request_packet(self, method, args): """Get packet data for a user authentication request""" # pylint: disable=no-self-use return b''.join((Byte(MSG_USERAUTH_REQUEST), String('user'), String('service'), String(method)) + args) def get_userauth_request_data(self, method, *args): """Get signature data for a user authentication request""" return String('') + self._get_userauth_request_packet(method, args) def send_userauth_packet(self, pkttype, *args, handler=None, trivial=True): """Send a user authentication packet""" # pylint: disable=unused-argument self.send_packet(pkttype, *args, handler=handler) class _AuthClientStub(_AuthConnectionStub): """Stub class for client connection""" @classmethod def make_pair(cls, method, **kwargs): """Make a client and server connection pair to test authentication""" client_conn = cls(method, **kwargs) return client_conn, client_conn.get_peer() def __init__(self, method, gss_host=None, override_gss_mech=False, host_based_auth=False, client_host_key=None, client_host_cert=None, public_key_auth=False, client_key=None, client_cert=None, override_pk_ok=False, password_auth=False, password=None, password_change=NotImplemented, password_change_prompt=None, kbdint_auth=False, kbdint_submethods=None, kbdint_challenge=None, kbdint_response=None, success=False): super().__init__(_AuthServerStub(self, gss_host, override_gss_mech, host_based_auth, public_key_auth, override_pk_ok, password_auth, password_change_prompt, kbdint_auth, kbdint_challenge, success), False) self._gss = GSSClient(gss_host, None, False) if gss_host else None self._client_host_key = client_host_key self._client_host_cert = client_host_cert self._client_key = client_key self._client_cert = client_cert self._password = password self._password_change = password_change self._password_changed = None self._kbdint_submethods = kbdint_submethods self._kbdint_response = kbdint_response self._auth_waiter = asyncio.Future() self._auth = 
lookup_client_auth(self, method) if self._auth is None: self.close() raise ValueError('Invalid auth method') def connection_lost(self, exc=None): """Handle the closing of a connection""" if exc: self._auth_waiter.set_exception(exc) self.close() async def process_packet(self, data): """Process an incoming packet""" packet = SSHPacket(data) pkttype = packet.get_byte() if pkttype == MSG_USERAUTH_FAILURE: _ = packet.get_namelist() partial_success = packet.get_boolean() packet.check_end() if partial_success: # pragma: no cover # Partial success not implemented yet self._auth.auth_succeeded() else: self._auth.auth_failed() self._auth_waiter.set_result((False, self._password_changed)) self._auth = None self._auth_waiter = None elif pkttype == MSG_USERAUTH_SUCCESS: packet.check_end() self._auth.auth_succeeded() self._auth_waiter.set_result((True, self._password_changed)) self._auth = None self._auth_waiter = None else: result = self._auth.process_packet(pkttype, None, packet) if inspect.isawaitable(result): await result async def get_auth_result(self): """Return the result of the authentication""" return await self._auth_waiter def try_next_auth(self, *, next_method=False): """Handle a request to move to another form of auth""" # pylint: disable=unused-argument # Report that the current auth attempt failed self._auth_waiter.set_result((False, self._password_changed)) self._auth = None self._auth_waiter = None async def send_userauth_request(self, method, *args, key=None, trivial=True): """Send a user authentication request""" packet = self._get_userauth_request_packet(method, args) if key: packet += String(key.sign(String('') + packet)) self.send_userauth_packet(MSG_USERAUTH_REQUEST, packet[1:], trivial=trivial) def get_gss_context(self): """Return the GSS context associated with this connection""" return self._gss def gss_mic_auth_requested(self): """Return whether to allow GSS MIC authentication or not""" return bool(self._gss) async def host_based_auth_requested(self): """Return a host key pair, host, and user to authenticate with""" if self._client_host_key: keypair = asyncssh.load_keypairs((self._client_host_key, self._client_host_cert))[0] else: keypair = None return keypair, 'host', 'user' async def public_key_auth_requested(self): """Return key to use for public key authentication""" if self._client_key: return asyncssh.load_keypairs((self._client_key, self._client_cert))[0] else: return None async def password_auth_requested(self): """Return password to send for password authentication""" return self._password async def password_change_requested(self, _prompt, _lang): """Return old & new passwords for password change""" if self._password_change is True: return 'password', 'new_password' else: return self._password_change def password_changed(self): """Handle a successful password change""" self._password_changed = True def password_change_failed(self): """Handle an unsuccessful password change""" self._password_changed = False async def kbdint_auth_requested(self): """Return submethods to send for keyboard-interactive authentication""" return self._kbdint_submethods async def kbdint_challenge_received(self, _name, _instruction, _lang, _prompts): """Return responses to keyboard-interactive challenge""" if self._kbdint_response is True: return ('password',) else: return self._kbdint_response class _AuthServerStub(_AuthConnectionStub): """Stub class for server connection""" def __init__(self, peer=None, gss_host=None, override_gss_mech=False, host_based_auth=False, public_key_auth=False, 
override_pk_ok=False, password_auth=False, password_change_prompt=None, kbdint_auth=False, kbdint_challenge=False, success=False): super().__init__(peer, True) self._gss = GSSServer(gss_host, None) if gss_host else None self._override_gss_mech = override_gss_mech self._host_based_auth = host_based_auth self._public_key_auth = public_key_auth self._override_pk_ok = override_pk_ok self._password_auth = password_auth self._password_change_prompt = password_change_prompt self._kbdint_auth = kbdint_auth self._kbdint_challenge = kbdint_challenge self._success = success self._auth = None def connection_lost(self, exc=None): """Handle the closing of a connection""" if self._peer: self._peer.connection_lost(exc) self.close() async def process_packet(self, data): """Process an incoming packet""" packet = SSHPacket(data) pkttype = packet.get_byte() if pkttype == MSG_USERAUTH_REQUEST: _ = packet.get_string() # username _ = packet.get_string() # service method = packet.get_string() if self._auth: self._auth.cancel() if self._override_gss_mech: self.send_userauth_packet(MSG_USERAUTH_GSSAPI_RESPONSE, String('mismatch')) elif self._override_pk_ok: self.send_userauth_packet(MSG_USERAUTH_PK_OK, String(''), String('')) else: self._auth = lookup_server_auth(self, 'user', method, packet) else: result = self._auth.process_packet(pkttype, None, packet) if inspect.isawaitable(result): await result def send_userauth_failure(self, partial_success): """Send a user authentication failure response""" self._auth = None self.send_userauth_packet(MSG_USERAUTH_FAILURE, NameList([]), Boolean(partial_success)) def send_userauth_success(self): """Send a user authentication success response""" self._auth = None self.send_userauth_packet(MSG_USERAUTH_SUCCESS) def get_gss_context(self): """Return the GSS context associated with this connection""" return self._gss def gss_kex_auth_supported(self): """Return whether or not GSS key exchange authentication is supported""" return bool(self._gss) def gss_mic_auth_supported(self): """Return whether or not GSS MIC authentication is supported""" return bool(self._gss) async def validate_gss_principal(self, _username, _user_principal, _host_principal): """Validate the GSS principal name for the specified user""" return self._success def host_based_auth_supported(self): """Return whether or not host-based authentication is supported""" return self._host_based_auth async def validate_host_based_auth(self, _username, _key_data, _client_host, _client_username, _msg, _signature): """Validate host based authentication for the specified host and user""" return self._success def public_key_auth_supported(self): """Return whether or not public key authentication is supported""" return self._public_key_auth async def validate_public_key(self, _username, _key_data, _msg, _signature): """Validate public key""" return self._success def password_auth_supported(self): """Return whether or not password authentication is supported""" return self._password_auth async def validate_password(self, _username, _password): """Validate password""" if self._password_change_prompt: raise asyncssh.PasswordChangeRequired(self._password_change_prompt) else: return self._success async def change_password(self, _username, _old_password, _new_password): """Validate password""" return self._success def kbdint_auth_supported(self): """Return whether or not keyboard-interactive authentication is supported""" return self._kbdint_auth async def get_kbdint_challenge(self, _username, _lang, _submethods): """Return a 
keyboard-interactive challenge""" if self._kbdint_challenge is True: return '', '', '', (('Password:', False),) else: return self._kbdint_challenge async def validate_kbdint_response(self, _username, _responses): """Validate keyboard-interactive responses""" return self._success @patch_gss class _TestAuth(AsyncTestCase): """Unit tests for auth module""" async def check_auth(self, method, expected_result, **kwargs): """Unit test authentication""" client_conn, server_conn = _AuthClientStub.make_pair(method, **kwargs) try: self.assertEqual((await client_conn.get_auth_result()), expected_result) finally: client_conn.close() server_conn.close() @asynctest async def test_client_auth_methods(self): """Test client auth methods""" with self.subTest('Unknown client auth method'): with self.assertRaises(ValueError): _AuthClientStub.make_pair(b'xxx') @asynctest async def test_server_auth_methods(self): """Test server auth methods""" with self.subTest('No auth methods'): server_conn = _AuthServerStub() self.assertEqual( get_supported_server_auth_methods(server_conn), []) server_conn.close() with self.subTest('All auth methods'): gss_host = '1' if gss_available else None server_conn = _AuthServerStub( gss_host=gss_host, host_based_auth=True, public_key_auth=True, password_auth=True, kbdint_auth=True) if gss_available: # pragma: no branch self.assertEqual( get_supported_server_auth_methods(server_conn), [b'gssapi-keyex', b'gssapi-with-mic', b'hostbased', b'publickey', b'keyboard-interactive', b'password']) else: # pragma: no cover self.assertEqual( get_supported_server_auth_methods(server_conn), [b'hostbased', b'publickey', b'keyboard-interactive', b'password']) server_conn.close() with self.subTest('Unknown auth method'): server_conn = _AuthServerStub() self.assertEqual(lookup_server_auth(server_conn, 'user', b'xxx', SSHPacket(b'')), None) server_conn.close() @asynctest async def test_null_auth(self): """Unit test null authentication""" await self.check_auth(b'none', (False, None)) @unittest.skipUnless(gss_available, 'GSS not available') @asynctest async def test_gss_auth(self): """Unit test GSS authentication""" with self.subTest('GSS with MIC auth not available'): await self.check_auth(b'gssapi-with-mic', (False, None)) for steps in range(4): with self.subTest('GSS with MIC auth available'): await self.check_auth(b'gssapi-with-mic', (True, None), gss_host=str(steps), success=True) gss_host = str(steps) + ',step_error' with self.subTest('GSS with MIC error', steps=steps): await self.check_auth(b'gssapi-with-mic', (False, None), gss_host=gss_host) with self.subTest('GSS with MIC error with token', steps=steps): await self.check_auth(b'gssapi-with-mic', (False, None), gss_host=gss_host + ',errtok') with self.subTest('GSS with MIC without integrity'): await self.check_auth(b'gssapi-with-mic', (True, None), gss_host='1,no_client_integrity,' + 'no_server_integrity', success=True) with self.subTest('GSS client integrity mismatch'): await self.check_auth(b'gssapi-with-mic', (False, None), gss_host='1,no_client_integrity') with self.subTest('GSS server integrity mismatch'): await self.check_auth(b'gssapi-with-mic', (False, None), gss_host='1,no_server_integrity') with self.subTest('GSS mechanism unknown'): await self.check_auth(b'gssapi-with-mic', (False, None), gss_host='1,unknown_mech') with self.subTest('GSS mechanism mismatch'): with self.assertRaises(asyncssh.ProtocolError): await self.check_auth(b'gssapi-with-mic', (False, None), gss_host='1', override_gss_mech=True) @asynctest async def 
test_hostbased_auth(self): """Unit test host-based authentication""" hkey = get_test_key('ecdsa-sha2-nistp256') cert = hkey.generate_host_certificate(hkey, 'host') with self.subTest('Host-based auth not available'): await self.check_auth(b'hostbased', (False, None)) with self.subTest('Untrusted key'): await self.check_auth(b'hostbased', (False, None), client_host_key=hkey, host_based_auth=True) with self.subTest('Trusted key'): await self.check_auth(b'hostbased', (True, None), client_host_key=hkey, host_based_auth=True, success=True) with self.subTest('Trusted certificate'): await self.check_auth(b'hostbased', (True, None), client_host_key=hkey, client_host_cert=cert, host_based_auth=True, success=True) @asynctest async def test_publickey_auth(self): """Unit test public key authentication""" ckey = get_test_key('ecdsa-sha2-nistp256') cert = ckey.generate_user_certificate(ckey, 'name') with self.subTest('Public key auth not available'): await self.check_auth(b'publickey', (False, None)) with self.subTest('Untrusted key'): await self.check_auth(b'publickey', (False, None), client_key=ckey, public_key_auth=True) with self.subTest('Trusted key'): await self.check_auth(b'publickey', (True, None), client_key=ckey, public_key_auth=True, success=True) with self.subTest('Trusted certificate'): await self.check_auth(b'publickey', (True, None), client_key=ckey, client_cert=cert, public_key_auth=True, success=True) with self.subTest('Invalid PK_OK message'): with self.assertRaises(asyncssh.ProtocolError): await self.check_auth(b'publickey', (False, None), client_key=ckey, public_key_auth=True, override_pk_ok=True) @asynctest async def test_password_auth(self): """Unit test password authentication""" with self.subTest('Password auth not available'): await self.check_auth(b'password', (False, None)) with self.subTest('Invalid password'): with self.assertRaises(asyncssh.ProtocolError): await self.check_auth(b'password', (False, None), password_auth=True, password=b'\xff') with self.subTest('Incorrect password'): await self.check_auth(b'password', (False, None), password_auth=True, password='password') with self.subTest('Correct password'): await self.check_auth(b'password', (True, None), password_auth=True, password='password', success=True) with self.subTest('Password change not available'): await self.check_auth(b'password', (False, None), password_auth=True, password='password', password_change_prompt='change') with self.subTest('Invalid password change prompt'): with self.assertRaises(asyncssh.ProtocolError): await self.check_auth(b'password', (False, False), password_auth=True, password='password', password_change=True, password_change_prompt=b'\xff') with self.subTest('Password change failed'): await self.check_auth(b'password', (False, False), password_auth=True, password='password', password_change=True, password_change_prompt='change') with self.subTest('Password change succeeded'): await self.check_auth(b'password', (True, True), password_auth=True, password='password', password_change=True, password_change_prompt='change', success=True) @asynctest async def test_kbdint_auth(self): """Unit test keyboard-interactive authentication""" with self.subTest('Kbdint auth not available'): await self.check_auth(b'keyboard-interactive', (False, None)) with self.subTest('No submethods'): await self.check_auth(b'keyboard-interactive', (False, None), kbdint_auth=True) with self.subTest('Invalid submethods'): with self.assertRaises(asyncssh.ProtocolError): await self.check_auth(b'keyboard-interactive', 
(False, None), kbdint_auth=True, kbdint_submethods=b'\xff') with self.subTest('No challenge'): await self.check_auth(b'keyboard-interactive', (False, None), kbdint_auth=True, kbdint_submethods='') with self.subTest('Invalid challenge name'): with self.assertRaises(asyncssh.ProtocolError): await self.check_auth(b'keyboard-interactive', (False, None), kbdint_auth=True, kbdint_submethods='', kbdint_challenge=(b'\xff', '', '', ())) with self.subTest('Invalid challenge prompt'): with self.assertRaises(asyncssh.ProtocolError): await self.check_auth(b'keyboard-interactive', (False, None), kbdint_auth=True, kbdint_submethods='', kbdint_challenge=('', '', '', ((b'\xff', False),))) with self.subTest('No response'): await self.check_auth(b'keyboard-interactive', (False, None), kbdint_auth=True, kbdint_submethods='', kbdint_challenge=True) with self.subTest('Invalid response'): with self.assertRaises(asyncssh.ProtocolError): await self.check_auth(b'keyboard-interactive', (False, None), kbdint_auth=True, kbdint_submethods='', kbdint_challenge=True, kbdint_response=(b'\xff',)) with self.subTest('Incorrect response'): await self.check_auth(b'keyboard-interactive', (False, None), kbdint_auth=True, kbdint_submethods='', kbdint_challenge=True, kbdint_response=True) with self.subTest('Correct response'): await self.check_auth(b'keyboard-interactive', (True, None), kbdint_auth=True, kbdint_submethods='', kbdint_challenge=True, kbdint_response=True, success=True) asyncssh-2.20.0/tests/test_auth_keys.py000066400000000000000000000210061475467777400202300ustar00rootroot00000000000000# Copyright (c) 2015-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation """Unit tests for matching against authorized_keys file""" import unittest import asyncssh from .util import TempDirTestCase, get_test_key, x509_available class _TestAuthorizedKeys(TempDirTestCase): """Unit tests for auth_keys module""" keylist = [] imported_keylist = [] certlist = [] imported_certlist = [] @classmethod def setUpClass(cls): """Create public keys needed for test""" super().setUpClass() for i in range(3): key = get_test_key('ssh-rsa', i) cls.keylist.append(key.export_public_key().decode('ascii')) cls.imported_keylist.append(key.convert_to_public()) if x509_available: # pragma: no branch subject = f'CN=cert{i}' cert = key.generate_x509_user_certificate(key, subject) cls.certlist.append(cert.export_certificate().decode('ascii')) cls.imported_certlist.append(cert) def build_keys(self, keys, x509=False, from_file=False): """Build and import a list of authorized keys""" auth_keys = '# Comment line\n # Comment line with whitespace\n\n' for options in keys: options = options + ' ' if options else '' keynum = 1 if 'cert-authority' in options else 0 key_or_cert = (self.certlist if x509 else self.keylist)[keynum] auth_keys += options + key_or_cert if from_file: with open('authorized_keys', 'w') as f: f.write(auth_keys) return 
asyncssh.read_authorized_keys('authorized_keys') else: return asyncssh.import_authorized_keys(auth_keys) def match_keys(self, tests, x509=False): """Match against authorized keys""" for keys, matches in tests: auth_keys = self.build_keys(keys, x509) for (msg, keynum, client_host, \ client_addr, cert_principals, match) in matches: with self.subTest(msg, x509=x509): if x509: result, trusted_cert = auth_keys.validate_x509( self.imported_certlist[keynum], client_host, client_addr) if (trusted_cert and trusted_cert.subject != self.imported_certlist[keynum].subject): result = None else: result = auth_keys.validate( self.imported_keylist[keynum], client_host, client_addr, cert_principals, keynum == 1) self.assertEqual(result is not None, match) def test_matches(self): """Test authorized keys matching""" tests = ( ((None, 'cert-authority'), (('Match key or cert', 0, '1.2.3.4', '1.2.3.4', None, True), ('Match CA key or cert', 1, '1.2.3.4', '1.2.3.4', None, True), ('No match', 2, '1.2.3.4', '1.2.3.4', None, False))), (('from="1.2.3.4"',), (('Match IP', 0, '1.2.3.4', '1.2.3.4', None, True),)), (('from="1.2.3.0/24,!1.2.3.5"',), (('Match subnet', 0, '1.2.3.4', '1.2.3.4', None, True), ('Exclude IP', 0, '1.2.3.5', '1.2.3.5', None, False))), (('from="localhost*"',), (('Match host name', 0, 'localhost', '127.0.0.1', None, True),)), (('from="1.2.3.*,!1.2.3.5*"',), (('Match host pattern', 0, '1.2.3.4', '1.2.3.4', None, True), ('Exclude host pattern', 0, '1.2.3.5', '1.2.3.5', None, False))), (('principals="cert*,!cert1"',), (('Match principal', 0, '1.2.3.4', '1.2.3.4', ['cert0'], True),)), (('cert-authority,principals="cert*,!cert1"',), (('Exclude principal', 1, '1.2.3.4', '1.2.3.4', ['cert1'], False),)) ) self.match_keys(tests) if x509_available: # pragma: no branch self.match_keys(tests, x509=True) def test_options(self): """Test authorized keys returned option values""" tests = ( ('Command', 'command="ls abc"', {'command': 'ls abc'}), ('PermitOpen', 'permitopen="xxx:123"', {'permitopen': {('xxx', 123)}}), ('PermitOpen IPv6 address', 'permitopen="[fe80::1]:123"', {'permitopen': {('fe80::1', 123)}}), ('PermitOpen wildcard port', 'permitopen="xxx:*"', {'permitopen': {('xxx', None)}}), ('Unknown option', 'foo=abc,foo=def', {'foo': ['abc', 'def']}), ('Escaped value', 'environment="FOO=\\"xxx\\""', {'environment': {'FOO': '"xxx"'}}) ) for msg, options, expected in tests: with self.subTest(msg): auth_keys = self.build_keys([options]) result = auth_keys.validate(self.imported_keylist[0], '1.2.3.4', None, False) self.assertEqual(result, expected) def test_file(self): """Test reading authorized keys from file""" self.build_keys([None], from_file=True) @unittest.skipUnless(x509_available, 'X.509 not available') def test_subject_match(self): """Test match on X.509 subject name""" auth_keys = asyncssh.import_authorized_keys( 'x509v3-ssh-rsa subject=CN=cert0\n') result, _ = auth_keys.validate_x509( self.imported_certlist[0], '1.2.3.4', '1.2.3.4') self.assertIsNotNone(result) @unittest.skipUnless(x509_available, 'X.509 not available') def test_subject_option_match(self): """Test match on X.509 subject in options""" auth_keys = asyncssh.import_authorized_keys( 'subject=CN=cert0 ' + self.certlist[0]) result, _ = auth_keys.validate_x509( self.imported_certlist[0], '1.2.3.4', '1.2.3.4') self.assertIsNotNone(result) @unittest.skipUnless(x509_available, 'X.509 not available') def test_subject_option_mismatch(self): """Test failed match on X.509 subject in options""" auth_keys = asyncssh.import_authorized_keys( 
'subject=CN=cert1 ' + self.certlist[0]) result, _ = auth_keys.validate_x509( self.imported_certlist[0], '1.2.3.4', '1.2.3.4') self.assertIsNone(result) @unittest.skipUnless(x509_available, 'X.509 not available') def test_cert_authority_with_subject(self): """Test error when cert-authority is used with subject""" with self.assertRaises(ValueError): asyncssh.import_authorized_keys( 'cert-authority x509v3-sign-rsa subject=CN=cert0\n') @unittest.skipUnless(x509_available, 'X.509 not available') def test_non_root_ca(self): """Test error on non-root X.509 CA""" key = get_test_key('ssh-rsa') cert = key.generate_x509_user_certificate(key, 'CN=a', 'CN=b') data = 'cert-authority ' + cert.export_certificate().decode('ascii') with self.assertRaises(ValueError): asyncssh.import_authorized_keys(data) def test_errors(self): """Test various authorized key parsing errors""" tests = ( ('Bad key', 'xxx\n'), ('Unbalanced quote', 'xxx"\n'), ('Unbalanced backslash', 'xxx\\\n'), ('Missing option name', '=xxx\n'), ('Environment missing equals', 'environment="FOO"\n'), ('Environment missing variable name', 'environment="=xxx"\n'), ('PermitOpen missing colon', 'permitopen="xxx"\n'), ('PermitOpen non-integer port', 'permitopen="xxx:yyy"\n') ) for msg, data in tests: with self.subTest(msg): with self.assertRaises(ValueError): asyncssh.import_authorized_keys(data) asyncssh-2.20.0/tests/test_channel.py000066400000000000000000001774571475467777400176720ustar00rootroot00000000000000# Copyright (c) 2016-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation """Unit tests for AsyncSSH channel API""" import asyncio import os import tempfile import unittest from signal import SIGINT from unittest.mock import patch import asyncssh from asyncssh.constants import DEFAULT_LANG, MSG_USERAUTH_REQUEST from asyncssh.constants import MSG_CHANNEL_OPEN_CONFIRMATION from asyncssh.constants import MSG_CHANNEL_OPEN_FAILURE from asyncssh.constants import MSG_CHANNEL_WINDOW_ADJUST from asyncssh.constants import MSG_CHANNEL_DATA from asyncssh.constants import MSG_CHANNEL_EXTENDED_DATA from asyncssh.constants import MSG_CHANNEL_EOF, MSG_CHANNEL_CLOSE from asyncssh.constants import MSG_CHANNEL_SUCCESS from asyncssh.packet import Byte, String, UInt32 from asyncssh.public_key import CERT_TYPE_USER from asyncssh.stream import SSHTCPStreamSession, SSHUNIXStreamSession from asyncssh.stream import SSHTunTapStreamSession from asyncssh.tuntap import SSH_TUN_MODE_POINTTOPOINT, SSH_TUN_MODE_ETHERNET from .server import Server, ServerTestCase from .util import asynctest, echo, make_certificate PTY_OP_PARTIAL = 158 PTY_OP_NO_END = 159 class _ClientChannel(asyncssh.SSHClientChannel): """Patched SSH client channel for unit testing""" def _send_request(self, request, *args, want_reply=False): """Send a channel request""" if request == b'pty-req': if args[5][-6:-5] == Byte(PTY_OP_PARTIAL): args = args[:5] + (String(args[5][4:-5]),) elif args[5][-6:-5] == 
Byte(PTY_OP_NO_END): args = args[:5] + (String(args[5][4:-6]),) super()._send_request(request, *args, want_reply=want_reply) def get_send_pktsize(self): """Return the sender's max packet size """ return self._send_pktsize def send_request(self, request, *args): """Send a custom request (for unit testing)""" self._send_request(request, *args) async def make_request(self, request, *args): """Make a custom request (for unit testing)""" return await self._make_request(request, *args) class _ClientSession(asyncssh.SSHClientSession): """Unit test SSH client session""" def __init__(self): self._chan = None self.recv_buf = {None: [], asyncssh.EXTENDED_DATA_STDERR: []} self.xon_xoff = None self.exit_status = None self.exit_signal_msg = None self.exc = None def connection_made(self, chan): """Handle connection open""" self._chan = chan def connection_lost(self, exc): """Handle connection close""" self.exc = exc self._chan = None def data_received(self, data, datatype): """Handle data from the channel""" self.recv_buf[datatype].append(data) def xon_xoff_requested(self, client_can_do): """Handle request to enable/disable XON/XOFF flow control""" self.xon_xoff = client_can_do def exit_status_received(self, status): """Handle remote exit status""" # pylint: disable=unused-argument self.exit_status = status def exit_signal_received(self, signal, core_dumped, msg, lang): """Handle remote exit signal""" # pylint: disable=unused-argument self.exit_signal_msg = msg async def _create_session(conn, command=None, *, subsystem=None, **kwargs): """Create a client session""" return await conn.create_session(_ClientSession, command, subsystem=subsystem, **kwargs) class _ServerChannel(asyncssh.SSHServerChannel): """Patched SSH server channel class for unit testing""" def _send_request(self, request, *args, want_reply=False): """Send a channel request""" if request == b'exit-signal': if args[0] == String('invalid'): args = (String(b'\xff'),) + args[1:] if args[3] == String('invalid'): args = args[:3] + (String(b'\xff'),) super()._send_request(request, *args, want_reply=want_reply) def _process_delayed_request(self, packet): """Process a request that delays before responding""" packet.check_end() asyncio.get_event_loop().call_later(0.1, self._report_response, True) def get_send_pktsize(self): """Return the sender's max packet size """ return self._send_pktsize async def open_session(self): """Attempt to open a session on the client""" return await self._open(b'session') class _EchoServerSession(asyncssh.SSHServerSession): """A shell session which echoes data from stdin to stdout/stderr""" def __init__(self): self._chan = None def connection_made(self, chan): """Handle session open""" self._chan = chan username = self._chan.get_extra_info('username') if username == 'close': self._chan.close() elif username == 'task_error': raise RuntimeError('Exception handler test') def shell_requested(self): """Handle shell request""" return True def data_received(self, data, datatype): """Handle data from the channel""" self._chan.write(data[:1]) self._chan.writelines([data[1:]]) self._chan.write_stderr(data[:1]) self._chan.writelines_stderr([data[1:]]) def eof_received(self): """Handle EOF on the channel""" self._chan.write_eof() self._chan.close() class _PTYServerSession(asyncssh.SSHServerSession): """Server for testing PTY requests""" def __init__(self): self._chan = None self._pty_ok = True def connection_made(self, chan): """Handle session open""" self._chan = chan username = self._chan.get_extra_info('username') if username == 
'no_pty': self._pty_ok = False def pty_requested(self, term_type, term_size, term_modes): """Handle pseudo-terminal request""" self._chan.set_extra_info( pty_args=(term_type, term_size, term_modes.get(asyncssh.PTY_OP_OSPEED))) return self._pty_ok def shell_requested(self): """Handle shell request""" return True def session_started(self): """Handle session start""" chan = self._chan chan.write(f'Req: {chan.get_extra_info("pty_args")}\n') chan.close() class _ChannelServer(Server): """Server for testing the AsyncSSH channel API""" async def _begin_session(self, stdin, stdout, stderr): """Begin processing a new session""" # pylint: disable=too-many-statements action = stdin.channel.get_command() or stdin.channel.get_subsystem() if not action: action = 'echo' if action == 'echo': await echo(stdin, stdout, stderr) elif action == 'conn_close': await stdin.read(1) stdout.write('\n') self._conn.close() elif action == 'close': await stdin.read(1) stdout.write('\n') elif action == 'agent': try: async with asyncssh.connect_agent(self._conn) as agent: stdout.write(str(len(await agent.get_keys())) + '\n') except (OSError, asyncssh.ChannelOpenError): stdout.channel.exit(1) elif action == 'agent_sock': agent_path = stdin.channel.get_agent_path() if agent_path: async with asyncssh.connect_agent(agent_path) as agent: await asyncio.sleep(0.1) stdout.write(str(len(await agent.get_keys())) + '\n') else: stdout.channel.exit(1) elif action == 'rejected_agent': agent_path = stdin.channel.get_agent_path() stdout.write(str(bool(agent_path)) + '\n') chan = self._conn.create_agent_channel() try: await chan.open(SSHUNIXStreamSession) except asyncssh.ChannelOpenError: stdout.channel.exit(1) elif action == 'rejected_session': chan = _ServerChannel(self._conn, asyncio.get_event_loop(), False, False, False, 0, 1024, None, 'strict', 1, 32768) try: await chan.open_session() except asyncssh.ChannelOpenError: stdout.channel.exit(1) elif action == 'rejected_tcpip_direct': chan = self._conn.create_tcp_channel() try: await chan.connect(SSHTCPStreamSession, '', 0, '', 0) except asyncssh.ChannelOpenError: stdout.channel.exit(1) elif action == 'unknown_tcpip_listener': chan = self._conn.create_tcp_channel() try: await chan.accept(SSHTCPStreamSession, 'xxx', 0, '', 0) except asyncssh.ChannelOpenError: stdout.channel.exit(1) elif action == 'invalid_tcpip_listener': chan = self._conn.create_tcp_channel() try: await chan.accept(SSHTCPStreamSession, b'\xff', 0, '', 0) except asyncssh.ChannelOpenError: stdout.channel.exit(1) elif action == 'rejected_unix_direct': chan = self._conn.create_unix_channel() try: await chan.connect(SSHUNIXStreamSession, '') except asyncssh.ChannelOpenError: stdout.channel.exit(1) elif action == 'unknown_unix_listener': chan = self._conn.create_unix_channel() try: await chan.accept(SSHUNIXStreamSession, 'xxx') except asyncssh.ChannelOpenError: stdout.channel.exit(1) elif action == 'invalid_unix_listener': chan = self._conn.create_unix_channel() try: await chan.accept(SSHUNIXStreamSession, b'\xff') except asyncssh.ChannelOpenError: stdout.channel.exit(1) elif action == 'rejected_tun_request': chan = self._conn.create_tuntap_channel() try: await chan.open(SSHTunTapStreamSession, SSH_TUN_MODE_POINTTOPOINT, 0) except asyncssh.ChannelOpenError: stdout.channel.exit(1) elif action == 'rejected_tap_request': chan = self._conn.create_tuntap_channel() try: await chan.open(SSHTunTapStreamSession, SSH_TUN_MODE_ETHERNET, 0) except asyncssh.ChannelOpenError: stdout.channel.exit(1) elif action == 'late_auth_banner': try: 
self._conn.send_auth_banner('auth banner') except OSError: stdin.channel.exit(1) elif action == 'invalid_open_confirm': stdin.channel.send_packet(MSG_CHANNEL_OPEN_CONFIRMATION, UInt32(0), UInt32(0), UInt32(0)) elif action == 'invalid_open_failure': stdin.channel.send_packet(MSG_CHANNEL_OPEN_FAILURE, UInt32(0), String(''), String('')) elif action == 'env': value = stdin.channel.get_environment_bytes().get(b'TEST', b'') stdout.write(value.decode('utf-8', 'backslashreplace') + '\n') elif action == 'env_binary_key': value = stdin.channel.get_environment_bytes().get(b'TEST\xff', b'') stdout.write(value.decode('utf-8', 'backslashreplace') + '\n') elif action == 'env_str': value = stdin.channel.get_environment().get('TEST', '') stdout.write(value + '\n') elif action == 'env_str_cached': value1 = stdin.channel.get_environment().get('TEST', '') value2 = stdin.channel.get_environment().get('TEST', '') stdout.write(value1 + value2 + '\n') elif action == 'env_non_string_key': value = stdin.channel.get_environment().get('1', '') stdout.write(value + '\n') elif action == 'term': chan = stdin.channel info = str((chan.get_terminal_type(), chan.get_terminal_size(), chan.get_terminal_mode(asyncssh.PTY_OP_OSPEED))) stdout.write(info + '\n') elif action == 'xon_xoff': stdin.channel.set_xon_xoff(True) elif action == 'no_xon_xoff': stdin.channel.set_xon_xoff(False) elif action == 'signals': try: await stdin.readline() except asyncssh.BreakReceived as exc: stdin.channel.exit_with_signal('ABRT', False, str(exc.msec)) except asyncssh.SignalReceived as exc: stdin.channel.exit_with_signal('ABRT', False, exc.signal) except asyncssh.TerminalSizeChanged as exc: size = (exc.width, exc.height, exc.pixwidth, exc.pixheight) stdin.channel.exit_with_signal('ABRT', False, str(size)) elif action == 'exit_status': stdin.channel.exit(1) elif action == 'closed_status': stdin.channel.close() stdin.channel.exit(1) elif action == 'exit_signal': stdin.channel.exit_with_signal('INT', False, 'exit_signal') elif action == 'unknown_signal': stdin.channel.exit_with_signal('unknown', False, 'unknown_signal') elif action == 'closed_signal': stdin.channel.close() stdin.channel.exit_with_signal('INT', False, 'closed_signal') elif action == 'invalid_exit_signal': stdin.channel.exit_with_signal('invalid') elif action == 'invalid_exit_lang': stdin.channel.exit_with_signal('INT', False, '', 'invalid') elif action == 'window_after_close': stdin.channel.send_packet(MSG_CHANNEL_CLOSE) stdin.channel.send_packet(MSG_CHANNEL_WINDOW_ADJUST, UInt32(0)) elif action == 'empty_data': stdin.channel.send_packet(MSG_CHANNEL_DATA, String('')) elif action == 'partial_unicode': data = '\xff\xff'.encode() stdin.channel.send_packet(MSG_CHANNEL_DATA, String(data[:3])) stdin.channel.send_packet(MSG_CHANNEL_DATA, String(data[3:])) elif action == 'partial_unicode_at_eof': data = '\xff\xff'.encode() stdin.channel.send_packet(MSG_CHANNEL_DATA, String(data[:3])) elif action == 'unicode_error': stdin.channel.send_packet(MSG_CHANNEL_DATA, String(b'\xff')) elif action == 'data_past_window': stdin.channel.send_packet(MSG_CHANNEL_DATA, String(2*1025*1024*'\0')) elif action == 'ext_data_past_window': stdin.channel.send_packet(MSG_CHANNEL_EXTENDED_DATA, UInt32(asyncssh.EXTENDED_DATA_STDERR), String(2*1025*1024*'\0')) elif action == 'data_after_eof': stdin.channel.send_packet(MSG_CHANNEL_EOF) stdout.write('xxx') elif action == 'data_after_close': await asyncio.sleep(0.1) stdout.write('xxx') elif action == 'ext_data_after_eof': stdin.channel.send_packet(MSG_CHANNEL_EOF) 
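# Note: the stderr write that follows is deliberately sent after the EOF packet above, so the client-side test can exercise handling of extended data received after EOF.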
stdin.channel.write_stderr('xxx') elif action == 'invalid_datatype': stdin.channel.send_packet(MSG_CHANNEL_EXTENDED_DATA, UInt32(255), String('')) elif action == 'double_eof': await asyncio.sleep(0.1) stdin.channel.send_packet(MSG_CHANNEL_EOF) stdin.channel.write_eof() elif action == 'double_close': await asyncio.sleep(0.1) stdout.write('xxx') stdin.channel.send_packet(MSG_CHANNEL_CLOSE) elif action == 'request_after_close': stdin.channel.send_packet(MSG_CHANNEL_CLOSE) stdin.channel.exit(1) elif action == 'unexpected_auth': self._conn.send_packet(MSG_USERAUTH_REQUEST, String('guest'), String('ssh-connection'), String('none')) elif action == 'invalid_response': stdin.channel.send_packet(MSG_CHANNEL_SUCCESS) elif action == 'send_pktsize': stdout.write(str(stdout.channel.get_send_pktsize())) stdout.close() else: stdin.channel.exit(255) stdin.channel.close() await stdin.channel.wait_closed() async def _conn_close(self): """Close the connection during a channel open""" self._conn.close() await asyncio.sleep(0.1) return _EchoServerSession() def begin_auth(self, username): """Handle client authentication request""" return username not in {'guest', 'conn_close_startup', 'conn_close_open', 'close', 'echo', 'no_channels', 'no_pty', 'request_pty', 'task_error'} def session_requested(self): """Handle a request to create a new session""" username = self._conn.get_extra_info('username') with patch('asyncssh.connection.SSHServerChannel', _ServerChannel): channel = self._conn.create_server_channel() if username == 'conn_close_startup': self._conn.close() return False elif username == 'conn_close_open': return (channel, self._conn_close()) elif username in {'close', 'echo', 'task_error'}: return (channel, _EchoServerSession()) elif username in {'request_pty', 'no_pty'}: return (channel, _PTYServerSession()) elif username != 'no_channels': return (channel, self._begin_session) else: return False class _TestChannel(ServerTestCase): """Unit tests for AsyncSSH channel API""" # pylint: disable=too-many-public-methods @classmethod async def start_server(cls): """Start an SSH server for the tests to use""" return (await cls.create_server( _ChannelServer, authorized_client_keys='authorized_keys')) async def _check_action(self, command, expected_result): """Run a command on a remote session and check for a specific result""" async with self.connect() as conn: chan, session = await _create_session(conn, command) await chan.wait_closed() self.assertEqual(session.exit_status, expected_result) async def _check_session(self, conn, command=(), *, large_block=False, **kwargs): """Open a session and test if an input line is echoed back""" chan, session = await _create_session(conn, command, **kwargs) if large_block: data = 4 * [1025*1024*'\0'] else: data = [str(id(self))] chan.writelines(data) self.assertTrue(chan.can_write_eof()) self.assertFalse(chan.is_closing()) chan.write_eof() self.assertTrue(chan.is_closing()) await chan.wait_closed() data = ''.join(data) for buf in session.recv_buf.values(): self.assertEqual(data, ''.join(buf)) chan.close() @asynctest async def test_shell(self): """Test starting a shell""" async with self.connect(username='echo') as conn: await self._check_session(conn) @asynctest async def test_shell_failure(self): """Test failure to start a shell""" async with self.connect(username='no_channels') as conn: with self.assertRaises(asyncssh.ChannelOpenError): await _create_session(conn) @asynctest async def test_shell_internal_error(self): """Test internal error in callback to start a shell""" async 
with self.connect(username='task_error') as conn: with self.assertRaises((OSError, asyncssh.ConnectionLost)): await _create_session(conn) @asynctest async def test_shell_large_block(self): """Test starting a shell and sending a large block of data""" async with self.connect(username='echo') as conn: await self._check_session(conn, large_block=True) @asynctest async def test_exec(self): """Test execution of a remote command""" async with self.connect() as conn: await self._check_session(conn, 'echo', window=1024*1024, max_pktsize=16384) @asynctest async def test_exec_from_connect(self): """Test execution of a remote command set on connection""" async with self.connect(command='echo') as conn: await self._check_session(conn) @asynctest async def test_forced_exec(self): """Test execution of a forced remote command""" ckey = asyncssh.read_private_key('ckey') cert = make_certificate('ssh-rsa-cert-v01@openssh.com', CERT_TYPE_USER, ckey, ckey, ['ckey'], options={'force-command': String('echo')}) async with self.connect(username='ckey', client_keys=[(ckey, cert)], agent_path=None) as conn: await self._check_session(conn) @asynctest async def test_invalid_exec(self): """Test execution of an invalid remote command""" async with self.connect() as conn: with self.assertRaises(asyncssh.ChannelOpenError): await _create_session(conn, b'\xff') @asynctest async def test_exec_failure(self): """Test failure to execute a remote command""" async with self.connect(username='no_channels') as conn: with self.assertRaises(asyncssh.ChannelOpenError): await _create_session(conn, 'echo') @asynctest async def test_subsystem(self): """Test starting a subsystem""" async with self.connect() as conn: await self._check_session(conn, subsystem='echo') @asynctest async def test_invalid_subsystem(self): """Test starting an invalid subsystem""" async with self.connect() as conn: with self.assertRaises(asyncssh.ChannelOpenError): await _create_session(conn, subsystem=b'\xff') @asynctest async def test_subsystem_failure(self): """Test failure to start a subsystem""" async with self.connect(username='no_channels') as conn: with self.assertRaises(asyncssh.ChannelOpenError): await _create_session(conn, subsystem='echo') @asynctest async def test_conn_close_during_startup(self): """Test connection close during channel startup""" async with self.connect(username='conn_close_startup') as conn: with self.assertRaises(asyncssh.ChannelOpenError): await _create_session(conn) @asynctest async def test_conn_close_during_open(self): """Test connection close during channel open""" async with self.connect(username='conn_close_open') as conn: with self.assertRaises(asyncssh.ChannelOpenError): await _create_session(conn) @asynctest async def test_close_during_startup(self): """Test channel close during startup""" async with self.connect(username='close') as conn: with self.assertRaises(asyncssh.ChannelOpenError): await _create_session(conn) @asynctest async def test_inbound_conn_close_while_read_paused(self): """Test inbound connection close while reading is paused""" async with self.connect() as conn: chan, _ = await _create_session(conn, 'conn_close') chan.pause_reading() chan.write('\n') await asyncio.sleep(0.1) conn.close() await chan.wait_closed() @asynctest async def test_outbound_conn_close_while_read_paused(self): """Test outbound connection close while reading is paused""" async with self.connect() as conn: chan, _ = await _create_session(conn, 'close') chan.pause_reading() chan.write('\n') await asyncio.sleep(0.1) conn.close() await 
chan.wait_closed() @asynctest async def test_close_while_read_paused(self): """Test closing a remotely closed channel while reading is paused""" async with self.connect() as conn: chan, _ = await _create_session(conn, 'close') chan.pause_reading() chan.write('\n') await asyncio.sleep(0.1) chan.close() await chan.wait_closed() @asynctest async def test_keepalive(self): """Test keepalive channel requests""" with patch('asyncssh.connection.SSHClientChannel', _ClientChannel): async with self.connect() as conn: chan, _ = await _create_session(conn) result = await chan.make_request(b'keepalive@openssh.com') self.assertFalse(result) @asynctest async def test_invalid_open_confirmation(self): """Test receiving an open confirmation on already open channel""" async with self.connect() as conn: chan, _ = await _create_session(conn, 'invalid_open_confirm') await chan.wait_closed() @asynctest async def test_invalid_open_failure(self): """Test receiving an open failure on already open channel""" async with self.connect() as conn: chan, _ = await _create_session(conn, 'invalid_open_failure') await chan.wait_closed() @asynctest async def test_unknown_channel_request(self): """Test sending unknown channel request""" with patch('asyncssh.connection.SSHClientChannel', _ClientChannel): async with self.connect() as conn: chan, _ = await _create_session(conn) self.assertFalse(await chan.make_request('unknown')) @asynctest async def test_invalid_channel_request(self): """Test sending non-ASCII channel request""" with patch('asyncssh.connection.SSHClientChannel', _ClientChannel): async with self.connect() as conn: chan, _ = await _create_session(conn) with self.assertRaises(asyncssh.ProtocolError): await chan.make_request('\xff') @asynctest async def test_delayed_channel_request(self): """Test queuing channel requests with delayed response""" with patch('asyncssh.connection.SSHClientChannel', _ClientChannel): async with self.connect() as conn: chan, _ = await _create_session(conn) chan.send_request(b'delayed') chan.send_request(b'delayed') @asynctest async def test_invalid_channel_response(self): """Test receiving response for non-existent channel request""" async with self.connect() as conn: chan, _ = await _create_session(conn, 'invalid_response') chan.close() @asynctest async def test_already_open(self): """Test connect on an already open channel""" async with self.connect() as conn: chan, _ = await _create_session(conn) with self.assertRaises(OSError): await chan.create(None, None, None, {}, False, None, None, None, False, None, None, False, False) chan.close() @asynctest async def test_write_buffer(self): """Test setting write buffer limits""" async with self.connect() as conn: chan, _ = await _create_session(conn) chan.set_write_buffer_limits() chan.set_write_buffer_limits(low=8192) chan.set_write_buffer_limits(high=32768) chan.set_write_buffer_limits(32768, 8192) with self.assertRaises(ValueError): chan.set_write_buffer_limits(8192, 32768) self.assertEqual(chan.get_write_buffer_size(), 0) chan.close() @asynctest async def test_empty_write(self): """Test writing an empty block of data""" async with self.connect() as conn: chan, _ = await _create_session(conn) chan.write('') chan.close() @asynctest async def test_invalid_write_extended(self): """Test writing using an invalid extended data type""" async with self.connect() as conn: chan, _ = await _create_session(conn) with self.assertRaises(OSError): chan.write('test', -1) @asynctest async def test_unneeded_resume_reading(self): """Test resume reading when 
not paused""" async with self.connect() as conn: chan, _ = await _create_session(conn) await asyncio.sleep(0.1) chan.resume_reading() chan.close() @asynctest async def test_agent_forwarding_explicit(self): """Test SSH agent forwarding with explicit path""" if not self.agent_available(): # pragma: no cover self.skipTest('ssh-agent not available') async with self.connect(username='ckey', agent_forwarding='agent') as conn: chan, session = await _create_session(conn, 'agent') await chan.wait_closed() result = ''.join(session.recv_buf[None]) self.assertEqual(result, '3\n') chan, session = await _create_session(conn, 'agent') await chan.wait_closed() result = ''.join(session.recv_buf[None]) self.assertEqual(result, '3\n') @asynctest async def test_agent_forwarding_sock(self): """Test SSH agent forwarding via UNIX domain socket""" if not self.agent_available(): # pragma: no cover self.skipTest('ssh-agent not available') async with self.connect(username='ckey', agent_forwarding=True) as conn: chan, session = await _create_session(conn, 'agent_sock') await chan.wait_closed() result = ''.join(session.recv_buf[None]) self.assertEqual(result, '3\n') @asynctest async def test_rejected_session(self): """Test receiving inbound session request""" await self._check_action('rejected_session', 1) @asynctest async def test_rejected_tcpip_direct(self): """Test receiving inbound direct TCP/IP connection""" await self._check_action('rejected_tcpip_direct', 1) @asynctest async def test_unknown_tcpip_listener(self): """Test receiving connection on unknown TCP/IP listener""" await self._check_action('unknown_tcpip_listener', 1) @asynctest async def test_invalid_tcpip_listener(self): """Test receiving connection on invalid TCP/IP listener path""" await self._check_action('invalid_tcpip_listener', None) @asynctest async def test_rejected_unix_direct(self): """Test receiving inbound direct UNIX connection""" await self._check_action('rejected_unix_direct', 1) @asynctest async def test_unknown_unix_listener(self): """Test receiving connection on unknown UNIX listener""" await self._check_action('unknown_unix_listener', 1) @asynctest async def test_invalid_unix_listener(self): """Test receiving connection on invalid UNIX listener path""" await self._check_action('invalid_unix_listener', None) @asynctest async def test_rejected_tun_request(self): """Test receiving inbound TUN request""" await self._check_action('rejected_tun_request', 1) @asynctest async def test_rejected_tap_request(self): """Test receiving inbound TAP request""" await self._check_action('rejected_tap_request', 1) @asynctest async def test_agent_forwarding_failure(self): """Test failure of SSH agent forwarding""" ckey = asyncssh.read_private_key('ckey') cert = make_certificate('ssh-rsa-cert-v01@openssh.com', CERT_TYPE_USER, ckey, ckey, ['ckey'], extensions={'no-agent-forwarding': ''}) async with self.connect(username='ckey', client_keys=[(ckey, cert)], agent_path=None, agent_forwarding=True) as conn: chan, session = await _create_session(conn, 'agent') await chan.wait_closed() self.assertEqual(session.exit_status, 1) @asynctest async def test_agent_forwarding_sock_failure(self): """Test failure to create SSH agent forwarding socket""" old_tempdir = tempfile.tempdir try: tempfile.tempdir = 'xxx' async with self.connect(username='ckey', agent_forwarding=True) as conn: chan, session = await _create_session(conn, 'agent_sock') await chan.wait_closed() self.assertEqual(session.exit_status, 1) finally: tempfile.tempdir = old_tempdir @asynctest async def 
test_agent_forwarding_not_offered(self): """Test SSH agent forwarding not offered by client""" async with self.connect() as conn: chan, session = await _create_session(conn, 'agent') await chan.wait_closed() self.assertEqual(session.exit_status, 1) @asynctest async def test_agent_forwarding_rejected(self): """Test rejection of SSH agent forwarding by client""" async with self.connect() as conn: chan, session = await _create_session(conn, 'rejected_agent') await chan.wait_closed() result = ''.join(session.recv_buf[None]) self.assertEqual(result, 'False\n') self.assertEqual(session.exit_status, 1) @asynctest async def test_request_pty(self): """Test requesting a PTY with terminal information""" modes = {asyncssh.PTY_OP_OSPEED: 9600} async with self.connect(username='request_pty') as conn: chan, session = await _create_session(conn, term_type='ansi', term_size=(80, 24), term_modes=modes) await chan.wait_closed() result = ''.join(session.recv_buf[None]) self.assertEqual(result, "Req: ('ansi', (80, 24, 0, 0), 9600)\r\n") @asynctest async def test_terminal_full_size(self): """Test sending terminal information with full size""" modes = {asyncssh.PTY_OP_OSPEED: 9600} async with self.connect() as conn: chan, session = await _create_session(conn, 'term', term_type='ansi', term_size=(80, 24, 480, 240), term_modes=modes) await chan.wait_closed() result = ''.join(session.recv_buf[None]) self.assertEqual(result, "('ansi', (80, 24, 480, 240), 9600)\r\n") @asynctest async def test_pty_without_term_type(self): """Test requesting a PTY without setting the terminal type""" async with self.connect() as conn: chan, session = await _create_session(conn, 'term', request_pty='force') await chan.wait_closed() result = ''.join(session.recv_buf[None]) self.assertEqual(result, "('', (0, 0, 0, 0), None)\n") @asynctest async def test_invalid_terminal_size(self): """Test sending invalid terminal size""" async with self.connect() as conn: with self.assertRaises(ValueError): await _create_session(conn, 'term', term_type='ansi', term_size=(0, 0, 0)) @asynctest async def test_invalid_terminal_modes(self): """Test sending invalid terminal modes""" modes = {asyncssh.PTY_OP_RESERVED: 0} async with self.connect() as conn: with self.assertRaises(ValueError): await _create_session(conn, 'term', term_type='ansi', term_modes=modes) @asynctest async def test_pty_disallowed_by_cert(self): """Test rejection of pty request by certificate""" ckey = asyncssh.read_private_key('ckey') cert = make_certificate('ssh-rsa-cert-v01@openssh.com', CERT_TYPE_USER, ckey, ckey, ['ckey'], extensions={'no-pty': ''}) async with self.connect(username='ckey', client_keys=[(ckey, cert)], agent_path=None) as conn: with self.assertRaises(asyncssh.ChannelOpenError): await _create_session(conn, 'term', term_type='ansi') @asynctest async def test_pty_disallowed_by_session(self): """Test rejection of pty request by session""" async with self.connect(username='no_pty') as conn: with self.assertRaises(asyncssh.ChannelOpenError): await _create_session(conn, term_type='ansi') @asynctest async def test_invalid_term_type(self): """Test requesting an invalid terminal type""" with patch('asyncssh.connection.SSHClientChannel', _ClientChannel): async with self.connect() as conn: with self.assertRaises(asyncssh.ProtocolError): await _create_session(conn, term_type=b'\xff') @asynctest async def test_term_modes_missing_end(self): """Test sending terminal modes without PTY_OP_END""" modes = {asyncssh.PTY_OP_OSPEED: 9600, PTY_OP_NO_END: 0} with
patch('asyncssh.connection.SSHClientChannel', _ClientChannel): async with self.connect() as conn: chan, session = await _create_session(conn, 'term', term_type='ansi', term_modes=modes) await chan.wait_closed() result = ''.join(session.recv_buf[None]) self.assertEqual(result, "('ansi', (0, 0, 0, 0), 9600)\r\n") @asynctest async def test_term_modes_incomplete(self): """Test sending terminal modes with incomplete value""" modes = {asyncssh.PTY_OP_OSPEED: 9600, PTY_OP_PARTIAL: 0} with patch('asyncssh.connection.SSHClientChannel', _ClientChannel): async with self.connect() as conn: with self.assertRaises(asyncssh.ProtocolError): await _create_session(conn, 'term', term_type='ansi', term_modes=modes) @asynctest async def test_env(self): """Test setting environment with byte strings""" async with self.connect() as conn: chan, session = await _create_session(conn, 'env', env={b'TEST': b'test'}) await chan.wait_closed() result = ''.join(session.recv_buf[None]) self.assertEqual(result, 'test\n') @asynctest async def test_env_str(self): """Test setting environment using Unicode strings""" async with self.connect() as conn: chan, session = await _create_session(conn, 'env_str', env={'TEST': 'test'}) await chan.wait_closed() result = ''.join(session.recv_buf[None]) self.assertEqual(result, 'test\n') @asynctest async def test_env_str_cached(self): """Test caching of Unicode string environment dict on server""" async with self.connect() as conn: chan, session = await _create_session(conn, 'env_str_cached', env={'TEST': 'test'}) await chan.wait_closed() result = ''.join(session.recv_buf[None]) self.assertEqual(result, 'testtest\n') @asynctest async def test_env_invalid_str(self): """Test trying to access binary environment value as a Unicode string""" async with self.connect() as conn: chan, session = await _create_session(conn, 'env_str', env={'TEST': b'test\xff'}) await chan.wait_closed() result = ''.join(session.recv_buf[None]) self.assertEqual(result, '\n') @asynctest async def test_env_binary_key(self): """Test setting environment with binary data in key""" async with self.connect() as conn: chan, session = await _create_session(conn, 'env_binary_key', env={b'TEST\xff': 'test'}) await chan.wait_closed() result = ''.join(session.recv_buf[None]) self.assertEqual(result, 'test\n') @asynctest async def test_env_binary_value(self): """Test setting environment with binary data in value""" async with self.connect() as conn: chan, session = await _create_session(conn, 'env', env={'TEST': b'test\xff'}) await chan.wait_closed() result = ''.join(session.recv_buf[None]) self.assertEqual(result, 'test\\xff\n') @asynctest async def test_env_non_string_key(self): """Test setting environment with non-string as a key""" async with self.connect() as conn: chan, session = await _create_session(conn, 'env_non_string_key', env={1: 'test'}) await chan.wait_closed() result = ''.join(session.recv_buf[None]) self.assertEqual(result, 'test\n') @asynctest async def test_env_non_string_value(self): """Test setting environment with non-string as a value""" async with self.connect() as conn: chan, session = await _create_session(conn, 'env', env={'TEST': 1}) await chan.wait_closed() result = ''.join(session.recv_buf[None]) self.assertEqual(result, '1\n') @asynctest async def test_invalid_env(self): """Test sending invalid environment""" async with self.connect() as conn: with self.assertRaises(ValueError): await _create_session(conn, 'env', env=1) @asynctest async def test_env_from_connect(self): """Test setting environment on
connection""" async with self.connect(env={'TEST': 'test'}) as conn: chan, session = await _create_session(conn, 'env') await chan.wait_closed() result = ''.join(session.recv_buf[None]) self.assertEqual(result, 'test\n') @asynctest async def test_env_list(self): """Test setting environment using a list of name=value strings""" async with self.connect() as conn: chan, session = await _create_session(conn, 'env', env=['TEST=test']) await chan.wait_closed() result = ''.join(session.recv_buf[None]) self.assertEqual(result, 'test\n') @asynctest async def test_env_list_binary(self): """Test setting environment using a list of name=value byte strings""" async with self.connect() as conn: chan, session = await _create_session(conn, 'env', env=[b'TEST=test\xff']) await chan.wait_closed() result = ''.join(session.recv_buf[None]) self.assertEqual(result, 'test\\xff\n') @asynctest async def test_env_tuple(self): """Test setting environment using a tuple of name=value strings""" async with self.connect() as conn: chan, session = await _create_session(conn, 'env', env=('TEST=test',)) await chan.wait_closed() result = ''.join(session.recv_buf[None]) self.assertEqual(result, 'test\n') @asynctest async def test_invalid_env_list(self): """Test setting environment using an invalid string""" with self.assertRaises(ValueError): async with self.connect() as conn: await _create_session(conn, 'env', env=['XXX']) @asynctest async def test_send_env(self): """Test sending local environment""" async with self.connect() as conn: try: os.environ['TEST'] = 'test' chan, session = await _create_session(conn, 'env', send_env=['TEST']) finally: del os.environ['TEST'] await chan.wait_closed() result = ''.join(session.recv_buf[None]) self.assertEqual(result, 'test\n') @unittest.skipUnless(os.supports_bytes_environ, 'skip binary send env if not supported by OS') @asynctest async def test_send_env_binary(self): """Test sending local environment using a byte string""" async with self.connect() as conn: try: os.environb[b'TEST'] = b'test\xff' chan, session = await _create_session(conn, 'env', send_env=[b'TEST']) finally: del os.environb[b'TEST'] await chan.wait_closed() result = ''.join(session.recv_buf[None]) self.assertEqual(result, 'test\\xff\n') @asynctest async def test_send_env_from_connect(self): """Test sending local environment on connection""" try: os.environ['TEST'] = 'test' async with self.connect(send_env=['TEST']) as conn: chan, session = await _create_session(conn, 'env') await chan.wait_closed() result = ''.join(session.recv_buf[None]) self.assertEqual(result, 'test\n') finally: del os.environ['TEST'] @asynctest async def test_mixed_env(self): """Test sending a mix of local environment and new values""" async with self.connect() as conn: try: os.environ['TEST'] = '1' chan, session = await _create_session(conn, 'env', env={'TEST': 2}, send_env='TEST') finally: del os.environ['TEST'] await chan.wait_closed() result = ''.join(session.recv_buf[None]) self.assertEqual(result, '2\n') @asynctest async def test_xon_xoff_enable(self): """Test enabling XON/XOFF flow control""" async with self.connect() as conn: chan, session = await _create_session(conn, 'xon_xoff') await chan.wait_closed() self.assertEqual(session.xon_xoff, True) @asynctest async def test_xon_xoff_disable(self): """Test disabling XON/XOFF flow control""" async with self.connect() as conn: chan, session = await _create_session(conn, 'no_xon_xoff') await chan.wait_closed() self.assertEqual(session.xon_xoff, False) @asynctest async def test_break(self): 
"""Test sending a break""" async with self.connect() as conn: chan, session = await _create_session(conn, 'signals') chan.send_break(1000) await chan.wait_closed() self.assertEqual(session.exit_signal_msg, '1000') @asynctest async def test_signal(self): """Test sending a signal""" async with self.connect() as conn: chan, session = await _create_session(conn, 'signals') chan.send_signal('INT') await chan.wait_closed() self.assertEqual(session.exit_signal_msg, 'INT') @asynctest async def test_numeric_signal(self): """Test sending a signal using a numeric value""" async with self.connect() as conn: chan, session = await _create_session(conn, 'signals') chan.send_signal(SIGINT) await chan.wait_closed() self.assertEqual(session.exit_signal_msg, 'INT') @asynctest async def test_unknown_signal(self): """Test sending a signal with an unknown numeric value""" async with self.connect() as conn: chan, _ = await _create_session(conn, 'signals') with self.assertRaises(ValueError): chan.send_signal(123) chan.close() @asynctest async def test_terminate(self): """Test sending a terminate signal""" async with self.connect() as conn: chan, session = await _create_session(conn, 'signals') chan.terminate() await chan.wait_closed() self.assertEqual(session.exit_signal_msg, 'TERM') @asynctest async def test_kill(self): """Test sending a kill signal""" async with self.connect() as conn: chan, session = await _create_session(conn, 'signals') chan.kill() await chan.wait_closed() self.assertEqual(session.exit_signal_msg, 'KILL') @asynctest async def test_invalid_signal(self): """Test sending an invalid signal""" with patch('asyncssh.connection.SSHClientChannel', _ClientChannel): async with self.connect() as conn: chan, session = await _create_session(conn, 'signals') chan.send_signal(b'\xff') chan.write('\n') await chan.wait_closed() self.assertEqual(session.exit_status, None) @asynctest async def test_terminal_size_change(self): """Test sending terminal size change""" async with self.connect() as conn: chan, session = await _create_session(conn, 'signals', term_type='ansi') chan.change_terminal_size(80, 24) await chan.wait_closed() self.assertEqual(session.exit_signal_msg, '(80, 24, 0, 0)') @asynctest async def test_full_terminal_size_change(self): """Test sending full terminal size change""" async with self.connect() as conn: chan, session = await _create_session(conn, 'signals', term_type='ansi') chan.change_terminal_size(80, 24, 480, 240) await chan.wait_closed() self.assertEqual(session.exit_signal_msg, '(80, 24, 480, 240)') @asynctest async def test_exit_status(self): """Test receiving exit status""" async with self.connect() as conn: chan, session = await _create_session(conn, 'exit_status') await chan.wait_closed() self.assertEqual(session.exit_status, 1) self.assertEqual(chan.get_exit_status(), 1) self.assertIsNone(chan.get_exit_signal()) self.assertEqual(chan.get_returncode(), 1) @asynctest async def test_exit_status_after_close(self): """Test delivery of exit status after remote close""" async with self.connect() as conn: chan, session = await _create_session(conn, 'closed_status') await chan.wait_closed() self.assertIsNone(session.exit_status) self.assertIsNone(chan.get_exit_status()) self.assertIsNone(chan.get_exit_signal()) self.assertIsNone(chan.get_returncode()) @asynctest async def test_exit_signal(self): """Test receiving exit signal""" async with self.connect() as conn: chan, session = await _create_session(conn, 'exit_signal') await chan.wait_closed() self.assertEqual(session.exit_signal_msg, 
'exit_signal') self.assertEqual(chan.get_exit_status(), -1) self.assertEqual(chan.get_exit_signal(), ('INT', False, 'exit_signal', DEFAULT_LANG)) self.assertEqual(chan.get_returncode(), -SIGINT) @asynctest async def test_exit_signal_after_close(self): """Test delivery of exit signal after remote close""" async with self.connect() as conn: chan, session = await _create_session(conn, 'closed_signal') await chan.wait_closed() self.assertIsNone(session.exit_signal_msg) self.assertIsNone(chan.get_exit_status()) self.assertIsNone(chan.get_exit_signal()) self.assertIsNone(chan.get_returncode()) @asynctest async def test_unknown_exit_signal(self): """Test receiving unknown exit signal""" async with self.connect() as conn: chan, session = await _create_session(conn, 'unknown_signal') await chan.wait_closed() self.assertEqual(session.exit_signal_msg, 'unknown_signal') self.assertEqual(chan.get_exit_status(), -1) self.assertEqual(chan.get_exit_signal(), ('unknown', False, 'unknown_signal', DEFAULT_LANG)) self.assertEqual(chan.get_returncode(), -99) @asynctest async def test_invalid_exit_signal(self): """Test delivery of invalid exit signal""" async with self.connect() as conn: chan, _ = await _create_session(conn, 'invalid_exit_signal') await chan.wait_closed() @asynctest async def test_invalid_exit_lang(self): """Test delivery of invalid exit signal language""" async with self.connect() as conn: chan, _ = await _create_session(conn, 'invalid_exit_lang') await chan.wait_closed() @asynctest async def test_window_adjust_after_eof(self): """Test receiving window adjust after EOF""" async with self.connect() as conn: chan, _ = await _create_session(conn, 'window_after_close') await chan.wait_closed() @asynctest async def test_empty_data(self): """Test receiving empty data packet""" async with self.connect() as conn: chan, _ = await _create_session(conn, 'empty_data') chan.close() @asynctest async def test_partial_unicode(self): """Test receiving Unicode data spread across two packets""" async with self.connect() as conn: chan, session = await _create_session(conn, 'partial_unicode') await chan.wait_closed() result = ''.join(session.recv_buf[None]) self.assertEqual(result, '\xff\xff') @asynctest async def test_partial_unicode_at_eof(self): """Test receiving partial Unicode data and then EOF""" async with self.connect() as conn: chan, session = await _create_session( conn, 'partial_unicode_at_eof') await chan.wait_closed() self.assertIsInstance(session.exc, asyncssh.ProtocolError) @asynctest async def test_unicode_error(self): """Test receiving bad Unicode data""" async with self.connect() as conn: chan, session = await _create_session(conn, 'unicode_error') await chan.wait_closed() self.assertIsInstance(session.exc, asyncssh.ProtocolError) @asynctest async def test_data_past_window(self): """Test receiving a data packet past the advertised window""" async with self.connect() as conn: chan, _ = await _create_session(conn, 'data_past_window') await chan.wait_closed() @asynctest async def test_ext_data_past_window(self): """Test receiving an extended data packet past the advertised window""" async with self.connect() as conn: chan, _ = await _create_session(conn, 'ext_data_past_window') await chan.wait_closed() @asynctest async def test_data_after_eof(self): """Test receiving data after EOF""" async with self.connect() as conn: chan, _ = await _create_session(conn, 'data_after_eof') await chan.wait_closed() @asynctest async def test_data_after_close(self): """Test receiving data after close""" async with 
self.connect() as conn: chan, _ = await _create_session(conn, 'data_after_close') chan.write(4*1025*1024*'\0') chan.close() await asyncio.sleep(0.2) await chan.wait_closed() @asynctest async def test_extended_data_after_eof(self): """Test receiving extended data after EOF""" async with self.connect() as conn: chan, _ = await _create_session(conn, 'ext_data_after_eof') await chan.wait_closed() @asynctest async def test_invalid_datatype(self): """Test receiving data with invalid data type""" async with self.connect() as conn: chan, _ = await _create_session(conn, 'invalid_datatype') await chan.wait_closed() @asynctest async def test_double_eof(self): """Test receiving two EOF messages""" async with self.connect() as conn: chan, _ = await _create_session(conn, 'double_eof') await chan.wait_closed() @asynctest async def test_double_close(self): """Test receiving two close messages""" async with self.connect() as conn: chan, _ = await _create_session(conn, 'double_close') chan.pause_reading() await asyncio.sleep(0.2) chan.resume_reading() await chan.wait_closed() @asynctest async def test_request_after_close(self): """Test receiving a channel request after a close""" async with self.connect() as conn: chan, _ = await _create_session(conn, 'request_after_close') await chan.wait_closed() @asynctest async def test_late_auth_banner(self): """Test server sending authentication banner after auth completes""" async with self.connect() as conn: chan, session = await _create_session(conn, 'late_auth_banner') await chan.wait_closed() self.assertEqual(session.exit_status, 1) @asynctest async def test_unexpected_userauth_request(self): """Test userauth request sent to client""" async with self.connect() as conn: chan, _ = await _create_session(conn, 'unexpected_auth') await chan.wait_closed() @asynctest async def test_unknown_action(self): """Test unknown action""" async with self.connect() as conn: chan, session = await _create_session(conn, 'unknown') await chan.wait_closed() self.assertEqual(session.exit_status, 255) class _TestChannelNoPTY(ServerTestCase): """Unit tests for AsyncSSH channel module with PTYs disallowed""" @classmethod async def start_server(cls): """Start an SSH server for the tests to use""" return (await cls.create_server( _ChannelServer, authorized_client_keys='authorized_keys', allow_pty=False)) @asynctest async def test_shell_pty(self): """Test starting a shell that request a PTY""" async with self.connect() as conn: with self.assertRaises(asyncssh.ChannelOpenError): await conn.run(term_type='ansi') @asynctest async def test_shell_no_pty(self): """Test starting a shell that doesn't request a PTY""" async with self.connect() as conn: await conn.run(request_pty=False, stdin=asyncssh.DEVNULL) @asynctest async def test_exec_pty(self): """Test execution of a remote command that requests a PTY""" async with self.connect() as conn: with self.assertRaises(asyncssh.ChannelOpenError): await conn.run('echo', request_pty='force') @asynctest async def test_exec_pty_from_connect(self): """Test execution of a command that requests a PTY on the connection""" async with self.connect(request_pty='force') as conn: with self.assertRaises(asyncssh.ChannelOpenError): await conn.run('echo') @asynctest async def test_exec_no_pty(self): """Test execution of a remote command that doesn't request a PTY""" async with self.connect() as conn: await conn.run('echo', term_type='ansi', request_pty='auto', stdin=asyncssh.DEVNULL) class _TestChannelNoAgentForwarding(ServerTestCase): """Unit tests for channel module 
with agent forwarding disallowed""" @classmethod async def start_server(cls): """Start an SSH server with agent forwarding disabled""" return (await cls.create_server( _ChannelServer, authorized_client_keys='authorized_keys', agent_forwarding=False)) @asynctest async def test_agent_forwarding_disallowed(self): """Test running a command when agent forwarding is disallowed by the server""" async with self.connect(agent_forwarding=True) as conn: result = await conn.run('agent') self.assertEqual(result.exit_status, 1) class _TestConnectionDropbearClient(ServerTestCase): """Unit tests for testing Dropbear client compatibility fix""" @classmethod async def start_server(cls): """Start an SSH server to connect to""" return await cls.create_server(_ChannelServer) @asynctest async def test_dropbear_client(self): """Test reduced dropbear send packet size""" with patch('asyncssh.connection.SSHServerChannel', _ServerChannel): async with self.connect( client_version='dropbear', max_pktsize=32759, compression_algs=['zlib@openssh.com']) as conn: _, stdout, _ = await conn.open_session('send_pktsize') self.assertEqual((await stdout.read()), '32758') async with self.connect(client_version='dropbear', max_pktsize=32759, compression_algs=None) as conn: _, stdout, _ = await conn.open_session('send_pktsize') self.assertEqual((await stdout.read()), '32759') class _TestConnectionDropbearServer(ServerTestCase): """Unit tests for testing Dropbear server compatibility fix""" @classmethod async def start_server(cls): """Start an SSH server to connect to""" return await cls.create_server( _ChannelServer, server_version='dropbear', max_pktsize=32759) @asynctest async def test_dropbear_server(self): """Test reduced dropbear send packet size""" with patch('asyncssh.connection.SSHClientChannel', _ClientChannel): async with self.connect( compression_algs='zlib@openssh.com') as conn: stdin, _, _ = await conn.open_session() self.assertEqual(stdin.channel.get_send_pktsize(), 32758) async with self.connect(compression_algs=None) as conn: stdin, _, _ = await conn.open_session() self.assertEqual(stdin.channel.get_send_pktsize(), 32759) asyncssh-2.20.0/tests/test_compression.py000066400000000000000000000030041475467777400205730ustar00rootroot00000000000000# Copyright (c) 2015-2018 by Ron Frederick and others.
# # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation """Unit tests for compression""" import os import unittest from asyncssh.compression import get_compression_algs, get_compression_params from asyncssh.compression import get_compressor, get_decompressor class TestCompression(unittest.TestCase): """Unit tests for compression module""" def test_compression_algs(self): """Unit test compression algorithms""" for alg in get_compression_algs(): with self.subTest(alg=alg): get_compression_params(alg) data = os.urandom(256) compressor = get_compressor(alg) decompressor = get_decompressor(alg) if compressor: cmpdata = compressor.compress(data) self.assertEqual(decompressor.decompress(cmpdata), data) asyncssh-2.20.0/tests/test_config.py000066400000000000000000000550031475467777400175050ustar00rootroot00000000000000# Copyright (c) 2020-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation """Unit tests for parsing OpenSSH-compatible config file""" import os import socket import unittest from pathlib import Path from unittest.mock import patch import asyncssh from asyncssh.config import SSHClientConfig, SSHServerConfig from .util import TempDirTestCase class _TestConfig(TempDirTestCase): """Unit tests for config module""" @classmethod def setUpClass(cls): """Set up $HOME and .ssh directory""" super().setUpClass() os.mkdir('.ssh', 0o700) os.environ['HOME'] = '.' os.environ['USERPROFILE'] = '.' 
def _load_config(self, config, last_config=None, reload=False): """Abstract method to load a config object""" raise NotImplementedError def _parse_config(self, config_data, **kwargs): """Return a config object based on the specified data""" with open('config', 'w') as f: f.write(config_data) return self._load_config('config', **kwargs) def test_blank_and_comment(self): """Test blank and comment lines""" config = self._parse_config('\n#Port 22') self.assertIsNone(config.get('Port')) def test_set_bool(self): """Test boolean config option""" for value, result in (('yes', True), ('true', True), ('no', False), ('false', False)): config = self._parse_config(f'Compression {value}') self.assertEqual(config.get('Compression'), result) config = self._parse_config('Compression yes\nCompression no') self.assertEqual(config.get('Compression'), True) def test_set_int(self): """Test integer config option""" config = self._parse_config('Port 1') self.assertEqual(config.get('Port'), 1) config = self._parse_config('Port 1\nPort 2') self.assertEqual(config.get('Port'), 1) def test_set_string(self): """Test string config option""" config = self._parse_config('BindAddress addr') self.assertEqual(config.get('BindAddress'), 'addr') config = self._parse_config('BindAddress addr1\nBindAddress addr2') self.assertEqual(config.get('BindAddress'), 'addr1') def test_set_address_family(self): """Test address family config option""" for family, result in (('any', socket.AF_UNSPEC), ('inet', socket.AF_INET), ('inet6', socket.AF_INET6)): config = self._parse_config(f'AddressFamily {family}') self.assertEqual(config.get('AddressFamily'), result) config = self._parse_config('AddressFamily inet\n' 'AddressFamily inet6') self.assertEqual(config.get('AddressFamily'), socket.AF_INET) def test_set_canonicaize_host(self): """Test canonicalize host config option""" for value, result in (('yes', True), ('true', True), ('no', False), ('false', False), ('always', 'always')): config = self._parse_config(f'CanonicalizeHostname {value}') self.assertEqual(config.get('CanonicalizeHostname'), result) config = self._parse_config('CanonicalizeHostname yes\n' 'CanonicalizeHostname no') self.assertEqual(config.get('CanonicalizeHostname'), True) def test_set_rekey_limit(self): """Test rekey limit config option""" for value, result in (('1', ('1', ())), ('1 2', ('1', '2')), ('1 none', ('1', None)), ('default', ((), ())), ('default 2', ((), '2')), ('default none', ((), None))): config = self._parse_config(f'RekeyLimit {value}') self.assertEqual(config.get('RekeyLimit'), result) config = self._parse_config('RekeyLimit 1 2\nRekeyLimit 3 4') self.assertEqual(config.get('RekeyLimit'), ('1', '2')) def test_get_compression_algs(self): """Test getting compression algorithms""" config = self._parse_config('Compression yes') self.assertEqual(config.get_compression_algs(), 'zlib@openssh.com,zlib,none') config = self._parse_config('Compression no') self.assertEqual(config.get_compression_algs(), 'none,zlib@openssh.com,zlib') config = self._parse_config('') self.assertEqual(config.get_compression_algs(), ()) def test_include(self): """Test include config option""" with open('.ssh/include', 'w') as f: f.write('Port 2222') for path in ('include', Path('.ssh/include').resolve().as_posix()): config = self._parse_config(f'Include {path}') self.assertEqual(config.get('Port'), 2222) def test_missing_include(self): """Test missing include target""" # Missing include files should be ignored self._parse_config('Include xxx') def test_multiple_include(self): """Test 
multiple levels of include""" os.mkdir('.ssh/dir1') os.mkdir('.ssh/dir2') with open('.ssh/include', 'w') as f: f.write('Include dir1/include2\n' 'Include dir2/include4\n') with open('.ssh/dir1/include2', 'w') as f: f.write('Include dir1/include3\n') with open('.ssh/dir1/include3', 'w') as f: f.write('AddressFamily inet\n') with open('.ssh/dir2/include4', 'w') as f: f.write('Port 2222\n') config = self._parse_config('Include include') self.assertEqual(config.get('AddressFamily'), socket.AF_INET) self.assertEqual(config.get('Port'), 2222) def test_match_all(self): """Test a match block which always matches""" config = self._parse_config('Match user xxx\nMatch all\nPort 2222') self.assertEqual(config.get('Port'), 2222) def test_match_negated(self): """Test a match block which never matches due to negation""" config = self._parse_config('Match !all user xxx\nPort 2222') self.assertEqual(config.get('Port'), None) def test_match_canonical(self): """Test a match block which matches when the host is canonicalized""" config = self._parse_config('Match canonical\nPort 2222') self.assertEqual(config.get('Port'), None) def test_match_final(self): """Test a match block which matches on the final parsing pass""" config = self._parse_config('Match final\nPort 2222') self.assertEqual(config.get('Port'), None) def test_match_exec(self): """Test a match block which runs a subprocess""" config = self._parse_config('Match exec "exit 0"\nPort 2222') self.assertEqual(config.get('Port'), 2222) config = self._parse_config('Match exec "exit 1"\nPort 2222') self.assertEqual(config.get('Port'), None) def test_config_disabled(self): """Test config loading being disabled""" self._load_config(None) def test_config_list(self): """Test reading multiple config files""" with open('config1', 'w') as f: f.write('BindAddress addr') with open('config2', 'w') as f: f.write('Port 2222') config = self._load_config(['config1', 'config2']) self.assertEqual(config.get('BindAddress'), 'addr') self.assertEqual(config.get('Port'), 2222) def test_equals(self): """Test config option with equals instead of space""" for delimiter in ('=', ' =', '= ', ' = '): config = self._parse_config(f'Compression{delimiter}yes') self.assertEqual(config.get('Compression'), True) def test_unknown(self): """Test unknown config option""" config = self._parse_config('XXX') self.assertIsNone(config.get('XXX')) def test_errors(self): """Test config errors""" for desc, config_data in ( ('Missing value', 'AddressFamily'), ('Unbalanced quotes', 'BindAddress "foo'), ('Extra data at end', 'BindAddress foo bar'), ('Invalid address family', 'AddressFamily xxx'), ('Invalid canonicalization option', 'CanonicalizeHostname xxx'), ('Invalid boolean', 'Compression xxx'), ('Invalid integer', 'Port xxx'), ('Invalid match condition', 'Match xxx')): with self.subTest(desc): with self.assertRaises(asyncssh.ConfigParseError): self._parse_config(config_data) class _TestClientConfig(_TestConfig): """Unit tests for client config objects""" def _load_config(self, config, last_config=None, reload=False, canonical=False, final=False, local_user='user', user=(), host='host', port=()): """Load a client configuration""" # pylint: disable=arguments-differ return SSHClientConfig.load(last_config, config, reload, canonical, final, local_user, user, host, port) def test_set_string_none(self): """Test string config option""" config = self._parse_config('IdentityAgent none') self.assertIsNone(config.get('IdentityAgent', ())) def test_append_string(self): """Test appending a string config 
option to a list""" config = self._parse_config('IdentityFile foo\nIdentityFile bar') self.assertEqual(config.get('IdentityFile'), ['foo', 'bar']) config = self._parse_config('IdentityFile foo\nIdentityFile none') self.assertEqual(config.get('IdentityFile'), ['foo']) config = self._parse_config('IdentityFile none') self.assertEqual(config.get('IdentityFile'), []) def test_set_string_list(self): """Test string list config option""" config = self._parse_config('UserKnownHostsFile file1 file2') self.assertEqual(config.get('UserKnownHostsFile'), ['file1', 'file2']) config = self._parse_config('UserKnownHostsFile file1\n' 'UserKnownHostsFile file2') self.assertEqual(config.get('UserKnownHostsFile'), ['file1']) config = self._parse_config('UserKnownHostsFile none\n' 'UserKnownHostsFile file2') self.assertEqual(config.get('UserKnownHostsFile'), []) def test_append_string_list(self): """Test appending multiple string config options to a list""" config = self._parse_config('SendEnv foo\nSendEnv bar baz') self.assertEqual(config.get('SendEnv'), ['foo', 'bar', 'baz']) def test_set_environment(self): """Test setting environment with equals-separated key/value pairs""" config = self._parse_config('SetEnv A=1 B= C=D=2\nSetEnv E=3') self.assertEqual(config.get('SetEnv'), ['A=1', 'B=', 'C=D=2']) def test_set_remote_command(self): """Test setting a remote command""" config = self._parse_config(' RemoteCommand foo bar baz') self.assertEqual(config.get('RemoteCommand'), 'foo bar baz') def test_set_forward_agent(self): """Test agent forwarding path config option""" for value, result in (('yes', True), ('true', True), ('no', False), ('false', False), ('agent', 'agent'), ('%d/agent', './agent')): config = self._parse_config(f'ForwardAgent {value}') self.assertEqual(config.get('ForwardAgent'), result) config = self._parse_config('ForwardAgent yes\nForwardAgent no') self.assertEqual(config.get('ForwardAgent'), True) def test_set_request_tty(self): """Test pseudo-terminal request config option""" for value, result in (('yes', True), ('true', True), ('no', False), ('false', False), ('force', 'force'), ('auto', 'auto')): config = self._parse_config(f'RequestTTY {value}') self.assertEqual(config.get('RequestTTY'), result) config = self._parse_config('RequestTTY yes\nRequestTTY no') self.assertEqual(config.get('RequestTTY'), True) def test_set_and_match_hostname(self): """Test setting and matching hostname""" config = self._parse_config('Host host\n' ' Hostname new%h\n' 'Match originalhost host\n' ' BindAddress addr\n' 'Match host host\n' ' Port 1111\n' 'Match host newhost\n' ' Hostname newhost2\n' ' Port 2222') self.assertEqual(config.get('Hostname'), 'newhost') self.assertEqual(config.get('BindAddress'), 'addr') self.assertEqual(config.get('Port'), 2222) def test_host_key_alias(self): """Test setting HostKeyAlias""" config = self._parse_config('Host host\n' ' Hostname 127.0.0.1\n' ' HostKeyAlias alias') self.assertEqual(config.get('HostKeyAlias'), 'alias') def test_set_and_match_user(self): """Test setting and matching user""" config = self._parse_config('User newuser\n' 'Match localuser user\n' ' BindAddress addr\n' 'Match user user\n' ' Port 1111\n' 'Match user new*\n' ' User newuser2\n' ' Port 2222') self.assertEqual(config.get('User'), 'newuser') self.assertEqual(config.get('BindAddress'), 'addr') self.assertEqual(config.get('Port'), 2222) def test_tag(self): """Test setting and matching a tag""" config = self._parse_config('Tag tag2\n' 'Match tagged tag1\n' ' Port 1111\n' 'Match tagged tag*\n' ' Port 2222') 
self.assertEqual(config.get('Port'), 2222) def test_port_already_set(self): """Test that port is ignored if set outside of the config""" config = self._parse_config('Port 2222', port=22) self.assertEqual(config.get('Port'), 22) def test_user_already_set(self): """Test that user is ignored if set outside of the config""" config = self._parse_config('User newuser', user='user') self.assertEqual(config.get('User'), 'user') def test_client_errors(self): """Test client config errors""" for desc, config_data in ( ('Invalid pseudo-terminal request', 'RequestTTY xxx'), ('Missing match host', 'Match host')): with self.subTest(desc): with self.assertRaises(asyncssh.ConfigParseError): self._parse_config(config_data) def test_percent_expansion(self): """Test token percent expansion""" def mock_gethostname(): """Return a static local hostname for testing""" return 'thishost.local' def mock_expanduser(_): """Return a static local home directory""" return '/home/user' with patch('socket.gethostname', mock_gethostname): with patch('os.path.expanduser', mock_expanduser): config = self._parse_config( 'Hostname newhost\n' 'User newuser\n' 'Port 2222\n' 'RemoteCommand %% %C %d %h %L %l %n %p %r %u') self.assertEqual(config.get('RemoteCommand'), '% 98625d1ca14854f2cdc34268f2afcad5237e2d9d ' '/home/user newhost thishost thishost.local ' 'host 2222 newuser user') @unittest.skipUnless(hasattr(os, 'getuid'), 'UID not available') def test_uid_percent_expansion(self): """Test UID token percent expansion where available""" def mock_getuid(): """Return a static local UID""" return 123 with patch('os.getuid', mock_getuid): config = self._parse_config('RemoteCommand %i') self.assertEqual(config.get('RemoteCommand'), '123') def test_home_percent_expansion_unavailable(self): """Test home directory token percent expansion not being available""" def mock_expanduser(path): """Don't expand the home directory""" return path def mock_pathlib_expanduser(self): """Expand user even with os.path.expanduser mocked out""" return Path(os.environ['HOME'], *self.parts[1:]) with self.assertRaises(asyncssh.ConfigParseError): with patch('os.path.expanduser', mock_expanduser), \ patch('pathlib.Path.expanduser', mock_pathlib_expanduser): self._parse_config('RemoteCommand %d') def test_uid_percent_expansion_unavailable(self): """Test UID token percent expansion not being available""" orig_hasattr = hasattr def mock_hasattr(obj, attr): if obj == os and attr == 'getuid': return False else: # pragma: no cover return orig_hasattr(obj, attr) with self.assertRaises(asyncssh.ConfigParseError): with patch('builtins.hasattr', mock_hasattr): self._parse_config('RemoteCommand %i') def test_invalid_percent_expansion(self): """Test invalid percent expansion""" for desc, config_data in ( ('Bad token in hostname', 'Hostname %p'), ('Invalid token', 'IdentityFile %x')): with self.subTest(desc): with self.assertRaises(asyncssh.ConfigParseError): self._parse_config(config_data) def test_env_expansion(self): """Test environment variable expansion""" config = self._parse_config('RemoteCommand ${HOME}/.ssh') self.assertEqual(config.get('RemoteCommand'), './.ssh') def test_invalid_env_expansion(self): """Test invalid environment variable expansion""" with self.assertRaises(asyncssh.ConfigParseError): self._parse_config('RemoteCommand ${XXX}') class _TestServerConfig(_TestConfig): """Unit tests for server config objects""" def _load_config(self, config, last_config=None, reload=False, canonical=False, final=False, local_addr='127.0.0.1', local_port=22, user='user', 
host=None, addr='127.0.0.1'): """Load a server configuration""" # pylint: disable=arguments-differ return SSHServerConfig.load(last_config, config, reload, canonical, final, local_addr, local_port, user, host, addr) def test_match_local_address(self): """Test matching on local address""" config = self._parse_config('Match localaddress 127.0.0.1\n' 'PermitTTY no') self.assertEqual(config.get('PermitTTY'), False) def test_match_local_port(self): """Test matching on local port""" config = self._parse_config('Match localport 22\nPermitTTY no') self.assertEqual(config.get('PermitTTY'), False) def test_match_user(self): """Test matching on user""" config = self._parse_config('Match user user\nPermitTTY no') self.assertEqual(config.get('PermitTTY'), False) def test_match_address(self): """Test matching on client address""" config = self._parse_config('Match address 127.0.0.0/8\nPermitTTY no') self.assertEqual(config.get('PermitTTY'), False) def test_reload(self): """Test update of match options""" config = self._parse_config('Match address 1.1.1.1\n' ' PermitTTY no\n' 'Match address 2.2.2.2\n' ' PermitTTY yes\n', addr='1.1.1.1') self.assertEqual(config.get('PermitTTY'), False) config = self._load_config('config', config, True, addr='2.2.2.2') self.assertEqual(config.get('PermitTTY'), True) del _TestConfig class _TestOptions(TempDirTestCase): """Test client and server connection options""" def test_client_options(self): """Test client connection options""" with open('config', 'w') as f: f.write('User newuser\nServerAliveInterval 1') options = asyncssh.SSHClientConnectionOptions( username='user', config='config') self.assertEqual(options.username, 'user') self.assertEqual(options.keepalive_interval, 1) with open('config', 'w') as f: f.write('ServerAliveInterval 2\nServerAliveCountMax 3\n') options = asyncssh.SSHClientConnectionOptions(options, config='config') self.assertEqual(options.keepalive_interval, 1) self.assertEqual(options.keepalive_count_max, 3) def test_server_options(self): """Test server connection options""" with open('config', 'w') as f: f.write('ClientAliveInterval 1\nClientAliveInterval 2') options = asyncssh.SSHServerConnectionOptions(config='config') self.assertEqual(options.keepalive_interval, 1) with open('config', 'w') as f: f.write('ClientAliveInterval 2\nClientAliveCountMax 3\n') options = asyncssh.SSHServerConnectionOptions(options, config='config') self.assertEqual(options.keepalive_interval, 1) self.assertEqual(options.keepalive_count_max, 3) asyncssh-2.20.0/tests/test_connection.py000066400000000000000000002713521475467777400204060ustar00rootroot00000000000000# Copyright (c) 2016-2024 by Ron Frederick and others. 
# # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation """Unit tests for AsyncSSH connection API""" import asyncio from copy import copy import os from pathlib import Path import socket import sys import unittest from unittest.mock import patch import asyncssh from asyncssh.constants import MSG_IGNORE, MSG_DEBUG from asyncssh.constants import MSG_SERVICE_REQUEST, MSG_SERVICE_ACCEPT from asyncssh.constants import MSG_KEXINIT, MSG_NEWKEYS from asyncssh.constants import MSG_KEX_FIRST, MSG_KEX_LAST from asyncssh.constants import MSG_USERAUTH_REQUEST, MSG_USERAUTH_SUCCESS from asyncssh.constants import MSG_USERAUTH_FAILURE, MSG_USERAUTH_BANNER from asyncssh.constants import MSG_USERAUTH_FIRST from asyncssh.constants import MSG_GLOBAL_REQUEST from asyncssh.constants import MSG_CHANNEL_OPEN, MSG_CHANNEL_OPEN_CONFIRMATION from asyncssh.constants import MSG_CHANNEL_OPEN_FAILURE, MSG_CHANNEL_DATA from asyncssh.compression import get_compression_algs from asyncssh.crypto.cipher import GCMCipher from asyncssh.encryption import get_encryption_algs from asyncssh.kex import get_kex_algs from asyncssh.kex_dh import MSG_KEX_ECDH_REPLY from asyncssh.mac import _HMAC, _mac_handler, get_mac_algs from asyncssh.packet import SSHPacket, Boolean, NameList, String, UInt32 from asyncssh.public_key import get_default_public_key_algs from asyncssh.public_key import get_default_certificate_algs from asyncssh.public_key import get_default_x509_certificate_algs from .server import Server, ServerTestCase from .util import asynctest, patch_extra_kex, patch_getaddrinfo from .util import patch_getnameinfo, patch_gss from .util import gss_available, nc_available, x509_available class _CheckAlgsClientConnection(asyncssh.SSHClientConnection): """Test specification of encryption algorithms""" def get_enc_algs(self): """Return the selected encryption algorithms""" return self._enc_algs def get_server_host_key_algs(self): """Return the selected server host key algorithms""" return self._server_host_key_algs class _SplitClientConnection(asyncssh.SSHClientConnection): """Test SSH messages being split into multiple packets""" def data_received(self, data, datatype=None): """Handle incoming data on the connection""" super().data_received(data[:3], datatype) super().data_received(data[3:6], datatype) super().data_received(data[6:9], datatype) super().data_received(data[9:], datatype) class _ReplayKexClientConnection(asyncssh.SSHClientConnection): """Test starting SSH key exchange while it is in progress""" def replay_kex(self): """Replay last kexinit packet""" self.send_packet(MSG_KEXINIT, self._client_kexinit[1:]) class _KeepaliveClientConnection(asyncssh.SSHClientConnection): """Test handling of keepalive requests on client""" def _process_keepalive_at_openssh_dot_com_global_request(self, packet): """Process an incoming OpenSSH keepalive request""" super()._process_keepalive_at_openssh_dot_com_global_request(packet) self.disconnect(asyncssh.DISC_BY_APPLICATION, 
'Keepalive') class _KeepaliveClientConnectionFailure(asyncssh.SSHClientConnection): """Test handling of keepalive failures on client""" def _process_keepalive_at_openssh_dot_com_global_request(self, packet): """Ignore an incoming OpenSSH keepalive request""" class _KeepaliveServerConnection(asyncssh.SSHServerConnection): """Test handling of keepalive requests on server""" def _process_keepalive_at_openssh_dot_com_global_request(self, packet): """Process an incoming OpenSSH keepalive request""" super()._process_keepalive_at_openssh_dot_com_global_request(packet) self.disconnect(asyncssh.DISC_BY_APPLICATION, 'Keepalive') class _KeepaliveServerConnectionFailure(asyncssh.SSHServerConnection): """Test handling of keepalive failures on server""" def _process_keepalive_at_openssh_dot_com_global_request(self, packet): """Ignore an incoming OpenSSH keepalive request""" class _VersionedServerConnection(asyncssh.SSHServerConnection): """Test alternate SSH server version lines""" def __init__(self, version, leading_text, newline, *args, **kwargs): super().__init__(*args, **kwargs) self._version = version self._leading_text = leading_text self._newline = newline @classmethod def create(cls, version=b'SSH-2.0-AsyncSSH_Test', leading_text=b'', newline=b'\r\n'): """Return a connection factory which sends modified version lines""" return (lambda *args, **kwargs: cls(version, leading_text, newline, *args, **kwargs)) def _send_version(self): """Start the SSH handshake""" self._server_version = self._version self._extra.update(server_version=self._version.decode('ascii')) self._send(self._leading_text + self._version + self._newline) class _BadHostKeyServerConnection(asyncssh.SSHServerConnection): """Test returning invalid server host key""" def get_server_host_key(self): """Return the chosen server host key""" result = copy(super().get_server_host_key()) result.public_data = b'xxx' return result class _ExtInfoServerConnection(asyncssh.SSHServerConnection): """Test adding an unrecognized extension in extension info""" def _send_ext_info(self): """Send extension information""" self._extensions_to_send['xxx'] = b'' super()._send_ext_info() class _BadSignatureServerConnection(asyncssh.SSHServerConnection): """Test returning a bad signature in host keys prove request""" def _process_hostkeys_prove_00_at_openssh_dot_com_global_request( self, packet): """Prove the server has private keys for all requested host keys""" self._report_global_response(String(b'')) class _ProveFailedServerConnection(asyncssh.SSHServerConnection): """Test returning failure in host keys prove request""" def _process_hostkeys_prove_00_at_openssh_dot_com_global_request( self, packet): """Prove the server has private keys for all requested host keys""" super()._process_hostkeys_prove_00_at_openssh_dot_com_global_request( SSHPacket(String(b''))) def _failing_get_mac(alg, key): """Replace HMAC class with FailingMAC""" class _FailingMAC(_HMAC): """Test error in MAC validation""" def verify(self, seq, packet, sig): """Verify the signature of a message""" return super().verify(seq, packet + b'\xff', sig) _, hash_size, args = _mac_handler[alg] return _FailingMAC(key, hash_size, *args) async def _slow_connect(*_args, **_kwargs): """Simulate a really slow connect that ends up timing out""" await asyncio.sleep(5) class _FailingGCMCipher(GCMCipher): """Test error in GCM tag verification""" def verify_and_decrypt(self, header, data, mac): """Verify the signature of and decrypt a block of data""" return super().verify_and_decrypt(header, data + b'\xff', 
mac) class _ValidateHostKeyClient(asyncssh.SSHClient): """Test server host key/CA validation callbacks""" def __init__(self, host_key=None, ca_key=None): self._host_key = \ asyncssh.read_public_key(host_key) if host_key else None self._ca_key = \ asyncssh.read_public_key(ca_key) if ca_key else None def validate_host_public_key(self, host, addr, port, key): """Return whether key is an authorized key for this host""" # pylint: disable=unused-argument return key == self._host_key def validate_host_ca_key(self, host, addr, port, key): """Return whether key is an authorized CA key for this host""" # pylint: disable=unused-argument return key == self._ca_key class _PreAuthRequestClient(asyncssh.SSHClient): """Test sending a request prior to auth complete""" def __init__(self): self._conn = None def connection_made(self, conn): """Save connection for use later""" self._conn = conn def password_auth_requested(self): """Attempt to execute a command before authentication is complete""" # pylint: disable=protected-access self._conn._auth_complete = True self._conn.send_packet(MSG_GLOBAL_REQUEST, String(b'\xff'), Boolean(True)) return 'pw' class _InternalErrorClient(asyncssh.SSHClient): """Test of internal error exception handler""" def connection_made(self, conn): """Raise an error when a new connection is opened""" # pylint: disable=unused-argument raise RuntimeError('Exception handler test') class _TunnelServer(Server): """Allow forwarding to test server host key request tunneling""" def connection_requested(self, dest_host, dest_port, orig_host, orig_port): """Handle a request to create a new connection""" return True class _AbortServer(Server): """Server for testing connection abort during auth""" def begin_auth(self, username): """Abort the connection during auth""" self._conn.abort() return False class _CloseDuringAuthServer(Server): """Server for testing connection close during long auth callback""" def password_auth_supported(self): """Return that password auth is supported""" return True async def validate_password(self, username, password): """Delay validating password""" # pylint: disable=unused-argument await asyncio.sleep(1) return False # pragma: no cover - closed before we get here class _InternalErrorServer(Server): """Server for testing internal error during auth""" def debug_msg_received(self, msg, lang, always_display): """Process a debug message""" # pylint: disable=unused-argument raise RuntimeError('Exception handler test') class _InvalidAuthBannerServer(Server): """Server for testing invalid auth banner""" def begin_auth(self, username): """Send an invalid auth banner""" self._conn.send_auth_banner(b'\xff') return False class _VersionRecordingClient(asyncssh.SSHClient): """Client for testing custom client version""" def __init__(self): self.reported_version = None def auth_banner_received(self, msg, lang): """Record the client version reported in the auth banner""" self.reported_version = msg class _VersionReportingServer(Server): """Server for testing custom client version""" def begin_auth(self, username): """Report the client's version in the auth banner""" version = self._conn.get_extra_info('client_version') self._conn.send_auth_banner(version) return False @patch_gss @patch('asyncssh.connection.SSHClientConnection', _CheckAlgsClientConnection) class _TestConnection(ServerTestCase): """Unit tests for AsyncSSH connection API""" # pylint: disable=too-many-public-methods @classmethod async def start_server(cls): """Start an SSH server to connect to""" def acceptor(conn): 
"""Acceptor for SSH connections""" conn.logger.info('Acceptor called') return (await cls.create_server(_TunnelServer, gss_host=(), compression_algs='*', encryption_algs='*', kex_algs='*', mac_algs='*', acceptor=acceptor)) async def get_server_host_key(self, **kwargs): """Get host key from the test server""" return (await asyncssh.get_server_host_key(self._server_addr, self._server_port, **kwargs)) async def _check_version(self, *args, **kwargs): """Check alternate SSH server version lines""" with patch('asyncssh.connection.SSHServerConnection', _VersionedServerConnection.create(*args, **kwargs)): async with self.connect(): pass @asynctest async def test_connect(self): """Test connecting with async context manager""" async with self.connect() as conn: pass self.assertTrue(conn.is_closed()) @asynctest async def test_connect_sock(self): """Test connecting using an already-connected socket""" sock = socket.socket() await self.loop.sock_connect(sock, (self._server_addr, self._server_port)) async with asyncssh.connect(sock=sock): pass @unittest.skipUnless(nc_available, 'Netcat not available') @asynctest async def test_connect_non_tcp_sock(self): """Test connecting using an non-TCP socket""" sock1, sock2 = socket.socketpair() proc = await asyncio.create_subprocess_exec( 'nc', str(self._server_addr), str(self._server_port), stdin=sock1, stdout=sock1, stderr=sock1) async with asyncssh.connect( self._server_addr, self._server_port, sock=sock2): pass await proc.wait() sock1.close() @asynctest async def test_run_client(self): """Test running an SSH client on an already-connected socket""" sock = socket.socket() await self.loop.sock_connect(sock, (self._server_addr, self._server_port)) async with self.run_client(sock): pass @asynctest async def test_connect_encrypted_key(self): """Test connecting with encrypted client key and no passphrase""" async with self.connect(client_keys='ckey_encrypted', ignore_encrypted=True): pass with self.assertRaises(asyncssh.KeyImportError): await self.connect(client_keys='ckey_encrypted') with open('config', 'w') as f: f.write('IdentityFile ckey_encrypted') async with self.connect(config='config'): pass with self.assertRaises(asyncssh.KeyImportError): await self.connect(config='config', ignore_encrypted=False) @asynctest async def test_connect_invalid_options_type(self): """Test connecting using options using incorrect type of options""" options = asyncssh.SSHServerConnectionOptions() with self.assertRaises(TypeError): await self.connect(options=options) @asynctest async def test_connect_invalid_option_name(self): """Test connecting using incorrect option name""" with self.assertRaises(TypeError): await self.connect(xxx=1) @asynctest async def test_connect_failure(self): """Test failure connecting""" with self.assertRaises(OSError): await asyncssh.connect('\xff') @asynctest async def test_connect_failure_without_agent(self): """Test failure connecting with SSH agent disabled""" with self.assertRaises(OSError): await asyncssh.connect('\xff', agent_path=None) @asynctest async def test_connect_timeout_exceeded(self): """Test connect timeout exceeded""" with self.assertRaises(asyncio.TimeoutError): with patch('asyncio.BaseEventLoop.create_connection', _slow_connect): await asyncssh.connect('', connect_timeout=1) @asynctest async def test_connect_timeout_exceeded_string(self): """Test connect timeout exceeded with string value""" with self.assertRaises(asyncio.TimeoutError): with patch('asyncio.BaseEventLoop.create_connection', _slow_connect): await asyncssh.connect('', 
connect_timeout='0m1s') @asynctest async def test_connect_timeout_exceeded_tunnel(self): """Test connect timeout exceeded""" with self.assertRaises(asyncio.TimeoutError): with patch('asyncio.BaseEventLoop.create_connection', _slow_connect): await asyncssh.listen(server_host_keys=['skey'], tunnel='', connect_timeout=1) @asynctest async def test_invalid_connect_timeout(self): """Test invalid connect timeout""" with self.assertRaises(ValueError): await self.connect(connect_timeout=-1) @asynctest async def test_connect_tcp_keepalive_off(self): """Test connecting with TCP keepalive disabled""" async with self.connect(tcp_keepalive=False) as conn: sock = conn.get_extra_info('socket') self.assertEqual(bool(sock.getsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE)), False) @asynctest async def test_split_version(self): """Test version split across two packets""" with patch('asyncssh.connection.SSHClientConnection', _SplitClientConnection): async with self.connect(): pass @asynctest async def test_version_1_99(self): """Test SSH server version 1.99""" await self._check_version(b'SSH-1.99-Test') @asynctest async def test_banner_before_version(self): """Test banner lines before SSH server version""" await self._check_version(leading_text=b'Banner 1\r\nBanner 2\r\n') @asynctest async def test_banner_line_too_long(self): """Test excessively long banner line""" with self.assertRaises(asyncssh.ProtocolError): await self._check_version(leading_text=8192*b'*' + b'\r\n') @asynctest async def test_too_many_banner_lines(self): """Test too many banner lines""" with self.assertRaises(asyncssh.ProtocolError): await self._check_version(leading_text=2048*b'Banner line\r\n') @asynctest async def test_version_without_cr(self): """Test SSH server version with LF instead of CRLF""" await self._check_version(newline=b'\n') @asynctest async def test_version_line_too_long(self): """Test excessively long version line""" with self.assertRaises(asyncssh.ProtocolError): await self._check_version(newline=256*b'*' + b'\r\n') @asynctest async def test_unknown_version(self): """Test unknown SSH server version""" with self.assertRaises(asyncssh.ProtocolNotSupported): await self._check_version(b'SSH-1.0-Test') @asynctest async def test_no_server_host_keys(self): """Test starting a server with no host keys""" with self.assertRaises(ValueError): await asyncssh.create_server(Server, server_host_keys=[], gss_host=None) @asynctest async def test_duplicate_type_server_host_keys(self): """Test starting a server with duplicate host key types""" with self.assertRaises(ValueError): await asyncssh.listen(server_host_keys=['skey', 'skey']) @asynctest async def test_reserved_server_host_keys(self): """Test reserved host keys with host key sending enabled""" async with self.listen(server_host_keys=['skey', 'skey'], send_server_host_keys=True): pass @asynctest async def test_get_server_host_key(self): """Test retrieving a server host key""" keylist = asyncssh.load_public_keys('skey.pub') key = await self.get_server_host_key() self.assertEqual(key, keylist[0]) @asynctest async def test_get_server_host_key_tunnel(self): """Test retrieving a server host key while tunneling over SSH""" keylist = asyncssh.load_public_keys('skey.pub') async with self.connect() as conn: key = await self.get_server_host_key(tunnel=conn) self.assertEqual(key, keylist[0]) @asynctest async def test_get_server_host_key_connect_failure(self): """Test failure connecting when retrieving a server host key""" with self.assertRaises(OSError): await 
asyncssh.get_server_host_key('\xff') @unittest.skipUnless(nc_available, 'Netcat not available') @asynctest async def test_get_server_host_key_proxy(self): """Test retrieving a server host key using proxy command""" keylist = asyncssh.load_public_keys('skey.pub') proxy_command = ('nc', str(self._server_addr), str(self._server_port)) key = await self.get_server_host_key(proxy_command=proxy_command) self.assertEqual(key, keylist[0]) @unittest.skipUnless(nc_available, 'Netcat not available') @asynctest async def test_get_server_host_key_proxy_failure(self): """Test failure retrieving a server host key using proxy command""" # Leave out arguments to 'nc' to trigger a failure proxy_command = 'nc' with self.assertRaises((OSError, asyncssh.ConnectionLost)): await self.connect(proxy_command=proxy_command) @asynctest async def test_known_hosts_not_present(self): """Test connecting with default known hosts file not present""" try: os.rename(os.path.join('.ssh', 'known_hosts'), os.path.join('.ssh', 'known_hosts.save')) with self.assertRaises(asyncssh.HostKeyNotVerifiable): await self.connect() finally: os.rename(os.path.join('.ssh', 'known_hosts.save'), os.path.join('.ssh', 'known_hosts')) @unittest.skipIf(sys.platform == 'win32', 'skip chmod tests on Windows') @asynctest async def test_known_hosts_not_readable(self): """Test connecting with default known hosts file not readable""" try: os.chmod(os.path.join('.ssh', 'known_hosts'), 0) with self.assertRaises(asyncssh.HostKeyNotVerifiable): await self.connect() finally: os.chmod(os.path.join('.ssh', 'known_hosts'), 0o644) @asynctest async def test_known_hosts_none(self): """Test connecting with known hosts checking disabled""" default_algs = get_default_x509_certificate_algs() + \ get_default_certificate_algs() + \ get_default_public_key_algs() async with self.connect(known_hosts=None) as conn: self.assertEqual(conn.get_server_host_key_algs(), default_algs) @asynctest async def test_known_hosts_none_in_config(self): """Test connecting with known hosts checking disabled in config file""" with open('config', 'w') as f: f.write('UserKnownHostsFile none') async with self.connect(config='config'): pass @asynctest async def test_known_hosts_none_without_x509(self): """Test connecting with known hosts checking and X.509 disabled""" non_x509_algs = get_default_certificate_algs() + \ get_default_public_key_algs() async with self.connect(known_hosts=None, x509_trusted_certs=None) as conn: self.assertEqual(conn.get_server_host_key_algs(), non_x509_algs) @asynctest async def test_known_hosts_multiple_keys(self): """Test connecting with multiple trusted known hosts keys""" rsa_algs = [alg for alg in get_default_public_key_algs() if b'rsa' in alg] async with self.connect(x509_trusted_certs=None, known_hosts=(['skey.pub', 'skey.pub'], [], [])) as conn: self.assertEqual(conn.get_server_host_key_algs(), rsa_algs) @asynctest async def test_known_hosts_ca(self): """Test connecting with a known hosts CA""" async with self.connect(known_hosts=([], ['skey.pub'], [])) as conn: self.assertEqual(conn.get_server_host_key_algs(), get_default_x509_certificate_algs() + get_default_certificate_algs()) @asynctest async def test_known_hosts_bytes(self): """Test connecting with known hosts passed in as bytes""" with open('skey.pub', 'rb') as f: skey = f.read() async with self.connect(known_hosts=([skey], [], [])): pass @asynctest async def test_known_hosts_keylist_file(self): """Test connecting with known hosts passed as a keylist file""" async with 
self.connect(known_hosts=('skey.pub', [], [])): pass @asynctest async def test_known_hosts_sshkeys(self): """Test connecting with known hosts passed in as SSHKeys""" keylist = asyncssh.load_public_keys('skey.pub') async with self.connect(known_hosts=(keylist, [], [])) as conn: self.assertEqual(conn.get_server_host_key(), keylist[0]) @asynctest async def test_read_known_hosts(self): """Test connecting with known hosts object from read_known_hosts""" known_hosts = asyncssh.read_known_hosts('~/.ssh/known_hosts') async with self.connect(known_hosts=known_hosts): pass @asynctest async def test_read_known_hosts_filelist(self): """Test connecting with known hosts from read_known_hosts file list""" known_hosts = asyncssh.read_known_hosts(['~/.ssh/known_hosts']) async with self.connect(known_hosts=known_hosts): pass @asynctest async def test_import_known_hosts(self): """Test connecting with known hosts object from import_known_hosts""" known_hosts_path = os.path.join('.ssh', 'known_hosts') with open(known_hosts_path) as f: known_hosts = asyncssh.import_known_hosts(f.read()) async with self.connect(known_hosts=known_hosts): pass @asynctest async def test_validate_host_ca_callback(self): """Test callback to validate server CA key""" def client_factory(): """Return an SSHClient which can validate the server CA key""" return _ValidateHostKeyClient(ca_key='skey.pub') conn, _ = await self.create_connection(client_factory, known_hosts=([], [], [])) async with conn: pass @asynctest async def test_untrusted_known_hosts_ca(self): """Test untrusted server CA key""" with self.assertRaises(asyncssh.HostKeyNotVerifiable): await self.connect(known_hosts=([], ['ckey.pub'], [])) @asynctest async def test_untrusted_host_key_callback(self): """Test callback to validate server host key returning failure""" def client_factory(): """Return an SSHClient which can validate the server host key""" return _ValidateHostKeyClient(host_key='ckey.pub') with self.assertRaises(asyncssh.HostKeyNotVerifiable): await self.create_connection(client_factory, known_hosts=([], [], [])) @asynctest async def test_untrusted_host_ca_callback(self): """Test callback to validate server CA key returning failure""" def client_factory(): """Return an SSHClient which can validate the server CA key""" return _ValidateHostKeyClient(ca_key='ckey.pub') with self.assertRaises(asyncssh.HostKeyNotVerifiable): await self.create_connection(client_factory, known_hosts=([], [], [])) @asynctest async def test_revoked_known_hosts_key(self): """Test revoked server host key""" with self.assertRaises(asyncssh.HostKeyNotVerifiable): await self.connect(known_hosts=(['ckey.pub'], [], ['skey.pub'])) @asynctest async def test_revoked_known_hosts_ca(self): """Test revoked server CA key""" with self.assertRaises(asyncssh.HostKeyNotVerifiable): await self.connect(known_hosts=([], ['ckey.pub'], ['skey.pub'])) @asynctest async def test_empty_known_hosts(self): """Test empty known hosts list""" with self.assertRaises(asyncssh.HostKeyNotVerifiable): await self.connect(known_hosts=([], [], [])) @asynctest async def test_invalid_server_host_key(self): """Test invalid server host key""" with patch('asyncssh.connection.SSHServerConnection', _BadHostKeyServerConnection): with self.assertRaises(asyncssh.HostKeyNotVerifiable): await self.connect() @asynctest async def test_changing_server_host_key(self): """Test changing server host key""" self._server.update(server_host_keys=['skey_ecdsa']) async with self.connect(known_hosts=None): pass
self._server.update(server_host_keys=['skey']) with self.assertRaises(asyncssh.KeyExchangeFailed): await self.connect(known_hosts=(['skey_ecdsa.pub'], [], [])) @asynctest async def test_kex_algs(self): """Test connecting with different key exchange algorithms""" for kex in get_kex_algs(): kex = kex.decode('ascii') if kex.startswith('gss-') and not gss_available: # pragma: no cover continue with self.subTest(kex_alg=kex): async with self.connect(kex_algs=[kex], gss_host='1'): pass @asynctest async def test_duplicate_encryption_algs(self): """Test connecting with a duplicated encryption algorithm""" with patch('asyncssh.connection.SSHClientConnection', _CheckAlgsClientConnection): async with self.connect( encryption_algs=['aes256-ctr', 'aes256-ctr']) as conn: self.assertEqual(conn.get_enc_algs(), [b'aes256-ctr']) @asynctest async def test_leading_encryption_alg(self): """Test adding a new first encryption algorithm""" with patch('asyncssh.connection.SSHClientConnection', _CheckAlgsClientConnection): async with self.connect(encryption_algs='^aes256-ctr') as conn: self.assertEqual(conn.get_enc_algs()[0], b'aes256-ctr') @asynctest async def test_trailing_encryption_alg(self): """Test adding a new last encryption algorithm""" with patch('asyncssh.connection.SSHClientConnection', _CheckAlgsClientConnection): async with self.connect(encryption_algs='+3des-cbc') as conn: self.assertEqual(conn.get_enc_algs()[-1], b'3des-cbc') @asynctest async def test_removing_encryption_alg(self): """Test removing an encryption algorithm""" with patch('asyncssh.connection.SSHClientConnection', _CheckAlgsClientConnection): async with self.connect(encryption_algs='-aes256-ctr') as conn: self.assertTrue(b'aes256-ctr' not in conn.get_enc_algs()) @asynctest async def test_empty_kex_algs(self): """Test connecting with an empty list of key exchange algorithms""" with self.assertRaises(ValueError): await self.connect(kex_algs=[]) @asynctest async def test_invalid_kex_alg(self): """Test connecting with invalid key exchange algorithm""" with self.assertRaises(ValueError): await self.connect(kex_algs=['xxx']) @asynctest async def test_invalid_kex_alg_str(self): """Test connecting with invalid key exchange algorithm pattern""" with self.assertRaises(ValueError): await self.connect(kex_algs='diffie-hallman-group14-sha1,xxx') @asynctest async def test_invalid_kex_alg_config(self): """Test connecting with invalid key exchange algorithm config""" with open('config', 'w') as f: f.write('KexAlgorithms diffie-hellman-group14-sha1,xxx') async with self.connect(config='config'): pass @asynctest async def test_unsupported_kex_alg(self): """Test connecting with unsupported key exchange algorithm""" def unsupported_kex_alg(): """Patched version of get_kex_algs to test unsupported algorithm""" return [b'fail'] + get_kex_algs() with patch('asyncssh.connection.get_kex_algs', unsupported_kex_alg): with self.assertRaises(asyncssh.KeyExchangeFailed): await self.connect(kex_algs=['fail']) @asynctest async def test_unknown_ext_info(self): """Test receiving unknown extension information""" with patch('asyncssh.connection.SSHServerConnection', _ExtInfoServerConnection): async with self.connect(): pass @asynctest async def test_server_ext_info(self): """Test receiving unsolicited extension information on server""" def send_newkeys(self, k, h): """Finish a key exchange and send a new keys message""" asyncssh.connection.SSHConnection.send_newkeys(self, k, h) self._send_ext_info() with patch('asyncssh.connection.SSHClientConnection.send_newkeys', 
send_newkeys): with self.assertRaises((ConnectionError, asyncssh.ProtocolError)): await self.connect() @asynctest async def test_message_before_kexinit_strict_kex(self): """Test receiving a message before KEXINIT with strict_kex enabled""" def send_packet(self, pkttype, *args, **kwargs): if pkttype == MSG_KEXINIT: self.send_packet(MSG_IGNORE, String(b'')) asyncssh.connection.SSHConnection.send_packet( self, pkttype, *args, **kwargs) with patch('asyncssh.connection.SSHClientConnection.send_packet', send_packet): with self.assertRaises(asyncssh.ProtocolError): await self.connect() @asynctest async def test_message_during_kex_strict_kex(self): """Test receiving an unexpected message with strict_kex enabled""" def send_packet(self, pkttype, *args, **kwargs): if pkttype == MSG_KEX_ECDH_REPLY: self.send_packet(MSG_IGNORE, String(b'')) asyncssh.connection.SSHConnection.send_packet( self, pkttype, *args, **kwargs) with patch('asyncssh.connection.SSHServerConnection.send_packet', send_packet): with self.assertRaises(asyncssh.ProtocolError): await self.connect() @asynctest async def test_unknown_message_during_kex_strict_kex(self): """Test receiving an unknown message with strict_kex enabled""" def send_packet(self, pkttype, *args, **kwargs): if pkttype == MSG_KEX_ECDH_REPLY: self.send_packet(MSG_KEX_LAST) asyncssh.connection.SSHConnection.send_packet( self, pkttype, *args, **kwargs) with patch('asyncssh.connection.SSHServerConnection.send_packet', send_packet): with self.assertRaises(asyncssh.ProtocolError): await self.connect() @asynctest async def test_encryption_algs(self): """Test connecting with different encryption algorithms""" for enc in get_encryption_algs(): enc = enc.decode('ascii') with self.subTest(encryption_alg=enc): async with self.connect(encryption_algs=[enc]): pass @asynctest async def test_empty_encryption_algs(self): """Test connecting with an empty list of encryption algorithms""" with self.assertRaises(ValueError): await self.connect(encryption_algs=[]) @asynctest async def test_invalid_encryption_alg(self): """Test connecting with invalid encryption algorithm""" with self.assertRaises(ValueError): await self.connect(encryption_algs=['xxx']) @asynctest async def test_mac_algs(self): """Test connecting with different MAC algorithms""" for mac in get_mac_algs(): mac = mac.decode('ascii') with self.subTest(mac_alg=mac): async with self.connect(encryption_algs=['aes128-ctr'], mac_algs=[mac]): pass @asynctest async def test_mac_verify_error(self): """Test MAC validation failure""" with patch('asyncssh.encryption.get_mac', _failing_get_mac): for mac in ('hmac-sha2-256-etm@openssh.com', 'hmac-sha2-256'): with self.subTest(mac_alg=mac): with self.assertRaises(asyncssh.MACError): await self.connect(encryption_algs=['aes128-ctr'], mac_algs=[mac]) @asynctest async def test_gcm_verify_error(self): """Test GCM tag validation failure""" with patch('asyncssh.encryption.GCMCipher', _FailingGCMCipher): with self.assertRaises(asyncssh.MACError): await self.connect(encryption_algs=['aes128-gcm@openssh.com']) @asynctest async def test_empty_mac_algs(self): """Test connecting with an empty list of MAC algorithms""" with self.assertRaises(ValueError): await self.connect(mac_algs=[]) @asynctest async def test_invalid_mac_alg(self): """Test connecting with invalid MAC algorithm""" with self.assertRaises(ValueError): await self.connect(mac_algs=['xxx']) @asynctest async def test_compression_algs(self): """Test connecting with different compression algorithms""" for cmp in get_compression_algs(): cmp 
= cmp.decode('ascii') with self.subTest(cmp_alg=cmp): async with self.connect(compression_algs=[cmp]): pass @asynctest async def test_no_compression(self): """Test connecting with compression disabled""" async with self.connect(compression_algs=None): pass @asynctest async def test_invalid_cmp_alg(self): """Test connecting with invalid compression algorithm""" with self.assertRaises(ValueError): await self.connect(compression_algs=['xxx']) @asynctest async def test_disconnect(self): """Test sending disconnect message""" conn = await self.connect() conn.disconnect(asyncssh.DISC_BY_APPLICATION, 'Closing') await conn.wait_closed() @asynctest async def test_invalid_disconnect(self): """Test sending disconnect message with invalid Unicode in it""" conn = await self.connect() conn.disconnect(asyncssh.DISC_BY_APPLICATION, b'\xff') await conn.wait_closed() @asynctest async def test_debug(self): """Test sending debug message""" async with self.connect() as conn: conn.send_debug('debug') @asynctest async def test_invalid_debug(self): """Test sending debug message with invalid Unicode in it""" conn = await self.connect() conn.send_debug(b'\xff') await conn.wait_closed() @asynctest async def test_service_request_before_kex_complete(self): """Test service request before kex is complete""" def send_newkeys(self, k, h): """Finish a key exchange and send a new keys message""" self._kex_complete = True self.send_packet(MSG_SERVICE_REQUEST, String('ssh-userauth')) asyncssh.connection.SSHConnection.send_newkeys(self, k, h) with patch('asyncssh.connection.SSHClientConnection.send_newkeys', send_newkeys): with self.assertRaises(asyncssh.ProtocolError): await self.connect() @asynctest async def test_service_accept_before_kex_complete(self): """Test service accept before kex is complete""" def send_newkeys(self, k, h): """Finish a key exchange and send a new keys message""" self._kex_complete = True self.send_packet(MSG_SERVICE_ACCEPT, String('ssh-userauth')) asyncssh.connection.SSHConnection.send_newkeys(self, k, h) with patch('asyncssh.connection.SSHServerConnection.send_newkeys', send_newkeys): with self.assertRaises(asyncssh.ProtocolError): await self.connect() @asynctest async def test_unexpected_service_name_in_request(self): """Test unexpected service name in service request""" conn = await self.connect() conn.send_packet(MSG_SERVICE_REQUEST, String('xxx')) await conn.wait_closed() @asynctest async def test_unexpected_service_name_in_accept(self): """Test unexpected service name in accept sent by server""" def send_newkeys(self, k, h): """Finish a key exchange and send a new keys message""" asyncssh.connection.SSHConnection.send_newkeys(self, k, h) self.send_packet(MSG_SERVICE_ACCEPT, String('xxx')) with patch('asyncssh.connection.SSHServerConnection.send_newkeys', send_newkeys): with self.assertRaises(asyncssh.ServiceNotAvailable): await self.connect() @asynctest async def test_service_accept_from_client(self): """Test service accept sent by client""" conn = await self.connect() conn.send_packet(MSG_SERVICE_ACCEPT, String('ssh-userauth')) await conn.wait_closed() @asynctest async def test_service_request_from_server(self): """Test service request sent by server""" def send_newkeys(self, k, h): """Finish a key exchange and send a new keys message""" asyncssh.connection.SSHConnection.send_newkeys(self, k, h) self.send_packet(MSG_SERVICE_REQUEST, String('ssh-userauth')) with patch('asyncssh.connection.SSHServerConnection.send_newkeys', send_newkeys): with self.assertRaises(asyncssh.ProtocolError): await 
self.connect() @asynctest async def test_client_decompression_failure(self): """Test client decompression failure""" def send_packet(self, pkttype, *args, **kwargs): """Send an SSH packet""" asyncssh.connection.SSHConnection.send_packet( self, pkttype, *args, **kwargs) if pkttype == MSG_USERAUTH_SUCCESS: self._compressor = None self.send_debug('Test') with patch('asyncssh.connection.SSHServerConnection.send_packet', send_packet): await self.connect(compression_algs=['zlib@openssh.com']) @asynctest async def test_packet_decode_error(self): """Test SSH packet decode error""" conn = await self.connect() conn.send_packet(MSG_DEBUG) await conn.wait_closed() @asynctest async def test_unknown_packet(self): """Test unknown SSH packet""" async with self.connect() as conn: conn.send_packet(0xff) await asyncio.sleep(0.1) @asynctest async def test_client_keepalive(self): """Test sending keepalive from client""" with patch('asyncssh.connection.SSHServerConnection', _KeepaliveServerConnection): conn = await self.connect(keepalive_interval=0.1) await conn.wait_closed() @asynctest async def test_client_keepalive_string(self): """Test sending keepalive from client with string argument""" with patch('asyncssh.connection.SSHServerConnection', _KeepaliveServerConnection): conn = await self.connect(keepalive_interval='0.1s') await conn.wait_closed() @asynctest async def test_client_set_keepalive_interval(self): """Test sending keepalive interval with set_keepalive""" with patch('asyncssh.connection.SSHServerConnection', _KeepaliveServerConnection): conn = await self.connect() conn.set_keepalive('0m0.1s') await conn.wait_closed() @asynctest async def test_invalid_client_keepalive(self): """Test setting invalid keepalive from client""" with self.assertRaises(ValueError): await self.connect(keepalive_interval=-1) @asynctest async def test_client_set_invalid_keepalive_interval(self): """Test setting invalid keepalive interval with set_keepalive""" async with self.connect() as conn: with self.assertRaises(ValueError): conn.set_keepalive(interval=-1) @asynctest async def test_client_set_keepalive_count_max(self): """Test sending keepalive count max with set_keepalive""" with patch('asyncssh.connection.SSHServerConnection', _KeepaliveServerConnection): conn = await self.connect(keepalive_interval=0.1) conn.set_keepalive(count_max=10) await conn.wait_closed() @asynctest async def test_invalid_client_keepalive_count_max(self): """Test setting invalid keepalive count max from client""" with self.assertRaises(ValueError): await self.connect(keepalive_count_max=-1) @asynctest async def test_client_set_invalid_keepalive_count_max(self): """Test setting invalid keepalive count max with set_keepalive""" async with self.connect() as conn: with self.assertRaises(ValueError): conn.set_keepalive(count_max=-1) @asynctest async def test_client_keepalive_failure(self): """Test client keepalive failure""" with patch('asyncssh.connection.SSHServerConnection', _KeepaliveServerConnectionFailure): conn = await self.connect(keepalive_interval=0.1) await conn.wait_closed() @asynctest async def test_rekey_bytes(self): """Test SSH re-keying with byte limit""" async with self.connect(rekey_bytes=1) as conn: await asyncio.sleep(0.1) conn.send_debug('test') await asyncio.sleep(0.1) @asynctest async def test_rekey_bytes_string(self): """Test SSH re-keying with string byte limit""" async with self.connect(rekey_bytes='1') as conn: await asyncio.sleep(0.1) conn.send_debug('test') await asyncio.sleep(0.1) @asynctest async def 
test_invalid_rekey_bytes(self): """Test invalid rekey bytes""" for desc, rekey_bytes in ( ('Negative integer', -1), ('Missing value', ''), ('Missing integer', 'k'), ('Invalid integer', '!'), ('Invalid suffix', '1x')): with self.subTest(desc): with self.assertRaises(ValueError): await self.connect(rekey_bytes=rekey_bytes) @asynctest async def test_rekey_seconds(self): """Test SSH re-keying with time limit""" async with self.connect(rekey_seconds=0.1) as conn: await asyncio.sleep(0.1) conn.send_debug('test') await asyncio.sleep(0.1) @asynctest async def test_rekey_seconds_string(self): """Test SSH re-keying with string time limit""" async with self.connect(rekey_seconds='0m0.1s') as conn: await asyncio.sleep(0.1) conn.send_debug('test') await asyncio.sleep(0.1) @asynctest async def test_rekey_time_disabled(self): """Test SSH re-keying by time being disabled""" async with self.connect(rekey_seconds=None): pass @asynctest async def test_invalid_rekey_seconds(self): """Test invalid rekey seconds""" with self.assertRaises(ValueError): await self.connect(rekey_seconds=-1) @asynctest async def test_kex_in_progress(self): """Test starting SSH key exchange while it is in progress""" with patch('asyncssh.connection.SSHClientConnection', _ReplayKexClientConnection): conn = await self.connect() conn.replay_kex() conn.replay_kex() await conn.wait_closed() @asynctest async def test_no_matching_kex_algs(self): """Test no matching key exchange algorithms""" conn = await self.connect() conn.send_packet(MSG_KEXINIT, os.urandom(16), NameList([b'xxx']), NameList([]), NameList([]), NameList([]), NameList([]), NameList([]), NameList([]), NameList([]), NameList([]), NameList([]), Boolean(False), UInt32(0)) await conn.wait_closed() @asynctest async def test_no_matching_host_key_algs(self): """Test no matching server host key algorithms""" conn = await self.connect() conn.send_packet(MSG_KEXINIT, os.urandom(16), NameList([b'ecdh-sha2-nistp521']), NameList([b'xxx']), NameList([]), NameList([]), NameList([]), NameList([]), NameList([]), NameList([]), NameList([]), NameList([]), Boolean(False), UInt32(0)) await conn.wait_closed() @asynctest async def test_invalid_newkeys(self): """Test invalid new keys request""" conn = await self.connect() conn.send_packet(MSG_NEWKEYS) await conn.wait_closed() @asynctest async def test_kex_after_kex_complete(self): """Test kex request when kex not in progress""" conn = await self.connect() conn.send_packet(MSG_KEX_FIRST) await conn.wait_closed() @asynctest async def test_userauth_after_auth_complete(self): """Test userauth request when auth not in progress""" conn = await self.connect() conn.send_packet(MSG_USERAUTH_FIRST) await conn.wait_closed() @asynctest async def test_userauth_before_kex_complete(self): """Test receiving userauth before kex is complete""" def send_newkeys(self, k, h): """Finish a key exchange and send a new keys message""" self._kex_complete = True self.send_packet(MSG_USERAUTH_REQUEST, String('guest'), String('ssh-connection'), String('none')) asyncssh.connection.SSHConnection.send_newkeys(self, k, h) with patch('asyncssh.connection.SSHClientConnection.send_newkeys', send_newkeys): with self.assertRaises(asyncssh.ProtocolError): await self.connect() @asynctest async def test_invalid_userauth_service(self): """Test invalid service in userauth request""" conn = await self.connect() conn.send_packet(MSG_USERAUTH_REQUEST, String('guest'), String('xxx'), String('none')) await conn.wait_closed() @asynctest async def
test_no_local_username(self): """Test failure to get local username""" def _failing_getuser(): raise KeyError with patch('getpass.getuser', _failing_getuser): with self.assertRaises(ValueError): await self.connect() @asynctest async def test_invalid_username(self): """Test invalid username in userauth request""" conn = await self.connect() conn.send_packet(MSG_USERAUTH_REQUEST, String(b'\xff'), String('ssh-connection'), String('none')) await conn.wait_closed() @asynctest async def test_username_too_long(self): """Test username being too long in userauth request""" with self.assertRaises(asyncssh.IllegalUserName): await self.connect(username=2048*'a') @asynctest async def test_extra_userauth_request(self): """Test userauth request after auth is complete""" async with self.connect() as conn: conn.send_packet(MSG_USERAUTH_REQUEST, String('guest'), String('ssh-connection'), String('none')) await asyncio.sleep(0.1) @asynctest async def test_late_userauth_request(self): """Test userauth request after auth is final""" async with self.connect() as conn: conn.send_packet(MSG_GLOBAL_REQUEST, String('xxx'), Boolean(False)) conn.send_packet(MSG_USERAUTH_REQUEST, String('guest'), String('ssh-connection'), String('none')) await conn.wait_closed() @asynctest async def test_unexpected_userauth_success(self): """Test unexpected userauth success response""" conn = await self.connect() conn.send_packet(MSG_USERAUTH_SUCCESS) await conn.wait_closed() @asynctest async def test_unexpected_userauth_failure(self): """Test unexpected userauth failure response""" conn = await self.connect() conn.send_packet(MSG_USERAUTH_FAILURE, NameList([]), Boolean(False)) await conn.wait_closed() @asynctest async def test_unexpected_userauth_banner(self): """Test unexpected userauth banner""" conn = await self.connect() conn.send_packet(MSG_USERAUTH_BANNER, String(''), String('')) await conn.wait_closed() @asynctest async def test_invalid_global_request(self): """Test invalid global request""" conn = await self.connect() conn.send_packet(MSG_GLOBAL_REQUEST, String(b'\xff'), Boolean(True)) await conn.wait_closed() @asynctest async def test_unexpected_global_response(self): """Test unexpected global response""" conn = await self.connect() conn.send_packet(MSG_GLOBAL_REQUEST, String('xxx'), Boolean(True)) await conn.wait_closed() @asynctest async def test_invalid_channel_open(self): """Test invalid channel open request""" conn = await self.connect() conn.send_packet(MSG_CHANNEL_OPEN, String(b'\xff'), UInt32(0), UInt32(0), UInt32(0)) await conn.wait_closed() @asynctest async def test_unknown_channel_type(self): """Test unknown channel open type""" conn = await self.connect() conn.send_packet(MSG_CHANNEL_OPEN, String('xxx'), UInt32(0), UInt32(0), UInt32(0)) await conn.wait_closed() @asynctest async def test_invalid_channel_open_confirmation_number(self): """Test invalid channel number in open confirmation""" conn = await self.connect() conn.send_packet(MSG_CHANNEL_OPEN_CONFIRMATION, UInt32(0xff), UInt32(0), UInt32(0), UInt32(0)) await conn.wait_closed() @asynctest async def test_invalid_channel_open_failure_number(self): """Test invalid channel number in open failure""" conn = await self.connect() conn.send_packet(MSG_CHANNEL_OPEN_FAILURE, UInt32(0xff), UInt32(0), String(''), String('')) await conn.wait_closed() @asynctest async def test_invalid_channel_open_failure_reason(self): """Test invalid reason in channel open failure""" conn = await self.connect() conn.send_packet(MSG_CHANNEL_OPEN_FAILURE, UInt32(0), UInt32(0),
String(b'\xff'), String('')) await conn.wait_closed() @asynctest async def test_invalid_channel_open_failure_language(self): """Test invalid language in channel open failure""" conn = await self.connect() conn.send_packet(MSG_CHANNEL_OPEN_FAILURE, UInt32(0), UInt32(0), String(''), String(b'\xff')) await conn.wait_closed() @asynctest async def test_missing_data_channel_number(self): """Test missing channel number in channel data message""" conn = await self.connect() conn.send_packet(MSG_CHANNEL_DATA) await conn.wait_closed() @asynctest async def test_invalid_data_channel_number(self): """Test invalid channel number in channel data message""" conn = await self.connect() conn.send_packet(MSG_CHANNEL_DATA, UInt32(99), String('')) await conn.wait_closed() @asynctest async def test_internal_error(self): """Test internal error in client callback""" with self.assertRaises(RuntimeError): await self.create_connection(_InternalErrorClient) @patch_extra_kex class _TestConnectionNoStrictKex(ServerTestCase): """Unit tests for connection API with ext info and strict kex disabled""" @classmethod async def start_server(cls): """Start an SSH server to connect to""" return (await cls.create_server(_TunnelServer, gss_host=(), compression_algs='*', encryption_algs='*', kex_algs='*', mac_algs='*')) @asynctest async def test_skip_ext_info(self): """Test not requesting extension info from the server""" async with self.connect(): pass @asynctest async def test_message_before_kexinit(self): """Test receiving a message before KEXINIT""" def send_packet(self, pkttype, *args, **kwargs): if pkttype == MSG_KEXINIT: self.send_packet(MSG_IGNORE, String(b'')) asyncssh.connection.SSHConnection.send_packet( self, pkttype, *args, **kwargs) with patch('asyncssh.connection.SSHClientConnection.send_packet', send_packet): async with self.connect(): pass @asynctest async def test_message_during_kex(self): """Test receiving an unexpected message in key exchange""" def send_packet(self, pkttype, *args, **kwargs): if pkttype == MSG_KEX_ECDH_REPLY: self.send_packet(MSG_IGNORE, String(b'')) asyncssh.connection.SSHConnection.send_packet( self, pkttype, *args, **kwargs) with patch('asyncssh.connection.SSHServerConnection.send_packet', send_packet): async with self.connect(): pass @asynctest async def test_sequence_wrap_during_kex(self): """Test sequence wrap during initial key exchange""" def send_packet(self, pkttype, *args, **kwargs): if pkttype == MSG_KEXINIT: if self._options.command == 'send': self._send_seq = 0xfffffffe else: self._recv_seq = 0xfffffffe asyncssh.connection.SSHConnection.send_packet( self, pkttype, *args, **kwargs) with patch('asyncssh.connection.SSHClientConnection.send_packet', send_packet): with self.assertRaises(asyncssh.ProtocolError): await self.connect(command='send') with self.assertRaises(asyncssh.ProtocolError): await self.connect(command='recv') class _TestConnectionHostKeysHandler(ServerTestCase): """Unit test for specifying a host keys handler""" @classmethod async def start_server(cls): """Start an SSH server to connect to""" return (await cls.create_server( server_host_keys=['skey', 'skey_ecdsa'], send_server_host_keys=True)) async def _check_host_keys(self, host_keys, known_hosts, expected): """Check server host keys handler""" def host_keys_handler(*results): """Check reported host keys against expected value""" self.assertEqual([len(r) for r in results], expected) conn.close() async def async_host_keys_handler(*results): """Check async version of server host keys handler""" 
host_keys_handler(*results) self._server.update(server_host_keys=host_keys) conn = await self.connect(server_host_keys_handler=host_keys_handler, known_hosts=known_hosts) if expected is None: await asyncio.sleep(0.1) conn.close() await conn.wait_closed() if expected: conn = await self.connect( server_host_keys_handler=async_host_keys_handler, known_hosts=known_hosts) await conn.wait_closed() @asynctest async def test_host_key_handler_disabled(self): """Test server host keys handler being disabled""" async with self.connect(): await asyncio.sleep(0.1) @asynctest async def test_host_key_added(self): """Test server host keys handler showing a key added""" await self._check_host_keys(['skey', 'skey_ecdsa'], [['skey'], [], []], [1, 0, 1, 0]) @asynctest async def test_host_key_removed(self): """Test server host keys handler showing a key removed""" await self._check_host_keys(['skey'], [['skey', 'skey_ecdsa'], [], []], [0, 1, 1, 0]) @asynctest async def test_host_key_revoked(self): """Test server host keys handler showing a key revoked""" await self._check_host_keys(['skey', 'skey_ecdsa'], [['skey'], [], ['skey_ecdsa']], [0, 0, 1, 1]) @asynctest async def test_no_trusted_hosts(self): """Test server host keys handler is disabled due to no trusted hosts""" await self._check_host_keys(['skey'], None, None) @asynctest async def test_host_key_bad_signature(self): """Test server host keys handler getting back a bad signature""" with patch('asyncssh.connection.SSHServerConnection', _BadSignatureServerConnection): await self._check_host_keys(['skey', 'skey_ecdsa'], [['skey'], [], []], [0, 0, 1, 0]) @asynctest async def test_host_key_prove_failed(self): """Test server host keys handler getting back a prove failure""" with patch('asyncssh.connection.SSHServerConnection', _ProveFailedServerConnection): await self._check_host_keys(['skey', 'skey_ecdsa'], [['skey'], [], []], [0, 0, 1, 0]) class _TestConnectionListenSock(ServerTestCase): """Unit test for specifying a listen socket""" @classmethod async def start_server(cls): """Start an SSH server to connect to""" sock = socket.socket() sock.bind(('', 0)) return await cls.create_server(_TunnelServer, sock=sock) @asynctest async def test_connect(self): """Test specifying explicit listen sock""" with self.assertLogs(level='INFO'): async with self.connect(): pass class _TestConnectionAsyncAcceptor(ServerTestCase): """Unit test for async acceptor""" @classmethod async def start_server(cls): """Start an SSH server to connect to""" async def acceptor(conn): """Async acceptor for SSH connections""" conn.logger.info('Acceptor called') return (await cls.create_server(_TunnelServer, gss_host=(), acceptor=acceptor)) @asynctest async def test_connect(self): """Test acceptor""" with self.assertLogs(level='INFO'): async with self.connect(): pass @patch_gss class _TestConnectionServerCerts(ServerTestCase): """Unit tests for AsyncSSH server using server_certs argument""" @classmethod async def start_server(cls): """Start an SSH server to connect to""" return (await cls.create_server(_TunnelServer, gss_host=(), compression_algs='*', encryption_algs='*', kex_algs='*', mac_algs='*', server_host_keys='skey', server_host_certs='skey-cert.pub')) @asynctest async def test_connect(self): """Test connecting with async context manager""" async with self.connect(known_hosts=([], ['skey.pub'], [])): pass class _TestConnectionReverse(ServerTestCase): """Unit test for reverse direction connections""" @classmethod async def start_server(cls): """Start an SSH listener which opens SSH client
connections""" def acceptor(conn): """Acceptor for reverse-direction SSH connections""" conn.logger.info('Reverse acceptor called') return await cls.listen_reverse(acceptor=acceptor) @asynctest async def test_connect_reverse(self): """Test reverse direction SSH connection""" with self.assertLogs(level='INFO'): async with self.connect_reverse(): pass @asynctest async def test_connect_reverse_sock(self): """Test reverse connection using an already-connected socket""" sock = socket.socket() await self.loop.sock_connect(sock, (self._server_addr, self._server_port)) async with self.connect_reverse(sock=sock): pass @asynctest async def test_run_server(self): """Test running an SSH server on an already-connected socket""" sock = socket.socket() await self.loop.sock_connect(sock, (self._server_addr, self._server_port)) async with self.run_server(sock): pass @unittest.skipUnless(nc_available, 'Netcat not available') @asynctest async def test_connect_reverse_proxy(self): """Test reverse direction SSH connection with proxy command""" proxy_command = ('nc', str(self._server_addr), str(self._server_port)) async with self.connect_reverse(proxy_command=proxy_command): pass @asynctest async def test_connect_reverse_options(self): """Test reverse direction SSH connection with options""" async with self.connect_reverse(passphrase=None): pass @asynctest async def test_connect_reverse_no_server_host_keys(self): """Test starting a reverse direction connection with no host keys""" with self.assertRaises(ValueError): await self.connect_reverse(server_host_keys=[]) class _TestConnectionReverseAsyncAcceptor(ServerTestCase): """Unit test for reverse direction connections with async acceptor""" @classmethod async def start_server(cls): """Start an SSH listener which opens SSH client connections""" async def acceptor(conn): """Acceptor for reverse-direction SSH connections""" conn.logger.info('async acceptor called') return await cls.listen_reverse(acceptor=acceptor) @asynctest async def test_connect_reverse_async_acceptor(self): """Test reverse direction SSH connection with async acceptor""" with self.assertLogs(level='INFO'): async with self.connect_reverse(): pass class _TestConnectionReverseFailed(ServerTestCase): """Unit test for reverse direction connection failure""" @classmethod async def start_server(cls): """Start an SSH listener which opens SSH client connections""" def err_handler(conn, _exc): """Error handler for failed SSH handshake""" conn.logger.info('Error handler called') return (await cls.listen_reverse(username='user', error_handler=err_handler)) @asynctest async def test_connect_failed(self): """Test starting a reverse direction connection which fails""" with self.assertLogs(level='INFO'): with self.assertRaises(asyncssh.ConnectionLost): await self.connect_reverse(authorized_client_keys=[]) class _TestConnectionKeepalive(ServerTestCase): """Unit test for keepalive""" @classmethod async def start_server(cls): """Start an SSH server which sends keepalive messages""" return await cls.create_server(keepalive_interval=0.1, keepalive_count_max=3) @asynctest async def test_server_keepalive(self): """Test sending keepalive""" with patch('asyncssh.connection.SSHClientConnection', _KeepaliveClientConnection): conn = await self.connect() await conn.wait_closed() @asynctest async def test_server_keepalive_failure(self): """Test server keepalive failure""" with patch('asyncssh.connection.SSHClientConnection', _KeepaliveClientConnectionFailure): conn = await self.connect() await conn.wait_closed() class 
_TestConnectionAbort(ServerTestCase): """Unit test for connection abort""" @classmethod async def start_server(cls): """Start an SSH server which aborts connections during auth""" return await cls.create_server(_AbortServer) @asynctest async def test_abort(self): """Test connection abort""" with self.assertRaises(asyncssh.ConnectionLost): await self.connect() class _TestDuringAuth(ServerTestCase): """Unit test for operations during auth""" @classmethod async def start_server(cls): """Start an SSH server which aborts connections during auth""" return await cls.create_server(_CloseDuringAuthServer) @asynctest async def test_close_during_auth(self): """Test connection close during long auth callback""" with self.assertRaises(asyncio.TimeoutError): await asyncio.wait_for(self.connect(username='user', password=''), 0.5) @asynctest async def test_request_during_auth(self): """Test sending a request prior to auth complete""" with self.assertRaises(asyncssh.ProtocolError): await self.create_connection(_PreAuthRequestClient, username='user', compression_algs=['none']) @unittest.skipUnless(x509_available, 'X.509 not available') class _TestServerX509Self(ServerTestCase): """Unit test for server with self-signed X.509 host certificate""" @classmethod async def start_server(cls): """Start an SSH server with a self-signed X.509 host certificate""" return await cls.create_server(server_host_keys=['skey_x509_self']) @asynctest async def test_connect_x509_self(self): """Test connecting with X.509 self-signed certificate""" async with self.connect(): pass @asynctest async def test_connect_x509_untrusted_self(self): """Test connecting with untrusted X.509 self-signed certificate""" with self.assertRaises(asyncssh.HostKeyNotVerifiable): await self.connect(x509_trusted_certs='root_ca_cert.pem') @asynctest async def test_connect_x509_revoked_self(self): """Test connecting with revoked X.509 self-signed certificate""" with self.assertRaises(asyncssh.HostKeyNotVerifiable): await self.connect(known_hosts=([], [], [], ['root_ca_cert.pem'], ['skey_x509_self.pem'], [], [])) @asynctest async def test_connect_x509_trusted_subject(self): """Test connecting to server with trusted X.509 subject name""" async with self.connect(known_hosts=([], [], [], [], [], ['OU=name'], ['OU=name1']), x509_trusted_certs=['skey_x509_self.pem']): pass @asynctest async def test_connect_x509_untrusted_subject(self): """Test connecting to server with untrusted X.509 subject name""" with self.assertRaises(asyncssh.HostKeyNotVerifiable): await self.connect(known_hosts=([], [], [], [], [], ['OU=name1'], []), x509_trusted_certs=['skey_x509_self.pem']) @asynctest async def test_connect_x509_revoked_subject(self): """Test connecting to server with revoked X.509 subject name""" with self.assertRaises(asyncssh.HostKeyNotVerifiable): await self.connect(known_hosts=([], [], [], [], [], [], ['OU=name']), x509_trusted_certs=['skey_x509_self.pem']) @asynctest async def test_connect_x509_disabled(self): """Test connecting to X.509 server with X.509 disabled""" with self.assertRaises(asyncssh.HostKeyNotVerifiable): await self.connect(known_hosts=([], [], [], [], [], ['OU=name'], []), x509_trusted_certs=None) @unittest.skipIf(sys.platform == 'win32', 'skip chmod tests on Windows') @asynctest async def test_trusted_x509_certs_not_readable(self): """Test connecting with default trusted X509 cert file not readable""" try: os.chmod(os.path.join('.ssh', 'ca-bundle.crt'), 0) with self.assertRaises(asyncssh.HostKeyNotVerifiable): await self.connect() finally: 
os.chmod(os.path.join('.ssh', 'ca-bundle.crt'), 0o644) @unittest.skipUnless(x509_available, 'X.509 not available') class _TestServerX509Chain(ServerTestCase): """Unit test for server with X.509 host certificate chain""" @classmethod async def start_server(cls): """Start an SSH server with an X.509 host certificate chain""" return await cls.create_server(server_host_keys=['skey_x509_chain']) @asynctest async def test_connect_x509_chain(self): """Test connecting with X.509 certificate chain""" async with self.connect(x509_trusted_certs='root_ca_cert.pem'): pass @asynctest async def test_connect_x509_chain_cert_path(self): """Test connecting with X.509 certificate and certificate path""" async with self.connect(x509_trusted_cert_paths=['cert_path'], known_hosts=b'\n'): pass @asynctest async def test_connect_x509_untrusted_root(self): """Test connecting to server with untrusted X.509 root CA""" with self.assertRaises(asyncssh.HostKeyNotVerifiable): await self.connect() @asynctest async def test_connect_x509_untrusted_root_cert_path(self): """Test connecting to server with untrusted X.509 root CA""" with self.assertRaises(asyncssh.HostKeyNotVerifiable): await self.connect(known_hosts=b'\n') @asynctest async def test_connect_x509_revoked_intermediate(self): """Test connecting to server with revoked X.509 intermediate CA""" with self.assertRaises(asyncssh.HostKeyNotVerifiable): await self.connect(known_hosts=([], [], [], ['root_ca_cert.pem'], ['int_ca_cert.pem'], [], [])) @asynctest async def test_connect_x509_openssh_known_hosts_trusted(self): """Test connecting with OpenSSH cert in known hosts trusted list""" with self.assertRaises(ValueError): await self.connect(known_hosts=[[], [], [], 'skey-cert.pub', [], [], []]) @asynctest async def test_connect_x509_openssh_known_hosts_revoked(self): """Test connecting with OpenSSH cert in known hosts revoked list""" with self.assertRaises(ValueError): await self.connect(known_hosts=[[], [], [], [], 'skey-cert.pub', [], []]) @asynctest async def test_connect_x509_openssh_x509_trusted(self): """Test connecting with OpenSSH cert in X.509 trusted certs list""" with self.assertRaises(ValueError): await self.connect(x509_trusted_certs='skey-cert.pub') @asynctest async def test_invalid_x509_path(self): """Test passing in invalid trusted X.509 certificate path""" with self.assertRaises(ValueError): await self.connect(x509_trusted_cert_paths='xxx') @unittest.skipUnless(gss_available, 'GSS not available') @patch_gss class _TestServerNoHostKey(ServerTestCase): """Unit test for server with no server host key""" @classmethod async def start_server(cls): """Start an SSH server which sets no server host keys""" return await cls.create_server(server_host_keys=None, gss_host='1') @asynctest async def test_gss_with_no_host_key(self): """Test GSS key exchange with no server host key specified""" async with self.connect(known_hosts=b'\n', gss_host='1', x509_trusted_certs=None, x509_trusted_cert_paths=None): pass @asynctest async def test_dh_with_no_host_key(self): """Test failure of DH key exchange with no server host key specified""" with self.assertRaises(asyncssh.KeyExchangeFailed): await self.connect() @patch('asyncssh.connection.SSHClientConnection', _CheckAlgsClientConnection) class _TestServerWithoutCert(ServerTestCase): """Unit tests with a server that advertises a host key instead of a cert""" @classmethod async def start_server(cls): """Start an SSH server to connect to""" return await cls.create_server(server_host_keys=[('skey', None)]) @asynctest async def 
test_validate_host_key_callback(self): """Test callback to validate server host key""" def client_factory(): """Return an SSHClient which can validate the server host key""" return _ValidateHostKeyClient(host_key='skey.pub') conn, _ = await self.create_connection(client_factory, known_hosts=([], [], [])) async with conn: pass @asynctest async def test_validate_host_key_callback_with_algs(self): """Test callback to validate server host key with alg list""" def client_factory(): """Return an SSHClient which can validate the server host key""" return _ValidateHostKeyClient(host_key='skey.pub') conn, _ = await self.create_connection( client_factory, known_hosts=([], [], []), server_host_key_algs=['rsa-sha2-256']) async with conn: pass @asynctest async def test_default_server_host_keys(self): """Test validation with default server host key algs""" def client_factory(): """Return an SSHClient which can validate the server host key""" return _ValidateHostKeyClient(host_key='skey.pub') default_algs = get_default_x509_certificate_algs() + \ get_default_certificate_algs() + \ get_default_public_key_algs() conn, _ = await self.create_connection(client_factory, known_hosts=([], [], []), server_host_key_algs='default') async with conn: self.assertEqual(conn.get_server_host_key_algs(), default_algs) @asynctest async def test_untrusted_known_hosts_key(self): """Test untrusted server host key""" with self.assertRaises(asyncssh.HostKeyNotVerifiable): await self.connect(known_hosts=(['ckey.pub'], [], [])) @asynctest async def test_known_hosts_none_with_key(self): """Test disabled known hosts checking with server host key""" async with self.connect(known_hosts=None): pass class _TestHostKeyAlias(ServerTestCase): """Unit test for HostKeyAlias""" @classmethod async def start_server(cls): """Start an SSH server to connect to""" skey = asyncssh.read_private_key('skey') skey_cert = skey.generate_host_certificate( skey, 'name', principals=['certifiedfakehost']) skey_cert.write_certificate('skey-cert.pub') return await cls.create_server(server_host_keys=['skey']) @classmethod async def asyncSetUpClass(cls): """Set up keys, custom host cert, and suitable known_hosts""" await super().asyncSetUpClass() skey_str = Path('skey.pub').read_text() Path('.ssh/known_hosts').write_text( f"fakehost {skey_str}" f"@cert-authority certifiedfakehost {skey_str}") Path('.ssh/config').write_text( 'Host server-with-key-config\n' ' Hostname 127.0.0.1\n' ' HostKeyAlias fakehost\n' '\n' 'Host server-with-cert-config\n' ' Hostname 127.0.0.1\n' ' HostKeyAlias certifiedfakehost\n') @asynctest async def test_host_key_mismatch(self): """Test host key mismatch""" with self.assertRaises(asyncssh.HostKeyNotVerifiable): await self.connect() @asynctest async def test_host_key_unknown(self): """Test unknown host key alias""" with self.assertRaises(asyncssh.HostKeyNotVerifiable): await self.connect(host_key_alias='unknown') @asynctest async def test_host_key_match(self): """Test host key match""" async with self.connect(host_key_alias='fakehost'): pass @asynctest async def test_host_cert_match(self): """Test host cert match""" async with self.connect(host_key_alias='certifiedfakehost'): pass @asynctest async def test_host_key_match_config(self): """Test host key match using HostKeyAlias in config file""" async with self.connect('server-with-key-config'): pass @asynctest async def test_host_cert_match_config(self): """Test host cert match using HostKeyAlias in config file""" async with self.connect('server-with-cert-config'): pass class
_TestServerInternalError(ServerTestCase): """Unit test for server internal error during auth""" @classmethod async def start_server(cls): """Start an SSH server which raises an error during auth""" return await cls.create_server(_InternalErrorServer) @asynctest async def test_server_internal_error(self): """Test server internal error during auth""" with self.assertRaises(asyncssh.ChannelOpenError): conn = await self.connect() conn.send_debug('Test') await conn.run() class _TestInvalidAuthBanner(ServerTestCase): """Unit test for invalid auth banner""" @classmethod async def start_server(cls): """Start an SSH server which sends invalid auth banner""" return await cls.create_server(_InvalidAuthBannerServer) @asynctest async def test_invalid_auth_banner(self): """Test server sending invalid auth banner""" with self.assertRaises(asyncssh.ProtocolError): await self.connect() class _TestExpiredServerHostCertificate(ServerTestCase): """Unit tests for expired server host certificate""" @classmethod async def start_server(cls): """Start an SSH server with an expired host certificate""" return await cls.create_server(server_host_keys=['exp_skey']) @asynctest async def test_expired_server_host_cert(self): """Test expired server host certificate""" with self.assertRaises(asyncssh.HostKeyNotVerifiable): await self.connect(known_hosts=([], ['skey.pub'], [])) @asynctest async def test_known_hosts_none_with_expired_cert(self): """Test disabled known hosts checking with expired host certificate""" async with self.connect(known_hosts=None): pass class _TestCustomClientVersion(ServerTestCase): """Unit test for custom SSH client version""" @classmethod async def start_server(cls): """Start an SSH server which sends client version in auth banner""" return await cls.create_server(_VersionReportingServer) async def _check_client_version(self, version): """Check custom client version""" conn, client = \ await self.create_connection(_VersionRecordingClient, client_version=version) async with conn: self.assertEqual(client.reported_version, 'SSH-2.0-custom') @asynctest async def test_custom_client_version(self): """Test custom client version""" await self._check_client_version('custom') @asynctest async def test_custom_client_version_bytes(self): """Test custom client version set as bytes""" await self._check_client_version(b'custom') @asynctest async def test_long_client_version(self): """Test client version which is too long""" with self.assertRaises(ValueError): await self.connect(client_version=246*'a') @asynctest async def test_nonprintable_client_version(self): """Test client version with non-printable character""" with self.assertRaises(ValueError): await self.connect(client_version='xxx\0') class _TestCustomServerVersion(ServerTestCase): """Unit test for custom SSH server version""" @classmethod async def start_server(cls): """Start an SSH server which sends a custom version""" return await cls.create_server(server_version='custom') @asynctest async def test_custom_server_version(self): """Test custom server version""" async with self.connect() as conn: version = conn.get_extra_info('server_version') self.assertEqual(version, 'SSH-2.0-custom') @asynctest async def test_long_server_version(self): """Test server version which is too long""" with self.assertRaises(ValueError): await self.create_server(server_version=246*'a') @asynctest async def test_nonprintable_server_version(self): """Test server version with non-printable character""" with self.assertRaises(ValueError): await 
self.create_server(server_version='xxx\0') @patch_getnameinfo class _TestReverseDNS(ServerTestCase): """Unit test for reverse DNS lookup of client address""" @classmethod async def start_server(cls): """Start an SSH server which does reverse DNS lookups of client addresses""" with open('config', 'w') as f: f.write('Match host localhost\nPubkeyAuthentication no') return await cls.create_server( authorized_client_keys='authorized_keys', rdns_lookup=True, config='config') @asynctest async def test_reverse_dns(self): """Test reverse DNS of the client address""" with self.assertRaises(asyncssh.PermissionDenied): await self.connect(username='ckey') class _TestListenerContextManager(ServerTestCase): """Test using an SSH listener as a context manager""" @classmethod async def start_server(cls): """Defer starting the SSH server to the test""" @asynctest async def test_ssh_listen_context_manager(self): """Test using an SSH listener as a context manager""" async with self.listen() as server: listen_port = server.get_port() async with asyncssh.connect('127.0.0.1', listen_port, known_hosts=(['skey.pub'], [], [])): pass @patch_getaddrinfo class _TestCanonicalizeHost(ServerTestCase): """Test hostname canonicalization""" @classmethod async def start_server(cls): """Start an SSH server to connect to""" return await cls.create_server(_TunnelServer) @asynctest async def test_canonicalize(self): """Test hostname canonicalization""" async with self.connect('testhost', known_hosts=None, canonicalize_hostname=True, canonical_domains=['test']) as conn: self.assertEqual(conn.get_extra_info('host'), 'testhost.test') @asynctest async def test_canonicalize_max_dots(self): """Test hostname canonicalization exceeding max_dots""" async with self.connect('testhost.test', known_hosts=None, canonicalize_hostname=True, canonicalize_max_dots=0, canonical_domains=['test']) as conn: self.assertEqual(conn.get_extra_info('host'), 'testhost.test') @asynctest async def test_canonicalize_ip_address(self): """Test hostname canonicalization with IP address""" async with self.connect('127.0.0.1', known_hosts=None, canonicalize_hostname=True, canonicalize_max_dots=3, canonical_domains=['test']) as conn: self.assertEqual(conn.get_extra_info('host'), '127.0.0.1') @asynctest async def test_canonicalize_proxy(self): """Test hostname canonicalization with proxy""" with open('config', 'w') as f: f.write('UserKnownHostsFile none\n') async with self.connect('testhost', config='config', tunnel=f'localhost:{self._server_port}', canonicalize_hostname=True, canonical_domains=['test']) as conn: self.assertEqual(conn.get_extra_info('host'), 'testhost.test') @asynctest async def test_canonicalize_always(self): """Test hostname canonicalization for all connections""" with open('config', 'w') as f: f.write('UserKnownHostsFile none\n') async with self.connect('testhost', config='config', tunnel=f'localhost:{self._server_port}', canonicalize_hostname='always', canonical_domains=['test']) as conn: self.assertEqual(conn.get_extra_info('host'), 'testhost.test') @asynctest async def test_canonicalize_failure(self): """Test hostname canonicalization failure""" with self.assertRaises(socket.gaierror): await self.connect('unknown', known_hosts=(['skey.pub'], [], []), canonicalize_hostname=True, canonical_domains=['test']) @asynctest async def test_canonicalize_failed_no_fallback(self): """Test hostname canonicalization failure with local fallback disabled""" with self.assertRaises(OSError): await self.connect('unknown', known_hosts=(['skey.pub'], [], []), canonicalize_hostname=True, canonical_domains=['test'],
canonicalize_fallback_local=False) @asynctest async def test_cname_returned(self): """Test hostname canonicalization with cname returned""" async with self.connect('testcname', known_hosts=(['skey.pub'], [], []), canonicalize_hostname=True, canonical_domains=['test'], canonicalize_permitted_cnames= \ [('*.test', '*.test')]) as conn: self.assertEqual(conn.get_extra_info('host'), 'cname.test') @asynctest async def test_cname_not_returned(self): """Test hostname canonicalization with cname not returned""" async with self.connect('testcname', known_hosts=(['skey.pub'], [], []), canonicalize_hostname=True, canonical_domains=['test'], canonicalize_permitted_cnames= \ ['*.xxx:*.test']) as conn: self.assertEqual(conn.get_extra_info('host'), 'testcname.test') @asynctest async def test_bad_cname_rules(self): """Test hostname canonicalization with bad cname rules""" with self.assertRaises(ValueError): await self.connect('testcname', known_hosts=(['skey.pub'], [], []), canonicalize_hostname=True, canonical_domains=['test'], canonicalize_permitted_cnames= \ ['*.xxx:*.test:*.xxx']) asyncssh-2.20.0/tests/test_connection_auth.py000066400000000000000000002277771475467777400214430ustar00rootroot00000000000000# Copyright (c) 2016-2022 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation """Unit tests for AsyncSSH connection authentication""" import asyncio import os import sys import unittest from unittest.mock import patch from cryptography.exceptions import UnsupportedAlgorithm import asyncssh from asyncssh.misc import async_context_manager, write_file from asyncssh.packet import String from asyncssh.public_key import CERT_TYPE_USER, CERT_TYPE_HOST from .keysign_stub import create_subprocess_exec_stub from .server import Server, ServerTestCase from .util import asynctest, gss_available, patch_getnameinfo from .util import patch_getnameinfo_error, patch_gss from .util import make_certificate, nc_available, x509_available class _FailValidateHostSSHServerConnection(asyncssh.SSHServerConnection): """Test error in validating host key signature""" async def validate_host_based_auth(self, username, key_data, client_host, client_username, msg, signature): """Validate host based authentication for the specified host and user""" return await super().validate_host_based_auth(username, key_data, client_host, client_username, msg + b'\xff', signature) class _AsyncGSSServer(asyncssh.SSHServer): """Server for testing async GSS authentication""" # pylint: disable=useless-super-delegation async def validate_gss_principal(self, username, user_principal, host_principal): """Return whether password is valid for this user""" return super().validate_gss_principal(username, user_principal, host_principal) class _NullServer(Server): """Server for testing disabled auth""" async def begin_auth(self, username): """Handle client authentication request""" return False class _HostBasedServer(Server): """Server for testing host-based 
authentication""" def __init__(self, host_key=None, ca_key=None): super().__init__() self._host_key = \ asyncssh.read_public_key(host_key) if host_key else None self._ca_key = \ asyncssh.read_public_key(ca_key) if ca_key else None def host_based_auth_supported(self): """Return whether or not host based authentication is supported""" return True def validate_host_public_key(self, client_host, client_addr, client_port, key): """Return whether key is an authorized key for this host""" # pylint: disable=unused-argument return key == self._host_key def validate_host_ca_key(self, client_host, client_addr, client_port, key): """Return whether key is an authorized CA key for this host""" # pylint: disable=unused-argument return key == self._ca_key def validate_host_based_user(self, username, client_host, client_username): """Return whether remote host and user is authorized for this user""" # pylint: disable=unused-argument return client_username == 'user' class _AsyncHostBasedServer(Server): """Server for testing async host-based authentication""" # pylint: disable=useless-super-delegation async def validate_host_based_user(self, username, client_host, client_username): """Return whether remote host and user is authorized for this user""" return super().validate_host_based_user(username, client_host, client_username) class _InvalidUsernameClientConnection(asyncssh.connection.SSHClientConnection): """Test sending a client username with invalid Unicode to the server""" async def host_based_auth_requested(self): """Return a host key pair, host, and user to authenticate with""" keypair, host, _ = await super().host_based_auth_requested() return keypair, host, b'\xff' class _PublicKeyClient(asyncssh.SSHClient): """Test client public key authentication""" def __init__(self, keylist, delay=0): self._keylist = keylist self._delay = delay async def public_key_auth_requested(self): """Return a public key to authenticate with""" if self._delay: await asyncio.sleep(self._delay) return self._keylist.pop(0) if self._keylist else None class _AsyncPublicKeyClient(_PublicKeyClient): """Test async client public key authentication""" # pylint: disable=useless-super-delegation async def public_key_auth_requested(self): """Return a public key to authenticate with""" return await super().public_key_auth_requested() class _PublicKeyServer(Server): """Server for testing public key authentication""" def __init__(self, client_keys=(), authorized_keys=None, delay=0): super().__init__() self._client_keys = client_keys self._authorized_keys = authorized_keys self._delay = delay def connection_made(self, conn): """Called when a connection is made""" super().connection_made(conn) conn.send_auth_banner('auth banner') async def begin_auth(self, username): """Handle client authentication request""" if self._authorized_keys: self._conn.set_authorized_keys(self._authorized_keys) else: self._client_keys = asyncssh.load_public_keys(self._client_keys) if self._delay: await asyncio.sleep(self._delay) return True def public_key_auth_supported(self): """Return whether or not public key authentication is supported""" return True def validate_public_key(self, username, key): """Return whether key is an authorized client key for this user""" return key in self._client_keys def validate_ca_key(self, username, key): """Return whether key is an authorized CA key for this user""" return key in self._client_keys class _AsyncPublicKeyServer(_PublicKeyServer): """Server for testing async public key authentication""" # pylint: 
disable=useless-super-delegation async def begin_auth(self, username): """Handle client authentication request""" return await super().begin_auth(username) async def validate_public_key(self, username, key): """Return whether key is an authorized client key for this user""" return super().validate_public_key(username, key) async def validate_ca_key(self, username, key): """Return whether key is an authorized CA key for this user""" return super().validate_ca_key(username, key) class _PasswordClient(asyncssh.SSHClient): """Test client password authentication""" def __init__(self, password, old_password, new_password): self._password = password self._old_password = old_password self._new_password = new_password def password_auth_requested(self): """Return a password to authenticate with""" if self._password: result = self._password self._password = None return result else: return None def password_change_requested(self, prompt, lang): """Change the client's password""" return self._old_password, self._new_password class _AsyncPasswordClient(_PasswordClient): """Test async client password authentication""" # pylint: disable=useless-super-delegation async def password_auth_requested(self): """Return a password to authenticate with""" return super().password_auth_requested() async def password_change_requested(self, prompt, lang): """Change the client's password""" return super().password_change_requested(prompt, lang) class _PasswordServer(Server): """Server for testing password authentication""" def password_auth_supported(self): """Enable password authentication""" return True def validate_password(self, username, password): """Accept password of pw, trigger password change on oldpw""" if password == 'oldpw': raise asyncssh.PasswordChangeRequired('Password change required') else: return password == 'pw' def change_password(self, username, old_password, new_password): """Only allow password change from password oldpw""" return old_password == 'oldpw' class _AsyncPasswordServer(_PasswordServer): """Server for testing async password authentication""" # pylint: disable=useless-super-delegation async def validate_password(self, username, password): """Return whether password is valid for this user""" return super().validate_password(username, password) async def change_password(self, username, old_password, new_password): """Handle a request to change a user's password""" return super().change_password(username, old_password, new_password) class _KbdintClient(asyncssh.SSHClient): """Test keyboard-interactive client auth""" def __init__(self, responses): self._responses = responses def kbdint_auth_requested(self): """Return the list of supported keyboard-interactive auth methods""" return '' if self._responses else None def kbdint_challenge_received(self, name, instructions, lang, prompts): """Return responses to a keyboard-interactive auth challenge""" # pylint: disable=unused-argument if not prompts: return [] elif self._responses: result = self._responses self._responses = None return result else: return None class _AsyncKbdintClient(_KbdintClient): """Test keyboard-interactive client auth""" # pylint: disable=useless-super-delegation async def kbdint_auth_requested(self): """Return the list of supported keyboard-interactive auth methods""" return super().kbdint_auth_requested() async def kbdint_challenge_received(self, name, instructions, lang, prompts): """Return responses to a keyboard-interactive auth challenge""" return super().kbdint_challenge_received(name, instructions, lang, prompts) 
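# --- Illustrative sketch, not part of the original asyncssh test suite ---
# The _KbdintClient helpers above exercise asyncssh's keyboard-interactive
# callbacks. The sketch below shows how an application client might use the
# same callbacks; the host name ('localhost'), the canned 'secret' response,
# and known_hosts=None are assumptions for illustration only.

class _ExampleKbdintClient(asyncssh.SSHClient):
    """Minimal keyboard-interactive client (illustrative only)"""

    def kbdint_auth_requested(self):
        # Opt in to keyboard-interactive auth, accepting any submethod
        return ''

    def kbdint_challenge_received(self, name, instructions, lang, prompts):
        # Answer every prompt with a canned response; a real client would
        # prompt the user instead
        return ['secret' for _prompt, _echo in prompts]


async def _example_kbdint_connect():
    """Connect using the illustrative client above (not run by the tests)"""

    conn, _ = await asyncssh.create_connection(
        _ExampleKbdintClient, 'localhost', known_hosts=None)

    async with conn:
        await conn.run('true')
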
class _KbdintServer(Server): """Server for testing keyboard-interactive authentication""" def __init__(self): super().__init__() self._kbdint_round = 0 def kbdint_auth_supported(self): """Enable keyboard-interactive authentication""" return True def get_kbdint_challenge(self, username, lang, submethods): """Return an initial challenge with only instructions""" return '', 'instructions', '', [] def validate_kbdint_response(self, username, responses): """Return a password challenge after the instructions""" if self._kbdint_round == 0: if username == 'none': result = ('', '', '', []) elif username == 'pw': result = ('', '', '', [('Password:', False)]) elif username == 'pc': result = ('', '', '', [('Passcode:', False)]) elif username == 'multi': result = ('', '', '', [('Prompt1:', True), ('Prompt2', True)]) else: result = ('', '', '', [('Other Challenge:', False)]) else: if responses in ([], ['kbdint'], ['1', '2']): result = True else: result = ('', '', '', [('Second Challenge:', True)]) self._kbdint_round += 1 return result class _AsyncKbdintServer(_KbdintServer): """Server for testing async keyboard-interactive authentication""" # pylint: disable=useless-super-delegation async def get_kbdint_challenge(self, username, lang, submethods): """Return a keyboard-interactive auth challenge""" return super().get_kbdint_challenge(username, lang, submethods) async def validate_kbdint_response(self, username, responses): """Return whether the keyboard-interactive response is valid for this user""" return super().validate_kbdint_response(username, responses) class _UnknownAuthClientConnection(asyncssh.connection.SSHClientConnection): """Test getting back an unknown auth method from the SSH server""" def try_next_auth(self, *, next_method=False): """Attempt client authentication using an unknown method""" self._auth_methods = [b'unknown'] + self._auth_methods super().try_next_auth(next_method=next_method) class _TestNullAuth(ServerTestCase): """Unit tests for testing disabled authentication""" @classmethod async def start_server(cls): """Start an SSH server which supports disabled authentication""" return await cls.create_server(_NullServer) @asynctest async def test_get_server_auth_methods(self): """Test getting auth methods from the test server""" auth_methods = await asyncssh.get_server_auth_methods( self._server_addr, self._server_port) self.assertEqual(auth_methods, ['none']) @asynctest async def test_disabled_auth(self): """Test disabled authentication""" async with self.connect(username='user'): pass @asynctest async def test_disabled_trivial_auth(self): """Test disabling trivial auth with no authentication""" with self.assertRaises(asyncssh.PermissionDenied): await self.connect(username='user', disable_trivial_auth=True) @unittest.skipUnless(gss_available, 'GSS not available') @patch_gss class _TestGSSAuth(ServerTestCase): """Unit tests for GSS authentication""" @unittest.skipIf(sys.platform == 'win32', 'skip GSS store test on Windows') @classmethod async def start_server(cls): """Start an SSH server which supports GSS authentication""" return await cls.create_server(_AsyncGSSServer, gss_host='1', gss_store='a') @asynctest async def test_get_server_auth_methods(self): """Test getting auth methods from the test server""" auth_methods = await asyncssh.get_server_auth_methods( self._server_addr, self._server_port) self.assertEqual(auth_methods, ['gssapi-with-mic']) @asynctest async def test_gss_kex_auth(self): """Test GSS key exchange authentication""" async with 
self.connect(kex_algs=['gss-gex-sha256'], username='user', gss_host='1'): pass @asynctest async def test_gss_mic_auth(self): """Test GSS MIC authentication""" async with self.connect(kex_algs=['ecdh-sha2-nistp256'], username='user', gss_host='1'): pass @unittest.skipIf(sys.platform == 'win32', 'skip GSS store test on Windows') @asynctest async def test_gss_mic_auth_store(self): """Test GSS MIC authentication with GSS store set""" async with self.connect(kex_algs=['ecdh-sha2-nistp256'], username='user', gss_host='1', gss_store='a'): pass @asynctest async def test_gss_mic_auth_sign_error(self): """Test GSS MIC authentication signing failure""" with self.assertRaises(asyncssh.PermissionDenied): await self.connect(kex_algs=['ecdh-sha2-nistp256'], username='user', gss_host='1,sign_error') @asynctest async def test_gss_mic_auth_verify_error(self): """Test GSS MIC authentication signature verification failure""" with self.assertRaises(asyncssh.PermissionDenied): await self.connect(kex_algs=['ecdh-sha2-nistp256'], username='user', gss_host='1,verify_error') @asynctest async def test_gss_delegate(self): """Test GSS credential delegation""" async with self.connect(username='user', gss_host='1', gss_delegate_creds=True): pass @asynctest async def test_gss_kex_disabled(self): """Test GSS key exchange being disabled""" with self.assertRaises(asyncssh.PermissionDenied): await self.connect(username='user', gss_host=(), gss_kex=False, preferred_auth='gssapi-keyex') @asynctest async def test_gss_auth_disabled(self): """Test GSS authentication being disabled""" with self.assertRaises(asyncssh.PermissionDenied): await self.connect(username='user', gss_host=(), gss_auth=False) @asynctest async def test_gss_auth_unavailable(self): """Test GSS authentication being unavailable""" with self.assertRaises(asyncssh.PermissionDenied): await self.connect(username='user1', gss_host=()) @asynctest async def test_gss_client_error(self): """Test GSS client error""" with self.assertRaises(asyncssh.PermissionDenied): await self.connect(gss_host='1,init_error', username='user') @asynctest async def test_disabled_trivial_gss_kex_auth(self): """Test disabling trivial auth with GSS key exchange authentication""" async with self.connect(kex_algs=['gss-gex-sha256'], username='user', gss_host='1', disable_trivial_auth=True): pass @asynctest async def test_disabled_trivial_gss_mic_auth(self): """Test disabling trivial auth with GSS MIC authentication""" async with self.connect(kex_algs=['ecdh-sha2-nistp256'], username='user', gss_host='1', disable_trivial_auth=True): pass @unittest.skipUnless(gss_available, 'GSS not available') @patch_gss class _TestGSSServerAuthDisabled(ServerTestCase): """Unit tests for with GSS key exchange and auth disabled on server""" @classmethod async def start_server(cls): """Start an SSH server with GSS key exchange and auth disabled""" return await cls.create_server(gss_host='1', gss_kex=False, gss_auth=False) @asynctest async def test_gss_kex_unavailable(self): """Test GSS key exchange being unavailable""" with self.assertRaises(asyncssh.PermissionDenied): await self.connect(username='user', gss_host=(), preferred_auth='gssapi-keyex') @asynctest async def test_gss_auth_unavailable(self): """Test GSS authentication being unavailable""" with self.assertRaises(asyncssh.PermissionDenied): await self.connect(username='user', gss_host=(), preferred_auth='gssapi-with-mic') @unittest.skipUnless(gss_available, 'GSS not available') @patch_gss class _TestGSSServerError(ServerTestCase): """Unit tests for GSS server 
error""" @classmethod async def start_server(cls): """Start an SSH server which raises an error on GSS authentication""" return await cls.create_server(gss_host='1,init_error') @asynctest async def test_gss_server_error(self): """Test GSS error on server""" with self.assertRaises(asyncssh.PermissionDenied): await self.connect(username='user') @unittest.skipUnless(gss_available, 'GSS not available') @patch_gss class _TestGSSFQDN(ServerTestCase): """Unit tests for GSS server error""" @classmethod async def start_server(cls): """Start an SSH server which raises an error on GSS authentication""" def mock_gethostname(): """Return a non-fully-qualified hostname""" return 'host' def mock_getfqdn(): """Confirm getfqdn is called on relative hostnames""" return '1' with patch('socket.gethostname', mock_gethostname): with patch('socket.getfqdn', mock_getfqdn): return await cls.create_server(gss_host=()) @asynctest async def test_gss_fqdn_lookup(self): """Test GSS FQDN lookup""" async with self.connect(username='user', gss_host=()): pass @patch_getnameinfo class _TestHostBasedAuth(ServerTestCase): """Unit tests for host-based authentication""" @classmethod async def start_server(cls): """Start an SSH server which supports host-based authentication""" return await cls.create_server( _HostBasedServer, known_client_hosts='known_hosts') @asynctest async def test_get_server_auth_methods(self): """Test getting auth methods from the test server""" auth_methods = await asyncssh.get_server_auth_methods( self._server_addr, self._server_port, username='user') self.assertEqual(auth_methods, ['hostbased']) @unittest.skipUnless(nc_available, 'Netcat not available') @asynctest async def test_get_server_auth_methods_no_sockname(self): """Test getting auth methods from the test server""" proxy_command = ('nc', str(self._server_addr), str(self._server_port)) with self.assertRaises(asyncssh.PermissionDenied): await self.connect(username='user', client_host_keys='skey', proxy_command=proxy_command) @asynctest async def test_client_host_auth(self): """Test connecting with host-based authentication""" async with self.connect(username='user', client_host_keys='skey', client_username='user'): pass @asynctest async def test_client_host_auth_disabled(self): """Test connecting with host-based authentication disabled""" with self.assertRaises(asyncssh.PermissionDenied): await self.connect(username='user', client_host_keys='skey', client_username='user', host_based_auth=False) @asynctest async def test_client_host_key_bytes(self): """Test client host key passed in as bytes""" with open('skey', 'rb') as f: skey = f.read() async with self.connect(username='user', client_host_keys=[skey], client_username='user'): pass @asynctest async def test_client_host_key_sshkey(self): """Test client host key passed in as an SSHKey""" skey = asyncssh.read_private_key('skey') async with self.connect(username='user', client_host_keys=[skey], client_username='user'): pass @asynctest async def test_client_host_key_keypairs(self): """Test client host keys passed in as a list of SSHKeyPairs""" keys = asyncssh.load_keypairs('skey') async with self.connect(username='user', client_host_keys=keys, client_username='user'): pass @asynctest async def test_client_host_signature_algs(self): """Test host based authentication with specific signature algorithms""" for alg in ('rsa-sha2-256', 'rsa-sha2-512'): async with self.connect(username='user', client_host_keys='skey', client_username='user', signature_algs=[alg]): pass @asynctest async def 
test_no_server_signature_algs(self): """Test a server which doesn't advertise signature algorithms""" def skip_ext_info(self): """Don't send extension information""" # pylint: disable=unused-argument return [] with patch('asyncssh.connection.SSHConnection._get_extra_kex_algs', skip_ext_info): try: async with self.connect(username='user', client_host_keys='skey', client_username='user'): pass except UnsupportedAlgorithm: # pragma: no cover pass @asynctest async def test_untrusted_client_host_key(self): """Test untrusted client host key""" with self.assertRaises(asyncssh.PermissionDenied): await self.connect(username='user', client_host_keys='ckey', client_username='user') @asynctest async def test_missing_cert(self): """Test missing client host certificate""" with self.assertRaises(OSError): await self.connect(username='user', client_host_keys=[('skey', 'xxx')], client_username='user') @asynctest async def test_invalid_client_host_signature(self): """Test invalid client host signature""" with patch('asyncssh.connection.SSHServerConnection', _FailValidateHostSSHServerConnection): with self.assertRaises(asyncssh.PermissionDenied): await self.connect(username='user', client_host_keys='skey', client_username='user') @asynctest async def test_client_host_trailing_dot(self): """Test stripping of trailing dot from client host""" async with self.connect(username='user', client_host_keys='skey', client_host='localhost.', client_username='user'): pass @asynctest async def test_mismatched_client_host(self): """Test ignoring of mismatched client host due to canonicalization""" async with self.connect(username='user', client_host_keys='skey', client_host='xxx', client_username='user'): pass @asynctest async def test_mismatched_client_username(self): """Test mismatched client username""" with self.assertRaises(asyncssh.PermissionDenied): await self.connect(username='user', client_host_keys='skey', client_username='xxx') @asynctest async def test_invalid_client_username(self): """Test invalid client username""" with patch('asyncssh.connection.SSHClientConnection', _InvalidUsernameClientConnection): with self.assertRaises(asyncssh.ProtocolError): await self.connect(username='user', client_host_keys='skey') @asynctest async def test_expired_cert(self): """Test expired certificate""" ckey = asyncssh.read_private_key('ckey') skey = asyncssh.read_private_key('skey') cert = make_certificate('ssh-rsa-cert-v01@openssh.com', CERT_TYPE_HOST, ckey, skey, ['localhost'], valid_before=1) with self.assertRaises(asyncssh.PermissionDenied): await self.connect(username='user', client_host_keys=[(ckey, cert)], client_username='user') @asynctest async def test_untrusted_ca(self): """Test untrusted CA""" ckey = asyncssh.read_private_key('ckey') cert = make_certificate('ssh-rsa-cert-v01@openssh.com', CERT_TYPE_HOST, ckey, ckey, ['localhost']) with self.assertRaises(asyncssh.PermissionDenied): await self.connect(username='user', client_host_keys=[(ckey, cert)], client_username='user') @asynctest async def test_disabled_trivial_client_host_auth(self): """Test disabling trivial auth with host-based authentication""" with self.assertRaises(asyncssh.PermissionDenied): await self.connect(username='user', client_host_keys='skey', client_username='user', disable_trivial_auth=True) class _TestHostBasedAuthNoRDNS(ServerTestCase): """Unit tests for host-based authentication with no reverse DNS""" @classmethod async def start_server(cls): """Start an SSH server which supports host-based authentication""" return await cls.create_server( 
_HostBasedServer, known_client_hosts='known_hosts') @patch_getnameinfo_error @asynctest async def test_client_host_auth_no_rdns(self): """Test connecting with host-based authentication with no RDNS""" async with self.connect(username='user', client_host_keys='skey', client_username='user'): pass @patch_getnameinfo class _TestCallbackHostBasedAuth(ServerTestCase): """Unit tests for host-based authentication using callback""" @classmethod async def start_server(cls): """Start an SSH server which supports host-based authentication""" def server_factory(): """Return an SSHServer which can validate the client host key""" return _HostBasedServer(host_key='skey.pub', ca_key='skey.pub') return await cls.create_server(server_factory) @asynctest async def test_validate_client_host_callback(self): """Test using callback to validate client host key""" async with self.connect(username='user', client_host_keys=[('skey', None)], client_username='user'): pass @asynctest async def test_validate_client_host_ca_callback(self): """Test using callback to validate client host CA key""" async with self.connect(username='user', client_host_keys='skey', client_username='user'): pass @asynctest async def test_untrusted_client_host_callback(self): """Test callback to validate client host key returning failure""" with self.assertRaises(asyncssh.PermissionDenied): await self.connect(username='user', client_host_keys=[('ckey', None)], client_username='user') @asynctest async def test_untrusted_client_host_ca_callback(self): """Test callback to validate client host CA key returning failure""" with self.assertRaises(asyncssh.PermissionDenied): await self.connect(username='user', client_host_keys='ckey', client_username='user') @patch_getnameinfo class _TestKeysignHostBasedAuth(ServerTestCase): """Unit tests for host-based authentication using ssh-keysign""" @classmethod async def start_server(cls): """Start an SSH server which supports host-based authentication""" return await cls.create_server( _HostBasedServer, known_client_hosts=(['skey_ecdsa.pub'], [], [])) @async_context_manager async def _connect_keysign(self, client_host_keysign=True, client_host_keys=None, keysign_dirs=('.',)): """Open a connection to test host-based auth using ssh-keysign""" with patch('asyncio.create_subprocess_exec', create_subprocess_exec_stub): with patch('asyncssh.keysign._DEFAULT_KEYSIGN_DIRS', keysign_dirs): with patch('asyncssh.public_key._DEFAULT_HOST_KEY_DIRS', ['.']): with patch('asyncssh.public_key._DEFAULT_HOST_KEY_FILES', ['skey_ecdsa', 'xxx']): return await self.connect( username='user', client_host_keysign=client_host_keysign, client_host_keys=client_host_keys, client_username='user') @asynctest async def test_keysign(self): """Test host-based authentication using ssh-keysign""" async with self._connect_keysign(): pass @asynctest async def test_explicit_keysign(self): """Test ssh-keysign with an explicit path""" async with self._connect_keysign(client_host_keysign='.'): pass @asynctest async def test_keysign_explicit_host_keys(self): """Test ssh-keysign with explicit host public keys""" async with self._connect_keysign(client_host_keys='skey_ecdsa.pub'): pass @asynctest async def test_invalid_keysign_response(self): """Test invalid ssh-keysign response""" with patch('asyncssh.keysign.KEYSIGN_VERSION', 0): with self.assertRaises(asyncssh.PermissionDenied): await self._connect_keysign() @asynctest async def test_keysign_error(self): """Test ssh-keysign error response""" with patch('asyncssh.keysign.KEYSIGN_VERSION', 1): with
self.assertRaises(asyncssh.PermissionDenied): await self._connect_keysign() @asynctest async def test_invalid_keysign_version(self): """Test invalid version in ssh-keysign request""" with patch('asyncssh.keysign.KEYSIGN_VERSION', 99): with self.assertRaises(asyncssh.PermissionDenied): await self._connect_keysign() @asynctest async def test_keysign_not_found(self): """Test ssh-keysign executable not being found""" with self.assertRaises(ValueError): await self._connect_keysign(keysign_dirs=()) @asynctest async def test_explicit_keysign_not_found(self): """Test explicit ssh-keysign executable not being found""" with self.assertRaises(ValueError): await self._connect_keysign(client_host_keysign='xxx') @asynctest async def test_keysign_dir_not_present(self): """Test ssh-keysign executable not in a keysign dir""" with self.assertRaises(ValueError): await self._connect_keysign(keysign_dirs=('xxx',)) @patch_getnameinfo class _TestHostBasedAsyncServerAuth(_TestHostBasedAuth): """Unit tests for host-based authentication with async server callbacks""" @classmethod async def start_server(cls): """Start an SSH server which supports async host-based auth""" return await cls.create_server(_AsyncHostBasedServer, known_client_hosts='known_hosts', trust_client_host=True) @asynctest async def test_mismatched_client_host(self): """Test mismatch of trusted client host""" with self.assertRaises(asyncssh.PermissionDenied): await self.connect(username='user', client_host_keys='skey', client_host='xxx', client_username='user') @patch_getnameinfo class _TestLimitedHostBasedSignatureAlgs(ServerTestCase): """Unit tests for limited host key signature algorithms""" @classmethod async def start_server(cls): """Start an SSH server which supports host-based authentication""" return await cls.create_server( _HostBasedServer, known_client_hosts='known_hosts', signature_algs=['ssh-rsa', 'rsa-sha2-512']) @asynctest async def test_mismatched_host_signature_algs(self): """Test mismatched host key signature algorithms""" with self.assertRaises(asyncssh.PermissionDenied): await self.connect(username='ckey', client_host_keys='skey', client_username='user', signature_algs=['rsa-sha2-256']) @asynctest async def test_host_signature_alg_fallback(self): """Test fall back to default host key signature algorithm""" try: async with self.connect(username='ckey', client_host_keys='skey', client_username='user', signature_algs=['rsa-sha2-256', 'ssh-rsa']): pass except UnsupportedAlgorithm: # pragma: no cover pass class _TestPublicKeyAuth(ServerTestCase): """Unit tests for public key authentication""" @classmethod async def start_server(cls): """Start an SSH server which supports public key authentication""" return await cls.create_server( _PublicKeyServer, authorized_client_keys='authorized_keys') @async_context_manager async def _connect_publickey(self, keylist, test_async=False): """Open a connection to test public key auth""" def client_factory(): """Return an SSHClient to use to do public key auth""" cls = _AsyncPublicKeyClient if test_async else _PublicKeyClient return cls(keylist) conn, _ = await self.create_connection(client_factory, username='ckey', client_keys=None) return conn @asynctest async def test_get_server_auth_methods(self): """Test getting auth methods from the test server""" auth_methods = await asyncssh.get_server_auth_methods( self._server_addr, self._server_port) self.assertEqual(auth_methods, ['publickey']) @asynctest async def test_encrypted_client_key(self): """Test public key auth with encrypted client key""" 
async with self.connect(username='ckey', client_keys='ckey_encrypted', passphrase='passphrase'): pass @asynctest async def test_encrypted_client_key_callable(self): """Test public key auth with callable passphrase""" def _passphrase(filename): self.assertEqual(filename, 'ckey_encrypted') return 'passphrase' async with self.connect(username='ckey', client_keys='ckey_encrypted', passphrase=_passphrase): pass @asynctest async def test_encrypted_client_key_awaitable(self): """Test public key auth with awaitable passphrase""" async def _passphrase(filename): self.assertEqual(filename, 'ckey_encrypted') return 'passphrase' async with self.connect(username='ckey', client_keys='ckey_encrypted', passphrase=_passphrase): pass @asynctest async def test_encrypted_client_key_list_callable(self): """Test public key auth with callable passphrase""" def _passphrase(filename): self.assertEqual(filename, 'ckey_encrypted') return 'passphrase' async with self.connect(username='ckey', client_keys=['ckey_encrypted'], passphrase=_passphrase): pass @asynctest async def test_encrypted_client_key_list_awaitable(self): """Test public key auth with awaitable passphrase""" async def _passphrase(filename): self.assertEqual(filename, 'ckey_encrypted') return 'passphrase' async with self.connect(username='ckey', client_keys=['ckey_encrypted'], passphrase=_passphrase): pass @asynctest async def test_encrypted_client_key_bad_passphrase(self): """Test wrong passphrase for encrypted client key""" with self.assertRaises(asyncssh.KeyEncryptionError): await self.connect(username='ckey', client_keys='ckey_encrypted', passphrase='xxx') @asynctest async def test_encrypted_client_key_missing_passphrase(self): """Test missing passphrase for encrypted client key""" with self.assertRaises(asyncssh.KeyImportError): await self.connect(username='ckey', client_keys='ckey_encrypted') @asynctest async def test_client_certs(self): """Test trusted client certificate via client_certs""" async with self.connect(username='ckey', client_keys='ckey', client_certs='ckey-cert.pub'): pass @asynctest async def test_agent_auth(self): """Test connecting with ssh-agent authentication""" if not self.agent_available(): # pragma: no cover self.skipTest('ssh-agent not available') async with self.connect(username='ckey'): pass @asynctest async def test_agent_identities(self): """Test connecting with ssh-agent auth with specific identities""" if not self.agent_available(): # pragma: no cover self.skipTest('ssh-agent not available') ckey = asyncssh.read_private_key('ckey') ckey.write_private_key('ckey.pem', 'pkcs8-pem') ckey_cert = asyncssh.read_certificate('ckey-cert.pub') ckey_ecdsa = asyncssh.read_public_key('ckey_ecdsa.pub') for pubkey in ('ckey-cert.pub', 'ckey_ecdsa.pub', 'ckey.pem', ckey_cert, ckey_ecdsa, ckey_ecdsa.public_data): async with self.connect(username='ckey', agent_identities=pubkey): pass @asynctest async def test_agent_identities_config(self): """Test connecting with ssh-agent auth and IdentitiesOnly config""" if not self.agent_available(): # pragma: no cover self.skipTest('ssh-agent not available') write_file('ckey_err', b'') write_file('config', 'IdentitiesOnly True\n' 'IdentityFile ckey-cert.pub\n' 'IdentityFile ckey_ecdsa.pub\n' 'IdentityFile ckey_err\n', 'w') async with self.connect(username='ckey', config='config'): pass @asynctest async def test_agent_identities_config_default_keys(self): """Test connecting with ssh-agent auth and default IdentitiesOnly""" if not self.agent_available(): # pragma: no cover self.skipTest('ssh-agent not 
available') write_file('config', 'IdentitiesOnly True\n', 'w') async with self.connect(username='ckey', config='config'): pass @asynctest async def test_agent_signature_algs(self): """Test ssh-agent keys with specific signature algorithms""" if not self.agent_available(): # pragma: no cover self.skipTest('ssh-agent not available') for alg in ('rsa-sha2-256', 'rsa-sha2-512'): async with self.connect(username='ckey', signature_algs=[alg]): pass @asynctest async def test_agent_auth_failure(self): """Test failure connecting with ssh-agent authentication""" if not self.agent_available(): # pragma: no cover self.skipTest('ssh-agent not available') with patch.dict(os.environ, HOME='xxx'): with self.assertRaises(asyncssh.PermissionDenied): await self.connect(username='ckey', agent_path='xxx', known_hosts='.ssh/known_hosts') @asynctest async def test_agent_auth_unset(self): """Test connecting with no local keys and no ssh-agent configured""" with patch.dict(os.environ, HOME='xxx', USERPROFILE='xxx', SSH_AUTH_SOCK=''): with self.assertRaises(asyncssh.PermissionDenied): await self.connect(username='ckey', known_hosts='.ssh/known_hosts') @asynctest async def test_public_key_auth(self): """Test connecting with public key authentication""" async with self.connect(username='ckey', client_keys='ckey'): pass @asynctest async def test_public_key_auth_disabled(self): """Test connecting with public key authentication disabled""" with self.assertRaises(asyncssh.PermissionDenied): await self.connect(username='ckey', client_keys='ckey', public_key_auth=False) @asynctest async def test_public_key_auth_not_preferred(self): """Test public key authentication not being in preferred auth list""" with self.assertRaises(asyncssh.PermissionDenied): await self.connect(username='ckey', client_keys='ckey', preferred_auth='password') @asynctest async def test_public_key_signature_algs(self): """Test public key authentication with specific signature algorithms""" for alg in ('rsa-sha2-256', 'rsa-sha2-512'): async with self.connect(username='ckey', agent_path=None, client_keys='ckey', signature_algs=[alg]): pass @asynctest async def test_no_server_signature_algs(self): """Test a server which doesn't advertise signature algorithms""" def skip_ext_info(self): """Don't send extension information""" # pylint: disable=unused-argument return [] with patch('asyncssh.connection.SSHConnection._get_extra_kex_algs', skip_ext_info): try: async with self.connect(username='ckey', client_keys='ckey', agent_path=None): pass except UnsupportedAlgorithm: # pragma: no cover pass @asynctest async def test_default_public_key_auth(self): """Test connecting with default public key authentication""" async with self.connect(username='ckey', agent_path=None): pass @asynctest async def test_invalid_default_key(self): """Test connecting with invalid default client key""" key_path = os.path.join('.ssh', 'id_dsa') with open(key_path, 'w') as f: f.write('-----XXX-----') with self.assertRaises(asyncssh.KeyImportError): await self.connect(username='ckey', agent_path=None) os.remove(key_path) @asynctest async def test_client_key_bytes(self): """Test client key passed in as bytes""" with open('ckey', 'rb') as f: ckey = f.read() async with self.connect(username='ckey', client_keys=[ckey]): pass @asynctest async def test_client_key_sshkey(self): """Test client key passed in as an SSHKey""" ckey = asyncssh.read_private_key('ckey') async with self.connect(username='ckey', client_keys=[ckey]): pass @asynctest async def test_client_key_keypairs(self): """Test client 
keys passed in as a list of SSHKeyPairs""" keys = asyncssh.load_keypairs('ckey') async with self.connect(username='ckey', client_keys=keys): pass @asynctest async def test_client_key_agent_keypairs(self): """Test client keys passed in as a list of SSHAgentKeyPairs""" if not self.agent_available(): # pragma: no cover self.skipTest('ssh-agent not available') async with asyncssh.connect_agent() as agent: for key in await agent.get_keys(): async with self.connect(username='ckey', client_keys=[key]): pass @asynctest async def test_keypair_with_replaced_cert(self): """Test connecting with a keypair with replaced cert""" ckey = asyncssh.load_keypairs(['ckey'])[0] async with self.connect(username='ckey', client_keys=[(ckey, 'ckey-cert.pub')]): pass @asynctest async def test_agent_keypair_with_replaced_cert(self): """Test connecting with an agent key with replaced cert""" if not self.agent_available(): # pragma: no cover self.skipTest('ssh-agent not available') async with asyncssh.connect_agent() as agent: ckey = (await agent.get_keys())[2] async with self.connect(username='ckey', client_keys=[(ckey, 'ckey-cert.pub')]): pass @asynctest async def test_untrusted_client_key(self): """Test untrusted client key""" with self.assertRaises(asyncssh.PermissionDenied): await self.connect(username='ckey', client_keys='skey', agent_path=None) @asynctest async def test_missing_cert(self): """Test missing client certificate""" with self.assertRaises(OSError): await self.connect(username='ckey', client_keys=[('ckey', 'xxx')]) @asynctest async def test_expired_cert(self): """Test expired certificate""" ckey = asyncssh.read_private_key('ckey') skey = asyncssh.read_private_key('skey') cert = make_certificate('ssh-rsa-cert-v01@openssh.com', CERT_TYPE_USER, skey, ckey, ['ckey'], valid_before=1) with self.assertRaises(asyncssh.PermissionDenied): await self.connect(username='ckey', client_keys=[(skey, cert)], agent_path=None) @asynctest async def test_allowed_address(self): """Test allowed address in certificate""" ckey = asyncssh.read_private_key('ckey') skey = asyncssh.read_private_key('skey') cert = make_certificate('ssh-rsa-cert-v01@openssh.com', CERT_TYPE_USER, skey, ckey, ['ckey'], options={'source-address': String('0.0.0.0/0,::/0')}) async with self.connect(username='ckey', client_keys=[(skey, cert)]): pass @asynctest async def test_disallowed_address(self): """Test disallowed address in certificate""" ckey = asyncssh.read_private_key('ckey') skey = asyncssh.read_private_key('skey') cert = make_certificate('ssh-rsa-cert-v01@openssh.com', CERT_TYPE_USER, skey, ckey, ['ckey'], options={'source-address': String('0.0.0.0')}) with self.assertRaises(asyncssh.PermissionDenied): await self.connect(username='ckey', client_keys=[(skey, cert)], agent_path=None) @asynctest async def test_untrusted_ca(self): """Test untrusted CA""" skey = asyncssh.read_private_key('skey') cert = make_certificate('ssh-rsa-cert-v01@openssh.com', CERT_TYPE_USER, skey, skey, ['skey']) with self.assertRaises(asyncssh.PermissionDenied): await self.connect(username='ckey', client_keys=[(skey, cert)], agent_path=None) @asynctest async def test_mismatched_ca(self): """Test mismatched CA""" ckey = asyncssh.read_private_key('ckey') skey = asyncssh.read_private_key('skey') cert = make_certificate('ssh-rsa-cert-v01@openssh.com', CERT_TYPE_USER, skey, skey, ['skey']) with self.assertRaises(ValueError): await self.connect(username='ckey', client_keys=[(ckey, cert)]) @asynctest async def test_callback(self): """Test connecting with public key authentication 
using callback""" async with self._connect_publickey(['ckey'], test_async=True): pass @asynctest async def test_callback_sshkeypair(self): """Test client key passed in as an SSHKeyPair by callback""" if not self.agent_available(): # pragma: no cover self.skipTest('ssh-agent not available') async with asyncssh.connect_agent() as agent: keylist = await agent.get_keys() async with self._connect_publickey(keylist): pass @asynctest async def test_callback_untrusted_client_key(self): """Test failure connecting with public key authentication callback""" with self.assertRaises(asyncssh.PermissionDenied): await self._connect_publickey(['skey']) @asynctest async def test_unknown_auth(self): """Test server returning an unknown auth method before public key""" with patch('asyncssh.connection.SSHClientConnection', _UnknownAuthClientConnection): async with self.connect(username='ckey', client_keys='ckey', agent_path=None): pass @asynctest async def test_disabled_trivial_public_key_auth(self): """Test disabling trivial auth with public key authentication""" async with self.connect(username='ckey', agent_path=None, disable_trivial_auth=True): pass class _TestPublicKeyAsyncServerAuth(_TestPublicKeyAuth): """Unit tests for public key authentication with async server callbacks""" @classmethod async def start_server(cls): """Start an SSH server which supports async public key auth""" def server_factory(): """Return an SSH server which trusts specific client keys""" return _AsyncPublicKeyServer(client_keys=['ckey.pub', 'ckey_ecdsa.pub']) return await cls.create_server(server_factory) class _TestLimitedPublicKeySignatureAlgs(ServerTestCase): """Unit tests for limited public key signature algorithms""" @classmethod async def start_server(cls): """Start an SSH server which supports public key authentication""" return await cls.create_server( _PublicKeyServer, authorized_client_keys='authorized_keys', signature_algs=['ssh-rsa', 'rsa-sha2-512']) @asynctest async def test_mismatched_client_signature_algs(self): """Test mismatched client key signature algorithms""" with self.assertRaises(asyncssh.PermissionDenied): await self.connect(username='ckey', client_keys='ckey', signature_algs=['rsa-sha2-256']) class _TestSetAuthorizedKeys(ServerTestCase): """Unit tests for public key authentication with set_authorized_keys""" @classmethod async def start_server(cls): """Start an SSH server which supports public key authentication""" def server_factory(): """Return an SSH server which calls set_authorized_keys""" return _PublicKeyServer(authorized_keys='authorized_keys') return await cls.create_server(server_factory) @asynctest async def test_set_authorized_keys(self): """Test set_authorized_keys method on server""" async with self.connect(username='ckey', client_keys='ckey'): pass @asynctest async def test_cert_principals(self): """Test certificate principals check""" ckey = asyncssh.read_private_key('ckey') cert = make_certificate('ssh-rsa-cert-v01@openssh.com', CERT_TYPE_USER, ckey, ckey, ['ckey']) async with self.connect(username='ckey', client_keys=[(ckey, cert)]): pass class _TestPreloadedAuthorizedKeys(ServerTestCase): """Unit tests for authentication with pre-loaded authorized keys""" @classmethod async def start_server(cls): """Start an SSH server which supports public key authentication""" def server_factory(): """Return an SSH server which calls set_authorized_keys""" authorized_keys = asyncssh.read_authorized_keys('authorized_keys') return _PublicKeyServer(authorized_keys=authorized_keys) return await 
cls.create_server(server_factory) @asynctest async def test_pre_loaded_authorized_keys(self): """Test pre-loaded authorized keys file""" async with self.connect(username='ckey', client_keys='ckey'): pass class _TestPreloadedAuthorizedKeysFileList(ServerTestCase): """Unit tests with pre-loaded authorized keys file list""" @classmethod async def start_server(cls): """Start an SSH server which supports public key authentication""" def server_factory(): """Return an SSH server which calls set_authorized_keys""" authorized_keys = asyncssh.read_authorized_keys(['authorized_keys']) return _PublicKeyServer(authorized_keys=authorized_keys) return await cls.create_server(server_factory) @asynctest async def test_pre_loaded_authorized_keys(self): """Test pre-loaded authorized keys file list""" async with self.connect(username='ckey', client_keys='ckey'): pass @unittest.skipUnless(x509_available, 'X.509 not available') class _TestX509Auth(ServerTestCase): """Unit tests for X.509 certificate authentication""" @classmethod async def start_server(cls): """Start an SSH server which supports public key authentication""" return await cls.create_server( _PublicKeyServer, authorized_client_keys='authorized_keys_x509') @asynctest async def test_x509_self(self): """Test connecting with X.509 self-signed certificate""" async with self.connect(username='ckey', client_keys=['ckey_x509_self']): pass @asynctest async def test_x509_chain(self): """Test connecting with X.509 certificate chain""" async with self.connect(username='ckey', client_keys=['ckey_x509_chain']): pass @asynctest async def test_keypair_with_x509_cert(self): """Test connecting with a keypair with replaced X.509 cert""" ckey = asyncssh.load_keypairs(['ckey'])[0] async with self.connect(username='ckey', client_keys=[(ckey, 'ckey_x509_chain')]): pass @asynctest async def test_agent_keypair_with_x509_cert(self): """Test connecting with an agent key with replaced X.509 cert""" if not self.agent_available(): # pragma: no cover self.skipTest('ssh-agent not available') async with asyncssh.connect_agent() as agent: ckey = (await agent.get_keys())[2] async with self.connect(username='ckey', client_keys=[(ckey, 'ckey_x509_chain')]): pass @asynctest async def test_x509_incomplete_chain(self): """Test connecting with incomplete X.509 certificate chain""" with self.assertRaises(asyncssh.PermissionDenied): await self.connect(username='ckey', client_keys=[('ckey_x509_chain', 'ckey_x509_partial.pem')]) @asynctest async def test_x509_untrusted_cert(self): """Test connecting with untrusted X.509 certificate chain""" with self.assertRaises(asyncssh.PermissionDenied): await self.connect(username='ckey', client_keys=['skey_x509_chain']) @asynctest async def test_disabled_trivial_x509_auth(self): """Test disabling trivial auth with X.509 certificate authentication""" async with self.connect(username='ckey', client_keys=['ckey_x509_self'], disable_trivial_auth=True): pass @unittest.skipUnless(x509_available, 'X.509 not available') class _TestX509AuthDisabled(ServerTestCase): """Unit tests for disabled X.509 certificate authentication""" @classmethod async def start_server(cls): """Start an SSH server which doesn't support X.509 authentication""" return await cls.create_server( _PublicKeyServer, x509_trusted_certs=None, authorized_client_keys='authorized_keys') @asynctest async def test_failed_x509_auth(self): """Test connect failure with X.509 certificate""" with self.assertRaises(asyncssh.PermissionDenied): await self.connect(username='ckey', 
client_keys=['ckey_x509_self'], signature_algs=['x509v3-ssh-rsa']) @asynctest async def test_non_x509(self): """Test connecting without an X.509 certificate""" async with self.connect(username='ckey', client_keys=['ckey']): pass @unittest.skipUnless(x509_available, 'X.509 not available') class _TestX509Subject(ServerTestCase): """Unit tests for X.509 certificate authentication by subject name""" @classmethod async def start_server(cls): """Start an SSH server which supports public key authentication""" authorized_keys = asyncssh.import_authorized_keys( 'x509v3-ssh-rsa subject=OU=name\n') return await cls.create_server( _PublicKeyServer, authorized_client_keys=authorized_keys, x509_trusted_certs=['ckey_x509_self.pub']) @asynctest async def test_x509_subject(self): """Test authenticating X.509 certificate by subject name""" async with self.connect(username='ckey', client_keys=['ckey_x509_self']): pass @unittest.skipUnless(x509_available, 'X.509 not available') class _TestX509Untrusted(ServerTestCase): """Unit tests for X.509 authentication with no trusted certificates""" @classmethod async def start_server(cls): """Start an SSH server which supports public key authentication""" return await cls.create_server(_PublicKeyServer, authorized_client_keys=None) @asynctest async def test_x509_untrusted(self): """Test untrusted X.509 self-signed certificate""" with self.assertRaises(asyncssh.PermissionDenied): await self.connect(username='ckey', client_keys=['ckey_x509_self']) @unittest.skipUnless(x509_available, 'X.509 not available') class _TestX509Disabled(ServerTestCase): """Unit tests for X.509 authentication with server support disabled""" @classmethod async def start_server(cls): """Start an SSH server with X.509 authentication disabled""" return await cls.create_server(_PublicKeyServer, x509_purposes=None) @asynctest async def test_x509_disabled(self): """Test X.509 client certificate with server support disabled""" with self.assertRaises(asyncssh.PermissionDenied): await self.connect(username='ckey', client_keys='skey_x509_self') class _TestPasswordAuth(ServerTestCase): """Unit tests for password authentication""" @classmethod async def start_server(cls): """Start an SSH server which supports password authentication""" return await cls.create_server(_PasswordServer) @asynctest async def test_get_server_auth_methods(self): """Test getting auth methods from the test server""" auth_methods = await asyncssh.get_server_auth_methods( self._server_addr, self._server_port, username='pw') self.assertEqual(auth_methods, ['keyboard-interactive', 'password']) @async_context_manager async def _connect_password(self, username, password, old_password='', new_password='', disable_trivial_auth=False, test_async=False): """Open a connection to test password authentication""" def client_factory(): """Return an SSHClient to use to do password change""" cls = _AsyncPasswordClient if test_async else _PasswordClient return cls(password, old_password, new_password) conn, _ = await self.create_connection( client_factory, username=username, client_keys=None, disable_trivial_auth=disable_trivial_auth) return conn @asynctest async def test_password_auth(self): """Test connecting with password authentication""" async with self.connect(username='pw', password='pw', client_keys=None): pass @asynctest async def test_password_auth_callable(self): """Test connecting with a callable for password authentication""" async with self.connect(username='pw', password=lambda: 'pw', client_keys=None): pass @asynctest async def 
test_password_auth_async_callable(self): """Test connecting with an async callable for password authentication""" async def get_password(): return 'pw' async with self.connect(username='pw', password=get_password, client_keys=None): pass @asynctest async def test_password_auth_awaitable(self): """Test connecting with an awaitable for password authentication""" async def get_password(): return 'pw' async with self.connect(username='pw', password=get_password(), client_keys=None): pass @asynctest async def test_password_auth_disabled(self): """Test connecting with password authentication disabled""" with self.assertRaises(asyncssh.PermissionDenied): await self.connect(username='pw', password='kbdint', password_auth=False, preferred_auth='password') @asynctest async def test_password_auth_failure(self): """Test failure connecting with password authentication""" with self.assertRaises(asyncssh.PermissionDenied): await self.connect(username='pw', password='badpw', client_keys=None) @asynctest async def test_password_auth_callback(self): """Test connecting with password authentication callback""" async with self._connect_password('pw', 'pw', test_async=True): pass @asynctest async def test_password_auth_callback_failure(self): """Test failure connecting with password authentication callback""" with self.assertRaises(asyncssh.PermissionDenied): await self._connect_password('pw', 'badpw') @asynctest async def test_password_change(self): """Test password change""" async with self._connect_password('pw', 'oldpw', 'oldpw', 'pw', test_async=True): pass @asynctest async def test_password_change_failure(self): """Test failure of password change""" with self.assertRaises(asyncssh.PermissionDenied): await self._connect_password('pw', 'oldpw', 'badpw', 'pw') @asynctest async def test_disabled_trivial_password_auth(self): """Test disabling trivial auth with password authentication""" async with self.connect(username='pw', password='pw', client_keys=None, disable_trivial_auth=True): pass @asynctest async def test_disabled_trivial_password_change(self): """Test disabling trivial auth with password change""" async with self._connect_password('pw', 'oldpw', 'oldpw', 'pw', disable_trivial_auth=True): pass class _TestPasswordAsyncServerAuth(_TestPasswordAuth): """Unit tests for password authentication with async server callbacks""" @classmethod async def start_server(cls): """Start an SSH server which supports async password authentication""" return await cls.create_server(_AsyncPasswordServer) class _TestKbdintAuth(ServerTestCase): """Unit tests for keyboard-interactive authentication""" @classmethod async def start_server(cls): """Start an SSH server which supports keyboard-interactive auth""" return await cls.create_server(_KbdintServer) @asynctest async def test_get_server_auth_methods(self): """Test getting auth methods from the test server""" auth_methods = await asyncssh.get_server_auth_methods( self._server_addr, self._server_port, username='none') self.assertEqual(auth_methods, ['keyboard-interactive']) @async_context_manager async def _connect_kbdint(self, username, responses, test_async=False): """Open a connection to test keyboard-interactive auth""" def client_factory(): """Return an SSHClient to use to do keyboard-interactive auth""" cls = _AsyncKbdintClient if test_async else _KbdintClient return cls(responses) conn, _ = await self.create_connection(client_factory, username=username, client_keys=None) return conn @asynctest async def test_kbdint_auth_no_prompts(self): """Test keyboard-interactive
authentication with no prompts""" async with self.connect(username='none', password='kbdint', client_keys=None): pass @asynctest async def test_kbdint_auth_password(self): """Test keyboard-interactive authentication via password""" async with self.connect(username='pw', password='kbdint', client_keys=None): pass @asynctest async def test_kbdint_auth_passcode(self): """Test keyboard-interactive authentication via passcode""" async with self.connect(username='pc', password='kbdint', client_keys=None): pass @asynctest async def test_kbdint_auth_not_password(self): """Test keyboard-interactive authentication other than password""" with self.assertRaises(asyncssh.PermissionDenied): await self.connect(username='kbdint', password='kbdint', client_keys=None) @asynctest async def test_kbdint_auth_multi_not_password(self): """Test keyboard-interactive authentication with multiple prompts""" with self.assertRaises(asyncssh.PermissionDenied): await self.connect(username='multi', password='kbdint', client_keys=None) @asynctest async def test_kbdint_auth_disabled(self): """Test connecting with keyboard-interactive authentication disabled""" with self.assertRaises(asyncssh.PermissionDenied): await self.connect(username='pw', password='kbdint', kbdint_auth=False) @asynctest async def test_kbdint_auth_failure(self): """Test failure connecting with keyboard-interactive authentication""" with self.assertRaises(asyncssh.PermissionDenied): await self.connect(username='kbdint', password='badpw', client_keys=None) @asynctest async def test_kbdint_auth_callback(self): """Test keyboard-interactive auth callback""" async with self._connect_kbdint('kbdint', ['kbdint'], test_async=True): pass @asynctest async def test_kbdint_auth_callback_multi(self): """Test keyboard-interactive auth callback with multiple challenges""" async with self._connect_kbdint('multi', ['1', '2'], test_async=True): pass @asynctest async def test_kbdint_auth_callback_failure(self): """Test failure connecting with keyboard-interactive auth callback""" with self.assertRaises(asyncssh.PermissionDenied): await self._connect_kbdint('kbdint', ['badpw']) @asynctest async def test_disabled_trivial_kbdint_auth(self): """Test disabled trivial auth with keyboard-interactive auth""" async with self.connect(username='pw', password='kbdint', client_keys=None, disable_trivial_auth=True): pass @asynctest async def test_disabled_trivial_kbdint_no_prompts(self): """Test disabled trivial auth with no keyboard-interactive prompts""" with self.assertRaises(asyncssh.PermissionDenied): await self.connect(username='none', password='kbdint', client_keys=None, disable_trivial_auth=True) class _TestKbdintAsyncServerAuth(_TestKbdintAuth): """Unit tests for keyboard-interactive auth with async server callbacks""" @classmethod async def start_server(cls): """Start an SSH server which supports async kbd-int auth""" return await cls.create_server(_AsyncKbdintServer) class _TestKbdintPasswordServerAuth(ServerTestCase): """Unit tests for keyboard-interactive auth with server password auth""" @classmethod async def start_server(cls): """Start an SSH server which supports server password auth""" return await cls.create_server(_PasswordServer) @async_context_manager async def _connect_kbdint(self, username, responses): """Open a connection to test keyboard-interactive auth""" def client_factory(): """Return an SSHClient to use to do keyboard-interactive auth""" return _KbdintClient(responses) conn, _ = await self.create_connection(client_factory, username=username,
client_keys=None) return conn @asynctest async def test_kbdint_password_auth(self): """Test keyboard-interactive server password authentication""" async with self._connect_kbdint('pw', ['pw']): pass @asynctest async def test_kbdint_password_auth_multiple_responses(self): """Test multiple responses to server password authentication""" with self.assertRaises(asyncssh.PermissionDenied): await self._connect_kbdint('pw', ['xxx', 'yyy']) @asynctest async def test_kbdint_password_change(self): """Test keyboard-interactive server password change""" with self.assertRaises(asyncssh.PermissionDenied): await self._connect_kbdint('pw', ['oldpw']) class _TestClientLoginTimeout(ServerTestCase): """Unit test for client login timeout""" @classmethod async def start_server(cls): """Start an SSH server which supports public key authentication""" def server_factory(): """Return an SSHServer that delays before starting auth""" return _PublicKeyServer(delay=2) return await cls.create_server( server_factory, authorized_client_keys='authorized_keys') @asynctest async def test_client_login_timeout_exceeded(self): """Test client login timeout exceeded""" with self.assertRaises(asyncssh.ConnectionLost): await self.connect(username='ckey', client_keys='ckey', login_timeout=1) @asynctest async def test_client_login_timeout_exceeded_string(self): """Test client login timeout exceeded with string value""" with self.assertRaises(asyncssh.ConnectionLost): await self.connect(username='ckey', client_keys='ckey', login_timeout='0m1s') @asynctest async def test_invalid_client_login_timeout(self): """Test invalid client login timeout""" with self.assertRaises(ValueError): await self.connect(login_timeout=-1) class _TestServerLoginTimeoutExceeded(ServerTestCase): """Unit test for server login timeout""" @classmethod async def start_server(cls): """Start an SSH server with a 1 second login timeout""" return await cls.create_server( _PublicKeyServer, authorized_client_keys='authorized_keys', login_timeout=1) @asynctest async def test_server_login_timeout_exceeded(self): """Test server_login timeout exceeded""" def client_factory(): """Return an SSHClient that delays before providing a key""" return _PublicKeyClient(['ckey'], 2) with self.assertRaises(asyncssh.ConnectionLost): await self.create_connection(client_factory, username='ckey', client_keys=None) class _TestServerLoginTimeoutDisabled(ServerTestCase): """Unit test for disabled server login timeout""" @classmethod async def start_server(cls): """Start an SSH server with no login timeout""" return await cls.create_server( _PublicKeyServer, authorized_client_keys='authorized_keys', login_timeout=None) @asynctest async def test_server_login_timeout_disabled(self): """Test with login timeout disabled""" async with self.connect(username='ckey', client_keys='ckey'): pass asyncssh-2.20.0/tests/test_editor.py000066400000000000000000000434541475467777400175350ustar00rootroot00000000000000# Copyright (c) 2016-2022 by Ron Frederick and others. 
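# The tests in this module drive AsyncSSH's server-side line editor through
# a pseudo-terminal. As a rough client-side usage sketch (not part of the
# test suite; the host, credentials, and helper name below are placeholders),
# editing keys such as backspace are interpreted by the editor before the
# completed line reaches the server session:

import asyncssh as _asyncssh_sketch

async def _editor_usage_sketch():
    """Send an edited line to an AsyncSSH server with the line editor enabled"""

    async with _asyncssh_sketch.connect('localhost', username='user',
                                        password='secret',
                                        known_hosts=None) as conn:
        process = await conn.create_process(term_type='ansi')

        # 'abcd' followed by a backspace is delivered to the session as 'abc'
        process.stdin.write('abcd\x08\n')
        print(await process.stdout.readline())

        process.stdin.write_eof()
        await process.wait()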
# # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation """Unit tests for AsyncSSH line editor""" import asyncio import asyncssh from .server import ServerTestCase from .util import asynctest async def _handle_session(stdin, stdout, stderr): """Accept lines of input and echo them with a prefix""" encoding = stdin.channel.get_encoding()[0] prefix = '>>>' if encoding else b'>>>' data = '' if encoding else b'' while not stdin.at_eof(): try: data += await stdin.readline() except asyncssh.SignalReceived as exc: if exc.signal == 'CLEAR': stdin.channel.clear_input() elif exc.signal == 'ECHO_OFF': # Set twice to get coverage of when echo isn't changing stdin.channel.set_echo(False) stdin.channel.set_echo(False) elif exc.signal == 'ECHO_ON': stdin.channel.set_echo(True) elif exc.signal == 'LINE_OFF': stdin.channel.set_line_mode(False) else: break except asyncssh.BreakReceived: stdin.channel.set_input('BREAK', 0) except asyncssh.TerminalSizeChanged: continue stderr.write('.' if encoding else b'.') stdout.write(prefix + data) stdout.close() async def _handle_ansi_attrs(_stdin, stdout, _stderr): """Output a line which has ANSI attributes in it""" stdout.write('\x1b[2m' + 72*'*' + '\x1b[0m') stdout.close() async def _handle_output_wrap(_stdin, stdout, _stderr): """Output a line which needs to wrap early""" stdout.write(79*'*' + '\uff10') stdout.close() async def _handle_soft_eof(stdin, stdout, _stderr): """Accept soft EOF using read()""" while not stdin.at_eof(): data = await stdin.read() stdout.write(data or 'EOF\n') stdout.close() async def _handle_app_line_echo(stdin, stdout, _stderr): """Perform line echo in the application""" while not stdin.at_eof(): stdout.write('> ') data = await stdin.readline() stdout.write(data) stdout.close() def _trigger_signal(line, pos): """Trigger a signal when Ctrl-Z is input""" # pylint: disable=unused-argument return 'SIG', -1 def _handle_key(_line, pos): """Handle exclamation point being input""" if pos == 0: return 'xyz', 3 elif pos == 1: return True else: return False async def _handle_register(stdin, stdout, _stderr): """Accept input using read() and echo it back""" while not stdin.at_eof(): try: data = await stdin.readline() except asyncssh.SignalReceived: stdout.write('SIGNAL') break if data == 'R\n': stdin.channel.register_key('!', _handle_key) stdin.channel.register_key('"', _handle_key) stdin.channel.register_key('\u2013', _handle_key) stdin.channel.register_key('\x1bOP', _handle_key) stdin.channel.register_key('\x1a', _trigger_signal) elif data == 'U\n': stdin.channel.unregister_key('!') stdin.channel.unregister_key('"') stdin.channel.unregister_key('\u2013') stdin.channel.unregister_key('\x1bOP') stdin.channel.unregister_key('\x1bOQ') # Test unregistered key stdin.channel.unregister_key('\x1b[25~') # Test unregistered prefix stdin.channel.unregister_key('\x1a') stdout.close() class _CheckEditor(ServerTestCase): """Utility functions for AsyncSSH line editor unit tests""" async def 
check_input(self, input_data, expected_result, term_type='ansi', set_width=False): """Feed input data and compare echoed back result""" async with self.connect() as conn: process = await conn.create_process(term_type=term_type) process.stdin.write(input_data) await process.stderr.readexactly(1) if set_width: process.change_terminal_size(132, 24) process.stdin.write_eof() output_data = (await process.wait()).stdout idx = output_data.rfind('>>>') self.assertNotEqual(idx, -1) output_data = output_data[idx+3:] self.assertEqual(output_data, expected_result) class _TestEditor(_CheckEditor): """Unit tests for AsyncSSH line editor""" @classmethod async def start_server(cls): """Start an SSH server for the tests to use""" return await cls.create_server(session_factory=_handle_session) @asynctest async def test_editor(self): """Test line editing""" tests = ( ('Simple line', 'abc\n', 'abc\r\n'), ('EOF', '\x04', ''), ('Erase left', 'abcd\x08\n', 'abc\r\n'), ('Erase left', 'abcd\x08\n', 'abc\r\n'), ('Erase left at beginning', '\x08abc\n', 'abc\r\n'), ('Erase right', 'abcd\x02\x04\n', 'abc\r\n'), ('Erase right at end', 'abc\x04\n', 'abc\r\n'), ('Erase line', 'abcdef\x15abc\n', 'abc\r\n'), ('Erase to end', 'abcdef\x02\x02\x02\x0b\n', 'abc\r\n'), ('Wrapping erase to end', 80*'*' + '\x02\x0b\n', 79*'*' + '\r\n'), ('History previous', 'abc\n\x10\n', 'abc\r\nabc\r\n'), ('History previous at top', '\x10abc\n', 'abc\r\n'), ('History next', 'a\nb\n\x10\x10\x0e\x08c\n', 'a\r\nb\r\nc\r\n'), ('History next to bottom', 'abc\n\x10\x0e\n', 'abc\r\n\r\n'), ('History next at bottom', '\x0eabc\n', 'abc\r\n'), ('Move left', 'abc\x02\n', 'abc\r\n'), ('Move left at beginning', '\x02abc\n', 'abc\r\n'), ('Move left arrow', 'abc\x1b[D\n', 'abc\r\n'), ('Move right', 'abc\x02\x06\n', 'abc\r\n'), ('Move right at end', 'abc\x06\n', 'abc\r\n'), ('Move to start', 'abc\x01\n', 'abc\r\n'), ('Move to end', 'abc\x02\x05\n', 'abc\r\n'), ('Redraw', 'abc\x12\n', 'abc\r\n'), ('Insert erased', 'abc\x15\x19\x19\n', 'abcabc\r\n'), ('Send break', 'abc\x03', 'BREAK'), ('Long line', 100*'*' + '\x02\x01\x05\n', 100*'*' + '\r\n'), ('Wide char wrap', 79*'*' + '\U0001F910\x08\n', 79*'*' + '\r\n'), ('Line length limit', 1024*'*' + '\x05*\n', 1024*'*' + '\r\n'), ('Unknown key', '\x07abc\n', 'abc\r\n') ) for testname, input_data, expected_result in tests: with self.subTest(testname): await self.check_input(input_data, expected_result) @asynctest async def test_non_wrap(self): """Test line editing in non-wrap mode""" tests = ( ('Simple line', 'abc\n', 'abc\r\n'), ('Long line', 100*'*' + '\x02\x01\x05\n', 100*'*' + '\r\n'), ('Long line 2', 101*'*' + 30*'\x02' + '\x08\n', 100*'*' + '\r\n'), ('Redraw', 'abc\x12\n', 'abc\r\n') ) for testname, input_data, expected_result in tests: with self.subTest(testname): await self.check_input(input_data, expected_result, term_type='dumb') @asynctest async def test_no_terminal(self): """Test that editor is disabled when no pseudo-terminal is requested""" await self.check_input('abc\n', 'abc\n', term_type=None) @asynctest async def test_change_width(self): """Test changing the terminal width""" await self.check_input('abc\n', 'abc\r\n', set_width=True) @asynctest async def test_change_width_non_wrap(self): """Test changing the terminal width when not wrapping""" await self.check_input('abc\n', 'abc\r\n', term_type='dumb', set_width=True) @asynctest async def test_editor_clear_input(self): """Test clearing editor's input line""" async with self.connect() as conn: process = await conn.create_process(term_type='ansi') 
process.stdin.write('abc') process.send_signal('CLEAR') await process.stderr.readexactly(1) process.stdin.write('\n') await process.stderr.readexactly(1) process.stdin.write_eof() output_data = (await process.wait()).stdout self.assertEqual(output_data, 'abc\x1b[3D \x1b[3D\r\n>>>\r\n') @asynctest async def test_editor_echo_off(self): """Test editor with echo disabled""" async with self.connect() as conn: process = await conn.create_process(term_type='ansi') process.send_signal('ECHO_OFF') await process.stderr.readexactly(1) process.stdin.write('abcd\x08\n') await process.stderr.readexactly(1) process.stdin.write_eof() output_data = (await process.wait()).stdout self.assertEqual(output_data, '\r\n>>>abc\r\n') @asynctest async def test_editor_echo_on(self): """Test editor with echo re-enabled""" async with self.connect() as conn: process = await conn.create_process(term_type='ansi') process.send_signal('ECHO_OFF') await process.stderr.readexactly(1) process.stdin.write('abc') process.send_signal('ECHO_ON') await process.stderr.readexactly(1) process.stdin.write('d\x08\n') await process.stderr.readexactly(1) process.stdin.write_eof() output_data = (await process.wait()).stdout self.assertEqual(output_data, 'abcd\x1b[1D \x1b[1D\r\n>>>abc\r\n') @asynctest async def test_editor_line_mode_off(self): """Test editor with line mode disabled""" async with self.connect() as conn: process = await conn.create_process(term_type='ansi') process.send_signal('LINE_OFF') await process.stderr.readexactly(1) process.stdin.write('abc\n') await process.stderr.readexactly(1) process.stdin.write_eof() output_data = (await process.wait()).stdout self.assertEqual(output_data, '>>>abc\r\n') @asynctest async def test_unknown_signal(self): """Test unknown signal""" async with self.connect() as conn: process = await conn.create_process(term_type='ansi') process.send_signal('XXX') output_data = (await process.wait()).stdout self.assertEqual(output_data, '>>>') class _TestEditorDisabled(_CheckEditor): """Unit tests for AsyncSSH line editor being disabled""" @classmethod async def start_server(cls): """Start an SSH server for the tests to use""" return (await cls.create_server(session_factory=_handle_session, line_editor=False)) @asynctest async def test_editor_disabled(self): """Test that editor is disabled""" await self.check_input('abc\n', 'abc\n') class _TestEditorEncodingNone(_CheckEditor): """Unit tests for AsyncSSH line editor disabled due to encoding None""" @classmethod async def start_server(cls): """Start an SSH server for the tests to use""" return (await cls.create_server(session_factory=_handle_session, encoding=None)) @asynctest async def test_editor_disabled_encoding_none(self): """Test that editor is disabled when encoding is None""" await self.check_input('abc\n', 'abc\n') @asynctest async def test_change_width(self): """Test changing the terminal width""" await self.check_input('abc\n', 'abc\n', set_width=True) class _TestEditorUnlimitedLength(_CheckEditor): """Unit tests for AsyncSSH line editor with no line length limit""" @classmethod async def start_server(cls): """Start an SSH server for the tests to use""" return await cls.create_server(session_factory=_handle_session, max_line_length=None) @asynctest async def test_editor_unlimited_length(self): """Test that editor can handle very long lines""" await self.check_input(32768*'*' + '\n', 32768*'*' + '\r\n') class _TestEditorANSI(_CheckEditor): """Unit tests for AsyncSSH line editor handling ANSI attributes""" @classmethod async def start_server(cls): 
"""Start an SSH server for the tests to use""" return await cls.create_server(session_factory=_handle_ansi_attrs) @asynctest async def test_editor_ansi(self): """Test that editor properly handles ANSI attributes in output""" async with self.connect() as conn: process = await conn.create_process(term_type='ansi') output_data = (await process.wait()).stdout self.assertEqual(output_data, '\x1b[2m' + 72*'*' + '\x1b[0m') class _TestEditorOutputWrap(_CheckEditor): """Unit tests for AsyncSSH line editor wrapping output text""" @classmethod async def start_server(cls): """Start an SSH server for the tests to use""" return await cls.create_server(session_factory=_handle_output_wrap) @asynctest async def test_editor_output_wrap(self): """Test that editor properly wraps wide characters during output""" async with self.connect() as conn: process = await conn.create_process(term_type='ansi') output_data = (await process.wait()).stdout self.assertEqual(output_data, 79*'*' + '\uff10') class _TestEditorSoftEOF(ServerTestCase): """Unit tests for AsyncSSH line editor sending soft EOF""" @classmethod async def start_server(cls): """Start an SSH server for the tests to use""" return await cls.create_server(session_factory=_handle_soft_eof) @asynctest async def test_editor_soft_eof(self): """Test editor sending soft EOF""" async with self.connect() as conn: process = await conn.create_process(term_type='ansi') process.stdin.write('\x04') self.assertEqual((await process.stdout.readline()), 'EOF\r\n') process.stdin.write('abc\n\x04') self.assertEqual((await process.stdout.readline()), 'abc\r\n') self.assertEqual((await process.stdout.readline()), 'abc\r\n') self.assertEqual((await process.stdout.readline()), 'EOF\r\n') process.stdin.write('abc\n') process.stdin.write_eof() self.assertEqual((await process.stdout.read()), 'abc\r\nabc\r\n') class _TestEditorRegisterKey(ServerTestCase): """Unit tests for AsyncSSH line editor register key callback""" @classmethod async def start_server(cls): """Start an SSH server for the tests to use""" return await cls.create_server(session_factory=_handle_register) @asynctest async def test_editor_register_key(self): """Test editor register key functionality""" async with self.connect() as conn: process = await conn.create_process(term_type='ansi') for inp, result in (('R', 'R'), ('!a', 'xyza'), ('\u2013a', 'xyza'), ('a!b', 'a!b'), ('ab!', 'ab\x07'), ('ab!!', 'ab\x07'), ('\x1bOPa', 'xyza'), ('a\x1bOPb', 'a\x07b'), ('ab\x1bOP', 'ab\x07'), ('U', 'U'), ('!', '!'), ('\x1bOP', '\x07')): process.stdin.write(inp + '\n') self.assertEqual((await process.stdout.readline()), result + '\r\n') process.stdin.write_eof() @asynctest async def test_editor_signal(self): """Test editor register key triggering a signal""" async with self.connect() as conn: process = await conn.create_process(term_type='ansi') process.stdin.write('R\n') await process.stdout.readline() process.stdin.write('\x1a') self.assertEqual((await process.stdout.read()), 'SIGNAL') class _TestEditorLineEcho(_CheckEditor): """Unit tests for AsyncSSH line editor with line echo in application""" @classmethod async def start_server(cls): """Start an SSH server for the tests to use""" return (await cls.create_server(session_factory=_handle_app_line_echo, line_echo=False)) @asynctest async def test_editor_line_echo(self): """Test line echo handled by application""" async with self.connect() as conn: process = await conn.create_process(term_type='ansi') process.stdin.write('abc\rdef\r') await asyncio.sleep(0.1) 
process.stdin.write('ghi\r') await asyncio.sleep(0.1) process.stdin.write_eof() self.assertEqual((await process.stdout.read()), '> abc\x1b[3D \x1b[3Ddef\x1b[3D \x1b[3D' 'abc\r\n> def\r\n> ghi\r\n> ') asyncssh-2.20.0/tests/test_encryption.py000066400000000000000000000052741475467777400204370ustar00rootroot00000000000000# Copyright (c) 2015-2020 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation """Unit tests for symmetric key encryption""" import os import random import unittest from asyncssh.encryption import register_encryption_alg, get_encryption_algs from asyncssh.encryption import get_encryption_params, get_encryption from asyncssh.mac import get_mac_algs class _TestEncryption(unittest.TestCase): """Unit tests for encryption module""" def check_encryption_alg(self, enc_alg, mac_alg): """Check a symmetric encryption algorithm""" enc_keysize, enc_ivsize, enc_blocksize, mac_keysize, _, etm = \ get_encryption_params(enc_alg, mac_alg) enc_blocksize = max(8, enc_blocksize) enc_key = os.urandom(enc_keysize) enc_iv = os.urandom(enc_ivsize) mac_key = os.urandom(mac_keysize) seq = random.getrandbits(32) enc = get_encryption(enc_alg, enc_key, enc_iv, mac_alg, mac_key, etm) dec = get_encryption(enc_alg, enc_key, enc_iv, mac_alg, mac_key, etm) for i in range(2, 6): data = os.urandom(4*etm + i*enc_blocksize) hdr, packet = data[:4], data[4:] encdata, encmac = enc.encrypt_packet(seq, hdr, packet) first, rest = encdata[:enc_blocksize], encdata[enc_blocksize:] decfirst, dechdr = dec.decrypt_header(seq, first, 4) decdata = dec.decrypt_packet(seq, decfirst, rest, 4, encmac) self.assertEqual(dechdr, hdr) self.assertEqual(decdata, packet) seq = (seq + 1) & 0xffffffff def test_encryption_algs(self): """Unit test encryption algorithms""" for enc_alg in get_encryption_algs(): for mac_alg in get_mac_algs(): with self.subTest(enc_alg=enc_alg, mac_alg=mac_alg): self.check_encryption_alg(enc_alg, mac_alg) def test_unavailable_cipher(self): """Test registering encryption that uses an unavailable cipher""" # pylint: disable=no-self-use register_encryption_alg('xxx', 'xxx', '', True) asyncssh-2.20.0/tests/test_forward.py000066400000000000000000001426011475467777400177050ustar00rootroot00000000000000# Copyright (c) 2016-2022 by Ron Frederick and others. 
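# The tests in this module round-trip SSH packets through asyncssh's internal
# (undocumented) encryption helpers. The sketch below condenses that flow for
# a single packet; it mirrors the calls made by the test itself rather than a
# documented public API, and the helper function is illustrative only:

import os as _os_sketch

from asyncssh.encryption import get_encryption, get_encryption_algs
from asyncssh.encryption import get_encryption_params
from asyncssh.mac import get_mac_algs

def _encryption_usage_sketch():
    """Encrypt and decrypt one packet with the first registered algorithms"""

    enc_alg, mac_alg = get_encryption_algs()[0], get_mac_algs()[0]

    enc_keysize, enc_ivsize, enc_blocksize, mac_keysize, _, etm = \
        get_encryption_params(enc_alg, mac_alg)
    enc_blocksize = max(8, enc_blocksize)

    enc_key, enc_iv = _os_sketch.urandom(enc_keysize), _os_sketch.urandom(enc_ivsize)
    mac_key = _os_sketch.urandom(mac_keysize)

    # Encryptor and decryptor share the same key material
    enc = get_encryption(enc_alg, enc_key, enc_iv, mac_alg, mac_key, etm)
    dec = get_encryption(enc_alg, enc_key, enc_iv, mac_alg, mac_key, etm)

    data = _os_sketch.urandom(4*etm + 2*enc_blocksize)
    hdr, packet = data[:4], data[4:]

    encdata, encmac = enc.encrypt_packet(0, hdr, packet)
    first, rest = encdata[:enc_blocksize], encdata[enc_blocksize:]

    decfirst, dechdr = dec.decrypt_header(0, first, 4)
    decdata = dec.decrypt_packet(0, decfirst, rest, 4, encmac)

    assert (dechdr, decdata) == (hdr, packet)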
# # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation """Unit tests for AsyncSSH forwarding API""" import asyncio import codecs import os import socket import sys import unittest from unittest.mock import patch import asyncssh from asyncssh.misc import maybe_wait_closed, write_file from asyncssh.packet import String, UInt32 from asyncssh.public_key import CERT_TYPE_USER from asyncssh.socks import SOCKS5, SOCKS5_AUTH_NONE from asyncssh.socks import SOCKS4_OK_RESPONSE, SOCKS5_OK_RESPONSE_HDR from .server import Server, ServerTestCase from .util import asynctest, echo, make_certificate, try_remove def _echo_non_async(stdin, stdout, stderr=None): """Non-async version of echo callback""" conn = stdin.get_extra_info('connection') conn.create_task(echo(stdin, stdout, stderr)) def _listener(_orig_host, _orig_port): """Handle a forwarded TCP/IP connection""" return echo def _listener_non_async(_orig_host, _orig_port): """Non-async version of handler for a forwarded TCP/IP connection""" return _echo_non_async def _unix_listener(): """Handle a forwarded UNIX domain connection""" return echo def _unix_listener_non_async(): """Non-async version of handler for a forwarded UNIX domain connection""" return _echo_non_async async def _pause(reader, writer): """Sleep to allow buffered data to build up and trigger a pause""" await asyncio.sleep(0.1) await reader.read() writer.close() await maybe_wait_closed(writer) async def _async_runtime_error(_reader, _writer): """Raise a runtime error""" raise RuntimeError('Async internal error') class _ClientConn(asyncssh.SSHClientConnection): """Patched SSH client connection for unit testing""" async def make_global_request(self, request, *args): """Send a global request and wait for the response""" return await self._make_global_request(request, *args) class _EchoPortListener(asyncssh.SSHListener): """A TCP listener which opens a connection that echoes data""" def __init__(self, conn): super().__init__() self._conn = conn conn.create_task(self._open_connection()) async def _open_connection(self): """Open a forwarded connection that echoes data""" await asyncio.sleep(0.1) reader, writer = await self._conn.open_connection('open', 65535) await echo(reader, writer) def close(self): """Stop listening for new connections""" async def wait_closed(self): """Wait for the listener to close""" class _EchoPathListener(asyncssh.SSHListener): """A UNIX domain listener which opens a connection that echoes data""" def __init__(self, conn): super().__init__() self._conn = conn conn.create_task(self._open_connection()) async def _open_connection(self): """Open a forwarded connection that echoes data""" await asyncio.sleep(0.1) reader, writer = await self._conn.open_unix_connection('open') await echo(reader, writer) def close(self): """Stop listening for new connections""" async def wait_closed(self): """Wait for the listener to close""" class _TCPConnectionServer(Server): """Server for testing direct and forwarded 
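# The tests in this module cover TCP and UNIX domain forwarding. As a rough
# usage sketch (not part of the test suite; the host, ports, credentials, and
# helper name are placeholders), a client can forward a local listening port
# across the SSH connection like this:

import asyncssh as _asyncssh_fwd_sketch

async def _forwarding_usage_sketch():
    """Forward local port 8080 to port 80 as seen from the SSH server"""

    async with _asyncssh_fwd_sketch.connect('localhost', username='user',
                                            password='secret',
                                            known_hosts=None) as conn:
        listener = await conn.forward_local_port('', 8080, 'localhost', 80)

        # At this point, connections to 127.0.0.1:8080 are tunneled over the
        # SSH connection and delivered to localhost:80 on the server side
        listener.close()
        await listener.wait_closed()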
TCP connections""" def connection_requested(self, dest_host, dest_port, orig_host, orig_port): """Handle a request to create a new connection""" if dest_port == 1: return False elif dest_port == 7: return (self._conn.create_tcp_channel(), echo) elif dest_port == 8: return _pause elif dest_port == 9: self._conn.close() return (self._conn.create_tcp_channel(), echo) elif dest_port == 10: return _async_runtime_error else: return True def server_requested(self, listen_host, listen_port): """Handle a request to create a new socket listener""" if listen_host == 'open': return _EchoPortListener(self._conn) else: return listen_host != 'fail' class _TCPAsyncConnectionServer(_TCPConnectionServer): """Server for testing async direct and forwarded TCP connections""" async def server_requested(self, listen_host, listen_port): """Handle a request to create a new socket listener""" if listen_host == 'open': return _EchoPortListener(self._conn) else: return listen_host != 'fail' class _TCPAcceptHandlerServer(Server): """Server for testing forwarding accept handler""" async def server_requested(self, listen_host, listen_port): """Handle a request to create a new socket listener""" def accept_handler(_orig_host: str, _orig_port: int) -> bool: return True return accept_handler class _UNIXConnectionServer(Server): """Server for testing direct and forwarded UNIX domain connections""" def unix_connection_requested(self, dest_path): """Handle a request to create a new UNIX domain connection""" if dest_path == '': return True elif dest_path == '/echo': return (self._conn.create_unix_channel(), echo) else: return False def unix_server_requested(self, listen_path): """Handle a request to create a new UNIX domain listener""" if listen_path == 'open': return _EchoPathListener(self._conn) else: return listen_path != 'fail' class _UNIXAsyncConnectionServer(_UNIXConnectionServer): """Server for testing async direct and forwarded UNIX connections""" async def unix_server_requested(self, listen_path): """Handle a request to create a new UNIX domain listener""" if listen_path == 'open': return _EchoPathListener(self._conn) else: return listen_path != 'fail' class _CheckForwarding(ServerTestCase): """Utility functions for AsyncSSH forwarding unit tests""" async def _check_echo_line(self, reader, writer, delay=False, encoded=False): """Check if an input line is properly echoed back""" if delay: await asyncio.sleep(delay) line = str(id(self)) + '\n' if not encoded: line = line.encode('utf-8') writer.write(line) await writer.drain() result = await reader.readline() writer.close() await maybe_wait_closed(writer) self.assertEqual(line, result) async def _check_echo_block(self, reader, writer): """Check if a block of data is properly echoed back""" data = 4 * [1025*1024*b'\0'] writer.writelines(data) await writer.drain() writer.write_eof() result = await reader.read() #await reader.channel.wait_closed() writer.close() self.assertEqual(b''.join(data), result) async def _check_local_connection(self, listen_port, delay=None): """Open a local connection and test if an input line is echoed back""" reader, writer = await asyncio.open_connection('127.0.0.1', listen_port) await self._check_echo_line(reader, writer, delay=delay) async def _check_local_unix_connection(self, listen_path): """Open a local connection and test if an input line is echoed back""" # pylint doesn't think open_unix_connection exists # pylint: disable=no-member reader, writer = await asyncio.open_unix_connection(listen_path) # pylint: enable=no-member await 
        self._check_echo_line(reader, writer)


class _TestTCPForwarding(_CheckForwarding):
    """Unit tests for AsyncSSH TCP connection forwarding"""

    @classmethod
    async def start_server(cls):
        """Start an SSH server which supports TCP connection forwarding"""

        return (await cls.create_server(
            _TCPConnectionServer, authorized_client_keys='authorized_keys'))

    async def _check_connection(self, conn, dest_host='', dest_port=7,
                                **kwargs):
        """Open a connection and test if a block of data is echoed back"""

        reader, writer = await conn.open_connection(dest_host, dest_port,
                                                    **kwargs)

        await self._check_echo_block(reader, writer)

    @asynctest
    async def test_ssh_create_tunnel(self):
        """Test creating a tunneled SSH connection"""

        async with self.connect() as conn:
            conn2, _ = await conn.create_ssh_connection(
                None, self._server_addr, self._server_port)

            async with conn2:
                await self._check_connection(conn2)

    @asynctest
    async def test_ssh_connect_tunnel(self):
        """Test connecting a tunneled SSH connection"""

        async with self.connect() as conn:
            async with conn.connect_ssh(self._server_addr,
                                        self._server_port) as conn2:
                await self._check_connection(conn2)

    @asynctest
    async def test_ssh_connect_tunnel_string(self):
        """Test connecting a tunneled SSH connection via string"""

        async with self.connect(tunnel=f'{self._server_addr}:'
                                       f'{self._server_port}') as conn:
            await self._check_connection(conn)

    @asynctest
    async def test_ssh_connect_tunnel_string_failed(self):
        """Test failed connection on a tunneled SSH connection via string"""

        with self.assertRaises(asyncssh.ChannelOpenError):
            await asyncssh.connect(
                '\xff', tunnel=f'{self._server_addr}:{self._server_port}')

    @asynctest
    async def test_proxy_jump(self):
        """Test connecting a tunneled SSH connection using ProxyJump"""

        write_file('.ssh/config',
                   'Host target\n'
                   ' Hostname localhost\n'
                   f' Port {self._server_port}\n'
                   f' ProxyJump localhost:{self._server_port}\n'
                   'IdentityFile ckey\n', 'w')

        try:
            async with self.connect(host='target', username='ckey'):
                pass
        finally:
            os.remove('.ssh/config')

    @asynctest
    async def test_proxy_jump_multiple(self):
        """Test connecting a tunneled SSH connection using ProxyJump"""

        write_file('.ssh/config',
                   'Host target\n'
                   ' Hostname localhost\n'
                   f' Port {self._server_port}\n'
                   f' ProxyJump localhost:{self._server_port},'
                   f'localhost:{self._server_port}\n'
                   'IdentityFile ckey\n', 'w')

        try:
            async with self.connect(host='target', username='ckey'):
                pass
        finally:
            os.remove('.ssh/config')

    @asynctest
    async def test_proxy_jump_encrypted_key(self):
        """Test ProxyJump with encrypted client key"""

        write_file('.ssh/config',
                   'Host *\n'
                   ' User ckey\n'
                   'Host target\n'
                   ' Hostname localhost\n'
                   f' Port {self._server_port}\n'
                   f' ProxyJump localhost:{self._server_port}\n'
                   ' IdentityFile ckey_encrypted\n', 'w')

        try:
            async with self.connect(host='target', username='ckey',
                                    client_keys='ckey_encrypted',
                                    passphrase='passphrase'):
                pass
        finally:
            os.remove('.ssh/config')

    @asynctest
    async def test_proxy_jump_encrypted_key_missing_passphrase(self):
        """Test ProxyJump with encrypted client key and missing passphrase"""

        write_file('.ssh/config',
                   'Host *\n'
                   ' User ckey\n'
                   'Host target\n'
                   ' Hostname localhost\n'
                   f' Port {self._server_port}\n'
                   f' ProxyJump localhost:{self._server_port}\n'
                   ' IdentityFile ckey_encrypted\n', 'w')

        try:
            with self.assertRaises(asyncssh.KeyImportError):
                await self.connect(host='target', username='ckey',
                                   client_keys='ckey_encrypted')
        finally:
            os.remove('.ssh/config')

    @asynctest
    async def test_ssh_connect_reverse_tunnel(self):
        """Test creating a tunneled reverse direction SSH
connection""" server2 = await self.listen_reverse() listen_port = server2.sockets[0].getsockname()[1] async with self.connect() as conn: async with conn.connect_reverse_ssh('127.0.0.1', listen_port, server_factory=Server, server_host_keys=['skey']): pass server2.close() await server2.wait_closed() @asynctest async def test_ssh_listen_tunnel(self): """Test opening a tunneled SSH listener""" async with self.connect() as conn: async with conn.listen_ssh(port=0, server_factory=Server, server_host_keys=['skey']) as server: listen_port = server.get_port() self.assertEqual(server.get_addresses(), [('', listen_port)]) async with asyncssh.connect('127.0.0.1', listen_port, known_hosts=(['skey.pub'], [], [])): pass @asynctest async def test_ssh_listen_tunnel_string(self): """Test opening a tunneled SSH listener via string""" async with self.listen( tunnel=f'ckey@{self._server_addr}:{self._server_port}', server_factory=Server, server_host_keys=['skey']) as server: listen_port = server.get_port() async with asyncssh.connect('127.0.0.1', listen_port, known_hosts=(['skey.pub'], [], [])): pass @asynctest async def test_ssh_listen_tunnel_string_failed(self): """Test open failure on a tunneled SSH listener via string""" with self.assertRaises(asyncssh.ChannelListenError): await asyncssh.listen( '\xff', tunnel=f'{self._server_addr}:{self._server_port}', server_factory=Server, server_host_keys=['skey']) @asynctest async def test_ssh_listen_tunnel_default_port(self): """Test opening a tunneled SSH listener via string without port""" with patch('asyncssh.connection.DEFAULT_PORT', self._server_port): async with self.listen(tunnel='localhost', server_factory=Server, server_host_keys=['skey']) as server: listen_port = server.get_port() async with asyncssh.connect('127.0.0.1', listen_port, known_hosts=(['skey.pub'], [], [])): pass @asynctest async def test_ssh_listen_reverse_tunnel(self): """Test creating a tunneled reverse direction SSH connection""" async with self.connect() as conn: async with conn.listen_reverse_ssh(port=0, known_hosts=(['skey.pub'], [], [])) as server2: listen_port = server2.get_port() async with asyncssh.connect_reverse('127.0.0.1', listen_port, server_factory=Server, server_host_keys=['skey']): pass @asynctest async def test_connection(self): """Test opening a remote connection""" async with self.connect() as conn: await self._check_connection(conn) @asynctest async def test_connection_failure(self): """Test failure in opening a remote connection""" async with self.connect() as conn: with self.assertRaises(asyncssh.ChannelOpenError): await conn.open_connection('', 0) @asynctest async def test_connection_rejected(self): """Test rejection in opening a remote connection""" async with self.connect() as conn: with self.assertRaises(asyncssh.ChannelOpenError): await conn.open_connection('fail', 0) @asynctest async def test_connection_not_permitted(self): """Test permission denied in opening a remote connection""" ckey = asyncssh.read_private_key('ckey') cert = make_certificate('ssh-rsa-cert-v01@openssh.com', CERT_TYPE_USER, ckey, ckey, ['ckey'], extensions={'no-port-forwarding': ''}) async with self.connect(username='ckey', client_keys=[(ckey, cert)], agent_path=None) as conn: with self.assertRaises(asyncssh.ChannelOpenError): await conn.open_connection('', 7) @asynctest async def test_connection_not_permitted_open(self): """Test open permission denied in opening a remote connection""" async with self.connect(username='ckey', client_keys=['ckey'], agent_path=None) as conn: with 
self.assertRaises(asyncssh.ChannelOpenError): await conn.open_connection('fail', 7) @asynctest async def test_connection_invalid_unicode(self): """Test opening a connection with invalid Unicode in host""" async with self.connect() as conn: with self.assertRaises(asyncssh.ChannelOpenError): await conn.open_connection(b'\xff', 0) @asynctest async def test_server(self): """Test creating a remote listener""" async with self.connect() as conn: listener = await conn.start_server(_listener, '', 0) await self._check_local_connection(listener.get_port()) listener.close() listener.close() await listener.wait_closed() listener.close() @asynctest async def test_server_context_manager(self): """Test using a remote listener as a context manager""" async with self.connect() as conn: async with conn.start_server(_listener, '', 0) as listener: await self._check_local_connection(listener.get_port()) @asynctest async def test_server_open(self): """Test creating a remote listener which uses open_connection""" def new_connection(reader, writer): """Handle a forwarded TCP/IP connection""" waiter.set_result((reader, writer)) def handler_factory(_orig_host, _orig_port): """Handle all connections using new_connection""" return new_connection async with self.connect() as conn: waiter = self.loop.create_future() await conn.start_server(handler_factory, 'open', 0) reader, writer = await waiter await self._check_echo_line(reader, writer) # Clean up the listener during connection close @asynctest async def test_server_non_async(self): """Test creating a remote listener using non-async handler""" async with self.connect() as conn: async with conn.start_server(_listener_non_async, '', 0) as listener: await self._check_local_connection(listener.get_port()) @asynctest async def test_server_failure(self): """Test failure in creating a remote listener""" async with self.connect() as conn: with self.assertRaises(asyncssh.ChannelListenError): await conn.start_server(_listener, 'fail', 0) @asynctest async def test_forward_local_port(self): """Test forwarding of a local port""" async with self.connect() as conn: async with conn.forward_local_port('', 0, '', 7) as listener: await self._check_local_connection(listener.get_port(), delay=0.1) @asynctest async def test_forward_local_port_accept_handler(self): """Test forwarding of a local port with an accept handler""" def accept_handler(_orig_host: str, _orig_port: int) -> bool: return True async with self.connect() as conn: async with conn.forward_local_port('', 0, '', 7, accept_handler) as listener: await self._check_local_connection(listener.get_port(), delay=0.1) @asynctest async def test_forward_local_port_accept_handler_denial(self): """Test forwarding of a local port with an accept handler denial""" async def accept_handler(_orig_host: str, _orig_port: int) -> bool: return False async with self.connect() as conn: async with conn.forward_local_port('', 0, '', 7, accept_handler) as listener: listen_port = listener.get_port() reader, writer = await asyncio.open_connection('127.0.0.1', listen_port) self.assertEqual((await reader.read()), b'') writer.close() await maybe_wait_closed(writer) @unittest.skipIf(sys.platform == 'win32', 'skip UNIX domain socket tests on Windows') @asynctest async def test_forward_local_path_to_port(self): """Test forwarding of a local UNIX domain path to a remote TCP port""" async with self.connect() as conn: async with conn.forward_local_path_to_port('local', '', 7): await self._check_local_unix_connection('local') try_remove('local') 
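    # Illustrative sketch: the forward_local_port() calls exercised by the
    # tests above map a local listening socket to a destination reachable
    # from the SSH server.  Outside this test harness, a minimal use might
    # look like the following, where 'example.com' and the port numbers are
    # placeholder values and the server is assumed to permit TCP forwarding:
    #
    #     async def tunnel_http():
    #         async with asyncssh.connect('localhost') as conn:
    #             listener = await conn.forward_local_port(
    #                 '', 8080, 'example.com', 80)
    #             await listener.wait_closed()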
@unittest.skipIf(sys.platform == 'win32', 'skip UNIX domain socket tests on Windows') @asynctest async def test_forward_local_path_to_port_failure(self): """Test failure forwarding a local UNIX domain path to a TCP port""" open('local', 'w').close() async with self.connect() as conn: with self.assertRaises(OSError): await conn.forward_local_path_to_port('local', '', 7) try_remove('local') @asynctest async def test_forward_local_port_pause(self): """Test pause during forwarding of a local port""" async with self.connect() as conn: async with conn.forward_local_port('', 0, '', 8) as listener: listen_port = listener.get_port() reader, writer = await asyncio.open_connection('127.0.0.1', listen_port) writer.write(4*1024*1024*b'\0') writer.write_eof() await reader.read() writer.close() await maybe_wait_closed(writer) @asynctest async def test_forward_local_port_failure(self): """Test failure in forwarding a local port""" async with self.connect() as conn: async with conn.forward_local_port('', 0, '', 65535) as listener: listen_port = listener.get_port() reader, writer = await asyncio.open_connection('127.0.0.1', listen_port) self.assertEqual((await reader.read()), b'') writer.close() await maybe_wait_closed(writer) @unittest.skipIf(sys.platform == 'win32', 'skip dual-stack tests on Windows') @asynctest async def test_forward_bind_error_ipv4(self): """Test error binding a local forwarding port""" async with self.connect() as conn: async with conn.forward_local_port('0.0.0.0', 0, '', 7) as listener: with self.assertRaises(OSError): await conn.forward_local_port('', listener.get_port(), '', 7) @unittest.skipIf(sys.platform == 'win32', 'skip dual-stack tests on Windows') @asynctest async def test_forward_bind_error_ipv6(self): """Test error binding a local forwarding port""" async with self.connect() as conn: async with conn.forward_local_port('::', 0, '', 7) as listener: with self.assertRaises(OSError): await conn.forward_local_port('', listener.get_port(), '', 7) @unittest.skipIf(sys.platform == 'win32', 'skip UNIX domain socket tests on Windows') @asynctest async def test_forward_port_to_path_bind_error(self): """Test error binding a local port forwarding to remote path""" async with self.connect() as conn: async with conn.forward_local_port('0.0.0.0', 0, '', 7) as listener: with self.assertRaises(OSError): await conn.forward_local_port_to_path( '', listener.get_port(), '') @asynctest async def test_forward_connect_error(self): """Test error connecting a local forwarding port""" async with self.connect() as conn: async with conn.forward_local_port('', 0, '', 1) as listener: listen_port = listener.get_port() reader, writer = await asyncio.open_connection('127.0.0.1', listen_port) self.assertEqual((await reader.read()), b'') writer.close() await maybe_wait_closed(writer) @asynctest async def test_forward_immediate_eof(self): """Test getting EOF before forwarded connection is fully open""" async with self.connect() as conn: async with conn.forward_local_port('', 0, '', 7) as listener: listen_port = listener.get_port() _, writer = await asyncio.open_connection('127.0.0.1', listen_port) writer.close() await maybe_wait_closed(writer) await asyncio.sleep(0.1) @asynctest async def test_forward_remote_port(self): """Test forwarding of a remote port""" server = await asyncio.start_server(echo, None, 0, family=socket.AF_INET) server_port = server.sockets[0].getsockname()[1] async with self.connect() as conn: async with conn.forward_remote_port( '', 0, '127.0.0.1', server_port) as listener: await 
self._check_local_connection(listener.get_port()) server.close() await server.wait_closed() @unittest.skipIf(sys.platform == 'win32', 'skip UNIX domain socket tests on Windows') @asynctest async def test_forward_remote_port_to_path(self): """Test forwarding of a remote port to a local UNIX domain socket""" server = await asyncio.start_unix_server(echo, 'local') async with self.connect() as conn: async with conn.forward_remote_port_to_path( '', 0, 'local') as listener: await self._check_local_connection(listener.get_port()) server.close() await server.wait_closed() try_remove('local') @asynctest async def test_forward_remote_specific_port(self): """Test forwarding of a specific remote port""" server = await asyncio.start_server(echo, None, 0, family=socket.AF_INET) server_port = server.sockets[0].getsockname()[1] sock = socket.socket() sock.bind(('', 0)) remote_port = sock.getsockname()[1] sock.close() async with self.connect() as conn: async with conn.forward_remote_port( '', remote_port, '127.0.0.1', server_port) as listener: await self._check_local_connection(listener.get_port()) server.close() await server.wait_closed() @asynctest async def test_forward_remote_port_failure(self): """Test failure of forwarding a remote port""" async with self.connect() as conn: with self.assertRaises(asyncssh.ChannelListenError): await conn.forward_remote_port('', 65536, '', 0) @asynctest async def test_forward_remote_port_not_permitted(self): """Test permission denied in forwarding of a remote port""" ckey = asyncssh.read_private_key('ckey') cert = make_certificate('ssh-rsa-cert-v01@openssh.com', CERT_TYPE_USER, ckey, ckey, ['ckey'], extensions={'no-port-forwarding': ''}) async with self.connect(username='ckey', client_keys=[(ckey, cert)], agent_path=None) as conn: with self.assertRaises(asyncssh.ChannelListenError): await conn.forward_remote_port('', 0, '', 0) @asynctest async def test_forward_remote_port_invalid_unicode(self): """Test TCP/IP forwarding with invalid Unicode in host""" async with self.connect() as conn: with self.assertRaises(asyncssh.ChannelListenError): await conn.forward_remote_port(b'\xff', 0, '', 0) @asynctest async def test_cancel_forward_remote_port_invalid_unicode(self): """Test canceling TCP/IP forwarding with invalid Unicode in host""" with patch('asyncssh.connection.SSHClientConnection', _ClientConn): async with self.connect() as conn: pkttype, _ = await conn.make_global_request( b'cancel-tcpip-forward', String(b'\xff'), UInt32(0)) self.assertEqual(pkttype, asyncssh.MSG_REQUEST_FAILURE) @asynctest async def test_add_channel_after_close(self): """Test opening a connection after a close""" async with self.connect() as conn: with self.assertRaises(asyncssh.ChannelOpenError): await conn.open_connection('', 9) @asynctest async def test_async_runtime_error(self): """Test runtime error in async listener""" async with self.connect() as conn: reader, _ = await conn.open_connection('', 10) with self.assertRaises(asyncssh.ConnectionLost): await reader.read() @asynctest async def test_multiple_global_requests(self): """Test sending multiple global requests in parallel""" async with self.connect() as conn: listeners = await asyncio.gather( conn.forward_remote_port('', 0, '', 7), conn.forward_remote_port('', 0, '', 7)) for listener in listeners: listener.close() await listener.wait_closed() @asynctest async def test_listener_close_on_conn_close(self): """Test listener closes when connection closes""" async with self.connect() as conn: listener = await conn.forward_local_port('', 0, '', 80) 
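            # Port 10 causes the stub server's handler to raise a
            # RuntimeError, which tears down the SSH connection; the
            # forwarded listener is then expected to close as well.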
            await conn.open_connection('', 10)

            await listener.wait_closed()


class _TestTCPForwardingAcceptHandler(_CheckForwarding):
    """Unit tests for TCP forwarding with accept handler"""

    @classmethod
    async def start_server(cls):
        """Start an SSH server which supports TCP connection forwarding"""

        return await cls.create_server(
            _TCPAcceptHandlerServer, authorized_client_keys='authorized_keys')

    @asynctest
    async def test_forward_remote_port_accept_handler(self):
        """Test forwarding of a remote port with accept handler"""

        server = await asyncio.start_server(echo, None, 0,
                                            family=socket.AF_INET)

        server_port = server.sockets[0].getsockname()[1]

        async with self.connect() as conn:
            async with conn.forward_remote_port(
                    '', 0, '127.0.0.1', server_port) as listener:
                await self._check_local_connection(listener.get_port())

        server.close()
        await server.wait_closed()


class _TestAsyncTCPForwarding(_TestTCPForwarding):
    """Unit tests for AsyncSSH TCP connection forwarding with async return"""

    @classmethod
    async def start_server(cls):
        """Start an SSH server which supports TCP connection forwarding"""

        return await cls.create_server(
            _TCPAsyncConnectionServer, authorized_client_keys='authorized_keys')


@unittest.skipIf(sys.platform == 'win32',
                 'skip UNIX domain socket tests on Windows')
class _TestUNIXForwarding(_CheckForwarding):
    """Unit tests for AsyncSSH UNIX connection forwarding"""

    @classmethod
    async def start_server(cls):
        """Start an SSH server which supports UNIX connection forwarding"""

        return (await cls.create_server(
            _UNIXConnectionServer, authorized_client_keys='authorized_keys'))

    async def _check_unix_connection(self, conn, dest_path='/echo', **kwargs):
        """Open a UNIX connection and test if an input line is echoed back"""

        reader, writer = await conn.open_unix_connection(dest_path,
                                                         encoding='utf-8',
                                                         **kwargs)

        await self._check_echo_line(reader, writer, encoded=True)

    @asynctest
    async def test_unix_connection(self):
        """Test opening a remote UNIX connection"""

        async with self.connect() as conn:
            await self._check_unix_connection(conn)

    @asynctest
    async def test_unix_connection_failure(self):
        """Test failure in opening a remote UNIX connection"""

        async with self.connect() as conn:
            with self.assertRaises(asyncssh.ChannelOpenError):
                await conn.open_unix_connection('')

    @asynctest
    async def test_unix_connection_rejected(self):
        """Test rejection in opening a remote UNIX connection"""

        async with self.connect() as conn:
            with self.assertRaises(asyncssh.ChannelOpenError):
                await conn.open_unix_connection('/fail')

    @asynctest
    async def test_unix_connection_not_permitted(self):
        """Test permission denied in opening a remote UNIX connection"""

        ckey = asyncssh.read_private_key('ckey')
        cert = make_certificate('ssh-rsa-cert-v01@openssh.com',
                                CERT_TYPE_USER, ckey, ckey, ['ckey'],
                                extensions={'no-port-forwarding': ''})

        async with self.connect(username='ckey', client_keys=[(ckey, cert)],
                                agent_path=None) as conn:
            with self.assertRaises(asyncssh.ChannelOpenError):
                await conn.open_unix_connection('/echo')

    @asynctest
    async def test_unix_connection_invalid_unicode(self):
        """Test opening a UNIX connection with invalid Unicode in path"""

        async with self.connect() as conn:
            with self.assertRaises(asyncssh.ChannelOpenError):
                await conn.open_unix_connection(b'\xff')

    @asynctest
    async def test_unix_server(self):
        """Test creating a remote UNIX listener"""

        path = os.path.abspath('echo')

        async with self.connect() as conn:
            listener = await conn.start_unix_server(_unix_listener, path)
            await self._check_local_unix_connection('echo')
            listener.close()
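            # close() is called repeatedly below, including after
            # wait_closed(), to verify that redundant closes are harmless.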
listener.close() await listener.wait_closed() listener.close() try_remove('echo') @asynctest async def test_unix_server_open(self): """Test creating a UNIX listener which uses open_unix_connection""" def new_connection(reader, writer): """Handle a forwarded UNIX domain connection""" waiter.set_result((reader, writer)) def handler_factory(): """Handle all connections using new_connection""" return new_connection async with self.connect() as conn: waiter = self.loop.create_future() async with conn.start_unix_server(handler_factory, 'open'): reader, writer = await waiter await self._check_echo_line(reader, writer) @asynctest async def test_unix_server_non_async(self): """Test creating a remote UNIX listener using non-async handler""" path = os.path.abspath('echo') async with self.connect() as conn: async with conn.start_unix_server(_unix_listener_non_async, path): await self._check_local_unix_connection('echo') try_remove('echo') @asynctest async def test_unix_server_failure(self): """Test failure in creating a remote UNIX listener""" async with self.connect() as conn: with self.assertRaises(asyncssh.ChannelListenError): await conn.start_unix_server(_unix_listener, 'fail') @asynctest async def test_forward_local_path(self): """Test forwarding of a local UNIX domain path""" async with self.connect() as conn: async with conn.forward_local_path('local', '/echo'): await self._check_local_unix_connection('local') try_remove('local') @asynctest async def test_forward_local_port_to_path_accept_handler(self): """Test forwarding of port to UNIX path with accept handler""" def accept_handler(_orig_host: str, _orig_port: int) -> bool: return True async with self.connect() as conn: async with conn.forward_local_port_to_path( '', 0, '/echo', accept_handler) as listener: await self._check_local_connection(listener.get_port(), delay=0.1) @asynctest async def test_forward_local_port_to_path_accept_handler_denial(self): """Test forwarding of port to UNIX path with accept handler denial""" async def accept_handler(_orig_host: str, _orig_port: int) -> bool: return False async with self.connect() as conn: async with conn.forward_local_port_to_path( '', 0, '/echo', accept_handler) as listener: listen_port = listener.get_port() reader, writer = await asyncio.open_connection('127.0.0.1', listen_port) self.assertEqual((await reader.read()), b'') writer.close() await maybe_wait_closed(writer) @asynctest async def test_forward_local_port_to_path(self): """Test forwarding of a local port to a remote UNIX domain socket""" async with self.connect() as conn: async with conn.forward_local_port_to_path('', 0, '/echo') as listener: await self._check_local_connection(listener.get_port(), delay=0.1) @asynctest async def test_forward_specific_local_port_to_path(self): """Test forwarding of a specific local port to a UNIX domain socket""" sock = socket.socket() sock.bind(('', 0)) listen_port = sock.getsockname()[1] sock.close() async with self.connect() as conn: async with conn.forward_local_port_to_path( '', listen_port, '/echo') as listener: await self._check_local_connection(listener.get_port(), delay=0.1) @asynctest async def test_forward_remote_path(self): """Test forwarding of a remote UNIX domain path""" # pylint doesn't think start_unix_server exists # pylint: disable=no-member server = await asyncio.start_unix_server(echo, 'local') # pylint: enable=no-member path = os.path.abspath('echo') async with self.connect() as conn: async with conn.forward_remote_path(path, 'local'): await 
self._check_local_unix_connection('echo') server.close() await server.wait_closed() try_remove('echo') try_remove('local') @asynctest async def test_forward_remote_path_to_port(self): """Test forwarding of a remote UNIX domain path to a local TCP port""" server = await asyncio.start_server(echo, None, 0, family=socket.AF_INET) server_port = server.sockets[0].getsockname()[1] path = os.path.abspath('echo') async with self.connect() as conn: async with conn.forward_remote_path_to_port( path, '127.0.0.1', server_port): await self._check_local_unix_connection('echo') server.close() await server.wait_closed() try_remove('echo') @asynctest async def test_forward_remote_path_failure(self): """Test failure of forwarding a remote UNIX domain path""" open('echo', 'w').close() path = os.path.abspath('echo') async with self.connect() as conn: with self.assertRaises(asyncssh.ChannelListenError): await conn.forward_remote_path(path, 'local') try_remove('echo') @asynctest async def test_forward_remote_path_not_permitted(self): """Test permission denied in forwarding a remote UNIX domain path""" ckey = asyncssh.read_private_key('ckey') cert = make_certificate('ssh-rsa-cert-v01@openssh.com', CERT_TYPE_USER, ckey, ckey, ['ckey'], extensions={'no-port-forwarding': ''}) async with self.connect(username='ckey', client_keys=[(ckey, cert)], agent_path=None) as conn: with self.assertRaises(asyncssh.ChannelListenError): await conn.forward_remote_path('', 'local') @asynctest async def test_forward_remote_path_invalid_unicode(self): """Test forwarding a UNIX domain path with invalid Unicode in it""" async with self.connect() as conn: with self.assertRaises(asyncssh.ChannelListenError): await conn.forward_remote_path(b'\xff', 'local') @asynctest async def test_cancel_forward_remote_path_invalid_unicode(self): """Test canceling UNIX forwarding with invalid Unicode in path""" with patch('asyncssh.connection.SSHClientConnection', _ClientConn): async with self.connect() as conn: pkttype, _ = await conn.make_global_request( b'cancel-streamlocal-forward@openssh.com', String(b'\xff')) self.assertEqual(pkttype, asyncssh.MSG_REQUEST_FAILURE) class _TestAsyncUNIXForwarding(_TestUNIXForwarding): """Unit tests for AsyncSSH UNIX connection forwarding with async return""" @classmethod async def start_server(cls): """Start an SSH server which supports UNIX connection forwarding""" return await cls.create_server( _UNIXAsyncConnectionServer, authorized_client_keys='authorized_keys') class _TestSOCKSForwarding(_CheckForwarding): """Unit tests for AsyncSSH SOCKS dynamic port forwarding""" @classmethod async def start_server(cls): """Start an SSH server which supports TCP connection forwarding""" return (await cls.create_server( _TCPConnectionServer, authorized_client_keys='authorized_keys')) async def _check_early_error(self, reader, writer, data): """Check errors in the initial SOCKS message""" writer.write(data) self.assertEqual((await reader.read()), b'') async def _check_socks5_error(self, reader, writer, data): """Check SOCKSv5 errors after auth""" writer.write(bytes((SOCKS5, 1, SOCKS5_AUTH_NONE))) self.assertEqual((await reader.readexactly(2)), bytes((SOCKS5, SOCKS5_AUTH_NONE))) writer.write(data) self.assertEqual((await reader.read()), b'') async def _check_socks4_connect(self, reader, writer, data, result): """Check SOCKSv4 connect requests""" writer.write(data) response = await reader.readexactly(len(SOCKS4_OK_RESPONSE)) self.assertEqual(response, SOCKS4_OK_RESPONSE) if result: await self._check_echo_line(reader, writer) else: 
self.assertEqual((await reader.read()), b'') async def _check_socks5_connect(self, reader, writer, data, addrtype, addrlen, result): """Check SOCKSv5 connect_requests""" writer.write(bytes((SOCKS5, 1, SOCKS5_AUTH_NONE))) self.assertEqual((await reader.readexactly(2)), bytes((SOCKS5, SOCKS5_AUTH_NONE))) writer.write(data[:20]) await asyncio.sleep(0.1) writer.write(data[20:]) expected = SOCKS5_OK_RESPONSE_HDR + bytes((addrtype,)) + \ (addrlen + 2) * b'\0' response = await reader.readexactly(len(expected)) self.assertEqual(response, expected) if result: await self._check_echo_line(reader, writer) else: self.assertEqual((await reader.read()), b'') async def _check_socks(self, handler, listen_port, msg, data, *args): """Unit test SOCKS dynamic port forwarding""" with self.subTest(msg=msg, data=data): data = codecs.decode(data, 'hex') reader, writer = await asyncio.open_connection('127.0.0.1', listen_port) try: await handler(reader, writer, data, *args) finally: writer.close() await maybe_wait_closed(writer) @asynctest async def test_forward_socks(self): """Test dynamic port forwarding via SOCKS""" _socks_early_errors = [ ('Bad version', '0000'), ('Bad SOCKSv4 command', '0400'), ('Bad SOCKSv4 Unicode data', '040100010000000100ff00'), ('SOCKSv4 hostname too long', '040100010000000100' + 256 * 'ff'), ('Bad SOCKSv5 auth list', '050101') ] _socks5_postauth_errors = [ ('Bad command', '05000001'), ('Bad address', '05010000'), ('Bad Unicode data', '0501000301ff0007') ] _socks4_connects = [ ('IPv4', '040100077f00000100', True), ('Hostname', '0401000700000001006c6f63616c686f737400', True), ('Rejected', '04010001000000010000', False) ] _socks5_connects = [ ('IPv4', '050100017f0000010007', 1, 4, True), ('Hostname', '05010003096c6f63616c686f73740007', 1, 4, True), ('IPv6', '05010004' + 15*'00' + '010007', 4, 16, True), ('Rejected', '05010003000001', 1, 4, False) ] async with self.connect() as conn: async with conn.forward_socks('', 0) as listener: listen_port = listener.get_port() for msg, data in _socks_early_errors: await self._check_socks(self._check_early_error, listen_port, msg, data) for msg, data in _socks5_postauth_errors: await self._check_socks(self._check_socks5_error, listen_port, msg, data) for msg, data, result in _socks4_connects: await self._check_socks(self._check_socks4_connect, listen_port, msg, data, result) for msg, data, addrtype, addrlen, result in _socks5_connects: await self._check_socks(self._check_socks5_connect, listen_port, msg, data, addrtype, addrlen, result) @asynctest async def test_forward_socks_specific_port(self): """Test dynamic forwarding on a specific port""" sock = socket.socket() sock.bind(('', 0)) listen_port = sock.getsockname()[1] sock.close() async with self.connect() as conn: async with conn.forward_socks('', listen_port): pass @unittest.skipIf(sys.platform == 'win32', 'Avoid issue with SO_REUSEADDR on Windows') @asynctest async def test_forward_bind_error_socks(self): """Test error binding a local dynamic forwarding port""" async with self.connect() as conn: async with conn.forward_socks('', 0) as listener: with self.assertRaises(OSError): await conn.forward_socks('', listener.get_port()) asyncssh-2.20.0/tests/test_kex.py000066400000000000000000000573111475467777400170330ustar00rootroot00000000000000# Copyright (c) 2015-2020 by Ron Frederick and others. 
# # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation """Unit tests for key exchange""" import asyncio import inspect import unittest from hashlib import sha1 import asyncssh from asyncssh.crypto import curve25519_available, curve448_available from asyncssh.crypto import sntrup_available from asyncssh.crypto import Curve25519DH, Curve448DH, ECDH, PQDH from asyncssh.kex_dh import MSG_KEXDH_INIT, MSG_KEXDH_REPLY from asyncssh.kex_dh import MSG_KEX_DH_GEX_REQUEST, MSG_KEX_DH_GEX_GROUP from asyncssh.kex_dh import MSG_KEX_DH_GEX_INIT, MSG_KEX_DH_GEX_REPLY, _KexDHGex from asyncssh.kex_dh import MSG_KEX_ECDH_INIT, MSG_KEX_ECDH_REPLY from asyncssh.kex_dh import MSG_KEXGSS_INIT, MSG_KEXGSS_COMPLETE from asyncssh.kex_dh import MSG_KEXGSS_ERROR from asyncssh.kex_rsa import MSG_KEXRSA_PUBKEY, MSG_KEXRSA_SECRET from asyncssh.kex_rsa import MSG_KEXRSA_DONE from asyncssh.gss import GSSClient, GSSServer from asyncssh.kex import register_kex_alg, get_kex_algs, get_kex from asyncssh.packet import SSHPacket, Boolean, Byte, MPInt, String from asyncssh.public_key import decode_ssh_public_key from .util import asynctest, get_test_key, gss_available, patch_gss from .util import AsyncTestCase, ConnectionStub class _KexConnectionStub(ConnectionStub): """Connection stub class to test key exchange""" def __init__(self, alg, gss, peer, server=False): super().__init__(peer, server) self._gss = gss self._key_waiter = asyncio.Future() self._kex = get_kex(self, alg) async def start(self): """Start key exchange""" await self._kex.start() def connection_lost(self, exc): """Handle the closing of a connection""" raise NotImplementedError def enable_gss_kex_auth(self): """Ignore request to enable GSS key exchange authentication""" async def process_packet(self, data): """Process an incoming packet""" packet = SSHPacket(data) pkttype = packet.get_byte() result = self._kex.process_packet(pkttype, None, packet) if inspect.isawaitable(result): await result def get_hash_prefix(self): """Return the bytes used in calculating unique connection hashes""" # pylint: disable=no-self-use return b'prefix' def send_newkeys(self, k, h): """Handle a request to send a new keys message""" self._key_waiter.set_result(self._kex.compute_key(k, h, b'A', h, 128)) async def get_key(self): """Return generated key data""" return await self._key_waiter def get_gss_context(self): """Return the GSS context associated with this connection""" return self._gss async def simulate_dh_init(self, e): """Simulate receiving a DH init packet""" await self.process_packet(Byte(MSG_KEXDH_INIT) + MPInt(e)) async def simulate_dh_reply(self, host_key_data, f, sig): """Simulate receiving a DH reply packet""" await self.process_packet(b''.join((Byte(MSG_KEXDH_REPLY), String(host_key_data), MPInt(f), String(sig)))) async def simulate_dh_gex_group(self, p, g): """Simulate receiving a DH GEX group packet""" await self.process_packet(Byte(MSG_KEX_DH_GEX_GROUP) + MPInt(p) + MPInt(g)) async def 
simulate_dh_gex_init(self, e): """Simulate receiving a DH GEX init packet""" await self.process_packet(Byte(MSG_KEX_DH_GEX_INIT) + MPInt(e)) async def simulate_dh_gex_reply(self, host_key_data, f, sig): """Simulate receiving a DH GEX reply packet""" await self.process_packet(b''.join((Byte(MSG_KEX_DH_GEX_REPLY), String(host_key_data), MPInt(f), String(sig)))) async def simulate_gss_complete(self, f, sig): """Simulate receiving a GSS complete packet""" await self.process_packet(b''.join((Byte(MSG_KEXGSS_COMPLETE), MPInt(f), String(sig), Boolean(False)))) async def simulate_ecdh_init(self, client_pub): """Simulate receiving an ECDH init packet""" await self.process_packet(Byte(MSG_KEX_ECDH_INIT) + String(client_pub)) async def simulate_ecdh_reply(self, host_key_data, server_pub, sig): """Simulate receiving ab ECDH reply packet""" await self.process_packet(b''.join((Byte(MSG_KEX_ECDH_REPLY), String(host_key_data), String(server_pub), String(sig)))) async def simulate_rsa_pubkey(self, host_key_data, trans_key_data): """Simulate receiving an RSA pubkey packet""" await self.process_packet(Byte(MSG_KEXRSA_PUBKEY) + String(host_key_data) + String(trans_key_data)) async def simulate_rsa_secret(self, encrypted_k): """Simulate receiving an RSA secret packet""" await self.process_packet(Byte(MSG_KEXRSA_SECRET) + String(encrypted_k)) async def simulate_rsa_done(self, sig): """Simulate receiving an RSA done packet""" await self.process_packet(Byte(MSG_KEXRSA_DONE) + String(sig)) class _KexClientStub(_KexConnectionStub): """Stub class for client connection""" @classmethod def make_pair(cls, alg, gss_host=None): """Make a client and server connection pair to test key exchange""" client_conn = cls(alg, gss_host) return client_conn, client_conn.get_peer() def __init__(self, alg, gss_host): server_conn = _KexServerStub(alg, gss_host, self) if gss_host: gss = GSSClient(gss_host, None, 'delegate' in gss_host) else: gss = None super().__init__(alg, gss, server_conn) def connection_lost(self, exc): """Handle the closing of a connection""" if exc and not self._key_waiter.done(): self._key_waiter.set_exception(exc) self.close() def validate_server_host_key(self, host_key_data): """Validate and return the server's host key""" # pylint: disable=no-self-use return decode_ssh_public_key(host_key_data) class _KexServerStub(_KexConnectionStub): """Stub class for server connection""" def __init__(self, alg, gss_host, peer): gss = GSSServer(gss_host, None) if gss_host else None super().__init__(alg, gss, peer, True) if gss_host and 'no_host_key' in gss_host: self._server_host_key = None else: priv_key = get_test_key('ecdsa-sha2-nistp256') self._server_host_key = asyncssh.load_keypairs(priv_key)[0] def connection_lost(self, exc): """Handle the closing of a connection""" if self._peer: self._peer.connection_lost(exc) self.close() def get_server_host_key(self): """Return the server host key""" return self._server_host_key @patch_gss class _TestKex(AsyncTestCase): """Unit tests for kex module""" async def _check_kex(self, alg, gss_host=None): """Unit test key exchange""" client_conn, server_conn = _KexClientStub.make_pair(alg, gss_host) try: await client_conn.start() await server_conn.start() self.assertEqual((await client_conn.get_key()), (await server_conn.get_key())) finally: client_conn.close() server_conn.close() @asynctest async def test_key_exchange_algs(self): """Unit test key exchange algorithms""" for alg in get_kex_algs(): with self.subTest(alg=alg): if alg.startswith(b'gss-'): if gss_available: # pragma: no branch 
await self._check_kex(alg + b'-mech', '1') else: await self._check_kex(alg) if gss_available: # pragma: no branch for steps in range(4): with self.subTest('GSS key exchange', steps=steps): await self._check_kex(b'gss-group1-sha1-mech', str(steps)) with self.subTest('GSS with credential delegation'): await self._check_kex(b'gss-group1-sha1-mech', '1,delegate') with self.subTest('GSS with no host key'): await self._check_kex(b'gss-group1-sha1-mech', '1,no_host_key') with self.subTest('GSS with full host principal'): await self._check_kex(b'gss-group1-sha1-mech', 'host/1@TEST') @asynctest async def test_dh_gex_old(self): """Unit test old DH group exchange request""" register_kex_alg(b'dh-gex-sha1-1024', _KexDHGex, sha1, (1024,), True) register_kex_alg(b'dh-gex-sha1-2048', _KexDHGex, sha1, (2048,), True) for size in (b'1024', b'2048'): with self.subTest('Old DH group exchange', size=size): await self._check_kex(b'dh-gex-sha1-' + size) @asynctest async def test_dh_gex(self): """Unit test old DH group exchange request""" register_kex_alg(b'dh-gex-sha1-1024-1536', _KexDHGex, sha1, (1024, 1536), True) register_kex_alg(b'dh-gex-sha1-1536-3072', _KexDHGex, sha1, (1536, 3072), True) register_kex_alg(b'dh-gex-sha1-2560-2560', _KexDHGex, sha1, (2560, 2560), True) register_kex_alg(b'dh-gex-sha1-2560-4096', _KexDHGex, sha1, (2560, 4096), True) register_kex_alg(b'dh-gex-sha1-9216-9216', _KexDHGex, sha1, (9216, 9216), True) for size in (b'1024-1536', b'1536-3072', b'2560-2560', b'2560-4096', b'9216-9216'): with self.subTest('Old DH group exchange', size=size): await self._check_kex(b'dh-gex-sha1-' + size) @asynctest async def test_dh_errors(self): """Unit test error conditions in DH key exchange""" client_conn, server_conn = \ _KexClientStub.make_pair(b'diffie-hellman-group14-sha1') host_key = server_conn.get_server_host_key() with self.subTest('Init sent to client'): with self.assertRaises(asyncssh.ProtocolError): await client_conn.process_packet(Byte(MSG_KEXDH_INIT)) with self.subTest('Reply sent to server'): with self.assertRaises(asyncssh.ProtocolError): await server_conn.process_packet(Byte(MSG_KEXDH_REPLY)) with self.subTest('Invalid e value'): with self.assertRaises(asyncssh.ProtocolError): await server_conn.simulate_dh_init(0) with self.subTest('Invalid f value'): with self.assertRaises(asyncssh.ProtocolError): await client_conn.start() await client_conn.simulate_dh_reply(host_key.public_data, 0, b'') with self.subTest('Invalid signature'): with self.assertRaises(asyncssh.KeyExchangeFailed): await client_conn.start() await client_conn.simulate_dh_reply(host_key.public_data, 2, b'') client_conn.close() server_conn.close() @asynctest async def test_dh_gex_errors(self): """Unit test error conditions in DH group exchange""" client_conn, server_conn = \ _KexClientStub.make_pair(b'diffie-hellman-group-exchange-sha1') with self.subTest('Request sent to client'): with self.assertRaises(asyncssh.ProtocolError): await client_conn.process_packet(Byte(MSG_KEX_DH_GEX_REQUEST)) with self.subTest('Group sent to server'): with self.assertRaises(asyncssh.ProtocolError): await server_conn.simulate_dh_gex_group(1, 2) with self.subTest('Init sent to client'): with self.assertRaises(asyncssh.ProtocolError): await client_conn.simulate_dh_gex_init(1) with self.subTest('Init sent before group'): with self.assertRaises(asyncssh.ProtocolError): await server_conn.simulate_dh_gex_init(1) with self.subTest('Reply sent to server'): with self.assertRaises(asyncssh.ProtocolError): await server_conn.simulate_dh_gex_reply(b'', 1, 
b'') with self.subTest('Reply sent before group'): with self.assertRaises(asyncssh.ProtocolError): await client_conn.simulate_dh_gex_reply(b'', 1, b'') client_conn.close() server_conn.close() @unittest.skipUnless(gss_available, 'GSS not available') @asynctest async def test_gss_errors(self): """Unit test error conditions in GSS key exchange""" client_conn, server_conn = \ _KexClientStub.make_pair(b'gss-group1-sha1-mech', '3') with self.subTest('Init sent to client'): with self.assertRaises(asyncssh.ProtocolError): await client_conn.process_packet(Byte(MSG_KEXGSS_INIT)) with self.subTest('Complete sent to server'): with self.assertRaises(asyncssh.ProtocolError): await server_conn.process_packet(Byte(MSG_KEXGSS_COMPLETE)) with self.subTest('Exchange failed to complete'): with self.assertRaises(asyncssh.ProtocolError): await client_conn.simulate_gss_complete(1, b'succeed') with self.subTest('Error sent to server'): with self.assertRaises(asyncssh.ProtocolError): await server_conn.process_packet(Byte(MSG_KEXGSS_ERROR)) client_conn.close() server_conn.close() with self.subTest('Signature verification failure'): with self.assertRaises(asyncssh.KeyExchangeFailed): await self._check_kex(b'gss-group1-sha1-mech', '0,verify_error') with self.subTest('Empty token in init'): with self.assertRaises(asyncssh.ProtocolError): await self._check_kex(b'gss-group1-sha1-mech', '0,empty_init') with self.subTest('Empty token in continue'): with self.assertRaises(asyncssh.ProtocolError): await self._check_kex(b'gss-group1-sha1-mech', '1,empty_continue') with self.subTest('Token after complete'): with self.assertRaises(asyncssh.ProtocolError): await self._check_kex(b'gss-group1-sha1-mech', '0,continue_token') for steps in range(2): with self.subTest('Token after complete', steps=steps): with self.assertRaises(asyncssh.ProtocolError): await self._check_kex(b'gss-group1-sha1-mech', str(steps) + ',extra_token') with self.subTest('Context not secure'): with self.assertRaises(asyncssh.ProtocolError): await self._check_kex(b'gss-group1-sha1-mech', '1,no_server_integrity') with self.subTest('GSS error'): with self.assertRaises(asyncssh.KeyExchangeFailed): await self._check_kex(b'gss-group1-sha1-mech', '1,step_error') with self.subTest('GSS error with error token'): with self.assertRaises(asyncssh.KeyExchangeFailed): await self._check_kex(b'gss-group1-sha1-mech', '1,step_error,errtok') @asynctest async def test_ecdh_errors(self): """Unit test error conditions in ECDH key exchange""" client_conn, server_conn = \ _KexClientStub.make_pair(b'ecdh-sha2-nistp256') with self.subTest('Init sent to client'): with self.assertRaises(asyncssh.ProtocolError): await client_conn.simulate_ecdh_init(b'') with self.subTest('Invalid client public key'): with self.assertRaises(asyncssh.ProtocolError): await server_conn.simulate_ecdh_init(b'') with self.subTest('Reply sent to server'): with self.assertRaises(asyncssh.ProtocolError): await server_conn.simulate_ecdh_reply(b'', b'', b'') with self.subTest('Invalid server host key'): with self.assertRaises(asyncssh.KeyImportError): await client_conn.simulate_ecdh_reply(b'', b'', b'') with self.subTest('Invalid server public key'): with self.assertRaises(asyncssh.ProtocolError): host_key = server_conn.get_server_host_key() await client_conn.simulate_ecdh_reply(host_key.public_data, b'', b'') with self.subTest('Invalid signature'): with self.assertRaises(asyncssh.KeyExchangeFailed): host_key = server_conn.get_server_host_key() server_pub = ECDH(b'nistp256').get_public() await 
client_conn.simulate_ecdh_reply(host_key.public_data, server_pub, b'') client_conn.close() server_conn.close() @unittest.skipUnless(curve25519_available, 'Curve25519 not available') @asynctest async def test_curve25519dh_errors(self): """Unit test error conditions in Curve25519DH key exchange""" client_conn, server_conn = \ _KexClientStub.make_pair(b'curve25519-sha256') with self.subTest('Invalid client public key'): with self.assertRaises(asyncssh.ProtocolError): await server_conn.simulate_ecdh_init(b'') with self.subTest('Invalid server public key'): with self.assertRaises(asyncssh.ProtocolError): host_key = server_conn.get_server_host_key() await client_conn.simulate_ecdh_reply(host_key.public_data, b'', b'') with self.subTest('Invalid peer public key'): with self.assertRaises(asyncssh.ProtocolError): host_key = server_conn.get_server_host_key() server_pub = b'\x01' + 31*b'\x00' await client_conn.simulate_ecdh_reply(host_key.public_data, server_pub, b'') with self.subTest('Invalid signature'): with self.assertRaises(asyncssh.KeyExchangeFailed): host_key = server_conn.get_server_host_key() server_pub = Curve25519DH().get_public() await client_conn.simulate_ecdh_reply(host_key.public_data, server_pub, b'') client_conn.close() server_conn.close() @unittest.skipUnless(curve448_available, 'Curve448 not available') @asynctest async def test_curve448dh_errors(self): """Unit test error conditions in Curve448DH key exchange""" client_conn, server_conn = \ _KexClientStub.make_pair(b'curve448-sha512') with self.subTest('Invalid client public key'): with self.assertRaises(asyncssh.ProtocolError): await server_conn.simulate_ecdh_init(b'') with self.subTest('Invalid server public key'): with self.assertRaises(asyncssh.ProtocolError): host_key = server_conn.get_server_host_key() await client_conn.simulate_ecdh_reply(host_key.public_data, b'', b'') with self.subTest('Invalid signature'): with self.assertRaises(asyncssh.KeyExchangeFailed): host_key = server_conn.get_server_host_key() server_pub = Curve448DH().get_public() await client_conn.simulate_ecdh_reply(host_key.public_data, server_pub, b'') client_conn.close() server_conn.close() @unittest.skipUnless(sntrup_available, 'SNTRUP761 not available') @asynctest async def test_sntrup761dh_errors(self): """Unit test error conditions in SNTRUP761 key exchange""" pqdh = PQDH(b'sntrup761') client_conn, server_conn = \ _KexClientStub.make_pair(b'sntrup761x25519-sha512@openssh.com') with self.subTest('Invalid client SNTRUP761 public key'): with self.assertRaises(asyncssh.ProtocolError): await server_conn.simulate_ecdh_init(b'') with self.subTest('Invalid client Curve25519 public key'): with self.assertRaises(asyncssh.ProtocolError): pub = pqdh.pubkey_bytes * b'\0' await server_conn.simulate_ecdh_init(pub) with self.subTest('Invalid server SNTRUP761 public key'): with self.assertRaises(asyncssh.ProtocolError): host_key = server_conn.get_server_host_key() await client_conn.simulate_ecdh_reply(host_key.public_data, b'', b'') with self.subTest('Invalid server Curve25519 public key'): with self.assertRaises(asyncssh.ProtocolError): host_key = server_conn.get_server_host_key() ciphertext = pqdh.ciphertext_bytes * b'\0' await client_conn.simulate_ecdh_reply(host_key.public_data, ciphertext, b'') client_conn.close() server_conn.close() @asynctest async def test_rsa_errors(self): """Unit test error conditions in RSA key exchange""" client_conn, server_conn = \ _KexClientStub.make_pair(b'rsa2048-sha256') with self.subTest('Pubkey sent to server'): with 
self.assertRaises(asyncssh.ProtocolError): await server_conn.simulate_rsa_pubkey(b'', b'') with self.subTest('Secret sent to client'): with self.assertRaises(asyncssh.ProtocolError): await client_conn.simulate_rsa_secret(b'') with self.subTest('Done sent to server'): with self.assertRaises(asyncssh.ProtocolError): await server_conn.simulate_rsa_done(b'') with self.subTest('Invalid transient public key'): with self.assertRaises(asyncssh.ProtocolError): await client_conn.simulate_rsa_pubkey(b'', b'') with self.subTest('Invalid encrypted secret'): with self.assertRaises(asyncssh.KeyExchangeFailed): await server_conn.start() await server_conn.simulate_rsa_secret(b'') with self.subTest('Invalid signature'): with self.assertRaises(asyncssh.KeyExchangeFailed): host_key = server_conn.get_server_host_key() trans_key = get_test_key('ssh-rsa', 2048) await client_conn.simulate_rsa_pubkey(host_key.public_data, trans_key.public_data) await client_conn.simulate_rsa_done(b'') client_conn.close() server_conn.close() asyncssh-2.20.0/tests/test_known_hosts.py000066400000000000000000000263651475467777400206250ustar00rootroot00000000000000# Copyright (c) 2015-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation """Unit tests for matching against known_hosts file""" import binascii import hashlib import hmac import os import asyncssh from .util import TempDirTestCase, get_test_key, x509_available if x509_available: # pragma: no branch from asyncssh.crypto import X509NamePattern def _hash(host): """Return a hashed version of a hostname in a known_hosts file""" salt = os.urandom(20) hosthash = hmac.new(salt, host.encode(), hashlib.sha1).digest() entry = b'|'.join((b'', b'1', binascii.b2a_base64(salt)[:-1], binascii.b2a_base64(hosthash)[:-1])) return entry.decode() class _TestKnownHosts(TempDirTestCase): """Unit tests for known_hosts module""" keylists = ([], [], [], [], [], [], []) imported_keylists = ([], [], [], [], [], [], []) @classmethod def setUpClass(cls): """Create public keys needed for test""" super().setUpClass() for keylist, imported_keylist in zip(cls.keylists[:3], cls.imported_keylists[:3]): for i in range(3): key = get_test_key('ssh-rsa', i) keylist.append(key.export_public_key().decode('ascii')) imported_keylist.append(key.convert_to_public()) if x509_available: # pragma: no branch for keylist, imported_keylist in zip(cls.keylists[3:5], cls.imported_keylists[3:5]): for i in range(3, 5): key = get_test_key('ssh-rsa', i) cert = key.generate_x509_user_certificate(key, 'OU=user', 'OU=user') keylist.append( cert.export_certificate('openssh').decode('ascii')) imported_keylist.append(cert) for keylist, imported_keylist in zip(cls.keylists[5:], cls.imported_keylists[5:]): for name in ('OU=user', 'OU=revoked'): keylist.append('x509v3-ssh-rsa subject=' + name + '\n') imported_keylist.append(X509NamePattern(name)) def check_match(self, known_hosts, results=None, host='host', addr='1.2.3.4', port=22): 
"""Check the result of calling match_known_hosts""" if results: results = tuple([kl[r] for r in result] for kl, result in zip(self.imported_keylists, results)) matches = asyncssh.match_known_hosts(known_hosts, host, addr, port) self.assertEqual(matches, results) def check_hosts(self, patlists, results=None, host='host', addr='1.2.3.4', port=22, from_file=False, from_bytes=False, as_callable=False, as_tuple=False): """Check a known_hosts file built from the specified patterns""" def call_match(host, addr, port): """Test passing callable as known_hosts""" return asyncssh.match_known_hosts(_known_hosts, host, addr, port) prefixes = ('', '@cert-authority ', '@revoked ', '', '@revoked ', '', '@revoked ') known_hosts = '# Comment line\n # Comment line with whitespace\n\n' for prefix, patlist, keys in zip(prefixes, patlists, self.keylists): for pattern, key in zip(patlist, keys): known_hosts += f'{prefix}{pattern} {key}' if from_file: with open('known_hosts', 'w') as f: f.write(known_hosts) known_hosts = 'known_hosts' elif from_bytes: known_hosts = known_hosts.encode() elif as_callable: _known_hosts = asyncssh.import_known_hosts(known_hosts) known_hosts = call_match elif as_tuple: known_hosts = asyncssh.import_known_hosts(known_hosts) known_hosts = asyncssh.match_known_hosts(known_hosts, host, addr, port) else: known_hosts = asyncssh.import_known_hosts(known_hosts) return self.check_match(known_hosts, results, host, addr, port) def test_match(self): """Test known host matching""" matches = ( ('Empty file', ([], [], [], [], [], [], []), ([], [], [], [], [], [], [])), ('Exact host and port', (['[host]:22'], [], [], [], [], [], []), ([0], [], [], [], [], [], [])), ('Exact host', (['host'], [], [], [], [], [], []), ([0], [], [], [], [], [], [])), ('Exact host CA', ([], ['host'], [], [], [], [], []), ([], [0], [], [], [], [], [])), ('Exact host revoked', ([], [], ['host'], [], [], [], []), ([], [], [0], [], [], [], [])), ('Multiple exact', (['host'], ['host'], [], [], [], [], []), ([0], [0], [], [], [], [], [])), ('Wildcard host', (['hos*'], [], [], [], [], [], []), ([0], [], [], [], [], [], [])), ('Mismatched port', (['[host]:23'], [], [], [], [], [], []), ([], [], [], [], [], [], [])), ('Negative host', (['hos*,!host'], [], [], [], [], [], []), ([], [], [], [], [], [], [])), ('Exact addr and port', (['[1.2.3.4]:22'], [], [], [], [], [], []), ([0], [], [], [], [], [], [])), ('Exact addr', (['1.2.3.4'], [], [], [], [], [], []), ([0], [], [], [], [], [], [])), ('Subnet', (['1.2.3.0/24'], [], [], [], [], [], []), ([0], [], [], [], [], [], [])), ('Negative addr', (['1.2.3.0/24,!1.2.3.4', [], [], []], [], [], []), ([], [], [], [], [], [], [])), ('Hashed host', ([_hash('host')], [], [], [], [], [], []), ([0], [], [], [], [], [], [])), ('Hashed addr', ([_hash('1.2.3.4')], [], [], [], [], [], []), ([0], [], [], [], [], [], [])) ) if x509_available: # pragma: no branch matches += ( ('Exact host X.509', ([], [], [], ['host'], [], [], []), ([], [], [], [0], [], [], [])), ('Exact host X.509 revoked', ([], [], [], [], ['host'], [], []), ([], [], [], [], [0], [], [])), ('Exact host subject', ([], [], [], [], [], ['host'], []), ([], [], [], [], [], [0], [])), ('Exact host revoked subject', ([], [], [], [], [], [], ['host']), ([], [], [], [], [], [], [0])), ) for testname, patlists, result in matches: with self.subTest(testname): self.check_hosts(patlists, result) def test_no_addr(self): """Test match without providing addr""" self.check_hosts((['host'], [], [], [], [], [], []), ([0], [], [], [], [], [], []), 
addr='') self.check_hosts((['1.2.3.4'], [], [], [], [], [], []), ([], [], [], [], [], [], []), addr='') def test_no_port(self): """Test match without providing port""" self.check_hosts((['host'], [], [], [], [], [], []), ([0], [], [], [], [], [], []), port=None) self.check_hosts((['[host]:22'], [], [], [], [], [], []), ([], [], [], [], [], [], []), port=None) def test_no_match(self): """Test for cases where no match is found""" no_match = (([], [], [], [], [], [], []), (['host1', 'host2'], [], [], [], [], [], []), (['2.3.4.5', '3.4.5.6'], [], [], [], [], [], []), (['[host]:2222', '[host]:22222'], [], [], [], [], [], [])) for patlists in no_match: self.check_hosts(patlists, ([], [], [], [], [], [], [])) def test_scoped_addr(self): """Test match on scoped addresses""" self.check_hosts((['fe80::1%1'], [], [], [], [], [], []), ([0], [], [], [], [], [], []), addr='fe80::1%1') self.check_hosts((['fe80::%1/64'], [], [], [], [], [], []), ([0], [], [], [], [], [], []), addr='fe80::1%1') self.check_hosts((['fe80::1%2'], [], [], [], [], [], []), ([], [], [], [], [], [], []), addr='fe80::1%1') self.check_hosts((['2001:2::%3/64'], [], [], [], [], [], []), ([0], [], [], [], [], [], []), addr='2001:2::1') def test_missing_key(self): """Test for line with missing key data""" with self.assertRaises(ValueError): self.check_match(b'xxx\n') def test_missing_key_with_tag(self): """Test for line with tag with missing key data""" with self.assertRaises(ValueError): self.check_match(b'@cert-authority xxx\n') def test_invalid_key(self): """Test for line with invalid key""" self.check_match(b'xxx yyy\n', ([], [], [], [], [], [], [])) def test_invalid_marker(self): """Test for line with invalid marker""" with self.assertRaises(ValueError): self.check_match(b'@xxx yyy zzz\n') def test_incomplete_hash(self): """Test for line with incomplete host hash""" with self.assertRaises(ValueError): self.check_hosts((['|1|aaaa'], [], [], [], [], [], [], [], [])) def test_invalid_hash(self): """Test for line with invalid host hash""" with self.assertRaises(ValueError): self.check_hosts((['|1|aaa'], [], [], [], [], [], [], [], [])) def test_unknown_hash_type(self): """Test for line with unknown host hash type""" with self.assertRaises(ValueError): self.check_hosts((['|2|aaaa|'], [], [], [], [], [], [], [], [])) def test_file(self): """Test match against file""" self.check_hosts((['host'], [], [], [], [], [], []), ([0], [], [], [], [], [], []), from_file=True) def test_bytes(self): """Test match against byte string""" self.check_hosts((['host'], [], [], [], [], [], []), ([0], [], [], [], [], [], []), from_bytes=True) def test_callable(self): """Test match using callable""" self.check_hosts((['host'], [], [], [], [], [], []), ([0], [], [], [], [], [], []), as_callable=True) def test_tuple(self): """Test passing already constructed tuple of keys""" self.check_hosts((['host'], [], [], [], [], [], []), ([0], [], [], [], [], [], []), as_tuple=True) asyncssh-2.20.0/tests/test_logging.py000066400000000000000000000141171475467777400176670ustar00rootroot00000000000000# Copyright (c) 2017-2024 by Ron Frederick and others. 
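# A minimal sketch of the known_hosts matching exercised by the tests above,
# assuming a local file named 'known_hosts' and an illustrative host, address,
# and port.  The return value is a tuple of matching key and certificate
# lists, as the 7-element tuples checked in the tests above show.

import asyncssh


def _example_match_known_hosts():
    """Illustrative known_hosts lookup (file name and host are assumptions)"""

    with open('known_hosts') as f:
        known_hosts = asyncssh.import_known_hosts(f.read())

    return asyncssh.match_known_hosts(known_hosts, 'example.com',
                                      '192.0.2.1', 22)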
# # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation """Unit tests for AsyncSSH logging API""" import asyncssh from asyncssh.logging import logger from asyncssh.session import SSHClientSession from asyncssh.sftp import SFTPServer from .server import ServerTestCase from .util import asynctest, echo async def _handle_client(process): """Handle a new client request""" await echo(process.stdin, process.stdout, process.stderr) process.close() await process.wait_closed() class _SFTPServer(SFTPServer): """Test SFTP server""" def stat(self, path): """Get attributes of a file or directory""" self.logger.info('stat called') return super().stat(path) class _TestLogging(ServerTestCase): """Unit tests for AsyncSSH logging API""" @classmethod async def start_server(cls): """Start an SSH server for the tests to use""" return await cls.create_server(process_factory=_handle_client, sftp_factory=_SFTPServer) @asynctest async def test_logging(self): """Test AsyncSSH logging""" asyncssh.set_log_level('INFO') with self.assertLogs(level='INFO') as log: logger.info('Test') self.assertEqual(len(log.records), 1) self.assertEqual(log.records[0].msg, 'Test') @asynctest async def test_debug_levels(self): """Test log debug levels""" asyncssh.set_log_level('DEBUG') for debug_level in range(1, 4): with self.subTest(debug_level=debug_level): asyncssh.set_debug_level(debug_level) with self.assertLogs(level='DEBUG') as log: logger.debug1('DEBUG') logger.debug2('DEBUG') logger.packet(None, b'', 'DEBUG') self.assertEqual(len(log.records), debug_level) for record in log.records: self.assertEqual(record.msg, record.levelname) @asynctest async def test_packet_logging(self): """Test packet logging""" asyncssh.set_log_level('DEBUG') asyncssh.set_debug_level(3) with self.assertLogs(level='DEBUG') as log: logger.packet(0, bytes(range(0x10, 0x30)), 'CONTROL') self.assertEqual(log.records[0].msg, '[pktid=0] CONTROL\n' + ' 00000000: 10 11 12 13 14 15 16 17 18 ' + '19 1a 1b 1c 1d 1e 1f ................\n' + ' 00000010: 20 21 22 23 24 25 26 27 28 ' + '29 2a 2b 2c 2d 2e 2f !"#$%%&\'()*+,-./') @asynctest async def test_connection_log(self): """Test connection-level logger""" asyncssh.set_log_level('INFO') async with self.connect() as conn: with self.assertLogs(level='INFO') as log: conn.logger.info('Test') self.assertEqual(len(log.records), 1) self.assertRegex(log.records[0].msg, r'\[conn=\d+\] Test') @asynctest async def test_channel_log(self): """Test channel-level logger""" asyncssh.set_log_level('INFO') async with self.connect() as conn: for i in range(2): chan, _ = await conn.create_session(SSHClientSession) with self.assertLogs(level='INFO') as log: chan.logger.info('Test') chan.write_eof() await chan.wait_closed() self.assertEqual(len(log.records), 1) self.assertRegex(log.records[0].msg, rf'\[conn=\d+, chan={i}\] Test') @asynctest async def test_stream_log(self): """Test stream-level logger""" asyncssh.set_log_level('INFO') async with self.connect() as conn: 
stdin, _, _ = await conn.open_session() with self.assertLogs(level='INFO') as log: stdin.logger.info('Test') stdin.write_eof() await stdin.channel.wait_closed() self.assertEqual(len(log.records), 1) self.assertRegex(log.records[0].msg, r'\[conn=\d+, chan=0\] Test') @asynctest async def test_process_log(self): """Test process-level logger""" asyncssh.set_log_level('INFO') async with self.connect() as conn: process = await conn.create_process() with self.assertLogs(level='INFO') as log: process.logger.info('Test') process.stdin.write_eof() await process.wait() asyncssh.set_log_level('WARNING') self.assertEqual(len(log.records), 1) self.assertRegex(log.records[0].msg, r'\[conn=\d+, chan=0\] Test') @asynctest async def test_sftp_log(self): """Test sftp-level logger""" asyncssh.set_sftp_log_level('INFO') async with self.connect() as conn: async with conn.start_sftp_client() as sftp: with self.assertLogs(level='INFO') as log: sftp.logger.info('Test') await sftp.stat('.') asyncssh.set_sftp_log_level('WARNING') self.assertEqual(len(log.records), 1) self.assertEqual(log.records[0].name, 'asyncssh.sftp') self.assertRegex(log.records[0].msg, r'\[conn=\d+, chan=0\] Test') @asynctest async def test_invalid_debug_level(self): """Test invalid debug level""" with self.assertRaises(ValueError): asyncssh.set_debug_level(5) asyncssh-2.20.0/tests/test_mac.py000066400000000000000000000044241475467777400170010ustar00rootroot00000000000000# Copyright (c) 2015-2021 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation """Unit tests for message authentication""" import os import unittest from asyncssh.mac import get_mac_algs, get_mac_params, get_mac class _TestMAC(unittest.TestCase): """Unit tests for mac module""" def test_mac_algs(self): """Unit test MAC algorithms""" for mac_alg in get_mac_algs(): with self.subTest(mac_alg=mac_alg): mac_keysize, _, _ = get_mac_params(mac_alg) mac_key = os.urandom(mac_keysize) packet = os.urandom(256) enc_mac = get_mac(mac_alg, mac_key) dec_mac = get_mac(mac_alg, mac_key) badpacket = bytearray(packet) badpacket[-1] ^= 0xff mac = enc_mac.sign(0, packet) badmac = bytearray(mac) badmac[-1] ^= 0xff self.assertTrue(dec_mac.verify(0, packet, mac)) self.assertFalse(dec_mac.verify(0, bytes(badpacket), mac)) self.assertFalse(dec_mac.verify(0, packet, bytes(badmac))) def test_umac_wrapper(self): """Unit test some unused parts of the UMAC wrapper code""" try: # pylint: disable=import-outside-toplevel from asyncssh.crypto import umac32 except ImportError: # pragma: no cover self.skipTest('umac not available') mac_key = os.urandom(16) mac1 = umac32(mac_key) mac1.update(b'test') mac2 = mac1.copy() mac1.update(b'123') mac2.update(b'123') self.assertEqual(mac1.hexdigest(), mac2.hexdigest()) asyncssh-2.20.0/tests/test_packet.py000066400000000000000000000153201475467777400175050ustar00rootroot00000000000000# Copyright (c) 2016-2021 by Ron Frederick and others. 
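# A minimal sketch of the logging controls exercised by the logging tests
# above: set_log_level() takes a standard logging level, set_debug_level()
# takes an AsyncSSH-specific detail level of 1-3, and set_sftp_log_level()
# controls the separate 'asyncssh.sftp' logger.  The values chosen here are
# illustrative only.

import asyncssh


def _example_enable_verbose_logging():
    """Illustrative AsyncSSH logging setup"""

    asyncssh.set_log_level('DEBUG')
    asyncssh.set_debug_level(2)
    asyncssh.set_sftp_log_level('INFO')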
# # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation """Unit tests for SSH packet encoding and decoding""" import codecs import unittest from asyncssh.packet import Byte, Boolean, UInt32, UInt64, String, MPInt from asyncssh.packet import NameList, PacketDecodeError, SSHPacket class _TestPacket(unittest.TestCase): """Unit tests for SSH packet module""" tests = [ (Byte, SSHPacket.get_byte, [ (0, '00'), (127, '7f'), (128, '80'), (255, 'ff') ]), (Boolean, SSHPacket.get_boolean, [ (False, '00'), (True, '01') ]), (UInt32, SSHPacket.get_uint32, [ (0, '00000000'), (256, '00000100'), (0x12345678, '12345678'), (0x7fffffff, '7fffffff'), (0x80000000, '80000000'), (0xffffffff, 'ffffffff') ]), (UInt64, SSHPacket.get_uint64, [ (0, '0000000000000000'), (256, '0000000000000100'), (0x123456789abcdef0, '123456789abcdef0'), (0x7fffffffffffffff, '7fffffffffffffff'), (0x8000000000000000, '8000000000000000'), (0xffffffffffffffff, 'ffffffffffffffff') ]), (String, SSHPacket.get_string, [ (b'', '00000000'), (b'foo', '00000003666f6f'), (1024*b'\xff', '00000400' + 1024*'ff') ]), (MPInt, SSHPacket.get_mpint, [ (0, '00000000'), (1, '0000000101'), (127, '000000017f'), (128, '000000020080'), (32767, '000000027fff'), (32768, '00000003008000'), (0x123456789abcdef01234, '0000000a123456789abcdef01234'), (-1, '00000001ff'), (-128, '0000000180'), (-129, '00000002ff7f'), (-32768, '000000028000'), (-32769, '00000003ff7fff'), (-0xdeadbeef, '00000005ff21524111') ]), (NameList, SSHPacket.get_namelist, [ ([], '00000000'), ([b'foo'], '00000003666f6f'), ([b'foo', b'bar'], '00000007666f6f2c626172') ]) ] encode_errors = [ (Byte, -1, ValueError), (Byte, 256, ValueError), (Byte, 'a', TypeError), (UInt32, None, AttributeError), (UInt32, -1, OverflowError), (UInt32, 0x100000000, OverflowError), (UInt64, None, AttributeError), (UInt64, -1, OverflowError), (UInt64, 0x10000000000000000, OverflowError), (String, None, TypeError), (String, True, TypeError), (String, 0, TypeError), (MPInt, None, AttributeError), (MPInt, '', AttributeError), (MPInt, [], AttributeError), (NameList, None, TypeError), (NameList, 'xxx', TypeError) ] decode_errors = [ (SSHPacket.get_byte, ''), (SSHPacket.get_byte, '1234'), (SSHPacket.get_boolean, ''), (SSHPacket.get_boolean, '1234'), (SSHPacket.get_uint32, '123456'), (SSHPacket.get_uint32, '1234567890'), (SSHPacket.get_uint64, '12345678'), (SSHPacket.get_uint64, '123456789abcdef012'), (SSHPacket.get_string, '123456'), (SSHPacket.get_string, '12345678'), (SSHPacket.get_string, '000000011234') ] def test_packet(self): """Unit test SSH packet module""" for encode, decode, values in self.tests: for value, data in values: data = codecs.decode(data, 'hex') with self.subTest(msg='encode', value=value): self.assertEqual(encode(value), data) with self.subTest(msg='decode', data=data): packet = SSHPacket(data) decoded_value = decode(packet) packet.check_end() self.assertEqual(decoded_value, value) self.assertEqual(packet.get_consumed_payload(), data) 
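# The encode/decode pairs above use AsyncSSH's internal asyncssh.packet
# helpers.  A tiny round-trip sketch using only the constructs these tests
# already exercise (an internal API, so treat this as illustrative only):

from asyncssh.packet import String, UInt32, SSHPacket


def _example_packet_round_trip():
    """Illustrative encode/decode round trip"""

    data = String(b'foo') + UInt32(42)

    packet = SSHPacket(data)
    value = packet.get_string()      # b'foo'
    number = packet.get_uint32()     # 42
    packet.check_end()

    return value, number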
self.assertEqual(packet.get_remaining_payload(), b'') for encode, value, exc in self.encode_errors: with self.subTest(msg='encode error', encode=encode, value=value): with self.assertRaises(exc): encode(value) for decode, data in self.decode_errors: with self.subTest(msg='decode error', data=data): with self.assertRaises(PacketDecodeError): packet = SSHPacket(codecs.decode(data, 'hex')) decode(packet) packet.check_end() def test_unicode(self): """Unit test encoding of UTF-8 string""" self.assertEqual(String('\u2000'), b'\x00\x00\x00\x03\xe2\x80\x80') asyncssh-2.20.0/tests/test_pkcs11.py000066400000000000000000000144331475467777400173440ustar00rootroot00000000000000# Copyright (c) 2020-2021 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation """Unit tests for AsyncSSH PKCS#11 security token support""" import unittest import asyncssh from .pkcs11_stub import pkcs11_available from .pkcs11_stub import get_pkcs11_public_keys, get_pkcs11_certs from .pkcs11_stub import stub_pkcs11, unstub_pkcs11 from .server import ServerTestCase from .util import asynctest class _CheckPKCS11Auth(ServerTestCase): """Common code for testing security key authentication""" _certs_available = False _pkcs11_tokens = [ ('Token 1', b'1234', [('ssh-rsa', 'RSA key'), ('ecdsa-sha2-nistp256', 'EC key 1')]), ('Token 2', b'5678', [('ecdsa-sha2-nistp384', 'EC key 2'), ('ssh-ed25519', 'ED key (unsupported)')]) ] @classmethod async def start_server(cls): """Start an SSH server which supports security key authentication""" cls.addClassCleanup(unstub_pkcs11, *stub_pkcs11(cls._pkcs11_tokens)) pubkeys = get_pkcs11_public_keys() certs = get_pkcs11_certs() cls._certs_available = bool(certs) for cert in certs: cert.append_certificate('auth_keys') for key in pubkeys: key.append_public_key('auth_keys') if pubkeys: ca_key = asyncssh.read_private_key('ckey') cert = ca_key.generate_user_certificate(pubkeys[0], 'name', principals=['ckey']) with open('auth_keys', 'a') as auth_keys: auth_keys.write('cert-authority ') ca_key.append_public_key('auth_keys') cert.write_certificate('pkcs11_cert.pub') auth_keys = 'auth_keys' if cls._pkcs11_tokens else () return await cls.create_server(authorized_client_keys=auth_keys, x509_trusted_certs=certs) @unittest.skipUnless(pkcs11_available, 'pkcs11 support not available') class _TestPKCS11TokenNotFound(_CheckPKCS11Auth): """Unit tests for PKCS#11 authentication with no token found""" _pkcs11_tokens = [] @asynctest async def test_key_not_found(self): """Test PKCS#11 with no token found""" self.assertEqual(asyncssh.load_pkcs11_keys('xxx'), []) @unittest.skipUnless(pkcs11_available, 'pkcs11 support not available') class _TestPKCS11Auth(_CheckPKCS11Auth): """Unit tests for PKCS#11 authentication""" @asynctest async def test_load_keys(self): """Test loading keys and certs from PKCS#11 tokens""" keys = asyncssh.load_pkcs11_keys('xxx') self.assertEqual(len(keys), 6 if self._certs_available else 3) @asynctest async 
def test_load_keys_without_certs(self): """Test loading keys without certs from PKCS#11 tokens""" keys = asyncssh.load_pkcs11_keys('xxx', load_certs=False) self.assertEqual(len(keys), 3) @asynctest async def test_match_token_label(self): """Test matching on PKCS#11 token label""" keys = asyncssh.load_pkcs11_keys('xxx', token_label='Token 2') self.assertEqual(len(keys), 2 if self._certs_available else 1) @asynctest async def test_match_token_serial(self): """Test matching on PKCS#11 token serial number""" keys = asyncssh.load_pkcs11_keys('xxx', token_serial='1234') self.assertEqual(len(keys), 4 if self._certs_available else 2) @asynctest async def test_match_token_serial_bytes(self): """Test matching on PKCS#11 token serial number as bytes""" keys = asyncssh.load_pkcs11_keys('xxx', token_serial=b'1234') self.assertEqual(len(keys), 4 if self._certs_available else 2) @asynctest async def test_match_key_label(self): """Test matching on PKCS#11 key label""" keys = asyncssh.load_pkcs11_keys('xxx', key_label='EC key 2') self.assertEqual(len(keys), 2 if self._certs_available else 1) @asynctest async def test_match_key_id(self): """Test matching on PKCS#11 key id""" keys = asyncssh.load_pkcs11_keys('xxx', key_id='02') self.assertEqual(len(keys), 2 if self._certs_available else 1) @asynctest async def test_match_key_id_bytes(self): """Test matching on PKCS#11 key id as bytes""" keys = asyncssh.load_pkcs11_keys('xxx', key_id=b'\x02') self.assertEqual(len(keys), 2 if self._certs_available else 1) @asynctest async def test_pkcs11_auth(self): """Test authenticating with PKCS#11 token""" async with self.connect(username='ckey', pkcs11_provider='xxx'): pass @asynctest async def test_pkcs11_load_keys(self): """Test authenticating with explicitly loaded PKCS#11 keys""" for key in asyncssh.load_pkcs11_keys('xxx'): for sig_alg in key.sig_algorithms: sig_alg = sig_alg.decode('ascii') # Disable unit tests that involve SHA-1 hashes if sig_alg in ('ssh-rsa', 'x509v3-ssh-rsa'): continue with self.subTest(key=key.get_comment(), sig_alg=sig_alg): async with self.connect( username='ckey', pkcs11_provider='xxx', client_keys=[key], signature_algs=[sig_alg]): pass @asynctest async def test_pkcs11_with_replaced_cert(self): """Test authenticating with a PKCS#11 with replaced cert""" ckey = asyncssh.load_pkcs11_keys('xxx')[1] async with self.connect(username='ckey', pkcs11_provider='xxx', client_keys=[(ckey, 'pkcs11_cert.pub')]): pass asyncssh-2.20.0/tests/test_process.py000066400000000000000000001420321475467777400177150ustar00rootroot00000000000000# Copyright (c) 2016-2024 by Ron Frederick and others. 
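# A minimal sketch of the PKCS#11 flow exercised by the tests above.  The
# provider path, token label, host, and username are illustrative
# assumptions; the tests above use a stub provider rather than a real
# library.

import asyncssh


async def _example_pkcs11_client():
    """Illustrative PKCS#11 token authentication"""

    keys = asyncssh.load_pkcs11_keys('/usr/lib/softhsm/libsofthsm2.so',
                                     token_label='Token 1')

    async with asyncssh.connect('host', username='user',
                                client_keys=keys) as conn:
        await conn.run('true', check=True)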
# # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation """Unit tests for AsyncSSH process API""" import asyncio import io import os from pathlib import Path from signal import SIGINT import socket import sys import unittest import asyncssh from .server import ServerTestCase from .util import asynctest, echo if sys.platform != 'win32': # pragma: no branch import fcntl import struct import termios try: import aiofiles _aiofiles_available = True except ImportError: # pragma: no cover _aiofiles_available = False async def _handle_client(process): """Handle a new client request""" action = process.command or process.subsystem if not action: action = 'echo' if action == 'break': try: await process.stdin.readline() except asyncssh.BreakReceived as exc: process.exit_with_signal('ABRT', False, str(exc.msec)) elif action == 'delay': await asyncio.sleep(1) await echo(process.stdin, process.stdout, process.stderr) elif action == 'echo': await echo(process.stdin, process.stdout, process.stderr) elif action == 'exit_status': process.channel.set_encoding('utf-8') process.stderr.write('Exiting with status 1') process.exit(1) elif action == 'env': process.channel.set_encoding('utf-8') process.stdout.write(process.env.get('TEST', '')) elif action.startswith('redirect '): _, addr, port, action = action.split(None, 3) async with asyncssh.connect(addr, int(port)) as conn: upstream_process = await conn.create_process( command=action, encoding=None, term_type=process.term_type, stdin=process.stdin, stdout=process.stdout) result = await upstream_process.wait() process.exit_with_signal(*result.exit_signal) elif action == 'redirect_stdin': await process.redirect_stdin(process.stdout) await process.stdout.drain() elif action == 'redirect_stdout': await process.redirect_stdout(process.stdin) await process.stdout.drain() elif action == 'redirect_stderr': await process.redirect_stderr(process.stdin) await process.stderr.drain() elif action == 'old_term': info = str((process.get_terminal_type(), process.get_terminal_size(), process.get_terminal_mode(asyncssh.PTY_OP_OSPEED))) process.channel.set_encoding('utf-8') process.stdout.write(info) elif action == 'term': info = str((process.term_type, process.term_size, process.term_modes.get(asyncssh.PTY_OP_OSPEED), sorted(process.term_modes.items()))) process.channel.set_encoding('utf-8') process.stdout.write(info) elif action == 'term_size': try: await process.stdin.readline() except asyncssh.TerminalSizeChanged as exc: process.exit_with_signal('ABRT', False, f'{exc.width}x{exc.height}') elif action == 'term_size_tty': master, slave = os.openpty() await process.redirect_stdin(master, recv_eof=False) process.stdout.write(b'\n') await process.stdin.readline() size = fcntl.ioctl(slave, termios.TIOCGWINSZ, 8*b'\0') height, width, _, _ = struct.unpack('hhhh', size) process.stdout.write(f'{width}x{height}'.encode()) os.close(slave) elif action == 'term_size_nontty': rpipe, wpipe = os.pipe() await 
process.redirect_stdin(wpipe) process.stdout.write(b'\n') await process.stdin.readline() os.close(rpipe) elif action == 'timeout': process.channel.set_encoding('utf-8') process.stdout.write('Sleeping') await asyncio.sleep(1) else: process.exit(255) process.close() await process.wait_closed() class _TestProcess(ServerTestCase): """Unit tests for AsyncSSH process API""" @classmethod async def start_server(cls): """Start an SSH server for the tests to use""" return await cls.create_server(process_factory=_handle_client, encoding=None) class _TestProcessBasic(_TestProcess): """Unit tests for AsyncSSH process basic functions""" @asynctest async def test_shell(self): """Test starting a remote shell""" data = str(id(self)) async with self.connect() as conn: process = await conn.create_process(env={'TEST': 'test'}) process.stdin.write(data) self.assertFalse(process.is_closing()) process.stdin.write_eof() self.assertTrue(process.is_closing()) result = await process.wait() self.assertEqual(result.env, {'TEST': 'test'}) self.assertEqual(result.command, None) self.assertEqual(result.subsystem, None) self.assertEqual(result.exit_status, None) self.assertEqual(result.exit_signal, None) self.assertEqual(result.returncode, None) self.assertEqual(result.stdout, data) self.assertEqual(result.stderr, data) @asynctest async def test_command(self): """Test executing a remote command""" data = str(id(self)) async with self.connect() as conn: process = await conn.create_process('echo') process.stdin.write(data) process.stdin.write_eof() result = await process.wait() self.assertEqual(result.command, 'echo') self.assertEqual(result.subsystem, None) self.assertEqual(result.stdout, data) self.assertEqual(result.stderr, data) @asynctest async def test_subsystem(self): """Test starting a remote subsystem""" data = str(id(self)) async with self.connect() as conn: process = await conn.create_process(subsystem='echo') process.stdin.write(data) process.stdin.write_eof() result = await process.wait() self.assertEqual(result.command, None) self.assertEqual(result.subsystem, 'echo') self.assertEqual(result.stdout, data) self.assertEqual(result.stderr, data) @asynctest async def test_communicate(self): """Test communicate""" data = str(id(self)) async with self.connect() as conn: async with conn.create_process() as process: stdout_data, stderr_data = await process.communicate(data) self.assertEqual(stdout_data, data) self.assertEqual(stderr_data, data) @asynctest async def test_communicate_paused(self): """Test communicate when reading is already paused""" data = 4*1024*1024*'*' async with self.connect() as conn: async with conn.create_process(input=data) as process: await asyncio.sleep(1) stdout_data, stderr_data = await process.communicate() self.assertEqual(stdout_data, data) self.assertEqual(stderr_data, data) @asynctest async def test_env(self): """Test sending environment""" async with self.connect() as conn: process = await conn.create_process('env', env={'TEST': 'test'}) result = await process.wait() self.assertEqual(result.stdout, 'test') @asynctest async def test_old_terminal_info(self): """Test setting and retrieving terminal information with old API""" modes = {asyncssh.PTY_OP_OSPEED: 9600} async with self.connect() as conn: process = await conn.create_process('old_term', term_type='ansi', term_size=(80, 24), term_modes=modes) result = await process.wait() self.assertEqual(result.stdout, "('ansi', (80, 24, 0, 0), 9600)") @asynctest async def test_terminal_info(self): """Test setting and retrieving terminal 
information""" modes = {asyncssh.PTY_OP_ISPEED: 9600, asyncssh.PTY_OP_OSPEED: 9600} async with self.connect() as conn: process = await conn.create_process('term', term_type='ansi', term_size=(80, 24), term_modes=modes) result = await process.wait() self.assertEqual(result.stdout, "('ansi', (80, 24, 0, 0), 9600, " "[(128, 9600), (129, 9600)])") @asynctest async def test_change_terminal_size(self): """Test changing terminal size""" async with self.connect() as conn: process = await conn.create_process('term_size', term_type='ansi') process.change_terminal_size(80, 24) result = await process.wait() self.assertEqual(result.exit_signal[2], '80x24') @asynctest async def test_break(self): """Test sending a break""" async with self.connect() as conn: process = await conn.create_process('break') process.send_break(1000) result = await process.wait() self.assertEqual(result.exit_signal[2], '1000') @asynctest async def test_signal(self): """Test sending a signal""" async with self.connect() as conn: process = await conn.create_process() process.send_signal('INT') result = await process.wait() self.assertEqual(result.exit_signal[0], 'INT') self.assertEqual(result.returncode, -SIGINT) @asynctest async def test_numeric_signal(self): """Test sending a signal using a numeric value""" async with self.connect() as conn: process = await conn.create_process() process.send_signal(SIGINT) result = await process.wait() self.assertEqual(result.exit_signal[0], 'INT') self.assertEqual(result.returncode, -SIGINT) @asynctest async def test_terminate(self): """Test sending a terminate signal""" async with self.connect() as conn: process = await conn.create_process() process.terminate() result = await process.wait() self.assertEqual(result.exit_signal[0], 'TERM') @asynctest async def test_kill(self): """Test sending a kill signal""" async with self.connect() as conn: process = await conn.create_process() process.kill() result = await process.wait() self.assertEqual(result.exit_signal[0], 'KILL') @asynctest async def test_exit_status(self): """Test checking exit status""" async with self.connect() as conn: result = await conn.run('exit_status') self.assertEqual(result.exit_status, 1) self.assertEqual(result.returncode, 1) self.assertEqual(result.stdout, '') self.assertEqual(result.stderr, 'Exiting with status 1') @asynctest async def test_raise_on_exit_status(self): """Test raising an exception on non-zero exit status""" async with self.connect() as conn: with self.assertRaises(asyncssh.ProcessError) as exc: await conn.run('exit_status', env={'TEST': 'test'}, check=True) self.assertEqual(exc.exception.env, {'TEST': 'test'}) self.assertEqual(exc.exception.command, 'exit_status') self.assertEqual(exc.exception.subsystem, None) self.assertEqual(exc.exception.exit_status, 1) self.assertEqual(exc.exception.reason, 'Process exited with non-zero exit status 1') self.assertEqual(exc.exception.returncode, 1) @asynctest async def test_raise_on_timeout(self): """Test raising an exception on timeout""" async with self.connect() as conn: with self.assertRaises(asyncssh.ProcessError) as exc: await conn.run('timeout', timeout=0.1) self.assertEqual(exc.exception.command, 'timeout') self.assertEqual(exc.exception.reason, '') self.assertEqual(exc.exception.stdout, 'Sleeping') @asynctest async def test_exit_signal(self): """Test checking exit signal""" async with self.connect() as conn: process = await conn.create_process() process.send_signal('INT') result = await process.wait() self.assertEqual(result.exit_status, -1) 
self.assertEqual(result.exit_signal[0], 'INT') self.assertEqual(result.returncode, -SIGINT) @asynctest async def test_raise_on_exit_signal(self): """Test raising an exception on exit signal""" async with self.connect() as conn: process = await conn.create_process() with self.assertRaises(asyncssh.ProcessError) as exc: process.send_signal('INT') await process.wait(check=True) self.assertEqual(exc.exception.exit_status, -1) self.assertEqual(exc.exception.exit_signal[0], 'INT') self.assertEqual(exc.exception.reason, 'Process exited with signal INT') self.assertEqual(exc.exception.returncode, -SIGINT) @asynctest async def test_split_unicode(self): """Test Unicode split across blocks""" data = '\u2000test\u2000' with open('stdin', 'w', encoding='utf-8') as file: file.write(data) async with self.connect() as conn: result = await conn.run('echo', stdin='stdin', bufsize=2) self.assertEqual(result.stdout, data) @asynctest async def test_invalid_unicode(self): """Test invalid Unicode data""" data = b'\xfftest' with open('stdin', 'wb') as file: file.write(data) async with self.connect() as conn: with self.assertRaises(asyncssh.ProtocolError): await conn.run('echo', stdin='stdin') @asynctest async def test_ignoring_invalid_unicode(self): """Test ignoring invalid Unicode data""" data = b'\xfftest' with open('stdin', 'wb') as file: file.write(data) async with self.connect() as conn: await conn.run('echo', stdin='stdin', encoding='utf-8', errors='ignore') @asynctest async def test_incomplete_unicode(self): """Test incomplete Unicode data""" data = '\u2000'.encode()[:2] with open('stdin', 'wb') as file: file.write(data) async with self.connect() as conn: with self.assertRaises(asyncssh.ProtocolError): await conn.run('echo', stdin='stdin') @asynctest async def test_disconnect(self): """Test collecting output from a disconnected channel""" data = str(id(self)) async with self.connect() as conn: process = await conn.create_process() process.stdin.write(data) process.send_signal('ABRT') result = await process.wait() self.assertEqual(result.stdout, data) self.assertEqual(result.stderr, data) @asynctest async def test_get_extra_info(self): """Test get_extra_info on streams""" async with self.connect() as conn: process = await conn.create_process() self.assertEqual(process.get_extra_info('connection'), conn) process.stdin.write_eof() await process.wait() @asynctest async def test_unknown_action(self): """Test unknown action""" async with self.connect() as conn: result = await conn.run('unknown') self.assertEqual(result.exit_status, 255) class _TestProcessRedirection(_TestProcess): """Unit tests for AsyncSSH process I/O redirection""" @asynctest async def test_input(self): """Test with input from a string""" data = str(id(self)) async with self.connect() as conn: result = await conn.run('echo', input=data) self.assertEqual(result.stdout, data) self.assertEqual(result.stderr, data) @asynctest async def test_stdin_devnull(self): """Test with stdin redirected to DEVNULL""" async with self.connect() as conn: result = await conn.run('echo', stdin=asyncssh.DEVNULL) self.assertEqual(result.stdout, '') self.assertEqual(result.stderr, '') @asynctest async def test_stdin_file(self): """Test with stdin redirected to a file""" data = str(id(self)) with open('stdin', 'w') as file: file.write(data) async with self.connect() as conn: result = await conn.run('echo', stdin='stdin') self.assertEqual(result.stdout, data) self.assertEqual(result.stderr, data) @asynctest async def test_stdin_binary_file(self): """Test with stdin 
redirected to a file in binary mode""" data = str(id(self)).encode() + b'\xff' with open('stdin', 'wb') as file: file.write(data) async with self.connect() as conn: result = await conn.run('echo', stdin='stdin', encoding=None) self.assertEqual(result.stdout, data) self.assertEqual(result.stderr, data) @asynctest async def test_stdin_pathlib(self): """Test with stdin redirected to a file name specified by pathlib""" data = str(id(self)) with open('stdin', 'w') as file: file.write(data) async with self.connect() as conn: result = await conn.run('echo', stdin=Path('stdin')) self.assertEqual(result.stdout, data) self.assertEqual(result.stderr, data) @asynctest async def test_stdin_open_file(self): """Test with stdin redirected to an open file""" data = str(id(self)) with open('stdin', 'w') as file: file.write(data) file = open('stdin') async with self.connect() as conn: result = await conn.run('echo', stdin=file) self.assertEqual(result.stdout, data) self.assertEqual(result.stderr, data) @asynctest async def test_stdin_open_binary_file(self): """Test with stdin redirected to an open file in binary mode""" data = str(id(self)).encode() + b'\xff' with open('stdin', 'wb') as file: file.write(data) file = open('stdin', 'rb') async with self.connect() as conn: result = await conn.run('echo', stdin=file, encoding=None) self.assertEqual(result.stdout, data) self.assertEqual(result.stderr, data) @asynctest async def test_stdin_stringio(self): """Test with stdin redirected to a StringIO object""" data = str(id(self)) with open('stdin', 'w') as file: file.write(data) file = io.StringIO(data) async with self.connect() as conn: result = await conn.run('echo', stdin=file) self.assertEqual(result.stdout, data) self.assertEqual(result.stderr, data) @asynctest async def test_stdin_bytesio(self): """Test with stdin redirected to a BytesIO object""" data = str(id(self)) with open('stdin', 'w') as file: file.write(data) file = io.BytesIO(data.encode('ascii')) async with self.connect() as conn: result = await conn.run('echo', stdin=file) self.assertEqual(result.stdout, data) self.assertEqual(result.stderr, data) @asynctest async def test_stdin_process(self): """Test with stdin redirected to another SSH process""" data = str(id(self)) async with self.connect() as conn: proc1 = await conn.create_process(input=data) proc2 = await conn.create_process(stdin=proc1.stdout) result = await proc2.wait() self.assertEqual(result.stdout, data) self.assertEqual(result.stderr, data) @asynctest async def test_forward_terminal_size(self): """Test forwarding a terminal size change""" async with self.connect() as conn: cmd = f'redirect {self._server_addr} {self._server_port} term_size' process = await conn.create_process(cmd, term_type='ansi') process.change_terminal_size(80, 24) result = await process.wait() self.assertEqual(result.exit_signal[2], '80x24') @unittest.skipIf(sys.platform == 'win32', 'skip TTY terminal size tests on Windows') @asynctest async def test_forward_terminal_size_tty(self): """Test forwarding a terminal size change to a remote tty""" async with self.connect() as conn: process = await conn.create_process('term_size_tty', term_type='ansi') await process.stdout.readline() process.change_terminal_size(80, 24) process.stdin.write_eof() result = await process.wait() self.assertEqual(result.stdout, '80x24') @unittest.skipIf(sys.platform == 'win32', 'skip TTY terminal size tests on Windows') @asynctest async def test_forward_terminal_size_nontty(self): """Test forwarding a terminal size change to a remote 
non-tty""" async with self.connect() as conn: process = await conn.create_process('term_size_nontty', term_type='ansi') await process.stdout.readline() process.change_terminal_size(80, 24) process.stdin.write_eof() result = await process.wait() self.assertEqual(result.stdout, '') @asynctest async def test_forward_break(self): """Test forwarding a break""" async with self.connect() as conn: cmd = f'redirect {self._server_addr} {self._server_port} break' process = await conn.create_process(cmd) process.send_break(1000) result = await process.wait() self.assertEqual(result.exit_signal[2], '1000') @asynctest async def test_forward_signal(self): """Test forwarding a signal""" async with self.connect() as conn: cmd = f'redirect {self._server_addr} {self._server_port} echo' process = await conn.create_process(cmd) process.stdin.write('\n') await process.stdout.readline() process.send_signal('INT') result = await process.wait() self.assertEqual(result.exit_signal[0], 'INT') self.assertEqual(result.returncode, -SIGINT) @unittest.skipIf(sys.platform == 'win32', 'skip asyncio.subprocess tests on Windows') @asynctest async def test_stdin_stream(self): """Test with stdin redirected to an asyncio stream""" data = 4*1024*1024*'*' async with self.connect() as conn: proc1 = await asyncio.create_subprocess_shell( 'cat', stdin=asyncio.subprocess.PIPE, stdout=asyncio.subprocess.PIPE) proc1.stdin.write(data.encode('ascii')) proc1.stdin.write_eof() proc2 = await conn.create_process('delay', stdin=proc1.stdout) result = await proc2.wait() self.assertEqual(result.stdout, data) self.assertEqual(result.stderr, data) @asynctest async def test_stdout_devnull(self): """Test with stdout redirected to DEVNULL""" data = str(id(self)) async with self.connect() as conn: result = await conn.run('echo', input=data, stdout=asyncssh.DEVNULL) self.assertEqual(result.stdout, '') self.assertEqual(result.stderr, data) @asynctest async def test_stdout_file(self): """Test with stdout redirected to a file""" data = str(id(self)) async with self.connect() as conn: result = await conn.run('echo', input=data, stdout='stdout') with open('stdout') as file: stdout_data = file.read() self.assertEqual(stdout_data, data) self.assertEqual(result.stdout, '') self.assertEqual(result.stderr, data) @asynctest async def test_stdout_binary_file(self): """Test with stdout redirected to a file in binary mode""" data = str(id(self)).encode() + b'\xff' async with self.connect() as conn: result = await conn.run('echo', input=data, stdout='stdout', encoding=None) with open('stdout', 'rb') as file: stdout_data = file.read() self.assertEqual(stdout_data, data) self.assertEqual(result.stdout, b'') self.assertEqual(result.stderr, data) @asynctest async def test_stdout_pathlib(self): """Test with stdout redirected to a file name specified by pathlib""" data = str(id(self)) async with self.connect() as conn: result = await conn.run('echo', input=data, stdout=Path('stdout')) with open('stdout') as file: stdout_data = file.read() self.assertEqual(stdout_data, data) self.assertEqual(result.stdout, '') self.assertEqual(result.stderr, data) @asynctest async def test_stdout_open_file(self): """Test with stdout redirected to an open file""" data = str(id(self)) file = open('stdout', 'w') async with self.connect() as conn: result = await conn.run('echo', input=data, stdout=file) with open('stdout') as file: stdout_data = file.read() self.assertEqual(stdout_data, data) self.assertEqual(result.stdout, '') self.assertEqual(result.stderr, data) @asynctest async def 
test_stdout_open_file_keep_open(self): """Test with stdout redirected to an open file which remains open""" data = str(id(self)) with open('stdout', 'w') as file: async with self.connect() as conn: await conn.run('echo', input=data, stdout=file, recv_eof=False) await conn.run('echo', input=data, stdout=file, recv_eof=False) with open('stdout') as file: stdout_data = file.read() self.assertEqual(stdout_data, 2*data) @asynctest async def test_stdout_open_binary_file(self): """Test with stdout redirected to an open binary file""" data = str(id(self)).encode() + b'\xff' file = open('stdout', 'wb') async with self.connect() as conn: result = await conn.run('echo', input=data, stdout=file, encoding=None) with open('stdout', 'rb') as file: stdout_data = file.read() self.assertEqual(stdout_data, data) self.assertEqual(result.stdout, b'') self.assertEqual(result.stderr, data) @asynctest async def test_stdout_stringio(self): """Test with stdout redirected to a StringIO""" class _StringIOTest(io.StringIO): """Test class for StringIO which preserves output after close""" def __init__(self): super().__init__() self.output = None def close(self): if self.output is None: self.output = self.getvalue() super().close() data = str(id(self)) file = _StringIOTest() async with self.connect() as conn: result = await conn.run('echo', input=data, stdout=file) self.assertEqual(file.output, data) self.assertEqual(result.stdout, '') self.assertEqual(result.stderr, data) @asynctest async def test_stdout_bytesio(self): """Test with stdout redirected to a BytesIO""" class _BytesIOTest(io.BytesIO): """Test class for BytesIO which preserves output after close""" def __init__(self): super().__init__() self.output = None def close(self): if self.output is None: self.output = self.getvalue() super().close() data = str(id(self)) file = _BytesIOTest() async with self.connect() as conn: result = await conn.run('echo', input=data, stdout=file) self.assertEqual(file.output, data.encode('ascii')) self.assertEqual(result.stdout, '') self.assertEqual(result.stderr, data) @asynctest async def test_stdout_process(self): """Test with stdout redirected to another SSH process""" data = str(id(self)) async with self.connect() as conn: async with conn.create_process() as proc2: proc1 = await conn.create_process(stdout=proc2.stdin) proc1.stdin.write(data) proc1.stdin.write_eof() result = await proc2.wait() self.assertEqual(result.stdout, data) self.assertEqual(result.stderr, data) @unittest.skipIf(sys.platform == 'win32', 'skip asyncio.subprocess tests on Windows') @asynctest async def test_stdout_stream(self): """Test with stdout redirected to an asyncio stream""" data = str(id(self)) async with self.connect() as conn: proc2 = await asyncio.create_subprocess_shell( 'cat', stdin=asyncio.subprocess.PIPE, stdout=asyncio.subprocess.PIPE) async with conn.create_process(input=data, stdout=proc2.stdin): stdout_data = await proc2.stdout.read() self.assertEqual(stdout_data, data.encode('ascii')) @unittest.skipIf(sys.platform == 'win32', 'skip asyncio.subprocess tests on Windows') @asynctest async def test_stdout_stream_keep_open(self): """Test with stdout redirected to asyncio stream which remains open""" data = str(id(self)) async with self.connect() as conn: proc2 = await asyncio.create_subprocess_shell( 'cat', stdin=asyncio.subprocess.PIPE, stdout=asyncio.subprocess.PIPE) await conn.run('echo', input=data, stdout=proc2.stdin, stderr=asyncssh.DEVNULL, recv_eof=False) await conn.run('echo', input=data, stdout=proc2.stdin, stderr=asyncssh.DEVNULL) 
stdout_data = await proc2.stdout.read() self.assertEqual(stdout_data, 2*data.encode('ascii')) @asynctest async def test_change_stdout(self): """Test changing stdout of an open process""" async with self.connect() as conn: process = await conn.create_process(stdout='stdout') process.stdin.write('xxx') await asyncio.sleep(0.1) await process.redirect_stdout(asyncssh.PIPE) process.stdin.write('yyy') process.stdin.write_eof() result = await process.wait() with open('stdout') as file: stdout_data = file.read() self.assertEqual(stdout_data, 'xxx') self.assertEqual(result.stdout, 'yyy') self.assertEqual(result.stderr, 'xxxyyy') @asynctest async def test_change_stdin_process(self): """Test changing stdin of an open process reading from another""" data = str(id(self)) async with self.connect() as conn: async with conn.create_process() as proc2: proc1 = await conn.create_process(stdout=proc2.stdin) proc1.stdin.write(data) await asyncio.sleep(0.1) await proc2.redirect_stdin(asyncssh.PIPE) proc2.stdin.write(data) await asyncio.sleep(0.1) await proc2.redirect_stdin(proc1.stdout) proc1.stdin.write_eof() result = await proc2.wait() self.assertEqual(result.stdout, data+data) self.assertEqual(result.stderr, data+data) @asynctest async def test_change_stdout_process(self): """Test changing stdout of an open process sending to another""" data = str(id(self)) async with self.connect() as conn: async with conn.create_process() as proc2: proc1 = await conn.create_process(stdout=proc2.stdin) proc1.stdin.write(data) await asyncio.sleep(0.1) await proc1.redirect_stdout(asyncssh.DEVNULL) proc1.stdin.write(data) await asyncio.sleep(0.1) await proc1.redirect_stdout(proc2.stdin) proc1.stdin.write_eof() result = await proc2.wait() self.assertEqual(result.stdout, data) self.assertEqual(result.stderr, data) @asynctest async def test_stderr_stdout(self): """Test with stderr redirected to stdout""" data = str(id(self)) async with self.connect() as conn: result = await conn.run('echo', input=data, stderr=asyncssh.STDOUT) self.assertEqual(result.stdout, data+data) @asynctest async def test_server_redirect_stdin(self): """Test redirect on server of stdin""" data = str(id(self)) async with self.connect() as conn: result = await conn.run('redirect_stdin', input=data) self.assertEqual(result.stdout, data) self.assertEqual(result.stderr, '') @asynctest async def test_server_redirect_stdout(self): """Test redirect on server of stdout""" data = str(id(self)) async with self.connect() as conn: result = await conn.run('redirect_stdout', input=data) self.assertEqual(result.stdout, data) self.assertEqual(result.stderr, '') @asynctest async def test_server_redirect_stderr(self): """Test redirect on server of stderr""" data = str(id(self)) async with self.connect() as conn: result = await conn.run('redirect_stderr', input=data) self.assertEqual(result.stdout, '') self.assertEqual(result.stderr, data) @asynctest async def test_pause_file_reader(self): """Test pausing and resuming reading from a file""" data = 4*1024*1024*'*' with open('stdin', 'w') as file: file.write(data) async with self.connect() as conn: result = await conn.run('echo', stdin='stdin', stderr=asyncssh.DEVNULL) self.assertEqual(result.stdout, data) @asynctest async def test_pause_process_reader(self): """Test pausing and resuming reading from another SSH process""" data = 4*1024*1024*'*' async with self.connect() as conn: proc1 = await conn.create_process(input=data) proc2 = await conn.create_process('delay', stdin=proc1.stdout, stderr=asyncssh.DEVNULL) proc3 = await 
conn.create_process('delay', stdin=proc1.stderr, stderr=asyncssh.DEVNULL) result2, result3 = await asyncio.gather(proc2.wait(), proc3.wait()) self.assertEqual(result2.stdout, data) self.assertEqual(result3.stdout, data) @asynctest async def test_redirect_stdin_when_paused(self): """Test redirecting stdin when write is paused""" data = 4*1024*1024*'*' with open('stdin', 'w') as file: file.write(data) async with self.connect() as conn: process = await conn.create_process() process.stdin.write(data) await process.redirect_stdin('stdin') result = await process.wait() self.assertEqual(result.stdout, data+data) self.assertEqual(result.stderr, data+data) @asynctest async def test_redirect_process_when_paused(self): """Test redirecting away from a process when write is paused""" data = 4*1024*1024*'*' async with self.connect() as conn: proc1 = await conn.create_process(input=data) proc2 = await conn.create_process('delay', stdin=proc1.stdout) proc3 = await conn.create_process('delay', stdin=proc1.stderr) await proc1.redirect_stderr(asyncssh.DEVNULL) result = await proc2.wait() proc3.close() self.assertEqual(result.stdout, data) self.assertEqual(result.stderr, data) @asynctest async def test_consecutive_redirect(self): """Test consecutive redirects using drain""" data = 4*1024*1024*'*' with open('stdin', 'w') as file: file.write(data) async with self.connect() as conn: process = await conn.create_process() await process.redirect_stdin('stdin', send_eof=False) await process.stdin.drain() await process.redirect_stdin('stdin') result = await process.wait() self.assertEqual(result.stdout, data+data) self.assertEqual(result.stderr, data+data) @unittest.skipUnless(_aiofiles_available, 'Async file I/O not available') class _TestAsyncFileRedirection(_TestProcess): """Unit tests for AsyncSSH async file redirection""" @asynctest async def test_stdin_aiofile(self): """Test with stdin redirected to an aiofile""" data = str(id(self)) with open('stdin', 'w') as file: file.write(data) file = await aiofiles.open('stdin', 'r') async with self.connect() as conn: result = await conn.run('echo', stdin=file) self.assertEqual(result.stdout, data) self.assertEqual(result.stderr, data) @asynctest async def test_stdin_binary_aiofile(self): """Test with stdin redirected to an aiofile in binary mode""" data = str(id(self)).encode() + b'\xff' with open('stdin', 'wb') as file: file.write(data) file = await aiofiles.open('stdin', 'rb') async with self.connect() as conn: result = await conn.run('echo', stdin=file, encoding=None) self.assertEqual(result.stdout, data) self.assertEqual(result.stderr, data) @asynctest async def test_stdout_aiofile(self): """Test with stdout redirected to an aiofile""" data = str(id(self)) file = await aiofiles.open('stdout', 'w') async with self.connect() as conn: result = await conn.run('echo', input=data, stdout=file) with open('stdout') as file: stdout_data = file.read() self.assertEqual(stdout_data, data) self.assertEqual(result.stdout, '') self.assertEqual(result.stderr, data) @asynctest async def test_stdout_aiofile_keep_open(self): """Test with stdout redirected to an aiofile which remains open""" data = str(id(self)) async with aiofiles.open('stdout', 'w') as file: async with self.connect() as conn: await conn.run('echo', input=data, stdout=file, recv_eof=False) await conn.run('echo', input=data, stdout=file, recv_eof=False) with open('stdout') as file: stdout_data = file.read() self.assertEqual(stdout_data, 2*data) @asynctest async def test_stdout_binary_aiofile(self): """Test with stdout 
redirected to an aiofile in binary mode""" data = str(id(self)).encode() + b'\xff' file = await aiofiles.open('stdout', 'wb') async with self.connect() as conn: result = await conn.run('echo', input=data, stdout=file, encoding=None) with open('stdout', 'rb') as file: stdout_data = file.read() self.assertEqual(stdout_data, data) self.assertEqual(result.stdout, b'') self.assertEqual(result.stderr, data) @asynctest async def test_pause_async_file_reader(self): """Test pausing and resuming reading from an aiofile""" data = 4*1024*1024*'*' with open('stdin', 'w') as file: file.write(data) file = await aiofiles.open('stdin', 'r') async with self.connect() as conn: result = await conn.run('delay', stdin=file, stderr=asyncssh.DEVNULL) self.assertEqual(result.stdout, data) @asynctest async def test_pause_async_file_writer(self): """Test pausing and resuming writing to an aiofile""" data = 4*1024*1024*'*' async with aiofiles.open('stdout', 'w') as file: async with self.connect() as conn: await conn.run('delay', input=data, stdout=file, stderr=asyncssh.DEVNULL) with open('stdout') as file: self.assertEqual(file.read(), data) @unittest.skipIf(sys.platform == 'win32', 'skip pipe tests on Windows') class _TestProcessPipes(_TestProcess): """Unit tests for AsyncSSH process I/O using pipes""" @asynctest async def test_stdin_pipe(self): """Test with stdin redirected to a pipe""" data = str(id(self)) rpipe, wpipe = os.pipe() os.write(wpipe, data.encode()) os.close(wpipe) async with self.connect() as conn: result = await conn.run('echo', stdin=rpipe) self.assertEqual(result.stdout, data) self.assertEqual(result.stderr, data) @asynctest async def test_stdin_text_pipe(self): """Test with stdin redirected to a pipe in text mode""" data = str(id(self)) rpipe, wpipe = os.pipe() rpipe = os.fdopen(rpipe, 'r') wpipe = os.fdopen(wpipe, 'w') wpipe.write(data) wpipe.close() async with self.connect() as conn: result = await conn.run('echo', stdin=rpipe) self.assertEqual(result.stdout, data) self.assertEqual(result.stderr, data) @asynctest async def test_stdin_binary_pipe(self): """Test with stdin redirected to a pipe in binary mode""" data = str(id(self)).encode() + b'\xff' rpipe, wpipe = os.pipe() os.write(wpipe, data) os.close(wpipe) async with self.connect() as conn: result = await conn.run('echo', stdin=rpipe, encoding=None) self.assertEqual(result.stdout, data) self.assertEqual(result.stderr, data) @asynctest async def test_stdout_pipe(self): """Test with stdout redirected to a pipe""" data = str(id(self)) rpipe, wpipe = os.pipe() async with self.connect() as conn: result = await conn.run('echo', input=data, stdout=wpipe) stdout_data = os.read(rpipe, 1024) os.close(rpipe) self.assertEqual(stdout_data.decode(), data) self.assertEqual(result.stdout, '') self.assertEqual(result.stderr, data) @asynctest async def test_stdout_pipe_keep_open(self): """Test with stdout redirected to a pipe which remains open""" data = str(id(self)) rpipe, wpipe = os.pipe() os.write(wpipe, data.encode()) async with self.connect() as conn: await conn.run('echo', input=data, stdout=wpipe, recv_eof=False) await conn.run('echo', input=data, stdout=wpipe, recv_eof=False) os.write(wpipe, data.encode()) os.close(wpipe) stdout_data = os.read(rpipe, 1024) os.close(rpipe) self.assertEqual(stdout_data.decode(), 4*data) @asynctest async def test_stdout_text_pipe(self): """Test with stdout redirected to a pipe in text mode""" data = str(id(self)) rpipe, wpipe = os.pipe() rpipe = os.fdopen(rpipe, 'r') wpipe = os.fdopen(wpipe, 'w') async with 
self.connect() as conn: result = await conn.run('echo', input=data, stdout=wpipe) stdout_data = rpipe.read(1024) rpipe.close() self.assertEqual(stdout_data, data) self.assertEqual(result.stdout, '') self.assertEqual(result.stderr, data) @asynctest async def test_stdout_text_pipe_keep_open(self): """Test with stdout to a pipe in text mode which remains open""" data = str(id(self)) rpipe, wpipe = os.pipe() rpipe = os.fdopen(rpipe, 'r') wpipe = os.fdopen(wpipe, 'w') wpipe.write(data) async with self.connect() as conn: await conn.run('echo', input=data, stdout=wpipe, recv_eof=False) await conn.run('echo', input=data, stdout=wpipe, recv_eof=False) wpipe.write(data) wpipe.close() stdout_data = rpipe.read(1024) rpipe.close() self.assertEqual(stdout_data, 4*data) @asynctest async def test_stdout_binary_pipe(self): """Test with stdout redirected to a pipe in binary mode""" data = str(id(self)).encode() + b'\xff' rpipe, wpipe = os.pipe() async with self.connect() as conn: result = await conn.run('echo', input=data, stdout=wpipe, encoding=None) stdout_data = os.read(rpipe, 1024) os.close(rpipe) self.assertEqual(stdout_data, data) self.assertEqual(result.stdout, b'') self.assertEqual(result.stderr, data) @unittest.skipIf(sys.platform == 'win32', 'skip socketpair tests on Windows') class _TestProcessSocketPair(_TestProcess): """Unit tests for AsyncSSH process I/O using socketpair""" @asynctest async def test_stdin_socketpair(self): """Test with stdin redirected to a socketpair""" data = str(id(self)) sock1, sock2 = socket.socketpair() sock1.send(data.encode()) sock1.close() async with self.connect() as conn: result = await conn.run('echo', stdin=sock2) self.assertEqual(result.stdout, data) self.assertEqual(result.stderr, data) @asynctest async def test_change_stdin(self): """Test changing stdin of an open process""" sock1, sock2 = socket.socketpair() sock3, sock4 = socket.socketpair() sock1.send(b'xxx') sock3.send(b'yyy') async with self.connect() as conn: process = await conn.create_process(stdin=sock2) await asyncio.sleep(0.1) await process.redirect_stdin(sock4) sock1.close() sock3.close() result = await process.wait() self.assertEqual(result.stdout, 'xxxyyy') self.assertEqual(result.stderr, 'xxxyyy') @asynctest async def test_stdout_socketpair(self): """Test with stdout redirected to a socketpair""" data = str(id(self)) sock1, sock2 = socket.socketpair() async with self.connect() as conn: result = await conn.run('echo', input=data, stdout=sock1) stdout_data = sock2.recv(1024) sock2.close() self.assertEqual(stdout_data.decode(), data) self.assertEqual(result.stderr, data) @asynctest async def test_pause_socketpair_pipes(self): """Test pausing and resuming reading from and writing to pipes""" data = 4*1024*1024*b'*' sock1, sock2 = socket.socketpair() sock3, sock4 = socket.socketpair() _, writer1 = await asyncio.open_unix_connection(sock=sock1) writer1.write(data) writer1.close() reader2, writer2 = await asyncio.open_unix_connection(sock=sock4) async with self.connect() as conn: process = await conn.create_process('delay', encoding=None, stdin=sock2, stdout=sock3, stderr=asyncssh.DEVNULL) self.assertEqual((await reader2.read()), data) await process.wait() writer2.close() @asynctest async def test_pause_socketpair_streams(self): """Test pausing and resuming reading from and writing to streams""" data = 4*1024*1024*b'*' sock1, sock2 = socket.socketpair() sock3, sock4 = socket.socketpair() _, writer1 = await asyncio.open_unix_connection(sock=sock1) writer1.write(data) writer1.close() reader2, writer2 = 
await asyncio.open_unix_connection(sock=sock2) _, writer3 = await asyncio.open_unix_connection(sock=sock3) reader4, writer4 = await asyncio.open_unix_connection(sock=sock4) async with self.connect() as conn: process = await conn.create_process('delay', encoding=None, stdin=reader2, stdout=writer3, stderr=asyncssh.DEVNULL) self.assertEqual((await reader4.read()), data) await process.wait() writer2.close() writer3.close() writer4.close() asyncssh-2.20.0/tests/test_public_key.py000066400000000000000000003045051475467777400203720ustar00rootroot00000000000000# Copyright (c) 2014-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation """Unit tests for reading and writing public and private keys Note: These tests look for the openssl and ssh-keygen commands in the user's path and will whenever possible use them to perform interoperability tests. Otherwise, these tests will only test AsyncSSH against itself. """ import binascii from datetime import datetime import os from pathlib import Path import shutil import subprocess import sys import unittest from cryptography.exceptions import UnsupportedAlgorithm import asyncssh from asyncssh.asn1 import der_encode, BitString, ObjectIdentifier from asyncssh.asn1 import TaggedDERObject from asyncssh.crypto import chacha_available, ed25519_available, ed448_available from asyncssh.misc import write_file from asyncssh.packet import MPInt, String, UInt32 from asyncssh.pbe import pkcs1_decrypt from asyncssh.public_key import CERT_TYPE_USER, CERT_TYPE_HOST, SSHKey from asyncssh.public_key import SSHX509CertificateChain from asyncssh.public_key import decode_ssh_certificate from asyncssh.public_key import get_public_key_algs, get_certificate_algs from asyncssh.public_key import get_x509_certificate_algs from asyncssh.public_key import import_certificate_subject from asyncssh.public_key import load_identities from .sk_stub import sk_available, stub_sk, unstub_sk from .util import bcrypt_available, get_test_key, x509_available from .util import make_certificate, run, TempDirTestCase _ES1_SHA1_DES = ObjectIdentifier('1.2.840.113549.1.5.10') _P12_RC4_40 = ObjectIdentifier('1.2.840.113549.1.12.1.2') _ES2 = ObjectIdentifier('1.2.840.113549.1.5.13') _ES2_PBKDF2 = ObjectIdentifier('1.2.840.113549.1.5.12') _ES2_AES128 = ObjectIdentifier('2.16.840.1.101.3.4.1.2') _ES2_DES3 = ObjectIdentifier('1.2.840.113549.3.7') try: _openssl_version = run('openssl version') except subprocess.CalledProcessError: # pragma: no cover _openssl_version = b'' _openssl_available = _openssl_version != b'' if _openssl_available: # pragma: no branch _openssl_curves = run('openssl ecparam -list_curves') else: # pragma: no cover _openssl_curves = b'' # The openssl "-v2prf" option is only available in OpenSSL 1.0.2 or later _openssl_supports_v2prf = _openssl_version >= b'OpenSSL 1.0.2' # Ed25519/Ed448 support via "pkey" is only available in OpenSSL 1.1.1 or later _openssl_supports_pkey = 
_openssl_version >= b'OpenSSL 1.1.1' if _openssl_version >= b'OpenSSL 3': # pragma: no branch _openssl_legacy = '-provider default -provider legacy ' else: # pragma: no cover _openssl_legacy = '' try: if sys.platform != 'win32': _openssh_version = run('ssh -V') else: # pragma: no cover _openssh_version = b'' except subprocess.CalledProcessError: # pragma: no cover _openssh_version = b'' _openssh_available = _openssh_version != b'' # GCM & Chacha tests require OpenSSH 6.9 due to a bug in earlier versions: # https://bugzilla.mindrot.org/show_bug.cgi?id=2366 _openssh_supports_gcm_chacha = _openssh_version >= b'OpenSSH_6.9' _openssh_supports_arcfour_blowfish_cast = (_openssh_available and _openssh_version < b'OpenSSH_7.6') pkcs1_ciphers = (('aes128-cbc', '-aes128', False), ('aes192-cbc', '-aes192', False), ('aes256-cbc', '-aes256', False), ('des-cbc', '-des', True), ('des3-cbc', '-des3', False)) pkcs8_ciphers = ( ('aes128-cbc', 'sha224', 2, '-v2 aes-128-cbc ' '-v2prf hmacWithSHA224', _openssl_supports_v2prf, False), ('aes128-cbc', 'sha256', 2, '-v2 aes-128-cbc ' '-v2prf hmacWithSHA256', _openssl_supports_v2prf, False), ('aes128-cbc', 'sha384', 2, '-v2 aes-128-cbc ' '-v2prf hmacWithSHA384', _openssl_supports_v2prf, False), ('aes128-cbc', 'sha512', 2, '-v2 aes-128-cbc ' '-v2prf hmacWithSHA512', _openssl_supports_v2prf, False), ('des-cbc', 'md5', 1, '-v1 PBE-MD5-DES', _openssl_available, True), ('des-cbc', 'sha1', 1, '-v1 PBE-SHA1-DES', _openssl_available, True), ('des2-cbc', 'sha1', 1, '-v1 PBE-SHA1-2DES', _openssl_available, False), ('des3-cbc', 'sha1', 1, '-v1 PBE-SHA1-3DES', _openssl_available, False), ('rc4-40', 'sha1', 1, '-v1 PBE-SHA1-RC4-40', _openssl_available, True), ('rc4-128', 'sha1', 1, '-v1 PBE-SHA1-RC4-128', _openssl_available, True), ('aes128-cbc', 'sha1', 2, '-v2 aes-128-cbc', _openssl_available, False), ('aes192-cbc', 'sha1', 2, '-v2 aes-192-cbc', _openssl_available, False), ('aes256-cbc', 'sha1', 2, '-v2 aes-256-cbc', _openssl_available, False), ('blowfish-cbc', 'sha1', 2, '-v2 bf-cbc', _openssl_available, True), ('cast128-cbc', 'sha1', 2, '-v2 cast-cbc', _openssl_available, True), ('des-cbc', 'sha1', 2, '-v2 des-cbc', _openssl_available, True), ('des3-cbc', 'sha1', 2, '-v2 des-ede3-cbc', _openssl_available, False)) openssh_ciphers = ( ('aes128-gcm@openssh.com', _openssh_supports_gcm_chacha), ('aes256-gcm@openssh.com', _openssh_supports_gcm_chacha), ('arcfour', _openssh_supports_arcfour_blowfish_cast), ('arcfour128', _openssh_supports_arcfour_blowfish_cast), ('arcfour256', _openssh_supports_arcfour_blowfish_cast), ('blowfish-cbc', _openssh_supports_arcfour_blowfish_cast), ('cast128-cbc', _openssh_supports_arcfour_blowfish_cast), ('aes128-cbc', _openssh_available), ('aes192-cbc', _openssh_available), ('aes256-cbc', _openssh_available), ('aes128-ctr', _openssh_available), ('aes192-ctr', _openssh_available), ('aes256-ctr', _openssh_available), ('3des-cbc', _openssh_available) ) if chacha_available: # pragma: no branch openssh_ciphers += (('chacha20-poly1305@openssh.com', _openssh_supports_gcm_chacha),) def select_passphrase(cipher, pbe_version=0): """Randomize between string and bytes version of passphrase""" if cipher is None: return None elif os.urandom(1)[0] & 1: return 'passphrase' elif pbe_version == 1 and cipher in ('des2-cbc', 'des3-cbc', 'rc4-40', 'rc4-128'): return 'passphrase'.encode('utf-16-be') else: return b'passphrase' class _TestPublicKey(TempDirTestCase): """Unit tests for public key modules""" # pylint: disable=too-many-public-methods keyclass = None base_format = 
None private_formats = () public_formats = () default_cert_version = '' x509_supported = False generate_args = () single_cipher = True use_openssh = _openssh_available use_openssl = _openssl_available def __init__(self, methodName='runTest'): super().__init__(methodName) self.privkey = None self.pubkey = None self.privca = None self.pubca = None self.usercert = None self.hostcert = None self.rootx509 = None self.userx509 = None self.hostx509 = None self.otherx509 = None def make_certificate(self, *args, **kwargs): """Construct an SSH certificate""" return make_certificate(self.default_cert_version, *args, **kwargs) def validate_openssh(self, cert, cert_type, name): """Check OpenSSH certificate validation""" self.assertIsNone(cert.validate(cert_type, name)) def validate_x509(self, cert, user_principal=None): """Check X.509 certificate validation""" self.assertIsNone(cert.validate_chain([], [self.rootx509], [], 'any', user_principal, None)) with self.assertRaises(ValueError): cert.validate_chain([self.rootx509], [], [], 'any', None, None) chain = SSHX509CertificateChain.construct_from_certs([cert]) self.assertEqual(chain, decode_ssh_certificate(chain.public_data)) self.assertIsNone(chain.validate_chain([self.rootx509], [], [], 'any', user_principal, None)) self.assertIsNone(chain.validate_chain([self.rootx509], [], [self.otherx509], 'any', user_principal, None)) with self.assertRaises(ValueError): chain.validate_chain([], [], [], 'any', user_principal, None) with self.assertRaises(ValueError): chain.validate_chain([self.rootx509], [], [cert], 'any', user_principal, None) def check_private(self, format_name, passphrase=None): """Check for a private key match""" newkey = asyncssh.read_private_key('new', passphrase) algorithm = newkey.get_algorithm() keydata = newkey.export_private_key() pubdata = newkey.public_data self.assertEqual(newkey, self.privkey) self.assertEqual(hash(newkey), hash(self.privkey)) keypair = asyncssh.load_keypairs(newkey, passphrase)[0] self.assertEqual(keypair.get_key_type(), 'local') self.assertEqual(keypair.get_algorithm(), algorithm) self.assertEqual(keypair.public_data, pubdata) self.assertIsNotNone(keypair.get_agent_private_key()) keypair = asyncssh.load_keypairs([keypair])[0] self.assertEqual(keypair.public_data, pubdata) keypair = asyncssh.load_keypairs(keydata)[0] self.assertEqual(keypair.public_data, pubdata) keypair = asyncssh.load_keypairs('new', passphrase)[0] self.assertEqual(keypair.public_data, pubdata) keypair = asyncssh.load_keypairs([newkey])[0] self.assertEqual(keypair.public_data, pubdata) keypair = asyncssh.load_keypairs([(newkey, None)])[0] self.assertEqual(keypair.public_data, pubdata) keypair = asyncssh.load_keypairs([keydata])[0] self.assertEqual(keypair.public_data, pubdata) keypair = asyncssh.load_keypairs([(keydata, None)])[0] self.assertEqual(keypair.public_data, pubdata) keypair = asyncssh.load_keypairs(['new'], passphrase)[0] self.assertEqual(keypair.public_data, pubdata) keypair = asyncssh.load_keypairs([('new', None)], passphrase)[0] self.assertEqual(keypair.public_data, pubdata) keypair = asyncssh.load_keypairs(Path('new'), passphrase)[0] self.assertEqual(keypair.public_data, pubdata) keypair = asyncssh.load_keypairs([Path('new')], passphrase)[0] self.assertEqual(keypair.public_data, pubdata) keypair = asyncssh.load_keypairs([(Path('new'), None)], passphrase)[0] self.assertEqual(keypair.public_data, pubdata) keylist = asyncssh.load_keypairs([]) self.assertEqual(keylist, []) if passphrase: with 
self.assertRaises((asyncssh.KeyEncryptionError, asyncssh.KeyImportError)): asyncssh.load_keypairs('new', 'xxx') if format_name == 'openssh': identities = load_identities(['new']) self.assertEqual(identities[0], pubdata) else: with self.assertRaises(asyncssh.KeyImportError): load_identities(['new']) identities = load_identities(['new'], skip_private=True) self.assertEqual(identities, []) else: newkey.write_private_key('list', format_name) newkey.append_private_key('list', format_name) keylist = asyncssh.read_private_key_list('list') self.assertEqual(keylist[0].public_data, pubdata) self.assertEqual(keylist[1].public_data, pubdata) newkey.write_private_key(Path('list'), format_name) newkey.append_private_key(Path('list'), format_name) keylist = asyncssh.load_keypairs(Path('list')) self.assertEqual(keylist[0].public_data, pubdata) self.assertEqual(keylist[1].public_data, pubdata) if self.x509_supported and format_name[-4:] == '-pem': cert = newkey.generate_x509_user_certificate(newkey, 'OU=user') chain = SSHX509CertificateChain.construct_from_certs([cert]) cert.write_certificate('new_cert') keypair = asyncssh.load_keypairs(('new', 'new_cert'), passphrase)[0] self.assertEqual(keypair.public_data, chain.public_data) self.assertIsNotNone(keypair.get_agent_private_key()) keypair = asyncssh.load_keypairs('new', passphrase, 'new_cert')[0] self.assertEqual(keypair.public_data, chain.public_data) self.assertIsNotNone(keypair.get_agent_private_key()) newkey.write_private_key('new_bundle', format_name, passphrase) cert.append_certificate('new_bundle', 'pem') keypair = asyncssh.load_keypairs('new_bundle', passphrase)[0] self.assertEqual(keypair.public_data, chain.public_data) with self.assertRaises(OSError): asyncssh.load_keypairs(('new', 'not_found'), passphrase) def check_public(self, format_name): """Check for a public key match""" newkey = asyncssh.read_public_key('new') pubkey = newkey.export_public_key() pubdata = newkey.public_data self.assertEqual(newkey, self.pubkey) self.assertEqual(hash(newkey), hash(self.pubkey)) pubkey = asyncssh.load_public_keys('new')[0] self.assertEqual(pubkey, newkey) pubkey = asyncssh.load_public_keys([newkey])[0] self.assertEqual(pubkey, newkey) pubkey = asyncssh.load_public_keys([pubkey])[0] self.assertEqual(pubkey, newkey) pubkey = asyncssh.load_public_keys(['new'])[0] self.assertEqual(pubkey, newkey) pubkey = asyncssh.load_public_keys(Path('new'))[0] self.assertEqual(pubkey, newkey) pubkey = asyncssh.load_public_keys([Path('new')])[0] self.assertEqual(pubkey, newkey) identity = load_identities(['new'])[0] self.assertEqual(identity, pubdata) newkey.write_public_key('list', format_name) newkey.append_public_key('list', format_name) keylist = asyncssh.read_public_key_list('list') self.assertEqual(keylist[0], newkey) self.assertEqual(keylist[1], newkey) newkey.write_public_key(Path('list'), format_name) newkey.append_public_key(Path('list'), format_name) write_file('list', b'Extra text at end of key list\n', 'ab') keylist = asyncssh.load_public_keys(Path('list')) self.assertEqual(keylist[0], newkey) self.assertEqual(keylist[1], newkey) for hash_name in ('md5', 'sha1', 'sha256', 'sha384', 'sha512'): fp = newkey.get_fingerprint(hash_name) if self.use_openssh: # pragma: no branch keygen_fp = run(f'ssh-keygen -l -E {hash_name} -f sshpub') self.assertEqual(fp, keygen_fp.decode('ascii').split()[1]) with self.assertRaises(ValueError): newkey.get_fingerprint('xxx') def check_certificate(self, cert_type, format_name): """Check for a certificate match""" cert = 
asyncssh.read_certificate('cert') certdata = cert.export_certificate() self.assertEqual(cert.key, self.pubkey) if cert.is_x509: self.validate_x509(cert) else: self.validate_openssh(cert, cert_type, 'name') certlist = asyncssh.load_certificates(cert) self.assertEqual(certlist[0], cert) self.assertEqual(hash(certlist[0]), hash(cert)) if cert.is_x509: self.assertEqual(certlist[0].x509_cert, cert.x509_cert) self.assertEqual(hash(certlist[0].x509_cert), hash(cert.x509_cert)) certlist = asyncssh.load_certificates(certdata) self.assertEqual(certlist[0], cert) certlist = asyncssh.load_certificates([cert]) self.assertEqual(certlist[0], cert) certlist = asyncssh.load_certificates([certdata]) self.assertEqual(certlist[0], cert) certlist = asyncssh.load_certificates('cert') self.assertEqual(certlist[0], cert) certlist = asyncssh.load_certificates(Path('cert')) self.assertEqual(certlist[0], cert) certlist = asyncssh.load_certificates([Path('cert')]) self.assertEqual(certlist[0], cert) certlist = asyncssh.load_certificates(certdata + b'Extra text in the middle\n' + certdata) self.assertEqual(certlist[0], cert) self.assertEqual(certlist[1], cert) cert.write_certificate('list', format_name) cert.append_certificate('list', format_name) certlist = asyncssh.load_certificates('list') self.assertEqual(certlist[0], cert) self.assertEqual(certlist[1], cert) cert.write_certificate(Path('list'), format_name) cert.append_certificate(Path('list'), format_name) write_file('list', b'Extra text at end of certificate list\n', 'ab') certlist = asyncssh.load_certificates(Path('list')) self.assertEqual(certlist[0], cert) self.assertEqual(certlist[1], cert) certlist = asyncssh.load_certificates(['list', [cert]]) self.assertEqual(certlist[0], cert) self.assertEqual(certlist[1], cert) self.assertEqual(certlist[2], cert) certlist = asyncssh.load_certificates(['list', certdata]) self.assertEqual(certlist[0], cert) self.assertEqual(certlist[1], cert) self.assertEqual(certlist[2], cert) if format_name == 'openssh': certlist = asyncssh.load_certificates(certdata[:-1]) self.assertEqual(certlist[0], cert) certlist = asyncssh.load_certificates(certdata + certdata[:-1]) self.assertEqual(certlist[0], cert) self.assertEqual(certlist[1], cert) certlist = asyncssh.load_certificates(certdata[1:-1]) self.assertEqual(len(certlist), 0) certlist = asyncssh.load_certificates(certdata[1:] + certdata[:-1]) self.assertEqual(len(certlist), 1) self.assertEqual(certlist[0], cert) def import_pkcs1_private(self, fmt, cipher=None, args=None): """Check import of a PKCS#1 private key""" format_name = f'pkcs1-{fmt}' if self.use_openssl: # pragma: no branch if cipher: run(f'openssl {self.keyclass} {args} -in priv -inform pem ' f'-out new -outform {fmt} -passout pass:passphrase') else: run(f'openssl {self.keyclass} -in priv -inform pem ' f'-out new -outform {fmt}') else: # pragma: no cover self.privkey.write_private_key('new', format_name, select_passphrase(cipher), cipher) self.check_private(format_name, select_passphrase(cipher)) def export_pkcs1_private(self, fmt, cipher=None, legacy_args=None): """Check export of a PKCS#1 private key""" format_name = f'pkcs1-{fmt}' self.privkey.write_private_key('privout', format_name, select_passphrase(cipher), cipher) if self.use_openssl: # pragma: no branch if cipher: run(f'openssl {self.keyclass} {legacy_args} -in privout ' f'-inform {fmt} -out new -outform pem ' '-passin pass:passphrase') else: run(f'openssl {self.keyclass} -in privout -inform {fmt} ' '-out new -outform pem') else: # pragma: no cover priv = 
asyncssh.read_private_key('privout', select_passphrase(cipher)) priv.write_private_key('new', format_name) self.check_private(format_name) def import_pkcs1_public(self, fmt): """Check import of a PKCS#1 public key""" format_name = f'pkcs1-{fmt}' if (not self.use_openssl or self.keyclass == 'dsa' or _openssl_version < b'OpenSSL 1.0.0'): # pragma: no cover # OpenSSL no longer has support for PKCS#1 DSA, and PKCS#1 # RSA is not supported before OpenSSL 1.0.0, so we only test # against ourselves in these cases. self.pubkey.write_public_key('new', format_name) else: run(f'openssl {self.keyclass} -pubin -in pub -inform pem ' f'-RSAPublicKey_out -out new -outform {fmt}') self.check_public(format_name) def export_pkcs1_public(self, fmt): """Check export of a PKCS#1 public key""" format_name = f'pkcs1-{fmt}' self.privkey.write_public_key('pubout', format_name) if not self.use_openssl or self.keyclass == 'dsa': # pragma: no cover # OpenSSL no longer has support for PKCS#1 DSA, so we can # only test against ourselves. pub = asyncssh.read_public_key('pubout') pub.write_public_key('new', format_name) else: run(f'openssl {self.keyclass} -RSAPublicKey_in -in pubout ' f'-inform {fmt} -out new -outform pem') self.check_public(format_name) def import_pkcs8_private(self, fmt, openssl_ok=True, cipher=None, hash_alg=None, pbe_version=None, args=None): """Check import of a PKCS#8 private key""" format_name = f'pkcs8-{fmt}' if self.use_openssl and openssl_ok: # pragma: no branch if cipher: run(f'openssl pkcs8 -topk8 {args} -in priv -inform pem ' f'-out new -outform {fmt} -passout pass:passphrase') else: run('openssl pkcs8 -topk8 -nocrypt -in priv -inform pem ' f'-out new -outform {fmt}') else: # pragma: no cover self.privkey.write_private_key('new', format_name, select_passphrase(cipher, pbe_version), cipher, hash_alg, pbe_version) self.check_private(format_name, select_passphrase(cipher, pbe_version)) def export_pkcs8_private(self, fmt, openssl_ok=True, cipher=None, hash_alg=None, pbe_version=None, legacy_args=None): """Check export of a PKCS#8 private key""" format_name = f'pkcs8-{fmt}' self.privkey.write_private_key('privout', format_name, select_passphrase(cipher, pbe_version), cipher, hash_alg, pbe_version) if self.use_openssl and openssl_ok: # pragma: no branch if cipher: run(f'openssl pkcs8 {legacy_args} -in privout -inform {fmt} ' '-out new -outform pem -passin pass:passphrase') else: run(f'openssl pkcs8 -nocrypt -in privout -inform {fmt} ' '-out new -outform pem') else: # pragma: no cover priv = asyncssh.read_private_key( 'privout', select_passphrase(cipher, pbe_version)) priv.write_private_key('new', format_name) self.check_private(format_name) def import_pkcs8_public(self, fmt): """Check import of a PKCS#8 public key""" format_name = f'pkcs8-{fmt}' if self.use_openssl: if _openssl_supports_pkey: run('openssl pkey -pubin -in pub -inform pem -out new ' f'-outform {fmt}') else: # pragma: no cover run(f'openssl {self.keyclass} -pubin -in pub -inform pem ' f'-out new -outform {fmt}') else: # pragma: no cover self.pubkey.write_public_key('new', format_name) self.check_public(format_name) def export_pkcs8_public(self, fmt): """Check export of a PKCS#8 public key""" format_name = f'pkcs8-{fmt}' self.privkey.write_public_key('pubout', format_name) if self.use_openssl: if _openssl_supports_pkey: run(f'openssl pkey -pubin -in pubout -inform {fmt} ' '-out new -outform pem') else: # pragma: no cover run(f'openssl {self.keyclass} -pubin -in pubout ' f'-inform {fmt} -out new -outform pem') else: # pragma: no cover 
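            # Editor's note (added commentary, not from the original
            # source): when the openssl command is unavailable, the
            # exported key is simply round-tripped through AsyncSSH
            # itself, as described in the module docstring above.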
pub = asyncssh.read_public_key('pubout') pub.write_public_key('new', format_name) self.check_public(format_name) def import_openssh_private(self, openssh_ok=True, cipher=None): """Check import of an OpenSSH private key""" if self.use_openssh and openssh_ok: # pragma: no branch shutil.copy('priv', 'new') if cipher: run(f'ssh-keygen -p -a 1 -N passphrase -Z {cipher} -o -f new') else: run('ssh-keygen -p -N "" -o -f new') else: # pragma: no cover self.privkey.write_private_key('new', 'openssh', select_passphrase(cipher), cipher, rounds=1, ignore_few_rounds=True) self.check_private('openssh', select_passphrase(cipher)) def export_openssh_private(self, openssh_ok=True, cipher=None): """Check export of an OpenSSH private key""" self.privkey.write_private_key('new', 'openssh', select_passphrase(cipher), cipher, rounds=1, ignore_few_rounds=True) if self.use_openssh and openssh_ok: # pragma: no branch os.chmod('new', 0o600) if cipher: run('ssh-keygen -p -P passphrase -N "" -o -f new') else: run('ssh-keygen -p -N "" -o -f new') else: # pragma: no cover priv = asyncssh.read_private_key('new', select_passphrase(cipher)) priv.write_private_key('new', 'openssh') self.check_private('openssh') def import_openssh_public(self): """Check import of an OpenSSH public key""" shutil.copy('sshpub', 'new') self.check_public('openssh') def export_openssh_public(self): """Check export of an OpenSSH public key""" self.privkey.write_public_key('pubout', 'openssh') if self.use_openssh: # pragma: no branch run('ssh-keygen -e -f pubout -m rfc4716 > new') else: # pragma: no cover pub = asyncssh.read_public_key('pubout') pub.write_public_key('new', 'rfc4716') self.check_public('openssh') def import_openssh_certificate(self, cert_type, cert): """Check import of an OpenSSH certificate""" shutil.copy(cert, 'cert') self.check_certificate(cert_type, 'openssh') def export_openssh_certificate(self, cert_type, cert): """Check export of an OpenSSH certificate""" cert.write_certificate('certout', 'openssh') if self.use_openssh: # pragma: no branch run('ssh-keygen -e -f certout -m rfc4716 > cert') else: # pragma: no cover cert = asyncssh.read_certificate('certout') cert.write_certificate('cert', 'rfc4716') self.check_certificate(cert_type, 'openssh') def import_rfc4716_public(self): """Check import of an RFC4716 public key""" if self.use_openssh: # pragma: no branch run('ssh-keygen -e -f sshpub -m rfc4716 > new') else: # pragma: no cover self.pubkey.write_public_key('new', 'rfc4716') self.check_public('rfc4716') pubdata = self.pubkey.export_public_key('rfc4716') write_file('new', pubdata.replace(b'\n', b'\nXXX:\n', 1)) self.check_public('rfc4716') def export_rfc4716_public(self): """Check export of an RFC4716 public key""" self.pubkey.write_public_key('pubout', 'rfc4716') if self.use_openssh: # pragma: no branch run('ssh-keygen -i -f pubout -m rfc4716 > new') else: # pragma: no cover pub = asyncssh.read_public_key('pubout') pub.write_public_key('new', 'openssh') self.check_public('rfc4716') def import_rfc4716_certificate(self, cert_type, cert): """Check import of an RFC4716 certificate""" if self.use_openssh: # pragma: no branch run(f'ssh-keygen -e -f {cert} -m rfc4716 > cert') else: # pragma: no cover if cert_type == CERT_TYPE_USER: cert = self.usercert else: cert = self.hostcert cert.write_certificate('cert', 'rfc4716') self.check_certificate(cert_type, 'rfc4716') def export_rfc4716_certificate(self, cert_type, cert): """Check export of an RFC4716 certificate""" cert.write_certificate('certout', 'rfc4716') if self.use_openssh: # 
pragma: no branch run('ssh-keygen -i -f certout -m rfc4716 > cert') else: # pragma: no cover cert = asyncssh.read_certificate('certout') cert.write_certificate('cert', 'openssh') self.check_certificate(cert_type, 'rfc4716') def import_der_x509_certificate(self, cert_type, cert): """Check import of a DER X.509 certificate""" cert.write_certificate('cert', 'der') self.check_certificate(cert_type, 'der') def export_der_x509_certificate(self, cert_type, cert): """Check export of a DER X.509 certificate""" cert.write_certificate('certout', 'der') cert = asyncssh.read_certificate('certout') cert.write_certificate('cert', 'openssh') self.check_certificate(cert_type, 'der') def import_pem_x509_certificate(self, cert_type, cert, trusted=False): """Check import of a PEM X.509 certificate""" cert.write_certificate('cert', 'pem') if trusted: with open('cert') as f: lines = f.readlines() lines[0] = lines[0][:11] + 'TRUSTED ' + lines[0][11:] idx = lines[-2].find('=') lines[-2] = lines[-2][:idx] + 'XXXX' + lines[-2][idx:] lines[-1] = lines[-1][:9] + 'TRUSTED ' + lines[-1][9:] with open('cert', 'w') as f: f.writelines(lines) self.check_certificate(cert_type, 'pem') def export_pem_x509_certificate(self, cert_type, cert): """Check export of a PEM X.509 certificate""" cert.write_certificate('certout', 'pem') cert = asyncssh.read_certificate('certout') cert.write_certificate('cert', 'openssh') self.check_certificate(cert_type, 'pem') def import_openssh_x509_certificate(self, cert_type, cert): """Check import of an OpenSSH X.509 certificate""" cert.write_certificate('cert') self.check_certificate(cert_type, 'openssh') def export_openssh_x509_certificate(self, cert_type, cert): """Check export of an OpenSSH X.509 certificate""" cert.write_certificate('certout') cert = asyncssh.read_certificate('certout') cert.write_certificate('cert', 'pem') self.check_certificate(cert_type, 'openssh') def check_encode_errors(self): """Check error code paths in key encoding""" for fmt in ('pkcs1-der', 'pkcs1-pem', 'pkcs8-der', 'pkcs8-pem', 'openssh', 'rfc4716', 'xxx'): with self.subTest(f'Encode private from public ({fmt})'): with self.assertRaises(asyncssh.KeyExportError): self.pubkey.export_private_key(fmt) with self.subTest('Encode with unknown key format'): with self.assertRaises(asyncssh.KeyExportError): self.privkey.export_public_key('xxx') with self.subTest('Encode encrypted pkcs1-der'): with self.assertRaises(asyncssh.KeyExportError): self.privkey.export_private_key('pkcs1-der', 'x') if self.keyclass == 'ec': with self.subTest('Encode EC public key with PKCS#1'): with self.assertRaises(asyncssh.KeyExportError): self.privkey.export_public_key('pkcs1-pem') if 'pkcs1' in self.private_formats: with self.subTest('Encode with unknown PKCS#1 cipher'): with self.assertRaises(asyncssh.KeyEncryptionError): self.privkey.export_private_key('pkcs1-pem', 'x', 'xxx') if 'pkcs8' in self.private_formats: # pragma: no branch with self.subTest('Encode with unknown PKCS#8 cipher'): with self.assertRaises(asyncssh.KeyEncryptionError): self.privkey.export_private_key('pkcs8-pem', 'x', 'xxx') with self.subTest('Encode with unknown PKCS#8 hash'): with self.assertRaises(asyncssh.KeyEncryptionError): self.privkey.export_private_key('pkcs8-pem', 'x', 'aes128-cbc', 'xxx') with self.subTest('Encode with unknown PKCS#8 version'): with self.assertRaises(asyncssh.KeyEncryptionError): self.privkey.export_private_key('pkcs8-pem', 'x', 'aes128-cbc', 'sha1', 3) if bcrypt_available: # pragma: no branch with self.subTest('Encode with unknown openssh 
cipher'): with self.assertRaises(asyncssh.KeyEncryptionError): self.privkey.export_private_key('openssh', 'x', 'xxx') with self.subTest('Encode agent cert private from public'): with self.assertRaises(asyncssh.KeyExportError): self.pubkey.encode_agent_cert_private() def check_decode_errors(self): """Check error code paths in key decoding""" private_errors = [ ('Non-ASCII', '\xff'), ('Incomplete ASN.1', b''), ('Invalid PKCS#1', der_encode(None)), ('Invalid PKCS#1 params', der_encode((1, b'', TaggedDERObject(0, b'')))), ('Invalid PKCS#1 EC named curve OID', der_encode((1, b'', TaggedDERObject(0, ObjectIdentifier('1.1'))))), ('Invalid PKCS#8', der_encode((0, (self.privkey.pkcs8_oid, ()), der_encode(None)))), ('Unknown PKCS#8 algorithm', der_encode((0, (ObjectIdentifier('1.1'), None), b''))), ('Invalid PKCS#8 ASN.1', der_encode((0, (self.privkey.pkcs8_oid, None), b''))), ('Invalid PKCS#8 params', der_encode((1, (self.privkey.pkcs8_oid, b''), der_encode((1, b''))))), ('Invalid PEM header', b'-----BEGIN XXX-----\n'), ('Missing PEM footer', b'-----BEGIN PRIVATE KEY-----\n'), ('Invalid PEM key type', b'-----BEGIN XXX PRIVATE KEY-----\n' + binascii.b2a_base64(der_encode(None)) + b'-----END XXX PRIVATE KEY-----'), ('Invalid PEM Base64', b'-----BEGIN PRIVATE KEY-----\n' b'X\n' b'-----END PRIVATE KEY-----'), ('Missing PKCS#1 passphrase', b'-----BEGIN DSA PRIVATE KEY-----\n' b'Proc-Type: 4,ENCRYPTED\n' b'-----END DSA PRIVATE KEY-----'), ('Incomplete PEM ASN.1', b'-----BEGIN PRIVATE KEY-----\n' b'-----END PRIVATE KEY-----'), ('Missing PEM PKCS#8 passphrase', b'-----BEGIN ENCRYPTED PRIVATE KEY-----\n' + binascii.b2a_base64(der_encode(None)) + b'-----END ENCRYPTED PRIVATE KEY-----'), ('Invalid PEM PKCS#1 key', b'-----BEGIN DSA PRIVATE KEY-----\n' + binascii.b2a_base64(der_encode(None)) + b'-----END DSA PRIVATE KEY-----'), ('Invalid PEM PKCS#8 key', b'-----BEGIN PRIVATE KEY-----\n' + binascii.b2a_base64(der_encode(None)) + b'-----END PRIVATE KEY-----'), ('Unknown format OpenSSH key', b'-----BEGIN OPENSSH PRIVATE KEY-----\n' + binascii.b2a_base64(b'XXX') + b'-----END OPENSSH PRIVATE KEY-----'), ('Incomplete OpenSSH key', b'-----BEGIN OPENSSH PRIVATE KEY-----\n' + binascii.b2a_base64(b'openssh-key-v1\0') + b'-----END OPENSSH PRIVATE KEY-----'), ('Invalid OpenSSH nkeys', b'-----BEGIN OPENSSH PRIVATE KEY-----\n' + binascii.b2a_base64(b''.join( (b'openssh-key-v1\0', String(''), String(''), String(''), UInt32(2), String(''), String('')))) + b'-----END OPENSSH PRIVATE KEY-----'), ('Missing OpenSSH passphrase', b'-----BEGIN OPENSSH PRIVATE KEY-----\n' + binascii.b2a_base64(b''.join( (b'openssh-key-v1\0', String('xxx'), String(''), String(''), UInt32(1), String(''), String('')))) + b'-----END OPENSSH PRIVATE KEY-----'), ('Mismatched OpenSSH check bytes', b'-----BEGIN OPENSSH PRIVATE KEY-----\n' + binascii.b2a_base64(b''.join( (b'openssh-key-v1\0', String('none'), String(''), String(''), UInt32(1), String(''), String(b''.join((UInt32(1), UInt32(2))))))) + b'-----END OPENSSH PRIVATE KEY-----'), ('Invalid OpenSSH algorithm', b'-----BEGIN OPENSSH PRIVATE KEY-----\n' + binascii.b2a_base64(b''.join( (b'openssh-key-v1\0', String('none'), String(''), String(''), UInt32(1), String(''), String(b''.join((UInt32(1), UInt32(1), String('xxx'))))))) + b'-----END OPENSSH PRIVATE KEY-----'), ('Invalid OpenSSH pad', b'-----BEGIN OPENSSH PRIVATE KEY-----\n' + binascii.b2a_base64(b''.join( (b'openssh-key-v1\0', String('none'), String(''), String(''), UInt32(1), String(''), String(b''.join((UInt32(1), UInt32(1), String('ssh-dss'), 
5*MPInt(0), String(''), b'\0')))))) + b'-----END OPENSSH PRIVATE KEY-----') ] decrypt_errors = [ ('Invalid PKCS#1', der_encode(None)), ('Invalid PKCS#8', der_encode((0, (self.privkey.pkcs8_oid, ()), der_encode(None)))), ('Invalid PEM params', b'-----BEGIN DSA PRIVATE KEY-----\n' b'Proc-Type: 4,ENCRYPTED\n' b'DEK-Info: XXX\n' b'-----END DSA PRIVATE KEY-----'), ('Invalid PEM cipher', b'-----BEGIN DSA PRIVATE KEY-----\n' b'Proc-Type: 4,ENCRYPTED\n' b'DEK-Info: XXX,00\n' b'-----END DSA PRIVATE KEY-----'), ('Invalid PEM IV', b'-----BEGIN DSA PRIVATE KEY-----\n' b'Proc-Type: 4,ENCRYPTED\n' b'DEK-Info: AES-256-CBC,XXX\n' b'-----END DSA PRIVATE KEY-----'), ('Invalid PEM PKCS#8 encrypted data', b'-----BEGIN ENCRYPTED PRIVATE KEY-----\n' + binascii.b2a_base64(der_encode(None)) + b'-----END ENCRYPTED PRIVATE KEY-----'), ('Invalid PEM PKCS#8 encrypted header', b'-----BEGIN ENCRYPTED PRIVATE KEY-----\n' + binascii.b2a_base64(der_encode((None, None))) + b'-----END ENCRYPTED PRIVATE KEY-----'), ('Invalid PEM PKCS#8 encryption algorithm', b'-----BEGIN ENCRYPTED PRIVATE KEY-----\n' + binascii.b2a_base64(der_encode(((None, None), b''))) + b'-----END ENCRYPTED PRIVATE KEY-----'), ('Invalid PEM PKCS#8 PBES1 encryption parameters', b'-----BEGIN ENCRYPTED PRIVATE KEY-----\n' + binascii.b2a_base64(der_encode(((_ES1_SHA1_DES, None), b''))) + b'-----END ENCRYPTED PRIVATE KEY-----'), ('Invalid PEM PKCS#8 PBES1 PKCS#12 encryption parameters', b'-----BEGIN ENCRYPTED PRIVATE KEY-----\n' + binascii.b2a_base64(der_encode(((_P12_RC4_40, None), b''))) + b'-----END ENCRYPTED PRIVATE KEY-----'), ('Invalid PEM PKCS#8 PBES1 PKCS#12 salt', b'-----BEGIN ENCRYPTED PRIVATE KEY-----\n' + binascii.b2a_base64(der_encode(((_P12_RC4_40, (b'', 0)), b''))) + b'-----END ENCRYPTED PRIVATE KEY-----'), ('Invalid PEM PKCS#8 PBES1 PKCS#12 iteration count', b'-----BEGIN ENCRYPTED PRIVATE KEY-----\n' + binascii.b2a_base64(der_encode(((_P12_RC4_40, (b'x', 0)), b''))) + b'-----END ENCRYPTED PRIVATE KEY-----'), ('Invalid PEM PKCS#8 PBES2 encryption parameters', b'-----BEGIN ENCRYPTED PRIVATE KEY-----\n' + binascii.b2a_base64(der_encode(((_ES2, None), b''))) + b'-----END ENCRYPTED PRIVATE KEY-----'), ('Invalid PEM PKCS#8 PBES2 KDF algorithm', b'-----BEGIN ENCRYPTED PRIVATE KEY-----\n' + binascii.b2a_base64(der_encode( ((_ES2, ((None, None), (None, None))), b''))) + b'-----END ENCRYPTED PRIVATE KEY-----'), ('Invalid PEM PKCS#8 PBES2 encryption algorithm', b'-----BEGIN ENCRYPTED PRIVATE KEY-----\n' + binascii.b2a_base64(der_encode( ((_ES2, ((_ES2_PBKDF2, None), (None, None))), b''))) + b'-----END ENCRYPTED PRIVATE KEY-----'), ('Invalid PEM PKCS#8 PBES2 PBKDF2 parameters', b'-----BEGIN ENCRYPTED PRIVATE KEY-----\n' + binascii.b2a_base64(der_encode( ((_ES2, ((_ES2_PBKDF2, None), (_ES2_AES128, None))), b''))) + b'-----END ENCRYPTED PRIVATE KEY-----'), ('Invalid PEM PKCS#8 PBES2 PBKDF2 salt', b'-----BEGIN ENCRYPTED PRIVATE KEY-----\n' + binascii.b2a_base64(der_encode( ((_ES2, ((_ES2_PBKDF2, (None, None)), (_ES2_AES128, None))), b''))) + b'-----END ENCRYPTED PRIVATE KEY-----'), ('Invalid PEM PKCS#8 PBES2 PBKDF2 iteration count', b'-----BEGIN ENCRYPTED PRIVATE KEY-----\n' + binascii.b2a_base64(der_encode( ((_ES2, ((_ES2_PBKDF2, (b'', None)), (_ES2_AES128, None))), b''))) + b'-----END ENCRYPTED PRIVATE KEY-----'), ('Invalid PEM PKCS#8 PBES2 PBKDF2 PRF', b'-----BEGIN ENCRYPTED PRIVATE KEY-----\n' + binascii.b2a_base64(der_encode( ((_ES2, ((_ES2_PBKDF2, (b'', 1, None)), (_ES2_AES128, None))), b''))) + b'-----END ENCRYPTED PRIVATE KEY-----'), ('Unknown PEM 
PKCS#8 PBES2 PBKDF2 PRF', b'-----BEGIN ENCRYPTED PRIVATE KEY-----\n' + binascii.b2a_base64(der_encode( ((_ES2, ((_ES2_PBKDF2, (b'', 1, (ObjectIdentifier('1.1'), None))), (_ES2_AES128, None))), b''))) + b'-----END ENCRYPTED PRIVATE KEY-----'), ('Invalid PEM PKCS#8 PBES2 encryption parameters', b'-----BEGIN ENCRYPTED PRIVATE KEY-----\n' + binascii.b2a_base64(der_encode( ((_ES2, ((_ES2_PBKDF2, (b'', 1)), (_ES2_AES128, None))), b''))) + b'-----END ENCRYPTED PRIVATE KEY-----'), ('Invalid length PEM PKCS#8 PBES2 IV', b'-----BEGIN ENCRYPTED PRIVATE KEY-----\n' + binascii.b2a_base64(der_encode( ((_ES2, ((_ES2_PBKDF2, (b'', 1)), (_ES2_AES128, b''))), b''))) + b'-----END ENCRYPTED PRIVATE KEY-----'), ('Invalid OpenSSH cipher', b'-----BEGIN OPENSSH PRIVATE KEY-----\n' + binascii.b2a_base64(b''.join( (b'openssh-key-v1\0', String('xxx'), String(''), String(''), UInt32(1), String(''), String('')))) + b'-----END OPENSSH PRIVATE KEY-----'), ('Invalid OpenSSH kdf', b'-----BEGIN OPENSSH PRIVATE KEY-----\n' + binascii.b2a_base64(b''.join( (b'openssh-key-v1\0', String('aes256-cbc'), String('xxx'), String(''), UInt32(1), String(''), String('')))) + b'-----END OPENSSH PRIVATE KEY-----'), ('Invalid OpenSSH kdf data', b'-----BEGIN OPENSSH PRIVATE KEY-----\n' + binascii.b2a_base64(b''.join( (b'openssh-key-v1\0', String('aes256-cbc'), String('bcrypt'), String(''), UInt32(1), String(''), String('')))) + b'-----END OPENSSH PRIVATE KEY-----'), ('Invalid OpenSSH salt', b'-----BEGIN OPENSSH PRIVATE KEY-----\n' + binascii.b2a_base64(b''.join( (b'openssh-key-v1\0', String('aes256-cbc'), String('bcrypt'), String(b''.join((String(b''), UInt32(128)))), UInt32(1), String(''), String('')))) + b'-----END OPENSSH PRIVATE KEY-----'), ('Invalid OpenSSH encrypted data', b'-----BEGIN OPENSSH PRIVATE KEY-----\n' + binascii.b2a_base64(b''.join( (b'openssh-key-v1\0', String('aes256-cbc'), String('bcrypt'), String(b''.join((String(16*b'\0'), UInt32(128)))), UInt32(1), String(''), String('')))) + b'-----END OPENSSH PRIVATE KEY-----'), ('Unexpected OpenSSH trailing data', b'-----BEGIN OPENSSH PRIVATE KEY-----\n' + binascii.b2a_base64(b''.join( (b'openssh-key-v1\0', String('aes256-cbc'), String('bcrypt'), String(b''.join((String(16*b'\0'), UInt32(128)))), UInt32(1), String(''), String(''), String('xxx')))) + b'-----END OPENSSH PRIVATE KEY-----') ] public_errors = [ ('Non-ASCII', '\xff'), ('Invalid ASN.1', b'\x30'), ('Invalid PKCS#1', der_encode(None)), ('Invalid PKCS#8', der_encode(((self.pubkey.pkcs8_oid, ()), BitString(der_encode(None))))), ('Unknown PKCS#8 algorithm', der_encode(((ObjectIdentifier('1.1'), None), BitString(b'')))), ('Invalid PKCS#8 ASN.1', der_encode(((self.pubkey.pkcs8_oid, None), BitString(b'')))), ('Invalid PEM header', b'-----BEGIN XXX-----\n'), ('Missing PEM footer', b'-----BEGIN PUBLIC KEY-----\n'), ('Invalid PEM key type', b'-----BEGIN XXX PUBLIC KEY-----\n' + binascii.b2a_base64(der_encode(None)) + b'-----END XXX PUBLIC KEY-----'), ('Invalid PEM Base64', b'-----BEGIN PUBLIC KEY-----\n' b'X\n' b'-----END PUBLIC KEY-----'), ('Incomplete PEM ASN.1', b'-----BEGIN PUBLIC KEY-----\n' b'-----END PUBLIC KEY-----'), ('Invalid PKCS#1 ASN.1', b'-----BEGIN DSA PUBLIC KEY-----\n' + binascii.b2a_base64(b'\x30') + b'-----END PUBLIC KEY-----'), ('Invalid PKCS#1 key data', b'-----BEGIN DSA PUBLIC KEY-----\n' + binascii.b2a_base64(der_encode(None)) + b'-----END DSA PUBLIC KEY-----'), ('Invalid PKCS#8 key data', b'-----BEGIN PUBLIC KEY-----\n' + binascii.b2a_base64(der_encode(None)) + b'-----END PUBLIC KEY-----'), ('Invalid 
OpenSSH', b'xxx'),
            ('Invalid OpenSSH Base64', b'ssh-dss X'),
            ('Unknown OpenSSH algorithm',
             b'ssh-dss ' + binascii.b2a_base64(String('xxx'))),
            ('Invalid OpenSSH body',
             b'ssh-dss ' + binascii.b2a_base64(String('ssh-dss'))),
            ('Unknown format OpenSSH key',
             b'-----BEGIN OPENSSH PRIVATE KEY-----\n' +
             binascii.b2a_base64(b'XXX') +
             b'-----END OPENSSH PRIVATE KEY-----'),
            ('Incomplete OpenSSH key',
             b'-----BEGIN OPENSSH PRIVATE KEY-----\n' +
             binascii.b2a_base64(b'openssh-key-v1\0') +
             b'-----END OPENSSH PRIVATE KEY-----'),
            ('Invalid OpenSSH nkeys',
             b'-----BEGIN OPENSSH PRIVATE KEY-----\n' +
             binascii.b2a_base64(b''.join(
                 (b'openssh-key-v1\0', String(''), String(''), String(''),
                  UInt32(2), String(''), String('')))) +
             b'-----END OPENSSH PRIVATE KEY-----'),
            ('Invalid RFC4716 header', b'---- XXX ----\n'),
            ('Missing RFC4716 footer', b'---- BEGIN SSH2 PUBLIC KEY ----\n'),
            ('Invalid RFC4716 header',
             b'---- BEGIN SSH2 PUBLIC KEY ----\n'
             b'Comment: comment\n'
             b'XXX:\\\n'
             b'---- END SSH2 PUBLIC KEY ----\n'),
            ('Invalid RFC4716 Base64',
             b'---- BEGIN SSH2 PUBLIC KEY ----\n'
             b'X\n'
             b'---- END SSH2 PUBLIC KEY ----\n')
        ]

        keypair_errors = [
            ('Mismatched certificate', (self.privca, self.usercert)),
            ('Invalid signature algorithm string',
             (self.privkey, None, 'xxx')),
            ('Invalid signature algorithm bytes',
             (self.privkey, None, b'xxx'))
        ]

        for fmt, data in private_errors:
            with self.subTest(f'Decode private ({fmt})'):
                with self.assertRaises(asyncssh.KeyImportError):
                    asyncssh.import_private_key(data)

        for fmt, data in decrypt_errors:
            with self.subTest(f'Decrypt private ({fmt})'):
                with self.assertRaises((asyncssh.KeyEncryptionError,
                                        asyncssh.KeyImportError)):
                    asyncssh.import_private_key(data, 'x')

        for fmt, data in public_errors:
            with self.subTest(f'Decode public ({fmt})'):
                with self.assertRaises(asyncssh.KeyImportError):
                    asyncssh.import_public_key(data)

        for fmt, key in keypair_errors:
            with self.subTest(f'Load keypair ({fmt})'):
                with self.assertRaises(ValueError):
                    asyncssh.load_keypairs([key])

    def check_sshkey_base_errors(self):
        """Check SSHKey base class errors"""

        key = SSHKey(None)

        with self.subTest('SSHKey base class errors'):
            with self.assertRaises(asyncssh.KeyExportError):
                key.encode_pkcs1_private()

            with self.assertRaises(asyncssh.KeyExportError):
                key.encode_pkcs1_public()

            with self.assertRaises(asyncssh.KeyExportError):
                key.encode_pkcs8_private()

            with self.assertRaises(asyncssh.KeyExportError):
                key.encode_pkcs8_public()

            with self.assertRaises(asyncssh.KeyExportError):
                key.encode_ssh_private()

            with self.assertRaises(asyncssh.KeyExportError):
                key.encode_ssh_public()

    def check_sign_and_verify(self):
        """Check key signing and verification"""

        with self.subTest('Sign/verify test'):
            data = os.urandom(8)

            for cert in (None, self.usercert, self.userx509):
                keypair = asyncssh.load_keypairs([(self.privkey, cert)])[0]

                for sig_alg in keypair.sig_algorithms:
                    with self.subTest('Good signature', sig_alg=sig_alg):
                        try:
                            keypair.set_sig_algorithm(sig_alg)
                            sig = keypair.sign(data)

                            with self.subTest('Good signature'):
                                self.assertTrue(self.pubkey.verify(data, sig))

                            badsig = bytearray(sig)
                            badsig[-1] ^= 0xff
                            badsig = bytes(badsig)

                            with self.subTest('Bad signature'):
                                self.assertFalse(
                                    self.pubkey.verify(data, badsig))

                            if sig_alg.startswith(b'webauthn-'):
                                idx = sig.rfind(b'ssh:')
                                badpfx = bytearray(sig)
                                badpfx[idx] = ord('x')
                                badpfx = bytes(badpfx)

                                with self.subTest('Bad prefix'):
                                    self.assertFalse(
                                        self.pubkey.verify(data, badpfx))
                        except UnsupportedAlgorithm:  # pragma: no cover
                            pass

            with self.subTest('Missing signature'):
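                # Editor's note (added commentary, not from the original
                # source): an SSH signature blob is encoded as
                # String(algorithm) + String(signature data). The subtests
                # here drop or empty the second field to exercise the error
                # handling in verify(); the algorithm name below is only an
                # illustration:
                #
                #   good    = String(b'ssh-ed25519') + String(raw_sig)
                #   missing = String(b'ssh-ed25519')
                #   empty   = String(b'ssh-ed25519') + String(b'')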
self.assertFalse(self.pubkey.verify( data, String(self.pubkey.sig_algorithms[0]))) with self.subTest('Empty signature'): self.assertFalse(self.pubkey.verify( data, String(self.pubkey.sig_algorithms[0]) + String(b''))) with self.subTest('Sign with bad algorithm'): with self.assertRaises(ValueError): self.privkey.sign(data, b'xxx') with self.subTest('Verify with bad algorithm'): self.assertFalse(self.pubkey.verify( data, String('xxx') + String(''))) with self.subTest('Sign with public key'): with self.assertRaises(ValueError): self.pubkey.sign(data, self.pubkey.sig_algorithms[0]) def check_set_certificate(self): """Check setting certificate on existing keypair""" keypair = asyncssh.load_keypairs([self.privkey])[0] keypair.set_certificate(self.usercert) self.assertEqual(keypair.public_data, self.usercert.public_data) keypair = asyncssh.load_keypairs(self.privkey)[0] keypair = asyncssh.load_keypairs((keypair, self.usercert))[0] self.assertEqual(keypair.public_data, self.usercert.public_data) key2 = get_test_key('ssh-rsa', 1) with self.assertRaises(ValueError): asyncssh.load_keypairs((key2, self.usercert)) def check_comment(self): """Check getting and setting comments""" with self.subTest('Comment test'): self.assertEqual(self.privkey.get_comment_bytes(), b'comment') self.assertEqual(self.privkey.get_comment(), 'comment') self.assertEqual(self.pubkey.get_comment_bytes(), b'pub_comment') self.assertEqual(self.pubkey.get_comment(), 'pub_comment') key = asyncssh.import_private_key( self.privkey.export_private_key('openssh')) self.assertEqual(key.get_comment_bytes(), b'comment') self.assertEqual(key.get_comment(), 'comment') key.set_comment('new_comment') self.assertEqual(key.get_comment_bytes(), b'new_comment') self.assertEqual(key.get_comment(), 'new_comment') key.set_comment(b'new_comment') self.assertEqual(key.get_comment_bytes(), b'new_comment') self.assertEqual(key.get_comment(), 'new_comment') key.set_comment(b'\xff') self.assertEqual(key.get_comment_bytes(), b'\xff') with self.assertRaises(UnicodeDecodeError): key.get_comment() cert = asyncssh.import_certificate( self.usercert.export_certificate()) cert.set_comment(b'\xff') self.assertEqual(cert.get_comment_bytes(), b'\xff') with self.assertRaises(UnicodeDecodeError): cert.get_comment() if self.x509_supported: cert = asyncssh.import_certificate( self.userx509.export_certificate()) cert.set_comment(b'\xff') self.assertEqual(cert.get_comment_bytes(), b'\xff') with self.assertRaises(UnicodeDecodeError): cert.get_comment() for fmt in ('openssh', 'rfc4716'): key = asyncssh.import_public_key( self.pubkey.export_public_key(fmt)) self.assertEqual(key.get_comment_bytes(), b'pub_comment') self.assertEqual(key.get_comment(), 'pub_comment') key = asyncssh.import_public_key( self.pubca.export_public_key(fmt)) self.assertEqual(key.get_comment_bytes(), None) self.assertEqual(key.get_comment(), None) key.set_comment('new_comment') self.assertEqual(key.get_comment_bytes(), b'new_comment') self.assertEqual(key.get_comment(), 'new_comment') key.set_comment(b'new_comment') self.assertEqual(key.get_comment_bytes(), b'new_comment') self.assertEqual(key.get_comment(), 'new_comment') for fmt in ('openssh', 'rfc4716'): cert = asyncssh.import_certificate( self.usercert.export_certificate(fmt)) self.assertEqual(cert.get_comment_bytes(), b'user_comment') self.assertEqual(cert.get_comment(), 'user_comment') cert = self.privca.generate_user_certificate( self.pubkey, 'name', principals='name1,name2', comment='cert_comment') self.assertEqual(cert.principals, ['name1', 
'name2']) self.assertEqual(cert.get_comment_bytes(), b'cert_comment') self.assertEqual(cert.get_comment(), 'cert_comment') cert = asyncssh.import_certificate( self.hostcert.export_certificate(fmt)) self.assertEqual(cert.get_comment_bytes(), b'host_comment') self.assertEqual(cert.get_comment(), 'host_comment') cert = self.privca.generate_host_certificate( self.pubkey, 'name', principals=['name1', 'name2'], comment=b'\xff') self.assertEqual(cert.principals, ['name1', 'name2']) self.assertEqual(cert.get_comment_bytes(), b'\xff') with self.assertRaises(UnicodeDecodeError): cert.get_comment() cert.set_comment('new_comment') self.assertEqual(cert.get_comment_bytes(), b'new_comment') self.assertEqual(cert.get_comment(), 'new_comment') cert.set_comment(b'new_comment') self.assertEqual(cert.get_comment_bytes(), b'new_comment') self.assertEqual(cert.get_comment(), 'new_comment') if self.x509_supported: for fmt in ('openssh', 'der', 'pem'): cert = asyncssh.import_certificate( self.rootx509.export_certificate(fmt)) self.assertEqual(cert.get_comment_bytes(), None) self.assertEqual(cert.get_comment(), None) cert = self.privca.generate_x509_ca_certificate( self.pubkey, 'OU=root', comment='ca_comment') self.assertEqual(cert.get_comment_bytes(), b'ca_comment') self.assertEqual(cert.get_comment(), 'ca_comment') cert = asyncssh.import_certificate( self.userx509.export_certificate(fmt)) self.assertEqual(cert.get_comment_bytes(), b'user_comment') self.assertEqual(cert.get_comment(), 'user_comment') cert = self.privca.generate_x509_user_certificate( self.pubkey, 'OU=user', 'OU=root', comment='user_comment') self.assertEqual(cert.get_comment_bytes(), b'user_comment') self.assertEqual(cert.get_comment(), 'user_comment') cert = asyncssh.import_certificate( self.hostx509.export_certificate(fmt)) self.assertEqual(cert.get_comment_bytes(), b'host_comment') self.assertEqual(cert.get_comment(), 'host_comment') cert = self.privca.generate_x509_host_certificate( self.pubkey, 'OU=host', 'OU=root', comment='host_comment') self.assertEqual(cert.get_comment_bytes(), b'host_comment') self.assertEqual(cert.get_comment(), 'host_comment') cert.set_comment('new_comment') self.assertEqual(cert.get_comment_bytes(), b'new_comment') self.assertEqual(cert.get_comment(), 'new_comment') cert.set_comment(b'new_comment') self.assertEqual(cert.get_comment_bytes(), b'new_comment') self.assertEqual(cert.get_comment(), 'new_comment') keypair = asyncssh.load_keypairs([self.privkey])[0] self.assertEqual(keypair.get_comment_bytes(), b'comment') self.assertEqual(keypair.get_comment(), 'comment') keypair.set_comment('new_comment') self.assertEqual(keypair.get_comment_bytes(), b'new_comment') self.assertEqual(keypair.get_comment(), 'new_comment') keypair.set_comment(b'new_comment') self.assertEqual(keypair.get_comment_bytes(), b'new_comment') self.assertEqual(keypair.get_comment(), 'new_comment') keypair.set_comment(b'\xff') self.assertEqual(keypair.get_comment_bytes(), b'\xff') with self.assertRaises(UnicodeDecodeError): keypair.get_comment() priv = asyncssh.read_private_key('priv') priv.set_comment(None) keypair = asyncssh.load_keypairs((priv, self.pubkey))[0] self.assertEqual(keypair.get_comment(), 'pub_comment') keypair = asyncssh.load_keypairs((priv, self.usercert))[0] self.assertEqual(keypair.get_comment(), 'user_comment') keypair = asyncssh.load_keypairs(priv, None, self.usercert)[0] self.assertEqual(keypair.get_comment(), 'user_comment') pubdata = self.pubkey.export_public_key() keypair = asyncssh.load_keypairs((priv, pubdata))[0] 
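        # Editor's note (added commentary, not from the original source):
        # load_keypairs() also accepts (private key, public key or
        # certificate) tuples, and when the private key itself carries no
        # comment, the comment is taken from the associated public key or
        # certificate, which is what the assertions around here verify.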
self.assertEqual(keypair.get_comment(), 'pub_comment') certdata = self.usercert.export_certificate() keypair = asyncssh.load_keypairs((priv, certdata))[0] self.assertEqual(keypair.get_comment(), 'user_comment') keypair = asyncssh.load_keypairs(priv, None, certdata)[0] self.assertEqual(keypair.get_comment(), 'user_comment') priv.write_private_key('key') keypair = asyncssh.load_keypairs('key')[0] self.assertEqual(keypair.get_comment(), 'key') keypair = asyncssh.load_keypairs(('key', 'sshpub'))[0] self.assertEqual(keypair.get_comment(), 'pub_comment') keypair = asyncssh.load_keypairs(('key', 'usercert'))[0] self.assertEqual(keypair.get_comment(), 'user_comment') keypair = asyncssh.load_keypairs('key', None, 'usercert')[0] self.assertEqual(keypair.get_comment(), 'user_comment') self.pubkey.write_public_key('key.pub') keypair = asyncssh.load_keypairs('key')[0] self.assertEqual(keypair.get_comment(), 'pub_comment') self.usercert.write_certificate('key-cert.pub') keypair = asyncssh.load_keypairs('key')[0] self.assertEqual(keypair.get_comment(), 'user_comment') keypair = asyncssh.load_keypairs('key')[1] self.assertEqual(keypair.get_comment(), 'pub_comment') keypair = asyncssh.load_keypairs(('key', None))[0] self.assertEqual(keypair.get_comment(), 'pub_comment') key2 = get_test_key('ssh-rsa', 1) with self.assertRaises(ValueError): asyncssh.load_keypairs((key2, 'pub')) for f in ('key', 'key.pub', 'key-cert.pub'): os.remove(f) def check_pkcs1_private(self): """Check PKCS#1 private key format""" with self.subTest('Import PKCS#1 PEM private'): self.import_pkcs1_private('pem') with self.subTest('Export PKCS#1 PEM private'): self.export_pkcs1_private('pem') with self.subTest('Import PKCS#1 DER private'): self.import_pkcs1_private('der') with self.subTest('Export PKCS#1 DER private'): self.export_pkcs1_private('der') for cipher, args, legacy in pkcs1_ciphers: legacy_args = _openssl_legacy if legacy else '' with self.subTest(f'Import PKCS#1 PEM private ({cipher})'): self.import_pkcs1_private('pem', cipher, legacy_args + args) with self.subTest(f'Export PKCS#1 PEM private ({cipher})'): self.export_pkcs1_private('pem', cipher, legacy_args) def check_pkcs1_public(self): """Check PKCS#1 public key format""" with self.subTest('Import PKCS#1 PEM public'): self.import_pkcs1_public('pem') with self.subTest('Export PKCS#1 PEM public'): self.export_pkcs1_public('pem') with self.subTest('Import PKCS#1 DER public'): self.import_pkcs1_public('der') with self.subTest('Export PKCS#1 DER public'): self.export_pkcs1_public('der') def check_pkcs8_private(self): """Check PKCS#8 private key format""" with self.subTest('Import PKCS#8 PEM private'): self.import_pkcs8_private('pem') with self.subTest('Export PKCS#8 PEM private'): self.export_pkcs8_private('pem') with self.subTest('Import PKCS#8 DER private'): self.import_pkcs8_private('der') with self.subTest('Export PKCS#8 DER private'): self.export_pkcs8_private('der') for cipher, hash_alg, pbe_version, args, \ openssl_ok, legacy in pkcs8_ciphers: legacy_args = _openssl_legacy if legacy else '' with self.subTest(f'Import PKCS#8 PEM private ({cipher}-' f'{hash_alg}-v{pbe_version})'): self.import_pkcs8_private('pem', openssl_ok, cipher, hash_alg, pbe_version, legacy_args + args) with self.subTest(f'Export PKCS#8 PEM private ({cipher}-' f'{hash_alg}-v{pbe_version})'): self.export_pkcs8_private('pem', openssl_ok, cipher, hash_alg, pbe_version, legacy_args) with self.subTest(f'Import PKCS#8 DER private ({cipher}-' f'{hash_alg}-v{pbe_version})'): self.import_pkcs8_private('der', 
openssl_ok, cipher, hash_alg, pbe_version, legacy_args + args) with self.subTest(f'Export PKCS#8 DER private ({cipher}-' f'{hash_alg}-v{pbe_version})'): self.export_pkcs8_private('der', openssl_ok, cipher, hash_alg, pbe_version, legacy_args) if self.single_cipher: break def check_pkcs8_public(self): """Check PKCS#8 public key format""" with self.subTest('Import PKCS#8 PEM public'): self.import_pkcs8_public('pem') with self.subTest('Export PKCS#8 PEM public'): self.export_pkcs8_public('pem') with self.subTest('Import PKCS#8 DER public'): self.import_pkcs8_public('der') with self.subTest('Export PKCS#8 DER public'): self.export_pkcs8_public('der') def check_openssh_private(self): """Check OpenSSH private key format""" with self.subTest('Import OpenSSH private'): self.import_openssh_private() with self.subTest('Export OpenSSH private'): self.export_openssh_private() if bcrypt_available: # pragma: no branch for cipher, openssh_ok in openssh_ciphers: with self.subTest(f'Import OpenSSH private ({cipher})'): self.import_openssh_private(openssh_ok, cipher) with self.subTest(f'Export OpenSSH private ({cipher})'): self.export_openssh_private(openssh_ok, cipher) if self.single_cipher: break def check_openssh_public(self): """Check OpenSSH public key format""" with self.subTest('Import OpenSSH public'): self.import_openssh_public() with self.subTest('Export OpenSSH public'): self.export_openssh_public() def check_openssh_certificate(self): """Check OpenSSH certificate format""" with self.subTest('Import OpenSSH user certificate'): self.import_openssh_certificate(CERT_TYPE_USER, 'usercert') with self.subTest('Export OpenSSH user certificate'): self.export_openssh_certificate(CERT_TYPE_USER, self.usercert) with self.subTest('Import OpenSSH host certificate'): self.import_openssh_certificate(CERT_TYPE_HOST, 'hostcert') with self.subTest('Export OpenSSH host certificate'): self.export_openssh_certificate(CERT_TYPE_HOST, self.hostcert) def check_rfc4716_public(self): """Check RFC4716 public key format""" with self.subTest('Import RFC4716 public'): self.import_rfc4716_public() with self.subTest('Export RFC4716 public'): self.export_rfc4716_public() def check_rfc4716_certificate(self): """Check RFC4716 certificate format""" with self.subTest('Import RFC4716 user certificate'): self.import_rfc4716_certificate(CERT_TYPE_USER, 'usercert') with self.subTest('Export RFC4716 user certificate'): self.export_rfc4716_certificate(CERT_TYPE_USER, self.usercert) with self.subTest('Import RFC4716 host certificate'): self.import_rfc4716_certificate(CERT_TYPE_HOST, 'hostcert') with self.subTest('Export RFC4716 host certificate'): self.export_rfc4716_certificate(CERT_TYPE_HOST, self.hostcert) def check_der_x509_certificate(self): """Check DER X.509 certificate format""" with self.subTest('Import DER X.509 user certificate'): self.import_der_x509_certificate(CERT_TYPE_USER, self.userx509) with self.subTest('Export DER X.509 user certificate'): self.export_der_x509_certificate(CERT_TYPE_USER, self.userx509) with self.subTest('Import DER X.509 host certificate'): self.import_der_x509_certificate(CERT_TYPE_HOST, self.hostx509) with self.subTest('Export DER X.509 host certificate'): self.export_der_x509_certificate(CERT_TYPE_HOST, self.hostx509) def check_pem_x509_certificate(self): """Check PEM X.509 certificate format""" with self.subTest('Import PEM X.509 user certificate'): self.import_pem_x509_certificate(CERT_TYPE_USER, self.userx509) with self.subTest('Export PEM X.509 user certificate'): 
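            # Editor's note (added commentary, not from the original
            # source): the X.509 import/export helpers exercised here
            # round-trip the same certificate through the 'pem', 'der'
            # and 'openssh' encodings via write_certificate() and
            # read_certificate().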
self.export_pem_x509_certificate(CERT_TYPE_USER, self.userx509) with self.subTest('Import PEM X.509 host certificate'): self.import_pem_x509_certificate(CERT_TYPE_HOST, self.hostx509) with self.subTest('Export PEM X.509 host certificate'): self.export_pem_x509_certificate(CERT_TYPE_HOST, self.hostx509) with self.subTest('Import PEM X.509 trusted user certificate'): self.import_pem_x509_certificate(CERT_TYPE_USER, self.userx509, trusted=True) with self.subTest('Import PEM X.509 trusted host certificate'): self.import_pem_x509_certificate(CERT_TYPE_HOST, self.hostx509, trusted=True) def check_openssh_x509_certificate(self): """Check OpenSSH X.509 certificate format""" with self.subTest('Import OpenSSH X.509 user certificate'): self.import_openssh_x509_certificate(CERT_TYPE_USER, self.userx509) with self.subTest('Export OpenSSH X.509 user certificate'): self.export_openssh_x509_certificate(CERT_TYPE_USER, self.userx509) with self.subTest('Import OpenSSH X.509 host certificate'): self.import_openssh_x509_certificate(CERT_TYPE_HOST, self.hostx509) with self.subTest('Export OpenSSH X.509 host certificate'): self.export_openssh_x509_certificate(CERT_TYPE_HOST, self.hostx509) def check_certificate_options(self): """Check SSH certificate options""" cert = self.privca.generate_user_certificate( self.pubkey, 'name', force_command='command', source_address=['1.2.3.4'], permit_x11_forwarding=False, permit_agent_forwarding=False, permit_port_forwarding=False, permit_pty=False, permit_user_rc=False, touch_required=False) cert.write_certificate('cert') self.check_certificate(CERT_TYPE_USER, 'openssh') for valid_after, valid_before in ((0, 1.), (datetime.now(), '+1m'), ('20160101', '20160102'), ('20160101000000', '20160102235959'), ('now', '1w2d3h4m5s'), ('-52w', '+52w')): cert = self.privca.generate_host_certificate( self.pubkey, 'name', valid_after=valid_after, valid_before=valid_before) cert.write_certificate('cert') cert2 = asyncssh.read_certificate('cert') self.assertEqual(cert2.public_data, cert.public_data) def check_certificate_errors(self, cert_type): """Check general and OpenSSH certificate error cases""" with self.subTest('Non-ASCII certificate'): with self.assertRaises(asyncssh.KeyImportError): asyncssh.import_certificate('\u0080\n') with self.subTest('Invalid SSH format'): with self.assertRaises(asyncssh.KeyImportError): asyncssh.import_certificate('xxx\n') with self.subTest('Invalid certificate packetization'): with self.assertRaises(asyncssh.KeyImportError): asyncssh.import_certificate( b'xxx ' + binascii.b2a_base64(b'\x00')) with self.subTest('Invalid certificate algorithm'): with self.assertRaises(asyncssh.KeyImportError): asyncssh.import_certificate( b'xxx ' + binascii.b2a_base64(String(b'xxx'))) with self.subTest('Invalid certificate critical option'): with self.assertRaises(asyncssh.KeyImportError): cert = self.make_certificate(cert_type, self.pubkey, self.privca, ('name',), options={b'xxx': b''}) asyncssh.import_certificate(cert) with self.subTest('Ignored certificate extension'): cert = self.make_certificate(cert_type, self.pubkey, self.privca, ('name',), extensions={b'xxx': b''}) self.assertIsNotNone(asyncssh.import_certificate(cert)) with self.subTest('Invalid certificate signature'): with self.assertRaises(asyncssh.KeyImportError): cert = self.make_certificate(cert_type, self.pubkey, self.privca, ('name',), bad_signature=True) asyncssh.import_certificate(cert) with self.subTest('Invalid characters in certificate key ID'): with self.assertRaises(asyncssh.KeyImportError): cert = 
self.make_certificate(cert_type, self.pubkey, self.privca, ('name',), key_id=b'\xff') asyncssh.import_certificate(cert) with self.subTest('Invalid characters in certificate principal'): with self.assertRaises(asyncssh.KeyImportError): cert = self.make_certificate(cert_type, self.pubkey, self.privca, (b'\xff',)) asyncssh.import_certificate(cert) if cert_type == CERT_TYPE_USER: with self.subTest('Invalid characters in force-command'): with self.assertRaises(asyncssh.KeyImportError): cert = self.make_certificate(cert_type, self.pubkey, self.privca, ('name',), options={'force-command': String(b'\xff')}) asyncssh.import_certificate(cert) with self.subTest('Invalid characters in source-address'): with self.assertRaises(asyncssh.KeyImportError): cert = self.make_certificate(cert_type, self.pubkey, self.privca, ('name',), options={'source-address': String(b'\xff')}) asyncssh.import_certificate(cert) with self.subTest('Invalid IP network in source-address'): with self.assertRaises(asyncssh.KeyImportError): cert = self.make_certificate(cert_type, self.pubkey, self.privca, ('name',), options={'source-address': String('1.1.1.256')}) asyncssh.import_certificate(cert) with self.subTest('Invalid certificate type'): with self.assertRaises(asyncssh.KeyImportError): cert = self.make_certificate(0, self.pubkey, self.privca, ('name',)) asyncssh.import_certificate(cert) with self.subTest('Mismatched certificate type'): with self.assertRaises(ValueError): cert = self.make_certificate(cert_type, self.pubkey, self.privca, ('name',)) cert = asyncssh.import_certificate(cert) self.validate_openssh(cert, cert_type ^ 3, 'name') with self.subTest('Certificate not yet valid'): with self.assertRaises(ValueError): cert = self.make_certificate(cert_type, self.pubkey, self.privca, ('name',), valid_after=0xffffffffffffffff) cert = asyncssh.import_certificate(cert) self.validate_openssh(cert, cert_type, 'name') with self.subTest('Certificate expired'): with self.assertRaises(ValueError): cert = self.make_certificate(cert_type, self.pubkey, self.privca, ('name',), valid_before=0) cert = asyncssh.import_certificate(cert) self.validate_openssh(cert, cert_type, 'name') with self.subTest('Certificate principal mismatch'): with self.assertRaises(ValueError): cert = self.make_certificate(cert_type, self.pubkey, self.privca, ('name',)) cert = asyncssh.import_certificate(cert) self.validate_openssh(cert, cert_type, 'name2') for fmt in ('der', 'pem', 'xxx'): with self.subTest('Invalid certificate export format', fmt=fmt): with self.assertRaises(asyncssh.KeyExportError): self.usercert.export_certificate(fmt) def check_x509_certificate_errors(self): """Check X.509 certificate error cases""" with self.subTest('Invalid DER format'): with self.assertRaises(asyncssh.KeyImportError): asyncssh.import_certificate(b'\x30\x00') with self.subTest('Invalid DER format in certificate list'): with self.assertRaises(asyncssh.KeyImportError): write_file('certlist', b'\x30\x00') asyncssh.read_certificate_list('certlist') with self.subTest('Invalid PEM format'): with self.assertRaises(asyncssh.KeyImportError): asyncssh.import_certificate('-----') with self.subTest('Invalid PEM certificate type'): with self.assertRaises(asyncssh.KeyImportError): asyncssh.import_certificate('-----BEGIN XXX CERTIFICATE-----\n' '-----END XXX CERTIFICATE-----\n') with self.subTest('Missing PEM footer'): with self.assertRaises(asyncssh.KeyImportError): asyncssh.import_certificate('-----BEGIN CERTIFICATE-----\n') with self.subTest('Invalid PEM Base64'): with 
self.assertRaises(asyncssh.KeyImportError): asyncssh.import_certificate('-----BEGIN CERTIFICATE-----\n' 'X\n' '-----END CERTIFICATE-----\n') with self.subTest('Invalid PEM trusted certificate'): with self.assertRaises(asyncssh.KeyImportError): asyncssh.import_certificate( '-----BEGIN TRUSTED CERTIFICATE-----\n' 'MA==\n' '-----END TRUSTED CERTIFICATE-----\n') with self.subTest('Invalid PEM certificate data'): with self.assertRaises(asyncssh.KeyImportError): asyncssh.import_certificate('-----BEGIN CERTIFICATE-----\n' 'XXXX\n' '-----END CERTIFICATE-----\n') with self.subTest('Certificate not yet valid'): cert = self.privca.generate_x509_user_certificate( self.pubkey, 'OU=user', 'OU=root', valid_after=0xfffffffffffffffe) with self.assertRaises(ValueError): self.validate_x509(cert) with self.subTest('Certificate expired'): cert = self.privca.generate_x509_user_certificate( self.pubkey, 'OU=user', 'OU=root', valid_before=1) with self.assertRaises(ValueError): self.validate_x509(cert) with self.subTest('Certificate principal mismatch'): cert = self.privca.generate_x509_user_certificate( self.pubkey, 'OU=user', 'OU=root', principals=['name']) with self.assertRaises(ValueError): self.validate_x509(cert, 'name2') for fmt in ('rfc4716', 'xxx'): with self.subTest('Invalid certificate export format', fmt=fmt): with self.assertRaises(asyncssh.KeyExportError): self.userx509.export_certificate(fmt) with self.subTest('Empty certificate chain'): with self.assertRaises(asyncssh.KeyImportError): decode_ssh_certificate(String('x509v3-ssh-rsa') + UInt32(0) + UInt32(0)) def check_x509_certificate_subject(self): """Check X.509 certificate subject cases""" with self.subTest('Missing certificate subject algorithm'): with self.assertRaises(asyncssh.KeyImportError): import_certificate_subject('xxx') with self.subTest('Unknown certificate subject algorithm'): with self.assertRaises(asyncssh.KeyImportError): import_certificate_subject('xxx subject=OU=name') with self.subTest('Invalid certificate subject'): with self.assertRaises(asyncssh.KeyImportError): import_certificate_subject('x509v3-ssh-rsa xxx') subject = import_certificate_subject('x509v3-ssh-rsa subject=OU=name') self.assertEqual(subject, 'OU=name') def test_keys(self): """Check keys and certificates""" for alg_name, kwargs in self.generate_args: with self.subTest(alg_name=alg_name, **kwargs): self.privkey = get_test_key( alg_name, comment='comment', **kwargs) self.privkey.write_private_key('priv', self.base_format) self.pubkey = self.privkey.convert_to_public() self.pubkey.set_comment('pub_comment') self.pubkey.write_public_key('pub', self.base_format) self.pubkey.write_public_key('sshpub', 'openssh') self.privca = get_test_key(alg_name, 1, **kwargs) self.privca.write_private_key('privca', self.base_format) self.pubca = self.privca.convert_to_public() self.pubca.write_public_key('pubca', self.base_format) self.usercert = self.privca.generate_user_certificate( self.pubkey, 'name', comment='user_comment') self.usercert.write_certificate('usercert') hostcert_sig_alg = self.privca.sig_algorithms[0].decode() self.hostcert = self.privca.generate_host_certificate( self.pubkey, 'name', sig_alg=hostcert_sig_alg, comment='host_comment') self.hostcert.write_certificate('hostcert') for f in ('priv', 'privca'): os.chmod(f, 0o600) self.assertEqual(self.privkey.get_algorithm(), alg_name) self.assertEqual(self.usercert.get_algorithm(), self.default_cert_version) if self.x509_supported: self.rootx509 = self.privca.generate_x509_ca_certificate( self.pubca, 'OU=root') 
                    self.rootx509.write_certificate('rootx509')

                    self.userx509 = self.privca.generate_x509_user_certificate(
                        self.pubkey, 'OU=user', 'OU=root',
                        comment='user_comment')

                    self.assertEqual(self.userx509.get_algorithm(),
                                     'x509v3-' + alg_name)

                    self.userx509.write_certificate('userx509')

                    self.hostx509 = self.privca.generate_x509_host_certificate(
                        self.pubkey, 'OU=host', 'OU=root',
                        comment='host_comment')

                    self.hostx509.write_certificate('hostx509')

                    self.otherx509 = self.privca.generate_x509_user_certificate(
                        self.pubkey, 'OU=other', 'OU=root')

                    self.otherx509.write_certificate('otherx509')

                self.check_encode_errors()
                self.check_decode_errors()
                self.check_sshkey_base_errors()
                self.check_sign_and_verify()
                self.check_set_certificate()
                self.check_comment()

                if 'pkcs1' in self.private_formats:
                    self.check_pkcs1_private()

                if 'pkcs1' in self.public_formats:
                    self.check_pkcs1_public()

                if 'pkcs8' in self.private_formats:  # pragma: no branch
                    self.check_pkcs8_private()

                if 'pkcs8' in self.public_formats:  # pragma: no branch
                    self.check_pkcs8_public()

                self.check_openssh_private()
                self.check_openssh_public()
                self.check_openssh_certificate()
                self.check_rfc4716_public()
                self.check_rfc4716_certificate()
                self.check_certificate_options()

                for cert_type in (CERT_TYPE_USER, CERT_TYPE_HOST):
                    self.check_certificate_errors(cert_type)

                if self.x509_supported:
                    self.check_der_x509_certificate()
                    self.check_pem_x509_certificate()
                    self.check_openssh_x509_certificate()
                    self.check_x509_certificate_errors()
                    self.check_x509_certificate_subject()


class TestDSA(_TestPublicKey):
    """Test DSA keys"""

    keyclass = 'dsa'
    base_format = 'pkcs8-pem'
    private_formats = ('pkcs1', 'pkcs8', 'openssh')
    public_formats = ('pkcs1', 'pkcs8', 'openssh', 'rfc4716')
    default_cert_version = 'ssh-dss-cert-v01@openssh.com'
    x509_supported = x509_available
    generate_args = (('ssh-dss', {}),)
    use_openssh = False


class TestRSA(_TestPublicKey):
    """Test RSA keys"""

    keyclass = 'rsa'
    base_format = 'pkcs8-pem'
    private_formats = ('pkcs1', 'pkcs8', 'openssh')
    public_formats = ('pkcs1', 'pkcs8', 'openssh', 'rfc4716')
    default_cert_version = 'ssh-rsa-cert-v01@openssh.com'
    x509_supported = x509_available
    generate_args = (('ssh-rsa', {'key_size': 1024}),
                     ('ssh-rsa', {'key_size': 2048}),
                     ('ssh-rsa', {'key_size': 3072}),
                     ('ssh-rsa', {'exponent': 3}))


class TestECDSA(_TestPublicKey):
    """Test ECDSA keys"""

    keyclass = 'ec'
    base_format = 'pkcs8-pem'
    private_formats = ('pkcs1', 'pkcs8', 'openssh')
    public_formats = ('pkcs8', 'openssh', 'rfc4716')
    x509_supported = x509_available
    generate_args = (('ecdsa-sha2-nistp256', {}),
                     ('ecdsa-sha2-nistp384', {}),
                     ('ecdsa-sha2-nistp521', {}))

    @property
    def default_cert_version(self):
        """Return default SSH certificate version"""

        return self.privkey.algorithm.decode('ascii') + '-cert-v01@openssh.com'


@unittest.skipUnless(ed25519_available, 'ed25519 not available')
class TestEd25519(_TestPublicKey):
    """Test Ed25519 keys"""

    keyclass = 'ed25519'
    base_format = 'pkcs8-pem'
    private_formats = ('pkcs8', 'openssh')
    public_formats = ('pkcs8', 'openssh', 'rfc4716')
    x509_supported = x509_available
    default_cert_version = 'ssh-ed25519-cert-v01@openssh.com'
    generate_args = (('ssh-ed25519', {}),)
    single_cipher = False
    use_openssh = False
    use_openssl = _openssl_supports_pkey


@unittest.skipUnless(ed448_available, 'ed448 not available')
class TestEd448(_TestPublicKey):
    """Test Ed448 keys"""

    keyclass = 'ed448'
    base_format = 'pkcs8-pem'
    private_formats = ('pkcs8', 'openssh')
    public_formats = ('pkcs8', 'openssh', 'rfc4716')
    x509_supported = x509_available
    default_cert_version = 'ssh-ed448-cert-v01@openssh.com'
    generate_args = (('ssh-ed448', {}),)
    use_openssh = False
    use_openssl = _openssl_supports_pkey


@unittest.skipUnless(sk_available, 'security key support not available')
class TestSKECDSA(_TestPublicKey):
    """Test U2F ECDSA keys"""

    keyclass = 'sk-ecdsa'
    base_format = 'openssh'
    private_formats = ('openssh',)
    public_formats = ('openssh',)
    generate_args = (('sk-ecdsa-sha2-nistp256@openssh.com', {}),)
    use_openssh = False

    def setUp(self):
        """Set up ECDSA security key test"""

        super().setUp()
        self.addCleanup(unstub_sk, *stub_sk([1]))

    @property
    def default_cert_version(self):
        """Return default SSH certificate version"""

        return self.privkey.algorithm.decode('ascii')[:-12] + \
            '-cert-v01@openssh.com'


@unittest.skipUnless(sk_available, 'security key support not available')
@unittest.skipUnless(ed25519_available, 'ed25519 not available')
class TestSKEd25519(_TestPublicKey):
    """Test U2F Ed25519 keys"""

    keyclass = 'sk-ed25519'
    base_format = 'openssh'
    private_formats = ('openssh',)
    public_formats = ('openssh',)
    default_cert_version = 'sk-ssh-ed25519-cert-v01@openssh.com'
    generate_args = (('sk-ssh-ed25519@openssh.com', {}),)
    use_openssh = False

    def setUp(self):
        """Set up Ed25519 security key test"""

        super().setUp()
        self.addCleanup(unstub_sk, *stub_sk([2]))


del _TestPublicKey


class _TestPublicKeyTopLevel(TempDirTestCase):
    """Top-level public key module tests"""

    def test_public_key(self):
        """Test public key top-level functions"""

        self.assertIsNotNone(get_public_key_algs())
        self.assertIsNotNone(get_certificate_algs())
        self.assertEqual(bool(get_x509_certificate_algs()), x509_available)

    def test_public_key_algorithm_mismatch(self):
        """Test algorithm mismatch in SSH public key"""

        privkey = get_test_key('ssh-rsa')
        keydata = privkey.export_public_key('openssh')
        keydata = b'ssh-dss ' + keydata.split(None, 1)[1]

        with self.assertRaises(asyncssh.KeyImportError):
            asyncssh.import_public_key(keydata)

        write_file('list', keydata)

        with self.assertRaises(asyncssh.KeyImportError):
            asyncssh.read_public_key_list('list')

    def test_pad_error(self):
        """Test for missing RFC 1423 padding on PBE decrypt"""

        with self.assertRaises(asyncssh.KeyEncryptionError):
            pkcs1_decrypt(b'', b'AES-128-CBC', os.urandom(16), 'x')

    def test_ec_explicit(self):
        """Test EC certificate with explicit parameters"""

        if _openssl_available:  # pragma: no branch
            for curve in ('secp256r1', 'secp384r1', 'secp521r1'):
                with self.subTest('Import EC key with explicit parameters',
                                  curve=curve):
                    run('openssl ecparam -out priv -noout -genkey '
                        f'-name {curve} -param_enc explicit')

                    asyncssh.read_private_key('priv')

    @unittest.skipIf(not _openssl_available, "openssl isn't available")
    @unittest.skipIf(b'secp224r1' not in _openssl_curves,
                     "this openssl doesn't support secp224r1")
    def test_ec_explicit_unknown(self):
        """Import EC key with unknown explicit parameters"""

        run('openssl ecparam -out priv -noout -genkey -name secp224r1 '
            '-param_enc explicit')

        with self.assertRaises(asyncssh.KeyImportError):
            asyncssh.read_private_key('priv')

    def test_generate_errors(self):
        """Test errors in private key and certificate generation"""

        for alg_name, kwargs in (('xxx', {}), ('ssh-dss', {'xxx': 0}),
                                 ('ssh-rsa', {'xxx': 0}),
                                 ('ecdsa-sha2-nistp256', {'xxx': 0}),
                                 ('ssh-ed25519', {'xxx': 0}),
                                 ('ssh-ed448', {'xxx': 0})):
            with self.subTest(alg_name=alg_name, **kwargs):
                with self.assertRaises(asyncssh.KeyGenerationError):
                    asyncssh.generate_private_key(alg_name, **kwargs)

        privkey = get_test_key('ssh-rsa')
        pubkey = privkey.convert_to_public()
        privca = get_test_key('ssh-rsa', 1)

        with self.assertRaises(asyncssh.KeyGenerationError):
            privca.generate_user_certificate(pubkey, 'name', version=0)

        with self.assertRaises(ValueError):
            privca.generate_user_certificate(pubkey, 'name', valid_after=())

        with self.assertRaises(ValueError):
            privca.generate_user_certificate(pubkey, 'name',
                                             valid_after='xxx')

        with self.assertRaises(ValueError):
            privca.generate_user_certificate(pubkey, 'name',
                                             valid_after='now',
                                             valid_before='-1m')

        with self.assertRaises(ValueError):
            privca.generate_x509_user_certificate(pubkey, 'OU=user',
                                                  valid_after=())

        with self.assertRaises(ValueError):
            privca.generate_x509_user_certificate(pubkey, 'OU=user',
                                                  valid_after='xxx')

        with self.assertRaises(ValueError):
            privca.generate_x509_user_certificate(pubkey, 'OU=user',
                                                  valid_after='now',
                                                  valid_before='-1m')

        privca.x509_algorithms = None

        with self.assertRaises(asyncssh.KeyGenerationError):
            privca.generate_x509_user_certificate(pubkey, 'OU=user')

    def test_rsa_encrypt_error(self):
        """Test RSA encryption error"""

        privkey = get_test_key('ssh-rsa', 2048)
        pubkey = privkey.convert_to_public()

        self.assertIsNone(pubkey.encrypt(os.urandom(256), pubkey.algorithm))

    def test_rsa_decrypt_error(self):
        """Test RSA decryption error"""

        privkey = get_test_key('ssh-rsa', 2048)

        self.assertIsNone(privkey.decrypt(b'', privkey.algorithm))

    @unittest.skipUnless(x509_available, 'x509 not available')
    def test_x509_certificate_hashes(self):
        """Test X.509 certificate hash algorithms"""

        privkey = get_test_key('ssh-rsa')
        pubkey = privkey.convert_to_public()

        for hash_alg in ('sha256', 'sha512'):
            cert = privkey.generate_x509_user_certificate(
                pubkey, 'OU=user', hash_alg=hash_alg)

            cert.write_certificate('cert', 'pem')

            cert2 = asyncssh.read_certificate('cert')
            self.assertEqual(str(cert2.subject), 'OU=user')
asyncssh-2.20.0/tests/test_saslprep.py000066400000000000000000000064341475467777400200750ustar00rootroot00000000000000
# Copyright (c) 2015-2018 by Ron Frederick and others.
#
# This program and the accompanying materials are made available under
# the terms of the Eclipse Public License v2.0 which accompanies this
# distribution and is available at:
#
#     http://www.eclipse.org/legal/epl-2.0/
#
# This program may also be made available under the following secondary
# licenses when the conditions for such availability set forth in the
# Eclipse Public License v2.0 are satisfied:
#
#    GNU General Public License, Version 2.0, or any later versions of
#    that license
#
# SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later
#
# Contributors:
#     Ron Frederick - initial implementation, API, and documentation

"""Unit tests for SASL string preparation"""

import unittest

from asyncssh.saslprep import saslprep, SASLPrepError


class _TestSASLPrep(unittest.TestCase):
    """Unit tests for saslprep module"""

    def test_nonstring(self):
        """Test passing a non-string value"""

        with self.assertRaises(TypeError):
            saslprep(b'xxx')

    def test_unassigned(self):
        """Test passing strings with unassigned code points"""

        for s in ('\u0221', '\u038b', '\u0510', '\u070e', '\u0900', '\u0a00'):
            with self.assertRaises(SASLPrepError, msg=f'U+{ord(s):08x}'):
                saslprep('abc' + s + 'def')

    def test_map_to_nothing(self):
        """Test passing strings with characters that map to nothing"""

        for s in ('\u00ad', '\u034f', '\u1806', '\u200c', '\u2060', '\ufe00'):
            self.assertEqual(saslprep('abc' + s + 'def'), 'abcdef',
                             msg=f'U+{ord(s):08x}')

    def test_map_to_whitespace(self):
        """Test passing strings with characters that map to whitespace"""

        for s in ('\u00a0', '\u1680', '\u2000', '\u202f', '\u205f', '\u3000'):
            self.assertEqual(saslprep('abc' + s + 'def'), 'abc def',
                             msg=f'U+{ord(s):08x}')

    def test_normalization(self):
        """Test Unicode normalization form KC conversions"""

        for (s, n) in (('\u00aa', 'a'), ('\u2168', 'IX')):
            self.assertEqual(saslprep('abc' + s + 'def'), 'abc' + n + 'def',
                             msg=f'U+{ord(s):08x}')

    def test_prohibited(self):
        """Test passing strings with prohibited characters"""

        for s in ('\u0000', '\u007f', '\u0080', '\u06dd', '\u180e', '\u200e',
                  '\u2028', '\u202a', '\u206a', '\u2ff0', '\u2ffb', '\ud800',
                  '\udfff', '\ue000', '\ufdd0', '\ufef9', '\ufffc', '\uffff',
                  '\U0001d173', '\U000E0001', '\U00100000', '\U0010fffd'):
            with self.assertRaises(SASLPrepError, msg=f'U+{ord(s):08x}'):
                saslprep('abc' + s + 'def')

    def test_bidi(self):
        """Test passing strings with bidirectional characters"""

        for s in ('\u05be\u05c0\u05c3\u05d0',  # RorAL only
                  'abc\u00c0\u00c1\u00c2',     # L only
                  '\u0627\u0031\u0628'):       # Mix of RorAL and other
            self.assertEqual(saslprep(s), s)

        with self.assertRaises(SASLPrepError):
            saslprep('abc\u05be\u05c0\u05c3')  # Mix of RorAL and L

        with self.assertRaises(SASLPrepError):
            saslprep('\u0627\u0031')  # RorAL not at both start & end
asyncssh-2.20.0/tests/test_sftp.py000066400000000000000000005577441475467777400172300ustar00rootroot00000000000000
# Copyright (c) 2015-2024 by Ron Frederick and others.
#
# This program and the accompanying materials are made available under
# the terms of the Eclipse Public License v2.0 which accompanies this
# distribution and is available at:
#
#     http://www.eclipse.org/legal/epl-2.0/
#
# This program may also be made available under the following secondary
# licenses when the conditions for such availability set forth in the
# Eclipse Public License v2.0 are satisfied:
#
#    GNU General Public License, Version 2.0, or any later versions of
#    that license
#
# SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later
#
# Contributors:
#     Ron Frederick - initial implementation, API, and documentation

"""Unit tests for AsyncSSH SFTP client and server"""

import asyncio
import errno
import functools
import os
from pathlib import Path
import posixpath
import shutil
import stat
import sys
import time
import unittest
from unittest.mock import patch

import asyncssh
from asyncssh import SFTPError, SFTPNoSuchFile, SFTPPermissionDenied
from asyncssh import SFTPFailure, SFTPBadMessage, SFTPNoConnection
from asyncssh import SFTPConnectionLost, SFTPOpUnsupported, SFTPInvalidHandle
from asyncssh import SFTPNoSuchPath, SFTPFileAlreadyExists, SFTPWriteProtect
from asyncssh import SFTPNoMedia, SFTPNoSpaceOnFilesystem, SFTPQuotaExceeded
from asyncssh import SFTPUnknownPrincipal, SFTPLockConflict, SFTPDirNotEmpty
from asyncssh import SFTPNotADirectory, SFTPInvalidFilename, SFTPLinkLoop
from asyncssh import SFTPCannotDelete, SFTPInvalidParameter
from asyncssh import SFTPFileIsADirectory, SFTPByteRangeLockConflict
from asyncssh import SFTPByteRangeLockRefused, SFTPDeletePending
from asyncssh import SFTPFileCorrupt, SFTPOwnerInvalid, SFTPGroupInvalid
from asyncssh import SFTPNoMatchingByteRangeLock
from asyncssh import SFTPAttrs, SFTPVFSAttrs, SFTPName, SFTPServer
from asyncssh import SEEK_CUR, SEEK_END
from asyncssh import FXP_INIT, FXP_VERSION, FXP_OPEN, FXP_READ
from asyncssh import FXP_WRITE, FXP_STATUS, FXP_HANDLE, FXP_DATA
from asyncssh import FXF_WRITE, FXF_APPEND, FXF_CREAT, FXF_TRUNC
from asyncssh import FXF_CREATE_NEW, FXF_CREATE_TRUNCATE, FXF_OPEN_EXISTING
from asyncssh import FXF_OPEN_OR_CREATE, FXF_TRUNCATE_EXISTING
from asyncssh import FXF_APPEND_DATA, FXF_BLOCK_READ
from asyncssh import ACE4_READ_DATA, ACE4_WRITE_DATA, ACE4_APPEND_DATA
from asyncssh import FXR_OVERWRITE
from asyncssh import FXRP_STAT_IF_EXISTS, FXRP_STAT_ALWAYS
from asyncssh import FILEXFER_ATTR_UIDGID, FILEXFER_ATTR_OWNERGROUP
from asyncssh import FILEXFER_TYPE_REGULAR, FILEXFER_TYPE_DIRECTORY
from asyncssh import FILEXFER_TYPE_SYMLINK, FILEXFER_TYPE_SPECIAL
from asyncssh import FILEXFER_TYPE_UNKNOWN, FILEXFER_TYPE_SOCKET
from asyncssh import FILEXFER_TYPE_CHAR_DEVICE, FILEXFER_TYPE_BLOCK_DEVICE
from asyncssh import FILEXFER_TYPE_FIFO
from asyncssh import FILEXFER_ATTR_BITS_READONLY, FILEXFER_ATTR_KNOWN_TEXT
from asyncssh import FX_OK, scp
from asyncssh.packet import SSHPacket, String, UInt32
from asyncssh.sftp import SAFE_SFTP_READ_LEN, SAFE_SFTP_WRITE_LEN
from asyncssh.sftp import LocalFile, SFTPHandler, SFTPLimits, SFTPServerHandler

from .server import ServerTestCase
from .util import asynctest


def _getpwuid_error(uid):
    """Simulate not being able to resolve user name"""

    # pylint: disable=unused-argument

    raise KeyError


def _getgrgid_error(gid):
    """Simulate not being able to resolve group name"""

    # pylint: disable=unused-argument

    raise KeyError


def tuple_to_nsec(sec, nsec):
    """Convert seconds and remainder to nanoseconds since epoch"""

    return sec * 1_000_000_000 + (nsec or 0)


def lookup_user(uid):
"""Return the user name associated with a uid""" try: # pylint: disable=import-outside-toplevel import pwd return pwd.getpwuid(uid).pw_name except ImportError: # pragma: no cover return '' def lookup_group(gid): """Return the group name associated with a gid""" try: # pylint: disable=import-outside-toplevel import grp return grp.getgrgid(gid).gr_name except ImportError: # pragma: no cover return '' def remove(files): """Remove files and directories""" for f in files.split(): try: if os.path.isdir(f) and not os.path.islink(f): shutil.rmtree(f) else: os.remove(f) except OSError: pass def sftp_test(func): """Decorator for running SFTP tests""" @asynctest @functools.wraps(func) async def sftp_wrapper(self): """Run a test after opening an SFTP client""" async with self.connect() as conn: async with conn.start_sftp_client() as sftp: await func(self, sftp) return sftp_wrapper def sftp_test_v4(func): """Decorator for running SFTPv4 tests""" @asynctest @functools.wraps(func) async def sftp_wrapper(self): """Run a test after opening an SFTP client""" async with self.connect() as conn: async with conn.start_sftp_client(sftp_version=4) as sftp: await func(self, sftp) return sftp_wrapper def sftp_test_v5(func): """Decorator for running SFTPv5 tests""" @asynctest @functools.wraps(func) async def sftp_wrapper(self): """Run a test after opening an SFTP client""" async with self.connect() as conn: async with conn.start_sftp_client(sftp_version=5) as sftp: await func(self, sftp) return sftp_wrapper def sftp_test_v6(func): """Decorator for running SFTPv6 tests""" @asynctest @functools.wraps(func) async def sftp_wrapper(self): """Run a test after opening an SFTP client""" async with self.connect() as conn: async with conn.start_sftp_client(sftp_version=6) as sftp: await func(self, sftp) return sftp_wrapper class _ResetFileHandleServerHandler(SFTPServerHandler): """Reset file handle counter on each request to test handle-in-use check""" async def recv_packet(self): """Reset next handle counter to test handle-in-use check""" self._next_handle = 0 return await super().recv_packet() class _IncompleteMessageServerHandler(SFTPServerHandler): """Close the SFTP session in the middle of sending a message""" async def run(self): """Close the session after sending an incomplete message""" await self.recv_packet() self._writer.write(UInt32(1)) self._writer.close() class _WriteCloseServerHandler(SFTPServerHandler): """Close the SFTP session in the middle of a write request""" async def _process_packet(self, pkttype, pktid, packet): """Close the session when a file close request is received""" if pkttype == FXP_WRITE: await self._cleanup(None) else: await super()._process_packet(pkttype, pktid, packet) class _ReorderReadServerHandler(SFTPServerHandler): """Reorder first two read requests""" _request = 'delay' async def _process_packet(self, pkttype, pktid, packet): """Close the session when a file close request is received""" if pkttype == FXP_READ: if self._request == 'delay': self._request = pkttype, pktid, packet elif self._request: await super()._process_packet(pkttype, pktid, packet) pkttype, pktid, packet = self._request await super()._process_packet(pkttype, pktid, packet) self._request = None else: await super()._process_packet(pkttype, pktid, packet) else: await super()._process_packet(pkttype, pktid, packet) class _CheckPropSFTPServer(SFTPServer): """Return an FTP server which checks channel properties""" def listdir(self, _path): """List the contents of a directory""" if self.channel.get_connection() == 
self.connection: # pragma: no branch return [SFTPName(k.encode()) for k in self.env.keys()] class _ChrootSFTPServer(SFTPServer): """Return an FTP server with a changed root""" def __init__(self, chan): os.mkdir('chroot') super().__init__(chan, 'chroot') def exit(self): """Clean up the changed root directory""" remove('chroot') def stat(self, path): """Get attributes of a file or directory, following symlinks""" return SFTPAttrs.from_local(super().stat(path)) class _OpenErrorSFTPServer(SFTPServer): """Return an error on file open""" async def open56(self, path, desired_access, flags, attrs): """Return an error when opening a file""" err = getattr(errno, path.decode('ascii')) raise OSError(err, os.strerror(err)) class _IOErrorSFTPServer(SFTPServer): """Return an I/O error during file writing""" async def read(self, file_obj, offset, size): """Return an error for reads past 4 MB in a file""" if offset >= 4*1024*1024: raise SFTPFailure('I/O error') else: return super().read(file_obj, offset, size) async def write(self, file_obj, offset, data): """Return an error for writes past 4 MB in a file""" if offset >= 4*1024*1024: raise SFTPFailure('I/O error') else: super().write(file_obj, offset, data) class _SmallBlockSizeSFTPServer(SFTPServer): """Limit reads to a small block size""" async def read(self, file_obj, offset, size): """Limit reads to return no more than 4 KB at a time""" return super().read(file_obj, offset, min(size, 4096)) class _TruncateSFTPServer(SFTPServer): """Truncate a file when it is accessed, simulating a simultaneous writer""" async def read(self, file_obj, offset, size): """Truncate a file to 32 KB when a read is done""" os.truncate('src', 32768) return super().read(file_obj, offset, size) class _NotImplSFTPServer(SFTPServer): """Return an error that a request is not implemented""" async def symlink(self, oldpath, newpath): """Return that symlinks aren't implemented""" raise NotImplementedError class _FileTypeSFTPServer(SFTPServer): """Return a list of files of each possible file type""" _file_types = ((FILEXFER_TYPE_REGULAR, stat.S_IFREG), (FILEXFER_TYPE_DIRECTORY, stat.S_IFDIR), (FILEXFER_TYPE_SYMLINK, stat.S_IFLNK), (FILEXFER_TYPE_SPECIAL, 0xf000), (FILEXFER_TYPE_UNKNOWN, 0), (FILEXFER_TYPE_SOCKET, stat.S_IFSOCK), (FILEXFER_TYPE_CHAR_DEVICE, stat.S_IFCHR), (FILEXFER_TYPE_BLOCK_DEVICE, stat.S_IFBLK), (FILEXFER_TYPE_FIFO, stat.S_IFIFO)) def listdir(self, _path): """List the contents of a directory""" return [SFTPName(str(filetype).encode('ascii'), attrs=SFTPAttrs(permissions=mode)) for filetype, mode in self._file_types] class _LongnameSFTPServer(SFTPServer): """Return a fixed set of files in response to a listdir request""" def listdir(self, _path): """List the contents of a directory""" # pylint: disable=no-self-use return list((b'.', b'..', SFTPName(b'.file'), SFTPName(b'file1'), SFTPName(b'file2', b'', SFTPAttrs(permissions=0, nlink=1, uid=0, gid=0, size=0, mtime=0)), SFTPName(b'file3', b'', SFTPAttrs(mtime=time.time())), SFTPName(b'file4', 56*b' ' + b'file4'))) def lstat(self, path): """Get attributes of a file, directory, or symlink""" return SFTPAttrs.from_local(super().lstat(path)) class _LargeDirSFTPServer(SFTPServer): """Return a really large listdir result""" async def listdir(self, path): """Return a really large listdir result""" # pylint: disable=unused-argument return 100000 * [SFTPName(b'a', '', SFTPAttrs())] class _StatVFSSFTPServer(SFTPServer): """Return a fixed set of attributes in response to a statvfs request""" expected_statvfs = SFTPVFSAttrs(1, 2, 3, 
4, 5, 6, 7, 8, 9, 10, 11) def statvfs(self, path): """Get attributes of the file system containing a file""" # pylint: disable=unused-argument return self.expected_statvfs def fstatvfs(self, file_obj): """Return attributes of the file system containing an open file""" # pylint: disable=unused-argument return self.expected_statvfs class _ChownSFTPServer(SFTPServer): """Simulate file ownership changes""" _ownership = {} def setstat(self, path, attrs): """Get attributes of a file or directory, following symlinks""" self._ownership[self.map_path(path)] = \ (attrs.uid, attrs.gid, attrs.owner, attrs.group) def stat(self, path): """Get attributes of a file or directory, following symlinks""" path = self.map_path(path) attrs = SFTPAttrs.from_local(os.stat(path)) if path in self._ownership: # pragma: no branch attrs.uid, attrs.gid, attrs.owner, attrs.group = \ self._ownership[path] return attrs class _SymlinkSFTPServer(SFTPServer): """Implement symlink with non-standard argument order""" def symlink(self, oldpath, newpath): """Create a symbolic link""" # pylint: disable=arguments-out-of-order return super().symlink(newpath, oldpath) class _SFTPAttrsSFTPServer(SFTPServer): """Implement stat which returns SFTPAttrs and raises SFTPError""" async def stat(self, path): """Get attributes of a file or directory, following symlinks""" try: return SFTPAttrs.from_local(super().stat(path)) except OSError as exc: if exc.errno == errno.EACCES: raise SFTPPermissionDenied(exc.strerror) from None else: raise SFTPError(99, exc.strerror) from None async def lstat(self, path): """Get attributes of a local file, directory, or symlink""" return SFTPAttrs.from_local(super().lstat(path)) async def fstat(self, file_obj): """Get attributes of an open file""" return SFTPAttrs.from_local(super().fstat(file_obj)) async def scandir(self, path): """Return names and attributes of the files in a local directory""" async for name in super().scandir(path): yield name class _AsyncSFTPServer(SFTPServer): """Implement all SFTP callbacks as async methods""" # pylint: disable=useless-super-delegation async def format_longname(self, name): """Format the long name associated with an SFTP name""" return super().format_longname(name) async def open(self, path, pflags, attrs): """Open a file to serve to a remote client""" return super().open(path, pflags, attrs) async def close(self, file_obj): """Close an open file or directory""" super().close(file_obj) async def read(self, file_obj, offset, size): """Read data from an open file""" return super().read(file_obj, offset, size) async def write(self, file_obj, offset, data): """Write data to an open file""" return super().write(file_obj, offset, data) async def lstat(self, path): """Get attributes of a file, directory, or symlink""" return super().lstat(path) async def fstat(self, file_obj): """Get attributes of an open file""" return super().fstat(file_obj) async def setstat(self, path, attrs): """Set attributes of a file or directory, following symlinks""" super().setstat(path, attrs) async def lsetstat(self, path, attrs): """Set attributes of a file, directory, or symlink""" super().lsetstat(path, attrs) async def fsetstat(self, file_obj, attrs): """Set attributes of an open file""" super().fsetstat(file_obj, attrs) def scandir(self, path): """Scan the contents of a directory""" return super().scandir(path) async def remove(self, path): """Remove a file or symbolic link""" super().remove(path) async def mkdir(self, path, attrs): """Create a directory with the specified attributes""" 
super().mkdir(path, attrs) async def rmdir(self, path): """Remove a directory""" super().rmdir(path) async def realpath(self, path): """Return the canonical version of a path""" return super().realpath(path) async def stat(self, path): """Get attributes of a file or directory, following symlinks""" return super().stat(path) async def rename(self, oldpath, newpath): """Rename a file, directory, or link""" super().rename(oldpath, newpath) async def readlink(self, path): """Return the target of a symbolic link""" return super().readlink(path) async def symlink(self, oldpath, newpath): """Create a symbolic link""" super().symlink(oldpath, newpath) async def posix_rename(self, oldpath, newpath): """Rename a file, directory, or link with POSIX semantics""" super().posix_rename(oldpath, newpath) async def statvfs(self, path): """Get attributes of the file system containing a file""" return super().statvfs(path) async def fstatvfs(self, file_obj): """Return attributes of the file system containing an open file""" return super().fstatvfs(file_obj) async def link(self, oldpath, newpath): """Create a hard link""" super().link(oldpath, newpath) async def lock(self, file_obj, offset, length, flags): """Acquire a byte range lock on an open file""" super().lock(file_obj, offset, length, flags) async def unlock(self, file_obj, offset, length): """Release a byte range lock on an open file""" super().unlock(file_obj, offset, length) async def fsync(self, file_obj): """Force file data to be written to disk""" super().fsync(file_obj) class _CheckSFTP(ServerTestCase): """Utility functions for AsyncSSH SFTP unit tests""" @classmethod def setUpClass(cls): """Check if symlink is available on this platform""" super().setUpClass() try: os.symlink('file', 'link') os.remove('link') cls._symlink_supported = True except OSError: # pragma: no cover cls._symlink_supported = False def _create_file(self, name, data=(), mode=None, utime=None): """Create a test file""" if data == (): data = str(id(self)) binary = 'b' if isinstance(data, bytes) else '' with open(name, 'w' + binary) as f: f.write(data) if mode is not None: os.chmod(name, mode) if utime is not None: os.utime(name, utime) def _check_attr(self, name1, name2, follow_symlinks, check_atime): """Check if attributes on two files are equal""" statfunc = os.stat if follow_symlinks else os.lstat attrs1 = statfunc(name1) attrs2 = statfunc(name2) self.assertEqual(stat.S_IMODE(attrs1.st_mode), stat.S_IMODE(attrs2.st_mode)) self.assertEqual(int(attrs1.st_mtime), int(attrs2.st_mtime)) if check_atime: self.assertEqual(int(attrs1.st_atime), int(attrs2.st_atime)) def _check_file(self, name1, name2, preserve=False, follow_symlinks=False, check_atime=True): """Check if two files are equal""" if preserve: self._check_attr(name1, name2, follow_symlinks, check_atime) with open(name1, 'rb') as file1: with open(name2, 'rb') as file2: self.assertEqual(file1.read(), file2.read()) def _check_stat(self, sftp_stat, local_stat): """Check if file attributes are equal""" self.assertEqual(sftp_stat.size, local_stat.st_size) self.assertEqual(sftp_stat.uid, local_stat.st_uid) self.assertEqual(sftp_stat.gid, local_stat.st_gid) self.assertEqual(sftp_stat.permissions, local_stat.st_mode) self.assertEqual(sftp_stat.atime, int(local_stat.st_atime)) self.assertEqual(sftp_stat.mtime, int(local_stat.st_mtime)) def _check_stat_v4(self, sftp_stat, local_stat): """Check if file attributes are equal""" self.assertEqual(sftp_stat.size, local_stat.st_size) if sys.platform != 'win32': # pragma: no branch 
self.assertEqual(sftp_stat.owner, lookup_user(local_stat.st_uid)) self.assertEqual(sftp_stat.group, lookup_group(local_stat.st_gid)) self.assertEqual(sftp_stat.permissions, stat.S_IMODE(local_stat.st_mode)) self.assertEqual(tuple_to_nsec(sftp_stat.atime, sftp_stat.atime_ns), local_stat.st_atime_ns) self.assertEqual(tuple_to_nsec(sftp_stat.mtime, sftp_stat.mtime_ns), local_stat.st_mtime_ns) def _check_link(self, link, target): """Check if a symlink points to the right target""" link = os.readlink(link) if link.startswith('\\\\?\\'): # pragma: no cover link = link[4:] self.assertEqual(Path(link).resolve(), Path(target).resolve()) class _TestSFTP(_CheckSFTP): """Unit tests for AsyncSSH SFTP client and server""" @classmethod async def start_server(cls): """Start an SFTP server for the tests to use""" return await cls.create_server(sftp_factory=True, sftp_version=6) @sftp_test async def _dummy_sftp_client(self, sftp): """Test starting a new SFTPv3 client session and immediately exiting""" self.assertEqual(sftp.version, 3) @sftp_test_v5 async def _dummy_sftp_client_v5(self, sftp): """Test starting a new SFTPv5 client session and immediately exiting""" self.assertEqual(sftp.version, 5) @sftp_test_v6 async def _dummy_sftp_client_v6(self, sftp): """Test starting a new SFTPv6 client session and immediately exiting""" self.assertEqual(sftp.version, 6) @sftp_test async def test_copy(self, sftp): """Test copying a file over SFTP""" for method in ('get', 'put', 'copy'): for src in ('src', b'src', Path('src')): with self.subTest(method=method, src=type(src)): try: self._create_file('src') await getattr(sftp, method)(src, 'dst') self._check_file('src', 'dst') finally: remove('src dst') @sftp_test async def test_copy_max_requests(self, sftp): """Test copying a file over SFTP with max requests set""" for method in ('get', 'put', 'copy'): for src in ('src', b'src', Path('src')): with self.subTest(method=method, src=type(src)): try: self._create_file('src', 16*1024*1024*'\0') await getattr(sftp, method)(src, 'dst', max_requests=4) self._check_file('src', 'dst') finally: remove('src dst') def test_copy_non_remote(self): """Test copying without using remote_copy function""" @sftp_test async def _test_copy_non_remote(self, sftp): """Test copying without using remote_copy function""" for method in ('copy', 'mcopy'): with self.subTest(method=method): try: self._create_file('src') await sftp.copy('src', 'dst') self._check_file('src', 'dst') finally: remove('src dst') with patch('asyncssh.sftp.SFTPServerHandler._extensions', []): # pylint: disable=no-value-for-parameter _test_copy_non_remote(self) def test_copy_remote_only(self): """Test copying while allowing only remote copy""" @sftp_test async def _test_copy_remote_only(self, sftp): """Test copying with only remote copy allowed""" for method in ('copy', 'mcopy'): with self.subTest(method=method): try: self._create_file('src') with self.assertRaises(SFTPOpUnsupported): await getattr(sftp, method)('src', 'dst', remote_only=True) finally: remove('src') with patch('asyncssh.sftp.SFTPServerHandler._extensions', []): # pylint: disable=no-value-for-parameter _test_copy_remote_only(self) @sftp_test async def test_copy_progress(self, sftp): """Test copying a file over SFTP with progress reporting""" def _report_progress(_srcpath, _dstpath, bytes_copied, _total_bytes): """Monitor progress of copy""" reports.append(bytes_copied) for method in ('get', 'put', 'copy'): for size in (0, 100000): with self.subTest(method=method, size=size): reports = [] try: 
self._create_file('src', size * 'a') await getattr(sftp, method)( 'src', 'dst', block_size=8192, progress_handler=_report_progress) self._check_file('src', 'dst') if method != 'copy': self.assertEqual(len(reports), (size // 8192) + 1) self.assertEqual(reports[-1], size) finally: remove('src dst') @sftp_test async def test_copy_preserve(self, sftp): """Test copying a file with preserved attributes over SFTP""" for method in ('get', 'put', 'copy'): with self.subTest(method=method): try: self._create_file('src', mode=0o666, utime=(1, 2)) await getattr(sftp, method)('src', 'dst', preserve=True) self._check_file('src', 'dst', preserve=True) finally: remove('src dst') @unittest.skipIf(sys.platform == 'win32', 'skip lsetstat tests on Windows') @sftp_test async def test_copy_preserve_link(self, sftp): """Test copying a symlink with preserved attributes over SFTP""" for method in ('get', 'put', 'copy'): with self.subTest(method=method): try: os.symlink('file', 'link1') os.utime('link1', times=(1, 2), follow_symlinks=False) await getattr(sftp, method)( 'link1', 'link2', preserve=True, follow_symlinks=False) self.assertEqual(os.lstat('link2').st_mtime, 2) finally: remove('link1 link2') @unittest.skipIf(sys.platform == 'win32', 'skip lsetstat tests on Windows') def test_copy_preserve_link_unsupported(self): """Test preserving symlink attributes over SFTP without lsetstat""" @sftp_test async def _lsetstat_unsupported(self, sftp): """Try copying link attributes without lsetstat""" try: os.symlink('file', 'link1') os.utime('link1', times=(1, 2), follow_symlinks=False) await sftp.put('link1', 'link2', preserve=True, follow_symlinks=False) self.assertNotEqual(int(os.lstat('link2').st_mtime), 2) finally: remove('link1 link2') with patch('asyncssh.sftp.SFTPServerHandler._extensions', []): # pylint: disable=no-value-for-parameter _lsetstat_unsupported(self) @sftp_test async def test_copy_recurse(self, sftp): """Test recursively copying a directory over SFTP""" for method in ('get', 'put', 'copy'): with self.subTest(method=method): try: os.mkdir('src') self._create_file('src/file1') if self._symlink_supported: # pragma: no branch os.symlink('file1', 'src/file2') await getattr(sftp, method)('src', 'dst', recurse=True) self._check_file('src/file1', 'dst/file1') if self._symlink_supported: # pragma: no branch self._check_link('dst/file2', 'file1') finally: remove('src dst') @sftp_test async def test_copy_recurse_existing(self, sftp): """Test recursively copying over SFTP where target dir exists""" for method in ('get', 'put', 'copy'): with self.subTest(method=method): try: os.mkdir('src') os.mkdir('dst') os.mkdir('dst/src') self._create_file('src/file1') if self._symlink_supported: # pragma: no branch os.symlink('file1', 'src/file2') await getattr(sftp, method)('src', 'dst', recurse=True) self._check_file('src/file1', 'dst/src/file1') if self._symlink_supported: # pragma: no branch self._check_link('dst/src/file2', 'file1') finally: remove('src dst') @sftp_test async def test_copy_follow_symlinks(self, sftp): """Test copying a file over SFTP while following symlinks""" if not self._symlink_supported: # pragma: no cover raise unittest.SkipTest('symlink not available') for method in ('get', 'put', 'copy'): with self.subTest(method=method): try: self._create_file('src') os.symlink('src', 'link') await getattr(sftp, method)('link', 'dst', follow_symlinks=True) self._check_file('src', 'dst') finally: remove('src dst link') @sftp_test async def test_copy_recurse_follow_symlinks(self, sftp): """Test recursively copying 
over SFTP while following symlinks""" if not self._symlink_supported: # pragma: no cover raise unittest.SkipTest('symlink not available') for method in ('get', 'put', 'copy'): with self.subTest(method=method): try: os.mkdir('src') self._create_file('src/file1') os.symlink('file1', 'src/file2') await getattr(sftp, method)('src', 'dst', recurse=True, follow_symlinks=True) self._check_file('src/file1', 'dst/file2') finally: remove('src dst') @sftp_test async def test_copy_invalid_name(self, sftp): """Test copying a file with an invalid name over SFTP""" for method in ('get', 'put', 'copy', 'mget', 'mput', 'mcopy'): with self.subTest(method=method): with self.assertRaises((OSError, SFTPNoSuchFile, SFTPFailure, UnicodeDecodeError)): await getattr(sftp, method)(b'\xff') @sftp_test async def test_copy_directory_no_recurse(self, sftp): """Test copying a directory over SFTP without recurse option""" for method in ('get', 'put', 'copy', 'mget', 'mput', 'mcopy'): with self.subTest(method=method): try: os.mkdir('dir') with self.assertRaises(SFTPFailure): await getattr(sftp, method)('dir') finally: remove('dir') @sftp_test_v6 async def test_copy_directory_no_recurse_v6(self, sftp): """Test copying a directory over SFTPv6 without recurse option""" for method in ('get', 'put', 'copy', 'mget', 'mput', 'mcopy'): with self.subTest(method=method): try: os.mkdir('dir') with self.assertRaises(SFTPFileIsADirectory): await getattr(sftp, method)('dir') finally: remove('dir') @sftp_test async def test_multiple_copy(self, sftp): """Test copying multiple files over SFTP""" for method in ('get', 'put', 'copy'): for seq in (list, tuple): with self.subTest(method=method): try: self._create_file('src1', 'xxx') self._create_file('src2', 'yyy') os.mkdir('dst') await getattr(sftp, method)(seq(('src1', 'src2')), 'dst') self._check_file('src1', 'dst/src1') self._check_file('src2', 'dst/src2') finally: remove('src1 src2 dst') @sftp_test_v4 async def test_multiple_copy_v4(self, sftp): """Test copying multiple files over SFTPv4""" for method in ('get', 'put', 'copy'): for seq in (list, tuple): with self.subTest(method=method): try: self._create_file('src1', 'xxx') self._create_file('src2', 'yyy') os.mkdir('dst') await getattr(sftp, method)(seq(('src1', 'src2')), 'dst') self._check_file('src1', 'dst/src1') self._check_file('src2', 'dst/src2') finally: remove('src1 src2 dst') @sftp_test_v5 async def test_multiple_copy_v5(self, sftp): """Test copying multiple files over SFTPv5""" for method in ('get', 'put', 'copy'): for seq in (list, tuple): with self.subTest(method=method): try: self._create_file('src1', 'xxx') self._create_file('src2', 'yyy') os.mkdir('dst') await getattr(sftp, method)(seq(('src1', 'src2')), 'dst') self._check_file('src1', 'dst/src1') self._check_file('src2', 'dst/src2') finally: remove('src1 src2 dst') @sftp_test_v6 async def test_multiple_copy_v6(self, sftp): """Test copying multiple files over SFTPv6""" for method in ('get', 'put', 'copy'): for seq in (list, tuple): with self.subTest(method=method): try: self._create_file('src1', 'xxx') self._create_file('src2', 'yyy') os.mkdir('dst') await getattr(sftp, method)(seq(('src1', 'src2')), 'dst') self._check_file('src1', 'dst/src1') self._check_file('src2', 'dst/src2') finally: remove('src1 src2 dst') @sftp_test async def test_multiple_copy_glob(self, sftp): """Test copying multiple files via glob over SFTP""" for method in ('mget', 'mput', 'mcopy'): with self.subTest(method=method): try: self._create_file('src1', 'xxx') self._create_file('src2', 'yyy') 
os.mkdir('dst') await getattr(sftp, method)(['', 'src*'], 'dst') self._check_file('src1', 'dst/src1') self._check_file('src2', 'dst/src2') finally: remove('src1 src2 dst') @sftp_test async def test_multiple_copy_bytes_path(self, sftp): """Test copying multiple files with byte string paths over SFTP""" for method in ('mget', 'mput', 'mcopy'): with self.subTest(method=method): try: self._create_file('src1', 'xxx') self._create_file('src2', 'yyy') os.mkdir('dst') await getattr(sftp, method)(b'src*', b'dst') self._check_file('src1', 'dst/src1') self._check_file('src2', 'dst/src2') finally: remove('src1 src2 dst') @sftp_test async def test_multiple_copy_pathlib_path(self, sftp): """Test copying multiple files with pathlib paths over SFTP""" for method in ('mget', 'mput', 'mcopy'): with self.subTest(method=method): try: self._create_file('src1', 'xxx') self._create_file('src2', 'yyy') os.mkdir('dst') await getattr(sftp, method)(Path('src*'), Path('dst')) self._check_file('src1', 'dst/src1') self._check_file('src2', 'dst/src2') finally: remove('src1 src2 dst') @sftp_test async def test_multiple_copy_target_not_dir(self, sftp): """Test copying multiple files over SFTP with non-directory target""" for method in ('mget', 'mput', 'mcopy'): with self.subTest(method=method): try: self._create_file('src1') self._create_file('src2') with self.assertRaises(SFTPFailure): await getattr(sftp, method)('src*', 'dst') finally: remove('src') @sftp_test_v6 async def test_multiple_copy_target_not_dir_v6(self, sftp): """Test copying multiple files over SFTP with non-directory target""" for method in ('mget', 'mput', 'mcopy'): with self.subTest(method=method): try: self._create_file('src1') self._create_file('src2') with self.assertRaises(SFTPNotADirectory): await getattr(sftp, method)('src*', 'dst') finally: remove('src') @sftp_test async def test_multiple_copy_error_handler(self, sftp): """Test copying multiple files over SFTP with error handler""" def err_handler(exc): """Catch error for non-recursive copy of directory""" self.assertEqual(exc.reason, 'src2 is a directory') for method in ('mget', 'mput', 'mcopy'): with self.subTest(method=method): try: self._create_file('src1') os.mkdir('src2') os.mkdir('dst') await getattr(sftp, method)('src*', 'dst', error_handler=err_handler) self._check_file('src1', 'dst/src1') finally: remove('src1 src2 dst') def test_remote_copy_unsupported(self): """Test remote copy on a server which doesn't support it""" @sftp_test async def _test_remote_copy_unsupported(self, sftp): """Test remote copy not being supported""" try: self._create_file('src') with self.assertRaises(SFTPOpUnsupported): await sftp.remote_copy('src', 'dst') finally: remove('src') with patch('asyncssh.sftp.SFTPServerHandler._extensions', []): # pylint: disable=no-value-for-parameter _test_remote_copy_unsupported(self) @sftp_test async def test_remote_copy_arguments(self, sftp): """Test remote copy arguments""" try: self._create_file('src', os.urandom(2*1024*1024)) async with sftp.open('src', 'rb') as src: async with sftp.open('dst', 'wb') as dst: await sftp.remote_copy(src, dst, 0, 1024*1024, 0) await sftp.remote_copy(src, dst, 1024*1024, 0, 1024*1024) self._check_file('src', 'dst') finally: remove('src dst') @sftp_test async def test_remote_copy_closed_file(self, sftp): """Test remote copy of a closed file""" try: self._create_file('file') async with sftp.open('file', 'rb') as f: await f.close() with self.assertRaises(ValueError): await sftp.remote_copy(f, f) finally: remove('file') @sftp_test async def 
test_glob(self, sftp): """Test a glob pattern match over SFTP""" glob_tests = ( ('file*', ['file1', 'filedir']), ('./file*', ['./file1', './filedir']), (b'file*', [b'file1', b'filedir']), (['file*'], ['file1', 'filedir']), (['', 'file*'], ['file1', 'filedir']), (['file*/*2'], ['filedir/file2', 'filedir/filedir2']), (['file*/*[3-9]'], ['filedir/file3']), (['**/file[12]'], ['file1', 'filedir/file2']), (['**/file*/'], ['filedir/', 'filedir/filedir2/']), (['filedir/**'], ['filedir', 'filedir/file2', 'filedir/file3', 'filedir/filedir2', 'filedir/filedir2/file4', 'filedir/filedir2/file5']), ('filedir/file2', ['filedir/file2']), ('./filedir/file2', ['./filedir/file2']), ('filedir/file*', ['filedir/file2', 'filedir/file3', 'filedir/filedir2']), ('./filedir/file*', ['./filedir/file2', './filedir/file3', './filedir/filedir2']), ('./filedir/filedir2/file*', ['./filedir/filedir2/file4', './filedir/filedir2/file5']), ('filedir/filedir2/file*', ['filedir/filedir2/file4', 'filedir/filedir2/file5']), ('./filedir/*/file4', ['./filedir/filedir2/file4']), ('filedir/*/file4', ['filedir/filedir2/file4']), ('./*/filedir2/file4', ['./filedir/filedir2/file4']), ('*/filedir2/file4', ['filedir/filedir2/file4']), ('*/filedir2/file*4', ['filedir/filedir2/file4']), ('./filedir/filedir*/file*', ['./filedir/filedir2/file4', './filedir/filedir2/file5']), ('filedir/filedir*/file*', ['filedir/filedir2/file4', 'filedir/filedir2/file5']), ('./**/filedir2/file4', ['./filedir/filedir2/file4']), ('**/filedir2/file4', ['filedir/filedir2/file4']), (['file1', '**/file1'], ['file1'])) try: os.mkdir('filedir') self._create_file('file1') self._create_file('filedir/file2') self._create_file('filedir/file3') os.mkdir('filedir/filedir2') self._create_file('filedir/filedir2/file4') self._create_file('filedir/filedir2/file5') for pattern, matches in glob_tests: with self.subTest(pattern=pattern): self.assertEqual(sorted(await sftp.glob(pattern)), matches) self.assertEqual((await sftp.glob([b'fil*1', 'fil*dir'])), [b'file1', 'filedir']) finally: remove('file1 filedir') @sftp_test async def test_glob_errors(self, sftp): """Test glob pattern match errors over SFTP""" _glob_errors = ( 'file*', 'dir/file1/*', 'dir*/file1/*', 'dir/dir1/*') try: os.mkdir('dir') self._create_file('dir/file1') os.mkdir('dir/dir1') os.chmod('dir/dir1', 0) for pattern in _glob_errors: with self.subTest(pattern=pattern): with self.assertRaises(SFTPNoSuchFile): await sftp.glob(pattern) finally: os.chmod('dir/dir1', 0o700) remove('dir') @sftp_test_v4 async def test_glob_error_v4(self, sftp): """Test a glob pattern match error over SFTP""" with self.assertRaises(SFTPNoSuchPath): await sftp.glob('file*') @sftp_test async def test_glob_error_handler(self, sftp): """Test a glob pattern match with error handler over SFTP""" def err_handler(exc): """Catch error for nonexistent file1""" self.assertEqual(exc.reason, 'No matches found') try: self._create_file('file2') self.assertEqual((await sftp.glob(['file1*', 'file2*'], error_handler=err_handler)), ['file2']) finally: remove('file2') @sftp_test async def test_stat(self, sftp): """Test getting attributes on a file""" try: os.mkdir('dir') self._create_file('file') if self._symlink_supported: # pragma: no branch os.symlink('bad', 'badlink') os.symlink('dir', 'dirlink') os.symlink('file', 'filelink') self._check_stat((await sftp.stat('dir')), os.stat('dir')) self._check_stat((await sftp.stat('file')), os.stat('file')) if self._symlink_supported: # pragma: no branch self._check_stat((await sftp.stat('dirlink')), os.stat('dir')) 
self._check_stat((await sftp.stat('filelink')), os.stat('file')) with self.assertRaises(SFTPNoSuchFile): await sftp.stat('badlink') self.assertTrue(await sftp.isdir('dir')) self.assertFalse(await sftp.isdir('file')) if self._symlink_supported: # pragma: no branch self.assertFalse(await sftp.isdir('badlink')) self.assertTrue(await sftp.isdir('dirlink')) self.assertFalse(await sftp.isdir('filelink')) self.assertFalse(await sftp.isfile('dir')) self.assertTrue(await sftp.isfile('file')) if self._symlink_supported: # pragma: no branch self.assertFalse(await sftp.isfile('badlink')) self.assertFalse(await sftp.isfile('dirlink')) self.assertTrue(await sftp.isfile('filelink')) self.assertFalse(await sftp.islink('dir')) self.assertFalse(await sftp.islink('file')) if self._symlink_supported: # pragma: no branch self.assertTrue(await sftp.islink('badlink')) self.assertTrue(await sftp.islink('dirlink')) self.assertTrue(await sftp.islink('filelink')) finally: remove('dir file badlink dirlink filelink') @sftp_test async def test_lstat(self, sftp): """Test getting attributes on a link""" if not self._symlink_supported: # pragma: no cover raise unittest.SkipTest('symlink not available') try: os.symlink('file', 'link') self._check_stat((await sftp.lstat('link')), os.lstat('link')) finally: remove('link') @sftp_test_v4 async def test_lstat_v4(self, sftp): """Test getting attributes on a link with SFTPv4""" if not self._symlink_supported: # pragma: no cover raise unittest.SkipTest('symlink not available') try: os.symlink('file', 'link') self._check_stat_v4((await sftp.lstat('link')), os.lstat('link')) finally: remove('link') @sftp_test_v6 async def test_lstat_v6(self, sftp): """Test getting attributes on a link with SFTPv6""" if not self._symlink_supported: # pragma: no cover raise unittest.SkipTest('symlink not available') try: os.symlink('file', 'link') self._check_stat_v4((await sftp.lstat('link')), os.lstat('link')) finally: remove('link') @sftp_test async def test_lstat_via_stat(self, sftp): """Test getting attributes on a link by disabling follow_symlinks""" if not self._symlink_supported: # pragma: no cover raise unittest.SkipTest('symlink not available') try: os.symlink('file', 'link') self._check_stat((await sftp.stat('link', follow_symlinks=False)), os.lstat('link')) finally: remove('link') @sftp_test async def test_setstat(self, sftp): """Test setting attributes on a file""" try: self._create_file('file') await sftp.setstat('file', SFTPAttrs(permissions=0o666)) self.assertEqual(stat.S_IMODE(os.stat('file').st_mode), 0o666) with self.assertRaises(ValueError): await sftp.setstat('file', SFTPAttrs(owner='root', group='wheel')) finally: remove('file') @sftp_test_v4 async def test_setstat_v4(self, sftp): """Test setting attributes on a file""" try: self._create_file('file') await sftp.setstat('file', SFTPAttrs(atime=1)) stat_result = os.stat('file') self.assertEqual(stat_result.st_atime, 1) await sftp.setstat('file', SFTPAttrs(mtime=2)) stat_result = os.stat('file') self.assertEqual(stat_result.st_mtime, 2) finally: remove('file') @unittest.skipIf(sys.platform == 'win32', 'skip uid/gid tests on Windows') @sftp_test_v6 async def test_setstat_invalid_owner_group_v6(self, sftp): """Test setting invalid owner/group on a file""" try: self._create_file('file') with patch('pwd.getpwuid', _getpwuid_error): with self.assertRaises(SFTPOwnerInvalid): await sftp.setstat('file', SFTPAttrs(owner='xxx', group='0')) with patch('grp.getgrgid', _getgrgid_error): with self.assertRaises(SFTPGroupInvalid): await 
sftp.setstat('file', SFTPAttrs(owner='0', group='yyy')) finally: remove('file') @unittest.skipIf(sys.platform == 'win32', 'skip lsetstat tests on Windows') @sftp_test async def test_lsetstat(self, sftp): """Test setting attributes on a link""" try: os.symlink('file', 'link') await sftp.setstat('link', SFTPAttrs(atime=1, mtime=2), follow_symlinks=False) stat_result = os.lstat('link') self.assertEqual(stat_result.st_atime, 1) self.assertEqual(stat_result.st_mtime, 2) finally: remove('link') @unittest.skipIf(sys.platform == 'win32', 'skip lsetstat tests on Windows') @sftp_test_v4 async def test_lsetstat_v4(self, sftp): """Test setting attributes on a link""" try: os.symlink('file', 'link') await sftp.setstat('link', SFTPAttrs(atime=1), follow_symlinks=False) self.assertEqual(os.lstat('link').st_atime, 1) await sftp.setstat('link', SFTPAttrs(mtime=2), follow_symlinks=False) self.assertEqual(os.lstat('link').st_mtime, 2) finally: remove('link') @unittest.skipIf(sys.platform == 'win32', 'skip lsetstat tests on Windows') @sftp_test_v6 async def test_lsetstat_v6(self, sftp): """Test setting attributes on a link""" try: os.symlink('file', 'link') await sftp.setstat('link', SFTPAttrs(atime=1), follow_symlinks=False) self.assertEqual(os.lstat('link').st_atime, 1) await sftp.setstat('link', SFTPAttrs(mtime=2), follow_symlinks=False) self.assertEqual(os.lstat('link').st_mtime, 2) finally: remove('link') @unittest.skipIf(sys.platform == 'win32', 'skip statvfs tests on Windows') @sftp_test async def test_statvfs(self, sftp): """Test getting attributes on a filesystem We can't compare the values returned by a live statvfs call since they can change at any time. See the separate _TestSFTStatPVFS class for a more complete test, but this is left in for code coverage purposes. """ self.assertIsInstance((await sftp.statvfs('.')), SFTPVFSAttrs) @sftp_test async def test_truncate(self, sftp): """Test truncating a file""" try: self._create_file('file', '01234567890123456789') await sftp.truncate('file', 10) self.assertEqual((await sftp.getsize('file')), 10) with open('file') as localf: self.assertEqual(localf.read(), '0123456789') finally: remove('file') @unittest.skipIf(sys.platform == 'win32', 'skip chown tests on Windows') @sftp_test async def test_chown(self, sftp): """Test changing ownership of a file We can't change to a different user/group here if we're not root, so just change to the same user/group. See the separate _TestSFTPChown class for a more complete test, but this is left in for code coverage purposes. """ try: self._create_file('file') stat_result = os.stat('file') await sftp.chown('file', stat_result.st_uid, stat_result.st_gid) new_stat_result = os.stat('file') self.assertEqual(new_stat_result.st_uid, stat_result.st_uid) self.assertEqual(new_stat_result.st_gid, stat_result.st_gid) finally: remove('file') @unittest.skipIf(sys.platform == 'win32', 'skip chown tests on Windows') @sftp_test_v4 async def test_chown_v4(self, sftp): """Test changing ownership of a file We can't change to a different user/group here if we're not root, so just change to the same user/group. See the separate _TestSFTPChown class for a more complete test, but this is left in for code coverage purposes. 
""" try: self._create_file('file') stat_result = os.stat('file') owner = lookup_user(stat_result.st_uid) group = lookup_group(stat_result.st_gid) await sftp.chown('file', owner, group) new_stat_result = os.stat('file') self.assertEqual(new_stat_result.st_uid, stat_result.st_uid) self.assertEqual(new_stat_result.st_gid, stat_result.st_gid) await sftp.chown('file', str(stat_result.st_uid), group) new_stat_result = os.stat('file') self.assertEqual(new_stat_result.st_uid, stat_result.st_uid) self.assertEqual(new_stat_result.st_gid, stat_result.st_gid) await sftp.chown('file', owner, str(stat_result.st_gid)) new_stat_result = os.stat('file') self.assertEqual(new_stat_result.st_uid, stat_result.st_uid) self.assertEqual(new_stat_result.st_gid, stat_result.st_gid) await sftp.chown('file', str(stat_result.st_uid), group) new_stat_result = os.stat('file') self.assertEqual(new_stat_result.st_uid, stat_result.st_uid) self.assertEqual(new_stat_result.st_gid, stat_result.st_gid) await sftp.chown('file', owner, str(stat_result.st_gid)) new_stat_result = os.stat('file') self.assertEqual(new_stat_result.st_uid, stat_result.st_uid) self.assertEqual(new_stat_result.st_gid, stat_result.st_gid) finally: remove('file') @unittest.skipIf(sys.platform == 'win32', 'skip chmod tests on Windows') @sftp_test async def test_chmod(self, sftp): """Test changing permissions on a file""" try: self._create_file('file') await sftp.chmod('file', 0o4321) self.assertEqual(stat.S_IMODE(os.stat('file').st_mode), 0o4321) finally: remove('file') @sftp_test async def test_utime(self, sftp): """Test changing access and modify times on a file""" try: self._create_file('file') await sftp.utime('file') await sftp.utime('file', (1, 2)) stat_result = os.stat('file') self.assertEqual(stat_result.st_atime, 1) self.assertEqual(stat_result.st_mtime, 2) self.assertEqual((await sftp.getatime('file')), 1) self.assertEqual((await sftp.getmtime('file')), 2) finally: remove('file') @sftp_test_v4 async def test_utime_v4(self, sftp): """Test changing access and modify times on a file with SFTPv4""" try: self._create_file('file') await sftp.utime('file') await sftp.utime('file', (1.0, 2.25)) stat_result = os.stat('file') self.assertEqual(stat_result.st_atime, 1.0) self.assertEqual(stat_result.st_atime_ns, 1000000000) self.assertEqual(stat_result.st_mtime, 2.25) self.assertEqual(stat_result.st_mtime_ns, 2250000000) self.assertEqual((await sftp.getatime('file')), 1.0) self.assertEqual((await sftp.getatime_ns('file')), 1000000000) self.assertIsNotNone(await sftp.getcrtime('file')) self.assertIsNotNone(await sftp.getcrtime_ns('file')) self.assertEqual((await sftp.getmtime('file')), 2.25) self.assertEqual((await sftp.getmtime_ns('file')), 2250000000) await sftp.utime('file', ns=(3500000000, 4750000000)) stat_result = os.stat('file') self.assertEqual(stat_result.st_atime, 3.5) self.assertEqual(stat_result.st_atime_ns, 3500000000) self.assertEqual(stat_result.st_mtime, 4.75) self.assertEqual(stat_result.st_mtime_ns, 4750000000) self.assertEqual((await sftp.getatime('file')), 3.5) self.assertEqual((await sftp.getatime_ns('file')), 3500000000) self.assertIsNotNone(await sftp.getcrtime('file')) self.assertIsNotNone(await sftp.getcrtime_ns('file')) self.assertEqual((await sftp.getmtime('file')), 4.75) self.assertEqual((await sftp.getmtime_ns('file')), 4750000000) finally: remove('file') @sftp_test async def test_exists(self, sftp): """Test checking whether a file exists""" try: self._create_file('file1') self.assertTrue(await sftp.exists('file1')) 
self.assertFalse(await sftp.exists('file2')) finally: remove('file1') @sftp_test async def test_lexists(self, sftp): """Test checking whether a link exists""" if not self._symlink_supported: # pragma: no cover raise unittest.SkipTest('symlink not available') try: os.symlink('file', 'link1') self.assertTrue(await sftp.lexists('link1')) self.assertFalse(await sftp.lexists('link2')) finally: remove('link1') @sftp_test async def test_remove(self, sftp): """Test removing a file""" try: self._create_file('file') await sftp.remove('file') with self.assertRaises(FileNotFoundError): os.stat('file') with self.assertRaises(SFTPNoSuchFile): await sftp.remove('file') finally: remove('file') @sftp_test async def test_unlink(self, sftp): """Test unlinking a file""" try: self._create_file('file') await sftp.unlink('file') with self.assertRaises(FileNotFoundError): os.stat('file') with self.assertRaises(SFTPNoSuchFile): await sftp.unlink('file') finally: remove('file') @sftp_test async def test_rename(self, sftp): """Test renaming a file""" try: self._create_file('file1', 'xxx') self._create_file('file2', 'yyy') with self.assertRaises(SFTPFailure): await sftp.rename('file1', 'file2') await sftp.rename('file1', 'file3') with open('file3') as localf: self.assertEqual(localf.read(), 'xxx') await sftp.rename('file2', 'file3', FXR_OVERWRITE) with open('file3') as localf: self.assertEqual(localf.read(), 'yyy') finally: remove('file1 file2 file3') @sftp_test_v6 async def test_rename_v6(self, sftp): """Test renaming a file with SFTPv6""" try: self._create_file('file1', 'xxx') self._create_file('file2', 'yyy') with self.assertRaises(SFTPFileAlreadyExists): await sftp.rename('file1', 'file2') await sftp.rename('file1', 'file3') with open('file3') as localf: self.assertEqual(localf.read(), 'xxx') await sftp.rename('file2', 'file3', FXR_OVERWRITE) with open('file3') as localf: self.assertEqual(localf.read(), 'yyy') finally: remove('file1 file2 file3') @sftp_test async def test_posix_rename(self, sftp): """Test renaming a file that replaces a target file""" try: self._create_file('file1', 'xxx') self._create_file('file2', 'yyy') await sftp.posix_rename('file1', 'file2') with open('file2') as localf: self.assertEqual(localf.read(), 'xxx') finally: remove('file1 file2') @sftp_test_v6 async def test_posix_rename_v6(self, sftp): """Test renaming a file that replaces a target file""" try: self._create_file('file1', 'xxx') self._create_file('file2', 'yyy') await sftp.posix_rename('file1', 'file2') with open('file2') as localf: self.assertEqual(localf.read(), 'xxx') finally: remove('file1 file2') @sftp_test async def test_listdir(self, sftp): """Test listing files in a directory""" try: os.mkdir('dir') self._create_file('dir/file1') self._create_file('dir/file2') self.assertEqual(sorted(await sftp.listdir('dir')), ['.', '..', 'file1', 'file2']) finally: remove('dir') @sftp_test_v4 async def test_listdir_v4(self, sftp): """Test listing files in a directory with SFTPv4""" try: os.mkdir('dir') self._create_file('dir/file1') self._create_file('dir/file2') self.assertEqual(sorted(await sftp.listdir('dir')), ['.', '..', 'file1', 'file2']) finally: remove('dir') @sftp_test_v4 async def test_listdir_error_v4(self, sftp): """Test error while listing contents of a directory""" orig_readdir = asyncssh.sftp.SFTPClientHandler.readdir async def _readdir_error(self, handle): """Return an error on an SFTP readdir request""" # pylint: disable=unused-argument return await orig_readdir(self, b'\xff\xff\xff\xff') try: os.mkdir('dir') with 
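# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the test suite): the rename tests above
# show that SFTPClient.rename() refuses to clobber an existing target
# (SFTPFailure, or SFTPFileAlreadyExists on SFTPv6) unless an overwrite flag
# is passed, while posix_rename() always replaces it.  Host, login and file
# names below are placeholders.

import asyncio

import asyncssh


async def rotate_log() -> None:
    """Minimal rename example, assuming a placeholder host and login."""

    async with asyncssh.connect('sftp.example.com', username='user') as conn:
        async with conn.start_sftp_client() as sftp:
            try:
                # Fails if 'app.log.1' already exists.
                await sftp.rename('app.log', 'app.log.1')
            except asyncssh.SFTPError as exc:
                print('rename refused:', exc.reason)

                # POSIX semantics: silently replace the existing target.
                await sftp.posix_rename('app.log', 'app.log.1')


if __name__ == '__main__':
    asyncio.run(rotate_log())
# ---------------------------------------------------------------------------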
patch('asyncssh.sftp.SFTPClientHandler.readdir', _readdir_error): with self.assertRaises(SFTPInvalidHandle): await sftp.listdir('dir') finally: remove('dir') @sftp_test async def test_mkdir(self, sftp): """Test creating a directory""" try: await sftp.mkdir('dir') self.assertTrue(os.path.isdir('dir')) finally: remove('dir') @sftp_test async def test_rmdir(self, sftp): """Test removing a directory""" try: os.mkdir('dir') await sftp.rmdir('dir') with self.assertRaises(FileNotFoundError): os.stat('dir') finally: remove('dir') @sftp_test_v6 async def test_rmdir_not_empty_v6(self, sftp): """Test rmdir on a non-empty directory""" try: os.mkdir('dir') self._create_file('dir/file') with self.assertRaises(SFTPDirNotEmpty): await sftp.rmdir('dir') finally: remove('dir') @sftp_test_v6 async def test_open_file_dir_v6(self, sftp): """Test open on a directory""" try: os.mkdir('dir') with self.assertRaises((SFTPPermissionDenied, SFTPFileIsADirectory)): await sftp.open('dir') finally: remove('dir') @sftp_test async def test_rmtree(self, sftp): """Test removing a directory tree""" try: os.mkdir('dir') os.mkdir('dir/dir1') os.mkdir('dir/dir1/dir2') os.mkdir('dir/dir3') self._create_file('dir/file1') self._create_file('dir/file2') self._create_file('dir/dir1/file3') await sftp.rmtree('dir') with self.assertRaises(FileNotFoundError): os.stat('dir') finally: remove('dir') @sftp_test async def test_rmtree_non_existent(self, sftp): """Test passing a non-existent directory to rmtree""" with self.assertRaises(SFTPNoSuchFile): await sftp.rmtree('xxx') @sftp_test async def test_rmtree_ignore_errors(self, sftp): """Test ignoring errors in rmtree""" await sftp.rmtree('xxx', ignore_errors=True) @sftp_test async def test_rmtree_onerror(self, sftp): """Test onerror callback in rmtree""" def _error_handler(*args): errors.append(args) errors = [] await sftp.rmtree('xxx', onerror=_error_handler) self.assertEqual(errors[0][0], sftp.scandir) self.assertEqual(errors[0][1], b'xxx') self.assertEqual(errors[0][2][0], SFTPNoSuchFile) @sftp_test async def test_rmtree_file(self, sftp): """Test passing a file to rmtree""" try: self._create_file('file') with self.assertRaises(SFTPNoSuchFile): await sftp.rmtree('file') finally: remove('file') @sftp_test async def test_rmtree_symlink(self, sftp): """Test passing a symlink to rmtree""" try: os.mkdir('dir') os.symlink('dir', 'link') with self.assertRaises(SFTPNoSuchFile): await sftp.rmtree('link') finally: remove('dir link') @sftp_test async def test_rmtree_symlink_onerror(self, sftp): """Test passing a symlink to rmtree with onerror callback""" def _error_handler(*args): errors.append(args) errors = [] try: os.mkdir('dir') os.symlink('dir', 'link') await sftp.rmtree('link', onerror=_error_handler) self.assertEqual(errors[0][0], sftp.islink) self.assertEqual(errors[0][1], b'link') self.assertEqual(errors[0][2][0], SFTPNoSuchFile) finally: remove('dir link') @sftp_test async def test_rmtree_rmdir_failure(self, sftp): """Test rmdir failing in rmtree""" try: os.mkdir('dir') os.mkdir('dir/subdir') os.chmod('dir', 0o555) with self.assertRaises(SFTPPermissionDenied): await sftp.rmtree('dir') finally: os.chmod('dir', 0o755) remove('dir') @sftp_test async def test_rmtree_unlink_failure(self, sftp): """Test unlink failing in rmtree""" try: os.mkdir('dir') self._create_file('dir/file') os.chmod('dir', 0o555) with self.assertRaises(SFTPPermissionDenied): await sftp.rmtree('dir') finally: os.chmod('dir', 0o755) remove('dir') @sftp_test async def test_readlink(self, sftp): """Test reading a symlink""" 
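# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the test suite): the rmtree tests above
# cover the ignore_errors flag and the onerror callback, which receives the
# failing function, the path, and sys.exc_info().  Host, login and the
# 'scratch/old' path are placeholders.

import asyncio

import asyncssh


async def clean_scratch() -> None:
    """Minimal rmtree example, assuming a placeholder host and login."""

    async with asyncssh.connect('sftp.example.com', username='user') as conn:
        async with conn.start_sftp_client() as sftp:
            # Errors (e.g. a missing path) can be swallowed entirely...
            await sftp.rmtree('scratch/old', ignore_errors=True)

            # ...or routed through a callback for logging.
            def onerror(func, path, exc_info):
                print('rmtree step failed:', path, exc_info[1])

            await sftp.rmtree('scratch/old', onerror=onerror)


if __name__ == '__main__':
    asyncio.run(clean_scratch())
# ---------------------------------------------------------------------------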
if not self._symlink_supported: # pragma: no cover raise unittest.SkipTest('symlink not available') try: os.symlink('/file', 'link') self.assertEqual((await sftp.readlink('link')), '/file') self.assertEqual((await sftp.readlink(b'link')), b'/file') finally: remove('link') @sftp_test_v6 async def test_readlink_v6(self, sftp): """Test reading a symlink with SFTPv6""" if not self._symlink_supported: # pragma: no cover raise unittest.SkipTest('symlink not available') try: os.symlink('/file', 'link') self.assertEqual((await sftp.readlink('link')), '/file') self.assertEqual((await sftp.readlink(b'link')), b'/file') finally: remove('link') @sftp_test async def test_readlink_decode_error(self, sftp): """Test unicode decode error while reading a symlink""" async def _readlink_error(self, path): """Return invalid unicode on an SFTP readlink request""" # pylint: disable=unused-argument return [SFTPName(b'\xff')], False with patch('asyncssh.sftp.SFTPClientHandler.readlink', _readlink_error): with self.assertRaises(SFTPBadMessage): await sftp.readlink('link') @sftp_test async def test_symlink(self, sftp): """Test creating a symlink""" if not self._symlink_supported: # pragma: no cover raise unittest.SkipTest('symlink not available') try: await sftp.symlink('file', 'link') self._check_link('link', 'file') with self.assertRaises(SFTPFailure): await sftp.symlink('file', 'link') finally: remove('file link') @sftp_test_v4 async def test_symlink_v4(self, sftp): """Test creating a symlink with SFTPv4""" if not self._symlink_supported: # pragma: no cover raise unittest.SkipTest('symlink not available') try: await sftp.symlink('file', 'link') self._check_link('link', 'file') with self.assertRaises(SFTPFileAlreadyExists): await sftp.symlink('file', 'link') finally: remove('file link') @sftp_test_v6 async def test_symlink_v6(self, sftp): """Test creating a symlink with SFTPv6""" try: await sftp.symlink('file', 'link') self._check_link('link', 'file') with self.assertRaises(SFTPFileAlreadyExists): await sftp.symlink('file', 'link') finally: remove('file link') @asynctest async def test_symlink_encode_error(self): """Test creating a unicode symlink with no path encoding set""" if not self._symlink_supported: # pragma: no cover raise unittest.SkipTest('symlink not available') async with self.connect() as conn: async with conn.start_sftp_client(path_encoding=None) as sftp: with self.assertRaises(SFTPBadMessage): await sftp.symlink('file', 'link') @asynctest async def test_nonstandard_symlink_client(self): """Test creating a symlink with opposite argument order""" if not self._symlink_supported: # pragma: no cover raise unittest.SkipTest('symlink not available') try: async with self.connect(client_version='OpenSSH') as conn: async with conn.start_sftp_client() as sftp: await sftp.symlink('link', 'file') self._check_link('link', 'file') finally: remove('file link') @sftp_test async def test_link(self, sftp): """Test creating a hard link""" try: self._create_file('file1') await sftp.link('file1', 'file2') self._check_file('file1', 'file2') finally: remove('file1 file2') @sftp_test_v6 async def test_link_v6(self, sftp): """Test creating a hard link with SFTPv6""" try: self._create_file('file1') await sftp.link('file1', 'file2') self._check_file('file1', 'file2') finally: remove('file1 file2') @sftp_test async def test_open_read(self, sftp): """Test reading data from a file""" f = None try: self._create_file('file', 'xxx') f = await sftp.open('file') self.assertEqual((await f.read()), 'xxx') finally: if f: # pragma: no 
branch await f.close() remove('file') @sftp_test async def test_open_read_bytes(self, sftp): """Test reading bytes from a file""" f = None try: self._create_file('file', 'xxx') f = await sftp.open('file', 'rb') self.assertEqual((await f.read()), b'xxx') await f.seek(0) self.assertEqual([result async for result in await f.read_parallel()], [(0, b'xxx')]) finally: if f: # pragma: no branch await f.close() remove('file') @sftp_test async def test_open_read_offset_size(self, sftp): """Test reading at a specific offset and size""" f = None try: self._create_file('file', 'xxxxyyyy') f = await sftp.open('file') self.assertEqual((await f.read(4, 2)), 'xxyy') self.assertEqual([result async for result in await f.read_parallel(4, 2)], [(2, b'xxyy')]) finally: if f: # pragma: no branch await f.close() remove('file') @sftp_test async def test_open_read_no_blocksize(self, sftp): """Test reading with no block size set""" f = None try: self._create_file('file', 'xxxxyyyy') f = await sftp.open('file', block_size=None) self.assertEqual((await f.read(4, 2)), 'xxyy') finally: if f: # pragma: no branch await f.close() remove('file') @sftp_test async def test_open_read_parallel(self, sftp): """Test reading data from a file using parallel I/O""" f = None try: self._create_file('file', 40*1024*'\0') f = await sftp.open('file') self.assertEqual(len(await f.read(64*1024)), 40*1024) finally: if f: # pragma: no branch await f.close() remove('file') @sftp_test async def test_open_read_max_requests(self, sftp): """Test reading data from a file with max requests set""" f = None try: self._create_file('file', 16*1024*1024*'\0') f = await sftp.open('file', max_requests=4) self.assertEqual(len(await f.read()), 16*1024*1024) finally: if f: # pragma: no branch await f.close() remove('file') def test_open_read_out_of_order(self): """Test parallel read with out-of-order responses""" @sftp_test async def _test_read_out_of_order(self, sftp): """Test parallel read with out-of-order responses""" f = None try: random_data = os.urandom(12*1024*1024) self._create_file('file', random_data) async with sftp.open('file', 'rb') as f: self.assertEqual(await f.read(), random_data) finally: remove('file') with patch('asyncssh.sftp.SFTPServerHandler', _ReorderReadServerHandler): # pylint: disable=no-value-for-parameter _test_read_out_of_order(self) @sftp_test async def test_open_read_nonexistent(self, sftp): """Test reading data from a nonexistent file""" f = None try: with self.assertRaises(SFTPNoSuchFile): f = await sftp.open('file') finally: if f: # pragma: no cover await f.close() @unittest.skipIf(sys.platform == 'win32', 'skip permission tests on Windows') @sftp_test async def test_open_read_not_permitted(self, sftp): """Test reading data from a file with no read permission""" f = None try: self._create_file('file', mode=0) with self.assertRaises(SFTPPermissionDenied): f = await sftp.open('file') finally: if f: # pragma: no cover await f.close() remove('file') @sftp_test async def test_open_write(self, sftp): """Test writing data to a file""" f = None try: f = await sftp.open('file', 'w') await f.write('xxx') await f.close() with open('file') as localf: self.assertEqual(localf.read(), 'xxx') finally: if f: # pragma: no branch await f.close() remove('file') @sftp_test async def test_open_write_bytes(self, sftp): """Test writing bytes to a file""" f = None try: f = await sftp.open('file', 'wb') await f.write(b'xxx') await f.close() with open('file', 'rb') as localf: self.assertEqual(localf.read(), b'xxx') finally: if f: # pragma: no branch 
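# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the test suite): the parallel-read tests
# above use SFTPClientFile.read_parallel(), which returns an async iterator
# of (offset, data) chunks.  This sketch reassembles the chunks by offset so
# it stays correct even if they complete out of order.  Host, login and the
# path argument are placeholders.

import asyncio

import asyncssh


async def fetch(path: str) -> bytes:
    """Minimal parallel-read example, assuming a placeholder host and login."""

    async with asyncssh.connect('sftp.example.com', username='user') as conn:
        async with conn.start_sftp_client() as sftp:
            async with sftp.open(path, 'rb') as f:
                chunks = {offset: data
                          async for offset, data in await f.read_parallel()}

    return b''.join(chunks[offset] for offset in sorted(chunks))


if __name__ == '__main__':
    print(len(asyncio.run(fetch('remote-file'))))
# ---------------------------------------------------------------------------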
await f.close() remove('file') @sftp_test_v6 async def test_open_write_v6(self, sftp): """Test writing bytes to a file with SFTPv6 open""" f = None try: f = await sftp.open('file', 'wb') await f.write('xxx') await f.close() with open('file') as localf: self.assertEqual(localf.read(), 'xxx') finally: if f: # pragma: no branch await f.close() remove('file') @sftp_test_v6 async def test_open56_write_v6(self, sftp): """Test writing bytes to a file with SFTPv6 open56""" f = None try: f = await sftp.open56('file', ACE4_WRITE_DATA, FXF_CREATE_TRUNCATE) await f.write('xxx') await f.close() with open('file') as localf: self.assertEqual(localf.read(), 'xxx') finally: if f: # pragma: no branch await f.close() remove('file') @sftp_test async def test_open_truncate(self, sftp): """Test truncating a file at open time""" f = None try: self._create_file('file', 'xxxyyy') f = await sftp.open('file', 'w') await f.write('zzz') await f.close() with open('file') as localf: self.assertEqual(localf.read(), 'zzz') finally: if f: # pragma: no branch await f.close() remove('file') @sftp_test_v6 async def test_open_truncate_v6(self, sftp): """Test truncating a file at open time with SFTPv6 open""" f = None try: self._create_file('file', 'xxxyyy') f = await sftp.open('file', FXF_WRITE | FXF_TRUNC) await f.write('zzz') await f.close() with open('file') as localf: self.assertEqual(localf.read(), 'zzz') finally: if f: # pragma: no branch await f.close() remove('file') @sftp_test_v6 async def test_open56_truncate_v6(self, sftp): """Test truncating a file at open time with SFTPv6 open56""" f = None try: self._create_file('file', 'xxxyyy') f = await sftp.open56('file', ACE4_WRITE_DATA, FXF_TRUNCATE_EXISTING) await f.write('zzz') await f.close() with open('file') as localf: self.assertEqual(localf.read(), 'zzz') finally: if f: # pragma: no branch await f.close() remove('file') @sftp_test async def test_open_append(self, sftp): """Test appending data to an existing file""" f = None try: self._create_file('file', 'xxx') f = await sftp.open('file', 'a+') await f.write('yyy') self.assertEqual((await f.read()), '') self.assertEqual([result async for result in await f.read_parallel()], []) await f.close() with open('file') as localf: self.assertEqual(localf.read(), 'xxxyyy') finally: if f: # pragma: no branch await f.close() remove('file') @sftp_test_v6 async def test_open_append_v6(self, sftp): """Test appending data to an existing file with SFTPv6 open""" f = None try: self._create_file('file', 'xxx') f = await sftp.open('file', FXF_WRITE | FXF_APPEND) await f.write('yyy') self.assertEqual((await f.read()), '') await f.close() with open('file') as localf: self.assertEqual(localf.read(), 'xxxyyy') finally: if f: # pragma: no branch await f.close() remove('file') @sftp_test_v6 async def test_open56_append_v6(self, sftp): """Test appending data to an existing file with SFTPv6 open56""" f = None try: self._create_file('file', 'xxx') f = await sftp.open56('file', ACE4_READ_DATA | ACE4_WRITE_DATA | ACE4_APPEND_DATA, FXF_OPEN_EXISTING | FXF_APPEND_DATA) await f.write('yyy') self.assertEqual((await f.read()), '') await f.close() with open('file') as localf: self.assertEqual(localf.read(), 'xxxyyy') finally: if f: # pragma: no branch await f.close() remove('file') @sftp_test async def test_open_exclusive_create(self, sftp): """Test creating a new file""" f = None try: f = await sftp.open('file', 'x') await f.write('xxx') await f.close() with open('file') as localf: self.assertEqual(localf.read(), 'xxx') with 
self.assertRaises(SFTPFailure): f = await sftp.open('file', 'x') finally: if f: # pragma: no branch await f.close() remove('file') @sftp_test_v6 async def test_open_exclusive_create_v6(self, sftp): """Test creating a new file with SFTPv6 open""" f = None try: f = await sftp.open('file', 'x') await f.write('xxx') await f.close() with open('file') as localf: self.assertEqual(localf.read(), 'xxx') with self.assertRaises(SFTPFileAlreadyExists): f = await sftp.open('file', 'x') finally: if f: # pragma: no branch await f.close() remove('file') @sftp_test_v6 async def test_open56_exclusive_create_v6(self, sftp): """Test creating a new file with SFTPv6 open56""" f = None try: f = await sftp.open56('file', ACE4_WRITE_DATA, FXF_CREATE_NEW) await f.write('xxx') await f.close() with open('file') as localf: self.assertEqual(localf.read(), 'xxx') with self.assertRaises(SFTPFileAlreadyExists): f = await sftp.open56('file', ACE4_WRITE_DATA, FXF_CREATE_NEW) finally: if f: # pragma: no branch await f.close() remove('file') @sftp_test async def test_open_exclusive_create_existing(self, sftp): """Test exclusive create of an existing file""" f = None try: self._create_file('file') with self.assertRaises(SFTPFailure): f = await sftp.open('file', 'x') finally: if f: # pragma: no cover await f.close() remove('file') @sftp_test_v4 async def test_open_exclusive_create_existing_v4(self, sftp): """Test exclusive create of an existing file with SFTPv4""" f = None try: self._create_file('file') with self.assertRaises(SFTPFileAlreadyExists): f = await sftp.open('file', 'x') finally: if f: # pragma: no cover await f.close() remove('file') @sftp_test_v6 async def test_open56_exclusive_create_existing_v6(self, sftp): """Test exclusive create of an existing file with SFTPv6 open56""" f = None try: self._create_file('file') with self.assertRaises(SFTPFileAlreadyExists): f = await sftp.open56('file', ACE4_WRITE_DATA, FXF_CREATE_NEW) finally: if f: # pragma: no cover await f.close() remove('file') @sftp_test async def test_open_overwrite(self, sftp): """Test overwriting part of an existing file""" f = None try: self._create_file('file', 'xxxyyy') f = await sftp.open('file', 'r+') await f.write('zzz') await f.close() with open('file') as localf: self.assertEqual(localf.read(), 'zzzyyy') finally: if f: # pragma: no branch await f.close() remove('file') @sftp_test_v6 async def test_open56_overwrite_v6(self, sftp): """Test overwriting part of an existing file with SFTPv6 open56""" f = None try: self._create_file('file', 'xxxyyy') f = await sftp.open56('file', ACE4_WRITE_DATA, FXF_OPEN_EXISTING) await f.write('zzz') await f.close() with open('file') as localf: self.assertEqual(localf.read(), 'zzzyyy') finally: if f: # pragma: no branch await f.close() remove('file') @sftp_test async def test_open_overwrite_offset_size(self, sftp): """Test writing data at a specific offset""" f = None try: self._create_file('file', 'xxxxyyyy') f = await sftp.open('file', 'r+') await f.write('zz', 3) await f.close() with open('file') as localf: self.assertEqual(localf.read(), 'xxxzzyyy') finally: if f: # pragma: no branch await f.close() remove('file') @sftp_test_v6 async def test_open_overwrite_offset_size_v6(self, sftp): """Test writing data at a specific offset with SFTPv6 open""" f = None try: self._create_file('file', 'xxxxyyyy') f = await sftp.open('file', FXF_WRITE | FXF_CREAT) await f.write('zz', 3) await f.close() with open('file') as localf: self.assertEqual(localf.read(), 'xxxzzyyy') finally: if f: # pragma: no branch await f.close() 
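# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the test suite): the exclusive-create tests
# above open a file with mode 'x', which fails if the file already exists
# (SFTPFailure, or SFTPFileAlreadyExists on SFTPv4 and later).  One possible
# use is a simple remote lock file, sketched below with a placeholder host,
# login and lock path.

import asyncio

import asyncssh


async def acquire_lock() -> bool:
    """Minimal exclusive-create example, assuming a placeholder host/login."""

    async with asyncssh.connect('sftp.example.com', username='user') as conn:
        async with conn.start_sftp_client() as sftp:
            try:
                async with sftp.open('job.lock', 'x') as f:
                    await f.write('locked\n')
                return True
            except asyncssh.SFTPError:
                # Someone else created the lock file first.
                return False


if __name__ == '__main__':
    print(asyncio.run(acquire_lock()))
# ---------------------------------------------------------------------------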
remove('file') @sftp_test_v6 async def test_open56_overwrite_offset_size_v6(self, sftp): """Test writing data at a specific offset with SFTPv6 open56""" f = None try: self._create_file('file', 'xxxxyyyy') f = await sftp.open56('file', ACE4_WRITE_DATA, FXF_OPEN_OR_CREATE) await f.write('zz', 3) await f.close() with open('file') as localf: self.assertEqual(localf.read(), 'xxxzzyyy') finally: if f: # pragma: no branch await f.close() remove('file') @sftp_test async def test_open_overwrite_nonexistent(self, sftp): """Test overwriting a nonexistent file""" f = None try: with self.assertRaises(SFTPNoSuchFile): f = await sftp.open('file', 'r+') finally: if f: # pragma: no cover await f.close() @sftp_test_v6 async def test_open_link_loop_v6(self, sftp): """Test opening a symlink which is a loop""" f = None try: os.symlink('link1', 'link2') os.symlink('link2', 'link1') with self.assertRaises((SFTPInvalidParameter, SFTPLinkLoop)): f = await sftp.open('link1') finally: if f: # pragma: no cover await f.close() remove('link1 link2') @sftp_test async def test_file_seek(self, sftp): """Test seeking within a file""" f = None try: f = await sftp.open('file', 'w+') await f.write('xxxxyyyy') await f.seek(3) await f.write('zz') await f.seek(-3, SEEK_CUR) self.assertEqual((await f.read(4)), 'xzzy') await f.seek(-4, SEEK_END) self.assertEqual((await f.read()), 'zyyy') self.assertEqual((await f.read()), '') self.assertEqual((await f.read(1)), '') with self.assertRaises(ValueError): await f.seek(0, -1) await f.close() f = await sftp.open('file', 'a+') await f.seek(-4, SEEK_CUR) self.assertEqual((await f.read()), 'zyyy') await f.close() with open('file') as localf: self.assertEqual(localf.read(), 'xxxzzyyy') finally: if f: # pragma: no branch await f.close() remove('file') @sftp_test async def test_file_stat(self, sftp): """Test getting attributes on an open file""" f = None try: f = await sftp.open('file', 'w') self._check_stat((await f.stat()), os.stat('file')) finally: if f: # pragma: no branch await f.close() remove('file') @sftp_test_v4 async def test_file_stat_v4(self, sftp): """Test getting attributes on an open file with SFTPv4""" f = None try: f = await sftp.open('file', 'w') self._check_stat_v4((await f.stat()), os.stat('file')) finally: if f: # pragma: no branch await f.close() remove('file') @sftp_test_v6 async def test_file_stat_v6(self, sftp): """Test getting attributes on an open file with SFTPv6""" f = None try: f = await sftp.open('file', 'w') self._check_stat_v4((await f.stat()), os.stat('file')) finally: if f: # pragma: no branch await f.close() remove('file') @sftp_test async def test_file_setstat(self, sftp): """Test setting attributes on an open file""" f = None try: f = await sftp.open('file', 'w') await f.setstat(SFTPAttrs(permissions=0o666)) self.assertEqual(stat.S_IMODE(os.stat('file').st_mode), 0o666) finally: if f: # pragma: no branch await f.close() remove('file') @sftp_test_v6 async def test_file_setstat_v6(self, sftp): """Test setting attributes on an open file with SFTPv6""" f = None try: f = await sftp.open('file', 'w') await f.setstat(SFTPAttrs(permissions=0o666)) self.assertEqual(stat.S_IMODE(os.stat('file').st_mode), 0o666) finally: if f: # pragma: no branch await f.close() remove('file') @unittest.skipIf(sys.platform == 'win32', 'skip chown tests on Windows') @sftp_test async def test_file_chown(self, sftp): """Test changing ownership of an open file We can't change to a different user/group here if we're not root, so just change to the same user/group. 
See the separate _TestSFTPChown class for a more complete test, but this is left in for code coverage purposes. """ f = None try: f = await sftp.open('file', 'w') stat_result = os.stat('file') await f.chown(stat_result.st_uid, stat_result.st_gid) new_stat_result = os.stat('file') self.assertEqual(new_stat_result.st_uid, stat_result.st_uid) self.assertEqual(new_stat_result.st_gid, stat_result.st_gid) await f.chown(uid=stat_result.st_uid, gid=stat_result.st_gid) new_stat_result = os.stat('file') self.assertEqual(new_stat_result.st_uid, stat_result.st_uid) self.assertEqual(new_stat_result.st_gid, stat_result.st_gid) finally: if f: # pragma: no branch await f.close() remove('file') @unittest.skipIf(sys.platform == 'win32', 'skip chown tests on Windows') @sftp_test_v4 async def test_file_chown_v4(self, sftp): """Test changing ownership of an open file We can't change to a different user/group here if we're not root, so just change to the same user/group. See the separate _TestSFTPChown class for a more complete test, but this is left in for code coverage purposes. """ f = None try: f = await sftp.open('file', 'w') stat_result = os.stat('file') owner = lookup_user(stat_result.st_uid) group = lookup_group(stat_result.st_gid) await f.chown(owner, group) new_stat_result = os.stat('file') self.assertEqual(new_stat_result.st_uid, stat_result.st_uid) self.assertEqual(new_stat_result.st_gid, stat_result.st_gid) await f.chown(owner=owner, group=group) new_stat_result = os.stat('file') self.assertEqual(new_stat_result.st_uid, stat_result.st_uid) self.assertEqual(new_stat_result.st_gid, stat_result.st_gid) finally: if f: # pragma: no branch await f.close() remove('file') @sftp_test async def test_file_truncate(self, sftp): """Test truncating an open file""" f = None try: self._create_file('file', '01234567890123456789') f = await sftp.open('file', 'a+') await f.truncate(10) self.assertEqual((await f.tell()), 10) self.assertEqual((await f.read(offset=0)), '0123456789') self.assertEqual((await f.tell()), 10) finally: if f: # pragma: no branch await f.close() remove('file') @sftp_test async def test_file_utime(self, sftp): """Test changing access and modify times on an open file""" f = None try: f = await sftp.open('file', 'w') await f.utime() await f.utime((1, 2)) stat_result = os.stat('file') self.assertEqual(stat_result.st_atime, 1) self.assertEqual(stat_result.st_mtime, 2) finally: if f: # pragma: no branch await f.close() remove('file') @sftp_test_v4 async def test_file_utime_v4(self, sftp): """Test changing access and modify times on an open file with SFTPv4""" f = None try: f = await sftp.open('file', 'w') await f.utime() await f.utime((1.0, 2.25)) stat_result = os.stat('file') self.assertEqual(stat_result.st_atime, 1.0) self.assertEqual(stat_result.st_atime_ns, 1000000000) self.assertEqual(stat_result.st_mtime, 2.25) self.assertEqual(stat_result.st_mtime_ns, 2250000000) self.assertEqual((await sftp.getatime('file')), 1.0) self.assertEqual((await sftp.getatime_ns('file')), 1000000000) self.assertIsNotNone(await sftp.getcrtime('file')) self.assertIsNotNone(await sftp.getcrtime_ns('file')) self.assertEqual((await sftp.getmtime('file')), 2.25) self.assertEqual((await sftp.getmtime_ns('file')), 2250000000) await f.utime('file', ns=(3500000000, 4750000000)) stat_result = os.stat('file') self.assertEqual(stat_result.st_atime, 3.5) self.assertEqual(stat_result.st_atime_ns, 3500000000) self.assertEqual(stat_result.st_mtime, 4.75) self.assertEqual(stat_result.st_mtime_ns, 4750000000) self.assertEqual((await 
sftp.getatime('file')), 3.5) self.assertEqual((await sftp.getatime_ns('file')), 3500000000) self.assertIsNotNone(await sftp.getcrtime('file')) self.assertIsNotNone(await sftp.getcrtime_ns('file')) self.assertEqual((await sftp.getmtime('file')), 4.75) self.assertEqual((await sftp.getmtime_ns('file')), 4750000000) finally: if f: # pragma: no branch await f.close() remove('file') @unittest.skipIf(sys.platform == 'win32', 'skip statvfs tests on Windows') @sftp_test async def test_file_statvfs(self, sftp): """Test getting attributes on the filesystem containing an open file We can't compare the values returned by a live statvfs call since they can change at any time. See the separate _TestSFTStatPVFS class for a more complete test, but this is left in for code coverage purposes. """ f = None try: f = await sftp.open('file', 'w') self.assertIsInstance((await f.statvfs()), SFTPVFSAttrs) finally: if f: # pragma: no branch await f.close() remove('file') @sftp_test async def test_file_lock(self, sftp): """Test file lock against earlier version SFTP server""" f = None try: f = await sftp.open('file', 'w') with self.assertRaises(SFTPOpUnsupported): await f.lock(0, 0, FXF_BLOCK_READ) with self.assertRaises(SFTPOpUnsupported): await f.unlock(0, 0) finally: if f: # pragma: no branch await f.close() remove('file') @sftp_test_v6 async def test_file_lock_v6(self, sftp): """Test file lock""" f = None try: f = await sftp.open('file', 'w') with self.assertRaises(SFTPOpUnsupported): await f.lock(0, 0, FXF_BLOCK_READ) with self.assertRaises(SFTPOpUnsupported): await f.unlock(0, 0) finally: if f: # pragma: no branch await f.close() remove('file') @sftp_test async def test_file_sync(self, sftp): """Test file sync""" f = None try: f = await sftp.open('file', 'w') self.assertIsNone(await f.fsync()) finally: if f: # pragma: no branch await f.close() remove('file') @sftp_test async def test_exited_session(self, sftp): """Test use of SFTP session after exit""" sftp.exit() await sftp.wait_closed() f = None try: with self.assertRaises(SFTPNoConnection): f = await sftp.open('file') finally: if f: # pragma: no cover await f.close() @sftp_test async def test_cleanup_open_files(self, sftp): """Test cleanup of open file handles on exit""" try: await sftp.open('file', 'w') finally: sftp.exit() await sftp.wait_closed() remove('file') @sftp_test async def test_invalid_open_mode(self, sftp): """Test opening file with invalid mode""" with self.assertRaises(ValueError): await sftp.open('file', 'z') @sftp_test async def test_invalid_open56(self, sftp): """Test calling open56 on an earlier version SFTP server""" with self.assertRaises(SFTPOpUnsupported): await sftp.open56('file', ACE4_WRITE_DATA, FXF_OPEN_OR_CREATE) @sftp_test_v6 async def test_invalid_access_flags_v6(self, sftp): """Test opening file with invalid access flags with SFTPv6""" with self.assertRaises(SFTPInvalidParameter): await sftp.open56('file', 0x80000000, FXF_OPEN_OR_CREATE) @sftp_test_v6 async def test_invalid_open_flags_v6(self, sftp): """Test opening file with invalid open flags with SFTPv6""" with self.assertRaises(SFTPInvalidParameter): await sftp.open56('file', ACE4_WRITE_DATA, 0x80000000) @sftp_test async def test_invalid_handle(self, sftp): """Test sending requests associated with an invalid file handle""" async def _return_invalid_handle(self, path, pflags, attrs): """Return an invalid file handle""" # pylint: disable=unused-argument return UInt32(0xffffffff) with patch('asyncssh.sftp.SFTPClientHandler.open', _return_invalid_handle): f = await 
sftp.open('file') with self.assertRaises(SFTPFailure): await f.read() with self.assertRaises(SFTPFailure): await f.read(1) with self.assertRaises(SFTPFailure): await f.write('') with self.assertRaises(SFTPFailure): await f.stat() with self.assertRaises(SFTPFailure): await f.setstat(SFTPAttrs()) if sys.platform != 'win32': # pragma: no branch with self.assertRaises(SFTPFailure): await f.statvfs() with self.assertRaises(SFTPFailure): await f.fsync() with self.assertRaises(SFTPFailure): await sftp.remote_copy(f, f) with self.assertRaises(SFTPFailure): await f.close() @sftp_test_v6 async def test_invalid_handle_v6(self, sftp): """Test sending requests associated with an invalid file handle""" async def _return_invalid_handle(self, path, pflags, attrs): """Return an invalid file handle""" # pylint: disable=unused-argument return UInt32(0xffffffff) with patch('asyncssh.sftp.SFTPClientHandler.open', _return_invalid_handle): f = await sftp.open('file') with self.assertRaises(SFTPInvalidHandle): await f.lock(0, 0, FXF_BLOCK_READ) with self.assertRaises(SFTPInvalidHandle): await f.unlock(0, 0) @sftp_test async def test_closed_file(self, sftp): """Test I/O operations on a closed file""" f = None try: self._create_file('file') async with sftp.open('file') as f: # Do an explicit close to test double-close await f.close() with self.assertRaises(ValueError): await f.read() with self.assertRaises(ValueError): await f.read_parallel() with self.assertRaises(ValueError): await f.write('') with self.assertRaises(ValueError): await f.seek(0) with self.assertRaises(ValueError): await f.tell() with self.assertRaises(ValueError): await f.stat() with self.assertRaises(ValueError): await f.setstat(SFTPAttrs()) with self.assertRaises(ValueError): await f.statvfs() with self.assertRaises(ValueError): await f.truncate() with self.assertRaises(ValueError): await f.chown(0, 0) with self.assertRaises(ValueError): await f.chmod(0) with self.assertRaises(ValueError): await f.utime() with self.assertRaises(ValueError): await f.lock(0, 0, FXF_BLOCK_READ) with self.assertRaises(ValueError): await f.unlock(0, 0) with self.assertRaises(ValueError): await f.fsync() finally: if f: # pragma: no branch await f.close() remove('file') def test_unexpected_client_close(self): """Test an unexpected connection close from client""" async def _unexpected_client_close(self): """Close the SSH connection before sending an init request""" self._writer.channel.get_connection().abort() with patch('asyncssh.sftp.SFTPClientHandler.start', _unexpected_client_close): # pylint: disable=no-value-for-parameter self._dummy_sftp_client() def test_unexpected_server_close(self): """Test an unexpected connection close from server""" async def _unexpected_server_close(self): """Close the SSH connection before sending a version response""" packet = await SFTPHandler.recv_packet(self) self._writer.channel.get_connection().abort() return packet with patch('asyncssh.sftp.SFTPServerHandler.recv_packet', _unexpected_server_close): with self.assertRaises(SFTPConnectionLost): # pylint: disable=no-value-for-parameter self._dummy_sftp_client() def test_incomplete_message(self): """Test session cleanup in the middle of a write request""" with patch('asyncssh.sftp.SFTPServerHandler', _IncompleteMessageServerHandler): with self.assertRaises(SFTPConnectionLost): # pylint: disable=no-value-for-parameter self._dummy_sftp_client() def test_immediate_client_close(self): """Test closing SFTP channel immediately after opening""" async def _closing_start(self): """Immediately 
close the SFTP channel""" self.exit() with patch('asyncssh.sftp.SFTPClientHandler.start', _closing_start): # pylint: disable=no-value-for-parameter self._dummy_sftp_client() def test_no_init(self): """Test sending non-init request at start""" async def _no_init_start(self): """Send a non-init request at start""" self.send_packet(FXP_OPEN, 0, UInt32(0)) with patch('asyncssh.sftp.SFTPClientHandler.start', _no_init_start): # pylint: disable=no-value-for-parameter self._dummy_sftp_client() def test_incomplete_init_request(self): """Test sending init with missing version""" async def _missing_version_start(self): """Send an init request with missing version""" self.send_packet(FXP_INIT, None) with patch('asyncssh.sftp.SFTPClientHandler.start', _missing_version_start): # pylint: disable=no-value-for-parameter self._dummy_sftp_client() def test_incomplete_version_response(self): """Test sending an incomplete version response""" async def _incomplete_version_response(self): """Send an incomplete version response""" packet = await SFTPHandler.recv_packet(self) self.send_packet(FXP_VERSION, None) return packet with patch('asyncssh.sftp.SFTPServerHandler.recv_packet', _incomplete_version_response): with self.assertRaises(SFTPBadMessage): # pylint: disable=no-value-for-parameter self._dummy_sftp_client() def test_nonstandard_version(self): """Test sending init with non-standard version""" with patch('asyncssh.sftp.MIN_SFTP_VERSION', 2): # pylint: disable=no-value-for-parameter self._dummy_sftp_client() def test_non_version_response(self): """Test sending a non-version message in response to init""" async def _non_version_response(self): """Send a non-version response to init""" packet = await SFTPHandler.recv_packet(self) self.send_packet(FXP_STATUS, None) return packet with patch('asyncssh.sftp.SFTPServerHandler.recv_packet', _non_version_response): with self.assertRaises(SFTPBadMessage): # pylint: disable=no-value-for-parameter self._dummy_sftp_client() def test_unsupported_version_response(self): """Test sending an unsupported version in response to init""" async def _unsupported_version_response(self): """Send an unsupported version in response to init""" packet = await SFTPHandler.recv_packet(self) self.send_packet(FXP_VERSION, None, UInt32(99)) return packet with patch('asyncssh.sftp.SFTPServerHandler.recv_packet', _unsupported_version_response): with self.assertRaises(SFTPBadMessage): # pylint: disable=no-value-for-parameter self._dummy_sftp_client() def test_extension_in_init(self): """Test sending an extension in version 3 init request""" async def _init_extension_start(self): """Send an init request with missing version""" self.send_packet(FXP_INIT, None, UInt32(3), String(b'xxx'), String(b'1')) with patch('asyncssh.sftp.SFTPClientHandler.start', _init_extension_start): # pylint: disable=no-value-for-parameter self._dummy_sftp_client() def test_unknown_extension_response(self): """Test sending an unknown extension in version response""" with patch('asyncssh.sftp.SFTPServerHandler._extensions', [(b'xxx', b'1')]): # pylint: disable=no-value-for-parameter self._dummy_sftp_client() def test_empty_extension_response_v5(self): """Test sending an empty extension list in SFTPv5 version response""" with patch('asyncssh.sftp.SFTPServerHandler._extensions', []): # pylint: disable=no-value-for-parameter self._dummy_sftp_client_v5() def test_attrib_extension_response_v6(self): """Test sending an attrib extension in version response""" with patch('asyncssh.sftp.SFTPServerHandler._attrib_extensions', 
[b'xxx']): # pylint: disable=no-value-for-parameter self._dummy_sftp_client_v6() def test_close_after_init(self): """Test close immediately after init request at start""" async def _close_after_init_start(self): """Send a close immediately after init request at start""" self.send_packet(FXP_INIT, None, UInt32(3)) await self._cleanup(None) with patch('asyncssh.sftp.SFTPClientHandler.start', _close_after_init_start): # pylint: disable=no-value-for-parameter self._dummy_sftp_client() def test_file_handle_skip(self): """Test skipping over a file handle already in use""" @sftp_test async def _reset_file_handle(self, sftp): """Open multiple files, resetting next handle each time""" file1 = None file2 = None try: self._create_file('file1', 'xxx') self._create_file('file2', 'yyy') file1 = await sftp.open('file1') file2 = await sftp.open('file2') self.assertEqual((await file1.read()), 'xxx') self.assertEqual((await file2.read()), 'yyy') finally: if file1: # pragma: no branch await file1.close() if file2: # pragma: no branch await file2.close() remove('file1 file2') with patch('asyncssh.sftp.SFTPServerHandler', _ResetFileHandleServerHandler): # pylint: disable=no-value-for-parameter _reset_file_handle(self) @sftp_test async def test_missing_request_pktid(self, sftp): """Test sending request without a packet ID""" async def _missing_pktid(self, filename, pflags, attrs): """Send a request without a packet ID""" # pylint: disable=unused-argument self.send_packet(FXP_OPEN, None) with patch('asyncssh.sftp.SFTPClientHandler.open', _missing_pktid): await sftp.open('file') @sftp_test async def test_malformed_open_request(self, sftp): """Test sending malformed open request""" async def _malformed_open(self, filename, pflags, attrs): """Send a malformed open request""" # pylint: disable=unused-argument return await self._make_request(FXP_OPEN) with patch('asyncssh.sftp.SFTPClientHandler.open', _malformed_open): with self.assertRaises(SFTPBadMessage): await sftp.open('file') @sftp_test async def test_unknown_request(self, sftp): """Test sending unknown request type""" async def _unknown_request(self, filename, pflags, attrs): """Send a request with an unknown type""" # pylint: disable=unused-argument return await self._make_request(0xff) with patch('asyncssh.sftp.SFTPClientHandler.open', _unknown_request): with self.assertRaises(SFTPOpUnsupported): await sftp.open('file') @sftp_test async def test_unrecognized_response_pktid(self, sftp): """Test sending a response with an unrecognized packet ID""" async def _unrecognized_response_pktid(self, pkttype, pktid, packet): """Send a response with an unrecognized packet ID""" # pylint: disable=unused-argument self.send_packet(FXP_HANDLE, 0xffffffff, UInt32(0xffffffff), String('')) with patch('asyncssh.sftp.SFTPServerHandler._process_packet', _unrecognized_response_pktid): with self.assertRaises(SFTPBadMessage): await sftp.open('file') @sftp_test async def test_bad_response_type(self, sftp): """Test sending a response with an incorrect response type""" async def _bad_response_type(self, pkttype, pktid, packet): """Send a response with an incorrect response type""" # pylint: disable=unused-argument self.send_packet(FXP_DATA, pktid, UInt32(pktid), String('')) with patch('asyncssh.sftp.SFTPServerHandler._process_packet', _bad_response_type): with self.assertRaises(SFTPBadMessage): await sftp.open('file') @sftp_test async def test_unexpected_ok_response(self, sftp): """Test sending an unexpected FX_OK response""" async def _unexpected_ok_response(self, pkttype, pktid, 
packet): """Send an unexpected FX_OK response""" # pylint: disable=unused-argument self.send_packet(FXP_STATUS, pktid, UInt32(pktid), UInt32(FX_OK), String(''), String('')) with patch('asyncssh.sftp.SFTPServerHandler._process_packet', _unexpected_ok_response): with self.assertRaises(SFTPBadMessage): await sftp.open('file') @sftp_test async def test_malformed_ok_response(self, sftp): """Test sending an FX_OK response containing invalid Unicode""" async def _malformed_ok_response(self, pkttype, pktid, packet): """Send an FX_OK response containing invalid Unicode""" # pylint: disable=unused-argument self.send_packet(FXP_STATUS, pktid, UInt32(pktid), UInt32(FX_OK), String(b'\xff'), String('')) with patch('asyncssh.sftp.SFTPServerHandler._process_packet', _malformed_ok_response): with self.assertRaises(SFTPBadMessage): await sftp.open('file') @sftp_test async def test_short_ok_response(self, sftp): """Test sending an FX_OK response without a reason and lang""" async def _short_ok_response(self, pkttype, pktid, packet): """Send an FX_OK response missing reason and lang""" # pylint: disable=unused-argument self.send_packet(FXP_STATUS, pktid, UInt32(pktid), UInt32(FX_OK)) with patch('asyncssh.sftp.SFTPServerHandler._process_packet', _short_ok_response): self.assertIsNone(await sftp.mkdir('dir')) @sftp_test async def test_malformed_realpath_response(self, sftp): """Test receiving malformed realpath response""" async def _malformed_realpath(self, path): """Return a malformed realpath response""" # pylint: disable=unused-argument return [SFTPName(''), SFTPName('')], False with patch('asyncssh.sftp.SFTPClientHandler.realpath', _malformed_realpath): with self.assertRaises(SFTPBadMessage): await sftp.realpath('.') @sftp_test async def test_malformed_readlink_response(self, sftp): """Test receiving malformed readlink response""" async def _malformed_readlink(self, path): """Return a malformed readlink response""" # pylint: disable=unused-argument return [SFTPName(''), SFTPName('')], False with patch('asyncssh.sftp.SFTPClientHandler.readlink', _malformed_readlink): with self.assertRaises(SFTPBadMessage): await sftp.readlink('.') def test_unsupported_extensions(self): """Test using extensions on a server that doesn't support them""" @sftp_test async def _unsupported_extensions(self, sftp): """Try using unsupported extensions""" f = None try: self._create_file('file1', 'xxx') self._create_file('file2', 'yyy') with self.assertRaises(SFTPOpUnsupported): await sftp.statvfs('.') f = await sftp.open('file1') with self.assertRaises(SFTPOpUnsupported): await f.statvfs() with self.assertRaises(SFTPOpUnsupported): await sftp.posix_rename('file1', 'file2') with self.assertRaises(SFTPOpUnsupported): await sftp.rename('file1', 'file2', flags=FXR_OVERWRITE) with self.assertRaises(SFTPOpUnsupported): await sftp.link('file1', 'file2') with self.assertRaises(SFTPOpUnsupported): await f.fsync() with self.assertRaises(SFTPOpUnsupported): await sftp.setstat('file1', SFTPAttrs(), follow_symlinks=False) finally: if f: # pragma: no branch await f.close() remove('file1') with patch('asyncssh.sftp.SFTPServerHandler._extensions', []): # pylint: disable=no-value-for-parameter _unsupported_extensions(self) def test_unsupported_extensions_v6(self): """Test using extensions on a server that doesn't support them""" @sftp_test_v6 async def _unsupported_extensions_v6(self, sftp): """Try using unsupported extensions""" try: self._create_file('file1', 'xxx') self._create_file('file2', 'yyy') self._create_file('file3', 'zzz') await 
sftp.posix_rename('file1', 'file2') with open('file2') as localf: self.assertEqual(localf.read(), 'xxx') await sftp.rename('file2', 'file3', FXR_OVERWRITE) with open('file3') as localf: self.assertEqual(localf.read(), 'xxx') await sftp.link('file3', 'file4') with open('file4') as localf: self.assertEqual(localf.read(), 'xxx') finally: remove('file1 file2 file3 file4') with patch('asyncssh.sftp.SFTPServerHandler._extensions', []): # pylint: disable=no-value-for-parameter _unsupported_extensions_v6(self) @asynctest async def test_zero_limits(self): """Test sending a server limits response with zero read/write length""" async def _send_zero_read_write_len(self, packet): """Send a server limits response with zero read/write length""" # pylint: disable=unused-argument return SFTPLimits(0, 0, 0, 0) with patch.dict('asyncssh.sftp.SFTPServerHandler._packet_handlers', {b'limits@openssh.com': _send_zero_read_write_len}): async with self.connect() as conn: async with conn.start_sftp_client() as sftp: self.assertEqual(sftp.limits.max_read_len, SAFE_SFTP_READ_LEN) self.assertEqual(sftp.limits.max_write_len, SAFE_SFTP_WRITE_LEN) def test_write_close(self): """Test session cleanup in the middle of a write request""" @sftp_test async def _write_close(self, sftp): """Initiate write that triggers cleanup""" try: async with sftp.open('file', 'w') as f: with self.assertRaises(SFTPConnectionLost): await f.write('a') finally: sftp.exit() remove('file') with patch('asyncssh.sftp.SFTPServerHandler', _WriteCloseServerHandler): # pylint: disable=no-value-for-parameter _write_close(self) @sftp_test_v4 async def test_write_protect_v4(self, sftp): """Test write protect error in SFTPv4""" def _write_error(self, file_obj, offset, data): """Return read-only FS error when writing to a file""" raise OSError(errno.EROFS, 'Read-only filesystem') try: with patch('asyncssh.sftp.SFTPServer.write', _write_error): with self.assertRaises(SFTPWriteProtect): async with sftp.open('file', 'wb') as f: await f.write(b'\0') finally: remove('file') @sftp_test_v4 async def test_no_media_v4(self, sftp): """Test no media error in SFTPv4""" def _write_error(self, file_obj, offset, data): """Return read-only FS error when writing to a file""" raise SFTPNoMedia('No media in requested drive') try: with patch('asyncssh.sftp.SFTPServer.write', _write_error): with self.assertRaises(SFTPNoMedia): async with sftp.open('file', 'wb') as f: await f.write(b'\0') finally: remove('file') @sftp_test_v5 async def test_no_space_v5(self, sftp): """Test no space on filesystem error in SFTPv5""" def _write_error(self, file_obj, offset, data): """Return no space error when writing to a file""" raise OSError(errno.ENOSPC, 'No space left on device') try: with patch('asyncssh.sftp.SFTPServer.write', _write_error): with self.assertRaises(SFTPNoSpaceOnFilesystem): async with sftp.open('file', 'wb') as f: await f.write(b'\0') finally: remove('file') @sftp_test_v5 async def test_quota_exceeded_v5(self, sftp): """Test quota exceeded error in SFTPv5""" def _write_error(self, file_obj, offset, data): """Return quota exceeded error when writing to a file""" raise OSError(errno.EDQUOT, 'Disk quota exceeded') try: with patch('asyncssh.sftp.SFTPServer.write', _write_error): with self.assertRaises(SFTPQuotaExceeded): async with sftp.open('file', 'wb') as f: await f.write(b'\0') finally: remove('file') @sftp_test_v5 async def test_unknown_principal_v5(self, sftp): """Test unknown principal error in SFTPv5""" def _open56_error(self, path, desired_access, flags, attrs): """Return 
unknown principal error when opening a file""" raise SFTPUnknownPrincipal('Unknown principal', unknown_names=(attrs.owner, attrs.group or b'\xff')) try: with patch('asyncssh.sftp.SFTPServer.open56', _open56_error): with self.assertRaises(SFTPUnknownPrincipal): await sftp.open('file', 'wb', SFTPAttrs(owner='aaa', group='bbb')) with self.assertRaises(SFTPBadMessage): await sftp.open('file', 'wb', SFTPAttrs(owner=b'aaa', group='')) finally: remove('file') @sftp_test_v5 async def test_lock_conflict_v5(self, sftp): """Test lock conflict error in SFTPv5""" def _open56_error(self, path, desired_access, flags, attrs): """Return lock conflict error when opening a file""" raise SFTPLockConflict('Lock conflict') try: with patch('asyncssh.sftp.SFTPServer.open56', _open56_error): with self.assertRaises(SFTPLockConflict): await sftp.open56('file', ACE4_WRITE_DATA, FXF_WRITE | FXF_CREATE_TRUNCATE) finally: remove('file') @sftp_test_v6 async def test_cannot_delete_v6(self, sftp): """Test cannot delete error in SFTPv6""" def _remove_error(self, path): """Return cannot delete error when removing a file""" raise SFTPCannotDelete('Cannot delete file') with patch('asyncssh.sftp.SFTPServer.remove', _remove_error): with self.assertRaises(SFTPCannotDelete): await sftp.remove('file') @sftp_test_v6 async def test_byte_range_lock_conflict_v6(self, sftp): """Test byte range lock conflict error in SFTPv6""" def _lock_error(self, file_obj, offset, length, flags): """Return byte range lock conflict error""" raise SFTPByteRangeLockConflict('Byte range lock conflict') f = None try: with patch('asyncssh.sftp.SFTPServer.lock', _lock_error): with self.assertRaises(SFTPByteRangeLockConflict): async with sftp.open('file', 'wb') as f: await f.lock(0, 0, FXF_BLOCK_READ) finally: if f: # pragma: no branch await f.close() remove('file') @sftp_test_v6 async def test_byte_range_lock_refused_v6(self, sftp): """Test byte range lock refused error in SFTPv6""" def _lock_error(self, file_obj, offset, length, flags): """Return byte range lock refused error""" raise SFTPByteRangeLockRefused('Byte range lock refused') f = None try: with patch('asyncssh.sftp.SFTPServer.lock', _lock_error): with self.assertRaises(SFTPByteRangeLockRefused): async with sftp.open('file', 'wb') as f: await f.lock(0, 0, FXF_BLOCK_READ) finally: if f: # pragma: no branch await f.close() remove('file') @sftp_test_v6 async def test_delete_pending_v6(self, sftp): """Test delete pending error in SFTPv6""" def _remove_error(self, path): """Return delete pending error when removing a file""" raise SFTPDeletePending('Delete of file is pending') with patch('asyncssh.sftp.SFTPServer.remove', _remove_error): with self.assertRaises(SFTPDeletePending): await sftp.remove('file') @sftp_test_v6 async def test_file_corrupt_v6(self, sftp): """Test file corrupt error in SFTPv6""" def _open56_error(self, path, desired_access, flags, attrs): """Return file corrupt error when opening a file""" raise SFTPFileCorrupt('Filesystem is corrupt') with patch('asyncssh.sftp.SFTPServer.open56', _open56_error): with self.assertRaises(SFTPFileCorrupt): await sftp.open('file') @sftp_test_v6 async def test_byte_range_unlock_mismatch_v6(self, sftp): """Test byte range unlock mismatch error in SFTPv6""" def _unlock_error(self, file_obj, offset, length): """Return byte range unlock mismatch error""" raise SFTPNoMatchingByteRangeLock('Byte range unlock mismatch') f = None try: with patch('asyncssh.sftp.SFTPServer.unlock', _unlock_error): with self.assertRaises(SFTPNoMatchingByteRangeLock): async with 
sftp.open('file', 'wb') as f: await f.unlock(0, 0) finally: if f: # pragma: no branch await f.close() remove('file') @sftp_test async def test_log_formatting(self, sftp): """Exercise log formatting of SFTP objects""" asyncssh.set_sftp_log_level('DEBUG') with self.assertLogs(level='DEBUG'): await sftp.realpath('.') await sftp.stat('.') if sys.platform != 'win32': # pragma: no cover await sftp.statvfs('.') asyncssh.set_sftp_log_level('WARNING') @sftp_test async def test_makedirs_no_parent_perms(self, sftp): """Test creating a directory path without perms for a parent dir""" orig_mkdir = sftp.mkdir def _mkdir(path, *args, **kwargs): if path == b'/': raise SFTPPermissionDenied('') return orig_mkdir(path, *args, **kwargs) try: root = os.path.abspath(os.getcwd()) with patch.object(sftp, 'mkdir', _mkdir): await sftp.makedirs(os.path.join(root, 'dir/dir1')) self.assertTrue(os.path.isdir(os.path.join(root, 'dir/dir1'))) finally: remove('dir') @sftp_test async def test_makedirs_no_perms(self, sftp): """Test creating a directory path without perms for all parents""" root = os.path.abspath(os.getcwd()) with patch.object(sftp, 'mkdir', side_effect=SFTPPermissionDenied('')): with self.assertRaises(SFTPPermissionDenied): await sftp.makedirs(os.path.join(root, 'dir/dir1')) class _TestSFTPCallable(_CheckSFTP): """Unit tests for AsyncSSH SFTP factory being a callable""" @classmethod async def start_server(cls): """Start an SFTP server using a callable""" def sftp_factory(conn): """Return an SFTP server""" return SFTPServer(conn) return await cls.create_server(sftp_factory=sftp_factory) @sftp_test async def test_stat(self, sftp): """Test getting attributes on a file""" # pylint: disable=no-self-use await sftp.stat('.') class _TestSFTPServerProperties(_CheckSFTP): """Unit test for checking SFTP server properties""" @classmethod async def start_server(cls): """Start an SFTP server which checks channel properties""" return await cls.create_server(sftp_factory=_CheckPropSFTPServer) @asynctest async def test_properties(self): """Test SFTP server channel properties""" async with self.connect() as conn: async with conn.start_sftp_client(env={'A': 1, 'B': 2}) as sftp: files = await sftp.listdir() self.assertEqual(sorted(files), ['A', 'B']) class _TestSFTPChroot(_CheckSFTP): """Unit test for SFTP server with changed root""" @classmethod async def start_server(cls): """Start an SFTP server with a changed root""" return await cls.create_server(sftp_factory=_ChrootSFTPServer, sftp_version=6) @sftp_test async def test_chroot_copy(self, sftp): """Test copying a file to an FTP server with a changed root""" try: self._create_file('src') await sftp.put('src', 'dst') self._check_file('src', 'chroot/dst') finally: remove('src chroot/dst') @sftp_test async def test_chroot_glob(self, sftp): """Test a glob pattern match over SFTP with a changed root""" try: self._create_file('chroot/file1') self._create_file('chroot/file2') self.assertEqual(sorted(await sftp.glob('/file*')), ['/file1', '/file2']) finally: remove('chroot/file1 chroot/file2') @sftp_test async def test_chroot_realpath(self, sftp): """Test canonicalizing a path on an SFTP server with a changed root""" self.assertEqual((await sftp.realpath('/dir/../file')), '/file') self._create_file('chroot/file1') name = await sftp.realpath('/dir/..', 'file1', check=FXRP_STAT_IF_EXISTS) self.assertEqual(name.attrs.type, FILEXFER_TYPE_REGULAR) name = await sftp.realpath('/dir/..', 'file2', FXRP_STAT_IF_EXISTS) self.assertEqual(name.attrs.type, FILEXFER_TYPE_UNKNOWN) with 
self.assertRaises(SFTPNoSuchFile): await sftp.realpath('/dir', '..', 'file2', check=FXRP_STAT_ALWAYS) @sftp_test_v6 async def test_chroot_realpath_v6(self, sftp): """Test canonicalizing a path on an SFTP server with a changed root""" self.assertEqual((await sftp.realpath('/dir/../file')), '/file') self._create_file('chroot/file1') name = await sftp.realpath('/dir/..', 'file1', FXRP_STAT_IF_EXISTS) self.assertEqual(name.attrs.type, FILEXFER_TYPE_REGULAR) name = await sftp.realpath('/dir/..', 'file2', check=FXRP_STAT_IF_EXISTS) self.assertEqual(name.attrs.type, FILEXFER_TYPE_UNKNOWN) with self.assertRaises(SFTPNoSuchFile): await sftp.realpath('/dir', '..', 'file2', check=FXRP_STAT_ALWAYS) with self.assertRaises(SFTPInvalidParameter): await sftp.realpath('.', check=99) @sftp_test async def test_getcwd_and_chdir(self, sftp): """Test changing directory on an SFTP server with a changed root""" try: os.mkdir('chroot/dir') self.assertEqual((await sftp.getcwd()), '/') await sftp.chdir('dir') self.assertEqual((await sftp.getcwd()), '/dir') finally: remove('chroot/dir') @sftp_test async def test_chroot_readlink(self, sftp): """Test reading symlinks on an FTP server with a changed root""" if not self._symlink_supported: # pragma: no cover raise unittest.SkipTest('symlink not available') try: root = os.path.join(os.getcwd(), 'chroot') os.symlink(root, 'chroot/link1') os.symlink(os.path.join(root, 'file'), 'chroot/link2') os.symlink('/xxx', 'chroot/link3') self.assertEqual((await sftp.readlink('link1')), '/') self.assertEqual((await sftp.readlink('link2')), '/file') with self.assertRaises(SFTPNoSuchFile): await sftp.readlink('link3') finally: remove('chroot/link1 chroot/link2 chroot/link3') @sftp_test async def test_chroot_symlink(self, sftp): """Test setting a symlink on an SFTP server with a changed root""" if not self._symlink_supported: # pragma: no cover raise unittest.SkipTest('symlink not available') try: await sftp.symlink('/file', 'link1') await sftp.symlink('../../file', 'link2') self._check_link('chroot/link1', os.path.abspath('chroot/file')) self._check_link('chroot/link2', 'file') finally: remove('chroot/link1 chroot/link2') @sftp_test async def test_chroot_makedirs(self, sftp): """Test creating a directory path""" try: await sftp.makedirs('dir/dir1') self.assertTrue(os.path.isdir('chroot/dir')) self.assertTrue(os.path.isdir('chroot/dir/dir1')) await sftp.makedirs('dir/dir2') self.assertTrue(os.path.isdir('chroot/dir/dir2')) await sftp.makedirs('dir/dir2', exist_ok=True) self.assertTrue(os.path.isdir('chroot/dir/dir2')) with self.assertRaises(SFTPFailure): await sftp.makedirs('/dir/dir2') self._create_file('chroot/file') with self.assertRaises(SFTPFailure): await sftp.makedirs('file/dir') finally: remove('chroot/dir') @sftp_test_v6 async def test_chroot_makedirs_v6(self, sftp): """Test creating a directory path with SFTPv6""" try: await sftp.makedirs('dir/dir1') self.assertTrue(os.path.isdir('chroot/dir')) self.assertTrue(os.path.isdir('chroot/dir/dir1')) await sftp.makedirs('dir/dir2') self.assertTrue(os.path.isdir('chroot/dir/dir2')) await sftp.makedirs('dir/dir2', exist_ok=True) self.assertTrue(os.path.isdir('chroot/dir/dir2')) with self.assertRaises(SFTPFileAlreadyExists): await sftp.makedirs('/dir/dir2') self._create_file('chroot/file') with self.assertRaises(SFTPNotADirectory): await sftp.makedirs('file/dir') finally: remove('chroot/dir') class _TestSFTPReadEOFWithAttrs(_CheckSFTP): """Unit test for SFTP server read EOF flags with SFTPAttrs from fstat""" @classmethod async def 
start_server(cls): """Start an SFTP server which returns SFTPAttrs on fstat""" return await cls.create_server(sftp_factory=_SFTPAttrsSFTPServer, sftp_version=6) @sftp_test_v6 async def test_get(self, sftp): """Test copying a file over SFTP""" try: self._create_file('src') await sftp.get('src', 'dst') self._check_file('src', 'dst') finally: remove('src dst') class _TestSFTPUnknownError(_CheckSFTP): """Unit test for SFTP server returning unknown error""" @classmethod async def start_server(cls): """Start an SFTP server which returns unknown error""" return await cls.create_server(sftp_factory=_SFTPAttrsSFTPServer) @sftp_test async def test_stat_error(self, sftp): """Test error when getting attributes of a file on an SFTP server""" with self.assertRaises(SFTPError) as exc: await sftp.stat('file') self.assertEqual(exc.exception.code, 99) class _TestSFTPOpenError(_CheckSFTP): """Unit test for SFTP server returning error on file open""" @classmethod async def start_server(cls): """Start an SFTP server which returns file I/O errors""" return await cls.create_server(sftp_factory=_OpenErrorSFTPServer, sftp_version=6) @sftp_test_v6 async def test_open_error_v6(self, sftp): """Test error when opening a file on an SFTP server""" with self.assertRaises(SFTPInvalidFilename): await sftp.open('ENAMETOOLONG') with self.assertRaises(SFTPInvalidParameter): await sftp.open('EINVAL') with self.assertRaises(SFTPFailure): await sftp.open('ENXIO') class _TestSFTPIOError(_CheckSFTP): """Unit test for SFTP server returning file I/O error""" @classmethod async def start_server(cls): """Start an SFTP server which returns file I/O errors""" return await cls.create_server(sftp_factory=_IOErrorSFTPServer) def test_copy_error(self): """Test error when copying a file on an SFTP server""" @sftp_test async def _test_copy_error(self, sftp): """Test error when copying a file on an SFTP server""" try: self._create_file('src', 8*1024*1024*'\0') with self.assertRaises(SFTPFailure): await sftp.copy('src', 'dst') finally: remove('src dst') with patch('asyncssh.sftp.SFTPServerHandler._extensions', []): # pylint: disable=no-value-for-parameter _test_copy_error(self) @sftp_test async def test_read_error(self, sftp): """Test error when reading a file on an SFTP server""" try: self._create_file('file', 8*1024*1024*'\0') async with sftp.open('file') as f: with self.assertRaises(SFTPFailure): await f.read(8*1024*1024) with self.assertRaises(SFTPFailure): async for _ in await f.read_parallel(8*1024*1024): pass finally: remove('file') @sftp_test async def test_write_error(self, sftp): """Test error when writing a file on an SFTP server""" try: with self.assertRaises(SFTPFailure): async with sftp.open('file', 'w') as f: await f.write(8*1024*1024*'\0') finally: remove('file') class _TestSFTPSmallBlockSize(_CheckSFTP): """Unit test for SFTP server returning file I/O error""" @classmethod async def start_server(cls): """Start an SFTP server which returns file I/O errors""" return (await cls.create_server( sftp_factory=_SmallBlockSizeSFTPServer)) @sftp_test async def test_read(self, sftp): """Test a large read on a server with a small block size""" try: data = os.urandom(65536) self._create_file('file', data) async with sftp.open('file', 'rb', block_size=16384) as f: result = await f.read(65536, 16384) self.assertEqual(result, data[16384:]) finally: remove('file') @sftp_test async def test_get(self, sftp): """Test getting a file from an SFTP server with a small block size""" try: data = os.urandom(8*1024*1024) self._create_file('src', data) 
await sftp.get('src', 'dst') self._check_file('src', 'dst') finally: remove('src dst') class _TestSFTPEOFDuringCopy(_CheckSFTP): """Unit test for SFTP server returning EOF during a file copy""" @classmethod async def start_server(cls): """Start an SFTP server which truncates files when accessed""" return await cls.create_server(sftp_factory=_TruncateSFTPServer) @sftp_test async def test_get(self, sftp): """Test getting a file from an SFTP server truncated during the copy""" try: self._create_file('src', 8*1024*1024*'\0') with self.assertRaises(SFTPFailure): await sftp.get('src', 'dst') finally: remove('src dst') class _TestSFTPNotImplemented(_CheckSFTP): """Unit test for SFTP server returning not-implemented error""" @classmethod async def start_server(cls): """Start an SFTP server which returns not-implemented errors""" return await cls.create_server(sftp_factory=_NotImplSFTPServer) @sftp_test async def test_symlink_error(self, sftp): """Test error when creating a symbolic link on an SFTP server""" with self.assertRaises(SFTPOpUnsupported): await sftp.symlink('file', 'link') class _TestSFTPFileType(_CheckSFTP): """Unit test for SFTP server formatting directory listings""" @classmethod async def start_server(cls): """Start an SFTP server which returns a fixed directory listing""" return await cls.create_server(sftp_factory=_FileTypeSFTPServer) @sftp_test async def test_filetype(self, sftp): """Test permission to filetype conversion in SFTP readdir call""" for file in await sftp.readdir('/'): self.assertEqual(file.filename, str(file.attrs.type)) class _TestSFTPLongname(_CheckSFTP): """Unit test for SFTP server formatting directory listings""" @classmethod async def start_server(cls): """Start an SFTP server which returns a fixed directory listing""" return await cls.create_server(sftp_factory=_LongnameSFTPServer) @sftp_test async def test_longname(self, sftp): """Test long name formatting in SFTP readdir call""" for file in await sftp.readdir('/'): self.assertEqual(file.longname[56:], file.filename) @sftp_test async def test_glob_hidden(self, sftp): """Test a glob pattern match on hidden files""" self.assertEqual((await sftp.glob('/.*')), ['/.file']) @unittest.skipIf(sys.platform == 'win32', 'skip uid/gid tests on Windows') @sftp_test async def test_getpwuid_error(self, sftp): """Test long name formatting where user name can't be resolved""" with patch('pwd.getpwuid', _getpwuid_error): result = await sftp.readdir('/') self.assertEqual(result[3].longname[16:24], ' ') self.assertEqual(result[4].longname[16:24], '0 ') @unittest.skipIf(sys.platform == 'win32', 'skip uid/gid tests on Windows') @sftp_test async def test_getgrgid_error(self, sftp): """Test long name formatting where group name can't be resolved""" with patch('grp.getgrgid', _getgrgid_error): result = await sftp.readdir('/') self.assertEqual(result[3].longname[25:33], ' ') self.assertEqual(result[4].longname[25:33], '0 ') @sftp_test async def test_strftime_error(self, sftp): """Test long name formatting with strftime not supporting %e""" orig_strftime = time.strftime def strftime_error(fmt, t): """Simulate Windows srtftime that doesn't support %e""" if '%e' in fmt: raise ValueError else: return orig_strftime(fmt, t) with patch('time.strftime', strftime_error): result = await sftp.readdir('/') self.assertEqual(result[3].longname[51:55], ' ') self.assertIn(result[4].longname[51:55], ('1969', '1970')) class _TestSFTPLargeListDir(_CheckSFTP): """Unit test for SFTP server returning large listdir result""" @classmethod async def 
start_server(cls): """Start an SFTP server which returns file I/O errors""" return await cls.create_server(sftp_factory=_LargeDirSFTPServer) @sftp_test async def test_large_listdir(self, sftp): """Test large listdir result""" self.assertEqual(len(await sftp.readdir('/')), 100000) @unittest.skipIf(sys.platform == 'win32', 'skip statvfs tests on Windows') class _TestSFTPStatVFS(_CheckSFTP): """Unit test for SFTP server filesystem attributes""" @classmethod async def start_server(cls): """Start an SFTP server which returns fixed filesystem attrs""" return await cls.create_server(sftp_factory=_StatVFSSFTPServer) def _check_statvfs(self, sftp_statvfs): """Check if filesystem attributes are equal""" expected_statvfs = _StatVFSSFTPServer.expected_statvfs self.assertEqual(sftp_statvfs.bsize, expected_statvfs.bsize) self.assertEqual(sftp_statvfs.frsize, expected_statvfs.frsize) self.assertEqual(sftp_statvfs.blocks, expected_statvfs.blocks) self.assertEqual(sftp_statvfs.bfree, expected_statvfs.bfree) self.assertEqual(sftp_statvfs.bavail, expected_statvfs.bavail) self.assertEqual(sftp_statvfs.files, expected_statvfs.files) self.assertEqual(sftp_statvfs.ffree, expected_statvfs.ffree) self.assertEqual(sftp_statvfs.favail, expected_statvfs.favail) self.assertEqual(sftp_statvfs.fsid, expected_statvfs.fsid) self.assertEqual(sftp_statvfs.flags, expected_statvfs.flags) self.assertEqual(sftp_statvfs.namemax, expected_statvfs.namemax) self.assertEqual(repr(sftp_statvfs), repr(expected_statvfs)) @sftp_test async def test_statvfs(self, sftp): """Test getting attributes on a filesystem""" self._check_statvfs(await sftp.statvfs('.')) @sftp_test async def test_file_statvfs(self, sftp): """Test getting attributes on the filesystem containing an open file""" f = None try: self._create_file('file') f = await sftp.open('file') self._check_statvfs(await f.statvfs()) finally: if f: # pragma: no branch await f.close() remove('file') @unittest.skipIf(sys.platform == 'win32', 'skip chown tests on Windows') class _TestSFTPChown(_CheckSFTP): """Unit test for SFTP server file ownership""" @classmethod async def start_server(cls): """Start an SFTP server which simulates file ownership changes""" return await cls.create_server(sftp_factory=_ChownSFTPServer, sftp_version=6) @sftp_test async def test_chown(self, sftp): """Test changing ownership of a file""" try: self._create_file('file') await sftp.chown('file', 1, 2) attrs = await sftp.stat('file') self.assertEqual(attrs.uid, 1) self.assertEqual(attrs.gid, 2) finally: remove('file') @sftp_test_v4 async def test_chown_v4(self, sftp): """Test changing ownership of a file with SFTPv4""" try: self._create_file('file') await sftp.chown('file', owner='root', group='wheel') attrs = await sftp.stat('file') self.assertEqual(attrs.owner, 'root') self.assertEqual(attrs.group, 'wheel') finally: remove('file') class _TestSFTPAttrs(unittest.TestCase): """Unit test for SFTPAttrs object""" def test_attrs(self): """Test encoding and decoding of SFTP attributes""" for kwargs in ({'size': 1234}, {'uid': 1, 'gid': 2}, {'permissions': 0o7777}, {'atime': 1, 'mtime': 2}, {'extended': [(b'a1', b'v1'), (b'a2', b'v2')]}): attrs = SFTPAttrs(**kwargs) packet = SSHPacket(attrs.encode(3)) self.assertEqual(repr(SFTPAttrs.decode(packet, 3)), repr(attrs)) for kwargs in ({'type': FILEXFER_TYPE_REGULAR}, {'size': 1234}, {'owner': 'a', 'group': 'b'}, {'permissions': 0o7777}, {'atime': 1, 'atime_ns': 2}, {'crtime': 3, 'crtime_ns': 4}, {'mtime': 5, 'mtime_ns': 6}, {'atime': 7, 'crtime': 8, 'mtime': 9}, {'acl': 
b''}): attrs = SFTPAttrs(**kwargs) packet = SSHPacket(attrs.encode(4)) self.assertEqual(repr(SFTPAttrs.decode(packet, 4)), repr(attrs)) packet = SSHPacket(SFTPAttrs(uid=1, gid=2).encode(4)) self.assertEqual(repr(SFTPAttrs.decode(packet, 4)), repr(SFTPAttrs(owner='1', group='2'))) for kwargs in ({'type': FILEXFER_TYPE_REGULAR}, {'size': 1234}, {'owner': 'a', 'group': 'b'}, {'permissions': 0o7777}, {'atime': 1, 'atime_ns': 2}, {'crtime': 3, 'crtime_ns': 4}, {'mtime': 5, 'mtime_ns': 6}, {'atime': 7, 'crtime': 8, 'mtime': 9}, {'acl': b''}, {'attrib_bits': FILEXFER_ATTR_BITS_READONLY, 'attrib_valid': FILEXFER_ATTR_BITS_READONLY}): attrs = SFTPAttrs(**kwargs) packet = SSHPacket(attrs.encode(5)) self.assertEqual(repr(SFTPAttrs.decode(packet, 5)), repr(attrs)) for kwargs in ({'type': FILEXFER_TYPE_REGULAR}, {'size': 1234, 'alloc_size': 5678}, {'owner': 'a', 'group': 'b'}, {'permissions': 0o7777}, {'atime': 1, 'atime_ns': 2}, {'crtime': 3, 'crtime_ns': 4}, {'mtime': 5, 'mtime_ns': 6}, {'ctime': 7, 'ctime_ns': 8}, {'atime': 7, 'crtime': 8, 'mtime': 9, 'ctime': 10}, {'acl': b''}, {'attrib_bits': FILEXFER_ATTR_BITS_READONLY, 'attrib_valid': FILEXFER_ATTR_BITS_READONLY}, {'text_hint': FILEXFER_ATTR_KNOWN_TEXT}, {'mime_type': 'application/octet-stream'}, {'untrans_name': b'\xff'}, {'extended': [(b'a1', b'v1'), (b'a2', b'v2')]}): attrs = SFTPAttrs(**kwargs) packet = SSHPacket(attrs.encode(6)) self.assertEqual(repr(SFTPAttrs.decode(packet, 6)), repr(attrs)) def test_illegal_attrs(self): """Test decoding illegal SFTP attributes value""" with self.assertRaises(SFTPBadMessage): SFTPAttrs.decode(SSHPacket(UInt32(FILEXFER_ATTR_OWNERGROUP)), 3) for version in range(4, 7): with self.assertRaises(SFTPBadMessage): SFTPAttrs.decode(SSHPacket( UInt32(FILEXFER_ATTR_UIDGID)), version) with self.assertRaises(SFTPOwnerInvalid): SFTPAttrs.decode(SSHPacket( SFTPAttrs(owner=b'\xff', group='').encode(6)), 6) with self.assertRaises(SFTPGroupInvalid): SFTPAttrs.decode(SSHPacket( SFTPAttrs(owner='', group=b'\xff').encode(6)), 6) with self.assertRaises(SFTPBadMessage): SFTPAttrs.decode(SSHPacket( SFTPAttrs(mime_type=b'\xff').encode(6)), 6) class _TestSFTPNonstandardSymlink(_CheckSFTP): """Unit tests for SFTP server with non-standard symlink order""" @classmethod async def start_server(cls): """Start an SFTP server for the tests to use""" return await cls.create_server(server_version='OpenSSH', sftp_factory=_SymlinkSFTPServer) @asynctest async def test_nonstandard_symlink_client(self): """Test creating a symlink with opposite argument order""" if not self._symlink_supported: # pragma: no cover raise unittest.SkipTest('symlink not available') try: async with self.connect(client_version='OpenSSH') as conn: async with conn.start_sftp_client() as sftp: await sftp.symlink('link', 'file') self._check_link('link', 'file') finally: remove('file link') class _TestSFTPAsync(_TestSFTP): """Unit test for an async SFTPServer""" @classmethod async def start_server(cls): """Start an SFTP server with async callbacks""" return await cls.create_server(sftp_factory=_AsyncSFTPServer, sftp_version=6) @sftp_test async def test_async_realpath(self, sftp): """Test canonicalizing a path on an async SFTP server""" self.assertEqual((await sftp.realpath('dir/../file')), posixpath.join((await sftp.getcwd()), 'file')) @sftp_test_v6 async def test_async_realpath_v6(self, sftp): """Test canonicalizing a path on an async SFTPv6 server""" self._create_file('file1') self.assertEqual((await sftp.realpath('dir/../file')), posixpath.join((await sftp.getcwd()), 
'file')) name = await sftp.realpath('dir/../file1', check=FXRP_STAT_ALWAYS) self.assertEqual(name.attrs.type, FILEXFER_TYPE_REGULAR) class _CheckSCP(_CheckSFTP): """Utility functions for AsyncSSH SCP unit tests""" @classmethod async def asyncSetUpClass(cls): """Set up SCP target host/port tuple""" await super().asyncSetUpClass() cls._scp_server = (cls._server_addr, cls._server_port) @classmethod async def start_server(cls): """Start an SFTP server with SCP enabled for the tests to use""" return await cls.create_server(sftp_factory=True, allow_scp=True) async def _check_scp(self, src, dst, data=(), **kwargs): """Check copying a file over SCP""" try: self._create_file('src', data) await scp(src, dst, **kwargs) self._check_file('src', 'dst') finally: remove('src dst') async def _check_progress(self, src, dst): """Check copying a file over SCP with progress reporting""" def _report_progress(_srcpath, _dstpath, bytes_copied, _total_bytes): """Monitor progress of copy""" reports.append(bytes_copied) for size in (0, 100000): with self.subTest(size=size): reports = [] await self._check_scp(src, dst, size * 'a', block_size=8192, progress_handler=_report_progress) self.assertEqual(len(reports), (size // 8192) + 1) self.assertEqual(reports[-1], size) async def _check_cancel(self, src, dst): """Check cancelling a file transfer over SCP""" def _cancel(_srcpath, _dstpath, _bytes_copied, _total_bytes): """Cancel transfer""" task.cancel() try: self._create_file('src', 1024*8192 * 'a') coro = scp(src, dst, block_size=8192, progress_handler=_cancel) task = asyncio.create_task(coro) await task finally: remove('src dst') class _TestSCP(_CheckSCP): """Unit tests for AsyncSSH SCP client and server""" @asynctest async def test_get(self): """Test getting a file over SCP""" for src in ('src', b'src', Path('src')): for dst in ('dst', b'dst', Path('dst')): with self.subTest(src=type(src), dst=type(dst)): await self._check_scp((self._scp_server, src), dst) @asynctest async def test_get_progress(self): """Test getting a file over SCP with progress reporting""" await self._check_progress((self._scp_server, 'src'), 'dst') @asynctest async def test_get_cancel(self): """Test cancelling a get of a file over SCP""" await self._check_cancel((self._scp_server, 'src'), 'dst') @asynctest async def test_get_preserve(self): """Test getting a file with preserved attributes over SCP""" try: self._create_file('src', utime=(1, 2)) await scp((self._scp_server, 'src'), 'dst', preserve=True) self._check_file('src', 'dst', preserve=True, check_atime=False) finally: remove('src dst') @asynctest async def test_get_recurse(self): """Test recursively getting a directory over SCP""" try: os.mkdir('src') self._create_file('src/file1') await scp((self._scp_server, 'src'), 'dst', recurse=True) self._check_file('src/file1', 'dst/file1') finally: remove('src dst') @asynctest async def test_get_error_handler(self): """Test getting multiple files over SCP with error handler""" def err_handler(exc): """Catch error for non-recursive copy of directory""" self.assertEqual(exc.reason, 'scp: Not a regular file: src2') try: self._create_file('src1') os.mkdir('src2') os.mkdir('dst') await scp((self._scp_server, 'src*'), 'dst', error_handler=err_handler) self._check_file('src1', 'dst/src1') finally: remove('src1 src2 dst') @asynctest async def test_get_recurse_existing(self): """Test getting a directory over SCP where target dir exists""" try: os.mkdir('src') os.mkdir('dst') os.mkdir('dst/src') self._create_file('src/file1') await scp((self._scp_server, 
'src'), 'dst', recurse=True) self._check_file('src/file1', 'dst/src/file1') finally: remove('src dst') @unittest.skipIf(sys.platform == 'win32', 'skip permission tests on Windows') @asynctest async def test_get_not_permitted(self): """Test getting a file with no read permissions over SCP""" try: self._create_file('src', mode=0) with self.assertRaises(SFTPFailure): await scp((self._scp_server, 'src'), 'dst') finally: remove('src dst') @asynctest async def test_get_directory_as_file(self): """Test getting a file which is actually a directory over SCP""" try: os.mkdir('src') with self.assertRaises(SFTPFailure): await scp((self._scp_server, 'src'), 'dst') finally: remove('src dst') @asynctest async def test_get_non_directory_in_path(self): """Test getting a file with a non-directory in path over SCP""" try: self._create_file('src') with self.assertRaises(SFTPFailure): await scp((self._scp_server, 'src/xxx'), 'dst') finally: remove('src dst') @asynctest async def test_get_recurse_not_directory(self): """Test getting a directory over SCP where target is not directory""" try: os.mkdir('src') self._create_file('dst') self._create_file('src/file1') with self.assertRaises(SFTPFailure): await scp((self._scp_server, 'src'), 'dst', recurse=True) finally: remove('src dst') @asynctest async def test_put(self): """Test putting a file over SCP""" for src in ('src', b'src', Path('src')): for dst in ('dst', b'dst', Path('dst')): with self.subTest(src=type(src), dst=type(dst)): await self._check_scp(src, (self._scp_server, dst)) @asynctest async def test_put_progress(self): """Test putting a file over SCP with progress reporting""" await self._check_progress('src', (self._scp_server, 'dst')) @asynctest async def test_put_cancel(self): """Test cancelling a put of a file over SCP""" await self._check_cancel('src', (self._scp_server, 'dst')) @asynctest async def test_put_preserve(self): """Test putting a file with preserved attributes over SCP""" try: self._create_file('src', utime=(1, 2)) await scp('src', (self._scp_server, 'dst'), preserve=True) self._check_file('src', 'dst', preserve=True, check_atime=False) finally: remove('src dst') @asynctest async def test_put_recurse(self): """Test recursively putting a directory over SCP""" try: os.mkdir('src') self._create_file('src/file1') await scp('src', (self._scp_server, 'dst'), recurse=True) self._check_file('src/file1', 'dst/file1') finally: remove('src dst') @asynctest async def test_put_recurse_existing(self): """Test putting a directory over SCP where target dir exists""" try: os.mkdir('src') os.mkdir('dst') self._create_file('src/file1') await scp('src', (self._scp_server, 'dst'), recurse=True) self._check_file('src/file1', 'dst/src/file1') finally: remove('src dst') @asynctest async def test_put_must_be_dir(self): """Test putting multiple files to a non-directory over SCP""" try: self._create_file('src1') self._create_file('src2') self._create_file('dst') with self.assertRaises(SFTPFailure): await scp(['src1', 'src2'], (self._scp_server, 'dst')) finally: remove('src1 src2 dst') @asynctest async def test_put_non_directory_in_path(self): """Test putting a file with a non-directory in path over SCP""" try: self._create_file('src') with self.assertRaises(OSError): await scp('src/xxx', (self._scp_server, 'dst')) finally: remove('src') @asynctest async def test_put_recurse_not_directory(self): """Test putting a directory over SCP where target is not directory""" try: os.mkdir('src') self._create_file('dst') self._create_file('src/file1') with 
self.assertRaises(SFTPFailure): await scp('src', (self._scp_server, 'dst'), recurse=True) finally: remove('src dst') @asynctest async def test_put_read_error(self): """Test read errors when putting a file over SCP""" async def _read_error(self, size, offset): """Return an error for reads past 4 MB in a file""" if offset >= 4*1024*1024: raise OSError(errno.EIO, 'I/O error') else: return await orig_read(self, size, offset) try: self._create_file('src', 8*1024*1024*'\0') orig_read = LocalFile.read with patch('asyncssh.sftp.LocalFile.read', _read_error): with self.assertRaises(OSError): await scp('src', (self._scp_server, 'dst')) finally: remove('src dst') @asynctest async def test_put_read_early_eof(self): """Test getting early EOF when putting a file over SCP""" async def _read_early_eof(self, size, offset): """Return an early EOF for reads past 4 MB in a file""" if offset >= 4*1024*1024: return b'' else: return await orig_read(self, size, offset) try: self._create_file('src', 8*1024*1024*'\0') orig_read = LocalFile.read with patch('asyncssh.sftp.LocalFile.read', _read_early_eof): with self.assertRaises(SFTPFailure): await scp('src', (self._scp_server, 'dst')) finally: remove('src dst') @asynctest async def test_put_name_too_long(self): """Test putting a file over SCP with too long a name""" try: self._create_file('src') with self.assertRaises(SFTPFailure): await scp('src', (self._scp_server, 256*'a')) finally: remove('src dst') @asynctest async def test_copy(self): """Test copying a file between remote hosts over SCP""" for src in ('src', b'src', Path('src')): for dst in ('dst', b'dst', Path('dst')): with self.subTest(src=type(src), dst=type(dst)): await self._check_scp((self._scp_server, src), (self._scp_server, dst)) @asynctest async def test_copy_progress(self): """Test copying a file over SCP with progress reporting""" await self._check_progress((self._scp_server, 'src'), (self._scp_server, 'dst')) @asynctest async def test_copy_cancel(self): """Test cancelling a copy of a file over SCP""" await self._check_cancel((self._scp_server, 'src'), (self._scp_server, 'dst')) @asynctest async def test_copy_preserve(self): """Test copying a file with preserved attributes between hosts""" try: self._create_file('src', utime=(1, 2)) await scp((self._scp_server, 'src'), (self._scp_server, 'dst'), preserve=True) self._check_file('src', 'dst', preserve=True, check_atime=False) finally: remove('src dst') @asynctest async def test_copy_recurse(self): """Test recursively copying a directory between hosts over SCP""" try: os.mkdir('src') self._create_file('src/file1') await scp((self._scp_server, 'src'), (self._scp_server, 'dst'), recurse=True) self._check_file('src/file1', 'dst/file1') finally: remove('src dst') @asynctest async def test_copy_error_handler_source(self): """Test copying multiple files over SCP with error handler""" def err_handler(exc): """Catch error for non-recursive copy of directory""" self.assertEqual(exc.reason, 'scp: Not a regular file: src2') try: self._create_file('src1') os.mkdir('src2') os.mkdir('dst') await scp((self._scp_server, 'src*'), (self._scp_server, 'dst'), error_handler=err_handler) self._check_file('src1', 'dst/src1') finally: remove('src1 src2 dst') @asynctest async def test_copy_error_handler_sink(self): """Test copying multiple files over SCP with error handler""" def err_handler(exc): """Catch error for non-recursive copy of directory""" if sys.platform == 'win32': # pragma: no cover self.assertEqual(exc.reason, 'scp: Permission denied: dst\\src2') else: 
self.assertEqual(exc.reason, 'scp: Is a directory: dst/src2') try: self._create_file('src1') self._create_file('src2') os.mkdir('dst') os.mkdir('dst/src2') await scp((self._scp_server, 'src*'), (self._scp_server, 'dst'), error_handler=err_handler) self._check_file('src1', 'dst/src1') finally: remove('src1 src2 dst') @asynctest async def test_copy_recurse_existing(self): """Test copying a directory over SCP where target dir exists""" try: os.mkdir('src') os.mkdir('dst') self._create_file('src/file1') await scp((self._scp_server, 'src'), (self._scp_server, 'dst'), recurse=True) self._check_file('src/file1', 'dst/src/file1') finally: remove('src dst') @asynctest async def test_local_copy(self): """Test for error return when attempting to copy local files""" with self.assertRaises(ValueError): await scp('src', 'dst') @asynctest async def test_copy_multiple(self): """Test copying multiple files over SCP""" try: os.mkdir('src') self._create_file('src/file1') self._create_file('src/file2') await scp([(self._scp_server, 'src/file1'), (self._scp_server, 'src/file2')], '.') self._check_file('src/file1', 'file1') self._check_file('src/file2', 'file2') finally: remove('src file1 file2') @asynctest async def test_copy_recurse_not_directory(self): """Test copying a directory over SCP where target is not directory""" try: os.mkdir('src') self._create_file('dst') self._create_file('src/file1') with self.assertRaises(SFTPFailure): await scp((self._scp_server, 'src'), (self._scp_server, 'dst'), recurse=True) finally: remove('src dst') @asynctest async def test_source_string(self): """Test passing a string to SCP""" with self.assertRaises(OSError): await scp('\xff:xxx', '.') @asynctest async def test_source_bytes(self): """Test passing a byte string to SCP""" with self.assertRaises(OSError): await scp('\xff:xxx'.encode(), '.') @asynctest async def test_source_open_connection(self): """Test passing an open SSHClientConnection to SCP as source""" try: async with self.connect() as conn: self._create_file('src') await scp((conn, 'src'), 'dst') self._check_file('src', 'dst') finally: remove('src dst') @asynctest async def test_destination_open_connection(self): """Test passing an open SSHClientConnection to SCP as destination""" try: async with self.connect() as conn: os.mkdir('src') self._create_file('src/file1') await scp('src/file1', conn) self._check_file('src/file1', 'file1') finally: remove('src file1') @asynctest async def test_missing_path(self): """Test running SCP with missing path""" async with self.connect() as conn: result = await conn.run('scp ') self.assertEqual(result.stderr, 'scp: the following arguments ' 'are required: path\n') @asynctest async def test_missing_direction(self): """Test running SCP with missing direction argument""" async with self.connect() as conn: result = await conn.run('scp xxx') self.assertEqual(result.stderr, 'scp: one of the arguments -f -t ' 'is required\n') @asynctest async def test_invalid_argument(self): """Test running SCP with invalid argument""" async with self.connect() as conn: result = await conn.run('scp -f -x src') self.assertEqual(result.stderr, 'scp: unrecognized arguments: -x\n') @asynctest async def test_invalid_c_argument(self): """Test running SCP with invalid argument to C request""" async with self.connect() as conn: result = await conn.run('scp -t dst', input='C\n') self.assertEqual(result.stdout, '\0\x01scp: Invalid copy or dir request\n') @asynctest async def test_invalid_t_argument(self): """Test running SCP with invalid argument to C request""" 
async with self.connect() as conn: result = await conn.run('scp -t -p dst', input='T\n') self.assertEqual(result.stdout, '\0\x01scp: Invalid time request\n') class _TestSCPAsync(_TestSCP): """Unit test for AsyncSSH SCP using an async SFTPServer""" @classmethod async def start_server(cls): """Start an SFTP server with async callbacks""" return await cls.create_server(sftp_factory=_AsyncSFTPServer, allow_scp=True) class _TestSCPAttrs(_CheckSCP): """Unit test for SCP with SFTP server returning SFTPAttrs""" @classmethod async def start_server(cls): """Start an SFTP server which returns SFTPAttrs from stat""" return await cls.create_server(sftp_factory=_SFTPAttrsSFTPServer, allow_scp=True) @asynctest async def test_get(self): """Test getting a file over SCP with stat returning SFTPAttrs""" try: self._create_file('src') await scp((self._scp_server, 'src*'), 'dst') self._check_file('src', 'dst') finally: remove('src dst') @asynctest async def test_put_recurse_not_directory(self): """Test putting a directory over SCP where target is not directory""" try: os.mkdir('src') self._create_file('dst') self._create_file('src/file1') with self.assertRaises(SFTPFailure): await scp('src', (self._scp_server, 'dst'), recurse=True) finally: remove('src dst') @asynctest async def test_put_not_permitted(self): """Test putting a file over SCP onto an unwritable target""" try: self._create_file('src') os.mkdir('dst') os.chmod('dst', 0) with self.assertRaises(SFTPFailure): await scp('src', (self._scp_server, 'dst/src')) finally: os.chmod('dst', 0o755) remove('src dst') class _TestSCPIOError(_CheckSCP): """Unit test for SCP with SFTP server returning file I/O error""" @classmethod async def start_server(cls): """Start an SFTP server which returns file I/O errors""" return await cls.create_server(sftp_factory=_IOErrorSFTPServer, allow_scp=True) @asynctest async def test_put_error(self): """Test error when putting a file over SCP""" try: self._create_file('src', 8*1024*1024*'\0') with self.assertRaises(SFTPFailure): await scp('src', (self._scp_server, 'dst')) finally: remove('src dst') @asynctest async def test_copy_error(self): """Test error when copying a file over SCP""" try: self._create_file('src', 8*1024*1024*'\0') with self.assertRaises(SFTPFailure): await scp((self._scp_server, 'src'), (self._scp_server, 'dst')) finally: remove('src dst') class _TestSCPErrors(_CheckSCP): """Unit test for SCP returning error on startup""" @classmethod async def start_server(cls): """Start an SFTP server which returns file I/O errors""" async def _handle_client(process): """Handle new client""" async with process: command = process.command if command.endswith('get_connection_lost'): pass elif command.endswith('get_dir_no_recurse'): await process.stdin.read(1) process.stdout.write('D0755 0 src\n') elif command.endswith('get_early_eof'): await process.stdin.read(1) process.stdout.write('C0644 10 src\n') await process.stdin.read(1) elif command.endswith('get_extra_e'): await process.stdin.read(1) process.stdout.write('E\n') await process.stdin.read(1) elif command.endswith('get_t_without_preserve'): await process.stdin.read(1) process.stdout.write('T0 0 0 0\n') await process.stdin.read(1) elif command.endswith('get_unknown_action'): await process.stdin.read(1) process.stdout.write('X\n') await process.stdin.read(1) elif command.endswith('put_connection_lost'): process.stdout.write('\0\0') elif command.endswith('put_startup_error'): process.stdout.write('Error starting SCP\n') elif command.endswith('recv_early_eof'): 
process.stdout.write('\0') await process.stdin.readline() try: process.stdout.write('\0') except BrokenPipeError: pass else: process.exit(255) return await cls.create_server(process_factory=_handle_client) @asynctest async def test_get_directory_without_recurse(self): """Test receiving directory when recurse wasn't requested""" try: with self.assertRaises((SFTPBadMessage, SFTPConnectionLost)): await scp((self._scp_server, 'get_dir_no_recurse'), 'dst') finally: remove('dst') @asynctest async def test_get_early_eof(self): """Test getting early EOF when getting a file over SCP""" try: with self.assertRaises(SFTPConnectionLost): await scp((self._scp_server, 'get_early_eof'), 'dst') finally: remove('dst') @asynctest async def test_get_t_without_preserve(self): """Test getting timestamps with requesting preserve""" try: await scp((self._scp_server, 'get_t_without_preserve'), 'dst') finally: remove('dst') @asynctest async def test_get_unknown_action(self): """Test getting unknown action from SCP server during get""" try: with self.assertRaises(SFTPBadMessage): await scp((self._scp_server, 'get_unknown_action'), 'dst') finally: remove('dst') @asynctest async def test_put_startup_error(self): """Test SCP server returning an error on startup""" try: self._create_file('src') with self.assertRaises(SFTPFailure) as exc: await scp('src', (self._scp_server, 'put_startup_error')) self.assertEqual(exc.exception.reason, 'Error starting SCP') finally: remove('src') @asynctest async def test_put_connection_lost(self): """Test SCP server abruptly closing connection on put""" try: self._create_file('src') with self.assertRaises(SFTPConnectionLost) as exc: await scp('src', (self._scp_server, 'put_connection_lost')) self.assertEqual(exc.exception.reason, 'Connection lost') finally: remove('src') @asynctest async def test_copy_connection_lost_source(self): """Test source abruptly closing connection during SCP copy""" with self.assertRaises(SFTPConnectionLost) as exc: await scp((self._scp_server, 'get_connection_lost'), (self._scp_server, 'recv_early_eof')) self.assertEqual(exc.exception.reason, 'Connection lost') @asynctest async def test_copy_connection_lost_sink(self): """Test sink abruptly closing connection during SCP copy""" with self.assertRaises(SFTPConnectionLost) as exc: await scp((self._scp_server, 'get_early_eof'), (self._scp_server, 'put_connection_lost')) self.assertEqual(exc.exception.reason, 'Connection lost') @asynctest async def test_copy_early_eof(self): """Test getting early EOF when copying a file over SCP""" with self.assertRaises(SFTPConnectionLost): await scp((self._scp_server, 'get_early_eof'), (self._scp_server, 'recv_early_eof')) @asynctest async def test_copy_extra_e(self): """Test getting extra E when copying a file over SCP""" await scp((self._scp_server, 'get_extra_e'), (self._scp_server, 'recv_early_eof')) @asynctest async def test_copy_unknown_action(self): """Test getting unknown action from SCP server during copy""" with self.assertRaises(SFTPBadMessage): await scp((self._scp_server, 'get_unknown_action'), (self._scp_server, 'recv_early_eof')) @asynctest async def test_unknown(self): """Test unknown SCP server request for code coverage""" with self.assertRaises(SFTPConnectionLost): await scp('src', (self._scp_server, 'unknown')) asyncssh-2.20.0/tests/test_sk.py000066400000000000000000000335011475467777400166540ustar00rootroot00000000000000# Copyright (c) 2019-2020 by Ron Frederick and others. 
# # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation """Unit tests for AsyncSSH security key support""" import unittest import asyncssh from .server import ServerTestCase from .sk_stub import sk_available, stub_sk, unstub_sk, patch_sk, sk_error from .util import asynctest, get_test_key class _CheckSKAuth(ServerTestCase): """Common code for testing security key authentication""" _sk_devs = [2] _sk_alg = 'sk-ssh-ed25519@openssh.com' _sk_resident = False _sk_touch_required = True _sk_auth_touch_required = True _sk_use_webauthn = False _sk_cert = False _sk_host = False @classmethod async def start_server(cls): """Start an SSH server which supports security key authentication""" cls.addClassCleanup(unstub_sk, *stub_sk(cls._sk_devs, cls._sk_use_webauthn)) cls._privkey = get_test_key( cls._sk_alg, cls._sk_use_webauthn, resident=cls._sk_resident, touch_required=cls._sk_touch_required) if cls._sk_host: if cls._sk_cert: cert = cls._privkey.generate_host_certificate( cls._privkey, 'localhost', principals=['127.0.0.1']) key = (cls._privkey, cert) else: key = cls._privkey return await cls.create_server(server_host_keys=[key]) else: options = [] if cls._sk_cert: options.append('cert-authority') if not cls._sk_auth_touch_required: options.append('no-touch-required') auth_keys = asyncssh.import_authorized_keys( ','.join(options) + (' ' if options else '') + cls._privkey.export_public_key().decode('utf-8')) return await cls.create_server(authorized_client_keys=auth_keys) @unittest.skipUnless(sk_available, 'security key support not available') class _TestSKAuthKeyNotFound(ServerTestCase): """Unit tests for security key authentication with no key found""" @patch_sk([]) @asynctest async def test_enroll_key_not_found(self): """Test generating key with no security key found""" with self.assertRaises(ValueError): asyncssh.generate_private_key('sk-ssh-ed25519@openssh.com') @unittest.skipUnless(sk_available, 'security key support not available') class _TestSKAuthCTAP1(_CheckSKAuth): """Unit tests for security key authentication with CTAP version 1""" _sk_devs = [1] _sk_alg = 'sk-ecdsa-sha2-nistp256@openssh.com' @asynctest async def test_auth(self): """Test authenticating with a CTAP 1 security key""" async with self.connect(username='ckey', client_keys=[self._privkey]): pass @asynctest async def test_sk_unsupported_alg(self): """Test unsupported security key algorithm""" with self.assertRaises(ValueError): asyncssh.generate_private_key('sk-ssh-ed25519@openssh.com') @asynctest async def test_enroll_ctap1_error(self): """Test generating key returning a CTAP 1 error""" with sk_error('err'): with self.assertRaises(ValueError): asyncssh.generate_private_key(self._sk_alg) @asynctest async def test_auth_ctap1_error(self): """Test security key returning a CTAP 1 error""" with sk_error('err'): with self.assertRaises(asyncssh.PermissionDenied): await self.connect(username='ckey', client_keys=[self._privkey]) @unittest.skipUnless(sk_available, 'security 
key support not available') class _TestSKAuthCTAP2(_CheckSKAuth): """Unit tests for security key authentication with CTAP version 2""" _sk_devs = [2] @asynctest async def test_auth(self): """Test authenticating with a CTAP 2 security key""" async with self.connect(username='ckey', client_keys=[self._privkey]): pass @asynctest async def test_enroll_without_pin(self): """Test generating key without a PIN""" key = get_test_key('sk-ssh-ed25519@openssh.com') self.assertIsNotNone(key) @asynctest async def test_enroll_with_pin(self): """Test generating key with a PIN""" key = get_test_key('sk-ssh-ed25519@openssh.com', pin=b'123456') self.assertIsNotNone(key) @asynctest async def test_enroll_ctap2_error(self): """Test generating key returning a CTAP 2 error""" with sk_error('err'): with self.assertRaises(ValueError): asyncssh.generate_private_key('sk-ssh-ed25519@openssh.com') @asynctest async def test_auth_ctap2_error(self): """Test security key returning a CTAP 2 error""" with sk_error('err'): with self.assertRaises(asyncssh.PermissionDenied): await self.connect(username='ckey', client_keys=[self._privkey]) @asynctest async def test_enroll_pin_invalid(self): """Test generating key while providing invalid PIN""" with sk_error('badpin'): with self.assertRaises(ValueError): asyncssh.generate_private_key('sk-ssh-ed25519@openssh.com', pin=b'123456') @asynctest async def test_enroll_pin_required(self): """Test generating key without providing a required PIN""" with sk_error('pinreq'): with self.assertRaises(ValueError): asyncssh.generate_private_key('sk-ssh-ed25519@openssh.com') @unittest.skipUnless(sk_available, 'security key support not available') class _TestSKAuthWebAuthN(_CheckSKAuth): """Unit tests for security key authentication with WebAuthN""" _sk_alg = 'sk-ecdsa-sha2-nistp256@openssh.com' _sk_use_webauthn = True @asynctest async def test_auth(self): """Test authenticating with the Windows WebAuthN API""" async with self.connect(username='ckey', client_keys=[self._privkey]): pass @unittest.skipUnless(sk_available, 'security key support not available') class _TestSKAuthMultipleKeys(_CheckSKAuth): """Unit tests for security key authentication with multiple keys""" _sk_devs = [2, 1] @asynctest async def test_auth_cred_not_found(self): """Test authenticating with security credential not found""" with sk_error('nocred'): with self.assertRaises(asyncssh.PermissionDenied): await self.connect(username='ckey', client_keys=[self._privkey]) @unittest.skipUnless(sk_available, 'security key support not available') class _TestSKAuthResidentKeys(_CheckSKAuth): """Unit tests for loading resident keys""" _sk_resident = True @asynctest async def test_load_resident(self): """Test loading resident keys""" keys = asyncssh.load_resident_keys(b'123456') async with self.connect(username='ckey', client_keys=[keys[0]]): pass @asynctest async def test_load_resident_user_match(self): """Test loading resident keys matching a specific user""" keys = asyncssh.load_resident_keys(b'123456', user='AsyncSSH') async with self.connect(username='ckey', client_keys=[keys[0]]): pass @asynctest async def test_koad_resident_user_match(self): """Test loading resident keys matching a specific user""" self.assertIsNotNone(asyncssh.load_resident_keys(b'123456', user='AsyncSSH')) @asynctest async def test_load_resident_no_match(self): """Test loading resident keys with no user match""" self.assertEqual(asyncssh.load_resident_keys(b'123456', user='xxx'), []) @asynctest async def test_no_resident_keys(self): """Test retrieving empty list of 
resident keys""" with sk_error('nocred'): self.assertEqual(asyncssh.load_resident_keys(b'123456'), []) @asynctest async def test_load_resident_ctap2_error(self): """Test getting resident keys returning a CTAP 2 error""" with sk_error('err'): with self.assertRaises(asyncssh.KeyImportError): asyncssh.load_resident_keys(b'123456') @asynctest async def test_load_resident_pin_invalid(self): """Test getting resident keys while providing invalid PIN""" with sk_error('badpin'): with self.assertRaises(ValueError): asyncssh.load_resident_keys(b'123456') @asynctest async def test_pin_not_set(self): """Test getting resident keys from a key with no configured PIN""" with sk_error('nopin'): with self.assertRaises(ValueError): asyncssh.load_resident_keys(b'123456') @unittest.skipUnless(sk_available, 'security key support not available') class _TestSKAuthTouchNotRequired(_CheckSKAuth): """Unit tests for security key authentication without touch""" _sk_touch_required = False _sk_auth_touch_required = False @asynctest async def test_auth_without_touch(self): """Test authenticating with a security key without touch""" async with self.connect(username='ckey', client_keys=[self._privkey]): pass @unittest.skipUnless(sk_available, 'security key support not available') class _TestSKAuthTouchRequiredECDSA(_CheckSKAuth): """Unit tests for security key authentication failing without touch""" _sk_alg = 'sk-ecdsa-sha2-nistp256@openssh.com' _sk_touch_required = False _sk_auth_touch_required = True @asynctest async def test_auth_touch_required(self): """Test auth failing with a security key not providing touch""" with self.assertRaises(asyncssh.PermissionDenied): await self.connect(username='ckey', client_keys=[self._privkey]) @unittest.skipUnless(sk_available, 'security key support not available') class _TestSKCertAuthTouchNotRequired(_CheckSKAuth): """Unit tests for security key cert authentication without touch""" _sk_touch_required = False _sk_auth_touch_required = False _sk_cert = True @asynctest async def test_cert_auth_cert_touch_not_required(self): """Test authenticating with a security key cert not requiring touch""" cert = self._privkey.generate_user_certificate(self._privkey, 'name', touch_required=False) async with self.connect(username='ckey', client_keys=[(self._privkey, cert)]): pass @asynctest async def test_cert_auth_cert_touch_required(self): """Test cert auth failing with a security key cert requiring touch""" cert = self._privkey.generate_user_certificate(self._privkey, 'name', touch_required=True) with self.assertRaises(asyncssh.PermissionDenied): await self.connect(username='ckey', client_keys=[(self._privkey, cert)]) @unittest.skipUnless(sk_available, 'security key support not available') class _TestSKCertAuthTouchRequired(_CheckSKAuth): """Unit tests for security key cert authentication failing without touch""" _sk_touch_required = False _sk_auth_touch_required = True _sk_cert = True @asynctest async def test_cert_auth_touch_required(self): """Test cert auth failing with a security key requiring touch""" cert = self._privkey.generate_user_certificate(self._privkey, 'name', touch_required=False) with self.assertRaises(asyncssh.PermissionDenied): await self.connect(username='ckey', client_keys=[(self._privkey, cert)]) @asynctest async def test_cert_auth_cert_touch_required(self): """Test cert auth failing with a security key cert requiring touch""" cert = self._privkey.generate_user_certificate(self._privkey, 'name', touch_required=True) with self.assertRaises(asyncssh.PermissionDenied): await 
self.connect(username='ckey', client_keys=[(self._privkey, cert)]) @unittest.skipUnless(sk_available, 'security key support not available') class _TestSKHostAuth(_CheckSKAuth): """Unit tests for security key host authentication""" _sk_host = True @asynctest async def test_sk_host_auth(self): """Test a server using a security key as a host key""" pubkey = self._privkey.convert_to_public() async with self.connect(known_hosts=([pubkey], [], [])): pass @unittest.skipUnless(sk_available, 'security key support not available') class _TestSKHostCertAuth(_CheckSKAuth): """Unit tests for security key host cert authentication""" _sk_cert = True _sk_host = True @asynctest async def test_sk_host_auth(self): """Test a server host using a security key host certificate""" pubkey = self._privkey.convert_to_public() async with self.connect(known_hosts=([pubkey], [pubkey], [])): pass asyncssh-2.20.0/tests/test_stream.py000066400000000000000000000326501475467777400175360ustar00rootroot00000000000000# Copyright (c) 2016-2020 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation """Unit tests for AsyncSSH stream API""" import asyncio import re import asyncssh from .server import Server, ServerTestCase from .util import asynctest, echo class _StreamServer(Server): """Server for testing the AsyncSSH stream API""" async def _begin_session(self, stdin, stdout, stderr): """Begin processing a new session""" # pylint: disable=no-self-use action = stdin.channel.get_command() if not action: action = 'echo' if action == 'echo': await echo(stdin, stdout) elif action == 'echo_stderr': await echo(stdin, stdout, stderr) elif action == 'close': await stdin.read(1) stdout.write('\n') elif action == 'disconnect': stdout.write(await stdin.read(1)) raise asyncssh.ConnectionLost('Connection lost') elif action == 'custom_disconnect': await stdin.read(1) raise asyncssh.DisconnectError(99, 'Disconnect') elif action == 'partial': try: await stdin.readexactly(10) except asyncio.IncompleteReadError as exc: stdout.write(exc.partial) try: await stdin.read() except asyncssh.TerminalSizeChanged: pass stdout.write(await stdin.readexactly(5)) else: stdin.channel.exit(255) stdin.channel.close() await stdin.channel.wait_closed() def _begin_session_non_async(self, stdin, stdout, stderr): """Non-async version of session handler""" self._conn.create_task(self._begin_session(stdin, stdout, stderr)) def begin_auth(self, username): """Handle client authentication request""" return False def session_requested(self): """Handle a request to create a new session""" username = self._conn.get_extra_info('username') if username == 'non_async': return self._begin_session_non_async elif username != 'no_channels': return self._begin_session else: return False class _TestStream(ServerTestCase): """Unit tests for AsyncSSH stream API""" @classmethod async def start_server(cls): """Start an SSH server for the tests to use""" return await 
cls.create_server(_StreamServer) async def _check_session(self, conn, large_block=False): """Open a session and test if an input line is echoed back""" stdin, stdout, stderr = await conn.open_session('echo_stderr') if large_block: data = 4 * [1025*1024*'\0'] else: data = [str(id(self))] stdin.writelines(data) await stdin.drain() self.assertTrue(stdin.can_write_eof()) self.assertFalse(stdin.is_closing()) stdin.write_eof() self.assertTrue(stdin.is_closing()) stdout_data, stderr_data = await asyncio.gather(stdout.read(), stderr.read()) data = ''.join(data) self.assertEqual(data, stdout_data) self.assertEqual(data, stderr_data) await stdin.channel.wait_closed() await stdin.drain() stdin.close() @asynctest async def test_shell(self): """Test starting a shell""" async with self.connect() as conn: await self._check_session(conn) @asynctest async def test_shell_failure(self): """Test failure to start a shell""" async with self.connect(username='no_channels') as conn: with self.assertRaises(asyncssh.ChannelOpenError): await conn.open_session() @asynctest async def test_shell_non_async(self): """Test starting a shell using non-async handler""" async with self.connect(username='non_async') as conn: await self._check_session(conn) @asynctest async def test_large_block(self): """Test sending and receiving a large block of data""" async with self.connect() as conn: await self._check_session(conn, large_block=True) @asynctest async def test_feed(self): """Test feeding data into an SSHReader""" async with self.connect() as conn: _, stdout, stderr = await conn.open_session() stdout.feed_data('stdout') stderr.feed_data('stderr') stdout.feed_eof() self.assertEqual(await stdout.read(), 'stdout') self.assertEqual(await stderr.read(), 'stderr') @asynctest async def test_async_iterator(self): """Test reading lines by using SSHReader as an async iterator""" async with self.connect() as conn: stdin, stdout, _ = await conn.open_session() data = ['Line 1\n', 'Line 2\n'] stdin.writelines(data) stdin.write_eof() async for line in stdout: self.assertEqual(line, data.pop(0)) self.assertEqual(data, []) @asynctest async def test_write_broken_pipe(self): """Test close while we're writing""" async with self.connect() as conn: stdin, _, _ = await conn.open_session('close') stdin.write(4*1024*1024*'\0') with self.assertRaises((ConnectionError, asyncssh.ConnectionLost)): await stdin.drain() @asynctest async def test_write_disconnect(self): """Test disconnect while we're writing""" async with self.connect() as conn: stdin, _, _ = await conn.open_session('disconnect') stdin.write(6*1024*1024*'\0') with self.assertRaises((ConnectionError, asyncssh.ConnectionLost)): await stdin.drain() @asynctest async def test_read_exception(self): """Test read returning an exception""" async with self.connect() as conn: stdin, stdout, _ = await conn.open_session('disconnect') stdin.write('\0') self.assertEqual((await stdout.read()), '\0') with self.assertRaises(asyncssh.ConnectionLost): await stdout.read(1) stdin.close() @asynctest async def test_readline_exception(self): """Test readline returning an exception""" async with self.connect() as conn: stdin, stdout, _ = await conn.open_session('disconnect') stdin.write('\0') self.assertEqual((await stdout.readline()), '\0') with self.assertRaises(asyncssh.ConnectionLost): await stdout.readline() @asynctest async def test_readexactly_partial_exception(self): """Test readexactly returning partial data before an exception""" async with self.connect() as conn: stdin, stdout, _ = await 
conn.open_session('partial') stdin.write('abcde') stdout.channel.change_terminal_size(80, 24) stdin.write('fghij') self.assertEqual((await stdout.read()), 'abcdefghij') @asynctest async def test_custom_disconnect(self): """Test receiving a custom disconnect message""" async with self.connect() as conn: stdin, stdout, _ = await conn.open_session('custom_disconnect') stdin.write('\0') with self.assertRaises(asyncssh.DisconnectError) as exc: await stdout.read() self.assertEqual(exc.exception.code, 99) self.assertEqual(exc.exception.reason, 'Disconnect (error 99)') @asynctest async def test_readuntil_bigger_than_window(self): """Test readuntil getting data bigger than the receive window""" async with self.connect() as conn: stdin, stdout, _ = await conn.open_session() stdin.write(4*1024*1024*'\0') with self.assertRaises(asyncio.IncompleteReadError) as exc: await stdout.readuntil('\n') self.assertEqual(exc.exception.partial, stdin.channel.get_recv_window()*'\0') stdin.close() await conn.wait_closed() @asynctest async def test_readline_timeout(self): """Test receiving a timeout while calling readline""" async with self.connect() as conn: stdin, stdout, _ = await conn.open_session() stdin.write('ab') try: await asyncio.wait_for(stdout.readline(), timeout=0.1) except asyncio.TimeoutError: pass stdin.write('c\n') self.assertEqual((await stdout.readline()), 'abc\n') stdin.close() @asynctest async def test_pause_read(self): """Test pause reading""" async with self.connect() as conn: stdin, stdout, _ = await conn.open_session() stdin.write(6*1024*1024*'\0') await asyncio.sleep(0.01) await stdout.read(1) await asyncio.sleep(0.01) await stdout.read(1) @asynctest async def test_readuntil(self): """Test readuntil with multi-character separator""" async with self.connect() as conn: stdin, stdout, _ = await conn.open_session() stdin.write('abc\r') await asyncio.sleep(0.01) stdin.write('\ndef') await asyncio.sleep(0.01) stdin.write('\r\n') await asyncio.sleep(0.01) stdin.write('ghi') stdin.write_eof() self.assertEqual((await stdout.readuntil('\r\n')), 'abc\r\n') self.assertEqual((await stdout.readuntil('\r\n')), 'def\r\n') with self.assertRaises(asyncio.IncompleteReadError) as exc: await stdout.readuntil('\r\n') self.assertEqual(exc.exception.partial, 'ghi') stdin.close() @asynctest async def test_readuntil_separator_list(self): """Test readuntil with a list of separators""" seps = ('+', '-', '\r\n') async with self.connect() as conn: stdin, stdout, _ = await conn.open_session() stdin.write('ab') await asyncio.sleep(0.01) stdin.write('c+d') await asyncio.sleep(0.01) stdin.write('ef-gh') await asyncio.sleep(0.01) stdin.write('i\r') await asyncio.sleep(0.01) stdin.write('\n') stdin.write_eof() self.assertEqual((await stdout.readuntil(seps)), 'abc+') self.assertEqual((await stdout.readuntil(seps)), 'def-') self.assertEqual((await stdout.readuntil(seps)), 'ghi\r\n') stdin.close() @asynctest async def test_readuntil_empty_separator(self): """Test readuntil with empty separator""" async with self.connect() as conn: stdin, stdout, _ = await conn.open_session() with self.assertRaises(ValueError): await stdout.readuntil('') stdin.close() @asynctest async def test_readuntil_regex(self): """Test readuntil with a regex pattern""" async with self.connect() as conn: stdin, stdout, _ = await conn.open_session() stdin.write("hello world\nhello world") output = await stdout.readuntil( re.compile('hello world'), len('hello world') ) self.assertEqual(output, "hello world") output = await stdout.readuntil( re.compile('hello 
world'), len('hello world') ) self.assertEqual(output, "\nhello world") stdin.close() await conn.wait_closed() @asynctest async def test_abort(self): """Test abort on a channel""" async with self.connect() as conn: stdin, _, _ = await conn.open_session() stdin.channel.abort() @asynctest async def test_abort_closed(self): """Test abort on an already-closed channel""" async with self.connect() as conn: stdin, stdout, _ = await conn.open_session('close') stdin.write('\n') await stdout.read() stdin.channel.abort() @asynctest async def test_get_extra_info(self): """Test get_extra_info on streams""" async with self.connect() as conn: stdin, stdout, _ = await conn.open_session() self.assertEqual(stdin.get_extra_info('connection'), stdout.get_extra_info('connection')) stdin.close() @asynctest async def test_unknown_action(self): """Test unknown action""" async with self.connect() as conn: stdin, _, _ = await conn.open_session('unknown') await stdin.channel.wait_closed() self.assertEqual(stdin.channel.get_exit_status(), 255) asyncssh-2.20.0/tests/test_subprocess.py000066400000000000000000000205631475467777400204330ustar00rootroot00000000000000# Copyright (c) 2019 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation """Unit tests for AsyncSSH subprocess API""" import asyncio from signal import SIGINT import asyncssh from .server import Server, ServerTestCase from .util import asynctest, echo class _SubprocessProtocol(asyncssh.SSHSubprocessProtocol): """Unit test SSH subprocess protocol""" def __init__(self): self._chan = None self.recv_buf = {1: [], 2: []} self.exc = {1: None, 2: None} def pipe_connection_lost(self, fd, exc): """Handle remote process close""" self.exc[fd] = exc def pipe_data_received(self, fd, data): """Handle data from the remote process""" self.recv_buf[fd].append(data) async def _create_subprocess(conn, command=None, **kwargs): """Create a client subprocess""" return await conn.create_subprocess(_SubprocessProtocol, command, **kwargs) class _SubprocessServer(Server): """Server for testing the AsyncSSH subprocess API""" def begin_auth(self, username): """Handle client authentication request""" return False def session_requested(self): """Handle a request to create a new session""" return self._begin_session async def _begin_session(self, stdin, stdout, stderr): """Begin processing a new session""" # pylint: disable=no-self-use action = stdin.channel.get_command() if not action: action = 'echo' if action == 'exit_status': stdout.channel.exit(1) elif action == 'signal': try: await stdin.readline() except asyncssh.SignalReceived as exc: stdout.channel.exit_with_signal(exc.signal) else: await echo(stdin, stdout, stderr) class _TestSubprocess(ServerTestCase): """Unit tests for AsyncSSH subprocess API""" @classmethod async def start_server(cls): """Start an SSH server for the tests to use""" return (await cls.create_server( _SubprocessServer, authorized_client_keys='authorized_keys')) 
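    # Illustrative sketch (added for clarity, not part of the original test
    # suite): the helpers below exercise AsyncSSH's asyncio-style subprocess
    # API.  Assuming an already-connected SSHClientConnection `conn` and the
    # _SubprocessProtocol class defined above, typical client code looks
    # roughly like:
    #
    #     transport, protocol = await conn.create_subprocess(
    #         _SubprocessProtocol, 'echo')
    #     stdin = transport.get_pipe_transport(0)
    #     stdin.write(b'some data')
    #     stdin.write_eof()
    #     await transport.wait_closed()
    #
    # after which any received output is available in protocol.recv_buf.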
    async def _check_subprocess(self, conn, command=None, *,
                                encoding=None, **kwargs):
        """Start a subprocess and test if an input line is echoed back"""

        transport, protocol = await _create_subprocess(conn, command,
                                                       encoding=encoding,
                                                       **kwargs)

        data = str(id(self))

        if encoding is None:
            data = data.encode('ascii')

        stdin = transport.get_pipe_transport(0)

        self.assertTrue(stdin.can_write_eof())

        stdin.writelines([data])

        self.assertFalse(transport.is_closing())
        stdin.write_eof()
        self.assertTrue(transport.is_closing())

        await transport.wait_closed()

        sep = '' if encoding else b''

        for buf in protocol.recv_buf.values():
            self.assertEqual(sep.join([data]), sep.join(buf))

        transport.close()

    @asynctest
    async def test_shell(self):
        """Test starting a shell"""

        async with self.connect() as conn:
            await self._check_subprocess(conn)

    @asynctest
    async def test_exec(self):
        """Test execution of a remote command"""

        async with self.connect() as conn:
            await self._check_subprocess(conn, 'echo')

    @asynctest
    async def test_encoding(self):
        """Test setting encoding"""

        async with self.connect() as conn:
            await self._check_subprocess(conn, 'echo', encoding='ascii')

    @asynctest
    async def test_input(self):
        """Test providing input when creating a subprocess"""

        data = str(id(self)).encode('ascii')

        async with self.connect() as conn:
            transport, protocol = await _create_subprocess(conn, input=data)

            await transport.wait_closed()

            for buf in protocol.recv_buf.values():
                self.assertEqual(b''.join(buf), data)

    @asynctest
    async def test_redirect_stderr(self):
        """Test redirecting stderr to file"""

        data = str(id(self)).encode('ascii')

        async with self.connect() as conn:
            transport, protocol = await _create_subprocess(conn,
                                                           stderr='stderr')

            stdin = transport.get_pipe_transport(0)
            stdin.write(data)
            stdin.write_eof()

            await transport.wait_closed()

            with open('stderr', 'rb') as f:
                stderr_data = f.read()

            self.assertEqual(b''.join(protocol.recv_buf[1]), data)
            self.assertEqual(b''.join(protocol.recv_buf[2]), b'')
            self.assertEqual(stderr_data, data)

    @asynctest
    async def test_close(self):
        """Test closing transport"""

        async with self.connect() as conn:
            transport, protocol = await _create_subprocess(conn)

            transport.close()

            for buf in protocol.recv_buf.values():
                self.assertEqual(b''.join(buf), b'')

    @asynctest
    async def test_exit_status(self):
        """Test reading exit status"""

        async with self.connect() as conn:
            transport, protocol = await _create_subprocess(conn,
                                                           'exit_status')

            await transport.wait_closed()

            for buf in protocol.recv_buf.values():
                self.assertEqual(b''.join(buf), b'')

            self.assertEqual(transport.get_returncode(), 1)

    @asynctest
    async def test_stdin_abort(self):
        """Test abort on stdin"""

        async with self.connect() as conn:
            transport, protocol = await _create_subprocess(conn)

            stdin = transport.get_pipe_transport(0)
            stdin.abort()

            for buf in protocol.recv_buf.values():
                self.assertEqual(b''.join(buf), b'')

    @asynctest
    async def test_stdin_close(self):
        """Test closing stdin"""

        async with self.connect() as conn:
            transport, protocol = await _create_subprocess(conn)

            stdin = transport.get_pipe_transport(0)
            stdin.close()

            for buf in protocol.recv_buf.values():
                self.assertEqual(b''.join(buf), b'')

    @asynctest
    async def test_read_pause(self):
        """Test read pause"""

        async with self.connect() as conn:
            transport, protocol = await _create_subprocess(conn)

            stdin = transport.get_pipe_transport(0)
            stdout = transport.get_pipe_transport(1)

            stdout.pause_reading()
            stdin.write(b'\n')
            await asyncio.sleep(0.1)

            for buf in protocol.recv_buf.values():
                self.assertEqual(b''.join(buf), b'')
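            # While reading is paused on the stdout pipe, nothing should have
            # been delivered to the protocol; resuming reading (next statement)
            # should deliver the buffered b'\n' echoed back by the server.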
stdout.resume_reading() for buf in protocol.recv_buf.values(): self.assertEqual(b''.join(buf), b'\n') stdin.close() @asynctest async def test_signal(self): """Test sending a signal""" async with self.connect() as conn: transport, _ = await _create_subprocess(conn, 'signal') transport.send_signal(SIGINT) await transport.wait_closed() self.assertEqual(transport.get_returncode(), -SIGINT) @asynctest async def test_misc(self): """Test other transport and pipe methods""" async with self.connect() as conn: transport, _ = await _create_subprocess(conn) self.assertEqual(transport.get_pid(), None) stdin = transport.get_pipe_transport(0) self.assertEqual(transport.get_extra_info('socket'), stdin.get_extra_info('socket')) stdin.set_write_buffer_limits() self.assertEqual(stdin.get_write_buffer_size(), 0) stdin.close() asyncssh-2.20.0/tests/test_tuntap.py000066400000000000000000000476121475467777400175620ustar00rootroot00000000000000# Copyright (c) 2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation """Unit tests for AsyncSSH TUN/TAP support""" import asyncio import builtins import errno import socket import struct import sys from unittest import skipIf, skipUnless from unittest.mock import patch import asyncssh from asyncssh.tuntap import IFF_FMT, LINUX_IFF_TUN from .server import Server, ServerTestCase from .util import asynctest if sys.platform != 'win32': # pragma: no branch import fcntl _orig_funcs = {} class _TunTapSocketMock: """TunTap socket mock""" def ioctl(self, request, arg): """Ignore ioctl requests to bring interface up""" # pylint: disable=no-self-use,unused-argument return arg def close(self): """Close this mock""" # pylint: disable=no-self-use class _TunTapMock: """Common TUN/TAP mock""" _from_intf = {} def __init__(self, interface=None): if interface in self._from_intf: raise OSError(errno.EBUSY, 'Device busy') self._loop = asyncio.get_event_loop() self._sock1, self._sock2 = socket.socketpair(type=socket.SOCK_DGRAM) self._sock2.setblocking(False) self._interface = interface if interface: self._from_intf[interface] = self @classmethod def lookup_intf(cls, interface): """Look up mock by interface""" return cls._from_intf[interface] def fileno(self): """Return the fileno of sock1""" return self._sock1.fileno() def setblocking(self, blocking): """Set blocking mode on the socket""" self._sock1.setblocking(blocking) async def get_packets(self, count): """Get packets written to the TUN/TAP""" return [await self._loop.sock_recv(self._sock2, 65536) for _ in range(count)] def put_packets(self, packets): """Put packets for the TUN/TAP to read""" for packet in packets: self._sock2.send(packet) def read(self, size=-1): """Read a packet""" return self._sock1.recv(size) def write(self, packet): """Write a packet""" return self._sock1.send(packet) def close(self): """Close this mock""" self._from_intf.pop(self._interface, None) self._sock2.send(b'') self._sock1.close() self._sock2.close() class 
_TunTapOSXMock(_TunTapMock): """TunTapOSX mock""" disable = False def __init__(self, name): if self.disable: raise OSError(errno.ENOENT, 'No such device') interface = name[5:] if int(interface[3:]) >= 16: raise OSError(errno.ENOENT, 'No such device') super().__init__(interface) class _DarwinUTunMock(_TunTapMock): """Darwin UTun mock""" _AF_INET_PREFIX = socket.AF_INET.to_bytes(4, 'big') def __init__(self): super().__init__() self._unit = None def ioctl(self, request, arg): """Respond to DARWIN_CTLIOCGINFO request""" # pylint: disable=no-self-use,unused-argument return arg def connect(self, addr): """Connect to requested unit""" _, unit = addr if unit == 0: for unit in range(16): interface = f'utun{unit}' if interface not in self._from_intf: break else: raise OSError(errno.EBUSY, 'No utun devices available') elif unit <= 16: unit -= 1 interface = f'utun{unit}' if interface in self._from_intf: raise OSError(errno.EBUSY, 'Device busy') else: raise OSError(errno.ENOENT, 'No such device') self._unit = unit self._interface = interface self._from_intf[interface] = self return 0 def getpeername(self): """Return utun unit""" return (0, self._unit + 1) def send(self, packet): """Send a packet""" return super().write(packet[4:]) def recv(self, size): """Receive a packet""" return self._AF_INET_PREFIX + self.read(size) class _LinuxMock(_TunTapMock): """Linux TUN/TAP mock""" def __init__(self): super().__init__() self._sock1.setblocking(False) def ioctl(self, request, arg): """Respond to LINUX_TUNSETIFF request""" # pylint: disable=unused-argument name, flags = struct.unpack(IFF_FMT, arg) if name[0] == 0: prefix = 'tun' if flags & LINUX_IFF_TUN else 'tap' for unit in range(16): interface = f'{prefix}{unit}' if interface not in self._from_intf: break else: self.close() raise OSError(errno.EBUSY, 'No tun devices available') arg = struct.pack(IFF_FMT, interface.encode(), flags) else: interface = name.strip(b'\0').decode() unit = int(interface[3:]) if unit >= 16: raise OSError(errno.ENOENT, 'No such device') self._interface = interface self._from_intf[interface] = self return arg def read(self, size=-1): """Read a packet""" try: return super().read(size) except BlockingIOError: return None def _open(name, mode, *args, **kwargs): """Mock file open""" name = str(name) if name.startswith('/dev/tun') or name.startswith('/dev/tap'): return _TunTapOSXMock(name) elif name == '/dev/net/tun': return _LinuxMock() else: return _orig_funcs['open'](name, mode, *args, **kwargs) # pylint: disable=redefined-builtin def _socket(family=socket.AF_INET, type=socket.SOCK_STREAM, proto=0, fileno=None): """Mock socket creation""" if hasattr(socket, 'PF_SYSTEM') and family == socket.PF_SYSTEM and \ type == socket.SOCK_DGRAM and proto == socket.SYSPROTO_CONTROL: return _DarwinUTunMock() elif family == socket.AF_INET and type == socket.SOCK_DGRAM: return _TunTapSocketMock() else: return _orig_funcs['socket'](family, type, proto, fileno) def _ioctl(file, request, arg): """Mock ioctl""" if isinstance(file, (_DarwinUTunMock, _LinuxMock, _TunTapSocketMock)): return file.ioctl(request, arg) else: # pragma: no cover return _orig_funcs['ioctl'](file, request, arg) async def get_packets(interface, count): """Return TUN/TAP packets written""" return await _TunTapMock.lookup_intf(interface).get_packets(count) def put_packets(interface, packets): """Feed packets to a TUN/TAP mock""" _TunTapMock.lookup_intf(interface).put_packets(packets) def patch_tuntap(cls): """Decorator to stub out TUN/TAP functions""" _orig_funcs['open'] = builtins.open 
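    # Keep references to the original builtins before patching, so that the
    # _open(), _socket() and _ioctl() wrappers above can fall back to the
    # real calls for anything that isn't a TUN/TAP device.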
_orig_funcs['socket'] = socket.socket cls = patch('builtins.open', _open)(cls) cls = patch('socket.socket', _socket)(cls) if sys.platform != 'win32': # pragma: no branch _orig_funcs['ioctl'] = fcntl.ioctl cls = patch('fcntl.ioctl', _ioctl)(cls) return cls class _EchoSession(asyncssh.SSHTunTapSession): """Echo packets on a TUN session""" def __init__(self): self._chan = None def connection_made(self, chan): """Handle session open""" self._chan = chan def data_received(self, data, datatype): """Handle data from the channel""" self._chan.write(data) def eof_received(self): """Handle EOF from the channel""" self._chan.write_eof() class _TunTapServer(Server): """Server for testing TUN/TAP functions""" async def _echo_handler(self, reader, writer): """Echo packets on a TUN session""" try: async for packet in reader: writer.write(packet) finally: writer.close() def tun_requested(self, unit): """Handle TUN requests""" if unit is None or unit <= 32: return True elif unit == 33: return _EchoSession() elif unit == 34: return (self._conn.create_tuntap_channel(), _EchoSession()) elif unit == 35: return self._echo_handler else: return False def tap_requested(self, unit): """Handle TAP requests""" return True @skipIf(sys.platform == 'win32', 'skip TUN/TAP tests on Windows') @patch_tuntap class _TestTunTap(ServerTestCase): """Unit tests for TUN/TAP functions""" @classmethod async def start_server(cls): """Start an SSH server to connect to""" return await cls.create_server( _TunTapServer, authorized_client_keys='authorized_keys') async def _check_tuntap(self, coro, interface): """Check sending data on a TUN or TAP channel""" reader, writer = await coro try: packets = [b'123', b'456', b'789'] count = len(packets) for packet in packets: writer.write(packet) self.assertEqual((await get_packets(interface, count)), packets) put_packets(interface, packets) for packet in packets: self.assertEqual((await reader.read()), packet) finally: writer.close() async def _check_tuntap_forward(self, coro, remote_interface): """Check sending data on a TUN or TAP channel""" async with coro as forw: local_interface = forw.get_extra_info('interface') packets = [b'123', b'456', b'789'] count = len(packets) put_packets(local_interface, packets) self.assertEqual((await get_packets(remote_interface, count)), packets) put_packets(remote_interface, packets) self.assertEqual((await get_packets(local_interface, count)), packets) async def _check_tuntap_echo(self, coro): """Check echoing of packets on a TUN channel""" reader, writer = await coro try: writer.write(b'123') self.assertEqual((await reader.read()), b'123') writer.write_eof() self.assertEqual((await reader.read()), b'') finally: writer.close() await writer.wait_closed() @skipUnless(sys.platform == 'darwin', 'only run TapTunOSX tests on macOS') @asynctest async def test_darwin_open_tun(self): """Test sending packets on a layer 3 tunnel on macOS""" async with self.connect() as conn: await self._check_tuntap(conn.open_tun(), 'tun0') @skipUnless(sys.platform == 'darwin', 'only run TapTunOSX tests on macOS') @asynctest async def test_darwin_open_tun_specific_unit(self): """Test sending on a layer 3 tunnel with specific unit on macOS""" async with self.connect() as conn: await self._check_tuntap(conn.open_tun(0), 'tun0') @skipUnless(sys.platform == 'darwin', 'only run TapTunOSX tests on macOS') @asynctest async def test_darwin_open_tun_error(self): """Test returning an open error on a layer 3 tunnel on macOS""" with self.assertRaises(asyncssh.ChannelOpenError): async with 
self.connect() as conn: await conn.open_tun(32) @skipUnless(sys.platform == 'darwin', 'only run utun tests on macOS') @asynctest async def test_darwin_open_utun(self): """Test sending packets on a layer 3 tunnel using UTun on macOS""" async with self.connect() as conn: await self._check_tuntap(conn.open_tun(16), 'utun0') @skipUnless(sys.platform == 'darwin', 'only run utun tests on macOS') @asynctest async def test_darwin_failover_to_utun(self): """Test failing over from TunTapOSX to UTun on macOS""" try: _TunTapOSXMock.disable = True async with self.connect() as conn: await self._check_tuntap(conn.open_tun(), 'utun0') finally: _TunTapOSXMock.disable = False @skipUnless(sys.platform == 'darwin', 'only run utun tests on macOS') @asynctest async def test_darwin_utun_in_use(self): """Test UTun device already in use on macOS""" async with self.connect() as conn: _, writer = await conn.open_tun(16) try: with self.assertRaises(asyncssh.ChannelOpenError): await conn.open_tun(16) finally: writer.close() await writer.wait_closed() @skipUnless(sys.platform == 'darwin', 'only run utun tests on macOS') @asynctest async def test_darwin_utun_all_in_use(self): """Test all UTun devices already in use on macOS""" async with self.connect() as conn: writers = [] try: for unit in range(32): _, writer = await conn.open_tun(unit) writers.append(writer) with self.assertRaises(asyncssh.ChannelOpenError): await conn.open_tun() finally: for writer in writers: writer.close() await writer.wait_closed() @skipUnless(sys.platform == 'darwin', 'only run TapTunOSX tests on macOS') @asynctest async def test_darwin_open_tap(self): """Test sending packets on a layer 2 tunnel on macOS""" async with self.connect() as conn: await self._check_tuntap(conn.open_tap(), 'tap0') @skipUnless(sys.platform == 'darwin', 'only run TapTunOSX tests on macOS') @asynctest async def test_darwin_open_tap_unavailable(self): """Test TunTapOSX not being available on macOS""" try: _TunTapOSXMock.disable = True with self.assertRaises(asyncssh.ChannelOpenError): async with self.connect() as conn: await conn.open_tap() finally: _TunTapOSXMock.disable = False @skipUnless(sys.platform == 'darwin', 'only run TapTunOSX tests on macOS') @asynctest async def test_darwin_open_tap_error(self): """Test sending packets on a layer 2 tunnel on macOS""" with self.assertRaises(asyncssh.ChannelOpenError): async with self.connect() as conn: await conn.open_tap(16) @skipUnless(sys.platform == 'darwin', 'only run TapTunOSX tests on macOS') @asynctest async def test_darwin_forward_tun(self): """Test forwarding packets on a layer 3 tunnel on macOS""" async with self.connect() as conn: await self._check_tuntap_forward(conn.forward_tun(), 'tun0') @skipUnless(sys.platform == 'darwin', 'only run utun tests on macOS') @asynctest async def test_darwin_forward_utun(self): """Test forwarding packets on a layer 3 tunnel on macOS""" async with self.connect() as conn: await self._check_tuntap_forward(conn.forward_tun(16, 17), 'utun1') @skipUnless(sys.platform == 'darwin', 'only run TapTunOSX tests on macOS') @asynctest async def test_darwin_forward_tap(self): """Test forwarding packets on a layer 2 tunnel on macOS""" async with self.connect() as conn: await self._check_tuntap_forward(conn.forward_tap(), 'tap0') @patch('sys.platform', 'linux') @asynctest async def test_linux_open_tun(self): """Test sending packets on a layer 3 tunnel on Linux""" async with self.connect() as conn: await self._check_tuntap(conn.open_tun(), 'tun0') @patch('sys.platform', 'linux') @asynctest async def 
test_linux_open_tun_specific_unit(self): """Test sending on a layer 3 tunnel with specific unit on Linux""" async with self.connect() as conn: await self._check_tuntap(conn.open_tun(), 'tun0') @patch('sys.platform', 'linux') @asynctest async def test_linux_open_tun_error(self): """Test returning an open error on a layer 3 tunnel on Linux""" with self.assertRaises(asyncssh.ChannelOpenError): async with self.connect() as conn: await conn.open_tun(32) @patch('sys.platform', 'linux') @asynctest async def test_linux_open_tap(self): """Test sending packets on a layer 2 tunnel on Linux""" async with self.connect() as conn: await self._check_tuntap(conn.open_tap(), 'tap0') @patch('sys.platform', 'linux') @asynctest async def test_linux_forward_tun(self): """Test forwarding packets on a layer 3 tunnel on Linux""" async with self.connect() as conn: await self._check_tuntap_forward(conn.forward_tun(), 'tun0') @patch('sys.platform', 'linux') @asynctest async def test_linux_forward_tap(self): """Test forwarding packets on a layer 2 tunnel on Linux""" async with self.connect() as conn: await self._check_tuntap_forward(conn.forward_tap(), 'tap0') @patch('sys.platform', 'linux') @asynctest async def test_linux_all_in_use(self): """Test all TUN devices already in use on Linux""" async with self.connect() as conn: writers = [] try: for unit in range(16): _, writer = await conn.open_tun(unit) writers.append(writer) with self.assertRaises(asyncssh.ChannelOpenError): await conn.open_tun() finally: for writer in writers: writer.close() await writer.wait_closed() @patch('sys.platform', 'xxx') @asynctest async def test_unknown_platform(self): """Test unknown platform""" async with self.connect() as conn: with self.assertRaises(asyncssh.ChannelOpenError): await conn.open_tun() @asynctest async def test_open_tun_echo_session(self): """Test an echo session on a layer 3 tunnel""" async with self.connect() as conn: await self._check_tuntap_echo(conn.open_tun(33)) @asynctest async def test_open_tun_echo_session_channel(self): """Test an echo session & channel on a layer 3 tunnel""" async with self.connect() as conn: await self._check_tuntap_echo(conn.open_tun(34)) @asynctest async def test_open_tun_echo_handler(self): """Test an echo stream handler on a layer 3 tunnel""" async with self.connect() as conn: await self._check_tuntap_echo(conn.open_tun(35)) @asynctest async def test_open_tun_denied(self): """Test returning an open error on a layer 3 tunnel""" with self.assertRaises(asyncssh.ChannelOpenError): async with self.connect() as conn: await conn.open_tun(36) @asynctest async def test_tun_forward_error(self): """Test returning a forward error on a layer 3 tunnel""" with self.assertRaises(asyncssh.ChannelOpenError): async with self.connect() as conn: await conn.forward_tun(36) @asynctest async def test_invalid_tun_mode(self): """Test sending an invalid mode in a TUN/TAP request""" async with self.connect() as conn: chan = conn.create_tuntap_channel() with self.assertRaises(asyncssh.ChannelOpenError): await chan.open(asyncssh.SSHTunTapSession, 32, 0) asyncssh-2.20.0/tests/test_x11.py000066400000000000000000000516151475467777400166560ustar00rootroot00000000000000# Copyright (c) 2016-2022 by Ron Frederick and others. 
# # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation """Unit tests for AsyncSSH X11 forwarding""" import asyncio import os import socket from unittest.mock import patch import asyncssh from asyncssh.misc import maybe_wait_closed from asyncssh.packet import Boolean, String, UInt32 from asyncssh.x11 import XAUTH_FAMILY_IPV4, XAUTH_FAMILY_DECNET from asyncssh.x11 import XAUTH_FAMILY_IPV6, XAUTH_FAMILY_HOSTNAME from asyncssh.x11 import XAUTH_FAMILY_WILD, XAUTH_PROTO_COOKIE from asyncssh.x11 import XAUTH_COOKIE_LEN, X11_BASE_PORT, X11_LISTEN_HOST from asyncssh.x11 import SSHXAuthorityEntry, SSHX11ClientListener from asyncssh.x11 import walk_xauth, lookup_xauth, update_xauth from .server import Server, ServerTestCase from .util import asynctest def _failing_bind(self, address): """Raise OSError to simulate a socket bind failure""" # pylint: disable=unused-argument raise OSError async def _create_x11_process(conn, command=None, x11_forwarding=True, x11_display='test:0', **kwargs): """Create a client process with X11 forwarding enabled""" return await conn.create_process(command, x11_forwarding=x11_forwarding, x11_display=x11_display, **kwargs) class _X11Peer: """Peer representing X server to forward connections to""" expected_auth = b'' @classmethod async def connect(cls, session_factory, host, port): """Simulate connecting to an X server""" # pylint: disable=unused-argument if port == X11_BASE_PORT: return None, cls() else: raise OSError('Connection refused') def __init__(self): self._peer = None self._check_auth = True def set_peer(self, peer): """Set the peer forwarder to exchange data with""" self._peer = peer def write(self, data): """Consume data from the peer""" if self._check_auth: match = data[32:48] == self.expected_auth self._peer.write(b'\x01' if match else b'\xff') self._check_auth = False else: self._peer.write(data) def write_eof(self): """Consume EOF from the peer""" def was_eof_received(self): """Report that an incoming EOF has not been reeceived""" # pylint: disable=no-self-use return False # pragma: no cover def pause_reading(self): """Ignore flow control requests""" def resume_reading(self): """Ignore flow control requests""" def close(self): """Consume close request""" class _X11ClientListener(SSHX11ClientListener): """Unit test X server to forward connections to""" async def forward_connection(self): """Forward a connection to this server""" self._connect_coro = _X11Peer.connect return await super().forward_connection() class _X11ClientChannel(asyncssh.SSHClientChannel): """Patched X11 client channel for unit testing""" async def make_x11_forwarding_request(self, proto, data, screen): """Make a request to enable X11 forwarding""" return await self._make_request(b'x11-req', Boolean(False), String(proto), String(data), UInt32(screen)) class _X11ServerConnection(asyncssh.SSHServerConnection): """Unit test X11 forwarding server connection""" async def attach_x11_listener(self, chan, auth_proto, auth_data, 
screen): """Attach a channel to a remote X11 display""" if screen == 9: return False _X11Server.auth_proto = auth_proto _X11Server.auth_data = auth_data return await super().attach_x11_listener(chan, auth_proto, auth_data, screen) class _X11Server(Server): """Server for testing AsyncSSH X11 forwarding""" auth_proto = b'' auth_data = b'' @staticmethod def _uint16(value, endian): """Encode a 16-bit value using the specified endianness""" if endian == 'B': return bytes((value >> 8, value & 255)) else: return bytes((value & 255, value >> 8)) @staticmethod def _pad(data): """Pad a string to a multiple of 4 bytes""" length = len(data) % 4 return data + ((4 - length) * b'\00' if length else b'') async def _open_x11(self, chan, endian, bad): """Open an X11 connection back to the client""" display = chan.get_x11_display() if display: dpynum = int(display.rsplit(':')[-1].split('.')[0]) else: return 2 reader, writer = await asyncio.open_connection( X11_LISTEN_HOST, X11_BASE_PORT + dpynum) auth_data = bytearray(self.auth_data) if bad: auth_data[-1] ^= 0xff request = b''.join((endian.encode('ascii'), b'\x00', self._uint16(11, endian), self._uint16(0, endian), self._uint16(len(self.auth_proto), endian), self._uint16(len(self.auth_data), endian), b'\x00\x00', self._pad(self.auth_proto), self._pad(auth_data))) writer.write(request[:24]) await asyncio.sleep(0.1) writer.write(request[24:]) result = await reader.read(1) if result == b'': result = b'\x02' if result == b'\x01': writer.write(b'\x00') writer.close() await maybe_wait_closed(writer) return result[0] async def _begin_session(self, stdin, _stdout, _stderr): """Begin processing a new session""" action = stdin.channel.get_command() if action: if action.startswith('connect '): endian = action[8:9] bad = bool(action[9:] == 'X') result = await self._open_x11(stdin.channel, endian, bad) stdin.channel.exit(result) elif action == 'attach': with patch('socket.socket.bind', _failing_bind): result = await self._conn.attach_x11_listener( None, b'', b'', 0) stdin.channel.exit(bool(result)) elif action == 'open': try: result = await self._conn.create_x11_connection(None) except asyncssh.ChannelOpenError: result = None stdin.channel.exit(bool(result)) elif action == 'invalid': try: result = await self._conn.create_x11_connection( None, b'\xff') except asyncssh.ChannelOpenError: pass elif action == 'sleep': await asyncio.sleep(0.1) else: stdin.channel.exit(255) stdin.channel.close() await stdin.channel.wait_closed() def session_requested(self): return self._begin_session @patch('asyncssh.connection.SSHServerConnection', _X11ServerConnection) @patch('asyncssh.x11.SSHX11ClientListener', _X11ClientListener) class _TestX11(ServerTestCase): """Unit tests for AsyncSSH X11 forwarding""" @classmethod def setUpClass(cls): """Create Xauthority file needed for test""" super().setUpClass() auth_data = os.urandom(XAUTH_COOKIE_LEN) with open('.Xauthority', 'wb') as auth_file: auth_file.write(bytes(SSHXAuthorityEntry( XAUTH_FAMILY_HOSTNAME, b'test', b'1', XAUTH_PROTO_COOKIE, auth_data))) auth_file.write(bytes(SSHXAuthorityEntry( XAUTH_FAMILY_HOSTNAME, b'test', b'0', XAUTH_PROTO_COOKIE, auth_data))) auth_file.write(bytes(SSHXAuthorityEntry( XAUTH_FAMILY_IPV4, socket.inet_pton(socket.AF_INET, '127.0.0.2'), b'0', XAUTH_PROTO_COOKIE, auth_data))) auth_file.write(bytes(SSHXAuthorityEntry( XAUTH_FAMILY_IPV4, socket.inet_pton(socket.AF_INET, '127.0.0.1'), b'0', XAUTH_PROTO_COOKIE, auth_data))) auth_file.write(bytes(SSHXAuthorityEntry( XAUTH_FAMILY_IPV6, socket.inet_pton(socket.AF_INET6, 
'::2'), b'0', XAUTH_PROTO_COOKIE, auth_data))) auth_file.write(bytes(SSHXAuthorityEntry( XAUTH_FAMILY_IPV6, socket.inet_pton(socket.AF_INET6, '::1'), b'0', XAUTH_PROTO_COOKIE, auth_data))) # Added to cover case where we don't match on address family auth_file.write(bytes(SSHXAuthorityEntry( XAUTH_FAMILY_DECNET, b'test', b'0', XAUTH_PROTO_COOKIE, auth_data))) # Wildcard address family match auth_file.write(bytes(SSHXAuthorityEntry( XAUTH_FAMILY_WILD, b'', b'0', XAUTH_PROTO_COOKIE, auth_data))) with open('.Xauthority-empty', 'wb'): pass with open('.Xauthority-corrupted', 'wb') as auth_file: auth_file.write(b'\x00\x00\x00') _X11Peer.expected_auth = auth_data @classmethod async def start_server(cls): """Start an SSH server for the tests to use""" return (await cls.create_server( _X11Server, x11_forwarding=True, authorized_client_keys='authorized_keys')) async def _check_x11(self, command=None, *, exc=None, exit_status=None, **kwargs): """Check requesting X11 forwarding""" async with self.connect() as conn: if exc: with self.assertRaises(exc): await _create_x11_process(conn, command, **kwargs) else: proc = await _create_x11_process(conn, command, **kwargs) await proc.wait() self.assertEqual(proc.exit_status, exit_status) @asynctest async def test_xauth_lookup(self): """Test writing an xauth entry and looking it back up""" await update_xauth('xauth', 'test', '0', b'', b'\x00') _, auth_data = await lookup_xauth(asyncio.get_event_loop(), 'xauth', 'test', '0') os.unlink('xauth') self.assertEqual(auth_data, b'\x00') @asynctest async def test_xauth_dead_lock(self): """Test removal of dead Xauthority lock""" with open('xauth-c', 'w'): pass await asyncio.sleep(6) await update_xauth('xauth', 'test', '0', b'', b'\x00') _, auth_data = await lookup_xauth(asyncio.get_event_loop(), 'xauth', 'test', '0') os.unlink('xauth') self.assertEqual(auth_data, b'\x00') @asynctest async def test_xauth_update(self): """Test overwriting an xauth entry""" await update_xauth('xauth', 'test', '0', b'', b'\x00') await update_xauth('xauth', 'test', '0', b'', b'\x01') self.assertEqual(len(list(walk_xauth('xauth'))), 1) _, auth_data = await lookup_xauth(asyncio.get_event_loop(), 'xauth', 'test', '0') os.unlink('xauth') self.assertEqual(auth_data, b'\x01') @asynctest async def test_forward_big(self): """Test requesting X11 forwarding with big-endian connect""" await self._check_x11('connect B', exit_status=1, x11_display='test:0.0', x11_single_connection=True) @asynctest async def test_forward_little(self): """Test requesting X11 forwarding with little-endian connect""" await self._check_x11('connect l', exit_status=1) @asynctest async def test_connection_refused_big(self): """Test the X server refusing connection with big-endian connect""" await self._check_x11('connect B', exit_status=2, x11_display='test:1') @asynctest async def test_connection_refused_little(self): """Test the X server refusing connection with little-endian connect""" await self._check_x11('connect l', exit_status=2, x11_display='test:1') @asynctest async def test_bad_auth_big(self): """Test sending bad auth data with big-endian connect""" await self._check_x11('connect BX', exit_status=0) @asynctest async def test_bad_auth_little(self): """Test sending bad auth data with little-endian connect""" await self._check_x11('connect lX', exit_status=0) @asynctest async def test_ipv4_address(self): """Test matching against an IPv4 address""" await self._check_x11(x11_display='127.0.0.1:0') @asynctest async def test_ipv6_address(self): """Test matching against an 
IPv6 address""" await self._check_x11(x11_display='[::1]:0') @asynctest async def test_wildcard_address(self): """Test matching against a wildcard host entry""" await self._check_x11(x11_display='wild:0') @asynctest async def test_local_server(self): """Test matching against a local X server""" await self._check_x11(x11_display=':0') @asynctest async def test_domain_socket(self): """Test matching against an explicit domain socket""" await self._check_x11(x11_display='/test:0') @asynctest async def test_display_environment(self): """Test getting X11 display from the environment""" os.environ['DISPLAY'] = 'test:0' await self._check_x11(x11_display=None) del os.environ['DISPLAY'] @asynctest async def test_display_not_set(self): """Test requesting X11 forwarding with no display set""" await self._check_x11(exc=asyncssh.ChannelOpenError, x11_display=None) @asynctest async def test_forwarding_denied(self): """Test SSH server denying X11 forwarding""" await self._check_x11(exc=asyncssh.ChannelOpenError, x11_display='test:0.9') @asynctest async def test_xauth_environment(self): """Test getting Xauthority path from the environment""" os.environ['XAUTHORITY'] = '.Xauthority' await self._check_x11() del os.environ['XAUTHORITY'] @asynctest async def test_no_xauth_match(self): """Test no xauth match""" await self._check_x11(x11_display='no_match:1') @asynctest async def test_invalid_display(self): """Test invalid X11 display value""" await self._check_x11(exc=asyncssh.ChannelOpenError, x11_display='test') @asynctest async def test_xauth_missing(self): """Test missing .Xauthority file""" await self._check_x11(x11_auth_path='.Xauthority-missing') @asynctest async def test_xauth_empty(self): """Test empty .Xauthority file""" await self._check_x11(x11_auth_path='.Xauthority-empty') @asynctest async def test_xauth_corrupted(self): """Test .Xauthority file with corrupted entry""" await self._check_x11(exc=asyncssh.ChannelOpenError, x11_auth_path='.Xauthority-corrupted') @asynctest async def test_selective_forwarding(self): """Test requesting X11 forwarding from one session and not another""" async with self.connect() as conn: await conn.create_process('sleep') await _create_x11_process(conn, 'sleep', x11_display='test:0') @asynctest async def test_from_connect(self): """Test requesting X11 forwarding on connection""" async with self.connect(x11_forwarding=True, x11_display='text:0') as conn: await conn.create_process('sleep') @asynctest async def test_multiple_sessions(self): """Test requesting X11 forwarding from two different sessions""" async with self.connect() as conn: await _create_x11_process(conn) await _create_x11_process(conn) @asynctest async def test_simultaneous_sessions(self): """Test X11 forwarding from multiple sessions simultaneously""" async with self.connect() as conn: await _create_x11_process(conn, 'sleep') await _create_x11_process(conn, 'sleep', x11_display='test:0.1') @asynctest async def test_consecutive_different_servers(self): """Test X11 forwarding to different X servers consecutively""" async with self.connect() as conn: proc = await _create_x11_process(conn) await proc.wait() await _create_x11_process(conn, x11_display='test1:0') @asynctest async def test_simultaneous_different_servers(self): """Test X11 forwarding to different X servers simultaneously""" async with self.connect() as conn: await _create_x11_process(conn, 'sleep') with self.assertRaises(asyncssh.ChannelOpenError): await _create_x11_process(conn, x11_display='test1:0') @asynctest async def 
test_forwarding_disabled(self): """Test X11 request when forwarding was never enabled""" async with self.connect() as conn: result = await conn.run('connect l') self.assertEqual(result.exit_status, 2) @asynctest async def test_attach_failure(self): """Test X11 listener attach when forwarding was never enabled""" async with self.connect() as conn: result = await conn.run('attach') self.assertEqual(result.exit_status, 0) @asynctest async def test_attach_lock_failure(self): """Test X11 listener attach when Xauthority can't be locked""" with open('.Xauthority-c', 'w'): pass await self._check_x11('connect l', exc=asyncssh.ChannelOpenError) os.unlink('.Xauthority-c') @asynctest async def test_open_failure(self): """Test opening X11 connection when forwarding was never enabled""" async with self.connect() as conn: result = await conn.run('open') self.assertEqual(result.exit_status, 0) @asynctest async def test_open_invalid_unicode(self): """Test opening X11 connection with invalid unicode in original host""" async with self.connect() as conn: result = await conn.run('invalid') self.assertEqual(result.exit_status, None) @asynctest async def test_forwarding_not_allowed(self): """Test an X11 request from a non-authorized user""" ckey = asyncssh.read_private_key('ckey') cert = ckey.generate_user_certificate(ckey, 'name', principals=['ckey'], permit_x11_forwarding=False) async with self.connect(username='ckey', client_keys=[(ckey, cert)], agent_path=None) as conn: with self.assertRaises(asyncssh.ChannelOpenError): await _create_x11_process(conn, 'connect l') @asynctest async def test_forwarding_ignore_failure(self): """Test ignoring failure on an X11 forwarding request""" ckey = asyncssh.read_private_key('ckey') cert = ckey.generate_user_certificate(ckey, 'name', principals=['ckey'], permit_x11_forwarding=False) async with self.connect(username='ckey', client_keys=[(ckey, cert)], agent_path=None) as conn: proc = await _create_x11_process( conn, x11_forwarding='ignore_failure', x11_display='test') await proc.wait() proc = await _create_x11_process( conn, x11_forwarding='ignore_failure') await proc.wait() @asynctest async def test_invalid_x11_forwarding_request(self): """Test an invalid X11 forwarding request""" with patch('asyncssh.connection.SSHClientChannel', _X11ClientChannel): async with self.connect() as conn: stdin, _, _ = await conn.open_session('sleep') result = await stdin.channel.make_x11_forwarding_request( '', 'xx', 0) self.assertFalse(result) @asynctest async def test_unknown_action(self): """Test unknown action""" async with self.connect() as conn: result = await conn.run('unknown') self.assertEqual(result.exit_status, 255) asyncssh-2.20.0/tests/test_x509.py000066400000000000000000000237531475467777400167540ustar00rootroot00000000000000# Copyright (c) 2017-2020 by Ron Frederick and others. 
# # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation """Unit tests for X.509 certificate handling""" import time import unittest from cryptography import x509 from .util import get_test_key, x509_available if x509_available: # pragma: no branch from asyncssh.crypto import X509Name, X509NamePattern from asyncssh.crypto import generate_x509_certificate from asyncssh.crypto import import_x509_certificate _purpose_secureShellClient = x509.ObjectIdentifier('1.3.6.1.5.5.7.3.21') @unittest.skipUnless(x509_available, 'X.509 not available') class _TestX509(unittest.TestCase): """Unit tests for X.509 module""" @classmethod def setUpClass(cls): cls._privkey = get_test_key('ssh-rsa') cls._pubkey = cls._privkey.convert_to_public() cls._pubdata = cls._pubkey.export_public_key('pkcs8-der') def generate_certificate(self, subject='OU=name', issuer=None, serial=None, valid_after=0, valid_before=0xffffffffffffffff, ca=False, ca_path_len=None, purposes=None, user_principals=(), host_principals=(), hash_alg='sha256', comment=None): """Generate and check an X.509 certificate""" cert = generate_x509_certificate(self._privkey.pyca_key, self._pubkey.pyca_key, subject, issuer, serial, valid_after, valid_before, ca, ca_path_len, purposes, user_principals, host_principals, hash_alg, comment) self.assertEqual(cert.data, import_x509_certificate(cert.data).data) self.assertEqual(cert.subject, X509Name(subject)) self.assertEqual(cert.issuer, X509Name(issuer if issuer else subject)) self.assertEqual(cert.key_data, self._pubdata) if isinstance(comment, str): comment = comment.encode('utf-8') self.assertEqual(cert.comment, comment) return cert def test_generate(self): """Test X.509 certificate generation""" cert = self.generate_certificate(purposes='secureShellClient') self.assertEqual(cert.purposes, {_purpose_secureShellClient}) def test_generate_ca(self): """Test X.509 CA certificate generation""" self.generate_certificate(ca=True, ca_path_len=0) def test_serial(self): """Test X.509 certificate generation with serial number""" self.generate_certificate(serial=1) def test_user_principals(self): """Test X.509 certificate generation with user principals""" cert = self.generate_certificate(user_principals='user1,user2') self.assertEqual(cert.user_principals, ['user1', 'user2']) def test_host_principals(self): """Test X.509 certificate generation with host principals""" cert = self.generate_certificate(host_principals='host1,host2') self.assertEqual(cert.host_principals, ['host1', 'host2']) def test_principal_in_common_name(self): """Test X.509 certificate generation with user principals""" cert = self.generate_certificate(subject='CN=name') self.assertEqual(cert.user_principals, ['name']) self.assertEqual(cert.host_principals, ['name']) def test_comment(self): """Test X.509 certificate generation with comment""" self.generate_certificate(comment=b'comment') self.generate_certificate(comment='comment') def test_unknown_hash(self): """Test X.509 certificate 
generation with unknown hash""" with self.assertRaises(ValueError): self.generate_certificate(hash_alg='xxx') def test_valid_self(self): """Test validation of X.509 self-signed certificate""" cert = self.generate_certificate() self.assertIsNone(cert.validate([cert], None, None, None)) def test_untrusted_self(self): """Test failed validation of untrusted X.509 self-signed certificate""" cert1 = self.generate_certificate() cert2 = self.generate_certificate() with self.assertRaises(ValueError): cert1.validate([cert2], None, None, None) def test_valid_chain(self): """Test validation of X.509 certificate chain""" root_ca = self.generate_certificate('OU=root', ca=True, ca_path_len=1) int_ca = self.generate_certificate('OU=int', 'OU=root', ca=True, ca_path_len=0) cert = self.generate_certificate('OU=user', 'OU=int') self.assertIsNone(cert.validate([int_ca, root_ca], None, None, None)) def test_incomplete_chain(self): """Test failed validation of incomplete X.509 certificate chain""" root_ca = self.generate_certificate('OU=root', ca=True, ca_path_len=1) int_ca = self.generate_certificate('OU=int', 'OU=root', ca=True, ca_path_len=0) cert = self.generate_certificate('OU=user', 'OU=int2') with self.assertRaises(ValueError): cert.validate([int_ca, root_ca], None, None, None) def test_not_yet_valid_self(self): """Test failed validation of not-yet-valid X.509 certificate""" cert = self.generate_certificate(valid_after=time.time() + 60) with self.assertRaises(ValueError): cert.validate([cert], None, None, None) def test_expired_self(self): """Test failed validation of expired X.509 certificate""" cert = self.generate_certificate(valid_before=time.time() - 60) with self.assertRaises(ValueError): cert.validate([cert], None, None, None) def test_expired_intermediate(self): """Test failed validation of expired X.509 intermediate CA""" root_ca = self.generate_certificate('OU=root', ca=True, ca_path_len=1) int_ca = self.generate_certificate('OU=int', 'OU=root', ca=True, ca_path_len=0, valid_before=time.time() - 60) cert = self.generate_certificate('OU=user', 'OU=int') with self.assertRaises(ValueError): cert.validate([int_ca, root_ca], None, None, None) def test_purpose_mismatch(self): """Test failed validation due to purpose mismatch""" cert = self.generate_certificate(purposes='secureShellClient') with self.assertRaises(ValueError): cert.validate([cert], 'secureShellServer', None, None) def test_user_principal_match(self): """Test validation of user principal""" cert = self.generate_certificate(user_principals='user') self.assertIsNone(cert.validate([cert], None, 'user', None)) def test_user_principal_mismatch(self): """Test failed validation due to user principal mismatch""" cert = self.generate_certificate(user_principals='user1,user2') with self.assertRaises(ValueError): cert.validate([cert], None, 'user3', None) def test_host_principal_match(self): """Test validation of host principal""" cert = self.generate_certificate(host_principals='host') self.assertIsNone(cert.validate([cert], None, None, 'host')) def test_host_principal_mismatch(self): """Test failed validation due to host principal mismatch""" cert = self.generate_certificate(host_principals='host1,host2') with self.assertRaises(ValueError): cert.validate([cert], None, None, 'host3') def test_name(self): """Test X.509 distinguished name generation""" name = X509Name('O=Org,OU=Unit') self.assertEqual(name, X509Name('O=Org, OU=Unit')) self.assertEqual(name, X509Name(name)) self.assertEqual(name, X509Name(name.rdns)) self.assertEqual(len(name), 2) 
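        # Note (added for clarity): len(name) counts individual attributes,
        # while name.rdns groups them into relative distinguished names, so
        # the two lengths only differ when one RDN carries multiple
        # attributes (see test_multiple_attrs_in_rdn below).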
self.assertEqual(len(name.rdns), 2) self.assertEqual(str(name), 'O=Org,OU=Unit') self.assertNotEqual(name, X509Name('OU=Unit,O=Org')) def test_multiple_attrs_in_rdn(self): """Test multiple attributes in a relative distinguished name""" name1 = X509Name('O=Org,OU=Unit1+OU=Unit2') name2 = X509Name('O=Org,OU=Unit2+OU=Unit1') self.assertEqual(name1, name2) self.assertEqual(len(name1), 3) self.assertEqual(len(name1.rdns), 2) def test_invalid_attribute(self): """Test X.509 distinguished name with invalid attributes""" with self.assertRaises(ValueError): X509Name('xxx') with self.assertRaises(ValueError): X509Name('X=xxx') def test_exact_name_pattern(self): """Test X.509 distinguished name exact match""" pattern1 = X509NamePattern('O=Org,OU=Unit') pattern2 = X509NamePattern('O=Org, OU=Unit') self.assertEqual(pattern1, pattern2) self.assertEqual(hash(pattern1), hash(pattern2)) self.assertTrue(pattern1.matches(X509Name('O=Org,OU=Unit'))) self.assertFalse(pattern1.matches(X509Name('O=Org,OU=Unit2'))) def test_prefix_pattern(self): """Test X.509 distinguished name prefix match""" pattern = X509NamePattern('O=Org,*') self.assertTrue(pattern.matches(X509Name('O=Org,OU=Unit'))) self.assertFalse(pattern.matches(X509Name('O=Org2,OU=Unit'))) asyncssh-2.20.0/tests/util.py000066400000000000000000000317261475467777400161640ustar00rootroot00000000000000# Copyright (c) 2015-2022 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this # distribution and is available at: # # http://www.eclipse.org/legal/epl-2.0/ # # This program may also be made available under the following secondary # licenses when the conditions for such availability set forth in the # Eclipse Public License v2.0 are satisfied: # # GNU General Public License, Version 2.0, or any later versions of # that license # # SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later # # Contributors: # Ron Frederick - initial implementation, API, and documentation """Utility functions for unit tests""" import asyncio import binascii import functools import os import shutil import socket import subprocess import sys import tempfile import unittest from unittest.mock import patch from asyncssh import set_default_skip_rsa_key_validation from asyncssh.gss import gss_available from asyncssh.logging import logger from asyncssh.misc import ConnectionLost, SignalReceived from asyncssh.packet import Byte, String, UInt32, UInt64 from asyncssh.public_key import generate_private_key # pylint: disable=ungrouped-imports, unused-import try: import bcrypt bcrypt_available = hasattr(bcrypt, 'kdf') except ImportError: # pragma: no cover bcrypt_available = False nc_available = bool(shutil.which('nc')) try: import uvloop uvloop_available = True except ImportError: # pragma: no cover uvloop_available = False try: from asyncssh.crypto import X509Name x509_available = True except ImportError: # pragma: no cover x509_available = False # pylint: enable=ungrouped-imports, unused-import # pylint: disable=no-member if hasattr(asyncio, 'all_tasks'): all_tasks = asyncio.all_tasks current_task = asyncio.current_task else: # pragma: no cover all_tasks = asyncio.Task.all_tasks current_task = asyncio.Task.current_task # pylint: enable=no-member _test_keys = {} set_default_skip_rsa_key_validation(True) def asynctest(coro): """Decorator for async tests, for use with AsyncTestCase""" @functools.wraps(coro) def async_wrapper(self, *args, **kwargs): """Run a coroutine and wait for it 
to finish""" return self.loop.run_until_complete(coro(self, *args, **kwargs)) return async_wrapper def patch_getaddrinfo(cls): """Decorator for patching socket.getaddrinfo""" # pylint: disable=redefined-builtin cls.orig_getaddrinfo = socket.getaddrinfo hosts = {'testhost.test': '', 'testcname.test': 'cname.test', 'cname.test': ''} def getaddrinfo(host, port, family=0, type=0, proto=0, flags=0): """Mock DNS lookup of server hostname""" # pylint: disable=unused-argument try: return [(socket.AF_INET, socket.SOCK_STREAM, socket.IPPROTO_TCP, hosts[host], ('127.0.0.1', port))] except KeyError: return cls.orig_getaddrinfo(host, port, family, type, proto, flags) return patch('socket.getaddrinfo', getaddrinfo)(cls) def patch_getnameinfo(cls): """Decorator for patching socket.getnameinfo""" def getnameinfo(sockaddr, flags): """Mock reverse DNS lookup of client address""" # pylint: disable=unused-argument return ('localhost', sockaddr[1]) return patch('socket.getnameinfo', getnameinfo)(cls) def patch_getnameinfo_error(cls): """Decorator for patching socket.getnameinfo to raise an error""" def getnameinfo_error(sockaddr, flags): """Mock failure of reverse DNS lookup of client address""" # pylint: disable=unused-argument raise socket.gaierror() return patch('socket.getnameinfo', getnameinfo_error)(cls) def patch_extra_kex(cls): """Decorator for skipping extra kex algs""" def skip_extra_kex_algs(self): """Don't send extra key exchange algorithms""" # pylint: disable=unused-argument return [] return patch('asyncssh.connection.SSHConnection._get_extra_kex_algs', skip_extra_kex_algs)(cls) def patch_gss(cls): """Decorator for patching GSSAPI classes""" if not gss_available: # pragma: no cover return cls # pylint: disable=import-outside-toplevel if sys.platform == 'win32': # pragma: no cover from .sspi_stub import SSPIAuth cls = patch('asyncssh.gss_win32.ClientAuth', SSPIAuth)(cls) cls = patch('asyncssh.gss_win32.ServerAuth', SSPIAuth)(cls) else: from .gssapi_stub import Name, Credentials, RequirementFlag from .gssapi_stub import SecurityContext cls = patch('asyncssh.gss_unix.Name', Name)(cls) cls = patch('asyncssh.gss_unix.Credentials', Credentials)(cls) cls = patch('asyncssh.gss_unix.RequirementFlag', RequirementFlag)(cls) cls = patch('asyncssh.gss_unix.SecurityContext', SecurityContext)(cls) return cls async def echo(stdin, stdout, stderr=None): """Echo data from stdin back to stdout and stderr (if open)""" try: while not stdin.at_eof(): data = await stdin.read(65536) if data: stdout.write(data) if stderr: stderr.write(data) await stdout.drain() if stderr: await stderr.drain() stdout.write_eof() except SignalReceived as exc: if exc.signal == 'ABRT': raise ConnectionLost('Abort') from None else: stdin.channel.exit_with_signal(exc.signal) except OSError: pass stdout.close() def _encode_options(options): """Encode SSH certificate critical options and extensions""" return b''.join((String(k) + String(v) for k, v in options.items())) def get_test_key(alg_name, key_id=0, **kwargs): """Generate or return a key with the requested parameters""" params = tuple((alg_name, key_id)) + tuple(kwargs.items()) try: key = _test_keys[params] except KeyError: key = generate_private_key(alg_name, **kwargs) _test_keys[params] = key return key def make_certificate(cert_version, cert_type, key, signing_key, principals, key_id='name', valid_after=0, valid_before=0xffffffffffffffff, options=None, extensions=None, bad_signature=False): """Construct an SSH certificate""" keydata = key.encode_ssh_public() principals = 

def make_certificate(cert_version, cert_type, key, signing_key, principals,
                     key_id='name', valid_after=0,
                     valid_before=0xffffffffffffffff, options=None,
                     extensions=None, bad_signature=False):
    """Construct an SSH certificate"""

    keydata = key.encode_ssh_public()
    principals = b''.join(String(p) for p in principals)
    options = _encode_options(options) if options else b''
    extensions = _encode_options(extensions) if extensions else b''
    signing_keydata = b''.join((String(signing_key.algorithm),
                                signing_key.encode_ssh_public()))

    data = b''.join((String(cert_version), String(os.urandom(32)), keydata,
                     UInt64(0), UInt32(cert_type), String(key_id),
                     String(principals), UInt64(valid_after),
                     UInt64(valid_before), String(options),
                     String(extensions), String(''),
                     String(signing_keydata)))

    if bad_signature:
        data += String('')
    else:
        data += String(signing_key.sign(data, signing_key.sig_algorithms[0]))

    return b''.join((cert_version.encode('ascii'), b' ',
                     binascii.b2a_base64(data)))


def run(cmd):
    """Run a shell command and return the output"""

    try:
        return subprocess.check_output(cmd, shell=True,
                                       stderr=subprocess.STDOUT)
    except subprocess.CalledProcessError as exc: # pragma: no cover
        logger.error('Error running command: %s' % cmd)
        logger.error(exc.output.decode())
        raise


def try_remove(filename):
    """Try to remove a file, ignoring errors"""

    try:
        os.remove(filename)
    except OSError: # pragma: no cover
        pass


class ConnectionStub:
    """Stub class used to replace an SSHConnection object"""

    def __init__(self, peer, server):
        self._peer = peer
        self._server = server

        if peer:
            self._packet_queue = asyncio.queues.Queue()
            self._queue_task = self.create_task(self._process_packets())
        else:
            self._packet_queue = None
            self._queue_task = None

        self._logger = logger.get_child(context='conn=99')

    @property
    def logger(self):
        """A logger associated with this connection"""

        return self._logger

    async def _run_task(self, coro):
        """Run an asynchronous task"""

        # pylint: disable=broad-except

        try:
            await coro
        except Exception as exc:
            if self._peer: # pragma: no branch
                self.queue_packet(exc)

            self.connection_lost(exc)

    def create_task(self, coro):
        """Create an asynchronous task"""

        return asyncio.ensure_future(self._run_task(coro))

    def is_client(self):
        """Return if this is a client connection"""

        return not self._server

    def is_server(self):
        """Return if this is a server connection"""

        return self._server

    def get_peer(self):
        """Return the peer of this connection"""

        return self._peer

    async def _process_packets(self):
        """Process the queue of incoming packets"""

        while True:
            data = await self._packet_queue.get()

            if data is None or isinstance(data, Exception):
                self._queue_task = None
                self.connection_lost(data)
                break

            await self.process_packet(data)

    def connection_lost(self, exc):
        """Handle the closing of a connection"""

        raise NotImplementedError

    def process_packet(self, data):
        """Process an incoming packet"""

        raise NotImplementedError

    def queue_packet(self, data):
        """Add an incoming packet to the queue"""

        self._packet_queue.put_nowait(data)

    def send_packet(self, pkttype, *args, **kwargs):
        """Send a packet to this connection's peer"""

        # pylint: disable=unused-argument

        if self._peer:
            self._peer.queue_packet(Byte(pkttype) + b''.join(args))

    def close(self):
        """Close the connection, stopping processing of incoming packets"""

        if self._peer:
            self._peer.queue_packet(None)
            self._peer = None

        if self._queue_task:
            self.queue_packet(None)
            self._queue_task = None

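
# Illustrative sketch only, not part of the original test utilities:
# ConnectionStub leaves connection_lost() and process_packet() unimplemented,
# so tests are expected to subclass it. This hypothetical subclass simply
# records what it receives; note that process_packet() is written as a
# coroutine because _process_packets() awaits it.

class _ExampleConnectionStub(ConnectionStub):
    """Hypothetical ConnectionStub subclass which records incoming packets"""

    def __init__(self, peer, server):
        super().__init__(peer, server)

        self.received = []
        self.lost_exc = None

    async def process_packet(self, data):
        """Record an incoming packet"""

        self.received.append(data)

    def connection_lost(self, exc):
        """Record the exception (if any) which closed the connection"""

        self.lost_exc = exc
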
tearDownClass""" cls._class_cleanups.append((function, args, kwargs)) @classmethod def tearDownClass(cls): """Run cleanups after tearDown""" super().tearDownClass() while cls._class_cleanups: function, args, kwargs = cls._class_cleanups.pop() function(*args, **kwargs) class TempDirTestCase(ClassCleanupTestCase): """Unit test class which operates in a temporary directory""" _tempdir = None @classmethod def setUpClass(cls): """Create temporary directory and set it as current directory""" cls._tempdir = tempfile.TemporaryDirectory() os.chdir(cls._tempdir.name) @classmethod def tearDownClass(cls): """Clean up temporary directory""" os.chdir('..') cls._tempdir.cleanup() class AsyncTestCase(TempDirTestCase): """Unit test class which supports tests using asyncio""" loop = None @classmethod def setUpClass(cls): """Set up event loop to run async tests and run async class setup""" super().setUpClass() if uvloop_available and os.environ.get('USE_UVLOOP'): # pragma: no cover cls.loop = uvloop.new_event_loop() else: cls.loop = asyncio.new_event_loop() asyncio.set_event_loop(cls.loop) try: cls.loop.run_until_complete(cls.asyncSetUpClass()) except AttributeError: pass @classmethod def tearDownClass(cls): """Run async class teardown and close event loop""" try: cls.loop.run_until_complete(cls.asyncTearDownClass()) except AttributeError: pass cls.loop.close() super().tearDownClass() def setUp(self): """Run async setup if any""" try: self.loop.run_until_complete(self.asyncSetUp()) except AttributeError: pass def tearDown(self): """Run async teardown if any""" try: self.loop.run_until_complete(self.asyncTearDown()) except AttributeError: pass asyncssh-2.20.0/tox.ini000066400000000000000000000021441475467777400147760ustar00rootroot00000000000000[tox] minversion = 3.8 skip_missing_interpreters = True envlist = clean report py3{8,9,10,11,12,13}-{linux,darwin,windows} [testenv] deps = aiofiles>=0.6.0 bcrypt>=3.1.3 fido2>=0.9.2 libnacl>=1.4.2 pyOpenSSL>=17.0.0 pytest>=7.0.1 pytest-cov>=3.0.0 setuptools>=18.5 linux,darwin: gssapi>=1.2.0 linux,darwin: python-pkcs11>=0.7.0 linux,darwin: uvloop>=0.9.1 windows: pywin32>=227 platform = linux: linux darwin: darwin windows: win32 usedevelop = True setenv = PIP_USE_PEP517 = 1 COVERAGE_FILE = .coverage.{envname} commands = {envpython} -m pytest --cov --cov-report=term-missing:skip-covered {posargs} depends = clean [testenv:clean] deps = coverage skip_install = true setenv = COVERAGE_FILE = commands = coverage erase depends = [testenv:report] deps = coverage skip_install = true parallel_show_output = true setenv = COVERAGE_FILE = commands = coverage combine coverage report --show-missing coverage html coverage xml depends = py3{8,9,10,11,12,13}-{linux,darwin,windows} [pytest] testpaths = tests